p -> drop_rate
ASEM000 committed Aug 4, 2023
1 parent fff1261 commit f6440f3
Showing 6 changed files with 64 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANEGLOG.md
@@ -15,6 +15,7 @@
- `MLP` produces smaller `jaxprs` and is faster to compile; for my use case (higher-order differentiation through a `PINN`), the new `MLP` compiles faster.
- `kernel_dilation` -> `dilation`
- `input_dilation` -> Removed.
- `p` -> `drop_rate` in all dropout layers

### Additions

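A minimal before/after sketch of the `p` -> `drop_rate` rename listed in the changelog hunk above. The keyword-argument form is an assumption (the docstring and tests in this commit use the positional form `sk.nn.Dropout(0.5)`); the `.at["drop_rate"]` spelling mirrors the updated test further down.

```python
import serket as sk

# before this commit (old field name `p`):
#   layer = sk.nn.Dropout(p=0.5)
#   layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)

# after this commit the field is `drop_rate`:
layer = sk.nn.Dropout(drop_rate=0.5)                          # keyword form assumed to work
layer = layer.at["drop_rate"].set(0.0, is_leaf=sk.is_frozen)  # as in tests/test_dropout.py
```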
58 changes: 46 additions & 12 deletions docs/notebooks/layers_overview.ipynb
@@ -250,17 +250,19 @@
}
],
"source": [
"import serket as sk \n",
"import jax \n",
"import math \n",
"import serket as sk\n",
"import jax\n",
"import math\n",
"\n",
"# 1) linear layer with no bias\n",
"linear = sk.nn.Linear(1, 10, weight_init=\"he_normal\", bias_init=None)\n",
"\n",
"\n",
"# linear layer with custom initialization function\n",
"def init_func(key, shape, dtype=jax.numpy.float32):\n",
" return jax.numpy.arange(math.prod(shape), dtype=dtype).reshape(shape)\n",
"\n",
"\n",
"linear = sk.nn.Linear(1, 10, weight_init=init_func, bias_init=None)\n",
"print(linear.weight)\n",
"# [[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]\n",
@@ -322,28 +324,29 @@
"outputs": [],
"source": [
"import serket as sk\n",
"import jax \n",
"import jax\n",
"\n",
"# 1) activation function with a string\n",
"linear = sk.nn.FNN([1,1],act_func=\"relu\")\n",
"linear = sk.nn.FNN([1, 1], act_func=\"relu\")\n",
"\n",
"# 2) activation function with a function\n",
"linear = sk.nn.FNN([1,1],act_func=jax.nn.relu)\n",
"linear = sk.nn.FNN([1, 1], act_func=jax.nn.relu)\n",
"\n",
"\n",
"@sk.autoinit\n",
"class MyTrainableActivation(sk.TreeClass):\n",
" my_param: float = 10.0\n",
" def __call__(self, x):\n",
" return x * self.my_param\n",
" my_param: float = 10.0\n",
"\n",
" def __call__(self, x):\n",
" return x * self.my_param\n",
"\n",
"\n",
"# 3) activation function with a class\n",
"linear = sk.nn.FNN([1,1],act_func=MyTrainableActivation())\n",
"linear = sk.nn.FNN([1, 1], act_func=MyTrainableActivation())\n",
"\n",
"# 4) activation function with a registered class\n",
"sk.def_act_entry(\"my_act\", MyTrainableActivation)\n",
"linear = sk.nn.FNN([1,1],act_func=\"my_act\")"
"linear = sk.nn.FNN([1, 1], act_func=\"my_act\")"
]
},
{
@@ -379,12 +382,43 @@
}
],
"source": [
"import serket as sk \n",
"import serket as sk\n",
"import jax\n",
"\n",
"linear = sk.nn.Linear(10, 5, dtype=jax.numpy.float16)\n",
"linear\n",
"# note the dtype is f16(float16) in the repr output"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import serket as sk \n",
"import jax.numpy as jnp\n",
"\n",
"class MultiHeadDotProductAttention(sk.TreeClass):\n",
" def __init__(\n",
" self,\n",
" qkv_features:tuple[int,int,int],\n",
" out_features:int,\n",
" *,\n",
" num_heads:int,\n",
" q_weight_init,\n",
" q_bias_init,\n",
" k_weight_init,\n",
" k_bias_init,\n",
" v_weight_init,\n",
" v_bias_init,\n",
"\n",
" ):\n",
" \n",
"\n",
"\n",
" "
]
}
],
"metadata": {
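The last notebook cell added above is committed with an empty `__init__` body. Purely as a sketch of where that cell appears to be heading, and not serket's actual attention API, the following fills in the body under assumed semantics: `qkv_features` is read as the input sizes of the query/key/value projections, every projection outputs `out_features` (assumed divisible by `num_heads`), the `*_weight_init`/`*_bias_init` defaults (`"he_normal"`, `None`) are borrowed from the `Linear` examples earlier in the notebook, and the head-splitting logic is standard scaled dot-product attention.

```python
import jax
import jax.numpy as jnp
import serket as sk


class MultiHeadDotProductAttention(sk.TreeClass):
    """Sketch only; completes the empty-bodied cell above under assumed semantics."""

    def __init__(
        self,
        qkv_features: tuple[int, int, int],
        out_features: int,
        *,
        num_heads: int,
        q_weight_init="he_normal",
        q_bias_init=None,
        k_weight_init="he_normal",
        k_bias_init=None,
        v_weight_init="he_normal",
        v_bias_init=None,
    ):
        q_in, k_in, v_in = qkv_features  # assumed: input feature sizes of q, k, v
        self.num_heads = num_heads
        # each projection maps its input to `out_features` (assumed divisible by num_heads)
        self.q_proj = sk.nn.Linear(q_in, out_features, weight_init=q_weight_init, bias_init=q_bias_init)
        self.k_proj = sk.nn.Linear(k_in, out_features, weight_init=k_weight_init, bias_init=k_bias_init)
        self.v_proj = sk.nn.Linear(v_in, out_features, weight_init=v_weight_init, bias_init=v_bias_init)
        self.out_proj = sk.nn.Linear(out_features, out_features)

    def __call__(self, q, k, v):
        # project inputs; assuming Linear acts on the last axis of (..., length, features)
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        def split_heads(x):
            # split the feature axis into (num_heads, head_dim)
            return x.reshape(*x.shape[:-1], self.num_heads, -1)

        qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)

        # scaled dot-product attention per head
        logits = jnp.einsum("...qhd,...khd->...hqk", qh, kh) / jnp.sqrt(qh.shape[-1])
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("...hqk,...khd->...qhd", weights, vh)

        # merge heads and apply the output projection
        return self.out_proj(out.reshape(*out.shape[:-2], -1))
```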
4 changes: 3 additions & 1 deletion docs/notebooks/regularization.ipynb
@@ -101,7 +101,8 @@
" x = jax.nn.tanh(self.linear3(x))\n",
" x = self.linear4(x)\n",
" return x\n",
" \n",
"\n",
"\n",
"def linear_12_weight_l1_loss(net: Net):\n",
" return (\n",
" # select desired branches (linear1, linear2 in this example)\n",
@@ -113,6 +114,7 @@
" .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)\n",
" )\n",
"\n",
"\n",
"net = Net()\n",
"print(linear_12_weight_l1_loss(net))"
]
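The branch-selection lines of `linear_12_weight_l1_loss` are collapsed in the hunk above; only the final `.reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)` is visible. As a plain-JAX illustration of what that reduction computes (not serket's `.at`/`.reduce` API, and with made-up weights), an L1 penalty over a set of selected arrays looks like this:

```python
import jax
import jax.numpy as jnp


def l1_penalty(selected_params):
    """Sum of absolute values over every array leaf in `selected_params`."""
    leaves = jax.tree_util.tree_leaves(selected_params)
    return sum(jnp.sum(jnp.abs(leaf)) for leaf in leaves)


# hypothetical weights standing in for the selected linear1/linear2 branches
selected = {
    "linear1_weight": jnp.ones((2, 3)),   # absolute values sum to 6
    "linear2_weight": -jnp.ones((3, 1)),  # absolute values sum to 3
}
print(l1_penalty(selected))  # 9.0
```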
2 changes: 2 additions & 0 deletions docs/notebooks/train_eval.ipynb
@@ -213,6 +213,7 @@
"\n",
"add_one = AddOne()\n",
"\n",
"\n",
"class AddOneEval(sk.TreeClass):\n",
" def __call__(self, x: jax.Array) -> jax.Array:\n",
" return x # no-op\n",
@@ -222,6 +223,7 @@
"def _(_: AddOne) -> AddOneEval:\n",
" return AddOneEval()\n",
"\n",
"\n",
"print(add_one(x))\n",
"print(sk.tree_eval(add_one)(x))"
]
22 changes: 11 additions & 11 deletions serket/nn/dropout.py
@@ -32,17 +32,17 @@ class Dropout(sk.TreeClass):
"""Drop some elements of the input tensor.
Randomly zeroes some of the elements of the input tensor with
probability ``p`` using samples from a Bernoulli distribution.
probability ``drop_rate`` using samples from a Bernoulli distribution.
Args:
p: probability of an element to be zeroed. Default: 0.5
drop_rate: probability of an element to be zeroed. Default: 0.5
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> layer = sk.nn.Dropout(0.5)
>>> # change `p` to 0.0 to turn off dropout
>>> layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)
>>> print(layer(jnp.ones([10])))
[2. 0. 2. 2. 2. 2. 2. 2. 0. 0.]
Note:
Use :func:`.tree_eval` to turn off dropout during evaluation.
@@ -65,27 +65,27 @@ class Dropout(sk.TreeClass):
)
"""

p: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)])

def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)):
return jnp.where(
(keep_prop := jax.lax.stop_gradient(1 - self.p)) == 0.0,
(keep_prop := jax.lax.stop_gradient(1 - self.drop_rate)) == 0.0,
jnp.zeros_like(x),
jnp.where(jr.bernoulli(key, keep_prop, x.shape), x / keep_prop, 0),
)


@sk.autoinit
class DropoutND(sk.TreeClass):
p: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)])

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x, *, key=jr.PRNGKey(0)):
# drops full feature maps along first axis.
shape = (x.shape[0], *([1] * (x.ndim - 1)))

return jnp.where(
(keep_prop := jax.lax.stop_gradient(1 - self.p)) == 0.0,
(keep_prop := jax.lax.stop_gradient(1 - self.drop_rate)) == 0.0,
jnp.zeros_like(x),
jnp.where(jr.bernoulli(key, keep_prop, shape=shape), x / keep_prop, 0),
)
@@ -100,7 +100,7 @@ class Dropout1D(DropoutND):
"""Drops full feature maps along the channel axis.
Args:
p: fraction of an elements to be zeroed out.
drop_rate: fraction of elements to be zeroed out.
Example:
>>> import serket as sk
@@ -143,7 +143,7 @@ class Dropout2D(DropoutND):
"""Drops full feature maps along the channel axis.
Args:
p: fraction of an elements to be zeroed out.
drop_rate: fraction of elements to be zeroed out.
Example:
>>> import serket as sk
@@ -190,7 +190,7 @@ class Dropout3D(DropoutND):
"""Drops full feature maps along the channel axis.
Args:
p: fraction of an elements to be zeroed out.
drop_rate: fraction of elements to be zeroed out.
Example:
>>> import serket as sk
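A short runtime sketch of the renamed field, grounded in the docstring example, the `Note` about `tree_eval`, and the updated test below; the key and printed values are illustrative, and the call signature `__call__(x, *, key=...)` is the one shown in the hunk above.

```python
import jax.numpy as jnp
import jax.random as jr
import serket as sk

x = jnp.ones([10])
key = jr.PRNGKey(42)

layer = sk.nn.Dropout(0.5)
print(layer(x, key=key))
# surviving entries are scaled by 1 / (1 - drop_rate) = 2.0, dropped entries are 0

# drop_rate == 1.0 hits the keep_prop == 0 branch and zeroes everything, as the test below asserts
print(sk.nn.Dropout(1.0)(x, key=key))

# per the Note in the docstring, sk.tree_eval turns dropout into a no-op for evaluation
print(sk.tree_eval(layer)(x))  # expected to equal x
```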
2 changes: 1 addition & 1 deletion tests/test_dropout.py
@@ -26,7 +26,7 @@ def test_dropout():
layer = Dropout(1.0)
npt.assert_allclose(layer(x), jnp.array([0.0, 0.0, 0.0, 0.0, 0.0]))

layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)
layer = layer.at["drop_rate"].set(0.0, is_leaf=sk.is_frozen)
npt.assert_allclose(layer(x), x)

with pytest.raises(ValueError):
