# ocrvqa_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.

import copy

import torch
from datasets import load_dataset

# Check whether the system-prompt or user-prompt header token sequence
# appears anywhere in the current token list.
def check_header(targets, seq):
    for i in range(len(seq) - 2):
        if seq[i : i + 3] in targets:
            return True
    return False


# Replace every occurrence of the 3-token target sequence with -100 so that
# those positions are ignored by the cross-entropy loss.
def replace_target(target, seq):
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
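
# Illustrative examples (token ids from the Llama 3 tokenizer, shown as plain
# lists; these calls are a sketch, not part of the training path):
#   check_header([[128006, 882, 128007]], [128006, 882, 128007, 15339, 128009])
#       -> True   (the user header appears in the sequence)
#   replace_target([128006, 78191, 128007], [128006, 78191, 128007, 15339])
#       -> [-100, -100, -100, 15339]   (the assistant header is masked in place)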

def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    # The tokenizer re-adds <|begin_of_text|> during tokenization, so it is
    # stripped from the templated text here to avoid duplication.
    text_prompt = [prompt.replace("<|begin_of_text|>", "") for prompt in text_prompt]
    batch = processor(
        images=images,
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        # 128009 is the <|eot_id|> token that closes each turn.
        eot_indices = [idx for idx, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>"
        # tokenizes to [128006, 9125, 128007]; the user prompt header
        # "<|start_header_id|>user<|end_header_id|>" tokenizes to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx : idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # A prompt header was found, so this segment is a system or
                # user turn and is masked out of the loss.
                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
        # which tokenizes to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask padding tokens and the image token (index 128256).
        for j in range(len(labels)):
            if labels[j] == processor.tokenizer.pad_token_id or labels[j] == 128256:
                labels[j] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch
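
# Usage sketch for tokenize_dialogs (the processor checkpoint is an assumption;
# any Llama 3.2 Vision processor using this chat template should behave the same):
#
#   from transformers import AutoProcessor
#   processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
#   dialogs = [[
#       {"role": "user", "content": [{"type": "image"},
#                                    {"type": "text", "text": "What is the book title?"}]},
#       {"role": "assistant", "content": [{"type": "text", "text": "Moby Dick"}]},
#   ]]
#   batch = tokenize_dialogs(dialogs, [[pil_image]], processor)  # pil_image: a PIL.Image
#   # batch["labels"] keeps only the assistant answer tokens; prompt headers,
#   # user turns, padding, and image tokens are all -100.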

def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict whose "train" split holds all the data.
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # For quick testing, use only 2000 samples; comment out the following line
    # to train on the full dataset.
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(
        test_size=1 - split_ratio, shuffle=True, seed=42
    )[split]
    return dataset
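
# Usage sketch: `split` selects one side of the 90/10 split. `dataset_config`
# and `processor` are accepted for interface compatibility but unused here, so
# passing None for the config is safe in this sketch:
#
#   train_set = get_custom_dataset(None, processor, "train")  # 1800 samples
#   eval_set = get_custom_dataset(None, processor, "test")    # 200 samples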

class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        # During training, padding is always applied on the right.
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only one image per sample is supported")
            image = image_list[0].convert("RGB")  # only use the first image
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # Attach the image only to the first user turn.
                    dialog += [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image"},
                                {"type": "text", "text": sample_dict["user"].strip()},
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": sample_dict["assistant"].strip(),
                                }
                            ],
                        },
                    ]
                else:
                    dialog += [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sample_dict["user"].strip()}
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": sample_dict["assistant"].strip(),
                                }
                            ],
                        },
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)

def get_data_collator(processor):
    return OCRVQADataCollator(processor)
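
# Minimal smoke-test sketch (assumptions: the checkpoint id below is available
# locally and `datasets`/`transformers` are installed; this block is
# illustrative, not the recipe's training entry point).
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed checkpoint id
    )
    train_set = get_custom_dataset(None, processor, "train")
    loader = DataLoader(
        train_set, batch_size=2, collate_fn=get_data_collator(processor)
    )
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape)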