Skip to content

Commit

Permalink
v0.11.3 (#9)
Browse files Browse the repository at this point in the history
* next
  • Loading branch information
ASEM000 authored Dec 16, 2023
1 parent ac2d76e commit ea1ad24
Show file tree
Hide file tree
Showing 12 changed files with 505 additions and 485 deletions.
24 changes: 24 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
# Changelog

## V0.11.3

- Raise error if `autoinit` is used with `__init__` method defined.
- Avoid applying `copy.copy` `jax.Array` during flatten/unflatten or `AtIndexer` operations.
- Add `at` as an alias for `AtIndexer` for shorter syntax.
- Deprecate `AtIndexer.__call__` in favor of `value_and_tree` to apply function in a functional manner by copying the input argument.

```python
import sepes as sp
class Counter(sp.TreeClass):
def __init__(self, count: int):
self.count = count
def increment(self, value):
self.count += value
return self.count
counter = Counter(0)
# the function follow jax.value_and_grad semantics where the tree is the
# copied mutated input argument, if the function mutates the input arguments
sp.value_and_tree(lambda C: C.increment(1))(counter)
# (1, Counter(count=1))
```

- Updated docstrings. e.g. How to construct flops counter in `tree_summary` using `jax.jit`

## V0.11.2

- No freezing rule for `jax.Tracer` in `sp.freeze`
Expand Down
10 changes: 5 additions & 5 deletions docs/API/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
.. autoclass:: TreeClass
:members:
at

.. autoclass:: Partial
.. autoclass:: partial
.. autoclass:: AtIndexer
:members:
get,
Expand All @@ -17,15 +17,15 @@
scan,
reduce,
pluck,
__call__

.. autoclass:: at
.. autoclass:: BaseKey
:members:
__eq__

.. autofunction:: autoinit
.. autofunction:: leafwise
.. autofunction:: field
.. autofunction:: fields
.. autofunction:: bcmap
.. autofunction:: is_tree_equal
.. autofunction:: is_tree_equal
.. autofunction:: value_and_tree
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ ignore = [
"N813",
"D105",
"C901",
"B102",
]

[tool.ruff.pydocstyle]
Expand Down
16 changes: 13 additions & 3 deletions sepes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sepes._src.backend import backend_context
from sepes._src.code_build import autoinit, field, fields
from sepes._src.tree_base import TreeClass
from sepes._src.tree_index import AtIndexer, BaseKey
from sepes._src.tree_index import AtIndexer, BaseKey, at
from sepes._src.tree_mask import (
freeze,
is_frozen,
Expand All @@ -32,7 +32,14 @@
tree_str,
tree_summary,
)
from sepes._src.tree_util import Partial, bcmap, is_tree_equal, leafwise
from sepes._src.tree_util import (
Partial,
bcmap,
is_tree_equal,
leafwise,
partial,
value_and_tree,
)

__all__ = (
# general utils
Expand All @@ -57,16 +64,19 @@
"tree_mask",
# indexing utils
"AtIndexer",
"at",
"BaseKey",
# tree utils
"bcmap",
"Partial",
"partial",
"leafwise",
"value_and_tree",
# backend utils
"backend_context",
)

__version__ = "0.11.2"
__version__ = "0.11.3"

AtIndexer.__module__ = "sepes"
TreeClass.__module__ = "sepes"
Expand Down
88 changes: 56 additions & 32 deletions sepes/_src/code_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@
from collections.abc import Callable, MutableMapping, MutableSequence, MutableSet
from typing import Any, Literal, Sequence, TypeVar, get_args
from warnings import warn
from weakref import WeakSet
from typing_extensions import dataclass_transform


T = TypeVar("T")
PyTree = Any
EllipsisType = type(Ellipsis)
KindType = Literal["POS_ONLY", "POS_OR_KW", "VAR_POS", "KW_ONLY", "VAR_KW", "CLASS_VAR"]
arg_kinds: tuple[str, ...] = get_args(KindType)
EXCLUDED_FIELD_NAMES: set[str] = {"self", "__post_init__", "__annotations__"}
_autoinit_registry: WeakSet[type] = WeakSet()


class Null:
Expand Down Expand Up @@ -83,6 +86,22 @@ def slots(klass) -> tuple[str, ...]:
return getattr(klass, "__slots__", ())


def pipe(funcs: Sequence[Callable[[Any], Any]], name: str | None, value: Any):
"""Apply a sequence of functions on the field value."""
for func in funcs:
# for a given sequence of unary functions, apply them on the field value
# and return the result. if an error is raised, emit a descriptive error
try:
value = func(value)
except Exception as e:
# emit a *descriptive* error message with the name of the attribute
# associated with the field and the name of the function that raised
# the error.
cname = getattr(func, "__name__", func)
raise type(e)(f"On applying {cname} for field=`{name}`:\n{e}")
return value


class Field:
"""Field descriptor placeholder"""

Expand Down Expand Up @@ -133,20 +152,13 @@ def replace(self, **kwargs) -> Field:
# to allow the user to replace the field attributes.
return type(self)(**{k: kwargs.get(k, getattr(self, k)) for k in slots(Field)})

def pipe(self, funcs: Sequence[Callable[[Any], Any]], value: Any):
"""Apply a sequence of functions on the field value."""
for func in funcs:
# for a given sequence of unary functions, apply them on the field value
# and return the result. if an error is raised, emit a descriptive error
try:
value = func(value)
except Exception as e:
# emit a *descriptive* error message with the name of the attribute
# associated with the field and the name of the function that raised
# the error.
cname = getattr(func, "__name__", func)
raise type(e)(f"On applying {cname} for field=`{self.name}`:\n{e}")
return value
def pipe_on_setattr(self, value: Any) -> Any:
"""Apply a sequence of functions on the field value during setting."""
return pipe(self.on_setattr, self.name, value)

def pipe_on_getattr(self, value: Any) -> Any:
"""Apply a sequence of functions on the field value during getting."""
return pipe(self.on_getattr, self.name, value)

def __set_name__(self, owner, name: str) -> None:
"""Set the field name."""
Expand All @@ -167,11 +179,11 @@ def __get__(self: T, instance, _) -> T | Any:
"""Return the field value."""
if instance is None:
return self
return self.pipe(self.on_getattr, vars(instance)[self.name])
return self.pipe_on_getattr(vars(instance)[self.name])

def __set__(self: T, instance, value) -> None:
"""Set the field value."""
vars(instance)[self.name] = self.pipe(self.on_setattr, value)
vars(instance)[self.name] = self.pipe_on_setattr(value)

def __delete__(self: T, instance) -> None:
"""Delete the field value."""
Expand Down Expand Up @@ -460,7 +472,6 @@ def check_order_of_args(field_map: dict[KindType, Field]) -> dict[KindType, Fiel

def build_init_method(klass: type[T]) -> type[T]:
field_map: dict[KindType, Field] = build_field_map(klass)

field_map = check_excluded_types(field_map)
field_map = check_duplicate_var_kind(field_map)
field_map = check_order_of_args(field_map)
Expand All @@ -469,10 +480,12 @@ def build_init_method(klass: type[T]) -> type[T]:

body: list[str] = []
head: list[str] = ["self"]
heads: dict[str, list[str]] = defaultdict(list)
heads: dict[KindType, list[str]] = defaultdict(list)

for field in field_map.values():
if field.kind == "CLASS_VAR":
# skip class variables from init synthesis
# e.g. class A: x = field(default=1, kind="CLASS_VAR")
continue

if field.init:
Expand All @@ -489,6 +502,7 @@ def build_init_method(klass: type[T]) -> type[T]:
# e.g def __init__(.., x=value) but
# pass reference to the default value
heads[field.kind] += [f"{alias}=refmap['{field.name}'].default"]

else:
if field.default is not NULL:
# case for fields with `init=False` and no default value
Expand Down Expand Up @@ -516,14 +530,15 @@ def build_init_method(klass: type[T]) -> type[T]:

code += f"\n\t\t{';'.join(body)}"
code += f"\n\t__init__.__qualname__ = '{klass.__qualname__}.__init__'"
code += f"\n\t__init__.__annotations__ = refmap['__annotations__']"
code += "\n\t__init__.__annotations__ = refmap['__annotations__']"
code += "\n\treturn __init__"

# execute the code in the class namespace to generate the method
exec(code, vars(sys.modules[klass.__module__]), namespace := dict())
method = namespace["closure"](field_map)
# add the method to the class
setattr(klass, "__init__", method)
# mark the class as transformed
return klass


Expand All @@ -549,33 +564,41 @@ def autoinit(klass: type[T]) -> type[T]:
Example:
>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Tree:
... x: int
... y: int
>>> inspect.signature(Tree.__init__)
<Signature (self, x: int, y: int) -> None>
>>> tree = Tree(1, 2)
>>> tree.x, tree.y
(1, 2)
Example:
>>> # define fields with different argument kinds
Define fields with different argument kinds
>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Tree:
... kw_only_field: int = sp.field(default=1, kind="KW_ONLY")
... pos_only_field: int = sp.field(default=2, kind="POS_ONLY")
>>> inspect.signature(Tree.__init__)
<Signature (self, pos_only_field: int = 2, /, *, kw_only_field: int = 1) -> None>
Example:
>>> # define a converter to apply ``abs`` on the field value
Define a converter to apply ``abs`` on the field value
>>> @sp.autoinit
... class Tree:
... a:int = sp.field(on_setattr=[abs])
>>> Tree(a=-1).a
1
.. warning::
- The ``autoinit`` decorator will is no-op if the class already has a
user-defined ``__init__`` method.
The ``autoinit`` decorator will raise ``TypeError`` if the user defines
``__init__`` method in the decorated class.
Note:
- In case of inheritance, the ``__init__`` method is generated from the
Expand Down Expand Up @@ -630,22 +653,23 @@ def autoinit(klass: type[T]) -> type[T]:
Traceback (most recent call last):
...
"""
if klass in _autoinit_registry:
# autoinit(autoinit(klass)) == autoinit(klass)
# idempotent decorator to avoid redefining the class
return klass

if "__init__" in vars(klass):
# if the class already has a user-defined __init__ method
# then return the class as is without any modification
warn(f"autoinit({klass.__name__}) with `__init__` is a no-op")
return klass

for base in klass.__mro__[1:-1]:
# skip the current and object class
if "__init__" in vars(base):
warn(f"autoinit({klass.__name__}) skips base class {base.__name__} hints")
# then raise an error to avoid confusing the user
raise TypeError(f"autoinit({klass.__name__}) with defined `__init__`.")

# first convert the current class hints to fields
# then build the __init__ method from the fields of the current class
# and any base classes that are decorated with `autoinit`
return build_init_method(convert_hints_to_fields(klass))
klass = build_init_method(convert_hints_to_fields(klass))
# add the class to the registry to avoid redefining the class
_autoinit_registry.add(klass)
return klass


excluded_type_dispatcher = ft.singledispatch(lambda _: None)
Expand Down
Loading

0 comments on commit ea1ad24

Please sign in to comment.