Browse Source

Created using Colaboratory

Maxime Labonne 1 year ago
parent
commit
4dc551d702
1 changed files with 6 additions and 12 deletions
  1. 6 12
      Fine_tune_a_Mistral_7b_model_with_DPO.ipynb

+ 6 - 12
Fine_tune_a_Mistral_7b_model_with_DPO.ipynb

@@ -6,7 +6,7 @@
       "provenance": [],
       "machine_shape": "hm",
       "gpuType": "A100",
-      "authorship_tag": "ABX9TyOJJCuqxZQnS1q+Fvz5+URG",
+      "authorship_tag": "ABX9TyNuIN7/ICiXCX5xELzN1Y3R",
       "include_colab_link": true
     },
     "kernelspec": {
@@ -380,6 +380,8 @@
       "source": [
         "# Fine-tune a Mistral-7b model with DPO\n",
         "\n",
+        "> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)\n",
+        "\n",
         "❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)."
       ],
       "metadata": {
@@ -469,10 +471,10 @@
         "    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
         "\n",
         "    # Format chosen answer\n",
-        "    chosen = example['chatgpt'] + \"<|im_end|>\\n\"\n",
+        "    chosen = example['chosen'] + \"<|im_end|>\\n\"\n",
         "\n",
         "    # Format rejected answer\n",
-        "    rejected = example['llama2-13b-chat'] + \"<|im_end|>\\n\"\n",
+        "    rejected = example['rejected'] + \"<|im_end|>\\n\"\n",
         "\n",
         "    return {\n",
         "        \"prompt\": system + prompt,\n",
@@ -561,13 +563,6 @@
         ")\n",
         "model.config.use_cache = False\n",
         "\n",
-        "# Reference model\n",
-        "ref_model = AutoModelForCausalLM.from_pretrained(\n",
-        "    model_name,\n",
-        "    torch_dtype=torch.float16,\n",
-        "    load_in_4bit=True\n",
-        ")\n",
-        "\n",
         "# Training arguments\n",
         "training_args = TrainingArguments(\n",
         "    per_device_train_batch_size=4,\n",
@@ -588,7 +583,6 @@
         "# Create DPO trainer\n",
         "dpo_trainer = DPOTrainer(\n",
         "    model,\n",
-        "    ref_model,\n",
         "    args=training_args,\n",
         "    train_dataset=dataset,\n",
         "    tokenizer=tokenizer,\n",
@@ -624,7 +618,7 @@
         "tokenizer.save_pretrained(\"final_checkpoint\")\n",
         "\n",
         "# Flush memory\n",
-        "del dpo_trainer, model, ref_model\n",
+        "del dpo_trainer, model\n",
         "gc.collect()\n",
         "torch.cuda.empty_cache()\n",
         "\n",