@@ -39,7 +39,7 @@ from utils.train_utils import (
     clear_gpu_cache,
     get_parameter_dtypes,
     print_model_size,
-    get_policies   
+    get_policies
 )
 
 from utils.dataset_utils import get_preprocessed_dataset
@@ -88,10 +88,10 @@ def main(**kwargs):
     if torch.distributed.is_initialized():
         torch.cuda.set_device(rank)
         setup_environ_flags(rank)
-     
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-      
+
     # Load the pre-trained model and setup its configuration
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
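The hunk above only touches whitespace, but the line it passes through deserves a note: gradient_accumulation_steps is the ratio of the global training batch size to the per-step micro batch size. A minimal sketch of how a training loop typically consumes that value (names and numbers here are illustrative, not taken from this diff):

    import torch

    # Illustrative values standing in for train_config.batch_size_training
    # and train_config.micro_batch_size.
    batch_size_training = 64
    micro_batch_size = 16
    gradient_accumulation_steps = batch_size_training // micro_batch_size  # 4

    def run_accumulated_step(model, optimizer, micro_batches):
        # Scale each loss down by the accumulation factor so the summed
        # gradient matches one batch of size batch_size_training.
        optimizer.zero_grad()
        for step, batch in enumerate(micro_batches, start=1):
            loss = model(**batch).loss / gradient_accumulation_steps
            loss.backward()
            if step % gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()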
@@ -113,19 +113,20 @@ def main(**kwargs):
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
+
     else:
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
         )
-     
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-     
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-         
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
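Blank-line changes aside, this hunk shows the low_cpu_fsdp loading path: rank 0 loads the real weights, and every other rank builds the model on the meta device, so host RAM usage does not grow with world size. A condensed, self-contained sketch of the same pattern (the model name is a placeholder; assumes torch >= 2.0 and an initialized process group):

    import torch
    import torch.distributed as dist
    from transformers import LlamaConfig, LlamaForCausalLM

    rank = dist.get_rank()
    if rank == 0:
        # Only rank 0 materializes the pretrained weights in CPU memory.
        model = LlamaForCausalLM.from_pretrained("path/to/model")
    else:
        # Meta tensors carry shape and dtype but allocate no storage;
        # FSDP later fills them from rank 0 (see the wrapping hunk below).
        config = LlamaConfig.from_pretrained("path/to/model")
        with torch.device("meta"):
            model = LlamaForCausalLM(config)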
@@ -134,7 +135,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-             
+
                 "pad_token": "<PAD>",
             }
         )
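A side note on the pad-token hunk: add_special_tokens returns the number of tokens that were actually new, and if that is nonzero the model's embedding table generally needs to grow to cover the new id. The resize call below is standard transformers API but is not part of this diff; treat it as a hedged reminder, not as something the script does here.

    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("path/to/model")
    num_added = tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if num_added > 0:
        # Grow the embedding matrix so the new pad id has a row
        # (assumes `model` is already loaded in this scope).
        model.resize_token_embeddings(len(tokenizer))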
@@ -142,11 +143,11 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-     
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-             
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
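generate_peft_config is a repo helper that turns train_config into a peft config object. For readers unfamiliar with it, a roughly equivalent setup built directly against the peft library looks like this (the LoRA hyperparameters are assumptions for illustration, not values from this diff):

    from peft import LoraConfig, get_peft_model

    # Illustrative LoRA settings; the script derives the real ones from
    # generate_peft_config(train_config, kwargs).
    peft_config = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()  # only adapter weights stay trainable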
@@ -159,8 +160,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
-            sync_module_states=True,
-            param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
+            sync_module_states=True if train_config.low_cpu_fsdp else False,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
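This is the substantive change in the PR: sync_module_states and param_init_fn are now gated on low_cpu_fsdp instead of being applied unconditionally. The two options cooperate with the meta-device loading above: param_init_fn turns each non-rank-0 module's meta tensors into allocated but uninitialized CUDA storage via to_empty, and sync_module_states=True makes FSDP broadcast rank 0's real weights into that storage while wrapping. When low_cpu_fsdp is off, every rank already holds real weights and neither step is needed. A condensed restatement of the new call (context names such as wrapping_policy and rank come from the surrounding script):

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        mixed_precision=mixed_precision_policy,
        sharding_strategy=fsdp_config.sharding_strategy,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        # Broadcast rank 0's weights only when the meta-device path was used.
        sync_module_states=train_config.low_cpu_fsdp,
        # On non-zero ranks, give each meta tensor real (uninitialized) CUDA
        # storage so the broadcast has somewhere to copy into.
        param_init_fn=(
            (lambda m: m.to_empty(device=torch.device("cuda"), recurse=False))
            if train_config.low_cpu_fsdp and rank != 0
            else None
        ),
    )

Passing the boolean directly (sync_module_states=train_config.low_cpu_fsdp) is equivalent to the `True if ... else False` ternary in the hunk.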
@@ -168,14 +170,14 @@ def main(**kwargs):
         model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-     
+
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-     
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -202,7 +204,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-         
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
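The DistributedSampler constructed above gives each rank a disjoint shard of the dataset. A minimal sketch of how the sampler and DataLoader fit together (assumes an initialized process group; dataset_train and train_config follow the script's names):

    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler

    train_sampler = DistributedSampler(
        dataset_train,
        rank=dist.get_rank(),
        num_replicas=dist.get_world_size(),
    )
    train_dataloader = DataLoader(
        dataset_train,
        batch_size=train_config.micro_batch_size,
        sampler=train_sampler,  # a sampler replaces shuffle= on the loader
        drop_last=True,
    )
    # Call train_sampler.set_epoch(epoch) once per epoch to reshuffle shards.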
@@ -224,7 +226,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-         
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
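AnyPrecisionAdamW keeps its optimizer state in a configurable dtype, which is what makes the pure_bf16 path "pure": weights, gradients, and optimizer state can all stay in bfloat16 rather than falling back to fp32 Adam moments. The constructor arguments below follow the repo's policies module as I understand it; verify them against your checkout before relying on this sketch.

    import torch
    from policies import AnyPrecisionAdamW  # repo-local module

    optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=train_config.lr,
        # Keep both Adam moments in bf16 instead of fp32.
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        use_kahan_summation=False,
    )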
@@ -246,7 +248,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader,  
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,