From ff17aa8016ee481109b76c8b03ae2f0849a26dbb Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Tue, 8 Aug 2023 17:02:09 +0300 Subject: [PATCH] Update dropout.py --- serket/nn/dropout.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py index 2752b21..536eb31 100644 --- a/serket/nn/dropout.py +++ b/serket/nn/dropout.py @@ -49,6 +49,7 @@ def dropout_nd( ) +@sk.autoinit class GeneralDropout(sk.TreeClass): """Drop some elements of the input tensor. @@ -58,11 +59,8 @@ class GeneralDropout(sk.TreeClass): dropout is applied to all axes. """ - def __init__( - self, drop_rate: float, drop_axes: tuple[int, ...] | Literal["..."] = ... - ): - self.drop_rate = drop_rate - self.drop_axes = drop_axes + drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)]) + drop_axes: tuple[int, ...] | Literal["..."] = ... def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)): return dropout_nd(x, jax.lax.stop_gradient(self.drop_rate), key, self.drop_axes)