more edits
ASEM000 committed Jul 14, 2023
1 parent d269138 commit dd20431
Showing 27 changed files with 2,459 additions and 1,661 deletions.
39 changes: 22 additions & 17 deletions pyproject.toml
@@ -6,27 +6,32 @@ build-backend = "setuptools.build_meta"
name = "serket"
dynamic = ["version"]
requires-python = ">=3.8"
-license = {text = "Apache-2.0"}
+license = { text = "Apache-2.0" }
description = "Functional neural network library in JAX"
-authors = [{name = "Mahmoud Asem", email = "[email protected]"}]
-keywords = ["jax", "neural-networks", "functional-programming", "machine-learning"]
-dependencies = ["jax>=0.4.7", "typing-extensions"]
+authors = [{ name = "Mahmoud Asem", email = "[email protected]" }]
+keywords = [
+    "jax",
+    "neural-networks",
+    "functional-programming",
+    "machine-learning",
+]
+dependencies = ["pytreeclass>=0.4.0"]

-classifiers=[
-    "Development Status :: 5 - Production/Stable",
-    "Environment :: Console",
-    "Intended Audience :: Science/Research",
-    "Intended Audience :: Developers",
-    "License :: OSI Approved :: Apache Software License",
-    "Operating System :: OS Independent",
-    "Programming Language :: Python",
-    "Programming Language :: Python :: 3",
-    "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    "Topic :: Software Development :: Libraries :: Python Modules",
+classifiers = [
+    "Development Status :: 5 - Production/Stable",
+    "Environment :: Console",
+    "Intended Audience :: Science/Research",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: Apache Software License",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Software Development :: Libraries :: Python Modules",
]

[tool.setuptools.dynamic]
-version = {attr = "serket.__version__" }
+version = { attr = "serket.__version__" }

[tool.setuptools.packages.find]
include = ["serket", "serket.*"]
@@ -39,5 +44,5 @@ Source = "https://github.com/ASEM000/Serket"
select = ["F", "E", "I"]
line-length = 120
ignore = [
"E731", # do not assign a lambda expression, use a def
"E731", # do not assign a lambda expression, use a def
]
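
A note on the `version = { attr = "serket.__version__" }` line above: because `version` is listed under `dynamic`, setuptools resolves it at build time by importing the named attribute. A minimal sketch of the package-side counterpart this assumes (the version string here is a placeholder, not taken from this commit):

    # serket/__init__.py (sketch) -- the attribute named in
    # [tool.setuptools.dynamic] must exist at the package root.
    __version__ = "0.0.0"  # placeholder value for illustration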
4 changes: 2 additions & 2 deletions serket/experimental/test_lazy_class.py
@@ -17,8 +17,8 @@
import jax
import jax.numpy as jnp
import pytest
-import pytreeclass as pytc

+import serket as sk
from serket.experimental import lazy_class


@@ -30,7 +30,7 @@ def test_lazy_class():
infer_method_name="__call__", # -> `infer_func` is applied to `__call__` method
lazy_marker=None, # -> `None` is used to indicate a lazy argument
)
-class LazyLinear(pytc.TreeClass):
+class LazyLinear(sk.TreeClass):
weight: jax.Array
bias: jax.Array

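
For readers unfamiliar with `lazy_class`: per the inline comments above, `infer_func` is applied to the method named by `infer_method_name`, and `lazy_marker=None` marks constructor arguments whose values are deferred until the first call. A hedged usage sketch, assuming `LazyLinear` takes an `in_features` argument that can be marked lazy with `None` (the exact constructor signature is cut off in this diff):

    import jax.numpy as jnp

    # Hypothetical usage: `in_features=None` defers weight/bias creation
    # until the first __call__, when the input's trailing dimension is known.
    layer = LazyLinear(in_features=None, out_features=4)
    x = jnp.ones((10, 3))
    y = layer(x)  # materializes parameters from x.shape[-1] == 3, then applies them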
92 changes: 46 additions & 46 deletions serket/nn/activation.py
@@ -18,198 +18,198 @@

import jax
import jax.numpy as jnp
-import pytreeclass as pytc
from jax import lax

+import serket as sk
from serket.nn.utils import IsInstance, Range, ScalarLike
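
The import swap above (`pytreeclass as pytc` out, `serket as sk` in) only works if serket re-exports the pytreeclass API it relies on. A one-line sketch of the assumed re-export (the actual export site would be in one of the 24 changed files not shown here):

    # serket/__init__.py (assumed) -- makes sk.TreeClass and sk.field available
    from pytreeclass import TreeClass, field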


-class AdaptiveLeakyReLU(pytc.TreeClass):
+class AdaptiveLeakyReLU(sk.TreeClass):
"""Leaky ReLU activation function
Note:
https://arxiv.org/pdf/1906.01170.pdf.
"""

-    a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])
-    v: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    v: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
v = jax.lax.stop_gradient(self.v)
return jnp.maximum(0, self.a * x) - v * jnp.maximum(0, -self.a * x)


-class AdaptiveReLU(pytc.TreeClass):
+class AdaptiveReLU(sk.TreeClass):
"""ReLU activation function with learnable parameters
Note:
https://arxiv.org/pdf/1906.01170.pdf.
"""

-    a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jnp.maximum(0, self.a * x)


-class AdaptiveSigmoid(pytc.TreeClass):
+class AdaptiveSigmoid(sk.TreeClass):
"""Sigmoid activation function with learnable `a` parameter
Note:
https://arxiv.org/pdf/1906.01170.pdf.
"""

-    a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return 1 / (1 + jnp.exp(-self.a * x))


-class AdaptiveTanh(pytc.TreeClass):
+class AdaptiveTanh(sk.TreeClass):
"""Tanh activation function with learnable parameters
Note:
https://arxiv.org/pdf/1906.01170.pdf.
"""

-    a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
a = self.a
return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x))


-class CeLU(pytc.TreeClass):
+class CeLU(sk.TreeClass):
"""Celu activation function"""

-    alpha: float = pytc.field(default=1.0, callbacks=[ScalarLike()])
+    alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.celu(x, alpha=lax.stop_gradient(self.alpha))


-class ELU(pytc.TreeClass):
+class ELU(sk.TreeClass):
"""Exponential linear unit"""

-    alpha: float = pytc.field(default=1.0, callbacks=[ScalarLike()])
+    alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.elu(x, alpha=lax.stop_gradient(self.alpha))


-class GELU(pytc.TreeClass):
+class GELU(sk.TreeClass):
"""Gaussian error linear unit"""

-    approximate: bool = pytc.field(default=1.0, callbacks=[IsInstance(bool)])
+    approximate: bool = sk.field(default=1.0, callbacks=[IsInstance(bool)])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.gelu(x, approximate=self.approximate)


-class GLU(pytc.TreeClass):
+class GLU(sk.TreeClass):
"""Gated linear unit"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.glu(x)


-class HardShrink(pytc.TreeClass):
+class HardShrink(sk.TreeClass):
"""Hard shrink activation function"""

-    alpha: float = pytc.field(default=0.5, callbacks=[Range(0), ScalarLike()])
+    alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
alpha = lax.stop_gradient(self.alpha)
return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0))


-class HardSigmoid(pytc.TreeClass):
+class HardSigmoid(sk.TreeClass):
"""Hard sigmoid activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.hard_sigmoid(x)


-class HardSwish(pytc.TreeClass):
+class HardSwish(sk.TreeClass):
"""Hard swish activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.hard_swish(x)


-class HardTanh(pytc.TreeClass):
+class HardTanh(sk.TreeClass):
"""Hard tanh activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.hard_tanh(x)


-class LogSigmoid(pytc.TreeClass):
+class LogSigmoid(sk.TreeClass):
"""Log sigmoid activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.log_sigmoid(x)


-class LogSoftmax(pytc.TreeClass):
+class LogSoftmax(sk.TreeClass):
"""Log softmax activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.log_softmax(x)


-class LeakyReLU(pytc.TreeClass):
+class LeakyReLU(sk.TreeClass):
"""Leaky ReLU activation function"""

-    negative_slope: float = pytc.field(default=0.01, callbacks=[Range(0), ScalarLike()])
+    negative_slope: float = sk.field(default=0.01, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.leaky_relu(x, lax.stop_gradient(self.negative_slope))


-class ReLU(pytc.TreeClass):
+class ReLU(sk.TreeClass):
"""ReLU activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.relu(x)


-class ReLU6(pytc.TreeClass):
+class ReLU6(sk.TreeClass):
"""ReLU6 activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.relu6(x)


-class SeLU(pytc.TreeClass):
+class SeLU(sk.TreeClass):
"""Scaled Exponential Linear Unit"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.selu(x)


-class Sigmoid(pytc.TreeClass):
+class Sigmoid(sk.TreeClass):
"""Sigmoid activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.sigmoid(x)


-class SoftPlus(pytc.TreeClass):
+class SoftPlus(sk.TreeClass):
"""SoftPlus activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.softplus(x)


-class SoftSign(pytc.TreeClass):
+class SoftSign(sk.TreeClass):
"""SoftSign activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return x / (1 + jnp.abs(x))


-class SoftShrink(pytc.TreeClass):
+class SoftShrink(sk.TreeClass):
"""SoftShrink activation function"""

-    alpha: float = pytc.field(default=0.5, callbacks=[Range(0), ScalarLike()])
+    alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
alpha = lax.stop_gradient(self.alpha)
@@ -220,61 +220,61 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
)


-class SquarePlus(pytc.TreeClass):
+class SquarePlus(sk.TreeClass):
"""SquarePlus activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return 0.5 * (x + jnp.sqrt(x * x + 4))


-class Swish(pytc.TreeClass):
+class Swish(sk.TreeClass):
"""Swish activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.swish(x)


-class Tanh(pytc.TreeClass):
+class Tanh(sk.TreeClass):
"""Tanh activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.tanh(x)


-class TanhShrink(pytc.TreeClass):
+class TanhShrink(sk.TreeClass):
"""TanhShrink activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return x - jax.nn.tanh(x)


-class ThresholdedReLU(pytc.TreeClass):
+class ThresholdedReLU(sk.TreeClass):
"""Thresholded ReLU activation function."""

-    theta: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    theta: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
theta = lax.stop_gradient(self.theta)
return jnp.where(x > theta, x, 0)


-class Mish(pytc.TreeClass):
+class Mish(sk.TreeClass):
"""Mish activation function https://arxiv.org/pdf/1908.08681.pdf."""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return x * jax.nn.tanh(jax.nn.softplus(x))


-class PReLU(pytc.TreeClass):
+class PReLU(sk.TreeClass):
"""Parametric ReLU activation function"""

-    a: float = pytc.field(default=0.25, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=0.25, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jnp.where(x >= 0, x, x * self.a)


-class Snake(pytc.TreeClass):
+class Snake(sk.TreeClass):
"""Snake activation function
Args:
@@ -284,7 +284,7 @@ class Snake(pytc.TreeClass):
https://arxiv.org/pdf/2006.08195.pdf.
"""

-    a: float = pytc.field(callbacks=[Range(0), ScalarLike()], default=1.0)
+    a: float = sk.field(callbacks=[Range(0), ScalarLike()], default=1.0)

def __call__(self, x: jax.Array, **k) -> jax.Array:
a = lax.stop_gradient(self.a)
@@ -360,7 +360,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
]


-act_map: dict[str, pytc.TreeClass] = dict(zip(get_args(ActivationLiteral), acts))
+act_map: dict[str, sk.TreeClass] = dict(zip(get_args(ActivationLiteral), acts))

ActivationFunctionType = Callable[[jax.typing.ArrayLike], jax.Array]
ActivationType = Union[ActivationLiteral, ActivationFunctionType]
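
The `act_map` table pairs each string key of `ActivationLiteral` with the matching class from `acts`, and `ActivationType` accepts either a string or a callable. A minimal sketch of the resolution logic this typing implies (`resolve_activation` is a hypothetical name, not from this diff):

    def resolve_activation(act: ActivationType) -> ActivationFunctionType:
        # Hypothetical helper: look up strings in act_map and instantiate
        # the class; pass callables through unchanged.
        if isinstance(act, str):
            return act_map[act]()
        return act

    relu = resolve_activation("relu")  # assumes "relu" is one of the literal keys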
(diff for the remaining 24 changed files not shown)
