1
0

base.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """
  2. InnoCore AI 基础智能体类
  3. """
  4. import asyncio
  5. from abc import ABC, abstractmethod
  6. from typing import Dict, List, Optional, Any, Callable
  7. from datetime import datetime
  8. import json
  9. import logging
  10. from core.config import get_config
  11. from core.llm_adapter import get_llm_adapter
  12. from core.exceptions import AgentException, TimeoutException
  13. logger = logging.getLogger(__name__)
  14. class BaseAgent(ABC):
  15. """基础智能体抽象类"""
  16. def __init__(self, name: str, llm = None,
  17. max_steps: int = None, timeout: int = None):
  18. self.name = name
  19. self.config = get_config()
  20. self.llm = llm or get_llm_adapter()
  21. self.max_steps = max_steps or self.config.agent_max_steps
  22. self.timeout = timeout or self.config.agent_timeout
  23. self.history = []
  24. self.tools = {}
  25. self.state = "idle"
  26. self.created_at = datetime.now()
  27. @abstractmethod
  28. async def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
  29. """执行智能体任务"""
  30. pass
  31. def add_tool(self, tool_name: str, tool_func: Callable, description: str = ""):
  32. """添加工具"""
  33. self.tools[tool_name] = {
  34. "function": tool_func,
  35. "description": description
  36. }
  37. def get_tools_description(self) -> str:
  38. """获取工具描述"""
  39. if not self.tools:
  40. return "暂无可用工具"
  41. descriptions = []
  42. for name, tool_info in self.tools.items():
  43. descriptions.append(f"- {name}: {tool_info['description']}")
  44. return "\n".join(descriptions)
  45. async def call_tool(self, tool_name: str, tool_input: Any) -> Any:
  46. """调用工具"""
  47. if tool_name not in self.tools:
  48. raise AgentException(f"工具 '{tool_name}' 不存在")
  49. try:
  50. tool_func = self.tools[tool_name]["function"]
  51. if asyncio.iscoroutinefunction(tool_func):
  52. result = await asyncio.wait_for(
  53. tool_func(tool_input),
  54. timeout=self.timeout
  55. )
  56. else:
  57. result = await asyncio.wait_for(
  58. asyncio.to_thread(tool_func, tool_input),
  59. timeout=self.timeout
  60. )
  61. self._add_to_history(f"Tool {tool_name} called with input: {tool_input}")
  62. self._add_to_history(f"Tool {tool_name} result: {result}")
  63. return result
  64. except asyncio.TimeoutError:
  65. raise TimeoutException(f"工具 '{tool_name}' 执行超时")
  66. except Exception as e:
  67. raise AgentException(f"工具 '{tool_name}' 执行失败: {str(e)}")
  68. async def think(self, prompt: str, context: Dict = None) -> str:
  69. """调用LLM进行思考"""
  70. try:
  71. # 构建完整的提示词
  72. full_prompt = prompt
  73. # 添加上下文信息
  74. if context:
  75. context_str = json.dumps(context, ensure_ascii=False, indent=2)
  76. full_prompt = f"上下文信息:\n{context_str}\n\n任务:\n{prompt}"
  77. # 添加历史记录
  78. if self.history:
  79. history_str = "\n".join(self.history[-10:]) # 只保留最近10条
  80. full_prompt += f"\n\n历史记录:\n{history_str}"
  81. # 调用 HelloAgent LLM
  82. response = await asyncio.wait_for(
  83. self.llm.ainvoke(full_prompt),
  84. timeout=self.timeout
  85. )
  86. response_text = response.content if hasattr(response, 'content') else str(response)
  87. self._add_to_history(f"LLM prompt: {prompt}")
  88. self._add_to_history(f"LLM response: {response_text}")
  89. return response_text
  90. except asyncio.TimeoutError:
  91. raise TimeoutException("LLM思考超时")
  92. except Exception as e:
  93. raise AgentException(f"LLM思考失败: {str(e)}")
  94. def _add_to_history(self, message: str):
  95. """添加到历史记录"""
  96. timestamp = datetime.now().isoformat()
  97. self.history.append(f"[{timestamp}] {message}")
  98. # 限制历史记录长度
  99. if len(self.history) > 100:
  100. self.history = self.history[-50:]
  101. def get_history(self, limit: int = 10) -> List[str]:
  102. """获取历史记录"""
  103. return self.history[-limit:]
  104. def clear_history(self):
  105. """清空历史记录"""
  106. self.history = []
  107. def set_state(self, state: str):
  108. """设置智能体状态"""
  109. self.state = state
  110. logger.info(f"Agent {self.name} state changed to: {state}")
  111. def get_status(self) -> Dict[str, Any]:
  112. """获取智能体状态"""
  113. return {
  114. "name": self.name,
  115. "state": self.state,
  116. "created_at": self.created_at.isoformat(),
  117. "history_count": len(self.history),
  118. "tools_count": len(self.tools),
  119. "max_steps": self.max_steps,
  120. "timeout": self.timeout
  121. }
  122. async def validate_input(self, input_data: Dict[str, Any]) -> bool:
  123. """验证输入数据"""
  124. required_fields = self.get_required_fields()
  125. for field in required_fields:
  126. if field not in input_data:
  127. raise AgentException(f"缺少必需字段: {field}")
  128. return True
  129. @abstractmethod
  130. def get_required_fields(self) -> List[str]:
  131. """获取必需的输入字段"""
  132. pass
  133. def __str__(self) -> str:
  134. return f"{self.__class__.__name__}(name='{self.name}', state='{self.state}')"
  135. def __repr__(self) -> str:
  136. return self.__str__()