Merge pull request #22 from ASEM000/batchnorm,-tree_evaluation
edits and state/eval addition
ASEM000 committed Jul 20, 2023
2 parents 3fcf63e + 4c58e37 commit 02f11f8
Showing 19 changed files with 643 additions and 322 deletions.
13 changes: 5 additions & 8 deletions .github/workflows/pypi.yml
@@ -1,9 +1,5 @@
 name: pypi
-
-on:
-  release:
-    types: [created]
-
+on: workflow_dispatch
 jobs:
   deploy:
     runs-on: ubuntu-latest
@@ -16,11 +12,12 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install setuptools wheel twine
+          pip install build
+          pip install twine
       - name: Build and publish
         env:
           TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
           TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
         run: |
-          python setup.py sdist bdist_wheel
-          twine upload dist/*
+          python -m build
+          python -m twine upload dist/*
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -22,8 +22,8 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install git+https://github.com/ASEM000/PyTreeClass
-          python -m pip install tensorflow
+          python -m pip install pytreeclass>=0.4.0
+          python -m pip install keras_core>=0.1.1
           python -m pip install pytest wheel optax jaxlib coverage kernex
       - name: Pytest Check
         run: |
8 changes: 7 additions & 1 deletion serket/__init__.py
@@ -27,6 +27,7 @@
     is_tree_equal,
     tree_diagram,
     tree_flatten_with_trace,
+    tree_graph,
     tree_leaves_with_trace,
     tree_map_with_trace,
     tree_mask,
@@ -40,6 +41,8 @@
 )

 from . import nn
+from .nn.evaluation import tree_evaluation
+from .nn.state import tree_state

 __all__ = (
     # general utils
@@ -49,6 +52,7 @@
     "fields",
     # pprint utils
     "tree_diagram",
+    "tree_graph",
     "tree_mermaid",
     "tree_repr",
     "tree_str",
@@ -72,7 +76,9 @@
     "Partial",
     # serket
     "nn",
+    "tree_evaluation",
+    "tree_state",
 )


-__version__ = "0.2.0b7"
+__version__ = "0.4.0b1"
14 changes: 3 additions & 11 deletions serket/nn/__init__.py
@@ -86,16 +86,8 @@
 from .flatten import Flatten, Unflatten
 from .flip import FlipLeftRight2D, FlipUpDown2D
 from .fully_connected import FNN, MLP
-from .linear import (
-    Bilinear,
-    Embedding,
-    GeneralLinear,
-    Identity,
-    Linear,
-    MergeLinear,
-    Multilinear,
-)
-from .normalization import GroupNorm, InstanceNorm, LayerNorm
+from .linear import Bilinear, Embedding, GeneralLinear, Identity, Linear, Multilinear
+from .normalization import BatchNorm, GroupNorm, InstanceNorm, LayerNorm
 from .padding import Pad1D, Pad2D, Pad3D
 from .pooling import (
     AdaptiveAvgPool1D,
@@ -149,7 +141,6 @@
     "Multilinear",
     "GeneralLinear",
     "Embedding",
-    "MergeLinear",
     # Dropout
     "Dropout",
     "Dropout1D",
@@ -215,6 +206,7 @@
     "LayerNorm",
     "InstanceNorm",
     "GroupNorm",
+    "BatchNorm",
     # Blur
     "AvgBlur2D",
     "GaussianBlur2D",
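
`BatchNorm` is the new layer that motivates the `tree_state`/`tree_evaluation` pair: its running statistics live in a separate state object rather than on the layer itself. The call convention is not shown in this diff, so the sketch below is a hypothetical pattern; the `BatchNorm(5)` signature and the `(output, state)` return shape are assumptions:

import jax.numpy as jnp
import serket as sk

bn = sk.nn.BatchNorm(5)    # assumed: in_features as the first argument
state = sk.tree_state(bn)  # build the initial running statistics

x = jnp.ones((10, 5))
# assumed: stateful layers return the output together with updated state
# y, state = bn(x, state=state)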
7 changes: 4 additions & 3 deletions serket/nn/contrast.py
@@ -95,9 +95,10 @@ def __init__(self, contrast_range=(0.5, 1)):
             and len(contrast_range) == 2
             and contrast_range[0] <= contrast_range[1]
         ):
-            msg = "contrast_range must be a tuple of two floats, "
-            msg += "with the first one smaller than the second one."
-            raise ValueError(msg)
+            raise ValueError(
+                "`contrast_range` must be a tuple of two floats, "
+                "with the first one smaller than the second one."
+            )

         self.contrast_range = contrast_range
67 changes: 26 additions & 41 deletions serket/nn/convolution.py
@@ -81,24 +81,20 @@ def __init__(
             self.spatial_ndim,
             name="kernel_dilation",
         )
-        self.weight_init_func = resolve_init_func(weight_init_func)
-        self.bias_init_func = resolve_init_func(bias_init_func)
+        weight_init_func = resolve_init_func(weight_init_func)
+        bias_init_func = resolve_init_func(bias_init_func)

         self.groups = positive_int_cb(groups)

         if self.out_features % self.groups != 0:
-            raise ValueError(
-                f"Expected out_features % groups == 0, \n"
-                f"got {self.out_features % self.groups}"
-            )
+            raise ValueError(f"{(out_features % groups == 0)=}")

         weight_shape = (out_features, in_features // groups, *self.kernel_size)
-        self.weight = self.weight_init_func(key, weight_shape)
+        self.weight = weight_init_func(key, weight_shape)

-        if bias_init_func is None:
-            self.bias = None
-        else:
-            bias_shape = (out_features, *(1,) * self.spatial_ndim)
-            self.bias = self.bias_init_func(key, bias_shape)
+        bias_shape = (out_features, *(1,) * self.spatial_ndim)
+        self.bias = bias_init_func(key, bias_shape)

     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
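
The rewritten error messages lean on the f-string `=` specifier (Python 3.8+), which renders the expression source text along with its value; terser than the old two-line message, at the cost of a less prose-like error. A quick illustration:

out_features, groups = 10, 3

# f"{(expr)=}" prints the expression itself followed by its value,
# so the ValueError reads like a failed assertion
print(f"{(out_features % groups == 0)=}")
# prints: (out_features % groups == 0)=False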
@@ -432,24 +428,18 @@ def __init__(
             self.spatial_ndim,
             name="kernel_dilation",
         )
-        self.weight_init_func = resolve_init_func(weight_init_func)
-        self.bias_init_func = resolve_init_func(bias_init_func)
+        weight_init_func = resolve_init_func(weight_init_func)
+        bias_init_func = resolve_init_func(bias_init_func)
         self.groups = positive_int_cb(groups)

         if self.out_features % self.groups != 0:
-            raise ValueError(
-                "Expected out_features % groups == 0,"
-                f"got {self.out_features % self.groups}"
-            )
+            raise ValueError(f"{(self.out_features % self.groups == 0)=}")

         weight_shape = (out_features, in_features // groups, *self.kernel_size)  # OIHW
-        self.weight = self.weight_init_func(key, weight_shape)
+        self.weight = weight_init_func(key, weight_shape)

-        if bias_init_func is None:
-            self.bias = None
-        else:
-            bias_shape = (out_features, *(1,) * self.spatial_ndim)
-            self.bias = self.bias_init_func(key, bias_shape)
+        bias_shape = (out_features, *(1,) * self.spatial_ndim)
+        self.bias = bias_init_func(key, bias_shape)

     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
@@ -774,19 +764,18 @@ def __init__(
         self.padding = padding  # delayed canonicalization
         self.input_dilation = canonicalize(1, self.spatial_ndim, name="input_dilation")
         self.kernel_dilation = canonicalize(
-            1, self.spatial_ndim, name="kernel_dilation"
+            1,
+            self.spatial_ndim,
+            name="kernel_dilation",
         )
-        self.weight_init_func = resolve_init_func(weight_init_func)
-        self.bias_init_func = resolve_init_func(bias_init_func)
+        weight_init_func = resolve_init_func(weight_init_func)
+        bias_init_func = resolve_init_func(bias_init_func)

         weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size)  # OIHW
-        self.weight = self.weight_init_func(key, weight_shape)
+        self.weight = weight_init_func(key, weight_shape)

-        if bias_init_func is None:
-            self.bias = None
-        else:
-            bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim)
-            self.bias = self.bias_init_func(key, bias_shape)
+        bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim)
+        self.bias = bias_init_func(key, bias_shape)

     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
@@ -1359,8 +1348,8 @@ def __init__(
             self.spatial_ndim,
             name="kernel_dilation",
         )
-        self.weight_init_func = resolve_init_func(weight_init_func)
-        self.bias_init_func = resolve_init_func(bias_init_func)
+        weight_init_func = resolve_init_func(weight_init_func)
+        bias_init_func = resolve_init_func(bias_init_func)

         out_size = calculate_convolution_output_shape(
             shape=self.in_size,
@@ -1376,14 +1365,10 @@
             *out_size,
         )

-        self.weight = self.weight_init_func(key, weight_shape)
+        self.weight = weight_init_func(key, weight_shape)

         bias_shape = (self.out_features, *out_size)

-        if bias_init_func is None:
-            self.bias = None
-        else:
-            self.bias = self.bias_init_func(key, bias_shape)
+        self.bias = bias_init_func(key, bias_shape)

     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
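
The same refactor repeats across all four convolution classes in this file: the resolved `weight_init_func`/`bias_init_func` are now plain locals consumed during `__init__` rather than attributes on `self`, and `resolve_init_func(None)` evidently yields a function returning `None`, which removes the `if bias_init_func is None` branches. One plausible motivation (my reading; the diff does not say) is to keep callables out of the layer pytree so that only arrays remain as leaves. A sketch of the difference, using `sk.TreeClass` as in `dropout.py` below:

import jax
import serket as sk

class WithFunc(sk.TreeClass):
    def __init__(self):
        init = jax.nn.initializers.glorot_uniform()
        self.init_func = init  # callable stored on the layer -> extra pytree leaf
        self.weight = init(jax.random.PRNGKey(0), (3, 3))

class WithoutFunc(sk.TreeClass):
    def __init__(self):
        init = jax.nn.initializers.glorot_uniform()  # resolved, used, discarded
        self.weight = init(jax.random.PRNGKey(0), (3, 3))

# only the weight array remains a leaf; no stray callables for jax transforms
print(jax.tree_util.tree_leaves(WithoutFunc()))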
11 changes: 10 additions & 1 deletion serket/nn/dropout.py
@@ -22,6 +22,8 @@
 from jax import lax

 import serket as sk
+from serket.nn.evaluation import tree_evaluation
+from serket.nn.linear import Identity
 from serket.nn.utils import Range, validate_spatial_ndim

@@ -38,7 +40,7 @@ class Dropout(sk.TreeClass):
         >>> import jax.numpy as jnp
         >>> layer = sk.nn.Dropout(0.5)
         >>> # change `p` to 0.0 to turn off dropout
-        >>> layer = layer.at["p"].set(0.0, is_leaf=pytc.is_frozen)
+        >>> layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)

     Note:
         Use `p` = 0.0 to turn off dropout.
@@ -157,3 +159,10 @@ def __init__(self, p: float = 0.5):
     @property
     def spatial_ndim(self) -> int:
         return 3
+
+
+@tree_evaluation.def_evaluation(Dropout)
+@tree_evaluation.def_evaluation(DropoutND)
+def dropout_evaluation(_):
+    # dropout is a no-op during evaluation
+    return Identity()
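
`def_evaluation` is plain `functools.singledispatch` registration (see `evaluation.py` below), so the same hook is open to user-defined layers. A sketch with a made-up `MyScale` layer, which is not part of serket:

import serket as sk
from serket.nn.evaluation import tree_evaluation

class MyScale(sk.TreeClass):  # hypothetical user layer
    def __init__(self, scale: float = 2.0):
        self.scale = scale

    def __call__(self, x):
        return x * self.scale

@tree_evaluation.def_evaluation(MyScale)
def myscale_evaluation(_):
    # treat scaling as training-only: evaluate as the identity
    return sk.nn.Identity()

print(sk.tree_evaluation(MyScale()))  # Identity()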
58 changes: 58 additions & 0 deletions serket/nn/evaluation.py
@@ -0,0 +1,58 @@
+# Copyright 2023 Serket authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Define dispatchers for custom tree evaluation."""
+
+from __future__ import annotations
+
+import functools as ft
+from typing import Any, TypeVar
+
+import jax
+
+T = TypeVar("T")
+
+
+def tree_evaluation(tree: T) -> T:
+    """Modify tree layers to disable any training-related behavior.
+
+    For example, the drop probability of `Dropout` layers is set to 0.0, and
+    the `BatchNorm` layer's `track_running_stats` is set to False when
+    evaluating the tree.
+
+    Args:
+        tree: A tree of layers.
+
+    Returns:
+        A tree of layers with evaluation behavior.
+
+    Example:
+        >>> # dropout is replaced by an identity layer in evaluation mode,
+        >>> # via the rule registered with `tree_evaluation.def_evaluation(Dropout)`
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> layer = sk.nn.Dropout(0.5)
+        >>> sk.tree_evaluation(layer)
+        Identity()
+    """
+
+    def is_leaf(x: Any) -> bool:
+        types = set(tree_evaluation.evaluation_dispatcher.registry.keys())
+        types.discard(object)
+        return isinstance(x, tuple(types))
+
+    return jax.tree_map(tree_evaluation.evaluation_dispatcher, tree, is_leaf=is_leaf)
+
+
+tree_evaluation.evaluation_dispatcher = ft.singledispatch(lambda x: x)
+tree_evaluation.def_evaluation = tree_evaluation.evaluation_dispatcher.register
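
The dispatcher is plain `functools.singledispatch` with the identity function as the default rule, so any layer type without a registered rule passes through `tree_evaluation` untouched. A self-contained sketch of the same mechanism:

import functools as ft

# default rule: return the object unchanged
dispatcher = ft.singledispatch(lambda x: x)

class Dropout:  # stand-in class for the sketch
    pass

@dispatcher.register(Dropout)
def _(layer):
    return "Identity()"  # stand-in for serket's Identity layer

print(dispatcher(Dropout()))  # 'Identity()' -- registered type is rewritten
print(dispatcher(3.14))       # 3.14 -- default identity rule applies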
… 11 more changed files not shown.
