| 
					
				 | 
			
			
				@@ -50,7 +50,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         import wandb 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     except ImportError: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         raise ImportError( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "You are trying to use wandb which is not currently installed" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "You are trying to use wandb which is not currently installed. " 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             "Please install it using pip install wandb" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     from llama_recipes.configs import wandb_config as WANDB_CONFIG 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -59,7 +59,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     update_config(wandb_config, **kwargs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     run.config.update(train_config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    run.config.update(fsdp_config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    run.config.update(fsdp_config, allow_val_change=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return run 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -84,6 +84,8 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         clear_gpu_cache(local_rank) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         setup_environ_flags(rank) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    wandb_run = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if train_config.enable_wandb: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if not train_config.enable_fsdp or rank==0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)     
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -152,9 +154,8 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         peft_config = generate_peft_config(train_config, kwargs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model = get_peft_model(model, peft_config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model.print_trainable_parameters() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if train_config.enable_wandb: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if not train_config.enable_fsdp or rank==0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                wandb_run.config.update(peft_config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if wandb_run: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            wandb_run.config.update(peft_config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     #setting up FSDP if enable_fsdp is enabled 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if train_config.enable_fsdp: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -260,7 +261,7 @@ def main(**kwargs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         fsdp_config if train_config.enable_fsdp else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         local_rank if train_config.enable_fsdp else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         rank if train_config.enable_fsdp else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        wandb_run if train_config.enable_wandb else None, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        wandb_run, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if not train_config.enable_fsdp or rank==0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         [print(f'Key: {k}, Value: {v}') for k, v in results.items()] 
			 |