@@ -187,7 +187,12 @@ def main(**kwargs):
freeze_transformer_layers(model, train_config.num_freeze_layers)
mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
- my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+ # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
+ if is_vision:
+     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+ else:
+     # Create the FSDP wrapper for LlamaDecoderLayer in text models
+     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
device_id = 0
if is_xpu_available():
    device_id = torch.xpu.current_device()
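For context, here is a minimal sketch of what a helper like `fsdp_auto_wrap_policy` is assumed to reduce to: binding the listed transformer layer classes to PyTorch's `transformer_auto_wrap_policy` so FSDP shards the model at those layer boundaries. The `fsdp_auto_wrap_policy_sketch` name, the unused `model` argument, and the commented-out `FSDP(...)` call are illustrative assumptions, not the repository's actual implementation.

```python
# Minimal sketch, assuming the helper only binds layer classes to
# transformer_auto_wrap_policy; the real helper in the repo may do more.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def fsdp_auto_wrap_policy_sketch(model, transformer_layer_classes):
    # FSDP calls the returned policy on every submodule; modules whose class
    # appears in transformer_layer_cls become their own FSDP unit.
    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(transformer_layer_classes),
    )


# Hypothetical usage mirroring the hunk above: pick the policy per model type,
# then hand it to FSDP together with the mixed-precision policy and device id.
# model = FSDP(
#     model,
#     auto_wrap_policy=my_auto_wrapping_policy,
#     mixed_precision=mixed_precision_policy,
#     device_id=device_id,
# )
```

Wrapping at decoder/encoder layer granularity is the usual choice here because each such layer is a self-contained block whose parameters can be gathered and resharded independently.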