task_service.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. """
  2. 任务服务
  3. """
  4. from typing import Optional, List, Dict, Any
  5. from sqlalchemy.orm import Session
  6. from datetime import datetime
  7. from ..core.database import get_db
  8. from ..models.task import TaskDB, Task, TaskCreate, TaskUpdate
  9. from ..core.exceptions import TaskNotFoundError
  10. from ..agents.controller import AgentController
  11. import json
  12. import asyncio
  13. class TaskService:
  14. """任务服务类"""
  15. def __init__(self, db: Session):
  16. self.db = db
  17. self.agent_controller = AgentController()
  18. def get_task_by_id(self, task_id: int) -> Optional[Task]:
  19. """根据ID获取任务"""
  20. task_db = self.db.query(TaskDB).filter(TaskDB.id == task_id).first()
  21. if not task_db:
  22. raise TaskNotFoundError(f"Task with id {task_id} not found")
  23. return Task.from_orm(task_db)
  24. def get_tasks_by_user(self, user_id: int, skip: int = 0, limit: int = 20) -> List[Task]:
  25. """获取用户的任务列表"""
  26. tasks_db = self.db.query(TaskDB).filter(
  27. TaskDB.user_id == user_id
  28. ).order_by(TaskDB.created_at.desc()).offset(skip).limit(limit).all()
  29. return [Task.from_orm(task) for task in tasks_db]
  30. def create_task(self, task_create: TaskCreate, user_id: int) -> Task:
  31. """创建任务"""
  32. task_db = TaskDB(
  33. title=task_create.title,
  34. description=task_create.description,
  35. task_type=task_create.task_type,
  36. priority=task_create.priority,
  37. parameters=task_create.parameters,
  38. user_id=user_id
  39. )
  40. self.db.add(task_db)
  41. self.db.commit()
  42. self.db.refresh(task_db)
  43. # 异步执行任务
  44. self._execute_task_async(task_db.id)
  45. return Task.from_orm(task_db)
  46. def update_task(self, task_id: int, task_update: TaskUpdate) -> Task:
  47. """更新任务"""
  48. task_db = self.db.query(TaskDB).filter(TaskDB.id == task_id).first()
  49. if not task_db:
  50. raise TaskNotFoundError(f"Task with id {task_id} not found")
  51. # 更新字段
  52. update_data = task_update.dict(exclude_unset=True)
  53. for field, value in update_data.items():
  54. setattr(task_db, field, value)
  55. # 如果任务完成,设置完成时间
  56. if task_update.status == "completed":
  57. task_db.completed_at = datetime.utcnow()
  58. self.db.commit()
  59. self.db.refresh(task_db)
  60. return Task.from_orm(task_db)
  61. def delete_task(self, task_id: int) -> bool:
  62. """删除任务"""
  63. task_db = self.db.query(TaskDB).filter(TaskDB.id == task_id).first()
  64. if not task_db:
  65. raise TaskNotFoundError(f"Task with id {task_id} not found")
  66. self.db.delete(task_db)
  67. self.db.commit()
  68. return True
  69. def cancel_task(self, task_id: int) -> Task:
  70. """取消任务"""
  71. return self.update_task(task_id, TaskUpdate(status="failed", error_message="Task cancelled by user"))
  72. def retry_task(self, task_id: int) -> Task:
  73. """重试任务"""
  74. # 重置任务状态
  75. task = self.update_task(task_id, TaskUpdate(
  76. status="pending",
  77. progress=0,
  78. error_message=None
  79. ))
  80. # 重新执行任务
  81. self._execute_task_async(task_id)
  82. return task
  83. def get_task_statistics(self, user_id: int) -> Dict[str, Any]:
  84. """获取任务统计信息"""
  85. total_tasks = self.db.query(TaskDB).filter(TaskDB.user_id == user_id).count()
  86. # 按状态统计
  87. status_stats = self.db.query(
  88. TaskDB.status,
  89. self.db.func.count(TaskDB.id)
  90. ).filter(TaskDB.user_id == user_id).group_by(TaskDB.status).all()
  91. # 按类型统计
  92. type_stats = self.db.query(
  93. TaskDB.task_type,
  94. self.db.func.count(TaskDB.id)
  95. ).filter(TaskDB.user_id == user_id).group_by(TaskDB.task_type).all()
  96. # 成功率
  97. completed_tasks = self.db.query(TaskDB).filter(
  98. and_(TaskDB.user_id == user_id, TaskDB.status == "completed")
  99. ).count()
  100. success_rate = completed_tasks / total_tasks if total_tasks > 0 else 0
  101. return {
  102. 'total_tasks': total_tasks,
  103. 'success_rate': success_rate,
  104. 'status_distribution': dict(status_stats),
  105. 'type_distribution': dict(type_stats)
  106. }
  107. def _execute_task_async(self, task_id: int):
  108. """异步执行任务"""
  109. try:
  110. # 获取任务信息
  111. task_db = self.db.query(TaskDB).filter(TaskDB.id == task_id).first()
  112. if not task_db:
  113. return
  114. # 更新任务状态为运行中
  115. task_db.status = "running"
  116. task_db.progress = 0
  117. self.db.commit()
  118. # 根据任务类型执行相应的智能体
  119. if task_db.task_type == "literature_search":
  120. result = asyncio.run(self._execute_literature_search(task_db))
  121. elif task_db.task_type == "analysis":
  122. result = asyncio.run(self._execute_analysis(task_db))
  123. elif task_db.task_type == "writing":
  124. result = asyncio.run(self._execute_writing(task_db))
  125. else:
  126. raise ValueError(f"Unknown task type: {task_db.task_type}")
  127. # 更新任务结果
  128. task_db.status = "completed"
  129. task_db.progress = 100
  130. task_db.results = result
  131. task_db.completed_at = datetime.utcnow()
  132. self.db.commit()
  133. except Exception as e:
  134. # 更新任务状态为失败
  135. task_db.status = "failed"
  136. task_db.error_message = str(e)
  137. self.db.commit()
  138. async def _execute_literature_search(self, task_db: TaskDB) -> Dict[str, Any]:
  139. """执行文献搜索任务"""
  140. parameters = task_db.parameters or {}
  141. query = parameters.get('query', '')
  142. max_papers = parameters.get('max_papers', 20)
  143. # 使用猎手智能体进行文献搜索
  144. hunter_agent = self.agent_controller.get_agent('hunter')
  145. # 更新进度
  146. await self._update_task_progress(task_db.id, 20)
  147. # 执行搜索
  148. search_results = await hunter_agent.search_papers(query, max_papers)
  149. # 更新进度
  150. await self._update_task_progress(task_db.id, 60)
  151. # 使用矿工智能体进行深度挖掘
  152. miner_agent = self.agent_controller.get_agent('miner')
  153. enriched_results = await miner_agent.enrich_papers(search_results)
  154. # 更新进度
  155. await self._update_task_progress(task_db.id, 90)
  156. # 保存论文到数据库
  157. paper_service = PaperService(self.db)
  158. saved_papers = []
  159. for paper_data in enriched_results:
  160. try:
  161. paper = paper_service.create_paper(
  162. PaperCreate(**paper_data),
  163. task_db.user_id
  164. )
  165. saved_papers.append(paper.dict())
  166. except Exception as e:
  167. print(f"Error saving paper: {e}")
  168. return {
  169. 'query': query,
  170. 'total_found': len(enriched_results),
  171. 'papers_saved': len(saved_papers),
  172. 'papers': saved_papers
  173. }
  174. async def _execute_analysis(self, task_db: TaskDB) -> Dict[str, Any]:
  175. """执行分析任务"""
  176. parameters = task_db.parameters or {}
  177. paper_ids = parameters.get('paper_ids', [])
  178. analysis_type = parameters.get('analysis_type', 'comprehensive')
  179. # 使用教练智能体进行分析
  180. coach_agent = self.agent_controller.get_agent('coach')
  181. # 更新进度
  182. await self._update_task_progress(task_db.id, 30)
  183. # 执行分析
  184. analysis_result = await coach_agent.analyze_papers(paper_ids, analysis_type)
  185. # 更新进度
  186. await self._update_task_progress(task_db.id, 80)
  187. # 保存分析结果
  188. analysis_service = AnalysisService(self.db)
  189. analysis = analysis_service.create_analysis(
  190. {
  191. 'title': f"Analysis of {len(paper_ids)} papers",
  192. 'analysis_type': analysis_type,
  193. 'paper_ids': paper_ids,
  194. 'methodology': analysis_result.get('methodology', ''),
  195. 'findings': analysis_result.get('findings', {}),
  196. 'insights': analysis_result.get('insights', ''),
  197. 'limitations': analysis_result.get('limitations', ''),
  198. 'recommendations': analysis_result.get('recommendations', ''),
  199. 'confidence_score': analysis_result.get('confidence_score', 0.0),
  200. 'novelty_score': analysis_result.get('novelty_score', 0.0),
  201. 'impact_score': analysis_result.get('impact_score', 0.0)
  202. },
  203. task_db.user_id,
  204. task_db.id
  205. )
  206. return {
  207. 'analysis_id': analysis.id,
  208. 'analysis_type': analysis_type,
  209. 'papers_analyzed': len(paper_ids),
  210. 'result': analysis.dict()
  211. }
  212. async def _execute_writing(self, task_db: TaskDB) -> Dict[str, Any]:
  213. """执行写作任务"""
  214. parameters = task_db.parameters or {}
  215. paper_ids = parameters.get('paper_ids', [])
  216. writing_type = parameters.get('writing_type', 'review')
  217. outline = parameters.get('outline')
  218. # 使用教练智能体进行写作
  219. coach_agent = self.agent_controller.get_agent('coach')
  220. # 更新进度
  221. await self._update_task_progress(task_db.id, 25)
  222. # 生成内容
  223. writing_result = await coach_agent.generate_writing(paper_ids, writing_type, outline)
  224. # 更新进度
  225. await self._update_task_progress(task_db.id, 75)
  226. # 保存写作结果
  227. writing_service = WritingService(self.db)
  228. writing = writing_service.create_writing(
  229. {
  230. 'title': writing_result.get('title', 'Generated Writing'),
  231. 'writing_type': writing_type,
  232. 'content': writing_result.get('content', ''),
  233. 'outline': writing_result.get('outline', []),
  234. 'sections': writing_result.get('sections', {}),
  235. 'citations': writing_result.get('citations', []),
  236. 'paper_ids': paper_ids
  237. },
  238. task_db.user_id,
  239. task_db.id
  240. )
  241. return {
  242. 'writing_id': writing.id,
  243. 'writing_type': writing_type,
  244. 'papers_referenced': len(paper_ids),
  245. 'word_count': writing.word_count,
  246. 'result': writing.dict()
  247. }
  248. async def _update_task_progress(self, task_id: int, progress: int):
  249. """更新任务进度"""
  250. task_db = self.db.query(TaskDB).filter(TaskDB.id == task_id).first()
  251. if task_db:
  252. task_db.progress = progress
  253. self.db.commit()