dataset_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Callable, Optional

import torch

from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
from llama_recipes.utils.config_utils import get_dataloader_kwargs

def get_preprocessed_dataset(
    tokenizer, dataset_config, split: str = "train"
) -> torch.utils.data.Dataset:
    if dataset_config.dataset not in DATASET_PREPROC:
        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")

    def get_split():
        return (
            dataset_config.train_split
            if split == "train"
            else dataset_config.test_split
        )

    return DATASET_PREPROC[dataset_config.dataset](
        dataset_config,
        tokenizer,
        get_split(),
    )

def get_custom_data_collator(
    dataset_processer, dataset_config
) -> Optional[Callable]:
    # Not every dataset registers a custom collator; return None so the
    # caller can fall back to the default collation.
    if dataset_config.dataset not in DATALOADER_COLLATE_FUNC:
        return None

    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
        dataset_processer,
        dataset_config,
    )

def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

    if split == "train" and train_config.batching_strategy == "packing":
        # Pack tokenized samples into fixed-length chunks of context_length tokens.
        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)

    # Create data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )
    return dataloader
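

# Usage sketch (illustrative only, not part of this module): how these helpers
# might be wired together in a fine-tuning script. The tokenizer name is a
# placeholder, and `dataset_config` / `train_config` are assumed to be the
# corresponding llama_recipes config objects with the fields referenced above.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#     train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
#     eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")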