# dataset_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# NOTE: the submodules are imported explicitly — `import importlib` alone
# does not guarantee that `importlib.util` / `importlib.machinery` are bound.
import importlib
import importlib.machinery
import importlib.util
from functools import partial
from pathlib import Path

import torch

from llama_recipes.datasets import (
    get_alpaca_dataset,
    get_grammar_dataset,
    get_llamaguard_toxicchat_dataset,
    get_samsum_dataset,
)
  13. def load_module_from_py_file(py_file: str) -> object:
  14. """
  15. This method loads a module from a py file which is not in the Python path
  16. """
  17. module_name = Path(py_file).name
  18. loader = importlib.machinery.SourceFileLoader(module_name, py_file)
  19. spec = importlib.util.spec_from_loader(module_name, loader)
  20. module = importlib.util.module_from_spec(spec)
  21. loader.exec_module(module)
  22. return module
  23. def get_custom_dataset(dataset_config, tokenizer, split: str):
  24. if ":" in dataset_config.file:
  25. module_path, func_name = dataset_config.file.split(":")
  26. else:
  27. module_path, func_name = dataset_config.file, "get_custom_dataset"
  28. if not module_path.endswith(".py"):
  29. raise ValueError(f"Dataset file {module_path} is not a .py file.")
  30. module_path = Path(module_path)
  31. if not module_path.is_file():
  32. raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
  33. module = load_module_from_py_file(module_path.as_posix())
  34. try:
  35. return getattr(module, func_name)(dataset_config, tokenizer, split)
  36. except AttributeError as e:
  37. print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
  38. raise e
  39. DATASET_PREPROC = {
  40. "alpaca_dataset": partial(get_alpaca_dataset),
  41. "grammar_dataset": get_grammar_dataset,
  42. "samsum_dataset": get_samsum_dataset,
  43. "custom_dataset": get_custom_dataset,
  44. "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
  45. }
  46. def get_preprocessed_dataset(
  47. tokenizer, dataset_config, split: str = "train"
  48. ) -> torch.utils.data.Dataset:
  49. if not dataset_config.dataset in DATASET_PREPROC:
  50. raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
  51. def get_split():
  52. return (
  53. dataset_config.train_split
  54. if split == "train"
  55. else dataset_config.test_split
  56. )
  57. return DATASET_PREPROC[dataset_config.dataset](
  58. dataset_config,
  59. tokenizer,
  60. get_split(),
  61. )