memory_service.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """
  2. 记忆系统 — 仪表盘快照的每日缓存与过期管理
  3. 记录每天第一次打开后端的时间日期;跨日或自选股数量变化时触发刷新。
  4. 每天首次启动时三线程并行获取指数、自选、热点资讯,写入 data/memory/dashboard_state.json。
  5. 与 HelloAgents 的 ConversationManager / MemoryManager 无关,本应用对话历史由前端与 SQLite 历史表管理。
  6. """
  7. from __future__ import annotations
  8. import json
  9. import logging
  10. import threading
  11. from concurrent.futures import ThreadPoolExecutor, as_completed
  12. from datetime import date
  13. from pathlib import Path
  14. from typing import Any, Optional
  15. from app.config import settings
  16. from app.models.memory_models import MemorySnapshot
  17. logger = logging.getLogger(__name__)
  18. _memory_lock = threading.Lock()
  19. class MemoryService:
  20. """
  21. 记忆系统核心服务
  22. 职责:
  23. - 记录每日首次启动日期
  24. - 日切时清空前一日数据并重新获取
  25. - 三线程并行获取仪表盘数据(指数、自选、热点资讯)
  26. - 检测自选股数量变化触发刷新
  27. """
  28. def __init__(self, storage_dir: Optional[Path] = None):
  29. self._today: Optional[str] = None
  30. self._snapshot: Optional[MemorySnapshot] = None
  31. self._lock = threading.Lock()
  32. self._watchlist_count: int = 0
  33. self._storage_dir = storage_dir or (settings.DATA_DIR / "memory")
  34. self._storage_dir.mkdir(parents=True, exist_ok=True)
  35. self._state_file = self._storage_dir / "dashboard_state.json"
  36. self._load_state()
  37. # ---- 持久化 ----
  38. def _load_state(self) -> None:
  39. """从磁盘恢复上次的快照状态"""
  40. try:
  41. if self._state_file.exists():
  42. data = json.loads(self._state_file.read_text(encoding="utf-8"))
  43. self._today = data.get("today")
  44. self._watchlist_count = data.get("watchlist_count", 0)
  45. snap = data.get("snapshot")
  46. if snap:
  47. self._snapshot = MemorySnapshot.from_dict(snap)
  48. logger.info("记忆系统状态已加载: date=%s, watchlist_count=%d", self._today, self._watchlist_count)
  49. except Exception as exc:
  50. logger.warning("加载记忆状态失败: %s", exc)
  51. def _save_state(self) -> None:
  52. """将当前快照状态持久化到磁盘"""
  53. try:
  54. data = {
  55. "today": self._today,
  56. "watchlist_count": self._watchlist_count,
  57. "snapshot": self._snapshot.to_dict() if self._snapshot else None,
  58. }
  59. self._state_file.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  60. except Exception as exc:
  61. logger.warning("保存记忆状态失败: %s", exc)
  62. # ---- 日期检测 ----
  63. def _get_today(self) -> str:
  64. return date.today().isoformat()
  65. def is_new_day(self) -> bool:
  66. """检查是否为新的一天"""
  67. return self._today != self._get_today()
  68. def should_refresh(self) -> bool:
  69. """
  70. 判断是否需要刷新数据:
  71. 1. 新的一天
  72. 2. 自选股数量发生变化
  73. """
  74. if self.is_new_day():
  75. logger.info("检测到新的一天,需要刷新仪表盘数据")
  76. return True
  77. try:
  78. from app.services import watchlist_service
  79. wl = watchlist_service.get_watchlist()
  80. current_count = wl.get("total", 0) if wl.get("success") else 0
  81. if current_count != self._watchlist_count and self._watchlist_count > 0:
  82. logger.info("自选股数量变化: %d -> %d,需要刷新", self._watchlist_count, current_count)
  83. self._watchlist_count = current_count
  84. self._save_state()
  85. return True
  86. except Exception as exc:
  87. logger.debug("检查自选股数量时出错: %s", exc)
  88. return False
  89. # ---- 数据获取 ----
  90. def _fetch_indices(self) -> list:
  91. """获取四大指数数据"""
  92. from app.services import market_service
  93. index_names = ("上证指数", "深证成指", "创业板指", "沪深300")
  94. results = []
  95. for name in index_names:
  96. try:
  97. data = market_service.get_index_quote(name)
  98. results.append({"name": name, "data": data})
  99. except Exception as exc:
  100. logger.debug("记忆系统获取指数失败 %s: %s", name, exc)
  101. return results
  102. def _fetch_watchlist(self) -> dict:
  103. """获取自选股列表(含行情数据)"""
  104. from app.services import watchlist_service
  105. try:
  106. wl = watchlist_service.get_watchlist()
  107. if wl.get("success"):
  108. self._watchlist_count = wl.get("total", 0)
  109. return wl
  110. except Exception as exc:
  111. logger.debug("记忆系统获取自选股失败: %s", exc)
  112. return {"success": False, "stocks": [], "total": 0}
  113. def _fetch_hot_news(self) -> dict:
  114. """获取热点资讯"""
  115. from app.services import news_service
  116. try:
  117. return news_service.search_market_news() or {}
  118. except Exception as exc:
  119. logger.debug("记忆系统获取热点资讯失败: %s", exc)
  120. return {}
  121. def parallel_fetch(self) -> MemorySnapshot:
  122. """
  123. 三线程并行获取仪表盘数据:指数、自选、热点资讯
  124. """
  125. logger.info("记忆系统: 开始三线程并行获取仪表盘数据...")
  126. with ThreadPoolExecutor(max_workers=3) as executor:
  127. future_indices = executor.submit(self._fetch_indices)
  128. future_watchlist = executor.submit(self._fetch_watchlist)
  129. future_news = executor.submit(self._fetch_hot_news)
  130. results: dict[str, Any] = {}
  131. for future in as_completed([future_indices, future_watchlist, future_news]):
  132. try:
  133. value = future.result()
  134. except Exception as exc:
  135. logger.warning("并行获取任务失败: %s", exc)
  136. value = None
  137. if future == future_indices:
  138. results["indices"] = value or []
  139. elif future == future_watchlist:
  140. results["watchlist"] = value or {}
  141. elif future == future_news:
  142. results["hot_news"] = value or {}
  143. today = self._get_today()
  144. snapshot = MemorySnapshot(
  145. date_str=today,
  146. indices=results.get("indices", []),
  147. watchlist=results.get("watchlist", {}),
  148. hot_news=results.get("hot_news", {}),
  149. watchlist_count=self._watchlist_count,
  150. )
  151. with self._lock:
  152. self._today = today
  153. self._snapshot = snapshot
  154. self._save_state()
  155. logger.info("记忆系统: 仪表盘数据获取完成 (date=%s, indices=%d, watchlist=%d)",
  156. today, len(snapshot.indices), snapshot.watchlist_count)
  157. return snapshot
  158. # ---- 公共接口 ----
  159. def get_snapshot(self) -> Optional[MemorySnapshot]:
  160. """获取当前缓存的仪表盘快照"""
  161. with self._lock:
  162. return self._snapshot
  163. def get_indices(self) -> list:
  164. """获取缓存的指数数据"""
  165. snap = self.get_snapshot()
  166. return snap.indices if snap else []
  167. def get_watchlist(self) -> dict:
  168. """获取缓存的自选股数据"""
  169. snap = self.get_snapshot()
  170. return snap.watchlist if snap else {}
  171. def get_hot_news(self) -> dict:
  172. """获取缓存的热点资讯数据"""
  173. snap = self.get_snapshot()
  174. return snap.hot_news if snap else {}
  175. def clear(self) -> None:
  176. """清空所有记忆数据"""
  177. with self._lock:
  178. self._today = None
  179. self._snapshot = None
  180. self._watchlist_count = 0
  181. try:
  182. if self._state_file.exists():
  183. self._state_file.unlink()
  184. except Exception:
  185. pass
  186. def get_stats(self) -> dict:
  187. """获取记忆系统状态"""
  188. with self._lock:
  189. return {
  190. "today": self._today,
  191. "has_snapshot": self._snapshot is not None,
  192. "watchlist_count": self._watchlist_count,
  193. "indices_count": len(self._snapshot.indices) if self._snapshot else 0,
  194. "storage_dir": str(self._storage_dir),
  195. }
  196. _memory_svc: Optional[MemoryService] = None
  197. def get_memory_service() -> MemoryService:
  198. """获取 MemoryService 全局单例"""
  199. global _memory_svc
  200. if _memory_svc is None:
  201. with _memory_lock:
  202. if _memory_svc is None:
  203. _memory_svc = MemoryService()
  204. return _memory_svc