grammar_dataset.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. # For dataset details visit: https://huggingface.co/datasets/jfleg
  4. # For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
  5. from datasets import load_dataset
  6. from pathlib import Path
  7. from torch.utils.data import Dataset
  8. from llama_recipes.datasets.utils import ConcatDataset
  9. class grammar(Dataset):
  10. def __init__(
  11. self,
  12. tokenizer,
  13. csv_name=None,
  14. ):
  15. try:
  16. self.dataset = load_dataset(
  17. "csv",
  18. data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"},
  19. delimiter=",",
  20. )
  21. except Exception as e:
  22. print("Loading of grammar dataset failed! Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset.")
  23. raise e
  24. # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path)
  25. # if num_samples:
  26. # self.dataset = self.dataset.select(list(range(0, num_samples)))
  27. self.tokenizer = tokenizer
  28. self.print_text = False # print_text
  29. def __len__(self):
  30. return self.dataset["train"].shape[0]
  31. def convert_to_features(self, example_batch):
  32. # Create prompt and tokenize contexts and questions
  33. if self.print_text:
  34. print("Input Text: ", self.clean_text(example_batch["text"]))
  35. input_ = example_batch["input"]
  36. target_ = example_batch["target"]
  37. prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
  38. sample = self.tokenizer(prompt)
  39. return sample
  40. def __getitem__(self, index):
  41. sample = self.convert_to_features(self.dataset["train"][int(index)])
  42. source_ids = sample["input_ids"]
  43. src_mask = sample["attention_mask"]
  44. return {
  45. "input_ids": source_ids,
  46. "attention_mask": src_mask,
  47. "labels": source_ids.copy(),
  48. }
  49. def get_dataset(
  50. dataset_config, tokenizer, csv_name=None
  51. ):
  52. """cover function for handling loading the working dataset"""
  53. """dataset loading"""
  54. if csv_name is None:
  55. currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
  56. print(f"Loading dataset {currPath}")
  57. csv_name = str(currPath)
  58. dataset = grammar(
  59. tokenizer=tokenizer,
  60. csv_name=csv_name,
  61. )
  62. return dataset