embedding.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. """统一嵌入模块(实现 + 提供器)
  2. 说明(中文):
  3. - 提供统一的文本嵌入接口与多实现:本地Transformer、DashScope(通义千问)、TF-IDF兜底。
  4. - 暴露 get_text_embedder()/get_dimension()/refresh_embedder() 供各记忆类型统一使用。
  5. - 通过环境变量优先级:dashscope > local > tfidf。
  6. 环境变量:
  7. - EMBED_MODEL_TYPE: "dashscope" | "local" | "tfidf"(默认 dashscope)
  8. - EMBED_MODEL_NAME: 模型名称(dashscope默认 text-embedding-v3;local默认 sentence-transformers/all-MiniLM-L6-v2)
  9. - EMBED_API_KEY: Embedding API Key(统一命名)
  10. - EMBED_BASE_URL: Embedding Base URL(统一命名,可选)
  11. """
  12. from typing import List, Union, Optional
  13. import threading
  14. import os
  15. import numpy as np
  16. # ==============
  17. # 抽象与实现
  18. # ==============
  class EmbeddingModel:
      """Minimal interface shared by every embedding backend.

      Concrete backends override ``encode`` and the ``dimension`` property;
      calling either on this base class raises NotImplementedError.
      """

      def encode(self, texts: Union[str, List[str]]):
          """Embed a single text or a list of texts (subclasses must override)."""
          raise NotImplementedError

      @property
      def dimension(self) -> int:
          """Length of the vectors produced by ``encode`` (subclasses must override)."""
          raise NotImplementedError
  class LocalTransformerEmbedding(EmbeddingModel):
      """Local Transformer embeddings.

      Prefers ``sentence-transformers``; if that backend cannot be loaded for
      any reason, falls back to ``transformers`` + ``torch`` using mean pooling
      over the last hidden state, restricted to real (non-padding) tokens.
      """

      def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
          self.model_name = model_name
          self._backend = None  # "st" (sentence-transformers) or "hf" (transformers)
          self._st_model = None
          self._hf_tokenizer = None
          self._hf_model = None
          self._dimension = None
          self._load_backend()

      @staticmethod
      def _masked_mean(last_hidden_state, attention_mask):
          """Average token vectors along dim 1, ignoring padding positions.

          Without the mask, padded positions in a batch are averaged in, so a
          text's vector would depend on which other texts share its batch.
          Falls back to a plain mean when no attention mask is available.
          """
          if attention_mask is None:
              return last_hidden_state.mean(dim=1)
          mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
          summed = (last_hidden_state * mask).sum(dim=1)
          counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
          return summed / counts

      def _load_backend(self):
          """Load the first usable backend and probe the embedding dimension.

          Raises:
              ImportError: if neither backend can be loaded; the transformers
                  failure is chained as ``__cause__``.
          """
          # Prefer sentence-transformers.
          try:
              from sentence_transformers import SentenceTransformer
              self._st_model = SentenceTransformer(self.model_name)
              test_vec = self._st_model.encode("test_text")
              self._dimension = len(test_vec)
              self._backend = "st"
              return
          except Exception:
              # Any failure here (missing package, download error, ...) -> try HF.
              self._st_model = None
          # Fall back to transformers + torch.
          try:
              from transformers import AutoTokenizer, AutoModel
              import torch
              self._hf_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
              self._hf_model = AutoModel.from_pretrained(self.model_name)
              with torch.no_grad():
                  inputs = self._hf_tokenizer("test_text", return_tensors="pt", padding=True, truncation=True)
                  outputs = self._hf_model(**inputs)
                  test_embedding = self._masked_mean(outputs.last_hidden_state, inputs.get("attention_mask"))
              self._dimension = int(test_embedding.shape[1])
              self._backend = "hf"
              return
          except Exception as err:
              self._hf_tokenizer = None
              self._hf_model = None
              raise ImportError("未找到可用的本地嵌入后端,请安装 sentence-transformers 或 transformers+torch") from err

      def encode(self, texts: Union[str, List[str]]):
          """Embed one text or a list of texts.

          Returns a single vector for a ``str`` input, otherwise a list of
          vectors (an empty list for an empty input).
          """
          single = isinstance(texts, str)
          inputs = [texts] if single else list(texts)
          if not inputs:
              # Tokenizers reject an empty batch; there is nothing to embed.
              return []
          if self._backend == "st":
              vecs = list(self._st_model.encode(inputs))
          else:
              import torch
              tokenized = self._hf_tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=512)
              with torch.no_grad():
                  outputs = self._hf_model(**tokenized)
              pooled = self._masked_mean(outputs.last_hidden_state, tokenized.get("attention_mask"))
              vecs = list(pooled.cpu().numpy())
          return vecs[0] if single else vecs

      @property
      def dimension(self) -> int:
          """Vector length probed at load time (0 if unknown)."""
          return int(self._dimension or 0)
  class TFIDFEmbedding(EmbeddingModel):
      """TF-IDF fallback so embeddings still work without a deep model.

      Unlike the other backends, this one must be fitted on a corpus with
      ``fit()`` before ``encode()`` can be called.
      """

      def __init__(self, max_features: int = 1000, stop_words: Optional[Union[str, List[str]]] = "english"):
          """
          Args:
              max_features: upper bound on the vocabulary size.
              stop_words: forwarded to ``TfidfVectorizer``. The default keeps
                  the previous built-in English list; pass ``None`` to disable
                  stop-word removal (e.g. for non-English text).
          """
          self.max_features = max_features
          self.stop_words = stop_words
          self._vectorizer = None
          self._is_fitted = False
          # Only an upper bound until fit() replaces it with the real vocabulary size.
          self._dimension = max_features
          self._init_vectorizer()

      def _init_vectorizer(self):
          """Create the (unfitted) vectorizer; raises ImportError without scikit-learn."""
          try:
              from sklearn.feature_extraction.text import TfidfVectorizer
          except ImportError as err:
              raise ImportError("请安装 scikit-learn: pip install scikit-learn") from err
          self._vectorizer = TfidfVectorizer(max_features=self.max_features, stop_words=self.stop_words)

      def fit(self, texts: Union[str, List[str]]):
          """Learn vocabulary and IDF weights from ``texts`` and update ``dimension``."""
          if isinstance(texts, str):
              # sklearn expects an iterable of documents, not one raw string.
              texts = [texts]
          self._vectorizer.fit(texts)
          self._is_fitted = True
          self._dimension = len(self._vectorizer.get_feature_names_out())

      def encode(self, texts: Union[str, List[str]]):
          """Embed one text (returns a 1-D array) or a list of texts (list of 1-D arrays).

          Raises:
              ValueError: if ``fit()`` has not been called yet.
          """
          if not self._is_fitted:
              raise ValueError("TF-IDF模型未训练,请先调用fit()方法")
          single = isinstance(texts, str)
          docs = [texts] if single else texts
          embeddings = self._vectorizer.transform(docs).toarray()
          return embeddings[0] if single else list(embeddings)

      @property
      def dimension(self) -> int:
          """``max_features`` before fitting; the fitted vocabulary size afterwards."""
          return self._dimension
  class DashScopeEmbedding(EmbeddingModel):
      """Alibaba Cloud DashScope (Tongyi Qianwen) embeddings via SDK or OpenAI-compatible REST.

      Behavior:
      - If ``base_url`` is given, use the OpenAI-compatible REST endpoint
        (``POST {base_url}/embeddings``).
      - Otherwise use the official ``dashscope`` SDK's ``TextEmbedding.call``.

      The constructor sends one real embedding request to probe the vector
      dimension, so bad credentials or configuration fail fast.
      """

      def __init__(self, model_name: str = "text-embedding-v3", api_key: Optional[str] = None, base_url: Optional[str] = None):
          self.model_name = model_name
          self.api_key = api_key
          self.base_url = base_url
          self._dimension = None
          # Only the SDK (non-REST) mode needs the dashscope package.
          if not self.base_url:
              self._init_client()
          # Probe the dimension with a real request.
          test = self.encode("health_check")
          self._dimension = len(test)

      def _init_client(self):
          """Check the SDK is importable and expose the API key where the SDK reads it.

          Note: this writes ``DASHSCOPE_API_KEY`` into ``os.environ`` for the
          whole process.
          """
          try:
              if self.api_key:
                  os.environ["DASHSCOPE_API_KEY"] = self.api_key
              import dashscope  # noqa: F401
          except ImportError as err:
              raise ImportError("请安装 dashscope: pip install dashscope") from err

      @staticmethod
      def _to_vector(raw) -> np.ndarray:
          """Convert one raw embedding list to an array, rejecting missing data."""
          if raw is None:
              raise RuntimeError("Embedding 返回缺少向量数据")
          return np.array(raw)

      def _encode_rest(self, inputs: List[str]) -> List[np.ndarray]:
          """Call the OpenAI-compatible ``/embeddings`` endpoint."""
          import requests
          url = self.base_url.rstrip("/") + "/embeddings"
          headers = {"Content-Type": "application/json"}
          if self.api_key:
              # Omit the header entirely instead of sending an empty Authorization value.
              headers["Authorization"] = f"Bearer {self.api_key}"
          payload = {"model": self.model_name, "input": inputs}
          resp = requests.post(url, headers=headers, json=payload, timeout=30)
          if resp.status_code >= 400:
              raise RuntimeError(f"Embedding REST 调用失败: {resp.status_code} {resp.text}")
          data = resp.json()
          # Expected shape: {"data": [{"index": i, "embedding": [...]}, ...]}
          items = data.get("data") or []
          # Each item carries the index of its input; sort on it so vectors line
          # up with inputs regardless of response order (stable if absent).
          items = sorted(items, key=lambda item: item.get("index", 0))
          return [self._to_vector(item.get("embedding")) for item in items]

      def _encode_sdk(self, inputs: List[str]) -> List[np.ndarray]:
          """Call ``dashscope.TextEmbedding`` and extract the vectors."""
          from dashscope import TextEmbedding
          rsp = TextEmbedding.call(model=self.model_name, input=inputs)
          if isinstance(rsp, dict):
              embeddings_obj = (rsp.get("output") or {}).get("embeddings")
          else:
              embeddings_obj = getattr(getattr(rsp, "output", None), "embeddings", None)
          if not embeddings_obj:
              raise RuntimeError("DashScope 返回为空或格式不匹配")
          # DashScope tags each vector with ``text_index``; keep input order
          # (stable sort, so order is unchanged if the field is absent).
          items = sorted(embeddings_obj, key=lambda item: item.get("text_index", 0))
          return [self._to_vector(item.get("embedding") or item.get("vector")) for item in items]

      def encode(self, texts: Union[str, List[str]]):
          """Embed one text (returns a 1-D array) or a list of texts (list of 1-D arrays).

          Raises:
              RuntimeError: on an HTTP error, an empty/malformed response, or a
                  response whose vector count differs from the input count.
          """
          single = isinstance(texts, str)
          inputs = [texts] if single else list(texts)
          if not inputs:
              # Nothing to embed; avoid a pointless (and likely rejected) request.
              return []
          vecs = self._encode_rest(inputs) if self.base_url else self._encode_sdk(inputs)
          if len(vecs) != len(inputs):
              raise RuntimeError(f"Embedding 返回数量不匹配: 期望 {len(inputs)}, 实际 {len(vecs)}")
          return vecs[0] if single else vecs

      @property
      def dimension(self) -> int:
          """Vector length probed in the constructor (0 if unknown)."""
          return int(self._dimension or 0)
  190. # ==============
  191. # 工厂与回退
  192. # ==============
  def create_embedding_model(model_type: str = "local", **kwargs) -> EmbeddingModel:
      """Instantiate an embedding backend by name.

      Args:
          model_type: "dashscope" | "local" | "tfidf"; "sentence_transformer"
              and "huggingface" are accepted as aliases for "local".
          **kwargs: forwarded unchanged to the chosen backend's constructor
              (e.g. model_name, api_key, base_url).

      Raises:
          ValueError: for an unknown ``model_type``.
      """
      if model_type in ("local", "sentence_transformer", "huggingface"):
          backend = LocalTransformerEmbedding
      elif model_type == "dashscope":
          backend = DashScopeEmbedding
      elif model_type == "tfidf":
          backend = TFIDFEmbedding
      else:
          raise ValueError(f"不支持的模型类型: {model_type}")
      return backend(**kwargs)
  # Constructor keyword arguments each backend understands. Anything else is
  # dropped, so e.g. ``api_key`` (meant for DashScope) cannot make the local
  # or TF-IDF constructors fail with TypeError during fallback.
  _FALLBACK_KWARGS = {
      "dashscope": ("model_name", "api_key", "base_url"),
      "local": ("model_name",),
      "tfidf": ("max_features",),
  }


  def _fallback_kwargs(model_type: str, kwargs: dict, keep_model_name: bool) -> dict:
      """Select the subset of ``kwargs`` that ``model_type``'s constructor accepts.

      ``model_name`` is kept only for the preferred backend. A name such as
      ``text-embedding-v3`` means nothing to the local backend (and vice
      versa), so fallback backends use their own default model.
      """
      allowed = _FALLBACK_KWARGS.get(model_type, ())
      picked = {key: value for key, value in kwargs.items() if key in allowed}
      if not keep_model_name:
          picked.pop("model_name", None)
      return picked


  def create_embedding_model_with_fallback(preferred_type: str = "dashscope", **kwargs) -> EmbeddingModel:
      """Create an embedding model, trying dashscope -> local -> tfidf.

      ``preferred_type`` is tried first ("sentence_transformer" and
      "huggingface" mean "local"), then the remaining backends in the default
      order. Each backend receives only the kwargs its constructor accepts.

      Raises:
          RuntimeError: if every backend fails; the last failure is chained as
              ``__cause__``.
      """
      if preferred_type in ("sentence_transformer", "huggingface"):
          preferred_type = "local"
      fallback = ["dashscope", "local", "tfidf"]
      # Move the preferred backend to the front.
      if preferred_type in fallback:
          fallback.remove(preferred_type)
          fallback.insert(0, preferred_type)
      last_error = None
      for model_type in fallback:
          try:
              return create_embedding_model(
                  model_type,
                  **_fallback_kwargs(model_type, kwargs, keep_model_name=(model_type == preferred_type)),
              )
          except Exception as err:  # any failure means "try the next backend"
              last_error = err
      raise RuntimeError("所有嵌入模型都不可用,请安装依赖或检查配置") from last_error
  # ==================
  # Provider (process-wide singleton)
  # ==================
  # Cached embedder shared by get_text_embedder()/refresh_embedder(); None until first built.
  _embedder: Optional[EmbeddingModel] = None
  # Re-entrant lock guarding construction/replacement of ``_embedder``.
  _lock = threading.RLock()
  def _build_embedder() -> EmbeddingModel:
      """Build an embedder from the EMBED_* environment variables.

      EMBED_MODEL_TYPE selects the preferred backend (default "dashscope",
      matched case-insensitively). EMBED_MODEL_NAME overrides the model name.
      EMBED_API_KEY and EMBED_BASE_URL are forwarded when non-empty.
      Construction goes through create_embedding_model_with_fallback().
      """
      preferred = os.getenv("EMBED_MODEL_TYPE", "dashscope").strip().lower()
      # Default model per backend. TF-IDF takes no model name, so none is
      # passed for it (TFIDFEmbedding would reject a model_name kwarg).
      if preferred == "dashscope":
          default_model = "text-embedding-v3"
      elif preferred == "tfidf":
          default_model = ""
      else:
          default_model = "sentence-transformers/all-MiniLM-L6-v2"
      model_name = os.getenv("EMBED_MODEL_NAME", default_model).strip()
      kwargs = {}
      if model_name:
          kwargs["model_name"] = model_name
      # Only the unified EMBED_* names are read. Strip stray whitespace or
      # newlines (common when values come from .env files) before use.
      api_key = (os.getenv("EMBED_API_KEY") or "").strip()
      if api_key:
          kwargs["api_key"] = api_key
      base_url = (os.getenv("EMBED_BASE_URL") or "").strip()
      if base_url:
          kwargs["base_url"] = base_url
      return create_embedding_model_with_fallback(preferred_type=preferred, **kwargs)
  def get_text_embedder() -> EmbeddingModel:
      """Return the process-wide shared text embedder, building it on first use.

      Once built, the cached instance is returned without taking the lock.
      The first build happens under ``_lock`` with a re-check, so concurrent
      first callers end up sharing a single instance.
      """
      global _embedder
      cached = _embedder
      if cached is None:
          with _lock:
              if _embedder is None:
                  _embedder = _build_embedder()
              cached = _embedder
      return cached
  def get_dimension(default: int = 384) -> int:
      """Return the shared embedder's vector dimension as an int.

      Any exception, including a failure to build the embedder, yields
      ``int(default)`` instead of propagating.
      """
      try:
          embedder = get_text_embedder()
          dim = getattr(embedder, "dimension", default)
          return int(dim)
      except Exception:
          return int(default)
  def refresh_embedder() -> EmbeddingModel:
      """Rebuild the shared embedder from the current environment and return it.

      Useful after changing EMBED_* variables at runtime. If building raises,
      the exception propagates and the previously cached instance is kept.
      """
      global _embedder
      with _lock:
          fresh = _build_embedder()
          _embedder = fresh
          return fresh