@@ -246,13 +246,13 @@ def main(**kwargs):
             print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
             my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                 model,
                 [
                     MllamaSelfAttentionDecoderLayer,
-                    MllamaSelfAttentionDecoderLayer,
+                    MllamaCrossAttentionDecoderLayer,
                     MllamaVisionEncoderLayer,
                 ],
             )
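
For context: the class list passed to fsdp_auto_wrap_policy names the transformer block types that FSDP should shard as individual units, so the duplicated MllamaSelfAttentionDecoderLayer entry silently left MllamaCrossAttentionDecoderLayer out of the wrap set. Below is a minimal sketch of what the corrected set amounts to, assuming the standard torch transformer_auto_wrap_policy; the repo's fsdp_auto_wrap_policy helper is assumed here to build a comparable policy from the same class list.

    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.mllama.modeling_mllama import (
        MllamaCrossAttentionDecoderLayer,
        MllamaSelfAttentionDecoderLayer,
        MllamaVisionEncoderLayer,
    )

    # Each class listed here becomes its own FSDP unit. With the pre-fix
    # duplicate, the cross-attention decoder layers were never matched by the
    # policy and were only sharded as part of their enclosing module.
    my_auto_wrapping_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            MllamaSelfAttentionDecoderLayer,
            MllamaCrossAttentionDecoderLayer,
            MllamaVisionEncoderLayer,
        },
    )

    # Applied inside an initialized process group, e.g.:
    # model = torch.distributed.fsdp.FullyShardedDataParallel(
    #     model, auto_wrap_policy=my_auto_wrapping_policy)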