| 
					
				 | 
			
			
				@@ -1,4 +1,5 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from unittest.mock import patch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import importlib 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from torch.utils.data.dataloader import DataLoader 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -11,7 +12,7 @@ from llama_recipes.finetuning import main 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 @patch('llama_recipes.finetuning.optim.AdamW') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 @patch('llama_recipes.finetuning.StepLR') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    kwargs = {"run_validation": True} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    kwargs = {"run_validation": False} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     get_dataset.return_value = [1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -36,8 +37,7 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 @patch('llama_recipes.finetuning.optim.AdamW') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 @patch('llama_recipes.finetuning.StepLR') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    kwargs = {"run_validation": False} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    kwargs = {"run_validation": True} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     get_dataset.return_value = [1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     main(**kwargs) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -47,7 +47,6 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     args, kwargs = train.call_args 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     train_dataloader = args[1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     eval_dataloader = args[2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     assert isinstance(train_dataloader, DataLoader) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     assert isinstance(eval_dataloader, DataLoader) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 |