document_store.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. """文档存储实现
  2. 支持多种文档数据库后端:
  3. - SQLite: 轻量级关系型数据库
  4. - PostgreSQL: 企业级关系型数据库(可扩展)
  5. """
  6. from abc import ABC, abstractmethod
  7. from typing import List, Dict, Any, Optional
  8. import sqlite3
  9. import json
  10. import os
  11. import threading
  12. class DocumentStore(ABC):
  13. """文档存储基类"""
  14. @abstractmethod
  15. def add_memory(
  16. self,
  17. memory_id: str,
  18. user_id: str,
  19. content: str,
  20. memory_type: str,
  21. timestamp: int,
  22. importance: float,
  23. properties: Dict[str, Any] = None
  24. ) -> str:
  25. """添加记忆"""
  26. pass
  27. @abstractmethod
  28. def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
  29. """获取单个记忆"""
  30. pass
  31. @abstractmethod
  32. def search_memories(
  33. self,
  34. user_id: Optional[str] = None,
  35. memory_type: Optional[str] = None,
  36. text_query: Optional[str] = None,
  37. start_time: Optional[int] = None,
  38. end_time: Optional[int] = None,
  39. importance_threshold: Optional[float] = None,
  40. limit: int = 10
  41. ) -> List[Dict[str, Any]]:
  42. """搜索记忆"""
  43. pass
  44. @abstractmethod
  45. def update_memory(
  46. self,
  47. memory_id: str,
  48. content: str = None,
  49. importance: float = None,
  50. properties: Dict[str, Any] = None
  51. ) -> bool:
  52. """更新记忆"""
  53. pass
  54. @abstractmethod
  55. def delete_memory(self, memory_id: str) -> bool:
  56. """删除记忆"""
  57. pass
  58. @abstractmethod
  59. def get_database_stats(self) -> Dict[str, Any]:
  60. """获取数据库统计信息"""
  61. pass
  62. @abstractmethod
  63. def add_document(self, content: str, metadata: Dict[str, Any] = None) -> str:
  64. """添加文档"""
  65. pass
  66. @abstractmethod
  67. def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
  68. """获取文档"""
  69. pass
  70. class SQLiteDocumentStore(DocumentStore):
  71. """SQLite文档存储实现"""
  72. _instances = {} # 存储已创建的实例
  73. _initialized_dbs = set() # 存储已初始化的数据库路径
  74. def __new__(cls, db_path: str = "./memory.db"):
  75. """单例模式,同一路径只创建一个实例"""
  76. abs_path = os.path.abspath(db_path)
  77. if abs_path not in cls._instances:
  78. instance = super(SQLiteDocumentStore, cls).__new__(cls)
  79. cls._instances[abs_path] = instance
  80. return cls._instances[abs_path]
  81. def __init__(self, db_path: str = "./memory.db"):
  82. # 避免重复初始化
  83. if hasattr(self, '_initialized'):
  84. return
  85. self.db_path = db_path
  86. self.local = threading.local()
  87. # 确保目录存在
  88. os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)
  89. # 初始化数据库(只初始化一次)
  90. abs_path = os.path.abspath(db_path)
  91. if abs_path not in self._initialized_dbs:
  92. self._init_database()
  93. self._initialized_dbs.add(abs_path)
  94. print(f"[OK] SQLite 文档存储初始化完成: {db_path}")
  95. self._initialized = True
  96. def _get_connection(self):
  97. """获取线程本地连接"""
  98. if not hasattr(self.local, 'connection'):
  99. self.local.connection = sqlite3.connect(self.db_path)
  100. self.local.connection.row_factory = sqlite3.Row # 使结果可以按列名访问
  101. return self.local.connection
  102. def _init_database(self):
  103. """初始化数据库表"""
  104. conn = self._get_connection()
  105. cursor = conn.cursor()
  106. # 创建用户表
  107. cursor.execute("""
  108. CREATE TABLE IF NOT EXISTS users (
  109. id TEXT PRIMARY KEY,
  110. name TEXT,
  111. properties TEXT,
  112. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  113. )
  114. """)
  115. # 创建记忆表
  116. cursor.execute("""
  117. CREATE TABLE IF NOT EXISTS memories (
  118. id TEXT PRIMARY KEY,
  119. user_id TEXT NOT NULL,
  120. content TEXT NOT NULL,
  121. memory_type TEXT NOT NULL,
  122. timestamp INTEGER NOT NULL,
  123. importance REAL NOT NULL,
  124. properties TEXT,
  125. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  126. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  127. FOREIGN KEY (user_id) REFERENCES users (id)
  128. )
  129. """)
  130. # 创建概念表
  131. cursor.execute("""
  132. CREATE TABLE IF NOT EXISTS concepts (
  133. id TEXT PRIMARY KEY,
  134. name TEXT NOT NULL,
  135. description TEXT,
  136. properties TEXT,
  137. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  138. )
  139. """)
  140. # 创建记忆-概念关联表
  141. cursor.execute("""
  142. CREATE TABLE IF NOT EXISTS memory_concepts (
  143. memory_id TEXT NOT NULL,
  144. concept_id TEXT NOT NULL,
  145. relevance_score REAL DEFAULT 1.0,
  146. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  147. PRIMARY KEY (memory_id, concept_id),
  148. FOREIGN KEY (memory_id) REFERENCES memories (id) ON DELETE CASCADE,
  149. FOREIGN KEY (concept_id) REFERENCES concepts (id) ON DELETE CASCADE
  150. )
  151. """)
  152. # 创建概念关系表
  153. cursor.execute("""
  154. CREATE TABLE IF NOT EXISTS concept_relationships (
  155. from_concept_id TEXT NOT NULL,
  156. to_concept_id TEXT NOT NULL,
  157. relationship_type TEXT NOT NULL,
  158. strength REAL DEFAULT 1.0,
  159. properties TEXT,
  160. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  161. PRIMARY KEY (from_concept_id, to_concept_id, relationship_type),
  162. FOREIGN KEY (from_concept_id) REFERENCES concepts (id) ON DELETE CASCADE,
  163. FOREIGN KEY (to_concept_id) REFERENCES concepts (id) ON DELETE CASCADE
  164. )
  165. """)
  166. # 创建索引
  167. indexes = [
  168. "CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories (user_id)",
  169. "CREATE INDEX IF NOT EXISTS idx_memories_type ON memories (memory_type)",
  170. "CREATE INDEX IF NOT EXISTS idx_memories_timestamp ON memories (timestamp)",
  171. "CREATE INDEX IF NOT EXISTS idx_memories_importance ON memories (importance)",
  172. "CREATE INDEX IF NOT EXISTS idx_memory_concepts_memory ON memory_concepts (memory_id)",
  173. "CREATE INDEX IF NOT EXISTS idx_memory_concepts_concept ON memory_concepts (concept_id)"
  174. ]
  175. for index_sql in indexes:
  176. cursor.execute(index_sql)
  177. conn.commit()
  178. print("[OK] SQLite 数据库表和索引创建完成")
  179. def add_memory(
  180. self,
  181. memory_id: str,
  182. user_id: str,
  183. content: str,
  184. memory_type: str,
  185. timestamp: int,
  186. importance: float,
  187. properties: Dict[str, Any] = None
  188. ) -> str:
  189. """添加记忆"""
  190. conn = self._get_connection()
  191. cursor = conn.cursor()
  192. # 确保用户存在
  193. cursor.execute("INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user_id, user_id))
  194. # 插入记忆
  195. cursor.execute("""
  196. INSERT OR REPLACE INTO memories
  197. (id, user_id, content, memory_type, timestamp, importance, properties, updated_at)
  198. VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
  199. """, (
  200. memory_id,
  201. user_id,
  202. content,
  203. memory_type,
  204. timestamp,
  205. importance,
  206. json.dumps(properties) if properties else None
  207. ))
  208. conn.commit()
  209. return memory_id
  210. def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
  211. """获取单个记忆"""
  212. conn = self._get_connection()
  213. cursor = conn.cursor()
  214. cursor.execute("""
  215. SELECT id, user_id, content, memory_type, timestamp, importance, properties, created_at
  216. FROM memories
  217. WHERE id = ?
  218. """, (memory_id,))
  219. row = cursor.fetchone()
  220. if not row:
  221. return None
  222. return {
  223. "memory_id": row["id"],
  224. "user_id": row["user_id"],
  225. "content": row["content"],
  226. "memory_type": row["memory_type"],
  227. "timestamp": row["timestamp"],
  228. "importance": row["importance"],
  229. "properties": json.loads(row["properties"]) if row["properties"] else {},
  230. "created_at": row["created_at"]
  231. }
  232. def search_memories(
  233. self,
  234. user_id: Optional[str] = None,
  235. memory_type: Optional[str] = None,
  236. text_query: Optional[str] = None,
  237. start_time: Optional[int] = None,
  238. end_time: Optional[int] = None,
  239. importance_threshold: Optional[float] = None,
  240. limit: int = 10
  241. ) -> List[Dict[str, Any]]:
  242. """搜索记忆"""
  243. conn = self._get_connection()
  244. cursor = conn.cursor()
  245. # 构建查询条件
  246. where_conditions = []
  247. params = []
  248. if user_id:
  249. where_conditions.append("user_id = ?")
  250. params.append(user_id)
  251. if memory_type:
  252. where_conditions.append("memory_type = ?")
  253. params.append(memory_type)
  254. if text_query:
  255. where_conditions.append("content LIKE ?")
  256. params.append(f"%{text_query}%")
  257. if start_time:
  258. where_conditions.append("timestamp >= ?")
  259. params.append(start_time)
  260. if end_time:
  261. where_conditions.append("timestamp <= ?")
  262. params.append(end_time)
  263. if importance_threshold:
  264. where_conditions.append("importance >= ?")
  265. params.append(importance_threshold)
  266. where_clause = ""
  267. if where_conditions:
  268. where_clause = "WHERE " + " AND ".join(where_conditions)
  269. cursor.execute(f"""
  270. SELECT id, user_id, content, memory_type, timestamp, importance, properties, created_at
  271. FROM memories
  272. {where_clause}
  273. ORDER BY importance DESC, timestamp DESC
  274. LIMIT ?
  275. """, params + [limit])
  276. memories = []
  277. for row in cursor.fetchall():
  278. memories.append({
  279. "memory_id": row["id"],
  280. "user_id": row["user_id"],
  281. "content": row["content"],
  282. "memory_type": row["memory_type"],
  283. "timestamp": row["timestamp"],
  284. "importance": row["importance"],
  285. "properties": json.loads(row["properties"]) if row["properties"] else {},
  286. "created_at": row["created_at"]
  287. })
  288. return memories
  289. def update_memory(
  290. self,
  291. memory_id: str,
  292. content: str = None,
  293. importance: float = None,
  294. properties: Dict[str, Any] = None
  295. ) -> bool:
  296. """更新记忆"""
  297. conn = self._get_connection()
  298. cursor = conn.cursor()
  299. # 构建更新字段
  300. update_fields = []
  301. params = []
  302. if content is not None:
  303. update_fields.append("content = ?")
  304. params.append(content)
  305. if importance is not None:
  306. update_fields.append("importance = ?")
  307. params.append(importance)
  308. if properties is not None:
  309. update_fields.append("properties = ?")
  310. params.append(json.dumps(properties))
  311. if not update_fields:
  312. return False
  313. update_fields.append("updated_at = CURRENT_TIMESTAMP")
  314. params.append(memory_id)
  315. cursor.execute(f"""
  316. UPDATE memories
  317. SET {', '.join(update_fields)}
  318. WHERE id = ?
  319. """, params)
  320. conn.commit()
  321. return cursor.rowcount > 0
  322. def delete_memory(self, memory_id: str) -> bool:
  323. """删除记忆"""
  324. conn = self._get_connection()
  325. cursor = conn.cursor()
  326. cursor.execute("DELETE FROM memories WHERE id = ?", (memory_id,))
  327. deleted_count = cursor.rowcount
  328. conn.commit()
  329. return deleted_count > 0
  330. def get_database_stats(self) -> Dict[str, Any]:
  331. """获取数据库统计信息"""
  332. conn = self._get_connection()
  333. cursor = conn.cursor()
  334. stats = {}
  335. # 统计各表的记录数
  336. tables = ["users", "memories", "concepts", "memory_concepts", "concept_relationships"]
  337. for table in tables:
  338. cursor.execute(f"SELECT COUNT(*) as count FROM {table}")
  339. stats[f"{table}_count"] = cursor.fetchone()["count"]
  340. # 统计记忆类型分布
  341. cursor.execute("""
  342. SELECT memory_type, COUNT(*) as count
  343. FROM memories
  344. GROUP BY memory_type
  345. """)
  346. memory_types = {}
  347. for row in cursor.fetchall():
  348. memory_types[row["memory_type"]] = row["count"]
  349. stats["memory_types"] = memory_types
  350. # 统计用户分布
  351. cursor.execute("""
  352. SELECT user_id, COUNT(*) as count
  353. FROM memories
  354. GROUP BY user_id
  355. ORDER BY count DESC
  356. LIMIT 10
  357. """)
  358. top_users = {}
  359. for row in cursor.fetchall():
  360. top_users[row["user_id"]] = row["count"]
  361. stats["top_users"] = top_users
  362. stats["store_type"] = "sqlite"
  363. stats["db_path"] = self.db_path
  364. return stats
  365. def add_document(self, content: str, metadata: Dict[str, Any] = None) -> str:
  366. """添加文档"""
  367. import uuid
  368. import time
  369. doc_id = str(uuid.uuid4())
  370. user_id = metadata.get("user_id", "system") if metadata else "system"
  371. return self.add_memory(
  372. memory_id=doc_id,
  373. user_id=user_id,
  374. content=content,
  375. memory_type="document",
  376. timestamp=int(time.time()),
  377. importance=0.5,
  378. properties=metadata or {}
  379. )
  380. def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
  381. """获取文档"""
  382. return self.get_memory(document_id)
  383. def close(self):
  384. """关闭数据库连接"""
  385. if hasattr(self.local, 'connection'):
  386. self.local.connection.close()
  387. delattr(self.local, 'connection')
  388. print("[OK] SQLite 连接已关闭")