@@ -1,7 +1,12 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from torch.distributed._tensor.device_mesh import init_device_mesh
 import os
+import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+
+from llama_recipes.policies import fpSixteen, bfSixteen, get_llama_wrapper
+from torch.distributed._tensor.device_mesh import init_device_mesh
 
 def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
@@ -79,3 +84,38 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
         raise RuntimeError("Failed to create a valid device mesh.")
 
     return device_mesh
+
+
+def get_policies(cfg, rank):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    # bf16 mixed precision requires CUDA 11+ with NCCL 2.10+, or an Intel XPU.
+    verify_bfloat_support = ((
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and torch.version.cuda >= "11.0"
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    ) or
+    (is_xpu_available()))
+
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen
+            if rank == 0:
+                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print(f"FP16 enabled")
+        else:
+            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
+    wrapping_policy = get_llama_wrapper()
+    return mixed_precision_policy, wrapping_policy
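
The two values returned by get_policies feed straight into the FSDP constructor during model setup. As a point of reference, here is a minimal, illustrative sketch of that call, assuming this diff targets llama_recipes/utils/fsdp_utils.py; the cfg, model, and rank objects are placeholders and not part of this change. Note also that get_policies calls is_xpu_available(), which is not imported in the hunks shown above; in llama-recipes that helper normally comes from accelerate.utils.

    # Illustrative usage sketch (not part of this diff); object names are placeholders.
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from llama_recipes.utils.fsdp_utils import get_policies  # assumed module path

    mixed_precision_policy, wrapping_policy = get_policies(cfg, rank)

    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,        # transformer-layer wrapping from get_llama_wrapper()
        mixed_precision=mixed_precision_policy,  # bfSixteen, fpSixteen, or None (plain FP32)
        device_id=torch.cuda.current_device(),
    )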