@@ -6,7 +6,7 @@
     "provenance": [],
     "machine_shape": "hm",
     "gpuType": "A100",
-    "authorship_tag": "ABX9TyOJJCuqxZQnS1q+Fvz5+URG",
+    "authorship_tag": "ABX9TyNuIN7/ICiXCX5xELzN1Y3R",
     "include_colab_link": true
   },
   "kernelspec": {
@@ -380,6 +380,8 @@
    "source": [
     "# Fine-tune a Mistral-7b model with DPO\n",
     "\n",
+    "> 🗣️ [Large Language Model Course](https://github.com/mlabonne/llm-course)\n",
+    "\n",
     "❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne)."
    ],
    "metadata": {
@@ -469,10 +471,10 @@
     "    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)\n",
     "\n",
     "    # Format chosen answer\n",
-    "    chosen = example['chatgpt'] + \"<|im_end|>\\n\"\n",
+    "    chosen = example['chosen'] + \"<|im_end|>\\n\"\n",
     "\n",
     "    # Format rejected answer\n",
-    "    rejected = example['llama2-13b-chat'] + \"<|im_end|>\\n\"\n",
+    "    rejected = example['rejected'] + \"<|im_end|>\\n\"\n",
     "\n",
     "    return {\n",
     "        \"prompt\": system + prompt,\n",
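This hunk switches the two answer columns from the model-specific `chatgpt`/`llama2-13b-chat` names to the generic `chosen`/`rejected` ones. A minimal sketch of the consuming side, assuming the notebook's `dataset` and `chatml_format` objects are in scope; the up-front guard is illustrative, not part of the notebook:

```python
# Fail fast if the loaded preference dataset doesn't expose the two columns
# that chatml_format now reads (hypothetical sanity check).
required = {"chosen", "rejected"}
missing = required - set(dataset.column_names)
if missing:
    raise KeyError(f"dataset lacks DPO columns: {missing}")

# Map every row onto the prompt/chosen/rejected schema that DPOTrainer
# expects, dropping all of the original columns.
dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
```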
@@ -561,13 +563,6 @@
     ")\n",
     "model.config.use_cache = False\n",
     "\n",
-    "# Reference model\n",
-    "ref_model = AutoModelForCausalLM.from_pretrained(\n",
-    "    model_name,\n",
-    "    torch_dtype=torch.float16,\n",
-    "    load_in_4bit=True\n",
-    ")\n",
-    "\n",
     "# Training arguments\n",
     "training_args = TrainingArguments(\n",
     "    per_device_train_batch_size=4,\n",
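Deleting this block frees the VRAM that a second 4-bit copy of the 7B weights would otherwise occupy. The reference policy itself is still required: the DPO objective penalizes the policy for drifting from a frozen reference when ranking the chosen completion above the rejected one,

$$\mathcal{L}_{\text{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$

TRL can recover those reference log-probabilities without materializing a separate model; the next hunk makes that switch in the trainer call.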
@@ -588,7 +583,6 @@
     "# Create DPO trainer\n",
     "dpo_trainer = DPOTrainer(\n",
     "    model,\n",
-    "    ref_model,\n",
     "    args=training_args,\n",
     "    train_dataset=dataset,\n",
     "    tokenizer=tokenizer,\n",
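With no explicit second model, the trainer has to produce the reference log-probabilities itself. In the TRL 0.7.x release line this notebook appears to target, omitting `ref_model` while passing a `peft_config` makes `DPOTrainer` score the reference by running the base weights with the LoRA adapters disabled, so only one copy of the 7B weights sits on the GPU. A sketch of the resulting call, assuming the notebook's `peft_config` and an illustrative `beta`:

```python
# Sketch (trl 0.7.x-style API; peft_config is the LoRA config defined
# earlier in the notebook, beta shown with an illustrative value).
dpo_trainer = DPOTrainer(
    model,                 # policy: 4-bit base model + trainable LoRA adapters
    ref_model=None,        # reference: same base weights, adapters disabled
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,              # illustrative DPO temperature
)
```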
@@ -624,7 +618,7 @@
     "tokenizer.save_pretrained(\"final_checkpoint\")\n",
     "\n",
     "# Flush memory\n",
-    "del dpo_trainer, model, ref_model\n",
+    "del dpo_trainer, model\n",
     "gc.collect()\n",
     "torch.cuda.empty_cache()\n",
     "\n",