Explorar o código

Adding simple recipe for tool usage with HF template

Beto hai 9 meses
pai
achega
35aef9cedc

+ 476 - 0
recipes/quickstart/inference/local_inference/simple_tool_test.ipynb

@@ -0,0 +1,476 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b03f60a1-898a-44c9-a03c-a7bc41890807",
+   "metadata": {},
+   "source": [
+    "# Simple tool formatting with HF\n",
+    "Showcasing how to use the new template format from HF for tool calling"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "512d775f-7625-4d6c-9618-2ea39da9ba17",
+   "metadata": {},
+   "source": [
+    "## Loading the models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "c42ba115-95c7-4be1-a050-457ee6c28cfd",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "37bc314d4a5c4944a84da770f80f555a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "import torch\n",
+    "\n",
+    "# model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+    "\n",
+    "model_id = \"/home/ubuntu/projects/llama-recipes/models--llhf--Meta-Llama-3.1-8B-Instruct/snapshots/9fd0b760200bab0a7af5e24c14f1283ecdb4765f\"\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "model = AutoModelForCausalLM.from_pretrained(\n",
+    "        model_id,\n",
+    "        torch_dtype=torch.bfloat16,\n",
+    "        device_map=\"auto\",\n",
+    "    )\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8de62346-83d7-495d-91c9-d832c34b6f87",
+   "metadata": {},
+   "source": [
+    "## Setting up the system messages and default tools "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "id": "64db2078-39fe-4775-8764-77aae26fcdce",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dialogs = [\n",
+    "    [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
+    "        {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
+    "    ],\n",
+    "    [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
+    "        {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
+    "    ],\n",
+    "  ]\n",
+    "\n",
+    "messages = [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
+    "        {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
+    "    ]\n",
+    "\n",
+    "builtin_tools = [\"code_interpreter\", \"wolfram_alpha\", \"brave_search\"]\n",
+    "\n",
+    "json_tools = [ \n",
+    "    { \"type\": \"function\",\n",
+    "      \"function\": {\n",
+    "          \"name\": \"spotify_trending_songs\",\n",
+    "          \"description\": \"Get top trending songs on Spotify\",\n",
+    "          \"parameters\": {\n",
+    "            \"n\": {\n",
+    "              \"param_type\": \"int\",\n",
+    "              \"description\": \"Number of trending songs to get\",\n",
+    "              \"required\": \"true\"\n",
+    "            }\n",
+    "          }\n",
+    "        }\n",
+    "   },\n",
+    "    {\n",
+    "      \"type\": \"function\",\n",
+    "      \"function\": {\n",
+    "        \"name\": \"get_current_temperature\",\n",
+    "        \"description\": \"Get the current temperature for a specific location\",\n",
+    "        \"parameters\": {\n",
+    "          \"type\": \"object\",\n",
+    "          \"properties\": {\n",
+    "            \"location\": {\n",
+    "              \"type\": \"string\",\n",
+    "              \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "            },\n",
+    "            \"unit\": {\n",
+    "              \"type\": \"string\",\n",
+    "              \"enum\": [\"Celsius\", \"Fahrenheit\"],\n",
+    "              \"description\": \"The temperature unit to use. Infer this from the user's location.\"\n",
+    "            }\n",
+    "          },\n",
+    "          \"required\": [\"location\", \"unit\"]\n",
+    "        }\n",
+    "      }\n",
+    "    },\n",
+    "    {\n",
+    "      \"type\": \"function\",\n",
+    "      \"function\": {\n",
+    "        \"name\": \"get_rain_probability\",\n",
+    "        \"description\": \"Get the probability of rain for a specific location\",\n",
+    "        \"parameters\": {\n",
+    "          \"type\": \"object\",\n",
+    "          \"properties\": {\n",
+    "            \"location\": {\n",
+    "              \"type\": \"string\",\n",
+    "              \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "            }\n",
+    "          },\n",
+    "          \"required\": [\"location\"]\n",
+    "        }\n",
+    "      }\n",
+    "    }\n",
+    "]\n",
+    "\n",
+    "\n",
+    "json_tools = [ \n",
+    "    { \n",
+    "      \"tool_name\": \"spotify_trending_songs\",\n",
+    "      \"description\": \"Get top trending songs on Spotify\",\n",
+    "      \"parameters\": {\n",
+    "        \"n\": {\n",
+    "          \"param_type\": \"int\",\n",
+    "          \"description\": \"Number of trending songs to get\",\n",
+    "          \"required\": \"true\"\n",
+    "        }\n",
+    "      }\n",
+    "    \n",
+    "   },\n",
+    "    {\n",
+    "      \"type\": \"function\",\n",
+    "      \"function\": {\n",
+    "        \"name\": \"get_rain_probability\",\n",
+    "        \"description\": \"Get the probability of rain for a specific location\",\n",
+    "        \"parameters\": {\n",
+    "          \"type\": \"object\",\n",
+    "          \"properties\": {\n",
+    "            \"location\": {\n",
+    "              \"type\": \"string\",\n",
+    "              \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "            }\n",
+    "          },\n",
+    "          \"required\": [\"location\"]\n",
+    "        }\n",
+    "      }\n",
+    "    }\n",
+    "]\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "484498c7-fab8-448a-9381-3c8860b444b4",
+   "metadata": {},
+   "source": [
+    "## Converting to input ids and checking how the prompt format was applied"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "057ad0c5-ec22-4fd0-ae70-99710368db13",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# for messages in dialog:\n",
+    "# Shouldn't output the Environment instruction, but it is.\n",
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "        messages,\n",
+    "        add_generation_prompt=True,\n",
+    "        return_tensors=\"pt\"\n",
+    "    ).to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "cde7edce-eda0-4823-9983-0b358ed14205",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Tools: wolfram_alpha, brave_search\n",
+      "\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "        messages,\n",
+    "        add_generation_prompt=True,\n",
+    "        return_tensors=\"pt\",\n",
+    "        builtin_tools=builtin_tools\n",
+    "    ).to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "a161e1a7-ef57-4f02-9a01-8b0487d35454",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "Use the function'spotify_trending_songs' to 'Get top trending songs on Spotify':\n",
+      "{\"name\": \"spotify_trending_songs\", \"description\": \"Get top trending songs on Spotify\", \"parameters\": {\n",
+      "    \"n\": {\n",
+      "        \"param_type\": \"int\",\n",
+      "        \"description\": \"Number of trending songs to get\",\n",
+      "        \"required\": \"true\"\n",
+      "    }\n",
+      "}Use the function 'get_rain_probability' to 'Get the probability of rain for a specific location':\n",
+      "{\"name\": \"get_rain_probability\", \"description\": \"Get the probability of rain for a specific location\", \"parameters\": {\n",
+      "    \"type\": \"object\",\n",
+      "    \"properties\": {\n",
+      "        \"location\": {\n",
+      "            \"type\": \"string\",\n",
+      "            \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+      "        }\n",
+      "    },\n",
+      "    \"required\": [\n",
+      "        \"location\"\n",
+      "    ]\n",
+      "}\n",
+      "\n",
+      "Think very carefully before calling functions.\n",
+      "If a you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
+      "\n",
+      "<function=example_function_name>{\"example_name\": \"example_value\"}</function>\n",
+      "\n",
+      "Reminder:\n",
+      "- If looking for real time information use relevant functions before falling back to brave_search\n",
+      "- Function calls MUST follow the specified format, start with <function= and end with </function>\n",
+      "- Required parameters MUST be specified\n",
+      "- Only call one function at a time\n",
+      "- Put the entire function call reply on one line\n",
+      "<|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "            messages,\n",
+    "            add_generation_prompt=True,\n",
+    "            return_tensors=\"pt\",\n",
+    "            custom_tools=json_tools,\n",
+    "        ).to(model.device)\n",
+    "    \n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f009729e-7afa-4b02-a356-e8f032c2f281",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "            messages,\n",
+    "            add_generation_prompt=True,\n",
+    "            return_tensors=\"pt\",\n",
+    "            custom_tools=json_tools,\n",
+    "            builtin_tools=builtin_tools\n",
+    "        ).to(model.device)\n",
+    "    \n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "d111c384-8250-4adf-ae66-8866bb712827",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([128000, 128006,   9125, 128007,    271,  13013,     25,   6125,  27993,\n",
+      "           198,  38766,   1303,  33025,   2696,     25,   6790,    220,   2366,\n",
+      "            18,    198,  15724,   2696,     25,    220,   1419,  10263,    220,\n",
+      "          2366,     19,    271,   2675,    527,    264,  11190,   6369,   6465,\n",
+      "        128009, 128006,    882, 128007,    271,   3923,    374,    279,   9282,\n",
+      "          3432,    304,   5960,  13175,     30, 128009, 128006,  78191, 128007,\n",
+      "           271])\n",
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(input_ids[0])\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "aa1b303d-d9fc-40df-8c47-bd4a80522d60",
+   "metadata": {},
+   "source": [
+    "## Running inference \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "6f6077e6-92a6-44f5-912f-19fea4e86c60",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n",
+      "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1850: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Output:\n",
+      "\n",
+      "I'm just an AI, I don't have access to real-time weather information. However, I can suggest some ways for you to find out the current weather in San Francisco.\n",
+      "\n",
+      "You can check the weather forecast on websites such as:\n",
+      "\n",
+      "1. AccuWeather (accuweather.com)\n",
+      "2. Weather.com (weather.com)\n",
+      "3. National Weather Service (weather.gov)\n",
+      "\n",
+      "You can also check the weather on your smartphone by downloading a weather app such as Dark Sky or Weather Underground.\n",
+      "\n",
+      "If you want to know the current weather in San Francisco, I can suggest some general information about the city's climate. San Francisco has a mild climate year-round, with cool summers and mild winters. The city is known for its foggy weather, especially during the summer months. The average high temperature in San Francisco is around 67°F (19°C), while the average low temperature is around 54°F (12°C).\n"
+     ]
+    }
+   ],
+   "source": [
+    "attention_mask = torch.ones_like(input_ids)\n",
+    "outputs = model.generate(\n",
+    "    input_ids,\n",
+    "    max_new_tokens=400,\n",
+    "    eos_token_id=tokenizer.eos_token_id,\n",
+    "    do_sample=True,\n",
+    "    temperature=0.6,\n",
+    "    top_p=0.9,\n",
+    "    attention_mask=attention_mask,\n",
+    ")\n",
+    "response = outputs[0][input_ids.shape[-1]:]\n",
+    "print(\"\\nOutput:\\n\")\n",
+    "print(tokenizer.decode(response, skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad1fd6a4-222e-4003-8c26-5bda9db05ec8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}