Commit

More refactoring and removals
- Remove operators
- Remove `Laplace2D`
- Some refactoring
- Fix dropout
- Update API to PyTreeClass 0.4 (see the masking sketch after the changed-files summary below)
ASEM000 committed Jul 9, 2023
1 parent 2722229 commit ec6262b
Showing 19 changed files with 226 additions and 418 deletions.
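For context on the "Update API to PyTreeClass 0.4" item: `serket/__init__.py` now re-exports the PyTreeClass utilities (`tree_mask`, `tree_unmask`, `freeze`, `AtIndexer`, ...). A minimal sketch of the masking workflow through these re-exports, assuming they behave like their `pytreeclass` 0.4 counterparts; the `params` dict and `loss` function below are illustrative only:

```python
import jax
import jax.numpy as jnp
import serket as sk

# a pytree mixing a differentiable array with a non-differentiable string leaf
params = {"weight": jnp.ones(3), "label": "layer0"}

# mask non-differentiable leaves so `jax.grad` treats them as static
# (assumption: `tree_mask` defaults to masking non-differentiable leaves)
masked = sk.tree_mask(params)


def loss(tree, x):
    tree = sk.tree_unmask(tree)  # unmask inside the traced function
    return jnp.sum(tree["weight"] * x)


grads = jax.grad(loss)(masked, jnp.arange(3.0))
```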
44 changes: 33 additions & 11 deletions serket/__init__.py
@@ -12,48 +12,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pytreeclass import (
TreeClass,
field,
from pytreeclass._src.code_build import field, fields
from pytreeclass._src.tree_base import TreeClass
from pytreeclass._src.tree_index import AtIndexer, BaseKey
from pytreeclass._src.tree_mask import (
freeze,
is_frozen,
is_nondiff,
is_tree_equal,
tree_mask,
tree_unmask,
unfreeze,
)
from pytreeclass._src.tree_pprint import (
tree_diagram,
tree_indent,
tree_mermaid,
tree_repr,
tree_repr_with_trace,
tree_str,
tree_summary,
unfreeze,
)
from pytreeclass._src.tree_util import (
Partial,
bcmap,
is_tree_equal,
tree_flatten_with_trace,
tree_leaves_with_trace,
tree_map_with_trace,
)

from . import nn
from .operators import diff, value_and_diff

__all__ = (
# general utils
"TreeClass",
"is_tree_equal",
"field",
"fields",
# pprint utils
"tree_diagram",
"tree_mermaid",
"tree_repr",
"tree_str",
"tree_indent",
"tree_summary",
"tree_trace_summary",
# freeze/unfreeze utils
# masking utils
"is_nondiff",
"is_frozen",
"freeze",
"unfreeze",
"tree_unmask",
"tree_mask",
# indexing utils
"AtIndexer",
"BaseKey",
# tree utils
"bcmap",
"tree_map_with_trace",
"tree_leaves_with_trace",
"tree_flatten_with_trace",
"tree_repr_with_trace",
"Partial",
# serket
"nn",
"diff",
"value_and_diff",
)


__version__ = "0.2.0b6"
__version__ = "0.2.0b7"
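The diff above also exports `AtIndexer` and `BaseKey` for path-based, out-of-place updates. A hedged sketch, assuming the indexer follows the `pytreeclass` 0.4 `AtIndexer` API; the `tree` below is illustrative:

```python
import jax.numpy as jnp
import serket as sk

tree = {"weight": jnp.ones((2, 2)), "bias": jnp.zeros(2)}

# build a new tree with `bias` replaced; the original is left untouched
updated = sk.AtIndexer(tree)["bias"].set(jnp.full(2, 0.5))
assert (tree["bias"] == 0).all()
```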
5 changes: 1 addition & 4 deletions serket/nn/__init__.py
@@ -47,7 +47,7 @@
)
from .blocks import UNetBlock, VGG16Block, VGG19Block
from .blur import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from .containers import Lambda, Sequential
from .containers import Sequential
from .contrast import AdjustContrast2D, RandomContrast2D
from .convolution import (
Conv1D,
@@ -86,7 +86,6 @@
from .flatten import Flatten, Unflatten
from .flip import FlipLeftRight2D, FlipUpDown2D
from .fully_connected import FNN, MLP
from .laplace import Laplace2D
from .linear import (
Bilinear,
Embedding,
@@ -158,7 +157,6 @@
"Dropout3D",
# containers
"Sequential",
"Lambda",
# Pooling
"MaxPool1D",
"MaxPool2D",
@@ -223,7 +221,6 @@
"Filter2D",
"FFTFilter2D",
# Resize
"Laplace2D",
"FlipLeftRight2D",
"FlipUpDown2D",
"Resize1D",
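`Laplace2D` is removed from the public API above. As a sketch only (not the removed layer's actual implementation), a similar effect could be recovered with `Filter2D`, whose constructor appears in the `blur.py` diff further down, assuming `Laplace2D` applied a fixed 3x3 Laplacian kernel per channel:

```python
import jax.numpy as jnp
import serket as sk

# standard 3x3 Laplacian kernel (assumption: matches what Laplace2D used)
laplacian = jnp.array(
    [[0.0, 1.0, 0.0],
     [1.0, -4.0, 1.0],
     [0.0, 1.0, 0.0]]
)

layer = sk.nn.Filter2D(in_features=1, kernel=laplacian)
out = layer(jnp.ones((1, 5, 5)))  # zero response in the interior, nonzero at the zero-padded borders
```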
108 changes: 33 additions & 75 deletions serket/nn/activation.py
@@ -20,61 +20,8 @@
import jax.numpy as jnp
import pytreeclass as pytc
from jax import lax
from pytreeclass import TreeClass

from serket.nn.utils import Range, ScalarLike


def adaptive_leaky_relu(x: jax.Array, a: float = 1.0, v: float = 1.0) -> jax.Array:
return jnp.maximum(0, a * x) - v * jnp.maximum(0, -a * x)


def adaptive_relu(x: jax.Array, a: float = 1.0) -> jax.Array:
return jnp.maximum(0, a * x)


def adaptive_sigmoid(x: jax.Array, a: float = 1.0) -> jax.Array:
return 1 / (1 + jnp.exp(-a * x))


def adaptive_tanh(x: jax.Array, a: float = 1.0) -> jax.Array:
return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x))


def hard_shrink(x: jax.Array, alpha: float = 0.5) -> jax.Array:
return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0))


def parametric_relu(x: jax.Array, a: float = 0.25) -> jax.Array:
return jnp.where(x >= 0, x, x * a)


def soft_shrink(x: jax.Array, alpha: float = 0.5) -> jax.Array:
return jnp.where(
x < -alpha,
x + alpha,
jnp.where(x > alpha, x - alpha, 0.0),
)


def square_plus(x: jax.Array) -> jax.Array:
return 0.5 * (x + jnp.sqrt(x * x + 4))


def soft_sign(x: jax.Array) -> jax.Array:
return x / (1 + jnp.abs(x))


def thresholded_relu(x: jax.Array, theta: float = 1.0) -> jax.Array:
return jnp.where(x > theta, x, 0)


def mish(x: jax.Array) -> jax.Array:
return x * jax.nn.tanh(jax.nn.softplus(x))


def snake(x: jax.Array, frequency: float = 1.0) -> jax.Array:
return x + (1 - jnp.cos(2 * frequency * x)) / (2 * frequency)
from serket.nn.utils import IsInstance, Range, ScalarLike


class AdaptiveLeakyReLU(pytc.TreeClass):
@@ -87,7 +34,8 @@ class AdaptiveLeakyReLU(pytc.TreeClass):
v: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return adaptive_leaky_relu(x, self.a, lax.stop_gradient(self.v))
v = jax.lax.stop_gradient(self.v)
return jnp.maximum(0, self.a * x) - v * jnp.maximum(0, -self.a * x)


class AdaptiveReLU(pytc.TreeClass):
@@ -99,7 +47,7 @@ class AdaptiveReLU(pytc.TreeClass):
a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return adaptive_relu(x, self.a)
return jnp.maximum(0, self.a * x)


class AdaptiveSigmoid(pytc.TreeClass):
@@ -111,7 +59,7 @@ class AdaptiveSigmoid(pytc.TreeClass):
a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return adaptive_sigmoid(x, self.a)
return 1 / (1 + jnp.exp(-self.a * x))


class AdaptiveTanh(pytc.TreeClass):
@@ -123,13 +71,14 @@ class AdaptiveTanh(pytc.TreeClass):
a: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return adaptive_tanh(x, self.a)
a = self.a
return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x))


class CeLU(pytc.TreeClass):
"""Celu activation function"""

alpha: float = 1.0
alpha: float = pytc.field(default=1.0, callbacks=[ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.celu(x, alpha=lax.stop_gradient(self.alpha))
@@ -138,7 +87,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
class ELU(pytc.TreeClass):
"""Exponential linear unit"""

alpha: float = 1.0
alpha: float = pytc.field(default=1.0, callbacks=[ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.elu(x, alpha=lax.stop_gradient(self.alpha))
@@ -147,7 +96,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
class GELU(pytc.TreeClass):
"""Gaussian error linear unit"""

approximate: bool = True
approximate: bool = pytc.field(default=True, callbacks=[IsInstance(bool)])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.gelu(x, approximate=self.approximate)
@@ -163,10 +112,11 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
class HardShrink(pytc.TreeClass):
"""Hard shrink activation function"""

alpha: float = 0.5
alpha: float = pytc.field(default=0.5, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return hard_shrink(x, lax.stop_gradient(self.alpha))
alpha = lax.stop_gradient(self.alpha)
return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0))


class HardSigmoid(pytc.TreeClass):
@@ -207,7 +157,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
class LeakyReLU(pytc.TreeClass):
"""Leaky ReLU activation function"""

negative_slope: float = 0.01
negative_slope: float = pytc.field(default=0.01, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.leaky_relu(x, lax.stop_gradient(self.negative_slope))
@@ -252,23 +202,28 @@ class SoftSign(pytc.TreeClass):
"""SoftSign activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return soft_sign(x)
return x / (1 + jnp.abs(x))


class SoftShrink(pytc.TreeClass):
"""SoftShrink activation function"""

alpha: float = 0.5
alpha: float = pytc.field(default=0.5, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return soft_shrink(x, lax.stop_gradient(self.alpha))
alpha = lax.stop_gradient(self.alpha)
return jnp.where(
x < -alpha,
x + alpha,
jnp.where(x > alpha, x - alpha, 0.0),
)


class SquarePlus(pytc.TreeClass):
"""SquarePlus activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return square_plus(x)
return 0.5 * (x + jnp.sqrt(x * x + 4))


class Swish(pytc.TreeClass):
@@ -295,26 +250,27 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
class ThresholdedReLU(pytc.TreeClass):
"""Thresholded ReLU activation function."""

theta: float = pytc.field(callbacks=[Range(0), ScalarLike()])
theta: float = pytc.field(default=1.0, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return thresholded_relu(x, lax.stop_gradient(self.theta))
theta = lax.stop_gradient(self.theta)
return jnp.where(x > theta, x, 0)


class Mish(pytc.TreeClass):
"""Mish activation function https://arxiv.org/pdf/1908.08681.pdf."""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return mish(x)
return x * jax.nn.tanh(jax.nn.softplus(x))


class PReLU(pytc.TreeClass):
"""Parametric ReLU activation function"""

a: float = 0.25
a: float = pytc.field(default=0.25, callbacks=[Range(0), ScalarLike()])

def __call__(self, x: jax.Array, **k) -> jax.Array:
return parametric_relu(x, self.a)
return jnp.where(x >= 0, x, x * self.a)


class Snake(pytc.TreeClass):
@@ -330,7 +286,8 @@ class Snake(pytc.TreeClass):
a: float = pytc.field(callbacks=[Range(0), ScalarLike()], default=1.0)

def __call__(self, x: jax.Array, **k) -> jax.Array:
return snake(x, lax.stop_gradient(self.a))
a = lax.stop_gradient(self.a)
return x + (1 - jnp.cos(2 * a * x)) / (2 * a)


# useful for building layers from configuration text
@@ -402,7 +359,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
]


act_map: dict[str, TreeClass] = dict(zip(get_args(ActivationLiteral), acts))
act_map: dict[str, pytc.TreeClass] = dict(zip(get_args(ActivationLiteral), acts))

ActivationFunctionType = Callable[[jax.typing.ArrayLike], jax.Array]
ActivationType = Union[ActivationLiteral, ActivationFunctionType]
@@ -414,6 +371,7 @@ def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
if isinstance(act_func, str):
if act_func in act_map:
return act_map[act_func]()

raise ValueError(
f"Unknown activation function {act_func=}, "
f"available activations are {list(act_map.keys())}"
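The activation refactor above inlines each formula into `__call__` and keeps the string-to-layer mapping. A short sketch, importing directly from `serket.nn.activation` (assumption: `"relu"` is one of the registered `ActivationLiteral` names):

```python
import jax.numpy as jnp
from serket.nn.activation import Snake, resolve_activation

x = jnp.linspace(-2.0, 2.0, 5)

# the snake formula now lives in `Snake.__call__`: x + (1 - cos(2*a*x)) / (2*a)
y = Snake(a=1.0)(x)

# strings still resolve to layer instances via `act_map`
relu = resolve_activation("relu")
z = relu(x)
```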
4 changes: 0 additions & 4 deletions serket/nn/blur.py
@@ -101,9 +101,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
"""

self.in_features = positive_int_cb(in_features)
self.kernel_size = positive_int_cb(kernel_size)

@@ -162,7 +160,6 @@ def __init__(self, in_features: int, kernel: jax.Array):
[6. 9. 9. 9. 6.]
[6. 9. 9. 9. 6.]
[4. 6. 6. 6. 4.]]]
"""
if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
@@ -207,7 +204,6 @@ def __init__(self, in_features: int, kernel: jax.Array):
[6.0000005 9. 9. 9. 6.0000005]
[4. 6.0000005 6.0000005 6.0000005 4. ]]]
"""

if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

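The `GaussianBlur2D` doctest above (corner values ≈ 0.527, edges ≈ 0.726) looks consistent with a 3x3 Gaussian kernel with `sigma=1.0` applied to an all-ones single-channel map. A usage sketch under that assumption:

```python
import jax.numpy as jnp
import serket as sk

layer = sk.nn.GaussianBlur2D(in_features=1, kernel_size=3, sigma=1.0)
out = layer(jnp.ones((1, 5, 5)))  # shape (1, 5, 5); interior stays ~1.0, borders are attenuated
```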