# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import torch

from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
from llama_recipes.utils.config_utils import get_dataloader_kwargs
def get_preprocessed_dataset(
    tokenizer, dataset_config, split: str = "train"
) -> torch.utils.data.Dataset:
    """Run the preprocessing function registered for the configured dataset.

    Args:
        tokenizer: Tokenizer forwarded to the dataset preprocessor.
        dataset_config: Config object; ``dataset_config.dataset`` selects the
            preprocessor from ``DATASET_PREPROC``, and ``train_split`` /
            ``test_split`` name the concrete split to load.
        split: ``"train"`` selects ``dataset_config.train_split``; any other
            value selects ``dataset_config.test_split``.

    Returns:
        The preprocessed dataset produced by the registered function.

    Raises:
        NotImplementedError: If no preprocessor is registered for
            ``dataset_config.dataset``.
    """
    if dataset_config.dataset not in DATASET_PREPROC:
        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")

    # Map the requested logical split onto the split name configured for
    # this dataset (replaces the original's needless nested closure).
    split_name = (
        dataset_config.train_split if split == "train" else dataset_config.test_split
    )

    return DATASET_PREPROC[dataset_config.dataset](
        dataset_config,
        tokenizer,
        split_name,
    )
def get_custom_data_collator(dataset_processer, dataset_config):
    """Return the custom collate function registered for the configured dataset.

    NOTE(review): the original signature annotated the return type as
    ``torch.utils.data.Dataset``, which was incorrect — the registered factory
    produces a collate callable, and ``None`` is returned when no custom
    collator exists. The wrong annotation has been dropped.

    Args:
        dataset_processer: Tokenizer/processor forwarded to the collator factory.
        dataset_config: Config whose ``dataset`` attribute keys the
            ``DATALOADER_COLLATE_FUNC`` registry.

    Returns:
        The collate callable built by the registered factory, or ``None`` when
        the dataset has no custom collator (callers fall back to the default).
    """
    if dataset_config.dataset not in DATALOADER_COLLATE_FUNC:
        return None

    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
        dataset_processer,
        dataset_config
    )
def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    """Build a ``torch.utils.data.DataLoader`` for the requested split.

    Preprocesses the dataset, derives loader kwargs from the training config,
    optionally packs training samples into fixed-size chunks when the
    "packing" batching strategy is configured, and wraps the result in a
    DataLoader with pinned memory.
    """
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
    # Loader kwargs are derived from the un-packed dataset, matching the
    # original call order.
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

    use_packing = split == "train" and train_config.batching_strategy == "packing"
    if use_packing:
        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)

    # Create data loader
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )