Bläddra i källkod

Merge pull request #540 from 939147533/feature/database_agent

[毕业设计] DatabaseAgent- 数据库Agent助手
jjyaoao 1 månad sedan
förälder
incheckning
786c5bdbc2

+ 13 - 0
Co-creation-projects/939147533-DatabaseAgent/.env.example

@@ -0,0 +1,13 @@
+# 数据库Agent助手配置文件
+
+# LLM配置 (使用ModelScope或其他兼容OpenAI接口的服务)
+LLM_MODEL_ID=Qwen/Qwen2.5-7B-Instruct
+LLM_API_KEY=your_api_key_here
+LLM_BASE_URL=https://api.modelscope.cn/v1
+
+# Oracle数据库配置
+DB_HOST=localhost
+DB_PORT=1521
+DB_SERVICE_NAME=ORCL
+DB_USERNAME=system
+DB_PASSWORD=your_password_here

+ 117 - 0
Co-creation-projects/939147533-DatabaseAgent/README.md

@@ -0,0 +1,117 @@
+# 数据库Agent助手
+
+基于hello-agents库实现的智能数据库查询助手,支持将自然语言转换为SQL查询并从Oracle数据库获取数据。
+
+## 📝 项目简介
+
+- 输入自然语言,自动生成SQL语句并执行查询
+- 支持Oracle数据库,返回格式化的查询结果
+- 适用于非技术人员查询数据库或者辅助技术人员快速生成SQL
+
+
+## ✨ 核心功能
+
+- **自然语言转sql**: 用中文描述查询需求,自动转换为SQL语句
+- **Oracle数据库查询**: 查询oracle数据库并返回结果
+
+
+##  🛠️ 技术栈
+
+基于HelloAgentsLearn项目中的ReAct框架实现:
+
+- **ReAct Agent**: 推理-行动循环框架
+- **Tool Registry**: 工具注册和管理
+- **LLM Integration**: 大语言模型集成
+- **Oracle DB**: Oracle数据库连接和查询
+
+## 工具实现
+
+1. **GetSchema**: 获取数据库表结构信息
+2. **GenerateSQL**: 将自然语言转换为SQL语句
+3. **ExecuteQuery**: 执行SQL查询并返回结果
+
+## 🚀 快速开始
+### 环境要求
+
+- Python 3.10+
+### 安装依赖
+
+```bash
+pip install -r requirements.txt
+```
+
+### 配置API密钥
+
+1. 复制示例配置文件:
+```bash
+cp .env.example .env
+```
+
+2. 编辑 `.env` 文件,配置以下参数:
+
+### LLM配置
+- `LLM_MODEL_ID`: 模型ID,如 `qwen3.6:35b-a3b-q4_K_M`
+- `LLM_API_KEY`: API密钥
+- `LLM_BASE_URL`: API服务地址
+
+### Oracle数据库配置
+- `DB_HOST`: 数据库主机地址
+- `DB_PORT`: 数据库端口 (默认: 1521)
+- `DB_SERVICE_NAME`: 服务名称
+- `DB_USERNAME`: 用户名
+- `DB_PASSWORD`: 密码
+
+
+### 创建测试数据
+使用提供的SQL脚本创建测试表和数据:
+
+```bash
+# 在Oracle SQL*Plus或其他Oracle客户端中执行
+sqlplus 用户名/用户密码@数据库地址:1521/服务名称 @setup_database.sql
+# 例如:
+sqlplus system/password@localhost:1521/ORCL @setup_database.sql
+```
+### 运行项目
+
+#### 运行测试程序
+python test.py
+
+#### 运行主程序:
+
+```bash
+python main.py
+```
+
+#### 查询示例
+
+- "查询所有员工信息"
+![img.png](img.png)
+![img_1.png](img_1.png)
+- "查询IT部门的员工平均工资"
+![img_2.png](img_2.png)
+![img_3.png](img_3.png)
+
+
+## 🔮 未来计划
+
+- 增加plan_solve智能体实现查询
+- 增加更多数据库支持
+- 优化提示词设计
+- 优化sql工具
+- 增加查询结果导出功能
+
+## 🤝 贡献指南
+
+欢迎提出Issue和Pull Request!
+
+## 📄 许可证
+
+MIT License
+
+## 👤 作者
+
+- GitHub: [@939147533](https://github.com/939147533)
+
+## 🙏 致谢
+
+感谢Datawhale社区和Hello-Agents项目!

BIN
Co-creation-projects/939147533-DatabaseAgent/img.png


BIN
Co-creation-projects/939147533-DatabaseAgent/img_1.png


BIN
Co-creation-projects/939147533-DatabaseAgent/img_2.png


BIN
Co-creation-projects/939147533-DatabaseAgent/img_3.png


+ 64 - 0
Co-creation-projects/939147533-DatabaseAgent/main.py

@@ -0,0 +1,64 @@
+"""
+数据库Agent助手 - 主程序
+演示如何使用DatabaseAgent进行自然语言查询
+"""
+import os
+from dotenv import load_dotenv
+from hello_agents import HelloAgentsLLM
+from react_agent import DatabaseAgent, DatabaseConfig
+
+load_dotenv()
+
+
+def main():
+    print("=" * 60)
+    print("🤖 数据库Agent助手")
+    print("=" * 60)
+    
+    llm = HelloAgentsLLM()
+    
+    db_config = DatabaseConfig()
+    
+    if not db_config.validate():
+        print("❌ 数据库配置不完整,请检查.env文件")
+        print("需要配置: DB_HOST, DB_PORT, DB_SERVICE_NAME, DB_USERNAME, DB_PASSWORD")
+        return
+    
+    agent = DatabaseAgent(
+        name="DatabaseAssistant",
+        llm=llm,
+        db_config=db_config,
+        max_steps=5
+    )
+    
+    print("\n📝 示例查询:")
+    print("1. 查询所有员工信息")
+    print("2. 查询工资大于5000的员工")
+    print("3. 统计各部门的员工数量")
+    print("4. 查询最近入职的5名员工")
+    print("5. 退出")
+    
+    while True:
+        print("\n" + "=" * 60)
+        user_input = input("请输入您的查询 (或输入 '5' 退出): ").strip()
+        
+        if user_input.lower() in ['5', 'exit', 'quit', '退出']:
+            print("👋 感谢使用数据库Agent助手!")
+            break
+
+        if not user_input:
+            print("⚠️ 请输入有效的查询")
+            continue
+        
+        try:
+            result = agent.run(user_input)
+            print("\n" + "=" * 60)
+            print("📊 查询结果:")
+            print("=" * 60)
+            print(result)
+        except Exception as e:
+            print(f"❌ 执行查询时出错: {e}")
+
+
+if __name__ == "__main__":
+    main()

+ 4 - 0
Co-creation-projects/939147533-DatabaseAgent/requirements.txt

@@ -0,0 +1,4 @@
+oracledb>=2.0.0
+python-dotenv>=1.0.0
+openai>=1.0.0
+hello-agents>=0.1.0

+ 77 - 0
Co-creation-projects/939147533-DatabaseAgent/setup_database.sql

@@ -0,0 +1,77 @@
+"""
+示例Oracle数据库创建脚本
+用于创建测试数据库表和示例数据
+"""
+
+-- 创建员工表
+CREATE TABLE EMPLOYEES (
+    ID NUMBER PRIMARY KEY,
+    NAME VARCHAR2(100) NOT NULL,
+    DEPARTMENT VARCHAR2(50),
+    SALARY NUMBER(10,2),
+    HIRE_DATE DATE,
+    EMAIL VARCHAR2(100),
+    PHONE VARCHAR2(20)
+);
+
+-- 创建部门表
+CREATE TABLE DEPARTMENTS (
+    ID NUMBER PRIMARY KEY,
+    NAME VARCHAR2(50) NOT NULL,
+    LOCATION VARCHAR2(100),
+    BUDGET NUMBER(15,2)
+);
+
+-- 创建项目表
+CREATE TABLE PROJECTS (
+    ID NUMBER PRIMARY KEY,
+    NAME VARCHAR2(100) NOT NULL,
+    DEPARTMENT_ID NUMBER,
+    START_DATE DATE,
+    END_DATE DATE,
+    STATUS VARCHAR2(20),
+    BUDGET NUMBER(15,2),
+    FOREIGN KEY (DEPARTMENT_ID) REFERENCES DEPARTMENTS(ID)
+);
+
+-- 插入部门数据
+INSERT INTO DEPARTMENTS VALUES (1, 'IT', '北京', 1000000);
+INSERT INTO DEPARTMENTS VALUES (2, 'HR', '上海', 500000);
+INSERT INTO DEPARTMENTS VALUES (3, 'Finance', '深圳', 800000);
+INSERT INTO DEPARTMENTS VALUES (4, 'Marketing', '广州', 600000);
+INSERT INTO DEPARTMENTS VALUES (5, 'Operations', '成都', 700000);
+
+-- 插入员工数据
+INSERT INTO EMPLOYEES VALUES (1, '张三', 'IT', 12000, TO_DATE('2020-01-15', 'YYYY-MM-DD'), 'zhangsan@company.com', '13800138001');
+INSERT INTO EMPLOYEES VALUES (2, '李四', 'IT', 15000, TO_DATE('2019-05-20', 'YYYY-MM-DD'), 'lisi@company.com', '13800138002');
+INSERT INTO EMPLOYEES VALUES (3, '王五', 'HR', 8000, TO_DATE('2021-03-10', 'YYYY-MM-DD'), 'wangwu@company.com', '13800138003');
+INSERT INTO EMPLOYEES VALUES (4, '赵六', 'Finance', 10000, TO_DATE('2020-08-25', 'YYYY-MM-DD'), 'zhaoliu@company.com', '13800138004');
+INSERT INTO EMPLOYEES VALUES (5, '钱七', 'IT', 18000, TO_DATE('2018-11-30', 'YYYY-MM-DD'), 'qianqi@company.com', '13800138005');
+INSERT INTO EMPLOYEES VALUES (6, '孙八', 'Marketing', 9000, TO_DATE('2021-07-15', 'YYYY-MM-DD'), 'sunba@company.com', '13800138006');
+INSERT INTO EMPLOYEES VALUES (7, '周九', 'Operations', 8500, TO_DATE('2020-12-01', 'YYYY-MM-DD'), 'zhoujiu@company.com', '13800138007');
+INSERT INTO EMPLOYEES VALUES (8, '吴十', 'IT', 14000, TO_DATE('2019-09-20', 'YYYY-MM-DD'), 'wushi@company.com', '13800138008');
+INSERT INTO EMPLOYEES VALUES (9, '郑十一', 'Finance', 11000, TO_DATE('2020-04-10', 'YYYY-MM-DD'), 'zhengshiyi@company.com', '13800138009');
+INSERT INTO EMPLOYEES VALUES (10, '王十二', 'HR', 7500, TO_DATE('2022-02-28', 'YYYY-MM-DD'), 'wangshier@company.com', '13800138010');
+
+-- 插入项目数据
+INSERT INTO PROJECTS VALUES (1, '数字化转型项目', 1, TO_DATE('2023-01-01', 'YYYY-MM-DD'), TO_DATE('2023-12-31', 'YYYY-MM-DD'), '进行中', 500000);
+INSERT INTO PROJECTS VALUES (2, '人力资源系统升级', 2, TO_DATE('2023-03-01', 'YYYY-MM-DD'), TO_DATE('2023-09-30', 'YYYY-MM-DD'), '已完成', 200000);
+INSERT INTO PROJECTS VALUES (3, '财务审计系统', 3, TO_DATE('2023-06-01', 'YYYY-MM-DD'), TO_DATE('2024-03-31', 'YYYY-MM-DD'), '进行中', 300000);
+INSERT INTO PROJECTS VALUES (4, '市场推广活动', 4, TO_DATE('2023-04-01', 'YYYY-MM-DD'), TO_DATE('2023-10-31', 'YYYY-MM-DD'), '已完成', 250000);
+INSERT INTO PROJECTS VALUES (5, '运营优化项目', 5, TO_DATE('2023-02-01', 'YYYY-MM-DD'), TO_DATE('2023-08-31', 'YYYY-MM-DD'), '已完成', 150000);
+
+COMMIT;
+
+-- 创建一些有用的视图
+CREATE OR REPLACE VIEW V_EMPLOYEE_DEPT AS
+SELECT e.ID, e.NAME, e.DEPARTMENT, e.SALARY, e.HIRE_DATE, d.LOCATION, d.BUDGET AS DEPT_BUDGET
+FROM EMPLOYEES e
+LEFT JOIN DEPARTMENTS d ON e.DEPARTMENT = d.NAME;
+
+CREATE OR REPLACE VIEW V_DEPARTMENT_STATS AS
+SELECT d.ID, d.NAME, d.LOCATION, d.BUDGET, COUNT(e.ID) AS EMP_COUNT, AVG(e.SALARY) AS AVG_SALARY
+FROM DEPARTMENTS d
+LEFT JOIN EMPLOYEES e ON d.NAME = e.DEPARTMENT
+GROUP BY d.ID, d.NAME, d.LOCATION, d.BUDGET;
+
+COMMIT;

+ 34 - 0
Co-creation-projects/939147533-DatabaseAgent/src/config.py

@@ -0,0 +1,34 @@
+"""
+数据库配置管理
+"""
+import os
+from typing import Optional
+from dotenv import load_dotenv
+
+load_dotenv()
+
+
+class DatabaseConfig:
+    """Oracle数据库配置类"""
+    
+    def __init__(
+        self,
+        host: Optional[str] = None,
+        port: Optional[int] = None,
+        service_name: Optional[str] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None
+    ):
+        self.host = host or os.getenv("DB_HOST", "localhost")
+        self.port = port or int(os.getenv("DB_PORT", "1521"))
+        self.service_name = service_name or os.getenv("DB_SERVICE_NAME", "ORCL")
+        self.username = username or os.getenv("DB_USERNAME", "system")
+        self.password = password or os.getenv("DB_PASSWORD", "")
+        
+    def get_connection_string(self) -> str:
+        """获取Oracle连接字符串"""
+        return f"{self.username}/{self.password}@{self.host}:{self.port}/{self.service_name}"
+    
+    def validate(self) -> bool:
+        """验证配置是否完整"""
+        return all([self.host, self.port, self.service_name, self.username, self.password])

+ 174 - 0
Co-creation-projects/939147533-DatabaseAgent/src/react_agent.py

@@ -0,0 +1,174 @@
+"""
+数据库Agent - 基于ReAct框架的智能数据库查询助手
+"""
+import re
+from typing import Optional, List
+from hello_agents import ReActAgent, HelloAgentsLLM, Config, Message, ToolRegistry
+from tools import OracleQueryTool, SQLGeneratorTool, format_query_result
+from config import DatabaseConfig
+
+
+DATABASE_AGENT_PROMPT = """你是一个专业的数据库查询助手。你可以理解用户的自然语言查询,将其转换为SQL语句,从Oracle数据库中获取数据并格式化输出。
+
+## 可用工具
+{tools}
+
+## 工作流程
+请严格按照以下格式进行回应:
+
+Thought: 你的思考过程,分析用户需求并规划下一步行动。
+Action: 你决定采取的行动,必须是以下格式之一:
+- `{{tool_name}}[{{tool_input}}]` - 调用指定工具
+- `Finish[最终答案]` - 当你有足够信息给出最终答案时
+
+## 使用指南
+1. 当用户提出查询需求时,首先使用 GetSchema 工具获取数据库表结构
+2. 使用 GenerateSQL 工具将自然语言转换为SQL语句
+3. 使用 ExecuteQuery 工具执行SQL并获取结果
+
+## 当前任务
+**Question:** {question}
+
+## 执行历史
+{history}
+
+现在开始你的推理和行动:
+"""
+
+
+class DatabaseAgent(ReActAgent):
+    """数据库查询Agent"""
+    
+    def __init__(
+        self,
+        name: str,
+        llm: HelloAgentsLLM,
+        db_config: DatabaseConfig,
+        system_prompt: Optional[str] = None,
+        config: Optional[Config] = None,
+        max_steps: int = 5
+    ):
+        super().__init__(name, llm, system_prompt, config)
+        
+        self.db_config = db_config
+        self.max_steps = max_steps
+        self.current_history: List[str] = []
+        self.prompt_template = DATABASE_AGENT_PROMPT
+        
+        self.oracle_tool = OracleQueryTool(db_config)
+        self.sql_generator = SQLGeneratorTool(llm)
+        
+        self.tool_registry = ToolRegistry()
+        self.tool_registry.register_function(
+            "GetSchema",
+            "获取数据库表结构信息,包括所有表名和字段定义。",
+            self._get_schema
+        )
+        self.tool_registry.register_function(
+            "GenerateSQL",
+            "将自然语言查询转换为Oracle SQL语句。",
+            self._generate_sql
+        )
+        self.tool_registry.register_function(
+            "ExecuteQuery",
+            "执行SQL查询并返回结果。",
+            self._execute_query
+        )
+        
+        self.schema_cache = None
+        print(f"✅ {name} 初始化完成,最大步数: {max_steps}")
+    
+    def _get_schema(self, input_text: str) -> str:
+        """获取数据库表结构信息,包括所有表名和字段定义"""
+        schema_info = self.oracle_tool.get_schema_info()
+        self.schema_cache = schema_info
+        return schema_info
+    
+    def _generate_sql(self, input_text: str) -> str:
+        """将自然语言查询转换为Oracle SQL语句"""
+        if not self.schema_cache:
+            self.schema_cache = self.oracle_tool.get_schema_info()
+        
+        sql = self.sql_generator.generate_sql(input_text, self.schema_cache)
+        
+        is_valid, msg = self.sql_generator.validate_sql(sql)
+        if not is_valid:
+            return f"SQL生成失败: {msg}"
+        
+        return f"生成的SQL: {sql}"
+    
+    def _execute_query(self, input_text: str) -> str:
+        """执行SQL查询并返回结果"""
+        sql = input_text.strip()
+        
+        if sql.startswith("生成的SQL: "):
+            sql = sql.replace("生成的SQL: ", "")
+        
+        result = self.oracle_tool.execute_query(sql)
+        
+        if not result["success"]:
+            return f"查询执行失败: {result['error']}"
+        
+        formatted_result = format_query_result(result)
+        return formatted_result
+    
+    def run(self, input_text: str, **kwargs) -> str:
+        """运行数据库Agent"""
+        self.current_history = []
+        current_step = 0
+        
+        print(f"\n🤖 {self.name} 开始处理问题: {input_text}")
+        
+        while current_step < self.max_steps:
+            current_step += 1
+            print(f"\n--- 第 {current_step} 步 ---")
+            # 1. 构建提示词
+            tools_desc = self.tool_registry.get_tools_description()
+            history_str = "\n".join(self.current_history)
+            prompt = self.prompt_template.format(
+                tools=tools_desc,
+                question=input_text,
+                history=history_str
+            )
+            # 2. 调用LLM
+            messages = [{"role": "user", "content": prompt}]
+            response_text = self.llm.invoke(messages, **kwargs)
+            # 3. 解析输出
+            thought, action = self._parse_output(response_text)
+            
+            if thought:
+                print(f"🤔 思考: {thought}")
+            
+            if action and action.startswith("Finish"):
+                final_answer = self._parse_action_input(action)
+                self.add_message(Message(input_text, "user"))
+                self.add_message(Message(final_answer, "assistant"))
+                return final_answer
+            
+            if action:
+                tool_name, tool_input = self._parse_action(action)
+                observation = self.tool_registry.execute_tool(tool_name, tool_input)
+                print(f"🎬 行动: {tool_name}[{tool_input}]")
+                print(f"👀 观察: {observation}")
+                self.current_history.append(f"Action: {action}")
+                self.current_history.append(f"Observation: {observation}")
+        
+        final_answer = "抱歉,我无法在限定步数内完成这个任务。"
+        self.add_message(Message(input_text, "user"))
+        self.add_message(Message(final_answer, "assistant"))
+        return final_answer
+    
+    def _parse_output(self, text: str):
+        thought_match = re.search(r"Thought:\s*(.*?)(?=\nAction:|$)", text, re.DOTALL)
+        action_match = re.search(r"Action:\s*(.*?)$", text, re.DOTALL)
+        thought = thought_match.group(1).strip() if thought_match else None
+        action = action_match.group(1).strip() if action_match else None
+        return thought, action
+    
+    def _parse_action(self, action_text: str):
+        match = re.match(r"(\w+)\[(.*)\]", action_text, re.DOTALL)
+        return (match.group(1), match.group(2)) if match else (None, None)
+    
+    def _parse_action_input(self, action_text: str):
+        match = re.match(r"\w+\[(.*)\]", action_text, re.DOTALL)
+        return match.group(1) if match else ""

+ 197 - 0
Co-creation-projects/939147533-DatabaseAgent/src/tools.py

@@ -0,0 +1,197 @@
+"""
+数据库查询工具集
+"""
+import oracledb
+from typing import Dict, Any
+from config import DatabaseConfig
+from hello_agents import HelloAgentsLLM
+
+
+class OracleQueryTool:
+    """Oracle数据库查询工具"""
+    
+    def __init__(self, config: DatabaseConfig):
+        self.config = config
+        self.connection = None
+        
+    def connect(self) -> bool:
+        """连接到Oracle数据库"""
+        try:
+            self.connection = oracledb.connect(
+                user=self.config.username,
+                password=self.config.password,
+                host=self.config.host,
+                port=self.config.port,
+                service_name=self.config.service_name
+            )
+            return True
+        except Exception as e:
+            print(f"数据库连接失败: {e}")
+            return False
+    
+    def disconnect(self):
+        """断开数据库连接"""
+        if self.connection:
+            self.connection.close()
+            self.connection = None
+    
+    def execute_query(self, sql: str) -> Dict[str, Any]:
+        """执行SQL查询并返回结果"""
+        if not self.connection:
+            if not self.connect():
+                return {"success": False, "error": "无法连接到数据库"}
+        
+        try:
+            cursor = self.connection.cursor()
+            cursor.execute(sql)
+            
+            columns = [col[0] for col in cursor.description]
+            rows = cursor.fetchall()
+            
+            cursor.close()
+            
+            return {
+                "success": True,
+                "columns": columns,
+                "rows": rows,
+                "row_count": len(rows),
+                "sql": sql
+            }
+        except Exception as e:
+            return {"success": False, "error": str(e), "sql": sql}
+    
+    def get_schema_info(self) -> str:
+        """获取数据库表结构信息"""
+        if not self.connection:
+            if not self.connect():
+                return "无法连接到数据库"
+        
+        try:
+            cursor = self.connection.cursor()
+            
+            cursor.execute("""
+                SELECT table_name 
+                FROM user_tables 
+                ORDER BY table_name
+            """)
+            tables = [row[0] for row in cursor.fetchall()]
+            
+            schema_info = []
+            for table in tables:
+                cursor.execute(f"""
+                    SELECT column_name, data_type, nullable
+                    FROM user_tab_columns
+                    WHERE table_name = UPPER('{table}')
+                    ORDER BY column_id
+                """)
+                columns = cursor.fetchall()
+                
+                col_desc = ", ".join([
+                    f"{col[0]} ({col[1]})" 
+                    for col in columns
+                ])
+                schema_info.append(f"表 {table}: {col_desc}")
+            
+            cursor.close()
+            return "\n".join(schema_info)
+        except Exception as e:
+            return f"获取表结构失败: {e}"
+
+
+class SQLGeneratorTool:
+    """SQL生成工具 - 使用LLM将自然语言转换为SQL"""
+    
+    def __init__(self, llm: HelloAgentsLLM):
+        self.llm = llm
+        self.system_prompt = """你是一个专业的SQL查询生成助手。你的任务是将用户的自然语言查询转换为准确的Oracle SQL语句。
+
+# 规则:
+1. 只返回SQL语句,不要包含任何解释或额外文字
+2. 使用Oracle SQL语法
+3. 表名和字段名使用大写
+4. 日期格式使用 'YYYY-MM-DD'
+5. 字符串使用单引号
+6. 确保SQL语句安全,避免SQL注入
+
+# 数据库表结构:
+{schema_info}
+
+# 示例:
+用户输入: 查询所有员工信息
+输出: SELECT * FROM EMPLOYEES
+
+用户输入: 查询工资大于5000的员工
+输出: SELECT * FROM EMPLOYEES WHERE SALARY > 5000
+
+现在,请根据用户的自然语言输入生成对应的SQL语句。
+"""
+    
+    def generate_sql(self, natural_query: str, schema_info: str) -> str:
+        """生成SQL语句"""
+        prompt = self.system_prompt.format(schema_info=schema_info)
+        
+        messages = [
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": natural_query}
+        ]
+        
+        response = self.llm.invoke(messages)
+        
+        sql = response.strip()
+        
+        if sql.startswith("```sql"):
+            sql = sql[6:]
+        if sql.startswith("```"):
+            sql = sql[3:]
+        if sql.endswith("```"):
+            sql = sql[:-3]
+        
+        return sql.strip()
+    
+    def validate_sql(self, sql: str) -> tuple[bool, str]:
+        """验证SQL语句的基本语法"""
+        sql_upper = sql.upper().strip()
+        
+        if not sql_upper.startswith(("SELECT", "WITH")):
+            return False, "只允许SELECT查询语句"
+        
+        dangerous_keywords = ["DROP", "DELETE", "UPDATE", "INSERT", "TRUNCATE", "ALTER", "CREATE"]
+        for keyword in dangerous_keywords:
+            if keyword in sql_upper:
+                return False, f"不允许使用 {keyword} 语句"
+        
+        return True, "SQL语句验证通过"
+
+
+def format_query_result(result: Dict[str, Any]) -> str:
+    """格式化查询结果为表格"""
+    if not result["success"]:
+        return f"查询失败: {result['error']}"
+    
+    if result["row_count"] == 0:
+        return "查询成功,但没有找到匹配的数据。"
+    
+    columns = result["columns"]
+    rows = result["rows"]
+    
+    col_widths = []
+    for i, col in enumerate(columns):
+        max_width = max(len(str(col)), max(len(str(row[i])) for row in rows))
+        col_widths.append(max_width + 2)
+    
+    separator = "+" + "+".join("-" * width for width in col_widths) + "+"
+    
+    header = "|" + "|".join(
+        str(col).center(width) for col, width in zip(columns, col_widths)
+    ) + "|"
+    
+    data_rows = []
+    for row in rows:
+        data_row = "|" + "|".join(
+            str(cell).center(width) for cell, width in zip(row, col_widths)
+        ) + "|"
+        data_rows.append(data_row)
+    
+    table = [separator, header, separator] + data_rows + [separator]
+    
+    return "\n".join(table)

+ 131 - 0
Co-creation-projects/939147533-DatabaseAgent/test.py

@@ -0,0 +1,131 @@
+"""
+数据库Agent助手 - 测试脚本
+用于测试各个组件的功能
+"""
+import os
+from dotenv import load_dotenv
+from hello_agents import HelloAgentsLLM
+from react_agent import DatabaseAgent, DatabaseConfig
+from tools import OracleQueryTool, SQLGeneratorTool
+
+load_dotenv()
+
+
+def test_database_connection():
+    """测试数据库连接"""
+    print("=" * 60)
+    print("测试1: 数据库连接")
+    print("=" * 60)
+    
+    db_config = DatabaseConfig()
+    
+    if not db_config.validate():
+        print("❌ 数据库配置不完整")
+        return False
+    
+    print(f"配置信息: {db_config.get_connection_string()}")
+    
+    oracle_tool = OracleQueryTool(db_config)
+    
+    if oracle_tool.connect():
+        print("✅ 数据库连接成功")
+        schema_info = oracle_tool.get_schema_info()
+        print("\n数据库表结构:")
+        print(schema_info)
+        oracle_tool.disconnect()
+        return True
+    else:
+        print("❌ 数据库连接失败")
+        return False
+
+
+def test_sql_generation():
+    """测试SQL生成功能"""
+    print("\n" + "=" * 60)
+    print("测试2: SQL生成")
+    print("=" * 60)
+    
+    try:
+        llm = HelloAgentsLLM()
+        
+        sql_generator = SQLGeneratorTool(llm)
+        
+        test_queries = [
+            "查询所有员工信息",
+            "查询工资大于5000的员工",
+            "统计各部门的员工数量"
+        ]
+        
+        for query in test_queries:
+            print(f"\n自然语言: {query}")
+            sql = sql_generator.generate_sql(query, "表 EMPLOYEES: ID (NUMBER), NAME (VARCHAR2), SALARY (NUMBER), DEPARTMENT (VARCHAR2)")
+            print(f"生成的SQL: {sql}")
+            
+            is_valid, msg = sql_generator.validate_sql(sql)
+            print(f"验证结果: {msg}")
+        
+        return True
+    except Exception as e:
+        print(f"❌ SQL生成测试失败: {e}")
+        return False
+
+
+def test_agent_query():
+    """测试Agent查询功能"""
+    print("\n" + "=" * 60)
+    print("测试3: Agent查询")
+    print("=" * 60)
+    
+    try:
+        llm = HelloAgentsLLM()
+        
+        db_config = DatabaseConfig()
+        
+        if not db_config.validate():
+            print("❌ 数据库配置不完整")
+            return False
+        
+        agent = DatabaseAgent(
+            name="TestAgent",
+            llm=llm,
+            db_config=db_config,
+            max_steps=5
+        )
+        
+        test_query = "查询所有员工的信息"
+        print(f"\n测试查询: {test_query}")
+        result = agent.run(test_query)
+        print(f"\n查询结果:\n{result}")
+        
+        return True
+    except Exception as e:
+        print(f"❌ Agent查询测试失败: {e}")
+        return False
+
+
+def main():
+    """运行所有测试"""
+    print("🧪 数据库Agent助手 - 测试套件")
+    print("=" * 60)
+    
+    results = []
+    
+    results.append(("数据库连接", test_database_connection()))
+    results.append(("SQL生成", test_sql_generation()))
+    results.append(("Agent查询", test_agent_query()))
+    
+    print("\n" + "=" * 60)
+    print("测试结果汇总")
+    print("=" * 60)
+    
+    for test_name, result in results:
+        status = "✅ 通过" if result else "❌ 失败"
+        print(f"{test_name}: {status}")
+    
+    passed = sum(1 for _, result in results if result)
+    total = len(results)
+    print(f"\n总计: {passed}/{total} 测试通过")
+
+
+if __name__ == "__main__":
+    main()