# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.

import copy

import torch
from datasets import load_dataset
 
# Check whether a system-prompt or user-prompt header token sequence appears in
# the given token list.
def check_header(targets, seq):
    # A 3-token window can start at any index from 0 to len(seq) - 3 inclusive.
    for i in range(len(seq) - 2):
        if seq[i : i + 3] in targets:
            return True
    return False
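
# Minimal sanity checks for check_header (illustrative; pure Python and cheap
# to run at import time). The ids are the user-header tokens documented in
# tokenize_dialogs below; the first case places the header flush against the
# end of the sequence, which the window bound above is chosen to cover.
assert check_header([[128006, 882, 128007]], [128000, 128006, 882, 128007])
assert not check_header([[128006, 882, 128007]], [128000, 9125, 128007])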
 
# Replace every occurrence of the 3-token target sequence in seq with -100 so
# those positions are ignored by the cross-entropy loss.
def replace_target(target, seq):
    for i in range(len(seq) - 2):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
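
# Illustrative check: the assistant-header ids are masked in place while the
# following token is left untouched.
assert replace_target([128006, 78191, 128007], [128006, 78191, 128007, 15]) == [-100, -100, -100, 15]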
 
def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    # Strip <|begin_of_text|> from the template output; the processor adds it
    # again during tokenization.
    text_prompt = [prompt.replace("<|begin_of_text|>", "") for prompt in text_prompt]
    batch = processor(
        images=images,
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        # 128009 is <|eot_id|>, which terminates every turn.
        eot_indices = [k for k, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007];
        # the user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx : idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a prompt header, so this turn should be masked out of the loss.
                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
        # which is tokenized to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask the padding token and the image token (id 128256).
        for j in range(len(labels)):
            if labels[j] == processor.tokenizer.pad_token_id or labels[j] == 128256:
                labels[j] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch
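
# A minimal usage sketch for tokenize_dialogs, kept as an unused helper so that
# importing this module stays side-effect free. The checkpoint name is an
# assumption (any Llama 3.2 vision processor should work), and the model is
# gated on the Hub, so access must be granted first.
def _example_tokenize_dialogs():
    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed checkpoint
    )
    dialogs = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is the title of this book?"},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "The title is not visible."}],
            },
        ]
    ]
    images = [[Image.new("RGB", (560, 560))]]  # blank stand-in image
    batch = tokenize_dialogs(dialogs, images, processor)
    # Every position except the assistant answer (and its <|eot_id|>) is -100.
    print(batch["input_ids"].shape, batch["labels"].shape)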
 
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict whose data lives entirely in the "train" split.
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # For quick testing, keep only 2000 samples; comment out the following line to use the full dataset.
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(
        test_size=1 - split_ratio, shuffle=True, seed=42
    )[split]
    return dataset
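
# Illustrative helper (not called anywhere): with the defaults above, the 2000
# selected samples split 1800/200. Note that this downloads the ocrvqa subset
# from the Hub on first use; dataset_config and processor are unused by
# get_custom_dataset, so None is fine here.
def _example_get_custom_dataset():
    train_ds = get_custom_dataset(None, None, "train")
    print(len(train_ds))  # 1800 with split_ratio=0.9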
 
class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        # During training, always pad on the right so labels stay aligned with inputs.
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only one image per sample is supported")
            image = image_list[0].convert("RGB")
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # Attach the image only to the first user turn.
                    dialog += [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image"},
                                {"type": "text", "text": sample_dict["user"].strip()},
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": sample_dict["assistant"].strip(),
                                }
                            ],
                        },
                    ]
                else:
                    dialog += [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": sample_dict["user"].strip()}
                            ],
                        },
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": sample_dict["assistant"].strip(),
                                }
                            ],
                        },
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)


def get_data_collator(processor):
    return OCRVQADataCollator(processor)
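
# End-to-end sketch: dataset -> collator -> one padded, masked batch. Guarded
# so importing this module never triggers the downloads. The checkpoint name is
# an assumption; substitute the Llama 3.2 vision model you are fine-tuning.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed checkpoint
    )
    train_ds = get_custom_dataset(None, processor, "train")
    loader = DataLoader(
        train_ds, batch_size=2, collate_fn=get_data_collator(processor)
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})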
 
 