pipeline.py 42 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207
  1. from typing import List, Dict, Optional, Any
  2. import os
  3. import hashlib
  4. import sqlite3
  5. import time
  6. import json
  7. from ..embedding import get_text_embedder, get_dimension
  8. from ..storage.qdrant_store import QdrantVectorStore
  9. def _get_markitdown_instance():
  10. """
  11. Get a configured MarkItDown instance for document conversion.
  12. """
  13. try:
  14. from markitdown import MarkItDown
  15. return MarkItDown()
  16. except ImportError:
  17. print("[WARNING] MarkItDown not available. Install with: pip install markitdown")
  18. return None
  19. def _is_markitdown_supported_format(path: str) -> bool:
  20. """
  21. Check if the file format is supported by MarkItDown.
  22. Supports: PDF, Office docs (docx, xlsx, pptx), images (jpg, png, gif, bmp, tiff),
  23. audio (mp3, wav, m4a), HTML, text formats (txt, md, csv, json, xml), ZIP files, etc.
  24. """
  25. ext = (os.path.splitext(path)[1] or '').lower()
  26. supported_formats = {
  27. # Documents
  28. '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx',
  29. # Text formats
  30. '.txt', '.md', '.csv', '.json', '.xml', '.html', '.htm',
  31. # Images (OCR + metadata)
  32. '.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif', '.webp',
  33. # Audio (transcription + metadata)
  34. '.mp3', '.wav', '.m4a', '.aac', '.flac', '.ogg',
  35. # Archives
  36. '.zip', '.tar', '.gz', '.rar',
  37. # Code files
  38. '.py', '.js', '.ts', '.java', '.cpp', '.c', '.h', '.css', '.scss',
  39. # Other text
  40. '.log', '.conf', '.ini', '.cfg', '.yaml', '.yml', '.toml'
  41. }
  42. return ext in supported_formats
  43. def _convert_to_markdown(path: str) -> str:
  44. """
  45. Universal document reader using MarkItDown with enhanced PDF processing.
  46. Converts any supported file format to markdown text.
  47. """
  48. if not os.path.exists(path):
  49. return ""
  50. # 对PDF文件使用增强处理
  51. ext = (os.path.splitext(path)[1] or '').lower()
  52. if ext == '.pdf':
  53. return _enhanced_pdf_processing(path)
  54. # 其他格式使用原有MarkItDown
  55. md_instance = _get_markitdown_instance()
  56. if md_instance is None:
  57. return _fallback_text_reader(path)
  58. try:
  59. result = md_instance.convert(path)
  60. text = getattr(result, "text_content", None)
  61. if isinstance(text, str) and text.strip():
  62. return text
  63. return ""
  64. except Exception as e:
  65. print(f"[WARNING] MarkItDown failed for {path}: {e}")
  66. return _fallback_text_reader(path)
  67. def _enhanced_pdf_processing(path: str) -> str:
  68. """
  69. Enhanced PDF processing with post-processing cleanup.
  70. """
  71. print(f"[RAG] Using enhanced PDF processing for: {path}")
  72. # 使用原有MarkItDown提取
  73. md_instance = _get_markitdown_instance()
  74. if md_instance is None:
  75. return _fallback_text_reader(path)
  76. try:
  77. result = md_instance.convert(path)
  78. raw_text = getattr(result, "text_content", None)
  79. if not raw_text or not raw_text.strip():
  80. return ""
  81. # 后处理:清理和重组文本
  82. cleaned_text = _post_process_pdf_text(raw_text)
  83. print(f"[RAG] PDF post-processing completed: {len(raw_text)} -> {len(cleaned_text)} chars")
  84. return cleaned_text
  85. except Exception as e:
  86. print(f"[WARNING] Enhanced PDF processing failed for {path}: {e}")
  87. return _fallback_text_reader(path)
  88. def _post_process_pdf_text(text: str) -> str:
  89. """
  90. Post-process PDF text to improve quality.
  91. """
  92. import re
  93. # 1. 按行分割并清理
  94. lines = text.splitlines()
  95. cleaned_lines = []
  96. for line in lines:
  97. line = line.strip()
  98. if not line:
  99. continue
  100. # 移除单个字符的行(通常是噪音)
  101. if len(line) <= 2 and not line.isdigit():
  102. continue
  103. # 移除明显的页眉页脚噪音
  104. if re.match(r'^\d+$', line): # 纯数字行(页码)
  105. continue
  106. if line.lower() in ['github', 'project', 'forks', 'stars', 'language']:
  107. continue
  108. cleaned_lines.append(line)
  109. # 2. 智能合并短行
  110. merged_lines = []
  111. i = 0
  112. while i < len(cleaned_lines):
  113. current_line = cleaned_lines[i]
  114. # 如果当前行很短,尝试与下一行合并
  115. if len(current_line) < 60 and i + 1 < len(cleaned_lines):
  116. next_line = cleaned_lines[i + 1]
  117. # 合并条件:都是内容,不是标题
  118. if (not current_line.endswith(':') and
  119. not current_line.endswith(':') and
  120. not current_line.startswith('#') and
  121. not next_line.startswith('#') and
  122. len(next_line) < 120):
  123. merged_line = current_line + " " + next_line
  124. merged_lines.append(merged_line)
  125. i += 2 # 跳过下一行
  126. continue
  127. merged_lines.append(current_line)
  128. i += 1
  129. # 3. 重新组织段落
  130. paragraphs = []
  131. current_paragraph = []
  132. for line in merged_lines:
  133. # 检查是否是新段落的开始
  134. if (line.startswith('#') or # 标题
  135. line.endswith(':') or # 中文冒号结尾
  136. line.endswith(':') or # 英文冒号结尾
  137. len(line) > 150 or # 长句通常是段落开始
  138. not current_paragraph): # 第一行
  139. # 保存当前段落
  140. if current_paragraph:
  141. paragraphs.append(' '.join(current_paragraph))
  142. current_paragraph = []
  143. paragraphs.append(line)
  144. else:
  145. current_paragraph.append(line)
  146. # 添加最后一个段落
  147. if current_paragraph:
  148. paragraphs.append(' '.join(current_paragraph))
  149. return '\n\n'.join(paragraphs)
  150. def _fallback_text_reader(path: str) -> str:
  151. """
  152. Simple fallback reader for basic text files when MarkItDown is unavailable.
  153. """
  154. try:
  155. with open(path, 'r', encoding='utf-8', errors='ignore') as f:
  156. return f.read()
  157. except Exception:
  158. try:
  159. with open(path, 'r', encoding='latin-1', errors='ignore') as f:
  160. return f.read()
  161. except Exception:
  162. return ""
  163. def _detect_lang(sample: str) -> str:
  164. try:
  165. from langdetect import detect
  166. return detect(sample[:1000]) if sample else "unknown"
  167. except Exception:
  168. return "unknown"
  169. def _is_cjk(ch: str) -> bool:
  170. code = ord(ch)
  171. return (
  172. 0x4E00 <= code <= 0x9FFF or
  173. 0x3400 <= code <= 0x4DBF or
  174. 0x20000 <= code <= 0x2A6DF or
  175. 0x2A700 <= code <= 0x2B73F or
  176. 0x2B740 <= code <= 0x2B81F or
  177. 0x2B820 <= code <= 0x2CEAF or
  178. 0xF900 <= code <= 0xFAFF
  179. )
  180. def _approx_token_len(text: str) -> int:
  181. # 近似估计:CJK字符按1 token,其他按空白分词
  182. cjk = sum(1 for ch in text if _is_cjk(ch))
  183. non_cjk_tokens = len([t for t in text.split() if t])
  184. return cjk + non_cjk_tokens
  185. def _split_paragraphs_with_headings(text: str) -> List[Dict]:
  186. lines = text.splitlines()
  187. heading_stack: List[str] = []
  188. paragraphs: List[Dict] = []
  189. buf: List[str] = []
  190. char_pos = 0
  191. def flush_buf(end_pos: int):
  192. if not buf:
  193. return
  194. content = "\n".join(buf).strip()
  195. if not content:
  196. return
  197. paragraphs.append({
  198. "content": content,
  199. "heading_path": " > ".join(heading_stack) if heading_stack else None,
  200. "start": max(0, end_pos - len(content)),
  201. "end": end_pos,
  202. })
  203. for ln in lines:
  204. raw = ln
  205. if raw.strip().startswith("#"):
  206. # heading line
  207. flush_buf(char_pos)
  208. level = len(raw) - len(raw.lstrip('#'))
  209. title = raw.lstrip('#').strip()
  210. if level <= 0:
  211. level = 1
  212. if level <= len(heading_stack):
  213. heading_stack = heading_stack[:level-1]
  214. heading_stack.append(title)
  215. char_pos += len(raw) + 1
  216. continue
  217. # paragraph accumulation
  218. if raw.strip() == "":
  219. flush_buf(char_pos)
  220. buf = []
  221. else:
  222. buf.append(raw)
  223. char_pos += len(raw) + 1
  224. flush_buf(char_pos)
  225. if not paragraphs:
  226. paragraphs = [{"content": text, "heading_path": None, "start": 0, "end": len(text)}]
  227. return paragraphs
  228. def _chunk_paragraphs(paragraphs: List[Dict], chunk_tokens: int, overlap_tokens: int) -> List[Dict]:
  229. chunks: List[Dict] = []
  230. cur: List[Dict] = []
  231. cur_tokens = 0
  232. i = 0
  233. while i < len(paragraphs):
  234. p = paragraphs[i]
  235. p_tokens = _approx_token_len(p["content"]) or 1
  236. if cur_tokens + p_tokens <= chunk_tokens or not cur:
  237. cur.append(p)
  238. cur_tokens += p_tokens
  239. i += 1
  240. else:
  241. # emit current chunk
  242. content = "\n\n".join(x["content"] for x in cur)
  243. start = cur[0]["start"]
  244. end = cur[-1]["end"]
  245. heading_path = next((x["heading_path"] for x in reversed(cur) if x.get("heading_path")), None)
  246. chunks.append({
  247. "content": content,
  248. "start": start,
  249. "end": end,
  250. "heading_path": heading_path,
  251. })
  252. # build overlap by keeping tail tokens
  253. if overlap_tokens > 0 and cur:
  254. kept: List[Dict] = []
  255. kept_tokens = 0
  256. for x in reversed(cur):
  257. t = _approx_token_len(x["content"]) or 1
  258. if kept_tokens + t > overlap_tokens:
  259. break
  260. kept.append(x)
  261. kept_tokens += t
  262. cur = list(reversed(kept))
  263. cur_tokens = kept_tokens
  264. else:
  265. cur = []
  266. cur_tokens = 0
  267. if cur:
  268. content = "\n\n".join(x["content"] for x in cur)
  269. start = cur[0]["start"]
  270. end = cur[-1]["end"]
  271. heading_path = next((x["heading_path"] for x in reversed(cur) if x.get("heading_path")), None)
  272. chunks.append({
  273. "content": content,
  274. "start": start,
  275. "end": end,
  276. "heading_path": heading_path,
  277. })
  278. return chunks
  279. def load_and_chunk_texts(paths: List[str], chunk_size: int = 800, chunk_overlap: int = 100, namespace: Optional[str] = None, source_label: str = "rag") -> List[Dict]:
  280. """
  281. Universal document loader and chunker using MarkItDown.
  282. Converts all supported formats to markdown, then chunks intelligently.
  283. """
  284. print(f"[RAG] Universal loader start: files={len(paths)} chunk_size={chunk_size} overlap={chunk_overlap} ns={namespace or 'default'}")
  285. chunks: List[Dict] = []
  286. seen_hashes = set()
  287. for path in paths:
  288. if not os.path.exists(path):
  289. print(f"[WARNING] File not found: {path}")
  290. continue
  291. print(f"[RAG] Processing: {path}")
  292. ext = (os.path.splitext(path)[1] or '').lower()
  293. # Convert to markdown using MarkItDown
  294. markdown_text = _convert_to_markdown(path)
  295. if not markdown_text.strip():
  296. print(f"[WARNING] No content extracted from: {path}")
  297. continue
  298. lang = _detect_lang(markdown_text)
  299. doc_id = hashlib.md5(f"{path}|{len(markdown_text)}".encode('utf-8')).hexdigest()
  300. # Always use markdown-aware chunking for better structure preservation
  301. para = _split_paragraphs_with_headings(markdown_text)
  302. token_chunks = _chunk_paragraphs(para, chunk_tokens=max(1, chunk_size), overlap_tokens=max(0, chunk_overlap))
  303. for ch in token_chunks:
  304. content = ch["content"]
  305. start = ch.get("start", 0)
  306. end = ch.get("end", start + len(content))
  307. norm = content.strip()
  308. if not norm:
  309. continue
  310. content_hash = hashlib.md5(norm.encode('utf-8')).hexdigest()
  311. if content_hash in seen_hashes:
  312. continue
  313. seen_hashes.add(content_hash)
  314. chunk_id = hashlib.md5(f"{doc_id}|{start}|{end}|{content_hash}".encode('utf-8')).hexdigest()
  315. chunks.append({
  316. "id": chunk_id,
  317. "content": content,
  318. "metadata": {
  319. "source_path": path,
  320. "file_ext": ext,
  321. "doc_id": doc_id,
  322. "lang": lang,
  323. "start": start,
  324. "end": end,
  325. "content_hash": content_hash,
  326. "namespace": namespace or "default",
  327. "source": source_label,
  328. "external": True,
  329. "heading_path": ch.get("heading_path"),
  330. "format": "markdown", # Mark all content as markdown-processed
  331. },
  332. })
  333. print(f"[RAG] Universal loader done: total_chunks={len(chunks)}")
  334. return chunks
  335. def build_graph_from_chunks(neo4j, chunks: List[Dict]) -> None:
  336. created_docs = set()
  337. for ch in chunks:
  338. mem_id = ch["id"]
  339. meta = ch.get("metadata", {})
  340. source_path = meta.get("source_path")
  341. doc_id = meta.get("doc_id")
  342. if doc_id and doc_id not in created_docs:
  343. created_docs.add(doc_id)
  344. try:
  345. neo4j.add_entity(
  346. entity_id=doc_id,
  347. name=os.path.basename(source_path or doc_id),
  348. entity_type="Document",
  349. properties={"source_path": source_path, "lang": meta.get("lang")}
  350. )
  351. except Exception:
  352. pass
  353. try:
  354. neo4j.add_entity(entity_id=mem_id, name=mem_id, entity_type="Memory", properties={
  355. "source_path": source_path,
  356. "doc_id": doc_id,
  357. "start": meta.get("start"),
  358. "end": meta.get("end"),
  359. })
  360. except Exception:
  361. pass
  362. if doc_id:
  363. try:
  364. neo4j.add_relationship(from_id=doc_id, to_id=mem_id, rel_type="HAS_CHUNK", properties={})
  365. except Exception:
  366. pass
  367. def _preprocess_markdown_for_embedding(text: str) -> str:
  368. """
  369. Preprocess markdown text for better embedding quality.
  370. Removes excessive markup while preserving semantic content.
  371. """
  372. import re
  373. # Remove markdown headers symbols but keep the text
  374. text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
  375. # Remove markdown links but keep the text
  376. text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
  377. # Remove markdown emphasis markers
  378. text = re.sub(r'\*\*([^*]+)\*\*', r'\1', text) # bold
  379. text = re.sub(r'\*([^*]+)\*', r'\1', text) # italic
  380. text = re.sub(r'`([^`]+)`', r'\1', text) # inline code
  381. # Remove markdown code blocks but keep content
  382. text = re.sub(r'```[^\n]*\n([\s\S]*?)```', r'\1', text)
  383. # Remove excessive whitespace
  384. text = re.sub(r'\n\s*\n', '\n\n', text)
  385. text = re.sub(r'[ \t]+', ' ', text)
  386. return text.strip()
  387. def _create_default_vector_store(dimension: int = None) -> QdrantVectorStore:
  388. """
  389. Create default Qdrant vector store with RAG-optimized settings.
  390. 使用连接管理器避免重复连接。
  391. """
  392. if dimension is None:
  393. dimension = get_dimension(384)
  394. # Check for Qdrant configuration
  395. qdrant_url = os.getenv("QDRANT_URL")
  396. qdrant_api_key = os.getenv("QDRANT_API_KEY")
  397. # 使用连接管理器
  398. from ..storage.qdrant_store import QdrantConnectionManager
  399. return QdrantConnectionManager.get_instance(
  400. url=qdrant_url,
  401. api_key=qdrant_api_key,
  402. collection_name="hello_agents_rag_vectors",
  403. vector_size=dimension,
  404. distance="cosine"
  405. )
  406. # Cache functions removed - using unified embedder with internal caching
  407. def index_chunks(
  408. store = None,
  409. chunks: List[Dict] = None,
  410. cache_db: Optional[str] = None,
  411. batch_size: int = 64,
  412. rag_namespace: str = "default"
  413. ) -> None:
  414. """
  415. Index markdown chunks with unified embedding and Qdrant storage.
  416. Uses百炼 API with fallback to sentence-transformers.
  417. """
  418. if not chunks:
  419. print("[RAG] No chunks to index")
  420. return
  421. # Use unified embedding from embedding module
  422. embedder = get_text_embedder()
  423. dimension = get_dimension(384)
  424. # Create default Qdrant store if not provided
  425. if store is None:
  426. store = _create_default_vector_store(dimension)
  427. print(f"[RAG] Created default Qdrant store with dimension {dimension}")
  428. # Preprocess markdown texts for better embeddings
  429. processed_texts = []
  430. for c in chunks:
  431. raw_content = c["content"]
  432. processed_content = _preprocess_markdown_for_embedding(raw_content)
  433. processed_texts.append(processed_content)
  434. print(f"[RAG] Embedding start: total_texts={len(processed_texts)} batch_size={batch_size}")
  435. # Batch encoding with unified embedder
  436. vecs: List[List[float]] = []
  437. for i in range(0, len(processed_texts), batch_size):
  438. part = processed_texts[i:i+batch_size]
  439. try:
  440. # Use unified embedder directly (handles caching internally)
  441. part_vecs = embedder.encode(part)
  442. # Normalize to List[List[float]]
  443. if not isinstance(part_vecs, list):
  444. # 单个numpy数组转为列表中的列表
  445. if hasattr(part_vecs, "tolist"):
  446. part_vecs = [part_vecs.tolist()]
  447. else:
  448. part_vecs = [list(part_vecs)]
  449. else:
  450. # 检查是否是嵌套列表
  451. if part_vecs and not isinstance(part_vecs[0], (list, tuple)) and hasattr(part_vecs[0], "__len__"):
  452. # numpy数组列表 -> 转换每个数组
  453. normalized_vecs = []
  454. for v in part_vecs:
  455. if hasattr(v, "tolist"):
  456. normalized_vecs.append(v.tolist())
  457. else:
  458. normalized_vecs.append(list(v))
  459. part_vecs = normalized_vecs
  460. elif part_vecs and not isinstance(part_vecs[0], (list, tuple)):
  461. # 单个向量被误判为列表,实际应该包装成[[...]]
  462. if hasattr(part_vecs, "tolist"):
  463. part_vecs = [part_vecs.tolist()]
  464. else:
  465. part_vecs = [list(part_vecs)]
  466. for v in part_vecs:
  467. try:
  468. # 确保向量是float列表
  469. if hasattr(v, "tolist"):
  470. v = v.tolist()
  471. v_norm = [float(x) for x in v]
  472. if len(v_norm) != dimension:
  473. print(f"[WARNING] 向量维度异常: 期望{dimension}, 实际{len(v_norm)}")
  474. # 用零向量填充或截断
  475. if len(v_norm) < dimension:
  476. v_norm.extend([0.0] * (dimension - len(v_norm)))
  477. else:
  478. v_norm = v_norm[:dimension]
  479. vecs.append(v_norm)
  480. except Exception as e:
  481. print(f"[WARNING] 向量转换失败: {e}, 使用零向量")
  482. vecs.append([0.0] * dimension)
  483. except Exception as e:
  484. print(f"[WARNING] Batch {i} encoding failed: {e}")
  485. print(f"[RAG] Retrying batch {i} with smaller chunks...")
  486. # 尝试重试:将批次分解为更小的块
  487. success = False
  488. for j in range(0, len(part), 8): # 更小的批次
  489. small_part = part[j:j+8]
  490. try:
  491. import time
  492. time.sleep(2) # 等待2秒避免频率限制
  493. small_vecs = embedder.encode(small_part)
  494. # Normalize to List[List[float]]
  495. if isinstance(small_vecs, list) and small_vecs and not isinstance(small_vecs[0], list):
  496. small_vecs = [small_vecs]
  497. for v in small_vecs:
  498. if hasattr(v, "tolist"):
  499. v = v.tolist()
  500. try:
  501. v_norm = [float(x) for x in v]
  502. if len(v_norm) != dimension:
  503. print(f"[WARNING] 向量维度异常: 期望{dimension}, 实际{len(v_norm)}")
  504. if len(v_norm) < dimension:
  505. v_norm.extend([0.0] * (dimension - len(v_norm)))
  506. else:
  507. v_norm = v_norm[:dimension]
  508. vecs.append(v_norm)
  509. success = True
  510. except Exception as e2:
  511. print(f"[WARNING] 小批次向量转换失败: {e2}")
  512. vecs.append([0.0] * dimension)
  513. except Exception as e2:
  514. print(f"[WARNING] 小批次 {j//8} 仍然失败: {e2}")
  515. # 为这个小批次创建零向量
  516. for _ in range(len(small_part)):
  517. vecs.append([0.0] * dimension)
  518. if not success:
  519. print(f"[ERROR] 批次 {i} 完全失败,使用零向量")
  520. print(f"[RAG] Embedding progress: {min(i+batch_size, len(processed_texts))}/{len(processed_texts)}")
  521. # Prepare metadata with RAG tags
  522. metas: List[Dict] = []
  523. ids: List[str] = []
  524. for ch in chunks:
  525. meta = {
  526. "memory_id": ch["id"],
  527. "user_id": "rag_user",
  528. "memory_type": "rag_chunk",
  529. "content": ch["content"], # Keep original markdown content
  530. "data_source": "rag_pipeline", # RAG identification tag
  531. "rag_namespace": rag_namespace,
  532. "is_rag_data": True, # Clear RAG data marker
  533. }
  534. # Merge chunk metadata
  535. meta.update(ch.get("metadata", {}))
  536. metas.append(meta)
  537. ids.append(ch["id"])
  538. print(f"[RAG] Qdrant upsert start: n={len(vecs)}")
  539. success = store.add_vectors(vectors=vecs, metadata=metas, ids=ids)
  540. if success:
  541. print(f"[RAG] Qdrant upsert done: {len(vecs)} vectors indexed")
  542. else:
  543. print(f"[RAG] Qdrant upsert failed")
  544. raise RuntimeError("Failed to index vectors to Qdrant")
  545. def embed_query(query: str) -> List[float]:
  546. """
  547. Embed query using unified embedding (百炼 with fallback).
  548. """
  549. embedder = get_text_embedder()
  550. dimension = get_dimension(384)
  551. try:
  552. vec = embedder.encode(query)
  553. # Normalize to List[float]
  554. if hasattr(vec, "tolist"):
  555. vec = vec.tolist()
  556. # 处理嵌套列表情况
  557. if isinstance(vec, list) and vec and isinstance(vec[0], (list, tuple)):
  558. vec = vec[0] # Extract first vector if nested
  559. # 转换为float列表
  560. result = [float(x) for x in vec]
  561. # 检查维度
  562. if len(result) != dimension:
  563. print(f"[WARNING] Query向量维度异常: 期望{dimension}, 实际{len(result)}")
  564. # 用零向量填充或截断
  565. if len(result) < dimension:
  566. result.extend([0.0] * (dimension - len(result)))
  567. else:
  568. result = result[:dimension]
  569. return result
  570. except Exception as e:
  571. print(f"[WARNING] Query embedding failed: {e}")
  572. # Return zero vector as fallback
  573. return [0.0] * dimension
  574. def search_vectors(
  575. store = None,
  576. query: str = "",
  577. top_k: int = 8,
  578. rag_namespace: Optional[str] = None,
  579. only_rag_data: bool = True,
  580. score_threshold: Optional[float] = None
  581. ) -> List[Dict]:
  582. """
  583. Search RAG vectors using unified embedding and Qdrant.
  584. """
  585. if not query:
  586. return []
  587. # Create default store if not provided
  588. if store is None:
  589. store = _create_default_vector_store()
  590. # Embed query with unified embedder
  591. qv = embed_query(query)
  592. # Build filter for RAG data
  593. where = {"memory_type": "rag_chunk"}
  594. if only_rag_data:
  595. where["is_rag_data"] = True
  596. where["data_source"] = "rag_pipeline"
  597. if rag_namespace:
  598. where["rag_namespace"] = rag_namespace
  599. try:
  600. return store.search_similar(
  601. query_vector=qv,
  602. limit=top_k,
  603. score_threshold=score_threshold,
  604. where=where
  605. )
  606. except Exception as e:
  607. print(f"[WARNING] RAG search failed: {e}")
  608. return []
  609. def _prompt_mqe(query: str, n: int) -> List[str]:
  610. try:
  611. from core.llm import HelloAgentsLLM
  612. llm = HelloAgentsLLM()
  613. prompt = [
  614. {"role": "system", "content": "你是检索查询扩展助手。生成语义等价或互补的多样化查询。使用中文,简短,避免标点。"},
  615. {"role": "user", "content": f"原始查询:{query}\n请给出{n}个不同表述的查询,每行一个。"}
  616. ]
  617. text = llm.invoke(prompt)
  618. lines = [ln.strip("- \t") for ln in (text or "").splitlines()]
  619. outs = [ln for ln in lines if ln]
  620. return outs[:n] or [query]
  621. except Exception:
  622. return [query]
  623. def _prompt_hyde(query: str) -> Optional[str]:
  624. try:
  625. from core.llm import HelloAgentsLLM
  626. llm = HelloAgentsLLM()
  627. prompt = [
  628. {"role": "system", "content": "根据用户问题,先写一段可能的答案性段落,用于向量检索的查询文档(不要分析过程)。"},
  629. {"role": "user", "content": f"问题:{query}\n请直接写一段中等长度、客观、包含关键术语的段落。"}
  630. ]
  631. return llm.invoke(prompt)
  632. except Exception:
  633. return None
  634. def search_vectors_expanded(
  635. store = None,
  636. query: str = "",
  637. top_k: int = 8,
  638. rag_namespace: Optional[str] = None,
  639. only_rag_data: bool = True,
  640. score_threshold: Optional[float] = None,
  641. enable_mqe: bool = False,
  642. mqe_expansions: int = 2,
  643. enable_hyde: bool = False,
  644. candidate_pool_multiplier: int = 4,
  645. ) -> List[Dict]:
  646. """
  647. Search with query expansion using unified embedding and Qdrant.
  648. """
  649. if not query:
  650. return []
  651. # Create default store if not provided
  652. if store is None:
  653. store = _create_default_vector_store()
  654. # expansions
  655. expansions: List[str] = [query]
  656. if enable_mqe and mqe_expansions > 0:
  657. expansions.extend(_prompt_mqe(query, mqe_expansions))
  658. if enable_hyde:
  659. hyde_text = _prompt_hyde(query)
  660. if hyde_text:
  661. expansions.append(hyde_text)
  662. # unique and trim
  663. uniq: List[str] = []
  664. for e in expansions:
  665. if e and e not in uniq:
  666. uniq.append(e)
  667. expansions = uniq[: max(1, len(uniq))]
  668. # distribute pool per expansion
  669. pool = max(top_k * candidate_pool_multiplier, 20)
  670. per = max(1, pool // max(1, len(expansions)))
  671. # Build filter for RAG data
  672. where = {"memory_type": "rag_chunk"}
  673. if only_rag_data:
  674. where["is_rag_data"] = True
  675. where["data_source"] = "rag_pipeline"
  676. if rag_namespace:
  677. where["rag_namespace"] = rag_namespace
  678. # collect hits across expansions
  679. agg: Dict[str, Dict] = {}
  680. for q in expansions:
  681. qv = embed_query(q)
  682. hits = store.search_similar(query_vector=qv, limit=per, score_threshold=score_threshold, where=where)
  683. for h in hits:
  684. mid = h.get("metadata", {}).get("memory_id", h.get("id"))
  685. s = float(h.get("score", 0.0))
  686. if mid not in agg or s > float(agg[mid].get("score", 0.0)):
  687. agg[mid] = h
  688. # return top by score
  689. merged = list(agg.values())
  690. merged.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)
  691. return merged[:top_k]
  692. def _try_load_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
  693. try:
  694. from sentence_transformers import CrossEncoder
  695. return CrossEncoder(model_name)
  696. except Exception:
  697. return None
  698. def rerank_with_cross_encoder(query: str, items: List[Dict], model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2", top_k: int = 10) -> List[Dict]:
  699. ce = _try_load_cross_encoder(model_name)
  700. if ce is None or not items:
  701. return items[:top_k]
  702. pairs = [[query, it.get("content", "")] for it in items]
  703. try:
  704. scores = ce.predict(pairs)
  705. for it, s in zip(items, scores):
  706. it["rerank_score"] = float(s)
  707. items.sort(key=lambda x: x.get("rerank_score", x.get("score", 0.0)), reverse=True)
  708. return items[:top_k]
  709. except Exception:
  710. return items[:top_k]
  711. def compute_graph_signals_from_pool(vector_hits: List[Dict], same_doc_weight: float = 1.0, proximity_weight: float = 1.0, proximity_window_chars: int = 1600) -> Dict[str, float]:
  712. """
  713. Compute graph signals with direct parameters instead of environment variables.
  714. """
  715. # group by doc
  716. by_doc: Dict[str, List[Dict]] = {}
  717. for h in vector_hits:
  718. meta = h.get("metadata", {})
  719. did = meta.get("doc_id")
  720. if not did:
  721. # fall back to memory_id grouping if doc missing
  722. did = meta.get("memory_id") or h.get("id")
  723. by_doc.setdefault(did, []).append(h)
  724. # same-doc density score
  725. doc_counts = {d: len(arr) for d, arr in by_doc.items()}
  726. max_count = max(doc_counts.values()) if doc_counts else 1
  727. # proximity score per hit within same doc
  728. graph_signal: Dict[str, float] = {}
  729. for did, arr in by_doc.items():
  730. arr.sort(key=lambda x: x.get("metadata", {}).get("start", 0))
  731. # precompute density
  732. density = doc_counts.get(did, 1) / max_count
  733. # proximity accumulation
  734. for i, h in enumerate(arr):
  735. mid = h.get("metadata", {}).get("memory_id", h.get("id"))
  736. pos_i = h.get("metadata", {}).get("start", 0)
  737. prox_acc = 0.0
  738. # look around neighbors within window
  739. # two-pointer expansion
  740. # left
  741. j = i - 1
  742. while j >= 0:
  743. pos_j = arr[j].get("metadata", {}).get("start", 0)
  744. dist = abs(pos_i - pos_j)
  745. if dist > proximity_window_chars:
  746. break
  747. prox_acc += max(0.0, 1.0 - (dist / max(1.0, float(proximity_window_chars))))
  748. j -= 1
  749. # right
  750. j = i + 1
  751. while j < len(arr):
  752. pos_j = arr[j].get("metadata", {}).get("start", 0)
  753. dist = abs(pos_i - pos_j)
  754. if dist > proximity_window_chars:
  755. break
  756. prox_acc += max(0.0, 1.0 - (dist / max(1.0, float(proximity_window_chars))))
  757. j += 1
  758. # combine
  759. score = same_doc_weight * density + proximity_weight * prox_acc
  760. graph_signal[mid] = graph_signal.get(mid, 0.0) + score
  761. # normalize to [0,1]
  762. if graph_signal:
  763. max_v = max(graph_signal.values())
  764. if max_v > 0:
  765. for k in list(graph_signal.keys()):
  766. graph_signal[k] = graph_signal[k] / max_v
  767. return graph_signal
  768. def rank(vector_hits: List[Dict], graph_signals: Optional[Dict[str, float]] = None, w_vector: float = 0.7, w_graph: float = 0.3) -> List[Dict]:
  769. """
  770. Rank results with direct weight parameters instead of environment variables.
  771. """
  772. items: List[Dict] = []
  773. graph_signals = graph_signals or {}
  774. for h in vector_hits:
  775. mid = h.get("metadata", {}).get("memory_id", h.get("id"))
  776. g = float(graph_signals.get(mid, 0.0))
  777. v = float(h.get("score", 0.0))
  778. score = w_vector * v + w_graph * g
  779. items.append({
  780. "memory_id": mid,
  781. "score": score,
  782. "vector_score": v,
  783. "graph_score": g,
  784. "content": h.get("metadata", {}).get("content", ""),
  785. "metadata": h.get("metadata", {}),
  786. })
  787. items.sort(key=lambda x: x["score"], reverse=True)
  788. return items
  789. def merge_snippets(ranked_items: List[Dict], max_chars: int = 1200) -> str:
  790. out: List[str] = []
  791. total = 0
  792. for it in ranked_items:
  793. text = it.get("content", "").strip()
  794. if not text:
  795. continue
  796. if total + len(text) > max_chars:
  797. remain = max_chars - total
  798. if remain <= 0:
  799. break
  800. out.append(text[:remain])
  801. total += remain
  802. break
  803. out.append(text)
  804. total += len(text)
  805. return "\n\n".join(out)
  806. def expand_neighbors_from_pool(selected: List[Dict], pool: List[Dict], neighbors: int = 1, max_additions: int = 5) -> List[Dict]:
  807. if not selected or not pool or neighbors <= 0:
  808. return selected
  809. # index pool by doc_id and sort by start
  810. by_doc: Dict[str, List[Dict]] = {}
  811. for it in pool:
  812. meta = it.get("metadata", {})
  813. did = meta.get("doc_id")
  814. if not did:
  815. continue
  816. by_doc.setdefault(did, []).append(it)
  817. for did, arr in by_doc.items():
  818. arr.sort(key=lambda x: (x.get("metadata", {}).get("start", 0)))
  819. selected_ids = set(it.get("memory_id") for it in selected)
  820. additions: List[Dict] = []
  821. for it in selected:
  822. meta = it.get("metadata", {})
  823. did = meta.get("doc_id")
  824. if not did or did not in by_doc:
  825. continue
  826. arr = by_doc[did]
  827. # find index
  828. try:
  829. idx = next(i for i, x in enumerate(arr) if x.get("memory_id") == it.get("memory_id"))
  830. except StopIteration:
  831. continue
  832. for offset in range(1, neighbors + 1):
  833. for j in (idx - offset, idx + offset):
  834. if 0 <= j < len(arr):
  835. cand = arr[j]
  836. mid = cand.get("memory_id")
  837. if mid not in selected_ids:
  838. additions.append(cand)
  839. selected_ids.add(mid)
  840. if len(additions) >= max_additions:
  841. break
  842. if len(additions) >= max_additions:
  843. break
  844. if len(additions) >= max_additions:
  845. break
  846. # keep relative order by score
  847. extended = list(selected) + additions
  848. extended.sort(key=lambda x: (x.get("rerank_score", x.get("score", 0.0))), reverse=True)
  849. return extended
  850. def merge_snippets_grouped(ranked_items: List[Dict], max_chars: int = 1200, include_citations: bool = True) -> str:
  851. # Group by doc_id and aggregate doc score
  852. by_doc: Dict[str, List[Dict]] = {}
  853. doc_score: Dict[str, float] = {}
  854. for it in ranked_items:
  855. meta = it.get("metadata", {})
  856. did = meta.get("doc_id") or meta.get("source_path") or "unknown"
  857. by_doc.setdefault(did, []).append(it)
  858. doc_score[did] = doc_score.get(did, 0.0) + float(it.get("score", 0.0))
  859. # Sort docs by aggregate score
  860. ordered_docs = sorted(by_doc.keys(), key=lambda d: doc_score.get(d, 0.0), reverse=True)
  861. # Within doc, order by start offset to preserve context
  862. for d in ordered_docs:
  863. by_doc[d].sort(key=lambda x: (x.get("metadata", {}).get("start", 0)))
  864. out: List[str] = []
  865. citations: List[Dict] = []
  866. total = 0
  867. cite_index = 1
  868. for did in ordered_docs:
  869. parts = by_doc[did]
  870. for it in parts:
  871. text = (it.get("content", "") or "").strip()
  872. if not text:
  873. continue
  874. # add citation marker if enabled
  875. suffix = ""
  876. if include_citations:
  877. suffix = f" [{cite_index}]"
  878. need = len(text) + (len(suffix) if suffix else 0)
  879. if total + need > max_chars:
  880. remain = max_chars - total
  881. if remain <= 0:
  882. break
  883. clipped = text[: max(0, remain - len(suffix))]
  884. if clipped:
  885. out.append(clipped + suffix)
  886. total += len(clipped) + len(suffix)
  887. if include_citations:
  888. m = it.get("metadata", {})
  889. citations.append({
  890. "index": cite_index,
  891. "source_path": m.get("source_path"),
  892. "doc_id": m.get("doc_id"),
  893. "start": m.get("start"),
  894. "end": m.get("end"),
  895. "heading_path": m.get("heading_path"),
  896. })
  897. cite_index += 1
  898. break
  899. out.append(text + suffix)
  900. total += need
  901. if include_citations:
  902. m = it.get("metadata", {})
  903. citations.append({
  904. "index": cite_index,
  905. "source_path": m.get("source_path"),
  906. "doc_id": m.get("doc_id"),
  907. "start": m.get("start"),
  908. "end": m.get("end"),
  909. "heading_path": m.get("heading_path"),
  910. })
  911. cite_index += 1
  912. if total >= max_chars:
  913. break
  914. merged = "\n\n".join(out)
  915. if include_citations and citations:
  916. lines: List[str] = [merged, "", "References:"]
  917. for c in citations:
  918. loc = ""
  919. if c.get("start") is not None and c.get("end") is not None:
  920. loc = f" ({c['start']}-{c['end']})"
  921. hp = f" – {c['heading_path']}" if c.get("heading_path") else ""
  922. sp = c.get("source_path") or c.get("doc_id") or "source"
  923. lines.append(f"[{c['index']}] {sp}{loc}{hp}")
  924. return "\n".join(lines)
  925. return merged
  926. def compress_ranked_items(ranked_items: List[Dict], enable_compression: bool = True, max_per_doc: int = 2, join_gap: int = 200) -> List[Dict]:
  927. """
  928. Compress ranked items with direct parameters instead of environment variables.
  929. """
  930. if not enable_compression:
  931. return ranked_items
  932. by_doc_count: Dict[str, int] = {}
  933. last_by_doc: Dict[str, Dict] = {}
  934. new_items: List[Dict] = []
  935. for it in ranked_items:
  936. meta = it.get("metadata", {})
  937. did = meta.get("doc_id") or meta.get("source_path") or "unknown"
  938. start = int(meta.get("start") or 0)
  939. end = int(meta.get("end") or (start + len(it.get("content", "") or "")))
  940. if did not in last_by_doc:
  941. last_by_doc[did] = it
  942. by_doc_count[did] = 1
  943. new_items.append(it)
  944. continue
  945. last = last_by_doc[did]
  946. lmeta = last.get("metadata", {})
  947. lstart = int(lmeta.get("start") or 0)
  948. lend = int(lmeta.get("end") or (lstart + len(last.get("content", "") or "")))
  949. if start - lend <= join_gap and start >= lstart:
  950. # merge into last
  951. merged_text = (last.get("content", "") or "").strip()
  952. add_text = (it.get("content", "") or "").strip()
  953. if add_text:
  954. if merged_text:
  955. merged_text = merged_text + "\n\n" + add_text
  956. else:
  957. merged_text = add_text
  958. last["content"] = merged_text
  959. lmeta["end"] = max(lend, end)
  960. # keep the higher score
  961. try:
  962. last["score"] = max(float(last.get("score", 0.0)), float(it.get("score", 0.0)))
  963. except Exception:
  964. pass
  965. last_by_doc[did] = last
  966. else:
  967. cnt = by_doc_count.get(did, 0)
  968. if cnt >= max_per_doc:
  969. continue
  970. new_items.append(it)
  971. last_by_doc[did] = it
  972. by_doc_count[did] = cnt + 1
  973. return new_items
  974. def tldr_summarize(text: str, bullets: int = 3) -> Optional[str]:
  975. try:
  976. if not text or len(text.strip()) == 0:
  977. return None
  978. from core.llm import HelloAgentsLLM
  979. llm = HelloAgentsLLM()
  980. prompt = [
  981. {"role": "system", "content": "请将以下内容概括为简洁的要点列表(最多3-5条),用中文,避免重复,突出关键信息。"},
  982. {"role": "user", "content": f"请用 {max(1, min(5, int(bullets)))} 条要点总结:\n\n{text}"},
  983. ]
  984. out = llm.invoke(prompt)
  985. return out
  986. except Exception:
  987. return None
  988. # ==================
  989. # High-level RAG Pipeline API
  990. # ==================
  991. def create_rag_pipeline(
  992. qdrant_url: Optional[str] = None,
  993. qdrant_api_key: Optional[str] = None,
  994. collection_name: str = "hello_agents_rag_vectors",
  995. rag_namespace: str = "default"
  996. ) -> Dict[str, Any]:
  997. """
  998. Create a complete RAG pipeline with Qdrant and unified embedding.
  999. Returns:
  1000. Dict containing store, namespace, and helper functions
  1001. """
  1002. dimension = get_dimension(384)
  1003. store = QdrantVectorStore(
  1004. url=qdrant_url,
  1005. api_key=qdrant_api_key,
  1006. collection_name=collection_name,
  1007. vector_size=dimension,
  1008. distance="cosine"
  1009. )
  1010. def add_documents(file_paths: List[str], chunk_size: int = 800, chunk_overlap: int = 100):
  1011. """Add documents to RAG pipeline"""
  1012. chunks = load_and_chunk_texts(
  1013. paths=file_paths,
  1014. chunk_size=chunk_size,
  1015. chunk_overlap=chunk_overlap,
  1016. namespace=rag_namespace,
  1017. source_label="rag"
  1018. )
  1019. index_chunks(
  1020. store=store,
  1021. chunks=chunks,
  1022. rag_namespace=rag_namespace
  1023. )
  1024. return len(chunks)
  1025. def search(query: str, top_k: int = 8, score_threshold: Optional[float] = None):
  1026. """Search RAG knowledge base"""
  1027. return search_vectors(
  1028. store=store,
  1029. query=query,
  1030. top_k=top_k,
  1031. rag_namespace=rag_namespace,
  1032. score_threshold=score_threshold
  1033. )
  1034. def search_advanced(
  1035. query: str,
  1036. top_k: int = 8,
  1037. enable_mqe: bool = False,
  1038. enable_hyde: bool = False,
  1039. score_threshold: Optional[float] = None
  1040. ):
  1041. """Advanced search with query expansion"""
  1042. return search_vectors_expanded(
  1043. store=store,
  1044. query=query,
  1045. top_k=top_k,
  1046. rag_namespace=rag_namespace,
  1047. enable_mqe=enable_mqe,
  1048. enable_hyde=enable_hyde,
  1049. score_threshold=score_threshold
  1050. )
  1051. def get_stats():
  1052. """Get pipeline statistics"""
  1053. return store.get_collection_stats()
  1054. return {
  1055. "store": store,
  1056. "namespace": rag_namespace,
  1057. "add_documents": add_documents,
  1058. "search": search,
  1059. "search_advanced": search_advanced,
  1060. "get_stats": get_stats
  1061. }