grpo-llama323b-text2sql.yaml

# Model arguments
model_name_or_path: meta-llama/Llama-3.2-3B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
bf16: true
tf32: true
output_dir: runs/llama-3.2-3b-grpo-text2sql-4rewards-6gpu
# LoRA arguments
# No LoRA is used here
# Training arguments
max_steps: 750 # 1000 # 500
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 5.0e-7 # 1.0e-6 as in the DeepSeek Math paper; 5.0e-7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
lr_scheduler_type: cosine
warmup_ratio: 0.03
# GRPO specific parameters
beta: 0.001 # 0.04 as in the DeepSeek Math paper; 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
max_prompt_length: 512 # 256
max_completion_length: 1024
num_generations: 8 # 6
use_vllm: true
# Reward function weights
# Order: [format_reward_func, execution_reward_func, ensemble_n_gram_reward_func];
# the fourth weight (1.0) applies to a fourth reward function not listed in this comment.
reward_weights: [1.0, 3.0, 1.0, 1.0]
# **Recommended Weight Strategy**
# **Current Setting: `[1.0, 3.0, 1.0]`** (the first three of the four weights above)
# * **Format reward (1.0)**: Standard weight, since format correctness is binary but essential
# * **Execution reward (3.0)**: **Highest weight** - SQL execution correctness is the most important signal for text2sql
# * **N-gram similarity (1.0)**: Standard weight for syntactic similarity
# **Alternative Weight Strategies**
# **Conservative approach: `[2.0, 4.0, 1.0]`**
# * Emphasizes both format and execution correctness
# * Lower weight on similarity metrics
# **Balanced approach: `[1.5, 2.0, 1.5]`**
# * More balanced across all three metrics
# * Good for early training stages
# **Similarity-focused: `[1.0, 2.0, 2.0]`**
# * Higher weight on N-gram similarity
# * Useful if execution often fails initially
# final_reward = format_reward*1.0 + execution_reward*3.0 + ngram_reward*1.0
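# Worked example of the weighted sum above (illustrative reward values, not from
# an actual run): if format_reward = 1.0, execution_reward = 0.0 (the generated
# SQL fails), and ngram_reward = 0.4, then
#   final_reward = 1.0*1.0 + 0.0*3.0 + 0.4*1.0 = 1.4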
vllm_device: "cuda:0" # vLLM is used for generation and DeepSpeed for distributed training.
# Set num_processes to the number of GPUs used for training; the remaining GPU
# is used by vLLM for generation. For example, with 6 GPUs set num_processes to 5
# and vllm_device to "cuda:5" so the last GPU is dedicated to generation.
# (Setting num_processes to 6 instead would share the 6th GPU between generation
# and training.)
vllm_gpu_memory_utilization: 0.5
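# Example launch for the 6-GPU setup described above - a sketch only: the training
# script name and DeepSpeed config file are placeholders, not part of this config.
#   accelerate launch --num_processes 5 --config_file deepspeed_zero3.yaml \
#     run_grpo_text2sql.py --config grpo-llama323b-text2sql.yaml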
# Logging arguments
logging_strategy: steps
logging_steps: 2
report_to:
- tensorboard
save_strategy: "steps"
save_steps: 50
seed: 42
# Hugging Face Hub
push_to_hub: false
# hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined, same as output_dir
hub_strategy: every_save