base.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. """
  2. HealthRecordAgent 基础智能体类
  3. """
  4. import asyncio
  5. import json
  6. import logging
  7. from abc import ABC, abstractmethod
  8. from typing import Any, Dict, List, Callable, Optional, ClassVar
  9. from datetime import datetime
  10. from core.config import get_config
  11. from core.llm_adapter import get_llm_adapter
  12. from core.exceptions import AgentException, TimeoutException
  13. from enum import Enum
  14. # 全局任务状态管理
  15. TASKS = {}
  16. def create_task(task_id: str, user_id: str | None = None):
  17. TASKS[task_id] = {
  18. "task_id": task_id,
  19. "user_id": user_id,
  20. "state": "running",
  21. "agents": {
  22. "PlannerAgent": "pending",
  23. "HealthIndicatorAgent": "pending",
  24. "RiskAssessmentAgent": "pending",
  25. "AdviceAgent": "pending",
  26. "ReportAgent": "pending"},
  27. "report": None, # 最终报告
  28. }
  29. def update_agent_state(task_id: str, agent_name: str, state: str, partial_report=None):
  30. task = TASKS.get(task_id)
  31. if not task:
  32. return
  33. task["agents"][agent_name] = state
  34. if partial_report:
  35. task["report"] = partial_report
  36. def complete_task(task_id: str, report: dict):
  37. task = TASKS.get(task_id)
  38. if not task:
  39. return
  40. task["state"] = "completed"
  41. task["report"] = report
  42. for agent in task["agents"]:
  43. task["agents"][agent] = "completed"
  44. def get_task_status(task_id: str):
  45. return TASKS.get(task_id)
  46. class TraceLevel(str, Enum):
  47. INFO = "INFO"
  48. DEBUG = "DEBUG"
  49. TRACE = "TRACE"
  50. ERROR = "ERROR"
  51. logger = logging.getLogger(__name__)
  52. class BaseAgent(ABC):
  53. """
  54. 基础智能体抽象类
  55. """
  56. def __init__(
  57. self, name: str, llm = None,
  58. max_steps: int = None, timeout: int = None, debug: bool = True, task_id = None):
  59. self.name = name
  60. self.config = get_config()
  61. self.llm = llm or get_llm_adapter()
  62. self.max_steps = max_steps or self.config.agent.max_steps
  63. self.timeout = timeout or self.config.agent.timeout
  64. self.history = []
  65. self.tools = {}
  66. self.state = "idle"
  67. self.created_at = datetime.now()
  68. self.debug = debug
  69. self.traces: List[Dict[str, Any]] = []
  70. self.task_id = task_id
  71. # ========== 核心接口 ==========
  72. @abstractmethod
  73. async def run(self, **kwargs) -> Any:
  74. """Agent 执行入口"""
  75. pass
  76. # ========== LLM 思考 ==========
  77. async def think(self, prompt: str, context: Dict = None) -> str:
  78. """调用LLM进行思考"""
  79. try:
  80. # 构建完整的提示词
  81. full_prompt = prompt
  82. # 添加上下文信息
  83. if context:
  84. context_str = json.dumps(context, ensure_ascii=False, indent=2)
  85. full_prompt = f"上下文信息:\n{context_str}\n\n任务:\n{prompt}"
  86. # 添加历史记录
  87. if self.history:
  88. history_str = "\n".join(self.history[-10:]) # 只保留最近10条
  89. full_prompt += f"\n\n历史记录:\n{history_str}"
  90. self.trace("LLM CALL",
  91. {
  92. "prompt_length": len(full_prompt),
  93. "history_length": len(self.history)
  94. },
  95. TraceLevel.INFO
  96. )
  97. start = datetime.now()
  98. # 调用 HelloAgent LLM
  99. response = await asyncio.wait_for(
  100. self.llm.ainvoke(full_prompt),
  101. timeout=self.timeout
  102. )
  103. duration = (datetime.now() - start).total_seconds()
  104. self.trace("LLM TTHINKING TIME",
  105. {
  106. "duration_sec": duration,
  107. "prompt_tokens": len(full_prompt),
  108. }
  109. )
  110. response_text = response.content if hasattr(response, 'content') else str(response)
  111. self.trace("LLM RESPONSE", response_text)
  112. self._add_to_history(f"LLM prompt: {prompt}")
  113. self._add_to_history(f"LLM response: {response_text}")
  114. return response_text
  115. except asyncio.TimeoutError:
  116. raise TimeoutException(f"LLM思考超时")
  117. except Exception as e:
  118. raise AgentException(f"LLM思考失败: {str(e)}")
  119. # ========== Tool 机制 ==========
  120. def add_tool(self, tool_name: str, tool_func: Callable, description: str = ""):
  121. """添加工具"""
  122. self.tools[tool_name] = {
  123. "function": tool_func,
  124. "description": description
  125. }
  126. def get_tools_description(self) -> str:
  127. """获取工具描述"""
  128. if not self.tools:
  129. return "暂无可用工具"
  130. descriptions = []
  131. for name, tool_info in self.tools.items():
  132. descriptions.append(f"- {name}: {tool_info['description']}")
  133. return "\n".join(descriptions)
  134. async def call_tool(self, tool_name: str, tool_input: Any) -> Any:
  135. """调用工具"""
  136. if tool_name not in self.tools:
  137. raise AgentException(f"工具 '{tool_name}' 不存在")
  138. try:
  139. tool_func = self.tools[tool_name]["function"]
  140. if asyncio.iscoroutinefunction(tool_func):
  141. result = await asyncio.wait_for(
  142. tool_func(tool_input),
  143. timeout=self.timeout
  144. )
  145. else:
  146. result = await asyncio.wait_for(
  147. asyncio.to_thread(tool_func, tool_input),
  148. timeout=self.timeout
  149. )
  150. self._add_to_history(f"Tool {tool_name} called with input: {tool_input}")
  151. self._add_to_history(f"Tool {tool_name} result: {result}")
  152. return result
  153. except asyncio.TimeoutError:
  154. raise TimeoutException(f"工具 '{tool_name}' 执行超时")
  155. except Exception as e:
  156. raise AgentException(f"工具 '{tool_name}' 执行失败: {str(e)}")
  157. # ========== 状态 & 历史 ==========
  158. def _add_to_history(self, message: str):
  159. """添加到历史记录"""
  160. timestamp = datetime.now().isoformat()
  161. self.history.append(f"[{timestamp}] {message}")
  162. # 限制历史记录长度
  163. if len(self.history) > 100:
  164. self.history = self.history[-50:]
  165. def get_history(self, limit: int = 10) -> List[str]:
  166. """获取历史记录"""
  167. return self.history[-limit:]
  168. def clear_history(self):
  169. """清空历史记录"""
  170. self.history = []
  171. def set_state(self, state: str):
  172. """设置智能体状态"""
  173. self.state = state
  174. # 更新全局任务状态
  175. if self.task_id:
  176. update_agent_state(self.task_id, self.name, state)
  177. self.trace("STATE CHANGE",
  178. {
  179. "state": state
  180. }
  181. )
  182. logger.info(f"Agent {self.name} state changed to: {state}")
  183. def get_status(self) -> Dict[str, Any]:
  184. """获取智能体状态"""
  185. return {
  186. "name": self.name,
  187. "state": self.state,
  188. "created_at": self.created_at.isoformat(),
  189. "history_count": len(self.history),
  190. "tools_count": len(self.tools),
  191. "max_steps": self.max_steps,
  192. "timeout": self.timeout
  193. }
  194. async def validate_input(self, input_data: Dict[str, Any]) -> bool:
  195. """验证输入数据"""
  196. required_fields = self.get_required_fields()
  197. for field in required_fields:
  198. if field not in input_data:
  199. raise AgentException(f"缺少必需字段: {field}")
  200. return True
  201. @abstractmethod
  202. def get_required_fields(self) -> List[str]:
  203. """获取必需的输入字段"""
  204. pass
  205. def trace(self, title: str, data: Any, level: TraceLevel = TraceLevel.DEBUG):
  206. """统一Agent调试输出"""
  207. event = {
  208. "agent": self.name,
  209. "title": title,
  210. "timestamp": datetime.now().isoformat(),
  211. "data": data
  212. }
  213. self.traces.append({
  214. **event,
  215. "level": level
  216. })
  217. if not self.debug:
  218. return
  219. if level in [TraceLevel.INFO, TraceLevel.ERROR]:
  220. logger.info(f"[{self.name}] {title}")
  221. return
  222. if level == TraceLevel.DEBUG:
  223. preview = self._preview(data)
  224. logger.debug(f"[{self.name}] {title}: {preview}")
  225. try:
  226. print(json.dumps(data, indent=2, ensure_ascii=False))
  227. except Exception:
  228. print(data)
  229. def trace_step(self, step: str, status: str):
  230. self.trace(
  231. "STEP",
  232. {
  233. "step": step,
  234. "status": status
  235. }
  236. )
  237. def get_traces(self) -> List[Dict[str, Any]]:
  238. return self.traces
  239. def _preview(self, data, max_len: int = 300):
  240. """日志摘要"""
  241. if data is None:
  242. return ""
  243. text = str(data)
  244. if len(text) > max_len:
  245. return text[:max_len] + "...(truncated)"
  246. return text