|
@@ -1,237 +0,0 @@
|
|
|
-import gradio as gr
|
|
|
-import pandas as pd
|
|
|
-import lancedb
|
|
|
-from lancedb.pydantic import LanceModel, Vector
|
|
|
-from lancedb.embeddings import get_registry
|
|
|
-from pathlib import Path
|
|
|
-from PIL import Image
|
|
|
-import io
|
|
|
-import base64
|
|
|
-from together import Together
|
|
|
-import os
|
|
|
-import logging
|
|
|
-import argparse
|
|
|
-import numpy as np
|
|
|
-
|
|
|
-
|
|
|
# ---- Command-line interface ------------------------------------------------
# Declarative table of CLI options: (flag, add_argument keyword arguments).
_CLI_OPTIONS = [
    ("--images_folder", {"required": True, "help": "Path to the folder containing compressed images"}),
    ("--csv_path", {"required": True, "help": "Path to the CSV file with clothing data"}),
    ("--table_path", {"default": "~/.lancedb", "help": "Table path for LanceDB"}),
    ("--use_existing_table", {"action": "store_true", "help": "Use existing table if it exists"}),
    ("--api_key", {"required": True, "help": "Together API key"}),
]

parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
|
|
|
-
|
|
|
# ---- Logging and embedding model -------------------------------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

print("Starting the Fashion Assistant application...")

# Build the sentence-transformer embedding model on CPU; the LanceDB schema
# uses it for its SourceField/VectorField pair.
print("Initializing the sentence transformer model...")
_registry = get_registry()
_factory = _registry.get("sentence-transformers")
model = _factory.create(name="BAAI/bge-large-en-v1.5", device="cpu")
print("Sentence transformer model initialized successfully.")
|
|
|
-
|
|
|
# Define the schema for our LanceDB table
class Schema(LanceModel):
    """Row schema for the 'clothes' LanceDB table.

    `Description` is the embedding source field: LanceDB feeds it through the
    sentence-transformer `model` and stores the result in `vector`.
    """

    Filename: str  # image file name; joined with args.images_folder for gallery display
    Title: str
    Size: str
    Gender: str
    Description: str = model.SourceField()  # text that is embedded automatically on insert
    Category: str
    Type: str
    # Embedding column; dimensionality comes from the model (bge-large-en-v1.5).
    vector: Vector(model.ndims()) = model.VectorField()
|
|
|
-
|
|
|
# Connect to LanceDB and create/load the table
print("Connecting to LanceDB...")
db = lancedb.connect(args.table_path)
if not args.use_existing_table:
    # Fresh run: (re)build the table and ingest the CSV catalogue.
    tbl = db.create_table(name="clothes", schema=Schema, mode="overwrite")
    # Load and clean data
    print(f"Loading and cleaning data from CSV: {args.csv_path}")
    catalog_df = pd.read_csv(args.csv_path).dropna().astype(str)
    tbl.add(catalog_df.to_dict('records'))
    print(f"Loaded and cleaned {len(catalog_df)} records into the LanceDB table.")
else:
    tbl = db.open_table("clothes")

print("Connected to LanceDB and created/loaded the 'clothes' table.")
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
# ---- Together API client ---------------------------------------------------
# Export the key for any downstream code that reads the environment, then
# build the client from that same value.
os.environ["TOGETHER_API_KEY"] = args.api_key
client = Together(api_key=os.environ["TOGETHER_API_KEY"])
print("Together API client set up successfully.")
|
|
|
-
|
|
|
def encode_image(image):
    """Serialize a PIL image to JPEG and return it as a base64 text string."""
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode('utf-8')
|
|
|
-
|
|
|
def generate_description(image):
    """Ask the vision model for a detailed description of a clothing image.

    Returns the model's text, or the string "Error generating description"
    when the API call fails.
    """
    print("Generating description for uploaded image...")
    base64_image = encode_image(image)
    # Build the multimodal message payload in named parts for readability.
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
    }
    text_part = {"type": "text", "text": "Describe this clothing item in detail."}
    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=[{"role": "user", "content": [image_part, text_part]}],
            max_tokens=512,
            temperature=0.7,
        )
        description = response.choices[0].message.content
        print(f"Generated description: {description}")
        return description
    except Exception as e:
        print(f"Error generating description: {e}")
        return "Error generating description"
|
|
|
-
|
|
|
def process_chat_input(chat_history, user_input):
    """Send the running conversation plus the new user message to the LLM.

    Returns a (user_message, bot_response) pair; on API failure the response
    is the string "Error processing chat input".
    """
    print(f"Processing chat input: {user_input}")
    messages = [{"role": "system", "content": "You are a helpful fashion assistant."}]
    for past_user, past_assistant in chat_history:
        messages.append({"role": "user", "content": past_user})
        messages.append({"role": "assistant", "content": past_assistant})

    # Steering suffix appended to the raw query before sending it to the model.
    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
    messages.append({"role": "user", "content": user_input})
    print(f"Chat history: {messages}")
    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        )
        bot_response = completion.choices[0].message.content

        print(f"Bot response: {bot_response}")
        return user_input, bot_response
    except Exception as e:
        print(f"Error processing chat input: {e}")
        return user_input, "Error processing chat input"
|
|
|
-
|
|
|
def retrieve_similar_items(description, n=10):
    """Vector-search the 'clothes' table for `description`.

    Returns up to `n` matching rows as a DataFrame; falls back to an empty
    DataFrame when the search fails.
    """
    print(f"Retrieving similar items for: {description}")
    try:
        matches = tbl.search(description).limit(n).to_pandas()
    except Exception as e:
        print(f"Error retrieving similar items: {e}")
        return pd.DataFrame()
    else:
        print(f"Retrieved {len(matches)} similar items.")
        return matches
|
|
|
-
|
|
|
def rewrite_query(original_query, item_description):
    """Have the LLM fold details from `item_description` into `original_query`.

    Returns the rewritten query, or the original one if the API call fails.
    """
    print(f"Rewriting query: {original_query}")
    prompt = [
        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
        {"role": "user", "content": f"Item description: {item_description}"},
        {"role": "user", "content": f"User query: {original_query}"},
        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."},
    ]

    try:
        rewritten_query = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=prompt,
            max_tokens=512,
            temperature=0.7,
        ).choices[0].message.content
        print(f"Rewritten query: {rewritten_query}")
        return rewritten_query
    except Exception as e:
        print(f"Error rewriting query: {e}")
        return original_query
|
|
|
-
|
|
|
def fashion_assistant(image, chat_input, chat_history):
    """Route one user turn: typed text takes priority over an uploaded image.

    Returns a 4-tuple: (updated chat history, last bot response, gallery data
    for the retrieved items, description to carry forward as state).
    """
    # BUG FIX: the original used `chat_input is not ""` — an identity
    # comparison against a literal (SyntaxWarning on CPython 3.8+) that is
    # True for None and not guaranteed for equal strings.  A truthiness test
    # handles both None and "" correctly.
    if chat_input:
        print("Processing chat input...")
        last_description = chat_history[-1][1] if chat_history else ""
        user_message, bot_response = process_chat_input(chat_history, chat_input)
        # Retrieve gallery items using the bot's answer as the search text.
        similar_items = retrieve_similar_items(bot_response)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
    elif image is not None:
        print("Processing uploaded image...")
        description = generate_description(image)
        user_message = f"I've uploaded an image. The description is: {description}"
        user_message, bot_response = process_chat_input(chat_history, user_message)
        # For images, search with the generated description rather than the reply.
        similar_items = retrieve_similar_items(description)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
    else:
        print("No input provided.")
        return chat_history, "", [], ""
|
|
|
-
|
|
|
def create_gallery_data(results):
    """Map search-result rows to (image_path, caption) pairs for the gallery."""
    images_dir = Path(args.images_folder)  # hoisted: invariant across rows
    return [
        (str(images_dir / row['Filename']), f"{row['Title']}\n{row['Description']}")
        for _, row in results.iterrows()
    ]
|
|
|
-
|
|
|
def on_select(evt: gr.SelectData):
    """Report which gallery image the user clicked, as display text."""
    value, index = evt.value, evt.index
    return f"Selected {value} at index {index}"
|
|
|
-
|
|
|
def update_chat(image, chat_input, chat_history, last_description):
    """Adapter between the Gradio event handlers and fashion_assistant.

    Returns six values, in the order the handlers' `outputs` lists expect;
    the empty string clears the chat input textbox.
    """
    new_history, last_response, gallery_data, new_description = fashion_assistant(
        image, chat_input, chat_history
    )
    carried = new_description if new_description else last_description
    return new_history, new_history, "", last_response, gallery_data, carried
|
|
|
-
|
|
|
# Define the Gradio interface
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Fashion Assistant")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Clothing Image")
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History")
            chat_input = gr.Textbox(label="Chat Input")
            chat_button = gr.Button("Send")
        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="Retrieved Clothes",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                object_fit="contain",
                height="auto"
            )
            selected_image = gr.Textbox(label="Selected Image")

    chat_state = gr.State([])
    last_description = gr.State("")
    # BUG FIX: both handlers previously listed `chat_input` twice in
    # `outputs`, so update_chat's 4th return value (the bot response)
    # overwrote the "" meant to clear the textbox.  Route the response into
    # its own State sink so the input box is actually cleared after a turn.
    last_response = gr.State("")
    _handler_outputs = [chat_state, chatbot, chat_input, last_response, gallery, last_description]

    image_input.change(update_chat,
                       inputs=[image_input, chat_input, chat_state, last_description],
                       outputs=_handler_outputs)
    chat_button.click(update_chat,
                      inputs=[image_input, chat_input, chat_state, last_description],
                      outputs=_handler_outputs)
    gallery.select(on_select, None, selected_image)

print("Gradio interface set up successfully. Launching the app...")
demo.launch()
print("Fashion Assistant application is now running!")
|