| 
					
				 | 
			
			
				@@ -48,18 +48,19 @@ def tokenize_dialogs(dialogs, images, processor): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 labels[i] = -100 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         label_list.append(labels) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     batch["labels"] = torch.tensor(label_list) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    tokenizer_length = len(processor.tokenizer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return batch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # load_dataset will return DatasetDict that contains all the data in the train set 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ai2d") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset = dataset_dict['train'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Comment out the following line to use the full dataset, for quick testing only use 2000 samples 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    dataset = dataset.select(range(2000)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-class VQADataCollator: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class OCRVQADataCollator: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(self, processor): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.processor = processor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -88,4 +89,4 @@ class VQADataCollator: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             images.append([image]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return tokenize_dialogs(dialogs,images, self.processor) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def get_data_collator(processor): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    return VQADataCollator(processor) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return OCRVQADataCollator(processor) 
			 |