Commit

add inline_automask
ASEM000 committed Dec 4, 2023
1 parent 348a412 commit 22e723a
Showing 2 changed files with 189 additions and 14 deletions.
201 changes: 188 additions & 13 deletions docs/notebooks/common_recipes.ipynb
@@ -1136,6 +1136,8 @@
"Linear(\n",
" in_features=(1), \n",
" out_features=1, \n",
" in_axis=(-1), \n",
" out_axis=-1, \n",
" weight_init=glorot_uniform, \n",
" bias_init=zeros, \n",
" weight=f32[1,1](μ=0.20, σ=0.00, ∈[0.20,0.20]), \n",
@@ -1239,14 +1241,18 @@
" encoder=Linear(\n",
" in_features=(#1), \n",
" out_features=#10, \n",
" in_axis=(#-1), \n",
" out_axis=#-1, \n",
" weight_init=#glorot_uniform, \n",
" bias_init=#zeros, \n",
" weight=f32[1,10](μ=-0.78, σ=1.11, ∈[-2.58,0.00]), \n",
" weight=f32[10,1](μ=-0.78, σ=1.11, ∈[-2.58,0.00]), \n",
" bias=f32[10](μ=-0.39, σ=0.55, ∈[-1.29,0.00])\n",
" ), \n",
" decoder=Linear(\n",
" in_features=(#10), \n",
" out_features=#1, \n",
" in_axis=(#-1), \n",
" out_axis=#-1, \n",
" weight_init=#glorot_uniform, \n",
" bias_init=#zeros, \n",
" weight=None, \n",
@@ -1334,38 +1340,58 @@
"import jax.random as jr\n",
"import jax.numpy as jnp\n",
"from typing import Any\n",
"from typing import Any, Callable, TypeVar\n",
"from typing_extensions import ParamSpec\n",
"\n",
"T = TypeVar(\"T\")\n",
"P = ParamSpec(\"P\")\n",
"\n",
"\n",
"def automask(jax_transform):\n",
" \"\"\"Enable jax transformations to accept non-jax types.\"\"\"\n",
"def automask(jax_transform: Callable[P, T]) -> Callable[P, T]:\n",
" \"\"\"Enable jax transformations to accept non-jax types. e.g. ``jax.grad``.\"\"\"\n",
" # works with functions that takes a function as input\n",
" # and returns a function as output e.g. `jax.grad`\n",
"\n",
" def out_transform(func, **transformation_kwargs):\n",
" @ft.partial(jax_transform, **transformation_kwargs)\n",
" def jax_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # outputs should return jax types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before the `jax` boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" # apply the jax transformation\n",
" output = jax_boundary(*args, **kwargs)\n",
" # unmask the outputs before returning\n",
" return sk.tree_unmask(output)\n",
"\n",
" return outer_wrapper\n",
"\n",
" return out_transform"
" return out_transform\n",
"\n",
"\n",
"def inline_automask(jax_transform: Callable[P, T]) -> Callable[P, T]:\n",
" \"\"\"Enable jax transformations to accept non-jax types e.g. ``jax.lax.scan``.\"\"\"\n",
" # works with functions that takes a function and arguments as input\n",
" # and returns jax types as output e.g. `jax.lax.scan`\n",
"\n",
" def outer_wrapper(func, *args, **kwargs):\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
"\n",
" def func_masked(*args, **kwargs):\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" output = jax_transform(func_masked, *args, **kwargs)\n",
" return sk.tree_unmask(output)\n",
"\n",
" return outer_wrapper"
]
},
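Both recipes lean on the same round trip: `sk.tree_mask` wraps non-jax leaves (strings, callables, ...) in frozen wrappers that jax treats as static pytree data (rendered with a leading `#` in the reprs above), and `sk.tree_unmask` restores the original values. A minimal sketch of that round trip; the dict tree below is illustrative, not from the notebook:

    import serket as sk
    import jax.numpy as jnp

    tree = {"name": "layer", "weight": jnp.ones((2, 2))}
    masked = sk.tree_mask(tree)        # "layer" becomes a frozen (static) leaf
    restored = sk.tree_unmask(masked)  # the original string is recovered
    assert restored["name"] == "layer"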
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`automask` with `jit`"
"### `automask` with `jit`"
]
},
{
@@ -1378,6 +1404,7 @@
"output_type": "stream",
"text": [
"`jit error`: Argument 'layer' of type <class 'str'> is not a valid JAX type\n",
"\n",
"Using automask:\n",
"forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
" [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
@@ -1404,15 +1431,15 @@
" print(\"`jit error`:\", e)\n",
" # now with `automask` the function can accept non-jax types (e.g. string)\n",
" forward_jit = automask(jax.jit)(forward)\n",
" print(\"Using automask:\")\n",
" print(\"\\nUsing automask:\")\n",
" print(f\"{forward_jit(params, x)=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`automask` with `vmap`"
"### `automask` with `vmap`"
]
},
{
@@ -1425,6 +1452,7 @@
"output_type": "stream",
"text": [
"`vmap error`: Output from batched function 'layer' with type <class 'str'> is not a valid JAX type\n",
"\n",
"Using automask:\n",
"{\n",
" name:layer, \n",
@@ -1449,9 +1477,156 @@
" print(\"`vmap error`:\", e)\n",
" # now with `automask` the function can accept non-jax types (e.g. string)\n",
" params = automask(jax.vmap)(make_params)(keys)\n",
" print(\"Using automask:\")\n",
" print(\"\\nUsing automask:\")\n",
" print(sk.tree_repr(params))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `automask` with `make_jaxpr`"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`jax.make_jaxpr` failed: error=TypeError(\"Argument 'glorot_uniform' of type <class 'str'> is not a valid JAX type\")\n",
"\n",
"Using `automask:\n",
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[1,1]\u001b[39m b\u001b[35m:f32[1]\u001b[39m c\u001b[35m:f32[10,10]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[1]\u001b[39m = squeeze[dimensions=(1,)] a\n",
" e\u001b[35m:f32[10]\u001b[39m = reduce_sum[axes=(1,)] c\n",
" f\u001b[35m:f32[10,1]\u001b[39m = dot_general[\n",
" dimension_numbers=(([], []), ([], []))\n",
" preferred_element_type=float32\n",
" ] e d\n",
" g\u001b[35m:f32[1,1]\u001b[39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] b\n",
" h\u001b[35m:f32[10,1]\u001b[39m = add f g\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(h,) }\n"
]
}
],
"source": [
"# make jaxpr\n",
"def func(layer, input):\n",
" # this function accepts non-jax types (e.g. `layer`)\n",
" output = layer(input)\n",
" return output\n",
"\n",
"\n",
"linear = sk.nn.Linear(1, 1, key=jr.PRNGKey(0))\n",
"try:\n",
" jax.make_jaxpr(func)(linear, jnp.ones((10, 10)))\n",
"except TypeError as error:\n",
" print(f\"`jax.make_jaxpr` failed: {error=}\")\n",
" print(\"\\nUsing `automask:\")\n",
" print(automask(jax.make_jaxpr)(func)(linear, jnp.ones((10, 10))))"
]
},
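The docstring above cites ``jax.grad``, which follows the same wrap-then-call pattern as `jit` and `vmap`. A hedged sketch, reusing `linear` and `automask` from the cells above; the `loss` function is illustrative, not from the commit:

    def loss(layer, input):
        # scalar loss; gradients flow through the array leaves,
        # while masked (static) leaves pass through untouched
        return jnp.mean(layer(input) ** 2)

    grads = automask(jax.grad)(loss)(linear, jnp.ones((10, 1)))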
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `inline_automask` with `jax.lax.scan`"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`jax.lax.scan` Failed: error=TypeError(\"Value 'glorot_uniform' with type <class 'str'> is not a valid JAX type\")\n",
"\n",
"Using `inline_automask`:\n",
"(Linear(\n",
" in_features=(1), \n",
" out_features=1, \n",
" in_axis=(-1), \n",
" out_axis=-1, \n",
" weight_init=glorot_uniform, \n",
" bias_init=zeros, \n",
" weight=f32[1,1](μ=0.20, σ=0.00, ∈[0.20,0.20]), \n",
" bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
"), Array([[1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917],\n",
" [1.9816917]], dtype=float32))\n"
]
}
],
"source": [
"def scan_func(layer, input):\n",
" # layer contains non-jax types\n",
" output = layer(input)\n",
" return layer, output\n",
"\n",
"layer = sk.nn.Linear(1, 1, key=jr.PRNGKey(0))\n",
"\n",
"try:\n",
" jax.lax.scan(func, linear, jnp.ones((10, 10)))\n",
"except TypeError as error:\n",
" print(f\"`jax.lax.scan` Failed: {error=}\")\n",
" print(\"\\nUsing `inline_automask`:\")\n",
" print(inline_automask(jax.lax.scan)(scan_func, linear, jnp.ones((10, 10))))"
]
},
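The two wrappers split along the transformation's calling convention: `jax.jit`, `jax.vmap`, `jax.grad`, and `jax.make_jaxpr` return a function that is called later, so `automask` can mask and unmask on both sides of that call; `jax.lax.scan` and `jax.eval_shape` consume the function together with its arguments in a single call, which is the case `inline_automask` handles. Side by side, with the names from the cells above:

    # wrap first, call later -> `automask`
    forward_jit = automask(jax.jit)(forward)
    output = forward_jit(params, x)

    # function and arguments in one call -> `inline_automask`
    carry, stacked = inline_automask(jax.lax.scan)(scan_func, layer, jnp.ones((10, 10)))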
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `inline_automask` with `jax.eval_shape`"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`jax.eval_shape` Failed: error=TypeError(\"Argument 'glorot_uniform' of type <class 'str'> is not a valid JAX type\")\n",
"\n",
"Using `inline_automask`:\n",
"ShapeDtypeStruct(shape=(10, 1), dtype=float32)\n"
]
}
],
"source": [
"# eval shape\n",
"def func(layer, input):\n",
" # this function accepts non-jax types (e.g. `layer`)\n",
" output = layer(input)\n",
" return output\n",
"\n",
"\n",
"linear = sk.nn.Linear(1, 1, key=jr.PRNGKey(0))\n",
"\n",
"try:\n",
" jax.eval_shape(func, linear, jnp.ones((10, 10)))\n",
"except TypeError as error:\n",
" print(f\"`jax.eval_shape` Failed: {error=}\")\n",
" print(\"\\nUsing `inline_automask`:\")\n",
" print(inline_automask(jax.eval_shape)(func, linear, jnp.ones((10, 10))))"
]
}
],
"metadata": {
2 changes: 1 addition & 1 deletion serket/_src/custom_transform.py
@@ -206,4 +206,4 @@ def is_leaf(x: Any) -> bool:


tree_eval.eval_dispatcher = ft.singledispatch(lambda x: x)
tree_eval.def_eval = tree_eval.eval_dispatcher.register
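For context, `tree_eval.def_eval` exposes the singledispatch table above as a registration decorator: a rule maps a layer type to its evaluation-time replacement. A minimal hypothetical sketch (the import path and the `Dropout`-to-`Identity` pairing are assumptions, not asserted by this diff):

    import serket as sk

    # hypothetical rule: swap a training-only layer for a no-op at evaluation time
    @sk.tree_eval.def_eval(sk.nn.Dropout)
    def _(layer: sk.nn.Dropout) -> sk.nn.Identity:
        return sk.nn.Identity()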
