```python
import gradio as gr
import pandas as pd
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
from lancedb.rerankers import ColbertReranker
from pathlib import Path
from PIL import Image
import io
import base64
from together import Together
import os
import logging
import argparse
 
# Set up argument parsing
parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
parser.add_argument("--api_key", required=True, help="Together API key")
parser.add_argument("--default_model", default="BAAI/bge-large-en-v1.5", help="Default embedding model to use")
args = parser.parse_args()
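
# The CSV is expected to provide the columns used by the schema below:
# Filename, Title, Size, Gender, Description, Category, Type.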
 
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
print("Starting the Fashion Assistant application...")
 
# Define available models
AVAILABLE_MODELS = {
    "BAAI/bge-large-en-v1.5": "bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5": "bge-small-en-v1.5",
    "BAAI/bge-reranker-base": "bge-reranker-base",
    "BAAI/bge-reranker-large": "bge-reranker-large"
}
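
# Each model gets its own LanceDB table (see create_table_for_model below),
# so switching models at runtime is just a dictionary lookup.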
 
# Define retrieval methods
RETRIEVAL_METHODS = ["Semantic Search", "Full Text Search", "Hybrid Search"]
 
# Connect to LanceDB (expand "~" so the default --table_path resolves correctly)
print("Connecting to LanceDB...")
db = lancedb.connect(Path(args.table_path).expanduser())
 
def create_table_for_model(model_name):
    print(f"Initializing the sentence transformer model: {model_name}")
    model = get_registry().get("sentence-transformers").create(name=model_name, device="cpu")

    class Schema(LanceModel):
        Filename: str
        Title: str
        Size: str
        Gender: str
        Description: str = model.SourceField()
        Category: str
        Type: str
        vector: Vector(model.ndims()) = model.VectorField()
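        # Description is the SourceField, so LanceDB embeds it automatically
        # on ingest; vector is the VectorField that stores the embedding,
        # sized to the model's output dimensionality.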
 
    table_name = f"clothes_{AVAILABLE_MODELS[model_name]}"
    if not args.use_existing_table:
        tbl = db.create_table(name=table_name, schema=Schema, mode="overwrite")
        df = pd.read_csv(args.csv_path)
        df = df.dropna().astype(str)
        tbl.add(df.to_dict('records'))
        tbl.create_fts_index(["Description"], replace=True)
        print(f"Created and populated table {table_name}")
    else:
        tbl = db.open_table(table_name)
        tbl.create_fts_index(["Description"], replace=True)
        print(f"Opened existing table {table_name}")
    return tbl
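
# One table per embedding model is built up front; the FTS index created in
# create_table_for_model is what Full Text Search and Hybrid Search run against.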
 
tables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}
current_table = tables[args.default_model]
current_retrieval_method = "Semantic Search"
 
# Set up the Together API client
os.environ["TOGETHER_API_KEY"] = args.api_key
client = Together(api_key=args.api_key)
print("Together API client set up successfully.")
 
def encode_image(image):
    # JPEG cannot encode alpha channels, so normalize uploads to RGB first.
    buffered = io.BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
def generate_description(image):
    print("Generating description for uploaded image...")
    base64_image = encode_image(image)
    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        },
                        {
                            "type": "text",
                            "text": "Describe this clothing item in detail."
                        }
                    ]
                }
            ],
            max_tokens=512,
            temperature=0.7,
        )
        description = response.choices[0].message.content
        print(f"Generated description: {description}")
        return description
    except Exception as e:
        print(f"Error generating description: {e}")
        return "Error generating description"
 
def process_chat_input(chat_history, user_input):
    print(f"Processing chat input: {user_input}")
    messages = [
        {"role": "system", "content": "You are a helpful fashion assistant."}
    ]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # Append the formatting instructions only to the prompt sent to the model,
    # so the raw user message is what ends up displayed in the chat history.
    prompt = user_input + ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY. Start your message with '1. ..' "
    messages.append({"role": "user", "content": prompt})
    print(f"Messages sent to the model: {messages}")
    try:
        bot_response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        ).choices[0].message.content

        print(f"Bot response: {bot_response}")
        return user_input, bot_response
    except Exception as e:
        print(f"Error processing chat input: {e}")
        return user_input, "Error processing chat input"
 
def retrieve_similar_items(description, n=10):
    print(f"Retrieving similar items for: {description}")
    try:
        if current_retrieval_method == "Semantic Search":
            results = current_table.search(description).limit(n).to_pandas()
        elif current_retrieval_method == "Full Text Search":
            results = current_table.search(description, query_type="fts").limit(n).to_pandas()
        elif current_retrieval_method == "Hybrid Search":
            reranker = ColbertReranker(
                model_name="answerdotai/answerai-colbert-small-v1",
                column="Description")
            results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()
        else:
            raise ValueError("Invalid retrieval method")
        print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")
        return results
    except Exception as e:
        print(f"Error retrieving similar items: {e}")
        return pd.DataFrame()
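
# Helper for query rewriting; defined here but not currently wired into the
# Gradio UI below.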
 
def rewrite_query(original_query, item_description):
    print(f"Rewriting query: {original_query}")
    messages = [
        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
        {"role": "user", "content": f"Item description: {item_description}"},
        {"role": "user", "content": f"User query: {original_query}"},
        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
    ]

    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        )
        rewritten_query = response.choices[0].message.content
        print(f"Rewritten query: {rewritten_query}")
        return rewritten_query
    except Exception as e:
        print(f"Error rewriting query: {e}")
        return original_query
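
# Text input takes precedence over an uploaded image: the image branch in
# fashion_assistant runs only when the chat box is empty.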
 
def fashion_assistant(image, chat_input, chat_history):
    if chat_input != "":
        print("Processing chat input...")
        last_description = chat_history[-1][1] if chat_history else ""
        user_message, bot_response = process_chat_input(chat_history, chat_input)
        similar_items = retrieve_similar_items(bot_response)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
    elif image is not None:
        print("Processing uploaded image...")
        description = generate_description(image)
        user_message = f"I've uploaded an image. The description is: {description}"
        user_message, bot_response = process_chat_input(chat_history, user_message)
        similar_items = retrieve_similar_items(description)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
    else:
        print("No input provided.")
        return chat_history, "", [], ""
 
def create_gallery_data(results):
    return [
        (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
        for _, row in results.iterrows()
    ]
 
def on_select(evt: gr.SelectData):
    return f"Selected {evt.value} at index {evt.index}"
 
def update_chat(image, chat_input, chat_history, last_description):
    new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
    if new_description:
        last_description = new_description
    # Returns: chat state, chatbot display, an empty string to clear the input
    # box, the gallery items, and the most recent description.
    return new_chat_history, new_chat_history, "", gallery_data, last_description
 
def update_model(model_name):
    global current_table
    current_table = tables[model_name]
    return f"Switched to model: {model_name}"

def update_retrieval_method(method):
    global current_retrieval_method
    current_retrieval_method = method
    return f"Switched to retrieval method: {method}"
 
# Define the Gradio interface
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Fashion Assistant")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Clothing Image")
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value=args.default_model,
                label="Embedding Model"
            )
            retrieval_dropdown = gr.Dropdown(
                choices=RETRIEVAL_METHODS,
                value="Semantic Search",
                label="Retrieval Method"
            )
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History")
            chat_input = gr.Textbox(label="Chat Input")
            chat_button = gr.Button("Send")
        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="Retrieved Clothes",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                object_fit="contain",
                height="auto"
            )
            selected_image = gr.Textbox(label="Selected Image")

    chat_state = gr.State([])
    last_description = gr.State("")
 
    # update_chat returns five values, so each output component is listed once.
    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                       outputs=[chat_state, chatbot, chat_input, gallery, last_description])
    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                      outputs=[chat_state, chatbot, chat_input, gallery, last_description])
    gallery.select(on_select, None, selected_image)
    model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])
    retrieval_dropdown.change(update_retrieval_method, inputs=[retrieval_dropdown], outputs=[gr.Textbox(label="Retrieval Method Status")])
    # Disable embedding model dropdown when Hybrid Search is selected
    retrieval_dropdown.change(lambda x: gr.update(interactive=x != "Hybrid Search"), inputs=[retrieval_dropdown], outputs=[model_dropdown])
 
- print("Gradio interface set up successfully. Launching the app...")
 
- demo.launch()
 
- print("Fashion Assistant application is now running!")
 
 
  |
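
Assuming the listing is saved as `fashion_assistant.py` (a hypothetical filename), a typical invocation looks like `python fashion_assistant.py --images_folder ./compressed_images --csv_path ./clothes.csv --api_key "$TOGETHER_API_KEY"`. The `--table_path`, `--default_model`, and `--use_existing_table` flags are optional; pass `--use_existing_table` on later runs to reuse the populated tables instead of re-embedding the CSV.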