label_script.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. import argparse
  3. import torch
  4. from transformers import MllamaForConditionalGeneration, MllamaProcessor
  5. from tqdm import tqdm
  6. import csv
  7. from PIL import Image
  8. import torch.multiprocessing as mp
  9. USER_TEXT = """
  10. You are an expert fashion captioner, we are writing descriptions of clothes, look at the image closely and write a caption for it.
  11. Write the following Title, Size, Category, Gender, Type, Description in JSON FORMAT, PLEASE DO NOT FORGET JSON, I WILL BE VERY SAD AND CRY
  12. ALSO START WITH THE JSON AND NOT ANY THING ELSE, FIRST CHAR IN YOUR RESPONSE IS ITS OPENING BRACE, I WILL DRINK CHAI IF YOU FOLLOW THIS
  13. FOLLOW THESE STEPS CLOSELY WHEN WRITING THE CAPTION:
  14. 1. Only start your response with a dictionary like the example below, nothing else, I NEED TO PARSE IT LATER, SO DONT ADD ANYTHING ELSE-IT WILL BREAK MY CODE AND I WILL BE VERY SAD
  15. Remember-DO NOT SAY ANYTHING ELSE ABOUT WHAT IS GOING ON, just the opening brace is the first thing in your response nothing else ok?
  16. 2. REMEMBER TO CLOSE THE DICTIONARY WITH '}'BRACE, IT GOES AFTER THE END OF DESCRIPTION-YOU ALWAYS FORGET IT, THIS WILL CAUSE A FIRE ON A PRODUCTION SERVER BEING USE BY MILLIONS
  17. 3. If you cant tell the size from image, guess it! its okay but dont literally write that you guessed it
  18. 4. Do not make the caption very literal, all of these are product photos, DO NOT CAPTION HOW OR WHERE THEY ARE PLACED, FOCUS ON WRITING ABOUT THE PIECE OF CLOTHING
  19. 5. BE CREATIVE WITH THE DESCRIPTION BUT FOLLOW EVERYTHING CLOSELY FOR STRUCTURE
  20. 6. Return your answer in dictionary format, see the example below
  21. 7. Please do NOT add new lines or tabs in the JSON
  22. 8. I REPEAT DO NOT GIVE ME YOUR EXPLAINATION START WITH THE JSON
  23. {"Title": "Title of item of clothing", "Size": {'S', 'M', 'L', 'XL'}, #select one randomly if you cant tell from the image. DO NOT TELL ME YOU ESTIMATE OR GUESSED IT ONLY THE LETTER IS ENOUGH", Category": {T-Shirt, Shoes, Tops, Pants, Jeans, Shorts, Skirts, Shoes, Footwear}, "Gender": {M, F, U}, "Type": {Casual, Formal, Work Casual, Lounge}, "Description": "Write it here"}
  24. Example: ALWAYS RETURN ANSWERS IN THE DICTIONARY FORMAT BELOW OK?
  25. {"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"}
  26. """
  27. def get_image(image_path):
  28. return Image.open(image_path).convert('RGB')
  29. def process_images(rank, world_size, args, model_name, input_files, output_csv):
  30. # Set up the model and processor for this GPU
  31. model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map=f"cuda:{rank}", torch_dtype=torch.bfloat16, token=args.hf_token)
  32. processor = MllamaProcessor.from_pretrained(model_name, token=args.hf_token)
  33. # Calculate the chunk of files this GPU will process
  34. chunk_size = len(input_files) // world_size
  35. start_idx = rank * chunk_size
  36. end_idx = start_idx + chunk_size if rank < world_size - 1 else len(input_files)
  37. results = []
  38. # Process files with TQDM
  39. for filename in tqdm(input_files[start_idx:end_idx], desc=f"GPU {rank} processing", position=rank):
  40. image_path = os.path.join(args.input_path, filename)
  41. image = get_image(image_path)
  42. conversation = [
  43. {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_TEXT}]}
  44. ]
  45. prompt = processor.apply_chat_template(conversation, add_special_tokens=False, add_generation_prompt=True, tokenize=False)
  46. inputs = processor(image, prompt, return_tensors="pt").to(model.device)
  47. output = model.generate(**inputs, temperature=1, top_p=0.9, max_new_tokens=512)
  48. decoded_output = processor.decode(output[0])[len(prompt):]
  49. results.append((filename, decoded_output))
  50. # Write results to CSV
  51. with open(output_csv, 'w', newline='', encoding='utf-8') as f:
  52. writer = csv.writer(f)
  53. writer.writerow(['Filename', 'Caption'])
  54. writer.writerows(results)
  55. def main():
  56. parser = argparse.ArgumentParser(description="Multi-GPU Image Captioning")
  57. parser.add_argument("--hf_token", required=True, help="Hugging Face API token")
  58. parser.add_argument("--input_path", required=True, help="Path to input image folder")
  59. parser.add_argument("--output_path", required=True, help="Path to output CSV folder")
  60. parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
  61. args = parser.parse_args()
  62. model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
  63. # Get list of image files
  64. input_files = [f for f in os.listdir(args.input_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
  65. # Set up multi-processing
  66. mp.set_start_method('spawn', force=True)
  67. processes = []
  68. for rank in range(args.num_gpus):
  69. output_csv = os.path.join(args.output_path, f"captions_gpu_{rank}.csv")
  70. p = mp.Process(target=process_images, args=(rank, args.num_gpus, args, model_name, input_files, output_csv))
  71. p.start()
  72. processes.append(p)
  73. for p in processes:
  74. p.join()
  75. if __name__ == "__main__":
  76. main()