|
@@ -0,0 +1,340 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "104f2b97-f9bb-4dcc-a4c8-099710768851",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# Parallel Tool Use"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "f8dc57b6-2c48-4ee3-bb2c-25441274ed2f",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Setup"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "e70814b4",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Make sure you have `ipykernel` and `pip` pre-installed"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "962ae5e2",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "%pip install -r requirements.txt"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "id": "e21816b3",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "data": {
|
|
|
+ "text/plain": [
|
|
|
+ "'Groq API key configured: gsk_XXXXXX...'"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "execution_count": 1,
|
|
|
+ "metadata": {},
|
|
|
+ "output_type": "execute_result"
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "import os\n",
|
|
|
+ "import json\n",
|
|
|
+ "\n",
|
|
|
+ "from groq import Groq\n",
|
|
|
+ "from dotenv import load_dotenv\n",
|
|
|
+ "\n",
|
|
|
+ "load_dotenv()\n",
|
|
|
+ "\"Groq API key configured: \" + os.environ[\"GROQ_API_KEY\"][:10] + \"...\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "7f7c9c55-e925-4cc1-89f2-58237acf14a4",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "We will use the `llama3-70b-8192` model in this demo. Note that you will need a Groq API Key to proceed and can create an account [here](https://console.groq.com/) to generate one for free. Only Llama 3 models support parallel tool use at this time (05/07/2024).\n",
|
|
|
+ "\n",
|
|
|
+ "We recommend using the 70B Llama 3 model, as the 8B model has subpar consistency."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "id": "0cca781b-1950-4167-b36a-c1099d6b3b00",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "client = Groq(api_key=os.getenv(\"GROQ_API_KEY\"))\n",
|
|
|
+ "model = \"llama3-70b-8192\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "2c23ec2b",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Let's define a dummy function we can invoke in our tool use loop"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "id": "f2ce18dc",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def get_weather(city: str):\n",
|
|
|
+ " if city == \"Madrid\":\n",
|
|
|
+ " return 35\n",
|
|
|
+ " elif city == \"San Francisco\":\n",
|
|
|
+ " return 18\n",
|
|
|
+ " elif city == \"Paris\":\n",
|
|
|
+ " return 20\n",
|
|
|
+ " else:\n",
|
|
|
+ " return 15"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "a37e3c92",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Now we define our messages and tools and run the completion request."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "id": "6b454910-4352-40cc-b9b2-cc79edabd7c1",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "messages = [\n",
|
|
|
+ " {\"role\": \"system\", \"content\": \"\"\"You are a helpful assistant.\"\"\"},\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"user\",\n",
|
|
|
+ " \"content\": \"What is the weather in Paris, Tokyo and Madrid?\",\n",
|
|
|
+ " },\n",
|
|
|
+ "]\n",
|
|
|
+ "tools = [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"type\": \"function\",\n",
|
|
|
+ " \"function\": {\n",
|
|
|
+ " \"name\": \"get_weather\",\n",
|
|
|
+ " \"description\": \"Returns the weather in the given city in degrees Celsius\",\n",
|
|
|
+ " \"parameters\": {\n",
|
|
|
+ " \"type\": \"object\",\n",
|
|
|
+ " \"properties\": {\n",
|
|
|
+ " \"city\": {\n",
|
|
|
+ " \"type\": \"string\",\n",
|
|
|
+ " \"description\": \"The name of the city\",\n",
|
|
|
+ " }\n",
|
|
|
+ " },\n",
|
|
|
+ " \"required\": [\"city\"],\n",
|
|
|
+ " },\n",
|
|
|
+ " },\n",
|
|
|
+ " }\n",
|
|
|
+ "]\n",
|
|
|
+ "response = client.chat.completions.create(\n",
|
|
|
+ " model=model, messages=messages, tools=tools, tool_choice=\"auto\", max_tokens=4096\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "response_message = response.choices[0].message"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "25c2838f",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Processing the tool calls\n",
|
|
|
+ "\n",
|
|
|
+ "Now we process the assistant message and construct the required messages to continue the conversation. \n",
|
|
|
+ "\n",
|
|
|
+ "*Including* invoking each tool_call against our actual function."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "fe623ab9",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"system\",\n",
|
|
|
+ " \"content\": \"You are a helpful assistant.\"\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"user\",\n",
|
|
|
+ " \"content\": \"What is the weather in Paris, Tokyo and Madrid?\"\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"assistant\",\n",
|
|
|
+ " \"tool_calls\": [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"id\": \"call_5ak8\",\n",
|
|
|
+ " \"function\": {\n",
|
|
|
+ " \"name\": \"get_weather\",\n",
|
|
|
+ " \"arguments\": \"{\\\"city\\\":\\\"Paris\\\"}\"\n",
|
|
|
+ " },\n",
|
|
|
+ " \"type\": \"function\"\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"id\": \"call_zq26\",\n",
|
|
|
+ " \"function\": {\n",
|
|
|
+ " \"name\": \"get_weather\",\n",
|
|
|
+ " \"arguments\": \"{\\\"city\\\":\\\"Tokyo\\\"}\"\n",
|
|
|
+ " },\n",
|
|
|
+ " \"type\": \"function\"\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"id\": \"call_znf3\",\n",
|
|
|
+ " \"function\": {\n",
|
|
|
+ " \"name\": \"get_weather\",\n",
|
|
|
+ " \"arguments\": \"{\\\"city\\\":\\\"Madrid\\\"}\"\n",
|
|
|
+ " },\n",
|
|
|
+ " \"type\": \"function\"\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"tool\",\n",
|
|
|
+ " \"content\": \"20\",\n",
|
|
|
+ " \"tool_call_id\": \"call_5ak8\"\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"tool\",\n",
|
|
|
+ " \"content\": \"15\",\n",
|
|
|
+ " \"tool_call_id\": \"call_zq26\"\n",
|
|
|
+ " },\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"tool\",\n",
|
|
|
+ " \"content\": \"35\",\n",
|
|
|
+ " \"tool_call_id\": \"call_znf3\"\n",
|
|
|
+ " }\n",
|
|
|
+ "]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "tool_calls = response_message.tool_calls\n",
|
|
|
+ "\n",
|
|
|
+ "messages.append(\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"assistant\",\n",
|
|
|
+ " \"tool_calls\": [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"id\": tool_call.id,\n",
|
|
|
+ " \"function\": {\n",
|
|
|
+ " \"name\": tool_call.function.name,\n",
|
|
|
+ " \"arguments\": tool_call.function.arguments,\n",
|
|
|
+ " },\n",
|
|
|
+ " \"type\": tool_call.type,\n",
|
|
|
+ " }\n",
|
|
|
+ " for tool_call in tool_calls\n",
|
|
|
+ " ],\n",
|
|
|
+ " }\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "available_functions = {\n",
|
|
|
+ " \"get_weather\": get_weather,\n",
|
|
|
+ "}\n",
|
|
|
+ "for tool_call in tool_calls:\n",
|
|
|
+ " function_name = tool_call.function.name\n",
|
|
|
+ " function_to_call = available_functions[function_name]\n",
|
|
|
+ " function_args = json.loads(tool_call.function.arguments)\n",
|
|
|
+ " function_response = function_to_call(**function_args)\n",
|
|
|
+ "\n",
|
|
|
+ " # Note how we create a separate tool call message for each tool call\n",
|
|
|
+ " # the model is able to discern the tool call result through the tool_call_id\n",
|
|
|
+ " messages.append(\n",
|
|
|
+ " {\n",
|
|
|
+ " \"role\": \"tool\",\n",
|
|
|
+ " \"content\": json.dumps(function_response),\n",
|
|
|
+ " \"tool_call_id\": tool_call.id,\n",
|
|
|
+ " }\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ "print(json.dumps(messages, indent=2))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "1abe981a",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Now we run our final completion with multiple tool call results included in the messages array.\n",
|
|
|
+ "\n",
|
|
|
+ "**Note**\n",
|
|
|
+ "\n",
|
|
|
+ "We pass the tool definitions again to help the model understand:\n",
|
|
|
+ "\n",
|
|
|
+ "1. The assistant message with the tool call\n",
|
|
|
+ "2. How to interpret the tool results."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "5f077df3",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "The weather in Paris is 20°C, in Tokyo is 15°C, and in Madrid is 35°C.\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "response = client.chat.completions.create(\n",
|
|
|
+ " model=model, messages=messages, tools=tools, tool_choice=\"auto\", max_tokens=4096\n",
|
|
|
+ ")\n",
|
|
|
+ "\n",
|
|
|
+ "print(response.choices[0].message.content)"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.10.13"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|