| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- """
- InnoCore AI 数据库管理模块
- """
- import asyncio
- import asyncpg
- from typing import Dict, List, Optional, Any, Union
- from datetime import datetime
- import json
- import uuid
- from contextlib import asynccontextmanager
- from .config import get_config
- from .exceptions import DatabaseException
- class DatabaseManager:
- """数据库管理器"""
-
- def __init__(self):
- self.config = get_config().database
- self.pool = None
-
- async def initialize(self):
- """初始化数据库连接池"""
- try:
- self.pool = await asyncpg.create_pool(
- host=self.config.host,
- port=self.config.port,
- database=self.config.database,
- user=self.config.username,
- password=self.config.password,
- min_size=1,
- max_size=self.config.pool_size
- )
- await self._create_tables()
- except Exception as e:
- raise DatabaseException(f"数据库初始化失败: {str(e)}")
-
- async def _create_tables(self):
- """创建数据库表"""
- create_tables_sql = """
- -- 用户表
- CREATE TABLE IF NOT EXISTS users (
- id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
- email VARCHAR(255) UNIQUE NOT NULL,
- profile JSONB DEFAULT '{}',
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- );
-
- -- 论文表
- CREATE TABLE IF NOT EXISTS papers (
- id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
- title TEXT NOT NULL,
- authors TEXT[] DEFAULT '{}',
- abstract TEXT,
- doi VARCHAR(255) UNIQUE,
- file_path TEXT,
- content_hash VARCHAR(64) UNIQUE,
- is_preset BOOLEAN DEFAULT FALSE,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- );
-
- -- 用户论文关系表
- CREATE TABLE IF NOT EXISTS user_paper_relations (
- id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
- user_id UUID REFERENCES users(id) ON DELETE CASCADE,
- paper_id UUID REFERENCES papers(id) ON DELETE CASCADE,
- tags TEXT[] DEFAULT '{}',
- rating INTEGER DEFAULT 0,
- is_read BOOLEAN DEFAULT FALSE,
- added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- UNIQUE(user_id, paper_id)
- );
-
- -- 分析报告表
- CREATE TABLE IF NOT EXISTS analysis_reports (
- id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
- paper_id UUID REFERENCES papers(id) ON DELETE CASCADE,
- generated_for_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
- summary TEXT,
- innovation_point TEXT,
- limitation TEXT,
- future_idea TEXT,
- vector_ids JSONB DEFAULT '{}',
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- );
-
- -- 引用缓存表
- CREATE TABLE IF NOT EXISTS reference_cache (
- doi VARCHAR(255) PRIMARY KEY,
- bibtex_std TEXT,
- is_verified BOOLEAN DEFAULT FALSE,
- last_check TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- );
-
- -- 创建索引
- CREATE INDEX IF NOT EXISTS idx_papers_content_hash ON papers(content_hash);
- CREATE INDEX IF NOT EXISTS idx_papers_doi ON papers(doi);
- CREATE INDEX IF NOT EXISTS idx_user_paper_relations_user_id ON user_paper_relations(user_id);
- CREATE INDEX IF NOT EXISTS idx_user_paper_relations_paper_id ON user_paper_relations(paper_id);
- CREATE INDEX IF NOT EXISTS idx_analysis_reports_paper_id ON analysis_reports(paper_id);
- CREATE INDEX IF NOT EXISTS idx_analysis_reports_user_id ON analysis_reports(generated_for_user_id);
- """
-
- async with self.pool.acquire() as conn:
- await conn.execute(create_tables_sql)
-
- @asynccontextmanager
- async def get_connection(self):
- """获取数据库连接"""
- if not self.pool:
- await self.initialize()
-
- async with self.pool.acquire() as conn:
- try:
- yield conn
- except Exception as e:
- raise DatabaseException(f"数据库操作失败: {str(e)}")
-
- # 用户相关操作
- async def create_user(self, email: str, profile: Dict = None) -> str:
- """创建用户"""
- async with self.get_connection() as conn:
- user_id = await conn.fetchval(
- "INSERT INTO users (email, profile) VALUES ($1, $2) RETURNING id",
- email, json.dumps(profile or {})
- )
- return str(user_id)
-
- async def get_user(self, user_id: str) -> Optional[Dict]:
- """获取用户信息"""
- async with self.get_connection() as conn:
- row = await conn.fetchrow(
- "SELECT * FROM users WHERE id = $1", user_id
- )
- return dict(row) if row else None
-
- async def update_user_profile(self, user_id: str, profile: Dict) -> bool:
- """更新用户配置"""
- async with self.get_connection() as conn:
- result = await conn.execute(
- "UPDATE users SET profile = $1 WHERE id = $2",
- json.dumps(profile), user_id
- )
- return result == "UPDATE 1"
-
- # 论文相关操作
- async def create_paper(self, title: str, authors: List[str],
- abstract: str = None, doi: str = None,
- file_path: str = None, content_hash: str = None,
- is_preset: bool = False) -> str:
- """创建论文记录"""
- async with self.get_connection() as conn:
- paper_id = await conn.fetchval(
- """
- INSERT INTO papers (title, authors, abstract, doi, file_path, content_hash, is_preset)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- RETURNING id
- """,
- title, authors, abstract, doi, file_path, content_hash, is_preset
- )
- return str(paper_id)
-
- async def get_paper(self, paper_id: str) -> Optional[Dict]:
- """获取论文信息"""
- async with self.get_connection() as conn:
- row = await conn.fetchrow(
- "SELECT * FROM papers WHERE id = $1", paper_id
- )
- return dict(row) if row else None
-
- async def get_paper_by_hash(self, content_hash: str) -> Optional[Dict]:
- """根据内容哈希获取论文"""
- async with self.get_connection() as conn:
- row = await conn.fetchrow(
- "SELECT * FROM papers WHERE content_hash = $1", content_hash
- )
- return dict(row) if row else None
-
- async def search_papers(self, query: str, limit: int = 10, offset: int = 0) -> List[Dict]:
- """搜索论文"""
- async with self.get_connection() as conn:
- rows = await conn.fetch(
- """
- SELECT * FROM papers
- WHERE title ILIKE $1 OR abstract ILIKE $1
- ORDER BY created_at DESC
- LIMIT $2 OFFSET $3
- """,
- f"%{query}%", limit, offset
- )
- return [dict(row) for row in rows]
-
- # 用户论文关系操作
- async def add_paper_to_user(self, user_id: str, paper_id: str,
- tags: List[str] = None, rating: int = 0) -> bool:
- """将论文添加到用户库"""
- async with self.get_connection() as conn:
- try:
- await conn.execute(
- """
- INSERT INTO user_paper_relations (user_id, paper_id, tags, rating)
- VALUES ($1, $2, $3, $4)
- ON CONFLICT (user_id, paper_id) DO UPDATE SET
- tags = EXCLUDED.tags,
- rating = EXCLUDED.rating,
- added_at = CURRENT_TIMESTAMP
- """,
- user_id, paper_id, tags or [], rating
- )
- return True
- except Exception:
- return False
-
- async def get_user_papers(self, user_id: str, limit: int = 50, offset: int = 0) -> List[Dict]:
- """获取用户的论文列表"""
- async with self.get_connection() as conn:
- rows = await conn.fetch(
- """
- SELECT p.*, upr.tags, upr.rating, upr.is_read, upr.added_at
- FROM papers p
- JOIN user_paper_relations upr ON p.id = upr.paper_id
- WHERE upr.user_id = $1
- ORDER BY upr.added_at DESC
- LIMIT $2 OFFSET $3
- """,
- user_id, limit, offset
- )
- return [dict(row) for row in rows]
-
- # 分析报告操作
- async def create_analysis_report(self, paper_id: str, summary: str,
- innovation_point: str, limitation: str,
- future_idea: str, vector_ids: Dict = None,
- user_id: str = None) -> str:
- """创建分析报告"""
- async with self.get_connection() as conn:
- report_id = await conn.fetchval(
- """
- INSERT INTO analysis_reports
- (paper_id, generated_for_user_id, summary, innovation_point, limitation, future_idea, vector_ids)
- VALUES ($1, $2, $3, $4, $5, $6, $7)
- RETURNING id
- """,
- paper_id, user_id, summary, innovation_point,
- limitation, future_idea, json.dumps(vector_ids or {})
- )
- return str(report_id)
-
- async def get_analysis_report(self, paper_id: str, user_id: str = None) -> Optional[Dict]:
- """获取分析报告"""
- async with self.get_connection() as conn:
- if user_id:
- row = await conn.fetchrow(
- """
- SELECT * FROM analysis_reports
- WHERE paper_id = $1 AND (generated_for_user_id = $2 OR generated_for_user_id IS NULL)
- ORDER BY created_at DESC LIMIT 1
- """,
- paper_id, user_id
- )
- else:
- row = await conn.fetchrow(
- """
- SELECT * FROM analysis_reports
- WHERE paper_id = $1 AND generated_for_user_id IS NULL
- ORDER BY created_at DESC LIMIT 1
- """,
- paper_id
- )
- return dict(row) if row else None
-
- # 引用缓存操作
- async def cache_reference(self, doi: str, bibtex: str, is_verified: bool = False):
- """缓存引用信息"""
- async with self.get_connection() as conn:
- await conn.execute(
- """
- INSERT INTO reference_cache (doi, bibtex_std, is_verified, last_check)
- VALUES ($1, $2, $3, CURRENT_TIMESTAMP)
- ON CONFLICT (doi) DO UPDATE SET
- bibtex_std = EXCLUDED.bibtex_std,
- is_verified = EXCLUDED.is_verified,
- last_check = CURRENT_TIMESTAMP
- """,
- doi, bibtex, is_verified
- )
-
- async def get_cached_reference(self, doi: str) -> Optional[Dict]:
- """获取缓存的引用信息"""
- async with self.get_connection() as conn:
- row = await conn.fetchrow(
- "SELECT * FROM reference_cache WHERE doi = $1", doi
- )
- return dict(row) if row else None
-
- async def close(self):
- """关闭数据库连接池"""
- if self.pool:
- await self.pool.close()
- # 全局数据库管理器实例
- db_manager = DatabaseManager()
|