store.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
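"""
SQLite persistence layer: users, health-report analysis history (report_runs),
user profiles, diet runs (diet_runs) and diet execution feedback (diet_reflect).

The API is synchronous; call it from async routes via asyncio.to_thread.
"""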
  5. from __future__ import annotations
  6. import json
  7. import logging
  8. import os
  9. import sqlite3
  10. from datetime import datetime, timezone
  11. from pathlib import Path
  12. from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# Per the original note this file lives at backend/memory/store.py, so
# parents[2] of the resolved file path is the project root (HealthRecordAgent).
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Default database location: <project root>/data/health_memory.db.
_DEFAULT_DATA_DIR = _PROJECT_ROOT / "data"
_DEFAULT_DB_PATH = _DEFAULT_DATA_DIR / "health_memory.db"
  18. def get_db_path() -> Path:
  19. override = os.getenv("HEALTH_MEMORY_DB_PATH")
  20. if override:
  21. return Path(override).expanduser().resolve()
  22. return _DEFAULT_DB_PATH
  23. def _connect() -> sqlite3.Connection:
  24. path = get_db_path()
  25. path.parent.mkdir(parents=True, exist_ok=True)
  26. conn = sqlite3.connect(str(path), check_same_thread=False)
  27. conn.row_factory = sqlite3.Row
  28. conn.execute("PRAGMA foreign_keys = ON")
  29. return conn
  30. def _ensure_legacy_columns(conn: sqlite3.Connection) -> None:
  31. """旧库补列(阶段 3:体检 trace、饮食 replay 溯源)。"""
  32. cols = {r[1] for r in conn.execute("PRAGMA table_info(report_runs)").fetchall()}
  33. if "agent_trace_json" not in cols:
  34. conn.execute("ALTER TABLE report_runs ADD COLUMN agent_trace_json TEXT")
  35. cols_d = {r[1] for r in conn.execute("PRAGMA table_info(diet_runs)").fetchall()}
  36. if "replayed_from_run_id" not in cols_d:
  37. conn.execute("ALTER TABLE diet_runs ADD COLUMN replayed_from_run_id TEXT")
  38. def init_db() -> None:
  39. """创建表与索引(幂等)。"""
  40. with _connect() as conn:
  41. conn.executescript(
  42. """
  43. CREATE TABLE IF NOT EXISTS users (
  44. user_id TEXT PRIMARY KEY,
  45. created_at TEXT NOT NULL
  46. );
  47. CREATE TABLE IF NOT EXISTS report_runs (
  48. id INTEGER PRIMARY KEY AUTOINCREMENT,
  49. task_id TEXT NOT NULL UNIQUE,
  50. user_id TEXT NOT NULL,
  51. created_at TEXT NOT NULL,
  52. summary_text TEXT,
  53. report_json TEXT NOT NULL,
  54. agent_trace_json TEXT,
  55. FOREIGN KEY (user_id) REFERENCES users (user_id)
  56. );
  57. CREATE INDEX IF NOT EXISTS idx_report_runs_user_created
  58. ON report_runs (user_id, created_at DESC);
  59. CREATE TABLE IF NOT EXISTS user_profiles (
  60. user_id TEXT PRIMARY KEY,
  61. profile_json TEXT NOT NULL DEFAULT '{}',
  62. updated_at TEXT NOT NULL,
  63. FOREIGN KEY (user_id) REFERENCES users (user_id)
  64. );
  65. CREATE TABLE IF NOT EXISTS diet_runs (
  66. run_id TEXT PRIMARY KEY,
  67. user_id TEXT NOT NULL,
  68. created_at TEXT NOT NULL,
  69. input_json TEXT NOT NULL,
  70. steps_trace_json TEXT NOT NULL,
  71. output_json TEXT NOT NULL,
  72. replayed_from_run_id TEXT,
  73. FOREIGN KEY (user_id) REFERENCES users (user_id)
  74. );
  75. CREATE TABLE IF NOT EXISTS diet_reflect (
  76. id INTEGER PRIMARY KEY AUTOINCREMENT,
  77. user_id TEXT NOT NULL,
  78. diet_run_id TEXT NOT NULL,
  79. followed INTEGER NOT NULL DEFAULT 0,
  80. reason_code TEXT,
  81. reason_detail TEXT,
  82. created_at TEXT NOT NULL,
  83. FOREIGN KEY (user_id) REFERENCES users (user_id)
  84. );
  85. CREATE INDEX IF NOT EXISTS idx_diet_runs_user_created
  86. ON diet_runs (user_id, created_at DESC);
  87. CREATE INDEX IF NOT EXISTS idx_diet_reflect_user_created
  88. ON diet_reflect (user_id, created_at DESC);
  89. """
  90. )
  91. _ensure_legacy_columns(conn)
  92. conn.commit()
  93. logger.info("SQLite 记忆库已就绪: %s", get_db_path())
  94. def ensure_user(user_id: str) -> None:
  95. now = datetime.now(timezone.utc).isoformat()
  96. with _connect() as conn:
  97. conn.execute(
  98. "INSERT OR IGNORE INTO users (user_id, created_at) VALUES (?, ?)",
  99. (user_id, now),
  100. )
  101. conn.commit()
  102. def save_completed_report_run(
  103. user_id: str,
  104. task_id: str,
  105. final_report: Dict[str, Any],
  106. agent_trace: Optional[Dict[str, Any]] = None,
  107. ) -> None:
  108. """
  109. 分析成功完成后写入一条履历;失败时由调用方捕获日志,不影响主流程。
  110. agent_trace: 各 Agent 的 trace 列表(阶段 3 可观测性)。
  111. """
  112. ensure_user(user_id)
  113. summary = ""
  114. report_inner = final_report.get("report") if isinstance(final_report, dict) else None
  115. if isinstance(report_inner, dict):
  116. s = report_inner.get("summary")
  117. if isinstance(s, str):
  118. summary = s[:8000]
  119. elif s is not None:
  120. summary = str(s)[:8000]
  121. payload = json.dumps(final_report, ensure_ascii=False)
  122. trace_payload = json.dumps(agent_trace, ensure_ascii=False) if agent_trace else None
  123. now = datetime.now(timezone.utc).isoformat()
  124. with _connect() as conn:
  125. conn.execute(
  126. """
  127. INSERT INTO report_runs (task_id, user_id, created_at, summary_text, report_json, agent_trace_json)
  128. VALUES (?, ?, ?, ?, ?, ?)
  129. """,
  130. (task_id, user_id, now, summary or None, payload, trace_payload),
  131. )
  132. conn.commit()
  133. def list_report_runs_for_user(user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
  134. limit = max(1, min(limit, 200))
  135. with _connect() as conn:
  136. cur = conn.execute(
  137. """
  138. SELECT task_id, user_id, created_at, summary_text
  139. FROM report_runs
  140. WHERE user_id = ?
  141. ORDER BY created_at DESC
  142. LIMIT ?
  143. """,
  144. (user_id, limit),
  145. )
  146. rows = cur.fetchall()
  147. return [dict(r) for r in rows]
  148. def get_report_run(task_id: str) -> Optional[Dict[str, Any]]:
  149. with _connect() as conn:
  150. cur = conn.execute(
  151. """
  152. SELECT task_id, user_id, created_at, summary_text, report_json, agent_trace_json
  153. FROM report_runs
  154. WHERE task_id = ?
  155. """,
  156. (task_id,),
  157. )
  158. row = cur.fetchone()
  159. if not row:
  160. return None
  161. d = dict(row)
  162. if d.get("report_json"):
  163. try:
  164. d["report"] = json.loads(d["report_json"])
  165. except json.JSONDecodeError:
  166. d["report"] = None
  167. del d["report_json"]
  168. raw_trace = d.pop("agent_trace_json", None)
  169. if raw_trace:
  170. try:
  171. d["agent_trace"] = json.loads(raw_trace)
  172. except json.JSONDecodeError:
  173. d["agent_trace"] = None
  174. else:
  175. d["agent_trace"] = None
  176. return d
  177. def save_diet_run(
  178. user_id: str,
  179. run_id: str,
  180. input_payload: Dict[str, Any],
  181. steps_trace: List[Dict[str, Any]],
  182. output_payload: Dict[str, Any],
  183. replayed_from_run_id: Optional[str] = None,
  184. ) -> None:
  185. ensure_user(user_id)
  186. now = datetime.now(timezone.utc).isoformat()
  187. with _connect() as conn:
  188. conn.execute(
  189. """
  190. INSERT INTO diet_runs (run_id, user_id, created_at, input_json, steps_trace_json, output_json, replayed_from_run_id)
  191. VALUES (?, ?, ?, ?, ?, ?, ?)
  192. """,
  193. (
  194. run_id,
  195. user_id,
  196. now,
  197. json.dumps(input_payload, ensure_ascii=False),
  198. json.dumps(steps_trace, ensure_ascii=False),
  199. json.dumps(output_payload, ensure_ascii=False),
  200. replayed_from_run_id,
  201. ),
  202. )
  203. conn.commit()
  204. def insert_diet_reflect(
  205. user_id: str,
  206. diet_run_id: str,
  207. followed: bool,
  208. reason_code: str | None,
  209. reason_detail: str | None,
  210. ) -> int:
  211. ensure_user(user_id)
  212. now = datetime.now(timezone.utc).isoformat()
  213. with _connect() as conn:
  214. cur = conn.execute(
  215. """
  216. INSERT INTO diet_reflect (user_id, diet_run_id, followed, reason_code, reason_detail, created_at)
  217. VALUES (?, ?, ?, ?, ?, ?)
  218. """,
  219. (
  220. user_id,
  221. diet_run_id,
  222. 1 if followed else 0,
  223. reason_code,
  224. (reason_detail or "")[:2000] or None,
  225. now,
  226. ),
  227. )
  228. conn.commit()
  229. return int(cur.lastrowid)
  230. def list_recent_diet_reflect(user_id: str, limit: int = 8) -> List[Dict[str, Any]]:
  231. limit = max(1, min(limit, 50))
  232. with _connect() as conn:
  233. cur = conn.execute(
  234. """
  235. SELECT id, diet_run_id, followed, reason_code, reason_detail, created_at
  236. FROM diet_reflect
  237. WHERE user_id = ?
  238. ORDER BY created_at DESC
  239. LIMIT ?
  240. """,
  241. (user_id, limit),
  242. )
  243. rows = cur.fetchall()
  244. out = []
  245. for r in rows:
  246. d = dict(r)
  247. d["followed"] = bool(d["followed"])
  248. out.append(d)
  249. return out
  250. def format_reflect_memory_for_prompt(user_id: str, limit: int = 5) -> str:
  251. rows = list_recent_diet_reflect(user_id, limit=limit)
  252. if not rows:
  253. return "(暂无历史执行反馈)"
  254. lines = []
  255. for r in rows:
  256. fl = "已执行" if r["followed"] else "未执行"
  257. rc = r.get("reason_code") or "-"
  258. rd = (r.get("reason_detail") or "").strip()
  259. lines.append(
  260. f"- {r['created_at'][:19]} | run={r['diet_run_id'][:8]}… | {fl} | 原因码={rc}"
  261. + (f" | 说明={rd}" if rd else "")
  262. )
  263. return "\n".join(lines)
  264. def get_diet_run(run_id: str) -> Optional[Dict[str, Any]]:
  265. with _connect() as conn:
  266. cur = conn.execute(
  267. """
  268. SELECT run_id, user_id, created_at, input_json, steps_trace_json, output_json, replayed_from_run_id
  269. FROM diet_runs
  270. WHERE run_id = ?
  271. """,
  272. (run_id,),
  273. )
  274. row = cur.fetchone()
  275. if not row:
  276. return None
  277. d = dict(row)
  278. mapping = {
  279. "input_json": "input",
  280. "steps_trace_json": "steps_trace",
  281. "output_json": "output",
  282. }
  283. for raw_key, out_key in mapping.items():
  284. raw = d.pop(raw_key, None)
  285. if raw:
  286. try:
  287. d[out_key] = json.loads(raw)
  288. except json.JSONDecodeError:
  289. d[out_key] = None
  290. else:
  291. d[out_key] = None
  292. return d
  293. def list_diet_runs_for_user(user_id: str, limit: int = 30) -> List[Dict[str, Any]]:
  294. limit = max(1, min(limit, 100))
  295. with _connect() as conn:
  296. cur = conn.execute(
  297. """
  298. SELECT run_id, user_id, created_at,
  299. json_extract(output_json, '$.meal_plan.total_est_protein_g') AS total_protein
  300. FROM diet_runs
  301. WHERE user_id = ?
  302. ORDER BY created_at DESC
  303. LIMIT ?
  304. """,
  305. (user_id, limit),
  306. )
  307. rows = cur.fetchall()
  308. return [dict(r) for r in rows]
  309. def get_diet_reflect(reflect_id: int) -> Optional[Dict[str, Any]]:
  310. with _connect() as conn:
  311. cur = conn.execute(
  312. """
  313. SELECT id, user_id, diet_run_id, followed, reason_code, reason_detail, created_at
  314. FROM diet_reflect
  315. WHERE id = ?
  316. """,
  317. (reflect_id,),
  318. )
  319. row = cur.fetchone()
  320. if not row:
  321. return None
  322. d = dict(row)
  323. d["followed"] = bool(d["followed"])
  324. return d
  325. def list_all_user_ids(limit: int = 5000) -> List[str]:
  326. limit = max(1, min(limit, 20000))
  327. with _connect() as conn:
  328. cur = conn.execute(
  329. """
  330. SELECT user_id
  331. FROM users
  332. ORDER BY created_at DESC
  333. LIMIT ?
  334. """,
  335. (limit,),
  336. )
  337. rows = cur.fetchall()
  338. return [r["user_id"] for r in rows]
  339. def list_user_memory_chunks_sql(user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
  340. """
  341. SQL 回退检索:按时间抓取用户近期文本记忆。
  342. """
  343. limit = max(1, min(limit, 500))
  344. out: List[Dict[str, Any]] = []
  345. with _connect() as conn:
  346. r1 = conn.execute(
  347. """
  348. SELECT task_id, created_at, summary_text
  349. FROM report_runs
  350. WHERE user_id = ?
  351. ORDER BY created_at DESC
  352. LIMIT ?
  353. """,
  354. (user_id, limit),
  355. ).fetchall()
  356. for r in r1:
  357. txt = (r["summary_text"] or "").strip()
  358. if not txt:
  359. continue
  360. out.append(
  361. {
  362. "chunk_id": f"report:{r['task_id']}",
  363. "user_id": user_id,
  364. "source_type": "report_summary",
  365. "source_id": r["task_id"],
  366. "created_at": r["created_at"],
  367. "text": txt[:8000],
  368. }
  369. )
  370. r2 = conn.execute(
  371. """
  372. SELECT run_id, created_at, output_json
  373. FROM diet_runs
  374. WHERE user_id = ?
  375. ORDER BY created_at DESC
  376. LIMIT ?
  377. """,
  378. (user_id, limit),
  379. ).fetchall()
  380. for r in r2:
  381. txt = ""
  382. try:
  383. obj = json.loads(r["output_json"] or "{}")
  384. mp = obj.get("meal_plan") or {}
  385. items = mp.get("items") or []
  386. hints = (obj.get("habit_extras") or {}).get("execution_hints", [])
  387. txt = ";".join(
  388. [f"{it.get('name')} {it.get('portion')} {it.get('why','')}" for it in items if isinstance(it, dict)]
  389. )
  390. if hints:
  391. txt += "\n执行提示:" + ";".join([str(h) for h in hints])
  392. except Exception:
  393. txt = str(r["output_json"] or "")
  394. txt = txt.strip()
  395. if not txt:
  396. continue
  397. out.append(
  398. {
  399. "chunk_id": f"diet:{r['run_id']}",
  400. "user_id": user_id,
  401. "source_type": "diet_plan",
  402. "source_id": r["run_id"],
  403. "created_at": r["created_at"],
  404. "text": txt[:8000],
  405. }
  406. )
  407. r3 = conn.execute(
  408. """
  409. SELECT id, created_at, followed, reason_code, reason_detail
  410. FROM diet_reflect
  411. WHERE user_id = ?
  412. ORDER BY created_at DESC
  413. LIMIT ?
  414. """,
  415. (user_id, limit),
  416. ).fetchall()
  417. for r in r3:
  418. txt = f"执行={bool(r['followed'])} 原因={r['reason_code'] or '-'} 说明={r['reason_detail'] or ''}".strip()
  419. out.append(
  420. {
  421. "chunk_id": f"reflect:{r['id']}",
  422. "user_id": user_id,
  423. "source_type": "diet_reflect",
  424. "source_id": str(r["id"]),
  425. "created_at": r["created_at"],
  426. "text": txt[:8000],
  427. }
  428. )
  429. out.sort(key=lambda x: x.get("created_at", ""), reverse=True)
  430. return out[:limit]