# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import copy

from datasets import load_dataset


# Check whether the system prompt or user prompt header token sequence appears in the current token list.
def check_header(targets, seq):
    # Slide a 3-token window over seq; the last window starts at len(seq) - 3.
    for i in range(len(seq) - 2):
        if seq[i:i + 3] in targets:
            return True
    return False
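
# Example: check_header([[128006, 9125, 128007], [128006, 882, 128007]],
#                       [128000, 128006, 882, 128007])
# returns True because the user header [128006, 882, 128007] starts at index 1.
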
def replace_target(target, seq):
    # Mask every occurrence of the 3-token target sequence in place with the -100 ignore index.
    for i in range(len(seq) - 2):
        if seq[i:i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
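
# Example (illustrative ids): replace_target([128006, 78191, 128007],
#                                            [128000, 128006, 78191, 128007, 9906])
# returns [128000, -100, -100, -100, 9906]: only the 3-token assistant header is masked;
# the tokens that follow it are left untouched.
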
def tokenize_dialog(dialog, images, processor):
    # Llama 3 family models have a vocab size above 128000; use the processor's chat template
    # to render the dialog into the prompt text before tokenizing it together with the images.
    text_prompt = processor.apply_chat_template(dialog)
    batch = processor(images=images, text=text_prompt)
    dialog_tokens = batch["input_ids"].tolist()[0]
    attention_mask = batch["attention_mask"].tolist()[0]
    labels = copy.copy(dialog_tokens)
    # 128009 is the <|eot_id|> token that closes every turn.
    eot_indices = [i for i, n in enumerate(labels) if n == 128009]
    last_idx = 0
    # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007]
    # and the user prompt header "<|start_header_id|>user<|end_header_id|>" to [128006, 882, 128007].
    prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
    for idx in eot_indices:
        current_seq = labels[last_idx:idx + 1]
        if check_header(prompt_header_seqs, current_seq):
            # Found a prompt header, indicating that this segment should be masked.
            labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
        else:
            last_idx = idx + 1
    # Lastly, mask the assistant header "<|start_header_id|>assistant<|end_header_id|>",
    # which is tokenized to [128006, 78191, 128007].
    assistant_header_seq = [128006, 78191, 128007]
    labels = replace_target(assistant_header_seq, labels)
    combined_tokens = {
        "input_ids": dialog_tokens,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": batch["pixel_values"].tolist()[0],
        "image_sizes": batch["image_sizes"].tolist()[0],
    }
    return combined_tokens
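
# Roughly speaking, for a dialog of the form
#   <system turn> <|eot_id|> <user turn> <|eot_id|> <assistant turn> <|eot_id|>
# the loop above masks every segment that contains a system or user header (including its
# trailing <|eot_id|>), and replace_target then masks the assistant header itself, so the
# loss is computed only on the assistant's answer tokens and the <|eot_id|> that closes them.
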
def image_tokenize(sample, processor):
    processor.tokenizer.padding_side = "right"  # during training, one always uses padding on the right
    images, sample_text = sample["images"], sample["messages"]
    dialog = []
    for line in sample_text:
        content = []
        messages = line["content"]
        role = line["role"]
        for message in messages:
            if message["type"] == "image":
                content.append({"type": "image"})
            elif message["type"] == "text":
                content.append({"type": "text", "text": message["text"].strip()})
        dialog.append({"role": role, "content": content})
    return tokenize_dialog(dialog, images, processor)
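
# Each raw sample is expected to look roughly like
#   {"images": [<PIL.Image>, ...],
#    "messages": [{"role": "user",
#                  "content": [{"type": "image"}, {"type": "text", "text": "..."}]},
#                 {"role": "assistant",
#                  "content": [{"type": "text", "text": "..."}]}]}
# which image_tokenize converts into the chat-template dialog format consumed above.
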
def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict containing every available split of the dataset.
    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
    dataset = dataset_dict[split]
    # Subsample: keep only the first 100 examples of the requested split.
    dataset = dataset.select(range(100))
    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
    return tokenized_datasets
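

# A minimal, hypothetical usage sketch (not part of the recipe). The checkpoint name is an
# assumption: any processor whose tokenizer uses the Llama 3 header ids above and whose
# outputs include "pixel_values" and "image_sizes" (e.g. a LLaVA-NeXT model built on
# Llama 3, such as "llava-hf/llama3-llava-next-8b-hf") should fit this script.
if __name__ == "__main__":
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
    train_set = get_custom_dataset(dataset_config=None, processor=processor, split="train")
    print(train_set)
    # Peek at the keys produced for one tokenized example.
    print(list(train_set[0].keys()))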