Explorar o código

updated notebook

Maxime Labonne hai 1 ano
pai
achega
01128d3d1f
Modificáronse 1 ficheiros con 29 adicións e 13 borrados
  1. 29 13
      Fine_tune_Llama_2_in_Google_Colab.ipynb

+ 29 - 13
Fine_tune_Llama_2_in_Google_Colab.ipynb

@@ -6,7 +6,7 @@
       "provenance": [],
       "machine_shape": "hm",
       "gpuType": "V100",
-      "authorship_tag": "ABX9TyPNl/WKBYXOzuJCP/puYm6d",
+      "authorship_tag": "ABX9TyMElK+4/0JPkM9Cs0WQVGXA",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -37,7 +37,7 @@
         "\n",
         "❤️ Created by [@maximelabonne](), based on Younes Belkada's [GitHub Gist](https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da).\n",
         "\n",
-        "This notebook runs on a T4 GPU with high RAM. (Last update: 23 Jul 2023)\n"
+        "This notebook runs on a T4 GPU with high RAM. (Last update: 26 Jul 2023)\n"
       ],
       "metadata": {
         "id": "OSHlAbqzDFDq"
@@ -88,19 +88,19 @@
         "dataset_name = \"mlabonne/guanaco-llama2-1k\"\n",
         "\n",
         "# Fine-tuned model name\n",
-        "new_model = \"llama-2-7b-guanaco\"\n",
+        "new_model = \"llama-2-7b-miniguanaco\"\n",
         "\n",
         "################################################################################\n",
         "# QLoRA parameters\n",
         "################################################################################\n",
         "\n",
-        "# Lora attention dimension\n",
+        "# LoRA attention dimension\n",
         "lora_r = 64\n",
         "\n",
-        "# Alpha parameter for Lora scaling\n",
+        "# Alpha parameter for LoRA scaling\n",
         "lora_alpha = 16\n",
         "\n",
-        "# Dropout probability for Lora layers\n",
+        "# Dropout probability for LoRA layers\n",
         "lora_dropout = 0.1\n",
         "\n",
         "################################################################################\n",
@@ -140,7 +140,7 @@
         "per_device_eval_batch_size = 4\n",
         "\n",
         "# Number of update steps to accumulate the gradients for\n",
-        "gradient_accumulation_steps = 1\n",
+        "gradient_accumulation_steps = 2\n",
         "\n",
         "# Enable gradient checkpointing\n",
         "gradient_checkpointing = True\n",
@@ -228,6 +228,11 @@
         "model.config.use_cache = False\n",
         "model.config.pretraining_tp = 1\n",
         "\n",
+        "# Load LLaMA tokenizer\n",
+        "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
+        "tokenizer.pad_token = tokenizer.eos_token\n",
+        "tokenizer.padding_side = \"right\" # Fix weird overflow issue with fp16 training\n",
+        "\n",
         "# Load LoRA configuration\n",
         "peft_config = LoraConfig(\n",
         "    lora_alpha=lora_alpha,\n",
@@ -237,11 +242,6 @@
         "    task_type=\"CAUSAL_LM\",\n",
         ")\n",
         "\n",
-        "# Load LLaMA tokenizer\n",
-        "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
-        "tokenizer.pad_token = tokenizer.eos_token\n",
-        "tokenizer.padding_side = \"right\" # Fix weird overflow issue with fp16 training\n",
-        "\n",
         "# Set training parameters\n",
         "training_arguments = TrainingArguments(\n",
         "    output_dir=output_dir,\n",
@@ -290,9 +290,25 @@
     {
       "cell_type": "code",
       "source": [
+        "%load_ext tensorboard\n",
+        "%tensorboard --logdir results/runs"
+      ],
+      "metadata": {
+        "id": "crj9svNe4hU5"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# Ignore warnings\n",
         "logging.set_verbosity(logging.CRITICAL)\n",
+        "\n",
+        "# Run text generation pipeline with our next model\n",
+        "prompt = \"What is a large language model?\"\n",
         "pipe = pipeline(task=\"text-generation\", model=model, tokenizer=tokenizer, max_length=200)\n",
-        "result = pipe(\"Tell me a joke\")\n",
+        "result = pipe(f\"<s>[INST] {prompt} [/INST]\")\n",
         "print(result[0]['generated_text'])"
       ],
       "metadata": {