05_grpo_training.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
"""
Example 5: complete GRPO training workflow.

Demonstrates how to use RLTrainingTool to run GRPO reinforcement-learning
training.
"""
import sys
from pathlib import Path
import json
# Make the project package importable: put "<directory containing this
# script's folder>/HelloAgents" at the front of sys.path. This must run
# before the hello_agents import below.
project_root = Path(__file__).parent.parent / "HelloAgents"
sys.path.insert(0, str(project_root))
from hello_agents.tools import RLTrainingTool
  12. # ============================================================================
  13. # 示例1: 最简单的GRPO训练
  14. # ============================================================================
  15. def minimal_grpo_training():
  16. """
  17. 最简单的GRPO训练示例
  18. 只需要调用RLTrainingTool即可
  19. """
  20. tool = RLTrainingTool()
  21. config = {
  22. "action": "train",
  23. "algorithm": "grpo",
  24. "model_name": "Qwen/Qwen3-0.6B",
  25. "output_dir": "./output/grpo_minimal",
  26. "max_samples": 10,
  27. "num_epochs": 1,
  28. }
  29. print("最简单的GRPO训练:")
  30. print(f" 模型: {config['model_name']}")
  31. print(f" 样本数: {config['max_samples']}")
  32. print(f" 训练轮数: {config['num_epochs']}")
  33. # 实际训练时取消注释
  34. # result = tool.run(config)
  35. # result_dict = json.loads(result)
  36. # print(f"\n✅ 训练完成! 模型保存在: {result_dict['output_dir']}")
  37. return config
  38. # ============================================================================
  39. # 示例2: 标准GRPO训练配置
  40. # ============================================================================
  41. def standard_grpo_training():
  42. """
  43. 标准的GRPO训练配置
  44. 通常在SFT模型基础上进行GRPO训练
  45. """
  46. tool = RLTrainingTool()
  47. config = {
  48. "action": "train",
  49. "algorithm": "grpo",
  50. # 模型配置 - 可以使用SFT训练后的模型
  51. "model_name": "Qwen/Qwen3-0.6B", # 或 "./output/sft_standard"
  52. "output_dir": "./output/grpo_standard",
  53. # 数据配置
  54. "max_samples": 500, # GRPO通常使用较少样本
  55. # 训练配置
  56. "num_epochs": 3,
  57. "batch_size": 2, # GRPO需要更多显存
  58. "learning_rate": 1e-5, # 比SFT小10倍
  59. # LoRA配置
  60. "use_lora": True,
  61. "lora_r": 16,
  62. "lora_alpha": 32,
  63. }
  64. print("标准GRPO训练配置:")
  65. print(f" 模型: {config['model_name']}")
  66. print(f" 样本数: {config['max_samples']}")
  67. print(f" 训练轮数: {config['num_epochs']}")
  68. print(f" batch_size: {config['batch_size']}")
  69. print(f" learning_rate: {config['learning_rate']} (比SFT小)")
  70. # 实际训练时取消注释
  71. # result = tool.run(config)
  72. # result_dict = json.loads(result)
  73. # print(f"\n✅ GRPO训练完成!")
  74. return config
  75. # ============================================================================
  76. # 示例3: 完整数据集训练
  77. # ============================================================================
  78. def full_dataset_training():
  79. """
  80. 使用完整数据集进行GRPO训练
  81. """
  82. tool = RLTrainingTool()
  83. config = {
  84. "action": "train",
  85. "algorithm": "grpo",
  86. "model_name": "Qwen/Qwen3-0.6B",
  87. "output_dir": "./output/grpo_full",
  88. # 使用全部数据
  89. "max_samples": None, # None = 使用全部数据
  90. "num_epochs": 3,
  91. "batch_size": 2,
  92. "learning_rate": 1e-5,
  93. "use_lora": True,
  94. "lora_r": 16,
  95. "lora_alpha": 32,
  96. }
  97. print("完整数据集GRPO训练:")
  98. print(f" 模型: {config['model_name']}")
  99. print(f" 样本数: 全部 (max_samples=None)")
  100. print(f" 训练轮数: {config['num_epochs']}")
  101. print(f" 预计样本数: ~7500 (GSM8K训练集)")
  102. # 实际训练时取消注释
  103. # result = tool.run(config)
  104. return config
  105. # ============================================================================
  106. # 示例4: SFT + GRPO完整流程
  107. # ============================================================================
  108. def complete_sft_grpo_pipeline():
  109. """
  110. 完整的SFT + GRPO训练流程
  111. 步骤:
  112. 1. SFT训练 - 学习基本格式
  113. 2. GRPO训练 - 优化推理能力
  114. """
  115. tool = RLTrainingTool()
  116. # 步骤1: SFT训练
  117. print("步骤1: SFT训练")
  118. sft_config = {
  119. "action": "train",
  120. "algorithm": "sft",
  121. "model_name": "Qwen/Qwen3-0.6B",
  122. "output_dir": "./output/pipeline_sft",
  123. "max_samples": 1000,
  124. "num_epochs": 3,
  125. "batch_size": 4,
  126. "use_lora": True,
  127. }
  128. print(f" 模型: {sft_config['model_name']}")
  129. print(f" 样本数: {sft_config['max_samples']}")
  130. # 实际训练时取消注释
  131. # sft_result = tool.run(sft_config)
  132. # print(f"✅ SFT训练完成: {sft_config['output_dir']}")
  133. # 步骤2: GRPO训练
  134. print("\n步骤2: GRPO训练")
  135. grpo_config = {
  136. "action": "train",
  137. "algorithm": "grpo",
  138. "model_name": "./output/pipeline_sft", # 使用SFT模型
  139. "output_dir": "./output/pipeline_grpo",
  140. "max_samples": 500,
  141. "num_epochs": 3,
  142. "batch_size": 2,
  143. "learning_rate": 1e-5,
  144. "use_lora": True,
  145. }
  146. print(f" 基础模型: {grpo_config['model_name']}")
  147. print(f" 样本数: {grpo_config['max_samples']}")
  148. # 实际训练时取消注释
  149. # grpo_result = tool.run(grpo_config)
  150. # print(f"✅ GRPO训练完成: {grpo_config['output_dir']}")
  151. print("\n💡 推荐使用GRPO模型进行推理")
  152. return sft_config, grpo_config
  153. # ============================================================================
  154. # 示例5: 不同奖励函数的使用
  155. # ============================================================================
  156. def using_different_rewards():
  157. """
  158. GRPO默认使用准确性奖励函数
  159. 可以通过创建自定义奖励函数来改变行为
  160. """
  161. print("GRPO奖励函数:")
  162. print("\n默认奖励函数: 准确性奖励")
  163. print(" - 答案正确: 1.0")
  164. print(" - 答案错误: 0.0")
  165. print("\n其他可用奖励函数:")
  166. print(" 1. 长度惩罚奖励: 鼓励简洁答案")
  167. print(" 2. 步骤奖励: 鼓励详细推理")
  168. print(" 3. 自定义奖励: 根据需求定制")
  169. print("\n创建奖励函数示例:")
  170. tool = RLTrainingTool()
  171. # 创建准确性奖励函数
  172. accuracy_config = {
  173. "action": "create_reward",
  174. "reward_type": "accuracy"
  175. }
  176. print("\n1. 准确性奖励:")
  177. print(f" 配置: {accuracy_config}")
  178. # 创建长度惩罚奖励函数
  179. length_config = {
  180. "action": "create_reward",
  181. "reward_type": "length_penalty",
  182. "penalty_weight": 0.001
  183. }
  184. print("\n2. 长度惩罚奖励:")
  185. print(f" 配置: {length_config}")
  186. # 创建步骤奖励函数
  187. step_config = {
  188. "action": "create_reward",
  189. "reward_type": "step",
  190. "step_bonus": 0.1
  191. }
  192. print("\n3. 步骤奖励:")
  193. print(f" 配置: {step_config}")
  194. return accuracy_config, length_config, step_config
  195. # ============================================================================
  196. # 示例6: 实际训练示例
  197. # ============================================================================
  198. def practical_training_example():
  199. """
  200. 实际训练示例 - 可以直接运行
  201. """
  202. tool = RLTrainingTool()
  203. config = {
  204. "action": "train",
  205. "algorithm": "grpo",
  206. "model_name": "Qwen/Qwen3-0.6B",
  207. "output_dir": "./output/grpo_practical",
  208. # 使用较少样本进行快速测试
  209. "max_samples": 50,
  210. "num_epochs": 1,
  211. "batch_size": 2,
  212. "learning_rate": 1e-5,
  213. # 使用LoRA
  214. "use_lora": True,
  215. "lora_r": 16,
  216. "lora_alpha": 32,
  217. }
  218. print("实际训练示例:")
  219. print(f" 模型: {config['model_name']}")
  220. print(f" 样本数: {config['max_samples']}")
  221. print(f" 训练轮数: {config['num_epochs']}")
  222. print(f" 输出目录: {config['output_dir']}")
  223. print("\n💡 提示: 取消下面的注释以开始训练")
  224. print("# result = tool.run(config)")
  225. print("# result_dict = json.loads(result)")
  226. print("# print(f'✅ 训练完成! 模型保存在: {result_dict[\"output_dir\"]}')")
  227. # 实际训练时取消注释
  228. # result = tool.run(config)
  229. # result_dict = json.loads(result)
  230. # print(f"\n✅ 训练完成!")
  231. # print(f"📁 模型保存在: {result_dict['output_dir']}")
  232. return config
  233. # ============================================================================
  234. # 主函数
  235. # ============================================================================
  236. if __name__ == "__main__":
  237. print("="*80)
  238. print("示例1: 最简单的GRPO训练")
  239. print("="*80)
  240. minimal_grpo_training()
  241. print("\n" + "="*80)
  242. print("示例2: 标准GRPO训练配置")
  243. print("="*80)
  244. standard_grpo_training()
  245. print("\n" + "="*80)
  246. print("示例3: 完整数据集训练")
  247. print("="*80)
  248. full_dataset_training()
  249. print("\n" + "="*80)
  250. print("示例4: SFT + GRPO完整流程")
  251. print("="*80)
  252. complete_sft_grpo_pipeline()
  253. print("\n" + "="*80)
  254. print("示例5: 不同奖励函数的使用")
  255. print("="*80)
  256. using_different_rewards()
  257. print("\n" + "="*80)
  258. print("示例6: 实际训练示例")
  259. print("="*80)
  260. practical_training_example()