Commit 232358b: tf conversion and edits
ASEM000 committed Sep 18, 2023
1 parent 0e73bca commit 232358b
Showing 7 changed files with 388 additions and 40 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -92,7 +92,7 @@ net = sk.tree_unmask(net)
|Attention| - `MultiHeadAttention`|
| Containers| - `Sequential`, `RandomApply`, `RandomChoice` |
| Convolution | - `{FFT,_}Conv{1D,2D,3D}` <br> - `{FFT,_}Conv{1D,2D,3D}Transpose` <br> - `Depthwise{FFT,_}Conv{1D,2D,3D}` <br> - `Separable{FFT,_}Conv{1D,2D,3D}` <br> - `Conv{1D,2D,3D}Local` |
-|Dropout|- `Dropout`<br> - `Dropout{1D,2D,3D}` <br> - `GeneralDropout` <br> - `RandomCutout{1D,2D}` |
+|Dropout|- `Dropout`<br> - `Dropout{1D,2D,3D}` <br> - `RandomCutout{1D,2D}` |
| Linear | - `Linear`, `Multilinear`, `GeneralLinear`, `Identity` |
|Densely connected| - `FNN` , <br> - `MLP` _compile time_ optimized |
|Normalization|- `{Layer,Instance,Group,Batch}Norm`|
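`GeneralDropout` disappears from the public API in this commit; a hedged migration sketch (the `drop_axes` keyword is taken from the dropout diff below):

```python
import serket as sk

# before this commit (now removed):
# drop = sk.nn.GeneralDropout(drop_rate=0.5, drop_axes=(0,))

# after: Dropout absorbs the axes argument; drop_axes=None keeps the old
# per-element Dropout behavior.
drop = sk.nn.Dropout(drop_rate=0.5, drop_axes=(0,))
```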
1 change: 1 addition & 0 deletions docs/index.rst
@@ -79,6 +79,7 @@ Install from github::
notebooks/evaluation
notebooks/mixed_precision
notebooks/checkpointing
+notebooks/convert_tensorflow
notebooks/regularization
notebooks/subset_training

361 changes: 361 additions & 0 deletions docs/notebooks/convert_tensorflow.ipynb

Large diffs are not rendered by default.
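Since the notebook is not rendered here, a minimal sketch of the kind of JAX-to-TensorFlow conversion it presumably demonstrates, using `jax.experimental.jax2tf`; the layer choice and the `tree_eval` freezing step are assumptions, not taken from the notebook:

```python
import jax.random as jr
import tensorflow as tf
from jax.experimental import jax2tf

import serket as sk

# Build a small serket layer; tree_eval swaps stochastic parts (e.g. Dropout)
# for Identity so the exported graph is deterministic.
layer = sk.tree_eval(sk.nn.Linear(in_features=3, out_features=2, key=jr.PRNGKey(0)))

# jax2tf.convert wraps the pure forward pass as a TF-traceable function.
tf_forward = tf.function(jax2tf.convert(lambda x: layer(x)), autograph=False)
print(tf_forward(tf.ones((3,))))
```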

2 changes: 1 addition & 1 deletion docs/notebooks/train_pinn_burgers.ipynb
@@ -457,7 +457,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
-"version": "3.10.0"
+"version": "3.11.0"
},
"orig_nbformat": 4
},
8 changes: 4 additions & 4 deletions serket/_src/nn/attention.py
@@ -69,7 +69,7 @@ def calculate_attention(
v_heads: jax.Array,
mask: jax.Array,
num_heads: int,
-drop_layer: sk.nn.GeneralDropout,
+drop_layer: sk.nn.Dropout,
key: jr.KeyArray,
) -> jax.Array:
"""Applies multi-head attention to the given inputs.
@@ -176,7 +176,7 @@ class MultiHeadAttention(sk.TreeClass):
>>> import serket as sk
>>> layer = sk.nn.MultiHeadAttention(1, 1, key=jr.PRNGKey(0))
>>> print(repr(layer.dropout))
-GeneralDropout(drop_rate=0.0, drop_axes=Ellipsis)
+Dropout(drop_rate=0.0, drop_axes=None)
>>> print(repr(sk.tree_eval(layer).dropout))
Identity()
@@ -247,8 +247,8 @@ def __init__(
qkey, kkey, vkey, okey = jr.split(key, 4)

self.num_heads = num_heads
-drop_axes = (-1, -2) if drop_broadcast else ...
-self.dropout = sk.nn.GeneralDropout(drop_rate, drop_axes)
+drop_axes = (-1, -2) if drop_broadcast else None
+self.dropout = sk.nn.Dropout(drop_rate, drop_axes)

self.q_projection = sk.nn.Linear(
in_features=q_features,
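A hedged sketch of the changed constructor path; the `drop_rate` and `drop_broadcast` keyword names are inferred from the `__init__` body above, and the exact repr is an assumption:

```python
import jax.random as jr
import serket as sk

# drop_broadcast=True shares one dropout mask over the last two axes,
# which now maps to Dropout(drop_rate, drop_axes=(-1, -2)).
layer = sk.nn.MultiHeadAttention(1, 1, drop_rate=0.5, drop_broadcast=True, key=jr.PRNGKey(0))
print(repr(layer.dropout))  # expected: Dropout(drop_rate=0.5, drop_axes=(-1, -2))
```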
52 changes: 20 additions & 32 deletions serket/_src/nn/dropout.py
@@ -37,13 +37,13 @@ def dropout_nd(
key: jr.KeyArray,
x: jax.Array,
drop_rate,
-drop_axes: Sequence[int] | Literal["..."] = ...,
+drop_axes: Sequence[int] | None = None,
) -> jax.Array:
"""Drop some elements of the input array."""
+# drop_axes = None means dropout is applied to all axes
shape = (
x.shape
-if drop_axes is ...
+if drop_axes is None
else (x.shape[i] if i in drop_axes else 1 for i in range(x.ndim))
)
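To make the mask-shape logic above concrete, a standalone sketch; the helper name `mask_shape` is illustrative, not part of serket:

```python
import jax.numpy as jnp

def mask_shape(x, drop_axes=None):
    # None -> one Bernoulli draw per element; otherwise keep size 1
    # (broadcast) on every axis not listed in drop_axes.
    if drop_axes is None:
        return x.shape
    return tuple(x.shape[i] if i in drop_axes else 1 for i in range(x.ndim))

x = jnp.ones((4, 8, 8))
print(mask_shape(x))        # (4, 8, 8): independent dropout per element
print(mask_shape(x, (0,)))  # (4, 1, 1): mask shared across axes 1 and 2
```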

@@ -132,40 +132,15 @@ def scan_step(x, key):


@sk.autoinit
-class GeneralDropout(sk.TreeClass):
-    """Drop some elements of the input array.
-
-    Args:
-        drop_rate: probability of an element to be zeroed. Default: 0.5
-        drop_axes: axes along which dropout is applied. default: ``...`` which means
-            dropout is applied to all axes.
-    """
-
-    drop_rate: float = sk.field(
-        default=0.5,
-        on_setattr=[IsInstance(float), Range(0, 1)],
-        on_getattr=[jax.lax.stop_gradient_p.bind],
-    )
-    drop_axes: tuple[int, ...] | Literal["..."] = ...
-
-    def __call__(self, x, *, key: jr.KeyArray):
-        """Drop some elements of the input array.
-
-        Args:
-            x: input array
-            key: random number generator key
-        """
-        return dropout_nd(key, x, self.drop_rate, self.drop_axes)
-
-
-class Dropout(GeneralDropout):
+class Dropout(sk.TreeClass):
"""Drop some elements of the input array.
Randomly zeroes some of the elements of the input array with
probability ``drop_rate`` using samples from a Bernoulli distribution.
Args:
drop_rate: probability of an element to be zeroed. Default: 0.5
+drop_axes: axes to apply dropout. Default: None to apply to all axes.
Example:
>>> import serket as sk
@@ -199,8 +174,21 @@ class Dropout(GeneralDropout):
)
"""

-    def __init__(self, drop_rate: float = 0.5):
-        super().__init__(drop_rate=drop_rate, drop_axes=...)
+    drop_rate: float = sk.field(
+        default=0.5,
+        on_setattr=[IsInstance(float), Range(0, 1)],
+        on_getattr=[jax.lax.stop_gradient_p.bind],
+    )
+    drop_axes: tuple[int, ...] | None = None
+
+    def __call__(self, x, *, key: jr.KeyArray):
+        """Drop some elements of the input array.
+
+        Args:
+            x: input array
+            key: random number generator key
+        """
+        return dropout_nd(key, x, self.drop_rate, self.drop_axes)


@sk.autoinit
@@ -467,7 +455,7 @@ def spatial_ndim(self) -> int:

@tree_eval.def_eval(RandomCutout1D)
@tree_eval.def_eval(RandomCutout2D)
-@tree_eval.def_eval(GeneralDropout)
+@tree_eval.def_eval(DropoutND)
@tree_eval.def_eval(Dropout)
def _(_) -> sk.nn.Identity:
return sk.nn.Identity()
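Taken together, the unified layer and its eval-mode rule behave like this hedged sketch (keyword usage follows the fields above):

```python
import jax.numpy as jnp
import jax.random as jr
import serket as sk

# drop_axes=None: independent per-element dropout (the old Dropout default).
# drop_axes=(0,): one mask entry per leading-axis slice, broadcast over the
# rest (what GeneralDropout previously provided).
drop = sk.nn.Dropout(drop_rate=0.5, drop_axes=(0,))
y = drop(jnp.ones((4, 4)), key=jr.PRNGKey(0))

# tree_eval swaps Dropout for Identity at inference time.
print(sk.tree_eval(drop))  # Identity()
```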
2 changes: 0 additions & 2 deletions serket/nn/__init__.py
@@ -80,7 +80,6 @@
Dropout1D,
Dropout2D,
Dropout3D,
-GeneralDropout,
RandomCutout1D,
RandomCutout2D,
)
@@ -233,7 +232,6 @@
"Dropout1D",
"Dropout2D",
"Dropout3D",
-"GeneralDropout",
"RandomCutout1D",
"RandomCutout2D",
# linear
