| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- """上下文获取工具 - 让模型按需获取扩展上下文
- 设计理念(借鉴 Claude Code):
- - 保底上下文由 ContextBuilder 自动注入(系统提示、对话历史、上次工具摘要)
- - 扩展上下文通过此工具按需获取(notes、memory、files、tests)
- - 模型自行决定何时需要更多证据,避免盲目全局扫描
- """
- from typing import Dict, Any, List, Optional
- import subprocess
- import os
- from ..base import Tool, ToolParameter
- class ContextFetchTool(Tool):
- """上下文获取工具
-
- 让模型按需获取扩展上下文,支持多种数据源:
- - notes: 检索笔记(blocker、insight、decision 等)
- - memory: 检索情景记忆(之前的对话/经验)
- - files: 搜索代码文件(rg + 上下文行)
- - tests: 获取最近测试失败信息
-
- 使用场景:
- - 模型发现证据不足时主动调用
- - 提到类名/函数名/错误栈时获取相关代码
- - 询问"之前做了什么"时检索记忆
- """
-
- def __init__(
- self,
- workspace: str,
- note_tool: Optional[Any] = None,
- memory_tool: Optional[Any] = None,
- max_tokens_per_source: int = 800,
- context_lines: int = 5, # 命中行前后各取 k 行
- ):
- super().__init__(
- name="context_fetch",
- description=(
- "获取扩展上下文。当保底上下文不足以回答问题时调用。"
- "可指定数据源:notes(笔记)、memory(记忆)、files(代码文件)、tests(测试结果)。"
- "返回结构化的证据块。"
- ),
- )
- self.workspace = workspace
- self.note_tool = note_tool
- self.memory_tool = memory_tool
- self.max_tokens_per_source = max_tokens_per_source
- self.context_lines = context_lines
-
- # 缓存最近的查询结果,避免重复查询
- self._cache: Dict[str, str] = {}
- self._cache_max_size = 20
-
- def get_parameters(self) -> List[ToolParameter]:
- """遵循基类接口返回参数定义"""
- return [
- ToolParameter(
- name="sources",
- type="array",
- description="要查询的数据源列表,可选: notes, memory, files, tests",
- required=True,
- ),
- ToolParameter(
- name="query",
- type="string",
- description="搜索关键词/符号名/错误栈片段",
- required=True,
- ),
- ToolParameter(
- name="paths",
- type="string",
- description="限定文件搜索范围的 glob 模式,如 'src/**/*.py'",
- required=False,
- ),
- ToolParameter(
- name="budget_tokens",
- type="integer",
- description="单个数据源的 token 上限,默认 800",
- required=False,
- ),
- ]
-
- def run(self, parameters: Dict[str, Any]) -> str:
- """执行上下文获取"""
- sources = parameters.get("sources", [])
- query = parameters.get("query", "")
- paths = parameters.get("paths", "")
- budget = parameters.get("budget_tokens", self.max_tokens_per_source)
-
- if not sources or not query:
- return "错误:必须指定 sources 和 query 参数"
-
- # 检查缓存
- cache_key = f"{','.join(sorted(sources))}|{query}|{paths}"
- if cache_key in self._cache:
- return f"[缓存命中]\n{self._cache[cache_key]}"
-
- results: List[str] = []
-
- for source in sources:
- if source == "notes":
- result = self._fetch_notes(query, budget)
- elif source == "memory":
- result = self._fetch_memory(query, budget)
- elif source == "files":
- result = self._fetch_files(query, paths, budget)
- elif source == "tests":
- result = self._fetch_tests(query, budget)
- else:
- result = f"[{source}] 未知数据源"
-
- if result:
- results.append(result)
-
- output = "\n\n".join(results) if results else "未找到相关上下文"
-
- # 更新缓存
- if len(self._cache) >= self._cache_max_size:
- # 简单 LRU:删除最早的
- oldest_key = next(iter(self._cache))
- del self._cache[oldest_key]
- self._cache[cache_key] = output
-
- return output
-
- def _fetch_notes(self, query: str, budget: int) -> str:
- """从笔记中检索"""
- if not self.note_tool:
- return "[notes] 笔记工具未配置"
-
- try:
- # 搜索相关笔记
- result = self.note_tool.run({
- "action": "search",
- "query": query,
- "limit": 5,
- })
- if result and "未找到" not in result:
- return f"[notes] 相关笔记:\n{self._truncate(result, budget)}"
- return "[notes] 未找到相关笔记"
- except Exception as e:
- return f"[notes] 检索失败: {e}"
-
- def _fetch_memory(self, query: str, budget: int) -> str:
- """从记忆中检索"""
- if not self.memory_tool:
- return "[memory] 记忆工具未配置"
-
- try:
- result = self.memory_tool.run({
- "action": "search",
- "query": query,
- "memory_types": getattr(self.memory_tool, "memory_types", ["episodic"]),
- "limit": 5,
- "min_importance": 0.0,
- })
- if result and "未找到" not in result:
- return f"[memory] 相关记忆:\n{self._truncate(result, budget)}"
- return "[memory] 未找到相关记忆"
- except Exception as e:
- return f"[memory] 检索失败: {e}"
-
- def _fetch_files(self, query: str, paths: str, budget: int) -> str:
- """从代码文件中检索"""
- try:
- # 使用 ripgrep 搜索
- cmd = ["rg", "--color=never", "-n", "-C", str(self.context_lines)]
-
- if paths:
- cmd.extend(["-g", paths])
-
- cmd.append(query)
- cmd.append(self.workspace)
-
- result = subprocess.run(
- cmd,
- capture_output=True,
- text=True,
- timeout=10,
- cwd=self.workspace,
- )
-
- output = result.stdout.strip()
- if output:
- # 结构化输出
- lines = output.split("\n")
- # 按文件分组
- grouped = self._group_by_file(lines)
- formatted = self._format_file_results(grouped, budget)
- return f"[files] 代码搜索结果:\n{formatted}"
- return f"[files] 未找到匹配 '{query}' 的内容"
- except subprocess.TimeoutExpired:
- return "[files] 搜索超时"
- except FileNotFoundError:
- # ripgrep 未安装,降级到 grep
- return self._fetch_files_fallback(query, paths, budget)
- except Exception as e:
- return f"[files] 搜索失败: {e}"
-
- def _fetch_files_fallback(self, query: str, paths: str, budget: int) -> str:
- """ripgrep 不可用时的降级方案"""
- try:
- cmd = f"grep -rn '{query}' {self.workspace}"
- if paths:
- cmd = f"find {self.workspace} -path '{paths}' -type f | xargs grep -n '{query}'"
-
- result = subprocess.run(
- cmd,
- shell=True,
- capture_output=True,
- text=True,
- timeout=10,
- )
- output = result.stdout.strip()
- if output:
- return f"[files] grep 结果:\n{self._truncate(output, budget)}"
- return f"[files] 未找到匹配 '{query}' 的内容"
- except Exception as e:
- return f"[files] grep 搜索失败: {e}"
-
- def _fetch_tests(self, query: str, budget: int) -> str:
- """获取测试相关信息"""
- # 查找最近的测试输出/日志
- test_patterns = [
- ".pytest_cache/v/cache/lastfailed",
- "test-results.xml",
- ".coverage",
- ]
-
- results = []
- for pattern in test_patterns:
- path = os.path.join(self.workspace, pattern)
- if os.path.exists(path):
- try:
- with open(path, "r", encoding="utf-8", errors="ignore") as f:
- content = f.read()
- if query.lower() in content.lower():
- results.append(f"[tests] {pattern}:\n{self._truncate(content, budget // 2)}")
- except Exception:
- pass
-
- if results:
- return "\n".join(results)
- return "[tests] 未找到相关测试信息"
-
- def _group_by_file(self, lines: List[str]) -> Dict[str, List[str]]:
- """按文件分组 ripgrep 输出"""
- grouped: Dict[str, List[str]] = {}
- current_file = None
-
- for line in lines:
- if ":" in line:
- # 格式: file:line:content 或 file-line-content
- parts = line.split(":", 2) if ":" in line else line.split("-", 2)
- if len(parts) >= 2:
- file_path = parts[0]
- if file_path != current_file:
- current_file = file_path
- grouped[current_file] = []
- grouped[current_file].append(line)
- elif current_file:
- grouped[current_file].append(line)
-
- return grouped
-
- def _format_file_results(self, grouped: Dict[str, List[str]], budget: int) -> str:
- """格式化文件搜索结果"""
- output_parts = []
- tokens_used = 0
- tokens_per_file = budget // max(len(grouped), 1)
-
- for file_path, lines in grouped.items():
- content = "\n".join(lines)
- truncated = self._truncate(content, tokens_per_file)
-
- # 相对路径
- rel_path = file_path.replace(self.workspace, "").lstrip("/")
- output_parts.append(f"--- {rel_path} ---\n{truncated}")
-
- tokens_used += len(truncated) // 4 # 粗略估算
- if tokens_used >= budget:
- output_parts.append("...(更多结果已截断)...")
- break
-
- return "\n\n".join(output_parts)
-
- def _truncate(self, text: str, max_tokens: int) -> str:
- """截断文本到指定 token 上限"""
- # 粗略估算:1 token ≈ 4 字符(英文),2 字符(中文)
- max_chars = max_tokens * 3
- if len(text) <= max_chars:
- return text
- return text[:max_chars] + "\n...(已截断)..."
-
- def clear_cache(self):
- """清空缓存"""
- self._cache.clear()
|