alpaca_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html

import copy
import json

import torch
from torch.utils.data import Dataset
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
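
# For illustration (the record below is hypothetical, not from the Alpaca data):
#   PROMPT_DICT["prompt_no_input"].format_map({"instruction": "List three primes."})
# returns the header text ending in "### Response:", to which the expected
# answer from the dataset record is appended in __getitem__ to form a full
# training example.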

class InstructionDataset(Dataset):
    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
        self.ann = json.load(open(dataset_config.data_path))
        if partition == "train":
            self.ann = self.ann
        else:
            # Hold out the first 200 examples as a small evaluation split.
            self.ann = self.ann[:200]

        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)
    def __getitem__(self, index):
        IGNORE_INDEX = -100  # The default ignore_index in CrossEntropyLoss
        ann = self.ann[index]
        if ann.get("input", "") == "":
            prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
        else:
            prompt = PROMPT_DICT["prompt_input"].format_map(ann)
        example = prompt + ann["output"]
        prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)
        example = self.tokenizer.encode(example)
        example.append(self.tokenizer.eos_token_id)
        example = torch.tensor(example, dtype=torch.int64)

        # Pad with the sentinel -1 (which the ge(0) masks below rely on) or
        # truncate, so every example is exactly max_words tokens long.
        padding = self.max_words - example.shape[0]
        if padding > 0:
            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            example = example[: self.max_words]

        # Mask out the prompt tokens so the loss is computed only on the response.
        labels = copy.deepcopy(example)
        labels[: len(prompt)] = -1

        # Replace the -1 sentinels: padding becomes token id 0 in the inputs and
        # IGNORE_INDEX in the labels; the attention mask marks real tokens.
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX
        example_mask = example_mask.float()

        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask,
        }
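

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes a Hugging
# Face LlamaTokenizer and a config object that only needs a `data_path`
# attribute; the checkpoint name and JSON path below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import LlamaTokenizer  # assumed dependency

    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder
    config = SimpleNamespace(data_path="alpaca_data.json")  # placeholder path

    dataset = InstructionDataset(config, tokenizer, partition="train", max_words=128)
    sample = dataset[0]
    print(sample["input_ids"].shape)       # torch.Size([128])
    print(sample["attention_mask"].sum())  # number of non-padding tokens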