health_analysis.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. 健康分析工作流服务
  3. 负责串联多个 Agent,完成一次完整的健康报告分析
  4. """
  5. import asyncio
  6. import logging
  7. from typing import Dict, Any
  8. from uuid import uuid4
  9. from agents.planner import PlannerAgent
  10. from agents.health_indicator import HealthIndicatorAgent
  11. from agents.risk_assess import RiskAssessmentAgent
  12. from agents.advice import AdviceAgent
  13. from agents.report import ReportAgent
  14. from agents.base import create_task, update_agent_state, complete_task
  15. from memory.store import save_completed_report_run
  16. from rag.indexers import index_report_run
  17. from rag.retriever import retrieve
  18. logger = logging.getLogger(__name__)
  19. class HealthAnalysisService:
  20. def __init__(self, task_id: str | None = None, user_id: str | None = None):
  21. self.task_id = task_id or str(uuid4())
  22. self.user_id = user_id
  23. # 任务初始化
  24. create_task(self.task_id, user_id=user_id)
  25. self.planner = PlannerAgent(task_id=self.task_id)
  26. self.indicator_agent = HealthIndicatorAgent(task_id=self.task_id)
  27. self.risk_agent = RiskAssessmentAgent(task_id=self.task_id)
  28. self.advice_agent = AdviceAgent(task_id=self.task_id)
  29. self.report_agent = ReportAgent(task_id=self.task_id)
  30. def _bundle_agent_traces(self, limit_per_agent: int = 80) -> Dict[str, Any]:
  31. """阶段 3:各 Agent 的 trace 切片落库。"""
  32. pairs = [
  33. ("PlannerAgent", self.planner),
  34. ("HealthIndicatorAgent", self.indicator_agent),
  35. ("RiskAssessmentAgent", self.risk_agent),
  36. ("AdviceAgent", self.advice_agent),
  37. ("ReportAgent", self.report_agent),
  38. ]
  39. out: Dict[str, Any] = {}
  40. for name, ag in pairs:
  41. try:
  42. t = ag.get_traces()
  43. out[name] = t[-limit_per_agent:] if len(t) > limit_per_agent else list(t)
  44. except Exception:
  45. out[name] = []
  46. return out
  47. async def run(self, report_text: str, user_id: str) -> Dict[str, Any]:
  48. """
  49. 执行完整的健康分析流程
  50. """
  51. # 1.任务规划
  52. update_agent_state(self.task_id, "PlannerAgent", "running")
  53. plan_result = await self.planner.run({"goal": f"分析以下体检报告并制定执行计划:\n{report_text}"})
  54. update_agent_state(self.task_id, "PlannerAgent", "completed")
  55. # 2.健康指标分析
  56. update_agent_state(self.task_id, "HealthIndicatorAgent", "running")
  57. indicator_result = await self.indicator_agent.run({
  58. "report_text": report_text,
  59. "plan": plan_result
  60. })
  61. update_agent_state(self.task_id, "HealthIndicatorAgent", "completed", partial_report={"indicator_results": indicator_result})
  62. # 3. 风险评估
  63. update_agent_state(self.task_id, "RiskAssessmentAgent", "running")
  64. risk_result = await self.risk_agent.run({
  65. "indicator_results": indicator_result
  66. })
  67. update_agent_state(self.task_id, "RiskAssessmentAgent", "completed", partial_report={"risk_assessment": risk_result})
  68. rag_result = await asyncio.to_thread(
  69. retrieve,
  70. user_id,
  71. {
  72. "scenario": "health_report_analysis",
  73. "risk_focus": str(risk_result.get("overall_risk_level", "")),
  74. "query": "历史体检变化与执行反馈",
  75. },
  76. )
  77. retrieved_memory = rag_result.get("summary", "(暂无召回记忆)")
  78. # 4. 健康建议生成
  79. update_agent_state(self.task_id, "AdviceAgent", "running")
  80. advice_result = await self.advice_agent.run({
  81. "risk_assessment": risk_result,
  82. "retrieved_memory": retrieved_memory,
  83. })
  84. update_agent_state(self.task_id, "AdviceAgent", "completed", partial_report={"advice": advice_result})
  85. # 5. 报告汇总
  86. update_agent_state(self.task_id, "ReportAgent", "running")
  87. final_report = await self.report_agent.run({
  88. "indicators": indicator_result,
  89. "risk_assessment": risk_result,
  90. "advice": advice_result,
  91. "retrieved_memory": retrieved_memory,
  92. })
  93. update_agent_state(self.task_id, "ReportAgent", "completed")
  94. complete_task(self.task_id, final_report)
  95. try:
  96. traces = self._bundle_agent_traces()
  97. await asyncio.to_thread(
  98. save_completed_report_run,
  99. user_id,
  100. self.task_id,
  101. final_report,
  102. traces,
  103. )
  104. except Exception as e:
  105. logger.exception("写入 SQLite 履历失败(分析结果仍有效): %s", e)
  106. try:
  107. await asyncio.to_thread(index_report_run, self.task_id)
  108. except Exception as e:
  109. logger.warning("report run 向量索引失败(不影响返回): %s", e)
  110. return self.task_id
  111. # ---------- 临时本地验证入口 ----------
  112. async def _demo():
  113. demo_text = """
  114. 男性,28岁,BMI 27.3,血压 145/95 mmHg,
  115. 总胆固醇 6.2 mmol/L,空腹血糖 6.1 mmol/L。
  116. """
  117. workflow = HealthAnalysisService(user_id="local-demo-user")
  118. result = await workflow.run(demo_text, user_id="local-demo-user")
  119. print(result)
  120. if __name__ == "__main__":
  121. asyncio.run(_demo())