Skip to content

Commit

Permalink
Optimize flatten/unflatten (#12)
Browse files Browse the repository at this point in the history
* create custom dataclass decorator

* fix tests

* add typing extension

* ignore overloads from coverage

* remove sort

* simplify flatten

* use dict for flatten

* ordered static fields

* 0.2.0

* no keys

* no keys

* use flatten_func

* conditional register

* revert

* ordered

* add deterministic

* add _pytree__order_leaves

* make immutable

* mapping

* sort fields after init

* change sort/filter order

* fix test
  • Loading branch information
cgarciae authored Apr 13, 2023
1 parent fb81642 commit aa7e016
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 54 deletions.
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "simple-pytree"
version = "0.1.7"
version = "0.2.0"
description = ""
authors = ["Cristian Garcia <[email protected]>"]
license = "MIT"
Expand Down
123 changes: 77 additions & 46 deletions simple_pytree/pytree.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
import dataclasses
import importlib.util
import itertools
import inspect
import typing as tp
from abc import ABCMeta
from copy import copy
from functools import partial
from types import MappingProxyType

import jax

P = tp.TypeVar("P", bound="Pytree")


class PytreeMeta(ABCMeta):
def __call__(self: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
obj: P = self.__new__(self, *args, **kwargs)
def __call__(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
obj: P = cls.__new__(cls, *args, **kwargs)
obj.__dict__["_pytree__initializing"] = True
try:
obj.__init__(*args, **kwargs)
finally:
del obj.__dict__["_pytree__initializing"]

vars_dict = vars(obj)
vars_dict["_pytree__node_fields"] = tuple(
sorted(
field for field in vars_dict if field not in cls._pytree__static_fields
)
)
return obj


class Pytree(metaclass=PytreeMeta):
_pytree__initializing: bool
_pytree__class_is_mutable: bool
_pytree__static_fields: tp.FrozenSet[str]
_pytree__static_fields: tp.Tuple[str, ...]
_pytree__node_fields: tp.Tuple[str, ...]
_pytree__setter_descriptors: tp.FrozenSet[str]

def __init_subclass__(cls, mutable: bool = False):
Expand All @@ -36,6 +45,9 @@ def __init_subclass__(cls, mutable: bool = False):
setter_descriptors = set()
static_fields = _inherited_static_fields(cls)

# add special static fields
static_fields.add("_pytree__node_fields")

for field, value in class_vars.items():
if isinstance(value, dataclasses.Field) and not value.metadata.get(
"pytree_node", True
Expand All @@ -46,28 +58,46 @@ def __init_subclass__(cls, mutable: bool = False):
if hasattr(value, "__set__"):
setter_descriptors.add(field)

static_fields = tuple(sorted(static_fields))

# init class variables
cls._pytree__initializing = False
cls._pytree__class_is_mutable = mutable
cls._pytree__static_fields = frozenset(static_fields)
cls._pytree__static_fields = static_fields
cls._pytree__setter_descriptors = frozenset(setter_descriptors)

# TODO: clean up this in the future once minimal supported version is 0.4.7
if hasattr(jax.tree_util, "register_pytree_with_keys"):
jax.tree_util.register_pytree_with_keys(
cls,
partial(
cls._pytree__flatten,
cls._pytree__static_fields,
with_key_paths=True,
),
cls._pytree__unflatten,
)
if (
"flatten_func"
in inspect.signature(jax.tree_util.register_pytree_with_keys).parameters
):
jax.tree_util.register_pytree_with_keys(
cls,
partial(
cls._pytree__flatten,
with_key_paths=True,
),
cls._pytree__unflatten,
flatten_func=partial(
cls._pytree__flatten,
with_key_paths=False,
),
)
else:
jax.tree_util.register_pytree_with_keys(
cls,
partial(
cls._pytree__flatten,
with_key_paths=True,
),
cls._pytree__unflatten,
)
else:
jax.tree_util.register_pytree_node(
cls,
partial(
cls._pytree__flatten,
cls._pytree__static_fields,
with_key_paths=False,
),
cls._pytree__unflatten,
Expand All @@ -86,45 +116,44 @@ def __init_subclass__(cls, mutable: bool = False):
@classmethod
def _pytree__flatten(
cls,
static_field_names: tp.FrozenSet[str],
pytree: "Pytree",
*,
with_key_paths: bool,
) -> tp.Tuple[
tp.List[tp.Any],
tp.Tuple[tp.List[str], tp.List[tp.Tuple[str, tp.Any]]],
]:
static_fields = []
node_names = []
node_values = []
# sort to ensure deterministic order
for field in sorted(vars(pytree)):
value = getattr(pytree, field)
if field in static_field_names:
static_fields.append((field, value))
else:
if with_key_paths:
value = (jax.tree_util.GetAttrKey(field), value)
node_names.append(field)
node_values.append(value)
) -> tp.Tuple[tp.Tuple[tp.Any, ...], tp.Mapping[str, tp.Any],]:
all_vars = vars(pytree).copy()
static = {k: all_vars.pop(k) for k in pytree._pytree__static_fields}

if with_key_paths:
node_values = tuple(
(jax.tree_util.GetAttrKey(field), all_vars.pop(field))
for field in pytree._pytree__node_fields
)
else:
node_values = tuple(
all_vars.pop(field) for field in pytree._pytree__node_fields
)

return node_values, (node_names, static_fields)
if all_vars:
raise ValueError(
f"Unexpected fields in {cls.__name__}: {', '.join(all_vars.keys())}"
)

return node_values, MappingProxyType(static)

@classmethod
def _pytree__unflatten(
cls: tp.Type[P],
metadata: tp.Tuple[tp.List[str], tp.List[tp.Tuple[str, tp.Any]]],
node_values: tp.List[tp.Any],
static_fields: tp.Mapping[str, tp.Any],
node_values: tp.Tuple[tp.Any, ...],
) -> P:
node_names, static_fields = metadata
node_fields = dict(zip(node_names, node_values))
pytree = object.__new__(cls)
pytree.__dict__.update(node_fields, **dict(static_fields))
pytree.__dict__.update(zip(static_fields["_pytree__node_fields"], node_values))
pytree.__dict__.update(static_fields)
return pytree

@classmethod
def _to_flax_state_dict(
cls, static_field_names: tp.FrozenSet[str], pytree: "Pytree"
cls, static_field_names: tp.Tuple[str, ...], pytree: "Pytree"
) -> tp.Dict[str, tp.Any]:
from flax import serialization

Expand All @@ -138,7 +167,7 @@ def _to_flax_state_dict(
@classmethod
def _from_flax_state_dict(
cls,
static_field_names: tp.FrozenSet[str],
static_field_names: tp.Tuple[str, ...],
pytree: P,
state: tp.Dict[str, tp.Any],
) -> P:
Expand Down Expand Up @@ -192,11 +221,13 @@ def replace(self: P, **kwargs: tp.Any) -> P:
if not tp.TYPE_CHECKING:

def __setattr__(self: P, field: str, value: tp.Any):
if (
not self._pytree__initializing
and not self._pytree__class_is_mutable
and field not in self._pytree__setter_descriptors
):
if self._pytree__initializing or field in self._pytree__setter_descriptors:
pass
elif not hasattr(self, field) and not self._pytree__initializing:
raise AttributeError(
f"Cannot add new fields to {type(self)} after initialization"
)
elif not self._pytree__class_is_mutable:
raise AttributeError(
f"{type(self)} is immutable, trying to update field {field}"
)
Expand Down
32 changes: 29 additions & 3 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@ class Foo(Pytree):

assert n == 1

with pytest.raises(AttributeError, match=r"<.*> is immutable"):
foo.y = 2

def test_replace_unknown_fields_error(self):
class Foo(Pytree):
pass
Expand Down Expand Up @@ -195,6 +192,24 @@ def __new__(cls, a):

pytree = jax.tree_map(lambda x: x * 2, pytree)

def test_deterministic_order(self):
    """Leaves are identical no matter which order ``__init__`` assigns the fields in."""

    class A(Pytree):
        def __init__(self, order: bool):
            # Same two attributes either way; only the assignment order differs.
            if not order:
                self.b = 2
                self.a = 1
            else:
                self.a = 1
                self.b = 2

    in_order = A(order=True)
    out_of_order = A(order=False)

    in_order_leaves = jax.tree_util.tree_leaves(in_order)
    out_of_order_leaves = jax.tree_util.tree_leaves(out_of_order)

    assert in_order_leaves == out_of_order_leaves


class TestMutablePytree:
def test_pytree(self):
Expand Down Expand Up @@ -222,6 +237,17 @@ def __init__(self, y) -> None:
pytree.x = 4
assert pytree.x == 4

def test_no_new_fields_after_init(self):
    """A mutable pytree accepts updates to existing fields but rejects brand-new ones."""

    class Foo(Pytree, mutable=True):
        def __init__(self, x):
            self.x = x

    instance = Foo(x=1)

    # ``x`` was set during __init__, so reassigning it on a mutable class must succeed.
    instance.x = 2

    # ``y`` was never set during __init__, so assigning it afterwards must raise.
    with pytest.raises(AttributeError, match=r"Cannot add new fields to"):
        instance.y = 2

def test_pytree_dataclass(self):
@dataclass
class Foo(Pytree, mutable=True):
Expand Down

0 comments on commit aa7e016

Please sign in to comment.