| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- """统一嵌入模块(实现 + 提供器)
- 说明(中文):
- - 提供统一的文本嵌入接口与多实现:本地Transformer、DashScope(通义千问)、TF-IDF兜底。
- - 暴露 get_text_embedder()/get_dimension()/refresh_embedder() 供各记忆类型统一使用。
- - 由环境变量 EMBED_MODEL_TYPE 选择首选实现;若首选不可用,按 dashscope > local > tfidf 的顺序依次回退。
- 环境变量:
- - EMBED_MODEL_TYPE: "dashscope" | "local" | "tfidf"(默认 dashscope)
- - EMBED_MODEL_NAME: 模型名称(dashscope默认 text-embedding-v3;local默认 sentence-transformers/all-MiniLM-L6-v2)
- - EMBED_API_KEY: Embedding API Key(统一命名)
- - EMBED_BASE_URL: Embedding Base URL(统一命名,可选)
- """
- from typing import List, Union, Optional
- import threading
- import os
- import numpy as np
- # ==============
- # 抽象与实现
- # ==============
class EmbeddingModel:
    """Minimal interface shared by every text-embedding backend.

    Subclasses must override both ``dimension`` and ``encode``; the base
    versions only raise ``NotImplementedError``.
    """

    @property
    def dimension(self) -> int:
        """Length of the vectors produced by ``encode`` (subclass hook)."""
        raise NotImplementedError

    def encode(self, texts: Union[str, List[str]]):
        """Embed one string or a list of strings (subclass hook)."""
        raise NotImplementedError
class LocalTransformerEmbedding(EmbeddingModel):
    """Local Transformer embeddings.

    Prefers ``sentence-transformers``. If that package (or the model) cannot be
    loaded, falls back to plain ``transformers`` + ``torch``. The fallback
    mean-pools the last hidden state over non-padding tokens.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.model_name = model_name
        self._backend = None  # "st" (sentence-transformers) or "hf" (transformers)
        self._st_model = None
        self._hf_tokenizer = None
        self._hf_model = None
        self._dimension = None
        self._load_backend()

    @staticmethod
    def _mean_pool(last_hidden_state, attention_mask):
        """Average token vectors per sequence, ignoring padding positions.

        Assumes ``last_hidden_state`` is (batch, seq_len, hidden) and
        ``attention_mask`` is (batch, seq_len) with 1 for real tokens and 0 for
        padding, as returned by ``transformers`` with ``return_tensors="pt"``.
        Without a mask, every position is averaged.
        """
        if attention_mask is None:
            return last_hidden_state.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        # The clamp avoids division by zero for a row that is entirely padding.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def _load_backend(self):
        """Initialise the first backend that works and probe it once for the dimension.

        Raises:
            ImportError: if neither backend can load ``self.model_name``. The
                underlying failure is chained as ``__cause__``.
        """
        last_error = None
        # Preferred backend: sentence-transformers.
        try:
            from sentence_transformers import SentenceTransformer

            self._st_model = SentenceTransformer(self.model_name)
            test_vec = self._st_model.encode("test_text")
            self._dimension = len(test_vec)
            self._backend = "st"
            return
        except Exception as err:  # missing package, unknown model, download failure, ...
            self._st_model = None
            last_error = err

        # Fallback backend: transformers + torch.
        try:
            from transformers import AutoTokenizer, AutoModel
            import torch

            self._hf_tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._hf_model = AutoModel.from_pretrained(self.model_name)
            with torch.no_grad():
                inputs = self._hf_tokenizer("test_text", return_tensors="pt", padding=True, truncation=True)
                outputs = self._hf_model(**inputs)
                test_embedding = self._mean_pool(outputs.last_hidden_state, inputs.get("attention_mask"))
            self._dimension = int(test_embedding.shape[1])
            self._backend = "hf"
            return
        except Exception as err:
            self._hf_tokenizer = None
            self._hf_model = None
            last_error = err

        raise ImportError(
            "未找到可用的本地嵌入后端,请安装 sentence-transformers 或 transformers+torch"
        ) from last_error

    def encode(self, texts: Union[str, List[str]]):
        """Embed texts.

        Args:
            texts: a single string, or an iterable of strings.

        Returns:
            One vector for a string input, otherwise a list of vectors (an empty
            list for empty input).
        """
        if isinstance(texts, str):
            inputs = [texts]
            single = True
        else:
            inputs = list(texts)
            single = False
            if not inputs:
                # Nothing to embed; don't call the tokenizer/model on an empty batch.
                return []

        if self._backend == "st":
            vecs = self._st_model.encode(inputs)
            if hasattr(vecs, "tolist"):
                vecs = list(vecs)
        else:
            import torch

            tokenized = self._hf_tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=512)
            with torch.no_grad():
                outputs = self._hf_model(**tokenized)
            # Masked pooling: padding added to the shorter texts of a batch must not
            # dilute their vectors, otherwise a text's embedding depends on its batch.
            pooled = self._mean_pool(outputs.last_hidden_state, tokenized.get("attention_mask"))
            vecs = list(pooled.cpu().numpy())

        return vecs[0] if single else vecs

    @property
    def dimension(self) -> int:
        """Vector length discovered at load time (0 if unknown)."""
        return int(self._dimension or 0)
class TFIDFEmbedding(EmbeddingModel):
    """Last-resort TF-IDF embedder that works without any neural model.

    ``fit`` must be called before ``encode``. Until then ``dimension`` reports
    ``max_features``; after fitting it reports the learned vocabulary size.
    """

    def __init__(self, max_features: int = 1000):
        self.max_features = max_features
        self._vectorizer = None
        self._is_fitted = False
        self._dimension = max_features
        self._init_vectorizer()

    def _init_vectorizer(self):
        """Create the scikit-learn vectorizer, with a friendly error if sklearn is missing."""
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
        except ImportError:
            raise ImportError("请安装 scikit-learn: pip install scikit-learn")
        self._vectorizer = TfidfVectorizer(max_features=self.max_features, stop_words='english')

    def fit(self, texts: List[str]):
        """Learn the vocabulary/IDF from ``texts`` and set ``dimension`` to the vocabulary size."""
        vectorizer = self._vectorizer
        vectorizer.fit(texts)
        self._is_fitted = True
        self._dimension = len(vectorizer.get_feature_names_out())

    def encode(self, texts: Union[str, List[str]]):
        """Return a dense TF-IDF vector for a string, or a list of vectors otherwise.

        Raises:
            ValueError: if ``fit`` has not been called yet.
        """
        if not self._is_fitted:
            raise ValueError("TF-IDF模型未训练,请先调用fit()方法")
        single = isinstance(texts, str)
        batch = [texts] if single else texts
        dense = self._vectorizer.transform(batch).toarray()
        return dense[0] if single else list(dense)

    @property
    def dimension(self) -> int:
        """``max_features`` before fitting, vocabulary size afterwards."""
        return self._dimension
class DashScopeEmbedding(EmbeddingModel):
    """Alibaba Cloud DashScope (Tongyi Qianwen) embeddings, via the SDK or an OpenAI-compatible REST API.

    Behaviour:
    - If ``base_url`` is given, the OpenAI-compatible REST endpoint
      ``POST {base_url}/embeddings`` is used.
    - Otherwise the official ``dashscope`` SDK (``TextEmbedding.call``) is used.

    The constructor makes one real embedding call to discover the vector
    dimension, so it needs working credentials and network access.
    """

    def __init__(self, model_name: str = "text-embedding-v3", api_key: Optional[str] = None, base_url: Optional[str] = None):
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self._dimension = None
        # The SDK is only needed when REST mode is not in use.
        if not self.base_url:
            self._init_client()
        # Probe the dimension with one request.
        test = self.encode("health_check")
        self._dimension = len(test)

    def _init_client(self):
        """Check that the dashscope SDK is importable and hand it the API key.

        The key is written to the ``DASHSCOPE_API_KEY`` environment variable,
        which is a process-wide side effect.
        """
        try:
            if self.api_key:
                os.environ["DASHSCOPE_API_KEY"] = self.api_key
            import dashscope  # noqa: F401
        except ImportError as err:
            raise ImportError("请安装 dashscope: pip install dashscope") from err

    @staticmethod
    def _get_field(obj, name):
        """Read ``name`` from a dict-like response part (``.get``) or an object (attribute)."""
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    def _encode_rest(self, inputs: List[str]) -> List[np.ndarray]:
        """Call the OpenAI-compatible ``/embeddings`` endpoint and return vectors in input order."""
        import requests

        url = self.base_url.rstrip("/") + "/embeddings"
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            # Send no Authorization header at all, rather than an empty one.
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {"model": self.model_name, "input": inputs}
        resp = requests.post(url, headers=headers, json=payload, timeout=30)
        if resp.status_code >= 400:
            raise RuntimeError(f"Embedding REST 调用失败: {resp.status_code} {resp.text}")
        data = resp.json()
        # Expected shape: {"data": [{"index": i, "embedding": [...]}, ...]}
        items = data.get("data") or []
        # Each item says which input it belongs to via "index"; sort to restore input order.
        items = sorted(items, key=lambda item: item.get("index") or 0)
        return [np.array(item.get("embedding")) for item in items]

    def _encode_sdk(self, inputs: List[str]) -> List[np.ndarray]:
        """Call ``dashscope.TextEmbedding`` and return vectors in input order."""
        from dashscope import TextEmbedding

        rsp = TextEmbedding.call(model=self.model_name, input=inputs)
        output = self._get_field(rsp, "output")
        embeddings_obj = self._get_field(output, "embeddings") if output is not None else None
        if not embeddings_obj:
            # Keep the service's own status/message (auth, quota, bad model, ...).
            status = self._get_field(rsp, "status_code")
            message = self._get_field(rsp, "message")
            raise RuntimeError(f"DashScope 返回为空或格式不匹配 (status_code={status}, message={message})")
        # Each item carries "text_index"; sort to restore input order.
        items = sorted(embeddings_obj, key=lambda item: self._get_field(item, "text_index") or 0)
        return [
            np.array(self._get_field(item, "embedding") or self._get_field(item, "vector"))
            for item in items
        ]

    def encode(self, texts: Union[str, List[str]]):
        """Embed texts.

        Args:
            texts: a single string, or an iterable of strings.

        Returns:
            One ``np.ndarray`` for a string input, otherwise a list of arrays
            (an empty list for empty input, without calling the service).

        Raises:
            RuntimeError: on an HTTP/SDK error, or when the service returns a
                different number of vectors than inputs.
        """
        if isinstance(texts, str):
            inputs = [texts]
            single = True
        else:
            inputs = list(texts)
            single = False
            if not inputs:
                return []

        vecs = self._encode_rest(inputs) if self.base_url else self._encode_sdk(inputs)
        if len(vecs) != len(inputs):
            raise RuntimeError(f"Embedding 返回的向量数量({len(vecs)})与输入数量({len(inputs)})不一致")
        return vecs[0] if single else vecs

    @property
    def dimension(self) -> int:
        """Vector length discovered by the constructor's probe call (0 if unknown)."""
        return int(self._dimension or 0)
- # ==============
- # 工厂与回退
- # ==============
def create_embedding_model(model_type: str = "local", **kwargs) -> EmbeddingModel:
    """Instantiate an embedding model by type name.

    Args:
        model_type: "dashscope" | "local" (aliases "sentence_transformer",
            "huggingface") | "tfidf". Matching ignores case and surrounding spaces.
        **kwargs: shared options such as ``model_name``, ``api_key``,
            ``base_url`` and ``max_features``. Only the options the selected
            class's constructor accepts are forwarded, so one settings bag can
            serve every backend. For example, ``api_key`` is dropped for
            "local" and ``model_name`` for "tfidf", instead of raising TypeError.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    import inspect

    normalized = (model_type or "").strip().lower()
    if normalized in ("local", "sentence_transformer", "huggingface"):
        cls = LocalTransformerEmbedding
    elif normalized == "dashscope":
        cls = DashScopeEmbedding
    elif normalized == "tfidf":
        cls = TFIDFEmbedding
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")

    params = inspect.signature(cls.__init__).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        accepted = kwargs
    else:
        names = {p.name for p in params}
        accepted = {key: value for key, value in kwargs.items() if key in names}
    return cls(**accepted)
def create_embedding_model_with_fallback(preferred_type: str = "dashscope", **kwargs) -> EmbeddingModel:
    """Create an embedding model, trying ``preferred_type`` first and then the rest.

    Attempt order: ``preferred_type`` (if known), then the remaining entries of
    dashscope -> local -> tfidf.

    ``model_name`` is provider specific: a DashScope model id is not a local
    Hugging Face checkpoint. It is therefore forwarded only to the first type
    attempted, and later fallbacks use their own default model. All other
    kwargs are passed to every attempt.

    Raises:
        RuntimeError: if every type fails. The message lists each type's error,
            and the last error is chained as ``__cause__``.
    """
    preferred = (preferred_type or "").strip().lower()
    if preferred in ("sentence_transformer", "huggingface"):
        preferred = "local"
    order = ["dashscope", "local", "tfidf"]
    # Move the preferred type to the front.
    if preferred in order:
        order.remove(preferred)
        order.insert(0, preferred)
    primary = order[0]

    errors = []
    last_error = None
    for model_type in order:
        attempt_kwargs = dict(kwargs)
        if model_type != primary:
            attempt_kwargs.pop("model_name", None)
        try:
            return create_embedding_model(model_type, **attempt_kwargs)
        except Exception as err:  # any backend failure just moves on to the next type
            errors.append(f"{model_type}: {type(err).__name__}: {err}")
            last_error = err
    raise RuntimeError(
        "所有嵌入模型都不可用,请安装依赖或检查配置 (" + "; ".join(errors) + ")"
    ) from last_error
- # ==================
- # Provider(单例)
- # ==================
# Shared embedder instance; None until built by get_text_embedder() or refresh_embedder().
_embedder: Optional[EmbeddingModel] = None
# Serialises building/replacing _embedder across threads.
_lock = threading.RLock()
def _build_embedder() -> EmbeddingModel:
    """Build an embedder from the EMBED_* environment variables.

    Reads ``EMBED_MODEL_TYPE`` (default "dashscope", case-insensitive),
    ``EMBED_MODEL_NAME``, ``EMBED_API_KEY`` and ``EMBED_BASE_URL``. Blank or
    whitespace-only values count as unset.

    When ``EMBED_MODEL_NAME`` is unset, no ``model_name`` is passed, so each
    backend uses its own constructor default (DashScope: text-embedding-v3;
    local: sentence-transformers/all-MiniLM-L6-v2). This also stops a DashScope
    model id from reaching the local fallback, and stops any model name from
    reaching TF-IDF, which has no ``model_name`` parameter.
    """

    def _env(name: str) -> str:
        # Missing and whitespace-only variables are treated the same way.
        return (os.getenv(name) or "").strip()

    preferred = _env("EMBED_MODEL_TYPE").lower() or "dashscope"
    kwargs = {}
    model_name = _env("EMBED_MODEL_NAME")
    if model_name:
        kwargs["model_name"] = model_name
    # Only the unified EMBED_* names are read.
    api_key = _env("EMBED_API_KEY")
    if api_key:
        kwargs["api_key"] = api_key
    base_url = _env("EMBED_BASE_URL")
    if base_url:
        kwargs["base_url"] = base_url
    return create_embedding_model_with_fallback(preferred_type=preferred, **kwargs)
def get_text_embedder() -> EmbeddingModel:
    """Return the process-wide shared embedder, building it lazily on first use.

    An already-built instance is returned without taking the lock. Otherwise
    ``_lock`` is taken and ``_embedder`` is re-checked, so only one thread
    calls ``_build_embedder()``.
    """
    global _embedder
    if _embedder is None:
        with _lock:
            if _embedder is None:
                _embedder = _build_embedder()
    return _embedder
def get_dimension(default: int = 384) -> int:
    """Return the shared embedder's vector dimension as an int.

    Returns ``int(default)`` instead of raising if anything fails: building the
    embedder, reading ``dimension``, or converting it to ``int``. An embedder
    without a ``dimension`` attribute also yields ``default``.
    """
    try:
        embedder = get_text_embedder()
        dim = getattr(embedder, "dimension", default)
        return int(dim)
    except Exception:
        return int(default)
def refresh_embedder() -> EmbeddingModel:
    """Rebuild the shared embedder from the current environment and return it.

    Useful after changing the EMBED_* variables at runtime. If the build raises,
    the previously cached instance is left in place.
    """
    global _embedder
    with _lock:
        replacement = _build_embedder()
        _embedder = replacement
    return _embedder
|