03_lora_configuration.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
"""
Example 3: LoRA configuration and usage.

Demonstrates how to configure and use LoRA for parameter-efficient
fine-tuning through RLTrainingTool.
"""
  5. import sys
  6. from pathlib import Path
  7. import json
  8. # 添加项目路径
  9. project_root = Path(__file__).parent.parent / "HelloAgents"
  10. sys.path.insert(0, str(project_root))
  11. from hello_agents.tools import RLTrainingTool
# ============================================================================
# Example 1: Basic LoRA configuration
# ============================================================================
  15. def basic_lora_config():
  16. """
  17. 最基础的LoRA配置
  18. LoRA (Low-Rank Adaptation):
  19. - 只训练少量额外参数
  20. - 减少60-80%显存占用
  21. - 提升2-3倍训练速度
  22. - 模型文件只有~10MB
  23. """
  24. tool = RLTrainingTool()
  25. # 使用RLTrainingTool进行SFT训练,启用LoRA
  26. config = {
  27. "action": "train",
  28. "algorithm": "sft",
  29. "model_name": "Qwen/Qwen3-0.6B",
  30. "output_dir": "./output/lora_basic",
  31. "max_samples": 100,
  32. "num_epochs": 1,
  33. # LoRA配置
  34. "use_lora": True, # 启用LoRA
  35. "lora_r": 16, # LoRA秩(rank)
  36. "lora_alpha": 32, # 缩放因子(通常是r的2倍)
  37. }
  38. print("基础LoRA配置:")
  39. print(f" 模型: {config['model_name']}")
  40. print(f" use_lora: {config['use_lora']}")
  41. print(f" lora_r: {config['lora_r']}")
  42. print(f" lora_alpha: {config['lora_alpha']}")
  43. print(f" 目标模块: ['q_proj', 'v_proj'] (默认)")
  44. # 实际训练时取消注释
  45. # result = tool.run(config)
  46. # print(json.dumps(json.loads(result), indent=2, ensure_ascii=False))
  47. return config
# ============================================================================
# Example 2: Comparing different LoRA ranks
# ============================================================================
  51. def compare_lora_ranks():
  52. """
  53. 对比不同LoRA秩的配置
  54. LoRA秩(r)的选择:
  55. - r=8: 较小参数量,适合快速实验
  56. - r=16: 推荐值,平衡性能和效率
  57. - r=32: 较大参数量,追求更好性能
  58. """
  59. configs = {
  60. "r=8 (快速实验)": {
  61. "lora_r": 8,
  62. "lora_alpha": 16,
  63. "params": "~16K"
  64. },
  65. "r=16 (推荐)": {
  66. "lora_r": 16,
  67. "lora_alpha": 32,
  68. "params": "~32K"
  69. },
  70. "r=32 (高性能)": {
  71. "lora_r": 32,
  72. "lora_alpha": 64,
  73. "params": "~65K"
  74. },
  75. }
  76. print("不同LoRA秩的对比:")
  77. for name, config in configs.items():
  78. print(f"\n{name}:")
  79. print(f" lora_r: {config['lora_r']}")
  80. print(f" lora_alpha: {config['lora_alpha']}")
  81. print(f" 预估参数量: {config['params']}")
  82. # 实际训练示例
  83. print("\n训练示例 (r=16):")
  84. print("""
  85. tool = RLTrainingTool()
  86. result = tool.run({
  87. "action": "train",
  88. "algorithm": "sft",
  89. "model_name": "Qwen/Qwen3-0.6B",
  90. "max_samples": 100,
  91. "num_epochs": 1,
  92. "use_lora": True,
  93. "lora_r": 16,
  94. "lora_alpha": 32,
  95. })
  96. """)
  97. return configs
# ============================================================================
# Example 3: LoRA vs. full fine-tuning
# ============================================================================
  101. def compare_lora_vs_full_finetuning():
  102. """
  103. 对比LoRA和完整微调的配置
  104. """
  105. print("LoRA vs 完整微调对比:")
  106. print("\nLoRA微调:")
  107. print(" 显存占用: ~4GB (0.5B模型)")
  108. print(" 训练速度: 快(2-3x)")
  109. print(" 模型大小: ~10MB")
  110. print(" batch_size: 8")
  111. print(" use_lora: True")
  112. print("\n完整微调:")
  113. print(" 显存占用: ~14GB (0.5B模型)")
  114. print(" 训练速度: 慢")
  115. print(" 模型大小: ~1GB")
  116. print(" batch_size: 2")
  117. print(" use_lora: False")
  118. print("\n推荐: 使用LoRA进行微调")
# ============================================================================
# Example 4: Practical training configurations
# ============================================================================
  122. def practical_training_configs():
  123. """
  124. 实际训练中的推荐配置
  125. """
  126. tool = RLTrainingTool()
  127. # 快速训练配置
  128. quick_config = {
  129. "action": "train",
  130. "algorithm": "sft",
  131. "model_name": "Qwen/Qwen3-0.6B",
  132. "output_dir": "./output/quick_test",
  133. "max_samples": 100,
  134. "num_epochs": 1,
  135. "batch_size": 8,
  136. "use_lora": True,
  137. "lora_r": 8,
  138. "lora_alpha": 16,
  139. }
  140. # 标准训练配置
  141. standard_config = {
  142. "action": "train",
  143. "algorithm": "sft",
  144. "model_name": "Qwen/Qwen3-0.6B",
  145. "output_dir": "./output/standard",
  146. "max_samples": 1000,
  147. "num_epochs": 3,
  148. "batch_size": 4,
  149. "use_lora": True,
  150. "lora_r": 16,
  151. "lora_alpha": 32,
  152. "learning_rate": 5e-5,
  153. }
  154. # 高质量训练配置
  155. high_quality_config = {
  156. "action": "train",
  157. "algorithm": "sft",
  158. "model_name": "Qwen/Qwen3-0.6B",
  159. "output_dir": "./output/high_quality",
  160. "max_samples": None, # 使用全部数据
  161. "num_epochs": 5,
  162. "batch_size": 2,
  163. "use_lora": True,
  164. "lora_r": 32,
  165. "lora_alpha": 64,
  166. "learning_rate": 3e-5,
  167. }
  168. print("实际训练配置示例:")
  169. print("\n1. 快速实验配置:")
  170. print(f" 样本数: {quick_config['max_samples']}")
  171. print(f" epochs: {quick_config['num_epochs']}")
  172. print(f" lora_r: {quick_config['lora_r']}")
  173. print(f" batch_size: {quick_config['batch_size']}")
  174. print("\n2. 标准训练配置:")
  175. print(f" 样本数: {standard_config['max_samples']}")
  176. print(f" epochs: {standard_config['num_epochs']}")
  177. print(f" lora_r: {standard_config['lora_r']}")
  178. print(f" batch_size: {standard_config['batch_size']}")
  179. print("\n3. 高质量训练配置:")
  180. print(f" 样本数: 全部 (max_samples=None)")
  181. print(f" epochs: {high_quality_config['num_epochs']}")
  182. print(f" lora_r: {high_quality_config['lora_r']}")
  183. print(f" batch_size: {high_quality_config['batch_size']}")
  184. # 实际训练时取消注释
  185. # result = tool.run(quick_config)
  186. # print(json.dumps(json.loads(result), indent=2, ensure_ascii=False))
  187. return quick_config, standard_config, high_quality_config
# ============================================================================
# Example 5: LoRA hyperparameter tuning guidelines
# ============================================================================
  191. def lora_tuning_guidelines():
  192. """
  193. LoRA参数调优建议
  194. """
  195. guidelines = {
  196. "lora_r (秩)": {
  197. "推荐值": 16,
  198. "范围": "8-32",
  199. "说明": "越大性能越好,但参数量和训练时间也越多",
  200. "选择建议": {
  201. "快速实验": 8,
  202. "平衡性能": 16,
  203. "追求性能": 32,
  204. }
  205. },
  206. "lora_alpha (缩放因子)": {
  207. "推荐值": 32,
  208. "范围": "16-64",
  209. "说明": "通常设置为lora_r的2倍",
  210. "公式": "lora_alpha = 2 * lora_r"
  211. },
  212. "max_samples (样本数)": {
  213. "快速实验": 100,
  214. "标准训练": 1000,
  215. "完整训练": "None (全部数据)",
  216. "说明": "None表示使用全部数据",
  217. },
  218. }
  219. print("LoRA参数调优建议:")
  220. for param, info in guidelines.items():
  221. print(f"\n{param}:")
  222. for key, value in info.items():
  223. if isinstance(value, dict):
  224. print(f" {key}:")
  225. for k, v in value.items():
  226. print(f" - {k}: {v}")
  227. else:
  228. print(f" {key}: {value}")
  229. return guidelines
# ============================================================================
# Main entry point
# ============================================================================
  233. if __name__ == "__main__":
  234. print("="*80)
  235. print("示例1: 基础LoRA配置")
  236. print("="*80)
  237. basic_lora_config()
  238. print("\n" + "="*80)
  239. print("示例2: 不同LoRA秩的对比")
  240. print("="*80)
  241. compare_lora_ranks()
  242. print("\n" + "="*80)
  243. print("示例3: LoRA vs 完整微调对比")
  244. print("="*80)
  245. compare_lora_vs_full_finetuning()
  246. print("\n" + "="*80)
  247. print("示例4: 实际训练配置示例")
  248. print("="*80)
  249. practical_training_configs()
  250. print("\n" + "="*80)
  251. print("示例5: LoRA参数调优建议")
  252. print("="*80)
  253. lora_tuning_guidelines()