mx_data_tool.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """
  2. 智能股票分析助手 — HelloAgents 金融数据工具封装
  3. 将东方财富 mx-data Skill 封装为符合 HelloAgents 标准 Tool 接口的工具类。
  4. Agent可通过此工具调用自然语言查询获取行情、财务、关联关系等数据。
  5. """
  6. import sys
  7. from pathlib import Path
  8. # 将HelloAgents框架和skills路径加入sys.path
  9. _PROJECT_ROOT = Path(__file__).parent.parent.parent
  10. _HELLO_PATH = _PROJECT_ROOT / "HelloAgents Optimized"
  11. _SKILLS_PATH = _PROJECT_ROOT / "skills" / "金融数据" / "mx-data"
  12. for p in [_HELLO_PATH, _SKILLS_PATH]:
  13. if str(p) not in sys.path:
  14. sys.path.insert(0, str(p))
  15. from hello_agents.tools import Tool, ToolParameter
  16. class MXDataTool(Tool):
  17. """金融数据查询工具 — 封装东方财富妙想mx-data Skill
  18. 支持通过自然语言查询A股行情、财务指标、公司概况、股东信息等金融数据。
  19. 使用示例:
  20. tool = MXDataTool(api_key="your_mx_apikey")
  21. result = tool.run({"query": "贵州茅台最新价 涨跌幅"})
  22. """
  23. def __init__(self, api_key: str = None):
  24. super().__init__(
  25. name="mx_data",
  26. description=(
  27. "东方财富金融数据查询工具。支持查询A股股票的实时行情、历史行情、"
  28. "财务指标(净利润、ROE、毛利率等)、公司概况(主营业务、高管信息)、"
  29. "股东信息(十大股东)、指数行情、板块行情等。"
  30. "支持自然语言查询,如'贵州茅台近三年净利润 营业收入'、"
  31. "'沪深300指数最新点位'、'比亚迪公司简介 主营业务'。"
  32. ),
  33. )
  34. # 获取API密钥:优先参数 > 环境变量
  35. import os
  36. self.api_key = api_key or os.getenv("MX_APIKEY", "")
  37. # 延迟导入mx_data模块
  38. self._mx_module = None
  39. def _get_mx_module(self):
  40. """延迟导入mx_data模块(避免初始化时的导入错误)"""
  41. if self._mx_module is None:
  42. import mx_data as _mx_data
  43. self._mx_module = _mx_data
  44. return self._mx_module
  45. def get_parameters(self) -> list:
  46. return [
  47. ToolParameter(
  48. name="query",
  49. type="string",
  50. description=(
  51. "自然语言查询语句。支持中文查询,例如:\n"
  52. "- 行情: '贵州茅台最新价 涨跌幅', '比亚迪近一年每个交易日收盘价'\n"
  53. "- 财务: '贵州茅台近三年净利润 营业收入 净资产收益率'\n"
  54. "- 公司: '比亚迪公司简介 主营业务 董事长是谁'\n"
  55. "- 股东: '贵州茅台十大股东'\n"
  56. "- 指数: '沪深300指数最新点位'"
  57. ),
  58. required=True,
  59. ),
  60. ]
  61. def run(self, parameters: dict) -> str:
  62. """执行金融数据查询
  63. Args:
  64. parameters: {"query": "自然语言查询"}
  65. Returns:
  66. 格式化的查询结果文本
  67. """
  68. query = parameters.get("query", "")
  69. if not query:
  70. return "错误:查询内容不能为空"
  71. if not self.api_key:
  72. return "错误:MX_APIKEY 未配置,无法查询金融数据。请设置环境变量 MX_APIKEY"
  73. try:
  74. mx = self._get_mx_module()
  75. # 创建MXData实例并查询
  76. data_querier = mx.MXData(api_key=self.api_key)
  77. result = data_querier.query(query)
  78. # 解析结果
  79. tables, condition_parts, total_rows, error = mx.MXData.parse_result(result)
  80. if error:
  81. return f"查询出错: {error}"
  82. if not tables:
  83. return "查询未返回任何数据"
  84. # 格式化输出
  85. return self._format_result(tables, condition_parts, total_rows)
  86. except Exception as e:
  87. return f"金融数据查询异常: {str(e)}"
  88. def _format_result(self, tables: list, condition_parts: list, total_rows: int) -> str:
  89. """将查询结果格式化为可读文本"""
  90. lines = []
  91. # 查询条件
  92. if condition_parts:
  93. lines.append("## 查询条件")
  94. for part in condition_parts:
  95. lines.append(part)
  96. lines.append("")
  97. # 数据表格
  98. lines.append(f"## 查询结果({len(tables)}个表,共{total_rows}行数据)\n")
  99. for idx, table in enumerate(tables):
  100. sheet_name = table.get("sheet_name", f"表{idx+1}")
  101. rows = table.get("rows", [])
  102. fieldnames = table.get("fieldnames", [])
  103. lines.append(f"### {sheet_name}")
  104. if not rows:
  105. lines.append("(无数据)")
  106. continue
  107. # 限制输出行数(避免上下文过长)
  108. max_rows = 30
  109. display_rows = rows[:max_rows]
  110. # 表头
  111. header = " | ".join(fieldnames[:10]) # 最多显示10列
  112. lines.append(f"| {header} |")
  113. lines.append(f"|{'|'.join(['---'] * min(len(fieldnames), 10))}|")
  114. # 数据行
  115. for row in display_rows:
  116. values = [str(row.get(col, "")) for col in fieldnames[:10]]
  117. lines.append(f"| {' | '.join(values)} |")
  118. if len(rows) > max_rows:
  119. lines.append(f"\n*(仅显示前{max_rows}行,共{len(rows)}行)*")
  120. lines.append("")
  121. return "\n".join(lines)