
Add OpenLLaMa notebook

Venelin Valkov 1 year ago
parent
commit
ee7e1e9d09
1 changed file with 544 additions and 0 deletions

+ 544 - 0
open-llama.ipynb

@@ -0,0 +1,544 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+  "colab": {
+   "provenance": [],
+   "gpuType": "T4",
+   "machine_shape": "hm"
+  },
+  "kernelspec": {
+   "name": "python3",
+   "display_name": "Python 3"
+  },
+  "language_info": {
+   "name": "python"
+  },
+  "accelerator": "GPU",
+  "gpuClass": "standard"
+ },
+ "cells": [
+  {
+   "cell_type": "code",
+   "source": [
+    "!nvidia-smi"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "uHJROIGn3i_Z",
+    "outputId": "764afe17-db79-4832-e93b-5e59728de378"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Sat May  6 10:14:14 2023       \n",
+      "+-----------------------------------------------------------------------------+\n",
+      "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
+      "|-------------------------------+----------------------+----------------------+\n",
+      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
+      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
+      "|                               |                      |               MIG M. |\n",
+      "|===============================+======================+======================|\n",
+      "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
+      "| N/A   58C    P0    29W /  70W |  13785MiB / 15360MiB |      0%      Default |\n",
+      "|                               |                      |                  N/A |\n",
+      "+-------------------------------+----------------------+----------------------+\n",
+      "                                                                               \n",
+      "+-----------------------------------------------------------------------------+\n",
+      "| Processes:                                                                  |\n",
+      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
+      "|        ID   ID                                                   Usage      |\n",
+      "|=============================================================================|\n",
+      "+-----------------------------------------------------------------------------+\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "!git clone https://huggingface.co/openlm-research/open_llama_7b_preview_300bt"
+   ],
+   "metadata": {
+    "id": "PlYcNqrMqcgy"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "KQUxwEIPmF36"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install -qqq transformers==4.28.1 --progress-bar off\n",
+    "!pip install -qqq bitsandbytes==0.38.1 --progress-bar off\n",
+    "!pip install -qqq accelerate==0.18.0 --progress-bar off\n",
+    "!pip install -qqq sentencepiece==0.1.99 --progress-bar off"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "import textwrap\n",
+    "\n",
+    "import torch\n",
+    "from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer"
+   ],
+   "metadata": {
+    "id": "dIaeE5sKplgx"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "def print_response(response: str):\n",
+    "    print(textwrap.fill(response, width=110))"
+   ],
+   "metadata": {
+    "id": "R7I0fMNXxc7k"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+    "# Special token ids following the LLaMA sentencepiece convention\n",
+    "BOS_TOKEN_ID = 1\n",
+    "EOS_TOKEN_ID = 2\n",
+    "# Rough context budget used when truncating prompts in process_chat below\n",
+    "MAX_TOKENS = 1024"
+   ],
+   "metadata": {
+    "id": "HtRfNE-RrDmT"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "MODEL_NAME = \"/content/open_llama_7b_preview_300bt/open_llama_7b_preview_300bt_transformers_weights\"\n",
+    "\n",
+    "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME, add_eos_token=True)\n",
+    "\n",
+    "model = LlamaForCausalLM.from_pretrained(\n",
+    "    MODEL_NAME, local_files_only=True, torch_dtype=torch.float16, device_map=\"auto\"\n",
+    ")"
+   ],
+   "metadata": {
+    "id": "a_8RthnX9SPG"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
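+  {
+   "cell_type": "markdown",
+   "source": [
+    "Note that `bitsandbytes` is installed above but never used below. As an alternative to the float16 load (a sketch, not run in this notebook), the same checkpoint can be loaded in 8-bit to roughly halve GPU memory, at some cost in speed and quality:"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Sketch: 8-bit quantized load via the bitsandbytes integration.\n",
+    "# Use this instead of the float16 load above, not in addition to it,\n",
+    "# or the T4's GPU memory will be exhausted.\n",
+    "model = LlamaForCausalLM.from_pretrained(\n",
+    "    MODEL_NAME, local_files_only=True, load_in_8bit=True, device_map=\"auto\"\n",
+    ")"
+   ],
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },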
+  {
+   "cell_type": "code",
+   "source": [
+    "# The preview checkpoint's tokenizer config may not match the LLaMA\n",
+    "# convention for special tokens, so pin the BOS id explicitly\n",
+    "tokenizer.bos_token_id = BOS_TOKEN_ID"
+   ],
+   "metadata": {
+    "id": "UcxAGhaTVO25"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Single prompt"
+   ],
+   "metadata": {
+    "id": "oBE_HnVDKJqJ"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "prompt = \"The world's highest building is\""
+   ],
+   "metadata": {
+    "id": "mAzCHuY9_1sn"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "generation_config = GenerationConfig(max_new_tokens=256, temperature=0.5)"
+   ],
+   "metadata": {
+    "id": "CLHPIWub__MW"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
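+  {
+   "cell_type": "markdown",
+   "source": [
+    "`temperature` only takes effect when sampling is enabled; with the default `do_sample=False`, `generate` decodes greedily and the value above is effectively ignored. A config that actually samples would look like this (a sketch with assumed values, not used by the cells below):"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Hypothetical sampling config (not used below): do_sample=True makes\n",
+    "# temperature and top_k take effect\n",
+    "sampling_config = GenerationConfig(\n",
+    "    max_new_tokens=256, temperature=0.5, do_sample=True, top_k=50\n",
+    ")"
+   ],
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },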
+  {
+   "cell_type": "code",
+   "source": [
+    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)"
+   ],
+   "metadata": {
+    "id": "jmtslKI1ACcX"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "inputs"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "YzAmXhqok9U-",
+    "outputId": "ab8f39fd-18b4-4ca4-ec7d-b38c49ce31ed"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "{'input_ids': tensor([[    1,   347,   925, 31889, 31842,  4454,  2203,   322,     0]],\n",
+       "       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 11
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "%%time\n",
+    "\n",
+    "with torch.inference_mode():\n",
+    "    tokens = model.generate(**inputs, generation_config=generation_config)"
+   ],
+   "metadata": {
+    "id": "T4p2rFsCAHJx",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "outputId": "a1487161-a49f-442a-fb64-97fafd886b15"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "CPU times: user 18.2 s, sys: 306 ms, total: 18.5 s\n",
+      "Wall time: 21.5 s\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "tokens"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "ceEPtizqlOJ3",
+    "outputId": "a87dcc6d-cb7e-42d6-b9ce-3e2076fd65f5"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "execute_result",
+     "data": {
+      "text/plain": [
+       "tensor([[    1,   347,   925, 31889, 31842,  4454,  2203,   322,     0, 31856,\n",
+       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
+       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
+       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
+       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
+       "         21880, 31856,   347,   925, 31889, 31842,  8522,   384,  2203,   322,\n",
+       "           266,  4646, 31892, 21527, 21880, 31856,   347,   925, 31889, 31842,\n",
+       "          8522,   384,  2203,   322,   266,  4646, 31892, 21527, 21880, 31856,\n",
+       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
+       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
+       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
+       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
+       "         21880, 31856,   347,   925, 31889, 31842,  8522,   384,  2203,   322,\n",
+       "           266,  4646, 31892, 21527, 21880, 31856,   347,   925, 31889, 31842,\n",
+       "          8522,   384,  2203,   322,   266,  4646, 31892, 21527, 21880, 31856,\n",
+       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
+       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
+       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
+       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
+       "         21880, 31856,   347,   925, 31889, 31842,  8522,   384,  2203,   322,\n",
+       "           266,  4646, 31892, 21527, 21880, 31856,   347,   925, 31889, 31842,\n",
+       "          8522,   384,  2203,   322,   266,  4646, 31892, 21527, 21880, 31856,\n",
+       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
+       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
+       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
+       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
+       "         21880, 31856,   347,   925, 31889]], device='cuda:0')"
+      ]
+     },
+     "metadata": {},
+     "execution_count": 13
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "completion = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
+    "print_response(completion)"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "LOYLu8ybAN0g",
+    "outputId": "bcc2a9bb-29fa-4e6e-87ec-6469002da31b"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "The world's highest building is. The world's tallest building is the Burj Khalifa. The world's tallest\n",
+      "building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building\n",
+      "is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building is the\n",
+      "Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building is the Burj\n",
+      "Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa.\n",
+      "The world's tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The\n",
+      "world's tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's\n",
+      "tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest\n",
+      "building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world'\n"
+     ]
+    }
+   ]
+  },
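+  {
+   "cell_type": "markdown",
+   "source": [
+    "The looping completion above is what greedy decoding tends to produce from a base (not instruction-tuned) model: the single most likely continuation keeps repeating. One common mitigation is a repetition penalty (a sketch with an assumed penalty value, not run here); the next section instead implements top-k sampling by hand:"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Sketch: discourage loops by down-weighting already-generated tokens\n",
+    "penalized_config = GenerationConfig(\n",
+    "    max_new_tokens=256, do_sample=True, temperature=0.5, repetition_penalty=1.2\n",
+    ")\n",
+    "with torch.inference_mode():\n",
+    "    penalized_tokens = model.generate(**inputs, generation_config=penalized_config)"
+   ],
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },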
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Prompting with sampling"
+   ],
+   "metadata": {
+    "id": "NtvOWBP2KMzU"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "def top_k_sampling(logits, k=10):\n",
+    "    # Keep only the k most likely tokens, renormalize, and sample one\n",
+    "    top_k = torch.topk(logits, k)\n",
+    "    top_k_indices = top_k.indices\n",
+    "    top_k_values = top_k.values\n",
+    "    probabilities = torch.softmax(top_k_values, dim=-1)\n",
+    "    choice = torch.multinomial(probabilities, num_samples=1)\n",
+    "    token_id = int(top_k_indices[choice])\n",
+    "    return token_id\n",
+    "\n",
+    "\n",
+    "def process_chat(\n",
+    "    model: LlamaForCausalLM,\n",
+    "    tokenizer: LlamaTokenizer,\n",
+    "    prompt: str,\n",
+    "    max_new_tokens: int = 256,\n",
+    "):\n",
+    "    input_ids = tokenizer(prompt).input_ids\n",
+    "\n",
+    "    output_token_ids = list(input_ids)\n",
+    "\n",
+    "    # Truncate the prompt so prompt + generated tokens fit the context budget\n",
+    "    max_src_len = MAX_TOKENS - max_new_tokens - 8\n",
+    "    input_ids = input_ids[-max_src_len:]\n",
+    "    with torch.inference_mode():\n",
+    "        for i in range(max_new_tokens):\n",
+    "            if i == 0:\n",
+    "                # First step: run the full prompt and cache attention states\n",
+    "                out = model(\n",
+    "                    input_ids=torch.as_tensor([input_ids], device=DEVICE),\n",
+    "                    use_cache=True,\n",
+    "                )\n",
+    "                logits = out.logits\n",
+    "                past_key_values = out.past_key_values\n",
+    "            else:\n",
+    "                # Later steps: feed only the last sampled token, reusing the cache\n",
+    "                out = model(\n",
+    "                    input_ids=torch.as_tensor([[token_id]], device=DEVICE),\n",
+    "                    use_cache=True,\n",
+    "                    past_key_values=past_key_values,\n",
+    "                )\n",
+    "                logits = out.logits\n",
+    "                past_key_values = out.past_key_values\n",
+    "\n",
+    "            last_token_logits = logits[0][-1]\n",
+    "\n",
+    "            token_id = top_k_sampling(last_token_logits)\n",
+    "\n",
+    "            output_token_ids.append(token_id)\n",
+    "\n",
+    "            # Stop as soon as the model emits the end-of-sequence token\n",
+    "            if token_id == EOS_TOKEN_ID:\n",
+    "                break\n",
+    "\n",
+    "    return tokenizer.decode(output_token_ids, skip_special_tokens=True)"
+   ],
+   "metadata": {
+    "id": "sx_zG4GvJ7Ze"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
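+  {
+   "cell_type": "markdown",
+   "source": [
+    "For reference, the hand-rolled loop above approximates what `model.generate` does with built-in top-k sampling. A roughly equivalent helper (a sketch; `k=10` matches the default in `top_k_sampling`) could look like:"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "def generate_top_k(prompt: str, max_new_tokens: int = 256) -> str:\n",
+    "    # Built-in equivalent of process_chat: sample from the 10 most likely tokens\n",
+    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
+    "    config = GenerationConfig(\n",
+    "        max_new_tokens=max_new_tokens, do_sample=True, top_k=10\n",
+    "    )\n",
+    "    with torch.inference_mode():\n",
+    "        out = model.generate(**inputs, generation_config=config)\n",
+    "    return tokenizer.decode(out[0], skip_special_tokens=True)"
+   ],
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },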
+  {
+   "cell_type": "code",
+   "source": [
+    "%%time\n",
+    "prompt = \"You're Michael G Scott from the office. What is your favorite phrase?\"\n",
+    "response = process_chat(model, tokenizer, prompt)\n",
+    "print_response(response)"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "UACwKuSkIEI_",
+    "outputId": "a6ffa056-fdaf-4b07-f027-7886257a66db"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "You're Michael G Scott from the office. What is your favorite phrase?. I'm a bit too old to be doing this!\n",
+      "What do you think of your character in the Office? Who is your favorite person you know who is also an actor?\n",
+      "How do you get your inspiration? How do you feel about working with John Krasinski? What is the hardest thing\n",
+      "about being in the show? What kind of a fan are you? What's your favorite thing about being an actor?\n",
+      "CPU times: user 5.44 s, sys: 0 ns, total: 5.44 s\n",
+      "Wall time: 5.46 s\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "prompt = \"The world's highest building is\"\n",
+    "response = process_chat(model, tokenizer, prompt)\n",
+    "print_response(response)"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "k6tcUkdT-H-q",
+    "outputId": "feea72cd-bb93-490c-d0b0-3718f28e0807"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "The world's highest building is tower, which is 1,336.8 meters from the ground in 2007. A building 1,336.8\n",
+      "meters high is called tallest building in history. The tallest building in the world is currently Burj Khalifa\n",
+      "(Dubai). Burj Khalifa is 829.8 meters tall. Burj Khalifa is a 77-story skyscraper. This tall building is the\n",
+      "28th tallest building in the world and the 10th highest structure in the world. Burj Khalifa is a mixed use\n",
+      "building. In addition to office space, there are apartments, hotels, shopping malls in the building. In\n",
+      "addition to being the highest building in the world, the tower Burj Khalifa was also the first building on the\n",
+      "planet to reach the height of 800 meters. This is a record that was achieved in 2010. Burj Khalifa is owned by\n",
+      "the real estate developer Emaar Properties. Burj Khalifa was developed in a record time, the building was\n",
+      "opened at 2010, but was first opened at the end of December 2010 and officially opened on January 4\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "prompt = \"The best way to invest $10,000 is\"\n",
+    "response = process_chat(model, tokenizer, prompt)\n",
+    "print_response(response)"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "ByY-PuTiGtRy",
+    "outputId": "835e55da-831a-41d3-e064-b8c2734adf53"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "The best way to invest $10,000 is The best way to invest $10,000 is to start by saving up to $5000 and then\n",
+      "buy stock that’s already cheap. This is a more conservative way of investing and you can buy a good stock for\n",
+      "a cheap price. I think you’re right, but the way you’re describing it is more like buying a stock for 50 cents\n",
+      "a share. If you look at the way that I describe it, it’s not a stock that’s a penny a share. It’s a stock\n",
+      "that’s a penny a share, so it’s a penny on a dollar. If you look at the way that I describe it, it’s a stock\n",
+      "that’s a penny on a dollar. It’s a penny on a dollar, so it’s a penny on a dollar. But if you look at it, it’s\n",
+      "a penny on a penny. I think it’s a penny on a penny. But the way you describe it is a penny on a penny. But if\n",
+      "you look at it, it’s a penny on a penny. I think it’s a penny on the dollar. But the way you describe it is a\n",
+      "penny on the penny. But\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "prompt = \"The best make and model v8 manual gearbox car is\"\n",
+    "response = process_chat(model, tokenizer, prompt)\n",
+    "print_response(response)"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "e8lclPl1G60Y",
+    "outputId": "2fa65dc0-27c7-4fdf-c66e-fa7e817c7fe0"
+   },
+   "execution_count": null,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "The best make and model v8 manual gearbox car is 1968-1969 gmc sierra with 4spd manual or gm overdrive auto.\n",
+      "If you want to know how to install a manual gearbox car, please do not hesitate to contact us. If you are\n",
+      "wondering how to install a manual gearbox car, please contact us at:. If you want to know how to install a\n",
+      "manual transaxle gearbox car, please read the following.\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## References\n",
+    "\n",
+    "- https://github.com/riversun/open_llama_7b_hands_on\n",
+    "- https://news.ycombinator.com/item?id=35798888\n",
+    "- https://github.com/openlm-research/open_llama\n",
+    "- https://huggingface.co/openlm-research/open_llama_7b_preview_300bt"
+   ],
+   "metadata": {
+    "id": "56SaOYzppR1m"
+   }
+  }
+ ]
+}