pytc 0.8.0 bump
ASEM000 committed Sep 6, 2023
1 parent 800109e commit 1e884d2
Showing 9 changed files with 31 additions and 30 deletions.
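
This bump tracks the pytreeclass 0.8 API change: the `callbacks` argument of `field` is now spelled `on_setattr`, and the listed callables run whenever the attribute is assigned. A minimal sketch of the new spelling, assuming pytreeclass >= 0.8 is installed; `positive_int` and `Posterize2DLike` are hypothetical stand-ins for serket's internal `IsInstance`/`Range` callbacks and the real `Posterize2D` class:

    import pytreeclass as pytc

    def positive_int(value):
        # hypothetical stand-in for IsInstance(int) + Range(1)
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"expected a positive int, got {value!r}")
        return value

    @pytc.autoinit
    class Posterize2DLike(pytc.TreeClass):
        # pytreeclass >= 0.8: validators go in `on_setattr` (formerly `callbacks`)
        bits: int = pytc.field(on_setattr=[positive_int])

With this, `Posterize2DLike(bits=0)` raises at construction time rather than deferring the error to call time.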
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ keywords = [
     "functional-programming",
     "machine-learning",
 ]
-dependencies = ["pytreeclass>=0.6.0", "kernex>=0.2.0"]
+dependencies = ["pytreeclass>=0.8.0", "kernex>=0.2.0"]
 
 classifiers = [
     "Development Status :: 5 - Production/Stable",
4 changes: 2 additions & 2 deletions serket/image/augment.py
@@ -311,7 +311,7 @@ class Posterize2D(sk.TreeClass):
         - https://github.com/python-pillow/Pillow/blob/main/src/PIL/ImageOps.py#L547
     """
 
-    bits: int = sk.field(callbacks=[IsInstance(int), Range(1, 8)])
+    bits: int = sk.field(on_setattr=[IsInstance(int), Range(1, 8)])
 
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x: jax.Array) -> jax.Array:
@@ -401,7 +401,7 @@ class JigSaw2D(sk.TreeClass):
         - https://imgaug.readthedocs.io/en/latest/source/overview/geometric.html#jigsaw
     """
 
-    tiles: int = sk.field(callbacks=[IsInstance(int), Range(1)])
+    tiles: int = sk.field(on_setattr=[IsInstance(int), Range(1)])
 
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
4 changes: 2 additions & 2 deletions serket/image/geometric.py
@@ -593,7 +593,7 @@ class HorizontalTranslate2D(sk.TreeClass):
         [ 0  0 21 22 23]]]
     """
 
-    shift: int = sk.field(callbacks=[IsInstance(int)])
+    shift: int = sk.field(on_setattr=[IsInstance(int)])
 
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x: jax.Array) -> jax.Array:
@@ -625,7 +625,7 @@ class VerticalTranslate2D(sk.TreeClass):
         [11 12 13 14 15]]]
     """
 
-    shift: int = sk.field(callbacks=[IsInstance(int)])
+    shift: int = sk.field(on_setattr=[IsInstance(int)])
 
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x: jax.Array) -> jax.Array:
28 changes: 14 additions & 14 deletions serket/nn/activation.py
@@ -48,8 +48,8 @@ class AdaptiveLeakyReLU(sk.TreeClass):
         https://arxiv.org/pdf/1906.01170.pdf.
     """
 
-    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
-    v: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()])
+    v: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         v = jax.lax.stop_gradient(self.v)
@@ -73,7 +73,7 @@ class AdaptiveReLU(sk.TreeClass):
         https://arxiv.org/pdf/1906.01170.pdf.
     """
 
-    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return adaptive_relu(x, self.a)
@@ -96,7 +96,7 @@ class AdaptiveSigmoid(sk.TreeClass):
         https://arxiv.org/pdf/1906.01170.pdf.
     """
 
-    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return adaptive_sigmoid(x, self.a)
@@ -119,7 +119,7 @@ class AdaptiveTanh(sk.TreeClass):
         https://arxiv.org/pdf/1906.01170.pdf.
     """
 
-    a: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         a = self.a
@@ -130,7 +130,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
 class CeLU(sk.TreeClass):
     """Celu activation function"""
 
-    alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()])
+    alpha: float = sk.field(default=1.0, on_setattr=[ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return jax.nn.celu(x, alpha=lax.stop_gradient(self.alpha))
@@ -140,7 +140,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
 class ELU(sk.TreeClass):
     """Exponential linear unit"""
 
-    alpha: float = sk.field(default=1.0, callbacks=[ScalarLike()])
+    alpha: float = sk.field(default=1.0, on_setattr=[ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return jax.nn.elu(x, alpha=lax.stop_gradient(self.alpha))
@@ -150,7 +150,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
 class GELU(sk.TreeClass):
     """Gaussian error linear unit"""
 
-    approximate: bool = sk.field(default=False, callbacks=[IsInstance(bool)])
+    approximate: bool = sk.field(default=False, on_setattr=[IsInstance(bool)])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return jax.nn.gelu(x, approximate=self.approximate)
@@ -177,7 +177,7 @@ def hard_shrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
 class HardShrink(sk.TreeClass):
     """Hard shrink activation function"""
 
-    alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()])
+    alpha: float = sk.field(default=0.5, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         alpha = lax.stop_gradient(self.alpha)
@@ -223,7 +223,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
 class LeakyReLU(sk.TreeClass):
     """Leaky ReLU activation function"""
 
-    negative_slope: float = sk.field(default=0.01, callbacks=[Range(0), ScalarLike()])
+    negative_slope: float = sk.field(default=0.01, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return jax.nn.leaky_relu(x, lax.stop_gradient(self.negative_slope))
@@ -293,7 +293,7 @@ def softshrink(x: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
 class SoftShrink(sk.TreeClass):
     """SoftShrink activation function"""
 
-    alpha: float = sk.field(default=0.5, callbacks=[Range(0), ScalarLike()])
+    alpha: float = sk.field(default=0.5, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         alpha = lax.stop_gradient(self.alpha)
@@ -355,7 +355,7 @@ def thresholded_relu(x: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Array:
 class ThresholdedReLU(sk.TreeClass):
     """Thresholded ReLU activation function."""
 
-    theta: float = sk.field(default=1.0, callbacks=[Range(0), ScalarLike()])
+    theta: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         theta = lax.stop_gradient(self.theta)
@@ -383,7 +383,7 @@ def prelu(x: jax.typing.ArrayLike, a: float = 0.25) -> jax.Array:
 class PReLU(sk.TreeClass):
     """Parametric ReLU activation function"""
 
-    a: float = sk.field(default=0.25, callbacks=[Range(0), ScalarLike()])
+    a: float = sk.field(default=0.25, on_setattr=[Range(0), ScalarLike()])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         return prelu(x, self.a)
@@ -412,7 +412,7 @@ class Snake(sk.TreeClass):
         https://arxiv.org/pdf/2006.08195.pdf.
     """
 
-    a: float = sk.field(callbacks=[Range(0), ScalarLike()], default=1.0)
+    a: float = sk.field(on_setattr=[Range(0), ScalarLike()], default=1.0)
 
     def __call__(self, x: jax.Array) -> jax.Array:
         a = lax.stop_gradient(self.a)
4 changes: 2 additions & 2 deletions serket/nn/clustering.py
@@ -180,8 +180,8 @@ class KMeans(sk.TreeClass):
         >>> assert jnp.all(eval_state.centers == state.centers)
     """
 
-    clusters: int = sk.field(callbacks=[IsInstance(int), Range(1)])
-    tol: float = sk.field(callbacks=[IsInstance(float), Range(0, min_inclusive=False)])
+    clusters: int = sk.field(on_setattr=[IsInstance(int), Range(1)])
+    tol: float = sk.field(on_setattr=[IsInstance(float), Range(0, min_inclusive=False)])
 
     def __call__(
         self,
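
Note that `tol` above uses `Range(0, min_inclusive=False)`, i.e. the lower bound is exclusive and `tol` must be strictly positive. A hedged sketch of what such an exclusive-bound check amounts to, written as a plain callable rather than serket's internal `Range` class:

    def strictly_positive(value):
        # mirrors IsInstance(float) + Range(0, min_inclusive=False):
        # zero is rejected because the lower bound is exclusive
        if not isinstance(value, float):
            raise TypeError(f"expected float, got {type(value).__name__}")
        if value <= 0.0:
            raise ValueError(f"expected a value > 0, got {value}")
        return value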
2 changes: 1 addition & 1 deletion serket/nn/containers.py
@@ -143,7 +143,7 @@ class RandomApply(sk.TreeClass):
     """
 
     layer: Any
-    rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
+    rate: float = sk.field(default=0.5, on_setattr=[Range(0, 1)])
 
     def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
         rate = jax.lax.stop_gradient(self.rate)
4 changes: 2 additions & 2 deletions serket/nn/dropout.py
@@ -59,7 +59,7 @@ class GeneralDropout(sk.TreeClass):
         dropout is applied to all axes.
     """
 
-    drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
+    drop_rate: float = sk.field(default=0.5, on_setattr=[Range(0, 1)])
     drop_axes: tuple[int, ...] | Literal["..."] = ...
 
     def __call__(self, x, *, key: jr.KeyArray = jr.PRNGKey(0)):
@@ -109,7 +109,7 @@ def __init__(self, drop_rate: float = 0.5):
 
 @sk.autoinit
 class DropoutND(sk.TreeClass):
-    drop_rate: float = sk.field(default=0.5, callbacks=[Range(0, 1)])
+    drop_rate: float = sk.field(default=0.5, on_setattr=[Range(0, 1)])
 
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     def __call__(self, x, *, key=jr.PRNGKey(0)):
5 changes: 3 additions & 2 deletions serket/nn/normalization.py
@@ -131,7 +131,7 @@ class LayerNorm(sk.TreeClass):
         https://nn.labml.ai/normalization/layer_norm/index.html
     """
 
-    eps: float = sk.field(callbacks=[Range(0, min_inclusive=False), ScalarLike()])
+    eps: float = sk.field(on_setattr=[Range(0, min_inclusive=False), ScalarLike()])
 
     @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
     def __init__(
@@ -225,6 +225,7 @@ class GroupNorm(sk.TreeClass):
     Reference:
         https://nn.labml.ai/normalization/group_norm/index.html
     """
+    eps: float = sk.field(on_setattr=[Range(0), ScalarLike()])
 
     @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
     def __init__(
@@ -240,7 +241,7 @@ def __init__(
     ):
         self.in_features = positive_int_cb(in_features)
         self.groups = positive_int_cb(groups)
-        self.eps = sk.field(callbacks=[Range(0), ScalarLike()])(eps)
+        self.eps = eps
 
         # needs more info for checking
         if in_features % groups != 0:
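
The `GroupNorm` hunk is more than a rename: the old `__init__` called `sk.field(callbacks=...)(eps)`, invoking a field descriptor as if it were a validator. The fix declares `eps` as a class-level field, so validation now happens on the plain assignment `self.eps = eps`. A minimal sketch of the pattern, assuming pytreeclass >= 0.8; `non_negative` and `GroupNormLike` are stand-ins for serket's `Range(0)`/`ScalarLike()` callbacks and the real `GroupNorm`:

    import pytreeclass as pytc

    def non_negative(value):
        # stand-in for Range(0) + ScalarLike()
        if value < 0:
            raise ValueError(f"eps must be >= 0, got {value}")
        return value

    class GroupNormLike(pytc.TreeClass):
        # declared at class level so `on_setattr` fires on assignment in __init__
        eps: float = pytc.field(on_setattr=[non_negative])

        def __init__(self, in_features: int, groups: int, eps: float = 1e-5):
            self.in_features = in_features
            self.groups = groups
            self.eps = eps  # validated here by the on_setattr callbacks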
8 changes: 4 additions & 4 deletions serket/nn/reshape.py
@@ -282,8 +282,8 @@ class Flatten(sk.TreeClass):
         https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html?highlight=flatten#torch.nn.Flatten
     """
 
-    start_dim: int = sk.field(default=0, callbacks=[IsInstance(int)])
-    end_dim: int = sk.field(default=-1, callbacks=[IsInstance(int)])
+    start_dim: int = sk.field(default=0, on_setattr=[IsInstance(int)])
+    end_dim: int = sk.field(default=-1, on_setattr=[IsInstance(int)])
 
     def __call__(self, x: jax.Array) -> jax.Array:
         start_dim = self.start_dim + (0 if self.start_dim >= 0 else x.ndim)
@@ -311,8 +311,8 @@ class Unflatten(sk.TreeClass):
         - https://pytorch.org/docs/stable/generated/torch.nn.Unflatten.html?highlight=unflatten
     """
 
-    dim: int = sk.field(default=0, callbacks=[IsInstance(int)])
-    shape: tuple = sk.field(default=None, callbacks=[IsInstance(tuple)])
+    dim: int = sk.field(default=0, on_setattr=[IsInstance(int)])
+    shape: tuple = sk.field(default=None, on_setattr=[IsInstance(tuple)])
 
     def __call__(self, x: jax.Array, **k) -> jax.Array:
         shape = list(x.shape)
