Browse Source

MM-RAG-New-PR (#847)

Sanyam Bhutani 3 months ago
parent
commit
4621f29368

File diff suppressed because it is too large
+ 1 - 1
3p-integrations/llama_on_prem.md


File diff suppressed because it is too large
+ 7 - 7
3p-integrations/togetherai/README.md


+ 116 - 0
end-to-end-use-cases/Multi-Modal-RAG/README.md

@@ -0,0 +1,116 @@
+# End to End Tutorial on using Llama models for Multi-Modal RAG 
+
+## Recipe Overview: Multi-Modal RAG using `Llama-3.2-11B` model: 
+
+This is a complete workshop on how to label images using the new Llama 3.2-Vision Models and performing RAG using the image caption capabilities of the model.
+
+- **Data Labeling and Preparation:** We start by downloading 5000 images of clothing items and labeling them using `Llama-3.2-11B-Vision-Instruct` model
+- **Cleaning Labels:** With the labels based on the notebook above, we will then clean the dataset and prepare it for RAG
+- **Building Vector DB and RAG Pipeline:** With the final clean dataset, we can use descriptions and 11B model to generate recommendations
+
+## Requirements:
+
+Before we start:
+
+1. Please grab your HF CLI Token from [here](https://huggingface.co/settings/tokens)
+2. Git clone [this dataset](https://huggingface.co/datasets/Sanyam/MM-Demo) inside the Multi-Modal-RAG folder: `git clone https://huggingface.co/datasets/Sanyam/MM-Demo` (Remember to thank the original author by up voting [Kaggle Dataset](https://www.kaggle.com/datasets/agrigorev/clothing-dataset-full))
+3. Make sure you grab a together.ai token [here](https://www.together.ai)
+
+## Detailed Outline for running:
+
+Order of running files, the notebook establish the method of approaching the problem. Once we establish it, we use the scripts to run the method end to end.
+
+- Notebook 1: `Part_1_Data_Preparation.ipynb`
+- Script: `label_script.py`
+- Notebook 2: `Part_2_Cleaning_Data_and_DB.ipynb`
+- Notebook 3: `Part_3_RAG_Setup_and_Validation.ipynb`
+- Script: `final_demo.py`
+
+Here's the detailed outline:
+
+### Step 1: Data Prep and Synthetic Labeling:
+
+In this step we start with an unlabeled dataset and use the image captioning capability of the model to write a description of the image and categorize it.
+
+[Notebook for Step 1](./notebooks/Part_1_Data_Preparation.ipynb) and [Script for Step 1](./scripts/label_script.py)
+
+To run the script (remember to set n):
+```
+python scripts/label_script.py --hf_token "your_huggingface_token_here" \
+    --input_path "../MM-Demo/images_compressed" \
+    --output_path "../MM-Demo/output/" \
+    --num_gpus N
+```
+
+The dataset consists of 5000 images with some meta-data.
+
+The first half is preparing the dataset for labeling:
+- Clean/Remove corrupt images
+- Some exploratory analysis to understand existing distribution
+- Merging up categories of clothes to reduce complexity 
+- Balancing dataset by randomly sampling images to have an equal distribution for retrieval
+
+Second Half consists of Labeling the dataset. Llama 3.2, 11B model can only process one image at a time:
+- We load a few images and test captioning
+- We run this pipeline on random images and iterate on the prompt till we feel the model is giving good outputs
+- Finally, we can create a script to label all 5000 images on multi-GPU
+
+After running the script on the entire dataset, we have more data cleaning to perform.
+
+### Step 2: Cleaning up Synthetic Labels and preparing the dataset:
+
+[Notebook for Step 2](./notebooks/Part_2_Cleaning_Data_and_DB.ipynb)
+
+We notice that even after some fun prompt engineering, the model faces some hallucinations-there are some issues with the JSON formatting and we notice that it hallucinates the label categories. Here is how we address this:
+
+- Re-balance the dataset by mapping correct categories. This is useful to make sure we have an equal distribution in our dataset for retrieval
+- Fix Descriptions so that we can create a CSV
+
+Now, we are ready to try our vector db pipeline:
+
+### Step 3: Notebook 3: MM-RAG using lance-db to validate idea
+
+[Notebook for Step 3](./notebooks/Part_3_RAG_Setup_and_Validation.ipynb) and [Final Demo Script](./scripts/label_script.py)
+
+
+With the cleaned descriptions and dataset, we can now store these in a vector-db, here's the steps:
+
+
+- We create embeddings using the text description of our clothes
+- Use 11-B model to describe the uploaded image
+- Ask the model to suggest complementary items to the upload
+- Try to find similar or complementary images based on the upload
+
+We try the approach with different retrieval methods.
+
+Finally, we can bring this all together in a Gradio App. 
+
+For running the script:
+```
+python scripts/final_demo.py \
+    --images_folder "../MM-Demo/compressed_images" \
+    --csv_path "../MM-Demo/final_balanced_sample_dataset.csv" \
+    --table_path "~/.lancedb" \
+    --api_key "your_together_api_key" \
+    --default_model "BAAI/bge-large-en-v1.5" \
+    --use_existing_table 
+```
+
+Note: We can further improve the description prompt. You will notice sometimes the description starts with the title of the cloth which causes in retrieval of "similar" clothes instead of "complementary" items
+
+- Upload an image
+- 11B model describes the image
+- We retrieve complementary clothes to wear based on the description
+- You can keep the loop going by chatting with the model
+
+## Resources used: 
+
+Credit and Thanks to List of models and resources used in the showcase:
+
+Firstly, thanks to the author here for providing this dataset on which we base our exercise [here](https://www.kaggle.com/datasets/agrigorev/clothing-dataset-full)
+
+- [Llama-3.2-11B-Vision-Instruct Model](https://www.llama.com/docs/how-to-guides/vision-capabilities/)
+- [Lance-db for vector database](https://lancedb.com)
+- [This Kaggle dataset](https://www.kaggle.com/datasets/agrigorev/clothing-dataset-full)
+- [HF Dataset](https://huggingface.co/datasets/Sanyam/MM-Demo) Since output of the model can be non-deterministic every time we run, we will use the uploaded dataset to give a universal experience
+- [Together API for demo](https://www.together.ai)

File diff suppressed because it is too large
+ 1756 - 0
end-to-end-use-cases/Multi-Modal-RAG/notebooks/Part_1_Data_Preparation.ipynb


File diff suppressed because it is too large
+ 1822 - 0
end-to-end-use-cases/Multi-Modal-RAG/notebooks/Part_2_Cleaning_Data_and_DB.ipynb


File diff suppressed because it is too large
+ 1304 - 0
end-to-end-use-cases/Multi-Modal-RAG/notebooks/Part_3_RAG_Setup_and_Validation.ipynb


+ 285 - 0
end-to-end-use-cases/Multi-Modal-RAG/scripts/final_demo.py

@@ -0,0 +1,285 @@
+import gradio as gr
+import pandas as pd
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+from lancedb.rerankers import ColbertReranker
+from pathlib import Path
+from PIL import Image
+import io
+import base64
+from together import Together
+import os
+import logging
+import argparse
+import numpy as np
+
+# Set up argument parsing
+parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
+parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
+parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
+parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
+parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
+parser.add_argument("--api_key", required=True, help="Together API key")
+parser.add_argument("--default_model", default="BAAI/bge-large-en-v1.5", help="Default embedding model to use")
+args = parser.parse_args()
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+print("Starting the Fashion Assistant application...")
+
+# Define available models
+AVAILABLE_MODELS = {
+    "BAAI/bge-large-en-v1.5": "bge-large-en-v1.5",
+    "BAAI/bge-small-en-v1.5": "bge-small-en-v1.5",
+    "BAAI/bge-reranker-base": "bge-reranker-base",
+    "BAAI/bge-reranker-large": "bge-reranker-large"
+}
+
+# Define retrieval methods
+RETRIEVAL_METHODS = ["Semantic Search", "Full Text Search", "Hybrid Search"]
+
+# Connect to LanceDB
+print("Connecting to LanceDB...")
+db = lancedb.connect(args.table_path)
+
+def create_table_for_model(model_name):
+    print(f"Initializing the sentence transformer model: {model_name}")
+    model = get_registry().get("sentence-transformers").create(name=model_name, device="cpu")
+    
+    class Schema(LanceModel):
+        Filename: str
+        Title: str
+        Size: str
+        Gender: str
+        Description: str = model.SourceField()
+        Category: str
+        Type: str
+        vector: Vector(model.ndims()) = model.VectorField()
+    
+    table_name = f"clothes_{AVAILABLE_MODELS[model_name]}"
+    if not args.use_existing_table:
+        tbl = db.create_table(name=table_name, schema=Schema, mode="overwrite")
+        df = pd.read_csv(args.csv_path)
+        df = df.dropna().astype(str)
+        tbl.add(df.to_dict('records'))
+        tbl.create_fts_index(["Description"], replace=True)
+        print(f"Created and populated table {table_name}")
+    else:
+        tbl = db.open_table(table_name)
+        tbl.create_fts_index(["Description"], replace=True)
+        print(f"Opened existing table {table_name}")
+    return tbl
+
+tables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}
+current_table = tables[args.default_model]
+current_retrieval_method = "Semantic Search"
+
+# Set up the Together API client
+os.environ["TOGETHER_API_KEY"] = args.api_key
+client = Together(api_key=args.api_key)
+print("Together API client set up successfully.")
+
+def encode_image(image):
+    buffered = io.BytesIO()
+    image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+def generate_description(image):
+    print("Generating description for uploaded image...")
+    base64_image = encode_image(image)
+    try:
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            }
+                        },
+                        {
+                            "type": "text",
+                            "text": "Describe this clothing item in detail."
+                        }
+                    ]
+                }
+            ],
+            max_tokens=512,
+            temperature=0.7,
+        )
+        description = response.choices[0].message.content
+        print(f"Generated description: {description}")
+        return description
+    except Exception as e:
+        print(f"Error generating description: {e}")
+        return "Error generating description"
+
+def process_chat_input(chat_history, user_input):
+    print(f"Processing chat input: {user_input}")
+    messages = [
+        {"role": "system", "content": "You are a helpful fashion assistant."}
+    ]
+    for user_msg, assistant_msg in chat_history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": assistant_msg})
+    
+    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY. Start your message with '1. ..' "
+    messages.append({"role": "user", "content": user_input})
+    print(f"Chat history: {messages}")
+    try:
+        bot_response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=messages,
+            max_tokens=512,
+            temperature=0.7,
+        ).choices[0].message.content
+        
+        print(f"Bot response: {bot_response}")
+        return user_input, bot_response
+    except Exception as e:
+        print(f"Error processing chat input: {e}")
+        return user_input, "Error processing chat input"
+
+def retrieve_similar_items(description, n=10):
+    print(f"Retrieving similar items for: {description}")
+    try:
+        if current_retrieval_method == "Semantic Search":
+            results = current_table.search(description).limit(n).to_pandas()
+        elif current_retrieval_method == "Full Text Search":
+            results = current_table.search(description, query_type="fts").limit(n).to_pandas()
+        elif current_retrieval_method == "Hybrid Search":
+            reranker = ColbertReranker(
+                model_name="answerdotai/answerai-colbert-small-v1",
+                column="Description")
+            results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()
+        else:
+            raise ValueError("Invalid retrieval method")
+        print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")
+        return results
+    except Exception as e:
+        print(f"Error retrieving similar items: {e}")
+        return pd.DataFrame()
+
+def rewrite_query(original_query, item_description):
+    print(f"Rewriting query: {original_query}")
+    messages = [
+        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
+        {"role": "user", "content": f"Item description: {item_description}"},
+        {"role": "user", "content": f"User query: {original_query}"},
+        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
+    ]
+    
+    try:
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=messages,
+            max_tokens=512,
+            temperature=0.7,
+        )
+        rewritten_query = response.choices[0].message.content
+        print(f"Rewritten query: {rewritten_query}")
+        return rewritten_query
+    except Exception as e:
+        print(f"Error rewriting query: {e}")
+        return original_query
+
+def fashion_assistant(image, chat_input, chat_history):
+    if chat_input != "":
+        print("Processing chat input...")
+        last_description = chat_history[-1][1] if chat_history else ""
+        user_message, bot_response = process_chat_input(chat_history, chat_input)
+        similar_items = retrieve_similar_items(bot_response)
+        gallery_data = create_gallery_data(similar_items)
+        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
+    elif image is not None:
+        print("Processing uploaded image...")
+        description = generate_description(image)
+        user_message = f"I've uploaded an image. The description is: {description}"
+        user_message, bot_response = process_chat_input(chat_history, user_message)
+        similar_items = retrieve_similar_items(description)
+        gallery_data = create_gallery_data(similar_items)
+        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
+    else:
+        print("No input provided.")
+        return chat_history, "", [], ""
+
+def create_gallery_data(results):
+    return [
+        (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
+        for _, row in results.iterrows()
+    ]
+
+def on_select(evt: gr.SelectData):
+    return f"Selected {evt.value} at index {evt.index}"
+
+def update_chat(image, chat_input, chat_history, last_description):
+    new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
+    if new_description:
+        last_description = new_description
+    return new_chat_history, new_chat_history, "", last_response, gallery_data, last_description
+
+def update_model(model_name):
+    global current_table
+    current_table = tables[model_name]
+    return f"Switched to model: {model_name}"
+
+def update_retrieval_method(method):
+    global current_retrieval_method
+    current_retrieval_method = method
+    return f"Switched to retrieval method: {method}"
+
+# Define the Gradio interface
+print("Setting up Gradio interface...")
+with gr.Blocks() as demo:
+    gr.Markdown("# Interactive Fashion Assistant")
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Upload Clothing Image")
+            model_dropdown = gr.Dropdown(
+                choices=list(AVAILABLE_MODELS.keys()),
+                value=args.default_model,
+                label="Embedding Model"
+            )
+            retrieval_dropdown = gr.Dropdown(
+                choices=RETRIEVAL_METHODS,
+                value="Semantic Search",
+                label="Retrieval Method"
+            )
+        with gr.Column(scale=1):
+            chatbot = gr.Chatbot(label="Chat History")
+            chat_input = gr.Textbox(label="Chat Input")
+            chat_button = gr.Button("Send")
+        with gr.Column(scale=2):
+            gallery = gr.Gallery(
+                label="Retrieved Clothes",
+                show_label=True,
+                elem_id="gallery",
+                columns=[5],
+                rows=[2],
+                object_fit="contain",
+                height="auto"
+            )
+            selected_image = gr.Textbox(label="Selected Image")
+    
+    chat_state = gr.State([])
+    last_description = gr.State("")
+    
+    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description], 
+                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
+    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description], 
+                      outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
+    gallery.select(on_select, None, selected_image)
+    model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])
+    retrieval_dropdown.change(update_retrieval_method, inputs=[retrieval_dropdown], outputs=[gr.Textbox(label="Retrieval Method Status")])
+
+    # Disable embedding model dropdown when Hybrid Search is selected
+    retrieval_dropdown.change(lambda x: gr.update(interactive=x != "Hybrid Search"), inputs=[retrieval_dropdown], outputs=[model_dropdown])
+
+print("Gradio interface set up successfully. Launching the app...")
+demo.launch()
+print("Fashion Assistant application is now running!")

+ 151 - 0
end-to-end-use-cases/Multi-Modal-RAG/scripts/label_script.py

@@ -0,0 +1,151 @@
+import os
+import argparse
+import torch
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from tqdm.auto import tqdm
+import csv
+from PIL import Image
+import torch.multiprocessing as mp
+from concurrent.futures import ProcessPoolExecutor
+import shutil
+import time
+
+USER_TEXT = """
+You are an expert fashion captioner, we are writing descriptions of clothes, look at the image closely and write a caption for it.
+
+Write the following Title, Size, Category, Gender, Type, Description in JSON FORMAT, PLEASE DO NOT FORGET JSON, 
+
+ALSO START WITH THE JSON AND NOT ANY THING ELSE, FIRST CHAR IN YOUR RESPONSE IS ITS OPENING BRACE
+
+FOLLOW THESE STEPS CLOSELY WHEN WRITING THE CAPTION: 
+1. Only start your response with a dictionary like the example below, nothing else, I NEED TO PARSE IT LATER, SO DONT ADD ANYTHING ELSE-IT WILL BREAK MY CODE
+Remember-DO NOT SAY ANYTHING ELSE ABOUT WHAT IS GOING ON, just the opening brace is the first thing in your response nothing else ok?
+2. REMEMBER TO CLOSE THE DICTIONARY WITH '}'BRACE, IT GOES AFTER THE END OF DESCRIPTION-YOU ALWAYS FORGET IT, THIS WILL CAUSE A LOT OF ISSUES
+3. If you cant tell the size from image, guess it! its okay but dont literally write that you guessed it
+4. Do not make the caption very literal, all of these are product photos, DO NOT CAPTION HOW OR WHERE THEY ARE PLACED, FOCUS ON WRITING ABOUT THE PIECE OF CLOTHING
+5. BE CREATIVE WITH THE DESCRIPTION BUT FOLLOW EVERYTHING CLOSELY FOR STRUCTURE
+6. Return your answer in dictionary format, see the example below
+
+{"Title": "Title of item of clothing", "Size": {'S', 'M', 'L', 'XL'}, #select one randomly if you cant tell from the image. DO NOT TELL ME YOU ESTIMATE OR GUESSED IT ONLY THE LETTER IS ENOUGH", Category":  {T-Shirt, Shoes, Tops, Pants, Jeans, Shorts, Skirts, Shoes, Footwear}, "Gender": {M, F, U}, "Type": {Casual, Formal, Work Casual, Lounge}, "Description": "Write it here"}
+
+Example: ALWAYS RETURN ANSWERS IN THE DICTIONARY FORMAT BELOW OK?
+
+{"Title": "Casual White pant with logo on it", "size": "L", "Category": "Jeans", "Gender": "U", "Type": "Work Casual", "Description": "Write it here, this is where your stuff goes"} 
+"""
+
+def is_image_corrupt(image_path):
+    try:
+        with Image.open(image_path) as img:
+            img.verify()
+        return False
+    except (IOError, SyntaxError, Image.UnidentifiedImageError):
+        return True
+
+def find_and_move_corrupt_images(folder_path, corrupt_folder):
+    image_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) 
+                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+    
+    num_cores = mp.cpu_count()
+    with tqdm(total=len(image_files), desc="Checking for corrupt images", unit="file", 
+              bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as pbar:
+        with ProcessPoolExecutor(max_workers=num_cores) as executor:
+            results = list(executor.map(is_image_corrupt, image_files))
+            pbar.update(len(image_files))
+    
+    corrupt_images = [img for img, is_corrupt in zip(image_files, results) if is_corrupt]
+    
+    os.makedirs(corrupt_folder, exist_ok=True)
+    for img in tqdm(corrupt_images, desc="Moving corrupt images", unit="file",
+                    bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"):
+        shutil.move(img, os.path.join(corrupt_folder, os.path.basename(img)))
+    
+    print(f"Moved {len(corrupt_images)} corrupt images to {corrupt_folder}")
+
+def get_image(image_path):
+    return Image.open(image_path).convert('RGB')
+
+def llama_progress_bar(total, desc, position=0):
+    """Custom progress bar with llama emojis."""
+    bar_format = "{desc}: |{bar}| {percentage:3.0f}% [{n_fmt}/{total_fmt}, {rate_fmt}{postfix}]"
+    return tqdm(total=total, desc=desc, position=position, bar_format=bar_format, ascii="🦙·")
+
+def process_images(rank, world_size, args, model_name, input_files, output_csv):
+    model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map=f"cuda:{rank}", torch_dtype=torch.bfloat16, token=args.hf_token)
+    processor = MllamaProcessor.from_pretrained(model_name, token=args.hf_token)
+
+    chunk_size = len(input_files) // world_size
+    start_idx = rank * chunk_size
+    end_idx = start_idx + chunk_size if rank < world_size - 1 else len(input_files)
+    
+    results = []
+    
+    pbar = llama_progress_bar(total=end_idx - start_idx, desc=f"GPU {rank}", position=rank)
+    
+    for filename in input_files[start_idx:end_idx]:
+        image_path = os.path.join(args.input_path, filename)
+        image = get_image(image_path)
+
+        conversation = [
+            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": USER_TEXT}]}
+        ]
+
+        prompt = processor.apply_chat_template(conversation, add_special_tokens=False, add_generation_prompt=True, tokenize=False)
+        inputs = processor(image, prompt, return_tensors="pt").to(model.device)
+
+        output = model.generate(**inputs, temperature=1, top_p=0.9, max_new_tokens=512)
+        decoded_output = processor.decode(output[0])[len(prompt):]
+
+        results.append((filename, decoded_output))
+        
+        pbar.update(1)
+        pbar.set_postfix({"Last File": filename})
+
+    pbar.close()
+
+    with open(output_csv, 'w', newline='', encoding='utf-8') as f:
+        writer = csv.writer(f)
+        writer.writerow(['Filename', 'Caption'])
+        writer.writerows(results)
+
+def main():
+    parser = argparse.ArgumentParser(description="Multi-GPU Image Captioning")
+    parser.add_argument("--hf_token", required=True, help="Hugging Face API token")
+    parser.add_argument("--input_path", required=True, help="Path to input image folder")
+    parser.add_argument("--output_path", required=True, help="Path to output CSV folder")
+    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
+    parser.add_argument("--corrupt_folder", default="corrupt_images", help="Folder to move corrupt images")
+    args = parser.parse_args()
+
+    model_name = "meta-llama/Llama-3.2-11b-Vision-Instruct"
+
+    print("🦙 Starting image processing pipeline...")
+    start_time = time.time()
+
+    # Find and move corrupt images
+    corrupt_folder = os.path.join(args.input_path, args.corrupt_folder)
+    find_and_move_corrupt_images(args.input_path, corrupt_folder)
+
+    # Get list of remaining (non-corrupt) image files
+    input_files = [f for f in os.listdir(args.input_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
+
+    print(f"\n🦙 Processing {len(input_files)} images using {args.num_gpus} GPUs...")
+
+    mp.set_start_method('spawn', force=True)
+    processes = []
+
+    for rank in range(args.num_gpus):
+        output_csv = os.path.join(args.output_path, f"captions_gpu_{rank}.csv")
+        p = mp.Process(target=process_images, args=(rank, args.num_gpus, args, model_name, input_files, output_csv))
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    end_time = time.time()
+    total_time = end_time - start_time
+    print(f"\n🦙 Total processing time: {total_time:.2f} seconds")
+    print("🦙 Image captioning completed successfully!")
+
+if __name__ == "__main__":
+    main()

+ 1 - 1
getting-started/README.md

@@ -3,7 +3,7 @@
 If you are new to developing with Meta Llama models, this is where you should start. This folder contains introductory-level notebooks across different techniques relating to Meta Llama.
 
 * The [Build_with_Llama 3.2](./build_with_Llama_3_2.ipynb) notebook showcases a comprehensive walkthrough of the new capabilities of Llama 3.2 models, including multimodal use cases, function/tool calling, Llama Stack, and Llama on edge.
-* The [Prompt_Engineering_with_Llama](./Prompt_Engineering_with_Llama_3.ipynb) notebook showcases the various ways to elicit appropriate outputs from Llama. Take this notebook for a spin to get a feel for how Llama responds to different inputs and generation parameters.
+* The [Prompt_Engineering_with_Llama](./Prompt_Engineering_with_Llama.ipynb) notebook showcases the various ways to elicit appropriate outputs from Llama. Take this notebook for a spin to get a feel for how Llama responds to different inputs and generation parameters.
 * The [inference](./inference/) folder contains scripts to deploy Llama for inference on server and mobile. See also [3p_integrations/vllm](../3p-integrations/vllm/) and [3p_integrations/tgi](../3p-integrations/tgi/) for hosting Llama on open-source model servers.
 * The [RAG](./RAG/) folder contains a simple Retrieval-Augmented Generation application using Llama.
 * The [finetuning](./finetuning/) folder contains resources to help you finetune Llama on your custom datasets, for both single- and multi-GPU setups. The scripts use the native llama-recipes finetuning code found in [finetuning.py](../src/llama_recipes/finetuning.py) which supports these features:

+ 1 - 1
getting-started/finetuning/finetune_vision_model.md

@@ -28,7 +28,7 @@ For **finetuning with LLM freeze using FSDP**, we can run the following code:
 
 For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
 
-For more details about local inference with the fine-tuned checkpoint, please read [Inference with FSDP checkpoints section](https://github.com/meta-llama/llama-recipes/tree/main/recipes/quickstart/inference/local_inference#inference-with-fsdp-checkpoints) to learn how to convert the FSDP weights into a consolidated Hugging Face formatted model for local inference.
+For more details about local inference with the fine-tuned checkpoint, please read [Inference with FSDP checkpoints section](../../getting-started/inference/local_inference/#inference-with-fsdp-checkpoints) to learn how to convert the FSDP weights into a consolidated Hugging Face formatted model for local inference.
 
 ### How to use a custom dataset to fine-tune vision model
 

+ 2 - 2
src/docs/FAQ.md

@@ -36,13 +36,13 @@ Here we discuss frequently asked questions that may occur and we found useful al
     os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
 
     ```
-    We also added this environment variable in `setup_environ_flags` of the [train_utils.py](../src/llama_recipes/utils/train_utils.py), feel free to uncomment it if required.
+    We also added this environment variable in `setup_environ_flags` of the [train_utils.py](../llama_recipes/utils/train_utils.py), feel free to uncomment it if required.
 
 8. Additional debugging flags?
 
     The environment variable `TORCH_DISTRIBUTED_DEBUG` can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. `TORCH_DISTRIBUTED_DEBUG` can be set to either OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL may impact the application performance and thus should only be used when debugging issues.
 
-    We also added this environment variable in `setup_environ_flags` of the [train_utils.py](../src/llama_recipes/utils/train_utils.py), feel free to uncomment it if required.
+    We also added this environment variable in `setup_environ_flags` of the [train_utils.py](../llama_recipes/utils/train_utils.py), feel free to uncomment it if required.
 
 9. I am getting import errors when running inference.