@@ -51,57 +51,12 @@ def tokenize_dialogs(dialogs, images, processor):
     tokenizer_length = len(processor.tokenizer)
     return batch
-def tokenize_dialog(dialog, images, processor):
-    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
-    text_prompt = processor.apply_chat_template(dialog)
-    #print("text_prompt",text_prompt)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
-    labels = copy.copy(batch["input_ids"].tolist()[0])
-    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
-    last_idx = 0
-    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
-    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
-    for n, idx in enumerate(eot_indices):
-        current_seq = labels[last_idx:idx+1]
-        if check_header(prompt_header_seqs,current_seq):
-            # found prompt header, indicating that this seq should be masked
-            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-        else:
-            last_idx = idx+1
-    # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
-    assistant_header_seq = [128006, 78191, 128007]
-    labels = replace_target(assistant_header_seq,labels)
-    #print("labels",labels)
-    # print("pixel_values .shape",batch["pixel_values"].shape)
-    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
-
-    batch["labels"] = torch.tensor(labels)
-    # exit()
-    # combined_tokens = {
-    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
-    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
-    #     "input_ids": dialog_tokens,
-    #     "labels": labels,
-    #     "attention_mask": [1]*len(dialog_tokens),
-    #     "pixel_values": batch["pixel_values"],
-    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
-    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
-    #     "cross_attention_mask": batch["cross_attention_mask"]
-    # }
-    # input_ids = list(itertools.chain(*(t for t in dialog_tokens))),
-    # labels = list(itertools.chain(*(t for t in labels_tokens))),
-    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
-    # pixel_values = batch["pixel_values"],
-    # image_sizes = batch["image_sizes"]
-#    print("combined_tokens",combined_tokens[image_sizes])
-
-    return batch
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
-    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
-    dataset = dataset_dict[split]
-    dataset = dataset.select(range(500))
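+    # NOTE: the_cauldron configs ship only a "train" split, so the eval split is
+    # carved out locally with train_test_split; the seed is fixed so the partition
+    # is reproducible across runs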
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d")
+    dataset = dataset_dict['train']
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
     return dataset
 
 class VQADataCollator:
@@ -111,35 +66,26 @@ class VQADataCollator:
     def __call__(self, samples):
         dialogs,images = [],[]
         for sample in samples:
-            image,sample_text = sample["images"],sample["messages"]
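+            # each the_cauldron sample looks like
+            # {"images": [PIL images], "texts": [{"user": ..., "assistant": ...}, ...]}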
+            image_list, sample_list = sample["images"], sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only one image per sample is supported")
+            image = image_list[0].convert("RGB")  # use the single image
             dialog = []
-            for line in sample_text:
-                content = []
-                messages = line["content"]
-                role = line["role"]
-                for message in messages:
-                    if message["type"] == "image":
-                        content.append({"type": "image"})
-                    elif message["type"] == "text":
-                        content.append({"type": "text", "text": message["text"].strip()})
-                dialog.append({"role": role,"content":content})
+            for sample_dict in sample_list:
+                if not dialog:
+                    # only attach the image to the first user turn
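+                    # (assumption: with Llama 3.2 Vision, later turns can still
+                    # reference the image through cross-attention, so a single
+                    # image placeholder is enough)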
+                    dialog += [
+                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
+                    ]
+                else:
+                    dialog += [
+                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
+                    ]
             dialogs.append(dialog)
-            images.append(image)
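+            # the processor expects one list of images per dialog, hence the nested list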
+            images.append([image])
         return tokenize_dialogs(dialogs,images, self.processor)
-    def __callworking__(self, samples):
-        for sample in samples:
-            image,sample_text = sample["images"],sample["messages"]
-            dialog = []
-            for line in sample_text:
-                content = []
-                messages = line["content"]
-                role = line["role"]
-                for message in messages:
-                    if message["type"] == "image":
-                        content.append({"type": "image"})
-                    elif message["type"] == "text":
-                        content.append({"type": "text", "text": message["text"].strip()})
-                dialog.append({"role": role,"content":content})
-        return tokenize_dialog(dialog,image, self.processor)
 def get_data_collator(processor):
     return VQADataCollator(processor)