|
@@ -0,0 +1,784 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "de56ee05-3b71-43c9-8cbf-6ad9b3233f38",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/dlai/AI_Agents_in_LangGraph_L1_Build_an_Agent_from_Scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "16ba1896-6867-4c68-9951-b0aadb819782",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "This notebook ports the DeepLearning.AI short course [AI Agents in LangGraph Lesson 1 Build an Agent from Scratch](https://learn.deeplearning.ai/courses/ai-agents-in-langgraph/lesson/2/build-an-agent-from-scratch) to using Llama 3, with a bonus section that ports the agent from scratch code to using LangGraph, introduced in [Lession 2 LangGraph Components](https://learn.deeplearning.ai/courses/ai-agents-in-langgraph/lesson/3/langgraph-components) of the course. \n",
|
|
|
+ "\n",
|
|
|
+ "You should take the course, especially the first two lessons, before or after going through this notebook, to have a deeper understanding."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "9b168b57-6ff8-41d1-8f8f-a0c4a5ff108e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "!pip install groq"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "2c067d5f-c58c-47c0-8ccd-9a8710711bf7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import os \n",
|
|
|
+ "from groq import Groq\n",
|
|
|
+ "\n",
|
|
|
+ "os.environ['GROQ_API_KEY'] = 'your_groq_api_key' # get a free key at https://console.groq.com/keys"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "5f7d8d95-36fb-4b14-bd28-99d305c0fded",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# a quick sanity test of calling Llama 3 70b on Groq \n",
|
|
|
+ "# see https://console.groq.com/docs/text-chat for more info\n",
|
|
|
+ "client = Groq()\n",
|
|
|
+ "chat_completion = client.chat.completions.create(\n",
|
|
|
+ " model=\"llama3-70b-8192\",\n",
|
|
|
+ " messages=[{\"role\": \"user\", \"content\": \"what are the words Charlotte wrote for the pig?\"}]\n",
|
|
|
+ ")\n",
|
|
|
+ "print(chat_completion.choices[0].message.content)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "f758c771-5afe-4008-9d7f-92a6f526778b",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### ReAct Agent from Sractch"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "c00c0479-0913-4a92-8991-fe5a9a936bdb",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "client = Groq()\n",
|
|
|
+ "model = \"llama3-8b-8192\" # this model works with the prompt below only for the first simpler example; you'll see how to modify the prompt to make it work for a more complicated question\n",
|
|
|
+ "#model = \"llama3-70b-8192\" # this model works with the prompt below for both example questions \n",
|
|
|
+ "\n",
|
|
|
+ "class Agent:\n",
|
|
|
+ " def __init__(self, system=\"\"):\n",
|
|
|
+ " self.system = system\n",
|
|
|
+ " self.messages = []\n",
|
|
|
+ " if self.system:\n",
|
|
|
+ " self.messages.append({\"role\": \"system\", \"content\": system})\n",
|
|
|
+ "\n",
|
|
|
+ " def __call__(self, message):\n",
|
|
|
+ " self.messages.append({\"role\": \"user\", \"content\": message})\n",
|
|
|
+ " result = self.execute()\n",
|
|
|
+ " self.messages.append({\"role\": \"assistant\", \"content\": result})\n",
|
|
|
+ " return result\n",
|
|
|
+ "\n",
|
|
|
+ " def execute(self):\n",
|
|
|
+ " completion = client.chat.completions.create(\n",
|
|
|
+ " model=model,\n",
|
|
|
+ " temperature=0,\n",
|
|
|
+ " messages=self.messages)\n",
|
|
|
+ " return completion.choices[0].message.content\n",
|
|
|
+ " "
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "f766fb44-e8c2-43db-af83-8b9053a334ef",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "prompt = \"\"\"\n",
|
|
|
+ "You run in a loop of Thought, Action, PAUSE, Observation.\n",
|
|
|
+ "At the end of the loop you output an Answer\n",
|
|
|
+ "Use Thought to describe your thoughts about the question you have been asked.\n",
|
|
|
+ "Use Action to run one of the actions available to you - then return PAUSE.\n",
|
|
|
+ "Observation will be the result of running those actions.\n",
|
|
|
+ "\n",
|
|
|
+ "Your available actions are:\n",
|
|
|
+ "\n",
|
|
|
+ "calculate:\n",
|
|
|
+ "e.g. calculate: 4 * 7 / 3\n",
|
|
|
+ "Runs a calculation and returns the number - uses Python so be sure to use floating point syntax if necessary\n",
|
|
|
+ "\n",
|
|
|
+ "average_dog_weight:\n",
|
|
|
+ "e.g. average_dog_weight: Collie\n",
|
|
|
+ "returns average weight of a dog when given the breed\n",
|
|
|
+ "\n",
|
|
|
+ "Example session:\n",
|
|
|
+ "\n",
|
|
|
+ "Question: How much does a Bulldog weigh?\n",
|
|
|
+ "Thought: I should look the dogs weight using average_dog_weight\n",
|
|
|
+ "Action: average_dog_weight: Bulldog\n",
|
|
|
+ "PAUSE\n",
|
|
|
+ "\n",
|
|
|
+ "You will be called again with this:\n",
|
|
|
+ "\n",
|
|
|
+ "Observation: A Bulldog weights 51 lbs\n",
|
|
|
+ "\n",
|
|
|
+ "You then output:\n",
|
|
|
+ "\n",
|
|
|
+ "Answer: A bulldog weights 51 lbs\n",
|
|
|
+ "\"\"\".strip()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "93ab1290-625b-4b69-be4d-210fca43a513",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def calculate(what):\n",
|
|
|
+ " return eval(what)\n",
|
|
|
+ "\n",
|
|
|
+ "def average_dog_weight(name):\n",
|
|
|
+ " if name in \"Scottish Terrier\": \n",
|
|
|
+ " return(\"Scottish Terriers average 20 lbs\")\n",
|
|
|
+ " elif name in \"Border Collie\":\n",
|
|
|
+ " return(\"a Border Collies average weight is 37 lbs\")\n",
|
|
|
+ " elif name in \"Toy Poodle\":\n",
|
|
|
+ " return(\"a toy poodles average weight is 7 lbs\")\n",
|
|
|
+ " else:\n",
|
|
|
+ " return(\"An average dog weights 50 lbs\")\n",
|
|
|
+ "\n",
|
|
|
+ "known_actions = {\n",
|
|
|
+ " \"calculate\": calculate,\n",
|
|
|
+ " \"average_dog_weight\": average_dog_weight\n",
|
|
|
+ "}"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "52f900d9-15f0-4f48-9bf3-6165c70e4b42",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot = Agent(prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "f1b612c9-2a7d-4325-b36f-182899252538",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "result = abot(\"How much does a toy poodle weigh?\")\n",
|
|
|
+ "print(result)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "e27dda33-c76d-4a19-8aef-02ba5389e7a6",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot.messages"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "b6e85ca1-85af-43e3-a5ea-c5faf0935361",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# manually call the exeternal func (tool) for now\n",
|
|
|
+ "result = average_dog_weight(\"Toy Poodle\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "a9263ac7-fa81-4c95-91c8-a6c0741ab7f8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "result"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "cb309710-0693-422f-a739-38ca9455e497",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "next_prompt = \"Observation: {}\".format(result)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "bd567e42-b5a9-4e4e-8807-38bb1d6c80a4",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "255bf148-bf85-40c5-b33e-d849a42c127b",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot.messages"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "c286a6d5-b5b3-473b-bad6-aa6f1468e603",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot = Agent(prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "f5e13b6e-e68e-45c2-b688-a257b531e482",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "question = \"\"\"I have 2 dogs, a border collie and a scottish terrier. \\\n",
|
|
|
+ "What is their combined weight\"\"\"\n",
|
|
|
+ "abot(question)\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "049202f1-585f-42c3-8511-08eca7e5ed0e",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot.messages"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "f086f19a-30fe-40ca-aafb-f1ce7c28982d",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "next_prompt = \"Observation: {}\".format(average_dog_weight(\"Border Collie\"))\n",
|
|
|
+ "print(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "1747c78d-642d-4f57-81a0-27218eab3958",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "85809d8f-cd95-4e0a-acb7-9705817bea70",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot.messages"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "e77591fa-4e04-4eb6-8a40-ca26a71765f9",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "next_prompt = \"Observation: {}\".format(average_dog_weight(\"Scottish Terrier\"))\n",
|
|
|
+ "print(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "1f72b638-de07-4972-bbdb-8c8602f3d143",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "eb5bf29d-22f9-4c0d-aea6-7e9c99e71835",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot.messages"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "a67add73-b3c3-42be-9c54-f8a6ac828869",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "next_prompt = \"Observation: {}\".format(eval(\"37 + 20\"))\n",
|
|
|
+ "print(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "801fda04-9756-4ae4-9990-559216d38be8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot(next_prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "56f7b9f4-289f-498d-8bc8-da9bb7365d52",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot.messages"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "155ee9b3-a4f9-43dd-b23e-0f268f72f198",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Automate the ReAct action execution"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "5b2196f8-88e6-4eb4-82b0-cf251a07e313",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import re\n",
|
|
|
+ "\n",
|
|
|
+ "# automate the action execution above to make the whole ReAct (Thought - Action- Observsation) process fully automated\n",
|
|
|
+ "action_re = re.compile('^Action: (\\w+): (.*)$') # python regular expression to selection action"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "ea5710d6-5d9a-46ff-a275-46311257d9fd",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def query(question, max_turns=5):\n",
|
|
|
+ " i = 0\n",
|
|
|
+ " bot = Agent(prompt) # set system prompt\n",
|
|
|
+ " next_prompt = question\n",
|
|
|
+ " while i < max_turns:\n",
|
|
|
+ " i += 1\n",
|
|
|
+ " result = bot(next_prompt)\n",
|
|
|
+ " print(result)\n",
|
|
|
+ " actions = [\n",
|
|
|
+ " action_re.match(a)\n",
|
|
|
+ " for a in result.split('\\n')\n",
|
|
|
+ " if action_re.match(a)\n",
|
|
|
+ " ]\n",
|
|
|
+ " if actions:\n",
|
|
|
+ " # There is an action to run\n",
|
|
|
+ " action, action_input = actions[0].groups()\n",
|
|
|
+ " if action not in known_actions:\n",
|
|
|
+ " raise Exception(\"Unknown action: {}: {}\".format(action, action_input))\n",
|
|
|
+ " print(\" -- running {} {}\".format(action, action_input))\n",
|
|
|
+ "\n",
|
|
|
+ " # key to make the agent process fully automated:\n",
|
|
|
+ " # programtically call the external func with arguments, with the info returned by LLM\n",
|
|
|
+ " observation = known_actions[action](action_input) \n",
|
|
|
+ "\n",
|
|
|
+ " print(\"Observation:\", observation)\n",
|
|
|
+ " next_prompt = \"Observation: {}\".format(observation)\n",
|
|
|
+ " else:\n",
|
|
|
+ " return"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "232d0818-7580-424b-9538-1e2b1c15360b",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "#### Using model \"llama3-8b-8192\", the code below will cause an invalid syntax error because the Action returned is calculate: (average_dog_weight: Border Collie) + (average_dog_weight: Scottish Terrier), instead of the expected \"Action: average_dog_weight: Border Collie\"."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "bb0095f3-b3f1-48cf-b3fb-36049b6b91f0",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "question = \"\"\"I have 2 dogs, a border collie and a scottish terrier. \\\n",
|
|
|
+ "What is their combined weight\"\"\"\n",
|
|
|
+ "query(question)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "952ffac8-5ec2-48f3-8049-d03c130dad0d",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "#### Prompt engineering in action:\n",
|
|
|
+ "REPLACE \"Use Thought to describe your thoughts about the question you have been asked. Use Action to run one of the actions available to you - then return PAUSE.\" with \n",
|
|
|
+ "\"First, use Thought to describe your thoughts about the question you have been asked, and generate Action to run one of the actions available to you, then return PAUSE.\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "ec791ad6-b39a-4f46-b149-704c23d6c506",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "prompt = \"\"\"\n",
|
|
|
+ "You run in a loop of Thought, Action, PAUSE, Observation.\n",
|
|
|
+ "At the end of the loop you output an Answer.\n",
|
|
|
+ "First, use Thought to describe your thoughts about the question you have been asked, and generate Action to run one of the actions available to you, then return PAUSE.\n",
|
|
|
+ "Observation will be the result of running those actions.\n",
|
|
|
+ "\n",
|
|
|
+ "Your available actions are:\n",
|
|
|
+ "\n",
|
|
|
+ "calculate:\n",
|
|
|
+ "e.g. calculate: 4 * 7 / 3\n",
|
|
|
+ "Runs a calculation and returns the number - uses Python so be sure to use floating point syntax if necessary\n",
|
|
|
+ "\n",
|
|
|
+ "average_dog_weight:\n",
|
|
|
+ "e.g. average_dog_weight: Collie\n",
|
|
|
+ "returns average weight of a dog when given the breed\n",
|
|
|
+ "\n",
|
|
|
+ "Example session:\n",
|
|
|
+ "\n",
|
|
|
+ "Question: How much does a Bulldog weigh?\n",
|
|
|
+ "Thought: I should look the dogs weight using average_dog_weight\n",
|
|
|
+ "Action: average_dog_weight: Bulldog\n",
|
|
|
+ "PAUSE\n",
|
|
|
+ "\n",
|
|
|
+ "You will be called again with this:\n",
|
|
|
+ "\n",
|
|
|
+ "Observation: A Bulldog weights 51 lbs\n",
|
|
|
+ "\n",
|
|
|
+ "You then output:\n",
|
|
|
+ "\n",
|
|
|
+ "Answer: A bulldog weights 51 lbs\n",
|
|
|
+ "\"\"\".strip()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "90bcf731-4d89-473b-98e1-53826da149f9",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "question = \"\"\"I have 2 dogs, a border collie and a scottish terrier. \\\n",
|
|
|
+ "What is their combined weight\"\"\"\n",
|
|
|
+ "query(question)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "09d30a8f-3783-4ee5-a48e-7d89e22a508a",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Bonus: Port the Agent Implementation to LangGraph"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "6b5ed82e-2d70-45ac-b026-904da211f81a",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "!pip install langchain\n",
|
|
|
+ "!pip install langgraph\n",
|
|
|
+ "!pip install langchain_openai\n",
|
|
|
+ "!pip install langchain_community\n",
|
|
|
+ "!pip install httpx\n",
|
|
|
+ "!pip install langchain-groq"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "a8ed3b90-688e-4aa2-8e43-e951af29a57f",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "from langgraph.graph import StateGraph, END\n",
|
|
|
+ "from typing import TypedDict, Annotated\n",
|
|
|
+ "import operator\n",
|
|
|
+ "from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage\n",
|
|
|
+ "from langchain_openai import ChatOpenAI\n",
|
|
|
+ "from langchain_community.tools.tavily_search import TavilySearchResults"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "c555f945-7db0-4dc9-9ea5-5632bf941fe4",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "from langchain_groq import ChatGroq\n",
|
|
|
+ "\n",
|
|
|
+ "model = ChatGroq(temperature=0, model_name=\"llama3-8b-8192\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "7755a055-fa1f-474f-b558-230cc5a67a33",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "from langchain_core.tools import tool\n",
|
|
|
+ "from langgraph.prebuilt import ToolNode\n",
|
|
|
+ "\n",
|
|
|
+ "@tool\n",
|
|
|
+ "def calculate(what):\n",
|
|
|
+ " \"\"\"Runs a calculation and returns the number.\"\"\"\n",
|
|
|
+ " return eval(what)\n",
|
|
|
+ "\n",
|
|
|
+ "@tool\n",
|
|
|
+ "def average_dog_weight(name):\n",
|
|
|
+ " \"\"\"Returns the average weight of a dog.\"\"\"\n",
|
|
|
+ " if name in \"Scottish Terrier\":\n",
|
|
|
+ " return(\"Scottish Terriers average 20 lbs\")\n",
|
|
|
+ " elif name in \"Border Collie\":\n",
|
|
|
+ " return(\"a Border Collies average weight is 37 lbs\")\n",
|
|
|
+ " elif name in \"Toy Poodle\":\n",
|
|
|
+ " return(\"a toy poodles average weight is 7 lbs\")\n",
|
|
|
+ " else:\n",
|
|
|
+ " return(\"An average dog weights 50 lbs\")\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "4a003862-8fd2-45b1-8fe4-78d7cd5888d9",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "prompt = \"\"\"\n",
|
|
|
+ "You run in a loop of Thought, Action, Observation.\n",
|
|
|
+ "At the end of the loop you output an Answer\n",
|
|
|
+ "Use Thought to describe your thoughts about the question you have been asked.\n",
|
|
|
+ "Use Action to run one of the actions available to you.\n",
|
|
|
+ "Observation will be the result of running those actions.\n",
|
|
|
+ "\n",
|
|
|
+ "Your available actions are:\n",
|
|
|
+ "\n",
|
|
|
+ "calculate:\n",
|
|
|
+ "e.g. calculate: 4 * 7 / 3\n",
|
|
|
+ "Runs a calculation and returns the number - uses Python so be sure to use floating point syntax if necessary\n",
|
|
|
+ "\n",
|
|
|
+ "average_dog_weight:\n",
|
|
|
+ "e.g. average_dog_weight: Collie\n",
|
|
|
+ "returns average weight of a dog when given the breed\n",
|
|
|
+ "\n",
|
|
|
+ "Example session:\n",
|
|
|
+ "\n",
|
|
|
+ "Question: How much does a Bulldog weigh?\n",
|
|
|
+ "Thought: I should look the dogs weight using average_dog_weight\n",
|
|
|
+ "Action: average_dog_weight: Bulldog\n",
|
|
|
+ "\n",
|
|
|
+ "You will be called again with this:\n",
|
|
|
+ "\n",
|
|
|
+ "Observation: A Bulldog weights 51 lbs\n",
|
|
|
+ "\n",
|
|
|
+ "You then output:\n",
|
|
|
+ "\n",
|
|
|
+ "Answer: A bulldog weights 51 lbs\n",
|
|
|
+ "\"\"\".strip()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "471c0aa7-547f-4d5f-9e99-73ef47101d41",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class AgentState(TypedDict):\n",
|
|
|
+ " messages: Annotated[list[AnyMessage], operator.add]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "530e8a60-085a-4485-af03-bafc6b2c1d88",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class Agent:\n",
|
|
|
+ "\n",
|
|
|
+ " def __init__(self, model, tools, system=\"\"):\n",
|
|
|
+ " self.system = system\n",
|
|
|
+ " graph = StateGraph(AgentState)\n",
|
|
|
+ " graph.add_node(\"llm\", self.call_llm)\n",
|
|
|
+ " graph.add_node(\"action\", self.take_action)\n",
|
|
|
+ " graph.add_conditional_edges(\n",
|
|
|
+ " \"llm\",\n",
|
|
|
+ " self.exists_action,\n",
|
|
|
+ " {True: \"action\", False: END}\n",
|
|
|
+ " )\n",
|
|
|
+ " graph.add_edge(\"action\", \"llm\")\n",
|
|
|
+ " graph.set_entry_point(\"llm\")\n",
|
|
|
+ " self.graph = graph.compile()\n",
|
|
|
+ " self.tools = {t.name: t for t in tools}\n",
|
|
|
+ " self.model = model.bind_tools(tools)\n",
|
|
|
+ "\n",
|
|
|
+ " def exists_action(self, state: AgentState):\n",
|
|
|
+ " result = state['messages'][-1]\n",
|
|
|
+ " return len(result.tool_calls) > 0\n",
|
|
|
+ "\n",
|
|
|
+ " def call_llm(self, state: AgentState):\n",
|
|
|
+ " messages = state['messages']\n",
|
|
|
+ " if self.system:\n",
|
|
|
+ " messages = [SystemMessage(content=self.system)] + messages\n",
|
|
|
+ " message = self.model.invoke(messages)\n",
|
|
|
+ " return {'messages': [message]}\n",
|
|
|
+ "\n",
|
|
|
+ " def take_action(self, state: AgentState):\n",
|
|
|
+ " tool_calls = state['messages'][-1].tool_calls\n",
|
|
|
+ " results = []\n",
|
|
|
+ " for t in tool_calls:\n",
|
|
|
+ " print(f\"Calling: {t}\")\n",
|
|
|
+ " result = self.tools[t['name']].invoke(t['args'])\n",
|
|
|
+ " results.append(ToolMessage(tool_call_id=t['id'], name=t['name'], content=str(result)))\n",
|
|
|
+ " print(\"Back to the model!\")\n",
|
|
|
+ " return {'messages': results}"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "3db8dcea-d4eb-46df-bd90-55acd4c5520a",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "abot = Agent(model, [calculate, average_dog_weight], system=prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "72c62e36-9321-40d2-86d8-b3c9caf3020f",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "messages = [HumanMessage(content=\"How much does a Toy Poodle weigh?\")]\n",
|
|
|
+ "result = abot.graph.invoke({\"messages\": messages})\n",
|
|
|
+ "result['messages'], result['messages'][-1].content\n",
|
|
|
+ "\n",
|
|
|
+ "# the code above will cause an error because Llama 3 8B incorrectly returns an extra \"calculate\" tool call"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "56b4c622-b306-4aa3-84e6-4ccd6d6f272f",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# using the Llama 3 70B will fix the error\n",
|
|
|
+ "model = ChatGroq(temperature=0, model_name=\"llama3-70b-8192\")\n",
|
|
|
+ "abot = Agent(model, [calculate, average_dog_weight], system=prompt)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "629ca375-979a-45d7-bad8-7240ae9ad844",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Toy Poodle case sensitive here - can be fixed easily by modifying def average_dog_weight\n",
|
|
|
+ "messages = [HumanMessage(content=\"How much does a Toy Poodle weigh?\")]\n",
|
|
|
+ "result = abot.graph.invoke({\"messages\": messages})\n",
|
|
|
+ "result['messages'], result['messages'][-1].content"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "30e253ae-e742-4df8-92e6-fadfc3826003",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "messages = [HumanMessage(content=\"I have 2 dogs, a border collie and a scottish terrier. What are their average weights? Total weight?\")]\n",
|
|
|
+ "result = abot.graph.invoke({\"messages\": messages})"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "238ec75c-4ff6-4561-bb0a-895530a61e47",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "result['messages'], result['messages'][-1].content"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.10.14"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|