main.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import os
  2. from groq import Groq
  3. import duckdb
  4. import yaml
  5. import glob
  6. import json
  7. def get_verified_queries(directory_path):
  8. """
  9. Reads YAML files from the specified directory, loads the verified SQL queries and their descriptions,
  10. and stores them in a dictionary.
  11. Parameters:
  12. directory_path (str): The path to the directory containing the YAML files with verified queries.
  13. Returns:
  14. dict: A dictionary where the keys are the names of the YAML files (without the directory path and file extension)
  15. and the values are the parsed content of the YAML files.
  16. """
  17. verified_queries_yaml_files = glob.glob(os.path.join(directory_path, '*.yaml'))
  18. verified_queries_dict = {}
  19. for file in verified_queries_yaml_files:
  20. with open(file, 'r') as stream:
  21. try:
  22. file_name = file[len(directory_path):-5]
  23. verified_queries_dict[file_name] = yaml.safe_load(stream)
  24. except yaml.YAMLError as exc:
  25. continue
  26. return verified_queries_dict
  27. def execute_duckdb_query_function_calling(query_name,verified_queries_dict):
  28. """
  29. Executes a SQL query from the verified queries dictionary using DuckDB and returns the result as a DataFrame.
  30. Parameters:
  31. query_name (str): The name of the query to be executed, corresponding to a key in the verified queries dictionary.
  32. verified_queries_dict (dict): A dictionary containing verified queries, where the keys are query names and the values
  33. are dictionaries with query details including the SQL statement.
  34. Returns:
  35. pandas.DataFrame: The result of the executed query as a DataFrame.
  36. """
  37. original_cwd = os.getcwd()
  38. os.chdir('data')
  39. query = verified_queries_dict[query_name]['sql']
  40. try:
  41. conn = duckdb.connect(database=':memory:', read_only=False)
  42. query_result = conn.execute(query).fetchdf().reset_index(drop=True)
  43. finally:
  44. os.chdir(original_cwd)
  45. return query_result
  46. model = "llama3-8b-8192"
  47. # Initialize the Groq client
  48. groq_api_key = os.getenv('GROQ_API_KEY')
  49. client = Groq(
  50. api_key=groq_api_key
  51. )
  52. directory_path = 'verified-queries/'
  53. verified_queries_dict = get_verified_queries(directory_path)
  54. # Display the title and introduction of the application
  55. multiline_text = """
  56. Welcome! Ask questions about employee data or purchase details, like "Show the 5 most recent purchases" or "What was the most expensive purchase?". The app matches your question to pre-verified SQL queries for accurate results.
  57. """
  58. print(multiline_text)
  59. while True:
  60. # Get user input from the console
  61. user_input = input("You: ")
  62. #Simplify verified_queries_dict to just show query name and description
  63. query_description_mapping = {key: subdict['description'] for key, subdict in verified_queries_dict.items()}
  64. # Step 1: send the conversation and available functions to the model
  65. # Define the messages to be sent to the Groq API
  66. messages = [
  67. {
  68. "role": "system",
  69. "content": '''You are a function calling LLM that uses the data extracted from the execute_duckdb_query_function_calling function to answer questions around a DuckDB dataset.
  70. Extract the query_name parameter from this mapping by finding the one whose description best matches the user's question:
  71. {query_description_mapping}
  72. '''.format(query_description_mapping=query_description_mapping)
  73. },
  74. {
  75. "role": "user",
  76. "content": user_input,
  77. }
  78. ]
  79. # Define the tool (function) to be used by the Groq API
  80. tools = [
  81. {
  82. "type": "function",
  83. "function": {
  84. "name": "execute_duckdb_query_function_calling",
  85. "description": "Executes a verified DuckDB SQL Query",
  86. "parameters": {
  87. "type": "object",
  88. "properties": {
  89. "query_name": {
  90. "type": "string",
  91. "description": "The name of the verified query (i.e. 'most-recent-purchases')",
  92. }
  93. },
  94. "required": ["query_name"],
  95. },
  96. },
  97. }
  98. ]
  99. # Send the conversation and available functions to the Groq API
  100. response = client.chat.completions.create(
  101. model=model,
  102. messages=messages,
  103. tools=tools,
  104. tool_choice="auto",
  105. max_tokens=4096
  106. )
  107. # Extract the response message and any tool calls from the response
  108. response_message = response.choices[0].message
  109. tool_calls = response_message.tool_calls
  110. # Define a dictionary of available functions
  111. available_functions = {
  112. "execute_duckdb_query_function_calling": execute_duckdb_query_function_calling,
  113. }
  114. # Iterate over the tool calls in the response
  115. for tool_call in tool_calls:
  116. function_name = tool_call.function.name # Get the function name
  117. function_to_call = available_functions[function_name] # Get the function to call
  118. function_args = json.loads(tool_call.function.arguments) # Parse the function arguments
  119. print('Query found: ', function_args.get("query_name"))
  120. # Call the function with the provided arguments
  121. function_response = function_to_call(
  122. query_name=function_args.get("query_name"),
  123. verified_queries_dict=verified_queries_dict
  124. )
  125. # Print the function response (query result)
  126. print(function_response)