|
@@ -48,18 +48,19 @@ def tokenize_dialogs(dialogs, images, processor):
|
|
|
labels[i] = -100
|
|
|
label_list.append(labels)
|
|
|
batch["labels"] = torch.tensor(label_list)
|
|
|
- tokenizer_length = len(processor.tokenizer)
|
|
|
return batch
|
|
|
|
|
|
|
|
|
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
|
|
|
# load_dataset will return DatasetDict that contains all the data in the train set
|
|
|
- dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d")
|
|
|
+ dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
|
|
|
dataset = dataset_dict['train']
|
|
|
+ # For quick testing we only use 2000 samples; comment out the following line to use the full dataset
|
|
|
+ dataset = dataset.select(range(2000))
|
|
|
dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
|
|
|
return dataset
|
|
|
|
|
|
-class VQADataCollator:
|
|
|
+class OCRVQADataCollator:
|
|
|
def __init__(self, processor):
|
|
|
self.processor = processor
|
|
|
self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
|
|
@@ -88,4 +89,4 @@ class VQADataCollator:
|
|
|
images.append([image])
|
|
|
return tokenize_dialogs(dialogs,images, self.processor)
|
|
|
def get_data_collator(processor):
|
|
|
- return VQADataCollator(processor)
|
|
|
+ return OCRVQADataCollator(processor)
|