context_fetch_tool.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
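"""Context fetch tool: lets the model fetch extended context on demand.

Design (borrowed from Claude Code):
- Baseline context is injected automatically by ContextBuilder (system prompt,
  conversation history, summary of the last tool call).
- Extended context is fetched on demand through this tool (notes, memory,
  files, tests).
- The model decides for itself when it needs more evidence, which avoids blind
  workspace-wide scans.
"""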
import base64
import glob
import json
import os
import subprocess
from typing import Any, Dict, List, Optional

from ..base import Tool, ToolParameter
  11. class ContextFetchTool(Tool):
  12. """上下文获取工具
  13. 让模型按需获取扩展上下文,支持多种数据源:
  14. - notes: 检索笔记(blocker、insight、decision 等)
  15. - memory: 检索情景记忆(之前的对话/经验)
  16. - files: 搜索代码文件(rg + 上下文行)
  17. - tests: 获取最近测试失败信息
  18. 使用场景:
  19. - 模型发现证据不足时主动调用
  20. - 提到类名/函数名/错误栈时获取相关代码
  21. - 询问"之前做了什么"时检索记忆
  22. """
  23. def __init__(
  24. self,
  25. workspace: str,
  26. note_tool: Optional[Any] = None,
  27. memory_tool: Optional[Any] = None,
  28. max_tokens_per_source: int = 800,
  29. context_lines: int = 5, # 命中行前后各取 k 行
  30. ):
  31. super().__init__(
  32. name="context_fetch",
  33. description=(
  34. "获取扩展上下文。当保底上下文不足以回答问题时调用。"
  35. "可指定数据源:notes(笔记)、memory(记忆)、files(代码文件)、tests(测试结果)。"
  36. "返回结构化的证据块。"
  37. ),
  38. )
  39. self.workspace = workspace
  40. self.note_tool = note_tool
  41. self.memory_tool = memory_tool
  42. self.max_tokens_per_source = max_tokens_per_source
  43. self.context_lines = context_lines
  44. # 缓存最近的查询结果,避免重复查询
  45. self._cache: Dict[str, str] = {}
  46. self._cache_max_size = 20
  47. def get_parameters(self) -> List[ToolParameter]:
  48. """遵循基类接口返回参数定义"""
  49. return [
  50. ToolParameter(
  51. name="sources",
  52. type="array",
  53. description="要查询的数据源列表,可选: notes, memory, files, tests",
  54. required=True,
  55. ),
  56. ToolParameter(
  57. name="query",
  58. type="string",
  59. description="搜索关键词/符号名/错误栈片段",
  60. required=True,
  61. ),
  62. ToolParameter(
  63. name="paths",
  64. type="string",
  65. description="限定文件搜索范围的 glob 模式,如 'src/**/*.py'",
  66. required=False,
  67. ),
  68. ToolParameter(
  69. name="budget_tokens",
  70. type="integer",
  71. description="单个数据源的 token 上限,默认 800",
  72. required=False,
  73. ),
  74. ]
  75. def run(self, parameters: Dict[str, Any]) -> str:
  76. """执行上下文获取"""
  77. sources = parameters.get("sources", [])
  78. query = parameters.get("query", "")
  79. paths = parameters.get("paths", "")
  80. budget = parameters.get("budget_tokens", self.max_tokens_per_source)
  81. if not sources or not query:
  82. return "错误:必须指定 sources 和 query 参数"
  83. # 检查缓存
  84. cache_key = f"{','.join(sorted(sources))}|{query}|{paths}"
  85. if cache_key in self._cache:
  86. return f"[缓存命中]\n{self._cache[cache_key]}"
  87. results: List[str] = []
  88. for source in sources:
  89. if source == "notes":
  90. result = self._fetch_notes(query, budget)
  91. elif source == "memory":
  92. result = self._fetch_memory(query, budget)
  93. elif source == "files":
  94. result = self._fetch_files(query, paths, budget)
  95. elif source == "tests":
  96. result = self._fetch_tests(query, budget)
  97. else:
  98. result = f"[{source}] 未知数据源"
  99. if result:
  100. results.append(result)
  101. output = "\n\n".join(results) if results else "未找到相关上下文"
  102. # 更新缓存
  103. if len(self._cache) >= self._cache_max_size:
  104. # 简单 LRU:删除最早的
  105. oldest_key = next(iter(self._cache))
  106. del self._cache[oldest_key]
  107. self._cache[cache_key] = output
  108. return output
  109. def _fetch_notes(self, query: str, budget: int) -> str:
  110. """从笔记中检索"""
  111. if not self.note_tool:
  112. return "[notes] 笔记工具未配置"
  113. try:
  114. # 搜索相关笔记
  115. result = self.note_tool.run({
  116. "action": "search",
  117. "query": query,
  118. "limit": 5,
  119. })
  120. if result and "未找到" not in result:
  121. return f"[notes] 相关笔记:\n{self._truncate(result, budget)}"
  122. return "[notes] 未找到相关笔记"
  123. except Exception as e:
  124. return f"[notes] 检索失败: {e}"
  125. def _fetch_memory(self, query: str, budget: int) -> str:
  126. """从记忆中检索"""
  127. if not self.memory_tool:
  128. return "[memory] 记忆工具未配置"
  129. try:
  130. result = self.memory_tool.run({
  131. "action": "search",
  132. "query": query,
  133. "memory_types": getattr(self.memory_tool, "memory_types", ["episodic"]),
  134. "limit": 5,
  135. "min_importance": 0.0,
  136. })
  137. if result and "未找到" not in result:
  138. return f"[memory] 相关记忆:\n{self._truncate(result, budget)}"
  139. return "[memory] 未找到相关记忆"
  140. except Exception as e:
  141. return f"[memory] 检索失败: {e}"
  142. def _fetch_files(self, query: str, paths: str, budget: int) -> str:
  143. """从代码文件中检索"""
  144. try:
  145. # 使用 ripgrep 搜索
  146. cmd = ["rg", "--color=never", "-n", "-C", str(self.context_lines)]
  147. if paths:
  148. cmd.extend(["-g", paths])
  149. cmd.append(query)
  150. cmd.append(self.workspace)
  151. result = subprocess.run(
  152. cmd,
  153. capture_output=True,
  154. text=True,
  155. timeout=10,
  156. cwd=self.workspace,
  157. )
  158. output = result.stdout.strip()
  159. if output:
  160. # 结构化输出
  161. lines = output.split("\n")
  162. # 按文件分组
  163. grouped = self._group_by_file(lines)
  164. formatted = self._format_file_results(grouped, budget)
  165. return f"[files] 代码搜索结果:\n{formatted}"
  166. return f"[files] 未找到匹配 '{query}' 的内容"
  167. except subprocess.TimeoutExpired:
  168. return "[files] 搜索超时"
  169. except FileNotFoundError:
  170. # ripgrep 未安装,降级到 grep
  171. return self._fetch_files_fallback(query, paths, budget)
  172. except Exception as e:
  173. return f"[files] 搜索失败: {e}"
  174. def _fetch_files_fallback(self, query: str, paths: str, budget: int) -> str:
  175. """ripgrep 不可用时的降级方案"""
  176. try:
  177. cmd = f"grep -rn '{query}' {self.workspace}"
  178. if paths:
  179. cmd = f"find {self.workspace} -path '{paths}' -type f | xargs grep -n '{query}'"
  180. result = subprocess.run(
  181. cmd,
  182. shell=True,
  183. capture_output=True,
  184. text=True,
  185. timeout=10,
  186. )
  187. output = result.stdout.strip()
  188. if output:
  189. return f"[files] grep 结果:\n{self._truncate(output, budget)}"
  190. return f"[files] 未找到匹配 '{query}' 的内容"
  191. except Exception as e:
  192. return f"[files] grep 搜索失败: {e}"
  193. def _fetch_tests(self, query: str, budget: int) -> str:
  194. """获取测试相关信息"""
  195. # 查找最近的测试输出/日志
  196. test_patterns = [
  197. ".pytest_cache/v/cache/lastfailed",
  198. "test-results.xml",
  199. ".coverage",
  200. ]
  201. results = []
  202. for pattern in test_patterns:
  203. path = os.path.join(self.workspace, pattern)
  204. if os.path.exists(path):
  205. try:
  206. with open(path, "r", encoding="utf-8", errors="ignore") as f:
  207. content = f.read()
  208. if query.lower() in content.lower():
  209. results.append(f"[tests] {pattern}:\n{self._truncate(content, budget // 2)}")
  210. except Exception:
  211. pass
  212. if results:
  213. return "\n".join(results)
  214. return "[tests] 未找到相关测试信息"
  215. def _group_by_file(self, lines: List[str]) -> Dict[str, List[str]]:
  216. """按文件分组 ripgrep 输出"""
  217. grouped: Dict[str, List[str]] = {}
  218. current_file = None
  219. for line in lines:
  220. if ":" in line:
  221. # 格式: file:line:content 或 file-line-content
  222. parts = line.split(":", 2) if ":" in line else line.split("-", 2)
  223. if len(parts) >= 2:
  224. file_path = parts[0]
  225. if file_path != current_file:
  226. current_file = file_path
  227. grouped[current_file] = []
  228. grouped[current_file].append(line)
  229. elif current_file:
  230. grouped[current_file].append(line)
  231. return grouped
  232. def _format_file_results(self, grouped: Dict[str, List[str]], budget: int) -> str:
  233. """格式化文件搜索结果"""
  234. output_parts = []
  235. tokens_used = 0
  236. tokens_per_file = budget // max(len(grouped), 1)
  237. for file_path, lines in grouped.items():
  238. content = "\n".join(lines)
  239. truncated = self._truncate(content, tokens_per_file)
  240. # 相对路径
  241. rel_path = file_path.replace(self.workspace, "").lstrip("/")
  242. output_parts.append(f"--- {rel_path} ---\n{truncated}")
  243. tokens_used += len(truncated) // 4 # 粗略估算
  244. if tokens_used >= budget:
  245. output_parts.append("...(更多结果已截断)...")
  246. break
  247. return "\n\n".join(output_parts)
  248. def _truncate(self, text: str, max_tokens: int) -> str:
  249. """截断文本到指定 token 上限"""
  250. # 粗略估算:1 token ≈ 4 字符(英文),2 字符(中文)
  251. max_chars = max_tokens * 3
  252. if len(text) <= max_chars:
  253. return text
  254. return text[:max_chars] + "\n...(已截断)..."
  255. def clear_cache(self):
  256. """清空缓存"""
  257. self._cache.clear()