{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"source": [
"# Visualizing GPT-2's Loss Landscape\n",
"\n",
"❤️ Created by [@maximelabonne](https://twitter.com/maximelabonne).\n",
"\n",
"Simple perturbation-based calculation of the negative log-likelihood loss in two directions, given \"I have a dream\" as input.\n",
"\n",
"Reference: [Visualizing the Loss Landscape of Neural Nets](https://arxiv.org/abs/1712.09913), by Li et al. (2018)"
],
"metadata": {
"id": "Dgptqrg0zEY5"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "lIYdn1woOS1n",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ffd96eb4-2861-4658-e89d-383eaeb36c3b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.2/7.2 MB\u001b[0m \u001b[31m47.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m14.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m75.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m67.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
],
"source": [
"!pip install -q transformers"
]
},
{
"cell_type": "code",
"source": [
"%%time\n",
"\n",
"import torch\n",
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
"import numpy as np\n",
"import plotly.graph_objects as go\n",
"from tqdm import tqdm\n",
"import imageio\n",
"import os\n",
"\n",
"# Load pre-trained model\n",
"model_name = 'gpt2'\n",
"model = GPT2LMHeadModel.from_pretrained(model_name)\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"\n",
"# Set model to evaluation mode\n",
"model.eval()\n",
"\n",
"# Define our input\n",
"input_text = \"I have a dream\"\n",
"inputs = tokenizer.encode_plus(input_text, return_tensors=\"pt\")\n",
"\n",
"# Compute the original loss\n",
"outputs = model(**inputs, labels=inputs[\"input_ids\"])\n",
"original_loss = outputs.loss.item()\n",
"\n",
"# Define two random directions\n",
"direction1 = [torch.randn_like(p) for p in model.parameters()]\n",
"direction2 = [torch.randn_like(p) for p in model.parameters()]\n",
"\n",
"# Normalize vectors\n",
"for p, d1, d2 in zip(model.parameters(), direction1, direction2):\n",
" norm_p = torch.linalg.norm(p.flatten())\n",
" d1.div_(torch.linalg.norm(d1.flatten())).mul_(norm_p)\n",
" d2.div_(torch.linalg.norm(d2.flatten())).mul_(norm_p)\n",
"\n",
"# Define the range to explore\n",
"x = np.linspace(-1, 1, 20)\n",
"y = np.linspace(-1, 1, 20)\n",
"X, Y = np.meshgrid(x, y)\n",
"\n",
"# Prepare to collect the losses\n",
"Z = np.zeros_like(X)\n",
"\n",
"# Compute loss for each direction\n",
"for i in tqdm(range(x.size), desc=\"x progress\"):\n",
" for j in tqdm(range(y.size), desc=\"y progress\", leave=False):\n",
" # Perturb the model parameters\n",
" for p, d1, d2 in zip(model.parameters(), direction1, direction2):\n",
" p.data.add_(x[i]*d1 + y[j]*d2)\n",
" \n",
" # Compute the loss\n",
" outputs = model(**inputs, labels=inputs['input_ids'])\n",
" Z[i, j] = outputs.loss.item()\n",
" \n",
" # Revert the model parameters\n",
" for p, d1, d2 in zip(model.parameters(), direction1, direction2):\n",
" p.data.sub_(x[i]*d1 + y[j]*d2)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4Wud37sZa1Y7",
"outputId": "c4e9c839-2938-4df8-c54f-96f5792cc0ed"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 20/20 [12:42<00:00, 38.11s/it]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Create 3D plot\n",
"fig = go.Figure(data=[go.Surface(z=Z, x=X, y=Y, \n",
" showscale=False,)])\n",
"fig.update_layout(\n",
" title=\"GPT-2's Loss Landscape\",\n",
" autosize=True,\n",
" width=1000,\n",
" height=600,\n",
" # scene=dict(\n",
" # xaxis=dict(visible=False),\n",
" # yaxis=dict(visible=False),\n",
" # zaxis=dict(visible=False),\n",
" # )\n",
")\n",
"fig.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 617
},
"id": "5zcGwU4ji67L",
"outputId": "f149f7eb-abfb-4b06-aead-ace3772c5379"
},
"execution_count": 32,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"