memory.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. """Step 6: 股票分析记忆系统
  2. 持久化存储: 关注列表、分析历史、用户偏好
  3. """
  4. import json
  5. import os
  6. from datetime import datetime
  7. from typing import Optional
  8. class StockMemory:
  9. """股票分析记忆 — JSON 文件持久化"""
  10. def __init__(self, path: str = "memory/stock_memory.json"):
  11. self.path = path
  12. self.data = self._load()
  13. def _load(self) -> dict:
  14. if os.path.exists(self.path):
  15. try:
  16. with open(self.path, "r", encoding="utf-8") as f:
  17. return json.load(f)
  18. except (json.JSONDecodeError, IOError):
  19. pass
  20. return {"watchlist": {}, "history": [], "preferences": {}}
  21. def _save(self):
  22. os.makedirs(os.path.dirname(self.path), exist_ok=True)
  23. with open(self.path, "w", encoding="utf-8") as f:
  24. json.dump(self.data, f, ensure_ascii=False, indent=2)
  25. # ===== 关注列表 =====
  26. def add_watchlist(self, code: str, name: str = "", notes: str = "") -> str:
  27. self.data["watchlist"][code] = {
  28. "name": name or code,
  29. "notes": notes,
  30. "added": datetime.now().strftime("%Y-%m-%d %H:%M"),
  31. }
  32. self._save()
  33. return f"已添加 {name or code}({code}) 到关注列表"
  34. def remove_watchlist(self, code: str) -> str:
  35. if code in self.data["watchlist"]:
  36. name = self.data["watchlist"][code]["name"]
  37. del self.data["watchlist"][code]
  38. self._save()
  39. return f"已从关注列表移除 {name}({code})"
  40. return f"关注列表中未找到 {code}"
  41. def get_watchlist(self, query: str = "") -> str:
  42. wl = self.data["watchlist"]
  43. if not wl:
  44. return "关注列表为空。说'关注 600519'来添加。"
  45. lines = [f"关注列表 ({len(wl)} 只):"]
  46. for code, info in wl.items():
  47. lines.append(f" {info['name']}({code}) [{info['added']}]")
  48. if info.get("notes"):
  49. lines.append(f" 备注: {info['notes']}")
  50. return "\n".join(lines)
  51. # ===== 分析历史 =====
  52. def save_analysis(self, code: str, question: str, summary: str) -> str:
  53. record = {
  54. "code": code,
  55. "question": question,
  56. "summary": summary[:500], # 截取前500字作为摘要
  57. "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  58. }
  59. self.data["history"].append(record)
  60. # 只保留最近100条
  61. if len(self.data["history"]) > 100:
  62. self.data["history"] = self.data["history"][-100:]
  63. self._save()
  64. return f"分析记录已保存 ({len(self.data['history'])} 条历史)"
  65. def get_history(self, query: str = "") -> str:
  66. code = query.strip() if query else ""
  67. records = self.data["history"]
  68. if code:
  69. records = [r for r in records if r["code"] == code]
  70. if not records:
  71. return f"暂无{' ' + code + ' 的' if code else ''}分析历史"
  72. lines = [f"分析历史 (最近{len(records)}条):"]
  73. for r in records[-10:]: # 最近10条
  74. lines.append(f" [{r['timestamp']}] {r['code']}: {r['question'][:60]}")
  75. return "\n".join(lines)
  76. def get_last_analysis(self, code: str = "") -> Optional[str]:
  77. records = self.data["history"]
  78. if code:
  79. records = [r for r in records if r["code"] == code]
  80. if records:
  81. return records[-1].get("summary", "")
  82. return None
  83. # ===== 用户偏好 =====
  84. def set_preference(self, key: str, value: str) -> str:
  85. self.data["preferences"][key] = value
  86. self._save()
  87. return f"偏好已设置: {key} = {value}"
  88. def get_preferences(self, query: str = "") -> str:
  89. prefs = self.data["preferences"]
  90. if not prefs:
  91. return "暂无保存的偏好。可以设置如: '偏好 分析风格=深度价值投资'"
  92. lines = ["用户偏好:"]
  93. for k, v in prefs.items():
  94. lines.append(f" {k}: {v}")
  95. return "\n".join(lines)
  96. def clear(self) -> str:
  97. self.data = {"watchlist": {}, "history": [], "preferences": {}}
  98. self._save()
  99. return "记忆已清空"
  100. # 全局单例
  101. _memory_instance = None
  102. def get_memory() -> StockMemory:
  103. global _memory_instance
  104. if _memory_instance is None:
  105. _memory_instance = StockMemory()
  106. return _memory_instance
  107. # ===== 工具函数(可直接注册到 ToolRegistry)=====
  108. def memory_add_watchlist(query: str) -> str:
  109. """添加股票到关注列表。输入: '代码|名称' 如 '600519|贵州茅台'"""
  110. parts = query.strip().split("|")
  111. code = parts[0].strip()
  112. name = parts[1].strip() if len(parts) > 1 else ""
  113. return get_memory().add_watchlist(code, name)
  114. def memory_remove_watchlist(code: str) -> str:
  115. """从关注列表移除股票。输入: 股票代码"""
  116. return get_memory().remove_watchlist(code.strip())
  117. def memory_get_watchlist(query: str = "") -> str:
  118. """查看关注列表"""
  119. return get_memory().get_watchlist(query)
  120. def memory_save_analysis(query: str) -> str:
  121. """保存分析结果。输入: '代码|问题|摘要' """
  122. parts = query.strip().split("|")
  123. code = parts[0].strip() if len(parts) > 0 else ""
  124. question = parts[1].strip() if len(parts) > 1 else ""
  125. summary = parts[2].strip() if len(parts) > 2 else ""
  126. return get_memory().save_analysis(code, question, summary)
  127. def memory_get_history(query: str = "") -> str:
  128. """查看分析历史。输入: 股票代码(可选,留空看全部)"""
  129. return get_memory().get_history(query)
  130. def memory_set_preference(query: str) -> str:
  131. """设置用户偏好。输入: 'key=value' 如 '风格=技术分析为主'"""
  132. if "=" in query:
  133. k, v = query.split("=", 1)
  134. return get_memory().set_preference(k.strip(), v.strip())
  135. return "格式: key=value"
  136. def memory_get_preferences(query: str = "") -> str:
  137. """查看用户偏好"""
  138. return get_memory().get_preferences()