# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from unittest.mock import patch

import pytest
import torch

from llama_recipes.data.sampler import LengthBasedBatchSampler
from llama_recipes.finetuning import main
from pytest import approx
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler


def get_fake_dataset():
    # Minimal single-sample dataset so main() can build its dataloaders
    # without loading any real data.
    return [
        {
            "input_ids": [1],
            "attention_mask": [1],
            "labels": [1],
        }
    ]


@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_no_validation(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    # With run_validation=False, train() should receive a train dataloader
    # but no eval dataloader.
    kwargs = {"run_validation": False}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert eval_dataloader is None

    if cuda_is_available:
        assert get_model.return_value.to.call_count == 1
        assert get_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_model.return_value.to.call_count == 0


@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_with_validation(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    # With run_validation=True, both train and eval dataloaders should be
    # passed to train().
    kwargs = {"run_validation": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert isinstance(eval_dataloader, DataLoader)

    if cuda_is_available:
        assert get_model.return_value.to.call_count == 1
        assert get_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_model.return_value.to.call_count == 0


@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.generate_peft_config")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_peft_lora(
    step_lr,
    optimizer,
    get_peft_model,
    gen_peft_config,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    # With use_peft=True, the PEFT-wrapped model (not the base model) is the
    # one moved to CUDA, and its trainable parameters are printed.
    kwargs = {"use_peft": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    if cuda_is_available:
        assert get_peft_model.return_value.to.call_count == 1
        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_peft_model.return_value.to.call_count == 0

    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1


@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.setup")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
def test_finetuning_peft_llama_adapter(
    get_dataset, tokenizer, get_model, train, setup, get_peft_model
):
    # llama_adapter combined with FSDP is unsupported and must raise; without
    # FSDP, main() should proceed as far as get_peft_model.
    kwargs = {
        "use_peft": True,
        "peft_method": "llama_adapter",
        "enable_fsdp": True,
    }

    get_dataset.return_value = get_fake_dataset()

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    with pytest.raises(
        RuntimeError,
        match="Llama_adapter is currently not supported in combination with FSDP",
    ):
        main(**kwargs)

    GET_ME_OUT = "Get me out of here"
    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
    kwargs["enable_fsdp"] = False
    with pytest.raises(
        RuntimeError,
        match=GET_ME_OUT,
    ):
        main(**kwargs)


@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.StepLR")
def test_finetuning_weight_decay(
    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
):
    # The weight_decay argument must be forwarded to the AdamW optimizer
    # that is handed to train().
    kwargs = {"weight_decay": 0.01}

    get_dataset.return_value = get_fake_dataset()

    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    optimizer = args[4]

    print(optimizer.state_dict())

    assert isinstance(optimizer, AdamW)
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)


@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
def test_batching_strategy(
    step_lr, optimizer, get_dataset, tokenizer, get_model, train
):
    # "packing" should use the default BatchSampler, "padding" should use
    # LengthBasedBatchSampler, and any other value should raise ValueError.
    kwargs = {"batching_strategy": "packing"}

    get_dataset.return_value = get_fake_dataset()

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]
    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)

    kwargs["batching_strategy"] = "padding"
    train.reset_mock()
    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]
    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)

    kwargs["batching_strategy"] = "none"

    with pytest.raises(ValueError):
        main(**kwargs)