# 01_dataset_loading.py (7.7 KB)
  1. """
  2. 示例1: 数据集加载和格式化
  3. 演示如何使用RLTrainingTool加载和查看GSM8K数据集
  4. """
  5. import sys
  6. from pathlib import Path
  7. import json
# Add the project path: "HelloAgents" is resolved as a sibling of this
# script's parent directory and placed at the front of sys.path,
# presumably so the `hello_agents` import that follows is found there.
project_root = Path(__file__).parent.parent / "HelloAgents"
sys.path.insert(0, str(project_root))
  11. from hello_agents.tools import RLTrainingTool
  12. # ============================================================================
  13. # 示例1: 加载SFT格式数据集
  14. # ============================================================================
  15. def load_sft_dataset():
  16. """
  17. 使用RLTrainingTool加载SFT格式的GSM8K数据集
  18. SFT数据格式:
  19. {
  20. "prompt": "Question: ...\n\nLet's solve this step by step:\n",
  21. "completion": "Step 1: ...\nFinal Answer: 42",
  22. "text": "Question: ...\n\nLet's solve this step by step:\nStep 1: ...\nFinal Answer: 42"
  23. }
  24. """
  25. tool = RLTrainingTool()
  26. config = {
  27. "action": "load_dataset",
  28. "format": "sft",
  29. "split": "train",
  30. "max_samples": 5
  31. }
  32. print("加载SFT格式数据集...")
  33. result = tool.run(config)
  34. result_dict = json.loads(result)
  35. print(f"✅ 数据集大小: {result_dict['dataset_size']}")
  36. print(f"📋 数据集列: {result_dict['sample_keys']}")
  37. print(f"\n💡 提示: 数据集已加载,可以用于训练")
  38. print(f" 使用 action='train' 开始训练")
  39. return result_dict
  40. # ============================================================================
  41. # 示例2: 加载RL格式数据集
  42. # ============================================================================
  43. def load_rl_dataset():
  44. """
  45. 使用RLTrainingTool加载RL格式的GSM8K数据集
  46. RL数据格式:
  47. {
  48. "prompt": "<|im_start|>user\nQuestion: ...\n<|im_end|>\n<|im_start|>assistant\n",
  49. "ground_truth": "42",
  50. "question": "...",
  51. "full_answer": "..."
  52. }
  53. """
  54. tool = RLTrainingTool()
  55. config = {
  56. "action": "load_dataset",
  57. "format": "rl",
  58. "split": "train",
  59. "max_samples": 5,
  60. "model_name": "Qwen/Qwen3-0.6B"
  61. }
  62. print("加载RL格式数据集...")
  63. result = tool.run(config)
  64. result_dict = json.loads(result)
  65. print(f"✅ 数据集大小: {result_dict['dataset_size']}")
  66. print(f"📋 数据集列: {result_dict['sample_keys']}")
  67. print(f"\n💡 提示: RL数据集已加载,包含prompt和ground_truth")
  68. print(f" 可用于GRPO训练")
  69. return result_dict
  70. # ============================================================================
  71. # 示例3: 加载不同split的数据集
  72. # ============================================================================
  73. def load_different_splits():
  74. """
  75. 加载训练集和测试集
  76. """
  77. tool = RLTrainingTool()
  78. # 加载训练集
  79. train_config = {
  80. "action": "load_dataset",
  81. "format": "sft",
  82. "split": "train",
  83. "max_samples": 100
  84. }
  85. print("加载训练集...")
  86. train_result = tool.run(train_config)
  87. train_data = json.loads(train_result)
  88. print(f"✅ 训练集: {train_data['dataset_size']} 样本")
  89. # 加载测试集
  90. test_config = {
  91. "action": "load_dataset",
  92. "format": "sft",
  93. "split": "test",
  94. "max_samples": 50
  95. }
  96. print("\n加载测试集...")
  97. test_result = tool.run(test_config)
  98. test_data = json.loads(test_result)
  99. print(f"✅ 测试集: {test_data['dataset_size']} 样本")
  100. return train_data, test_data
  101. # ============================================================================
  102. # 示例4: 加载完整数据集
  103. # ============================================================================
  104. def load_full_dataset():
  105. """
  106. 加载完整数据集 (max_samples=None)
  107. GSM8K数据集:
  108. - 训练集: ~7500 样本
  109. - 测试集: ~1300 样本
  110. """
  111. tool = RLTrainingTool()
  112. config = {
  113. "action": "load_dataset",
  114. "format": "sft",
  115. "split": "train",
  116. "max_samples": None # None = 使用全部数据
  117. }
  118. print("加载完整训练集...")
  119. print("⚠️ 这可能需要一些时间...")
  120. # 实际加载时取消注释
  121. # result = tool.run(config)
  122. # result_dict = json.loads(result)
  123. # print(f"✅ 完整训练集: {result_dict['dataset_size']} 样本")
  124. print("💡 提示: 设置 max_samples=None 可以加载全部数据")
  125. print(" GSM8K训练集约有 7500 个样本")
  126. return config
  127. # ============================================================================
  128. # 示例5: 对比SFT和RL格式
  129. # ============================================================================
  130. def compare_sft_rl_formats():
  131. """
  132. 对比SFT和RL数据格式的区别
  133. """
  134. tool = RLTrainingTool()
  135. print("="*80)
  136. print("SFT vs RL 数据格式对比")
  137. print("="*80)
  138. # SFT格式
  139. sft_config = {
  140. "action": "load_dataset",
  141. "format": "sft",
  142. "split": "train",
  143. "max_samples": 1
  144. }
  145. print("\n1. SFT格式:")
  146. sft_result = tool.run(sft_config)
  147. sft_data = json.loads(sft_result)
  148. print(f" 列: {sft_data['sample_keys']}")
  149. print(f" 用途: 监督微调 (Supervised Fine-Tuning)")
  150. print(f" 特点: 包含完整的prompt和completion")
  151. # RL格式
  152. rl_config = {
  153. "action": "load_dataset",
  154. "format": "rl",
  155. "split": "train",
  156. "max_samples": 1,
  157. "model_name": "Qwen/Qwen3-0.6B"
  158. }
  159. print("\n2. RL格式:")
  160. rl_result = tool.run(rl_config)
  161. rl_data = json.loads(rl_result)
  162. print(f" 列: {rl_data['sample_keys']}")
  163. print(f" 用途: 强化学习训练 (Reinforcement Learning)")
  164. print(f" 特点: 包含prompt和ground_truth,用于奖励计算")
  165. print("\n主要区别:")
  166. print(" - SFT: 直接学习正确答案")
  167. print(" - RL: 通过奖励信号学习,更灵活")
  168. return sft_data, rl_data
  169. # ============================================================================
  170. # 示例6: 数据集统计信息
  171. # ============================================================================
  172. def dataset_statistics():
  173. """
  174. 查看数据集的统计信息
  175. """
  176. tool = RLTrainingTool()
  177. config = {
  178. "action": "load_dataset",
  179. "format": "sft",
  180. "split": "train",
  181. "max_samples": 100
  182. }
  183. print("加载数据集...")
  184. result = tool.run(config)
  185. result_dict = json.loads(result)
  186. print("\n数据集统计:")
  187. print(f" 总样本数: {result_dict['dataset_size']}")
  188. print(f" 数据列: {', '.join(result_dict['sample_keys'])}")
  189. print(f" 数据集: GSM8K (Grade School Math 8K)")
  190. print(f" 任务类型: 数学推理")
  191. print(f"\n💡 提示: 数据集包含以下字段:")
  192. for key in result_dict['sample_keys']:
  193. print(f" - {key}")
  194. return result_dict
  195. # ============================================================================
  196. # 主函数
  197. # ============================================================================
  198. if __name__ == "__main__":
  199. print("="*80)
  200. print("示例1: 加载SFT格式数据集")
  201. print("="*80)
  202. load_sft_dataset()
  203. print("\n" + "="*80)
  204. print("示例2: 加载RL格式数据集")
  205. print("="*80)
  206. load_rl_dataset()
  207. print("\n" + "="*80)
  208. print("示例3: 加载不同split的数据集")
  209. print("="*80)
  210. load_different_splits()
  211. print("\n" + "="*80)
  212. print("示例4: 加载完整数据集")
  213. print("="*80)
  214. load_full_dataset()
  215. print("\n" + "="*80)
  216. print("示例5: 对比SFT和RL格式")
  217. print("="*80)
  218. compare_sft_rl_formats()
  219. print("\n" + "="*80)
  220. print("示例6: 数据集统计信息")
  221. print("="*80)
  222. dataset_statistics()