@@ -0,0 +1,476 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b03f60a1-898a-44c9-a03c-a7bc41890807",
+   "metadata": {},
+   "source": [
+    "# Simple tool formatting with HF\n",
+    "This notebook shows how to use the new Hugging Face chat template format for tool calling with Llama 3.1."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "512d775f-7625-4d6c-9618-2ea39da9ba17",
+   "metadata": {},
+   "source": [
+    "## Loading the model and tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "c42ba115-95c7-4be1-a050-457ee6c28cfd",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "37bc314d4a5c4944a84da770f80f555a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "import torch\n",
+    "\n",
+    "# model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
+    "\n",
+    "# Point model_id at a Hub repo id or, as here, a local snapshot path\n",
+    "model_id = \"/home/ubuntu/projects/llama-recipes/models--llhf--Meta-Llama-3.1-8B-Instruct/snapshots/9fd0b760200bab0a7af5e24c14f1283ecdb4765f\"\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "model = AutoModelForCausalLM.from_pretrained(\n",
+    "    model_id,\n",
+    "    torch_dtype=torch.bfloat16,\n",
+    "    device_map=\"auto\",\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8de62346-83d7-495d-91c9-d832c34b6f87",
+   "metadata": {},
+   "source": [
+    "## Setting up the system messages and default tools"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "id": "64db2078-39fe-4775-8764-77aae26fcdce",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dialogs = [\n",
+    "    [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
+    "        {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
+    "    ],\n",
+    "    [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
+    "        {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
+    "    ],\n",
+    "]\n",
+    "\n",
+    "# dialogs illustrates the batched format; the cells below only use the single messages list\n",
+    "messages = [\n",
+    "    {\"role\": \"system\", \"content\": \"You are a helpful chatbot\"},\n",
+    "    {\"role\": \"user\", \"content\": \"What is the weather today in San Francisco?\"},\n",
+    "]\n",
+    "\n",
+    "builtin_tools = [\"code_interpreter\", \"wolfram_alpha\", \"brave_search\"]\n",
+    "\n",
+    "json_tools = [\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"spotify_trending_songs\",\n",
+    "            \"description\": \"Get top trending songs on Spotify\",\n",
+    "            \"parameters\": {\n",
+    "                \"n\": {\n",
+    "                    \"param_type\": \"int\",\n",
+    "                    \"description\": \"Number of trending songs to get\",\n",
+    "                    \"required\": \"true\"\n",
+    "                }\n",
+    "            }\n",
+    "        }\n",
+    "    },\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_current_temperature\",\n",
+    "            \"description\": \"Get the current temperature for a specific location\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "                    },\n",
+    "                    \"unit\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"enum\": [\"Celsius\", \"Fahrenheit\"],\n",
+    "                        \"description\": \"The temperature unit to use. Infer this from the user's location.\"\n",
+    "                    }\n",
+    "                },\n",
+    "                \"required\": [\"location\", \"unit\"]\n",
+    "            }\n",
+    "        }\n",
+    "    },\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_rain_probability\",\n",
+    "            \"description\": \"Get the probability of rain for a specific location\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "                    }\n",
+    "                },\n",
+    "                \"required\": [\"location\"]\n",
+    "            }\n",
+    "        }\n",
+    "    }\n",
+    "]\n",
+    "\n",
+    "\n",
+    "# NOTE: this second assignment overwrites the json_tools defined above and is the\n",
+    "# definition rendered into the prompts below. Its first entry is a bare tool dict,\n",
+    "# the second uses the OpenAI-style nesting; the template renders both.\n",
+    "json_tools = [\n",
+    "    {\n",
+    "        \"tool_name\": \"spotify_trending_songs\",\n",
+    "        \"description\": \"Get top trending songs on Spotify\",\n",
+    "        \"parameters\": {\n",
+    "            \"n\": {\n",
+    "                \"param_type\": \"int\",\n",
+    "                \"description\": \"Number of trending songs to get\",\n",
+    "                \"required\": \"true\"\n",
+    "            }\n",
+    "        }\n",
+    "    },\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_rain_probability\",\n",
+    "            \"description\": \"Get the probability of rain for a specific location\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "                    }\n",
+    "                },\n",
+    "                \"required\": [\"location\"]\n",
+    "            }\n",
+    "        }\n",
+    "    }\n",
+    "]\n"
+   ]
+  },
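+  {
+   "cell_type": "markdown",
+   "id": "3f2a9c1e-5b7d-4e2a-9c3f-8d1e2b4a6c70",
+   "metadata": {},
+   "source": [
+    "The tool specs above are written out by hand. As an optional aside (not in the original notebook), newer `transformers` releases include a `get_json_schema` helper that builds an equivalent spec from a typed, documented Python function. The sketch below assumes such a release is installed; `get_current_temperature` is only an illustrative stub."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8c1d4e7f-2a3b-4c5d-9e6f-0a1b2c3d4e5f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: derive a tool spec from a typed, documented Python function\n",
+    "# (assumes a transformers release that provides get_json_schema).\n",
+    "from transformers.utils import get_json_schema\n",
+    "\n",
+    "\n",
+    "def get_current_temperature(location: str, unit: str) -> float:\n",
+    "    \"\"\"\n",
+    "    Get the current temperature for a specific location.\n",
+    "\n",
+    "    Args:\n",
+    "        location: The city and state, e.g., San Francisco, CA\n",
+    "        unit: The temperature unit to use (Celsius or Fahrenheit)\n",
+    "    \"\"\"\n",
+    "    return 22.0  # illustrative stub, not a real weather lookup\n",
+    "\n",
+    "\n",
+    "get_json_schema(get_current_temperature)"
+   ]
+  },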
+  {
+   "cell_type": "markdown",
+   "id": "484498c7-fab8-448a-9381-3c8860b444b4",
+   "metadata": {},
+   "source": [
+    "## Converting to input IDs and checking how the prompt format was applied"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "057ad0c5-ec22-4fd0-ae70-99710368db13",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# To batch over dialogs, wrap this call in `for messages in dialogs:`.\n",
+    "# The template shouldn't emit the `Environment: ipython` line when no tools are passed, but it currently does.\n",
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "    messages,\n",
+    "    add_generation_prompt=True,\n",
+    "    return_tensors=\"pt\",\n",
+    ").to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "cde7edce-eda0-4823-9983-0b358ed14205",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Tools: wolfram_alpha, brave_search\n",
+      "\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "    messages,\n",
+    "    add_generation_prompt=True,\n",
+    "    return_tensors=\"pt\",\n",
+    "    builtin_tools=builtin_tools,\n",
+    ").to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "a161e1a7-ef57-4f02-9a01-8b0487d35454",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "Use the function'spotify_trending_songs' to 'Get top trending songs on Spotify':\n",
+      "{\"name\": \"spotify_trending_songs\", \"description\": \"Get top trending songs on Spotify\", \"parameters\": {\n",
+      "    \"n\": {\n",
+      "        \"param_type\": \"int\",\n",
+      "        \"description\": \"Number of trending songs to get\",\n",
+      "        \"required\": \"true\"\n",
+      "    }\n",
+      "}Use the function 'get_rain_probability' to 'Get the probability of rain for a specific location':\n",
+      "{\"name\": \"get_rain_probability\", \"description\": \"Get the probability of rain for a specific location\", \"parameters\": {\n",
+      "    \"type\": \"object\",\n",
+      "    \"properties\": {\n",
+      "        \"location\": {\n",
+      "            \"type\": \"string\",\n",
+      "            \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+      "        }\n",
+      "    },\n",
+      "    \"required\": [\n",
+      "        \"location\"\n",
+      "    ]\n",
+      "}\n",
+      "\n",
+      "Think very carefully before calling functions.\n",
+      "If a you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
+      "\n",
+      "<function=example_function_name>{\"example_name\": \"example_value\"}</function>\n",
+      "\n",
+      "Reminder:\n",
+      "- If looking for real time information use relevant functions before falling back to brave_search\n",
+      "- Function calls MUST follow the specified format, start with <function= and end with </function>\n",
+      "- Required parameters MUST be specified\n",
+      "- Only call one function at a time\n",
+      "- Put the entire function call reply on one line\n",
+      "<|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "    messages,\n",
+    "    add_generation_prompt=True,\n",
+    "    return_tensors=\"pt\",\n",
+    "    custom_tools=json_tools,\n",
+    ").to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f009729e-7afa-4b02-a356-e8f032c2f281",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "    messages,\n",
+    "    add_generation_prompt=True,\n",
+    "    return_tensors=\"pt\",\n",
+    "    custom_tools=json_tools,\n",
+    "    builtin_tools=builtin_tools,\n",
+    ").to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
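+  {
+   "cell_type": "markdown",
+   "id": "6d9e2b4a-7c1f-4e3d-8a5b-2f0c9d1e3a7b",
+   "metadata": {},
+   "source": [
+    "The cells above pass `builtin_tools` and `custom_tools` straight through to the Llama 3.1 chat template. As an optional aside (not part of the original flow), recent `transformers` releases also accept a generic `tools=` argument on `apply_chat_template`. The sketch below shows that call shape with an OpenAI-style spec and assumes a release that supports it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4b8f1a2c-9d3e-4f5a-8b6c-7e0d1f2a3b4c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: the generic tools= path of apply_chat_template\n",
+    "# (assumes a transformers release that accepts this argument).\n",
+    "openai_style_tools = [\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_rain_probability\",\n",
+    "            \"description\": \"Get the probability of rain for a specific location\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+    "                    }\n",
+    "                },\n",
+    "                \"required\": [\"location\"]\n",
+    "            }\n",
+    "        }\n",
+    "    }\n",
+    "]\n",
+    "\n",
+    "generic_input_ids = tokenizer.apply_chat_template(\n",
+    "    messages,\n",
+    "    tools=openai_style_tools,\n",
+    "    add_generation_prompt=True,\n",
+    "    return_tensors=\"pt\",\n",
+    ").to(model.device)\n",
+    "\n",
+    "print(tokenizer.decode(generic_input_ids[0], skip_special_tokens=False))"
+   ]
+  },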
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "d111c384-8250-4adf-ae66-8866bb712827",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([128000, 128006, 9125, 128007, 271, 13013, 25, 6125, 27993,\n",
+      "        198, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366,\n",
+      "        18, 198, 15724, 2696, 25, 220, 1419, 10263, 220,\n",
+      "        2366, 19, 271, 2675, 527, 264, 11190, 6369, 6465,\n",
+      "        128009, 128006, 882, 128007, 271, 3923, 374, 279, 9282,\n",
+      "        3432, 304, 5960, 13175, 30, 128009, 128006, 78191, 128007,\n",
+      "        271])\n",
+      "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
+      "\n",
+      "Environment: ipython\n",
+      "Cutting Knowledge Date: December 2023\n",
+      "Today Date: 23 Jul 2024\n",
+      "\n",
+      "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+      "\n",
+      "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(input_ids[0])\n",
+    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "aa1b303d-d9fc-40df-8c47-bd4a80522d60",
+   "metadata": {},
+   "source": [
+    "## Running inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "6f6077e6-92a6-44f5-912f-19fea4e86c60",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n",
+      "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1850: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Output:\n",
+      "\n",
+      "I'm just an AI, I don't have access to real-time weather information. However, I can suggest some ways for you to find out the current weather in San Francisco.\n",
+      "\n",
+      "You can check the weather forecast on websites such as:\n",
+      "\n",
+      "1. AccuWeather (accuweather.com)\n",
+      "2. Weather.com (weather.com)\n",
+      "3. National Weather Service (weather.gov)\n",
+      "\n",
+      "You can also check the weather on your smartphone by downloading a weather app such as Dark Sky or Weather Underground.\n",
+      "\n",
+      "If you want to know the current weather in San Francisco, I can suggest some general information about the city's climate. San Francisco has a mild climate year-round, with cool summers and mild winters. The city is known for its foggy weather, especially during the summer months. The average high temperature in San Francisco is around 67°F (19°C), while the average low temperature is around 54°F (12°C).\n"
+     ]
+    }
+   ],
+   "source": [
+    "attention_mask = torch.ones_like(input_ids)\n",
+    "outputs = model.generate(\n",
+    "    input_ids,\n",
+    "    max_new_tokens=400,\n",
+    "    eos_token_id=tokenizer.eos_token_id,\n",
+    "    do_sample=True,\n",
+    "    temperature=0.6,\n",
+    "    top_p=0.9,\n",
+    "    attention_mask=attention_mask,\n",
+    ")\n",
+    "response = outputs[0][input_ids.shape[-1]:]\n",
+    "print(\"\\nOutput:\\n\")\n",
+    "print(tokenizer.decode(response, skip_special_tokens=True))"
+   ]
+  },
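+  {
+   "cell_type": "markdown",
+   "id": "1e5c7a9b-3d2f-4b6e-9a8c-0d4f6b8a2c1e",
+   "metadata": {},
+   "source": [
+    "For this prompt the model answered in plain text, but when it does decide to call a custom tool, the prompt above instructs it to reply as `<function=name>{...}</function>`. The cell below is an optional sketch (not part of the original notebook) of how such a reply could be parsed; the helper and regex are illustrative."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9f3b5d7c-1a2e-4c8b-b6d0-5e7a9c1b3d2f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: pull a <function=name>{...}</function> call out of a generated reply.\n",
+    "# Illustrative only; the model emits this format when it decides to call a tool.\n",
+    "import json\n",
+    "import re\n",
+    "\n",
+    "\n",
+    "def parse_tool_call(text):\n",
+    "    match = re.search(r\"<function=(\\w+)>(.*?)</function>\", text, re.DOTALL)\n",
+    "    if match is None:\n",
+    "        return None  # plain-text answer, no tool call\n",
+    "    name, raw_args = match.group(1), match.group(2)\n",
+    "    try:\n",
+    "        args = json.loads(raw_args)\n",
+    "    except json.JSONDecodeError:\n",
+    "        args = raw_args  # keep the raw string if the JSON is malformed\n",
+    "    return {\"name\": name, \"arguments\": args}\n",
+    "\n",
+    "\n",
+    "parse_tool_call(tokenizer.decode(response, skip_special_tokens=True))"
+   ]
+  },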
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad1fd6a4-222e-4003-8c26-5bda9db05ec8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}