qdrant_store.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. """
  2. Qdrant向量数据库存储实现
  3. 使用专业的Qdrant向量数据库替代ChromaDB
  4. """
  5. import logging
  6. import os
  7. import uuid
  8. import threading
  9. from typing import Dict, List, Optional, Any, Union
  10. import numpy as np
  11. from datetime import datetime
  12. try:
  13. from qdrant_client import QdrantClient
  14. from qdrant_client.http import models
  15. from qdrant_client.http.models import (
  16. Distance, VectorParams, PointStruct,
  17. Filter, FieldCondition, MatchValue, SearchRequest
  18. )
  19. QDRANT_AVAILABLE = True
  20. except ImportError:
  21. QDRANT_AVAILABLE = False
  22. QdrantClient = None
  23. models = None
  24. logger = logging.getLogger(__name__)
  25. class QdrantConnectionManager:
  26. """Qdrant连接管理器 - 防止重复连接和初始化"""
  27. _instances = {} # key: (url, collection_name) -> QdrantVectorStore instance
  28. _lock = threading.Lock()
  29. @classmethod
  30. def get_instance(
  31. cls,
  32. url: Optional[str] = None,
  33. api_key: Optional[str] = None,
  34. collection_name: str = "hello_agents_vectors",
  35. vector_size: int = 384,
  36. distance: str = "cosine",
  37. timeout: int = 30,
  38. **kwargs
  39. ) -> 'QdrantVectorStore':
  40. """获取或创建Qdrant实例(单例模式)"""
  41. # 创建唯一键
  42. key = (url or "local", collection_name)
  43. if key not in cls._instances:
  44. with cls._lock:
  45. # 双重检查锁定
  46. if key not in cls._instances:
  47. logger.debug(f"🔄 创建新的Qdrant连接: {collection_name}")
  48. cls._instances[key] = QdrantVectorStore(
  49. url=url,
  50. api_key=api_key,
  51. collection_name=collection_name,
  52. vector_size=vector_size,
  53. distance=distance,
  54. timeout=timeout,
  55. **kwargs
  56. )
  57. else:
  58. logger.debug(f"♻️ 复用现有Qdrant连接: {collection_name}")
  59. else:
  60. logger.debug(f"♻️ 复用现有Qdrant连接: {collection_name}")
  61. return cls._instances[key]
  62. class QdrantVectorStore:
  63. """Qdrant向量数据库存储实现"""
  64. def __init__(
  65. self,
  66. url: Optional[str] = None,
  67. api_key: Optional[str] = None,
  68. collection_name: str = "hello_agents_vectors",
  69. vector_size: int = 384,
  70. distance: str = "cosine",
  71. timeout: int = 30,
  72. **kwargs
  73. ):
  74. """
  75. 初始化Qdrant向量存储 (支持云API)
  76. Args:
  77. url: Qdrant云服务URL (如果为None则使用本地)
  78. api_key: Qdrant云服务API密钥
  79. collection_name: 集合名称
  80. vector_size: 向量维度
  81. distance: 距离度量方式 (cosine, dot, euclidean)
  82. timeout: 连接超时时间
  83. """
  84. if not QDRANT_AVAILABLE:
  85. raise ImportError(
  86. "qdrant-client未安装。请运行: pip install qdrant-client>=1.6.0"
  87. )
  88. self.url = url
  89. self.api_key = api_key
  90. self.collection_name = collection_name
  91. self.vector_size = vector_size
  92. self.timeout = timeout
  93. # HNSW/Query params via env
  94. try:
  95. self.hnsw_m = int(os.getenv("QDRANT_HNSW_M", "32"))
  96. except Exception:
  97. self.hnsw_m = 32
  98. try:
  99. self.hnsw_ef_construct = int(os.getenv("QDRANT_HNSW_EF_CONSTRUCT", "256"))
  100. except Exception:
  101. self.hnsw_ef_construct = 256
  102. try:
  103. self.search_ef = int(os.getenv("QDRANT_SEARCH_EF", "128"))
  104. except Exception:
  105. self.search_ef = 128
  106. self.search_exact = os.getenv("QDRANT_SEARCH_EXACT", "0") == "1"
  107. # 距离度量映射
  108. distance_map = {
  109. "cosine": Distance.COSINE,
  110. "dot": Distance.DOT,
  111. "euclidean": Distance.EUCLID,
  112. }
  113. self.distance = distance_map.get(distance.lower(), Distance.COSINE)
  114. # 初始化客户端
  115. self.client = None
  116. self._initialize_client()
  117. def _initialize_client(self):
  118. """初始化Qdrant客户端和集合"""
  119. try:
  120. # 根据配置创建客户端连接
  121. if self.url and self.api_key:
  122. # 使用云服务API
  123. self.client = QdrantClient(
  124. url=self.url,
  125. api_key=self.api_key,
  126. timeout=self.timeout
  127. )
  128. logger.info(f"✅ 成功连接到Qdrant云服务: {self.url}")
  129. elif self.url:
  130. # 使用自定义URL(无API密钥)
  131. self.client = QdrantClient(
  132. url=self.url,
  133. timeout=self.timeout
  134. )
  135. logger.info(f"✅ 成功连接到Qdrant服务: {self.url}")
  136. else:
  137. # 使用本地服务(默认)
  138. self.client = QdrantClient(
  139. host="localhost",
  140. port=6333,
  141. timeout=self.timeout
  142. )
  143. logger.info("✅ 成功连接到本地Qdrant服务: localhost:6333")
  144. # 检查连接
  145. collections = self.client.get_collections()
  146. # 创建或获取集合
  147. self._ensure_collection()
  148. except Exception as e:
  149. logger.error(f"❌ Qdrant连接失败: {e}")
  150. if not self.url:
  151. logger.info("💡 本地连接失败,可以考虑使用Qdrant云服务")
  152. logger.info("💡 或启动本地服务: docker run -p 6333:6333 qdrant/qdrant")
  153. else:
  154. logger.info("💡 请检查URL和API密钥是否正确")
  155. raise
  156. def _ensure_collection(self):
  157. """确保集合存在,不存在则创建"""
  158. try:
  159. # 检查集合是否存在
  160. collections = self.client.get_collections().collections
  161. collection_names = [c.name for c in collections]
  162. if self.collection_name not in collection_names:
  163. # 创建新集合
  164. hnsw_cfg = None
  165. try:
  166. hnsw_cfg = models.HnswConfigDiff(m=self.hnsw_m, ef_construct=self.hnsw_ef_construct)
  167. except Exception:
  168. hnsw_cfg = None
  169. self.client.create_collection(
  170. collection_name=self.collection_name,
  171. vectors_config=VectorParams(
  172. size=self.vector_size,
  173. distance=self.distance
  174. ),
  175. hnsw_config=hnsw_cfg
  176. )
  177. logger.info(f"✅ 创建Qdrant集合: {self.collection_name}")
  178. else:
  179. logger.info(f"✅ 使用现有Qdrant集合: {self.collection_name}")
  180. # 尝试更新 HNSW 配置
  181. try:
  182. self.client.update_collection(
  183. collection_name=self.collection_name,
  184. hnsw_config=models.HnswConfigDiff(m=self.hnsw_m, ef_construct=self.hnsw_ef_construct)
  185. )
  186. except Exception as ie:
  187. logger.debug(f"跳过更新HNSW配置: {ie}")
  188. # 确保必要的payload索引
  189. self._ensure_payload_indexes()
  190. except Exception as e:
  191. logger.error(f"❌ 集合初始化失败: {e}")
  192. raise
  193. def _ensure_payload_indexes(self):
  194. """为常用过滤字段创建payload索引"""
  195. try:
  196. index_fields = [
  197. ("memory_type", models.PayloadSchemaType.KEYWORD),
  198. ("user_id", models.PayloadSchemaType.KEYWORD),
  199. ("memory_id", models.PayloadSchemaType.KEYWORD),
  200. ("timestamp", models.PayloadSchemaType.INTEGER),
  201. ("modality", models.PayloadSchemaType.KEYWORD), # 感知记忆模态筛选
  202. ("source", models.PayloadSchemaType.KEYWORD),
  203. ("external", models.PayloadSchemaType.BOOL),
  204. ("namespace", models.PayloadSchemaType.KEYWORD),
  205. # RAG相关字段索引
  206. ("is_rag_data", models.PayloadSchemaType.BOOL),
  207. ("rag_namespace", models.PayloadSchemaType.KEYWORD),
  208. ("data_source", models.PayloadSchemaType.KEYWORD),
  209. ]
  210. for field_name, schema_type in index_fields:
  211. try:
  212. self.client.create_payload_index(
  213. collection_name=self.collection_name,
  214. field_name=field_name,
  215. field_schema=schema_type,
  216. )
  217. except Exception as ie:
  218. # 索引已存在会报错,忽略
  219. logger.debug(f"索引 {field_name} 已存在或创建失败: {ie}")
  220. except Exception as e:
  221. logger.debug(f"创建payload索引时出错: {e}")
  222. def add_vectors(
  223. self,
  224. vectors: List[List[float]],
  225. metadata: List[Dict[str, Any]],
  226. ids: Optional[List[str]] = None
  227. ) -> bool:
  228. """
  229. 添加向量到Qdrant
  230. Args:
  231. vectors: 向量列表
  232. metadata: 元数据列表
  233. ids: 可选的ID列表
  234. Returns:
  235. bool: 是否成功
  236. """
  237. try:
  238. if not vectors:
  239. logger.warning("⚠️ 向量列表为空")
  240. return False
  241. # 生成ID(如果未提供)
  242. if ids is None:
  243. ids = [f"vec_{i}_{int(datetime.now().timestamp() * 1000000)}"
  244. for i in range(len(vectors))]
  245. # 构建点数据
  246. logger.info(f"[Qdrant] add_vectors start: n_vectors={len(vectors)} n_meta={len(metadata)} collection={self.collection_name}")
  247. points = []
  248. for i, (vector, meta, point_id) in enumerate(zip(vectors, metadata, ids)):
  249. # 确保向量是正确的维度
  250. try:
  251. vlen = len(vector)
  252. except Exception:
  253. logger.error(f"[Qdrant] 非法向量类型: index={i} type={type(vector)} value={vector}")
  254. continue
  255. if vlen != self.vector_size:
  256. logger.warning(f"⚠️ 向量维度不匹配: 期望{self.vector_size}, 实际{len(vector)}")
  257. continue
  258. # 添加时间戳到元数据
  259. meta_with_timestamp = meta.copy()
  260. meta_with_timestamp["timestamp"] = int(datetime.now().timestamp())
  261. meta_with_timestamp["added_at"] = int(datetime.now().timestamp())
  262. if "external" in meta_with_timestamp and not isinstance(meta_with_timestamp.get("external"), bool):
  263. # normalize to bool
  264. val = meta_with_timestamp.get("external")
  265. meta_with_timestamp["external"] = True if str(val).lower() in ("1", "true", "yes") else False
  266. # 确保点ID是Qdrant接受的类型(无符号整数或UUID字符串)
  267. safe_id: Any
  268. if isinstance(point_id, int):
  269. safe_id = point_id
  270. elif isinstance(point_id, str):
  271. try:
  272. uuid.UUID(point_id)
  273. safe_id = point_id
  274. except Exception:
  275. safe_id = str(uuid.uuid4())
  276. else:
  277. safe_id = str(uuid.uuid4())
  278. point = PointStruct(
  279. id=safe_id,
  280. vector=vector,
  281. payload=meta_with_timestamp
  282. )
  283. points.append(point)
  284. if not points:
  285. logger.warning("⚠️ 没有有效的向量点")
  286. return False
  287. # 批量插入
  288. logger.info(f"[Qdrant] upsert begin: points={len(points)}")
  289. operation_info = self.client.upsert(
  290. collection_name=self.collection_name,
  291. points=points,
  292. wait=True
  293. )
  294. logger.info("[Qdrant] upsert done")
  295. logger.info(f"✅ 成功添加 {len(points)} 个向量到Qdrant")
  296. return True
  297. except Exception as e:
  298. logger.error(f"❌ 添加向量失败: {e}")
  299. return False
  300. def search_similar(
  301. self,
  302. query_vector: List[float],
  303. limit: int = 10,
  304. score_threshold: Optional[float] = None,
  305. where: Optional[Dict[str, Any]] = None
  306. ) -> List[Dict[str, Any]]:
  307. """
  308. 搜索相似向量
  309. Args:
  310. query_vector: 查询向量
  311. limit: 返回结果数量限制
  312. score_threshold: 相似度阈值
  313. where: 过滤条件
  314. Returns:
  315. List[Dict]: 搜索结果
  316. """
  317. try:
  318. if len(query_vector) != self.vector_size:
  319. logger.error(f"❌ 查询向量维度错误: 期望{self.vector_size}, 实际{len(query_vector)}")
  320. return []
  321. # 构建过滤器
  322. query_filter = None
  323. if where:
  324. conditions = []
  325. for key, value in where.items():
  326. if isinstance(value, (str, int, float, bool)):
  327. conditions.append(
  328. FieldCondition(
  329. key=key,
  330. match=MatchValue(value=value)
  331. )
  332. )
  333. if conditions:
  334. query_filter = Filter(must=conditions)
  335. # 执行搜索
  336. # 搜索参数
  337. search_params = None
  338. try:
  339. search_params = models.SearchParams(hnsw_ef=self.search_ef, exact=self.search_exact)
  340. except Exception:
  341. search_params = None
  342. response = self.client.query_points(
  343. collection_name=self.collection_name,
  344. query=query_vector,
  345. query_filter=query_filter,
  346. limit=limit,
  347. score_threshold=score_threshold,
  348. with_payload=True,
  349. with_vectors=False,
  350. search_params=search_params
  351. )
  352. search_result = response.points
  353. # 转换结果格式
  354. results = []
  355. for hit in search_result:
  356. result = {
  357. "id": hit.id,
  358. "score": hit.score,
  359. "metadata": hit.payload or {}
  360. }
  361. results.append(result)
  362. logger.debug(f"🔍 Qdrant搜索返回 {len(results)} 个结果")
  363. return results
  364. except Exception as e:
  365. logger.error(f"❌ 向量搜索失败: {e}")
  366. return []
  367. def delete_vectors(self, ids: List[str]) -> bool:
  368. """
  369. 删除向量
  370. Args:
  371. ids: 要删除的向量ID列表
  372. Returns:
  373. bool: 是否成功
  374. """
  375. try:
  376. if not ids:
  377. return True
  378. operation_info = self.client.delete(
  379. collection_name=self.collection_name,
  380. points_selector=models.PointIdsList(
  381. points=ids
  382. ),
  383. wait=True
  384. )
  385. logger.info(f"✅ 成功删除 {len(ids)} 个向量")
  386. return True
  387. except Exception as e:
  388. logger.error(f"❌ 删除向量失败: {e}")
  389. return False
  390. def clear_collection(self) -> bool:
  391. """
  392. 清空集合
  393. Returns:
  394. bool: 是否成功
  395. """
  396. try:
  397. # 删除并重新创建集合
  398. self.client.delete_collection(collection_name=self.collection_name)
  399. self._ensure_collection()
  400. logger.info(f"✅ 成功清空Qdrant集合: {self.collection_name}")
  401. return True
  402. except Exception as e:
  403. logger.error(f"❌ 清空集合失败: {e}")
  404. return False
  405. def delete_memories(self, memory_ids: List[str]):
  406. """
  407. 删除指定记忆(通过payload中的 memory_id 过滤删除)
  408. 注意:由于写入时可能将非UUID的点ID转换为UUID,这里不再依赖点ID,
  409. 而是通过payload中的memory_id来匹配删除,确保一致性。
  410. """
  411. try:
  412. if not memory_ids:
  413. return
  414. # 构建 should 过滤条件:memory_id 等于任一给定值
  415. conditions = [
  416. FieldCondition(key="memory_id", match=MatchValue(value=mid))
  417. for mid in memory_ids
  418. ]
  419. query_filter = Filter(should=conditions)
  420. self.client.delete(
  421. collection_name=self.collection_name,
  422. points_selector=models.FilterSelector(filter=query_filter),
  423. wait=True,
  424. )
  425. logger.info(f"✅ 成功按memory_id删除 {len(memory_ids)} 个Qdrant向量")
  426. except Exception as e:
  427. logger.error(f"❌ 删除记忆失败: {e}")
  428. raise
  429. def get_collection_info(self) -> Dict[str, Any]:
  430. """
  431. 获取集合信息
  432. Returns:
  433. Dict: 集合信息
  434. """
  435. try:
  436. collection_info = self.client.get_collection(self.collection_name)
  437. info = {
  438. "name": self.collection_name,
  439. "vectors_count": collection_info.vectors_count,
  440. "indexed_vectors_count": collection_info.indexed_vectors_count,
  441. "points_count": collection_info.points_count,
  442. "segments_count": collection_info.segments_count,
  443. "config": {
  444. "vector_size": self.vector_size,
  445. "distance": self.distance.value,
  446. }
  447. }
  448. return info
  449. except Exception as e:
  450. logger.error(f"❌ 获取集合信息失败: {e}")
  451. return {}
  452. def get_collection_stats(self) -> Dict[str, Any]:
  453. """
  454. 获取集合统计信息(兼容抽象接口)
  455. """
  456. info = self.get_collection_info()
  457. if not info:
  458. return {"store_type": "qdrant", "name": self.collection_name}
  459. info["store_type"] = "qdrant"
  460. return info
  461. def health_check(self) -> bool:
  462. """
  463. 健康检查
  464. Returns:
  465. bool: 服务是否健康
  466. """
  467. try:
  468. # 尝试获取集合列表
  469. collections = self.client.get_collections()
  470. return True
  471. except Exception as e:
  472. logger.error(f"❌ Qdrant健康检查失败: {e}")
  473. return False
  474. def __del__(self):
  475. """析构函数,清理资源"""
  476. if hasattr(self, 'client') and self.client:
  477. try:
  478. self.client.close()
  479. except:
  480. pass