# 07_model_evaluation.py
  1. """
  2. 示例7: 模型评估
  3. 演示如何使用RLTrainingTool评估训练后的模型
  4. """
  5. import sys
  6. from pathlib import Path
  7. import json
  8. # 添加项目路径
  9. project_root = Path(__file__).parent.parent / "HelloAgents"
  10. sys.path.insert(0, str(project_root))
  11. from hello_agents.tools import RLTrainingTool
  12. # ============================================================================
  13. # 示例1: 评估SFT模型
  14. # ============================================================================
  15. def evaluate_sft_model():
  16. """
  17. 评估SFT训练后的模型
  18. 使用测试集评估模型的准确率
  19. """
  20. tool = RLTrainingTool()
  21. config = {
  22. "action": "evaluate",
  23. "model_path": "./output/quick_test/sft",
  24. "max_samples": 50 # 使用50个测试样本
  25. }
  26. print("评估SFT模型:")
  27. print(f" 模型路径: {config['model_path']}")
  28. print(f" 测试样本数: {config['max_samples']}")
  29. # 实际评估时取消注释
  30. # result = tool.run(config)
  31. # result_dict = json.loads(result)
  32. # print(f"\n✅ 评估完成!")
  33. # print(f" 准确率: {result_dict['accuracy']}")
  34. # print(f" 平均奖励: {result_dict['average_reward']}")
  35. print("\n💡 提示: 取消注释以运行评估")
  36. return config
  37. # ============================================================================
  38. # 示例2: 评估GRPO模型
  39. # ============================================================================
  40. def evaluate_grpo_model():
  41. """
  42. 评估GRPO训练后的模型
  43. 对比GRPO模型和SFT模型的性能
  44. """
  45. tool = RLTrainingTool()
  46. config = {
  47. "action": "evaluate",
  48. "model_path": "./output/quick_test/grpo",
  49. "max_samples": 50
  50. }
  51. print("评估GRPO模型:")
  52. print(f" 模型路径: {config['model_path']}")
  53. print(f" 测试样本数: {config['max_samples']}")
  54. # 实际评估时取消注释
  55. # result = tool.run(config)
  56. # result_dict = json.loads(result)
  57. # print(f"\n✅ 评估完成!")
  58. # print(f" 准确率: {result_dict['accuracy']}")
  59. # print(f" 平均奖励: {result_dict['average_reward']}")
  60. print("\n💡 提示: 取消注释以运行评估")
  61. return config
  62. # ============================================================================
  63. # 示例3: 对比SFT和GRPO模型
  64. # ============================================================================
  65. def compare_sft_grpo():
  66. """
  67. 对比SFT和GRPO模型的性能
  68. 在相同的测试集上评估两个模型
  69. """
  70. tool = RLTrainingTool()
  71. print("="*80)
  72. print("SFT vs GRPO 模型对比")
  73. print("="*80)
  74. # 评估SFT模型
  75. print("\n1. 评估SFT模型...")
  76. sft_config = {
  77. "action": "evaluate",
  78. "model_path": "./output/quick_test/sft",
  79. "max_samples": 100
  80. }
  81. # 实际评估时取消注释
  82. # sft_result = tool.run(sft_config)
  83. # sft_data = json.loads(sft_result)
  84. # print(f" SFT准确率: {sft_data['accuracy']}")
  85. # 评估GRPO模型
  86. print("\n2. 评估GRPO模型...")
  87. grpo_config = {
  88. "action": "evaluate",
  89. "model_path": "./output/quick_test/grpo",
  90. "max_samples": 100
  91. }
  92. # 实际评估时取消注释
  93. # grpo_result = tool.run(grpo_config)
  94. # grpo_data = json.loads(grpo_result)
  95. # print(f" GRPO准确率: {grpo_data['accuracy']}")
  96. # 对比结果
  97. print("\n对比结果:")
  98. print(" SFT模型: 学习基本格式和推理步骤")
  99. print(" GRPO模型: 通过强化学习优化推理能力")
  100. print(" 预期: GRPO模型准确率 > SFT模型准确率")
  101. print("\n💡 提示: 取消注释以运行实际评估")
  102. return sft_config, grpo_config
  103. # ============================================================================
  104. # 示例4: 评估基线模型
  105. # ============================================================================
  106. def evaluate_baseline():
  107. """
  108. 评估基线模型(未训练的原始模型)
  109. 用于对比训练效果
  110. """
  111. tool = RLTrainingTool()
  112. config = {
  113. "action": "evaluate",
  114. "model_path": "Qwen/Qwen3-0.6B", # 原始模型
  115. "max_samples": 50
  116. }
  117. print("评估基线模型:")
  118. print(f" 模型: {config['model_path']}")
  119. print(f" 测试样本数: {config['max_samples']}")
  120. # 实际评估时取消注释
  121. # result = tool.run(config)
  122. # result_dict = json.loads(result)
  123. # print(f"\n✅ 评估完成!")
  124. # print(f" 基线准确率: {result_dict['accuracy']}")
  125. print("\n💡 提示: 基线模型通常准确率较低")
  126. print(" 训练后的模型应该显著优于基线")
  127. return config
  128. # ============================================================================
  129. # 示例5: 完整评估流程
  130. # ============================================================================
  131. def complete_evaluation():
  132. """
  133. 完整的评估流程
  134. 评估基线、SFT和GRPO三个模型
  135. """
  136. tool = RLTrainingTool()
  137. models = {
  138. "基线模型": "Qwen/Qwen3-0.6B",
  139. "SFT模型": "./output/quick_test/sft",
  140. "GRPO模型": "./output/quick_test/grpo"
  141. }
  142. print("="*80)
  143. print("完整评估流程")
  144. print("="*80)
  145. results = {}
  146. for name, model_path in models.items():
  147. print(f"\n评估 {name}...")
  148. print(f" 路径: {model_path}")
  149. config = {
  150. "action": "evaluate",
  151. "model_path": model_path,
  152. "max_samples": 100
  153. }
  154. # 实际评估时取消注释
  155. # result = tool.run(config)
  156. # result_dict = json.loads(result)
  157. # results[name] = result_dict
  158. # print(f" 准确率: {result_dict['accuracy']}")
  159. print("\n" + "="*80)
  160. print("评估总结")
  161. print("="*80)
  162. # 实际评估时取消注释
  163. # for name, result in results.items():
  164. # print(f"{name}: {result['accuracy']}")
  165. print("\n预期结果:")
  166. print(" 基线模型 < SFT模型 < GRPO模型")
  167. print(" 说明强化学习训练有效提升了模型性能")
  168. print("\n💡 提示: 取消注释以运行完整评估")
  169. return models
  170. # ============================================================================
  171. # 示例6: 实际评估示例
  172. # ============================================================================
  173. def practical_evaluation():
  174. """
  175. 实际评估示例 - 可以直接运行
  176. 评估quick_test训练的模型
  177. """
  178. tool = RLTrainingTool()
  179. print("="*80)
  180. print("实际评估示例")
  181. print("="*80)
  182. # 检查模型是否存在
  183. import os
  184. sft_path = "./output/quick_test/sft"
  185. grpo_path = "./output/quick_test/grpo"
  186. if not os.path.exists(sft_path):
  187. print(f"\n❌ SFT模型不存在: {sft_path}")
  188. print(" 请先运行 00_quick_test.py 训练模型")
  189. return None
  190. if not os.path.exists(grpo_path):
  191. print(f"\n❌ GRPO模型不存在: {grpo_path}")
  192. print(" 请先运行 00_quick_test.py 训练模型")
  193. return None
  194. print("\n✅ 模型文件存在,开始评估...")
  195. # 评估SFT模型
  196. print("\n1. 评估SFT模型...")
  197. sft_config = {
  198. "action": "evaluate",
  199. "model_path": sft_path,
  200. "max_samples": 20 # 使用较少样本快速测试
  201. }
  202. print("💡 提示: 取消下面的注释以开始评估")
  203. print("# sft_result = tool.run(sft_config)")
  204. print("# sft_data = json.loads(sft_result)")
  205. print("# print(f'SFT准确率: {sft_data[\"accuracy\"]}')")
  206. # 评估GRPO模型
  207. print("\n2. 评估GRPO模型...")
  208. grpo_config = {
  209. "action": "evaluate",
  210. "model_path": grpo_path,
  211. "max_samples": 20
  212. }
  213. print("💡 提示: 取消下面的注释以开始评估")
  214. print("# grpo_result = tool.run(grpo_config)")
  215. print("# grpo_data = json.loads(grpo_result)")
  216. print("# print(f'GRPO准确率: {grpo_data[\"accuracy\"]}')")
  217. # 实际评估时取消注释
  218. # sft_result = tool.run(sft_config)
  219. # sft_data = json.loads(sft_result)
  220. # print(f"\n✅ SFT评估完成: {sft_data['accuracy']}")
  221. # grpo_result = tool.run(grpo_config)
  222. # grpo_data = json.loads(grpo_result)
  223. # print(f"✅ GRPO评估完成: {grpo_data['accuracy']}")
  224. return sft_config, grpo_config
  225. # ============================================================================
  226. # 主函数
  227. # ============================================================================
  228. if __name__ == "__main__":
  229. print("="*80)
  230. print("示例1: 评估SFT模型")
  231. print("="*80)
  232. evaluate_sft_model()
  233. print("\n" + "="*80)
  234. print("示例2: 评估GRPO模型")
  235. print("="*80)
  236. evaluate_grpo_model()
  237. print("\n" + "="*80)
  238. print("示例3: 对比SFT和GRPO模型")
  239. print("="*80)
  240. compare_sft_grpo()
  241. print("\n" + "="*80)
  242. print("示例4: 评估基线模型")
  243. print("="*80)
  244. evaluate_baseline()
  245. print("\n" + "="*80)
  246. print("示例5: 完整评估流程")
  247. print("="*80)
  248. complete_evaluation()
  249. print("\n" + "="*80)
  250. print("示例6: 实际评估示例")
  251. print("="*80)
  252. practical_evaluation()