|  | @@ -0,0 +1,73 @@
 | 
	
		
			
				|  |  | +from unittest.mock import patch
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from torch.utils.data.dataloader import DataLoader
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +from llama_recipes.finetuning import main
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.train')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.get_preprocessed_dataset')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.optim.AdamW')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.StepLR')
 | 
	
		
			
				|  |  | +def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
 | 
	
		
			
				|  |  | +    kwargs = {"run_validation": True}
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    get_dataset.return_value = [1]
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    main(**kwargs)
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert train.call_count == 1
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    args, kwargs = train.call_args
 | 
	
		
			
				|  |  | +    train_dataloader = args[1]
 | 
	
		
			
				|  |  | +    eval_dataloader = args[2]
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert isinstance(train_dataloader, DataLoader)
 | 
	
		
			
				|  |  | +    assert eval_dataloader is None
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert get_model.return_value.to.call_args.args[0] == "cuda"
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.train')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.get_preprocessed_dataset')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.optim.AdamW')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.StepLR')
 | 
	
		
			
				|  |  | +def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
 | 
	
		
			
				|  |  | +    kwargs = {"run_validation": False}
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    get_dataset.return_value = [1]
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    main(**kwargs)
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert train.call_count == 1
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    args, kwargs = train.call_args
 | 
	
		
			
				|  |  | +    train_dataloader = args[1]
 | 
	
		
			
				|  |  | +    eval_dataloader = args[2]
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert isinstance(train_dataloader, DataLoader)
 | 
	
		
			
				|  |  | +    assert isinstance(eval_dataloader, DataLoader)
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert get_model.return_value.to.call_args.args[0] == "cuda"
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.train')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.get_preprocessed_dataset')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.generate_peft_config')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.get_peft_model')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.optim.AdamW')
 | 
	
		
			
				|  |  | +@patch('llama_recipes.finetuning.StepLR')
 | 
	
		
			
				|  |  | +def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
 | 
	
		
			
				|  |  | +    kwargs = {"use_peft": True}
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    get_dataset.return_value = [1]
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    main(**kwargs)
 | 
	
		
			
				|  |  | +    
 | 
	
		
			
				|  |  | +    assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
 | 
	
		
			
				|  |  | +    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 |