Skip to content

Commit

Permalink
Fix bug in safe_learning_rates colab, and make loss function an explicit parameter in functions that depend on it.

Browse files Browse the repository at this point in the history

This fixes issue #8
  • Loading branch information
mstreeter committed Apr 28, 2023
1 parent 298f97c commit 1f82563
Showing 1 changed file with 15 additions and 55 deletions.
70 changes: 15 additions & 55 deletions autobound/notebooks/safe_learning_rates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,16 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 3701,
"status": "ok",
"timestamp": 1680628964220,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": 420
},
"id": "0RGtcfhby1Ag"
},
"outputs": [],
"source": [
"import autobound.jax as ab\n",
"\n",
"\n",
"def bound_loss(x, update_dir, max_learning_rate):\n",
"def bound_loss(loss, x, update_dir, max_learning_rate):\n",
" \"\"\"Upper bound the loss as a function of the learning rate.\"\"\"\n",
" def next_loss(learning_rate):\n",
" next_x = jax.tree_util.tree_map(lambda w, v: w + learning_rate*v, x,\n",
Expand All @@ -94,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"colab": {
"height": 279
Expand Down Expand Up @@ -136,7 +126,7 @@
"\n",
"next_loss = jax.jit(lambda eta: loss(x0 + eta*update_dir))\n",
"\n",
"bound = bound_loss(x0, update_dir, max_learning_rate)\n",
"bound = bound_loss(loss, x0, update_dir, max_learning_rate)\n",
"etas = np.linspace(0, max_learning_rate, 101)\n",
"pyplot.plot(etas, [bound.upper(eta) for eta in etas], 'r',\n",
" label='Quadratic upper bound')\n",
Expand All @@ -158,18 +148,8 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 112,
"status": "ok",
"timestamp": 1680628972776,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": 420
},
"id": "-Y5bG4Ued-kg"
},
"outputs": [],
Expand Down Expand Up @@ -209,27 +189,17 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 68,
"status": "ok",
"timestamp": 1680628975924,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": 420
},
"id": "Ha2GeFJxjbJr"
},
"outputs": [],
"source": [
"import jax\n",
"\n",
"\n",
"def safe_learning_rate(x, update_dir, max_learning_rate):\n",
" bounds = bound_loss(x, update_dir, max_learning_rate)\n",
"def safe_learning_rate(loss, x, update_dir, max_learning_rate):\n",
" bounds = bound_loss(loss, x, update_dir, max_learning_rate)\n",
" if len(bounds.coefficients) \u003c 3:\n",
" raise NotImplementedError() # the loss is linear \n",
" c1 = bounds.coefficients[1]\n",
Expand All @@ -243,7 +213,7 @@
"loss = lambda x: (x-1)**2\n",
"x0 = 0.\n",
"update_dir = -jax.grad(loss)(x0)\n",
"safe_eta = safe_learning_rate(x0, update_dir, 1e12)\n",
"safe_eta = safe_learning_rate(loss, x0, update_dir, 1e12)\n",
"assert float(loss(x0 + safe_eta*update_dir)) == 0."
]
},
Expand All @@ -258,7 +228,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"colab": {
"height": 287
Expand Down Expand Up @@ -296,7 +266,7 @@
"update_dir = -jax.grad(loss)(x0)\n",
"\n",
"safe_eta = jax.jit(\n",
" lambda max_rate: safe_learning_rate(x0, update_dir, max_rate))\n",
" lambda max_rate: safe_learning_rate(loss, x0, update_dir, max_rate))\n",
"max_rates = [1.1**i for i in range(-90, 3)]\n",
"\n",
"pyplot.plot(max_rates, [safe_eta(r) for r in max_rates])\n",
Expand All @@ -319,18 +289,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 21,
"status": "ok",
"timestamp": 1680628985305,
"user": {
"displayName": "",
"userId": ""
},
"user_tz": 420
},
"id": "1RReHtyna6lr"
},
"outputs": [],
Expand All @@ -342,8 +302,8 @@
"\n",
" def update_state(state):\n",
" x, _, max_eta = state\n",
" jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))\n",
" safe_eta = safe_learning_rate(x, update_dir, max_eta)\n",
" update_dir = jax.tree_util.tree_map(lambda v: -v, jax.grad(loss)(x))\n",
" safe_eta = safe_learning_rate(loss, x, update_dir, max_eta)\n",
" next_x = jax.tree_util.tree_map(\n",
" lambda p, v: p + safe_eta*v, x, update_dir\n",
" )\n",
Expand Down Expand Up @@ -373,7 +333,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"colab": {
"height": 747
Expand Down

0 comments on commit 1f82563

Please sign in to comment.