agent.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. """Step 2: StockInsightAgent — ReAct 范式智能股票分析助手"""
  2. import re
  3. from llm_client import HelloAgentsLLM
  4. from tools import (
  5. ToolExecutor, get_realtime_quote, get_historical_data,
  6. get_financial_data, calc_indicators, get_news
  7. )
  8. STOCK_AGENT_PROMPT = """
  9. 你是一个专业的股票分析助手 StockInsightAgent。你可以获取A股实时行情、历史K线、
  10. 财务报表、技术指标和新闻舆情,然后综合这些信息给出分析结论。
  11. 可用工具如下:
  12. {tools}
  13. 请严格按照以下格式进行回应:
  14. Thought: 你的思考过程,分析用户需求并规划下一步行动。
  15. Action: 你决定采取的行动,必须是以下格式之一:
  16. - `{{tool_name}}[{{tool_input}}]`:调用一个可用工具。
  17. 工具输入格式说明:
  18. - 实时行情: 股票代码 或 股票简称,如 "600519" 或 "贵州茅台"
  19. - 历史K线: "代码|周期|天数",如 "600519|daily|60"
  20. - 财务数据: 股票代码,如 "600519"
  21. - 技术指标: "代码|周期|天数",如 "600519|daily|120"
  22. - 新闻舆情: 股票代码,如 "600519"
  23. - `Finish[最终分析报告]`:当你收集到足够的信息,能够输出完整分析报告时。
  24. 分析报告的格式应该包含:
  25. 1. 股票基本概况(最新价、涨跌幅、市值等)
  26. 2. 技术面分析(趋势、均线、MACD、RSI、支撑压力位)
  27. 3. 基本面分析(财务指标解读)
  28. 4. 消息面(近期新闻舆情)
  29. 5. 风险提示
  30. 6. 综合小结
  31. 重要:
  32. - 每次只调用一个工具
  33. - 如果用户只给名称没给代码,用该名称搜索实时行情就能找到代码
  34. - 收集到足够信息后输出完整的 Markdown 分析报告
  35. - 数据异常时如实说明,不要编造
  36. 现在,请开始分析:
  37. Question: {question}
  38. History: {history}
  39. """
  40. class StockInsightAgent:
  41. """智能股票分析 Agent — ReAct 范式"""
  42. def __init__(self, llm_client: HelloAgentsLLM, max_steps: int = 8):
  43. self.llm_client = llm_client
  44. self.tool_executor = ToolExecutor()
  45. self.max_steps = max_steps
  46. self.history = []
  47. # 注册 5 个分析工具
  48. print("注册工具:")
  49. self.tool_executor.registerTool(
  50. "GetRealtimeQuote",
  51. "获取实时行情(最新价/涨跌幅/成交量/PE/市值)。输入: 股票代码或简称",
  52. get_realtime_quote
  53. )
  54. self.tool_executor.registerTool(
  55. "GetHistoricalData",
  56. "获取历史K线(OHLCV)。输入格式: '代码|周期|天数',周期=daily/weekly/monthly",
  57. get_historical_data
  58. )
  59. self.tool_executor.registerTool(
  60. "GetFinancialData",
  61. "获取财务指标(ROE/ROA/毛利率/营收增长等)。输入: 股票代码",
  62. get_financial_data
  63. )
  64. self.tool_executor.registerTool(
  65. "CalcIndicators",
  66. "计算技术指标(MA/MACD/RSI/布林带/支撑压力位)。输入格式: '代码|周期|天数'",
  67. calc_indicators
  68. )
  69. self.tool_executor.registerTool(
  70. "GetNews",
  71. "获取近期新闻舆情。输入: 股票代码",
  72. get_news
  73. )
  74. print()
  75. def run(self, question: str):
  76. self.history = []
  77. current_step = 0
  78. print(f"\n{'='*60}")
  79. print(f" [用户]: {question}")
  80. print(f"{'='*60}")
  81. while current_step < self.max_steps:
  82. current_step += 1
  83. print(f"\n--- 第 {current_step}/{self.max_steps} 步 ---")
  84. tools_desc = self.tool_executor.getAvailableTools()
  85. history_str = "\n".join(self.history) if self.history else "(首次执行,无历史)"
  86. prompt = STOCK_AGENT_PROMPT.format(
  87. tools=tools_desc, question=question, history=history_str
  88. )
  89. messages = [{"role": "user", "content": prompt}]
  90. response_text = self.llm_client.think(messages=messages)
  91. if not response_text:
  92. print(" LLM 未返回有效响应。")
  93. break
  94. thought, action = self._parse_output(response_text)
  95. if thought:
  96. print(f" [思考] {thought}")
  97. if not action:
  98. print(" 未能解析出 Action,流程终止。")
  99. break
  100. if action.startswith("Finish"):
  101. final_answer = self._parse_action_input(action)
  102. print(f"\n{'='*60}")
  103. print(f" [分析报告]")
  104. print(f"{'='*60}")
  105. print(final_answer)
  106. return final_answer
  107. tool_name, tool_input = self._parse_action(action)
  108. if not tool_name:
  109. self.history.append("Observation: Action 格式无效。")
  110. continue
  111. print(f" [行动] {tool_name}[{tool_input[:60]}{'...' if len(tool_input)>60 else ''}]")
  112. tool_func = self.tool_executor.getTool(tool_name)
  113. observation = (
  114. tool_func(tool_input) if tool_func
  115. else f"错误:未找到工具 '{tool_name}'"
  116. )
  117. print(f" [观察]\n{observation[:300]}{'...' if len(str(observation))>300 else ''}")
  118. self.history.append(f"Action: {action}")
  119. self.history.append(f"Observation: {observation}")
  120. print(f"\n 已达到最大步数 ({self.max_steps}),流程终止。")
  121. return None
  122. def _parse_output(self, text: str):
  123. # 支持 Thought: / **Thought:** / Thought: 等多种格式
  124. thought_match = re.search(
  125. r"(?:\*\*)?Thought(?:\*\*)?\s*[::]\s*(.*?)(?=\n(?:\*\*)?Action(?:\*\*)?\s*[::]|$)",
  126. text, re.DOTALL | re.IGNORECASE
  127. )
  128. action_match = re.search(
  129. r"(?:\*\*)?Action(?:\*\*)?\s*[::]\s*(.*?)$",
  130. text, re.DOTALL | re.IGNORECASE
  131. )
  132. thought = thought_match.group(1).strip() if thought_match else None
  133. action = action_match.group(1).strip() if action_match else None
  134. # 清理 markdown 反引号
  135. if action:
  136. action = action.strip("`\"' \n\r")
  137. return thought, action
  138. def _parse_action(self, action_text: str):
  139. # 清理反引号、markdown bold 等
  140. clean = action_text.strip("`\"' \n\r*_")
  141. match = re.match(r"(\w+)\[(.*)\]", clean, re.DOTALL)
  142. return (match.group(1), match.group(2)) if match else (None, None)
  143. def _parse_action_input(self, action_text: str):
  144. clean = action_text.strip("`\"' \n\r*_")
  145. match = re.match(r"\w+\[(.*)\]", clean, re.DOTALL)
  146. return match.group(1) if match else ""