Browse Source

Update label_script.py

Improve TQDM
Sanyam Bhutani 8 months ago
parent
commit
4a297e5616
1 changed files with 29 additions and 5 deletions
  1. 29 5
      recipes/quickstart/Multi-Modal-RAG/label_script.py

+ 29 - 5
recipes/quickstart/Multi-Modal-RAG/label_script.py

@@ -27,12 +27,13 @@ import os
 import argparse
 import torch
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from tqdm import tqdm
+from tqdm.auto import tqdm
 import csv
 from PIL import Image
 import torch.multiprocessing as mp
 from concurrent.futures import ProcessPoolExecutor
 import shutil
+import time
 
 def is_image_corrupt(image_path):
     try:
@@ -47,13 +48,17 @@ def find_and_move_corrupt_images(folder_path, corrupt_folder):
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
     
     num_cores = mp.cpu_count()
-    with ProcessPoolExecutor(max_workers=num_cores) as executor:
-        results = executor.map(is_image_corrupt, image_files)
+    with tqdm(total=len(image_files), desc="Checking for corrupt images", unit="file", 
+              bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
+        with ProcessPoolExecutor(max_workers=num_cores) as executor:
+            results = list(executor.map(is_image_corrupt, image_files))
+            pbar.update(len(image_files))
     
     corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
     
     os.makedirs(corrupt_folder, exist_ok=True)
-    for img in corrupt_images:
+    for img in tqdm(corrupt_images, desc="Moving corrupt images", unit="file",
+                    bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"):
         shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
     
     print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
@@ -71,7 +76,14 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
     
     results = []
     
-    for filename in tqdm(input_files[start_idx:end_idx], desc=f"GPU {rank} processing", position=rank):
+    pbar = tqdm(input_files[start_idx:end_idx], 
+                desc=f"GPU {rank}", 
+                unit="img",
+                position=rank,
+                leave=True,
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
+    
+    for filename in pbar:
         image_path = os.path.join(args.input_path, filename)
         image = get_image(image_path)
 
@@ -86,6 +98,8 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
         decoded_output = processor.decode(output[0])[len(prompt):]
 
         results.append((filename, decoded_output))
+        
+        pbar.set_postfix({"Last File": filename})
 
     with open(output_csv, 'w', newline='', encoding='utf-8') as f:
         writer = csv.writer(f)
@@ -103,6 +117,9 @@ def main():
 
     model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
 
+    print("Starting image processing pipeline...")
+    start_time = time.time()
+
     # Find and move corrupt images
     corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
     find_and_move_corrupt_images(args.input_path, corrupt_folder)
@@ -110,6 +127,8 @@ def main():
     # Get list of remaining (non-corrupt) image files
     input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
 
+    print(f"\nProcessing {len(input_files)} images using {args.num_gpus} GPUs...")
+
     mp.set_start_method('spawn', force=True)
     processes = []
 
@@ -122,5 +141,10 @@ def main():
     for p in processes:
         p.join()
 
+    end_time = time.time()
+    total_time = end_time - start_time
+    print(f"\nTotal processing time: {total_time:.2f} seconds")
+    print("Image captioning completed successfully!")
+
 if __name__ == "__main__":
     main()