# gradio_app.py — Interactive Fashion Assistant (Gradio + LanceDB + Together AI)
  1. import gradio as gr
  2. import pandas as pd
  3. import lancedb
  4. from lancedb.pydantic import LanceModel, Vector
  5. from lancedb.embeddings import get_registry
  6. from pathlib import Path
  7. from PIL import Image
  8. import io
  9. import base64
  10. from together import Together
  11. import os
  12. import logging
  13. import argparse
  14. import numpy as np
# Set up argument parsing: all runtime configuration comes from the CLI.
parser = argparse.ArgumentParser(description="Interactive Fashion Assistant")
parser.add_argument("--images_folder", required=True, help="Path to the folder containing compressed images")
parser.add_argument("--csv_path", required=True, help="Path to the CSV file with clothing data")
# NOTE(review): the default "~/.lancedb" is handed to lancedb.connect() verbatim;
# confirm lancedb expands the tilde, otherwise a literal "~" directory is created.
parser.add_argument("--table_path", default="~/.lancedb", help="Table path for LanceDB")
parser.add_argument("--use_existing_table", action="store_true", help="Use existing table if it exists")
parser.add_argument("--api_key", required=True, help="Together API key")
args = parser.parse_args()
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
print("Starting the Fashion Assistant application...")
# Set up the sentence transformer model for CPU use.
print("Initializing the sentence transformer model...")
# BAAI/bge-large-en-v1.5 produces the embeddings LanceDB uses for vector search;
# loading it can take a while on first run (model download).
model = get_registry().get("sentence-transformers").create(name="BAAI/bge-large-en-v1.5", device="cpu")
print("Sentence transformer model initialized successfully.")
# Define the schema for our LanceDB table.
class Schema(LanceModel):
    # One row per clothing item. `Description` is the embedding SourceField:
    # LanceDB feeds it through the sentence-transformer model above and stores
    # the result in `vector`, which is what tbl.search(text) compares against.
    Filename: str
    Title: str
    Size: str
    Gender: str
    Description: str = model.SourceField()
    Category: str
    Type: str
    vector: Vector(model.ndims()) = model.VectorField()
  40. # Connect to LanceDB and create/load the table
  41. print("Connecting to LanceDB...")
  42. db = lancedb.connect(args.table_path)
  43. if args.use_existing_table:
  44. tbl = db.open_table("clothes")
  45. else:
  46. tbl = db.create_table(name="clothes", schema=Schema, mode="overwrite")
  47. # Load and clean data
  48. print(f"Loading and cleaning data from CSV: {args.csv_path}")
  49. df = pd.read_csv(args.csv_path)
  50. df = df.dropna().astype(str)
  51. tbl.add(df.to_dict('records'))
  52. print(f"Loaded and cleaned {len(df)} records into the LanceDB table.")
  53. print("Connected to LanceDB and created/loaded the 'clothes' table.")
  54. # Set up the Together API client
  55. os.environ["TOGETHER_API_KEY"] = args.api_key
  56. client = Together(api_key=args.api_key)
  57. print("Together API client set up successfully.")
  58. def encode_image(image):
  59. buffered = io.BytesIO()
  60. image.save(buffered, format="JPEG")
  61. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  62. def generate_description(image):
  63. print("Generating description for uploaded image...")
  64. base64_image = encode_image(image)
  65. try:
  66. response = client.chat.completions.create(
  67. model="meta-llama/Llama-Vision-Free",
  68. messages=[
  69. {
  70. "role": "user",
  71. "content": [
  72. {
  73. "type": "image_url",
  74. "image_url": {
  75. "url": f"data:image/jpeg;base64,{base64_image}"
  76. }
  77. },
  78. {
  79. "type": "text",
  80. "text": "Describe this clothing item in detail."
  81. }
  82. ]
  83. }
  84. ],
  85. max_tokens=512,
  86. temperature=0.7,
  87. )
  88. description = response.choices[0].message.content
  89. print(f"Generated description: {description}")
  90. return description
  91. except Exception as e:
  92. print(f"Error generating description: {e}")
  93. return "Error generating description"
  94. def process_chat_input(chat_history, user_input):
  95. print(f"Processing chat input: {user_input}")
  96. messages = [
  97. {"role": "system", "content": "You are a helpful fashion assistant."}
  98. ]
  99. for user_msg, assistant_msg in chat_history:
  100. messages.append({"role": "user", "content": user_msg})
  101. messages.append({"role": "assistant", "content": assistant_msg})
  102. user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
  103. messages.append({"role": "user", "content": user_input})
  104. print(f"Chat history: {messages}")
  105. try:
  106. bot_response = client.chat.completions.create(
  107. model="meta-llama/Llama-Vision-Free",
  108. messages=messages,
  109. max_tokens=512,
  110. temperature=0.7,
  111. ).choices[0].message.content
  112. print(f"Bot response: {bot_response}")
  113. return user_input, bot_response
  114. except Exception as e:
  115. print(f"Error processing chat input: {e}")
  116. return user_input, "Error processing chat input"
  117. def retrieve_similar_items(description, n=10):
  118. print(f"Retrieving similar items for: {description}")
  119. try:
  120. results = tbl.search(description).limit(n).to_pandas()
  121. print(f"Retrieved {len(results)} similar items.")
  122. return results
  123. except Exception as e:
  124. print(f"Error retrieving similar items: {e}")
  125. return pd.DataFrame()
  126. def rewrite_query(original_query, item_description):
  127. print(f"Rewriting query: {original_query}")
  128. messages = [
  129. {"role": "system", "content": "You are a helpful fashion assistant. Rewrite the user's query to include details from the item description."},
  130. {"role": "user", "content": f"Item description: {item_description}"},
  131. {"role": "user", "content": f"User query: {original_query}"},
  132. {"role": "user", "content": "Please rewrite the query to include relevant details from the item description."}
  133. ]
  134. try:
  135. response = client.chat.completions.create(
  136. model="meta-llama/Llama-Vision-Free",
  137. messages=messages,
  138. max_tokens=512,
  139. temperature=0.7,
  140. )
  141. rewritten_query = response.choices[0].message.content
  142. print(f"Rewritten query: {rewritten_query}")
  143. return rewritten_query
  144. except Exception as e:
  145. print(f"Error rewriting query: {e}")
  146. return original_query
  147. def fashion_assistant(image, chat_input, chat_history):
  148. if chat_input is not "":
  149. print("Processing chat input...")
  150. last_description = chat_history[-1][1] if chat_history else ""
  151. #rewritten_query = rewrite_query(chat_input, last_description)
  152. user_message, bot_response = process_chat_input(chat_history, chat_input)
  153. similar_items = retrieve_similar_items(bot_response)
  154. gallery_data = create_gallery_data(similar_items)
  155. return chat_history + [[user_message, bot_response]], bot_response, gallery_data, last_description
  156. elif image is not None:
  157. print("Processing uploaded image...")
  158. description = generate_description(image)
  159. user_message = f"I've uploaded an image. The description is: {description}"
  160. user_message, bot_response = process_chat_input(chat_history, user_message)
  161. similar_items = retrieve_similar_items(description)
  162. gallery_data = create_gallery_data(similar_items)
  163. return chat_history + [[user_message, bot_response]], bot_response, gallery_data, description
  164. else:
  165. print("No input provided.")
  166. return chat_history, "", [], ""
  167. def create_gallery_data(results):
  168. return [
  169. (str(Path(args.images_folder) / row['Filename']), f"{row['Title']}\n{row['Description']}")
  170. for _, row in results.iterrows()
  171. ]
  172. def on_select(evt: gr.SelectData):
  173. return f"Selected {evt.value} at index {evt.index}"
  174. def update_chat(image, chat_input, chat_history, last_description):
  175. new_chat_history, last_response, gallery_data, new_description = fashion_assistant(image, chat_input, chat_history)
  176. if new_description:
  177. last_description = new_description
  178. return new_chat_history, new_chat_history, "", last_response, gallery_data, last_description
# Define the Gradio interface
print("Setting up Gradio interface...")
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Fashion Assistant")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: uploading an image triggers a describe-and-search turn.
            image_input = gr.Image(type="pil", label="Upload Clothing Image")
        with gr.Column(scale=1):
            # Middle column: chat transcript plus free-text query input.
            chatbot = gr.Chatbot(label="Chat History")
            chat_input = gr.Textbox(label="Chat Input")
            chat_button = gr.Button("Send")
        with gr.Column(scale=2):
            # Right column: 5x2 grid of retrieved similar items.
            gallery = gr.Gallery(
                label="Retrieved Clothes",
                show_label=True,
                elem_id="gallery",
                columns=[5],
                rows=[2],
                object_fit="contain",
                height="auto"
            )
            selected_image = gr.Textbox(label="Selected Image")
    # Cross-turn session state: chat history and the last item description.
    chat_state = gr.State([])
    last_description = gr.State("")
    # NOTE(review): `chat_input` appears twice in each outputs list, so the
    # fourth value returned by update_chat overwrites the third in the chat
    # textbox — confirm this wiring is intended.
    image_input.change(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
    chat_button.click(update_chat, inputs=[image_input, chat_input, chat_state, last_description],
                      outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
    gallery.select(on_select, None, selected_image)
print("Gradio interface set up successfully. Launching the app...")
demo.launch()
print("Fashion Assistant application is now running!")