retriever.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. """
  2. 统一检索接口:retrieve(user_id, query_context)。
  3. 优先 Milvus 语义检索,不可用时回退 SQL 文本记忆。
  4. """
  5. from __future__ import annotations
  6. import time
  7. from collections import Counter
  8. from typing import Any, Dict, List
  9. from core.config import get_config
  10. from memory.store import list_user_memory_chunks_sql
  11. from rag.embedding import embed_texts
  12. from rag.milvus_store import search
  13. def _build_query_text(query_context: Dict[str, Any]) -> str:
  14. if not query_context:
  15. return "健康记忆检索"
  16. keys = [
  17. "goal",
  18. "query",
  19. "scenario",
  20. "risk_focus",
  21. "free_notes",
  22. "today_food_log_text",
  23. ]
  24. pieces: List[str] = []
  25. for k in keys:
  26. if k in query_context and query_context[k]:
  27. pieces.append(f"{k}:{query_context[k]}")
  28. if not pieces:
  29. pieces.append(str(query_context))
  30. return " | ".join(pieces)[:4000]
  def _search_milvus(user_id: str, query_text: str, k: int) -> List[Dict[str, Any]]:
      """Embed ``query_text`` and return the Milvus hits for ``user_id`` (top ``k``)."""
      query_vector = embed_texts([query_text])[0]
      return search(user_id=user_id, query_vector=query_vector, top_k=k) or []


  def _search_sql_fallback(user_id: str, k: int) -> List[Dict[str, Any]]:
      """Load up to ``k`` SQL text-memory rows for ``user_id`` as chunk dicts."""
      # NOTE(review): over-fetches max(8, 3k) rows but keeps only the first k;
      # confirm whether the extra rows were meant for filtering/re-ranking.
      rows = list_user_memory_chunks_sql(user_id=user_id, limit=max(8, k * 3))
      return [
          {
              "chunk_id": r["chunk_id"],
              "source_type": r["source_type"],
              "source_id": r["source_id"],
              "text": r["text"],
              "created_at": r.get("created_at"),
              # SQL rows carry no similarity signal, so every score is 0.0.
              "score": 0.0,
          }
          for r in rows[:k]
      ]


  def retrieve(user_id: str, query_context: Dict[str, Any], top_k: int | None = None) -> Dict[str, Any]:
      """Retrieve memory chunks for a user, preferring Milvus and falling back to SQL.

      When RAG is enabled in config, the query built from ``query_context`` is
      embedded and searched in Milvus. If that path raises or returns no hits,
      the user's SQL text memory is used instead.

      Args:
          user_id: User whose memory is searched (passed to both backends).
          query_context: Context dict flattened into the query text by
              ``_build_query_text``.
          top_k: Maximum number of chunks to return. ``None`` or ``0`` falls
              back to the configured ``rag.top_k``.

      Returns:
          A dict with:
            * ``"chunks"``: at most ``k`` chunk dicts (Milvus hits or
              normalised SQL rows).
            * ``"summary"``: one ``"- [source_type] text"`` line per chunk, or
              ``"(暂无检索结果)"`` ("no results") when empty; capped at 12000 chars.
            * ``"debug"``: ``rag_enabled``, ``mode`` (``"milvus"`` or
              ``"sql_fallback"``), ``retrieved_count``, ``top_k``,
              ``retrieval_ms``, ``source_breakdown`` (count per source_type),
              ``query_text_preview`` (first 240 chars) and ``milvus_error``
              (``None`` unless the Milvus path raised).

      Raises:
          Whatever ``get_config`` or the SQL fallback raises; errors from the
          Milvus path are caught and trigger the fallback instead.
      """
      cfg = get_config().rag
      # ``or`` (not ``is None``): a top_k of 0 also falls back to the default.
      k = top_k or cfg.top_k
      t0 = time.perf_counter()
      query_text = _build_query_text(query_context)

      chunks: List[Dict[str, Any]] = []
      mode = "sql_fallback"
      milvus_error = None
      if cfg.enabled:
          try:
              chunks = _search_milvus(user_id, query_text, k)
          except Exception as exc:  # noqa: BLE001 - any vector-store failure must degrade to SQL
              import logging

              logging.getLogger(__name__).warning(
                  "Milvus retrieval failed for user_id=%s; falling back to SQL",
                  user_id,
                  exc_info=True,
              )
              milvus_error = f"{type(exc).__name__}: {exc}"
          if chunks:
              mode = "milvus"
      if not chunks:
          chunks = _search_sql_fallback(user_id, k)

      # Truncate once so chunks, summary, count and breakdown describe the same set.
      # NOTE(review): Milvus hits are assumed to be dicts with at least "text";
      # confirm against rag.milvus_store.search.
      chunks = chunks[:k]
      summary = "\n".join(
          f"- [{c.get('source_type', 'unknown')}] {c['text']}" for c in chunks
      ) or "(暂无检索结果)"
      source_breakdown = dict(Counter(c.get("source_type", "unknown") for c in chunks))
      ms = int((time.perf_counter() - t0) * 1000)
      return {
          "chunks": chunks,
          "summary": summary[:12000],
          "debug": {
              "rag_enabled": cfg.enabled,
              "mode": mode,
              "retrieved_count": len(chunks),
              "top_k": k,
              "retrieval_ms": ms,
              "source_breakdown": source_breakdown,
              "query_text_preview": query_text[:240],
              "milvus_error": milvus_error,
          },
      }