embedding.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. """
  2. InnoCore AI 向量生成工具
  3. """
import asyncio
import hashlib
import json
import logging
import string
from collections import Counter
from typing import List, Dict, Optional, Any

import numpy as np
from openai import AsyncOpenAI

from ..core.config import get_config
from ..core.exceptions import AgentException
  12. class EmbeddingGenerator:
  13. """向量生成器"""
  14. def __init__(self):
  15. self.config = get_config()
  16. self.client = None
  17. self.embedding_model = self.config.vector_db.embedding_model
  18. self.cache = {} # 简单的内存缓存
  19. async def initialize(self):
  20. """初始化向量生成器"""
  21. try:
  22. self.client = AsyncOpenAI(
  23. api_key=self.config.llm.api_key,
  24. base_url=self.config.llm.base_url
  25. )
  26. except Exception as e:
  27. raise AgentException(f"向量生成器初始化失败: {str(e)}")
  28. async def generate_embedding(self, text: str, use_cache: bool = True) -> List[float]:
  29. """生成文本向量"""
  30. if not text:
  31. return [0.0] * 1536 # 返回零向量
  32. # 检查缓存
  33. if use_cache:
  34. cache_key = self._get_cache_key(text)
  35. if cache_key in self.cache:
  36. return self.cache[cache_key]
  37. try:
  38. # 清理文本
  39. cleaned_text = self._clean_text(text)
  40. # 调用OpenAI API
  41. response = await self.client.embeddings.create(
  42. model=self.embedding_model,
  43. input=cleaned_text
  44. )
  45. embedding = response.data[0].embedding
  46. # 缓存结果
  47. if use_cache:
  48. cache_key = self._get_cache_key(text)
  49. self.cache[cache_key] = embedding
  50. return embedding
  51. except Exception as e:
  52. raise AgentException(f"向量生成失败: {str(e)}")
  53. async def generate_batch_embeddings(self, texts: List[str],
  54. batch_size: int = 10) -> List[List[float]]:
  55. """批量生成向量"""
  56. embeddings = []
  57. for i in range(0, len(texts), batch_size):
  58. batch = texts[i:i + batch_size]
  59. try:
  60. # 批量调用API
  61. cleaned_texts = [self._clean_text(text) for text in batch]
  62. response = await self.client.embeddings.create(
  63. model=self.embedding_model,
  64. input=cleaned_texts
  65. )
  66. batch_embeddings = [item.embedding for item in response.data]
  67. embeddings.extend(batch_embeddings)
  68. except Exception as e:
  69. # 如果批量失败,逐个生成
  70. for text in batch:
  71. try:
  72. embedding = await self.generate_embedding(text)
  73. embeddings.append(embedding)
  74. except Exception as single_error:
  75. print(f"单个向量生成失败: {str(single_error)}")
  76. embeddings.append([0.0] * 1536) # 零向量
  77. return embeddings
  78. async def generate_paper_embedding(self, paper_info: Dict[str, Any]) -> List[float]:
  79. """为论文生成综合向量"""
  80. # 组合论文的关键信息
  81. title = paper_info.get("title", "")
  82. abstract = paper_info.get("abstract", "")
  83. authors = " ".join(paper_info.get("authors", []))
  84. # 构建综合文本
  85. combined_text = f"{title} {abstract} {authors}"
  86. # 如果有结构化内容,也包含进来
  87. sections = paper_info.get("sections", {})
  88. if sections:
  89. section_text = " ".join(sections.values())
  90. combined_text += " " + section_text
  91. return await self.generate_embedding(combined_text)
  92. async def generate_section_embeddings(self, sections: Dict[str, str]) -> Dict[str, List[float]]:
  93. """为各个章节生成向量"""
  94. section_embeddings = {}
  95. for section_name, section_content in sections.items():
  96. if section_content.strip():
  97. try:
  98. embedding = await self.generate_embedding(section_content)
  99. section_embeddings[section_name] = embedding
  100. except Exception as e:
  101. print(f"章节 {section_name} 向量生成失败: {str(e)}")
  102. section_embeddings[section_name] = [0.0] * 1536
  103. return section_embeddings
  104. def _clean_text(self, text: str) -> str:
  105. """清理文本"""
  106. if not text:
  107. return ""
  108. # 移除多余的空白字符
  109. text = ' '.join(text.split())
  110. # 截断过长的文本(OpenAI有token限制)
  111. max_length = 8000 # 保守估计
  112. if len(text) > max_length:
  113. text = text[:max_length]
  114. return text
  115. def _get_cache_key(self, text: str) -> str:
  116. """生成缓存键"""
  117. return hashlib.md5(text.encode()).hexdigest()
  118. def clear_cache(self):
  119. """清空缓存"""
  120. self.cache.clear()
  121. def get_cache_size(self) -> int:
  122. """获取缓存大小"""
  123. return len(self.cache)
  124. async def calculate_similarity(self, text1: str, text2: str) -> float:
  125. """计算两个文本的相似度"""
  126. try:
  127. embedding1 = await self.generate_embedding(text1)
  128. embedding2 = await self.generate_embedding(text2)
  129. return self._cosine_similarity(embedding1, embedding2)
  130. except Exception as e:
  131. print(f"相似度计算失败: {str(e)}")
  132. return 0.0
  133. def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
  134. """计算余弦相似度"""
  135. if len(vec1) != len(vec2):
  136. return 0.0
  137. try:
  138. vec1_np = np.array(vec1)
  139. vec2_np = np.array(vec2)
  140. dot_product = np.dot(vec1_np, vec2_np)
  141. norm1 = np.linalg.norm(vec1_np)
  142. norm2 = np.linalg.norm(vec2_np)
  143. if norm1 == 0 or norm2 == 0:
  144. return 0.0
  145. return dot_product / (norm1 * norm2)
  146. except Exception:
  147. return 0.0
  148. async def find_most_similar(self, query_text: str,
  149. candidate_texts: List[str],
  150. top_k: int = 5) -> List[Dict[str, Any]]:
  151. """找到最相似的文本"""
  152. if not candidate_texts:
  153. return []
  154. try:
  155. # 生成查询向量
  156. query_embedding = await self.generate_embedding(query_text)
  157. # 生成候选文本向量
  158. candidate_embeddings = await self.generate_batch_embeddings(candidate_texts)
  159. # 计算相似度
  160. similarities = []
  161. for i, candidate_embedding in enumerate(candidate_embeddings):
  162. similarity = self._cosine_similarity(query_embedding, candidate_embedding)
  163. similarities.append({
  164. "text": candidate_texts[i],
  165. "similarity": similarity,
  166. "index": i
  167. })
  168. # 按相似度排序
  169. similarities.sort(key=lambda x: x["similarity"], reverse=True)
  170. return similarities[:top_k]
  171. except Exception as e:
  172. print(f"相似文本查找失败: {str(e)}")
  173. return []
  174. async def cluster_texts(self, texts: List[str],
  175. num_clusters: int = 3) -> Dict[str, Any]:
  176. """文本聚类(简化实现)"""
  177. try:
  178. # 生成所有文本的向量
  179. embeddings = await self.generate_batch_embeddings(texts)
  180. # 简单的聚类逻辑(基于相似度阈值)
  181. clusters = {}
  182. cluster_id = 0
  183. used_indices = set()
  184. for i, embedding in enumerate(embeddings):
  185. if i in used_indices:
  186. continue
  187. # 创建新聚类
  188. clusters[f"cluster_{cluster_id}"] = {
  189. "texts": [texts[i]],
  190. "indices": [i],
  191. "center": embedding
  192. }
  193. used_indices.add(i)
  194. # 查找相似文本加入同一聚类
  195. for j, other_embedding in enumerate(embeddings):
  196. if j in used_indices:
  197. continue
  198. similarity = self._cosine_similarity(embedding, other_embedding)
  199. if similarity > 0.8: # 相似度阈值
  200. clusters[f"cluster_{cluster_id}"]["texts"].append(texts[j])
  201. clusters[f"cluster_{cluster_id}"]["indices"].append(j)
  202. used_indices.add(j)
  203. cluster_id += 1
  204. return {
  205. "clusters": clusters,
  206. "num_clusters": len(clusters),
  207. "total_texts": len(texts)
  208. }
  209. except Exception as e:
  210. print(f"文本聚类失败: {str(e)}")
  211. return {"clusters": {}, "num_clusters": 0, "total_texts": len(texts)}
  212. async def extract_keywords(self, text: str, max_keywords: int = 10) -> List[str]:
  213. """提取关键词(基于TF-IDF的简化实现)"""
  214. try:
  215. # 分词
  216. words = text.lower().split()
  217. # 过滤停用词(简化版)
  218. stop_words = {
  219. 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
  220. 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have',
  221. 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should'
  222. }
  223. filtered_words = [word for word in words if word not in stop_words and len(word) > 2]
  224. # 计算词频
  225. word_freq = {}
  226. for word in filtered_words:
  227. word_freq[word] = word_freq.get(word, 0) + 1
  228. # 按频率排序
  229. sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
  230. # 返回前N个关键词
  231. return [word for word, freq in sorted_words[:max_keywords]]
  232. except Exception as e:
  233. print(f"关键词提取失败: {str(e)}")
  234. return []
  235. def get_embedding_info(self) -> Dict[str, Any]:
  236. """获取向量生成器信息"""
  237. return {
  238. "model": self.embedding_model,
  239. "cache_size": len(self.cache),
  240. "vector_dimension": 1536, # OpenAI embedding维度
  241. "provider": "openai"
  242. }