adapt to pytc 0.5 (#23)
* adapt to pytc 0.5

* docs edit
ASEM000 authored Jul 24, 2023
1 parent e895601 commit dd5c522
Showing 20 changed files with 204 additions and 804 deletions.
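The recurring change in the files below follows pytreeclass 0.5, where generating __init__ from class-level field annotations is opt-in via the autoinit decorator (re-exported by serket as sk.autoinit); classes that previously relied on implicit init generation now carry the decorator. A minimal sketch of the pattern under that assumption (the Scale class is hypothetical, not part of this commit):

import jax
import jax.numpy as jnp

import serket as sk


@sk.autoinit
class Scale(sk.TreeClass):
    # autoinit turns the annotated field into a generated __init__ argument
    factor: float = 1.0

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.factor * x


layer = Scale(factor=2.0)      # generated __init__
print(layer(jnp.ones(3)))      # [2. 2. 2.]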
4 changes: 2 additions & 2 deletions docs/notebooks/bilstm.ipynb
@@ -161,13 +161,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.760744e-03\tTime: 0.020\r"
"Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.760744e-03\tTime: 0.025\r"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x13cab12a0>"
"<matplotlib.legend.Legend at 0x1490affa0>"
]
},
"execution_count": 5,
4 changes: 2 additions & 2 deletions docs/notebooks/mnist.ipynb
@@ -82,7 +82,7 @@
"source": [
"k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)\n",
"\n",
"\n",
"@sk.autoinit\n",
"class ConvNet(sk.TreeClass):\n",
" conv1: sk.nn.Conv2D = sk.nn.Conv2D(1, 32, 3, key=k1, padding=\"valid\")\n",
" pool1: sk.nn.MaxPool2D = sk.nn.MaxPool2D(2, 2)\n",
@@ -201,7 +201,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 001/001\tBatch: 468/468\tBatch loss: 2.040178e-01\tBatch accuracy: 0.984375\tTime: 18.339\r"
"Epoch: 001/001\tBatch: 468/468\tBatch loss: 2.040178e-01\tBatch accuracy: 0.984375\tTime: 18.784\r"
]
},
{
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ keywords = [
"functional-programming",
"machine-learning",
]
dependencies = ["pytreeclass>=0.4.0"]
dependencies = ["pytreeclass>=0.5.0"]

classifiers = [
"Development Status :: 5 - Production/Stable",
6 changes: 5 additions & 1 deletion serket/__init__.py
@@ -18,13 +18,15 @@
BaseKey,
Partial,
TreeClass,
autoinit,
bcmap,
field,
fields,
freeze,
is_frozen,
is_nondiff,
is_tree_equal,
leafwise,
tree_diagram,
tree_flatten_with_trace,
tree_graph,
@@ -50,6 +52,7 @@
"is_tree_equal",
"field",
"fields",
"autoinit",
# pprint utils
"tree_diagram",
"tree_graph",
@@ -74,11 +77,12 @@
"tree_flatten_with_trace",
"tree_repr_with_trace",
"Partial",
"leafwise",
# serket
"nn",
"tree_evaluation",
"tree_state",
)


__version__ = "0.4.0b1"
__version__ = "0.5.0b1"
19 changes: 19 additions & 0 deletions serket/nn/activation.py
@@ -24,6 +24,7 @@
from serket.nn.utils import IsInstance, Range, ScalarLike


@sk.autoinit
class AdaptiveLeakyReLU(sk.TreeClass):
"""Leaky ReLU activation function
@@ -39,6 +40,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jnp.maximum(0, self.a * x) - v * jnp.maximum(0, -self.a * x)


@sk.autoinit
class AdaptiveReLU(sk.TreeClass):
"""ReLU activation function with learnable parameters
Note:
@@ -51,6 +53,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jnp.maximum(0, self.a * x)


@sk.autoinit
class AdaptiveSigmoid(sk.TreeClass):
"""Sigmoid activation function with learnable `a` parameter
Note:
@@ -63,6 +66,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return 1 / (1 + jnp.exp(-self.a * x))


@sk.autoinit
class AdaptiveTanh(sk.TreeClass):
"""Tanh activation function with learnable parameters
Note:
@@ -76,6 +80,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return (jnp.exp(a * x) - jnp.exp(-a * x)) / (jnp.exp(a * x) + jnp.exp(-a * x))


@sk.autoinit
class CeLU(sk.TreeClass):
"""Celu activation function"""

@@ -85,6 +90,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.celu(x, alpha=lax.stop_gradient(self.alpha))


@sk.autoinit
class ELU(sk.TreeClass):
"""Exponential linear unit"""

@@ -94,6 +100,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.elu(x, alpha=lax.stop_gradient(self.alpha))


@sk.autoinit
class GELU(sk.TreeClass):
"""Gaussian error linear unit"""

@@ -103,13 +110,15 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.gelu(x, approximate=self.approximate)


@sk.autoinit
class GLU(sk.TreeClass):
"""Gated linear unit"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.glu(x)


@sk.autoinit
class HardShrink(sk.TreeClass):
"""Hard shrink activation function"""

@@ -120,41 +129,47 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jnp.where(x > alpha, x, jnp.where(x < -alpha, x, 0.0))


@sk.autoinit
class HardSigmoid(sk.TreeClass):
"""Hard sigmoid activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.hard_sigmoid(x)


@sk.autoinit
class HardSwish(sk.TreeClass):
"""Hard swish activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.hard_swish(x)


@sk.autoinit
class HardTanh(sk.TreeClass):
"""Hard tanh activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.hard_tanh(x)


@sk.autoinit
class LogSigmoid(sk.TreeClass):
"""Log sigmoid activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.log_sigmoid(x)


@sk.autoinit
class LogSoftmax(sk.TreeClass):
"""Log softmax activation function"""

def __call__(self, x: jax.Array, **k) -> jax.Array:
return jax.nn.log_softmax(x)


@sk.autoinit
class LeakyReLU(sk.TreeClass):
"""Leaky ReLU activation function"""

@@ -206,6 +221,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return x / (1 + jnp.abs(x))


@sk.autoinit
class SoftShrink(sk.TreeClass):
"""SoftShrink activation function"""

@@ -248,6 +264,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return x - jax.nn.tanh(x)


@sk.autoinit
class ThresholdedReLU(sk.TreeClass):
"""Thresholded ReLU activation function."""

@@ -265,6 +282,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return x * jax.nn.tanh(jax.nn.softplus(x))


@sk.autoinit
class PReLU(sk.TreeClass):
"""Parametric ReLU activation function"""

@@ -274,6 +292,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return jnp.where(x >= 0, x, x * self.a)


@sk.autoinit
class Snake(sk.TreeClass):
"""Snake activation function
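Calling the activation layers is unchanged by this commit; only their constructors are now generated by sk.autoinit. A short usage sketch, assuming the ELU field is named alpha, as self.alpha in its __call__ above suggests:

import jax.numpy as jnp

import serket as sk

act = sk.nn.ELU(alpha=1.0)                 # __init__ generated by sk.autoinit
print(act(jnp.array([-1.0, 0.0, 2.0])))    # jax.nn.elu applied elementwise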
1 change: 1 addition & 0 deletions serket/nn/blocks/unet.py
@@ -71,6 +71,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
return self.conv(x)


@sk.autoinit
class UNetBlock(sk.TreeClass):
"""Vanilla UNet
1 change: 1 addition & 0 deletions serket/nn/containers.py
@@ -23,6 +23,7 @@
import serket as sk


@sk.autoinit
class Sequential(sk.TreeClass):
"""A sequential container for layers.
8 changes: 2 additions & 6 deletions serket/nn/contrast.py
@@ -43,6 +43,7 @@ def random_contrast_nd(
return adjust_contrast_nd(x, contrast_factor)


@sk.autoinit
class AdjustContrastND(sk.TreeClass):
"""Adjusts the contrast of an NDimage by scaling the pixel values by a factor.
@@ -71,14 +72,12 @@ class AdjustContrast2D(AdjustContrastND):
contrast_factor: contrast factor to adjust the image by.
"""

def __init__(self, contrast_factor=1.0):
super().__init__(contrast_factor=contrast_factor)

@property
def spatial_ndim(self) -> int:
return 2


@sk.autoinit
class RandomContrastND(sk.TreeClass):
"""Randomly adjusts the contrast of an image by scaling the pixel
values by a factor.
@@ -131,9 +130,6 @@ class RandomContrast2D(RandomContrastND):
contrast_range: range of contrast factors to randomly sample from.
"""

def __init__(self, contrast_range=(0.5, 1)):
super().__init__(contrast_range=contrast_range)

@property
def spatial_ndim(self) -> int:
return 2
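The hand-written __init__ methods on the 2D wrappers are removed because the sk.autoinit-decorated ND base classes now provide the constructor; the subclasses inherit it and override only spatial_ndim. A minimal sketch of the refactor with hypothetical BaseND/Base2D names:

import serket as sk


@sk.autoinit
class BaseND(sk.TreeClass):
    contrast_factor: float = 1.0  # generated __init__ accepts this argument


class Base2D(BaseND):
    # inherits the generated __init__; only the dimensionality differs
    @property
    def spatial_ndim(self) -> int:
        return 2


layer = Base2D(contrast_factor=0.8)  # same call signature as before the refactor
print(layer.contrast_factor)         # 0.8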
(Diff truncated: 8 of the 20 changed files are shown above; the remaining 12 are omitted.)
