ocrvqa_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.

import copy
import itertools

import torch
from datasets import load_dataset


# Check whether the system prompt or user prompt header token sequence appears in the current token list.
def check_header(targets, seq):
    for i in range(len(seq) - 3):
        if seq[i : i + 3] in targets:
            return True
    return False


# Replace every occurrence of the 3-token target sequence with -100 so it is ignored by the loss.
def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
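
# A toy illustration of these helpers (hypothetical token values, not real Llama 3 token IDs):
#   check_header([[1, 2, 3]], [9, 1, 2, 3, 8])   -> True
#   replace_target([1, 2, 3], [9, 1, 2, 3, 8])   -> [9, -100, -100, -100, 8]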

def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        eot_indices = [idx for idx, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for n, idx in enumerate(eot_indices):
            current_seq = labels[last_idx : idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # found a prompt header, indicating that this seq should be masked
                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask all assistant header prompts <|start_header_id|>assistant<|end_header_id|>, which have been tokenized to [128006, 78191, 128007]
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask the padding tokens and the image token
        for j in range(len(labels)):
            if labels[j] == processor.tokenizer.pad_token_id or labels[j] == 128256:  # 128256 is the image token index
                labels[j] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch
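
# Rough sketch of the resulting label mask for one user/assistant exchange (token IDs as in the
# comments above): every token up to and including each user turn's <|eot_id|> (128009) is set
# to -100, the assistant header triple [128006, 78191, 128007] is set to -100, and padding/image
# tokens are set to -100, so the loss is computed essentially only on the assistant's answer
# tokens and their closing <|eot_id|>.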

def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset will return a DatasetDict that contains all the data in the train set
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # For quick testing we only use 2000 samples; comment out the following line to use the full dataset.
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
    return dataset

class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"  # during training, one always pads on the right

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only support one image per sample")
            image = image_list[0].convert("RGB")  # only use the first image
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # only attach the image to the first user turn
                    dialog += [
                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
                else:
                    dialog += [
                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)


def get_data_collator(processor):
    return OCRVQADataCollator(processor)
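
# A minimal usage sketch (not part of the recipe). The processor checkpoint and batch size below
# are assumptions; any Llama 3.2 Vision processor that matches the token IDs used above should work.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    train_set = get_custom_dataset(None, processor, "train")  # dataset_config is unused above
    loader = DataLoader(train_set, batch_size=2, collate_fn=get_data_collator(processor))
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape)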