Skip to content

Commit

Permalink
more edits switch to toml
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 9, 2023
1 parent a2d60a5 commit d269138
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 165 deletions.
14 changes: 0 additions & 14 deletions .flake8

This file was deleted.

2 changes: 0 additions & 2 deletions .isort.cfg

This file was deleted.

43 changes: 43 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
[build-systems]
requires = ["setuptools >= 61"]
build-backend = "setuptools.build_meta"

[project]
name = "serket"
dynamic = ["version"]
requires-python = ">=3.8"
license = {text = "Apache-2.0"}
description = "Functional neural network library in JAX"
authors = [{name = "Mahmoud Asem", email = "[email protected]"}]
keywords = ["jax", "neural-networks", "functional-programming", "machine-learning"]
dependencies = ["jax>=0.4.7", "typing-extensions"]

classifiers=[
"Development Status :: 5 - Production/Stable",
"Environment :: Console",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
]

[tool.setuptools.dynamic]
version = {attr = "serket.__version__" }

[tool.setuptools.packages.find]
include = ["serket", "serket.*"]

[project.urls]
Source = "https://github.com/ASEM000/Serket"


[tool.ruff]
select = ["F", "E", "I"]
line-length = 120
ignore = [
"E731", # do not assign a lambda expression, use a def
]
3 changes: 2 additions & 1 deletion serket/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@


class AdaptiveLeakyReLU(pytc.TreeClass):
"""Leaky ReLU activation function with learnable `a` parameter
"""Leaky ReLU activation function
Note:
https://arxiv.org/pdf/1906.01170.pdf.
"""
Expand Down
3 changes: 2 additions & 1 deletion serket/nn/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class Sequential(pytc.TreeClass):
"""A sequential container for layers.
Args:
layers: a tuple of layers.
layers: a tuple or a list of layers. if a list is passed, it will
be casted to a tuple to maintain immutable behavior.
Example:
>>> import jax.numpy as jnp
Expand Down
7 changes: 6 additions & 1 deletion serket/nn/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ def __init__(self, size: int | tuple[int, ...]):
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)) -> jax.Array:
start = tuple(
jr.randint(key, shape=(), minval=0, maxval=x.shape[i] - s)
jr.randint(
key,
shape=(),
minval=0,
maxval=x.shape[i] - s,
)
for i, s in enumerate(self.size)
)
return jax.lax.dynamic_slice(x, (0, *start), (x.shape[0], *self.size))
Expand Down
14 changes: 5 additions & 9 deletions serket/nn/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@
from serket.nn.utils import positive_int_cb, validate_spatial_ndim


@ft.partial(jax.jit, static_argnums=(1,))
def histeq(x, bins_count: int = 256):
hist, bins = jnp.histogram(x.flatten(), bins_count, density=True)
cdf = hist.cumsum()
cdf = (bins_count - 1) * cdf / cdf[-1]
return jnp.interp(x.flatten(), bins[:-1], cdf).reshape(x.shape)


class HistogramEqualization2D(pytc.TreeClass):
def __init__(self, bins: int = 256):
"""Apply histogram equalization to 2D spatial array channel wise
Expand All @@ -47,7 +39,11 @@ def __init__(self, bins: int = 256):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
return histeq(x, self.bins)
bins_count = self.bins
hist, bins = jnp.histogram(x.flatten(), bins_count, density=True)
cdf = hist.cumsum()
cdf = (bins_count - 1) * cdf / cdf[-1]
return jnp.interp(x.flatten(), bins[:-1], cdf).reshape(x.shape)

@property
def spatial_ndim(self) -> int:
Expand Down
4 changes: 1 addition & 3 deletions serket/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,9 +980,7 @@ def __call__(
**k,
) -> jax.Array:
if not isinstance(state, (RNNState, type(None))):
msg = "Expected state to be an instance of RNNState, "
msg += f"got {type(state).__name__}"
raise TypeError(msg)
raise TypeError(f"Expected state to be an instance of RNNState, {state=}")

# non-spatial RNN : (time steps, in_features)
# spatial RNN : (time steps, in_features, *spatial_dims)
Expand Down
125 changes: 45 additions & 80 deletions serket/nn/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,19 @@ class ResizeND(pytc.TreeClass):
Resize an image to a given size using a given interpolation method.
Args:
size: the size of the output.
method: the method of interpolation. Defaults to "nearest".
size: the size of the output. if size is None, the output size is
calculated as input size * scale
method: the method of interpolation. Defaults to "nearest". choices are:
- "nearest": Nearest neighbor interpolation. The values of antialias
and precision are ignored.
- "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
If antialias is True, uses a triangular filter when downsampling.
- "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
cubic kernel.
- "lanczos3": Lanczos resampling, using a kernel of radius 3.
- "lanczos5": Lanczos resampling, using a kernel of radius 5.
antialias: whether to use antialiasing. Defaults to True.
antialias: whether to use antialiasing. Defaults to True.
Note:
- if size is None, the output size is calculated as input size * scale
- interpolation methods
"nearest" :
Nearest neighbor interpolation. The values of antialias and precision are ignored.
"linear", "bilinear", "trilinear", "triangle" :
Linear interpolation. If antialias is True, uses a triangular filter when downsampling.
"cubic", "bicubic", "tricubic" :
Cubic interpolation, using the Keys cubic kernel.
"lanczos3" :
Lanczos resampling, using a kernel of radius 3.
"lanczos5"
Lanczos resampling, using a kernel of radius 5.
"""

def __init__(
Expand Down Expand Up @@ -118,27 +110,18 @@ def __init__(
"""Resize a 1D input to a given size using a given interpolation method.
Args:
size: the size of the output.
method: the method of interpolation. Defaults to "nearest".
size: the size of the output. if size is None, the output size is
calculated as input size * scale
method: the method of interpolation. Defaults to "nearest". choices are:
- "nearest": Nearest neighbor interpolation. The values of antialias
and precision are ignored.
- "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
If antialias is True, uses a triangular filter when downsampling.
- "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
cubic kernel.
- "lanczos3": Lanczos resampling, using a kernel of radius 3.
- "lanczos5": Lanczos resampling, using a kernel of radius 5.
antialias: whether to use antialiasing. Defaults to True.
Note:
- if size is None, the output size is calculated as input size * scale
- interpolation methods
"nearest" :
Nearest neighbor interpolation. The values of antialias and precision are ignored.
"linear", "bilinear", "trilinear", "triangle" :
Linear interpolation. If antialias is True, uses a triangular filter when downsampling.
"cubic", "bicubic", "tricubic" :
Cubic interpolation, using the Keys cubic kernel.
"lanczos3" :
Lanczos resampling, using a kernel of radius 3.
"lanczos5"
Lanczos resampling, using a kernel of radius 5.
"""
super().__init__(size=size, method=method, antialias=antialias)

Expand All @@ -157,27 +140,18 @@ def __init__(
"""Resize a 2D input to a given size using a given interpolation method.
Args:
size: the size of the output.
method: the method of interpolation. Defaults to "nearest".
size: the size of the output. if size is None, the output size is
calculated as input size * scale
method: the method of interpolation. Defaults to "nearest". choices are:
- "nearest": Nearest neighbor interpolation. The values of antialias
and precision are ignored.
- "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
If antialias is True, uses a triangular filter when downsampling.
- "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
cubic kernel.
- "lanczos3": Lanczos resampling, using a kernel of radius 3.
- "lanczos5": Lanczos resampling, using a kernel of radius 5.
antialias: whether to use antialiasing. Defaults to True.
Note:
- if size is None, the output size is calculated as input size * scale
- interpolation methods
"nearest" :
Nearest neighbor interpolation. The values of antialias and precision are ignored.
"linear", "bilinear", "trilinear", "triangle" :
Linear interpolation. If antialias is True, uses a triangular filter when downsampling.
"cubic", "bicubic", "tricubic" :
Cubic interpolation, using the Keys cubic kernel.
"lanczos3" :
Lanczos resampling, using a kernel of radius 3.
"lanczos5"
Lanczos resampling, using a kernel of radius 5.
"""
super().__init__(size=size, method=method, antialias=antialias)

Expand All @@ -196,27 +170,18 @@ def __init__(
"""Resize a 3D input to a given size using a given interpolation method.
Args:
size: the size of the output.
method: the method of interpolation. Defaults to "nearest".
size: the size of the output. if size is None, the output size is
calculated as input size * scale
method: the method of interpolation. Defaults to "nearest". choices are:
- "nearest": Nearest neighbor interpolation. The values of antialias
and precision are ignored.
- "linear", "bilinear", "trilinear", "triangle": Linear interpolation.
If antialias is True, uses a triangular filter when downsampling.
- "cubic", "bicubic", "tricubic": Cubic interpolation, using the Keys
cubic kernel.
- "lanczos3": Lanczos resampling, using a kernel of radius 3.
- "lanczos5": Lanczos resampling, using a kernel of radius 5.
antialias: whether to use antialiasing. Defaults to True.
Note:
- if size is None, the output size is calculated as input size * scale
- interpolation methods
"nearest" :
Nearest neighbor interpolation. The values of antialias and precision are ignored.
"linear", "bilinear", "trilinear", "triangle" :
Linear interpolation. If antialias is True, uses a triangular filter when downsampling.
"cubic", "bicubic", "tricubic" :
Cubic interpolation, using the Keys cubic kernel.
"lanczos3" :
Lanczos resampling, using a kernel of radius 3.
"lanczos5"
Lanczos resampling, using a kernel of radius 5.
"""
super().__init__(size=size, method=method, antialias=antialias)

Expand Down
54 changes: 0 additions & 54 deletions setup.py

This file was deleted.

0 comments on commit d269138

Please sign in to comment.