Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 31, 2024
1 parent 533d390 commit afe7753
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 108 deletions.
10 changes: 10 additions & 0 deletions docs/API/constructor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
🏗️ Constructor utils API
=============================


.. currentmodule:: serket

.. autofunction:: field
.. autofunction:: fields
.. autofunction:: autoinit
.. autofunction:: leafwise
31 changes: 0 additions & 31 deletions docs/API/core.rst

This file was deleted.

5 changes: 1 addition & 4 deletions docs/API/masking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

.. currentmodule:: serket

.. autofunction:: is_nondiff
.. autofunction:: freeze
.. autofunction:: unfreeze
.. autofunction:: is_frozen
.. autofunction:: is_masked
.. autofunction:: tree_mask
.. autofunction:: tree_unmask
10 changes: 10 additions & 0 deletions docs/API/module.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
📍 Module API
=============================


.. currentmodule:: serket

.. autoclass:: TreeClass
:members:
at

2 changes: 0 additions & 2 deletions docs/API/pretty_print.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
.. currentmodule:: serket

.. autofunction:: tree_diagram
.. autofunction:: tree_graph
.. autofunction:: tree_mermaid
.. autofunction:: tree_repr
.. autofunction:: tree_str
.. autofunction:: tree_summary
8 changes: 5 additions & 3 deletions docs/API/sepes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
=============================

.. note::
`sepes <https://sepes.readthedocs.io/en/latest/?badge=latest>`_ API is fully re-exported under the ``serket`` namespace.
`Check the docs <https://sepes.readthedocs.io/en/latest/?badge=latest>`_ for full details.
`sepes <https://sepes.readthedocs.io/>`_ API is fully re-exported under the ``serket`` namespace.
`Check the docs <https://sepes.readthedocs.io/>`_ for full details.

.. toctree::
:maxdepth: 2
:caption: API Documentation

core
module
masking
tree
constructor
pretty_print
17 changes: 17 additions & 0 deletions docs/API/tree.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
🌲 Tree utils API
=============================


.. currentmodule:: serket

.. autoclass:: at
:members:
get,
set,
apply,
scan,
reduce,
pluck,

.. autofunction:: value_and_tree
.. autofunction:: bcmap
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[build-systems]
[build-system]
requires = ["setuptools >= 61"]
build-backend = "setuptools.build_meta"

Expand All @@ -15,7 +15,7 @@ keywords = [
"functional-programming",
"machine-learning",
]
dependencies = ["sepes>=0.11.3"]
dependencies = ["sepes>=0.12.0"]

classifiers = [
"Development Status :: 5 - Production/Stable",
Expand Down
39 changes: 10 additions & 29 deletions serket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,20 @@


from sepes import (
AtIndexer,
BaseKey,
Partial,
TreeClass,
at,
autoinit,
bcmap,
field,
fields,
freeze,
is_frozen,
is_nondiff,
is_tree_equal,
is_masked,
leafwise,
partial,
tree_diagram,
tree_graph,
tree_mask,
tree_mermaid,
tree_repr,
tree_str,
tree_summary,
tree_unmask,
unfreeze,
value_and_tree,
)

Expand All @@ -49,35 +39,26 @@
from . import cluster, image, nn

__all__ = [
# general utils
# sepes
# module utils
"TreeClass",
"is_tree_equal",
"field",
"fields",
"autoinit",
# pprint utils
"tree_diagram",
"tree_graph",
"tree_mermaid",
"tree_repr",
"tree_str",
"tree_summary",
# masking utils
"is_nondiff",
"is_frozen",
"freeze",
"unfreeze",
"is_masked",
"tree_unmask",
"tree_mask",
"value_and_tree",
# indexing utils
"AtIndexer",
"at",
"BaseKey",
# tree utils
"at",
"bcmap",
"Partial",
"partial",
"value_and_tree",
# construction utils
"field",
"fields",
"autoinit",
"leafwise",
# serket
"cluster",
Expand Down
19 changes: 13 additions & 6 deletions serket/_src/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@

import serket as sk
from serket._src.custom_transform import tree_eval
from serket._src.utils import single_dispatch


@ft.singledispatch
@single_dispatch(argnum=0)
def sequential(key: jax.Array, _1, _2):
raise TypeError(f"Invalid {type(key)=}")


@sequential.register(type(None))
@sequential.def_type(type(None))
def _(key: None, layers: Sequence[Callable[..., Any]], array: Any):
del key # no key is supplied then no random number generation is needed
return ft.reduce(lambda x, layer: layer(x), layers, array)


@sequential.register(jax.Array)
@sequential.def_type(jax.Array)
def _(key: jax.Array, layers: Sequence[Callable[..., Any]], array: Any):
"""Applies a sequence of layers to an array.
Expand Down Expand Up @@ -79,16 +80,16 @@ def __init__(self, *layers):
def __call__(self, input: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
return sequential(key, self.layers, input)

@ft.singledispatchmethod
@single_dispatch(argnum=1)
def __getitem__(self, key):
raise TypeError(f"Invalid index type: {type(key)}")

@__getitem__.register(slice)
@__getitem__.def_type(slice)
def _(self, key: slice):
# return a new Sequential object with the sliced layers
return type(self)(*self.layers[key])

@__getitem__.register(int)
@__getitem__.def_type(int)
def _(self, key: int):
return self.layers[key]

Expand All @@ -102,6 +103,12 @@ def __reversed__(self):
return reversed(self.layers)


@sk.tree_summary.def_type(Sequential)
def _(node):
types = [type(x).__name__ for x in node]
return f"{type(node).__name__}[{','.join(types)}]"

Check warning on line 110 in serket/_src/containers.py

View check run for this annotation

Codecov / codecov/patch

serket/_src/containers.py#L108-L110

Added lines #L108 - L110 were not covered by tests

def random_choice(key: jax.Array, layers: tuple[Callable[..., Any], ...], array: Any):
"""Randomly selects one of the given layers/functions.
Expand Down
20 changes: 10 additions & 10 deletions serket/_src/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

from __future__ import annotations

import functools as ft
from inspect import getfullargspec
from typing import Any, TypeVar

import jax

import serket as sk
from serket._src.utils import single_dispatch

T = TypeVar("T")

Expand Down Expand Up @@ -94,20 +94,20 @@ def tree_state(tree: T, **kwargs) -> T:
# input. This poses a challenge for the user to pass the correct input
# to the state initialization rule.

types = tuple(set(tree_state.state_dispatcher.registry) - {object})
types = tuple(set(tree_state.dispatcher.registry) - {object})

def is_leaf(node: Any) -> bool:
return isinstance(node, types)

def dispatch_func(leaf):
try:
return tree_state.state_dispatcher(leaf, **kwargs)
return tree_state.dispatcher(leaf, **kwargs)

except TypeError as e:
# check if the leaf has a state rule

for mro in type(leaf).__mro__[:-1]:
if mro in (registry := tree_state.state_dispatcher.registry):
if mro in (registry := tree_state.dispatcher.registry):
func = registry[mro]
break
else:
Expand All @@ -134,8 +134,8 @@ def dispatch_func(leaf):
return jax.tree_map(dispatch_func, tree, is_leaf=is_leaf)


tree_state.state_dispatcher = ft.singledispatch(NoState)
tree_state.def_state = tree_state.state_dispatcher.register
tree_state.dispatcher = single_dispatch(argnum=0)(NoState)
tree_state.def_state = tree_state.dispatcher.def_type


def tree_eval(tree):
Expand Down Expand Up @@ -197,13 +197,13 @@ def tree_eval(tree):
[1. 1. 1.]]
"""

types = tuple(set(tree_eval.eval_dispatcher.registry) - {object})
types = tuple(set(tree_eval.dispatcher.registry) - {object})

def is_leaf(node: Any) -> bool:
return isinstance(node, types)

return jax.tree_map(tree_eval.eval_dispatcher, tree, is_leaf=is_leaf)
return jax.tree_map(tree_eval.dispatcher, tree, is_leaf=is_leaf)


tree_eval.eval_dispatcher = ft.singledispatch(lambda x: x)
tree_eval.def_eval = tree_eval.eval_dispatcher.register
tree_eval.dispatcher = single_dispatch(argnum=0)(lambda x: x)
tree_eval.def_eval = tree_eval.dispatcher.def_type
7 changes: 3 additions & 4 deletions serket/_src/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@

from __future__ import annotations

import functools as ft
from typing import Callable, Literal, TypeVar, Union, get_args

import jax
import jax.numpy as jnp
from jax import lax

import serket as sk
from serket._src.utils import IsInstance, Range, ScalarLike
from serket._src.utils import IsInstance, Range, ScalarLike, single_dispatch

T = TypeVar("T")

Expand Down Expand Up @@ -521,12 +520,12 @@ def __call__(self, input: jax.Array) -> jax.Array:
act_map = dict(zip(get_args(ActivationLiteral), acts))


@ft.singledispatch
@single_dispatch(argnum=0)
def resolve_activation(act: T) -> T:
return act


@resolve_activation.register(str)
@resolve_activation.def_type(str)
def _(act: str):
try:
return jax.tree_map(lambda x: x, act_map[act])
Expand Down
Loading

0 comments on commit afe7753

Please sign in to comment.