| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- """
- 示例1: 数据集加载和格式化
- 演示如何使用RLTrainingTool加载和查看GSM8K数据集
- """
- import sys
- from pathlib import Path
- import json
- # 添加项目路径
- project_root = Path(__file__).parent.parent / "HelloAgents"
- sys.path.insert(0, str(project_root))
- from hello_agents.tools import RLTrainingTool
- # ============================================================================
- # 示例1: 加载SFT格式数据集
- # ============================================================================
def load_sft_dataset():
    r"""Load a small SFT-format GSM8K sample through RLTrainingTool.

    Sends a ``load_dataset`` request (``format="sft"``, ``split="train"``,
    ``max_samples=5``), prints the reported dataset size and sample keys,
    and returns the decoded response.

    SFT record layout, as described by the example's original notes:
        {
            "prompt": "Question: ...\n\nLet's solve this step by step:\n",
            "completion": "Step 1: ...\nFinal Answer: 42",
            "text": "Question: ...\n\nLet's solve this step by step:\nStep 1: ...\nFinal Answer: 42"
        }

    Returns:
        The JSON response of ``tool.run`` decoded with ``json.loads``. This
        function reads its ``dataset_size`` and ``sample_keys`` entries.
    """
    loader = RLTrainingTool()
    request = dict(
        action="load_dataset",
        format="sft",
        split="train",
        max_samples=5,
    )

    print("加载SFT格式数据集...")
    # The tool answers with a JSON string; decode it before reading fields.
    response = json.loads(loader.run(request))

    size = response['dataset_size']
    columns = response['sample_keys']
    print(f"✅ 数据集大小: {size}")
    print(f"📋 数据集列: {columns}")
    print("\n💡 提示: 数据集已加载,可以用于训练")
    print(" 使用 action='train' 开始训练")
    return response
- # ============================================================================
- # 示例2: 加载RL格式数据集
- # ============================================================================
def load_rl_dataset():
    r"""Load a small RL-format GSM8K sample through RLTrainingTool.

    Sends a ``load_dataset`` request (``format="rl"``, ``split="train"``,
    ``max_samples=5``, ``model_name="Qwen/Qwen3-0.6B"``), prints the reported
    dataset size and sample keys, and returns the decoded response.

    RL record layout, as described by the example's original notes:
        {
            "prompt": "<|im_start|>user\nQuestion: ...\n<|im_end|>\n<|im_start|>assistant\n",
            "ground_truth": "42",
            "question": "...",
            "full_answer": "..."
        }

    Returns:
        The JSON response of ``tool.run`` decoded with ``json.loads``. This
        function reads its ``dataset_size`` and ``sample_keys`` entries.
    """
    loader = RLTrainingTool()
    request = dict(
        action="load_dataset",
        format="rl",
        split="train",
        max_samples=5,
        model_name="Qwen/Qwen3-0.6B",
    )

    print("加载RL格式数据集...")
    # The tool answers with a JSON string; decode it before reading fields.
    response = json.loads(loader.run(request))

    size = response['dataset_size']
    columns = response['sample_keys']
    print(f"✅ 数据集大小: {size}")
    print(f"📋 数据集列: {columns}")
    print("\n💡 提示: RL数据集已加载,包含prompt和ground_truth")
    print(" 可用于GRPO训练")
    return response
- # ============================================================================
- # 示例3: 加载不同split的数据集
- # ============================================================================
def load_different_splits():
    """Load the train and test splits in SFT format with one tool instance.

    The train split is requested with ``max_samples=100`` and the test split
    with ``max_samples=50``. For each split an announcement is printed, the
    request is run, and the reported ``dataset_size`` is printed.

    Returns:
        tuple: ``(train_data, test_data)``, the two decoded JSON responses.
    """
    loader = RLTrainingTool()

    # One row per split: (split name, sample cap, announcement, result label).
    plan = (
        ("train", 100, "加载训练集...", "训练集"),
        ("test", 50, "\n加载测试集...", "测试集"),
    )

    responses = []
    for split_name, sample_cap, announcement, label in plan:
        request = {
            "action": "load_dataset",
            "format": "sft",
            "split": split_name,
            "max_samples": sample_cap,
        }
        print(announcement)
        payload = json.loads(loader.run(request))
        print(f"✅ {label}: {payload['dataset_size']} 样本")
        responses.append(payload)

    train_data, test_data = responses
    return train_data, test_data
- # ============================================================================
- # 示例4: 加载完整数据集
- # ============================================================================
def load_full_dataset(*, execute=False):
    """Build (and optionally run) the request for the full GSM8K train split.

    The request sets ``max_samples`` to ``None``, which the example uses to
    mean "use all data". According to the example's original notes, GSM8K has
    roughly 7500 training samples and roughly 1300 test samples.

    By default nothing is loaded: the function only prints guidance and
    returns the request, so running the demo stays cheap. Earlier revisions
    kept the real load as commented-out code and still constructed an unused
    tool; the ``execute`` flag replaces both.

    Args:
        execute: Keyword-only. When true, create an ``RLTrainingTool``, run
            the request, and print the reported ``dataset_size``. Defaults
            to ``False`` (no tool is created and no load is performed).

    Returns:
        dict: The request configuration, whether or not it was executed.
    """
    config = {
        "action": "load_dataset",
        "format": "sft",
        "split": "train",
        "max_samples": None,  # None = use the whole split
    }

    print("加载完整训练集...")
    print("⚠️ 这可能需要一些时间...")

    if execute:
        # Only build the tool when a real load was requested.
        tool = RLTrainingTool()
        result_dict = json.loads(tool.run(config))
        print(f"✅ 完整训练集: {result_dict['dataset_size']} 样本")

    print("💡 提示: 设置 max_samples=None 可以加载全部数据")
    print(" GSM8K训练集约有 7500 个样本")

    return config
- # ============================================================================
- # 示例5: 对比SFT和RL格式
- # ============================================================================
def compare_sft_rl_formats():
    """Print a side-by-side description of the SFT and RL dataset formats.

    Loads one ``train`` sample in each format with the same tool instance
    and prints the sample keys each response reports, together with a short
    explanation of what each format is used for. The RL request additionally
    passes ``model_name="Qwen/Qwen3-0.6B"``.

    Returns:
        tuple: ``(sft_data, rl_data)``, the two decoded JSON responses.
    """
    loader = RLTrainingTool()
    rule = "=" * 80

    print(rule)
    print("SFT vs RL 数据格式对比")
    print(rule)

    # --- SFT format ---
    sft_request = dict(
        action="load_dataset",
        format="sft",
        split="train",
        max_samples=1,
    )
    print("\n1. SFT格式:")
    sft_data = json.loads(loader.run(sft_request))
    sft_columns = sft_data['sample_keys']
    print(f" 列: {sft_columns}")
    print(" 用途: 监督微调 (Supervised Fine-Tuning)")
    print(" 特点: 包含完整的prompt和completion")

    # --- RL format ---
    rl_request = dict(
        action="load_dataset",
        format="rl",
        split="train",
        max_samples=1,
        model_name="Qwen/Qwen3-0.6B",
    )
    print("\n2. RL格式:")
    rl_data = json.loads(loader.run(rl_request))
    rl_columns = rl_data['sample_keys']
    print(f" 列: {rl_columns}")
    print(" 用途: 强化学习训练 (Reinforcement Learning)")
    print(" 特点: 包含prompt和ground_truth,用于奖励计算")

    # Closing summary of the difference between the two formats.
    print("\n主要区别:")
    print(" - SFT: 直接学习正确答案")
    print(" - RL: 通过奖励信号学习,更灵活")

    return sft_data, rl_data
- # ============================================================================
- # 示例6: 数据集统计信息
- # ============================================================================
def dataset_statistics():
    """Print summary information for a 100-sample SFT training request.

    Runs a ``load_dataset`` request (``format="sft"``, ``split="train"``,
    ``max_samples=100``) and prints the reported sample count, the sample
    keys as a comma-separated line, fixed descriptive lines about GSM8K,
    and finally each sample key on its own line.

    Returns:
        The JSON response of ``tool.run`` decoded with ``json.loads``.
    """
    loader = RLTrainingTool()
    request = dict(
        action="load_dataset",
        format="sft",
        split="train",
        max_samples=100,
    )

    print("加载数据集...")
    stats = json.loads(loader.run(request))

    print("\n数据集统计:")
    print(f" 总样本数: {stats['dataset_size']}")
    joined_columns = ', '.join(stats['sample_keys'])
    print(f" 数据列: {joined_columns}")
    print(" 数据集: GSM8K (Grade School Math 8K)")
    print(" 任务类型: 数学推理")

    print("\n💡 提示: 数据集包含以下字段:")
    # One bullet per reported field name.
    for column in stats['sample_keys']:
        print(f" - {column}")

    return stats
- # ============================================================================
- # 主函数
- # ============================================================================
if __name__ == "__main__":
    # Run every example in order, each under an 80-column banner.
    demos = (
        ("示例1: 加载SFT格式数据集", load_sft_dataset),
        ("示例2: 加载RL格式数据集", load_rl_dataset),
        ("示例3: 加载不同split的数据集", load_different_splits),
        ("示例4: 加载完整数据集", load_full_dataset),
        ("示例5: 对比SFT和RL格式", compare_sft_rl_formats),
        ("示例6: 数据集统计信息", dataset_statistics),
    )
    rule = "=" * 80

    for position, (title, demo) in enumerate(demos):
        # Every banner except the first is preceded by a blank line.
        print(rule if position == 0 else "\n" + rule)
        print(title)
        print(rule)
        demo()
|