# coach.py
"""
InnoCore AI 写作助教 (Coach Agent)
负责风格迁移、实时润色、解释复杂概念
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any

from agents.base import BaseAgent
from core.database import db_manager
from core.exceptions import AgentException
from core.vector_store import vector_store_manager
  13. class CoachAgent(BaseAgent):
  14. """写作助教智能体"""
  15. def __init__(self, llm=None):
  16. super().__init__("Coach", llm)
  17. # 添加工具
  18. self.add_tool("explain_concept", self._explain_concept, "解释复杂概念")
  19. self.add_tool("polish_text", self._polish_text, "润色文本")
  20. self.add_tool("mimic_style", self._mimic_style, "模仿写作风格")
  21. self.add_tool("get_user_style", self._get_user_style, "获取用户写作风格")
  22. self.add_tool("suggest_improvements", self._suggest_improvements, "建议改进")
  23. async def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
  24. """执行写作助教任务"""
  25. await self.validate_input(input_data)
  26. self.set_state("running")
  27. try:
  28. user_id = input_data["user_id"]
  29. task_type = input_data["task_type"] # explain, polish, mimic, suggest
  30. content = input_data["content"]
  31. context = input_data.get("context", {})
  32. result = None
  33. if task_type == "explain":
  34. result = await self._handle_explain_task(user_id, content, context)
  35. elif task_type == "polish":
  36. result = await self._handle_polish_task(user_id, content, context)
  37. elif task_type == "mimic":
  38. result = await self._handle_mimic_task(user_id, content, context)
  39. elif task_type == "suggest":
  40. result = await self._handle_suggest_task(user_id, content, context)
  41. else:
  42. raise AgentException(f"不支持的任务类型: {task_type}")
  43. self.set_state("completed")
  44. return {
  45. "status": "success",
  46. "task_type": task_type,
  47. "user_id": user_id,
  48. "result": result,
  49. "timestamp": datetime.now().isoformat()
  50. }
  51. except Exception as e:
  52. self.set_state("error")
  53. raise AgentException(f"Coach Agent执行失败: {str(e)}")
  54. def get_required_fields(self) -> List[str]:
  55. """获取必需的输入字段"""
  56. return ["user_id", "task_type", "content"]
  57. async def _handle_explain_task(self, user_id: str, content: str, context: Dict) -> Dict[str, Any]:
  58. """处理解释任务"""
  59. try:
  60. # 获取用户的历史论文作为上下文
  61. user_context = await self._get_user_context(user_id)
  62. explain_prompt = f"""
  63. 请用通俗易懂的语言解释以下内容:
  64. 需要解释的内容:
  65. {content}
  66. 上下文信息:
  67. {json.dumps(context, ensure_ascii=False, indent=2)}
  68. 用户研究领域背景:
  69. {json.dumps(user_context, ensure_ascii=False, indent=2)}
  70. 请提供:
  71. 1. 简单易懂的解释
  72. 2. 相关的例子或类比
  73. 3. 在该领域的重要性
  74. 4. 可能的应用场景
  75. 请以JSON格式返回结果。
  76. """
  77. response = await self.think(explain_prompt)
  78. try:
  79. result = json.loads(response)
  80. except json.JSONDecodeError:
  81. result = {
  82. "explanation": response,
  83. "examples": ["需要补充具体例子"],
  84. "importance": "在相关领域具有重要意义",
  85. "applications": ["潜在应用场景"]
  86. }
  87. self._add_to_history(f"完成解释任务: {content[:50]}...")
  88. return result
  89. except Exception as e:
  90. self._add_to_history(f"解释任务失败: {str(e)}")
  91. return {
  92. "explanation": f"解释过程中出现错误: {str(e)}",
  93. "examples": [],
  94. "importance": "",
  95. "applications": []
  96. }
  97. async def _handle_polish_task(self, user_id: str, content: str, context: Dict) -> Dict[str, Any]:
  98. """处理润色任务"""
  99. try:
  100. # 获取用户的写作风格偏好
  101. user_style = await self._get_user_writing_style(user_id)
  102. # 获取相关的风格参考
  103. style_references = await self._get_style_references(user_id, content)
  104. polish_prompt = f"""
  105. 请将以下文本润色为地道的学术英语:
  106. 原文:
  107. {content}
  108. 用户写作风格偏好:
  109. {json.dumps(user_style, ensure_ascii=False, indent=2)}
  110. 风格参考:
  111. {json.dumps(style_references, ensure_ascii=False, indent=2)}
  112. 上下文信息:
  113. {json.dumps(context, ensure_ascii=False, indent=2)}
  114. 请提供:
  115. 1. 润色后的英文文本
  116. 2. 主要修改说明
  117. 3. 风格改进建议
  118. 4. 参考的论文句式来源
  119. 要求:
  120. - 保持原意不变
  121. - 使用地道的学术表达
  122. - 符合目标期刊/会议的写作风格
  123. - 在注释中说明参考了哪些历史论文的句式
  124. 请以JSON格式返回结果。
  125. """
  126. response = await self.think(polish_prompt)
  127. try:
  128. result = json.loads(response)
  129. except json.JSONDecodeError:
  130. result = {
  131. "polished_text": response,
  132. "modifications": ["语法修正", "词汇优化"],
  133. "style_suggestions": ["建议使用更正式的表达"],
  134. "references": ["基于学术写作规范"]
  135. }
  136. self._add_to_history(f"完成润色任务: {content[:50]}...")
  137. return result
  138. except Exception as e:
  139. self._add_to_history(f"润色任务失败: {str(e)}")
  140. return {
  141. "polished_text": content,
  142. "modifications": [f"润色过程中出现错误: {str(e)}"],
  143. "style_suggestions": [],
  144. "references": []
  145. }
  146. async def _handle_mimic_task(self, user_id: str, content: str, context: Dict) -> Dict[str, Any]:
  147. """处理模仿任务"""
  148. try:
  149. # 获取目标风格参考
  150. target_style = context.get("target_style", "formal_academic")
  151. reference_papers = context.get("reference_papers", [])
  152. # 如果没有指定参考论文,从用户库中获取
  153. if not reference_papers:
  154. reference_papers = await self._get_user_top_papers(user_id, limit=3)
  155. mimic_prompt = f"""
  156. 请基于以下参考论文的写作风格,重写给定内容:
  157. 原文:
  158. {content}
  159. 目标风格:
  160. {target_style}
  161. 参考论文:
  162. {json.dumps(reference_papers, ensure_ascii=False, indent=2)}
  163. 上下文信息:
  164. {json.dumps(context, ensure_ascii=False, indent=2)}
  165. 请提供:
  166. 1. 重写后的文本
  167. 2. 风格分析(说明如何体现目标风格)
  168. 3. 具体的模仿技巧
  169. 4. 参考的句式结构
  170. 请以JSON格式返回结果。
  171. """
  172. response = await self.think(mimic_prompt)
  173. try:
  174. result = json.loads(response)
  175. except json.JSONDecodeError:
  176. result = {
  177. "rewritten_text": response,
  178. "style_analysis": "基于学术写作风格进行重写",
  179. "mimic_techniques": ["句式结构模仿", "词汇选择"],
  180. "reference_structures": ["学术表达方式"]
  181. }
  182. self._add_to_history(f"完成模仿任务: {content[:50]}...")
  183. return result
  184. except Exception as e:
  185. self._add_to_history(f"模仿任务失败: {str(e)}")
  186. return {
  187. "rewritten_text": content,
  188. "style_analysis": f"模仿过程中出现错误: {str(e)}",
  189. "mimic_techniques": [],
  190. "reference_structures": []
  191. }
  192. async def _handle_suggest_task(self, user_id: str, content: str, context: Dict) -> Dict[str, Any]:
  193. """处理建议任务"""
  194. try:
  195. # 获取用户的历史写作数据
  196. user_writing_history = await self._get_user_writing_history(user_id)
  197. suggest_prompt = f"""
  198. 请对以下文本提供改进建议:
  199. 文本内容:
  200. {content}
  201. 用户写作历史:
  202. {json.dumps(user_writing_history, ensure_ascii=False, indent=2)}
  203. 上下文信息:
  204. {json.dumps(context, ensure_ascii=False, indent=2)}
  205. 请提供:
  206. 1. 整体评价
  207. 2. 具体改进建议(按重要性排序)
  208. 3. 语法和表达问题
  209. 4. 结构优化建议
  210. 5. 学术表达改进
  211. 请以JSON格式返回结果。
  212. """
  213. response = await self.think(suggest_prompt)
  214. try:
  215. result = json.loads(response)
  216. except json.JSONDecodeError:
  217. result = {
  218. "overall_evaluation": "文本整体质量良好",
  219. "improvement_suggestions": ["建议加强逻辑表达", "可以增加更多细节"],
  220. "grammar_issues": ["检查时态一致性"],
  221. "structure_suggestions": ["建议优化段落结构"],
  222. "academic_improvements": ["使用更正式的学术词汇"]
  223. }
  224. self._add_to_history(f"完成建议任务: {content[:50]}...")
  225. return result
  226. except Exception as e:
  227. self._add_to_history(f"建议任务失败: {str(e)}")
  228. return {
  229. "overall_evaluation": f"分析过程中出现错误: {str(e)}",
  230. "improvement_suggestions": [],
  231. "grammar_issues": [],
  232. "structure_suggestions": [],
  233. "academic_improvements": []
  234. }
  235. async def _get_user_context(self, user_id: str) -> Dict[str, Any]:
  236. """获取用户的研究背景"""
  237. try:
  238. user = await db_manager.get_user(user_id)
  239. if user:
  240. return user.get("profile", {})
  241. return {}
  242. except Exception:
  243. return {}
  244. async def _get_user_writing_style(self, user_id: str) -> Dict[str, Any]:
  245. """获取用户写作风格偏好"""
  246. user_context = await self._get_user_context(user_id)
  247. return user_context.get("writing_style", {
  248. "tone": "formal",
  249. "complexity": "medium",
  250. "preferred_journals": ["Nature", "Science"],
  251. "language": "english"
  252. })
  253. async def _get_style_references(self, user_id: str, content: str) -> List[Dict[str, Any]]:
  254. """获取风格参考"""
  255. try:
  256. # 搜索用户库中的相关论文
  257. search_results = await vector_store_manager.hybrid_search(
  258. query=content,
  259. user_id=user_id,
  260. top_k=3,
  261. include_l2=True,
  262. include_l1=False
  263. )
  264. references = []
  265. for result in search_results:
  266. payload = result["payload"]
  267. references.append({
  268. "title": payload.get("title", ""),
  269. "abstract": payload.get("abstract", "")[:200],
  270. "similarity": result["score"]
  271. })
  272. return references
  273. except Exception:
  274. return []
  275. async def _get_user_top_papers(self, user_id: str, limit: int = 3) -> List[Dict[str, Any]]:
  276. """获取用户评分最高的论文"""
  277. try:
  278. user_papers = await db_manager.get_user_papers(user_id, limit=limit)
  279. top_papers = []
  280. for paper in user_papers:
  281. top_papers.append({
  282. "title": paper.get("title", ""),
  283. "abstract": paper.get("abstract", "")[:300],
  284. "rating": paper.get("rating", 0),
  285. "authors": paper.get("authors", [])
  286. })
  287. return top_papers
  288. except Exception:
  289. return []
  290. async def _get_user_writing_history(self, user_id: str) -> List[Dict[str, Any]]:
  291. """获取用户写作历史"""
  292. try:
  293. # 这里应该从用户的写作历史记录中获取数据
  294. # 暂时返回模拟数据
  295. return [
  296. {
  297. "date": "2024-01-01",
  298. "content_type": "abstract",
  299. "word_count": 200,
  300. "feedback_score": 4.5
  301. }
  302. ]
  303. except Exception:
  304. return []
  305. # 工具方法
  306. async def _explain_concept(self, concept: str, context: Dict = None) -> Dict:
  307. """解释概念工具"""
  308. return await self._handle_explain_task(
  309. context.get("user_id", ""),
  310. concept,
  311. context or {}
  312. )
  313. async def _polish_text(self, text: str, context: Dict = None) -> Dict:
  314. """润色文本工具"""
  315. return await self._handle_polish_task(
  316. context.get("user_id", ""),
  317. text,
  318. context or {}
  319. )
  320. async def _mimic_style(self, text: str, target_style: str, context: Dict = None) -> Dict:
  321. """模仿风格工具"""
  322. ctx = context or {}
  323. ctx["target_style"] = target_style
  324. return await self._handle_mimic_task(
  325. ctx.get("user_id", ""),
  326. text,
  327. ctx
  328. )
  329. async def _get_user_style(self, user_id: str) -> Dict:
  330. """获取用户风格工具"""
  331. return await self._get_user_writing_style(user_id)
  332. async def _suggest_improvements(self, text: str, context: Dict = None) -> Dict:
  333. """建议改进工具"""
  334. return await self._handle_suggest_task(
  335. context.get("user_id", ""),
  336. text,
  337. context or {}
  338. )