{
"cells": [
{
"cell_type": "markdown",
"id": "b03f60a1-898a-44c9-a03c-a7bc41890807",
"metadata": {},
"source": [
"# Simple tool formatting with HF\n",
"Showcasing how to use the new template format from HF for tool calling"
]
},
{
"cell_type": "markdown",
"id": "512d775f-7625-4d6c-9618-2ea39da9ba17",
"metadata": {},
"source": [
"## Loading the models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "307ce66e-3d62-442e-9d85-b29565e6f583",
"metadata": {},
"outputs": [],
"source": [
"# from huggingface_hub import login\n",
"# login()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "c42ba115-95c7-4be1-a050-457ee6c28cfd",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e16d6350bfa44e0dbb6f5c63f949f0fe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"import torch\n",
"\n",
"model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_id,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\",\n",
" )\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "f14afed8-b6ad-46b3-86fa-2d68e2e2ea19",
"metadata": {},
"outputs": [],
"source": [
"# del model\n",
"# import gc\n",
"# gc.collect()\n",
"# torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"id": "8de62346-83d7-495d-91c9-d832c34b6f87",
"metadata": {},
"source": [
"## Setting up the system messages and default tools "
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "64db2078-39fe-4775-8764-77aae26fcdce",
"metadata": {},
"outputs": [],
"source": [
"dialogs = [\n",
" [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
" {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
" ],\n",
" [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
" {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
" ],\n",
" ]\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
" {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
" ]\n",
"\n",
"builtin_tools = [\"code_interpreter\", \"wolfram_alpha\", \"brave_search\"]\n",
"\n",
"json_tools = [ \n",
" { \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"spotify_trending_songs\",\n",
" \"description\": \"Get top trending songs on Spotify\",\n",
" \"parameters\": {\n",
" \"n\": {\n",
" \"param_type\": \"int\",\n",
" \"description\": \"Number of trending songs to get\",\n",
" \"required\": \"true\"\n",
" }\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_temperature\",\n",
" \"description\": \"Get the current temperature for a specific location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"enum\": [\"Celsius\", \"Fahrenheit\"],\n",
" \"description\": \"The temperature unit to use. Infer this from the user's location.\"\n",
" }\n",
" },\n",
" \"required\": [\"location\", \"unit\"]\n",
" }\n",
" }\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_rain_probability\",\n",
" \"description\": \"Get the probability of rain for a specific location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
" }\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" }\n",
" }\n",
"]\n",
"\n",
"\n",
"json_tools = [ \n",
" { \n",
" \"tool_name\": \"spotify_trending_songs\",\n",
" \"description\": \"Get top trending songs on Spotify\",\n",
" \"parameters\": {\n",
" \"n\": {\n",
" \"param_type\": \"int\",\n",
" \"description\": \"Number of trending songs to get\",\n",
" \"required\": \"true\"\n",
" }\n",
" }\n",
" \n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_rain_probability\",\n",
" \"description\": \"Get the probability of rain for a specific location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
" }\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" }\n",
" }\n",
"]\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "484498c7-fab8-448a-9381-3c8860b444b4",
"metadata": {},
"source": [
"## Converting to input ids and checking how the prompt format was applied"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "057ad0c5-ec22-4fd0-ae70-99710368db13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 23 Jul 2024\n",
"\n",
"You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"\n"
]
}
],
"source": [
"# for messages in dialog:\n",
"# Shouldn't output the Environment instruction, but it is.\n",
"input_ids = tokenizer.apply_chat_template(\n",
" messages,\n",
" add_generation_prompt=True,\n",
" return_tensors=\"pt\"\n",
" ).to(model.device)\n",
"\n",
"print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "cde7edce-eda0-4823-9983-0b358ed14205",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Tools: wolfram_alpha, brave_search\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 23 Jul 2024\n",
"\n",
"You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"\n"
]
}
],
"source": [
"input_ids = tokenizer.apply_chat_template(\n",
" messages,\n",
" add_generation_prompt=True,\n",
" return_tensors=\"pt\",\n",
" builtin_tools=builtin_tools\n",
" ).to(model.device)\n",
"\n",
"print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "a161e1a7-ef57-4f02-9a01-8b0487d35454",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 23 Jul 2024\n",
"\n",
"You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Use the function'spotify_trending_songs' to 'Get top trending songs on Spotify':\n",
"{\"name\": \"spotify_trending_songs\", \"description\": \"Get top trending songs on Spotify\", \"parameters\": {\n",
" \"n\": {\n",
" \"param_type\": \"int\",\n",
" \"description\": \"Number of trending songs to get\",\n",
" \"required\": \"true\"\n",
" }\n",
"}Use the function 'get_rain_probability' to 'Get the probability of rain for a specific location':\n",
"{\"name\": \"get_rain_probability\", \"description\": \"Get the probability of rain for a specific location\", \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
"}\n",
"\n",
"Think very carefully before calling functions.\n",
"If a you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
"\n",
"{\"example_name\": \"example_value\"}\n",
"\n",
"Reminder:\n",
"- If looking for real time information use relevant functions before falling back to brave_search\n",
"- Function calls MUST follow the specified format, start with \n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"<|start_header_id|>user<|end_header_id|>\n",
"\n",
"What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"\n"
]
}
],
"source": [
"input_ids = tokenizer.apply_chat_template(\n",
" messages,\n",
" add_generation_prompt=True,\n",
" return_tensors=\"pt\",\n",
" custom_tools=json_tools,\n",
" ).to(model.device)\n",
" \n",
"print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "f009729e-7afa-4b02-a356-e8f032c2f281",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
"\n",
"Environment: ipython\n",
"Tools: wolfram_alpha, brave_search\n",
"\n",
"Cutting Knowledge Date: December 2023\n",
"Today Date: 23 Jul 2024\n",
"\n",
"You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
"\n",
"Use the function'spotify_trending_songs' to 'Get top trending songs on Spotify':\n",
"{\"name\": \"spotify_trending_songs\", \"description\": \"Get top trending songs on Spotify\", \"parameters\": {\n",
" \"n\": {\n",
" \"param_type\": \"int\",\n",
" \"description\": \"Number of trending songs to get\",\n",
" \"required\": \"true\"\n",
" }\n",
"}Use the function 'get_rain_probability' to 'Get the probability of rain for a specific location':\n",
"{\"name\": \"get_rain_probability\", \"description\": \"Get the probability of rain for a specific location\", \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
" }\n",
" },\n",
" \"required\": [\n",
" \"location\"\n",
" ]\n",
"}\n",
"\n",
"Think very carefully before calling functions.\n",
"If a you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
"\n",
"{\"example_name\": \"example_value\"}\n",
"\n",
"Reminder:\n",
"- If looking for real time information use relevant functions before falling back to brave_search\n",
"- Function calls MUST follow the specified format, start with \n",
"- Required parameters MUST be specified\n",
"- Only call one function at a time\n",
"- Put the entire function call reply on one line\n",
"<|start_header_id|>user<|end_header_id|>\n",
"\n",
"What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"\n"
]
}
],
"source": [
"input_ids = tokenizer.apply_chat_template(\n",
" messages,\n",
" add_generation_prompt=True,\n",
" return_tensors=\"pt\",\n",
" custom_tools=json_tools,\n",
" builtin_tools=builtin_tools\n",
" ).to(model.device)\n",
" \n",
"print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
]
},
{
"cell_type": "markdown",
"id": "aa1b303d-d9fc-40df-8c47-bd4a80522d60",
"metadata": {},
"source": [
"## Running inference \n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "6f6077e6-92a6-44f5-912f-19fea4e86c60",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Output:\n",
"\n",
"{\"location\": \"San Francisco, CA\"}\n"
]
}
],
"source": [
"attention_mask = torch.ones_like(input_ids)\n",
"outputs = model.generate(\n",
" input_ids,\n",
" max_new_tokens=400,\n",
" eos_token_id=tokenizer.eos_token_id,\n",
" do_sample=True,\n",
" temperature=0.6,\n",
" top_p=0.9,\n",
" attention_mask=attention_mask,\n",
")\n",
"response = outputs[0][input_ids.shape[-1]:]\n",
"print(\"\\nOutput:\\n\")\n",
"print(tokenizer.decode(response, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad1fd6a4-222e-4003-8c26-5bda9db05ec8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}