04_sft_training.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. """
  2. 示例4: SFT训练完整流程
  3. 演示如何使用RLTrainingTool进行SFT监督微调
  4. """
  5. import sys
  6. from pathlib import Path
  7. import json
  8. # 添加项目路径
  9. project_root = Path(__file__).parent.parent / "HelloAgents"
  10. sys.path.insert(0, str(project_root))
  11. from hello_agents.tools import RLTrainingTool
  12. # ============================================================================
  13. # 示例1: 最简单的SFT训练
  14. # ============================================================================
  15. def minimal_sft_training():
  16. """
  17. 最简单的SFT训练示例
  18. 只需要调用RLTrainingTool即可
  19. """
  20. tool = RLTrainingTool()
  21. config = {
  22. "action": "train",
  23. "algorithm": "sft",
  24. "model_name": "Qwen/Qwen3-0.6B",
  25. "output_dir": "./output/sft_minimal",
  26. "max_samples": 10,
  27. "num_epochs": 1,
  28. }
  29. print("最简单的SFT训练:")
  30. print(f" 模型: {config['model_name']}")
  31. print(f" 样本数: {config['max_samples']}")
  32. print(f" 训练轮数: {config['num_epochs']}")
  33. # 实际训练时取消注释
  34. # result = tool.run(config)
  35. # result_dict = json.loads(result)
  36. # print(f"\n✅ 训练完成! 模型保存在: {result_dict['output_dir']}")
  37. return config
  38. # ============================================================================
  39. # 示例2: 标准SFT训练配置
  40. # ============================================================================
  41. def standard_sft_training():
  42. """
  43. 标准的SFT训练配置
  44. 包含:
  45. - LoRA参数高效微调
  46. - 合理的训练参数
  47. - 使用部分数据集
  48. """
  49. tool = RLTrainingTool()
  50. config = {
  51. "action": "train",
  52. "algorithm": "sft",
  53. # 模型配置
  54. "model_name": "Qwen/Qwen3-0.6B",
  55. "output_dir": "./output/sft_standard",
  56. # 数据配置
  57. "max_samples": 1000, # 使用1000个样本
  58. # 训练配置
  59. "num_epochs": 3,
  60. "batch_size": 4,
  61. "learning_rate": 5e-5,
  62. # LoRA配置
  63. "use_lora": True,
  64. "lora_r": 16,
  65. "lora_alpha": 32,
  66. }
  67. print("标准SFT训练配置:")
  68. print(f" 模型: {config['model_name']}")
  69. print(f" 样本数: {config['max_samples']}")
  70. print(f" 训练轮数: {config['num_epochs']}")
  71. print(f" batch_size: {config['batch_size']}")
  72. print(f" learning_rate: {config['learning_rate']}")
  73. print(f" LoRA秩: {config['lora_r']}")
  74. # 实际训练时取消注释
  75. # result = tool.run(config)
  76. # result_dict = json.loads(result)
  77. # print(f"\n✅ 训练完成!")
  78. # print(f"📁 模型保存在: {result_dict['output_dir']}")
  79. return config
  80. # ============================================================================
  81. # 示例3: 完整数据集训练
  82. # ============================================================================
  83. def full_dataset_training():
  84. """
  85. 使用完整数据集进行训练
  86. max_samples=None 表示使用全部数据
  87. """
  88. tool = RLTrainingTool()
  89. config = {
  90. "action": "train",
  91. "algorithm": "sft",
  92. "model_name": "Qwen/Qwen3-0.6B",
  93. "output_dir": "./output/sft_full",
  94. # 使用全部数据
  95. "max_samples": None, # None = 使用全部数据
  96. "num_epochs": 3,
  97. "batch_size": 4,
  98. "learning_rate": 5e-5,
  99. "use_lora": True,
  100. "lora_r": 16,
  101. "lora_alpha": 32,
  102. }
  103. print("完整数据集训练:")
  104. print(f" 模型: {config['model_name']}")
  105. print(f" 样本数: 全部 (max_samples=None)")
  106. print(f" 训练轮数: {config['num_epochs']}")
  107. print(f" 预计样本数: ~7500 (GSM8K训练集)")
  108. # 实际训练时取消注释
  109. # result = tool.run(config)
  110. # result_dict = json.loads(result)
  111. # print(f"\n✅ 训练完成!")
  112. return config
  113. # ============================================================================
  114. # 示例4: 不同学习率的对比
  115. # ============================================================================
  116. def compare_learning_rates():
  117. """
  118. 对比不同学习率的训练效果
  119. 常用学习率:
  120. - 1e-5: 保守,适合微调已经很好的模型
  121. - 5e-5: 推荐,平衡学习速度和稳定性
  122. - 1e-4: 激进,适合快速实验
  123. """
  124. learning_rates = {
  125. "保守 (1e-5)": 1e-5,
  126. "推荐 (5e-5)": 5e-5,
  127. "激进 (1e-4)": 1e-4,
  128. }
  129. print("不同学习率的对比:")
  130. for name, lr in learning_rates.items():
  131. print(f"\n{name}:")
  132. print(f" learning_rate: {lr}")
  133. print(f" 适用场景: ", end="")
  134. if lr == 1e-5:
  135. print("模型已经很好,只需微调")
  136. elif lr == 5e-5:
  137. print("标准训练,推荐使用")
  138. else:
  139. print("快速实验(可能不稳定)")
  140. # 训练示例
  141. print("\n训练示例 (推荐学习率):")
  142. tool = RLTrainingTool()
  143. config = {
  144. "action": "train",
  145. "algorithm": "sft",
  146. "model_name": "Qwen/Qwen3-0.6B",
  147. "max_samples": 1000,
  148. "num_epochs": 3,
  149. "learning_rate": 5e-5,
  150. "use_lora": True,
  151. }
  152. print(f" learning_rate: {config['learning_rate']}")
  153. # result = tool.run(config)
  154. return learning_rates
  155. # ============================================================================
  156. # 示例5: 显存优化配置
  157. # ============================================================================
  158. def memory_optimized_training():
  159. """
  160. 显存优化配置
  161. 适用于显存受限的情况:
  162. - 使用LoRA
  163. - 减小batch size
  164. - 使用较小的LoRA秩
  165. """
  166. tool = RLTrainingTool()
  167. config = {
  168. "action": "train",
  169. "algorithm": "sft",
  170. "model_name": "Qwen/Qwen3-0.6B",
  171. "output_dir": "./output/sft_memory_opt",
  172. # 显存优化
  173. "max_samples": 1000,
  174. "num_epochs": 3,
  175. "batch_size": 1, # 最小batch size
  176. "learning_rate": 5e-5,
  177. # LoRA配置
  178. "use_lora": True,
  179. "lora_r": 8, # 使用较小的秩
  180. "lora_alpha": 16,
  181. }
  182. print("显存优化配置:")
  183. print(f" batch_size: {config['batch_size']} (最小)")
  184. print(f" lora_r: {config['lora_r']} (较小)")
  185. print(f" use_lora: {config['use_lora']}")
  186. print(f" 预计显存占用: ~3-4GB")
  187. # 实际训练时取消注释
  188. # result = tool.run(config)
  189. return config
  190. # ============================================================================
  191. # 示例6: 实际训练示例
  192. # ============================================================================
  193. def practical_training_example():
  194. """
  195. 实际训练示例 - 可以直接运行
  196. """
  197. tool = RLTrainingTool()
  198. config = {
  199. "action": "train",
  200. "algorithm": "sft",
  201. "model_name": "Qwen/Qwen3-0.6B",
  202. "output_dir": "./output/sft_practical",
  203. # 使用较少样本进行快速测试
  204. "max_samples": 100,
  205. "num_epochs": 1,
  206. "batch_size": 4,
  207. "learning_rate": 5e-5,
  208. # 使用LoRA
  209. "use_lora": True,
  210. "lora_r": 16,
  211. "lora_alpha": 32,
  212. }
  213. print("实际训练示例:")
  214. print(f" 模型: {config['model_name']}")
  215. print(f" 样本数: {config['max_samples']}")
  216. print(f" 训练轮数: {config['num_epochs']}")
  217. print(f" 输出目录: {config['output_dir']}")
  218. print("\n💡 提示: 取消下面的注释以开始训练")
  219. print("# result = tool.run(config)")
  220. print("# result_dict = json.loads(result)")
  221. print("# print(f'✅ 训练完成! 模型保存在: {result_dict[\"output_dir\"]}')")
  222. # 实际训练时取消注释
  223. # result = tool.run(config)
  224. # result_dict = json.loads(result)
  225. # print(f"\n✅ 训练完成!")
  226. # print(f"📁 模型保存在: {result_dict['output_dir']}")
  227. return config
  228. # ============================================================================
  229. # 主函数
  230. # ============================================================================
  231. if __name__ == "__main__":
  232. print("="*80)
  233. print("示例1: 最简单的SFT训练")
  234. print("="*80)
  235. minimal_sft_training()
  236. print("\n" + "="*80)
  237. print("示例2: 标准SFT训练配置")
  238. print("="*80)
  239. standard_sft_training()
  240. print("\n" + "="*80)
  241. print("示例3: 完整数据集训练")
  242. print("="*80)
  243. full_dataset_training()
  244. print("\n" + "="*80)
  245. print("示例4: 不同学习率的对比")
  246. print("="*80)
  247. compare_learning_rates()
  248. print("\n" + "="*80)
  249. print("示例5: 显存优化配置")
  250. print("="*80)
  251. memory_optimized_training()
  252. print("\n" + "="*80)
  253. print("示例6: 实际训练示例")
  254. print("="*80)
  255. practical_training_example()