@@ -0,0 +1,237 @@
+import gradio as gr
+import pandas as pd
+import lancedb
+from lancedb.pydantic import LanceModel, Vector
+from lancedb.embeddings import get_registry
+from pathlib import Path
+from PIL import Image
+import io
+import base64
+from together import Together
+import os
+import logging
+import argparse
+
+
+# Set up argument parsing
+parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
+parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
+parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
+parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
+parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
+parser.add_argument("--api_key", required=True, help="Together API key")
+args = parser.parse_args()
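+# Example invocation (script name and paths are illustrative):
+#   python fashion_assistant.py --images_folder ./images --csv_path ./clothes.csv --api_key $TOGETHER_API_KEY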
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+print("Starting the Fashion Assistant application...")
+
+# Set up the sentence transformer model for CPU use
+print("Initializing the sentence transformer model...")
+model = get_registry().get("sentence-transformers").create(name="BAAI/bge-large-en-v1.5", device="cpu")
+print("Sentence transformer model initialized successfully.")
+# Define the schema for our LanceDB table
+class Schema(LanceModel):
+    Filename: str
+    Title: str
+    Size: str
+    Gender: str
+    Description: str = model.SourceField()
+    Category: str
+    Type: str
+    vector: Vector(model.ndims()) = model.VectorField()
+
+# Connect to LanceDB and create/load the table
+print("Connecting to LanceDB...")
+db = lancedb.connect(args.table_path)
+if args.use_existing_table:
+    tbl = db.open_table("clothes")
+else:
+    tbl = db.create_table(name="clothes", schema=Schema, mode="overwrite")
+    # Load and clean data
+    print(f"Loading and cleaning data from CSV: {args.csv_path}")
+    df = pd.read_csv(args.csv_path)
+    df = df.dropna().astype(str)
+    tbl.add(df.to_dict('records'))
+    print(f"Loaded and cleaned {len(df)} records into the LanceDB table.")
+
+print("Connected to LanceDB and created/loaded the 'clothes' table.")
+
+# Set up the Together API client
+os.environ["TOGETHER_API_KEY"] = args.api_key
+client = Together(api_key=args.api_key)
+print("Together API client set up successfully.")
+def encode_image(image):
+    # JPEG cannot store an alpha channel, so normalize RGBA/P uploads first
+    image = image.convert("RGB")
+    buffered = io.BytesIO()
+    image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
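+
+# generate_description sends the image as a base64 data URL in the
+# OpenAI-compatible multimodal message format that Together's chat API accepts.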
+def generate_description(image):
+    print("Generating description for uploaded image...")
+    base64_image = encode_image(image)
+    try:
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            }
+                        },
+                        {
+                            "type": "text",
+                            "text": "Describe this clothing item in detail."
+                        }
+                    ]
+                }
+            ],
+            max_tokens=512,
+            temperature=0.7,
+        )
+        description = response.choices[0].message.content
+        print(f"Generated description: {description}")
+        return description
+    except Exception as e:
+        print(f"Error generating description: {e}")
+        return "Error generating description"
+def process_chat_input(chat_history, user_input):
+    print(f"Processing chat input: {user_input}")
+    messages = [
+        {"role": "system", "content": "You are a helpful fashion assistant."}
+    ]
+    for user_msg, assistant_msg in chat_history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": assistant_msg})
+
+    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
+    messages.append({"role": "user", "content": user_input})
+    print(f"Chat history: {messages}")
+    try:
+        bot_response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=messages,
+            max_tokens=512,
+            temperature=0.7,
+        ).choices[0].message.content
+
+        print(f"Bot response: {bot_response}")
+        return user_input, bot_response
+    except Exception as e:
+        print(f"Error processing chat input: {e}")
+        return user_input, "Error processing chat input"
+def retrieve_similar_items(description, n=10):
+    print(f"Retrieving similar items for: {description}")
+    try:
+        results = tbl.search(description).limit(n).to_pandas()
+        print(f"Retrieved {len(results)} similar items.")
+        return results
+    except Exception as e:
+        print(f"Error retrieving similar items: {e}")
+        return pd.DataFrame()
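+
+# rewrite_query folds the last item description into the user's query; note
+# that its call site below is currently commented out.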
+def rewrite_query(original_query, item_description):
+    print(f"Rewriting query: {original_query}")
+    messages = [
+        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
+        {"role": "user", "content": f"Item description: {item_description}"},
+        {"role": "user", "content": f"User query: {original_query}"},
+        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
+    ]
+
+    try:
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-Vision-Free",
+            messages=messages,
+            max_tokens=512,
+            temperature=0.7,
+        )
+        rewritten_query = response.choices[0].message.content
+        print(f"Rewritten query: {rewritten_query}")
+        return rewritten_query
+    except Exception as e:
+        print(f"Error rewriting query: {e}")
+        return original_query
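+
+# fashion_assistant routes each turn: typed chat input takes priority over a
+# newly uploaded image, and both paths refresh the gallery with similar items.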
+def fashion_assistant(image, chat_input, chat_history):
+    if chat_input != "":
+        print("Processing chat input...")
+        last_description = chat_history[-1][1] if chat_history else ""
+        #rewritten_query = rewrite_query(chat_input, last_description)
+        user_message, bot_response = process_chat_input(chat_history, chat_input)
+        similar_items = retrieve_similar_items(bot_response)
+        gallery_data = create_gallery_data(similar_items)
+        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
+    elif image is not None:
+        print("Processing uploaded image...")
+        description = generate_description(image)
+        user_message = f"I've uploaded an image. The description is: {description}"
+        user_message, bot_response = process_chat_input(chat_history, user_message)
+        similar_items = retrieve_similar_items(description)
+        gallery_data = create_gallery_data(similar_items)
+        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
+    else:
+        print("No input provided.")
+        return chat_history, "", [], ""
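+
+# Each gallery tile pairs an image path under --images_folder with a
+# title/description caption.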
+def create_gallery_data(results):
+    return [
+        (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
+        for _, row in results.iterrows()
+    ]
+
+def on_select(evt: gr.SelectData):
+    return f"Selected {evt.value} at index {evt.index}"
+def update_chat(image, chat_input, chat_history, last_description):
+    new_chat_history, _bot_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
+    if new_description:
+        last_description = new_description
+    # One value per output: chat_state, chatbot, cleared chat_input, gallery, last_description
+    return new_chat_history, new_chat_history, "", gallery_data, last_description
+
+# Define the Gradio interface
+print("Setting up Gradio interface...")
+with gr.Blocks() as demo:
+    gr.Markdown("# Interactive Fashion Assistant")
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Upload Clothing Image")
+        with gr.Column(scale=1):
+            chatbot = gr.Chatbot(label="Chat History")
+            chat_input = gr.Textbox(label="Chat Input")
+            chat_button = gr.Button("Send")
+        with gr.Column(scale=2):
+            gallery = gr.Gallery(
+                label="Retrieved Clothes",
+                show_label=True,
+                elem_id="gallery",
+                columns=[5],
+                rows=[2],
+                object_fit="contain",
+                height="auto"
+            )
+            selected_image = gr.Textbox(label="Selected Image")
+
+    chat_state = gr.State([])
+    last_description = gr.State("")
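+
+    # gr.State holds per-session values (the chat history and the last item
+    # description) that persist across event callbacks.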
+    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
+                       outputs=[chat_state, chatbot, chat_input, gallery, last_description])
+    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
+                      outputs=[chat_state, chatbot, chat_input, gallery, last_description])
+    gallery.select(on_select, None, selected_image)
+
+print("Gradio interface set up successfully. Launching the app...")
+print("Fashion Assistant application is now running!")
+# launch() blocks here until the server is stopped (default URL http://127.0.0.1:7860)
+demo.launch()