analysis_service.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
"""Analysis service: CRUD, statistics and export operations for analyses."""
import json
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import func
from sqlalchemy.orm import Session

from ..core.database import get_db
from ..core.exceptions import AnalysisNotFoundError
from ..models.analysis import AnalysisDB, Analysis, AnalysisCreate, AnalysisUpdate
from ..services.paper_service import PaperService
  11. class AnalysisService:
  12. """分析服务类"""
  13. def __init__(self, db: Session):
  14. self.db = db
  15. self.paper_service = PaperService(db)
  16. def get_analysis_by_id(self, analysis_id: int) -> Optional[Analysis]:
  17. """根据ID获取分析"""
  18. analysis_db = self.db.query(AnalysisDB).filter(AnalysisDB.id == analysis_id).first()
  19. if not analysis_db:
  20. raise AnalysisNotFoundError(f"Analysis with id {analysis_id} not found")
  21. return Analysis.from_orm(analysis_db)
  22. def get_analyses_by_user(self, user_id: int, skip: int = 0, limit: int = 20) -> List[Analysis]:
  23. """获取用户的分析列表"""
  24. analyses_db = self.db.query(AnalysisDB).filter(
  25. AnalysisDB.user_id == user_id
  26. ).order_by(AnalysisDB.created_at.desc()).offset(skip).limit(limit).all()
  27. return [Analysis.from_orm(analysis) for analysis in analyses_db]
  28. def create_analysis(self, analysis_create: AnalysisCreate, user_id: int, task_id: Optional[int] = None) -> Analysis:
  29. """创建分析"""
  30. analysis_db = AnalysisDB(
  31. title=analysis_create.title,
  32. analysis_type=analysis_create.analysis_type,
  33. paper_ids=json.dumps(analysis_create.paper_ids),
  34. methodology=analysis_create.methodology,
  35. user_id=user_id,
  36. task_id=task_id
  37. )
  38. self.db.add(analysis_db)
  39. self.db.commit()
  40. self.db.refresh(analysis_db)
  41. return Analysis.from_orm(analysis_db)
  42. def update_analysis(self, analysis_id: int, analysis_update: AnalysisUpdate) -> Analysis:
  43. """更新分析"""
  44. analysis_db = self.db.query(AnalysisDB).filter(AnalysisDB.id == analysis_id).first()
  45. if not analysis_db:
  46. raise AnalysisNotFoundError(f"Analysis with id {analysis_id} not found")
  47. # 更新字段
  48. update_data = analysis_update.dict(exclude_unset=True)
  49. for field, value in update_data.items():
  50. if field == 'findings':
  51. setattr(analysis_db, field, json.dumps(value))
  52. else:
  53. setattr(analysis_db, field, value)
  54. self.db.commit()
  55. self.db.refresh(analysis_db)
  56. return Analysis.from_orm(analysis_db)
  57. def delete_analysis(self, analysis_id: int) -> bool:
  58. """删除分析"""
  59. analysis_db = self.db.query(AnalysisDB).filter(AnalysisDB.id == analysis_id).first()
  60. if not analysis_db:
  61. raise AnalysisNotFoundError(f"Analysis with id {analysis_id} not found")
  62. self.db.delete(analysis_db)
  63. self.db.commit()
  64. return True
  65. def get_analysis_statistics(self, user_id: int) -> Dict[str, Any]:
  66. """获取分析统计信息"""
  67. total_analyses = self.db.query(AnalysisDB).filter(AnalysisDB.user_id == user_id).count()
  68. # 按类型统计
  69. type_stats = self.db.query(
  70. AnalysisDB.analysis_type,
  71. self.db.func.count(AnalysisDB.id)
  72. ).filter(AnalysisDB.user_id == user_id).group_by(AnalysisDB.analysis_type).all()
  73. # 平均分数
  74. avg_scores = self.db.query(
  75. self.db.func.avg(AnalysisDB.confidence_score),
  76. self.db.func.avg(AnalysisDB.novelty_score),
  77. self.db.func.avg(AnalysisDB.impact_score)
  78. ).filter(AnalysisDB.user_id == user_id).first()
  79. return {
  80. 'total_analyses': total_analyses,
  81. 'type_distribution': dict(type_stats),
  82. 'average_confidence': float(avg_scores[0] or 0),
  83. 'average_novelty': float(avg_scores[1] or 0),
  84. 'average_impact': float(avg_scores[2] or 0)
  85. }
  86. def get_related_analyses(self, analysis_id: int, limit: int = 5) -> List[Analysis]:
  87. """获取相关分析"""
  88. analysis_db = self.db.query(AnalysisDB).filter(AnalysisDB.id == analysis_id).first()
  89. if not analysis_db:
  90. raise AnalysisNotFoundError(f"Analysis with id {analysis_id} not found")
  91. # 获取相同类型的分析
  92. related_analyses = self.db.query(AnalysisDB).filter(
  93. and_(
  94. AnalysisDB.user_id == analysis_db.user_id,
  95. AnalysisDB.analysis_type == analysis_db.analysis_type,
  96. AnalysisDB.id != analysis_id
  97. )
  98. ).order_by(AnalysisDB.created_at.desc()).limit(limit).all()
  99. return [Analysis.from_orm(analysis) for analysis in related_analyses]
  100. def export_analysis(self, analysis_id: int, format: str = "json") -> Dict[str, Any]:
  101. """导出分析结果"""
  102. analysis = self.get_analysis_by_id(analysis_id)
  103. if format == "json":
  104. return analysis.dict()
  105. elif format == "markdown":
  106. return self._export_to_markdown(analysis)
  107. elif format == "pdf":
  108. return self._export_to_pdf(analysis)
  109. else:
  110. raise ValueError(f"Unsupported export format: {format}")
  111. def _export_to_markdown(self, analysis: Analysis) -> str:
  112. """导出为Markdown格式"""
  113. markdown = f"# {analysis.title}\n\n"
  114. markdown += f"**分析类型**: {analysis.analysis_type}\n\n"
  115. markdown += f"**创建时间**: {analysis.created_at.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
  116. if analysis.methodology:
  117. markdown += f"## 方法论\n\n{analysis.methodology}\n\n"
  118. if analysis.findings:
  119. markdown += "## 主要发现\n\n"
  120. for key, value in analysis.findings.items():
  121. markdown += f"### {key}\n\n{value}\n\n"
  122. if analysis.insights:
  123. markdown += f"## 洞察\n\n{analysis.insights}\n\n"
  124. if analysis.limitations:
  125. markdown += f"## 局限性\n\n{analysis.limitations}\n\n"
  126. if analysis.recommendations:
  127. markdown += f"## 建议\n\n{analysis.recommendations}\n\n"
  128. # 添加评分
  129. markdown += "## 评分\n\n"
  130. markdown += f"- **置信度**: {analysis.confidence_score:.2f}\n"
  131. markdown += f"- **新颖性**: {analysis.novelty_score:.2f}\n"
  132. markdown += f"- **影响力**: {analysis.impact_score:.2f}\n"
  133. return markdown
  134. def _export_to_pdf(self, analysis: Analysis) -> bytes:
  135. """导出为PDF格式"""
  136. # 这里可以使用reportlab或其他PDF生成库
  137. # 暂时返回Markdown内容的字节
  138. markdown_content = self._export_to_markdown(analysis)
  139. return markdown_content.encode('utf-8')