1
0

test.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. """
  2. 数据库Agent助手 - 测试脚本
  3. 用于测试各个组件的功能
  4. """
  5. import os
  6. from dotenv import load_dotenv
  7. from hello_agents import HelloAgentsLLM
  8. from react_agent import DatabaseAgent, DatabaseConfig
  9. from tools import OracleQueryTool, SQLGeneratorTool
  10. load_dotenv()
  11. def test_database_connection():
  12. """测试数据库连接"""
  13. print("=" * 60)
  14. print("测试1: 数据库连接")
  15. print("=" * 60)
  16. db_config = DatabaseConfig()
  17. if not db_config.validate():
  18. print("❌ 数据库配置不完整")
  19. return False
  20. print(f"配置信息: {db_config.get_connection_string()}")
  21. oracle_tool = OracleQueryTool(db_config)
  22. if oracle_tool.connect():
  23. print("✅ 数据库连接成功")
  24. schema_info = oracle_tool.get_schema_info()
  25. print("\n数据库表结构:")
  26. print(schema_info)
  27. oracle_tool.disconnect()
  28. return True
  29. else:
  30. print("❌ 数据库连接失败")
  31. return False
  32. def test_sql_generation():
  33. """测试SQL生成功能"""
  34. print("\n" + "=" * 60)
  35. print("测试2: SQL生成")
  36. print("=" * 60)
  37. try:
  38. llm = HelloAgentsLLM()
  39. sql_generator = SQLGeneratorTool(llm)
  40. test_queries = [
  41. "查询所有员工信息",
  42. "查询工资大于5000的员工",
  43. "统计各部门的员工数量"
  44. ]
  45. for query in test_queries:
  46. print(f"\n自然语言: {query}")
  47. sql = sql_generator.generate_sql(query, "表 EMPLOYEES: ID (NUMBER), NAME (VARCHAR2), SALARY (NUMBER), DEPARTMENT (VARCHAR2)")
  48. print(f"生成的SQL: {sql}")
  49. is_valid, msg = sql_generator.validate_sql(sql)
  50. print(f"验证结果: {msg}")
  51. return True
  52. except Exception as e:
  53. print(f"❌ SQL生成测试失败: {e}")
  54. return False
  55. def test_agent_query():
  56. """测试Agent查询功能"""
  57. print("\n" + "=" * 60)
  58. print("测试3: Agent查询")
  59. print("=" * 60)
  60. try:
  61. llm = HelloAgentsLLM()
  62. db_config = DatabaseConfig()
  63. if not db_config.validate():
  64. print("❌ 数据库配置不完整")
  65. return False
  66. agent = DatabaseAgent(
  67. name="TestAgent",
  68. llm=llm,
  69. db_config=db_config,
  70. max_steps=5
  71. )
  72. test_query = "查询所有员工的信息"
  73. print(f"\n测试查询: {test_query}")
  74. result = agent.run(test_query)
  75. print(f"\n查询结果:\n{result}")
  76. return True
  77. except Exception as e:
  78. print(f"❌ Agent查询测试失败: {e}")
  79. return False
  80. def main():
  81. """运行所有测试"""
  82. print("🧪 数据库Agent助手 - 测试套件")
  83. print("=" * 60)
  84. results = []
  85. results.append(("数据库连接", test_database_connection()))
  86. results.append(("SQL生成", test_sql_generation()))
  87. results.append(("Agent查询", test_agent_query()))
  88. print("\n" + "=" * 60)
  89. print("测试结果汇总")
  90. print("=" * 60)
  91. for test_name, result in results:
  92. status = "✅ 通过" if result else "❌ 失败"
  93. print(f"{test_name}: {status}")
  94. passed = sum(1 for _, result in results if result)
  95. total = len(results)
  96. print(f"\n总计: {passed}/{total} 测试通过")
  97. if __name__ == "__main__":
  98. main()