simple_agent.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. """简单Agent实现 - 基于OpenAI原生API"""
  2. from typing import Optional, Iterator, TYPE_CHECKING, Callable
  3. import re
  4. from core.agent import Agent
  5. from core.llm import HelloAgentsLLM
  6. from core.config import Config
  7. from core.message import Message
  8. if TYPE_CHECKING:
  9. from tools.registry import ToolRegistry
  10. class SimpleAgent(Agent):
  11. """简单的对话Agent,支持可选的工具调用"""
  12. def __init__(
  13. self,
  14. name: str,
  15. llm: HelloAgentsLLM,
  16. system_prompt: Optional[str] = None,
  17. config: Optional[Config] = None,
  18. tool_registry: Optional['ToolRegistry'] = None,
  19. enable_tool_calling: bool = True,
  20. tool_confirm_callback: Optional[Callable[[str, dict], bool]] = None,
  21. ):
  22. """
  23. 初始化SimpleAgent
  24. Args:
  25. name: Agent名称
  26. llm: LLM实例
  27. system_prompt: 系统提示词
  28. config: 配置对象
  29. tool_registry: 工具注册表(可选,如果提供则启用工具调用)
  30. enable_tool_calling: 是否启用工具调用(只有在提供tool_registry时生效)
  31. """
  32. super().__init__(name, llm, system_prompt, config)
  33. self.tool_registry = tool_registry
  34. self.enable_tool_calling = enable_tool_calling and tool_registry is not None
  35. self.tool_confirm_callback = tool_confirm_callback
  36. def _get_enhanced_system_prompt(self) -> str:
  37. """构建增强的系统提示词,包含工具信息"""
  38. base_prompt = self.system_prompt or "你是一个有用的AI助手。"
  39. if not self.enable_tool_calling or not self.tool_registry:
  40. return base_prompt
  41. # 获取工具描述
  42. tools_description = self.tool_registry.get_tools_description()
  43. if not tools_description or tools_description == "暂无可用工具":
  44. return base_prompt
  45. tools_section = "\n\n## 可用工具\n"
  46. tools_section += "你可以使用以下工具来帮助回答问题:\n"
  47. tools_section += tools_description + "\n"
  48. tools_section += "\n## 工具调用格式\n"
  49. tools_section += "当需要使用工具时,请使用以下格式:\n"
  50. tools_section += "`[TOOL_CALL:{tool_name}:{parameters}]`\n\n"
  51. tools_section += "### 参数格式说明\n"
  52. tools_section += "1. **多个参数**:使用 `key=value` 格式,用逗号分隔\n"
  53. tools_section += " 示例:`[TOOL_CALL:calculator_multiply:a=12,b=8]`\n"
  54. tools_section += " 示例:`[TOOL_CALL:filesystem_read_file:path=README.md]`\n\n"
  55. tools_section += "2. **单个参数**:直接使用 `key=value`\n"
  56. tools_section += " 示例:`[TOOL_CALL:search:query=Python编程]`\n\n"
  57. tools_section += "3. **简单查询**:可以直接传入文本\n"
  58. tools_section += " 示例:`[TOOL_CALL:search:Python编程]`\n\n"
  59. tools_section += "### 重要提示\n"
  60. tools_section += "- 参数名必须与工具定义的参数名完全匹配\n"
  61. tools_section += "- 数字参数直接写数字,不需要引号:`a=12` 而不是 `a=\"12\"`\n"
  62. tools_section += "- 文件路径等字符串参数直接写:`path=README.md`\n"
  63. tools_section += "- 工具调用结果会自动插入到对话中,然后你可以基于结果继续回答\n"
  64. return base_prompt + tools_section
  65. def _parse_tool_calls(self, text: str) -> list:
  66. """解析文本中的工具调用"""
  67. pattern = r'\[TOOL_CALL:([^:]+):([^\]]+)\]'
  68. matches = re.findall(pattern, text)
  69. tool_calls = []
  70. for tool_name, parameters in matches:
  71. tool_calls.append({
  72. 'tool_name': tool_name.strip(),
  73. 'parameters': parameters.strip(),
  74. 'original': f'[TOOL_CALL:{tool_name}:{parameters}]'
  75. })
  76. return tool_calls
  77. def _execute_tool_call(self, tool_name: str, parameters: str) -> str:
  78. """执行工具调用"""
  79. if not self.tool_registry:
  80. return f"❌ 错误:未配置工具注册表"
  81. try:
  82. # 获取Tool对象
  83. tool = self.tool_registry.get_tool(tool_name)
  84. if not tool:
  85. return f"❌ 错误:未找到工具 '{tool_name}'"
  86. # 智能参数解析
  87. param_dict = self._parse_tool_parameters(tool_name, parameters)
  88. # 交互式确认门(由上层执行器裁决是否允许执行)
  89. if self.tool_confirm_callback is not None:
  90. try:
  91. allowed = bool(self.tool_confirm_callback(tool_name, param_dict))
  92. except Exception as e:
  93. return f"❌ 工具调用确认失败:{str(e)}"
  94. if not allowed:
  95. return "⛔️ 已取消本次工具调用(需要用户确认)。"
  96. # 调用工具
  97. result = tool.run(param_dict)
  98. return f"🔧 工具 {tool_name} 执行结果:\n{result}"
  99. except Exception as e:
  100. return f"❌ 工具调用失败:{str(e)}"
  101. def _parse_tool_parameters(self, tool_name: str, parameters: str) -> dict:
  102. """智能解析工具参数"""
  103. import json
  104. param_dict = {}
  105. # 尝试解析JSON格式
  106. if parameters.strip().startswith('{'):
  107. try:
  108. param_dict = json.loads(parameters)
  109. # JSON解析成功,进行类型转换
  110. param_dict = self._convert_parameter_types(tool_name, param_dict)
  111. return param_dict
  112. except json.JSONDecodeError:
  113. # JSON解析失败,继续使用其他方式
  114. pass
  115. if '=' in parameters:
  116. # 格式: key=value 或 action=search,query=Python
  117. if ',' in parameters:
  118. # 多个参数:action=search,query=Python,limit=3
  119. pairs = parameters.split(',')
  120. for pair in pairs:
  121. if '=' in pair:
  122. key, value = pair.split('=', 1)
  123. param_dict[key.strip()] = value.strip()
  124. else:
  125. # 单个参数:key=value
  126. key, value = parameters.split('=', 1)
  127. param_dict[key.strip()] = value.strip()
  128. # 类型转换
  129. param_dict = self._convert_parameter_types(tool_name, param_dict)
  130. # 智能推断action(如果没有指定)
  131. if 'action' not in param_dict:
  132. param_dict = self._infer_action(tool_name, param_dict)
  133. else:
  134. # 直接传入参数,根据工具类型智能推断
  135. param_dict = self._infer_simple_parameters(tool_name, parameters)
  136. return param_dict
  137. def _convert_parameter_types(self, tool_name: str, param_dict: dict) -> dict:
  138. """
  139. 根据工具的参数定义转换参数类型
  140. Args:
  141. tool_name: 工具名称
  142. param_dict: 参数字典
  143. Returns:
  144. 类型转换后的参数字典
  145. """
  146. if not self.tool_registry:
  147. return param_dict
  148. tool = self.tool_registry.get_tool(tool_name)
  149. if not tool:
  150. return param_dict
  151. # 获取工具的参数定义
  152. try:
  153. tool_params = tool.get_parameters()
  154. except:
  155. return param_dict
  156. # 创建参数类型映射
  157. param_types = {}
  158. for param in tool_params:
  159. param_types[param.name] = param.type
  160. # 转换参数类型
  161. converted_dict = {}
  162. for key, value in param_dict.items():
  163. if key in param_types:
  164. param_type = param_types[key]
  165. try:
  166. if param_type == 'number' or param_type == 'integer':
  167. # 转换为数字
  168. if isinstance(value, str):
  169. converted_dict[key] = float(value) if param_type == 'number' else int(value)
  170. else:
  171. converted_dict[key] = value
  172. elif param_type == 'boolean':
  173. # 转换为布尔值
  174. if isinstance(value, str):
  175. converted_dict[key] = value.lower() in ('true', '1', 'yes')
  176. else:
  177. converted_dict[key] = bool(value)
  178. else:
  179. converted_dict[key] = value
  180. except (ValueError, TypeError):
  181. # 转换失败,保持原值
  182. converted_dict[key] = value
  183. else:
  184. converted_dict[key] = value
  185. return converted_dict
  186. def _infer_action(self, tool_name: str, param_dict: dict) -> dict:
  187. """根据工具类型和参数推断action"""
  188. if tool_name == 'memory':
  189. if 'recall' in param_dict:
  190. param_dict['action'] = 'search'
  191. param_dict['query'] = param_dict.pop('recall')
  192. elif 'store' in param_dict:
  193. param_dict['action'] = 'add'
  194. param_dict['content'] = param_dict.pop('store')
  195. elif 'query' in param_dict:
  196. param_dict['action'] = 'search'
  197. elif 'content' in param_dict:
  198. param_dict['action'] = 'add'
  199. elif tool_name == 'rag':
  200. if 'search' in param_dict:
  201. param_dict['action'] = 'search'
  202. param_dict['query'] = param_dict.pop('search')
  203. elif 'query' in param_dict:
  204. param_dict['action'] = 'search'
  205. elif 'text' in param_dict:
  206. param_dict['action'] = 'add_text'
  207. return param_dict
  208. def _infer_simple_parameters(self, tool_name: str, parameters: str) -> dict:
  209. """为简单参数推断完整的参数字典"""
  210. if tool_name == 'rag':
  211. return {'action': 'search', 'query': parameters}
  212. elif tool_name == 'memory':
  213. return {'action': 'search', 'query': parameters}
  214. else:
  215. return {'input': parameters}
  216. def run(self, input_text: str, max_tool_iterations: int = 3, **kwargs) -> str:
  217. """
  218. 运行SimpleAgent,支持可选的工具调用
  219. Args:
  220. input_text: 用户输入
  221. max_tool_iterations: 最大工具调用迭代次数(仅在启用工具时有效)
  222. **kwargs: 其他参数
  223. Returns:
  224. Agent响应
  225. """
  226. # 构建消息列表
  227. messages = []
  228. # 添加系统消息(可能包含工具信息)
  229. enhanced_system_prompt = self._get_enhanced_system_prompt()
  230. messages.append({"role": "system", "content": enhanced_system_prompt})
  231. # 添加历史消息
  232. for msg in self._history:
  233. messages.append({"role": msg.role, "content": msg.content})
  234. # 添加当前用户消息
  235. messages.append({"role": "user", "content": input_text})
  236. # 如果没有启用工具调用,使用原有逻辑
  237. if not self.enable_tool_calling:
  238. response = self.llm.invoke(messages, **kwargs)
  239. self.add_message(Message(input_text, "user"))
  240. self.add_message(Message(response, "assistant"))
  241. return response
  242. # 迭代处理,支持多轮工具调用
  243. current_iteration = 0
  244. final_response = ""
  245. while current_iteration < max_tool_iterations:
  246. # 调用LLM
  247. response = self.llm.invoke(messages, **kwargs)
  248. # 检查是否有工具调用
  249. tool_calls = self._parse_tool_calls(response)
  250. if tool_calls:
  251. # 执行所有工具调用并收集结果
  252. tool_results = []
  253. clean_response = response
  254. for call in tool_calls:
  255. result = self._execute_tool_call(call['tool_name'], call['parameters'])
  256. tool_results.append(result)
  257. # 从响应中移除工具调用标记
  258. clean_response = clean_response.replace(call['original'], "")
  259. # 构建包含工具结果的消息
  260. messages.append({"role": "assistant", "content": clean_response})
  261. # 添加工具结果
  262. tool_results_text = "\n\n".join(tool_results)
  263. messages.append({"role": "user", "content": f"工具执行结果:\n{tool_results_text}\n\n请基于这些结果给出完整的回答。"})
  264. current_iteration += 1
  265. continue
  266. # 没有工具调用,这是最终回答
  267. final_response = response
  268. break
  269. # 如果超过最大迭代次数,获取最后一次回答
  270. if current_iteration >= max_tool_iterations and not final_response:
  271. final_response = self.llm.invoke(messages, **kwargs)
  272. # 保存到历史记录
  273. self.add_message(Message(input_text, "user"))
  274. self.add_message(Message(final_response, "assistant"))
  275. return final_response
  276. def add_tool(self, tool) -> None:
  277. """
  278. 添加工具到Agent(便利方法)
  279. 如果是MCP工具且启用了auto_expand,会自动展开为多个独立工具
  280. """
  281. if not self.tool_registry:
  282. from tools.registry import ToolRegistry
  283. self.tool_registry = ToolRegistry()
  284. self.enable_tool_calling = True
  285. # 检查是否是MCP工具且需要展开
  286. if hasattr(tool, 'auto_expand') and tool.auto_expand:
  287. # 获取展开的工具列表
  288. expanded_tools = tool.get_expanded_tools()
  289. if expanded_tools:
  290. # 注册所有展开的工具
  291. for expanded_tool in expanded_tools:
  292. self.tool_registry.register_tool(expanded_tool)
  293. print(f"✅ MCP工具 '{tool.name}' 已展开为 {len(expanded_tools)} 个独立工具")
  294. return
  295. # 普通工具或不展开的MCP工具
  296. self.tool_registry.register_tool(tool)
  297. def remove_tool(self, tool_name: str) -> bool:
  298. """移除工具(便利方法)"""
  299. if self.tool_registry:
  300. return self.tool_registry.unregister_tool(tool_name)
  301. return False
  302. def list_tools(self) -> list:
  303. """列出所有可用工具"""
  304. if self.tool_registry:
  305. return self.tool_registry.list_tools()
  306. return []
  307. def has_tools(self) -> bool:
  308. """检查是否有可用工具"""
  309. return self.enable_tool_calling and self.tool_registry is not None
  310. def stream_run(self, input_text: str, **kwargs) -> Iterator[str]:
  311. """
  312. 流式运行Agent
  313. Args:
  314. input_text: 用户输入
  315. **kwargs: 其他参数
  316. Yields:
  317. Agent响应片段
  318. """
  319. # 构建消息列表
  320. messages = []
  321. if self.system_prompt:
  322. messages.append({"role": "system", "content": self.system_prompt})
  323. for msg in self._history:
  324. messages.append({"role": msg.role, "content": msg.content})
  325. messages.append({"role": "user", "content": input_text})
  326. # 流式调用LLM
  327. full_response = ""
  328. for chunk in self.llm.stream_invoke(messages, **kwargs):
  329. full_response += chunk
  330. yield chunk
  331. # 保存完整对话到历史记录
  332. self.add_message(Message(input_text, "user"))
  333. self.add_message(Message(full_response, "assistant"))