Commit 9b30708: add missing functional form for conv

ASEM000 committed Aug 28, 2023
1 parent f31aeec commit 9b30708

Showing 6 changed files with 415 additions and 101 deletions.
2 changes: 1 addition & 1 deletion docs/API/convolution.rst
@@ -26,7 +26,7 @@ Convolution
.. note::
    The ``fft`` convolution variant is useful in a myriad of cases; in particular, it can be faster for larger kernel sizes. The following figure compares the speed of both implementations for different kernel sizes on a Mac ``m1`` CPU setup.

-    .. image:: fft_bench.svg
+    .. image:: ../_static/fft_bench.svg
:width: 600
:align: center
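
A minimal speed-check sketch (editorial, not part of this commit): the
``sk.nn.Conv2D`` / ``sk.nn.FFTConv2D`` names and the constructor signature
``(in_features, out_features, kernel_size)`` are assumptions based on this
API page, so the calls are marked skipped.

>>> import jax.numpy as jnp
>>> import serket as sk
>>> conv = sk.nn.Conv2D(3, 3, 31)  # doctest: +SKIP
>>> fft_conv = sk.nn.FFTConv2D(3, 3, 31)  # doctest: +SKIP
>>> x = jnp.ones((3, 256, 256))
>>> fft_conv(x).shape  # timing both calls should favor the fft variant at this kernel size  # doctest: +SKIP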

File renamed without changes
117 changes: 103 additions & 14 deletions serket/nn/activation.py
@@ -27,6 +27,19 @@
T = TypeVar("T")


def adaptive_leaky_relu(
x: jax.typing.ArrayLike,
a: float = 1.0,
v: float = 1.0,
) -> jax.Array:
"""Adaptive Leaky ReLU activation function
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""
return jnp.maximum(0, a * x) - v * jnp.maximum(0, -a * x)


@sk.autoinit
class AdaptiveLeakyReLU(sk.TreeClass):
"""Leaky ReLU activation function
@@ -40,7 +53,16 @@ class AdaptiveLeakyReLU(sk.TreeClass):

def __call__(self, x: jax.Array) -> jax.Array:
v = jax.lax.stop_gradient(self.v)
-        return jnp.maximum(0, self.a * x) - v * jnp.maximum(0, -self.a * x)
+        return adaptive_leaky_relu(x, self.a, v)
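# Editorial sketch (not part of this diff): for a > 0 the new functional form
# scales positives by `a` and negatives by `v * a`, and should match the
# module form (whose fields appear to default to a = v = 1.0, as in the
# sibling classes below).
# >>> import jax.numpy as jnp
# >>> x = jnp.array([-2.0, 0.0, 2.0])
# >>> adaptive_leaky_relu(x, a=1.0, v=0.5)
# Array([-1.,  0.,  2.], dtype=float32)
# >>> jnp.allclose(adaptive_leaky_relu(x), AdaptiveLeakyReLU()(x))
# Array(True, dtype=bool)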


def adaptive_relu(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array:
"""Adaptive ReLU activation function
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""
return jnp.maximum(0, a * x)


@sk.autoinit
@@ -54,7 +76,16 @@ class AdaptiveReLU(sk.TreeClass):
a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array) -> jax.Array:
-        return jnp.maximum(0, self.a * x)
+        return adaptive_relu(x, self.a)
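# Editorial sketch (not part of this diff): a ReLU whose positive slope is
# the scalar `a`.
# >>> import jax.numpy as jnp
# >>> adaptive_relu(jnp.array([-1.0, 2.0]), a=2.0)
# Array([0., 4.], dtype=float32)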


def adaptive_sigmoid(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array:
"""Adaptive sigmoid activation function
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""
return 1 / (1 + jnp.exp(-a * x))


@sk.autoinit
@@ -68,7 +99,16 @@ class AdaptiveSigmoid(sk.TreeClass):
a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array) -> jax.Array:
-        return 1 / (1 + jnp.exp(-self.a * x))
+        return adaptive_sigmoid(x, self.a)
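# Editorial sketch (not part of this diff): `a` sharpens the transition
# around zero; the output at x = 0 is 0.5 for any `a`.
# >>> import jax.numpy as jnp
# >>> adaptive_sigmoid(jnp.array(0.0), a=10.0)
# Array(0.5, dtype=float32)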


def adaptive_tanh(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array:
"""Adaptive tanh activation function
Reference:
https://arxiv.org/pdf/1906.01170.pdf.
"""
return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x))


@sk.autoinit
@@ -83,7 +123,7 @@ class AdaptiveTanh(sk.TreeClass):

def __call__(self, x: jax.Array) -> jax.Array:
a = self.a
-        return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x))
+        return adaptive_tanh(x, a)
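# Editorial note (not part of this diff): the explicit exponential form is
# algebraically identical to jnp.tanh(a * x).
# >>> import jax.numpy as jnp
# >>> x = jnp.linspace(-2.0, 2.0, 5)
# >>> jnp.allclose(adaptive_tanh(x, a=2.0), jnp.tanh(2.0 * x))
# Array(True, dtype=bool)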


@sk.autoinit
@@ -222,6 +262,19 @@ def __call__(self, x: jax.Array) -> jax.Array:
return x / (1 + jnp.abs(x))


def softshrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
"""Soft shrink activation function
Reference:
https://arxiv.org/pdf/1702.00783.pdf.
"""
return jnp.where(
x < -alpha,
x + alpha,
jnp.where(x > alpha, x - alpha, 0.0),
)


@sk.autoinit
class SoftShrink(sk.TreeClass):
"""SoftShrink activation function"""
@@ -230,18 +283,23 @@ class SoftShrink(sk.TreeClass):

def __call__(self, x: jax.Array) -> jax.Array:
alpha = lax.stop_gradient(self.alpha)
-        return jnp.where(
-            x < -alpha,
-            x + alpha,
-            jnp.where(x > alpha, x - alpha, 0.0),
-        )
+        return softshrink(x, alpha)
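# Editorial sketch (not part of this diff): softshrink zeroes the band
# [-alpha, alpha] and shifts everything outside it toward zero by alpha.
# >>> import jax.numpy as jnp
# >>> softshrink(jnp.array([-1.0, 0.2, 1.0]), alpha=0.5)
# Array([-0.5,  0. ,  0.5], dtype=float32)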


def squareplus(x: jax.typing.ArrayLike) -> jax.Array:
"""SquarePlus activation function
Reference:
        https://arxiv.org/pdf/2112.11687.pdf.
"""
return 0.5 * (x + jnp.sqrt(x * x + 4))


class SquarePlus(sk.TreeClass):
"""SquarePlus activation function"""

def __call__(self, x: jax.Array) -> jax.Array:
-        return 0.5 * (x + jnp.sqrt(x * x + 4))
+        return squareplus(x)
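# Editorial sketch (not part of this diff): squareplus is a smooth,
# softplus-like rectifier; unlike softplus, squareplus(0) is exactly 1.
# >>> import jax.numpy as jnp
# >>> squareplus(jnp.array(0.0))
# Array(1., dtype=float32)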


class Swish(sk.TreeClass):
@@ -265,6 +323,15 @@ def __call__(self, x: jax.Array) -> jax.Array:
return x - jax.nn.tanh(x)


def thresholded_relu(x: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Array:
"""Thresholded ReLU activation function
Reference:
https://arxiv.org/pdf/1911.09737.pdf.
"""
return jnp.where(x > theta, x, 0)


@sk.autoinit
class ThresholdedReLU(sk.TreeClass):
"""Thresholded ReLU activation function."""
@@ -273,14 +340,24 @@ class ThresholdedReLU(sk.TreeClass):

def __call__(self, x: jax.Array) -> jax.Array:
theta = lax.stop_gradient(self.theta)
-        return jnp.where(x > theta, x, 0)
+        return thresholded_relu(x, theta)
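# Editorial sketch (not part of this diff): inputs at or below `theta` are
# zeroed; inputs above it pass through unchanged.
# >>> import jax.numpy as jnp
# >>> thresholded_relu(jnp.array([0.5, 1.5]), theta=1.0)
# Array([0. , 1.5], dtype=float32)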


def mish(x: jax.typing.ArrayLike) -> jax.Array:
"""Mish activation function https://arxiv.org/pdf/1908.08681.pdf."""
return x * jax.nn.tanh(jax.nn.softplus(x))


class Mish(sk.TreeClass):
"""Mish activation function https://arxiv.org/pdf/1908.08681.pdf."""

def __call__(self, x: jax.Array) -> jax.Array:
-        return x * jax.nn.tanh(jax.nn.softplus(x))
+        return mish(x)
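# Editorial sketch (not part of this diff): mish(x) = x * tanh(softplus(x)),
# which is zero at the origin and close to the identity for large positive x.
# >>> import jax.numpy as jnp
# >>> mish(jnp.array(0.0))
# Array(0., dtype=float32)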


def prelu(x: jax.typing.ArrayLike, a: float = 0.25) -> jax.Array:
"""Parametric ReLU activation function"""
return jnp.where(x >= 0, x, x * a)


@sk.autoinit
@@ -290,7 +367,19 @@ class PReLU(sk.TreeClass):
a: float = sk.field(default=0.25, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array) -> jax.Array:
-        return jnp.where(x >= 0, x, x * self.a)
+        return prelu(x, self.a)
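# Editorial sketch (not part of this diff): prelu keeps positive inputs and
# scales negative ones by `a`.
# >>> import jax.numpy as jnp
# >>> prelu(jnp.array([-1.0, 1.0]), a=0.25)
# Array([-0.25,  1.  ], dtype=float32)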


def snake(x: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array:
"""Snake activation function
Args:
a: scalar (frequency) parameter of the activation function, default is 1.0.
Reference:
https://arxiv.org/pdf/2006.08195.pdf.
"""
return x + (1 - jnp.cos(2 * a * x)) / (2 * a)


@sk.autoinit
@@ -308,7 +397,7 @@ class Snake(sk.TreeClass):

def __call__(self, x: jax.Array) -> jax.Array:
a = lax.stop_gradient(self.a)
-        return x + (1 - jnp.cos(2 * a * x)) / (2 * a)
+        return snake(x, a)
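# Editorial note (not part of this diff): via 1 - cos(2t) = 2 * sin(t) ** 2,
# snake(x, a) equals x + sin(a * x) ** 2 / a.
# >>> import jax.numpy as jnp
# >>> x = jnp.linspace(-3.0, 3.0, 7)
# >>> jnp.allclose(snake(x, a=2.0), x + jnp.sin(2.0 * x) ** 2 / 2.0)
# Array(True, dtype=bool)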


# useful for building layers from configuration text
3 changes: 1 addition & 2 deletions serket/nn/clustering.py
@@ -144,7 +144,7 @@ class KMeans(sk.TreeClass):
>>> labels, state = layer(x)
>>> plt.scatter(x[:, 0], x[:, 1], c=labels[:, 0], cmap="jet_r") # doctest: +SKIP
>>> plt.scatter(state.centers[:, 0], state.centers[:, 1], c="r", marker="o", linewidths=4) # doctest: +SKIP
.. image:: ../_static/kmeans.svg
:width: 600
:align: center
@@ -227,7 +227,6 @@ def __call__(
) -> tuple[jax.Array, KMeansState]:
distances = distances_from_centers(x, state.centers)
labels = labels_from_distances(distances)
-        state = state._replace(iters=None, error=None)
return labels, state


