| 1234567891011121314151617181920212223242526272829303132333435363738394041 | 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
- import torch
 
- from functools import partial
 
- from ft_datasets import (
 
-     get_grammar_dataset,
 
-     get_alpaca_dataset,
 
-     get_samsum_dataset,
 
- )
 
- from typing import Optional
 
# Registry of supported datasets: maps the dataset name (the value of
# ``dataset_config.dataset``) to the callable that builds that dataset.
# The Alpaca builder is pre-bound with ``max_words=224`` via ``partial``.
DATASET_PREPROC = dict(
    alpaca_dataset=partial(get_alpaca_dataset, max_words=224),
    grammar_dataset=get_grammar_dataset,
    samsum_dataset=get_samsum_dataset,
)
 
def get_preprocessed_dataset(
    tokenizer, dataset_config, split: str = "train"
) -> torch.utils.data.Dataset:
    """Build the dataset selected by ``dataset_config.dataset``.

    Args:
        tokenizer: Forwarded unchanged to the dataset builder.
        dataset_config: Configuration object. This function reads its
            ``dataset``, ``train_split`` and ``test_split`` attributes and
            forwards the object itself to the builder.
        split: ``"train"`` selects ``dataset_config.train_split``; any other
            value selects ``dataset_config.test_split``.

    Returns:
        Whatever the builder registered in ``DATASET_PREPROC`` returns when
        called as ``builder(dataset_config, tokenizer, split_name)``.

    Raises:
        NotImplementedError: If ``dataset_config.dataset`` has no entry in
            ``DATASET_PREPROC``.
    """
    dataset_name = dataset_config.dataset
    if dataset_name not in DATASET_PREPROC:
        raise NotImplementedError(f"{dataset_name} is not (yet) implemented")

    # Anything other than "train" deliberately falls back to the test split,
    # matching the original behaviour (no validation of ``split`` is done).
    split_name = (
        dataset_config.train_split
        if split == "train"
        else dataset_config.test_split
    )

    return DATASET_PREPROC[dataset_name](
        dataset_config,
        tokenizer,
        split_name,
    )
 
 
  |