remove Bilinear and streamline naming
ASEM000 committed Jul 28, 2023
1 parent 0a7448c commit 8cef188
Showing 23 changed files with 409 additions and 462 deletions.
21 changes: 21 additions & 0 deletions CHANEGLOG.md
@@ -0,0 +1,21 @@
# Changelog

## V0.x

### Changes

- `ScanRNN` changes (see the sketch after this list):

  - `backward_cell=...` is deprecated; instead, use `ScanRNN(forward_cell, backward_cell, reverse=(False, True))`.
  - `ScanRNN` now accepts an arbitrary number of cells, along with a `reverse` argument that decides whether the corresponding cell scans its input in reverse.
  - `return_state` is added to control whether the final carry is returned.
  - `cell.init_state` is deprecated; use `sk.tree_state(cell, ...)` instead.
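
A minimal sketch of the new `ScanRNN` call pattern; the cell type, sizes, and shapes here are illustrative assumptions, not taken from this commit:

```python
import jax.numpy as jnp
import serket as sk

cell_fwd = sk.nn.SimpleRNNCell(in_features=3, hidden_features=8)
cell_bwd = sk.nn.SimpleRNNCell(in_features=3, hidden_features=8)

# previously: sk.nn.ScanRNN(cell_fwd, backward_cell=cell_bwd)
rnn = sk.nn.ScanRNN(cell_fwd, cell_bwd, reverse=(False, True), return_state=False)

x = jnp.ones([10, 3])  # (time steps, in_features)
y = rnn(x)             # scans cell_fwd forward and cell_bwd in reverse
```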

- Naming changes (see the sketch after this list):
  - `***_init_func` -> `***_init` (shorter and more concise)
  - `gamma_init_func` -> `weight_init`
  - `beta_init_func` -> `bias_init`
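
A minimal sketch of the rename; it assumes `Linear` accepts string initializer names and that normalization layers take the same `weight_init`/`bias_init` keywords:

```python
import serket as sk

# previously: sk.nn.Linear(1, 2, weight_init_func="he_normal", bias_init_func=None)
linear = sk.nn.Linear(1, 2, weight_init="he_normal", bias_init=None)
```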

### Deprecations

- `Bilinear` is deprecated; use `Multilinear((in1_features, in2_features), out_features)` instead, as sketched below.
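
A minimal migration sketch; the feature sizes and the one-array-per-input calling convention are assumptions for illustration:

```python
import jax.numpy as jnp
import serket as sk

x1 = jnp.ones([4])
x2 = jnp.ones([3])

# previously (approximately): sk.nn.Bilinear(4, 3, 2)
layer = sk.nn.Multilinear((4, 3), 2)
y = layer(x1, x2)  # shape: (2,)
```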
8 changes: 0 additions & 8 deletions changelog.md

This file was deleted.

13 changes: 6 additions & 7 deletions docs/API/api.rst
@@ -10,17 +10,16 @@
:maxdepth: 2
:caption: API Documentation

linear
dropout
activations
containers
pooling
convolution
dropout
image
linear
normalization
image_filtering
reshaping
random_transforms
activations
pooling
recurrent
reshaping
misc


File renamed without changes.
1 change: 0 additions & 1 deletion docs/API/linear.rst
@@ -3,7 +3,6 @@ Linear
.. currentmodule:: serket.nn

.. autoclass:: Linear
.. autoclass:: Bilinear
.. autoclass:: Identity
.. autoclass:: Multilinear
.. autoclass:: GeneralLinear
12 changes: 1 addition & 11 deletions serket/nn/__init__.py
@@ -94,16 +94,7 @@
PixelShuffle2D,
RandomContrast2D,
)
from .linear import (
FNN,
MLP,
Bilinear,
Embedding,
GeneralLinear,
Identity,
Linear,
Multilinear,
)
from .linear import FNN, MLP, Embedding, GeneralLinear, Identity, Linear, Multilinear
from .normalization import BatchNorm, GroupNorm, InstanceNorm, LayerNorm
from .pooling import (
AdaptiveAvgPool1D,
@@ -177,7 +168,6 @@
"MLP",
# Linear
"Linear",
"Bilinear",
"Identity",
"Multilinear",
"GeneralLinear",
6 changes: 1 addition & 5 deletions serket/nn/activation.py
@@ -389,9 +389,5 @@ def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
if isinstance(act_func, str):
if act_func in act_map:
return act_map[act_func]()

raise ValueError(
f"Unknown activation function {act_func=}, "
f"available activations are {list(act_map)}"
)
raise ValueError(f"Unknown {act_func=}, available activations: {list(act_map)}")
return act_func
4 changes: 2 additions & 2 deletions serket/nn/blocks/unet.py
@@ -40,14 +40,14 @@ def __init__(self, in_features: int, out_features: int):
out_features=out_features,
kernel_size=3,
padding=1,
bias_init_func=None,
bias_init=None,
)
self.conv2 = sk.nn.Conv2D(
in_features=out_features,
out_features=out_features,
kernel_size=3,
padding=1,
bias_init_func=None,
bias_init=None,
)

def __call__(self, x: jax.Array, **k) -> jax.Array:
5 changes: 2 additions & 3 deletions serket/nn/containers.py
@@ -18,7 +18,6 @@
from typing import Any

import jax
import jax.numpy as jnp
import jax.random as jr

import serket as sk
@@ -54,7 +53,6 @@ def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
try:
x = layer(x, key=key)
except TypeError:
# a `layer` or a `function` without a key argument
x = layer(x)
return x

Expand Down Expand Up @@ -109,7 +107,8 @@ class RandomApply(sk.TreeClass):
p: float = sk.field(default=0.5, callbacks=[Range(0, 1)])

def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)):
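# the new version below draws a concrete boolean and evaluates only the taken
# branch (jnp.where evaluated both paths); stop_gradient keeps the probability
# `p` out of the differentiation path.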
return jnp.where(jr.bernoulli(key, self.p), self.layer(x), x)
p = jax.lax.stop_gradient(self.p)
return self.layer(x) if jr.bernoulli(key, p) else x


@tree_evaluation.def_evaluation(RandomApply)