final_demo.py

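"""Interactive Fashion Assistant demo.

Gradio app that describes an uploaded clothing image with a Together-hosted
vision model, stores clothing metadata and embeddings in LanceDB, and
retrieves similar items via semantic, full-text, or hybrid search.
"""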
import gradio as gr
import pandas as pd
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
from lancedb.rerankers import ColbertReranker
from pathlib import Path
from PIL import Image
import io
import base64
from together import Together
import os
import logging
import argparse
import numpy as np
# Set up argument parsing
parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
parser.add_argument("--api_key", required=True, help="Together API key")
parser.add_argument("--default_model", default="BAAI/bge-large-en-v1.5", help="Default embedding model to use")
args = parser.parse_args()

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

print("Starting the Fashion Assistant application...")
# Define available models
AVAILABLE_MODELS = {
    "BAAI/bge-large-en-v1.5": "bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5": "bge-small-en-v1.5",
    "BAAI/bge-reranker-base": "bge-reranker-base",
    "BAAI/bge-reranker-large": "bge-reranker-large"
}

# Define retrieval methods
RETRIEVAL_METHODS = ["Semantic Search", "Full Text Search", "Hybrid Search"]

# Connect to LanceDB
print("Connecting to LanceDB...")
db = lancedb.connect(args.table_path)
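
# Each embedding model gets its own LanceDB table. The sentence-transformers
# entry from LanceDB's embedding registry declares `Description` as the source
# field and `vector` as the computed embedding field, so embeddings are
# generated automatically when rows are added. A full-text index on
# `Description` is (re)built in both branches so the FTS and hybrid retrieval
# modes work against the same table.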
def create_table_for_model(model_name):
    print(f"Initializing the sentence transformer model: {model_name}")
    model = get_registry().get("sentence-transformers").create(name=model_name, device="cpu")

    class Schema(LanceModel):
        Filename: str
        Title: str
        Size: str
        Gender: str
        Description: str = model.SourceField()
        Category: str
        Type: str
        vector: Vector(model.ndims()) = model.VectorField()

    table_name = f"clothes_{AVAILABLE_MODELS[model_name]}"
    if not args.use_existing_table:
        tbl = db.create_table(name=table_name, schema=Schema, mode="overwrite")
        df = pd.read_csv(args.csv_path)
        df = df.dropna().astype(str)
        tbl.add(df.to_dict('records'))
        tbl.create_fts_index(["Description"], replace=True)
        print(f"Created and populated table {table_name}")
    else:
        tbl = db.open_table(table_name)
        tbl.create_fts_index(["Description"], replace=True)
        print(f"Opened existing table {table_name}")
    return tbl

tables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}
current_table = tables[args.default_model]
current_retrieval_method = "Semantic Search"

# Set up the Together API client
os.environ["TOGETHER_API_KEY"] = args.api_key
client = Together(api_key=args.api_key)
print("Together API client set up successfully.")
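
# Image handling: the uploaded PIL image is base64-encoded and sent to the
# Together vision model, which returns a text description of the clothing
# item. That description (not the image itself) is what gets embedded and
# searched against the table.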
def encode_image(image):
    buffered = io.BytesIO()
    # JPEG has no alpha channel, so drop it if the upload is e.g. an RGBA PNG
    image.convert("RGB").save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
def generate_description(image):
    print("Generating description for uploaded image...")
    base64_image = encode_image(image)
    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        },
                        {
                            "type": "text",
                            "text": "Describe this clothing item in detail."
                        }
                    ]
                }
            ],
            max_tokens=512,
            temperature=0.7,
        )
        description = response.choices[0].message.content
        print(f"Generated description: {description}")
        return description
    except Exception as e:
        print(f"Error generating description: {e}")
        return "Error generating description"
def process_chat_input(chat_history, user_input):
    print(f"Processing chat input: {user_input}")
    messages = [
        {"role": "system", "content": "You are a helpful fashion assistant."}
    ]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY. Start your message with '1. ..' "
    messages.append({"role": "user", "content": user_input})
    print(f"Chat history: {messages}")
    try:
        bot_response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        ).choices[0].message.content
        print(f"Bot response: {bot_response}")
        return user_input, bot_response
    except Exception as e:
        print(f"Error processing chat input: {e}")
        return user_input, "Error processing chat input"
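
# Retrieval: the selected method determines how the description is matched
# against the table. Semantic search embeds the query text and runs a vector
# search, full text search hits the FTS index on `Description`, and hybrid
# search combines both and reranks the candidates with a ColBERT reranker.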
def retrieve_similar_items(description, n=10):
    print(f"Retrieving similar items for: {description}")
    try:
        if current_retrieval_method == "Semantic Search":
            results = current_table.search(description).limit(n).to_pandas()
        elif current_retrieval_method == "Full Text Search":
            results = current_table.search(description, query_type="fts").limit(n).to_pandas()
        elif current_retrieval_method == "Hybrid Search":
            reranker = ColbertReranker(
                model_name="answerdotai/answerai-colbert-small-v1",
                column="Description")
            results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()
        else:
            raise ValueError("Invalid retrieval method")
        print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")
        return results
    except Exception as e:
        print(f"Error retrieving similar items: {e}")
        return pd.DataFrame()
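
# Note: `rewrite_query` below folds the last item description into a
# follow-up query; it is defined here but not wired into any Gradio event
# handler in this demo.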
def rewrite_query(original_query, item_description):
    print(f"Rewriting query: {original_query}")
    messages = [
        {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
        {"role": "user", "content": f"Item description: {item_description}"},
        {"role": "user", "content": f"User query: {original_query}"},
        {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
    ]
    try:
        response = client.chat.completions.create(
            model="meta-llama/Llama-Vision-Free",
            messages=messages,
            max_tokens=512,
            temperature=0.7,
        )
        rewritten_query = response.choices[0].message.content
        print(f"Rewritten query: {rewritten_query}")
        return rewritten_query
    except Exception as e:
        print(f"Error rewriting query: {e}")
        return original_query
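
# Main routing logic: a non-empty chat message takes priority over the image.
# For chat input, the assistant's reply text is used as the retrieval query;
# for a fresh image upload, the generated description is used instead.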
def fashion_assistant(image, chat_input, chat_history):
    if chat_input != "":
        print("Processing chat input...")
        last_description = chat_history[-1][1] if chat_history else ""
        user_message, bot_response = process_chat_input(chat_history, chat_input)
        similar_items = retrieve_similar_items(bot_response)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
    elif image is not None:
        print("Processing uploaded image...")
        description = generate_description(image)
        user_message = f"I've uploaded an image. The description is: {description}"
        user_message, bot_response = process_chat_input(chat_history, user_message)
        similar_items = retrieve_similar_items(description)
        gallery_data = create_gallery_data(similar_items)
        return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
    else:
        print("No input provided.")
        return chat_history, "", [], ""
def create_gallery_data(results):
    return [
        (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
        for _, row in results.iterrows()
    ]

def on_select(evt: gr.SelectData):
    return f"Selected {evt.value} at index {evt.index}"

def update_chat(image, chat_input, chat_history, last_description):
    new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
    if new_description:
        last_description = new_description
    return new_chat_history, new_chat_history, "", last_response, gallery_data, last_description

def update_model(model_name):
    global current_table
    current_table = tables[model_name]
    return f"Switched to model: {model_name}"

def update_retrieval_method(method):
    global current_retrieval_method
    current_retrieval_method = method
    return f"Switched to retrieval method: {method}"
# Define the Gradio interface
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Fashion Assistant")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Clothing Image")
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value=args.default_model,
                label="Embedding Model"
            )
            retrieval_dropdown = gr.Dropdown(
                choices=RETRIEVAL_METHODS,
                value="Semantic Search",
                label="Retrieval Method"
            )
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History")
            chat_input = gr.Textbox(label="Chat Input")
            chat_button = gr.Button("Send")
        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="Retrieved Clothes",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                object_fit="contain",
                height="auto"
            )
            selected_image = gr.Textbox(label="Selected Image")

    chat_state = gr.State([])
    last_description = gr.State("")

    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                      outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
    gallery.select(on_select, None, selected_image)
    model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])
    retrieval_dropdown.change(update_retrieval_method, inputs=[retrieval_dropdown], outputs=[gr.Textbox(label="Retrieval Method Status")])

    # Disable embedding model dropdown when Hybrid Search is selected
    retrieval_dropdown.change(lambda x: gr.update(interactive=x != "Hybrid Search"), inputs=[retrieval_dropdown], outputs=[model_dropdown])

print("Gradio interface set up successfully. Launching the app...")
demo.launch()
print("Fashion Assistant application is now running!")
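
# Example invocation (paths and key are placeholders -- substitute your own):
#   python final_demo.py \
#       --images_folder ./compressed_images \
#       --csv_path ./clothes_metadata.csv \
#       --table_path ~/.lancedb \
#       --api_key $TOGETHER_API_KEY \
#       --default_model "BAAI/bge-large-en-v1.5"
# Add --use_existing_table to reuse a previously built LanceDB table instead
# of re-ingesting the CSV.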