{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SwcwXRajHelL", "outputId": "553c2843-5a43-430e-d2fc-cdfabfa8e309" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Wed Sep 6 14:56:22 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n", "| N/A 35C P0 26W / 300W | 0MiB / 16384MiB | 0% Default |\n", "| | | N/A |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| No running processes found |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rpf1Z0k4RJM6", "outputId": "40038a1c-2aad-472f-e966-f1d61349bc01" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install -Uqqq pip --progress-bar off\n", "!pip install -qqq torch==2.0.1 --progress-bar off\n", "!pip install -qqq transformers==4.32.1 --progress-bar off\n", "!pip install -qqq datasets==2.14.4 --progress-bar off\n", "!pip install -qqq peft==0.5.0 --progress-bar off\n", "!pip install -qqq bitsandbytes==0.41.1 --progress-bar off\n", "!pip install -qqq trl==0.7.1 --progress-bar off" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "P1PG0WSvRqVq" }, "outputs": [], "source": [ "import json\n", "import re\n", "from pprint import pprint\n", "\n", "import pandas as pd\n", "import torch\n", "from datasets import Dataset, load_dataset\n", "from huggingface_hub import notebook_login\n", "from peft import LoraConfig, PeftModel\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", " TrainingArguments,\n", ")\n", "from trl import SFTTrainer\n", "\n", "DEVICE = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "MODEL_NAME = \"meta-llama/Llama-2-7b-hf\"" ] }, { "cell_type": "markdown", "metadata": { "id": "4ixsX2Y4doEf" }, "source": [ "## Data" ] }, { "cell_type": "code", "source": [ "dataset = load_dataset(\"Salesforce/dialogstudio\", \"TweetSumm\")\n", "dataset" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Kc0CVTtUkWvl", "outputId": "f69dcc3b-aeeb-4267-bb26-476465cbc80a" }, "execution_count": 2, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],\n", " num_rows: 879\n", " })\n", " validation: Dataset({\n", " features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],\n", " num_rows: 110\n", " })\n", " test: Dataset({\n", " features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],\n", " num_rows: 110\n", " })\n", "})" ] }, "metadata": {}, "execution_count": 2 } ] }, { "cell_type": "code", "source": [ "DEFAULT_SYSTEM_PROMPT = \"\"\"\n", "Below is a conversation between a human and an AI agent. 
Write a summary of the conversation.\n", "\"\"\".strip()\n", "\n", "\n", "def generate_training_prompt(\n", " conversation: str, summary: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT\n", ") -> str:\n", " return f\"\"\"### Instruction: {system_prompt}\n", "\n", "### Input:\n", "{conversation.strip()}\n", "\n", "### Response:\n", "{summary}\n", "\"\"\".strip()" ], "metadata": { "id": "fbx71jhaMGK5" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "def clean_text(text):\n", " text = re.sub(r\"http\\S+\", \"\", text)\n", " text = re.sub(r\"@[^\\s]+\", \"\", text)\n", " text = re.sub(r\"\\s+\", \" \", text)\n", " return re.sub(r\"\\^[^ ]+\", \"\", text)\n", "\n", "\n", "def create_conversation_text(data_point):\n", " text = \"\"\n", " for item in data_point[\"log\"]:\n", " user = clean_text(item[\"user utterance\"])\n", " text += f\"user: {user.strip()}\\n\"\n", "\n", " agent = clean_text(item[\"system response\"])\n", " text += f\"agent: {agent.strip()}\\n\"\n", "\n", " return text" ], "metadata": { "id": "gMfzUJVOR9Lr" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "def generate_text(data_point):\n", " summaries = json.loads(data_point[\"original dialog info\"])[\"summaries\"][\n", " \"abstractive_summaries\"\n", " ]\n", " summary = summaries[0]\n", " summary = \" \".join(summary)\n", "\n", " conversation_text = create_conversation_text(data_point)\n", " return {\n", " \"conversation\": conversation_text,\n", " \"summary\": summary,\n", " \"text\": generate_training_prompt(conversation_text, summary),\n", " }" ], "metadata": { "id": "eRbskn48QNfW" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "example = generate_text(dataset[\"train\"][0])" ], "metadata": { "id": "J9VuGHGYSR5q" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "print(example[\"summary\"])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ycfL7LvHSu2h", "outputId": "5db80291-761f-4e23-83ab-11dcbf166947" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Customer enquired about his Iphone and Apple watch which is not showing his any steps/activity and health activities. Agent is asking to move to DM and look into it.\n" ] } ] }, { "cell_type": "code", "source": [ "print(example[\"conversation\"])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "eL7EDMq_SxlJ", "outputId": "96eca262-64d4-4b21-eafd-df5f65e1dde3" }, "execution_count": 10, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "user: So neither my iPhone nor my Apple Watch are recording my steps/activity, and Health doesn’t recognise either source anymore for some reason. Any ideas? please read the above.\n", "agent: Let’s investigate this together. To start, can you tell us the software versions your iPhone and Apple Watch are running currently?\n", "user: My iPhone is on 11.1.2, and my watch is on 4.1.\n", "agent: Thank you. Have you tried restarting both devices since this started happening?\n", "user: I’ve restarted both, also un-paired then re-paired the watch.\n", "agent: Got it. When did you first notice that the two devices were not talking to each other. Do the two devices communicate through other apps such as Messages?\n", "user: Yes, everything seems fine, it’s just Health and activity.\n", "agent: Let’s move to DM and look into this a bit more. When reaching out in DM, let us know when this first started happening please. 
For example, did it start after an update or after installing a certain app?\n", "\n" ] } ] }, { "cell_type": "code", "source": [ "print(example[\"text\"])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BbiL_o5dZZhg", "outputId": "a5606c26-bb8b-4f42-e478-3e7dde005f1e" }, "execution_count": 11, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "### Instruction: Below is a conversation between a human and an AI agent. Write a summary of the conversation.\n", "\n", "### Input:\n", "user: So neither my iPhone nor my Apple Watch are recording my steps/activity, and Health doesn’t recognise either source anymore for some reason. Any ideas? please read the above.\n", "agent: Let’s investigate this together. To start, can you tell us the software versions your iPhone and Apple Watch are running currently?\n", "user: My iPhone is on 11.1.2, and my watch is on 4.1.\n", "agent: Thank you. Have you tried restarting both devices since this started happening?\n", "user: I’ve restarted both, also un-paired then re-paired the watch.\n", "agent: Got it. When did you first notice that the two devices were not talking to each other. Do the two devices communicate through other apps such as Messages?\n", "user: Yes, everything seems fine, it’s just Health and activity.\n", "agent: Let’s move to DM and look into this a bit more. When reaching out in DM, let us know when this first started happening please. For example, did it start after an update or after installing a certain app?\n", "\n", "### Response:\n", "Customer enquired about his Iphone and Apple watch which is not showing his any steps/activity and health activities. Agent is asking to move to DM and look into it.\n" ] } ] }, { "cell_type": "code", "source": [ "def process_dataset(data: Dataset):\n", " return (\n", " data.shuffle(seed=42)\n", " .map(generate_text)\n", " .remove_columns(\n", " [\n", " \"original dialog id\",\n", " \"new dialog id\",\n", " \"dialog index\",\n", " \"original dialog info\",\n", " \"log\",\n", " \"prompt\",\n", " ]\n", " )\n", " )" ], "metadata": { "id": "jKidKeUpZkPb" }, "execution_count": 12, "outputs": [] }, { "cell_type": "code", "source": [ "dataset[\"train\"] = process_dataset(dataset[\"train\"])\n", "dataset[\"validation\"] = process_dataset(dataset[\"validation\"])" ], "metadata": { "id": "XHy1pVjlaLtm" }, "execution_count": 13, "outputs": [] }, { "cell_type": "code", "source": [ "dataset" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MTqi2hRkaXoW", "outputId": "7d64e882-50a2-4076-879b-ca80425d5480" }, "execution_count": 14, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['conversation', 'summary', 'text'],\n", " num_rows: 879\n", " })\n", " validation: Dataset({\n", " features: ['conversation', 'summary', 'text'],\n", " num_rows: 110\n", " })\n", " test: Dataset({\n", " features: ['original dialog id', 'new dialog id', 'dialog index', 'original dialog info', 'log', 'prompt'],\n", " num_rows: 110\n", " })\n", "})" ] }, "metadata": {}, "execution_count": 14 } ] }, { "cell_type": "markdown", "metadata": { "id": "usDeQuT2Wssl" }, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "07b07c9eec834123a841a004ee7e5833", "aca2cb7283fb4b09979f8fe9288f8e56", "ed814a34e19340c5845a65d40b2954cf", "73abe876bb9a4a2b969b1a3e3e818bc6", 
"12e0fda8bb4845cebc1153984b7d0b94", "afadebdc124d4a37a87b411d0ce4ae96", "ba2c0981676547cda46f0037ec326f9a", "0bddfa3c05d942e494c8c6f0517fe927", "4e91d215337f4977bd74293cb23be470", "a063e116f64e47e4bc7be4e962e768bf", "9d9208456f784d6f9a198ee9a9248a1a", "edb3dc0290bd4edc81d53716ba7c9056", "c426eac45c8940dfa9afb8fb85b1c2c4", "33da956ab0ab4bc2b3877a390e4daec9", "c5f333c5dc5d43e193545cd166779498", "6a8e8ab26daa44d4aced9f264dc12628", "2c4a02cf74344d898c61bca60dfba0ac", "c37ef9175a5a4b75bddb1437ac916353", "36748bab782548e7aba92347b96376b3", "76fbcc19d6474843b25312bc4d9ab21a", "16f892ef73614c598f84ea265c7c4124", "05be4210d3f44bfca2630399a01536a0", "7c30e55b127041a3a72242a26f4d0bf3", "2b991f85fcd94556866b161e6adfa1ec", "064bee1b3d5845f086fb95ba8bcff29d", "6c86a3c4e27048719acedeef120cada9", "f29066684c8045588d3ce63c9e7ee508", "47beb206fee04439b59fe610ceabfe72", "34e02d664a074fb8af934127e9182aca", "d475352432d84b2886548682e710f20f", "ce383a2e46a84edda042f39de179911e", "d445dfc58b514a3095da5db81a191514" ] }, "id": "3cLg8cYCdv74", "outputId": "88010e5e-d9ae-4a00-f924-87eea119dc2e" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
Step | Training Loss | Validation Loss\n", "
---|---|---\n", "
22 | 1.906400 | 1.921726\n", "
44 | 1.823500 | 1.881039\n", "
66 | 1.677000 | 1.861916\n", "
88 | 1.774600 | 1.853609\n", "
110 | 1.646800 | 1.852111\n", "
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=110, training_loss=1.8931689695878462, metrics={'train_runtime': 865.9278, 'train_samples_per_second': 2.03, 'train_steps_per_second': 0.127, 'total_flos': 1.2865354251706368e+16, 'train_loss': 1.8931689695878462, 'epoch': 2.0})" ] }, "metadata": {}, "execution_count": 25 } ] }, { "cell_type": "code", "source": [ "trainer.save_model()" ], "metadata": { "id": "jJdkDvxKOq8P" }, "execution_count": 26, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WywDQRmGEVOJ", "outputId": "6cfe63db-4fd4-43a9-9f8c-b39883618c84" }, "outputs": [ { "data": { "text/plain": [ "PeftModelForCausalLM(\n", " (base_model): LoraModel(\n", " (model): LlamaForCausalLM(\n", " (model): LlamaModel(\n", " (embed_tokens): Embedding(32000, 4096)\n", " (layers): ModuleList(\n", " (0-31): 32 x LlamaDecoderLayer(\n", " (self_attn): LlamaAttention(\n", " (q_proj): Linear4bit(\n", " in_features=4096, out_features=4096, bias=False\n", " (lora_dropout): ModuleDict(\n", " (default): Dropout(p=0.1, inplace=False)\n", " )\n", " (lora_A): ModuleDict(\n", " (default): Linear(in_features=4096, out_features=64, bias=False)\n", " )\n", " (lora_B): ModuleDict(\n", " (default): Linear(in_features=64, out_features=4096, bias=False)\n", " )\n", " (lora_embedding_A): ParameterDict()\n", " (lora_embedding_B): ParameterDict()\n", " )\n", " (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n", " (v_proj): Linear4bit(\n", " in_features=4096, out_features=4096, bias=False\n", " (lora_dropout): ModuleDict(\n", " (default): Dropout(p=0.1, inplace=False)\n", " )\n", " (lora_A): ModuleDict(\n", " (default): Linear(in_features=4096, out_features=64, bias=False)\n", " )\n", " (lora_B): ModuleDict(\n", " (default): Linear(in_features=64, out_features=4096, bias=False)\n", " )\n", " (lora_embedding_A): ParameterDict()\n", " (lora_embedding_B): ParameterDict()\n", " )\n", " (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n", " (rotary_emb): LlamaRotaryEmbedding()\n", " )\n", " (mlp): LlamaMLP(\n", " (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)\n", " (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)\n", " (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)\n", " (act_fn): SiLUActivation()\n", " )\n", " (input_layernorm): LlamaRMSNorm()\n", " (post_attention_layernorm): LlamaRMSNorm()\n", " )\n", " )\n", " (norm): LlamaRMSNorm()\n", " )\n", " (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", " )\n", " )\n", ")" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.model" ] }, { "cell_type": "code", "source": [ "from peft import AutoPeftModelForCausalLM\n", "\n", "trained_model = AutoPeftModelForCausalLM.from_pretrained(\n", " OUTPUT_DIR,\n", " low_cpu_mem_usage=True,\n", ")\n", "\n", "merged_model = trained_model.merge_and_unload()\n", "merged_model.save_pretrained(\"merged_model\", safe_serialization=True)\n", "tokenizer.save_pretrained(\"merged_model\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "f708fe2d8875400c8c59c4646609b46b", "63c19c247165446aa0f1fb419704403b", "8209c84bd76b4813b452bb3ca461616a", "561252818c534e30ae7d8b7c78d13115", "59e541dd430747ae95e871d615d18702", "c6b7d2c11152439691351a476b015478",
"3d66f708ceb441f6b75ee01d0da9a373", "f495f7dd19304c868b3f72ac5fa609a1", "4d46cf53f6bf48ff98fce4e084c4f84f", "9ffbdd68655345d2a367c83535814e87", "b13aea102ffe44b986b61ae46e6d579d" ] }, "id": "e43OUEfYOzFF", "outputId": "259702f6-eafa-4aed-eb7b-529a5d3b7186" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "f708fe2d8875400c8c59c4646609b46b" } }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "40d_XYPr_vzv" }, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "GIV-N87TQMFB" }, "outputs": [], "source": [ "def generate_prompt(\n", " conversation: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT\n", ") -> str:\n", " return f\"\"\"### Instruction: {system_prompt}\n", "\n", "### Input:\n", "{conversation.strip()}\n", "\n", "### Response:\n", "\"\"\".strip()" ] }, { "cell_type": "code", "source": [ "examples = []\n", "for data_point in dataset[\"test\"].select(range(5)):\n", " summaries = json.loads(data_point[\"original dialog info\"])[\"summaries\"][\n", " \"abstractive_summaries\"\n", " ]\n", " summary = summaries[0]\n", " summary = \" \".join(summary)\n", " conversation = create_conversation_text(data_point)\n", " examples.append(\n", " {\n", " \"summary\": summary,\n", " \"conversation\": conversation,\n", " \"prompt\": generate_prompt(conversation),\n", " }\n", " )\n", "test_df = pd.DataFrame(examples)\n", "test_df" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "XH-AyIzFf2z8", "outputId": "d7911c83-c20e-4b05-adb2-579b06ad13bc" }, "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " summary \\\n", "0 Customer is complaining that the watchlist is ... \n", "1 Customer is asking about the ACC to link to th... \n", "2 Customer is complaining about the new updates ... \n", "3 Customer is complaining about parcel service ... \n", "4 The customer says that he is stuck at Staines ... \n", "\n", " conversation \\\n", "0 user: My watchlist is not updating with new ep... \n", "1 user: hi , my Acc was linked to an old number.... \n", "2 user: the new update ios11 sucks. I can’t even... \n", "3 user: FUCK YOU AND YOUR SHITTY PARCEL SERVICE ... \n", "4 user: Stuck at Staines waiting for a Reading t... \n", "\n", " prompt \n", "0 ### Instruction: Below is a conversation betwe... \n", "1 ### Instruction: Below is a conversation betwe... \n", "2 ### Instruction: Below is a conversation betwe... \n", "3 ### Instruction: Below is a conversation betwe... \n", "4 ### Instruction: Below is a conversation betwe... " ], "text/html": [ "\n", "
 | summary | conversation | prompt\n", "
---|---|---|---\n", "
0 | Customer is complaining that the watchlist is ... | user: My watchlist is not updating with new ep... | ### Instruction: Below is a conversation betwe...\n", "
1 | Customer is asking about the ACC to link to th... | user: hi , my Acc was linked to an old number.... | ### Instruction: Below is a conversation betwe...\n", "
2 | Customer is complaining about the new updates ... | user: the new update ios11 sucks. I can’t even... | ### Instruction: Below is a conversation betwe...\n", "
3 | Customer is complaining about parcel service ... | user: FUCK YOU AND YOUR SHITTY PARCEL SERVICE ... | ### Instruction: Below is a conversation betwe...\n", "
4 | The customer says that he is stuck at Staines ... | user: Stuck at Staines waiting for a Reading t... | ### Instruction: Below is a conversation betwe...\n", "
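With the reference summaries, conversations, and prompts collected in test_df, summaries can be generated from the merged checkpoint saved earlier. A minimal sketch, assuming the model written to "merged_model" above; max_new_tokens and the greedy decoding settings are illustrative assumptions:

```python
# Sketch: summarize one held-out conversation with the merged fine-tuned model.
merged = AutoModelForCausalLM.from_pretrained(
    "merged_model", torch_dtype=torch.float16
).to(DEVICE)
merged_tokenizer = AutoTokenizer.from_pretrained("merged_model")


def summarize(prompt: str) -> str:
    inputs = merged_tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        output_ids = merged.generate(**inputs, max_new_tokens=256, do_sample=False)
    decoded = merged_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # The model echoes the prompt, so keep only the text after the final "### Response:".
    return decoded.split("### Response:")[-1].strip()


print(summarize(test_df.prompt[0]))
```

The generated text can then be compared against the reference in test_df.summary[0].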