memory_tool.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
"""Memory tool.

Tool implementation that gives agents in the HelloAgents framework a memory
capability. It can be added as a tool to any Agent so that the Agent can
store and retrieve memories.
"""
import inspect
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from memory import MemoryManager, MemoryConfig

from ..base import Tool, ToolParameter
  9. class MemoryTool(Tool):
  10. """记忆工具
  11. 为Agent提供记忆功能:
  12. - 添加记忆
  13. - 检索相关记忆
  14. - 获取记忆摘要
  15. - 管理记忆生命周期
  16. """
  17. def __init__(
  18. self,
  19. user_id: str = "default_user",
  20. memory_config: MemoryConfig = None,
  21. memory_types: List[str] = None
  22. ):
  23. super().__init__(
  24. name="memory",
  25. description="记忆工具 - 可以存储和检索对话历史、知识和经验"
  26. )
  27. # 初始化记忆管理器
  28. self.memory_config = memory_config or MemoryConfig()
  29. self.memory_types = memory_types or ["working", "episodic", "semantic"]
  30. self.memory_manager = MemoryManager(
  31. config=self.memory_config,
  32. user_id=user_id,
  33. enable_working="working" in self.memory_types,
  34. enable_episodic="episodic" in self.memory_types,
  35. enable_semantic="semantic" in self.memory_types,
  36. enable_perceptual="perceptual" in self.memory_types
  37. )
  38. # 会话状态
  39. self.current_session_id = None
  40. self.conversation_count = 0
  41. def run(self, parameters: Dict[str, Any]) -> str:
  42. """执行工具 - Tool基类要求的接口
  43. Args:
  44. parameters: 工具参数字典,必须包含action参数
  45. Returns:
  46. 执行结果字符串
  47. """
  48. if not self.validate_parameters(parameters):
  49. return "❌ 参数验证失败:缺少必需的参数"
  50. action = parameters.get("action")
  51. # 移除action参数,传递其余参数给execute方法
  52. kwargs = {k: v for k, v in parameters.items() if k != "action"}
  53. return self.execute(action, **kwargs)
  54. def get_parameters(self) -> List[ToolParameter]:
  55. """获取工具参数定义 - Tool基类要求的接口"""
  56. return [
  57. ToolParameter(
  58. name="action",
  59. type="string",
  60. description=(
  61. "要执行的操作:"
  62. "add(添加记忆), search(搜索记忆), summary(获取摘要), stats(获取统计), "
  63. "update(更新记忆), remove(删除记忆), forget(遗忘记忆), consolidate(整合记忆), clear_all(清空所有记忆)"
  64. ),
  65. required=True
  66. ),
  67. ToolParameter(name="content", type="string", description="记忆内容(add/update时可用;感知记忆可作描述)", required=False),
  68. ToolParameter(name="query", type="string", description="搜索查询(search时可用)", required=False),
  69. ToolParameter(name="memory_type", type="string", description="记忆类型:working, episodic, semantic, perceptual(默认:working)", required=False, default="working"),
  70. ToolParameter(name="importance", type="number", description="重要性分数,0.0-1.0(add/update时可用)", required=False),
  71. ToolParameter(name="limit", type="integer", description="搜索结果数量限制(默认:5)", required=False, default=5),
  72. ToolParameter(name="memory_id", type="string", description="目标记忆ID(update/remove时必需)", required=False),
  73. ToolParameter(name="file_path", type="string", description="感知记忆:本地文件路径(image/audio)", required=False),
  74. ToolParameter(name="modality", type="string", description="感知记忆模态:text/image/audio(不传则按扩展名推断)", required=False),
  75. ToolParameter(name="strategy", type="string", description="遗忘策略:importance_based/time_based/capacity_based(forget时可用)", required=False, default="importance_based"),
  76. ToolParameter(name="threshold", type="number", description="遗忘阈值(forget时可用,默认0.1)", required=False, default=0.1),
  77. ToolParameter(name="max_age_days", type="integer", description="最大保留天数(forget策略为time_based时可用)", required=False, default=30),
  78. ToolParameter(name="from_type", type="string", description="整合来源类型(consolidate时可用,默认working)", required=False, default="working"),
  79. ToolParameter(name="to_type", type="string", description="整合目标类型(consolidate时可用,默认episodic)", required=False, default="episodic"),
  80. ToolParameter(name="importance_threshold", type="number", description="整合重要性阈值(默认0.7)", required=False, default=0.7),
  81. ]
  82. def execute(self, action: str, **kwargs) -> str:
  83. """执行记忆操作
  84. 支持的操作:
  85. - add: 添加记忆
  86. - search: 搜索记忆
  87. - summary: 获取记忆摘要
  88. - stats: 获取统计信息
  89. """
  90. if action == "add":
  91. return self._add_memory(**kwargs)
  92. elif action == "search":
  93. return self._search_memory(**kwargs)
  94. elif action == "summary":
  95. return self._get_summary(**kwargs)
  96. elif action == "stats":
  97. return self._get_stats()
  98. elif action == "update":
  99. return self._update_memory(**kwargs)
  100. elif action == "remove":
  101. return self._remove_memory(**kwargs)
  102. elif action == "forget":
  103. return self._forget(**kwargs)
  104. elif action == "consolidate":
  105. return self._consolidate(**kwargs)
  106. elif action == "clear_all":
  107. return self._clear_all()
  108. else:
  109. return f"不支持的操作: {action}。支持的操作: add, search, summary, stats, update, remove, forget, consolidate, clear_all"
  110. def _add_memory(
  111. self,
  112. content: str = "",
  113. memory_type: str = "working",
  114. importance: float = 0.5,
  115. file_path: str = None,
  116. modality: str = None,
  117. **metadata
  118. ) -> str:
  119. """添加记忆"""
  120. try:
  121. # 确保会话ID存在
  122. if self.current_session_id is None:
  123. self.current_session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
  124. # 感知记忆文件支持:注入 raw_data 与模态
  125. if memory_type == "perceptual" and file_path:
  126. inferred = modality or self._infer_modality(file_path)
  127. metadata.setdefault("modality", inferred)
  128. metadata.setdefault("raw_data", file_path)
  129. # 添加会话信息到元数据
  130. metadata.update({
  131. "session_id": self.current_session_id,
  132. "timestamp": datetime.now().isoformat()
  133. })
  134. memory_id = self.memory_manager.add_memory(
  135. content=content,
  136. memory_type=memory_type,
  137. importance=importance,
  138. metadata=metadata,
  139. auto_classify=False # 禁用自动分类,使用明确指定的类型
  140. )
  141. return f"✅ 记忆已添加 (ID: {memory_id[:8]}...)"
  142. except Exception as e:
  143. return f"❌ 添加记忆失败: {str(e)}"
  144. def _infer_modality(self, path: str) -> str:
  145. """根据扩展名推断模态(默认image/audio/text)"""
  146. try:
  147. ext = (path.rsplit('.', 1)[-1] or '').lower()
  148. if ext in {"png", "jpg", "jpeg", "bmp", "gif", "webp"}:
  149. return "image"
  150. if ext in {"mp3", "wav", "flac", "m4a", "ogg"}:
  151. return "audio"
  152. return "text"
  153. except Exception:
  154. return "text"
  155. def _search_memory(
  156. self,
  157. query: str,
  158. limit: int = 5,
  159. memory_types: List[str] = None,
  160. memory_type: str = None, # 添加单数形式的参数支持
  161. min_importance: float = 0.1
  162. ) -> str:
  163. """搜索记忆"""
  164. try:
  165. # 处理单数形式的memory_type参数
  166. if memory_type and not memory_types:
  167. memory_types = [memory_type]
  168. results = self.memory_manager.retrieve_memories(
  169. query=query,
  170. limit=limit,
  171. memory_types=memory_types,
  172. min_importance=min_importance
  173. )
  174. if not results:
  175. return f"🔍 未找到与 '{query}' 相关的记忆"
  176. # 格式化结果
  177. formatted_results = []
  178. formatted_results.append(f"🔍 找到 {len(results)} 条相关记忆:")
  179. for i, memory in enumerate(results, 1):
  180. memory_type_label = {
  181. "working": "工作记忆",
  182. "episodic": "情景记忆",
  183. "semantic": "语义记忆",
  184. "perceptual": "感知记忆"
  185. }.get(memory.memory_type, memory.memory_type)
  186. content_preview = memory.content[:80] + "..." if len(memory.content) > 80 else memory.content
  187. formatted_results.append(
  188. f"{i}. [{memory_type_label}] {content_preview} (重要性: {memory.importance:.2f})"
  189. )
  190. return "\n".join(formatted_results)
  191. except Exception as e:
  192. return f"❌ 搜索记忆失败: {str(e)}"
  193. def _get_summary(self, limit: int = 10) -> str:
  194. """获取记忆摘要"""
  195. try:
  196. stats = self.memory_manager.get_memory_stats()
  197. summary_parts = [
  198. f"📊 记忆系统摘要",
  199. f"总记忆数: {stats['total_memories']}",
  200. f"当前会话: {self.current_session_id or '未开始'}",
  201. f"对话轮次: {self.conversation_count}"
  202. ]
  203. # 各类型记忆统计
  204. if stats['memories_by_type']:
  205. summary_parts.append("\n📋 记忆类型分布:")
  206. for memory_type, type_stats in stats['memories_by_type'].items():
  207. count = type_stats.get('count', 0)
  208. avg_importance = type_stats.get('avg_importance', 0)
  209. type_label = {
  210. "working": "工作记忆",
  211. "episodic": "情景记忆",
  212. "semantic": "语义记忆",
  213. "perceptual": "感知记忆"
  214. }.get(memory_type, memory_type)
  215. summary_parts.append(f" • {type_label}: {count} 条 (平均重要性: {avg_importance:.2f})")
  216. # 获取重要记忆 - 修复重复问题
  217. important_memories = self.memory_manager.retrieve_memories(
  218. query="",
  219. memory_types=None, # 从所有类型中检索
  220. limit=limit * 3, # 获取更多候选,然后去重
  221. min_importance=0.5 # 降低阈值以获取更多记忆
  222. )
  223. if important_memories:
  224. # 去重:使用记忆ID和内容双重去重
  225. seen_ids = set()
  226. seen_contents = set()
  227. unique_memories = []
  228. for memory in important_memories:
  229. # 使用ID去重
  230. if memory.id in seen_ids:
  231. continue
  232. # 使用内容去重(防止相同内容的不同记忆)
  233. content_key = memory.content.strip().lower()
  234. if content_key in seen_contents:
  235. continue
  236. seen_ids.add(memory.id)
  237. seen_contents.add(content_key)
  238. unique_memories.append(memory)
  239. # 按重要性排序
  240. unique_memories.sort(key=lambda x: x.importance, reverse=True)
  241. summary_parts.append(f"\n⭐ 重要记忆 (前{min(limit, len(unique_memories))}条):")
  242. for i, memory in enumerate(unique_memories[:limit], 1):
  243. content_preview = memory.content[:60] + "..." if len(memory.content) > 60 else memory.content
  244. summary_parts.append(f" {i}. {content_preview} (重要性: {memory.importance:.2f})")
  245. return "\n".join(summary_parts)
  246. except Exception as e:
  247. return f"❌ 获取摘要失败: {str(e)}"
  248. def _get_stats(self) -> str:
  249. """获取统计信息"""
  250. try:
  251. stats = self.memory_manager.get_memory_stats()
  252. stats_info = [
  253. f"📈 记忆系统统计",
  254. f"总记忆数: {stats['total_memories']}",
  255. f"启用的记忆类型: {', '.join(stats['enabled_types'])}",
  256. f"会话ID: {self.current_session_id or '未开始'}",
  257. f"对话轮次: {self.conversation_count}"
  258. ]
  259. return "\n".join(stats_info)
  260. except Exception as e:
  261. return f"❌ 获取统计信息失败: {str(e)}"
  262. def auto_record_conversation(self, user_input: str, agent_response: str):
  263. """自动记录对话
  264. 这个方法可以被Agent调用来自动记录对话历史
  265. """
  266. self.conversation_count += 1
  267. # 记录用户输入
  268. self._add_memory(
  269. content=f"用户: {user_input}",
  270. memory_type="working",
  271. importance=0.6,
  272. type="user_input",
  273. conversation_id=self.conversation_count
  274. )
  275. # 记录Agent响应
  276. self._add_memory(
  277. content=f"助手: {agent_response}",
  278. memory_type="working",
  279. importance=0.7,
  280. type="agent_response",
  281. conversation_id=self.conversation_count
  282. )
  283. # 如果是重要对话,记录为情景记忆
  284. if len(agent_response) > 100 or "重要" in user_input or "记住" in user_input:
  285. interaction_content = f"对话 - 用户: {user_input}\n助手: {agent_response}"
  286. self._add_memory(
  287. content=interaction_content,
  288. memory_type="episodic",
  289. importance=0.8,
  290. type="interaction",
  291. conversation_id=self.conversation_count
  292. )
  293. def _update_memory(self, memory_id: str, content: str = None, importance: float = None, **metadata) -> str:
  294. """更新记忆"""
  295. try:
  296. success = self.memory_manager.update_memory(
  297. memory_id=memory_id,
  298. content=content,
  299. importance=importance,
  300. metadata=metadata or None
  301. )
  302. return "✅ 记忆已更新" if success else "⚠️ 未找到要更新的记忆"
  303. except Exception as e:
  304. return f"❌ 更新记忆失败: {str(e)}"
  305. def _remove_memory(self, memory_id: str) -> str:
  306. """删除记忆"""
  307. try:
  308. success = self.memory_manager.remove_memory(memory_id)
  309. return "✅ 记忆已删除" if success else "⚠️ 未找到要删除的记忆"
  310. except Exception as e:
  311. return f"❌ 删除记忆失败: {str(e)}"
  312. def _forget(self, strategy: str = "importance_based", threshold: float = 0.1, max_age_days: int = 30) -> str:
  313. """遗忘记忆(支持多种策略)"""
  314. try:
  315. count = self.memory_manager.forget_memories(
  316. strategy=strategy,
  317. threshold=threshold,
  318. max_age_days=max_age_days
  319. )
  320. return f"🧹 已遗忘 {count} 条记忆(策略: {strategy})"
  321. except Exception as e:
  322. return f"❌ 遗忘记忆失败: {str(e)}"
  323. def _consolidate(self, from_type: str = "working", to_type: str = "episodic", importance_threshold: float = 0.7) -> str:
  324. """整合记忆(将重要的短期记忆提升为长期记忆)"""
  325. try:
  326. count = self.memory_manager.consolidate_memories(
  327. from_type=from_type,
  328. to_type=to_type,
  329. importance_threshold=importance_threshold,
  330. )
  331. return f"🔄 已整合 {count} 条记忆为长期记忆({from_type} → {to_type},阈值={importance_threshold})"
  332. except Exception as e:
  333. return f"❌ 整合记忆失败: {str(e)}"
  334. def _clear_all(self) -> str:
  335. """清空所有记忆"""
  336. try:
  337. self.memory_manager.clear_all_memories()
  338. return "🧽 已清空所有记忆"
  339. except Exception as e:
  340. return f"❌ 清空记忆失败: {str(e)}"
  341. def add_knowledge(self, content: str, importance: float = 0.9):
  342. """添加知识到语义记忆
  343. 便捷方法,用于添加重要知识
  344. """
  345. return self._add_memory(
  346. content=content,
  347. memory_type="semantic",
  348. importance=importance,
  349. knowledge_type="factual",
  350. source="manual"
  351. )
  352. def get_context_for_query(self, query: str, limit: int = 3) -> str:
  353. """为查询获取相关上下文
  354. 这个方法可以被Agent调用来获取相关的记忆上下文
  355. """
  356. results = self.memory_manager.retrieve_memories(
  357. query=query,
  358. limit=limit,
  359. min_importance=0.3
  360. )
  361. if not results:
  362. return ""
  363. context_parts = ["相关记忆:"]
  364. for memory in results:
  365. context_parts.append(f"- {memory.content}")
  366. return "\n".join(context_parts)
  367. def clear_session(self):
  368. """清除当前会话"""
  369. self.current_session_id = None
  370. self.conversation_count = 0
  371. # 清理工作记忆
  372. wm = self.memory_manager.memory_types.get('working') if hasattr(self.memory_manager, 'memory_types') else None
  373. if wm:
  374. wm.clear()
  375. def consolidate_memories(self):
  376. """整合记忆"""
  377. return self.memory_manager.consolidate_memories()
  378. def forget_old_memories(self, max_age_days: int = 30):
  379. """遗忘旧记忆"""
  380. return self.memory_manager.forget_memories(
  381. strategy="time_based",
  382. max_age_days=max_age_days
  383. )