# document.py
"""Document processing module."""
import copy
import hashlib
from dataclasses import dataclass
from datetime import datetime
from typing import List, Dict, Any, Optional
  6. @dataclass
  7. class Document:
  8. """文档类"""
  9. content: str
  10. metadata: Dict[str, Any]
  11. doc_id: Optional[str] = None
  12. def __post_init__(self):
  13. if self.doc_id is None:
  14. # 基于内容生成ID
  15. self.doc_id = hashlib.md5(self.content.encode()).hexdigest()
  16. @dataclass
  17. class DocumentChunk:
  18. """文档块类"""
  19. content: str
  20. metadata: Dict[str, Any]
  21. chunk_id: Optional[str] = None
  22. doc_id: Optional[str] = None
  23. chunk_index: int = 0
  24. def __post_init__(self):
  25. if self.chunk_id is None:
  26. # 基于文档ID和块索引生成ID
  27. chunk_content = f"{self.doc_id}_{self.chunk_index}_{self.content[:50]}"
  28. self.chunk_id = hashlib.md5(chunk_content.encode()).hexdigest()
  29. class DocumentProcessor:
  30. """文档处理器"""
  31. def __init__(
  32. self,
  33. chunk_size: int = 1000,
  34. chunk_overlap: int = 200,
  35. separators: Optional[List[str]] = None
  36. ):
  37. self.chunk_size = chunk_size
  38. self.chunk_overlap = chunk_overlap
  39. self.separators = separators or ["\n\n", "\n", "。", ".", " "]
  40. def process_document(self, document: Document) -> List[DocumentChunk]:
  41. """
  42. 处理文档,分割成块
  43. Args:
  44. document: 输入文档
  45. Returns:
  46. 文档块列表
  47. """
  48. chunks = self._split_text(document.content)
  49. document_chunks = []
  50. for i, chunk_content in enumerate(chunks):
  51. # 创建块的元数据
  52. chunk_metadata = document.metadata.copy()
  53. chunk_metadata.update({
  54. "doc_id": document.doc_id,
  55. "chunk_index": i,
  56. "total_chunks": len(chunks),
  57. "processed_at": datetime.now().isoformat()
  58. })
  59. chunk = DocumentChunk(
  60. content=chunk_content,
  61. metadata=chunk_metadata,
  62. doc_id=document.doc_id,
  63. chunk_index=i
  64. )
  65. document_chunks.append(chunk)
  66. return document_chunks
  67. def process_documents(self, documents: List[Document]) -> List[DocumentChunk]:
  68. """
  69. 批量处理文档
  70. Args:
  71. documents: 文档列表
  72. Returns:
  73. 所有文档块列表
  74. """
  75. all_chunks = []
  76. for document in documents:
  77. chunks = self.process_document(document)
  78. all_chunks.extend(chunks)
  79. return all_chunks
  80. def _split_text(self, text: str) -> List[str]:
  81. """
  82. 分割文本为块
  83. Args:
  84. text: 输入文本
  85. Returns:
  86. 文本块列表
  87. """
  88. if len(text) <= self.chunk_size:
  89. return [text]
  90. chunks = []
  91. start = 0
  92. while start < len(text):
  93. # 确定块的结束位置
  94. end = start + self.chunk_size
  95. if end >= len(text):
  96. # 最后一块
  97. chunks.append(text[start:])
  98. break
  99. # 寻找合适的分割点
  100. split_point = self._find_split_point(text, start, end)
  101. if split_point == -1:
  102. # 没找到合适的分割点,强制分割
  103. split_point = end
  104. chunks.append(text[start:split_point])
  105. # 计算下一块的开始位置(考虑重叠)
  106. start = max(start + 1, split_point - self.chunk_overlap)
  107. return chunks
  108. def _find_split_point(self, text: str, start: int, end: int) -> int:
  109. """
  110. 在指定范围内寻找最佳分割点
  111. Args:
  112. text: 文本
  113. start: 开始位置
  114. end: 结束位置
  115. Returns:
  116. 分割点位置,-1表示未找到
  117. """
  118. # 从后往前寻找分隔符
  119. for separator in self.separators:
  120. # 在end附近寻找分隔符
  121. search_start = max(start, end - 100) # 在最后100个字符中寻找
  122. for i in range(end - len(separator), search_start - 1, -1):
  123. if text[i:i + len(separator)] == separator:
  124. return i + len(separator)
  125. return -1
  126. def merge_chunks(self, chunks: List[DocumentChunk], max_length: int = 2000) -> List[DocumentChunk]:
  127. """
  128. 合并小的文档块
  129. Args:
  130. chunks: 文档块列表
  131. max_length: 合并后的最大长度
  132. Returns:
  133. 合并后的文档块列表
  134. """
  135. if not chunks:
  136. return []
  137. merged_chunks = []
  138. current_chunk = chunks[0]
  139. for next_chunk in chunks[1:]:
  140. # 检查是否可以合并
  141. combined_length = len(current_chunk.content) + len(next_chunk.content)
  142. if (combined_length <= max_length and
  143. current_chunk.doc_id == next_chunk.doc_id):
  144. # 合并块
  145. current_chunk.content += "\n" + next_chunk.content
  146. current_chunk.metadata["total_chunks"] = current_chunk.metadata.get("total_chunks", 1) + 1
  147. else:
  148. # 不能合并,保存当前块
  149. merged_chunks.append(current_chunk)
  150. current_chunk = next_chunk
  151. # 添加最后一个块
  152. merged_chunks.append(current_chunk)
  153. return merged_chunks
  154. def filter_chunks(self, chunks: List[DocumentChunk], min_length: int = 50) -> List[DocumentChunk]:
  155. """
  156. 过滤太短的文档块
  157. Args:
  158. chunks: 文档块列表
  159. min_length: 最小长度
  160. Returns:
  161. 过滤后的文档块列表
  162. """
  163. return [chunk for chunk in chunks if len(chunk.content.strip()) >= min_length]
  164. def add_chunk_metadata(self, chunks: List[DocumentChunk], metadata: Dict[str, Any]) -> List[DocumentChunk]:
  165. """
  166. 为文档块添加额外元数据
  167. Args:
  168. chunks: 文档块列表
  169. metadata: 要添加的元数据
  170. Returns:
  171. 更新后的文档块列表
  172. """
  173. for chunk in chunks:
  174. chunk.metadata.update(metadata)
  175. return chunks
  176. def load_text_file(file_path: str, encoding: str = "utf-8") -> Document:
  177. """
  178. 加载文本文件为文档
  179. Args:
  180. file_path: 文件路径
  181. encoding: 文件编码
  182. Returns:
  183. 文档对象
  184. """
  185. with open(file_path, 'r', encoding=encoding) as f:
  186. content = f.read()
  187. metadata = {
  188. "source": file_path,
  189. "type": "text_file",
  190. "loaded_at": datetime.now().isoformat()
  191. }
  192. return Document(content=content, metadata=metadata)
  193. def create_document(content: str, **metadata) -> Document:
  194. """
  195. 创建文档的便捷函数
  196. Args:
  197. content: 文档内容
  198. **metadata: 元数据
  199. Returns:
  200. 文档对象
  201. """
  202. return Document(content=content, metadata=metadata)