diet.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import asyncio
  2. from typing import List, Literal, Optional
  3. from fastapi import APIRouter, Body, HTTPException
  4. from pydantic import BaseModel, Field, field_validator
  5. from memory.store import (
  6. get_diet_run,
  7. insert_diet_reflect,
  8. list_diet_runs_for_user,
  9. list_recent_diet_reflect,
  10. )
  11. from service.diet_recommend_service import DietRecommendService, replay_diet_run
  12. from service.observability_views import build_diet_observability
  13. from rag.indexers import index_reflect_event
  # Router that every /diet/* endpoint in this module registers on via the
  # @router.get / @router.post decorators below.
  router: APIRouter = APIRouter()
  # Free-form description of the user's day. diet_recommend serializes it with
  # model_dump() and hands the dict to DietRecommendService.run.
  # (Documented with comments rather than a class docstring so the generated
  # JSON schema stays unchanged.)
  class DietContext(BaseModel):
      # Required free-text log of today's meals, 4 to 8000 characters.
      today_food_log_text: str = Field(
          ...,
          min_length=4,
          max_length=8000,
          description="今天吃了什么(自由文本)",
      )
      # Health goal, restricted to three literal values; muscle gain by default.
      goal: Literal["muscle_gain", "fat_loss", "maintain"] = Field(
          default="muscle_gain",
          description="健康目标",
      )
      # Purchase-channel tags; default_factory gives each instance a fresh list.
      channels: list[str] = Field(
          default_factory=lambda: ["convenience_store", "delivery"],
          description="可购买渠道标签",
      )
      # Optional exercise / sleep context, at most 2000 characters.
      activity_context: str = Field(
          default="",
          max_length=2000,
          description="运动/睡眠等上下文",
      )
      # Optional extra notes (e.g. "only a convenience store"), at most 2000 chars.
      free_notes: str = Field(
          default="",
          max_length=2000,
          description="额外说明(如只有便利店)",
      )
  # Request body of POST /diet/recommend: the caller's id plus the day context.
  class DietRecommendRequest(BaseModel):
      # Caller id. min_length is checked on the raw value; strip_uid then trims it.
      user_id: str = Field(..., min_length=1, max_length=256)
      context: DietContext

      @field_validator("user_id")
      @classmethod
      def strip_uid(cls, v: str) -> str:
          """Trim surrounding whitespace and reject an id that becomes empty."""
          cleaned = v.strip()
          if cleaned:
              return cleaned
          raise ValueError("user_id 不能为空")
  # Optional body of POST /diet/runs/{run_id}/replay. The docstring is kept
  # verbatim because pydantic exposes it as the schema description; it reads:
  # "Optional: when user_id is supplied it must match the run, to prevent
  # accidental replays."
  class DietReplayRequest(BaseModel):
      """可选:传入 user_id 时必须与 run 一致,防止误重放。"""

      # Owner id to check against the stored run; None skips the check.
      user_id: str | None = Field(default=None, max_length=256)
  # Request body of POST /diet/reflect: whether the user followed a previous
  # recommendation run and, optionally, a reason code and free-text detail.
  class DietReflectRequest(BaseModel):
      user_id: str = Field(..., min_length=1, max_length=256)
      diet_run_id: str = Field(..., min_length=8, max_length=64)
      # True when the user executed the previous recommendation.
      followed: bool = Field(..., description="是否按上次推荐执行")
      # Why it was not followed, or a summary code; None when not given.
      reason_code: Optional[
          Literal["cant_buy", "too_late", "dont_want", "executed_ok", "other"]
      ] = Field(default=None, description="未执行或总结原因类型")
      reason_detail: Optional[str] = Field(default=None, max_length=2000)

      @field_validator("user_id", "diet_run_id")
      @classmethod
      def strip_ids(cls, v: str, info) -> str:
          """Trim whitespace from the id fields and reject ids that become empty.

          This is an "after" validator, so pydantic has already applied
          ``min_length`` to the *unstripped* value: a whitespace-only id such as
          ``"   "`` passes that check and must be rejected here, matching what
          ``DietRecommendRequest.strip_uid`` does for its user_id.

          ``info`` is the pydantic ValidationInfo; its ``field_name`` names the
          offending field in the error message.

          Raises:
              ValueError: if the id is empty after stripping (pydantic turns
                  this into a validation error / HTTP 422).
          """
          v = v.strip()
          if not v:
              raise ValueError(f"{info.field_name} 不能为空")
          return v
  @router.post("/diet/recommend")
  async def diet_recommend(body: DietRecommendRequest):
      """
      饮食推荐:阶段 2 为 **Nutritionist → Coach → Habit** 三 Agent,固定 JSON schema + Pydantic 校验;
      每阶段最多 2 次尝试,失败则降级并写入 `errors` / `degraded`。
      仍落库 `diet_runs`,并读取 Reflect 记忆。
      """
      # English summary of the docstring above (kept verbatim because FastAPI
      # publishes it as the OpenAPI description):
      # Diet recommendation. Stage 2 runs three agents, Nutritionist -> Coach ->
      # Habit, with a fixed JSON schema plus Pydantic validation. Each stage gets
      # at most 2 attempts; on failure it degrades and records `errors` /
      # `degraded`. The run is still persisted to `diet_runs`, and Reflect memory
      # is read.
      #
      # The context model is flattened to a plain dict before it is handed to
      # the service; the service's result is returned unchanged.
      context_payload = body.context.model_dump()
      return await DietRecommendService().run(body.user_id, context_payload)
  # Strong references to in-flight fire-and-forget tasks. The event loop only
  # keeps weak references to tasks, so a task that nobody holds can be
  # garbage-collected before it finishes.
  _background_tasks: set[asyncio.Task] = set()


  def _on_background_task_done(task: asyncio.Task) -> None:
      """Drop a finished background task's reference and log any failure."""
      _background_tasks.discard(task)
      if task.cancelled():
          return
      # Retrieving the exception also stops asyncio from warning that it was
      # "never retrieved" when the task object is destroyed.
      exc = task.exception()
      if exc is not None:
          import logging  # local import keeps this block self-contained

          logging.getLogger(__name__).error(
              "background task %s failed", task.get_name(), exc_info=exc
          )


  @router.post("/diet/reflect")
  async def diet_reflect(body: DietReflectRequest):
      """
      Reflect:用户反馈是否执行及原因,写入 diet_reflect;下次 recommend 自动读取。
      """
      # English summary of the docstring above (kept verbatim because FastAPI
      # publishes it as the OpenAPI description): Reflect records whether the
      # user followed a run and why into diet_reflect; the next recommend reads
      # it automatically.
      #
      # Responds 404 when diet_run_id is unknown and 403 when the run belongs
      # to a different user_id.
      row = get_diet_run(body.diet_run_id)
      if not row:
          raise HTTPException(status_code=404, detail="diet_run_id 不存在")
      if row.get("user_id") != body.user_id:
          raise HTTPException(status_code=403, detail="该 run 不属于此 user_id")
      # A followed recommendation without an explicit reason is stored as OK.
      rc = body.reason_code
      if body.followed and rc is None:
          rc = "executed_ok"
      rid = insert_diet_reflect(
          user_id=body.user_id,
          diet_run_id=body.diet_run_id,
          followed=body.followed,
          reason_code=rc,
          reason_detail=body.reason_detail,
      )
      # Index the new reflect row off the request path. index_reflect_event runs
      # in a worker thread and is not awaited, so the response is not delayed.
      # The task stays referenced until it completes, and any failure is logged
      # instead of being silently dropped.
      task = asyncio.create_task(
          asyncio.to_thread(index_reflect_event, rid),
          name=f"index_reflect_event:{rid}",
      )
      _background_tasks.add(task)
      task.add_done_callback(_on_background_task_done)
      return {
          "ok": True,
          "reflect_id": rid,
          "user_id": body.user_id,
          "diet_run_id": body.diet_run_id,
      }
  # Largest `limit` the per-user list endpoints pass to the store; larger
  # requests are clamped to this value rather than rejected.
  _MAX_LIST_LIMIT = 200


  def _check_list_args(user_id: str, limit: int) -> tuple[str, int]:
      """Validate the shared arguments of the per-user list endpoints.

      Returns the whitespace-stripped user_id and the effective limit, capped at
      ``_MAX_LIST_LIMIT``. Raises HTTPException 400 for a blank user_id or for
      a limit below 1. Previously such a limit reached the store unchecked; a
      negative value may mean "no limit" if it ends up in a SQL LIMIT clause
      (store internals not visible here).
      """
      uid = user_id.strip()
      if not uid:
          raise HTTPException(status_code=400, detail="user_id 无效")
      if limit < 1:
          raise HTTPException(status_code=400, detail="limit 无效")
      return uid, min(limit, _MAX_LIST_LIMIT)


  # GET /diet/users/{user_id}/runs: the user's diet recommendation runs, as
  # returned by list_diet_runs_for_user.
  @router.get("/diet/users/{user_id}/runs")
  async def diet_runs(user_id: str, limit: int = 20):
      uid, limit = _check_list_args(user_id, limit)
      return {"user_id": uid, "items": list_diet_runs_for_user(uid, limit=limit)}


  # GET /diet/users/{user_id}/reflect_history: the user's recent reflect rows,
  # as returned by list_recent_diet_reflect.
  @router.get("/diet/users/{user_id}/reflect_history")
  async def diet_reflect_history(user_id: str, limit: int = 20):
      uid, limit = _check_list_args(user_id, limit)
      return {"user_id": uid, "items": list_recent_diet_reflect(uid, limit=limit)}
  # GET /diet/runs/{run_id}: return the stored run row exactly as
  # get_diet_run yields it (run_id is whitespace-trimmed first); 404 when the
  # lookup comes back empty.
  @router.get("/diet/runs/{run_id}")
  async def diet_run_detail(run_id: str):
      found = get_diet_run(run_id.strip())
      if found:
          return found
      raise HTTPException(status_code=404, detail="未找到该饮食推荐 run")
  @router.get("/diet/runs/{run_id}/observability")
  async def diet_run_observability(run_id: str):
      """
      阶段 3:可观测性视图 — timeline / errors / replay 说明(trace 已持久化在 diet_runs)。
      """
      # English summary of the docstring above (kept verbatim because FastAPI
      # publishes it as the OpenAPI description): stage-3 observability view
      # with timeline / errors / replay notes; the trace itself is already
      # persisted in diet_runs. The view is built from the stored run row; an
      # unknown run_id yields 404.
      stored = get_diet_run(run_id.strip())
      if stored:
          return build_diet_observability(stored)
      raise HTTPException(status_code=404, detail="未找到该饮食推荐 run")
  @router.post("/diet/runs/{run_id}/replay")
  async def diet_run_replay(
      run_id: str,
      body: DietReplayRequest | None = Body(default=None),
  ):
      """
      阶段 3:用该 run 落库的 input 重跑流水线(新 run_id;列 replayed_from_run_id 与 output.replayed_from 溯源)。
      Mock 工具确定性较高,LLM 输出仍可能不同。
      """
      # English summary of the docstring above (kept verbatim because FastAPI
      # publishes it as the OpenAPI description): stage 3 re-runs the pipeline
      # with the input stored for this run, producing a new run_id that is
      # traceable via the replayed_from_run_id column and output.replayed_from.
      # Mock tools are fairly deterministic, but LLM output may still differ.
      #
      # Responds 404 for an unknown run, 403 when a supplied user_id does not
      # own the run, and 400 when replay_diet_run raises ValueError.
      rid = run_id.strip()
      row = get_diet_run(rid)
      if not row:
          raise HTTPException(status_code=404, detail="run 不存在")
      # Ownership is only checked when the caller supplied a user_id. Use .get()
      # like diet_reflect so a row without user_id yields 403, not a KeyError.
      requested_uid = body.user_id if body else None
      if requested_uid and requested_uid.strip() != row.get("user_id"):
          raise HTTPException(status_code=403, detail="user_id 与 run 不匹配")
      try:
          return await replay_diet_run(rid)
      except ValueError as e:
          # Chain the cause so the original traceback survives in server logs.
          raise HTTPException(status_code=400, detail=str(e)) from e