| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- """
- 示例7: 模型评估
- 演示如何使用RLTrainingTool评估训练后的模型
- """
- import sys
- from pathlib import Path
- import json
- # 添加项目路径
- project_root = Path(__file__).parent.parent / "HelloAgents"
- sys.path.insert(0, str(project_root))
- from hello_agents.tools import RLTrainingTool
- # ============================================================================
- # 示例1: 评估SFT模型
- # ============================================================================
- def evaluate_sft_model():
- """
- 评估SFT训练后的模型
-
- 使用测试集评估模型的准确率
- """
- tool = RLTrainingTool()
-
- config = {
- "action": "evaluate",
- "model_path": "./output/quick_test/sft",
- "max_samples": 50 # 使用50个测试样本
- }
-
- print("评估SFT模型:")
- print(f" 模型路径: {config['model_path']}")
- print(f" 测试样本数: {config['max_samples']}")
-
- # 实际评估时取消注释
- # result = tool.run(config)
- # result_dict = json.loads(result)
- # print(f"\n✅ 评估完成!")
- # print(f" 准确率: {result_dict['accuracy']}")
- # print(f" 平均奖励: {result_dict['average_reward']}")
-
- print("\n💡 提示: 取消注释以运行评估")
-
- return config
- # ============================================================================
- # 示例2: 评估GRPO模型
- # ============================================================================
- def evaluate_grpo_model():
- """
- 评估GRPO训练后的模型
-
- 对比GRPO模型和SFT模型的性能
- """
- tool = RLTrainingTool()
-
- config = {
- "action": "evaluate",
- "model_path": "./output/quick_test/grpo",
- "max_samples": 50
- }
-
- print("评估GRPO模型:")
- print(f" 模型路径: {config['model_path']}")
- print(f" 测试样本数: {config['max_samples']}")
-
- # 实际评估时取消注释
- # result = tool.run(config)
- # result_dict = json.loads(result)
- # print(f"\n✅ 评估完成!")
- # print(f" 准确率: {result_dict['accuracy']}")
- # print(f" 平均奖励: {result_dict['average_reward']}")
-
- print("\n💡 提示: 取消注释以运行评估")
-
- return config
- # ============================================================================
- # 示例3: 对比SFT和GRPO模型
- # ============================================================================
- def compare_sft_grpo():
- """
- 对比SFT和GRPO模型的性能
-
- 在相同的测试集上评估两个模型
- """
- tool = RLTrainingTool()
-
- print("="*80)
- print("SFT vs GRPO 模型对比")
- print("="*80)
-
- # 评估SFT模型
- print("\n1. 评估SFT模型...")
- sft_config = {
- "action": "evaluate",
- "model_path": "./output/quick_test/sft",
- "max_samples": 100
- }
-
- # 实际评估时取消注释
- # sft_result = tool.run(sft_config)
- # sft_data = json.loads(sft_result)
- # print(f" SFT准确率: {sft_data['accuracy']}")
-
- # 评估GRPO模型
- print("\n2. 评估GRPO模型...")
- grpo_config = {
- "action": "evaluate",
- "model_path": "./output/quick_test/grpo",
- "max_samples": 100
- }
-
- # 实际评估时取消注释
- # grpo_result = tool.run(grpo_config)
- # grpo_data = json.loads(grpo_result)
- # print(f" GRPO准确率: {grpo_data['accuracy']}")
-
- # 对比结果
- print("\n对比结果:")
- print(" SFT模型: 学习基本格式和推理步骤")
- print(" GRPO模型: 通过强化学习优化推理能力")
- print(" 预期: GRPO模型准确率 > SFT模型准确率")
-
- print("\n💡 提示: 取消注释以运行实际评估")
-
- return sft_config, grpo_config
- # ============================================================================
- # 示例4: 评估基线模型
- # ============================================================================
- def evaluate_baseline():
- """
- 评估基线模型(未训练的原始模型)
-
- 用于对比训练效果
- """
- tool = RLTrainingTool()
-
- config = {
- "action": "evaluate",
- "model_path": "Qwen/Qwen3-0.6B", # 原始模型
- "max_samples": 50
- }
-
- print("评估基线模型:")
- print(f" 模型: {config['model_path']}")
- print(f" 测试样本数: {config['max_samples']}")
-
- # 实际评估时取消注释
- # result = tool.run(config)
- # result_dict = json.loads(result)
- # print(f"\n✅ 评估完成!")
- # print(f" 基线准确率: {result_dict['accuracy']}")
-
- print("\n💡 提示: 基线模型通常准确率较低")
- print(" 训练后的模型应该显著优于基线")
-
- return config
- # ============================================================================
- # 示例5: 完整评估流程
- # ============================================================================
- def complete_evaluation():
- """
- 完整的评估流程
-
- 评估基线、SFT和GRPO三个模型
- """
- tool = RLTrainingTool()
-
- models = {
- "基线模型": "Qwen/Qwen3-0.6B",
- "SFT模型": "./output/quick_test/sft",
- "GRPO模型": "./output/quick_test/grpo"
- }
-
- print("="*80)
- print("完整评估流程")
- print("="*80)
-
- results = {}
-
- for name, model_path in models.items():
- print(f"\n评估 {name}...")
- print(f" 路径: {model_path}")
-
- config = {
- "action": "evaluate",
- "model_path": model_path,
- "max_samples": 100
- }
-
- # 实际评估时取消注释
- # result = tool.run(config)
- # result_dict = json.loads(result)
- # results[name] = result_dict
- # print(f" 准确率: {result_dict['accuracy']}")
-
- print("\n" + "="*80)
- print("评估总结")
- print("="*80)
-
- # 实际评估时取消注释
- # for name, result in results.items():
- # print(f"{name}: {result['accuracy']}")
-
- print("\n预期结果:")
- print(" 基线模型 < SFT模型 < GRPO模型")
- print(" 说明强化学习训练有效提升了模型性能")
-
- print("\n💡 提示: 取消注释以运行完整评估")
-
- return models
- # ============================================================================
- # 示例6: 实际评估示例
- # ============================================================================
- def practical_evaluation():
- """
- 实际评估示例 - 可以直接运行
-
- 评估quick_test训练的模型
- """
- tool = RLTrainingTool()
-
- print("="*80)
- print("实际评估示例")
- print("="*80)
-
- # 检查模型是否存在
- import os
- sft_path = "./output/quick_test/sft"
- grpo_path = "./output/quick_test/grpo"
-
- if not os.path.exists(sft_path):
- print(f"\n❌ SFT模型不存在: {sft_path}")
- print(" 请先运行 00_quick_test.py 训练模型")
- return None
-
- if not os.path.exists(grpo_path):
- print(f"\n❌ GRPO模型不存在: {grpo_path}")
- print(" 请先运行 00_quick_test.py 训练模型")
- return None
-
- print("\n✅ 模型文件存在,开始评估...")
-
- # 评估SFT模型
- print("\n1. 评估SFT模型...")
- sft_config = {
- "action": "evaluate",
- "model_path": sft_path,
- "max_samples": 20 # 使用较少样本快速测试
- }
-
- print("💡 提示: 取消下面的注释以开始评估")
- print("# sft_result = tool.run(sft_config)")
- print("# sft_data = json.loads(sft_result)")
- print("# print(f'SFT准确率: {sft_data[\"accuracy\"]}')")
-
- # 评估GRPO模型
- print("\n2. 评估GRPO模型...")
- grpo_config = {
- "action": "evaluate",
- "model_path": grpo_path,
- "max_samples": 20
- }
-
- print("💡 提示: 取消下面的注释以开始评估")
- print("# grpo_result = tool.run(grpo_config)")
- print("# grpo_data = json.loads(grpo_result)")
- print("# print(f'GRPO准确率: {grpo_data[\"accuracy\"]}')")
-
- # 实际评估时取消注释
- # sft_result = tool.run(sft_config)
- # sft_data = json.loads(sft_result)
- # print(f"\n✅ SFT评估完成: {sft_data['accuracy']}")
-
- # grpo_result = tool.run(grpo_config)
- # grpo_data = json.loads(grpo_result)
- # print(f"✅ GRPO评估完成: {grpo_data['accuracy']}")
-
- return sft_config, grpo_config
- # ============================================================================
- # 主函数
- # ============================================================================
- if __name__ == "__main__":
- print("="*80)
- print("示例1: 评估SFT模型")
- print("="*80)
- evaluate_sft_model()
-
- print("\n" + "="*80)
- print("示例2: 评估GRPO模型")
- print("="*80)
- evaluate_grpo_model()
-
- print("\n" + "="*80)
- print("示例3: 对比SFT和GRPO模型")
- print("="*80)
- compare_sft_grpo()
-
- print("\n" + "="*80)
- print("示例4: 评估基线模型")
- print("="*80)
- evaluate_baseline()
-
- print("\n" + "="*80)
- print("示例5: 完整评估流程")
- print("="*80)
- complete_evaluation()
-
- print("\n" + "="*80)
- print("示例6: 实际评估示例")
- print("="*80)
- practical_evaluation()
|