
Update label_script.py

Added llama emojis to the progress bar and status messages
Sanyam Bhutani 7 months ago
commit 7a9e595d10
1 changed file with 26 additions and 23 deletions

+ 26 - 23
recipes/quickstart/Multi-Modal-RAG/label_script.py

@@ -1,3 +1,15 @@
+import os
+import argparse
+import torch
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from tqdm.auto import tqdm
+import csv
+from PIL import Image
+import torch.multiprocessing as mp
+from concurrent.futures import ProcessPoolExecutor
+import shutil
+import time
+
 USER_TEXT = """
 You are an expert fashion captioner, we are writing descriptions of clothes, look at the image closely and write a caption for it.
 
@@ -23,18 +35,6 @@ Example: ALWAYS RETURN ANSWERS IN THE DICTIONARY FORMAT BELOW OK?
 {"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"} 
 """
 
-import os
-import argparse
-import torch
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from tqdm.auto import tqdm
-import csv
-from PIL import Image
-import torch.multiprocessing as mp
-from concurrent.futures import ProcessPoolExecutor
-import shutil
-import time
-
 def is_image_corrupt(image_path):
     try:
         with Image.open(image_path) as img:
@@ -66,6 +66,11 @@ def find_and_move_corrupt_images(folder_path, corrupt_folder):
 def get_image(image_path):
     return Image.open(image_path).convert('RGB')
 
+def llama_progress_bar(total, desc, position=0):
+    """Custom progress bar with llama emojis."""
+    bar_format = "{desc}: |{bar}| {percentage:3.0f}% [{n_fmt}/{total_fmt}, {rate_fmt}{postfix}]"
+    return tqdm(total=total, desc=desc, position=position, bar_format=bar_format, ascii="🦙·")
+
 def process_images(rank, world_size, args, model_name, input_files, output_csv):
     model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map=f"cuda:{rank}", torch_dtype=torch.bfloat16, token=args.hf_token)
     processor = MllamaProcessor.from_pretrained(model_name, token=args.hf_token)
@@ -76,14 +81,9 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
     
     results = []
     
-    pbar = tqdm(input_files[start_idx:end_idx], 
-                desc=f"GPU {rank}", 
-                unit="img",
-                position=rank,
-                leave=True,
-                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
+    pbar = llama_progress_bar(total=end_idx - start_idx, desc=f"GPU {rank}", position=rank)
     
-    for filename in pbar:
+    for filename in input_files[start_idx:end_idx]:
         image_path = os.path.join(args.input_path, filename)
         image = get_image(image_path)
 
@@ -99,8 +99,11 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
 
         results.append((filename, decoded_output))
         
+        pbar.update(1)
         pbar.set_postfix({"Last File": filename})
 
+    pbar.close()
+
     with open(output_csv, 'w', newline='', encoding='utf-8') as f:
         writer = csv.writer(f)
         writer.writerow(['Filename', 'Caption'])
@@ -117,7 +120,7 @@ def main():
 
     model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
 
-    print("Starting image processing pipeline...")
+    print("🦙 Starting image processing pipeline...")
     start_time = time.time()
 
     # Find and move corrupt images
@@ -127,7 +130,7 @@ def main():
     # Get list of remaining (non-corrupt) image files
     input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
 
-    print(f"\nProcessing {len(input_files)} images using {args.num_gpus} GPUs...")
+    print(f"\n🦙 Processing {len(input_files)} images using {args.num_gpus} GPUs...")
 
     mp.set_start_method('spawn', force=True)
     processes = []
@@ -143,8 +146,8 @@ def main():
 
     end_time = time.time()
     total_time = end_time - start_time
-    print(f"\nTotal processing time: {total_time:.2f} seconds")
-    print("Image captioning completed successfully!")
+    print(f"\n🦙 Total processing time: {total_time:.2f} seconds")
+    print("🦙 Image captioning completed successfully!")
 
 if __name__ == "__main__":
     main()
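
For context, here is a minimal standalone sketch of the llama-themed progress bar introduced in this commit. The helper mirrors llama_progress_bar from label_script.py; the driver loop, the sleep call, and the file names are illustrative placeholders, not part of the actual script:

from tqdm.auto import tqdm
import time

def llama_progress_bar(total, desc, position=0):
    """Custom progress bar with llama emojis (same helper added in this commit)."""
    bar_format = "{desc}: |{bar}| {percentage:3.0f}% [{n_fmt}/{total_fmt}, {rate_fmt}{postfix}]"
    return tqdm(total=total, desc=desc, position=position, bar_format=bar_format, ascii="🦙·")

# Illustrative driver loop: sleep stands in for per-image model inference.
pbar = llama_progress_bar(total=5, desc="GPU 0")
for i in range(5):
    time.sleep(0.2)
    pbar.update(1)
    pbar.set_postfix({"Last File": f"img_{i:03d}.jpg"})
pbar.close()

Passing a string to tqdm's ascii argument tells tqdm to draw the bar with those characters instead of the default block glyphs, which is what produces the llama-filled bar while keeping the rest of the tqdm formatting intact.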