| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462 |
- """文档存储实现
- 支持多种文档数据库后端:
- - SQLite: 轻量级关系型数据库
- - PostgreSQL: 企业级关系型数据库(可扩展)
- """
- from abc import ABC, abstractmethod
- from typing import List, Dict, Any, Optional
- import sqlite3
- import json
- import os
- import threading
class DocumentStore(ABC):
    """Abstract interface that every document/memory storage backend implements.

    The class only declares method signatures; it holds no state and provides
    no behavior of its own. A concrete subclass must override every method
    below before it can be instantiated.
    """

    @abstractmethod
    def add_memory(
        self,
        memory_id: str,
        user_id: str,
        content: str,
        memory_type: str,
        timestamp: int,
        importance: float,
        properties: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Add a memory record.

        Args:
            memory_id: Identifier of the memory.
            user_id: Identifier of the user the memory belongs to.
            content: Text content of the memory.
            memory_type: Category label of the memory.
            timestamp: Integer timestamp attached to the memory.
            importance: Importance score of the memory.
            properties: Optional extra key/value data for the memory.

        Returns:
            A string identifier for the stored memory.
        """

    @abstractmethod
    def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a single memory by its identifier.

        Returns:
            The memory as a dictionary, or ``None``.
        """

    @abstractmethod
    def search_memories(
        self,
        user_id: Optional[str] = None,
        memory_type: Optional[str] = None,
        text_query: Optional[str] = None,
        start_time: Optional[int] = None,
        end_time: Optional[int] = None,
        importance_threshold: Optional[float] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Search memories using optional filters.

        Every filter defaults to ``None``; ``limit`` defaults to 10.

        Returns:
            A list of memory dictionaries.
        """

    @abstractmethod
    def update_memory(
        self,
        memory_id: str,
        content: Optional[str] = None,
        importance: Optional[float] = None,
        properties: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Update fields of an existing memory.

        Args:
            memory_id: Identifier of the memory to update.
            content: New content, or ``None``.
            importance: New importance score, or ``None``.
            properties: New properties dictionary, or ``None``.

        Returns:
            A boolean result flag.
        """

    @abstractmethod
    def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory by its identifier and return a boolean result flag."""

    @abstractmethod
    def get_database_stats(self) -> Dict[str, Any]:
        """Return statistics about the underlying database as a dictionary."""

    @abstractmethod
    def add_document(
        self, content: str, metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Add a document with optional metadata and return a string identifier."""

    @abstractmethod
    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a document by its identifier; returns a dictionary or ``None``."""
class SQLiteDocumentStore(DocumentStore):
    """SQLite-backed implementation of ``DocumentStore``.

    One instance is kept per absolute database path (see ``__new__``), and each
    thread lazily opens its own ``sqlite3`` connection stored on a
    ``threading.local`` object.
    """

    _instances = {}  # absolute db path -> the instance created for that path
    _initialized_dbs = set()  # absolute db paths whose schema was already created

    # Columns read back for every memory row returned to callers.
    _MEMORY_COLUMNS = (
        "id, user_id, content, memory_type, timestamp, importance, properties, created_at"
    )

    # Tables, in creation order.
    _TABLE_STATEMENTS = (
        """
        CREATE TABLE IF NOT EXISTS users (
            id TEXT PRIMARY KEY,
            name TEXT,
            properties TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS memories (
            id TEXT PRIMARY KEY,
            user_id TEXT NOT NULL,
            content TEXT NOT NULL,
            memory_type TEXT NOT NULL,
            timestamp INTEGER NOT NULL,
            importance REAL NOT NULL,
            properties TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users (id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS concepts (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            description TEXT,
            properties TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS memory_concepts (
            memory_id TEXT NOT NULL,
            concept_id TEXT NOT NULL,
            relevance_score REAL DEFAULT 1.0,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (memory_id, concept_id),
            FOREIGN KEY (memory_id) REFERENCES memories (id) ON DELETE CASCADE,
            FOREIGN KEY (concept_id) REFERENCES concepts (id) ON DELETE CASCADE
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS concept_relationships (
            from_concept_id TEXT NOT NULL,
            to_concept_id TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            strength REAL DEFAULT 1.0,
            properties TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            PRIMARY KEY (from_concept_id, to_concept_id, relationship_type),
            FOREIGN KEY (from_concept_id) REFERENCES concepts (id) ON DELETE CASCADE,
            FOREIGN KEY (to_concept_id) REFERENCES concepts (id) ON DELETE CASCADE
        )
        """,
    )

    # Secondary indexes on the columns used by the filters and joins.
    _INDEX_STATEMENTS = (
        "CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)",
        "CREATE INDEX IF NOT EXISTS idx_memories_type ON memories (memory_type)",
        "CREATE INDEX IF NOT EXISTS idx_memories_timestamp ON memories (timestamp)",
        "CREATE INDEX IF NOT EXISTS idx_memories_importance ON memories (importance)",
        "CREATE INDEX IF NOT EXISTS idx_memory_concepts_memory ON memory_concepts (memory_id)",
        "CREATE INDEX IF NOT EXISTS idx_memory_concepts_concept ON memory_concepts (concept_id)",
    )

    def __new__(cls, db_path: str = "./memory.db"):
        """Return the instance registered for ``db_path``, creating it on first use.

        Instances are keyed by the absolute form of ``db_path``, so different
        spellings of the same path share one instance.
        """
        abs_path = os.path.abspath(db_path)
        if abs_path not in cls._instances:
            instance = super(SQLiteDocumentStore, cls).__new__(cls)
            cls._instances[abs_path] = instance
        return cls._instances[abs_path]

    def __init__(self, db_path: str = "./memory.db") -> None:
        """Prepare the store for ``db_path`` and create the schema if needed.

        Because ``__new__`` may hand back an existing instance, ``__init__`` can
        run several times on the same object; every run after the first
        successful one returns immediately.
        """
        if hasattr(self, '_initialized'):
            return

        self.db_path = db_path
        self.local = threading.local()

        abs_path = os.path.abspath(db_path)
        # Make sure the directory that will hold the database file exists.
        os.makedirs(os.path.dirname(abs_path), exist_ok=True)

        # Create tables and indexes only once per database path.
        if abs_path not in self._initialized_dbs:
            self._init_database()
            self._initialized_dbs.add(abs_path)
            print(f"[OK] SQLite 文档存储初始化完成: {db_path}")

        self._initialized = True

    def _get_connection(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first use."""
        conn = getattr(self.local, 'connection', None)
        if conn is None:
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row  # allow access to columns by name
            # SQLite ignores FOREIGN KEY clauses (including ON DELETE CASCADE)
            # unless enforcement is switched on for each connection.
            conn.execute("PRAGMA foreign_keys = ON")
            self.local.connection = conn
        return conn

    def _init_database(self) -> None:
        """Create all tables and indexes if they do not exist yet."""
        conn = self._get_connection()
        cursor = conn.cursor()

        for statement in self._TABLE_STATEMENTS + self._INDEX_STATEMENTS:
            cursor.execute(statement)

        conn.commit()
        print("[OK] SQLite 数据库表和索引创建完成")

    @staticmethod
    def _row_to_memory(row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a ``memories`` row into the dictionary returned to callers."""
        raw_properties = row["properties"]
        return {
            "memory_id": row["id"],
            "user_id": row["user_id"],
            "content": row["content"],
            "memory_type": row["memory_type"],
            "timestamp": row["timestamp"],
            "importance": row["importance"],
            # A NULL/empty properties column is reported as an empty dict.
            "properties": json.loads(raw_properties) if raw_properties else {},
            "created_at": row["created_at"],
        }

    def add_memory(
        self,
        memory_id: str,
        user_id: str,
        content: str,
        memory_type: str,
        timestamp: int,
        importance: float,
        properties: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Insert a memory, or overwrite the existing one with the same id.

        A ``users`` row for ``user_id`` is created if missing. When a memory
        with ``memory_id`` already exists it is updated in place, so its
        ``created_at`` value and any ``memory_concepts`` links are kept.
        Both writes happen in one transaction that is rolled back on error.

        Returns:
            ``memory_id``.

        Raises:
            sqlite3.Error: If a statement fails (for example a NOT NULL
                violation); nothing is persisted in that case.
        """
        conn = self._get_connection()
        # Empty or missing properties are stored as NULL.
        serialized = json.dumps(properties) if properties else None
        values = (user_id, content, memory_type, timestamp, importance, serialized)

        with conn:  # commit on success, roll back on exception
            # Make sure the owning user exists.
            conn.execute(
                "INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
                (user_id, user_id),
            )
            # Update in place first: INSERT OR REPLACE would delete the old row,
            # resetting created_at and cascading to memory_concepts.
            updated = conn.execute(
                """
                UPDATE memories
                SET user_id = ?, content = ?, memory_type = ?, timestamp = ?,
                    importance = ?, properties = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = ?
                """,
                values + (memory_id,),
            )
            if updated.rowcount == 0:
                conn.execute(
                    """
                    INSERT INTO memories
                    (id, user_id, content, memory_type, timestamp, importance, properties, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
                    """,
                    (memory_id,) + values,
                )

        return memory_id

    def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
        """Return the memory with ``memory_id`` as a dict, or ``None`` if absent."""
        row = self._get_connection().execute(
            f"""
            SELECT {self._MEMORY_COLUMNS}
            FROM memories
            WHERE id = ?
            """,
            (memory_id,),
        ).fetchone()

        if not row:
            return None
        return self._row_to_memory(row)

    def search_memories(
        self,
        user_id: Optional[str] = None,
        memory_type: Optional[str] = None,
        text_query: Optional[str] = None,
        start_time: Optional[int] = None,
        end_time: Optional[int] = None,
        importance_threshold: Optional[float] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Search memories; all supplied filters are combined with AND.

        Args:
            user_id: Exact user id to match; ignored when ``None`` or empty.
            memory_type: Exact memory type to match; ignored when ``None`` or empty.
            text_query: Literal substring to look for in ``content`` via SQL
                ``LIKE``; ignored when ``None`` or empty.
            start_time: Inclusive lower bound on ``timestamp``; ignored when ``None``.
            end_time: Inclusive upper bound on ``timestamp``; ignored when ``None``.
            importance_threshold: Inclusive lower bound on ``importance``;
                ignored when ``None``.
            limit: Value passed to the SQL ``LIMIT`` clause.

        Returns:
            Matching memories ordered by importance, then timestamp, both descending.
        """
        conditions = []
        params = []

        if user_id:
            conditions.append("user_id = ?")
            params.append(user_id)

        if memory_type:
            conditions.append("memory_type = ?")
            params.append(memory_type)

        if text_query:
            # Escape LIKE metacharacters so '%' and '_' in the query are
            # matched literally instead of acting as wildcards.
            escaped = (
                text_query.replace("\\", "\\\\")
                .replace("%", "\\%")
                .replace("_", "\\_")
            )
            conditions.append("content LIKE ? ESCAPE '\\'")
            params.append(f"%{escaped}%")

        # Numeric filters are compared against None so that 0 / 0.0 are
        # honored as real bounds rather than treated as "not supplied".
        if start_time is not None:
            conditions.append("timestamp >= ?")
            params.append(start_time)

        if end_time is not None:
            conditions.append("timestamp <= ?")
            params.append(end_time)

        if importance_threshold is not None:
            conditions.append("importance >= ?")
            params.append(importance_threshold)

        where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""

        rows = self._get_connection().execute(
            f"""
            SELECT {self._MEMORY_COLUMNS}
            FROM memories
            {where_clause}
            ORDER BY importance DESC, timestamp DESC
            LIMIT ?
            """,
            params + [limit],
        ).fetchall()

        return [self._row_to_memory(row) for row in rows]

    def update_memory(
        self,
        memory_id: str,
        content: Optional[str] = None,
        importance: Optional[float] = None,
        properties: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Update the supplied fields of a memory and refresh ``updated_at``.

        Arguments left as ``None`` are not touched.

        Returns:
            ``True`` if a row was updated; ``False`` if no field was supplied
            or no memory has ``memory_id``.
        """
        assignments = []
        params = []

        if content is not None:
            assignments.append("content = ?")
            params.append(content)

        if importance is not None:
            assignments.append("importance = ?")
            params.append(importance)

        if properties is not None:
            assignments.append("properties = ?")
            params.append(json.dumps(properties))

        if not assignments:
            return False

        assignments.append("updated_at = CURRENT_TIMESTAMP")
        params.append(memory_id)

        conn = self._get_connection()
        with conn:  # commit on success, roll back on exception
            changed = conn.execute(
                f"""
                UPDATE memories
                SET {', '.join(assignments)}
                WHERE id = ?
                """,
                params,
            ).rowcount

        return changed > 0

    def delete_memory(self, memory_id: str) -> bool:
        """Delete a memory; return ``True`` if a row was removed.

        With foreign keys enforced on the connection, rows in
        ``memory_concepts`` that reference the memory are removed by the
        schema's ``ON DELETE CASCADE``.
        """
        conn = self._get_connection()
        with conn:  # commit on success, roll back on exception
            deleted_count = conn.execute(
                "DELETE FROM memories WHERE id = ?", (memory_id,)
            ).rowcount
        return deleted_count > 0

    def get_database_stats(self) -> Dict[str, Any]:
        """Return row counts per table plus memory-type and top-user breakdowns.

        Returns:
            A dict with ``<table>_count`` for each table, ``memory_types``
            (type -> count), ``top_users`` (up to 10 user ids with the most
            memories -> count), ``store_type`` and ``db_path``.
        """
        cursor = self._get_connection().cursor()
        stats = {}

        # Row count of every table; names come from this fixed list only.
        tables = ["users", "memories", "concepts", "memory_concepts", "concept_relationships"]
        for table in tables:
            cursor.execute(f"SELECT COUNT(*) as count FROM {table}")
            stats[f"{table}_count"] = cursor.fetchone()["count"]

        # Distribution of memories by type.
        cursor.execute("""
            SELECT memory_type, COUNT(*) as count
            FROM memories
            GROUP BY memory_type
        """)
        stats["memory_types"] = {
            row["memory_type"]: row["count"] for row in cursor.fetchall()
        }

        # The ten users owning the most memories.
        cursor.execute("""
            SELECT user_id, COUNT(*) as count
            FROM memories
            GROUP BY user_id
            ORDER BY count DESC
            LIMIT 10
        """)
        stats["top_users"] = {
            row["user_id"]: row["count"] for row in cursor.fetchall()
        }

        stats["store_type"] = "sqlite"
        stats["db_path"] = self.db_path

        return stats

    def add_document(
        self, content: str, metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Store ``content`` as a memory of type ``"document"``.

        A random UUID is used as the id, the current time as the timestamp and
        0.5 as the importance. The owner is ``metadata["user_id"]`` when
        present, otherwise ``"system"``; ``metadata`` is saved as the properties.

        Returns:
            The generated document id.
        """
        import uuid
        import time

        doc_id = str(uuid.uuid4())
        user_id = metadata.get("user_id", "system") if metadata else "system"

        return self.add_memory(
            memory_id=doc_id,
            user_id=user_id,
            content=content,
            memory_type="document",
            timestamp=int(time.time()),
            importance=0.5,
            properties=metadata or {},
        )

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Return the document with ``document_id``; delegates to ``get_memory``."""
        return self.get_memory(document_id)

    def close(self) -> None:
        """Close the calling thread's connection, if it has one open.

        Connections opened by other threads are not affected.
        """
        if hasattr(self.local, 'connection'):
            self.local.connection.close()
            delattr(self.local, 'connection')
            print("[OK] SQLite 连接已关闭")
|