瀏覽代碼

Fix Methods and fix prompts

Sanyam Bhutani 6 月之前
父節點
當前提交
089e0b16c2
共有 1 個文件被更改,包括 9 次插入3 次删除
  1. 9 3
      recipes/quickstart/Multi-Modal-RAG/final_demo.py

+ 9 - 3
recipes/quickstart/Multi-Modal-RAG/final_demo.py

@@ -64,9 +64,11 @@ def create_table_for_model(model_name):
         df = pd.read_csv(args.csv_path)
         df = df.dropna().astype(str)
         tbl.add(df.to_dict('records'))
+        tbl.create_fts_index(["Description"], replace=True)
         print(f"Created and populated table {table_name}")
     else:
         tbl = db.open_table(table_name)
+        tbl.create_fts_index(["Description"], replace=True)
         print(f"Opened existing table {table_name}")
     return tbl
 
@@ -126,7 +128,7 @@ def process_chat_input(chat_history, user_input):
         messages.append({"role": "user", "content": user_msg})
         messages.append({"role": "assistant", "content": assistant_msg})
     
-    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY."
+    user_input += ". START YOUR MESSAGE DIRECTLY WITH A RESPONSE LIST. DO NOT REPEAT THE NAME OF THE ITEM MENTIONED IN THE QUERY. Start your message with '1. ..' "
     messages.append({"role": "user", "content": user_input})
     print(f"Chat history: {messages}")
     try:
@@ -151,8 +153,12 @@ def retrieve_similar_items(description, n=10):
         elif current_retrieval_method == "Full Text Search":
             results = current_table.search(description, query_type="fts").limit(n).to_pandas()
         elif current_retrieval_method == "Hybrid Search":
-            reranker = ColbertReranker(column="Description")
+            reranker = ColbertReranker(
+                model_name="answerdotai/answerai-colbert-small-v1",
+                column="Description")
             results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()
+        else:
+            raise ValueError("Invalid retrieval method")
         print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")
         return results
     except Exception as e:
@@ -276,4 +282,4 @@ with gr.Blocks() as demo:
 
 print("Gradio interface set up successfully. Launching the app...")
 demo.launch()
-print("Fashion Assistant application is now running!")
+print("Fashion Assistant application is now running!")