123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238 |
- import gradio as gr
- import pandas as pd
- import lancedb
- from lancedb.pydantic import LanceModel, Vector
- from lancedb.embeddings import get_registry
- from pathlib import Path
- from PIL import Image
- import io
- import base64
- from together import Together
- import os
- import logging
- import argparse
- import numpy as np
# Set up argument parsing.
# Configuration is CLI-only and parsed at import time, so this file is meant
# to be executed as a script rather than imported as a module.
parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
parser.add_argument("--api_key", required=True, help="Together API key")
args = parser.parse_args()

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
print("Starting the Fashion Assistant application...")

# Set up the sentence transformer model for CPU use.
# The same model embeds CSV descriptions at ingest time and free-text queries
# at search time; device="cpu" keeps the demo runnable without a GPU.
print("Initializing the sentence transformer model...")
model = get_registry().get("sentence-transformers").create(name="BAAI/bge-large-en-v1.5", device="cpu")
print("Sentence transformer model initialized successfully.")
# Define the schema for our LanceDB table
class Schema(LanceModel):
    """Row schema for the 'clothes' table.

    ``Description`` is the embedding SourceField: LanceDB embeds it with the
    sentence-transformer model declared above and stores the result in
    ``vector`` (the VectorField), which vector searches run against.
    """
    Filename: str     # image file name, resolved against --images_folder when displayed
    Title: str
    Size: str
    Gender: str
    Description: str = model.SourceField()  # text that gets embedded automatically on add()
    Category: str
    Type: str
    vector: Vector(model.ndims()) = model.VectorField()  # auto-populated embedding column
# Connect to LanceDB and create/load the table.
# BUG FIX: the original version always re-read the CSV and re-added every
# record to the table — even with --use_existing_table — so each run against
# an existing table duplicated the whole dataset. Ingestion now happens only
# when the table is (re)created.
print("Connecting to LanceDB...")
db = lancedb.connect(args.table_path)
if args.use_existing_table:
    # Reuse a table built on a previous run; rows were embedded back then.
    tbl = db.open_table("clothes")
else:
    tbl = db.create_table(name="clothes", schema=Schema, mode="overwrite")
    # Load and clean data: drop rows with missing values and coerce all
    # columns to str so they match the Schema's string fields.
    print(f"Loading and cleaning data from CSV: {args.csv_path}")
    df = pd.read_csv(args.csv_path)
    df = df.dropna().astype(str)
    # add() triggers embedding of the Description source field.
    tbl.add(df.to_dict('records'))
    print(f"Loaded and cleaned {len(df)} records into the LanceDB table.")
print("Connected to LanceDB and created/loaded the 'clothes' table.")

# Set up the Together API client used for every LLM call below.
os.environ["TOGETHER_API_KEY"] = args.api_key
client = Together(api_key=args.api_key)
print("Together API client set up successfully.")
def encode_image(image):
    """Encode a PIL image as a base64 JPEG string for the vision API.

    BUG FIX: JPEG cannot store an alpha channel or palette, so saving an
    RGBA/P-mode image (a typical PNG upload) raised OSError in the original.
    Such images are converted to RGB first; RGB/L pass through untouched.
    """
    if getattr(image, "mode", "RGB") not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
def generate_description(image):
    """Ask the vision model to describe a clothing image in detail.

    Returns the model's text, or a fixed error string if the API call fails.
    """
    print("Generating description for uploaded image...")
    base64_image = encode_image(image)
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
    }
    text_part = {"type": "text", "text": "Describe this clothing item in detail."}
    request_messages = [{"role": "user", "content": [image_part, text_part]}]
    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=request_messages,
            max_tokens=512,
            temperature=0.7,
        )
        description = completion.choices[0].message.content
    except Exception as e:
        print(f"Error generating description: {e}")
        return "Error generating description"
    print(f"Generated description: {description}")
    return description
def process_chat_input(chat_history, user_input):
    """Send the running conversation plus *user_input* to the chat model.

    Returns ``(augmented_user_input, bot_response)``. On API failure the
    response is a fixed error string instead.
    """
    print(f"Processing chat input: {user_input}")
    conversation = [
        {"role": "system", "content": "You are a helpful fashion assistant."}
    ]
    for past_user, past_bot in chat_history:
        conversation.append({"role": "user", "content": past_user})
        conversation.append({"role": "assistant", "content": past_bot})

    # Steer the model toward a bare recommendation list (no item-name echo).
    augmented_input = user_input + ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
    conversation.append({"role": "user", "content": augmented_input})
    print(f"Chat history: {conversation}")
    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=conversation,
            max_tokens=512,
            temperature=0.7,
        )
        reply = completion.choices[0].message.content
    except Exception as e:
        print(f"Error processing chat input: {e}")
        return augmented_input, "Error processing chat input"
    print(f"Bot response: {reply}")
    return augmented_input, reply
def retrieve_similar_items(description, n=10):
    """Vector-search the 'clothes' table for the *n* rows nearest *description*.

    Returns a pandas DataFrame of matches; an empty DataFrame on failure.
    """
    print(f"Retrieving similar items for: {description}")
    try:
        matches = tbl.search(description).limit(n).to_pandas()
    except Exception as e:
        print(f"Error retrieving similar items: {e}")
        return pd.DataFrame()
    print(f"Retrieved {len(matches)} similar items.")
    return matches
def rewrite_query(original_query, item_description):
    """Have the LLM fold *item_description* details into *original_query*.

    Falls back to the unmodified query if the API call fails.
    """
    print(f"Rewriting query: {original_query}")
    prompt = [
        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
        {"role": "user", "content": f"Item description: {item_description}"},
        {"role": "user", "content": f"User query: {original_query}"},
        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."},
    ]

    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=prompt,
            max_tokens=512,
            temperature=0.7,
        )
        rewritten = completion.choices[0].message.content
    except Exception as e:
        print(f"Error rewriting query: {e}")
        return original_query
    print(f"Rewritten query: {rewritten}")
    return rewritten
def fashion_assistant(image, chat_input, chat_history):
    """Route one user turn: typed text takes priority over an uploaded image.

    Returns ``(updated_chat_history, bot_response, gallery_data, description)``.
    When neither input is present, everything comes back empty/unchanged.
    """
    # BUG FIX: the original tested `chat_input is not ""` — an identity
    # comparison against a string literal, which is a SyntaxWarning on
    # CPython >= 3.8, relies on string interning, and treats None as
    # non-empty. A plain truthiness check is the correct form.
    if chat_input:
        print("Processing chat input...")
        last_description = chat_history[-1][1] if chat_history else ""
        #rewritten_query = rewrite_query(chat_input, last_description)
        user_message, bot_response = process_chat_input(chat_history, chat_input)
        # The bot's reply (not the raw query) drives the similarity search.
        similar_items = retrieve_similar_items(bot_response)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
    elif image is not None:
        print("Processing uploaded image...")
        description = generate_description(image)
        user_message = f"I've uploaded an image. The description is: {description}"
        user_message, bot_response = process_chat_input(chat_history, user_message)
        # For images, the generated description drives the similarity search.
        similar_items = retrieve_similar_items(description)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
    else:
        print("No input provided.")
        return chat_history, "", [], ""
def create_gallery_data(results):
    """Turn search-result rows into (image_path, caption) pairs for gr.Gallery."""
    gallery_items = []
    for _, row in results.iterrows():
        image_path = str(Path(args.images_folder) / row['Filename'])
        caption = f"{row['Title']}\n{row['Description']}"
        gallery_items.append((image_path, caption))
    return gallery_items
def on_select(evt: gr.SelectData):
    """Report which gallery image the user clicked, for the 'Selected Image' box."""
    value, index = evt.value, evt.index
    return f"Selected {value} at index {index}"
def update_chat(image, chat_input, chat_history, last_description):
    """Bridge the Gradio callbacks to fashion_assistant and fan out its outputs."""
    new_history, response, gallery_items, fresh_description = fashion_assistant(
        image, chat_input, chat_history
    )
    # Keep the previous description when this turn produced none.
    description = fresh_description if fresh_description else last_description
    return new_history, new_history, "", response, gallery_items, description
# Define the Gradio interface
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Fashion Assistant")
    with gr.Row():
        # Left column: image upload.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Clothing Image")
        # Middle column: chat.
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History")
            chat_input = gr.Textbox(label="Chat Input")
            chat_button = gr.Button("Send")
        # Right column: retrieval results.
        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="Retrieved Clothes",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                object_fit="contain",
                height="auto"
            )
            selected_image = gr.Textbox(label="Selected Image")

    # Session state: chat history (list of [user, bot] pairs) and the most
    # recent item description, threaded through update_chat.
    chat_state = gr.State([])
    last_description = gr.State("")

    # update_chat returns 6 values matching these 6 outputs in order.
    # NOTE(review): chat_input appears twice in outputs — the 3rd return
    # value ("") and the 4th (the bot response) both target it, so the
    # clearing is overwritten and the reply lands in the input textbox.
    # This looks unintended (a separate response textbox was likely meant);
    # confirm before changing.
    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                      outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
    gallery.select(on_select, None, selected_image)
print("Gradio interface set up successfully. Launching the app...")
demo.launch()
print("Fashion Assistant application is now running!")
|