registry.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """工具注册表 - HelloAgents原生工具系统"""
  2. from typing import Optional, Any, Callable
  3. import json
  4. from .base import Tool
  5. class ToolRegistry:
  6. """
  7. HelloAgents工具注册表
  8. 提供工具的注册、管理和执行功能。
  9. 支持两种工具注册方式:
  10. 1. Tool对象注册(推荐)
  11. 2. 函数直接注册(简便)
  12. """
  13. def __init__(self):
  14. self._tools: dict[str, Tool] = {}
  15. self._functions: dict[str, dict[str, Any]] = {}
  16. def register_tool(self, tool: Tool):
  17. """
  18. 注册Tool对象
  19. Args:
  20. tool: Tool实例
  21. """
  22. if tool.name in self._tools:
  23. print(f"⚠️ 警告:工具 '{tool.name}' 已存在,将被覆盖。")
  24. self._tools[tool.name] = tool
  25. print(f"✅ 工具 '{tool.name}' 已注册。")
  26. def register_function(self, name: str, description: str, func: Callable[[str], str]):
  27. """
  28. 直接注册函数作为工具(简便方式)
  29. Args:
  30. name: 工具名称
  31. description: 工具描述
  32. func: 工具函数,接受字符串参数,返回字符串结果
  33. """
  34. if name in self._functions:
  35. print(f"⚠️ 警告:工具 '{name}' 已存在,将被覆盖。")
  36. self._functions[name] = {
  37. "description": description,
  38. "func": func
  39. }
  40. print(f"✅ 工具 '{name}' 已注册。")
  41. def unregister(self, name: str):
  42. """注销工具"""
  43. if name in self._tools:
  44. del self._tools[name]
  45. print(f"🗑️ 工具 '{name}' 已注销。")
  46. elif name in self._functions:
  47. del self._functions[name]
  48. print(f"🗑️ 工具 '{name}' 已注销。")
  49. else:
  50. print(f"⚠️ 工具 '{name}' 不存在。")
  51. def get_tool(self, name: str) -> Optional[Tool]:
  52. """获取Tool对象"""
  53. return self._tools.get(name)
  54. def get_function(self, name: str) -> Optional[Callable]:
  55. """获取工具函数"""
  56. func_info = self._functions.get(name)
  57. return func_info["func"] if func_info else None
  58. def execute_tool(self, name: str, input_text: str) -> str:
  59. """
  60. 执行工具
  61. Args:
  62. name: 工具名称
  63. input_text: 输入参数
  64. Returns:
  65. 工具执行结果
  66. """
  67. # 优先查找Tool对象
  68. if name in self._tools:
  69. tool = self._tools[name]
  70. try:
  71. raw = (input_text or "").strip()
  72. # 预处理:如果输入包含换行和另一个 Action,只取第一行
  73. if '\n' in raw and 'Action:' in raw:
  74. lines = raw.split('\n')
  75. raw = lines[0].strip()
  76. # 1) JSON 直通:允许 ReAct 里用 tool[{"k":"v"}] 精确传参
  77. def _try_json(txt: str):
  78. try:
  79. return json.loads(txt)
  80. except Exception:
  81. return None
  82. obj = None
  83. # 1a 单个对象
  84. if raw.startswith("{") and raw.endswith("}"):
  85. obj = _try_json(raw)
  86. # 1b 常见模型输出尾部多了一个 ']' 的容错
  87. if obj is None and raw.startswith("{") and raw.endswith("}]"):
  88. obj = _try_json(raw[:-1].strip())
  89. # 1c 模型输出为数组包裹一个对象
  90. if obj is None and raw.startswith("[") and raw.endswith("]"):
  91. arr = _try_json(raw)
  92. if isinstance(arr, list) and len(arr) == 1 and isinstance(arr[0], dict):
  93. obj = arr[0]
  94. # 1d 错位尾括号(常见:{"a":1,"b":2}])
  95. if obj is None and raw.endswith("}]") and raw.count("{") == 1 and raw.count("}") == 2:
  96. obj = _try_json(raw[:-1])
  97. # 1e 正则兜底:提取首个完整 JSON 对象
  98. if obj is None and "{" in raw and "}" in raw:
  99. try:
  100. import re
  101. # 使用括号匹配而非简单正则
  102. def extract_first_json_object(text: str):
  103. """从文本中提取第一个完整的 JSON 对象"""
  104. start = text.find('{')
  105. if start == -1:
  106. return None
  107. depth = 0
  108. in_string = False
  109. escape = False
  110. for i, c in enumerate(text[start:], start):
  111. if escape:
  112. escape = False
  113. continue
  114. if c == '\\' and in_string:
  115. escape = True
  116. continue
  117. if c == '"' and not escape:
  118. in_string = not in_string
  119. continue
  120. if in_string:
  121. continue
  122. if c == '{':
  123. depth += 1
  124. elif c == '}':
  125. depth -= 1
  126. if depth == 0:
  127. return text[start:i+1]
  128. return None
  129. json_str = extract_first_json_object(raw)
  130. if json_str:
  131. obj = json.loads(json_str)
  132. except Exception:
  133. pass
  134. if isinstance(obj, dict):
  135. return tool.run(obj)
  136. # 2) 单参数兜底:如果工具只有一个必填参数,把 input_text 映射到该参数名
  137. params = tool.get_parameters()
  138. required = [p for p in params if p.required]
  139. if len(required) == 1:
  140. return tool.run({required[0].name: input_text})
  141. # 3) 兼容旧行为:若存在 input 参数,使用 input
  142. if any(p.name == "input" for p in params):
  143. return tool.run({"input": input_text})
  144. return (
  145. f"错误:工具 '{name}' 需要结构化参数。"
  146. "请使用 JSON 形式传参,例如:tool[{\"param\":\"value\"}]"
  147. )
  148. except Exception as e:
  149. return f"错误:执行工具 '{name}' 时发生异常: {str(e)}"
  150. # 查找函数工具
  151. elif name in self._functions:
  152. func = self._functions[name]["func"]
  153. try:
  154. return func(input_text)
  155. except Exception as e:
  156. return f"错误:执行工具 '{name}' 时发生异常: {str(e)}"
  157. else:
  158. return f"错误:未找到名为 '{name}' 的工具。"
  159. def get_tools_description(self) -> str:
  160. """
  161. 获取所有可用工具的格式化描述字符串
  162. Returns:
  163. 工具描述字符串,用于构建提示词
  164. """
  165. descriptions = []
  166. # Tool对象描述
  167. for tool in self._tools.values():
  168. descriptions.append(f"- {tool.name}: {tool.description}")
  169. # 函数工具描述
  170. for name, info in self._functions.items():
  171. descriptions.append(f"- {name}: {info['description']}")
  172. return "\n".join(descriptions) if descriptions else "暂无可用工具"
  173. def list_tools(self) -> list[str]:
  174. """列出所有工具名称"""
  175. return list(self._tools.keys()) + list(self._functions.keys())
  176. def get_all_tools(self) -> list[Tool]:
  177. """获取所有Tool对象"""
  178. return list(self._tools.values())
  179. def clear(self):
  180. """清空所有工具"""
  181. self._tools.clear()
  182. self._functions.clear()
  183. print("🧹 所有工具已清空。")
  184. # 全局工具注册表
  185. global_registry = ToolRegistry()