vqa_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import copy

import torch
from datasets import load_dataset
# Check whether the system prompt or user prompt header token sequence
# appears anywhere in the current token list.
def check_header(targets, seq):
    # Slide a 3-token window over seq; i runs up to len(seq) - 3, the last
    # position where a full window fits.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] in targets:
            return True
    return False

# Replace every occurrence of the 3-token target sequence with -100 so that
# those positions are ignored by the cross-entropy loss.
def replace_target(target, seq):
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
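# Illustrative example (comments only): masking the assistant header
# [128006, 78191, 128007] in a toy label sequence.
#
#   >>> replace_target([128006, 78191, 128007], [1, 128006, 78191, 128007, 42])
#   [1, -100, -100, -100, 42]
#
# check_header slides the same 3-token window but only reports whether any of
# the target sequences is present.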
def tokenize_dialogs(dialogs, images, processor):
    # Llama 3 family models (vocab size above 128000): use the processor's
    # chat template to render the dialogs into the model's prompt format.
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        # 128009 is <|eot_id|>, which terminates every turn.
        eot_indices = [k for k, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" tokenizes to [128006, 9125, 128007];
        # the user prompt header "<|start_header_id|>user<|end_header_id|>" tokenizes to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx : idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a system/user prompt header: mask this whole turn.
                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Lastly, mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
        # which tokenizes to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch
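# Optional helper, a sketch that is not part of the original recipe: with
# padding=True and right padding, padded positions keep the pad-token id in
# "labels", so trainers that do not ignore pad ids automatically may want
# them masked to -100 as well.
def mask_padding_in_labels(batch, processor):
    # Hypothetical helper: set label positions that hold the pad token to -100.
    pad_id = processor.tokenizer.pad_token_id
    if pad_id is not None:
        batch["labels"][batch["labels"] == pad_id] = -100
    return batch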
def tokenize_dialog(dialog, images, processor):
    # Single-dialog variant of tokenize_dialogs: the same chat template and
    # prompt-masking logic, applied to one sample.
    text_prompt = processor.apply_chat_template(dialog)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    labels = copy.copy(batch["input_ids"].tolist()[0])
    # 128009 is <|eot_id|>, which terminates every turn.
    eot_indices = [k for k, n in enumerate(labels) if n == 128009]
    last_idx = 0
    # The system prompt header "<|start_header_id|>system<|end_header_id|>" tokenizes to [128006, 9125, 128007];
    # the user prompt header "<|start_header_id|>user<|end_header_id|>" tokenizes to [128006, 882, 128007].
    prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
    for idx in eot_indices:
        current_seq = labels[last_idx : idx + 1]
        if check_header(prompt_header_seqs, current_seq):
            # Found a system/user prompt header: mask this whole turn.
            labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
        else:
            last_idx = idx + 1
    # Lastly, mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
    # which tokenizes to [128006, 78191, 128007].
    assistant_header_seq = [128006, 78191, 128007]
    labels = replace_target(assistant_header_seq, labels)
    # Note: for the Llama 3.2 vision processor, batch["pixel_values"] has shape
    # (batch_size, num_concurrent_media, num_tiles, num_channels, height, width).
    batch["labels"] = torch.tensor(labels)
    return batch
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict containing all the data in the train set.
    # split_ratio is accepted for interface compatibility but is unused here.
    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
    dataset = dataset_dict[split]
    # Keep only the first 500 samples to shorten fine-tuning runs.
    dataset = dataset.select(range(500))
    return dataset
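# Expected sample schema, inferred from how VQADataCollator reads each sample
# below (field names come from the remyxai/vqasynth_spacellava dataset):
#
#   {
#       "images": [<PIL.Image>],
#       "messages": [
#           {"role": "user", "content": [{"type": "image"},
#                                        {"type": "text", "text": "..."}]},
#           {"role": "assistant", "content": [{"type": "text", "text": "..."}]},
#       ],
#   }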
class VQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        # During training, always pad on the right.
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image, sample_text = sample["images"], sample["messages"]
            dialog = []
            for line in sample_text:
                content = []
                messages = line["content"]
                role = line["role"]
                # Normalize each message into the chat-template content format.
                for message in messages:
                    if message["type"] == "image":
                        content.append({"type": "image"})
                    elif message["type"] == "text":
                        content.append({"type": "text", "text": message["text"].strip()})
                dialog.append({"role": role, "content": content})
            dialogs.append(dialog)
            images.append(image)
        return tokenize_dialogs(dialogs, images, self.processor)
    # Legacy single-sample variant kept for reference; note the return inside
    # the loop, so only the first sample in `samples` is ever processed.
    def __callworking__(self, samples):
        for sample in samples:
            image, sample_text = sample["images"], sample["messages"]
            dialog = []
            for line in sample_text:
                content = []
                messages = line["content"]
                role = line["role"]
                for message in messages:
                    if message["type"] == "image":
                        content.append({"type": "image"})
                    elif message["type"] == "text":
                        content.append({"type": "text", "text": message["text"].strip()})
                dialog.append({"role": role, "content": content})
            return tokenize_dialog(dialog, image, self.processor)
def get_data_collator(processor):
    return VQADataCollator(processor)
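# Minimal usage sketch (comments only), assuming a Hugging Face vision
# processor; the checkpoint name below is illustrative, not prescribed by
# this file.
#
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
#   train_dataset = get_custom_dataset(None, processor, "train")
#   collator = get_data_collator(processor)
#   batch = collator([train_dataset[0], train_dataset[1]])
#   print(batch["input_ids"].shape, batch["labels"].shape)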