
moving to repo model

Beto, 9 months ago
Commit 3bb2cce669
1 file changed, 75 insertions and 47 deletions

recipes/quickstart/inference/local_inference/simple_tool_test.ipynb    +75 −47

@@ -19,14 +19,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": null,
+   "id": "307ce66e-3d62-442e-9d85-b29565e6f583",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from huggingface_hub import login\n",
+    "# login()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
    "id": "c42ba115-95c7-4be1-a050-457ee6c28cfd",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "37bc314d4a5c4944a84da770f80f555a",
+       "model_id": "e16d6350bfa44e0dbb6f5c63f949f0fe",
        "version_major": 2,
        "version_minor": 0
       },
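
The new first cell keeps the Hugging Face login commented out. When the checkpoint has to be pulled from the Hub (a gated meta-llama repo), a minimal login sketch looks like the following; reading the token from an HF_TOKEN environment variable is an assumption, not something the notebook does:

import os
from huggingface_hub import login

# Authenticate against the Hugging Face Hub; gated repos such as
# meta-llama/Meta-Llama-3-8B-Instruct need a token whose account has accepted the licence.
login(token=os.environ["HF_TOKEN"])  # assumed env var; the notebook just calls login()
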
@@ -42,9 +53,7 @@
     "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
     "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
     "import torch\n",
     "import torch\n",
     "\n",
     "\n",
-    "# model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
-    "\n",
-    "model_id = \"/home/ubuntu/projects/llama-recipes/models--llhf--Meta-Llama-3.1-8B-Instruct/snapshots/9fd0b760200bab0a7af5e24c14f1283ecdb4765f\"\n",
+    "model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
     "\n",
     "\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
     "model = AutoModelForCausalLM.from_pretrained(\n",
     "model = AutoModelForCausalLM.from_pretrained(\n",
@@ -57,6 +66,19 @@
    ]
   },
   {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "f14afed8-b6ad-46b3-86fa-2d68e2e2ea19",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# del model\n",
+    "# import gc\n",
+    "# gc.collect()\n",
+    "# torch.cuda.empty_cache()"
+   ]
+  },
+  {
    "cell_type": "markdown",
    "id": "8de62346-83d7-495d-91c9-d832c34b6f87",
    "metadata": {},
@@ -331,47 +353,58 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 43,
    "id": "f009729e-7afa-4b02-a356-e8f032c2f281",
    "metadata": {},
-   "outputs": [],
-   "source": [
-    "input_ids = tokenizer.apply_chat_template(\n",
-    "            messages,\n",
-    "            add_generation_prompt=True,\n",
-    "            return_tensors=\"pt\",\n",
-    "            custom_tools=json_tools,\n",
-    "            builtin_tools=builtin_tools\n",
-    "        ).to(model.device)\n",
-    "    \n",
-    "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "id": "d111c384-8250-4adf-ae66-8866bb712827",
-   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "tensor([128000, 128006,   9125, 128007,    271,  13013,     25,   6125,  27993,\n",
-      "           198,  38766,   1303,  33025,   2696,     25,   6790,    220,   2366,\n",
-      "            18,    198,  15724,   2696,     25,    220,   1419,  10263,    220,\n",
-      "          2366,     19,    271,   2675,    527,    264,  11190,   6369,   6465,\n",
-      "        128009, 128006,    882, 128007,    271,   3923,    374,    279,   9282,\n",
-      "          3432,    304,   5960,  13175,     30, 128009, 128006,  78191, 128007,\n",
-      "           271])\n",
       "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
       "\n",
       "Environment: ipython\n",
+      "Tools: wolfram_alpha, brave_search\n",
+      "\n",
       "Cutting Knowledge Date: December 2023\n",
       "Today Date: 23 Jul 2024\n",
       "\n",
       "You are a helpful chatbot<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
       "\n",
+      "Use the function'spotify_trending_songs' to 'Get top trending songs on Spotify':\n",
+      "{\"name\": \"spotify_trending_songs\", \"description\": \"Get top trending songs on Spotify\", \"parameters\": {\n",
+      "    \"n\": {\n",
+      "        \"param_type\": \"int\",\n",
+      "        \"description\": \"Number of trending songs to get\",\n",
+      "        \"required\": \"true\"\n",
+      "    }\n",
+      "}Use the function 'get_rain_probability' to 'Get the probability of rain for a specific location':\n",
+      "{\"name\": \"get_rain_probability\", \"description\": \"Get the probability of rain for a specific location\", \"parameters\": {\n",
+      "    \"type\": \"object\",\n",
+      "    \"properties\": {\n",
+      "        \"location\": {\n",
+      "            \"type\": \"string\",\n",
+      "            \"description\": \"The city and state, e.g., San Francisco, CA\"\n",
+      "        }\n",
+      "    },\n",
+      "    \"required\": [\n",
+      "        \"location\"\n",
+      "    ]\n",
+      "}\n",
+      "\n",
+      "Think very carefully before calling functions.\n",
+      "If a you choose to call a function ONLY reply in the following format with no prefix or suffix:\n",
+      "\n",
+      "<function=example_function_name>{\"example_name\": \"example_value\"}</function>\n",
+      "\n",
+      "Reminder:\n",
+      "- If looking for real time information use relevant functions before falling back to brave_search\n",
+      "- Function calls MUST follow the specified format, start with <function= and end with </function>\n",
+      "- Required parameters MUST be specified\n",
+      "- Only call one function at a time\n",
+      "- Put the entire function call reply on one line\n",
+      "<|start_header_id|>user<|end_header_id|>\n",
+      "\n",
       "What is the weather today in San Francisco?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
       "\n",
       "\n"
@@ -379,7 +412,14 @@
     }
    ],
    "source": [
-    "print(input_ids[0])\n",
+    "input_ids = tokenizer.apply_chat_template(\n",
+    "            messages,\n",
+    "            add_generation_prompt=True,\n",
+    "            return_tensors=\"pt\",\n",
+    "            custom_tools=json_tools,\n",
+    "            builtin_tools=builtin_tools\n",
+    "        ).to(model.device)\n",
+    "    \n",
     "print(tokenizer.decode(input_ids[0], skip_special_tokens=False))"
    ]
   },
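
The consolidated cell now builds input_ids with apply_chat_template; the extra custom_tools / builtin_tools keyword arguments are forwarded to the chat template as template variables, so they only take effect with a template that reads them. A minimal sketch of the generation step that presumably follows (max_new_tokens is an assumed budget):

# Generate from the templated prompt and print only the newly produced tokens.
outputs = model.generate(
    input_ids,
    max_new_tokens=256,  # assumed; pick a budget large enough for a tool call
)
response_ids = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response_ids, skip_special_tokens=True))
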
@@ -393,7 +433,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 32,
    "id": "6f6077e6-92a6-44f5-912f-19fea4e86c60",
    "metadata": {},
    "outputs": [
@@ -401,9 +441,7 @@
      "name": "stderr",
      "name": "stderr",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n",
-      "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1850: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.\n",
-      "  warnings.warn(\n"
+      "Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.\n"
      ]
      ]
     },
     },
     {
     {
@@ -413,17 +451,7 @@
       "\n",
       "\n",
       "Output:\n",
       "Output:\n",
       "\n",
       "\n",
-      "I'm just an AI, I don't have access to real-time weather information. However, I can suggest some ways for you to find out the current weather in San Francisco.\n",
-      "\n",
-      "You can check the weather forecast on websites such as:\n",
-      "\n",
-      "1. AccuWeather (accuweather.com)\n",
-      "2. Weather.com (weather.com)\n",
-      "3. National Weather Service (weather.gov)\n",
-      "\n",
-      "You can also check the weather on your smartphone by downloading a weather app such as Dark Sky or Weather Underground.\n",
-      "\n",
-      "If you want to know the current weather in San Francisco, I can suggest some general information about the city's climate. San Francisco has a mild climate year-round, with cool summers and mild winters. The city is known for its foggy weather, especially during the summer months. The average high temperature in San Francisco is around 67°F (19°C), while the average low temperature is around 54°F (12°C).\n"
+      "<function=get_rain_probability>{\"location\": \"San Francisco, CA\"}</function>\n"
      ]
      ]
     }
     }
    ],
    ],
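
With the tool prompt in place, the reply collapses to a single <function=...>{...}</function> line, matching the format the reminder section mandates. A small sketch of parsing such a reply back into a tool name plus arguments; the regex and helper name are illustrative, not part of the notebook:

import json
import re

def parse_tool_call(reply: str):
    """Return (function_name, arguments) from a <function=...>...</function> reply, or None."""
    match = re.search(r"<function=(\w+)>(.*?)</function>", reply, re.DOTALL)
    if match is None:
        return None
    return match.group(1), json.loads(match.group(2))

print(parse_tool_call('<function=get_rain_probability>{"location": "San Francisco, CA"}</function>'))
# ('get_rain_probability', {'location': 'San Francisco, CA'})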