context_manager.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """Step 8: 上下文工程 — 对话压缩、Token 管理、多轮连贯性"""
  2. import json
  3. import os
  4. from datetime import datetime
  5. from typing import List, Dict, Optional
  6. class ContextManager:
  7. """对话上下文管理器:压缩历史、控制 Token 用量、保持连贯性"""
  8. def __init__(self, max_tokens: int = 4000, summary_trigger: int = 3000):
  9. self.max_tokens = max_tokens # 上下文最大 token 数
  10. self.summary_trigger = summary_trigger # 触发压缩的阈值
  11. self.turns: List[Dict] = [] # 对话轮次
  12. self.summary: str = "" # 压缩后的摘要
  13. self.total_turns = 0
  14. @staticmethod
  15. def _estimate_tokens(text: str) -> int:
  16. """简单 Token 估算:中文 ~1.5 字/token,英文 ~4 字/token"""
  17. chinese = sum(1 for c in text if '一' <= c <= '鿿')
  18. other = len(text) - chinese
  19. return int(chinese / 1.5 + other / 4)
  20. def add_turn(self, role: str, content: str):
  21. """添加一轮对话"""
  22. self.total_turns += 1
  23. turn = {
  24. "id": self.total_turns,
  25. "role": role,
  26. "content": content,
  27. "tokens": self._estimate_tokens(content),
  28. "time": datetime.now().strftime("%H:%M:%S"),
  29. }
  30. self.turns.append(turn)
  31. # 检查是否需要压缩
  32. total = sum(t["tokens"] for t in self.turns)
  33. if total > self.summary_trigger:
  34. self._compress()
  35. def _compress(self):
  36. """压缩早期对话为摘要"""
  37. if len(self.turns) <= 4:
  38. return # 保留最近 4 轮
  39. # 取最早的 60% 轮次进行压缩
  40. split = max(1, int(len(self.turns) * 0.6))
  41. old_turns = self.turns[:split]
  42. recent = self.turns[split:]
  43. # 生成摘要
  44. lines = []
  45. for t in old_turns:
  46. role_label = "用户" if t["role"] == "user" else "助手"
  47. snippet = t["content"][:200].replace("\n", " ")
  48. lines.append(f"[{role_label}]: {snippet}")
  49. new_summary = "对话历史摘要:\n" + "\n".join(lines)
  50. if self.summary:
  51. self.summary = self.summary[:500] + "\n...\n" + new_summary
  52. else:
  53. self.summary = new_summary
  54. # 限制摘要长度
  55. if self._estimate_tokens(self.summary) > 1500:
  56. self.summary = self.summary[-1500:]
  57. self.turns = recent
  58. def get_context(self, system_prompt: str = "",
  59. current_query: str = "") -> str:
  60. """构建当前上下文字符串"""
  61. parts = []
  62. # 压缩摘要
  63. if self.summary:
  64. parts.append(f"## 历史对话摘要\n{self.summary[:2000]}")
  65. # 最近对话
  66. if self.turns:
  67. parts.append("## 最近对话")
  68. for t in self.turns[-8:]: # 最近 8 轮
  69. role_label = "用户" if t["role"] == "user" else "助手"
  70. content = t["content"]
  71. if self._estimate_tokens(content) > 500:
  72. content = content[:500] + "..."
  73. parts.append(f"### {role_label}\n{content}")
  74. return "\n\n".join(parts)
  75. def get_stats(self) -> str:
  76. """获取上下文使用统计"""
  77. total = sum(t["tokens"] for t in self.turns)
  78. summary_tokens = self._estimate_tokens(self.summary) if self.summary else 0
  79. return (f"上下文: {len(self.turns)} 活跃轮次, "
  80. f"约 {total} tokens 活跃 + {summary_tokens} tokens 摘要, "
  81. f"总计 {self.total_turns} 轮对话")
  82. def clear(self):
  83. self.turns = []
  84. self.summary = ""
  85. self.total_turns = 0
  86. # ===== 上下文感知的 System Prompt 构建器 =====
  87. def build_context_aware_prompt(
  88. ctx: ContextManager,
  89. base_prompt: str,
  90. user_query: str,
  91. memory_context: str = "",
  92. kb_context: str = "",
  93. ) -> str:
  94. """构建完整上下文感知的系统消息"""
  95. parts = [base_prompt]
  96. # 对话上下文
  97. context_str = ctx.get_context()
  98. if context_str:
  99. parts.append(f"\n## 当前对话上下文\n{context_str}")
  100. # 记忆上下文
  101. if memory_context:
  102. parts.append(f"\n## 用户记忆\n{memory_context}")
  103. # 知识库上下文
  104. if kb_context:
  105. parts.append(f"\n## 相关知识\n{kb_context}")
  106. return "\n".join(parts)
  107. # 全局单例
  108. _ctx_instance: Optional[ContextManager] = None
  109. def get_context() -> ContextManager:
  110. global _ctx_instance
  111. if _ctx_instance is None:
  112. _ctx_instance = ContextManager()
  113. return _ctx_instance
  114. # ===== 工具函数 =====
  115. def context_stats(query: str = "") -> str:
  116. """查看当前上下文使用统计"""
  117. return get_context().get_stats()
  118. def context_clear(query: str = "") -> str:
  119. """清空上下文(开始新会话)"""
  120. get_context().clear()
  121. return "上下文已清空,开始新会话。"
  122. def context_summarize(query: str = "") -> str:
  123. """手动触发上下文压缩"""
  124. ctx = get_context()
  125. ctx._compress()
  126. return ctx.get_stats()