paper_service.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. """
  2. 论文服务
  3. """
import json
import logging
from typing import Optional, List, Dict, Any

from sqlalchemy import and_, or_, desc, func
from sqlalchemy.orm import Session

from ..core.database import get_db
from ..core.exceptions import PaperNotFoundError, PaperAlreadyExistsError
from ..core.vector_store import VectorStore
from ..models.paper import PaperDB, Paper, PaperCreate, PaperUpdate, PaperSearch
from ..utils.embedding import EmbeddingService
from ..utils.pdf_parser import PDFParser
  14. class PaperService:
  15. """论文服务类"""
  16. def __init__(self, db: Session):
  17. self.db = db
  18. self.vector_store = VectorStore()
  19. self.pdf_parser = PDFParser()
  20. self.embedding_service = EmbeddingService()
  21. def get_paper_by_id(self, paper_id: int) -> Optional[Paper]:
  22. """根据ID获取论文"""
  23. paper_db = self.db.query(PaperDB).filter(PaperDB.id == paper_id).first()
  24. if not paper_db:
  25. raise PaperNotFoundError(f"Paper with id {paper_id} not found")
  26. return Paper.from_orm(paper_db)
  27. def get_papers_by_user(self, user_id: int, skip: int = 0, limit: int = 20) -> List[Paper]:
  28. """获取用户的论文列表"""
  29. papers_db = self.db.query(PaperDB).filter(
  30. PaperDB.user_id == user_id
  31. ).offset(skip).limit(limit).all()
  32. return [Paper.from_orm(paper) for paper in papers_db]
  33. def create_paper(self, paper_create: PaperCreate, user_id: int) -> Paper:
  34. """创建论文记录"""
  35. # 检查DOI是否已存在
  36. if paper_create.doi:
  37. existing = self.db.query(PaperDB).filter(PaperDB.doi == paper_create.doi).first()
  38. if existing:
  39. raise PaperAlreadyExistsError(f"Paper with DOI {paper_create.doi} already exists")
  40. # 检查arXiv ID是否已存在
  41. if paper_create.arxiv_id:
  42. existing = self.db.query(PaperDB).filter(PaperDB.arxiv_id == paper_create.arxiv_id).first()
  43. if existing:
  44. raise PaperAlreadyExistsError(f"Paper with arXiv ID {paper_create.arxiv_id} already exists")
  45. # 创建论文记录
  46. paper_db = PaperDB(
  47. title=paper_create.title,
  48. authors=json.dumps(paper_create.authors),
  49. abstract=paper_create.abstract,
  50. keywords=json.dumps(paper_create.keywords),
  51. publication_year=paper_create.publication_year,
  52. journal=paper_create.journal,
  53. doi=paper_create.doi,
  54. arxiv_id=paper_create.arxiv_id,
  55. pdf_url=paper_create.pdf_url,
  56. user_id=user_id
  57. )
  58. self.db.add(paper_db)
  59. self.db.commit()
  60. self.db.refresh(paper_db)
  61. # 异步处理PDF和嵌入
  62. self._process_paper_async(paper_db.id)
  63. return Paper.from_orm(paper_db)
  64. def update_paper(self, paper_id: int, paper_update: PaperUpdate) -> Paper:
  65. """更新论文信息"""
  66. paper_db = self.db.query(PaperDB).filter(PaperDB.id == paper_id).first()
  67. if not paper_db:
  68. raise PaperNotFoundError(f"Paper with id {paper_id} not found")
  69. # 更新字段
  70. update_data = paper_update.dict(exclude_unset=True)
  71. for field, value in update_data.items():
  72. if field in ['authors', 'keywords']:
  73. setattr(paper_db, field, json.dumps(value))
  74. else:
  75. setattr(paper_db, field, value)
  76. self.db.commit()
  77. self.db.refresh(paper_db)
  78. return Paper.from_orm(paper_db)
  79. def delete_paper(self, paper_id: int) -> bool:
  80. """删除论文"""
  81. paper_db = self.db.query(PaperDB).filter(PaperDB.id == paper_id).first()
  82. if not paper_db:
  83. raise PaperNotFoundError(f"Paper with id {paper_id} not found")
  84. # 从向量存储中删除
  85. if paper_db.embeddings:
  86. self.vector_store.delete_document(paper_id)
  87. self.db.delete(paper_db)
  88. self.db.commit()
  89. return True
  90. def search_papers(self, search: PaperSearch, user_id: int) -> List[Paper]:
  91. """搜索论文"""
  92. query = self.db.query(PaperDB).filter(PaperDB.user_id == user_id)
  93. # 文本搜索
  94. if search.query:
  95. search_filter = or_(
  96. PaperDB.title.contains(search.query),
  97. PaperDB.abstract.contains(search.query),
  98. PaperDB.keywords.contains(search.query)
  99. )
  100. query = query.filter(search_filter)
  101. # 应用过滤器
  102. filters = search.filters
  103. if 'year_range' in filters:
  104. start_year, end_year = filters['year_range']
  105. query = query.filter(
  106. and_(
  107. PaperDB.publication_year >= start_year,
  108. PaperDB.publication_year <= end_year
  109. )
  110. )
  111. if 'venues' in filters:
  112. query = query.filter(PaperDB.journal.in_(filters['venues']))
  113. if 'authors' in filters:
  114. author_filter = or_(*[
  115. PaperDB.authors.contains(author) for author in filters['authors']
  116. ])
  117. query = query.filter(author_filter)
  118. # 排序
  119. if search.sort_by == "relevance":
  120. query = query.order_by(desc(PaperDB.relevance_score))
  121. elif search.sort_by == "quality":
  122. query = query.order_by(desc(PaperDB.quality_score))
  123. elif search.sort_by == "year":
  124. query = query.order_by(desc(PaperDB.publication_year))
  125. else:
  126. query = query.order_by(desc(PaperDB.created_at))
  127. # 分页
  128. papers_db = query.offset(search.offset).limit(search.limit).all()
  129. return [Paper.from_orm(paper) for paper in papers_db]
  130. def semantic_search(self, query: str, user_id: int, limit: int = 10) -> List[Paper]:
  131. """语义搜索论文"""
  132. # 生成查询向量
  133. query_embedding = self.embedding_service.get_embedding(query)
  134. # 在向量存储中搜索
  135. results = self.vector_store.search(query_embedding, user_id, limit)
  136. # 获取对应的论文
  137. paper_ids = [result['id'] for result in results]
  138. papers_db = self.db.query(PaperDB).filter(
  139. and_(
  140. PaperDB.id.in_(paper_ids),
  141. PaperDB.user_id == user_id
  142. )
  143. ).all()
  144. # 按相似度排序
  145. paper_dict = {paper.id: paper for paper in papers_db}
  146. sorted_papers = []
  147. for result in results:
  148. if result['id'] in paper_dict:
  149. paper = Paper.from_orm(paper_dict[result['id']])
  150. paper.relevance_score = result['score']
  151. sorted_papers.append(paper)
  152. return sorted_papers
  153. def _process_paper_async(self, paper_id: int):
  154. """异步处理论文(PDF解析和嵌入生成)"""
  155. try:
  156. paper_db = self.db.query(PaperDB).filter(PaperDB.id == paper_id).first()
  157. if not paper_db:
  158. return
  159. # 如果有PDF URL,下载并解析
  160. if paper_db.pdf_url and not paper_db.full_text:
  161. full_text = self.pdf_parser.parse_pdf_from_url(paper_db.pdf_url)
  162. if full_text:
  163. paper_db.full_text = full_text
  164. # 生成嵌入
  165. text_to_embed = paper_db.title + " " + (paper_db.abstract or "")
  166. if paper_db.full_text:
  167. text_to_embed += " " + paper_db.full_text
  168. embedding = self.embedding_service.get_embedding(text_to_embed)
  169. paper_db.embeddings = embedding.tolist()
  170. # 添加到向量存储
  171. self.vector_store.add_document(
  172. doc_id=paper_id,
  173. embedding=embedding,
  174. metadata={
  175. 'title': paper_db.title,
  176. 'user_id': paper_db.user_id
  177. }
  178. )
  179. paper_db.is_processed = True
  180. self.db.commit()
  181. except Exception as e:
  182. print(f"Error processing paper {paper_id}: {e}")
  183. # 可以在这里添加错误日志记录
  184. def get_paper_statistics(self, user_id: int) -> Dict[str, Any]:
  185. """获取论文统计信息"""
  186. total_papers = self.db.query(PaperDB).filter(PaperDB.user_id == user_id).count()
  187. processed_papers = self.db.query(PaperDB).filter(
  188. and_(PaperDB.user_id == user_id, PaperDB.is_processed == True)
  189. ).count()
  190. # 按年份统计
  191. year_stats = self.db.query(
  192. PaperDB.publication_year,
  193. self.db.func.count(PaperDB.id)
  194. ).filter(PaperDB.user_id == user_id).group_by(PaperDB.publication_year).all()
  195. # 按期刊统计
  196. journal_stats = self.db.query(
  197. PaperDB.journal,
  198. self.db.func.count(PaperDB.id)
  199. ).filter(PaperDB.user_id == user_id).group_by(PaperDB.journal).all()
  200. return {
  201. 'total_papers': total_papers,
  202. 'processed_papers': processed_papers,
  203. 'processing_rate': processed_papers / total_papers if total_papers > 0 else 0,
  204. 'year_distribution': dict(year_stats),
  205. 'journal_distribution': dict(journal_stats)
  206. }