final_demo.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import gradio as gr
  2. import pandas as pd
  3. import lancedb
  4. from lancedb.pydantic import LanceModel, Vector
  5. from lancedb.embeddings import get_registry
  6. from pathlib import Path
  7. from PIL import Image
  8. import io
  9. import base64
  10. from together import Together
  11. import os
  12. import logging
  13. import argparse
  14. import numpy as np
  15. # Set up argument parsing
  16. parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
  17. parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
  18. parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
  19. parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
  20. parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
  21. parser.add_argument("--api_key", required=True, help="Together API key")
  22. parser.add_argument("--default_model", default="BAAI/bge-large-en-v1.5", help="Default embedding model to use")
  23. args = parser.parse_args()
  24. # Set up logging
  25. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  26. print("Starting the Fashion Assistant application...")
  27. # Define available models
  28. AVAILABLE_MODELS = {
  29. "BAAI/bge-large-en-v1.5": "bge-large-en-v1.5",
  30. "BAAI/bge-small-en-v1.5": "bge-small-en-v1.5",
  31. "BAAI/bge-reranker-base": "bge-reranker-base",
  32. "BAAI/bge-reranker-large": "bge-reranker-large"
  33. }
  34. # Connect to LanceDB
  35. print("Connecting to LanceDB...")
  36. db = lancedb.connect(args.table_path)
  37. def create_table_for_model(model_name):
  38. print(f"Initializing the sentence transformer model: {model_name}")
  39. model = get_registry().get("sentence-transformers").create(name=model_name, device="cpu")
  40. class Schema(LanceModel):
  41. Filename: str
  42. Title: str
  43. Size: str
  44. Gender: str
  45. Description: str = model.SourceField()
  46. Category: str
  47. Type: str
  48. vector: Vector(model.ndims()) = model.VectorField()
  49. table_name = f"clothes_{AVAILABLE_MODELS[model_name]}"
  50. if not args.use_existing_table:
  51. tbl = db.create_table(name=table_name, schema=Schema, mode="overwrite")
  52. df = pd.read_csv(args.csv_path)
  53. df = df.dropna().astype(str)
  54. tbl.add(df.to_dict('records'))
  55. print(f"Created and populated table {table_name}")
  56. else:
  57. tbl = db.open_table(table_name)
  58. print(f"Opened existing table {table_name}")
  59. return tbl
  60. tables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}
  61. current_table = tables[args.default_model]
  62. # Set up the Together API client
  63. os.environ["TOGETHER_API_KEY"] = args.api_key
  64. client = Together(api_key=args.api_key)
  65. print("Together API client set up successfully.")
  66. def encode_image(image):
  67. buffered = io.BytesIO()
  68. image.save(buffered, format="JPEG")
  69. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  70. def generate_description(image):
  71. print("Generating description for uploaded image...")
  72. base64_image = encode_image(image)
  73. try:
  74. response = client.chat.completions.create(
  75. model="meta-llama/Llama-Vision-Free",
  76. messages=[
  77. {
  78. "role": "user",
  79. "content": [
  80. {
  81. "type": "image_url",
  82. "image_url": {
  83. "url": f"data:image/jpeg;base64,{base64_image}"
  84. }
  85. },
  86. {
  87. "type": "text",
  88. "text": "Describe this clothing item in detail."
  89. }
  90. ]
  91. }
  92. ],
  93. max_tokens=512,
  94. temperature=0.7,
  95. )
  96. description = response.choices[0].message.content
  97. print(f"Generated description: {description}")
  98. return description
  99. except Exception as e:
  100. print(f"Error generating description: {e}")
  101. return "Error generating description"
  102. def process_chat_input(chat_history, user_input):
  103. print(f"Processing chat input: {user_input}")
  104. messages = [
  105. {"role": "system", "content": "You are a helpful fashion assistant."}
  106. ]
  107. for user_msg, assistant_msg in chat_history:
  108. messages.append({"role": "user", "content": user_msg})
  109. messages.append({"role": "assistant", "content": assistant_msg})
  110. user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
  111. messages.append({"role": "user", "content": user_input})
  112. print(f"Chat history: {messages}")
  113. try:
  114. bot_response = client.chat.completions.create(
  115. model="meta-llama/Llama-Vision-Free",
  116. messages=messages,
  117. max_tokens=512,
  118. temperature=0.7,
  119. ).choices[0].message.content
  120. print(f"Bot response: {bot_response}")
  121. return user_input, bot_response
  122. except Exception as e:
  123. print(f"Error processing chat input: {e}")
  124. return user_input, "Error processing chat input"
  125. def retrieve_similar_items(description, n=10):
  126. print(f"Retrieving similar items for: {description}")
  127. try:
  128. results = current_table.search(description).limit(n).to_pandas()
  129. print(f"Retrieved {len(results)} similar items.")
  130. return results
  131. except Exception as e:
  132. print(f"Error retrieving similar items: {e}")
  133. return pd.DataFrame()
  134. def rewrite_query(original_query, item_description):
  135. print(f"Rewriting query: {original_query}")
  136. messages = [
  137. {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
  138. {"role": "user", "content": f"Item description: {item_description}"},
  139. {"role": "user", "content": f"User query: {original_query}"},
  140. {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
  141. ]
  142. try:
  143. response = client.chat.completions.create(
  144. model="meta-llama/Llama-Vision-Free",
  145. messages=messages,
  146. max_tokens=512,
  147. temperature=0.7,
  148. )
  149. rewritten_query = response.choices[0].message.content
  150. print(f"Rewritten query: {rewritten_query}")
  151. return rewritten_query
  152. except Exception as e:
  153. print(f"Error rewriting query: {e}")
  154. return original_query
  155. def fashion_assistant(image, chat_input, chat_history):
  156. if chat_input != "":
  157. print("Processing chat input...")
  158. last_description = chat_history[-1][1] if chat_history else ""
  159. user_message, bot_response = process_chat_input(chat_history, chat_input)
  160. similar_items = retrieve_similar_items(bot_response)
  161. gallery_data = create_gallery_data(similar_items)
  162. return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
  163. elif image is not None:
  164. print("Processing uploaded image...")
  165. description = generate_description(image)
  166. user_message = f"I've uploaded an image. The description is: {description}"
  167. user_message, bot_response = process_chat_input(chat_history, user_message)
  168. similar_items = retrieve_similar_items(description)
  169. gallery_data = create_gallery_data(similar_items)
  170. return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
  171. else:
  172. print("No input provided.")
  173. return chat_history, "", [], ""
  174. def create_gallery_data(results):
  175. return [
  176. (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
  177. for _, row in results.iterrows()
  178. ]
  179. def on_select(evt: gr.SelectData):
  180. return f"Selected {evt.value} at index {evt.index}"
  181. def update_chat(image, chat_input, chat_history, last_description):
  182. new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
  183. if new_description:
  184. last_description = new_description
  185. return new_chat_history, new_chat_history, "", last_response, gallery_data, last_description
  186. def update_model(model_name):
  187. global current_table
  188. current_table = tables[model_name]
  189. return f"Switched to model: {model_name}"
  190. # Define the Gradio interface
  191. print("Setting up Gradio interface...")
  192. with gr.Blocks() as demo:
  193. gr.Markdown("# Interactive Fashion Assistant")
  194. with gr.Row():
  195. with gr.Column(scale=1):
  196. image_input = gr.Image(type="pil", label="Upload Clothing Image")
  197. model_dropdown = gr.Dropdown(
  198. choices=list(AVAILABLE_MODELS.keys()),
  199. value=args.default_model,
  200. label="Embedding Model"
  201. )
  202. with gr.Column(scale=1):
  203. chatbot = gr.Chatbot(label="Chat History")
  204. chat_input = gr.Textbox(label="Chat Input")
  205. chat_button = gr.Button("Send")
  206. with gr.Column(scale=2):
  207. gallery = gr.Gallery(
  208. label="Retrieved Clothes",
  209. show_label=True,
  210. elem_id="gallery",
  211. columns=[5],
  212. rows=[2],
  213. object_fit="contain",
  214. height="auto"
  215. )
  216. selected_image = gr.Textbox(label="Selected Image")
  217. chat_state = gr.State([])
  218. last_description = gr.State("")
  219. image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
  220. outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
  221. chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
  222. outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
  223. gallery.select(on_select, None, selected_image)
  224. model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])
  225. print("Gradio interface set up successfully. Launching the app...")
  226. demo.launch()
  227. print("Fashion Assistant application is now running!")