test_grammar_datasets.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from pathlib import Path

import pytest
from unittest.mock import patch

DATA_DIR = Path(__file__).parents[2] / "llama_recipes/datasets/grammar_dataset/"
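
# The expensive pieces of llama_recipes.finetuning (model, config, tokenizer,
# optimizer, LR scheduler and the train loop itself) are patched out, so the test
# only exercises grammar_dataset loading and dataloader construction.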
@pytest.mark.skip_missing_tokenizer
@pytest.mark.skipif(not Path(DATA_DIR / "grammar_validation.csv").exists(), reason="grammar_validation.csv not found")
@pytest.mark.skipif(not Path(DATA_DIR / "gtrain_10k.csv").exists(), reason="gtrain_10k.csv not found")
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
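    # The mocked embedding table reports the vocab size matching the parametrized
    # llama_version: 32000 for Llama 2 checkpoints, 128256 otherwise.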
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_config.return_value.model_type = "llama"
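
    # Minimal full-finetuning run (no PEFT) on the grammar dataset with padded batches.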
    BATCH_SIZE = 8
    kwargs = {
        "model_name": llama_version,
        "batch_size_training": BATCH_SIZE,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "grammar_dataset",
        "batching_strategy": "padding",
    }

    main(**kwargs)
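
    # main() should hand off to the (mocked) train loop exactly once; its positional
    # arguments carry the dataloaders and the tokenizer inspected below.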
    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]
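
    # Expected sample counts of the two CSV splits; the train loader (batch size 8)
    # should yield TRAIN_SAMPLES // BATCH_SIZE batches, the validation loader
    # (batch size 1) one batch per sample.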
    VAL_SAMPLES = 2988
    TRAIN_SAMPLES = 13016

    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
    assert len(eval_dataloader) == VAL_SAMPLES

    batch = next(iter(train_dataloader))
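
    # The collated batch should expose the standard causal-LM tensors.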
  40. assert "labels" in batch.keys()
  41. assert "input_ids" in batch.keys()
  42. assert "attention_mask" in batch.keys()
    token = args[3]
    assert batch["input_ids"][0][0] == token.bos_token_id
    assert batch["labels"][0][-1] == token.eos_token_id
    assert batch["input_ids"][0][-1] == token.eos_token_id