
fix typo in auto wrap policy

Fix a typo in the auto wrap policy: the list of layer classes named `MllamaSelfAttentionDecoderLayer` twice, so `MllamaCrossAttentionDecoderLayer` was missing from the FSDP wrapping policy.
Guanghui Qin, 5 months ago
Commit a62aff3876
1 changed file with 2 additions and 2 deletions

src/llama_recipes/finetuning.py (+2, -2)

@@ -237,13 +237,13 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
             my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                 model,
                 [
                     MllamaSelfAttentionDecoderLayer,
-                    MllamaSelfAttentionDecoderLayer,
+                    MllamaCrossAttentionDecoderLayer,
                     MllamaVisionEncoderLayer,
                 ],
             )
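
For context, the repo's `fsdp_auto_wrap_policy(model, [...])` helper builds an FSDP auto-wrap policy from the listed layer classes. Below is a minimal sketch of the same idea using PyTorch's built-in `transformer_auto_wrap_policy`, not the repo's helper; it assumes the Mllama layer classes are importable from `transformers.models.mllama.modeling_mllama` (class names mirror the diff, the rest is illustrative):

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

# Each layer class in this set becomes its own FSDP unit. With the typo,
# MllamaCrossAttentionDecoderLayer was absent, so the cross-attention
# decoder layers were not wrapped as separate units.
my_auto_wrapping_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        MllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer,
        MllamaVisionEncoderLayer,
    },
)

# The policy is then passed to the FSDP constructor, e.g.:
# model = FullyShardedDataParallel(model, auto_wrap_policy=my_auto_wrapping_policy, ...)
```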