Skip to content

Commit

Permalink
define arraylib.array_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Mar 27, 2024
1 parent 7f437f9 commit 173aaf7
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 65 deletions.
20 changes: 3 additions & 17 deletions sepes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,9 @@
from sepes._src.code_build import autoinit, field, fields
from sepes._src.tree_base import TreeClass
from sepes._src.tree_index import AtIndexer, at
from sepes._src.tree_mask import (
is_masked,
is_nondiff,
tree_mask,
tree_unmask,
)
from sepes._src.tree_pprint import (
tree_diagram,
tree_repr,
tree_str,
tree_summary,
)
from sepes._src.tree_util import (
bcmap,
leafwise,
value_and_tree,
)
from sepes._src.tree_mask import is_masked, is_nondiff, tree_mask, tree_unmask
from sepes._src.tree_pprint import tree_diagram, tree_repr, tree_str, tree_summary
from sepes._src.tree_util import bcmap, leafwise, value_and_tree

__all__ = (
# general utils
Expand Down
41 changes: 26 additions & 15 deletions sepes/_src/backend/arraylib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,31 @@
"""Backend tools for sepes."""

from __future__ import annotations

import functools as ft
from typing import Callable, NamedTuple


class NoImplError(NamedTuple):
op: Callable

def __call__(self, *_, **__):
raise NotImplementedError(f"No implementation for {self.op}")


tobytes = ft.singledispatch(lambda array: ...)
where = ft.singledispatch(lambda condition, x, y: ...)
nbytes = ft.singledispatch(lambda array: ...)
shape = ft.singledispatch(lambda array: ...)
dtype = ft.singledispatch(lambda array: ...)
min = ft.singledispatch(lambda array: ...)
max = ft.singledispatch(lambda array: ...)
mean = ft.singledispatch(lambda array: ...)
std = ft.singledispatch(lambda array: ...)
all = ft.singledispatch(lambda array: ...)
is_floating = ft.singledispatch(lambda array: ...)
is_integer = ft.singledispatch(lambda array: ...)
is_inexact = ft.singledispatch(lambda array: ...)
is_bool = ft.singledispatch(lambda array: ...)
ndarrays: tuple[type, ...] = ()
tobytes = ft.singledispatch(NoImplError("tobytes"))
where = ft.singledispatch(NoImplError("where"))
nbytes = ft.singledispatch(NoImplError("nbytes"))
shape = ft.singledispatch(NoImplError("shape"))
dtype = ft.singledispatch(NoImplError("dtype"))
min = ft.singledispatch(NoImplError("min"))
max = ft.singledispatch(NoImplError("max"))
mean = ft.singledispatch(NoImplError("mean"))
std = ft.singledispatch(NoImplError("std"))
all = ft.singledispatch(NoImplError("all"))
array_equal = ft.singledispatch(NoImplError("array_equal"))
is_floating = ft.singledispatch(NoImplError("is_floating"))
is_integer = ft.singledispatch(NoImplError("is_integer"))
is_inexact = ft.singledispatch(NoImplError("is_inexact"))
is_bool = ft.singledispatch(NoImplError("is_bool"))
ndarrays: list[type] = []
10 changes: 6 additions & 4 deletions sepes/_src/backend/arraylib/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

from __future__ import annotations


from jax import Array
import jax.numpy as jnp
import numpy as np
from jax import Array

import sepes._src.backend.arraylib as arraylib

arraylib.tobytes.register(Array, lambda x: jnp.array(x).tobytes())
arraylib.tobytes.register(Array, lambda x: np.array(x).tobytes())
arraylib.where.register(Array, jnp.where)
arraylib.nbytes.register(Array, lambda x: x.nbytes)
arraylib.shape.register(Array, jnp.shape)
Expand All @@ -29,8 +30,9 @@
arraylib.mean.register(Array, jnp.mean)
arraylib.std.register(Array, jnp.std)
arraylib.all.register(Array, jnp.all)
arraylib.array_equal.register(Array, np.array_equal) # NOTE: not traceable
arraylib.is_floating.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.floating))
arraylib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer))
arraylib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact))
arraylib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_))
arraylib.ndarrays += (Array,)
arraylib.ndarrays.append(Array)
4 changes: 3 additions & 1 deletion sepes/_src/backend/arraylib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
from numpy import ndarray

import sepes._src.backend.arraylib as arraylib

arraylib.tobytes.register(ndarray, lambda x: np.array(x).tobytes())
Expand All @@ -28,8 +29,9 @@
arraylib.mean.register(ndarray, np.mean)
arraylib.std.register(ndarray, np.std)
arraylib.all.register(ndarray, np.all)
arraylib.array_equal.register(ndarray, np.array_equal)
arraylib.is_floating.register(ndarray, lambda x: np.issubdtype(x.dtype, np.floating))
arraylib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer))
arraylib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact))
arraylib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_))
arraylib.ndarrays += (ndarray,)
arraylib.ndarrays.append(ndarray)
4 changes: 3 additions & 1 deletion sepes/_src/backend/arraylib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import torch
from torch import Tensor

import sepes._src.backend.arraylib as arraylib

floatings = [torch.float16, torch.float32, torch.float64]
Expand All @@ -33,8 +34,9 @@
arraylib.mean.register(Tensor, torch.mean)
arraylib.std.register(Tensor, torch.std)
arraylib.all.register(Tensor, torch.all)
arraylib.array_equal.register(Tensor, torch.equal)
arraylib.is_floating.register(Tensor, lambda x: x.dtype in floatings)
arraylib.is_integer.register(Tensor, lambda x: x.dtype in integers)
arraylib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes)
arraylib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool)
arraylib.ndarrays += (Tensor,)
arraylib.ndarrays.append(Tensor)
5 changes: 2 additions & 3 deletions sepes/_src/tree_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
import abc
from typing import Any, Hashable, TypeVar

from typing_extensions import Unpack
from typing_extensions import Self, Unpack

import sepes
from sepes._src.code_build import fields
from sepes._src.tree_index import AtIndexer
from sepes._src.tree_pprint import PPSpec, tree_repr, tree_str
from sepes._src.tree_util import is_tree_equal, tree_copy, tree_hash, value_and_tree
from typing_extensions import Self
from sepes._src.tree_index import AtIndexer

T = TypeVar("T", bound=Hashable)
S = TypeVar("S")
Expand Down
18 changes: 12 additions & 6 deletions sepes/_src/tree_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"""

from __future__ import annotations
from sepes._src.backend import is_package_avaiable

import abc
import functools as ft
import re
Expand All @@ -45,6 +45,7 @@

import sepes
import sepes._src.backend.arraylib as arraylib
from sepes._src.backend import is_package_avaiable
from sepes._src.backend.treelib import ParallelConfig
from sepes._src.tree_pprint import tree_repr

Expand Down Expand Up @@ -177,6 +178,7 @@ def resolve_where(
is_leaf: Callable[[Any], bool] | None = None,
):
treelib = sepes._src.backend.treelib
ndarrays = tuple(arraylib.ndarrays)

def combine_bool_leaves(*leaves):
# given a list of boolean leaves, combine them using `and`
Expand All @@ -188,7 +190,7 @@ def combine_bool_leaves(*leaves):
return verdict

def is_bool_leaf(leaf: Any) -> bool:
if isinstance(leaf, arraylib.ndarrays):
if isinstance(leaf, ndarrays):
return arraylib.is_bool(leaf)
return isinstance(leaf, bool)

Expand Down Expand Up @@ -342,14 +344,15 @@ def get(
{'a': None, 'b': [1, None, None]}
"""
treelib = sepes._src.backend.treelib
ndarrays = tuple(arraylib.ndarrays)

def leaf_get(where: Any, leaf: Any):
# support both array and non-array leaves
# for array boolean mask we select **parts** of the array that
# matches the mask, for example if the mask is Array([True, False, False])
# and the leaf is Array([1, 2, 3]) then the result is Array([1])
# because of the variable resultant size of the output
if isinstance(where, arraylib.ndarrays) and len(arraylib.shape(where)):
if isinstance(where, ndarrays) and len(arraylib.shape(where)):
if fill_value is not _no_fill_value:
return arraylib.where(where, leaf, fill_value)
return leaf[where]
Expand Down Expand Up @@ -398,14 +401,15 @@ def set(
{'a': 1, 'b': [100, 2, 3]}
"""
treelib = sepes._src.backend.treelib
ndarrays = tuple(arraylib.ndarrays)

def leaf_set(where: Any, leaf: Any, set_value: Any):
# support both array and non-array leaves
# for array boolean mask we select **parts** of the array that
# matches the mask, for example if the mask is Array([True, False, False])
# and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100])
# with set_value = 100
if isinstance(where, arraylib.ndarrays):
if isinstance(where, ndarrays):
return arraylib.where(where, set_value, leaf)
return set_value if where else leaf

Expand Down Expand Up @@ -475,13 +479,14 @@ def apply(
>>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel) # doctest: +SKIP
"""
treelib = sepes._src.backend.treelib
ndarrays = tuple(arraylib.ndarrays)

def leaf_apply(where: Any, leaf: Any):
# same as `leaf_set` but with `func` applied to the leaf
# one thing to note is that, the where mask select an array
# then the function needs work properly when applied to the selected
# array elements
if isinstance(where, arraylib.ndarrays):
if isinstance(where, ndarrays):
return arraylib.where(where, func(leaf), leaf)
return func(leaf) if where else leaf

Expand Down Expand Up @@ -534,6 +539,7 @@ def scan(
leaf values while carrying a state and returning a single value.
"""
treelib = sepes._src.backend.treelib
ndarrays = tuple(arraylib.ndarrays)
running_state = state

def stateless_func(leaf):
Expand All @@ -542,7 +548,7 @@ def stateless_func(leaf):
return leaf

def leaf_apply(where: Any, leaf: Any):
if isinstance(where, arraylib.ndarrays):
if isinstance(where, ndarrays):
return arraylib.where(where, stateless_func(leaf), leaf)
return stateless_func(leaf) if where else leaf

Expand Down
7 changes: 6 additions & 1 deletion sepes/_src/tree_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def _(node) -> str:
return f"#{tree_summary.type_dispatcher(node.__wrapped__)}"


def hash_array(value: T) -> int:
bytes = arraylib.tobytes(value)
return int(hashlib.sha256(bytes).hexdigest(), 16)


class _MaskedHashable(_MaskBase):
def __hash__(self) -> int:
return tree_hash(self.__wrapped__)
Expand Down Expand Up @@ -120,7 +125,7 @@ def __eq__(self, other) -> bool:
return False
if arraylib.dtype(lhs) != arraylib.dtype(rhs):
return False
return arraylib.all(lhs == rhs)
return arraylib.array_equal(lhs, rhs)


def mask(value: T) -> _MaskBase[T]:
Expand Down
10 changes: 6 additions & 4 deletions sepes/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@

from __future__ import annotations

import copy
import functools as ft
import operator as op
import copy
from math import ceil, floor, trunc
from typing import Any, Callable, Hashable, Iterator, Sequence, Tuple, TypeVar, Generic
import sepes._src.backend.arraylib as arraylib
from sepes._src.backend import is_package_avaiable
from typing import Any, Callable, Generic, Hashable, Iterator, Sequence, Tuple, TypeVar

from typing_extensions import ParamSpec

import sepes
import sepes._src.backend.arraylib as arraylib
from sepes._src.backend import is_package_avaiable

T = TypeVar("T")
T1 = TypeVar("T1")
Expand Down
10 changes: 3 additions & 7 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@
# limitations under the License.

import copy
import functools as ft
import os
from typing import Any

import pytest

from sepes._src.backend import backend, treelib
from sepes._src.code_build import autoinit
from sepes._src.tree_base import TreeClass
from sepes._src.tree_mask import (
is_masked,
tree_mask,
tree_unmask,
)
import functools as ft
import os
from sepes._src.tree_mask import is_masked, tree_mask, tree_unmask
from sepes._src.tree_util import is_tree_equal, leafwise, tree_hash

test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
Expand Down
5 changes: 2 additions & 3 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
from __future__ import annotations

import math
import os
from typing import Any

import pytest

from sepes._src.backend import backend
from sepes._src.code_build import autoinit, field
from sepes._src.tree_base import TreeClass
from sepes._src.tree_util import bcmap, is_tree_equal, leafwise
import os

from sepes._src.tree_mask import tree_mask
from sepes._src.tree_util import bcmap, is_tree_equal, leafwise

freeze = lambda x: tree_mask(x, cond=lambda _: True)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from __future__ import annotations

import dataclasses as dc
import os
from collections import namedtuple
from typing import Any

import pytest
import os

test_arraylib = os.environ.get("SEPES_TEST_ARRAYLIB", "numpy")
backend = os.environ.get("SEPES_BACKEND", "jax")
Expand Down
5 changes: 3 additions & 2 deletions tests/test_treeclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import copy
import dataclasses as dc
import inspect
import os
from typing import Any

import numpy.testing as npt
import pytest
import os

from sepes._src.backend import backend, treelib
from sepes._src.code_build import (
autoinit,
Expand All @@ -29,8 +30,8 @@
fields,
)
from sepes._src.tree_base import TreeClass
from sepes._src.tree_util import partial, is_tree_equal, value_and_tree
from sepes._src.tree_mask import tree_mask
from sepes._src.tree_util import is_tree_equal, partial, value_and_tree

freeze = lambda x: tree_mask(x, cond=lambda _: True)

Expand Down

0 comments on commit 173aaf7

Please sign in to comment.