label_script.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import os
  2. import argparse
  3. import torch
  4. from transformers import MllamaForConditionalGeneration, MllamaProcessor
  5. from tqdm.auto import tqdm
  6. import csv
  7. from PIL import Image
  8. import torch.multiprocessing as mp
  9. from concurrent.futures import ProcessPoolExecutor
  10. import shutil
  11. import time
# Prompt text sent to the model together with every image (it is placed in the
# "text" part of the user turn inside process_images). It asks for a single
# JSON-style dictionary with the keys Title, Size, Category, Gender, Type and
# Description, and for the reply to begin with the opening brace.
# NOTE(review): the schema line has a stray quote after "ENOUGH" and a missing
# opening quote before Category", and the example uses the key "size" while
# the schema uses "Size". The text is left untouched here because it is
# runtime behavior -- confirm whether downstream parsing relies on the current
# wording before normalizing it.
USER_TEXT = """
You are an expert fashion captioner, we are writing descriptions of clothes, look at the image closely and write a caption for it.
Write the following Title, Size, Category, Gender, Type, Description in JSON FORMAT, PLEASE DO NOT FORGET JSON,
ALSO START WITH THE JSON AND NOT ANY THING ELSE, FIRST CHAR IN YOUR RESPONSE IS ITS OPENING BRACE
FOLLOW THESE STEPS CLOSELY WHEN WRITING THE CAPTION:
1. Only start your response with a dictionary like the example below, nothing else, I NEED TO PARSE IT LATER, SO DONT ADD ANYTHING ELSE-IT WILL BREAK MY CODE
Remember-DO NOT SAY ANYTHING ELSE ABOUT WHAT IS GOING ON, just the opening brace is the first thing in your response nothing else ok?
2. REMEMBER TO CLOSE THE DICTIONARY WITH '}'BRACE, IT GOES AFTER THE END OF DESCRIPTION-YOU ALWAYS FORGET IT, THIS WILL CAUSE A LOT OF ISSUES
3. If you cant tell the size from image, guess it! its okay but dont literally write that you guessed it
4. Do not make the caption very literal, all of these are product photos, DO NOT CAPTION HOW OR WHERE THEY ARE PLACED, FOCUS ON WRITING ABOUT THE PIECE OF CLOTHING
5. BE CREATIVE WITH THE DESCRIPTION BUT FOLLOW EVERYTHING CLOSELY FOR STRUCTURE
6. Return your answer in dictionary format, see the example below
{"Title": "Title of item of clothing", "Size": {'S', 'M', 'L', 'XL'}, #select one randomly if you cant tell from the image. DO NOT TELL ME YOU ESTIMATE OR GUESSED IT ONLY THE LETTER IS ENOUGH", Category": {T-Shirt, Shoes, Tops, Pants, Jeans, Shorts, Skirts, Shoes, Footwear}, "Gender": {M, F, U}, "Type": {Casual, Formal, Work Casual, Lounge}, "Description": "Write it here"}
Example: ALWAYS RETURN ANSWERS IN THE DICTIONARY FORMAT BELOW OK?
{"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"}
"""
  28. def is_image_corrupt(image_path):
  29. try:
  30. with Image.open(image_path) as img:
  31. img.verify()
  32. return False
  33. except (IOError, SyntaxError, Image.UnidentifiedImageError):
  34. return True
  35. def find_and_move_corrupt_images(folder_path, corrupt_folder):
  36. image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path)
  37. if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
  38. num_cores = mp.cpu_count()
  39. with tqdm(total=len(image_files), desc="Checking for corrupt images", unit="file",
  40. bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
  41. with ProcessPoolExecutor(max_workers=num_cores) as executor:
  42. results = list(executor.map(is_image_corrupt, image_files))
  43. pbar.update(len(image_files))
  44. corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
  45. os.makedirs(corrupt_folder, exist_ok=True)
  46. for img in tqdm(corrupt_images, desc="Moving corrupt images", unit="file",
  47. bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"):
  48. shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
  49. print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
  50. def get_image(image_path):
  51. return Image.open(image_path).convert('RGB')
  52. def llama_progress_bar(total, desc, position=0):
  53. """Custom progress bar with llama emojis."""
  54. bar_format = "{desc}: |{bar}| {percentage:3.0f}% [{n_fmt}/{total_fmt}, {rate_fmt}{postfix}]"
  55. return tqdm(total=total, desc=desc, position=position, bar_format=bar_format, ascii="🦙·")
  56. def process_images(rank, world_size, args, model_name, input_files, output_csv):
  57. model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map=f"cuda:{rank}", torch_dtype=torch.bfloat16, token=args.hf_token)
  58. processor = MllamaProcessor.from_pretrained(model_name, token=args.hf_token)
  59. chunk_size = len(input_files) // world_size
  60. start_idx = rank * chunk_size
  61. end_idx = start_idx + chunk_size if rank < world_size - 1 else len(input_files)
  62. results = []
  63. pbar = llama_progress_bar(total=end_idx - start_idx, desc=f"GPU {rank}", position=rank)
  64. for filename in input_files[start_idx:end_idx]:
  65. image_path = os.path.join(args.input_path, filename)
  66. image = get_image(image_path)
  67. conversation = [
  68. {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_TEXT}]}
  69. ]
  70. prompt = processor.apply_chat_template(conversation, add_special_tokens=False, add_generation_prompt=True, tokenize=False)
  71. inputs = processor(image, prompt, return_tensors="pt").to(model.device)
  72. output = model.generate(**inputs, temperature=1, top_p=0.9, max_new_tokens=512)
  73. decoded_output = processor.decode(output[0])[len(prompt):]
  74. results.append((filename, decoded_output))
  75. pbar.update(1)
  76. pbar.set_postfix({"Last File": filename})
  77. pbar.close()
  78. with open(output_csv, 'w', newline='', encoding='utf-8') as f:
  79. writer = csv.writer(f)
  80. writer.writerow(['Filename', 'Caption'])
  81. writer.writerows(results)
  82. def main():
  83. parser = argparse.ArgumentParser(description="Multi-GPU Image Captioning")
  84. parser.add_argument("--hf_token", required=True, help="Hugging Face API token")
  85. parser.add_argument("--input_path", required=True, help="Path to input image folder")
  86. parser.add_argument("--output_path", required=True, help="Path to output CSV folder")
  87. parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
  88. parser.add_argument("--corrupt_folder", default="corrupt_images", help="Folder to move corrupt images")
  89. args = parser.parse_args()
  90. model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
  91. print("🦙 Starting image processing pipeline...")
  92. start_time = time.time()
  93. # Find and move corrupt images
  94. corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
  95. find_and_move_corrupt_images(args.input_path, corrupt_folder)
  96. # Get list of remaining (non-corrupt) image files
  97. input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
  98. print(f"\n🦙 Processing {len(input_files)} images using {args.num_gpus} GPUs...")
  99. mp.set_start_method('spawn', force=True)
  100. processes = []
  101. for rank in range(args.num_gpus):
  102. output_csv = os.path.join(args.output_path, f"captions_gpu_{rank}.csv")
  103. p = mp.Process(target=process_images, args=(rank, args.num_gpus, args, model_name, input_files, output_csv))
  104. p.start()
  105. processes.append(p)
  106. for p in processes:
  107. p.join()
  108. end_time = time.time()
  109. total_time = end_time - start_time
  110. print(f"\n🦙 Total processing time: {total_time:.2f} seconds")
  111. print("🦙 Image captioning completed successfully!")
# Run the pipeline only when this file is executed as a script.
# NOTE(review): main() selects the 'spawn' start method; spawned workers
# presumably re-import this module, in which case this guard is also what
# keeps them from re-running main() -- confirm against the multiprocessing docs.
if __name__ == "__main__":
    main()