Skip to content

Commit

Permalink
zom and typo
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 27, 2023
1 parent d2547e2 commit 168c0e6
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 72 deletions.
11 changes: 8 additions & 3 deletions serket/nn/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from typing import Any

import jax
import jax.numpy as jnp
import jax.random as jr

import serket as sk
from serket.nn.custom_transform import tree_evaluation
from serket.nn.utils import Range


Expand Down Expand Up @@ -107,6 +109,9 @@ class RandomApply(sk.TreeClass):
p: float = sk.field(default=0.5, callbacks=[Range(0, 1)])

def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)):
if not jr.bernoulli(key, jax.lax.stop_gradient(self.p)):
return x
return self.layer(x)
return jnp.where(jr.bernoulli(key, self.p), self.layer(x), x)


@tree_evaluation.def_evaluation(RandomApply)
def tree_evaluation_random_apply(layer: RandomApply):
return layer.layer
2 changes: 1 addition & 1 deletion serket/nn/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ def is_leaf(x: Callable[[Any], bool]) -> bool:


tree_evaluation.evaluation_dispatcher = ft.singledispatch(lambda x: x)
tree_evaluation.def_evalutation = tree_evaluation.evaluation_dispatcher.register
tree_evaluation.def_evaluation = tree_evaluation.evaluation_dispatcher.register
8 changes: 4 additions & 4 deletions serket/nn/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,10 @@ def spatial_ndim(self) -> int:
return 2


@tree_evaluation.def_evalutation(RandomCutout1D)
@tree_evaluation.def_evalutation(RandomCutout2D)
@tree_evaluation.def_evalutation(Dropout)
@tree_evaluation.def_evalutation(DropoutND)
@tree_evaluation.def_evaluation(RandomCutout1D)
@tree_evaluation.def_evaluation(RandomCutout2D)
@tree_evaluation.def_evaluation(Dropout)
@tree_evaluation.def_evaluation(DropoutND)
def dropout_evaluation(_) -> Identity:
# dropout is a no-op during evaluation
return Identity()
26 changes: 2 additions & 24 deletions serket/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,6 @@ def layer_norm(
eps: float,
normalized_shape: int | tuple[int],
) -> jax.Array:
"""Layer Normalization
See: https://nn.labml.ai/normalization/layer_norm/index.html
transform the input by scaling and shifting to have zero mean and unit variance.
Args:
x: input array
gamma: scale
beta: shift
eps: a value added to the denominator for numerical stability.
normalized_shape: the shape of the input to be normalized.
"""
dims = tuple(range(len(x.shape) - len(normalized_shape), len(x.shape)))

μ = jnp.mean(x, axis=dims, keepdims=True)
Expand All @@ -67,17 +56,6 @@ def group_norm(
eps: float,
groups: int,
) -> jax.Array:
"""Group Normalization
See: https://nn.labml.ai/normalization/group_norm/index.html
transform the input by scaling and shifting to have zero mean and unit variance.
Args:
x: input array
gamma: scale Array
beta: shift Array
eps: a value added to the denominator for numerical stability.
groups: number of groups to separate the channels into
"""
# split channels into groups
xx = x.reshape(groups, -1)
μ = jnp.mean(xx, axis=-1, keepdims=True)
Expand Down Expand Up @@ -520,11 +498,11 @@ def __call__(
return x, state


@tree_evaluation.def_evalutation(BatchNorm)
@tree_evaluation.def_evaluation(BatchNorm)
def _(batchnorm: BatchNorm) -> EvalNorm:
return EvalNorm(
in_features=batchnorm.in_features,
momentum=batchnorm.momentum,
momentum=batchnorm.momentum, # ignored
eps=batchnorm.eps,
gamma_init_func=lambda *_: batchnorm.gamma,
beta_init_func=lambda *_: batchnorm.beta,
Expand Down
108 changes: 68 additions & 40 deletions serket/nn/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import jax.random as jr

import serket as sk
from serket.nn.custom_transform import tree_evaluation
from serket.nn.linear import Identity
from serket.nn.utils import (
IsInstance,
canonicalize,
Expand Down Expand Up @@ -437,22 +439,27 @@ def spatial_ndim(self) -> int:
return 3


def random_crop_nd(
x: jax.Array,
*,
crop_size: tuple[int, ...],
key: jr.KeyArray,
) -> jax.Array:
start: tuple[int, ...] = tuple(
jr.randint(key, shape=(), minval=0, maxval=x.shape[i] - s)
for i, s in enumerate(crop_size)
)
return jax.lax.dynamic_slice(x, start, crop_size)


class RandomCropND(sk.TreeClass):
def __init__(self, size: int | tuple[int, ...]):
self.size = canonicalize(size, self.spatial_ndim, name="size")

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
start = tuple(
jr.randint(
key,
shape=(),
minval=0,
maxval=x.shape[i] - s,
)
for i, s in enumerate(self.size)
)
return jax.lax.dynamic_slice(x, (0, *start), (x.shape[0], *self.size))
crop_size = [x.shape[0], *self.size]
return random_crop_nd(x, crop_size=crop_size, key=key)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -523,8 +530,7 @@ class FlipLeftRight2D(sk.TreeClass):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
flip = lambda x: jnp.flip(x, axis=1)
return jax.vmap(flip)(x)
return jax.vmap(lambda x: jnp.flip(x, axis=1))(x)

@property
def spatial_ndim(self) -> int:
Expand Down Expand Up @@ -554,23 +560,58 @@ class FlipUpDown2D(sk.TreeClass):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
flip = lambda x: jnp.flip(x, axis=0)
return jax.vmap(flip)(x)
return jax.vmap(lambda x: jnp.flip(x, axis=0))(x)

@property
def spatial_ndim(self) -> int:
return 2


def _zoom_axis(
x: jax.Array,
factor: float,
key: jr.KeyArray,
axis: int,
) -> jax.Array:
if factor == 0:
return x

axis_size = x.shape[axis]
dtype = x.dtype
resized_axis_size = int(axis_size * (1 + factor))

def zoom_in(x):
shape = list(x.shape)
resized_shape = list(shape)
resized_shape[axis] = resized_axis_size
x = jax.image.resize(x, shape=resized_shape, method="linear")
x = random_crop_nd(x, crop_size=shape, key=key)
return x.astype(dtype)

def zoom_out(x):
shape = list(x.shape)
resized_shape = list(shape)
resized_shape[axis] = resized_axis_size
x = jax.image.resize(x, shape=resized_shape, method="linear")
pad_width = [(0, 0)] * len(x.shape)
left = (axis_size - resized_axis_size) // 2
right = axis_size - resized_axis_size - left
pad_width[axis] = (left, right)
x = jnp.pad(x, pad_width=pad_width)
return x.astype(dtype)

return zoom_out(x) if factor < 0 else zoom_in(x)


class RandomZoom2D(sk.TreeClass):
def __init__(
self,
height_factor: tuple[float, float] = (0.0, 1.0),
width_factor: tuple[float, float] = (0.0, 1.0),
):
"""Randomly zooms an image.
"""Randomly zooms a channle-first image tensor.
Positive values are zoom in, negative values are zoom out.
Positive values are zoom in, negative values are zoom out, and 0 is no zoom.
Args:
height_factor: (min, max)
Expand All @@ -588,6 +629,7 @@ def __init__(
self.height_factor = height_factor
self.width_factor = width_factor

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
k1, k2, k3, k4 = jr.split(key, 4)

Expand All @@ -608,32 +650,18 @@ def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
(height_factor, width_factor)
)

r, c = x.shape[1:3] # R = rows, C = cols
rr = int(r * (1 + height_factor)) # RR = resized rows,
cc = int(c * (1 + width_factor)) # CC = resized cols

if height_factor > 0:
# zoom in rows
x = Resize2D((rr, c))(x)
x = RandomCrop2D((r, c))(x, key=k3)

if width_factor > 0:
# zoom in cols
x = Resize2D((r, cc))(x)
x = RandomCrop2D((r, c))(x, key=k4)

if height_factor < 0:
# zoom out rows
x = Resize2D((rr, c))(x)
x = Pad2D((((r - rr) // 2, (r - rr) - ((r - rr) // 2)), (0, 0)))(x)

if width_factor < 0:
# zoom out cols
x = Resize2D((r, cc))(x)
x = Pad2D(((0, 0), ((c - cc) // 2, (c - cc) - (c - cc) // 2)))(x)

x = _zoom_axis(x, height_factor, k3, axis=1)
x = _zoom_axis(x, width_factor, k4, axis=2)
return x

@property
def spatial_ndim(self) -> int:
return 2


@tree_evaluation.def_evaluation(RandomCrop1D)
@tree_evaluation.def_evaluation(RandomCrop2D)
@tree_evaluation.def_evaluation(RandomCrop3D)
@tree_evaluation.def_evaluation(RandomZoom2D)
def random_transform_eval(_) -> Identity:
return Identity()

0 comments on commit 168c0e6

Please sign in to comment.