Browse Source

Add Gradio App and Re-fact Part 3 nb

Sanyam Bhutani 6 tháng trước cách đây
mục cha
commit
abe2c5aed3

Những thai đổi đã bị hủy bỏ vì nó quá lớn
+ 1809 - 0
recipes/quickstart/Multi-Modal-RAG/Part_3.ipynb


+ 237 - 0
recipes/quickstart/Multi-Modal-RAG/gradio_app.py

@@ -0,0 +1,237 @@
+import gradio as gr
+import pandas as pd
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+from pathlib import Path
+from PIL import Image
+import io
+import base64
+from together import Together
+import os
+import logging
+import argparse
+import numpy as np
+
+
+# Set up argument parsing
+parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
+parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
+parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
+parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
+parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
+parser.add_argument("--api_key", required=True, help="Together API key")
+args = parser.parse_args()
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+print("Starting the Fashion Assistant application...")
+
+# Set up the sentence transformer model for CPU use
+print("Initializing the sentence transformer model...")
+model = get_registry().get("sentence-transformers").create(name="BAAI/bge-large-en-v1.5", device="cpu")
+print("Sentence transformer model initialized successfully.")
+
+# Define the schema for our LanceDB table
+class Schema(LanceModel):
+    Filename: str
+    Title: str
+    Size: str
+    Gender: str
+    Description: str = model.SourceField()
+    Category: str
+    Type: str
+    vector: Vector(model.ndims()) = model.VectorField()
+
+# Connect to LanceDB and create/load the table
+print("Connecting to LanceDB...")
+db = lancedb.connect(args.table_path)
+if args.use_existing_table:
+    tbl = db.open_table("clothes")
+else:
+    tbl = db.create_table(name="clothes", schema=Schema, mode="overwrite")
+    # Load and clean data
+    print(f"Loading and cleaning data from CSV: {args.csv_path}")
+    df = pd.read_csv(args.csv_path)
+    df = df.dropna().astype(str)
+    tbl.add(df.to_dict('records'))
+    print(f"Loaded and cleaned {len(df)} records into the LanceDB table.")
+
+print("Connected to LanceDB and created/loaded the 'clothes' table.")
+
+
+
+# Set up the Together API client
+os.environ["TOGETHER_API_KEY"] = args.api_key
+client = Together(api_key=args.api_key)
+print("Together API client set up successfully.")
+
+def encode_image(image):
+    buffered = io.BytesIO()
+    image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+def generate_description(image):
+    print("Generating description for uploaded image...")
+    base64_image = encode_image(image)
+    try:
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            }
+                        },
+                        {
+                            "type": "text",
+                            "text": "Describe this clothing item in detail."
+                        }
+                    ]
+                }
+            ],
+            max_tokens=512,
+            temperature=0.7,
+        )
+        description = response.choices[0].message.content
+        print(f"Generated description: {description}")
+        return description
+    except Exception as e:
+        print(f"Error generating description: {e}")
+        return "Error generating description"
+
+def process_chat_input(chat_history, user_input):
+    print(f"Processing chat input: {user_input}")
+    messages = [
+        {"role": "system", "content": "You are a helpful fashion assistant."}
+    ]
+    for user_msg, assistant_msg in chat_history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": assistant_msg})
+    
+    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
+    messages.append({"role": "user", "content": user_input})
+    print(f"Chat history: {messages}")
+    try:
+        bot_response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=messages,
+            max_tokens=512,
+            temperature=0.7,
+        ).choices[0].message.content
+        
+        print(f"Bot response: {bot_response}")
+        return user_input, bot_response
+    except Exception as e:
+        print(f"Error processing chat input: {e}")
+        return user_input, "Error processing chat input"
+
+def retrieve_similar_items(description, n=10):
+    print(f"Retrieving similar items for: {description}")
+    try:
+        results = tbl.search(description).limit(n).to_pandas()
+        print(f"Retrieved {len(results)} similar items.")
+        return results
+    except Exception as e:
+        print(f"Error retrieving similar items: {e}")
+        return pd.DataFrame()
+
+def rewrite_query(original_query, item_description):
+    print(f"Rewriting query: {original_query}")
+    messages = [
+        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
+        {"role": "user", "content": f"Item description: {item_description}"},
+        {"role": "user", "content": f"User query: {original_query}"},
+        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
+    ]
+    
+    try:
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=messages,
+            max_tokens=512,
+            temperature=0.7,
+        )
+        rewritten_query = response.choices[0].message.content
+        print(f"Rewritten query: {rewritten_query}")
+        return rewritten_query
+    except Exception as e:
+        print(f"Error rewriting query: {e}")
+        return original_query
+
+def fashion_assistant(image, chat_input, chat_history):
+    if chat_input is not "":
+        print("Processing chat input...")
+        last_description = chat_history[-1][1] if chat_history else ""
+        #rewritten_query = rewrite_query(chat_input, last_description)
+        user_message, bot_response = process_chat_input(chat_history, chat_input)
+        similar_items = retrieve_similar_items(bot_response)
+        gallery_data = create_gallery_data(similar_items)
+        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
+    elif image is not None:
+        print("Processing uploaded image...")
+        description = generate_description(image)
+        user_message = f"I've uploaded an image. The description is: {description}"
+        user_message, bot_response = process_chat_input(chat_history, user_message)
+        similar_items = retrieve_similar_items(description)
+        gallery_data = create_gallery_data(similar_items)
+        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
+    else:
+        print("No input provided.")
+        return chat_history, "", [], ""
+
+def create_gallery_data(results):
+    return [
+        (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
+        for _, row in results.iterrows()
+    ]
+
+def on_select(evt: gr.SelectData):
+    return f"Selected {evt.value} at index {evt.index}"
+
+def update_chat(image, chat_input, chat_history, last_description):
+    new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
+    if new_description:
+        last_description = new_description
+    return new_chat_history, new_chat_history, "", last_response, gallery_data, last_description
+
+# Define the Gradio interface
+print("Setting up Gradio interface...")
+with gr.Blocks() as demo:
+    gr.Markdown("# Interactive Fashion Assistant")
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Upload Clothing Image")
+        with gr.Column(scale=1):
+            chatbot = gr.Chatbot(label="Chat History")
+            chat_input = gr.Textbox(label="Chat Input")
+            chat_button = gr.Button("Send")
+        with gr.Column(scale=2):
+            gallery = gr.Gallery(
+                label="Retrieved Clothes",
+                show_label=True,
+                elem_id="gallery",
+                columns=[5],
+                rows=[2],
+                object_fit="contain",
+                height="auto"
+            )
+            selected_image = gr.Textbox(label="Selected Image")
+    
+    chat_state = gr.State([])
+    last_description = gr.State("")
+    
+    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description], 
+                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
+    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description], 
+                      outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
+    gallery.select(on_select, None, selected_image)
+
+print("Gradio interface set up successfully. Launching the app...")
+demo.launch()
+print("Fashion Assistant application is now running!")