Browse Source

Update label_script.py

Improve TQDM
Sanyam Bhutani 8 months ago
parent
commit
4a297e5616
1 changed files with 29 additions and 5 deletions
  1. 29 5
      recipes/quickstart/Multi-Modal-RAG/label_script.py

+ 29 - 5
recipes/quickstart/Multi-Modal-RAG/label_script.py

@@ -27,12 +27,13 @@ import os
 import argparse
 import argparse
 import torch
 import torch
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from tqdm import tqdm
+from tqdm.auto import tqdm
 import csv
 import csv
 from PIL import Image
 from PIL import Image
 import torch.multiprocessing as mp
 import torch.multiprocessing as mp
 from concurrent.futures import ProcessPoolExecutor
 from concurrent.futures import ProcessPoolExecutor
 import shutil
 import shutil
+import time
 
 
 def is_image_corrupt(image_path):
 def is_image_corrupt(image_path):
     try:
     try:
@@ -47,13 +48,17 @@ def find_and_move_corrupt_images(folder_path, corrupt_folder):
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
     
     
     num_cores = mp.cpu_count()
     num_cores = mp.cpu_count()
-    with ProcessPoolExecutor(max_workers=num_cores) as executor:
-        results = executor.map(is_image_corrupt, image_files)
+    with tqdm(total=len(image_files), desc="Checking for corrupt images", unit="file", 
+              bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
+        with ProcessPoolExecutor(max_workers=num_cores) as executor:
+            results = list(executor.map(is_image_corrupt, image_files))
+            pbar.update(len(image_files))
     
     
     corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
     corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
     
     
     os.makedirs(corrupt_folder, exist_ok=True)
     os.makedirs(corrupt_folder, exist_ok=True)
-    for img in corrupt_images:
+    for img in tqdm(corrupt_images, desc="Moving corrupt images", unit="file",
+                    bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"):
         shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
         shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
     
     
     print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
     print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
@@ -71,7 +76,14 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
     
     
     results = []
     results = []
     
     
-    for filename in tqdm(input_files[start_idx:end_idx], desc=f"GPU {rank} processing", position=rank):
+    pbar = tqdm(input_files[start_idx:end_idx], 
+                desc=f"GPU {rank}", 
+                unit="img",
+                position=rank,
+                leave=True,
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
+    
+    for filename in pbar:
         image_path = os.path.join(args.input_path, filename)
         image_path = os.path.join(args.input_path, filename)
         image = get_image(image_path)
         image = get_image(image_path)
 
 
@@ -86,6 +98,8 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
         decoded_output = processor.decode(output[0])[len(prompt):]
         decoded_output = processor.decode(output[0])[len(prompt):]
 
 
         results.append((filename, decoded_output))
         results.append((filename, decoded_output))
+        
+        pbar.set_postfix({"Last File": filename})
 
 
     with open(output_csv, 'w', newline='', encoding='utf-8') as f:
     with open(output_csv, 'w', newline='', encoding='utf-8') as f:
         writer = csv.writer(f)
         writer = csv.writer(f)
@@ -103,6 +117,9 @@ def main():
 
 
     model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
     model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
 
 
+    print("Starting image processing pipeline...")
+    start_time = time.time()
+
     # Find and move corrupt images
     # Find and move corrupt images
     corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
     corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
     find_and_move_corrupt_images(args.input_path, corrupt_folder)
     find_and_move_corrupt_images(args.input_path, corrupt_folder)
@@ -110,6 +127,8 @@ def main():
     # Get list of remaining (non-corrupt) image files
     # Get list of remaining (non-corrupt) image files
     input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
     input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
 
 
+    print(f"\nProcessing {len(input_files)} images using {args.num_gpus} GPUs...")
+
     mp.set_start_method('spawn', force=True)
     mp.set_start_method('spawn', force=True)
     processes = []
     processes = []
 
 
@@ -122,5 +141,10 @@ def main():
     for p in processes:
     for p in processes:
         p.join()
         p.join()
 
 
+    end_time = time.time()
+    total_time = end_time - start_time
+    print(f"\nTotal processing time: {total_time:.2f} seconds")
+    print("Image captioning completed successfully!")
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     main()
     main()