{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Subsurface Data Analytics \n", "\n", "## Interactive Gradient Descent\n", "\n", "#### Michael Pyrcz, Associate Professor, University of Texas at Austin \n", "\n", "##### [Twitter](https://twitter.com/geostatsguy) | [GitHub](https://github.com/GeostatsGuy) | [Website](http://michaelpyrcz.com) | [GoogleScholar](https://scholar.google.com/citations?user=QVZ20eQAAAAJ&hl=en&oi=ao) | [Book](https://www.amazon.com/Geostatistical-Reservoir-Modeling-Michael-Pyrcz/dp/0199731446) | [YouTube](https://www.youtube.com/channel/UCLqEr-xV-ceHdXXXrTId5ig) | [LinkedIn](https://www.linkedin.com/in/michael-pyrcz-61a648a1) | [GeostatsPy](https://github.com/GeostatsGuy/GeostatsPy)\n", "\n", "\n", 
"### PGE 383 Exercise: Interactive Gradient Descent\n", "\n", "Here's a simple demonstration of gradient descent. I have a lecture that describes the basics of optimization for machine learning model training: [Optimization Lecture](https://www.youtube.com/watch?v=4nYz5j0sAQs&list=PLG19vXLQHvSC2ZKFIkgVpI9fCjkN38kwf&index=23&t=578s).\n", "\n", "In this demonstration you can customize the loss function:\n", "\n", "* you can vary the magnitude and spatial correlation of the noise\n", "\n", "* **magnitude of noise** is controlled by the standard deviation of a Gaussian random residual\n", "\n", "* **spatial correlation of noise** is controlled by the size of a Gaussian kernel that is convolved with the random noise to impose spatial structure\n", "\n", 
"Then you can set the:\n", "\n", "* **initial parameter combination** ($\\phi_1, \\phi_2$)\n", "\n", "* **initial learning rate** \\[0.1,5\\]\n", "\n", "* **momentum** \\[0.0,1.0\\]\n", "\n", "* simple polynomial model\n", "\n", "* 1 predictor feature and 1 response feature\n", "\n", "for a highly interpretable model / simple illustration.\n", "\n", 
"#### Train / Test Split\n", "\n", "The available data is split into training and testing subsets.\n", "\n", "* in general 15-30% of the data is withheld 
from training to apply as testing data\n", "\n", "* testing data selection should be fair, i.e., the predictions should have the same difficulty (offset/different from the training data) as the planned real-world use of the model\n", "\n", "#### Machine Learning Model Training\n", "\n", "The training data is applied to train the model parameters such that the model minimizes mismatch with the training data\n", "\n", "* it is common to use **mean square error** (based on the **L2 norm**) as a loss function summarizing the model mismatch\n", "\n", "* **minimizing the loss function**: for simple models an analytical solution may be available, but for most machine learning models this requires an iterative optimization method, such as gradient descent, to find the best model parameters\n", "\n", "This process is repeated over a range of model complexities specified by hyperparameters. \n", "\n", 
"#### Machine Learning Model Tuning\n", "\n", "The withheld testing data is retrieved and the loss function (usually the **L2 norm** again) is calculated to summarize the error over the testing data\n", "\n", "* this is repeated over the range of specified hyperparameters\n", "\n", "* the model complexity / hyperparameters that minimize the loss function / error summary in testing are selected\n", "\n", "This is known as model hyperparameter tuning.\n", "\n", 
"#### Machine Learning Model Overfit\n", "\n", "More model complexity/flexibility than can be justified with the available data, data accuracy, frequency and coverage\n", "\n", "* Model explains “idiosyncrasies” of the data, capturing data noise/error in the model\n", "\n", "* High accuracy in training, but low accuracy in testing / real-world use away from training data cases – poor ability of the model to generalize\n", "\n", "\n", 
"#### Workflow Goals\n", "\n", "Learn the basics of machine learning model training and tuning for model generalization while avoiding model overfit.\n", "\n", "This includes:\n", "\n", "* Demonstrate model training and tuning by hand with an interactive exercise\n", "\n", "* Demonstrate the role of data error in 
leading to model overfit with complicated models\n", "\n", "#### Getting Started\n", "\n", "You will need to copy the following data files to your working directory. They are available [here](https://github.com/GeostatsGuy/GeoDataSets):\n", "\n", "* Tabular data - [Stochastic_1D_por_perm_demo.csv](https://github.com/GeostatsGuy/GeoDataSets/blob/master/Stochastic_1D_por_perm_demo.csv)\n", "* Tabular data - [Random_Parabola.csv](https://github.com/GeostatsGuy/GeoDataSets/blob/master/Random_Parabola.csv)\n", "\n", "#### Import Required Packages\n", "\n", "We will also need some standard packages. These should have been installed with Anaconda 3." ] }, 
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ 
"%matplotlib notebook\n", 
"import sys # system utilities\n", 
"import math # exponential for the learning rate decay\n", 
"import numpy as np # arrays and matrix math\n", 
"import matplotlib.pyplot as plt # plotting\n", 
"import scipy # kernel for convolution\n", 
"from ipywidgets import interactive # widgets and interactivity\n", 
"from ipywidgets import widgets \n", 
"from ipywidgets import Layout\n", 
"from ipywidgets import Label\n", 
"from ipywidgets import VBox, HBox\n", 
"from scipy.interpolate import griddata\n", 
"from matplotlib.animation import FuncAnimation\n", 
"import scipy.signal as signal\n", 
"from scipy.interpolate import RegularGridInterpolator\n", 
"from matplotlib import animation" 
] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Animated Optimization\n", "\n", "#### Make the Exhaustive Loss Function\n", "\n", "Let's place a global minimum in the middle of the solution space and add some spatially structured noise to increase the difficulty:\n", "\n", "* local minima where gradient descent can get stuck" ] }, 
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ 
"nx = 100; ny = 100 # size of the loss function grid\n", 
"cmap = plt.cm.inferno # set the colormap\n", 
"xmin = 0.0; xmax = 100.0; ymin = 0.0; ymax = 100.0 # set the extents\n", 
"X, Y = np.mgrid[ymin:ymax:complex(ny), xmin:xmax:complex(nx)] # regular grid over the parameter space\n", 
"XY = np.vstack((Y.flatten(), X.flatten())).T # make location list for interpolation method\n", 
"XY_list = list(map(tuple, XY))\n", 
"x = np.zeros(100); y = np.zeros(100)" 
] }, 
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ 
"r0 = 10.0 # initial learning rate\n", 
"momentum = 0.9 # momentum weight\n", 
"make_video = False\n", 
"\n", 
"delta = 0.001 # finite-difference step for the numerical gradient\n", 
"k = 0.04 # exponential decay parameter for the learning rate\n", 
"\n", 
"x = np.zeros(100); y = np.zeros(100)\n", 
"x[0] = 5.0; y[0] = 50.0 # initial parameter combination (phi_1, phi_2)\n", 
"\n", 
"plt.close()\n", 
"# set up the plot\n", 
"fig = plt.figure(figsize = (12,10))\n", 
"ax = plt.gca()\n", 
"ax.set_xlim([xmin,xmax]); ax.set_ylim([ymin,ymax])\n", 
"ax.set_xlabel(r'$\\phi_1$'); ax.set_ylabel(r'$\\phi_2$'); ax.set_title('Gradient Descent')\n", 
"plt.rc('xtick', labelsize=10); plt.rc('ytick', labelsize=10)\n", 
"plt.rc('axes', titlesize=20); plt.rc('axes', labelsize=15)\n", 
"c1 = ax.contourf(X, Y, truth, cmap = plt.cm.inferno, vmin=0, vmax=100, levels=np.linspace(0, 100, 50)) # draw the loss function once, before the colorbar needs it\n", 
"ax.contour(X, Y, truth, levels=10, linewidths=0.5, colors='black', linestyles='dashed')\n", 
"point, = ax.plot([],[], marker=\"o\", c='white',\n", 
"                 markeredgecolor = \"black\", ms=30, alpha = 0.5)\n", 
"cbar = plt.colorbar(c1)\n", 
"cbar.set_label('Loss Function', rotation = 270, labelpad = 20)\n", 
"\n", 
"x = np.clip(x,xmin+delta,xmax-delta); y = np.clip(y,ymin+delta,ymax-delta) # ensure start is not too close to edge\n", 
"for i in range(1,100):\n", 
"    # central-difference estimate of the loss function gradient\n", 
"    gradient_x = (interp((y[i-1], x[i-1]+delta)).item()-interp((y[i-1], x[i-1]-delta)).item())/(2*delta)\n", 
"    gradient_y = (interp((y[i-1]+delta, x[i-1])).item()-interp((y[i-1]-delta, x[i-1])).item())/(2*delta)\n", 
"\n", 
"    r = r0 * math.exp(-k*i) # exponentially decaying learning rate\n", 
"\n", 
"    if i > 1: # include momentum\n", 
"        x_step = momentum*prev_x + (1.0-momentum)*(-1*r*gradient_x) # integrate momentum\n", 
"        y_step = momentum*prev_y + (1.0-momentum)*(-1*r*gradient_y)\n", 
"    else:\n", 
"        x_step = -1*r*gradient_x\n", 
"        y_step = -1*r*gradient_y\n", 
"\n", 
"    x[i] = x[i-1] + x_step # step forward\n", 
"    y[i] = y[i-1] + y_step\n", 
"\n", 
"    x[i] = np.clip(x[i],xmin+delta,xmax-delta); y[i] = np.clip(y[i],ymin+delta,ymax-delta) # ensure point is not too close to edge\n", 
"\n", 
"    prev_x = (x[i] - x[i-1]) # store current step for momentum on next step\n", 
"    prev_y = (y[i] - y[i-1])\n", 
"\n", 
"def init(): # animation initialization function, the loss function is already drawn\n", 
"    point.set_data([], [])\n", 
"    return point, # blitting requires the changed artists to be returned\n", 
"\n", 
"def update(frame): # show the last N steps of the descent path\n", 
"    N = min(10, frame)\n", 
"    top = min(frame, 90)\n", 
"    point.set_data(x[frame-N:top], y[frame-N:top])\n", 
"    return point,\n", 
"\n", 
"ani = FuncAnimation(fig, update, frames = np.arange(1, len(x)-20), init_func = init, # make the animation\n", 
"                    interval=100, blit=True, repeat = False)\n", 
"\n", 
"if make_video: # set up formatting for the movie files\n", 
"    writer = animation.FFMpegWriter(fps=15, metadata=dict(artist='Michael Pyrcz'), bitrate=100000)\n", 
"    ani.save('gradient_descent.mp4', writer=writer)" 
] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"from matplotlib import animation, rc\n", 
"from IPython.display import HTML, display\n", 
"\n", 
"rc('animation', html='jshtml')\n", 
"\n", 
"l = widgets.Text(value=' Gradient Descent Demo, Prof. Michael Pyrcz, The University of Texas at Austin',\n", 
"                 layout=Layout(width='950px', height='30px'))\n", 
"\n", 
"momentum = widgets.FloatSlider(min=0.0, max = 1.0, value=0.5, step = 0.1, description = 'Momentum',orientation='horizontal',style = {'description_width': 'initial'}, continuous_update=False)\n", 
"r = widgets.FloatSlider(min=1.0, max = 5.0, value=1.5, step = 0.25, description = 'Learning Rate',orientation='horizontal',style = {'description_width': 'initial'}, continuous_update=False)\n", 
"\n", 
"ui = widgets.HBox([momentum,r],)\n", 
"ui2 = widgets.VBox([l,ui],)\n", 
"\n", 
"def run_plot(momentum,r):\n", 
"    make_video = False\n", 
"    nx = 100; ny = 100 # size of the loss function grid\n", 
"    xmin = 0.0; xmax = 100.0; ymin = 0.0; ymax = 100.0 # set the extents\n", 
"    X, Y = np.mgrid[ymin:ymax:complex(ny), xmin:xmax:complex(nx)]\n", 
"    delta = 0.001 # finite-difference step for the numerical gradient\n", 
"\n", 
"    x = np.zeros(50); y = np.zeros(50) # allocate the path before setting the start, so the start is not overwritten\n", 
"    x[0] = 25.0; y[0] = 0.0 # initial parameter combination (phi_1, phi_2)\n", 
"\n", 
"    plt.close()\n", 
"    # set up the plot\n", 
"    fig = plt.figure(figsize = (12,10))\n", 
"    ax = plt.gca()\n", 
"    ax.set_xlim([xmin,xmax]); ax.set_ylim([ymin,ymax])\n", 
"    ax.set_xlabel(r'$\\phi_1$'); ax.set_ylabel(r'$\\phi_2$'); ax.set_title('Gradient Descent')\n", 
"    plt.rc('xtick', labelsize=10); plt.rc('ytick', labelsize=10)\n", 
"    plt.rc('axes', titlesize=20); plt.rc('axes', labelsize=15)\n", 
"    c1 = ax.contourf(X, Y, truth, cmap = plt.cm.inferno, vmin=0, vmax=100, levels=np.linspace(0, 100, 50)) # draw the loss function once\n", 
"    ax.contour(X, Y, truth, levels=10, linewidths=0.5, colors='black', linestyles='dashed')\n", 
"    point, = ax.plot([],[], marker=\"o\", c='white',\n", 
"                     markeredgecolor = \"black\", ms=30, alpha = 0.5)\n", 
"    cbar = plt.colorbar(c1)\n", 
"    cbar.set_label('Loss Function', rotation = 270, labelpad = 20)\n", 
"\n", 
"    x = np.clip(x,xmin+delta,xmax-delta); y = np.clip(y,ymin+delta,ymax-delta) # ensure start is not too close to edge\n", 
"    for i in range(1,50):\n", 
"        # central-difference estimate of the loss function gradient\n", 
"        gradient_x = (interp((y[i-1], x[i-1]+delta)).item()-interp((y[i-1], x[i-1]-delta)).item())/(2*delta)\n", 
"        gradient_y = (interp((y[i-1]+delta, x[i-1])).item()-interp((y[i-1]-delta, x[i-1])).item())/(2*delta)\n", 
"\n", 
"        if i > 1: # include momentum\n", 
"            x_step = momentum*prev_x + (1.0-momentum)*(-1*r*gradient_x) # integrate momentum\n", 
"            y_step = momentum*prev_y + (1.0-momentum)*(-1*r*gradient_y)\n", 
"        else:\n", 
"            x_step = -1*r*gradient_x\n", 
"            y_step = -1*r*gradient_y\n", 
"\n", 
"        x[i] = x[i-1] + x_step # step forward\n", 
"        y[i] = y[i-1] + y_step\n", 
"\n", 
"        x[i] = np.clip(x[i],xmin+delta,xmax-delta); y[i] = np.clip(y[i],ymin+delta,ymax-delta) # ensure point is not too close to edge\n", 
"\n", 
"        prev_x = (x[i] - x[i-1]) # store current step for momentum on next step\n", 
"        prev_y = (y[i] - y[i-1])\n", 
"\n", 
"    def init(): # animation initialization function, the loss function is already drawn\n", 
"        point.set_data([], [])\n", 
"        return point, # blitting requires the changed artists to be returned\n", 
"\n", 
"    def update(frame): # show the last N steps of the descent path\n", 
"        N = min(10, frame-1)\n", 
"        point.set_data(x[frame-N:frame], y[frame-N:frame])\n", 
"        return point,\n", 
"\n", 
"    ani = FuncAnimation(fig, update, frames = np.arange(1, len(x)), init_func = init, # make the animation\n", 
"                        interval=100, blit=True, repeat = True)\n", 
"\n", 
"    if make_video: # set up formatting for the movie files\n", 
"        writer = animation.FFMpegWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800)\n", 
"        ani.save('gradient_descent.mp4', writer=writer)\n", 
"\n", 
"    display(HTML(ani.to_jshtml())) # render the animation inline; the local animation object would otherwise be garbage collected\n", 
"    plt.close(fig) # avoid also showing a static copy of the figure\n", 
"\n", 
"# connect the plotting function to the widgets\n", 
"interactive_plot = widgets.interactive_output(run_plot, {'momentum':momentum,'r':r})\n", 
"#interactive_plot.clear_output(wait = False) # reduce flickering by delaying plot updating\n", 
"\n", 
"display(ui2, interactive_plot) # display the interactive plot" 
] } 
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 2 }