From a17c19ebb8d2a838aeed23c7863716a121c7b415 Mon Sep 17 00:00:00 2001 From: Mahmoud Asem <48389287+ASEM000@users.noreply.github.com> Date: Thu, 14 Dec 2023 02:04:37 +0900 Subject: [PATCH] line search (#89) --- docs/notebooks/model_surgery.ipynb | 320 ++++++++++----------------- docs/notebooks/optimlib.ipynb | 341 +++++++++++++++++++++++------ 2 files changed, 395 insertions(+), 266 deletions(-) diff --git a/docs/notebooks/model_surgery.ipynb b/docs/notebooks/model_surgery.ipynb index 8d67aa8..ff4a6ac 100644 --- a/docs/notebooks/model_surgery.ipynb +++ b/docs/notebooks/model_surgery.ipynb @@ -15,7 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install git+https://github.com/ASEM000/serket --quiet" + "# !pip install git+https://github.com/ASEM000/serket --quiet" ] }, { @@ -61,7 +61,7 @@ } ], "source": [ - "pytree1 = [1, [2, 3], 4] \n", + "pytree1 = [1, [2, 3], 4]\n", "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)\n", "pytree2 = indexer[...].get() # get the whole pytree using ...\n", "print(f\"{pytree1=}, {pytree2=}\")\n", @@ -126,133 +126,92 @@ "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - "Original pytree1\n", - "dict\n", - "├── ['a']=-1\n", - "├── ['b']:dict\n", - "│ ├── ['c']=2\n", - "│ └── ['d']=3\n", - "├── ['e']=-4\n", - "└── ['f']:dict\n", - " ├── ['g']=7\n", - " └── ['h']=8\n", - "```" - ] - }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': -1, 'b': {'c': 200, 'd': 3}, 'e': -4, 'f': {'g': 7, 'h': 8}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# exmaple 1: set the value of pytree1[\"b\"][\"c\"] to 200\n", - "pytree2 = indexer[\"b\"][\"c\"].set(200)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - "Set pytree1['b']['c'] to 200:\n", - " dict\n", - " ├── ['a']=-1\n", - " ├── ['b']:dict\n", - "+│ ├── ['c']=200\n", - " │ └── ['d']=3\n", - " ├── ['e']=-4\n", - " └── ['f']:dict\n", - " ├── ['g']=7\n", - " └── ['h']=8\n", - "\n", - "```" + "pytree2 = indexer[\"b\"][\"c\"].set(200)\n", + "pytree2" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# example 2: set the value of pytree1[\"b\"] to 100\n", - "pytree3 = indexer[\"b\"].set(100)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - "Set pytree1['b'] to 100:\n", - " dict\n", - " ├── ['a']=-1\n", - "+├── ['b']=100\n", - " ├── ['e']=-4\n", - " └── ['f']:dict\n", - " ├── ['g']=7\n", - " └── ['h']=8\n", - "```" + "pytree3 = indexer[\"b\"].set(100)\n", + "pytree3" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': -1, 'b': {'c': 100, 'd': 100}, 'e': -4, 'f': {'g': 7, 'h': 8}}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# example 3: set _all leaves_ of \"b\" subtree to 100\n", - "pytree4 = indexer[\"b\"][...].set(100)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - "Set _all leaves_ of pytree1['b'] to 100:\n", - " dict\n", - " ├── ['a']=-1\n", - " ├── ['b']:dict\n", - "+│ ├── 
['c']=100\n", - "+│ └── ['d']=100\n", - " ├── ['e']=-4\n", - " └── ['f']:dict\n", - " ├── ['g']=7\n", - " └── ['h']=8\n", - "```" + "pytree4 = indexer[\"b\"][...].set(100)\n", + "pytree4" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': -1, 'b': {'c': 100, 'd': 100}, 'e': -4, 'f': {'g': 100, 'h': 100}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# example 4: set _all leaves_ of pytree1[\"b\"] _and_ pytree1[\"f\"] to 100\n", - "pytree5 = indexer[\"b\", \"f\"][...].set(100)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - "Set _all leaves_ of pytree1['b'] and pytree1['f'] to 100:\n", - " dict\n", - " ├── ['a']=-1\n", - " ├── ['b']:dict\n", - "+│ ├── ['c']=100\n", - "+│ └── ['d']=100\n", - " ├── ['e']=-4\n", - " └── ['f']:dict\n", - "+ ├── ['g']=100\n", - "+ └── ['h']=100\n", - "```" + "pytree5 = indexer[\"b\", \"f\"][...].set(100)\n", + "pytree5" ] }, { @@ -267,49 +226,24 @@ "cell_type": "code", "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n", "mask = jax.tree_map(lambda x: x < 0, pytree1)\n", "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)\n", - "pytree2 = indexer[mask].set(0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Original pytree\n", - "```diff\n", - "dict\n", - " ├── ['a']=-1\n", - " ├── ['b']:dict\n", - " │ ├── ['c']=2\n", - " │ └── ['d']=3\n", - " └── ['e']=-4\n", - "```\n", - "\n", - "Mask\n", - "\n", - "```diff\n", - "dict\n", - "+├── ['a']=True\n", - " ├── ['b']:dict\n", - " │ ├── ['c']=False\n", - " │ └── ['d']=False\n", - "+└── ['e']=True\n", - "```\n", - "\n", - "modified pytree\n", - "\n", - "```diff\n", - "dict\n", - "+├── ['a']=0\n", - " ├── ['b']:dict\n", - " │ ├── ['c']=2\n", - " │ └── ['d']=3\n", - "+└── ['e']=0\n", - "```" + "pytree2 = indexer[mask].set(0)\n", + "pytree2" ] }, { @@ -371,21 +305,6 @@ "print(f\"{net1=}\")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - "Net\n", - "├── .encoder:dict\n", - "│ ├── ['bias']=f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - "│ └── ['weight']=f32[5,3](μ=0.30, σ=0.90, ∈[-1.44,1.84])\n", - "└── .decoder:dict\n", - " ├── ['bias']=f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - " └── ['weight']=f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20])\n", - "```" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -397,64 +316,63 @@ "cell_type": "code", "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Net(\n", + " encoder={\n", + " bias:f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", + " weight:f32[5,3](μ=0.04, σ=0.18, ∈[-0.20,0.20])\n", + " }, \n", + " decoder={\n", + " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", + " weight:f32[5,5](μ=-0.02, σ=0.18, ∈[-0.20,0.20])\n", + " }\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# example 1: clip the `weights` of `encoder` and `decoder` to [-0.2, 0.2]\n", - "net2 = net1.at[\"encoder\", \"decoder\"][\"weight\"].apply(lambda x: jnp.clip(x, -0.2, 0.2))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - " Net\n", - " ├── 
.encoder:dict\n", - " │ ├── ['bias']=f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - "+│ └── ['weight']=f32[5,3](μ=0.04, σ=0.18, ∈[-0.20,0.20])\n", - " └── .decoder:dict\n", - " ├── ['bias']=f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - "+ └── ['weight']=f32[5,5](μ=-0.02, σ=0.18, ∈[-0.20,0.20])\n", - "````" + "net2 = net1.at[\"encoder\", \"decoder\"][\"weight\"].apply(lambda x: jnp.clip(x, -0.2, 0.2))\n", + "net2" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Net(\n", + " encoder={\n", + " weight:f32[5,3](μ=100.00, σ=0.00, ∈[100.00,100.00]), \n", + " bias:f32[5](μ=100.00, σ=0.00, ∈[100.00,100.00])\n", + " }, \n", + " decoder={\n", + " weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), \n", + " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " }\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# example 2: load pretrained weights for `encoder`\n", - "pretrained = {\"weight\": jnp.ones((5, 3))*100., \"bias\": jnp.ones((5,))*100.}\n", - "net3 = net1.at[\"encoder\"].set(pretrained)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```diff\n", - " net1=Net(\n", - " encoder={\n", - "- weight:f32[5,3](μ=0.30, σ=0.90, ∈[-1.44,1.84]), \n", - "- bias:f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - " }, \n", - " decoder={\n", - " weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), \n", - " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - " }\n", - " )\n", - " net3=Net(\n", - " encoder={\n", - "+ weight:f32[5,3](μ=100.00, σ=0.00, ∈[100.00,100.00]), \n", - "+ bias:f32[5](μ=100.00, σ=0.00, ∈[100.00,100.00])\n", - " }, \n", - " decoder={\n", - " weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), \n", - " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - " }\n", - " )\n", - "```" + "pretrained = {\"weight\": jnp.ones((5, 3)) * 100.0, \"bias\": jnp.ones((5,)) * 100.0}\n", + "net3 = net1.at[\"encoder\"].set(pretrained)\n", + "net3" ] } ], diff --git a/docs/notebooks/optimlib.ipynb b/docs/notebooks/optimlib.ipynb index bb837af..563c5f5 100644 --- a/docs/notebooks/optimlib.ipynb +++ b/docs/notebooks/optimlib.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# 🧮 Build mini optimizer library" + "# 🧮 Mini optimizer library" ] }, { @@ -13,28 +13,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In the following example a mini optimizer library is built using `serket`. The strategy will be to write the optimizer method with the inplace modification as done in similar libraries like `PyTorch`, then use `.at[method_name]` to execute the inplace modification on a new instance and comply with `jax` functional updates. Unlike `optax`, this optimizer library combines the state and update function as a method." + "In the following example a mini optimizer library is built using `serket`. The strategy will be to write the optimizer methods with the inplace modification as done in similar libraries like `PyTorch`, then use `.at[method_name]` to execute the inplace modification on a new instance and comply with `jax` functional updates." 
] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install git+https://github.com/ASEM000/serket --quiet" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Imports" - ] - }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -53,12 +46,48 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Optimizer functions" + "## Utils" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class FNN(sk.TreeClass):\n", + " def __init__(self, *, key: jax.Array):\n", + " k1, k2 = jr.split(key)\n", + " self.w1 = jax.random.normal(k1, [10, 1])\n", + " self.b1 = jnp.zeros([10], dtype=jnp.float32)\n", + " self.w2 = jax.random.normal(k2, [1, 10])\n", + " self.b2 = jnp.zeros([1], dtype=jnp.float32)\n", + "\n", + " def __call__(self, input: jax.Array) -> jax.Array:\n", + " output = input @ self.w1.T + self.b1\n", + " output = jax.nn.relu(output)\n", + " output = output @ self.w2.T + self.b2\n", + " return output\n", + "\n", + "\n", + "def loss_func(net: FNN, input: jax.Array, target: jax.Array) -> jax.Array:\n", + " return jnp.mean((net(input) - target) ** 2)\n", + "\n", + "\n", + "input = jnp.linspace(-1, 1, 100).reshape(-1, 1)\n", + "target = input**2 + 0.1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## First order-optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -77,6 +106,19 @@ "\n", "\n", "class Adam(sk.TreeClass):\n", + " \"\"\"Apply the Adam update rule to the incoming updates\n", + "\n", + " Args:\n", + " tree: PyTree of parameters to be optimized\n", + " beta1: exponential decay rate for the first moment\n", + " beta2: exponential decay rate for the second moment\n", + " eps: small value to avoid division by zero\n", + "\n", + " Note:\n", + " This implementation does not scale the updates by the learning rate.\n", + " Use ``jax.tree_map(lambda x: x * lr, updates)`` to scale the updates.\n", + " \"\"\"\n", + "\n", " def __init__(\n", " self,\n", " tree,\n", @@ -92,7 +134,11 @@ " self.count = 0\n", "\n", " def __call__(self, updates: T) -> T:\n", - " # calling this method will raise `AttributeError` because\n", + " \"\"\"Apply the Adam update rule to the incoming updates\"\"\"\n", + " # this method will transform the incoming updates(gradients) into\n", + " # the updates that will be applied to the parameters\n", + "\n", + " # NOTE: calling this method will raise `AttributeError` because\n", " # its mutating the state (e.g. 
self.something=something)\n", " # it will only work if used with `.at` that executes it functionally\n", " self.count += 1\n", @@ -108,6 +154,15 @@ "\n", "\n", "class ExponentialDecay(sk.TreeClass):\n", + " \"\"\"Scale the incoming updates by an exponentially decaying learning rate\n", + "\n", + " Args:\n", + " init_rate: initial learning rate\n", + " decay_rate: rate of decay\n", + " transition_steps: number of steps to transition from init_rate to 0\n", + " transition_begins: number of steps to wait before starting the transition\n", + " \"\"\"\n", + "\n", " def __init__(\n", " self,\n", " init_rate: float,\n", @@ -124,6 +179,10 @@ " self.transition_steps = transition_steps\n", "\n", " def __call__(self, updates: T) -> T:\n", + " \"\"\"Scale the updates by the current learning rate\"\"\"\n", + " # NOTE: calling this method will raise `AttributeError` because\n", + " # its mutating the state (e.g. self.something=something)\n", + " # it will only work if used with `.at` that executes it functionally\n", " self.count += 1\n", " count = self.count - self.transition_begins\n", " self.rate = jnp.where(\n", @@ -138,40 +197,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Construct a fully connected neural network" + "### Test" ] }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "class FNN(sk.TreeClass):\n", - " def __init__(self, *, key: jax.Array):\n", - " k1, k2 = jr.split(key)\n", - " self.w1 = jax.random.normal(k1, [1, 10])\n", - " self.b1 = jnp.zeros([10], dtype=jnp.float32)\n", - " self.w2 = jax.random.normal(k2, [10, 1])\n", - " self.b2 = jnp.zeros([1], dtype=jnp.float32)\n", - "\n", - " def __call__(self, x: jax.Array) -> jax.Array:\n", - " x = x @ self.w1 + self.b1\n", - " x = jax.nn.relu(x)\n", - " x = x @ self.w2 + self.b2\n", - " return x" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train function" - ] - }, - { - "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -193,10 +224,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, @@ -212,21 +243,6 @@ } ], "source": [ - "def loss_func(net: FNN, x: jax.Array, y: jax.Array) -> jax.Array:\n", - " return jnp.mean((net(x) - y) ** 2)\n", - "\n", - "\n", - "@jax.jit\n", - "def train_step(net, optim, x, y):\n", - " grads = jax.grad(loss_func)(net, x, y)\n", - " grads, optim = optim.at[\"__call__\"](grads)\n", - " net = jax.tree_map(lambda p, g: p + g, net, grads)\n", - " return net, optim\n", - "\n", - "\n", - "x = jnp.linspace(-1, 1, 100).reshape(-1, 1)\n", - "y = x**2 + 0.1\n", - "\n", "net = FNN(key=jr.PRNGKey(0))\n", "optim = sk.Sequential(\n", " Adam(net),\n", @@ -234,16 +250,211 @@ ")\n", "\n", "\n", + "@jax.jit\n", + "def train_step(net, optim, input, target):\n", + " grads = jax.grad(loss_func)(net, input, target)\n", + " grads, optim = sk.AtIndexer(optim)[\"__call__\"](grads)\n", + " net = jax.tree_map(lambda p, g: p + g, net, grads)\n", + " return net, optim\n", + "\n", + "\n", "for i in range(1, 10_000 + 1):\n", - " net, optim = train_step(net, optim, x, y)\n", + " net, optim = train_step(net, optim, input, target)\n", " if i % 1_000 == 0:\n", - " loss = loss_func(net, x, y)\n", + " loss = loss_func(net, input, target)\n", " print(f\"Epoch={i:003d}\\tLoss: {loss:.3e}\")\n", "\n", - "plt.plot(x, y, label=\"true\")\n", - "plt.plot(x, net(x), label=\"pred\")\n", + "plt.plot(input, target, 
label=\"true\")\n",
    "plt.plot(input, net(input), label=\"pred\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Line search\n",
    "\n",
    "In this section [backtracking line search](https://en.wikipedia.org/wiki/Backtracking_line_search) is implemented. The line search finds a step size that satisfies the strong [Wolfe conditions](https://en.wikipedia.org/wiki/Wolfe_conditions); for more details, see [N&W, Ch. 3](https://www.math.uci.edu/~qnie/Publications/NumericalOptimization.pdf). The method is written in a stateful manner, i.e. it modifies the state of the optimizer in place; however, it is executed functionally using the `at` method to comply with `jax` transformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import serket as sk\n",
    "import jax.numpy as jnp\n",
    "from typing import Callable\n",
    "import jax.tree_util as jtu\n",
    "import functools as ft\n",
    "\n",
    "# transform numpy functions that work on arrays to work on pytrees of\n",
    "# arrays. additionally, if the rhs is a scalar it will be broadcast\n",
    "# to the pytree\n",
    "tree_mul = sk.bcmap(jnp.multiply)\n",
    "tree_add = sk.bcmap(jnp.add)\n",
    "tree_neg = sk.bcmap(jnp.negative)\n",
    "tree_vdot = sk.bcmap(ft.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST))\n",
    "\n",
    "\n",
    "class BackTrackLS(sk.TreeClass):\n",
    "    \"\"\"Backtracking line search with strong Wolfe conditions.\n",
    "\n",
    "    Args:\n",
    "        func: the objective function to be minimized. accepts a single pytree\n",
    "            argument. for multiple arguments use `functools.partial` to fix\n",
    "            the other arguments.\n",
    "        maxiter: the maximum number of iterations.\n",
    "        tol: the tolerance for the stopping criterion.\n",
    "        c1: the sufficient decrease parameter for the Armijo condition. Must\n",
    "            satisfy 0 < c1 < c2 < 1.\n",
    "        c2: the curvature condition parameter. Must satisfy 0 < c1 < c2 < 1.\n",
    "        step_size: the initial step size.\n",
    "        decay: the factor by which the step size is decayed when the Wolfe\n",
    "            conditions are not satisfied.\n",
    "    \"\"\"\n",
    "\n",
    "    # `func` is not a jax type, so it is declared as a frozen field by\n",
    "    # freezing it on `setattr` and unfreezing it on `getattr`.\n",
    "    # check `sk.field` fields for more details\n",
    "    func: Callable = sk.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze])\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        *,\n",
    "        func: Callable[..., jax.Array],\n",
    "        maxiter: int = 100,\n",
    "        tol: float = 0.0,\n",
    "        c1: float = 1e-4,\n",
    "        c2: float = 0.9,\n",
    "        step_size: float = 1.0,\n",
    "        decay: float = 0.5,\n",
    "    ):\n",
    "        self.func = func\n",
    "        self.maxiter = maxiter\n",
    "        self.tol = tol\n",
    "        # wolfe conditions\n",
    "        self.c1 = c1  # armijo condition constant\n",
    "        self.c2 = c2  # curvature condition constant\n",
    "        self.step_size = step_size\n",
    "        self.decay = decay\n",
    "\n",
    "        # conditions numerics\n",
    "        self.wolfe1 = jnp.inf\n",
    "        self.wolfe2 = jnp.inf\n",
    "        self.error = jnp.inf\n",
    "\n",
    "        # status\n",
    "        self.tol_reached = False\n",
    "        self.max_iter_reached = False\n",
    "        self.fail = False\n",
    "        self.iter_count = 0\n",
    "\n",
    "    def step(self, xk0: T, fk0: jax.Array, dfk0: T) -> tuple[T, jax.Array, T]:\n",
    "        \"\"\"Compute the next iterate of the line search.\n",
    "\n",
    "        Args:\n",
    "            xk0: the initial parameters. accepts pytree.\n",
    "            fk0: the initial function value.\n",
    "            dfk0: the initial gradient. accepts pytree same structure as xk0.\n",
    "\n",
    "        Returns:\n",
    "            xk1: the next iterate. xk1 = xk0 + αpk\n",
    "            fk1: the next function value. f(xk0 + αpk)\n",
    "            dfk1: the next gradient. ∇f(xk0 + αpk)\n",
    "        \"\"\"\n",
    "        # NOTE: calling this method will raise `AttributeError` because\n",
    "        # it's mutating the state (e.g. 
self.something=something)\n", + " # it will only work if used with `.at` that executes it functionally\n", + "\n", + " self.step_size = jnp.minimum(1.0, self.step_size)\n", + " # for simplicity we will use the negative gradient as the descent direction\n", + " pk = tree_neg(dfk0)\n", + " # <∇f(xk), pk> but with pytrees\n", + " dfkTpk0 = sum(jtu.tree_leaves(tree_vdot(dfk0, pk)))\n", + " # xk+1 = xk + αpk\n", + " xk1 = tree_add(xk0, tree_mul(pk, self.step_size))\n", + " # f(xk+1), ∇f(xk+1)\n", + " fk1, dfk1 = jax.value_and_grad(self.func)(xk1)\n", + " # <∇f(xk+1), pk> but with pytrees\n", + " dfkTp1 = sum(jtu.tree_leaves(tree_vdot(dfk1, pk)))\n", + "\n", + " # armijo condition\n", + " # f(xk+1) ≤ f(xk) + c1α∇f(xk)⊤pk\n", + " self.wolfe1 = fk1 - (fk0 + self.c1 * self.step_size * dfkTpk0)\n", + "\n", + " # curvature condition\n", + " # |∇f(xk+1)⊤pk| ≤ c2|∇f(xk)⊤pk|\n", + " self.wolfe2 = abs(dfkTp1) - self.c2 * abs(dfkTpk0)\n", + " self.error = jnp.maximum(self.wolfe1, self.wolfe2)\n", + " self.error = jnp.maximum(self.error, 0.0)\n", + " self.iter_count += 1\n", + "\n", + " # check status\n", + " self.tol_reached = self.error <= self.tol\n", + " self.max_iter_reached = self.iter_count >= self.maxiter\n", + " self.fail = self.fail | (self.max_iter_reached & ~self.tol_reached)\n", + " self.step_size = jnp.where(\n", + " self.fail | self.tol_reached,\n", + " self.step_size,\n", + " self.step_size * self.decay,\n", + " )\n", + " return xk1, fk1, dfk1\n", + "\n", + " @staticmethod\n", + " def cond_func(state: T) -> bool:\n", + " *_, ls = state\n", + " return ~(ls.fail | ls.tol_reached | ls.max_iter_reached)\n", + "\n", + " @staticmethod\n", + " def body_func(state: T) -> T:\n", + " (xk0, fk0, dfk0), _, ls = state\n", + " # note that the step method mutates the state and \n", + " # will raise `AttributeError` if called directly\n", + " # instead of using `.at` to return a tuple of the method \n", + " # output and the mutated state with the new values.\n", + " (xk1, fk1, dfk1), ls = ls.at[\"step\"](xk0, fk0, dfk0)\n", + " return (xk0, fk0, dfk0), (xk1, fk1, dfk1), ls" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "ls = BackTrackLS(\n", + " func=ft.partial(loss_func, input=input, target=target),\n", + " maxiter=100,\n", + " tol=1e-4,\n", + " c1=1e-4,\n", + " c2=0.9,\n", + " step_size=1.0,\n", + " decay=0.9,\n", + ")\n", + "\n", + "# example usage\n", + "# pass the initial parameters, function value and gradient\n", + "fk0, dfk0 = jax.value_and_grad(loss_func)(net, input, target)\n", + "init = (net, fk0, dfk0)\n", + "state = init, init, ls\n", + "state = jax.lax.while_loop(ls.cond_func, ls.body_func, state)\n", + "_, (xk1, fk1, dfk1), ls = state" + ] } ], "metadata": {