|
@@ -1,12 +1,3 @@
|
|
|
-import os
|
|
|
-import argparse
|
|
|
-import torch
|
|
|
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
|
|
|
-from tqdm import tqdm
|
|
|
-import csv
|
|
|
-from PIL import Image
|
|
|
-import torch.multiprocessing as mp
|
|
|
-
|
|
|
USER_TEXT = """
|
|
|
You are an expert fashion captioner, we are writing descriptions of clothes, look at the image closely and write a caption for it.
|
|
|
|
|
@@ -29,25 +20,57 @@ Remember-DO NOT SAY ANYTHING ELSE ABOUT WHAT IS GOING ON, just the opening brace
|
|
|
|
|
|
Example: ALWAYS RETURN ANSWERS IN THE DICTIONARY FORMAT BELOW OK?
|
|
|
|
|
|
-{"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"}
|
|
|
+{"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"} "
|
|
|
"""
|
|
|
|
|
|
+import os
|
|
|
+import argparse
|
|
|
+import torch
|
|
|
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
|
|
|
+from tqdm import tqdm
|
|
|
+import csv
|
|
|
+from PIL import Image
|
|
|
+import torch.multiprocessing as mp
|
|
|
+from concurrent.futures import ProcessPoolExecutor
|
|
|
+import shutil
|
|
|
+
|
|
|
+def is_image_corrupt(image_path):
|
|
|
+ try:
|
|
|
+ with Image.open(image_path) as img:
|
|
|
+ img.verify()
|
|
|
+ return False
|
|
|
+ except (IOError, SyntaxError, Image.UnidentifiedImageError):
|
|
|
+ return True
|
|
|
+
|
|
|
+def find_and_move_corrupt_images(folder_path, corrupt_folder):
|
|
|
+ image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path)
|
|
|
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
|
|
+
|
|
|
+ num_cores = mp.cpu_count()
|
|
|
+ with ProcessPoolExecutor(max_workers=num_cores) as executor:
|
|
|
+ results = executor.map(is_image_corrupt, image_files)
|
|
|
+
|
|
|
+ corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
|
|
|
+
|
|
|
+ os.makedirs(corrupt_folder, exist_ok=True)
|
|
|
+ for img in corrupt_images:
|
|
|
+ shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
|
|
|
+
|
|
|
+ print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
|
|
|
+
|
|
|
def get_image(image_path):
|
|
|
return Image.open(image_path).convert('RGB')
|
|
|
|
|
|
def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
- # Set up the model and processor for this GPU
|
|
|
model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map=f"cuda:{rank}", torch_dtype=torch.bfloat16, token=args.hf_token)
|
|
|
processor = MllamaProcessor.from_pretrained(model_name, token=args.hf_token)
|
|
|
|
|
|
- # Calculate the chunk of files this GPU will process
|
|
|
chunk_size = len(input_files) // world_size
|
|
|
start_idx = rank * chunk_size
|
|
|
end_idx = start_idx + chunk_size if rank < world_size - 1 else len(input_files)
|
|
|
|
|
|
results = []
|
|
|
|
|
|
- # Process files with TQDM
|
|
|
for filename in tqdm(input_files[start_idx:end_idx], desc=f"GPU {rank} processing", position=rank):
|
|
|
image_path = os.path.join(args.input_path, filename)
|
|
|
image = get_image(image_path)
|
|
@@ -64,7 +87,6 @@ def process_images(rank, world_size, args, model_name, input_files, output_csv):
|
|
|
|
|
|
results.append((filename, decoded_output))
|
|
|
|
|
|
- # Write results to CSV
|
|
|
with open(output_csv, 'w', newline='', encoding='utf-8') as f:
|
|
|
writer = csv.writer(f)
|
|
|
writer.writerow(['Filename', 'Caption'])
|
|
@@ -76,14 +98,18 @@ def main():
|
|
|
parser.add_argument("--input_path", required=True, help="Path to input image folder")
|
|
|
parser.add_argument("--output_path", required=True, help="Path to output CSV folder")
|
|
|
parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
|
|
|
+ parser.add_argument("--corrupt_folder", default="corrupt_images", help="Folder to move corrupt images")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
|
|
|
|
|
|
- # Get list of image files
|
|
|
- input_files = [f for f in os.listdir(args.input_path) if f.endswith(('.jpg', '.jpeg', '.png'))]
|
|
|
+ # Find and move corrupt images
|
|
|
+ corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
|
|
|
+ find_and_move_corrupt_images(args.input_path, corrupt_folder)
|
|
|
+
|
|
|
+ # Get list of remaining (non-corrupt) image files
|
|
|
+ input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
|
|
|
|
|
|
- # Set up multi-processing
|
|
|
mp.set_start_method('spawn', force=True)
|
|
|
processes = []
|
|
|
|