# 06_complete_pipeline.py (8.2 KB)
  1. """
  2. 完整的Agentic RL训练流程(更新版)
  3. 从数据准备到模型部署的端到端示例
  4. 更新内容:
  5. 1. 修复了JSON解析问题
  6. 2. 添加了训练监控配置(wandb/tensorboard)
  7. 3. 支持详细日志输出
  8. """
  9. import sys
  10. import os
  11. # 添加HelloAgents到路径
  12. sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "HelloAgents"))
  13. from hello_agents.tools import RLTrainingTool
  14. import json
  15. from datetime import datetime
  16. class AgenticRLPipeline:
  17. """Agentic RL训练流水线"""
  18. def __init__(self, config_path="config.json"):
  19. """
  20. 初始化训练流水线
  21. Args:
  22. config_path: 配置文件路径
  23. """
  24. self.rl_tool = RLTrainingTool()
  25. self.config = self.load_config(config_path)
  26. self.results = {}
  27. def load_config(self, config_path):
  28. """加载配置文件"""
  29. with open(config_path, 'r') as f:
  30. return json.load(f)
  31. def log(self, message):
  32. """记录日志"""
  33. timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  34. print(f"[{timestamp}] {message}")
  35. def stage1_prepare_data(self):
  36. """阶段1: 数据准备"""
  37. self.log("=" * 50)
  38. self.log("阶段1: 数据准备")
  39. self.log("=" * 50)
  40. # 加载并检查数据集
  41. result = self.rl_tool.run({
  42. "action": "load_dataset",
  43. "format": "sft",
  44. "max_samples": self.config["data"]["max_samples"],
  45. })
  46. # 解析JSON结果
  47. dataset_info = json.loads(result)
  48. self.log(f"✓ 数据集加载完成")
  49. self.log(f" - 样本数: {dataset_info['dataset_size']}")
  50. self.log(f" - 格式: {dataset_info['format']}")
  51. self.log(f" - 数据列: {', '.join(dataset_info['sample_keys'])}")
  52. self.results["data"] = dataset_info
  53. return dataset_info
  54. def stage2_sft_training(self):
  55. """阶段2: SFT训练"""
  56. self.log("\n" + "=" * 50)
  57. self.log("阶段2: SFT训练")
  58. self.log("=" * 50)
  59. sft_config = self.config["sft"]
  60. result = self.rl_tool.run({
  61. "action": "train",
  62. "algorithm": "sft",
  63. "model_name": self.config["model"]["base_model"],
  64. "output_dir": sft_config["output_dir"],
  65. "max_samples": self.config["data"]["max_samples"],
  66. "num_epochs": sft_config["num_epochs"],
  67. "batch_size": sft_config["batch_size"],
  68. "use_lora": True,
  69. # 训练监控配置
  70. "use_wandb": self.config.get("monitoring", {}).get("use_wandb", False),
  71. "use_tensorboard": self.config.get("monitoring", {}).get("use_tensorboard", True),
  72. "wandb_project": self.config.get("monitoring", {}).get("wandb_project", None),
  73. })
  74. # 解析JSON结果
  75. result_data = json.loads(result)
  76. self.log(f"✓ SFT训练完成")
  77. self.log(f" - 模型路径: {result_data['output_dir']}")
  78. self.log(f" - 状态: {result_data['status']}")
  79. self.results["sft_training"] = result_data
  80. return result_data["output_dir"]
  81. def stage3_sft_evaluation(self, model_path):
  82. """阶段3: SFT评估"""
  83. self.log("\n" + "=" * 50)
  84. self.log("阶段3: SFT评估")
  85. self.log("=" * 50)
  86. result = self.rl_tool.run({
  87. "action": "evaluate",
  88. "model_path": model_path,
  89. "max_samples": self.config["eval"]["max_samples"],
  90. "use_lora": True,
  91. })
  92. eval_data = json.loads(result)
  93. self.log(f"✓ SFT评估完成")
  94. self.log(f" - 准确率: {eval_data['accuracy']}")
  95. self.log(f" - 平均奖励: {eval_data['average_reward']}")
  96. self.results["sft_evaluation"] = eval_data
  97. return eval_data
  98. def stage4_grpo_training(self, sft_model_path):
  99. """阶段4: GRPO训练"""
  100. self.log("\n" + "=" * 50)
  101. self.log("阶段4: GRPO训练")
  102. self.log("=" * 50)
  103. grpo_config = self.config["grpo"]
  104. result = self.rl_tool.run({
  105. "action": "train",
  106. "algorithm": "grpo",
  107. "model_name": sft_model_path,
  108. "output_dir": grpo_config["output_dir"],
  109. "max_samples": self.config["data"]["max_samples"],
  110. "num_epochs": grpo_config["num_epochs"],
  111. "batch_size": grpo_config["batch_size"],
  112. "use_lora": True,
  113. # 训练监控配置
  114. "use_wandb": self.config.get("monitoring", {}).get("use_wandb", False),
  115. "use_tensorboard": self.config.get("monitoring", {}).get("use_tensorboard", True),
  116. "wandb_project": self.config.get("monitoring", {}).get("wandb_project", None),
  117. })
  118. # 解析JSON结果
  119. result_data = json.loads(result)
  120. self.log(f"✓ GRPO训练完成")
  121. self.log(f" - 模型路径: {result_data['output_dir']}")
  122. self.log(f" - 状态: {result_data['status']}")
  123. self.results["grpo_training"] = result_data
  124. return result_data["output_dir"]
  125. def stage5_grpo_evaluation(self, model_path):
  126. """阶段5: GRPO评估"""
  127. self.log("\n" + "=" * 50)
  128. self.log("阶段5: GRPO评估")
  129. self.log("=" * 50)
  130. result = self.rl_tool.run({
  131. "action": "evaluate",
  132. "model_path": model_path,
  133. "max_samples": self.config["eval"]["max_samples"],
  134. "use_lora": True,
  135. })
  136. eval_data = json.loads(result)
  137. self.log(f"✓ GRPO评估完成")
  138. self.log(f" - 准确率: {eval_data['accuracy']}")
  139. self.log(f" - 平均奖励: {eval_data['average_reward']}")
  140. self.results["grpo_evaluation"] = eval_data
  141. return eval_data
  142. def stage6_save_results(self):
  143. """阶段6: 保存结果"""
  144. self.log("\n" + "=" * 50)
  145. self.log("阶段6: 保存结果")
  146. self.log("=" * 50)
  147. # 保存训练结果
  148. results_path = "training_results.json"
  149. with open(results_path, 'w') as f:
  150. json.dump(self.results, f, indent=2)
  151. self.log(f"✓ 结果已保存到: {results_path}")
  152. def run(self):
  153. """运行完整流程"""
  154. try:
  155. # 阶段1: 数据准备
  156. self.stage1_prepare_data()
  157. # 阶段2: SFT训练
  158. sft_model_path = self.stage2_sft_training()
  159. # 阶段3: SFT评估
  160. self.stage3_sft_evaluation(sft_model_path)
  161. # 阶段4: GRPO训练
  162. grpo_model_path = self.stage4_grpo_training(sft_model_path)
  163. # 阶段5: GRPO评估
  164. self.stage5_grpo_evaluation(grpo_model_path)
  165. # 阶段6: 保存结果
  166. self.stage6_save_results()
  167. self.log("\n" + "=" * 50)
  168. self.log("✓ 训练流程完成!")
  169. self.log("=" * 50)
  170. except Exception as e:
  171. self.log(f"\n✗ 训练失败: {str(e)}")
  172. raise
  173. # 使用示例
  174. if __name__ == "__main__":
  175. # 创建配置文件
  176. config = {
  177. "model": {
  178. "base_model": "Qwen/Qwen3-0.6B"
  179. },
  180. "data": {
  181. "max_samples": 100 # 使用100个样本快速测试
  182. },
  183. "sft": {
  184. "output_dir": "./models/sft_model",
  185. "num_epochs": 2,
  186. "batch_size": 4,
  187. },
  188. "grpo": {
  189. "output_dir": "./models/grpo_model",
  190. "num_epochs": 2,
  191. "batch_size": 2,
  192. },
  193. "eval": {
  194. "max_samples": 20,
  195. "sft_accuracy_threshold": 0.40
  196. },
  197. "monitoring": {
  198. "use_wandb": False, # 是否使用Wandb
  199. "use_tensorboard": True, # 是否使用TensorBoard
  200. "wandb_project": "agentic-rl-pipeline" # Wandb项目名
  201. }
  202. }
  203. # 保存配置
  204. with open("config.json", 'w') as f:
  205. json.dump(config, f, indent=2)
  206. # 运行训练流程
  207. pipeline = AgenticRLPipeline("config.json")
  208. pipeline.run()