@@ -237,13 +237,13 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
             my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                 model,
                 [
                     MllamaSelfAttentionDecoderLayer,
-                    MllamaSelfAttentionDecoderLayer,
+                    MllamaCrossAttentionDecoderLayer,
                     MllamaVisionEncoderLayer,
                 ],
             )
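For context, the duplicated entry meant `MllamaCrossAttentionDecoderLayer` was never part of the wrap set, so the cross-attention decoder layers were not sharded as their own FSDP units. Below is a minimal sketch of how the corrected class list feeds torch's `transformer_auto_wrap_policy`; the repo's `fsdp_auto_wrap_policy` helper builds its policy along these lines, and the imports assume a `transformers` release that ships the Mllama model classes:

```python
# Sketch only: constructs an FSDP auto-wrap policy from the three Mllama layer classes.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

# Listing MllamaSelfAttentionDecoderLayer twice left MllamaCrossAttentionDecoderLayer
# out of this set, so those layers were never wrapped as separate FSDP units.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        MllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer,
        MllamaVisionEncoderLayer,
    },
)

# sharded_model = FSDP(model, auto_wrap_policy=wrap_policy)  # within the usual distributed setup
```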