
fix typo in auto wrap policy (#793)

Kai Wu, 5 months ago
Parent
Current commit
de3e32c8a7
1 changed file with 2 additions and 2 deletions
      src/llama_recipes/finetuning.py

+ 2 - 2
src/llama_recipes/finetuning.py

@@ -246,13 +246,13 @@ def main(**kwargs):
             print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
         
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
             my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                 model,
                 [
                     MllamaSelfAttentionDecoderLayer,
-                    MllamaSelfAttentionDecoderLayer,
+                    MllamaCrossAttentionDecoderLayer,
                     MllamaVisionEncoderLayer,
                 ],
             )
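For context, an FSDP transformer auto-wrap policy ultimately reduces to a predicate that, given a module, decides whether that module becomes its own FSDP sharding unit; the fix above matters because the duplicated `MllamaSelfAttentionDecoderLayer` entry meant `MllamaCrossAttentionDecoderLayer` was never matched. The following is a minimal, torch-free sketch of that matching logic only — the class definitions are empty stand-ins for the real Mllama layers, and the real `fsdp_auto_wrap_policy` helper in llama-recipes builds on `torch.distributed.fsdp` wrapping utilities with a richer signature:

```python
# Stand-in layer classes; in the real code these come from
# transformers' Mllama modeling module.
class MllamaSelfAttentionDecoderLayer: ...
class MllamaCrossAttentionDecoderLayer: ...
class MllamaVisionEncoderLayer: ...
class Linear: ...  # an ordinary submodule, not a wrap target


def make_auto_wrap_policy(target_classes):
    """Return a predicate: wrap a module iff it is an instance of
    one of the target transformer-layer classes (a simplification
    of FSDP's transformer auto-wrap behavior)."""
    targets = tuple(target_classes)

    def policy(module) -> bool:
        # True -> this module is sharded as its own FSDP unit
        return isinstance(module, targets)

    return policy


# The corrected list from the diff: all three distinct layer types.
policy = make_auto_wrap_policy(
    [
        MllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer,  # the layer the typo dropped
        MllamaVisionEncoderLayer,
    ]
)

print(policy(MllamaCrossAttentionDecoderLayer()))  # True
print(policy(Linear()))                            # False
```

With the original typo, the list contained `MllamaSelfAttentionDecoderLayer` twice, so cross-attention decoder layers fell through to the default wrapping granularity instead of being sharded individually.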