diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py index 2752b21..536eb31 100644 --- a/serket/nn/dropout.py +++ b/serket/nn/dropout.py @@ -49,6 +49,7 @@ def dropout_nd( ) +@sk.autoinit class GeneralDropout(sk.TreeClass): """Drop some elements of the input tensor. @@ -58,11 +59,8 @@ class GeneralDropout(sk.TreeClass): dropout is applied to all axes. """ - def __init__( - self, drop_rate: float, drop_axes: tuple[int, ...] | Literal["..."] = ... - ): - self.drop_rate = drop_rate - self.drop_axes = drop_axes + drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)]) + drop_axes: tuple[int, ...] | Literal["..."] = ... def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)): return dropout_nd(x, jax.lax.stop_gradient(self.drop_rate), key, self.drop_axes)