{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e8cba0b6",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/use_cases/text2sql/StructuredLlama.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>  \n",
    "\n",
    "## Use Llama 3 to chat about structured data\n",
    "This demo shows how to use LangChain with Llama 3 to query structured data, the 2023-24 NBA roster info, stored in a SQLite DB, and how to ask Llama 3 follow up question about the DB."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f839d07d",
   "metadata": {},
   "source": [
    "We start by installing the necessary packages:\n",
    "- [Replicate](https://replicate.com/) to host the Llama 3 model\n",
    "- [langchain](https://python.langchain.com/docs/get_started/introduction) provides necessary RAG tools for this demo\n",
    "\n",
    "**Note** We will be using [Replicate](https://replicate.com/meta/meta-llama-3-8b-instruct) to run the examples here. You will need to first sign in with Replicate with your github account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. You can also use other Llama 3 cloud providers such as [Groq](https://console.groq.com/), [Together](https://api.together.xyz/playground/language/meta-llama/Llama-3-8b-hf), or [Anyscale](https://app.endpoints.anyscale.com/playground) - see Section 2 of the Getting to Know Llama [notebook](https://github.com/meta-llama/llama-recipes/blob/main/recipes/quickstart/Getting_to_know_Llama.ipynb) for more information.\n",
    "\n",
    "If you'd like to run Llama 3 locally for the benefits of privacy, no cost or no rate limit (some Llama 3 hosting providers set limits for free plan of queries or tokens per second or minute), see [Running Llama Locally](https://github.com/meta-llama/llama-recipes/blob/main/recipes/quickstart/Running_Llama2_Anywhere/Running_Llama_on_Mac_Windows_Linux.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install langchain replicate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa4562d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from getpass import getpass\n",
    "import os\n",
    "\n",
    "REPLICATE_API_TOKEN = getpass()\n",
    "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e586b75",
   "metadata": {},
   "source": [
    "Next we call the Llama 3 8b chat model from Replicate. You can also use Llama 3 70b model by replacing the `model` name with \"meta/meta-llama-3-70b-instruct\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dcd744c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.llms import Replicate\n",
    "llm = Replicate(\n",
    "    model=\"meta/meta-llama-3-8b-instruct\",\n",
    "    model_kwargs={\"temperature\": 0.0, \"top_p\": 1, \"max_new_tokens\":500}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d421ae7",
   "metadata": {},
   "source": [
    "To recreate the `nba_roster.db` file, run the two commands below:\n",
    "- `python txt2csv.py` to convert the `nba.txt` file to `nba_roster.csv`. The `nba.txt` file was created by scraping the NBA roster info from the web.\n",
    "- `python csv2db.py` to convert `nba_roster.csv` to `nba_roster.db`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb99f39-cd7a-4db6-91dd-02f3bf80347c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "# Note: to run in Colab, you need to upload the nba_roster.db file in the repo to the Colab folder first.\n",
    "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n",
    "\n",
    "def get_schema():\n",
    "    return db.get_table_info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d793ce7-324b-4861-926c-54973d7c9b43",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"What team is Klay Thompson on?\"\n",
    "prompt = f\"\"\"Based on the table schema below, write a SQL query that would answer the user's question; just return the SQL query and nothing else.\n",
    "\n",
    "Scheme:\n",
    "{get_schema()}\n",
    "\n",
    "Question: {question}\n",
    "\n",
    "SQL Query:\"\"\"\n",
    "\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70776558",
   "metadata": {},
   "outputs": [],
   "source": [
    "answer = llm.invoke(prompt)\n",
    "print(answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afcf423a",
   "metadata": {},
   "source": [
    "If you don't have the \"just return the SQL query and nothing else\" in the prompt above, or even with it but asking Llama 2 which doesn't follow instructions as well as Llama 3, you'll likely get more text other than the SQL query back in the answer."
   ]
  },
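  {
   "cell_type": "markdown",
   "id": "5b1e7c2a",
   "metadata": {},
   "source": [
    "When the model does return extra text around the query, the SQL has to be pulled out before it can be run. The cell below is a minimal sketch of one way to do that, not a complete parser: `extract_sql` is a hypothetical helper (it is not part of LangChain), it assumes the reply contains a single `SELECT` statement, optionally inside a Markdown code fence, and the sample reply with its table and column names is made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c94d06f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Three backticks, built programmatically so the marker can be used inside the pattern.\n",
    "FENCE = \"`\" * 3\n",
    "\n",
    "def extract_sql(text):\n",
    "    # Hypothetical helper: return the first SELECT statement found in a model reply, or None.\n",
    "    # Prefer the contents of a Markdown code fence if the reply has one.\n",
    "    fenced = re.search(FENCE + r\"(?:sql)?\\s*(.*?)\" + FENCE, text, re.DOTALL | re.IGNORECASE)\n",
    "    if fenced:\n",
    "        text = fenced.group(1)\n",
    "    # Keep everything from the first SELECT up to the first semicolon (or the end of the text).\n",
    "    match = re.search(r\"\\bSELECT\\b.*?(?:;|$)\", text, re.DOTALL | re.IGNORECASE)\n",
    "    return match.group(0).strip() if match else None\n",
    "\n",
    "# Made-up reply for illustration; the table and column names are assumptions.\n",
    "chatty_reply = \"Sure! Here is the query: SELECT team FROM nba_roster WHERE name = 'Klay Thompson'; Let me know if you need anything else.\"\n",
    "print(extract_sql(chatty_reply))"
   ]
  },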
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62472ce6-794b-4a61-b88c-a1e031e28e4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# note this is a dangerous operation and for demo purpose only; in production app you'll need to safe-guard any DB operation\n",
    "result = db.run(answer)"
   ]
  },
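  {
   "cell_type": "markdown",
   "id": "d3a86f14",
   "metadata": {},
   "source": [
    "One simple safeguard for the step above is to check that the generated statement is a single `SELECT` before running it. The cell below is a sketch under that assumption: `is_read_only_query` is a hypothetical helper, and a string check like this is deliberately conservative (it also rejects a query whose string literal contains a semicolon, or one that starts with `WITH`). It is not a complete defense; a database connection opened with read-only permissions gives stronger protection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e20b9a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_read_only_query(sql):\n",
    "    # Hypothetical guard: accept only a single statement that starts with SELECT.\n",
    "    statement = sql.strip().rstrip(\";\").strip()\n",
    "    # A semicolon left inside the text means several statements were chained together.\n",
    "    if \";\" in statement:\n",
    "        return False\n",
    "    return statement.upper().startswith(\"SELECT\")\n",
    "\n",
    "# `answer` and `db` come from the cells above.\n",
    "if is_read_only_query(answer):\n",
    "    result = db.run(answer)\n",
    "    print(result)\n",
    "else:\n",
    "    print(f\"Refusing to run a statement that is not a single SELECT: {answer!r}\")"
   ]
  },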
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ed4bc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# how about a follow up question\n",
    "follow_up = \"What's his salary?\"\n",
    "print(llm.invoke(follow_up))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98b2c523",
   "metadata": {},
   "source": [
    "Since we did not pass any context along with the follow-up to the model it did not know who \"his\" is and just picked LeBron James.\n",
    "\n",
    "Let's try to fix it by adding context to the follow-up prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c305278-29d2-4e88-9b3d-ad67c94ce0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = f\"\"\"Based on the table schema, question, SQL query, and SQL response below, write a new SQL response; be concise, just output the SQL response.\n",
    "\n",
    "Scheme:\n",
    "{get_schema()}\n",
    "\n",
    "Question: {follow_up}\n",
    "SQL Query: {question}\n",
    "SQL Result: {result}\n",
    "\n",
    "New SQL Response:\n",
    "\"\"\"\n",
    "print(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03739b96-e607-4fa9-bc5c-df118198dc7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_answer = llm.invoke(prompt)\n",
    "print(new_answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c782abb6-3b44-45be-8694-70fc29b82523",
   "metadata": {},
   "source": [
    "Because we have \"be concise, just output the SQL response\", Llama 3 is able to just generate the SQL statement; otherwise output parsing will be needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ecfca53-be7e-4668-bad1-5ca7571817d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "db.run(new_answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}