main.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import os
  2. from groq import Groq
  3. import json
  4. import duckdb
  5. import sqlparse
  6. def chat_with_groq(client, prompt, model, response_format):
  7. """
  8. This function sends a prompt to the Groq API and retrieves the AI's response.
  9. Parameters:
  10. client (Groq): The Groq API client.
  11. prompt (str): The prompt to send to the AI.
  12. model (str): The AI model to use for the response.
  13. response_format (dict): The format of the response.
  14. If response_format is a dictionary with {"type": "json_object"}, it configures JSON mode.
  15. Returns:
  16. str: The content of the AI's response.
  17. """
  18. completion = client.chat.completions.create(
  19. model=model,
  20. messages=[
  21. {
  22. "role": "user",
  23. "content": prompt
  24. }
  25. ],
  26. response_format=response_format
  27. )
  28. return completion.choices[0].message.content
  29. def execute_duckdb_query(query):
  30. """
  31. This function executes a SQL query on a DuckDB database and returns the result.
  32. Parameters:
  33. query (str): The SQL query to execute.
  34. Returns:
  35. DataFrame: The result of the query as a pandas DataFrame.
  36. """
  37. original_cwd = os.getcwd()
  38. os.chdir('data')
  39. try:
  40. conn = duckdb.connect(database=':memory:', read_only=False)
  41. query_result = conn.execute(query).fetchdf().reset_index(drop=True)
  42. finally:
  43. os.chdir(original_cwd)
  44. return query_result
  45. def get_summarization(client, user_question, df, model):
  46. """
  47. This function generates a summarization prompt based on the user's question and the resulting data.
  48. It then sends this summarization prompt to the Groq API and retrieves the AI's response.
  49. Parameters:
  50. client (Groqcloud): The Groq API client.
  51. user_question (str): The user's question.
  52. df (DataFrame): The DataFrame resulting from the SQL query.
  53. model (str): The AI model to use for the response.
  54. Returns:
  55. str: The content of the AI's response to the summarization prompt.
  56. """
  57. prompt = '''
  58. A user asked the following question pertaining to local database tables:
  59. {user_question}
  60. To answer the question, a dataframe was returned:
  61. Dataframe:
  62. {df}
  63. In a few sentences, summarize the data in the table as it pertains to the original user question. Avoid qualifiers like "based on the data" and do not comment on the structure or metadata of the table itself
  64. '''.format(user_question = user_question, df = df)
  65. # Response format is set to 'None'
  66. return chat_with_groq(client,prompt,model,None)
  67. def main():
  68. """
  69. The main function of the application. It handles user input, controls the flow of the application,
  70. and initiates a conversation in the command line.
  71. """
  72. model = "llama3-70b-8192"
  73. # Get the Groq API key and create a Groq client
  74. groq_api_key = os.getenv('GROQ_API_KEY')
  75. client = Groq(
  76. api_key=groq_api_key
  77. )
  78. print("Welcome to the DuckDB Query Generator!")
  79. print("You can ask questions about the data in the 'employees.csv' and 'purchases.csv' files.")
  80. # Load the base prompt
  81. with open('prompts/base_prompt.txt', 'r') as file:
  82. base_prompt = file.read()
  83. while True:
  84. # Get the user's question
  85. user_question = input("Ask a question: ")
  86. if user_question:
  87. # Generate the full prompt for the AI
  88. full_prompt = base_prompt.format(user_question=user_question)
  89. # Get the AI's response. Call with '{"type": "json_object"}' to use JSON mode
  90. llm_response = chat_with_groq(client, full_prompt, model, {"type": "json_object"})
  91. result_json = json.loads(llm_response)
  92. if 'sql' in result_json:
  93. sql_query = result_json['sql']
  94. results_df = execute_duckdb_query(sql_query)
  95. formatted_sql_query = sqlparse.format(sql_query, reindent=True, keyword_case='upper')
  96. print("```sql\n" + formatted_sql_query + "\n```")
  97. print(results_df.to_markdown(index=False))
  98. summarization = get_summarization(client,user_question,results_df,model)
  99. print(summarization.replace('$','\\$'))
  100. elif 'error' in result_json:
  101. print("ERROR:", 'Could not generate valid SQL for this question')
  102. print(result_json['error'])
# Start the interactive command-line session only when this file is run as a script.
if __name__ == "__main__":
    main()