reflection_agent.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. """Step 4: Reflection 反思模式 — 分析报告自审与优化
  2. 来自 hello-agents 教程第4章 Reflection 范式:
  3. 初始分析 -> 反思评审 -> 优化改进 -> 循环直到无需改进
  4. """
  5. from typing import List, Dict, Any
  6. from llm_client import HelloAgentsLLM
  7. from tools import (
  8. get_realtime_quote, get_historical_data, get_financial_data,
  9. calc_indicators, get_news
  10. )
  11. class Memory:
  12. """短期记忆:存储分析轨迹(初始报告 + 反思 + 改进报告)"""
  13. def __init__(self):
  14. self.records: List[Dict[str, Any]] = []
  15. def add_record(self, record_type: str, content: str):
  16. self.records.append({"type": record_type, "content": content})
  17. print(f" [记忆] 新增 '{record_type}' 记录")
  18. def get_trajectory(self) -> str:
  19. parts = []
  20. for r in self.records:
  21. label = "分析报告" if r['type'] == 'execution' else "评审意见"
  22. parts.append(f"--- {label} ---\n{r['content']}")
  23. return "\n\n".join(parts)
  24. def get_last_execution(self) -> str:
  25. for r in reversed(self.records):
  26. if r['type'] == 'execution':
  27. return r['content']
  28. return None
  29. # ===== 提示词模板 =====
  30. INITIAL_ANALYSIS_PROMPT = """你是资深股票分析师。请对以下股票进行全面分析。
  31. ## 可用数据
  32. {data_section}
  33. ## 分析要求
  34. {task}
  35. 请输出一份结构化的分析报告,包含:
  36. 1. 基本概况与走势判断
  37. 2. 技术面分析(趋势、均线、指标、关键价位)
  38. 3. 基本面与估值分析
  39. 4. 消息面与市场情绪
  40. 5. 风险提示
  41. 6. 短期/中长期操作建议
  42. """
  43. REFLECTION_PROMPT = """你是极其严格的股票投资评审专家。你的任务是审查以下分析报告,找出缺陷和遗漏。
  44. ## 原始分析需求
  45. {task}
  46. ## 待审查报告
  47. {report}
  48. ## 评审维度
  49. 1. **数据完整性**: 是否遗漏了关键数据维度?是否有数据解读错误?
  50. 2. **风险覆盖**: 是否遗漏了重要风险因素?(如:行业政策风险、汇率风险、大股东减持、解禁压力)
  51. 3. **逻辑一致性**: 技术面/基本面/消息面的结论是否一致?是否有自相矛盾?
  52. 4. **盲区检查**: 有没有未考虑的视角?(如:产业链上下游、跨市场联动、资金流向)
  53. 5. **反向思考**: 如果最终判断是看多,请从看空角度挑战;反之亦然。有没有可能判断错了?
  54. 请直接输出你的评审意见,指出至少3个具体的缺陷或遗漏。
  55. 如果分析报告已经全面、严谨、无明显疏漏,请回答"无需改进"。
  56. """
  57. REFINE_PROMPT = """你是资深股票分析师。评审专家指出了你上一轮分析报告的缺陷。
  58. ## 原始需求
  59. {task}
  60. ## 你的上一轮报告
  61. {previous_report}
  62. ## 评审意见
  63. {reflection}
  64. 请基于评审意见,生成一份改进后的完整分析报告。要特别针对评审指出的问题补充分析和修正。
  65. 输出完整的改进版报告(6个章节结构不变,但内容要体现反思后的改进)。
  66. """
  67. class ReflectionStockAgent:
  68. """反思式股票分析 Agent — 分析→评审→改进 循环"""
  69. TOOLS = {
  70. "GetRealtimeQuote": get_realtime_quote,
  71. "GetHistoricalData": get_historical_data,
  72. "CalcIndicators": calc_indicators,
  73. "GetFinancialData": get_financial_data,
  74. "GetNews": get_news,
  75. }
  76. def __init__(self, llm_client: HelloAgentsLLM, max_iterations: int = 2):
  77. self.llm_client = llm_client
  78. self.memory = Memory()
  79. self.max_iterations = max_iterations
  80. def run(self, task: str):
  81. print(f"\n{'='*60}")
  82. print(f" Reflection 反思模式 (最多{self.max_iterations}轮)")
  83. print(f" 问题: {task}")
  84. print(f"{'='*60}")
  85. # --- 阶段1: 自动采集数据 ---
  86. print("\n [阶段1] 自动采集数据...")
  87. data_text = self._collect_data(task)
  88. # --- 阶段2: 初始分析 ---
  89. print(f"\n [阶段2] 生成初始分析报告...")
  90. initial_prompt = INITIAL_ANALYSIS_PROMPT.format(
  91. data_section=data_text, task=task
  92. )
  93. messages = [{"role": "user", "content": initial_prompt}]
  94. initial_report = self.llm_client.think(messages=messages) or ""
  95. self.memory.add_record("execution", initial_report)
  96. print(f" [初始报告] 已生成 ({len(initial_report)} 字)")
  97. # --- 阶段3: 反思-改进循环 ---
  98. for iteration in range(self.max_iterations):
  99. print(f"\n [阶段3] 第 {iteration+1}/{self.max_iterations} 轮反思...")
  100. # 评审
  101. reflect_prompt = REFLECTION_PROMPT.format(
  102. task=task, report=self.memory.get_last_execution()
  103. )
  104. messages = [{"role": "user", "content": reflect_prompt}]
  105. feedback = self.llm_client.think(messages=messages) or ""
  106. self.memory.add_record("reflection", feedback)
  107. # 检查收敛
  108. if "无需改进" in feedback:
  109. print("\n [评审] 报告已无明显缺陷,反思结束。")
  110. break
  111. # 改进
  112. print(f"\n [阶段3] 基于评审意见改进报告...")
  113. refine_prompt = REFINE_PROMPT.format(
  114. task=task,
  115. previous_report=self.memory.get_last_execution(),
  116. reflection=feedback,
  117. )
  118. messages = [{"role": "user", "content": refine_prompt}]
  119. refined_report = self.llm_client.think(messages=messages) or ""
  120. self.memory.add_record("execution", refined_report)
  121. print(f" [改进报告] 已生成 ({len(refined_report)} 字)")
  122. # --- 输出最终报告 ---
  123. final_report = self.memory.get_last_execution()
  124. print(f"\n{'='*60}")
  125. print(f" 最终分析报告 (经 {sum(1 for r in self.memory.records if r['type']=='reflection')} 轮反思)")
  126. print(f"{'='*60}")
  127. print(final_report)
  128. return final_report
  129. def _collect_data(self, task: str) -> str:
  130. """自动从任务中提取股票代码,采集关键数据"""
  131. import re
  132. # 提取股票代码
  133. codes = re.findall(r"\b(\d{6})\b", task)
  134. if not codes:
  135. return "(未能自动识别股票代码,请在问题中包含6位代码)"
  136. code = codes[0]
  137. parts = []
  138. # 实时行情
  139. print(f" [采集] 实时行情 {code}...")
  140. r = self.TOOLS["GetRealtimeQuote"](code)
  141. parts.append(f"### 实时行情\n{r}")
  142. # 历史K线 (60天)
  143. print(f" [采集] 60天K线 {code}...")
  144. r = self.TOOLS["GetHistoricalData"](f"{code}|daily|60")
  145. parts.append(f"### 60天历史K线\n{r}")
  146. # 技术指标 (120天)
  147. print(f" [采集] 技术指标 {code}...")
  148. r = self.TOOLS["CalcIndicators"](f"{code}|daily|120")
  149. parts.append(f"### 技术指标\n{r}")
  150. # 财务数据
  151. print(f" [采集] 财务数据 {code}...")
  152. r = self.TOOLS["GetFinancialData"](code)
  153. parts.append(f"### 财务数据\n{r}")
  154. # 新闻
  155. print(f" [采集] 新闻舆情 {code}...")
  156. r = self.TOOLS["GetNews"](code)
  157. parts.append(f"### 新闻舆情\n{r}")
  158. return "\n\n".join(parts)