"""memory.py - Step 6: stock-analysis memory system.

Persistent storage for the watchlist, the analysis history and user
preferences.
"""
import json
import os
import tempfile
import threading
from datetime import datetime
from typing import Optional
  8. class StockMemory:
  9. """股票分析记忆 — JSON 文件持久化"""
  10. def __init__(self, path: str = "memory/stock_memory.json"):
  11. import threading
  12. self.path = path
  13. self.data = self._load()
  14. self._lock = threading.Lock()
  15. def _load(self) -> dict:
  16. if os.path.exists(self.path):
  17. try:
  18. with open(self.path, "r", encoding="utf-8") as f:
  19. return json.load(f)
  20. except (json.JSONDecodeError, IOError):
  21. pass
  22. return {"watchlist": {}, "history": [], "preferences": {}}
  23. def _save(self):
  24. import tempfile
  25. with self._lock:
  26. os.makedirs(os.path.dirname(self.path), exist_ok=True)
  27. fd, temp_path = tempfile.mkstemp(dir=os.path.dirname(self.path))
  28. try:
  29. with os.fdopen(fd, "w", encoding="utf-8") as f:
  30. json.dump(self.data, f, ensure_ascii=False, indent=2)
  31. os.replace(temp_path, self.path)
  32. except Exception:
  33. os.remove(temp_path)
  34. raise
  35. # ===== 关注列表 =====
  36. def add_watchlist(self, code: str, name: str = "", notes: str = "") -> str:
  37. self.data["watchlist"][code] = {
  38. "name": name or code,
  39. "notes": notes,
  40. "added": datetime.now().strftime("%Y-%m-%d %H:%M"),
  41. }
  42. self._save()
  43. return f"已添加 {name or code}({code}) 到关注列表"
  44. def remove_watchlist(self, code: str) -> str:
  45. if code in self.data["watchlist"]:
  46. name = self.data["watchlist"][code]["name"]
  47. del self.data["watchlist"][code]
  48. self._save()
  49. return f"已从关注列表移除 {name}({code})"
  50. return f"关注列表中未找到 {code}"
  51. def get_watchlist(self, query: str = "") -> str:
  52. wl = self.data["watchlist"]
  53. if not wl:
  54. return "关注列表为空。说'关注 600519'来添加。"
  55. lines = [f"关注列表 ({len(wl)} 只):"]
  56. for code, info in wl.items():
  57. lines.append(f" {info['name']}({code}) [{info['added']}]")
  58. if info.get("notes"):
  59. lines.append(f" 备注: {info['notes']}")
  60. return "\n".join(lines)
  61. # ===== 分析历史 =====
  62. def save_analysis(self, code: str, question: str, summary: str) -> str:
  63. record = {
  64. "code": code,
  65. "question": question,
  66. "summary": summary[:500], # 截取前500字作为摘要
  67. "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  68. }
  69. self.data["history"].append(record)
  70. # 只保留最近100条
  71. if len(self.data["history"]) > 100:
  72. self.data["history"] = self.data["history"][-100:]
  73. self._save()
  74. return f"分析记录已保存 ({len(self.data['history'])} 条历史)"
  75. def get_history(self, query: str = "") -> str:
  76. code = query.strip() if query else ""
  77. records = self.data["history"]
  78. if code:
  79. records = [r for r in records if r["code"] == code]
  80. if not records:
  81. return f"暂无{' ' + code + ' 的' if code else ''}分析历史"
  82. lines = [f"分析历史 (最近{len(records)}条):"]
  83. for r in records[-10:]: # 最近10条
  84. lines.append(f" [{r['timestamp']}] {r['code']}: {r['question'][:60]}")
  85. return "\n".join(lines)
  86. def get_last_analysis(self, code: str = "") -> Optional[str]:
  87. records = self.data["history"]
  88. if code:
  89. records = [r for r in records if r["code"] == code]
  90. if records:
  91. return records[-1].get("summary", "")
  92. return None
  93. # ===== 用户偏好 =====
  94. def set_preference(self, key: str, value: str) -> str:
  95. self.data["preferences"][key] = value
  96. self._save()
  97. return f"偏好已设置: {key} = {value}"
  98. def get_preferences(self, query: str = "") -> str:
  99. prefs = self.data["preferences"]
  100. if not prefs:
  101. return "暂无保存的偏好。可以设置如: '偏好 分析风格=深度价值投资'"
  102. lines = ["用户偏好:"]
  103. for k, v in prefs.items():
  104. lines.append(f" {k}: {v}")
  105. return "\n".join(lines)
  106. def clear(self) -> str:
  107. self.data = {"watchlist": {}, "history": [], "preferences": {}}
  108. self._save()
  109. return "记忆已清空"
  110. # 全局单例
  111. _memory_instance = None
  112. def get_memory() -> StockMemory:
  113. global _memory_instance
  114. if _memory_instance is None:
  115. _memory_instance = StockMemory()
  116. return _memory_instance
# ===== Tool functions (can be registered directly with ToolRegistry) =====
  118. def memory_add_watchlist(query: str) -> str:
  119. """添加股票到关注列表。输入: '代码|名称' 如 '600519|贵州茅台'"""
  120. parts = query.strip().split("|")
  121. code = parts[0].strip()
  122. name = parts[1].strip() if len(parts) > 1 else ""
  123. return get_memory().add_watchlist(code, name)
  124. def memory_remove_watchlist(code: str) -> str:
  125. """从关注列表移除股票。输入: 股票代码"""
  126. return get_memory().remove_watchlist(code.strip())
  127. def memory_get_watchlist(query: str = "") -> str:
  128. """查看关注列表"""
  129. return get_memory().get_watchlist(query)
  130. def memory_save_analysis(query: str) -> str:
  131. """保存分析结果。输入: '代码|问题|摘要' """
  132. parts = query.strip().split("|")
  133. code = parts[0].strip() if len(parts) > 0 else ""
  134. question = parts[1].strip() if len(parts) > 1 else ""
  135. summary = parts[2].strip() if len(parts) > 2 else ""
  136. return get_memory().save_analysis(code, question, summary)
  137. def memory_get_history(query: str = "") -> str:
  138. """查看分析历史。输入: 股票代码(可选,留空看全部)"""
  139. return get_memory().get_history(query)
  140. def memory_set_preference(query: str) -> str:
  141. """设置用户偏好。输入: 'key=value' 如 '风格=技术分析为主'"""
  142. if "=" in query:
  143. k, v = query.split("=", 1)
  144. return get_memory().set_preference(k.strip(), v.strip())
  145. return "格式: key=value"
  146. def memory_get_preferences(query: str = "") -> str:
  147. """查看用户偏好"""
  148. return get_memory().get_preferences()