Skip to content

Commit

Permalink
Update dropout.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 8, 2023
1 parent dcf9cc0 commit ff17aa8
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions serket/nn/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def dropout_nd(
)


@sk.autoinit
class GeneralDropout(sk.TreeClass):
"""Drop some elements of the input tensor.
Expand All @@ -58,11 +59,8 @@ class GeneralDropout(sk.TreeClass):
dropout is applied to all axes.
"""

def __init__(
self, drop_rate: float, drop_axes: tuple[int, ...] | Literal["..."] = ...
):
self.drop_rate = drop_rate
self.drop_axes = drop_axes
drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
drop_axes: tuple[int, ...] | Literal["..."] = ...

def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)):
return dropout_nd(x, jax.lax.stop_gradient(self.drop_rate), key, self.drop_axes)
Expand Down

0 comments on commit ff17aa8

Please sign in to comment.