浏览代码

fix readme and fsdp logic

Kai Wu 7 月之前
父节点
当前提交
2730bcaab7
共有 2 个文件被更改,包括 7 次插入2 次删除
  1. 1 1
      recipes/quickstart/finetuning/finetune_vision_model.md
  2. 6 1
      src/llama_recipes/finetuning.py

+ 1 - 1
recipes/quickstart/finetuning/finetune_vision_model.md

@@ -1,7 +1,7 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
 This recipe steps you through how to finetune a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.
 
-**Disclaimer** As our vision models already have a very good OCR ability, here we just use the OCRVQA dataset to demonstrate the steps needed for fine-tuning our vision models.
+**Disclaimer**: As our vision models already have a very good OCR ability, here we just use the OCRVQA dataset to demonstrate the steps needed for fine-tuning our vision models.
 
 ### Fine-tuning steps
 

+ 6 - 1
src/llama_recipes/finetuning.py

@@ -187,7 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
+        else:
+        # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()