# 00_quick_test.py (4.1 KB)
  1. """
  2. 快速实验测试
  3. 使用少量数据快速测试SFT和GRPO训练流程
  4. """
  5. import sys
  6. from pathlib import Path
  7. import json
  8. # 添加项目路径
  9. project_root = Path(__file__).parent.parent / "HelloAgents"
  10. sys.path.insert(0, str(project_root))
  11. from hello_agents.tools import RLTrainingTool
  12. def quick_test():
  13. """
  14. 快速实验测试
  15. 配置:
  16. - 模型: Qwen/Qwen3-0.6B
  17. - 样本数: 10个
  18. - 训练轮数: 1轮
  19. - 预计时间: ~2-3分钟
  20. """
  21. tool = RLTrainingTool()
  22. print("="*80)
  23. print("快速实验测试")
  24. print("="*80)
  25. # ========================================================================
  26. # 测试1: 数据加载
  27. # ========================================================================
  28. print("\n测试1: 数据加载")
  29. print("-"*80)
  30. data_config = {
  31. "action": "load_dataset",
  32. "format_type": "sft",
  33. "split": "train",
  34. "max_samples": 5
  35. }
  36. print("加载数据集...")
  37. result = tool.run(data_config)
  38. data = json.loads(result)
  39. print(f"✅ 数据集加载成功: {data['dataset_size']} 样本")
  40. print(json.dumps(data, indent=2, ensure_ascii=False))
  41. # ========================================================================
  42. # 测试2: SFT训练
  43. # ========================================================================
  44. print("\n测试2: SFT训练")
  45. print("-"*80)
  46. sft_config = {
  47. "action": "train",
  48. "algorithm": "sft",
  49. "model_name": "Qwen/Qwen3-0.6B",
  50. "output_dir": "./output/quick_test/sft",
  51. "max_samples": 10,
  52. "num_epochs": 1,
  53. "batch_size": 2,
  54. "use_lora": True,
  55. "lora_r": 8,
  56. "lora_alpha": 16,
  57. }
  58. print("SFT配置:")
  59. print(json.dumps(sft_config, indent=2, ensure_ascii=False))
  60. print("\n⏳ 开始SFT训练...")
  61. sft_result = tool.run(sft_config)
  62. sft_data = json.loads(sft_result)
  63. print("\n✅ SFT训练结果:")
  64. print(json.dumps(sft_data, indent=2, ensure_ascii=False))
  65. # ========================================================================
  66. # 测试3: GRPO训练
  67. # ========================================================================
  68. print("\n测试3: GRPO训练")
  69. print("-"*80)
  70. grpo_config = {
  71. "action": "train",
  72. "algorithm": "grpo",
  73. "model_name": "Qwen/Qwen3-0.6B",
  74. "output_dir": "./output/quick_test/grpo",
  75. "max_samples": 10,
  76. "num_epochs": 1,
  77. "batch_size": 2,
  78. "use_lora": True,
  79. "lora_r": 8,
  80. "lora_alpha": 16,
  81. }
  82. print("GRPO配置:")
  83. print(json.dumps(grpo_config, indent=2, ensure_ascii=False))
  84. print("\n⏳ 开始GRPO训练...")
  85. grpo_result = tool.run(grpo_config)
  86. grpo_data = json.loads(grpo_result)
  87. print("\n✅ GRPO训练结果:")
  88. print(json.dumps(grpo_data, indent=2, ensure_ascii=False))
  89. # ========================================================================
  90. # 测试4: 奖励函数
  91. # ========================================================================
  92. print("\n测试4: 奖励函数")
  93. print("-"*80)
  94. reward_config = {
  95. "action": "create_reward",
  96. "reward_type": "accuracy"
  97. }
  98. print("创建奖励函数...")
  99. reward_result = tool.run(reward_config)
  100. reward_data = json.loads(reward_result)
  101. print("✅ 奖励函数创建成功:")
  102. print(json.dumps(reward_data, indent=2, ensure_ascii=False))
  103. # ========================================================================
  104. # 总结
  105. # ========================================================================
  106. print("\n" + "="*80)
  107. print("测试总结")
  108. print("="*80)
  109. print("\n✅ 所有测试通过!")
  110. print("\n测试项目:")
  111. print(" 1. ✅ 数据加载")
  112. print(" 2. ✅ SFT训练")
  113. print(" 3. ✅ GRPO训练")
  114. print(" 4. ✅ 奖励函数创建")
  115. print("\n模型路径:")
  116. print(f" SFT模型: {sft_config['output_dir']}")
  117. print(f" GRPO模型: {grpo_config['output_dir']}")
  118. if __name__ == "__main__":
  119. quick_test()