more functionalize
ASEM000 committed Aug 28, 2023
1 parent 584cf53 commit 4663bb5
Showing 3 changed files with 122 additions and 51 deletions.
88 changes: 53 additions & 35 deletions serket/nn/activation.py
@@ -164,6 +164,15 @@ def __call__(self, x: jax.Array) -> jax.Array:
        return jax.nn.glu(x)


+ def hard_shrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
+     """Hard shrink activation function.
+
+     Reference:
+         https://arxiv.org/pdf/1702.00783.pdf
+     """
+     return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0))
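As a quick sanity check of the new functional form (input values chosen purely for illustration; the import path is assumed from the module shown in this diff): entries inside ``[-alpha, alpha]`` are zeroed and everything else passes through unchanged.

```python
import jax.numpy as jnp

from serket.nn.activation import hard_shrink  # path assumed from this diff

x = jnp.array([-1.0, -0.2, 0.0, 0.3, 2.0])
# default alpha=0.5: |x| <= 0.5 is zeroed, the rest is identity
print(hard_shrink(x))  # [-1.  0.  0.  0.  2.]
```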


@sk.autoinit
class HardShrink(sk.TreeClass):
"""Hard shrink activation function"""
@@ -172,7 +181,7 @@ class HardShrink(sk.TreeClass):

    def __call__(self, x: jax.Array) -> jax.Array:
        alpha = lax.stop_gradient(self.alpha)
-         return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0))
+         return hard_shrink(x, alpha)


class HardSigmoid(sk.TreeClass):
@@ -255,11 +264,16 @@ def __call__(self, x: jax.Array) -> jax.Array:
        return jax.nn.softplus(x)


+ def softsign(x: jax.typing.ArrayLike) -> jax.Array:
+     """SoftSign activation function"""
+     return x / (1 + jnp.abs(x))


class SoftSign(sk.TreeClass):
    """SoftSign activation function"""

    def __call__(self, x: jax.Array) -> jax.Array:
-         return x / (1 + jnp.abs(x))
+         return softsign(x)


def softshrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
@@ -316,11 +330,16 @@ def __call__(self, x: jax.Array) -> jax.Array:
        return jax.nn.tanh(x)


+ def tanh_shrink(x: jax.typing.ArrayLike) -> jax.Array:
+     """TanhShrink activation function"""
+     return x - jnp.tanh(x)


class TanhShrink(sk.TreeClass):
    """TanhShrink activation function"""

    def __call__(self, x: jax.Array) -> jax.Array:
-         return x - jax.nn.tanh(x)
+         return tanh_shrink(x)


def thresholded_relu(x: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Array:
@@ -436,43 +455,42 @@ def __call__(self, x: jax.Array) -> jax.Array:


acts = [
-     AdaptiveLeakyReLU(),
-     AdaptiveReLU(),
-     AdaptiveSigmoid(),
-     AdaptiveTanh(),
-     CeLU(),
-     ELU(),
-     GELU(),
-     GLU(),
-     HardShrink(),
-     HardSigmoid(),
-     HardSwish(),
-     HardTanh(),
-     LeakyReLU(),
-     LogSigmoid(),
-     LogSoftmax(),
-     Mish(),
-     PReLU(),
-     ReLU(),
-     ReLU6(),
-     SeLU(),
-     Sigmoid(),
-     Snake(),
-     SoftPlus(),
-     SoftShrink(),
-     SoftSign(),
-     SquarePlus(),
-     Swish(),
-     Tanh(),
-     TanhShrink(),
-     ThresholdedReLU(),
+     adaptive_leaky_relu,
+     adaptive_relu,
+     adaptive_sigmoid,
+     adaptive_tanh,
+     jax.nn.celu,
+     jax.nn.elu,
+     jax.nn.gelu,
+     jax.nn.glu,
+     hard_shrink,
+     jax.nn.hard_sigmoid,
+     jax.nn.hard_swish,
+     jax.nn.hard_tanh,
+     jax.nn.leaky_relu,
+     jax.nn.log_sigmoid,
+     jax.nn.log_softmax,
+     mish,
+     prelu,
+     jax.nn.relu,
+     jax.nn.relu6,
+     jax.nn.selu,
+     jax.nn.sigmoid,
+     snake,
+     jax.nn.softplus,
+     softshrink,
+     softsign,
+     squareplus,
+     jax.nn.swish,
+     jax.nn.tanh,
+     tanh_shrink,
+     thresholded_relu,
]


- act_map: dict[str, sk.TreeClass] = dict(zip(get_args(ActivationLiteral), acts))

ActivationFunctionType = Callable[[jax.typing.ArrayLike], jax.Array]
ActivationType = Union[ActivationLiteral, ActivationFunctionType]
+ act_map = dict(zip(get_args(ActivationLiteral), acts))
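Since ``acts`` now holds plain functions, ``act_map`` resolves a string name to a callable rather than to a ``TreeClass`` instance. A minimal lookup sketch, assuming ``"tanh"`` is one of the ``ActivationLiteral`` names (the exact keys come from ``get_args(ActivationLiteral)``):

```python
import jax.numpy as jnp

from serket.nn.activation import act_map  # path assumed from this diff

x = jnp.linspace(-2.0, 2.0, 5)
act = act_map["tanh"]  # assumed key; check ActivationLiteral in your version
print(act(x))
```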


@ft.singledispatch
71 changes: 62 additions & 9 deletions serket/nn/containers.py
@@ -25,6 +25,30 @@
from serket.nn.utils import Range


+ def sequential(
+     layers: tuple[Any, ...],
+     array: jax.Array,
+     *,
+     key: jr.KeyArray,
+ ):
+     """Applies a sequence of layers to an array.
+
+     Args:
+         layers: a tuple of layers; each layer is a callable that accepts an array
+             and returns an array, with an optional ``key`` argument for random
+             number generation.
+         array: an array to apply the layers to.
+         key: a random number generator key.
+     """
+     for key, layer in zip(jr.split(key, len(layers)), layers):
+         try:
+             array = layer(array, key=key)
+         except TypeError:
+             # ``key`` argument is not supported by this layer
+             array = layer(array)
+     return array
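A minimal usage sketch of the standalone ``sequential`` (plain ``jnp`` callables stand in for serket layers here): callables that reject the ``key`` keyword fall through to the ``TypeError`` branch and are called without it.

```python
import jax.numpy as jnp
import jax.random as jr

from serket.nn.containers import sequential  # path assumed from this diff

# neither callable accepts ``key``, so both hit the TypeError fallback
layers = (jnp.tanh, lambda a: a * 2.0)
out = sequential(layers, jnp.ones((3,)), key=jr.PRNGKey(0))
print(out)  # 2 * tanh(1.0) for each entry
```

One trade-off of this duck-typed dispatch is that a ``TypeError`` raised *inside* a key-accepting layer is also silently retried without the key.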


class Sequential(sk.TreeClass):
"""A sequential container for layers.
Expand All @@ -49,12 +73,7 @@ def __init__(self, *layers):
self.layers = layers

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
-         for key, layer in zip(jr.split(key, len(self.layers)), self.layers):
-             try:
-                 x = layer(x, key=key)
-             except TypeError:
-                 x = layer(x)
-         return x
+         return sequential(layers=self.layers, array=x, key=key)

    @ft.singledispatchmethod
    def __getitem__(self, key):
@@ -79,6 +98,24 @@ def __reversed__(self):
        return reversed(self.layers)


+ def random_apply(
+     layer,
+     array: jax.Array,
+     *,
+     rate: float,
+     key: jr.KeyArray,
+ ) -> jax.Array:
+     """Randomly applies a layer with probability ``rate``.
+
+     Args:
+         layer: layer to apply.
+         array: an array to apply the layer to.
+         rate: probability of applying the layer.
+         key: a random number generator key.
+     """
+     return layer(array) if jr.bernoulli(key, rate) else array
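A usage sketch with a stand-in layer: at ``rate=1.0`` the Bernoulli draw always succeeds, at ``0.0`` the input passes through untouched. Note that the draw feeds a Python ``if``, which is fine eagerly but would need ``lax.cond`` to trace under ``jit``.

```python
import jax.numpy as jnp
import jax.random as jr

from serket.nn.containers import random_apply  # path assumed from this diff

x = jnp.ones((3,))
# rate=1.0 makes jr.bernoulli return True, so the layer always fires
y = random_apply(jnp.negative, x, rate=1.0, key=jr.PRNGKey(0))
print(y)  # [-1. -1. -1.]
```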


@sk.autoinit
class RandomApply(sk.TreeClass):
"""Randomly applies a layer with probability ``rate``.
Expand Down Expand Up @@ -110,7 +147,24 @@ class RandomApply(sk.TreeClass):

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
        rate = jax.lax.stop_gradient(self.rate)
-         return self.layer(x) if jr.bernoulli(key, rate) else x
+         return random_apply(layer=self.layer, array=x, rate=rate, key=key)


+ def random_choice(
+     layers: tuple[Any, ...],
+     array: jax.Array,
+     *,
+     key: jr.KeyArray,
+ ) -> jax.Array:
+     """Randomly selects one of the given layers/functions.
+
+     Args:
+         layers: a tuple of layers/functions to select from.
+         array: an array to apply the selected layer to.
+         key: a random number generator key.
+     """
+     index = jr.randint(key, (), 0, len(layers))
+     return jax.lax.switch(index, layers, array)
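A usage sketch with two stand-in branches. Unlike ``random_apply``, the selection goes through ``jax.lax.switch``, so it stays traceable, provided every branch returns the same shape and dtype.

```python
import jax.numpy as jnp
import jax.random as jr

from serket.nn.containers import random_choice  # path assumed from this diff

branches = (jnp.sin, jnp.cos)  # branches must agree on output shape/dtype
y = random_choice(branches, jnp.zeros((2,)), key=jr.PRNGKey(42))
print(y)  # either sin(0) = [0. 0.] or cos(0) = [1. 1.]
```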


class RandomChoice(sk.TreeClass):
@@ -148,8 +202,7 @@ def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
-         index = jr.randint(key, (), 0, len(self.layers))
-         return jax.lax.switch(index, self.layers, x)
+         return random_choice(layers=self.layers, array=x, key=key)


@tree_eval.def_eval(RandomChoice)
14 changes: 7 additions & 7 deletions serket/nn/convolution.py
@@ -227,7 +227,7 @@ def convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
@@ -573,7 +573,7 @@ def fft_convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
@@ -964,7 +964,7 @@ def transposed_convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
@@ -1327,7 +1327,7 @@ def transposed_fft_convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
@@ -1724,7 +1724,7 @@ def depthwise_convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
@@ -2026,7 +2026,7 @@ def depthwise_fft_convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
@@ -2976,7 +2976,7 @@ def local_convolution_ndim(
    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
-         bias: bias. shape is (out_features, (1,)*spatial).
+         bias: bias. shape is (out_features, (1,)*spatial). set to ``None`` to not use a bias.
        strides: stride of the convolution accepts tuple of integers for different
            strides in each dimension.
        padding: padding of the input before convolution accepts tuple of integers
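All seven docstring updates in this file document the same convention: ``bias`` is shaped ``(out_features, *(1,) * len(spatial))`` so it broadcasts over the spatial axes, and passing ``None`` skips the addition entirely. A minimal sketch of that convention using ``jax.lax`` directly (shapes illustrative; this is not serket's internal implementation):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3, 32, 32))   # (in_features, *spatial)
w = jnp.ones((8, 3, 3, 3))  # (out_features, in_features, *kernel)
b = jnp.zeros((8, 1, 1))    # (out_features, 1, 1); or None to skip the bias

# lax expects a leading batch dimension, so add one and strip it afterwards
y = jax.lax.conv_general_dilated(x[None], w, window_strides=(1, 1), padding="SAME")[0]
if b is not None:  # the ``None``-bias convention described above
    y = y + b
print(y.shape)  # (8, 32, 32)
```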
