database.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. """
  2. InnoCore AI 数据库管理模块
  3. """
import asyncio
import json
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import asyncpg

from .config import get_config
from .exceptions import DatabaseException
  13. class DatabaseManager:
  14. """数据库管理器"""
  15. def __init__(self):
  16. self.config = get_config().database
  17. self.pool = None
  18. async def initialize(self):
  19. """初始化数据库连接池"""
  20. try:
  21. self.pool = await asyncpg.create_pool(
  22. host=self.config.host,
  23. port=self.config.port,
  24. database=self.config.database,
  25. user=self.config.username,
  26. password=self.config.password,
  27. min_size=1,
  28. max_size=self.config.pool_size
  29. )
  30. await self._create_tables()
  31. except Exception as e:
  32. raise DatabaseException(f"数据库初始化失败: {str(e)}")
  33. async def _create_tables(self):
  34. """创建数据库表"""
  35. create_tables_sql = """
  36. -- 用户表
  37. CREATE TABLE IF NOT EXISTS users (
  38. id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
  39. email VARCHAR(255) UNIQUE NOT NULL,
  40. profile JSONB DEFAULT '{}',
  41. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  42. );
  43. -- 论文表
  44. CREATE TABLE IF NOT EXISTS papers (
  45. id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
  46. title TEXT NOT NULL,
  47. authors TEXT[] DEFAULT '{}',
  48. abstract TEXT,
  49. doi VARCHAR(255) UNIQUE,
  50. file_path TEXT,
  51. content_hash VARCHAR(64) UNIQUE,
  52. is_preset BOOLEAN DEFAULT FALSE,
  53. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  54. );
  55. -- 用户论文关系表
  56. CREATE TABLE IF NOT EXISTS user_paper_relations (
  57. id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
  58. user_id UUID REFERENCES users(id) ON DELETE CASCADE,
  59. paper_id UUID REFERENCES papers(id) ON DELETE CASCADE,
  60. tags TEXT[] DEFAULT '{}',
  61. rating INTEGER DEFAULT 0,
  62. is_read BOOLEAN DEFAULT FALSE,
  63. added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  64. UNIQUE(user_id, paper_id)
  65. );
  66. -- 分析报告表
  67. CREATE TABLE IF NOT EXISTS analysis_reports (
  68. id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
  69. paper_id UUID REFERENCES papers(id) ON DELETE CASCADE,
  70. generated_for_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
  71. summary TEXT,
  72. innovation_point TEXT,
  73. limitation TEXT,
  74. future_idea TEXT,
  75. vector_ids JSONB DEFAULT '{}',
  76. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  77. );
  78. -- 引用缓存表
  79. CREATE TABLE IF NOT EXISTS reference_cache (
  80. doi VARCHAR(255) PRIMARY KEY,
  81. bibtex_std TEXT,
  82. is_verified BOOLEAN DEFAULT FALSE,
  83. last_check TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  84. );
  85. -- 创建索引
  86. CREATE INDEX IF NOT EXISTS idx_papers_content_hash ON papers(content_hash);
  87. CREATE INDEX IF NOT EXISTS idx_papers_doi ON papers(doi);
  88. CREATE INDEX IF NOT EXISTS idx_user_paper_relations_user_id ON user_paper_relations(user_id);
  89. CREATE INDEX IF NOT EXISTS idx_user_paper_relations_paper_id ON user_paper_relations(paper_id);
  90. CREATE INDEX IF NOT EXISTS idx_analysis_reports_paper_id ON analysis_reports(paper_id);
  91. CREATE INDEX IF NOT EXISTS idx_analysis_reports_user_id ON analysis_reports(generated_for_user_id);
  92. """
  93. async with self.pool.acquire() as conn:
  94. await conn.execute(create_tables_sql)
  95. @asynccontextmanager
  96. async def get_connection(self):
  97. """获取数据库连接"""
  98. if not self.pool:
  99. await self.initialize()
  100. async with self.pool.acquire() as conn:
  101. try:
  102. yield conn
  103. except Exception as e:
  104. raise DatabaseException(f"数据库操作失败: {str(e)}")
  105. # 用户相关操作
  106. async def create_user(self, email: str, profile: Dict = None) -> str:
  107. """创建用户"""
  108. async with self.get_connection() as conn:
  109. user_id = await conn.fetchval(
  110. "INSERT INTO users (email, profile) VALUES ($1, $2) RETURNING id",
  111. email, json.dumps(profile or {})
  112. )
  113. return str(user_id)
  114. async def get_user(self, user_id: str) -> Optional[Dict]:
  115. """获取用户信息"""
  116. async with self.get_connection() as conn:
  117. row = await conn.fetchrow(
  118. "SELECT * FROM users WHERE id = $1", user_id
  119. )
  120. return dict(row) if row else None
  121. async def update_user_profile(self, user_id: str, profile: Dict) -> bool:
  122. """更新用户配置"""
  123. async with self.get_connection() as conn:
  124. result = await conn.execute(
  125. "UPDATE users SET profile = $1 WHERE id = $2",
  126. json.dumps(profile), user_id
  127. )
  128. return result == "UPDATE 1"
  129. # 论文相关操作
  130. async def create_paper(self, title: str, authors: List[str],
  131. abstract: str = None, doi: str = None,
  132. file_path: str = None, content_hash: str = None,
  133. is_preset: bool = False) -> str:
  134. """创建论文记录"""
  135. async with self.get_connection() as conn:
  136. paper_id = await conn.fetchval(
  137. """
  138. INSERT INTO papers (title, authors, abstract, doi, file_path, content_hash, is_preset)
  139. VALUES ($1, $2, $3, $4, $5, $6, $7)
  140. RETURNING id
  141. """,
  142. title, authors, abstract, doi, file_path, content_hash, is_preset
  143. )
  144. return str(paper_id)
  145. async def get_paper(self, paper_id: str) -> Optional[Dict]:
  146. """获取论文信息"""
  147. async with self.get_connection() as conn:
  148. row = await conn.fetchrow(
  149. "SELECT * FROM papers WHERE id = $1", paper_id
  150. )
  151. return dict(row) if row else None
  152. async def get_paper_by_hash(self, content_hash: str) -> Optional[Dict]:
  153. """根据内容哈希获取论文"""
  154. async with self.get_connection() as conn:
  155. row = await conn.fetchrow(
  156. "SELECT * FROM papers WHERE content_hash = $1", content_hash
  157. )
  158. return dict(row) if row else None
  159. async def search_papers(self, query: str, limit: int = 10, offset: int = 0) -> List[Dict]:
  160. """搜索论文"""
  161. async with self.get_connection() as conn:
  162. rows = await conn.fetch(
  163. """
  164. SELECT * FROM papers
  165. WHERE title ILIKE $1 OR abstract ILIKE $1
  166. ORDER BY created_at DESC
  167. LIMIT $2 OFFSET $3
  168. """,
  169. f"%{query}%", limit, offset
  170. )
  171. return [dict(row) for row in rows]
  172. # 用户论文关系操作
  173. async def add_paper_to_user(self, user_id: str, paper_id: str,
  174. tags: List[str] = None, rating: int = 0) -> bool:
  175. """将论文添加到用户库"""
  176. async with self.get_connection() as conn:
  177. try:
  178. await conn.execute(
  179. """
  180. INSERT INTO user_paper_relations (user_id, paper_id, tags, rating)
  181. VALUES ($1, $2, $3, $4)
  182. ON CONFLICT (user_id, paper_id) DO UPDATE SET
  183. tags = EXCLUDED.tags,
  184. rating = EXCLUDED.rating,
  185. added_at = CURRENT_TIMESTAMP
  186. """,
  187. user_id, paper_id, tags or [], rating
  188. )
  189. return True
  190. except Exception:
  191. return False
  192. async def get_user_papers(self, user_id: str, limit: int = 50, offset: int = 0) -> List[Dict]:
  193. """获取用户的论文列表"""
  194. async with self.get_connection() as conn:
  195. rows = await conn.fetch(
  196. """
  197. SELECT p.*, upr.tags, upr.rating, upr.is_read, upr.added_at
  198. FROM papers p
  199. JOIN user_paper_relations upr ON p.id = upr.paper_id
  200. WHERE upr.user_id = $1
  201. ORDER BY upr.added_at DESC
  202. LIMIT $2 OFFSET $3
  203. """,
  204. user_id, limit, offset
  205. )
  206. return [dict(row) for row in rows]
  207. # 分析报告操作
  208. async def create_analysis_report(self, paper_id: str, summary: str,
  209. innovation_point: str, limitation: str,
  210. future_idea: str, vector_ids: Dict = None,
  211. user_id: str = None) -> str:
  212. """创建分析报告"""
  213. async with self.get_connection() as conn:
  214. report_id = await conn.fetchval(
  215. """
  216. INSERT INTO analysis_reports
  217. (paper_id, generated_for_user_id, summary, innovation_point, limitation, future_idea, vector_ids)
  218. VALUES ($1, $2, $3, $4, $5, $6, $7)
  219. RETURNING id
  220. """,
  221. paper_id, user_id, summary, innovation_point,
  222. limitation, future_idea, json.dumps(vector_ids or {})
  223. )
  224. return str(report_id)
  225. async def get_analysis_report(self, paper_id: str, user_id: str = None) -> Optional[Dict]:
  226. """获取分析报告"""
  227. async with self.get_connection() as conn:
  228. if user_id:
  229. row = await conn.fetchrow(
  230. """
  231. SELECT * FROM analysis_reports
  232. WHERE paper_id = $1 AND (generated_for_user_id = $2 OR generated_for_user_id IS NULL)
  233. ORDER BY created_at DESC LIMIT 1
  234. """,
  235. paper_id, user_id
  236. )
  237. else:
  238. row = await conn.fetchrow(
  239. """
  240. SELECT * FROM analysis_reports
  241. WHERE paper_id = $1 AND generated_for_user_id IS NULL
  242. ORDER BY created_at DESC LIMIT 1
  243. """,
  244. paper_id
  245. )
  246. return dict(row) if row else None
  247. # 引用缓存操作
  248. async def cache_reference(self, doi: str, bibtex: str, is_verified: bool = False):
  249. """缓存引用信息"""
  250. async with self.get_connection() as conn:
  251. await conn.execute(
  252. """
  253. INSERT INTO reference_cache (doi, bibtex_std, is_verified, last_check)
  254. VALUES ($1, $2, $3, CURRENT_TIMESTAMP)
  255. ON CONFLICT (doi) DO UPDATE SET
  256. bibtex_std = EXCLUDED.bibtex_std,
  257. is_verified = EXCLUDED.is_verified,
  258. last_check = CURRENT_TIMESTAMP
  259. """,
  260. doi, bibtex, is_verified
  261. )
  262. async def get_cached_reference(self, doi: str) -> Optional[Dict]:
  263. """获取缓存的引用信息"""
  264. async with self.get_connection() as conn:
  265. row = await conn.fetchrow(
  266. "SELECT * FROM reference_cache WHERE doi = $1", doi
  267. )
  268. return dict(row) if row else None
  269. async def close(self):
  270. """关闭数据库连接池"""
  271. if self.pool:
  272. await self.pool.close()
  273. # 全局数据库管理器实例
  274. db_manager = DatabaseManager()