rag.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. """Step 7: RAG 投资知识库 — 文档分块 + TF-IDF 检索"""
  2. import os
  3. import re
  4. import json
  5. import math
  6. from collections import Counter
  7. from typing import List, Dict, Tuple
  8. class InvestmentKnowledgeBase:
  9. """轻量投资知识库 — 无需外部嵌入 API,TF-IDF 检索"""
  10. def __init__(self, path: str = "memory/knowledge_base.json"):
  11. self.path = path
  12. self.chunks: List[Dict] = []
  13. self._load()
  14. def _load(self):
  15. if os.path.exists(self.path):
  16. try:
  17. with open(self.path, "r", encoding="utf-8") as f:
  18. self.chunks = json.load(f)
  19. except (json.JSONDecodeError, IOError):
  20. self.chunks = []
  21. def _save(self):
  22. os.makedirs(os.path.dirname(self.path), exist_ok=True)
  23. with open(self.path, "w", encoding="utf-8") as f:
  24. json.dump(self.chunks, f, ensure_ascii=False, indent=2)
  25. # ===== 文档导入 =====
  26. def add_text(self, text: str, title: str = "", source: str = "") -> str:
  27. """导入文本,自动分块"""
  28. chunks = self._chunk_text(text, title, source)
  29. self.chunks.extend(chunks)
  30. self._save()
  31. return f"已导入 '{title}',共 {len(chunks)} 个知识块 (总计 {len(self.chunks)} 块)"
  32. def add_file(self, filepath: str) -> str:
  33. """导入文件 (支持 .txt .md)"""
  34. if not os.path.exists(filepath):
  35. return f"文件不存在: {filepath}"
  36. try:
  37. try:
  38. with open(filepath, "r", encoding="utf-8") as f:
  39. text = f.read()
  40. except UnicodeDecodeError:
  41. with open(filepath, "r", encoding="gbk") as f:
  42. text = f.read()
  43. except Exception as e:
  44. return f"读取文件失败: {e}"
  45. title = os.path.basename(filepath)
  46. return self.add_text(text, title, filepath)
  47. def _chunk_text(self, text: str, title: str, source: str,
  48. chunk_size: int = 300, overlap: int = 50) -> List[Dict]:
  49. """按段落+句子边界智能分块"""
  50. # 先按段落分割
  51. paragraphs = re.split(r"\n\s*\n", text)
  52. chunks = []
  53. current = ""
  54. for para in paragraphs:
  55. para = para.strip()
  56. if not para:
  57. continue
  58. if len(current) + len(para) < chunk_size:
  59. current += ("\n" if current else "") + para
  60. else:
  61. if current:
  62. chunks.append(current)
  63. # 如果段落太长,按句子分
  64. if len(para) > chunk_size:
  65. sentences = re.split(r"(?<=[。!?\.!?])\s*", para)
  66. sub = ""
  67. for s in sentences:
  68. if len(sub) + len(s) < chunk_size:
  69. sub += s
  70. else:
  71. if sub:
  72. chunks.append(sub)
  73. sub = s
  74. if sub:
  75. current = sub
  76. else:
  77. current = ""
  78. else:
  79. current = para
  80. if current:
  81. chunks.append(current)
  82. return [{
  83. "id": f"{title}_{i}",
  84. "title": title,
  85. "source": source,
  86. "content": c,
  87. } for i, c in enumerate(chunks)]
  88. # ===== 检索 =====
  89. def search(self, query: str, top_k: int = 5) -> str:
  90. """TF-IDF 检索最相关的知识块"""
  91. if not self.chunks:
  92. return "知识库为空。可以用 '导入知识 文件路径' 来添加文档。"
  93. # 构建词汇表
  94. all_docs = [c["content"] for c in self.chunks]
  95. tokenized_docs = [self._tokenize(d) for d in all_docs]
  96. tokenized_query = self._tokenize(query)
  97. # TF-IDF 计算
  98. df = Counter()
  99. for tokens in tokenized_docs:
  100. df.update(set(tokens))
  101. N = len(tokenized_docs)
  102. scores = []
  103. for i, doc_tokens in enumerate(tokenized_docs):
  104. tf = Counter(doc_tokens)
  105. score = 0
  106. for term in set(tokenized_query):
  107. if term in tf:
  108. tf_val = tf[term] / max(len(doc_tokens), 1)
  109. idf_val = math.log((N + 1) / (df[term] + 1)) + 1
  110. score += tf_val * idf_val
  111. if score > 0:
  112. scores.append((score, i))
  113. scores.sort(key=lambda x: x[0], reverse=True)
  114. if not scores:
  115. # 回退到关键词匹配
  116. for i, (doc, doc_tokens) in enumerate(zip(all_docs, tokenized_docs)):
  117. if any(kw in doc_tokens for kw in tokenized_query):
  118. scores.append((0.5, i))
  119. scores.sort(key=lambda x: x[0], reverse=True)
  120. if not scores:
  121. return f"未找到与 '{query}' 相关的知识"
  122. lines = [f"知识库检索结果 (查询: '{query}'):"]
  123. for score, idx in scores[:top_k]:
  124. chunk = self.chunks[idx]
  125. lines.append(f"\n--- [{score:.2f}] {chunk['title']} ---")
  126. lines.append(chunk["content"][:400])
  127. return "\n".join(lines)
  128. def _tokenize(self, text: str) -> List[str]:
  129. """简单中文分词 (2-gram)"""
  130. # 提取中文字符和英文单词
  131. words = re.findall(r"[一-鿿]{1,2}|[a-zA-Z]+", text.lower())
  132. return [w for w in words if len(w) >= 2]
  133. def stats(self) -> str:
  134. titles = set(c["title"] for c in self.chunks)
  135. return (f"知识库: {len(self.chunks)} 个知识块, "
  136. f"{len(titles)} 篇文档")
  137. def clear(self) -> str:
  138. self.chunks = []
  139. self._save()
  140. return "知识库已清空"
  141. # ===== 预置投资知识 =====
# Built-in primer in Markdown form: valuation metrics, technical indicators,
# risk-control rules and A-share trading rules. get_kb() imports this text
# when the knowledge base it creates has no chunks.
INVESTMENT_KNOWLEDGE = """
# 股票估值方法
## 市盈率 (PE)
PE = 股价 / 每股收益。反映市场愿意为每元利润支付的价格。
- PE < 10: 可能低估(需排除盈利质量差的情况)
- PE 10-20: 合理区间
- PE 20-30: 中等偏高,通常对应成长股
- PE > 30: 高估值,需有高增长支撑
行业差异大:银行 PE 通常 5-10 倍,科技股 PE 可达 30-50 倍。
## 市净率 (PB)
PB = 股价 / 每股净资产。适用于重资产行业(银行、地产、制造业)。
- PB < 1: 破净,可能严重低估
- PB 1-2: 合理偏低
- PB 2-5: 正常水平
- PB > 5: 偏高,需有高 ROE 支撑
## PEG 指标
PEG = PE / 净利润增长率(%)。用于成长股估值。
- PEG < 0.5: 显著低估
- PEG 0.5-1.0: 合理偏低
- PEG 1.0-1.5: 合理
- PEG > 2.0: 高估
## 股息率
股息率 = 每股分红 / 股价。衡量现金回报。
- 股息率 > 4%: 高股息,防御性强
- 股息率 2-4%: 正常水平
- 股息率 < 2%: 偏低
# 技术指标解读
## MACD 金叉死叉
- 金叉: DIF 上穿 DEA,买入信号。零轴上方金叉更强。
- 死叉: DIF 下穿 DEA,卖出信号。零轴下方死叉更弱。
- 顶背离: 股价新高 MACD 未新高,见顶信号。
- 底背离: 股价新低 MACD 未新低,见底信号。
## RSI 相对强弱指标
- RSI > 80: 严重超买,回调风险大
- RSI 70-80: 超买区域,短期可能回调
- RSI 30-70: 正常区间
- RSI 20-30: 超卖区域,短期可能反弹
- RSI < 20: 严重超卖,反弹概率高
## 均线系统
- 多头排列: MA5 > MA10 > MA20 > MA60,上升趋势
- 空头排列: MA5 < MA10 < MA20 < MA60,下降趋势
- 金叉: 短期均线上穿长期均线
- 死叉: 短期均线下穿长期均线
## 布林带
- 价格触及上轨: 短期超买,可能回调
- 价格触及下轨: 短期超卖,可能反弹
- 带宽收窄: 变盘信号,可能突破
- 带宽扩大: 趋势加速
# 风险控制原则
## 仓位管理
- 单只股票不超过总仓位 20%
- 单一行业不超过总仓位 30%
- 永远保留 10-20% 现金应对极端情况
- 分批建仓: 至少分 3 次买入,降低成本集中风险
## 止损原则
- 技术止损: 跌破关键支撑位(MA60/前低)止损
- 比例止损: 亏损超过 8-10% 无条件止损
- 时间止损: 买入后 20 个交易日未达预期,重新评估
- 基本面止损: 公司基本面出现重大恶化,立即止损
## 风险收益比
- 每笔交易的风险收益比应 >= 1:2
- 预期收益应至少是潜在亏损的 2 倍
# A股交易规则
## 交易时间
- 早盘集合竞价: 9:15-9:25
- 连续竞价: 9:30-11:30, 13:00-15:00
- 深交所尾盘集合竞价: 14:57-15:00
## 涨跌幅限制
- 主板: ±10%
- 创业板(300开头)/科创板(688开头): ±20%
- ST 股票: ±5%
- 新股上市前 5 日无涨跌幅限制
## T+1 制度
A股实行 T+1 交易,当日买入次日才能卖出。
"""
  217. # 全局单例
  218. _kb_instance = None
  219. def get_kb() -> InvestmentKnowledgeBase:
  220. global _kb_instance
  221. if _kb_instance is None:
  222. _kb_instance = InvestmentKnowledgeBase()
  223. # 首次初始化时导入预置知识
  224. if not _kb_instance.chunks:
  225. _kb_instance.add_text(INVESTMENT_KNOWLEDGE, "投资基础知识", "built-in")
  226. return _kb_instance
  227. # ===== 工具函数 =====
  228. def rag_search(query: str) -> str:
  229. """搜索投资知识库。输入: 查询关键词或问题"""
  230. return get_kb().search(query.strip())
  231. def rag_import(query: str) -> str:
  232. """导入文档到知识库。输入: 文件路径"""
  233. return get_kb().add_file(query.strip())
  234. def rag_stats(query: str = "") -> str:
  235. """查看知识库统计"""
  236. return get_kb().stats()