
grpo llama 3.2 3b with 3 reward functions

Jeff Tang committed 4 weeks ago
Commit c88e10fab6

+ 22 - 0
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/deepspeed_zero3.yaml

@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

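This accelerate/DeepSpeed ZeRO-3 config targets a single 8-GPU machine with bf16 mixed precision and no CPU offload of parameters or optimizer state. A typical way to use it would be something like `accelerate launch --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml`; the exact CLI of grpo_text2sql.py is not shown in this commit (its diff is hidden below), so the `--config` flag here is an assumption.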
+ 72 - 0
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml

@@ -0,0 +1,72 @@
+# Model arguments
+model_name_or_path: meta-llama/Llama-3.2-3B-Instruct
+model_revision: main
+torch_dtype: bfloat16
+attn_implementation: flash_attention_2
+bf16: true
+tf32: true
+output_dir: runs/llama-3.2-3b-grpo-text2sql-alltrain-lr5-ng8
+
+# Lora Arguments
+# No LoRA is used here
+
+# Training arguments
+max_steps: 500 # 1000 #500
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+learning_rate: 5.0e-7 # 1.0e-6 as in the DeepSeek Math paper; 5e-7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
+lr_scheduler_type: cosine
+warmup_ratio: 0.03
+# GRPO specific parameters
+beta: 0.001 # 0.04 as in the DeepSeek Math paper; 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05
+max_prompt_length: 512 # 256
+max_completion_length: 1024
+num_generations: 8 # 6 # 8
+use_vllm: true
+
+# Reward function weights
+# Order: [format_reward_func, execution_reward_func, ensemble_n_gram_reward_func]
+reward_weights: [1.0, 3.0, 1.0]
+# **Recommended Weight Strategy**
+# Current setting: `[1.0, 3.0, 1.0]`
+# *   **Format reward (1.0)**: Standard weight since format correctness is binary but essential
+# *   **Execution reward (3.0)**: **Highest weight** - SQL execution correctness is most important for text2sql
+# *   **N-gram similarity (1.0)**: Standard weight for syntactic similarity
+
+# **Alternative Weight Strategies**
+# **Conservative approach: `[2.0, 4.0, 1.0]`**
+# *   Emphasizes both format and execution correctness
+# *   Lower weight on similarity metrics
+# **Balanced approach: `[1.5, 2.0, 1.5]`**
+# *   More balanced across all three metrics
+# *   Good for early training stages
+# **Similarity-focused: `[1.0, 2.0, 2.0]`**
+# *   Higher weight on N-gram similarity
+# *   Useful if execution often fails initially
+
+
+vllm_device: "cuda:0" # vLLM handles generation; DeepSpeed handles distributed training.
+# To dedicate a GPU to vLLM generation, set vllm_device to the last GPU and
+# num_processes to the number of remaining GPUs. For example, with 6 GPUs set
+# vllm_device to "cuda:5" and num_processes to 5; setting num_processes to 6
+# instead makes that GPU serve both generation and training.
+
+vllm_gpu_memory_utilization: 0.5
+
+# Logging arguments
+logging_strategy: steps
+logging_steps: 2
+report_to:
+- tensorboard
+save_strategy: "steps"
+save_steps: 50
+seed: 42
+
+# Hugging Face Hub
+push_to_hub: false
+# hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined, same as output_dir
+hub_strategy: every_save

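For orientation before the script itself (whose diff is suppressed below): the recipe above maps fairly directly onto trl's GRPOConfig and GRPOTrainer (trl is pinned to 0.14.0 in requirements.txt). The sketch below is illustrative only, not the contents of grpo_text2sql.py; the stub reward functions and one-row dataset are placeholders, and exact field support can vary across trl versions.

```python
# Minimal sketch (not the actual grpo_text2sql.py): how the recipe above would
# typically be wired into trl's GRPOTrainer. The stub reward functions and the
# one-row dataset are placeholders for illustration only.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer


def format_reward_func(completions, **kwargs):
    return [0.0 for _ in completions]  # placeholder; real check lives in grpo_text2sql.py


def execution_reward_func(completions, **kwargs):
    return [0.0 for _ in completions]  # placeholder


def ensemble_n_gram_reward_func(completions, **kwargs):
    return [0.0 for _ in completions]  # placeholder


training_args = GRPOConfig(
    output_dir="runs/llama-3.2-3b-grpo-text2sql-alltrain-lr5-ng8",
    learning_rate=5.0e-7,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    max_steps=500,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    beta=0.001,                       # KL penalty coefficient
    max_prompt_length=512,
    max_completion_length=1024,
    num_generations=8,                # completions sampled per prompt (the GRPO group size)
    use_vllm=True,
    vllm_gpu_memory_utilization=0.5,
    logging_steps=2,
    save_steps=50,
    seed=42,
)

trainer = GRPOTrainer(
    model="meta-llama/Llama-3.2-3B-Instruct",
    # Passed positionally, matching reward_weights [1.0, 3.0, 1.0] in the recipe.
    reward_funcs=[format_reward_func, execution_reward_func, ensemble_n_gram_reward_func],
    args=training_args,
    train_dataset=Dataset.from_list([{"prompt": "example text2sql prompt"}]),
)
trainer.train()
```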
File diff suppressed because it is too large
+ 537 - 0
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py

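The 537 added lines of grpo_text2sql.py are hidden in this view, so the three reward functions themselves are not visible. As a hedged sketch of what rewards like these commonly compute for text2sql (a format check, execution-result comparison against a reference database, and n-gram overlap with the gold SQL), the snippet below may help; the `<sql>` tag convention, the `gold_sql` and `db_path` dataset columns, and all function bodies are assumptions, not the actual implementation.

```python
# Hedged sketch of the three reward functions named in the config comment.
# Everything below (tag format, dataset columns such as "gold_sql" and
# "db_path", scoring details) is an assumption for illustration; the real
# implementations are in grpo_text2sql.py, whose diff is hidden here.
import re
import sqlite3


def format_reward_func(completions, **kwargs):
    """1.0 if the completion contains a well-formed <sql>...</sql> block, else 0.0."""
    pattern = re.compile(r"<sql>.+?</sql>", re.DOTALL)
    return [1.0 if pattern.search(c) else 0.0 for c in completions]


def execution_reward_func(completions, gold_sql, db_path, **kwargs):
    """1.0 if the generated SQL executes and returns the same rows as the gold SQL."""
    rewards = []
    for completion, gold, db in zip(completions, gold_sql, db_path):
        match = re.search(r"<sql>(.+?)</sql>", completion, re.DOTALL)
        if not match:
            rewards.append(0.0)
            continue
        try:
            conn = sqlite3.connect(db)
            pred_rows = set(conn.execute(match.group(1)).fetchall())
            gold_rows = set(conn.execute(gold).fetchall())
            conn.close()
            rewards.append(1.0 if pred_rows == gold_rows else 0.0)
        except Exception:
            rewards.append(0.0)  # malformed or failing SQL earns no execution reward
    return rewards


def ensemble_n_gram_reward_func(completions, gold_sql, **kwargs):
    """Average 1- to 3-gram overlap between generated and gold SQL tokens."""

    def ngram_overlap(pred, gold, n):
        pred_ngrams = {tuple(pred[i:i + n]) for i in range(len(pred) - n + 1)}
        gold_ngrams = {tuple(gold[i:i + n]) for i in range(len(gold) - n + 1)}
        if not gold_ngrams:
            return 0.0
        return len(pred_ngrams & gold_ngrams) / len(gold_ngrams)

    rewards = []
    for completion, gold in zip(completions, gold_sql):
        pred_tokens = completion.lower().split()
        gold_tokens = gold.lower().split()
        scores = [ngram_overlap(pred_tokens, gold_tokens, n) for n in (1, 2, 3)]
        rewards.append(sum(scores) / len(scores))
    return rewards
```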

+ 13 - 0
end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/requirements.txt

@@ -0,0 +1,13 @@
+torch==2.5.1
+tensorboard==2.19.0
+setuptools==70.3.0
+flash-attn==2.7.4.post1
+transformers==4.48.1
+datasets==3.1.0
+accelerate==1.3.0
+hf-transfer==0.1.9
+deepspeed==0.15.4
+trl==0.14.0
+peft==0.15.2
+vllm==0.7.0
+func_timeout==4.3.5
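These pins can be installed with `pip install -r requirements.txt`. One practical caveat: flash-attn generally needs torch to be installed first and is commonly built with `pip install flash-attn --no-build-isolation`, so installing torch before (or separately from) the rest of the file may be necessary depending on the environment.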