{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "974c2eb0-4844-4e0d-ae91-91571d070a3f",
"metadata": {},
"outputs": [],
"source": [
"! pip install -U langchain_groq langchain langchain-community tavily-python replicate langgraph matplotlib"
]
},
{
"cell_type": "markdown",
"id": "3a7cc5f2-0fd0-4f73-b5a2-c4d45335c356",
"metadata": {},
"source": [
"# LangGraph Tool Calling Agent with Llama3\n",
"\n",
"LLM-powered agents combine planning, memory, and tool-use (see [here](https://lilianweng.github.io/posts/2023-06-23-agent/), [here](https://www.deeplearning.ai/the-batch/how-agents-can-improve-llm-performance/)). \n",
"\n",
"LangGraph is a library that can be used to build agents:\n",
"\n",
"1) It allows us to define `nodes` for our assistant (which decides whether to call a tool) and our actions (tool calls).\n",
"2) It allows us to define specific `edges` that connect these nodes (e.g., based upon whether a tool call is decided).\n",
"3) It enables `cycles`, where we can call our assistant in a loop until a stopping condition is met.\n",
"\n",
"We'll build an agent that augments a tool-calling version of Llama 3 with various multi-modal capabilities.\n",
"\n",
"### Environment\n",
"\n",
"We'll use [Tavily](https://tavily.com/#api) for web search.\n",
"\n",
"We'll use [Replicate](https://replicate.com/), which offers a free-to-try API key, for various multi-modal capabilities.\n",
"\n",
"We can review LangChain LLM integrations that support tool calling [here](https://python.langchain.com/docs/integrations/chat/).\n",
"\n",
"Groq is included. [Here](https://github.com/groq/groq-api-cookbook/blob/main/llama3-stock-market-function-calling/llama3-stock-market-function-calling.ipynb) is a notebook by Groq on function calling with Llama 3 and LangChain."
]
},
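{
"cell_type": "markdown",
"id": "toy-graph-sketch-md",
"metadata": {},
"source": [
"To make nodes, edges, and cycles concrete, here is a minimal sketch of a LangGraph graph with no LLM or API key involved. `CounterState`, `increment`, and `should_continue` are hypothetical names used only for illustration. One node does the work, and a conditional edge either loops back to that node (a cycle) or routes to `END` once a stopping condition is met. Starting from `count = 0`, the loop should stop when the counter reaches 3."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "toy-graph-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"from typing_extensions import TypedDict\n",
"from langgraph.graph import END, StateGraph\n",
"\n",
"# Hypothetical toy state: a single integer (no reducer, so each update overwrites it)\n",
"class CounterState(TypedDict):\n",
"    count: int\n",
"\n",
"# Node: does the work by returning a partial state update\n",
"def increment(state: CounterState) -> dict:\n",
"    return {\"count\": state[\"count\"] + 1}\n",
"\n",
"# Router: picks the next edge after the node runs\n",
"def should_continue(state: CounterState) -> str:\n",
"    return \"loop\" if state[\"count\"] < 3 else \"stop\"\n",
"\n",
"toy_builder = StateGraph(CounterState)\n",
"toy_builder.add_node(\"increment\", increment)\n",
"toy_builder.set_entry_point(\"increment\")\n",
"# \"loop\" cycles back to the same node; \"stop\" ends the run\n",
"toy_builder.add_conditional_edges(\n",
"    \"increment\", should_continue, {\"loop\": \"increment\", \"stop\": END}\n",
")\n",
"toy_graph = toy_builder.compile()\n",
"\n",
"print(toy_graph.invoke({\"count\": 0}))"
]
},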
{
"cell_type": "code",
"execution_count": null,
"id": "d39c2a04-d7e7-42f4-9265-780a14f591c0",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from getpass import getpass\n",
"TAVILY_API_KEY = getpass()\n",
"os.environ[\"TAVILY_API_KEY\"] = TAVILY_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8bba6bb-d3f2-4502-83ff-a5cf61d859c8",
"metadata": {},
"outputs": [],
"source": [
"REPLICATE_API_TOKEN = getpass()\n",
"os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd3c1fe4-d6ce-484a-9b3c-e80491f03066",
"metadata": {},
"outputs": [],
"source": [
"GROQ_API_KEY = getpass()\n",
"os.environ[\"GROQ_API_KEY\"] = GROQ_API_KEY"
]
},
{
"cell_type": "markdown",
"id": "f8387853-9875-478c-8511-c75855c58e52",
"metadata": {},
"source": [
"Optionally, enable LangSmith tracing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6a812f4-0305-4eb5-bf29-2c41ef57cb03",
"metadata": {},
"outputs": [],
"source": [
"os.environ['LANGCHAIN_TRACING_V2'] = 'true'\n",
"os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'\n",
"os.environ['LANGCHAIN_API_KEY'] = getpass()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb7387b2-0094-480b-9b59-3788a22ed06e",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"LANGCHAIN_PROJECT\"] = \"llama3-tool-use-agent\""
]
},
{
"cell_type": "markdown",
"id": "9f0193d8-123e-4e1c-94e4-b7fb04d0f2f2",
"metadata": {},
"source": [
"### Define tools\n",
"\n",
"These are the same tools that we used in the [tool-calling-agent notebook](tool-calling-agent.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "586d8bba-2451-4136-b322-f47bcd4b9841",
"metadata": {},
"outputs": [],
"source": [
"import replicate\n",
"\n",
"from langchain_core.messages import ToolMessage\n",
"from langchain_core.runnables import Runnable, RunnableLambda\n",
"from langchain_core.tools import tool\n",
"from langgraph.prebuilt import ToolNode\n",
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"\n",
"@tool\n",
"def magic_function(input: int) -> int:\n",
"    \"\"\"Applies a magic function to an input.\"\"\"\n",
"    return input + 2\n",
"\n",
"@tool\n",
"def web_search(input: str) -> list:\n",
"    \"\"\"Runs web search.\"\"\"\n",
"    web_search_tool = TavilySearchResults()\n",
"    docs = web_search_tool.invoke({\"query\": input})\n",
"    return docs\n",
"\n",
"@tool\n",
"def text2image(text: str) -> str:\n",
"    \"\"\"Generate an image from a text prompt and return its URL.\"\"\"\n",
"    output = replicate.run(\n",
"        \"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc\",\n",
"        input={\n",
"            \"width\": 1024,\n",
"            \"height\": 1024,\n",
"            \"prompt\": text,  # e.g., a yellow lab puppy running free with wild flowers in the mountain behind\n",
"            \"scheduler\": \"KarrasDPM\",\n",
"            \"num_outputs\": 1,\n",
"            \"guidance_scale\": 7.5,\n",
"            \"apply_watermark\": True,\n",
"            \"negative_prompt\": \"worst quality, low quality\",\n",
"            \"prompt_strength\": 0.8,\n",
"            \"num_inference_steps\": 60\n",
"        }\n",
"    )\n",
"    print(output)\n",
"    return output[0]\n",
"\n",
"@tool\n",
"def image2text(image_url: str, prompt: str) -> str:\n",
"    \"\"\"Generate text about the image at image_url, guided by prompt.\"\"\"\n",
"    model_input = {\n",
"        \"image\": image_url,\n",
"        \"prompt\": prompt\n",
"    }\n",
"\n",
"    output = replicate.run(\n",
"        \"yorickvp/llava-13b:b5f6212d032508382d61ff00469ddda3e32fd8a0e75dc39d8a4191bb742157fb\",\n",
"        input=model_input\n",
"    )\n",
"\n",
"    # The model returns its answer as a sequence of text chunks; join them into one string\n",
"    return \"\".join(output)\n",
"\n",
"@tool\n",
"def text2speech(text: str) -> str:\n",
"    \"\"\"Convert text to speech and return the audio URL.\"\"\"\n",
"    output = replicate.run(\n",
"        \"cjwbw/seamless_communication:668a4fec05a887143e5fe8d45df25ec4c794dd43169b9a11562309b2d45873b0\",\n",
"        input={\n",
"            \"task_name\": \"T2ST (Text to Speech translation)\",\n",
"            \"input_text\": text,\n",
"            \"input_text_language\": \"English\",\n",
"            \"max_input_audio_length\": 60,\n",
"            \"target_language_text_only\": \"English\",\n",
"            \"target_language_with_speech\": \"English\"\n",
"        }\n",
"    )\n",
"    return output['audio_output']\n",
"\n",
"def create_tool_node_with_fallback(tools: list) -> Runnable:\n",
"    # If a tool raises, fall back to handle_tool_error so the error is reported back to the LLM\n",
"    return ToolNode(tools).with_fallbacks(\n",
"        [RunnableLambda(handle_tool_error)], exception_key=\"error\"\n",
"    )\n",
"\n",
"def _print_event(event: dict, _printed: set, max_length=1500):\n",
"    current_state = event.get(\"dialog_state\")\n",
"    if current_state:\n",
"        print(\"Currently in:\", current_state[-1])\n",
"    message = event.get(\"messages\")\n",
"    if message:\n",
"        if isinstance(message, list):\n",
"            message = message[-1]\n",
"        if message.id not in _printed:\n",
"            msg_repr = message.pretty_repr(html=True)\n",
"            if len(msg_repr) > max_length:\n",
"                msg_repr = msg_repr[:max_length] + \" ... (truncated)\"\n",
"            print(msg_repr)\n",
"            _printed.add(message.id)\n",
"\n",
"def handle_tool_error(state) -> dict:\n",
"    error = state.get(\"error\")\n",
"    tool_calls = state[\"messages\"][-1].tool_calls\n",
"    return {\n",
"        \"messages\": [\n",
"            ToolMessage(\n",
"                content=f\"Error: {repr(error)}\\n please fix your mistakes.\",\n",
"                tool_call_id=tc[\"id\"],\n",
"            )\n",
"            for tc in tool_calls\n",
"        ]\n",
"    }\n",
"\n",
"# List of tools\n",
"tools = [\n",
"    magic_function,\n",
"    web_search,\n",
"    text2image,\n",
"    image2text,\n",
"    text2speech\n",
"]"
]
},
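{
"cell_type": "markdown",
"id": "tool-inspect-sketch-md",
"metadata": {},
"source": [
"As a quick offline check (a sketch that the agent does not require), we can call a tool directly and inspect the name, description, and argument schema that LangChain derives from the function signature and docstring. This is roughly the information `bind_tools` passes to the LLM below. Only `magic_function` is used here because it needs no API key."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tool-inspect-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# Call the tool directly, bypassing the LLM\n",
"print(magic_function.invoke({\"input\": 3}))\n",
"\n",
"# Metadata derived from the function name, docstring, and type hints\n",
"print(magic_function.name)\n",
"print(magic_function.description)\n",
"print(magic_function.args)"
]
},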
{
"cell_type": "markdown",
"id": "2579e847-33f2-4d88-8579-78d9f4affbb3",
"metadata": {},
"source": [
"### State\n",
"\n",
"Our state is a list of messages that is passed to each node of our agent.\n",
"\n",
"This will serve as short-term memory that persists during the lifetime of our agent. \n",
"\n",
"See [this overview](https://github.com/langchain-ai/langgraph) of LangGraph for more detail."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f41abc45-bf1a-45a8-a6d3-d90c1cbbd3dd",
"metadata": {},
"outputs": [],
"source": [
"from typing import Annotated\n",
"from typing_extensions import TypedDict\n",
"from langgraph.graph.message import AnyMessage, add_messages\n",
"\n",
"class State(TypedDict):\n",
"    messages: Annotated[list[AnyMessage], add_messages]"
]
},
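{
"cell_type": "markdown",
"id": "add-messages-sketch-md",
"metadata": {},
"source": [
"The `add_messages` reducer in `State` decides how each node's output is merged into `messages`. Here is a minimal sketch with hand-written messages (the contents and the ids `\"1\"` and `\"2\"` are arbitrary). A message with a new id is appended, while a message whose id already exists replaces the earlier one rather than being duplicated."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "add-messages-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import AIMessage, HumanMessage\n",
"from langgraph.graph.message import add_messages\n",
"\n",
"history = [HumanMessage(content=\"What is magic_function(3)?\", id=\"1\")]\n",
"\n",
"# A message with a new id is appended to the list\n",
"merged = add_messages(history, [AIMessage(content=\"It is 5.\", id=\"2\")])\n",
"print([m.content for m in merged])\n",
"\n",
"# A message whose id already exists replaces the earlier message\n",
"revised = add_messages(merged, [AIMessage(content=\"magic_function(3) = 5.\", id=\"2\")])\n",
"print([m.content for m in revised])"
]
},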
{
"cell_type": "markdown",
"id": "1c06614c-e0d3-40a1-9b65-0f5f825cb4ca",
"metadata": {},
"source": [
"### Assistant \n",
"\n",
"The assistant is Llama 3 with tool calling, served via [Groq](https://python.langchain.com/v0.1/docs/integrations/chat/groq/).\n",
"\n",
"We bind the available tools to Llama 3. \n",
"\n",
"And we further specify the available tools in our assistant prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29fe9a47-857f-4554-8539-8777734e3faa",
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"from langchain_groq import ChatGroq\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import Runnable, RunnableConfig\n",
"\n",
"# Assistant\n",
"class Assistant:\n",
"\n",
"    def __init__(self, runnable: Runnable):\n",
"        self.runnable = runnable\n",
"\n",
"    def __call__(self, state: State, config: RunnableConfig):\n",
"        while True:\n",
"            # Get any user-provided configs (e.g., the image URL used by image2text)\n",
"            image_url = config.get(\"configurable\", {}).get(\"image_url\", None)\n",
"            # Append to state so the prompt can fill in {image_url}\n",
"            state = {**state, \"image_url\": image_url}\n",
"            # Invoke the tool-calling LLM\n",
"            result = self.runnable.invoke(state)\n",
"            # If it is a tool call -> response is valid\n",
"            # If it has meaningful text -> response is valid\n",
"            # Otherwise, we re-prompt it b/c the response is not meaningful\n",
"            if not result.tool_calls and (\n",
"                not result.content\n",
"                or (\n",
"                    isinstance(result.content, list)\n",
"                    and not result.content[0].get(\"text\")\n",
"                )\n",
"            ):\n",
"                messages = state[\"messages\"] + [(\"user\", \"Respond with a real output.\")]\n",
"                state = {**state, \"messages\": messages}\n",
"            else:\n",
"                break\n",
"        return {\"messages\": result}\n",
"\n",
"# Prompt\n",
"primary_assistant_prompt = ChatPromptTemplate.from_messages(\n",
"    [\n",
"        (\n",
"            \"system\",\n",
"            \"You are a helpful assistant with five tools: (1) web search, \"\n",
"            \"(2) a custom magic_function, (3) text to image, (4) image to text, \"\n",
"            \"and (5) text to speech. Use these provided tools in response to the user question. \"\n",
"            \"Your image url is: {image_url} \"\n",
"            \"Current time: {time}.\",\n",
"        ),\n",
"        (\"placeholder\", \"{messages}\"),\n",
"    ]\n",
").partial(time=datetime.now())\n",
"\n",
"# LLM chain\n",
"llm = ChatGroq(temperature=0, model=\"llama3-70b-8192\")\n",
"assistant_runnable = primary_assistant_prompt | llm.bind_tools(tools)"
]
},
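{
"cell_type": "markdown",
"id": "prompt-preview-sketch-md",
"metadata": {},
"source": [
"To see what the model actually receives, we can render the prompt template locally without calling Groq. This is a sketch, and the question and `image_url=None` are placeholder inputs. The output should show the system message, with `{image_url}` and `{time}` filled in, followed by the user message."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "prompt-preview-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# Render the prompt template locally; no API call is made\n",
"prompt_value = primary_assistant_prompt.invoke(\n",
"    {\"messages\": [(\"user\", \"What is magic_function(3)?\")], \"image_url\": None}\n",
")\n",
"for message in prompt_value.to_messages():\n",
"    print(f\"{message.type}: {message.content}\")"
]
},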
{
"cell_type": "markdown",
"id": "96514473-c092-4195-bf76-9c61089b5072",
"metadata": {},
"source": [
"### Graph\n",
"\n",
"Here, we lay out the graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dacf444f-be0f-41bd-b9bd-5ac776fbd5f8",
"metadata": {},
"outputs": [],
"source": [
"from langgraph.checkpoint.memory import MemorySaver\n",
"from langgraph.graph import END, StateGraph\n",
"from langgraph.prebuilt import ToolNode, tools_condition\n",
"from langchain_core.runnables import RunnableLambda\n",
"\n",
"# Graph\n",
"builder = StateGraph(State)\n",
"\n",
"# Define nodes: these do the work\n",
"builder.add_node(\"assistant\", Assistant(assistant_runnable))\n",
"builder.add_node(\"tools\", create_tool_node_with_fallback(tools))\n",
"\n",
"# Define edges: these determine how the control flow moves\n",
"builder.set_entry_point(\"assistant\")\n",
"builder.add_conditional_edges(\n",
"    \"assistant\",\n",
"    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools\n",
"    # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END\n",
"    tools_condition,\n",
"    # \"tools\" calls one of our tools. END causes the graph to terminate (and respond to the user)\n",
"    {\"tools\": \"tools\", END: END},\n",
")\n",
"builder.add_edge(\"tools\", \"assistant\")\n",
"\n",
"# The checkpointer lets the graph persist its state (kept in memory here)\n",
"memory = MemorySaver()\n",
"graph = builder.compile(checkpointer=memory)"
]
},
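{
"cell_type": "markdown",
"id": "tools-condition-sketch-md",
"metadata": {},
"source": [
"The routing in `tools_condition` can be checked offline by passing it hand-built `AIMessage`s. In this sketch, the tool call (its arguments and `id`) is made up for illustration. It should print `__end__` (the value of `END`) for the plain reply and `tools` for the reply that requests a tool."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tools-condition-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import AIMessage\n",
"from langgraph.prebuilt import tools_condition\n",
"\n",
"# A plain text reply has no tool calls, so the graph should end\n",
"plain_reply = AIMessage(content=\"magic_function(3) is 5.\")\n",
"print(tools_condition({\"messages\": [plain_reply]}))\n",
"\n",
"# A reply requesting a tool should be routed to the \"tools\" node\n",
"tool_request = AIMessage(\n",
"    content=\"\",\n",
"    tool_calls=[{\"name\": \"magic_function\", \"args\": {\"input\": 3}, \"id\": \"call_1\"}],\n",
")\n",
"print(tools_condition({\"messages\": [tool_request]}))"
]
},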
{
"cell_type": "markdown",
"id": "d404bfe1-fc7a-49f1-9e7a-4bd77f830c6d",
"metadata": {},
"source": [
"We can visualize it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "983ab01b-bd31-47b1-bfd8-15cb0e555156",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image, display\n",
"\n",
"try:\n",
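"    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))\n",
"except Exception:\n",
"    # Rendering can fail (e.g., without network access to the Mermaid rendering API); skip it\n",
"    pass"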
]
},
{
"cell_type": "markdown",
"id": "d6388175-6b8f-483c-ab31-948258a8fb7e",
"metadata": {},
"source": [
"### Test\n",
"\n",
"Now, we can test each tool!\n",
"\n",
"See the linked LangSmith traces to audit exactly what happens at each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbc9a513-56cf-45fc-a276-d68c54620203",
"metadata": {},
"outputs": [],
"source": [
"questions = [\"What is magic_function(3)\",\n",
" \"What is the weather in SF?\",\n",
" \"Generate an image based upon this text: 'a yellow lab puppy running free with wild flowers in the mountain behind'\",\n",
" \"Tell me a story about this image\",\n",
" \"Convert this text to speech: The image features a small white dog running down a dirt path, surrounded by a beautiful landscape. The dog is happily smiling as it runs, and the path is lined with colorful flowers, creating a vibrant and lively atmosphere. The scene appears to be set in a mountainous area, adding to the picturesque nature of the image.\"\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3baef29-84c5-491d-8ff7-107d64f23435",
"metadata": {},
"outputs": [],
"source": [
"import uuid \n",
"_printed = set()\n",
"image_url = None\n",
"thread_id = str(uuid.uuid4())\n",
"\n",
"config = {\n",
"    \"configurable\": {\n",
"        \"image_url\": image_url,\n",
"        # Checkpoints are accessed by thread_id\n",
"        \"thread_id\": thread_id,\n",
"    }\n",
"}\n",
"\n",
"events = graph.stream(\n",
"    {\"messages\": (\"user\", questions[0])}, config, stream_mode=\"values\"\n",
")\n",
"for event in events:\n",
"    _print_event(event, _printed)"
]
},
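{
"cell_type": "markdown",
"id": "get-state-sketch-md",
"metadata": {},
"source": [
"Because the graph was compiled with a checkpointer, the messages for each `thread_id` are saved as the graph runs. This is a minimal sketch of reading them back with `graph.get_state`. Run it right after the cell above so that `config` still refers to the same thread."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "get-state-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# Fetch the latest checkpoint for the thread referenced by `config`\n",
"snapshot = graph.get_state(config)\n",
"\n",
"# Print the type and a short preview of each message stored for this thread\n",
"for message in snapshot.values[\"messages\"]:\n",
"    print(f\"{message.type}: {str(message.content)[:80]}\")"
]
},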
{
"cell_type": "markdown",
"id": "d25fdf4f-feac-41c6-828c-24494d4bc7c9",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/e4f4055f-eb68-482a-8843-cecc67ea76d3/r"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5542c838-44f8-45d4-a866-72a056ece638",
"metadata": {},
"outputs": [],
"source": [
"_printed = set()\n",
"image_url = None\n",
"thread_id = str(uuid.uuid4())\n",
"\n",
"config = {\n",
"    \"configurable\": {\n",
"        \"image_url\": image_url,\n",
"        # Checkpoints are accessed by thread_id\n",
"        \"thread_id\": thread_id,\n",
"    }\n",
"}\n",
"\n",
"events = graph.stream(\n",
"    {\"messages\": (\"user\", questions[1])}, config, stream_mode=\"values\"\n",
")\n",
"for event in events:\n",
"    _print_event(event, _printed)"
]
},
{
"cell_type": "markdown",
"id": "8391f5fa-aa60-4784-a732-a5e098d11624",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/1a46bdba-448b-4b23-a78b-650d28d5ee7f/r"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed58a028-fad7-4ba4-8117-b69a5d0b6489",
"metadata": {},
"outputs": [],
"source": [
"_printed = set()\n",
"image_url = None\n",
"thread_id = str(uuid.uuid4())\n",
"\n",
"config = {\n",
"    \"configurable\": {\n",
"        \"image_url\": image_url,\n",
"        # Checkpoints are accessed by thread_id\n",
"        \"thread_id\": thread_id,\n",
"    }\n",
"}\n",
"\n",
"events = graph.stream(\n",
"    {\"messages\": (\"user\", questions[2])}, config, stream_mode=\"values\"\n",
")\n",
"for event in events:\n",
"    _print_event(event, _printed)"
]
},
{
"cell_type": "markdown",
"id": "03af241a-83a7-4f65-a628-fac0971468b6",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/cc9ca4f1-05c8-4dea-a85b-c852f22c14ae/r"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8283c47-145e-4243-9f5c-4203bfde62d3",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
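"from PIL import Image as PILImage  # aliased so it does not shadow IPython.display.Image\n",
"from io import BytesIO\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def display_image(image_url):\n",
"    \"\"\"Display generated image\"\"\"\n",
"    response = requests.get(image_url)\n",
"    response.raise_for_status()\n",
"    img = PILImage.open(BytesIO(response.content))\n",
"    plt.imshow(img)\n",
"    plt.axis('off')\n",
"    plt.show()\n",
"\n",
"# Assuming one tool call, the second-to-last message is the ToolMessage from text2image (the image URL)\n",
"image_url = event['messages'][-2].content\n",
"display_image(image_url)"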
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19fd5aaa-c882-4de3-b94f-ee290947f944",
"metadata": {},
"outputs": [],
"source": [
"import uuid \n",
"\n",
"_printed = set()\n",
"thread_id = str(uuid.uuid4())\n",
"\n",
"config = {\n",
"    \"configurable\": {\n",
"        \"image_url\": image_url,\n",
"        # Checkpoints are accessed by thread_id\n",
"        \"thread_id\": thread_id,\n",
"    }\n",
"}\n",
"\n",
"events = graph.stream(\n",
"    {\"messages\": (\"user\", questions[3])}, config, stream_mode=\"values\"\n",
")\n",
"for event in events:\n",
"    _print_event(event, _printed)"
]
},
{
"cell_type": "markdown",
"id": "dd67d65a-717c-4da1-bbd6-1a796bfc077e",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/89f45ee4-effc-4cca-b3e6-f12cf5c29168/r"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73e3e12e-d725-4e66-aba1-ddc2d270f468",
"metadata": {},
"outputs": [],
"source": [
"_printed = set()\n",
"image_url = None\n",
"thread_id = str(uuid.uuid4())\n",
"\n",
"config = {\n",
"    \"configurable\": {\n",
"        \"image_url\": image_url,\n",
"        # Checkpoints are accessed by thread_id\n",
"        \"thread_id\": thread_id,\n",
"    }\n",
"}\n",
"\n",
"events = graph.stream(\n",
"    {\"messages\": (\"user\", questions[4])}, config, stream_mode=\"values\"\n",
")\n",
"for event in events:\n",
"    _print_event(event, _printed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb7536d7-8883-4e8b-8584-251e1c3b92f2",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"\n",
"def play_audio(output_url):\n",
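"    return Audio(url=output_url, autoplay=False)\n",
"\n",
"# Assuming one tool call, the second-to-last message is the ToolMessage from text2speech (the audio URL)\n",
"audio_url = event['messages'][-2].content\n",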
"play_audio(audio_url)"
]
},
{
"cell_type": "markdown",
"id": "25ec5e4c-37ff-495f-992f-c3b3935b8565",
"metadata": {},
"source": [
"Trace: \n",
"\n",
"https://smith.langchain.com/public/b504a513-3123-4bfd-8796-3364968559b2/r"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}