瀏覽代碼

Added 3.2 Example

Updated imports and added 3.2 models
Sanyam Bhutani 6 月之前
父節點
當前提交
386092cf84

+ 339 - 6
recipes/quickstart/agents/Agents_101/Tool_Calling_101.ipynb

@@ -6,6 +6,8 @@
    "source": [
     "# Tool Calling 101:\n",
     "\n",
+    "Note: If you are looking for `3.2` Featherlight Model (1B and 3B) instructions, please scroll to the bottom\n",
+    "\n",
     "This is part (1/2) in the tool calling series, this notebook will cover the basics of what tool calling is and how to perform it with `Llama 3.1 models`\n",
     "\n",
     "Here's what you will learn in this notebook:\n",
@@ -14,6 +16,7 @@
     "- Avoid common mistakes when performing tool-calling with Llama\n",
     "- Understand Prompt templates for Tool Calling\n",
     "- Understand how the tool calls are handled under the hood\n",
+    "- 3.2 Model Tool Calling Format and Behaviour\n",
     "\n",
     "In Part 2, we will learn how to build system that can get us comparision between 2 papers"
    ]
@@ -42,7 +45,7 @@
     "#### Install and setup groq dependencies\n",
     "\n",
     "- Install `groq` api to access Llama model(s)\n",
-    "- Configure our client and authenticate with API Key(s)"
+    "- Configure our client and authenticate with API Key(s), Note: PLEASE UPDATE YOUR KEY BELOW"
    ]
   },
   {
@@ -56,14 +59,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 82,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
     "import os\n",
     "from groq import Groq\n",
     "# Create the Groq client\n",
-    "client = Groq(api_key=os.environ.get(\"GROQ_API_KEY\"), )"
+    "client = Groq(api_key='gsk_PDfGP611i_HAHAHAHA_THIS_IS_NOT_MY_REAL_KEY_PLEASE_REPLACE')"
    ]
   },
   {
@@ -91,7 +94,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 83,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -105,7 +108,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 84,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -667,6 +670,336 @@
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3.2 Models Prompt Format"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Life is great because Llama Team writes great docs for us, so we can conviently copy-pasta examples from there :)\n",
+    "\n",
+    "[Here](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-tool-calling-(1b/3b)-) are the docs for your reference that we will be using. \n",
+    "\n",
+    "Let's verify the details from `llama-toolchain` again and then start the prompt engineering for the small Llamas."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Traceback (most recent call last):\n",
+      "  File \"/opt/miniconda3/bin/llama\", line 8, in <module>\n",
+      "    sys.exit(main())\n",
+      "             ^^^^^^\n",
+      "  File \"/opt/miniconda3/lib/python3.12/site-packages/llama_toolchain/cli/llama.py\", line 44, in main\n",
+      "    parser.run(args)\n",
+      "  File \"/opt/miniconda3/lib/python3.12/site-packages/llama_toolchain/cli/llama.py\", line 38, in run\n",
+      "    args.func(args)\n",
+      "  File \"/opt/miniconda3/lib/python3.12/site-packages/llama_toolchain/cli/model/prompt_format.py\", line 59, in _run_model_template_cmd\n",
+      "    raise argparse.ArgumentTypeError(\n",
+      "argparse.ArgumentTypeError: llama3_1 is not a valid Model. Choose one from --\n",
+      "Llama3.1-8B\n",
+      "Llama3.1-70B\n",
+      "Llama3.1-405B\n",
+      "Llama3.1-8B-Instruct\n",
+      "Llama3.1-70B-Instruct\n",
+      "Llama3.1-405B-Instruct\n",
+      "Llama3.2-1B\n",
+      "Llama3.2-3B\n",
+      "Llama3.2-1B-Instruct\n",
+      "Llama3.2-3B-Instruct\n",
+      "Llama3.2-11B-Vision\n",
+      "Llama3.2-90B-Vision\n",
+      "Llama3.2-11B-Vision-Instruct\n",
+      "Llama3.2-90B-Vision-Instruct\n"
+     ]
+    }
+   ],
+   "source": [
+    "!llama model prompt-format"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "function_definitions = \"\"\"[\n",
+    "    {\n",
+    "        \"name\": \"get_user_info\",\n",
+    "        \"description\": \"Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.\",\n",
+    "        \"parameters\": {\n",
+    "            \"type\": \"dict\",\n",
+    "            \"required\": [\n",
+    "                \"user_id\"\n",
+    "            ],\n",
+    "            \"properties\": {\n",
+    "                \"user_id\": {\n",
+    "                \"type\": \"integer\",\n",
+    "                \"description\": \"The unique identifier of the user. It is used to fetch the specific user details from the database.\"\n",
+    "            },\n",
+    "            \"special\": {\n",
+    "                \"type\": \"string\",\n",
+    "                \"description\": \"Any special information or parameters that need to be considered while fetching user details.\",\n",
+    "                \"default\": \"none\"\n",
+    "                }\n",
+    "            }\n",
+    "        }\n",
+    "    }\n",
+    "]\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "system_prompt = \"\"\"You are an expert in composing functions. You are given a question and a set of possible functions. \n",
+    "Based on the question, you will need to make one or more function/tool calls to achieve the purpose. \n",
+    "If none of the function can be used, point it out. If the given question lacks the parameters required by the function,\n",
+    "also point it out. You should only return the function call in tools call sections.\n",
+    "\n",
+    "If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\\n\n",
+    "You SHOULD NOT include any other text in the response.\n",
+    "\n",
+    "Here is a list of functions in JSON format that you can invoke.\\n\\n{functions}\\n\"\"\".format(functions=function_definitions)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_history = []\n",
+    "\n",
+    "def model_chat(user_input: str, sys_prompt = system_prompt, temperature: int = 0.7, max_tokens=2048):\n",
+    "    \n",
+    "    chat_history = [\n",
+    "        {\n",
+    "            \"role\": \"system\",\n",
+    "            \"content\": system_prompt\n",
+    "        }\n",
+    "    ]\n",
+    "    \n",
+    "    chat_history.append({\"role\": \"user\", \"content\": user_input})\n",
+    "    \n",
+    "    #print(chat_history)\n",
+    "    \n",
+    "    #print(\"User: \", user_input)\n",
+    "    \n",
+    "    response = client.chat.completions.create(model=\"llama-3.2-3b-preview\",\n",
+    "                                          messages=chat_history,\n",
+    "                                          max_tokens=max_tokens,\n",
+    "                                          temperature=temperature)\n",
+    "    \n",
+    "    chat_history.append({\n",
+    "    \"role\": \"assistant\",\n",
+    "    \"content\": response.choices[0].message.content\n",
+    "    })\n",
+    "    \n",
+    "    \n",
+    "    #print(\"Assistant:\", response.choices[0].message.content)\n",
+    "    \n",
+    "    return response.choices[0].message.content"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Assistant: [get_user_info(user_id=7890, special='black')]\n"
+     ]
+    }
+   ],
+   "source": [
+    "user_input = \"Can you retrieve the details for the user with the ID 7890, who has black as their special request?\"\n",
+    "\n",
+    "print(\"Assistant:\", model_chat(user_input, sys_prompt=system_prompt))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Dummy dataset to make sure our model stays happy :) "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_user_info(user_id: int, special: str = \"none\") -> dict:\n",
+    "    # This is a mock database of users\n",
+    "    user_database = {\n",
+    "        7890: {\"name\": \"Emma Davis\", \"email\": \"emma@example.com\", \"age\": 31},\n",
+    "        1234: {\"name\": \"Liam Wilson\", \"email\": \"liam@example.com\", \"age\": 28},\n",
+    "        2345: {\"name\": \"Olivia Chen\", \"email\": \"olivia@example.com\", \"age\": 35},\n",
+    "        3456: {\"name\": \"Noah Taylor\", \"email\": \"noah@example.com\", \"age\": 42},\n",
+    "        4567: {\"name\": \"Ava Martinez\", \"email\": \"ava@example.com\", \"age\": 39},\n",
+    "        5678: {\"name\": \"Ethan Brown\", \"email\": \"ethan@example.com\", \"age\": 45},\n",
+    "        6789: {\"name\": \"Sophia Kim\", \"email\": \"sophia@example.com\", \"age\": 33},\n",
+    "        8901: {\"name\": \"Mason Lee\", \"email\": \"mason@example.com\", \"age\": 29},\n",
+    "        9012: {\"name\": \"Isabella Garcia\", \"email\": \"isabella@example.com\", \"age\": 37},\n",
+    "        1357: {\"name\": \"James Johnson\", \"email\": \"james@example.com\", \"age\": 41}\n",
+    "    }\n",
+    "    \n",
+    "    # Check if the user exists in our mock database\n",
+    "    if user_id in user_database:\n",
+    "        user_data = user_database[user_id]\n",
+    "        \n",
+    "        # Handle the 'special' parameter\n",
+    "        if special != \"none\":\n",
+    "            user_data[\"special_info\"] = f\"Special request: {special}\"\n",
+    "        \n",
+    "        return user_data\n",
+    "    else:\n",
+    "        return {\"error\": \"User not found\"}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'name': 'Emma Davis',\n",
+       "  'email': 'emma@example.com',\n",
+       "  'age': 31,\n",
+       "  'special_info': 'Special request: black'}]"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "[get_user_info(user_id=7890, special='black')]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Hello Regex, my good old friend :) "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "import json\n",
+    "\n",
+    "# Assuming you have defined get_user_info function and SYSTEM_PROMPT\n",
+    "\n",
+    "chat_history = []\n",
+    "\n",
+    "def process_response(response):\n",
+    "    function_call_pattern = r'\\[(.*?)\\((.*?)\\)\\]'\n",
+    "    function_calls = re.findall(function_call_pattern, response)\n",
+    "    \n",
+    "    if function_calls:\n",
+    "        processed_response = []\n",
+    "        for func_name, args_str in function_calls:\n",
+    "            args_dict = {}\n",
+    "            for arg in args_str.split(','):\n",
+    "                key, value = arg.split('=')\n",
+    "                key = key.strip()\n",
+    "                value = value.strip().strip(\"'\")\n",
+    "                if value.isdigit():\n",
+    "                    value = int(value)\n",
+    "                args_dict[key] = value\n",
+    "            \n",
+    "            if func_name == 'get_user_info':\n",
+    "                result = get_user_info(**args_dict)\n",
+    "                processed_response.append(f\"Function call result: {json.dumps(result, indent=2)}\")\n",
+    "            else:\n",
+    "                processed_response.append(f\"Unknown function: {func_name}\")\n",
+    "        return \"\\n\".join(processed_response)\n",
+    "    else:\n",
+    "        return response\n",
+    "\n",
+    "def model_chat(user_input: str, sys_prompt=system_prompt, temperature: float = 0.7, max_tokens: int = 2048):\n",
+    "    global chat_history\n",
+    "    \n",
+    "    if not chat_history:\n",
+    "        chat_history = [\n",
+    "            {\n",
+    "                \"role\": \"system\",\n",
+    "                \"content\": sys_prompt\n",
+    "            }\n",
+    "        ]\n",
+    "    \n",
+    "    chat_history.append({\"role\": \"user\", \"content\": user_input})\n",
+    "    \n",
+    "    response = client.chat.completions.create(\n",
+    "        model=\"llama-3.2-3b-preview\",\n",
+    "        messages=chat_history,\n",
+    "        max_tokens=max_tokens,\n",
+    "        temperature=temperature\n",
+    "    )\n",
+    "    \n",
+    "    assistant_response = response.choices[0].message.content\n",
+    "    processed_response = process_response(assistant_response)\n",
+    "    \n",
+    "    chat_history.append({\n",
+    "        \"role\": \"assistant\",\n",
+    "        \"content\": assistant_response\n",
+    "    })\n",
+    "    \n",
+    "    return processed_response"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Assistant: Function call result: {\n",
+      "  \"name\": \"Emma Davis\",\n",
+      "  \"email\": \"emma@example.com\",\n",
+      "  \"age\": 31,\n",
+      "  \"special_info\": \"Special request: black\"\n",
+      "}\n"
+     ]
+    }
+   ],
+   "source": [
+    "user_input = \"Can you retrieve the details for the user with the ID 7890, who has black as their special request?\"\n",
+    "\n",
+    "print(\"Assistant:\", model_chat(user_input, sys_prompt=system_prompt))"
+   ]
+  },
+  {
    "cell_type": "code",
    "execution_count": 56,
    "metadata": {},
@@ -692,7 +1025,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.4"
+   "version": "3.12.5"
   }
  },
  "nbformat": 4,

+ 9 - 3
recipes/quickstart/agents/Agents_101/Tool_Calling_201.ipynb

@@ -68,6 +68,13 @@
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### Note: PLEASE REPLACE API KEYS BELOW WITH YOUR REAL ONES"
+   ]
+  },
+  {
    "cell_type": "code",
    "execution_count": 38,
    "metadata": {},
@@ -78,10 +85,9 @@
     "from groq import Groq\n",
     "\n",
     "# Create the Groq client\n",
-    "client = Groq(api_key=os.environ.get(\"GROQ_API_KEY\"), )\n",
+    "client = Groq(api_key='gsk_PDfGP611i_HAHAHAHA_THIS_IS_NOT_MY_REAL_KEY_PLEASE_REPLACE')\n",
     "\n",
-    "TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY')\n",
-    "tavily_client = TavilyClient(api_key=TAVILY_API_KEY)\n"
+    "tavily_client = TavilyClient(api_key='fake_key_HAHAHAHA_THIS_IS_NOT_MY_REAL_KEY_PLEASE_REPLACE')\n"
    ]
   },
   {