From 1f82563346895efb87bceefd1bcd48ce7713426c Mon Sep 17 00:00:00 2001 From: Matt Streeter Date: Fri, 28 Apr 2023 10:18:22 -0700 Subject: [PATCH] Fix bug in safe_learning_rates colab, and make loss function an explicit parameter in functions that depend on it. This fixes issue #8 --- autobound/notebooks/safe_learning_rates.ipynb | 70 ++++--------------- 1 file changed, 15 insertions(+), 55 deletions(-) diff --git a/autobound/notebooks/safe_learning_rates.ipynb b/autobound/notebooks/safe_learning_rates.ipynb index 1f02103..aaecf81 100644 --- a/autobound/notebooks/safe_learning_rates.ipynb +++ b/autobound/notebooks/safe_learning_rates.ipynb @@ -53,18 +53,8 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { - "executionInfo": { - "elapsed": 3701, - "status": "ok", - "timestamp": 1680628964220, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": 420 - }, "id": "0RGtcfhby1Ag" }, "outputs": [], @@ -72,7 +62,7 @@ "import autobound.jax as ab\n", "\n", "\n", - "def bound_loss(x, update_dir, max_learning_rate):\n", + "def bound_loss(loss, x, update_dir, max_learning_rate):\n", " \"\"\"Upper bound the loss as a function of the learning rate.\"\"\"\n", " def next_loss(learning_rate):\n", " next_x = jax.tree_util.tree_map(lambda w, v: w + learning_rate*v, x,\n", @@ -94,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "colab": { "height": 279 @@ -136,7 +126,7 @@ "\n", "next_loss = jax.jit(lambda eta: loss(x0 + eta*update_dir))\n", "\n", - "bound = bound_loss(x0, update_dir, max_learning_rate)\n", + "bound = bound_loss(loss, x0, update_dir, max_learning_rate)\n", "etas = np.linspace(0, max_learning_rate, 101)\n", "pyplot.plot(etas, [bound.upper(eta) for eta in etas], 'r',\n", " label='Quadratic upper bound')\n", @@ -158,18 +148,8 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { - "executionInfo": { - "elapsed": 112, - "status": "ok", - "timestamp": 1680628972776, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": 420 - }, "id": "-Y5bG4Ued-kg" }, "outputs": [], @@ -209,18 +189,8 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { - "executionInfo": { - "elapsed": 68, - "status": "ok", - "timestamp": 1680628975924, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": 420 - }, "id": "Ha2GeFJxjbJr" }, "outputs": [], @@ -228,8 +198,8 @@ "import jax\n", "\n", "\n", - "def safe_learning_rate(x, update_dir, max_learning_rate):\n", - " bounds = bound_loss(x, update_dir, max_learning_rate)\n", + "def safe_learning_rate(loss, x, update_dir, max_learning_rate):\n", + " bounds = bound_loss(loss, x, update_dir, max_learning_rate)\n", " if len(bounds.coefficients) \u003c 3:\n", " raise NotImplementedError() # the loss is linear \n", " c1 = bounds.coefficients[1]\n", @@ -243,7 +213,7 @@ "loss = lambda x: (x-1)**2\n", "x0 = 0.\n", "update_dir = -jax.grad(loss)(x0)\n", - "safe_eta = safe_learning_rate(x0, update_dir, 1e12)\n", + "safe_eta = safe_learning_rate(loss, x0, update_dir, 1e12)\n", "assert float(loss(x0 + safe_eta*update_dir)) == 0." ] }, @@ -258,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "colab": { "height": 287 @@ -296,7 +266,7 @@ "update_dir = -jax.grad(loss)(x0)\n", "\n", "safe_eta = jax.jit(\n", - " lambda max_rate: safe_learning_rate(x0, update_dir, max_rate))\n", + " lambda max_rate: safe_learning_rate(loss, x0, update_dir, max_rate))\n", "max_rates = [1.1**i for i in range(-90, 3)]\n", "\n", "pyplot.plot(max_rates, [safe_eta(r) for r in max_rates])\n", @@ -319,18 +289,8 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { - "executionInfo": { - "elapsed": 21, - "status": "ok", - "timestamp": 1680628985305, - "user": { - "displayName": "", - "userId": "" - }, - "user_tz": 420 - }, "id": "1RReHtyna6lr" }, "outputs": [], @@ -342,8 +302,8 @@ "\n", " def update_state(state):\n", " x, _, max_eta = state\n", - " jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))\n", - " safe_eta = safe_learning_rate(x, update_dir, max_eta)\n", + " update_dir = jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))\n", + " safe_eta = safe_learning_rate(loss, x, update_dir, max_eta)\n", " next_x = jax.tree_util.tree_map(\n", " lambda p, v: p + safe_eta*v, x, update_dir\n", " )\n", @@ -373,7 +333,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "colab": { "height": 747