|
@@ -27,12 +27,13 @@ import os
|
|
|
import argparse
|
|
|
import torch
|
|
|
from transformers import MllamaForConditionalGeneration, MllamaProcessor
|
|
|
-from tqdm import tqdm
|
|
|
+from tqdm.auto import tqdm
|
|
|
import csv
|
|
|
from PIL import Image
|
|
|
import torch.multiprocessing as mp
|
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
|
import shutil
|
|
|
+import time
|
|
|
|
|
|
def is_image_corrupt(image_path):
|
|
|
try:
|
|
@@ -47,13 +48,17 @@ def find_and_move_corrupt_images(folder_path, corrupt_folder):
|
|
|
if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
|
|
|
|
|
num_cores = mp.cpu_count()
|
|
|
- with ProcessPoolExecutor(max_workers=num_cores) as executor:
|
|
|
- results = executor.map(is_image_corrupt, image_files)
|
|
|
+ with tqdm(total=len(image_files), desc="Checking for corrupt images", unit="file",
|
|
|
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
|
|
|
+ with ProcessPoolExecutor(max_workers=num_cores) as executor:
|
|
|
+ results = list(executor.map(is_image_corrupt, image_files))
|
|
|
+ pbar.update(len(image_files))
|
|
|
|
|
|
corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
|
|
|
|
|
|
os.makedirs(corrupt_folder, exist_ok=True)
|
|
|
- for img in corrupt_images:
|
|
|
+ for img in tqdm(corrupt_images, desc="Moving corrupt images", unit="file",
|
|
|
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"):
|
|
|
shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
|
|
|
|
|
|
print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
|
|
@@ -71,7 +76,14 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
|
|
|
results = []
|
|
|
|
|
|
- for filename in tqdm(input_files[start_idx:end_idx], desc=f"GPU {rank} processing", position=rank):
|
|
|
+ pbar = tqdm(input_files[start_idx:end_idx],
|
|
|
+ desc=f"GPU {rank}",
|
|
|
+ unit="img",
|
|
|
+ position=rank,
|
|
|
+ leave=True,
|
|
|
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")
|
|
|
+
|
|
|
+ for filename in pbar:
|
|
|
image_path = os.path.join(args.input_path, filename)
|
|
|
image = get_image(image_path)
|
|
|
|
|
@@ -86,6 +98,8 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
decoded_output = processor.decode(output[0])[len(prompt):]
|
|
|
|
|
|
results.append((filename, decoded_output))
|
|
|
+
|
|
|
+ pbar.set_postfix({"Last File": filename})
|
|
|
|
|
|
with open(output_csv, 'w', newline='', encoding='utf-8') as f:
|
|
|
writer = csv.writer(f)
|
|
@@ -103,6 +117,9 @@ def main():
|
|
|
|
|
|
model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
|
|
|
|
|
|
+ print("Starting image processing pipeline...")
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
# Find and move corrupt images
|
|
|
corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
|
|
|
find_and_move_corrupt_images(args.input_path, corrupt_folder)
|
|
@@ -110,6 +127,8 @@ def main():
|
|
|
# Get list of remaining (non-corrupt) image files
|
|
|
input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
|
|
|
|
|
|
+ print(f"\nProcessing {len(input_files)} images using {args.num_gpus} GPUs...")
|
|
|
+
|
|
|
mp.set_start_method('spawn', force=True)
|
|
|
processes = []
|
|
|
|
|
@@ -122,5 +141,10 @@ def main():
|
|
|
for p in processes:
|
|
|
p.join()
|
|
|
|
|
|
+ end_time = time.time()
|
|
|
+ total_time = end_time - start_time
|
|
|
+ print(f"\nTotal processing time: {total_time:.2f} seconds")
|
|
|
+ print("Image captioning completed successfully!")
|
|
|
+
|
|
|
if __name__ == "__main__":
|
|
|
main()
|