# controller.py
"""
InnoCore AI agent controller.

Responsible for the coordinated scheduling and task orchestration of the
four agents.
"""
import asyncio
import inspect
import json
import logging
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any, Callable

from agents.base import BaseAgent
from agents.coach import CoachAgent
from agents.hunter import HunterAgent
from agents.miner import MinerAgent
from agents.validator import ValidatorAgent
from core.config import get_config
from core.exceptions import AgentException, TimeoutException
  18. logger = logging.getLogger(__name__)
  19. class TaskType(Enum):
  20. """任务类型枚举"""
  21. PAPER_HUNTING = "paper_hunting"
  22. PAPER_ANALYSIS = "paper_analysis"
  23. WRITING_ASSISTANCE = "writing_assistance"
  24. CITATION_VALIDATION = "citation_validation"
  25. FULL_WORKFLOW = "full_workflow"
  26. class TaskStatus(Enum):
  27. """任务状态枚举"""
  28. PENDING = "pending"
  29. RUNNING = "running"
  30. COMPLETED = "completed"
  31. FAILED = "failed"
  32. CANCELLED = "cancelled"
  33. class AgentController:
  34. """智能体控制器"""
  35. def __init__(self):
  36. self.config = get_config()
  37. # 初始化智能体
  38. self.agents = {
  39. "hunter": HunterAgent(),
  40. "miner": MinerAgent(),
  41. "coach": CoachAgent(),
  42. "validator": ValidatorAgent()
  43. }
  44. # 任务管理
  45. self.active_tasks = {}
  46. self.task_history = []
  47. self.task_queue = asyncio.Queue()
  48. # 并发控制
  49. self.semaphore = asyncio.Semaphore(self.config.concurrent_agents)
  50. # 事件回调
  51. self.event_callbacks = {
  52. "task_started": [],
  53. "task_completed": [],
  54. "task_failed": [],
  55. "agent_status_changed": []
  56. }
  57. async def initialize(self):
  58. """初始化控制器"""
  59. logger.info("初始化Agent Controller...")
  60. # 这里可以添加智能体的初始化逻辑
  61. # 例如加载模型、建立连接等
  62. logger.info("Agent Controller初始化完成")
  63. async def submit_task(self, task_type: TaskType, input_data: Dict[str, Any],
  64. priority: int = 0, callback: Callable = None) -> str:
  65. """提交任务"""
  66. task_id = f"task_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{len(self.active_tasks)}"
  67. task = {
  68. "id": task_id,
  69. "type": task_type,
  70. "input_data": input_data,
  71. "status": TaskStatus.PENDING,
  72. "priority": priority,
  73. "callback": callback,
  74. "created_at": datetime.now(),
  75. "started_at": None,
  76. "completed_at": None,
  77. "result": None,
  78. "error": None,
  79. "agent_results": {}
  80. }
  81. self.active_tasks[task_id] = task
  82. await self.task_queue.put((priority, task))
  83. logger.info(f"任务已提交: {task_id}, 类型: {task_type.value}")
  84. return task_id
  85. async def execute_task(self, task_id: str) -> Dict[str, Any]:
  86. """执行单个任务"""
  87. if task_id not in self.active_tasks:
  88. raise AgentException(f"任务不存在: {task_id}")
  89. task = self.active_tasks[task_id]
  90. async with self.semaphore: # 并发控制
  91. try:
  92. task["status"] = TaskStatus.RUNNING
  93. task["started_at"] = datetime.now()
  94. await self._trigger_event("task_started", task)
  95. # 根据任务类型执行相应的逻辑
  96. if task["type"] == TaskType.PAPER_HUNTING:
  97. result = await self._execute_paper_hunting(task)
  98. elif task["type"] == TaskType.PAPER_ANALYSIS:
  99. result = await self._execute_paper_analysis(task)
  100. elif task["type"] == TaskType.WRITING_ASSISTANCE:
  101. result = await self._execute_writing_assistance(task)
  102. elif task["type"] == TaskType.CITATION_VALIDATION:
  103. result = await self._execute_citation_validation(task)
  104. elif task["type"] == TaskType.FULL_WORKFLOW:
  105. result = await self._execute_full_workflow(task)
  106. else:
  107. raise AgentException(f"不支持的任务类型: {task['type']}")
  108. task["status"] = TaskStatus.COMPLETED
  109. task["completed_at"] = datetime.now()
  110. task["result"] = result
  111. await self._trigger_event("task_completed", task)
  112. # 执行回调
  113. if task["callback"]:
  114. await task["callback"](task)
  115. return result
  116. except Exception as e:
  117. task["status"] = TaskStatus.FAILED
  118. task["completed_at"] = datetime.now()
  119. task["error"] = str(e)
  120. await self._trigger_event("task_failed", task)
  121. logger.error(f"任务执行失败 {task_id}: {str(e)}")
  122. raise AgentException(f"任务执行失败: {str(e)}")
  123. finally:
  124. # 移动到历史记录
  125. self.task_history.append(task.copy())
  126. del self.active_tasks[task_id]
  127. async def _execute_paper_hunting(self, task: Dict) -> Dict[str, Any]:
  128. """执行论文抓取任务"""
  129. input_data = task["input_data"]
  130. # 调用Hunter Agent
  131. hunter_result = await self.agents["hunter"].run(input_data)
  132. task["agent_results"]["hunter"] = hunter_result
  133. return {
  134. "task_type": "paper_hunting",
  135. "papers_found": hunter_result.get("downloaded_papers", []),
  136. "statistics": {
  137. "total_found": hunter_result.get("total_found", 0),
  138. "downloaded": hunter_result.get("downloaded_papers", 0)
  139. }
  140. }
  141. async def _execute_paper_analysis(self, task: Dict) -> Dict[str, Any]:
  142. """执行论文分析任务"""
  143. input_data = task["input_data"]
  144. # 调用Miner Agent
  145. miner_result = await self.agents["miner"].run(input_data)
  146. task["agent_results"]["miner"] = miner_result
  147. return {
  148. "task_type": "paper_analysis",
  149. "analysis_report": miner_result,
  150. "paper_id": input_data.get("paper_id")
  151. }
  152. async def _execute_writing_assistance(self, task: Dict) -> Dict[str, Any]:
  153. """执行写作辅助任务"""
  154. input_data = task["input_data"]
  155. # 调用Coach Agent
  156. coach_result = await self.agents["coach"].run(input_data)
  157. task["agent_results"]["coach"] = coach_result
  158. return {
  159. "task_type": "writing_assistance",
  160. "assistance_result": coach_result,
  161. "user_id": input_data.get("user_id")
  162. }
  163. async def _execute_citation_validation(self, task: Dict) -> Dict[str, Any]:
  164. """执行引用校验任务"""
  165. input_data = task["input_data"]
  166. # 调用Validator Agent
  167. validator_result = await self.agents["validator"].run(input_data)
  168. task["agent_results"]["validator"] = validator_result
  169. return {
  170. "task_type": "citation_validation",
  171. "validation_result": validator_result,
  172. "paper_info": input_data.get("paper_info")
  173. }
  174. async def _execute_full_workflow(self, task: Dict) -> Dict[str, Any]:
  175. """执行完整工作流"""
  176. input_data = task["input_data"]
  177. user_id = input_data.get("user_id")
  178. keywords = input_data.get("keywords", [])
  179. workflow_result = {
  180. "task_type": "full_workflow",
  181. "stages": {},
  182. "final_papers": [],
  183. "analysis_reports": []
  184. }
  185. try:
  186. # Stage 1: 论文抓取
  187. self._add_to_history("开始论文抓取阶段")
  188. hunting_input = {
  189. "keywords": keywords,
  190. "max_papers": input_data.get("max_papers", 10),
  191. "sources": input_data.get("sources", ["arxiv"])
  192. }
  193. hunting_result = await self.agents["hunter"].run(hunting_input)
  194. workflow_result["stages"]["hunting"] = hunting_result
  195. task["agent_results"]["hunter"] = hunting_result
  196. downloaded_papers = hunting_result.get("papers", [])
  197. workflow_result["final_papers"] = downloaded_papers
  198. # Stage 2: 论文分析
  199. self._add_to_history("开始论文分析阶段")
  200. for paper in downloaded_papers:
  201. if paper.get("db_id"):
  202. analysis_input = {
  203. "paper_id": paper["db_id"],
  204. "user_id": user_id,
  205. "analysis_type": "full"
  206. }
  207. try:
  208. analysis_result = await self.agents["miner"].run(analysis_input)
  209. workflow_result["analysis_reports"].append(analysis_result)
  210. except Exception as e:
  211. self._add_to_history(f"论文分析失败 {paper.get('title', 'Unknown')}: {str(e)}")
  212. # Stage 3: 引用校验(可选)
  213. if input_data.get("validate_citations", False):
  214. self._add_to_history("开始引用校验阶段")
  215. for paper in downloaded_papers:
  216. paper_info = {
  217. "title": paper.get("title", ""),
  218. "authors": paper.get("authors", []),
  219. "doi": paper.get("doi", ""),
  220. "year": datetime.now().year
  221. }
  222. validation_input = {
  223. "paper_info": paper_info,
  224. "formats": ["bibtex", "apa"],
  225. "verify_external": True
  226. }
  227. try:
  228. validation_result = await self.agents["validator"].run(validation_input)
  229. paper["citations"] = validation_result.get("citations", {})
  230. except Exception as e:
  231. self._add_to_history(f"引用校验失败 {paper.get('title', 'Unknown')}: {str(e)}")
  232. self._add_to_history("完整工作流执行完成")
  233. except Exception as e:
  234. self._add_to_history(f"工作流执行失败: {str(e)}")
  235. raise
  236. return workflow_result
  237. async def start_task_processor(self):
  238. """启动任务处理器"""
  239. logger.info("启动任务处理器...")
  240. while True:
  241. try:
  242. # 获取任务(按优先级排序)
  243. priority, task = await self.task_queue.get()
  244. # 异步执行任务
  245. asyncio.create_task(self.execute_task(task["id"]))
  246. except Exception as e:
  247. logger.error(f"任务处理器异常: {str(e)}")
  248. await asyncio.sleep(1)
  249. async def get_task_status(self, task_id: str) -> Optional[Dict]:
  250. """获取任务状态"""
  251. if task_id in self.active_tasks:
  252. task = self.active_tasks[task_id]
  253. return {
  254. "id": task["id"],
  255. "type": task["type"].value,
  256. "status": task["status"].value,
  257. "created_at": task["created_at"].isoformat(),
  258. "started_at": task["started_at"].isoformat() if task["started_at"] else None,
  259. "completed_at": task["completed_at"].isoformat() if task["completed_at"] else None,
  260. "priority": task["priority"]
  261. }
  262. else:
  263. # 在历史记录中查找
  264. for task in self.task_history:
  265. if task["id"] == task_id:
  266. return {
  267. "id": task["id"],
  268. "type": task["type"].value,
  269. "status": task["status"].value,
  270. "created_at": task["created_at"].isoformat(),
  271. "started_at": task["started_at"].isoformat() if task["started_at"] else None,
  272. "completed_at": task["completed_at"].isoformat() if task["completed_at"] else None,
  273. "priority": task["priority"]
  274. }
  275. return None
  276. async def cancel_task(self, task_id: str) -> bool:
  277. """取消任务"""
  278. if task_id in self.active_tasks:
  279. task = self.active_tasks[task_id]
  280. if task["status"] == TaskStatus.PENDING:
  281. task["status"] = TaskStatus.CANCELLED
  282. task["completed_at"] = datetime.now()
  283. # 移动到历史记录
  284. self.task_history.append(task.copy())
  285. del self.active_tasks[task_id]
  286. logger.info(f"任务已取消: {task_id}")
  287. return True
  288. return False
  289. async def get_agent_status(self) -> Dict[str, Any]:
  290. """获取所有智能体状态"""
  291. agent_status = {}
  292. for name, agent in self.agents.items():
  293. agent_status[name] = agent.get_status()
  294. return {
  295. "agents": agent_status,
  296. "active_tasks": len(self.active_tasks),
  297. "queued_tasks": self.task_queue.qsize(),
  298. "completed_tasks": len(self.task_history),
  299. "max_concurrent": self.config.concurrent_agents
  300. }
  301. def add_event_callback(self, event_type: str, callback: Callable):
  302. """添加事件回调"""
  303. if event_type in self.event_callbacks:
  304. self.event_callbacks[event_type].append(callback)
  305. async def _trigger_event(self, event_type: str, data: Any):
  306. """触发事件"""
  307. if event_type in self.event_callbacks:
  308. for callback in self.event_callbacks[event_type]:
  309. try:
  310. if asyncio.iscoroutinefunction(callback):
  311. await callback(data)
  312. else:
  313. callback(data)
  314. except Exception as e:
  315. logger.error(f"事件回调执行失败 {event_type}: {str(e)}")
  316. def _add_to_history(self, message: str):
  317. """添加到控制器历史记录"""
  318. timestamp = datetime.now().isoformat()
  319. logger.info(f"[{timestamp}] Controller: {message}")
  320. async def shutdown(self):
  321. """关闭控制器"""
  322. logger.info("关闭Agent Controller...")
  323. # 取消所有待处理任务
  324. for task_id in list(self.active_tasks.keys()):
  325. await self.cancel_task(task_id)
  326. # 清理智能体资源
  327. for agent in self.agents.values():
  328. if hasattr(agent, 'close'):
  329. await agent.close()
  330. logger.info("Agent Controller已关闭")
  331. # 全局控制器实例
  332. agent_controller = AgentController()