# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.

import copy
import itertools

import torch
from datasets import load_dataset

# Label value for tokens that must not contribute to the loss
# (the default ``ignore_index`` of torch's cross-entropy loss).
_IGNORE_INDEX = -100
# <|eot_id|>: terminates every system / user / assistant turn.
_EOT_TOKEN_ID = 128009
# Image placeholder token inserted by the processor.
_IMAGE_TOKEN_ID = 128256
# "<|start_header_id|>system<|end_header_id|>" and
# "<|start_header_id|>user<|end_header_id|>" as tokenized by the Llama 3 tokenizer.
_PROMPT_HEADER_SEQS = [[128006, 9125, 128007], [128006, 882, 128007]]
# "<|start_header_id|>assistant<|end_header_id|>" as tokenized by the Llama 3 tokenizer.
_ASSISTANT_HEADER_SEQ = [128006, 78191, 128007]


# check system prompt token seq or user prompt token seq is in the current token list
def check_header(targets, seq):
    """Return True if any token sequence in ``targets`` occurs contiguously in ``seq``.

    Args:
        targets: list of token-id lists to search for (e.g. prompt headers).
        seq: list of token ids to scan.
    """
    # Every window position is checked, including the one ending at the last
    # token (the original ``range(len(seq) - 3)`` skipped it).
    return any(
        seq[i : i + len(target)] == target
        for target in targets
        for i in range(len(seq) - len(target) + 1)
    )


def replace_target(target, seq):
    """Overwrite every occurrence of ``target`` in ``seq`` with -100, in place.

    Args:
        target: token-id list to mask (e.g. the assistant header).
        seq: list of token ids; it is mutated.

    Returns:
        The same ``seq`` object, for call-chaining convenience.
    """
    width = len(target)
    # ``+ 1`` so that a match ending on the last token is also masked.
    for i in range(len(seq) - width + 1):
        if seq[i : i + width] == target:
            seq[i : i + width] = [_IGNORE_INDEX] * width
    return seq


def _build_labels(dialog_tokens, pad_token_id):
    """Return a copy of one sample's token ids with non-target tokens set to -100.

    Masked: every turn containing a system/user header, every assistant header,
    padding tokens and image tokens. What remains are the assistant responses
    (including their closing <|eot_id|>).
    """
    labels = list(dialog_tokens)

    # Each turn ends with <|eot_id|>; mask the whole turn when it is a
    # system or user turn.
    eot_indices = [i for i, tok in enumerate(labels) if tok == _EOT_TOKEN_ID]
    turn_start = 0
    for eot_idx in eot_indices:
        turn = labels[turn_start : eot_idx + 1]
        if check_header(_PROMPT_HEADER_SEQS, turn):
            labels[turn_start : eot_idx + 1] = [_IGNORE_INDEX] * len(turn)
        turn_start = eot_idx + 1

    # Assistant headers are template boilerplate, not content to learn.
    replace_target(_ASSISTANT_HEADER_SEQ, labels)

    # Padding and image placeholder tokens never carry a target.
    # (pad_token_id may be None; no int token compares equal to it.)
    masked_ids = {pad_token_id, _IMAGE_TOKEN_ID}
    return [_IGNORE_INDEX if tok in masked_ids else tok for tok in labels]


def tokenize_dialogs(dialogs, images, processor):
    """Tokenize a batch of chat dialogs with their images and attach training labels.

    Args:
        dialogs: list of chat-message lists, one per sample, in the format
            accepted by ``processor.apply_chat_template``.
        images: list of per-sample image lists, parallel to ``dialogs``.
        processor: multimodal processor exposing ``apply_chat_template``,
            ``__call__`` and ``tokenizer.pad_token_id``.

    Returns:
        The processor's batch with an added ``"labels"`` tensor built from
        ``batch["input_ids"]``, where only assistant response tokens keep
        their ids and everything else is -100.
    """
    text_prompt = processor.apply_chat_template(dialogs)
    # NOTE(review): presumably the processor inserts its own BOS when
    # tokenizing, so the template's copy is stripped to avoid a duplicate -- confirm.
    text_prompt = [prompt.replace("<|begin_of_text|>", "") for prompt in text_prompt]
    batch = processor(
        images=images,
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )
    pad_token_id = processor.tokenizer.pad_token_id
    label_list = [
        _build_labels(input_ids.tolist(), pad_token_id)
        for input_ids in batch["input_ids"]
    ]
    batch["labels"] = torch.tensor(label_list)
    return batch


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    """Load the OCR-VQA subset of HuggingFaceM4/the_cauldron and return one split.

    Args:
        dataset_config: not used here (presumably required by the caller's
            dataset-loader interface).
        processor: not used here.
        split: key of the ``train_test_split`` result, i.e. ``"train"`` or ``"test"``.
        split_ratio: fraction of the samples assigned to the train split.

    Returns:
        The requested ``datasets.Dataset`` split.
    """
    # load_dataset returns a DatasetDict; the_cauldron only ships a train set.
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # Quick-testing subset: remove the following line to use the full dataset.
    # Bounded by the dataset size so select() never asks for missing rows.
    dataset = dataset.select(range(min(2000, len(dataset))))
    dataset = dataset.train_test_split(
        test_size=1 - split_ratio, shuffle=True, seed=42
    )[split]
    return dataset


class OCRVQADataCollator:
    """Collate raw the_cauldron OCR-VQA samples into a tokenized, labeled batch."""

    def __init__(self, processor):
        self.processor = processor
        # During training, one always uses padding on the right.
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        """Build chat dialogs from ``samples`` and tokenize them.

        Each sample must have an ``"images"`` list with exactly one image and a
        ``"texts"`` list of ``{"user": ..., "assistant": ...}`` turns.

        Raises:
            ValueError: if a sample has no image or more than one image.
        """
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if not image_list:
                raise ValueError("Each sample must contain exactly one image")
            if len(image_list) > 1:
                raise ValueError("Only support one image per sample")
            image = image_list[0].convert("RGB")
            dialogs.append(self._build_dialog(sample_list))
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)

    @staticmethod
    def _build_dialog(sample_list):
        """Convert ``{"user", "assistant"}`` turns into alternating chat messages.

        The image placeholder is attached only to the first user message.
        """
        dialog = []
        for turn_idx, turn in enumerate(sample_list):
            user_content = [{"type": "text", "text": turn["user"].strip()}]
            if turn_idx == 0:
                user_content.insert(0, {"type": "image"})
            dialog.append({"role": "user", "content": user_content})
            dialog.append(
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": turn["assistant"].strip()}
                    ],
                }
            )
        return dialog


def get_data_collator(processor):
    """Return an ``OCRVQADataCollator`` bound to ``processor``."""
    return OCRVQADataCollator(processor)