From 1e884d26f5ee999c37ba56ec11e4e4901a3375d2 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Wed, 6 Sep 2023 20:03:42 +0900 Subject: [PATCH] pytc 0.8.0 bump --- pyproject.toml | 2 +- serket/image/augment.py | 4 ++-- serket/image/geometric.py | 4 ++-- serket/nn/activation.py | 28 ++++++++++++++-------------- serket/nn/clustering.py | 4 ++-- serket/nn/containers.py | 2 +- serket/nn/dropout.py | 4 ++-- serket/nn/normalization.py | 5 +++-- serket/nn/reshape.py | 8 ++++---- 9 files changed, 31 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 712fbe3..19a9d6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ keywords = [ "functional-programming", "machine-learning", ] -dependencies = ["pytreeclass>=0.6.0", "kernex>=0.2.0"] +dependencies = ["pytreeclass>=0.8.0", "kernex>=0.2.0"] classifiers = [ "Development Status :: 5 - Production/Stable", diff --git a/serket/image/augment.py b/serket/image/augment.py index 1eff3ad..d02c897 100644 --- a/serket/image/augment.py +++ b/serket/image/augment.py @@ -311,7 +311,7 @@ class Posterize2D(sk.TreeClass): - https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py#L547 """ - bits: int = sk.field(callbacks=[IsInstance(int), Range(1, 8)]) + bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)]) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: @@ -401,7 +401,7 @@ class JigSaw2D(sk.TreeClass): - https://imgaug.readthedocs.io/en/latest/source/overview/geometric.html#jigsaw """ - tiles: int = sk.field(callbacks=[IsInstance(int), Range(1)]) + tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)]) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array: diff --git a/serket/image/geometric.py b/serket/image/geometric.py index a77e044..c7bf5ce 100644 --- a/serket/image/geometric.py +++ b/serket/image/geometric.py @@ -593,7 +593,7 @@ class HorizontalTranslate2D(sk.TreeClass): [ 0 0 21 22 23]]] """ - shift: int = sk.field(callbacks=[IsInstance(int)]) + shift: int = sk.field(on_setattr=[IsInstance(int)]) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: @@ -625,7 +625,7 @@ class VerticalTranslate2D(sk.TreeClass): [11 12 13 14 15]]] """ - shift: int = sk.field(callbacks=[IsInstance(int)]) + shift: int = sk.field(on_setattr=[IsInstance(int)]) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: diff --git a/serket/nn/activation.py b/serket/nn/activation.py index 5feb9fa..0d5421d 100644 --- a/serket/nn/activation.py +++ b/serket/nn/activation.py @@ -48,8 +48,8 @@ class AdaptiveLeakyReLU(sk.TreeClass): https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) - v: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) + v: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: v = jax.lax.stop_gradient(self.v) @@ -73,7 +73,7 @@ class AdaptiveReLU(sk.TreeClass): https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: return adaptive_relu(x, self.a) @@ -96,7 +96,7 @@ class AdaptiveSigmoid(sk.TreeClass): https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: return adaptive_sigmoid(x, self.a) @@ -119,7 +119,7 @@ class AdaptiveTanh(sk.TreeClass): https://arxiv.org/pdf/1906.01170.pdf. """ - a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: a = self.a @@ -130,7 +130,7 @@ def __call__(self, x: jax.Array) -> jax.Array: class CeLU(sk.TreeClass): """Celu activation function""" - alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()]) + alpha: float = sk.field(default=1.0, on_setattr=[ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: return jax.nn.celu(x, alpha=lax.stop_gradient(self.alpha)) @@ -140,7 +140,7 @@ def __call__(self, x: jax.Array) -> jax.Array: class ELU(sk.TreeClass): """Exponential linear unit""" - alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()]) + alpha: float = sk.field(default=1.0, on_setattr=[ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: return jax.nn.elu(x, alpha=lax.stop_gradient(self.alpha)) @@ -150,7 +150,7 @@ def __call__(self, x: jax.Array) -> jax.Array: class GELU(sk.TreeClass): """Gaussian error linear unit""" - approximate: bool = sk.field(default=False, callbacks=[IsInstance(bool)]) + approximate: bool = sk.field(default=False, on_setattr=[IsInstance(bool)]) def __call__(self, x: jax.Array) -> jax.Array: return jax.nn.gelu(x, approximate=self.approximate) @@ -177,7 +177,7 @@ def hard_shrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: class HardShrink(sk.TreeClass): """Hard shrink activation function""" - alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()]) + alpha: float = sk.field(default=0.5, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: alpha = lax.stop_gradient(self.alpha) @@ -223,7 +223,7 @@ def __call__(self, x: jax.Array) -> jax.Array: class LeakyReLU(sk.TreeClass): """Leaky ReLU activation function""" - negative_slope: float = sk.field(default=0.01, callbacks=[Range(0), ScalarLike()]) + negative_slope: float = sk.field(default=0.01, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: return jax.nn.leaky_relu(x, lax.stop_gradient(self.negative_slope)) @@ -293,7 +293,7 @@ def softshrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: class SoftShrink(sk.TreeClass): """SoftShrink activation function""" - alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()]) + alpha: float = sk.field(default=0.5, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: alpha = lax.stop_gradient(self.alpha) @@ -355,7 +355,7 @@ def thresholded_relu(x: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Array: class ThresholdedReLU(sk.TreeClass): """Thresholded ReLU activation function.""" - theta: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()]) + theta: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: theta = lax.stop_gradient(self.theta) @@ -383,7 +383,7 @@ def prelu(x: jax.typing.ArrayLike, a: float = 0.25) -> jax.Array: class PReLU(sk.TreeClass): """Parametric ReLU activation function""" - a: float = sk.field(default=0.25, callbacks=[Range(0), ScalarLike()]) + a: float = sk.field(default=0.25, on_setattr=[Range(0), ScalarLike()]) def __call__(self, x: jax.Array) -> jax.Array: return prelu(x, self.a) @@ -412,7 +412,7 @@ class Snake(sk.TreeClass): https://arxiv.org/pdf/2006.08195.pdf. """ - a: float = sk.field(callbacks=[Range(0), ScalarLike()], default=1.0) + a: float = sk.field(on_setattr=[Range(0), ScalarLike()], default=1.0) def __call__(self, x: jax.Array) -> jax.Array: a = lax.stop_gradient(self.a) diff --git a/serket/nn/clustering.py b/serket/nn/clustering.py index 81c7204..3b3f3f5 100644 --- a/serket/nn/clustering.py +++ b/serket/nn/clustering.py @@ -180,8 +180,8 @@ class KMeans(sk.TreeClass): >>> assert jnp.all(eval_state.centers == state.centers) """ - clusters: int = sk.field(callbacks=[IsInstance(int), Range(1)]) - tol: float = sk.field(callbacks=[IsInstance(float), Range(0, min_inclusive=False)]) + clusters: int = sk.field(on_setattr=[IsInstance(int), Range(1)]) + tol: float = sk.field(on_setattr=[IsInstance(float), Range(0, min_inclusive=False)]) def __call__( self, diff --git a/serket/nn/containers.py b/serket/nn/containers.py index 5a304bc..3946fd7 100644 --- a/serket/nn/containers.py +++ b/serket/nn/containers.py @@ -143,7 +143,7 @@ class RandomApply(sk.TreeClass): """ layer: Any - rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)]) + rate: float = sk.field(default=0.5, on_setattr=[Range(0, 1)]) def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)): rate = jax.lax.stop_gradient(self.rate) diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py index fb0e1a6..43afe80 100644 --- a/serket/nn/dropout.py +++ b/serket/nn/dropout.py @@ -59,7 +59,7 @@ class GeneralDropout(sk.TreeClass): dropout is applied to all axes. """ - drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)]) + drop_rate: float = sk.field(default=0.5, on_setattr=[Range(0, 1)]) drop_axes: tuple[int, ...] | Literal["..."] = ... def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)): @@ -109,7 +109,7 @@ def __init__(self, drop_rate: float = 0.5): @sk.autoinit class DropoutND(sk.TreeClass): - drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)]) + drop_rate: float = sk.field(default=0.5, on_setattr=[Range(0, 1)]) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x, *, key=jr.PRNGKey(0)): diff --git a/serket/nn/normalization.py b/serket/nn/normalization.py index 8212686..542ffbe 100644 --- a/serket/nn/normalization.py +++ b/serket/nn/normalization.py @@ -131,7 +131,7 @@ class LayerNorm(sk.TreeClass): https://nn.labml.ai/normalization/layer_norm/index.html """ - eps: float = sk.field(callbacks=[Range(0, min_inclusive=False), ScalarLike()]) + eps: float = sk.field(on_setattr=[Range(0, min_inclusive=False), ScalarLike()]) @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( @@ -225,6 +225,7 @@ class GroupNorm(sk.TreeClass): Reference: https://nn.labml.ai/normalization/group_norm/index.html """ + eps: float = sk.field(on_setattr=[Range(0), ScalarLike()]) @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( @@ -240,7 +241,7 @@ def __init__( ): self.in_features = positive_int_cb(in_features) self.groups = positive_int_cb(groups) - self.eps = sk.field(callbacks=[Range(0), ScalarLike()])(eps) + self.eps = eps # needs more info for checking if in_features % groups != 0: diff --git a/serket/nn/reshape.py b/serket/nn/reshape.py index 6df7f06..5dacfd1 100644 --- a/serket/nn/reshape.py +++ b/serket/nn/reshape.py @@ -282,8 +282,8 @@ class Flatten(sk.TreeClass): https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html?highlight=flatten#torch.nn.Flatten """ - start_dim: int = sk.field(default=0, callbacks=[IsInstance(int)]) - end_dim: int = sk.field(default=-1, callbacks=[IsInstance(int)]) + start_dim: int = sk.field(default=0, on_setattr=[IsInstance(int)]) + end_dim: int = sk.field(default=-1, on_setattr=[IsInstance(int)]) def __call__(self, x: jax.Array) -> jax.Array: start_dim = self.start_dim + (0 if self.start_dim >= 0 else x.ndim) @@ -311,8 +311,8 @@ class Unflatten(sk.TreeClass): - https://pytorch.org/docs/stable/generated/torch.nn.Unflatten.html?highlight=unflatten """ - dim: int = sk.field(default=0, callbacks=[IsInstance(int)]) - shape: tuple = sk.field(default=None, callbacks=[IsInstance(tuple)]) + dim: int = sk.field(default=0, on_setattr=[IsInstance(int)]) + shape: tuple = sk.field(default=None, on_setattr=[IsInstance(tuple)]) def __call__(self, x: jax.Array, **k) -> jax.Array: shape = list(x.shape)