Sanyam Bhutani 6 hónapja
szülő
commit
f1c1f6d824
1 módosított fájl, 27 hozzáadás és 2 törlés
  1. 27 2
      recipes/quickstart/Multi-Modal-RAG/final_demo.py

+ 27 - 2
recipes/quickstart/Multi-Modal-RAG/final_demo.py

@@ -3,6 +3,7 @@ import pandas as pd
 import lancedb
 from lancedb.pydantic import LanceModel, Vector
 from lancedb.embeddings import get_registry
+from lancedb.rerankers import ColbertReranker
 from pathlib import Path
 from PIL import Image
 import io
@@ -36,6 +37,9 @@ AVAILABLE_MODELS = {
     "BAAI/bge-reranker-large": "bge-reranker-large"
 }
 
+# Define retrieval methods
+RETRIEVAL_METHODS = ["Semantic Search", "Full Text Search", "Hybrid Search"]
+
 # Connect to LanceDB
 print("Connecting to LanceDB...")
 db = lancedb.connect(args.table_path)
@@ -68,6 +72,7 @@ def create_table_for_model(model_name):
 
 tables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}
 current_table = tables[args.default_model]
+current_retrieval_method = "Semantic Search"
 
 # Set up the Together API client
 os.environ["TOGETHER_API_KEY"] = args.api_key
@@ -141,8 +146,14 @@ def process_chat_input(chat_history, user_input):
 def retrieve_similar_items(description, n=10):
     print(f"Retrieving similar items for: {description}")
     try:
-        results = current_table.search(description).limit(n).to_pandas()
-        print(f"Retrieved {len(results)} similar items.")
+        if current_retrieval_method == "Semantic Search":
+            results = current_table.search(description).limit(n).to_pandas()
+        elif current_retrieval_method == "Full Text Search":
+            results = current_table.search(description, query_type="fts").limit(n).to_pandas()
+        elif current_retrieval_method == "Hybrid Search":
+            reranker = ColbertReranker(column="Description")
+            results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()
+        print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")
         return results
     except Exception as e:
         print(f"Error retrieving similar items: {e}")
@@ -211,6 +222,11 @@ def update_model(model_name):
     current_table = tables[model_name]
     return f"Switched to model: {model_name}"
 
+def update_retrieval_method(method):
+    global current_retrieval_method
+    current_retrieval_method = method
+    return f"Switched to retrieval method: {method}"
+
 # Define the Gradio interface
 print("Setting up Gradio interface...")
 with gr.Blocks() as demo:
@@ -223,6 +239,11 @@ with gr.Blocks() as demo:
                 value=args.default_model,
                 label="Embedding Model"
             )
+            retrieval_dropdown = gr.Dropdown(
+                choices=RETRIEVAL_METHODS,
+                value="Semantic Search",
+                label="Retrieval Method"
+            )
         with gr.Column(scale=1):
             chatbot = gr.Chatbot(label="Chat History")
             chat_input = gr.Textbox(label="Chat Input")
@@ -248,6 +269,10 @@ with gr.Blocks() as demo:
                       outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
     gallery.select(on_select, None, selected_image)
     model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])
+    retrieval_dropdown.change(update_retrieval_method, inputs=[retrieval_dropdown], outputs=[gr.Textbox(label="Retrieval Method Status")])
+
+    # Disable embedding model dropdown when Hybrid Search is selected
+    retrieval_dropdown.change(lambda x: gr.update(interactive=x != "Hybrid Search"), inputs=[retrieval_dropdown], outputs=[model_dropdown])
 
 print("Gradio interface set up successfully. Launching the app...")
 demo.launch()