|
@@ -1,3 +1,15 @@
|
|
|
+import os
|
|
|
+import argparse
|
|
|
+import torch
|
|
|
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
|
|
|
+from tqdm.auto import tqdm
|
|
|
+import csv
|
|
|
+from PIL import Image
|
|
|
+import torch.multiprocessing as mp
|
|
|
+from concurrent.futures import ProcessPoolExecutor
|
|
|
+import shutil
|
|
|
+import time
|
|
|
+
|
|
|
USER_TEXT = """
|
|
|
You are an expert fashion captioner, we are writing descriptions of clothes, look at the image closely and write a caption for it.
|
|
|
|
|
@@ -23,18 +35,6 @@ Example: ALWAYS RETURN ANSWERS IN THE DICTIONARY FORMAT BELOW OK?
|
|
|
{"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"}
|
|
|
"""
|
|
|
|
|
|
-import os
|
|
|
-import argparse
|
|
|
-import torch
|
|
|
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
|
|
|
-from tqdm.auto import tqdm
|
|
|
-import csv
|
|
|
-from PIL import Image
|
|
|
-import torch.multiprocessing as mp
|
|
|
-from concurrent.futures import ProcessPoolExecutor
|
|
|
-import shutil
|
|
|
-import time
|
|
|
-
|
|
|
def is_image_corrupt(image_path):
|
|
|
try:
|
|
|
with Image.open(image_path) as img:
|
|
@@ -66,6 +66,11 @@ def find_and_move_corrupt_images(folder_path, corrupt_folder):
|
|
|
def get_image(image_path):
    """Load the image at ``image_path`` and return it as an RGB image.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        Whatever ``Image.open`` or ``Image.convert`` raise for a missing,
        unreadable or unidentifiable file; nothing is caught here.
    """
    # Pillow opens files lazily and can keep the underlying file handle
    # open until the image object is closed. ``convert`` returns a separate
    # image, so the source is closed as soon as the conversion is done
    # instead of waiting for garbage collection.
    with Image.open(image_path) as source:
        return source.convert('RGB')
|
|
|
|
|
|
def llama_progress_bar(total, desc, position=0):
    """Create a tqdm progress bar that is drawn with llama emoji.

    Args:
        total: Number of expected updates, forwarded to tqdm as ``total``.
        desc: Label shown in front of the bar.
        position: Row offset forwarded to tqdm as ``position`` (default 0).

    Returns:
        The tqdm instance; the caller advances and closes it.
    """
    layout = "{desc}: |{bar}| {percentage:3.0f}% [{n_fmt}/{total_fmt}, {rate_fmt}{postfix}]"
    # NOTE(review): tqdm appears to use the first character of a custom
    # ``ascii`` string for unfilled cells and the last for filled ones, which
    # would make the bar start as llamas and fill with dots - confirm that
    # this is the intended look.
    bar_options = {
        "total": total,
        "desc": desc,
        "position": position,
        "bar_format": layout,
        "ascii": "🦙·",
    }
    return tqdm(**bar_options)
|
|
|
+
|
|
|
def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map=f"cuda:{rank}", torch_dtype=torch.bfloat16, token=args.hf_token)
|
|
|
processor = MllamaProcessor.from_pretrained(model_name, token=args.hf_token)
|
|
@@ -76,14 +81,9 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
|
|
|
results = []
|
|
|
|
|
|
- pbar = tqdm(input_files[start_idx:end_idx],
|
|
|
- desc=f"GPU {rank}",
|
|
|
- unit="img",
|
|
|
- position=rank,
|
|
|
- leave=True,
|
|
|
- bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
|
|
|
+ pbar = llama_progress_bar(total=end_idx - start_idx, desc=f"GPU {rank}", position=rank)
|
|
|
|
|
|
- for filename in pbar:
|
|
|
+ for filename in input_files[start_idx:end_idx]:
|
|
|
image_path = os.path.join(args.input_path, filename)
|
|
|
image = get_image(image_path)
|
|
|
|
|
@@ -99,8 +99,11 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
|
|
|
results.append((filename, decoded_output))
|
|
|
|
|
|
+ pbar.update(1)
|
|
|
pbar.set_postfix({"Last File": filename})
|
|
|
|
|
|
+ pbar.close()
|
|
|
+
|
|
|
with open(output_csv, 'w', newline='', encoding='utf-8') as f:
|
|
|
writer = csv.writer(f)
|
|
|
writer.writerow(['Filename', 'Caption'])
|
|
@@ -117,7 +120,7 @@ def main():
|
|
|
|
|
|
model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
|
|
|
|
|
|
- print("Starting image processing pipeline...")
|
|
|
+ print("🦙 Starting image processing pipeline...")
|
|
|
start_time = time.time()
|
|
|
|
|
|
# Find and move corrupt images
|
|
@@ -127,7 +130,7 @@ def main():
|
|
|
# Get list of remaining (non-corrupt) image files
|
|
|
input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
|
|
|
|
|
|
- print(f"\nProcessing {len(input_files)} images using {args.num_gpus} GPUs...")
|
|
|
+ print(f"\n🦙 Processing {len(input_files)} images using {args.num_gpus} GPUs...")
|
|
|
|
|
|
mp.set_start_method('spawn', force=True)
|
|
|
processes = []
|
|
@@ -143,8 +146,8 @@ def main():
|
|
|
|
|
|
end_time = time.time()
|
|
|
total_time = end_time - start_time
|
|
|
- print(f"\nTotal processing time: {total_time:.2f} seconds")
|
|
|
- print("Image captioning completed successfully!")
|
|
|
+ print(f"\n🦙 Total processing time: {total_time:.2f} seconds")
|
|
|
+ print("🦙 Image captioning completed successfully!")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main()
|