paper_analyzer.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # specialist/paper_analyzer.py
  2. """PDF 论文分析专家"""
  3. import os
  4. from pathlib import Path
  5. from typing import Dict, List
  6. import PyPDF2
  7. from hello_agents import HelloAgentsLLM
  class PaperAnalyzerAgent:
      """Analyze a PDF research paper and derive learning-oriented metadata.

      The agent:
      - reads the first pages of a PDF,
      - derives a title from the file name,
      - spots well-known academic keywords in the extracted text,
      - infers prerequisite knowledge from those keywords,
      - classifies the paper into a coarse research domain.
      """

      # Keywords searched for (case-insensitively, as substrings) in the paper
      # text. The order matters: results are reported in this order and only the
      # first five become "core_concepts".
      _ACADEMIC_KEYWORDS = (
          # Deep learning / machine learning
          "Neural Network",
          "Deep Learning",
          "Transformer",
          "Attention",
          "CNN",
          "RNN",
          "LSTM",
          "Backpropagation",
          "Gradient Descent",
          "Optimization",
          # Natural language processing
          "NLP",
          "Language Model",
          "Tokenization",
          "Embedding",
          "BERT",
          "GPT",
          # Computer vision
          "Computer Vision",
          "Image Processing",
          "Convolution",
          "Feature Extraction",
          # Miscellaneous
          "Algorithm",
          "Data Structure",
          "Complexity",
          "Statistics",
          "Probability",
      )

      # Keyword -> prerequisite topics a learner should know beforehand.
      _PREREQ_MAP = {
          "Deep Learning": ["Machine Learning", "Python", "Linear Algebra"],
          "Transformer": ["Attention Mechanism", "Sequence Models"],
          "Neural Network": ["Calculus", "Linear Algebra", "Probability"],
          "CNN": ["Image Processing", "Linear Algebra"],
          "RNN": ["Sequence Models", "Calculus"],
          "NLP": ["Machine Learning", "Statistics", "Python"],
          "Computer Vision": ["Linear Algebra", "Probability", "Python"],
      }

      # (domain, markers) pairs checked in order; the first domain with any
      # marker occurring in the joined, lower-cased keywords wins.
      _DOMAIN_RULES = (
          (
              "natural-language-processing",
              ("transformer", "attention", "nlp", "language", "bert", "gpt"),
          ),
          (
              "computer-vision",
              ("cnn", "image", "vision", "computer", "processing"),
          ),
          (
              "deep-learning",
              ("neural", "deep", "learning", "network", "backpropagation"),
          ),
          (
              "machine-learning",
              ("machine", "learning", "algorithm", "model"),
          ),
      )

      # Fixed placeholder values reported for every paper.
      _DEFAULT_DIFFICULTY = "高级"
      _DEFAULT_WEEKS = 8

      def __init__(self, llm: "HelloAgentsLLM"):
          """Create the agent.

          Args:
              llm: HelloAgentsLLM instance used by ``_analyze_with_llm``.
          """
          self.llm = llm

      def _extract_title_from_path(self, file_path: str) -> str:
          """Derive a human-readable title from a PDF file path.

          Args:
              file_path: Path to the PDF file; a leading ``~`` is expanded.

          Returns:
              The file name without its extension, with hyphens and
              underscores replaced by spaces.
          """
          # expanduser() leaves paths that do not start with "~" unchanged.
          stem = Path(os.path.expanduser(file_path)).stem
          return stem.replace("-", " ").replace("_", " ")

      def _extract_text_from_pdf(self, file_path: str, max_pages: int = 3) -> str:
          """Extract the text of the first pages of a PDF.

          Args:
              file_path: Path to the PDF file; a leading ``~`` is expanded.
              max_pages: Maximum number of leading pages to read. The default
                  of 3 usually covers the abstract and the introduction.

          Returns:
              The extracted text, each page followed by a newline.

          Raises:
              IOError: If the file cannot be opened or parsed. The underlying
                  exception is chained as ``__cause__``.
          """
          resolved = os.path.expanduser(file_path)
          try:
              with open(resolved, "rb") as handle:
                  reader = PyPDF2.PdfReader(handle)
                  page_count = min(max_pages, len(reader.pages))
                  # Guard against a None result from extract_text() (e.g. for
                  # pages without a text layer) so one such page does not make
                  # the whole document look unreadable.
                  return "".join(
                      (reader.pages[index].extract_text() or "") + "\n"
                      for index in range(page_count)
                  )
          except Exception as e:
              # PDF parsers raise many unrelated exception types; callers only
              # need to know that the file could not be read.
              raise IOError(f"无法读取 PDF 文件:{e}") from e

      def _extract_keywords_from_text(self, text: str) -> List[str]:
          """Find known academic keywords in the given text.

          Matching is a case-insensitive substring search.

          Args:
              text: Paper text; empty or None yields an empty list.

          Returns:
              The matching keywords in their canonical spelling, ordered as in
              the keyword table.
          """
          if not text:
              return []
          haystack = text.lower()
          return [kw for kw in self._ACADEMIC_KEYWORDS if kw.lower() in haystack]

      def _identify_prerequisites(self, keywords: List[str]) -> List[str]:
          """Infer prerequisite knowledge from detected keywords.

          Args:
              keywords: Keywords found in the paper.

          Returns:
              Prerequisite topics without duplicates, in first-seen order so
              the result is identical from run to run.
          """
          # dict preserves insertion order, unlike set, whose iteration order
          # for strings changes between interpreter runs.
          return list(
              dict.fromkeys(
                  topic
                  for keyword in keywords
                  for topic in self._PREREQ_MAP.get(keyword, ())
              )
          )

      def _rule_based_analysis(self, title: str, text: str) -> Dict[str, object]:
          """Build the analysis result from keyword rules only."""
          keywords = self._extract_keywords_from_text(text)
          return {
              "domain": self._infer_domain_from_keywords(keywords),
              "core_concepts": keywords[:5],  # the first five keywords
              "prerequisites": self._identify_prerequisites(keywords),
              "title": title,
              "learning_difficulty": self._DEFAULT_DIFFICULTY,
              "estimated_weeks": self._DEFAULT_WEEKS,
          }

      def _analyze_with_llm(self, title: str, text: str) -> Dict[str, object]:
          """Analyze the paper, consulting the LLM on a best-effort basis.

          Args:
              title: Paper title.
              text: Paper text; only the first 1000 characters are sent to
                  the LLM.

          Returns:
              Dict with the keys ``domain``, ``core_concepts``,
              ``prerequisites``, ``title``, ``learning_difficulty`` and
              ``estimated_weeks``.
          """
          import logging

          user_prompt = (
              "请分析以下学术论文并提取学习相关信息:\n"
              "【论文标题】\n"
              f"{title}\n"
              "【论文内容(前1000字)】\n"
              f"{text[:1000]}\n"
          )
          messages = [
              {
                  "role": "system",
                  "content": "你是一个学术教育专家,擅长分析学术论文并提取学习相关信息。",
              },
              {"role": "user", "content": user_prompt},
          ]
          try:
              # NOTE: simplified implementation - the LLM reply is not consumed
              # yet; the result below is purely rule-based.
              self.llm.invoke(messages)
          except Exception:
              # Degrade gracefully: an LLM failure must not break the analysis,
              # but it should not vanish silently either.
              logging.getLogger(__name__).warning(
                  "LLM invocation failed; falling back to rule-based analysis",
                  exc_info=True,
              )
          return self._rule_based_analysis(title, text)

      def _infer_domain_from_keywords(self, keywords: List[str]) -> str:
          """Infer the research domain from detected keywords.

          Args:
              keywords: Keywords found in the paper.

          Returns:
              One of ``"natural-language-processing"``, ``"computer-vision"``,
              ``"deep-learning"``, ``"machine-learning"`` or ``"general"``.
          """
          if not keywords:
              return "general"
          joined = " ".join(keywords).lower()
          for domain, markers in self._DOMAIN_RULES:
              if any(marker in joined for marker in markers):
                  return domain
          return "general"

      def analyze(self, pdf_path: str) -> Dict[str, object]:
          """Analyze a PDF paper.

          Args:
              pdf_path: Path to the PDF file.

          Returns:
              Dict containing:
              - domain: research domain
              - title: paper title
              - core_concepts: list of core concepts
              - prerequisites: list of prerequisite topics
              - learning_difficulty: learning difficulty
              - estimated_weeks: estimated number of study weeks
          """
          title = self._extract_title_from_path(pdf_path)
          try:
              text = self._extract_text_from_pdf(pdf_path)
          except IOError:
              # The PDF is unreadable: report what the path alone tells us.
              return {
                  "domain": "general",
                  "title": title,
                  "core_concepts": [],
                  "prerequisites": [],
                  "learning_difficulty": self._DEFAULT_DIFFICULTY,
                  "estimated_weeks": self._DEFAULT_WEEKS,
              }
          return self._analyze_with_llm(title, text)