| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 | import gradio as grimport pandas as pdimport lancedbfrom lancedb.pydantic import LanceModel, Vectorfrom lancedb.embeddings import get_registryfrom lancedb.rerankers import ColbertRerankerfrom pathlib import Pathfrom PIL import Imageimport ioimport base64from together import Togetherimport osimport loggingimport argparseimport numpy as np# Set up argument parsingparser = argparse.ArgumentParser(description="Interactive Fashion Assistant")parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")parser.add_argument("--api_key", required=True, help="Together API key")parser.add_argument("--default_model", default="BAAI/bge-large-en-v1.5", help="Default embedding model to use")args = parser.parse_args()# Set up logginglogging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')print("Starting the Fashion Assistant application...")# Define available modelsAVAILABLE_MODELS = {    "BAAI/bge-large-en-v1.5": "bge-large-en-v1.5",    "BAAI/bge-small-en-v1.5": "bge-small-en-v1.5",    "BAAI/bge-reranker-base": "bge-reranker-base",    "BAAI/bge-reranker-large": "bge-reranker-large"}# Define retrieval methodsRETRIEVAL_METHODS = ["Semantic Search", "Full Text Search", "Hybrid Search"]# Connect to LanceDBprint("Connecting to LanceDB...")db = lancedb.connect(args.table_path)def create_table_for_model(model_name):    print(f"Initializing the sentence transformer model: {model_name}")    model = get_registry().get("sentence-transformers").create(name=model_name, device="cpu")        class Schema(LanceModel):        Filename: str        Title: str        Size: str        Gender: str        Description: str = model.SourceField()        Category: str        Type: str        vector: Vector(model.ndims()) = model.VectorField()        table_name = f"clothes_{AVAILABLE_MODELS[model_name]}"    if not args.use_existing_table:        tbl = db.create_table(name=table_name, schema=Schema, mode="overwrite")        df = pd.read_csv(args.csv_path)        df = df.dropna().astype(str)        tbl.add(df.to_dict('records'))        tbl.create_fts_index(["Description"], replace=True)        print(f"Created and populated table {table_name}")    else:        tbl = db.open_table(table_name)        tbl.create_fts_index(["Description"], replace=True)        print(f"Opened existing table {table_name}")    return tbltables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}current_table = tables[args.default_model]current_retrieval_method = "Semantic Search"# Set up the Together API clientos.environ["TOGETHER_API_KEY"] = args.api_keyclient = Together(api_key=args.api_key)print("Together API client set up successfully.")def encode_image(image):    buffered = io.BytesIO()    image.save(buffered, format="JPEG")    return base64.b64encode(buffered.getvalue()).decode('utf-8')def generate_description(image):    print("Generating description for uploaded image...")    base64_image = encode_image(image)    try:        response = client.chat.completions.create(            model="meta-llama/Llama-Vision-Free",            messages=[                {                    "role": "user",                    "content": [                        {                            "type": "image_url",                            "image_url": {                                "url": f"data:image/jpeg;base64,{base64_image}"                            }                        },                        {                            "type": "text",                            "text": "Describe this clothing item in detail."                        }                    ]                }            ],            max_tokens=512,            temperature=0.7,        )        description = response.choices[0].message.content        print(f"Generated description: {description}")        return description    except Exception as e:        print(f"Error generating description: {e}")        return "Error generating description"def process_chat_input(chat_history, user_input):    print(f"Processing chat input: {user_input}")    messages = [        {"role": "system", "content": "You are a helpful fashion assistant."}    ]    for user_msg, assistant_msg in chat_history:        messages.append({"role": "user", "content": user_msg})        messages.append({"role": "assistant", "content": assistant_msg})        user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY. Start your message with '1. ..' "    messages.append({"role": "user", "content": user_input})    print(f"Chat history: {messages}")    try:        bot_response = client.chat.completions.create(            model="meta-llama/Llama-Vision-Free",            messages=messages,            max_tokens=512,            temperature=0.7,        ).choices[0].message.content                print(f"Bot response: {bot_response}")        return user_input, bot_response    except Exception as e:        print(f"Error processing chat input: {e}")        return user_input, "Error processing chat input"def retrieve_similar_items(description, n=10):    print(f"Retrieving similar items for: {description}")    try:        if current_retrieval_method == "Semantic Search":            results = current_table.search(description).limit(n).to_pandas()        elif current_retrieval_method == "Full Text Search":            results = current_table.search(description, query_type="fts").limit(n).to_pandas()        elif current_retrieval_method == "Hybrid Search":            reranker = ColbertReranker(                model_name="answerdotai/answerai-colbert-small-v1",                column="Description")            results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()        else:            raise ValueError("Invalid retrieval method")        print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")        return results    except Exception as e:        print(f"Error retrieving similar items: {e}")        return pd.DataFrame()def rewrite_query(original_query, item_description):    print(f"Rewriting query: {original_query}")    messages = [        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},        {"role": "user", "content": f"Item description: {item_description}"},        {"role": "user", "content": f"User query: {original_query}"},        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}    ]        try:        response = client.chat.completions.create(            model="meta-llama/Llama-Vision-Free",            messages=messages,            max_tokens=512,            temperature=0.7,        )        rewritten_query = response.choices[0].message.content        print(f"Rewritten query: {rewritten_query}")        return rewritten_query    except Exception as e:        print(f"Error rewriting query: {e}")        return original_querydef fashion_assistant(image, chat_input, chat_history):    if chat_input != "":        print("Processing chat input...")        last_description = chat_history[-1][1] if chat_history else ""        user_message, bot_response = process_chat_input(chat_history, chat_input)        similar_items = retrieve_similar_items(bot_response)        gallery_data = create_gallery_data(similar_items)        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description    elif image is not None:        print("Processing uploaded image...")        description = generate_description(image)        user_message = f"I've uploaded an image. The description is: {description}"        user_message, bot_response = process_chat_input(chat_history, user_message)        similar_items = retrieve_similar_items(description)        gallery_data = create_gallery_data(similar_items)        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description    else:        print("No input provided.")        return chat_history, "", [], ""def create_gallery_data(results):    return [        (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")        for _, row in results.iterrows()    ]def on_select(evt: gr.SelectData):    return f"Selected {evt.value} at index {evt.index}"def update_chat(image, chat_input, chat_history, last_description):    new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)    if new_description:        last_description = new_description    return new_chat_history, new_chat_history, "", last_response, gallery_data, last_descriptiondef update_model(model_name):    global current_table    current_table = tables[model_name]    return f"Switched to model: {model_name}"def update_retrieval_method(method):    global current_retrieval_method    current_retrieval_method = method    return f"Switched to retrieval method: {method}"# Define the Gradio interfaceprint("Setting up Gradio interface...")with gr.Blocks() as demo:    gr.Markdown("# Interactive Fashion Assistant")    with gr.Row():        with gr.Column(scale=1):            image_input = gr.Image(type="pil", label="Upload Clothing Image")            model_dropdown = gr.Dropdown(                choices=list(AVAILABLE_MODELS.keys()),                value=args.default_model,                label="Embedding Model"            )            retrieval_dropdown = gr.Dropdown(                choices=RETRIEVAL_METHODS,                value="Semantic Search",                label="Retrieval Method"            )        with gr.Column(scale=1):            chatbot = gr.Chatbot(label="Chat History")            chat_input = gr.Textbox(label="Chat Input")            chat_button = gr.Button("Send")        with gr.Column(scale=2):            gallery = gr.Gallery(                label="Retrieved Clothes",                show_label=True,                elem_id="gallery",                columns=[5],                rows=[2],                object_fit="contain",                height="auto"            )            selected_image = gr.Textbox(label="Selected Image")        chat_state = gr.State([])    last_description = gr.State("")        image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],                        outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])    gallery.select(on_select, None, selected_image)    model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])    retrieval_dropdown.change(update_retrieval_method, inputs=[retrieval_dropdown], outputs=[gr.Textbox(label="Retrieval Method Status")])    # Disable embedding model dropdown when Hybrid Search is selected    retrieval_dropdown.change(lambda x: gr.update(interactive=x != "Hybrid Search"), inputs=[retrieval_dropdown], outputs=[model_dropdown])print("Gradio interface set up successfully. Launching the app...")demo.launch()print("Fashion Assistant application is now running!")
 |