| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
"""
Example 4: complete SFT training workflow.

Demonstrates how to use RLTrainingTool for SFT (supervised fine-tuning).
"""
import sys
from pathlib import Path
import json

# Make the `hello_agents` package importable without installing it: prepend
# the "HelloAgents" directory located in this script's grandparent directory
# (i.e. a sibling of the directory that contains this file) to sys.path.
# This must run before the `hello_agents` import below.
project_root = Path(__file__).parent.parent / "HelloAgents"
sys.path.insert(0, str(project_root))

from hello_agents.tools import RLTrainingTool
# ============================================================================
# Example 1: the simplest SFT training run
# ============================================================================
def minimal_sft_training(*, run: bool = False) -> dict:
    """Build, print and optionally run the smallest SFT training config.

    Only the task keys, model, output directory and a tiny sample/epoch
    budget are set; no other training options are specified here.

    Args:
        run: When True, instantiate RLTrainingTool and launch training with
            the config. Defaults to False, which only prints the config
            (the previous dry-run behaviour).

    Returns:
        The config dict that is (or would be) passed to ``RLTrainingTool.run``.
    """
    config = {
        "action": "train",
        "algorithm": "sft",
        "model_name": "Qwen/Qwen3-0.6B",
        "output_dir": "./output/sft_minimal",
        "max_samples": 10,
        "num_epochs": 1,
    }

    print("最简单的SFT训练:")
    print(f" 模型: {config['model_name']}")
    print(f" 样本数: {config['max_samples']}")
    print(f" 训练轮数: {config['num_epochs']}")

    if run:
        # Only build the tool when it is actually used.
        tool = RLTrainingTool()
        # NOTE(review): assumes tool.run returns a JSON string that may carry
        # "output_dir" (as the former commented-out code expected) — confirm.
        result_dict = json.loads(tool.run(config))
        saved_to = result_dict.get("output_dir", config["output_dir"])
        print(f"\n✅ 训练完成! 模型保存在: {saved_to}")

    return config
# ============================================================================
# Example 2: standard SFT training config
# ============================================================================
def standard_sft_training(*, run: bool = False) -> dict:
    """Build, print and optionally run a standard SFT config.

    The config includes:
    - LoRA parameter-efficient fine-tuning (rank 16, alpha 32)
    - reasonable training hyper-parameters
    - a subset of the dataset (1000 samples)

    Args:
        run: When True, instantiate RLTrainingTool and launch training with
            the config. Defaults to False, which only prints the config.

    Returns:
        The config dict that is (or would be) passed to ``RLTrainingTool.run``.
    """
    config = {
        "action": "train",
        "algorithm": "sft",

        # Model
        "model_name": "Qwen/Qwen3-0.6B",
        "output_dir": "./output/sft_standard",

        # Data: cap at 1000 samples
        "max_samples": 1000,

        # Training
        "num_epochs": 3,
        "batch_size": 4,
        "learning_rate": 5e-5,

        # LoRA
        "use_lora": True,
        "lora_r": 16,
        "lora_alpha": 32,
    }

    print("标准SFT训练配置:")
    print(f" 模型: {config['model_name']}")
    print(f" 样本数: {config['max_samples']}")
    print(f" 训练轮数: {config['num_epochs']}")
    print(f" batch_size: {config['batch_size']}")
    print(f" learning_rate: {config['learning_rate']}")
    print(f" LoRA秩: {config['lora_r']}")

    if run:
        # Only build the tool when it is actually used.
        tool = RLTrainingTool()
        # NOTE(review): assumes tool.run returns a JSON string that may carry
        # "output_dir" (as the former commented-out code expected) — confirm.
        result_dict = json.loads(tool.run(config))
        saved_to = result_dict.get("output_dir", config["output_dir"])
        print("\n✅ 训练完成!")
        print(f"📁 模型保存在: {saved_to}")

    return config
# ============================================================================
# Example 3: training on the full dataset
# ============================================================================
def full_dataset_training(*, run: bool = False) -> dict:
    """Build, print and optionally run an SFT config over the whole dataset.

    ``max_samples=None`` means no sample cap, i.e. use all of the data.

    Args:
        run: When True, instantiate RLTrainingTool and launch training with
            the config. Defaults to False, which only prints the config.

    Returns:
        The config dict that is (or would be) passed to ``RLTrainingTool.run``.
    """
    config = {
        "action": "train",
        "algorithm": "sft",
        "model_name": "Qwen/Qwen3-0.6B",
        "output_dir": "./output/sft_full",

        # None = no cap, train on the full dataset
        "max_samples": None,

        "num_epochs": 3,
        "batch_size": 4,
        "learning_rate": 5e-5,
        "use_lora": True,
        "lora_r": 16,
        "lora_alpha": 32,
    }

    print("完整数据集训练:")
    print(f" 模型: {config['model_name']}")
    # Plain string: nothing to interpolate here.
    print(" 样本数: 全部 (max_samples=None)")
    print(f" 训练轮数: {config['num_epochs']}")
    # Hard-coded estimate (GSM8K train split); not derived from the config.
    print(" 预计样本数: ~7500 (GSM8K训练集)")

    if run:
        # Only build the tool when it is actually used.
        RLTrainingTool().run(config)
        print("\n✅ 训练完成!")

    return config
# ============================================================================
# Example 4: comparing learning rates
# ============================================================================
def compare_learning_rates(*, run: bool = False) -> dict:
    """Print common learning-rate choices, then a sample config using one.

    Common learning rates:
    - 1e-5: conservative, for models that are already good
    - 5e-5: recommended, balances speed and stability
    - 1e-4: aggressive, for quick experiments

    Args:
        run: When True, instantiate RLTrainingTool and launch training with
            the recommended-rate config. Defaults to False (print only).

    Returns:
        Mapping of display label -> learning rate, in presentation order.
    """
    # (label, learning rate, use case). The use case is stored next to its
    # rate instead of being picked with float ``==`` comparisons, so adding a
    # new rate can no longer silently inherit the "aggressive" description.
    presets = [
        ("保守 (1e-5)", 1e-5, "模型已经很好,只需微调"),
        ("推荐 (5e-5)", 5e-5, "标准训练,推荐使用"),
        ("激进 (1e-4)", 1e-4, "快速实验(可能不稳定)"),
    ]
    learning_rates = {label: lr for label, lr, _ in presets}

    print("不同学习率的对比:")
    for label, lr, use_case in presets:
        print(f"\n{label}:")
        print(f" learning_rate: {lr}")
        print(f" 适用场景: {use_case}")

    # Sample training config with the recommended learning rate.
    print("\n训练示例 (推荐学习率):")
    config = {
        "action": "train",
        "algorithm": "sft",
        "model_name": "Qwen/Qwen3-0.6B",
        "max_samples": 1000,
        "num_epochs": 3,
        "learning_rate": 5e-5,
        "use_lora": True,
    }
    print(f" learning_rate: {config['learning_rate']}")

    if run:
        # Only build the tool when it is actually used.
        RLTrainingTool().run(config)
        print("\n✅ 训练完成!")

    return learning_rates
# ============================================================================
# Example 5: memory-optimized config
# ============================================================================
def memory_optimized_training(*, run: bool = False) -> dict:
    """Build, print and optionally run an SFT config for limited GPU memory.

    Memory-saving choices:
    - use LoRA
    - smallest batch size (1)
    - smaller LoRA rank (r=8, alpha=16)

    Args:
        run: When True, instantiate RLTrainingTool and launch training with
            the config. Defaults to False, which only prints the config.

    Returns:
        The config dict that is (or would be) passed to ``RLTrainingTool.run``.
    """
    config = {
        "action": "train",
        "algorithm": "sft",
        "model_name": "Qwen/Qwen3-0.6B",
        "output_dir": "./output/sft_memory_opt",

        # Memory optimization
        "max_samples": 1000,
        "num_epochs": 3,
        "batch_size": 1,  # smallest batch size
        "learning_rate": 5e-5,

        # LoRA
        "use_lora": True,
        "lora_r": 8,  # smaller rank
        "lora_alpha": 16,
    }

    print("显存优化配置:")
    print(f" batch_size: {config['batch_size']} (最小)")
    print(f" lora_r: {config['lora_r']} (较小)")
    print(f" use_lora: {config['use_lora']}")
    # Rough, hard-coded estimate; not measured or computed from the config.
    print(" 预计显存占用: ~3-4GB")

    if run:
        # Only build the tool when it is actually used.
        RLTrainingTool().run(config)
        print("\n✅ 训练完成!")

    return config
# ============================================================================
# Example 6: practical training example
# ============================================================================
def practical_training_example(*, run: bool = False) -> dict:
    """Quick-test SFT run (100 samples, 1 epoch) that can really be executed.

    Args:
        run: When True, instantiate RLTrainingTool and launch training with
            the config. Defaults to False, which prints the config plus a
            hint showing how to start training.

    Returns:
        The config dict that is (or would be) passed to ``RLTrainingTool.run``.
    """
    config = {
        "action": "train",
        "algorithm": "sft",
        "model_name": "Qwen/Qwen3-0.6B",
        "output_dir": "./output/sft_practical",

        # Few samples for a fast smoke test
        "max_samples": 100,
        "num_epochs": 1,
        "batch_size": 4,
        "learning_rate": 5e-5,

        # LoRA
        "use_lora": True,
        "lora_r": 16,
        "lora_alpha": 32,
    }

    print("实际训练示例:")
    print(f" 模型: {config['model_name']}")
    print(f" 样本数: {config['max_samples']}")
    print(f" 训练轮数: {config['num_epochs']}")
    print(f" 输出目录: {config['output_dir']}")

    if not run:
        # The old hint told users to uncomment code; training is now enabled
        # with the `run` flag, and the snippet shows the equivalent tool calls.
        print("\n💡 提示: 传入 run=True 以开始训练, 等价于执行:")
        print("# result = tool.run(config)")
        print("# result_dict = json.loads(result)")
        print("# print(f'✅ 训练完成! 模型保存在: {result_dict[\"output_dir\"]}')")
        return config

    tool = RLTrainingTool()
    # NOTE(review): assumes tool.run returns a JSON string that may carry
    # "output_dir" (as the former commented-out code expected) — confirm.
    result_dict = json.loads(tool.run(config))
    saved_to = result_dict.get("output_dir", config["output_dir"])
    print("\n✅ 训练完成!")
    print(f"📁 模型保存在: {saved_to}")
    return config
# ============================================================================
# Entry point
# ============================================================================
if __name__ == "__main__":
    # Run every example in order, each introduced by an 80-character "="
    # banner. Every banner after the first is preceded by a blank line so
    # consecutive examples are visually separated.
    _banner = "=" * 80
    _examples = [
        ("示例1: 最简单的SFT训练", minimal_sft_training),
        ("示例2: 标准SFT训练配置", standard_sft_training),
        ("示例3: 完整数据集训练", full_dataset_training),
        ("示例4: 不同学习率的对比", compare_learning_rates),
        ("示例5: 显存优化配置", memory_optimized_training),
        ("示例6: 实际训练示例", practical_training_example),
    ]
    for _position, (_title, _example) in enumerate(_examples):
        print(_banner if _position == 0 else "\n" + _banner)
        print(_title)
        print(_banner)
        _example()
|