|
@@ -3,6 +3,7 @@ import pandas as pd
|
|
|
import lancedb
|
|
|
from lancedb.pydantic import LanceModel, Vector
|
|
|
from lancedb.embeddings import get_registry
|
|
|
+from lancedb.rerankers import ColbertReranker
|
|
|
from pathlib import Path
|
|
|
from PIL import Image
|
|
|
import io
|
|
@@ -36,6 +37,9 @@ AVAILABLE_MODELS = {
|
|
|
"BAAI/bge-reranker-large": "bge-reranker-large"
|
|
|
}
|
|
|
|
|
|
+# Define retrieval methods
|
|
|
+RETRIEVAL_METHODS = ["Semantic Search", "Full Text Search", "Hybrid Search"]
|
|
|
+
|
|
|
# Connect to LanceDB
|
|
|
print("Connecting to LanceDB...")
|
|
|
db = lancedb.connect(args.table_path)
|
|
@@ -68,6 +72,7 @@ def create_table_for_model(model_name):
|
|
|
|
|
|
tables = {model: create_table_for_model(model) for model in AVAILABLE_MODELS}
|
|
|
current_table = tables[args.default_model]
|
|
|
+current_retrieval_method = "Semantic Search"
|
|
|
|
|
|
# Set up the Together API client
|
|
|
os.environ["TOGETHER_API_KEY"] = args.api_key
|
|
@@ -141,8 +146,14 @@ def process_chat_input(chat_history, user_input):
|
|
|
def retrieve_similar_items(description, n=10):
|
|
|
print(f"Retrieving similar items for: {description}")
|
|
|
try:
|
|
|
- results = current_table.search(description).limit(n).to_pandas()
|
|
|
- print(f"Retrieved {len(results)} similar items.")
|
|
|
+ if current_retrieval_method == "Semantic Search":
|
|
|
+ results = current_table.search(description).limit(n).to_pandas()
|
|
|
+ elif current_retrieval_method == "Full Text Search":
|
|
|
+ results = current_table.search(description, query_type="fts").limit(n).to_pandas()
|
|
|
+ elif current_retrieval_method == "Hybrid Search":
|
|
|
+ reranker = ColbertReranker(column="Description")
|
|
|
+ results = current_table.search(description, query_type="hybrid").rerank(reranker=reranker).limit(n).to_pandas()
|
|
|
+ print(f"Retrieved {len(results)} similar items using {current_retrieval_method}.")
|
|
|
return results
|
|
|
except Exception as e:
|
|
|
print(f"Error retrieving similar items: {e}")
|
|
@@ -211,6 +222,11 @@ def update_model(model_name):
|
|
|
current_table = tables[model_name]
|
|
|
return f"Switched to model: {model_name}"
|
|
|
|
|
|
+def update_retrieval_method(method):
|
|
|
+ global current_retrieval_method
|
|
|
+ current_retrieval_method = method
|
|
|
+ return f"Switched to retrieval method: {method}"
|
|
|
+
|
|
|
# Define the Gradio interface
|
|
|
print("Setting up Gradio interface...")
|
|
|
with gr.Blocks() as demo:
|
|
@@ -223,6 +239,11 @@ with gr.Blocks() as demo:
|
|
|
value=args.default_model,
|
|
|
label="Embedding Model"
|
|
|
)
|
|
|
+ retrieval_dropdown = gr.Dropdown(
|
|
|
+ choices=RETRIEVAL_METHODS,
|
|
|
+ value="Semantic Search",
|
|
|
+ label="Retrieval Method"
|
|
|
+ )
|
|
|
with gr.Column(scale=1):
|
|
|
chatbot = gr.Chatbot(label="Chat History")
|
|
|
chat_input = gr.Textbox(label="Chat Input")
|
|
@@ -248,6 +269,10 @@ with gr.Blocks() as demo:
|
|
|
outputs=[chat_state, chatbot, chat_input, chat_input, gallery, last_description])
|
|
|
gallery.select(on_select, None, selected_image)
|
|
|
model_dropdown.change(update_model, inputs=[model_dropdown], outputs=[gr.Textbox(label="Model Status")])
|
|
|
+ retrieval_dropdown.change(update_retrieval_method, inputs=[retrieval_dropdown], outputs=[gr.Textbox(label="Retrieval Method Status")])
|
|
|
+
|
|
|
+ # Disable embedding model dropdown when Hybrid Search is selected
|
|
|
+ retrieval_dropdown.change(lambda x: gr.update(interactive=x != "Hybrid Search"), inputs=[retrieval_dropdown], outputs=[model_dropdown])
|
|
|
|
|
|
print("Gradio interface set up successfully. Launching the app...")
|
|
|
demo.launch()
|