milvus_store.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. """
  2. Milvus 存储层(可选启用)。
  3. """
  4. from __future__ import annotations
  5. import logging
  6. from typing import Any, Dict, List, Optional
  7. from core.config import get_config
  8. logger = logging.getLogger(__name__)
  9. def _import_milvus():
  10. try:
  11. from pymilvus import ( # type: ignore
  12. Collection,
  13. CollectionSchema,
  14. DataType,
  15. FieldSchema,
  16. connections,
  17. utility,
  18. )
  19. return Collection, CollectionSchema, DataType, FieldSchema, connections, utility
  20. except Exception:
  21. return None
  22. def _connect() -> bool:
  23. cfg = get_config().rag
  24. pkg = _import_milvus()
  25. if pkg is None:
  26. logger.warning("pymilvus 不可用,RAG 将回退 SQL 检索")
  27. return False
  28. _, _, _, _, connections, _ = pkg
  29. try:
  30. connections.connect(
  31. alias="default",
  32. uri=cfg.milvus_uri,
  33. token=cfg.milvus_token or None,
  34. )
  35. return True
  36. except Exception as e:
  37. logger.warning("Milvus 连接失败,RAG 回退 SQL 检索: %s", e)
  38. return False
  39. def init_collection(dim: int) -> bool:
  40. cfg = get_config().rag
  41. pkg = _import_milvus()
  42. if pkg is None or not _connect():
  43. return False
  44. Collection, CollectionSchema, DataType, FieldSchema, _, utility = pkg
  45. name = cfg.milvus_collection
  46. try:
  47. if utility.has_collection(name):
  48. return True
  49. fields = [
  50. FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, is_primary=True, max_length=128),
  51. FieldSchema(name="user_id", dtype=DataType.VARCHAR, max_length=256),
  52. FieldSchema(name="source_type", dtype=DataType.VARCHAR, max_length=64),
  53. FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=128),
  54. FieldSchema(name="created_at", dtype=DataType.VARCHAR, max_length=64),
  55. FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=8192),
  56. FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dim),
  57. ]
  58. schema = CollectionSchema(fields=fields, description="Health memory chunks")
  59. col = Collection(name=name, schema=schema)
  60. index_params = {"metric_type": "IP", "index_type": "AUTOINDEX", "params": {}}
  61. col.create_index(field_name="vector", index_params=index_params)
  62. col.load()
  63. return True
  64. except Exception as e:
  65. logger.warning("Milvus 集合初始化失败: %s", e)
  66. return False
  67. def upsert_chunks(chunks: List[Dict[str, Any]]) -> int:
  68. cfg = get_config().rag
  69. pkg = _import_milvus()
  70. if pkg is None or not chunks:
  71. return 0
  72. Collection, _, _, _, _, _ = pkg
  73. if not _connect():
  74. return 0
  75. dim = len(chunks[0].get("vector") or [])
  76. if dim <= 0 or not init_collection(dim):
  77. return 0
  78. try:
  79. col = Collection(cfg.milvus_collection)
  80. col.load()
  81. data = [
  82. [c["chunk_id"] for c in chunks],
  83. [c["user_id"] for c in chunks],
  84. [c["source_type"] for c in chunks],
  85. [c["source_id"] for c in chunks],
  86. [c.get("created_at", "") for c in chunks],
  87. [c["text"] for c in chunks],
  88. [c["vector"] for c in chunks],
  89. ]
  90. col.upsert(data)
  91. col.flush()
  92. return len(chunks)
  93. except Exception as e:
  94. logger.warning("Milvus upsert 失败: %s", e)
  95. return 0
  96. def search(
  97. user_id: str,
  98. query_vector: List[float],
  99. top_k: int = 5,
  100. source_types: Optional[List[str]] = None,
  101. ) -> List[Dict[str, Any]]:
  102. cfg = get_config().rag
  103. pkg = _import_milvus()
  104. if pkg is None or not query_vector:
  105. return []
  106. Collection, _, _, _, _, _ = pkg
  107. if not _connect():
  108. return []
  109. try:
  110. col = Collection(cfg.milvus_collection)
  111. col.load()
  112. expr = f'user_id == "{user_id}"'
  113. if source_types:
  114. src_expr = " or ".join([f'source_type == "{s}"' for s in source_types])
  115. expr = f"{expr} and ({src_expr})"
  116. res = col.search(
  117. data=[query_vector],
  118. anns_field="vector",
  119. param={"metric_type": "IP", "params": {}},
  120. limit=max(1, min(top_k, 20)),
  121. expr=expr,
  122. output_fields=["chunk_id", "source_type", "source_id", "text", "created_at"],
  123. )
  124. rows: List[Dict[str, Any]] = []
  125. for hits in res:
  126. for h in hits:
  127. entity = h.entity
  128. rows.append(
  129. {
  130. "chunk_id": entity.get("chunk_id"),
  131. "source_type": entity.get("source_type"),
  132. "source_id": entity.get("source_id"),
  133. "text": entity.get("text"),
  134. "created_at": entity.get("created_at"),
  135. "score": float(h.distance),
  136. }
  137. )
  138. return rows
  139. except Exception as e:
  140. logger.warning("Milvus 检索失败: %s", e)
  141. return []