Skip to content

Commit

Permalink
simplify arraylib, add 3.12 to ci
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Oct 12, 2023
1 parent 38c10c1 commit cb67b3a
Show file tree
Hide file tree
Showing 21 changed files with 221 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_numpy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_torch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Nested tree tools in python
![Tests](https://github.com/ASEM000/sepes/actions/workflows/test_jax.yml/badge.svg)
![Tests](https://github.com/ASEM000/sepes/actions/workflows/test_numpy.yml/badge.svg)
![Tests](https://github.com/ASEM000/sepes/actions/workflows/test_torch.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-blue)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11%203.12_-blue)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Downloads](https://static.pepy.tech/badge/sepes)](https://pepy.tech/project/sepes)
[![codecov](https://codecov.io/gh/ASEM000/sepes/branch/main/graph/badge.svg?token=TZBRMO0UQH)](https://codecov.io/gh/ASEM000/sepes)
Expand Down
3 changes: 0 additions & 3 deletions sepes/_src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from typing import Literal, Callable
import logging
from contextlib import contextmanager
from sepes._src.backend.arraylib.base import ArrayLib

arraylib = ArrayLib()


@ft.lru_cache(maxsize=None)
Expand Down
20 changes: 20 additions & 0 deletions sepes/_src/backend/arraylib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend tools for sepes."""

from __future__ import annotations
import functools as ft

tobytes = ft.singledispatch(lambda array: ...)
where = ft.singledispatch(lambda condition, x, y: ...)
nbytes = ft.singledispatch(lambda array: ...)
shape = ft.singledispatch(lambda array: ...)
dtype = ft.singledispatch(lambda array: ...)
min = ft.singledispatch(lambda array: ...)
max = ft.singledispatch(lambda array: ...)
mean = ft.singledispatch(lambda array: ...)
std = ft.singledispatch(lambda array: ...)
all = ft.singledispatch(lambda array: ...)
is_floating = ft.singledispatch(lambda array: ...)
is_integer = ft.singledispatch(lambda array: ...)
is_inexact = ft.singledispatch(lambda array: ...)
is_bool = ft.singledispatch(lambda array: ...)
ndarrays: tuple[type, ...] = ()
51 changes: 0 additions & 51 deletions sepes/_src/backend/arraylib/base.py

This file was deleted.

32 changes: 16 additions & 16 deletions sepes/_src/backend/arraylib/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@

from jax import Array
import jax.numpy as jnp
from sepes._src.backend.arraylib.base import ArrayLib
import sepes._src.backend.arraylib as arraylib

ArrayLib.tobytes.register(Array, lambda x: jnp.array(x).tobytes())
ArrayLib.where.register(Array, jnp.where)
ArrayLib.nbytes.register(Array, lambda x: x.nbytes)
ArrayLib.shape.register(Array, jnp.shape)
ArrayLib.dtype.register(Array, lambda x: x.dtype)
ArrayLib.min.register(Array, jnp.min)
ArrayLib.max.register(Array, jnp.max)
ArrayLib.mean.register(Array, jnp.mean)
ArrayLib.std.register(Array, jnp.std)
ArrayLib.all.register(Array, jnp.all)
ArrayLib.is_floating.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.floating))
ArrayLib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer))
ArrayLib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact))
ArrayLib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_))
ArrayLib.ndarrays += (Array,)
arraylib.tobytes.register(Array, lambda x: jnp.array(x).tobytes())
arraylib.where.register(Array, jnp.where)
arraylib.nbytes.register(Array, lambda x: x.nbytes)
arraylib.shape.register(Array, jnp.shape)
arraylib.dtype.register(Array, lambda x: x.dtype)
arraylib.min.register(Array, jnp.min)
arraylib.max.register(Array, jnp.max)
arraylib.mean.register(Array, jnp.mean)
arraylib.std.register(Array, jnp.std)
arraylib.all.register(Array, jnp.all)
arraylib.is_floating.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.floating))
arraylib.is_integer.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.integer))
arraylib.is_inexact.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.inexact))
arraylib.is_bool.register(Array, lambda x: jnp.issubdtype(x.dtype, jnp.bool_))
arraylib.ndarrays += (Array,)
32 changes: 16 additions & 16 deletions sepes/_src/backend/arraylib/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@

import numpy as np
from numpy import ndarray
from sepes._src.backend.arraylib.base import ArrayLib
import sepes._src.backend.arraylib as arraylib

ArrayLib.tobytes.register(ndarray, lambda x: np.array(x).tobytes())
ArrayLib.where.register(ndarray, np.where)
ArrayLib.nbytes.register(ndarray, lambda x: x.nbytes)
ArrayLib.shape.register(ndarray, np.shape)
ArrayLib.dtype.register(ndarray, lambda x: x.dtype)
ArrayLib.min.register(ndarray, np.min)
ArrayLib.max.register(ndarray, np.max)
ArrayLib.mean.register(ndarray, np.mean)
ArrayLib.std.register(ndarray, np.std)
ArrayLib.all.register(ndarray, np.all)
ArrayLib.is_floating.register(ndarray, lambda x: np.issubdtype(x.dtype, np.floating))
ArrayLib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer))
ArrayLib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact))
ArrayLib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_))
ArrayLib.ndarrays += (ndarray,)
arraylib.tobytes.register(ndarray, lambda x: np.array(x).tobytes())
arraylib.where.register(ndarray, np.where)
arraylib.nbytes.register(ndarray, lambda x: x.nbytes)
arraylib.shape.register(ndarray, np.shape)
arraylib.dtype.register(ndarray, lambda x: x.dtype)
arraylib.min.register(ndarray, np.min)
arraylib.max.register(ndarray, np.max)
arraylib.mean.register(ndarray, np.mean)
arraylib.std.register(ndarray, np.std)
arraylib.all.register(ndarray, np.all)
arraylib.is_floating.register(ndarray, lambda x: np.issubdtype(x.dtype, np.floating))
arraylib.is_integer.register(ndarray, lambda x: np.issubdtype(x.dtype, np.integer))
arraylib.is_inexact.register(ndarray, lambda x: np.issubdtype(x.dtype, np.inexact))
arraylib.is_bool.register(ndarray, lambda x: np.issubdtype(x.dtype, np.bool_))
arraylib.ndarrays += (ndarray,)
32 changes: 16 additions & 16 deletions sepes/_src/backend/arraylib/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
import numpy as np
import torch
from torch import Tensor
from sepes._src.backend.arraylib.base import ArrayLib
import sepes._src.backend.arraylib as arraylib

floatings = [torch.float16, torch.float32, torch.float64]
complexes = [torch.complex32, torch.complex64, torch.complex128]
integers = [torch.int8, torch.int16, torch.int32, torch.int64]

ArrayLib.tobytes.register(Tensor, lambda x: np.from_dlpack(x).tobytes())
ArrayLib.where.register(Tensor, torch.where)
ArrayLib.nbytes.register(Tensor, lambda x: x.nbytes)
ArrayLib.shape.register(Tensor, lambda x: tuple(x.shape))
ArrayLib.dtype.register(Tensor, lambda x: x.dtype)
ArrayLib.min.register(Tensor, torch.min)
ArrayLib.max.register(Tensor, torch.max)
ArrayLib.mean.register(Tensor, torch.mean)
ArrayLib.std.register(Tensor, torch.std)
ArrayLib.all.register(Tensor, torch.all)
ArrayLib.is_floating.register(Tensor, lambda x: x.dtype in floatings)
ArrayLib.is_integer.register(Tensor, lambda x: x.dtype in integers)
ArrayLib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes)
ArrayLib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool)
ArrayLib.ndarrays += (Tensor,)
arraylib.tobytes.register(Tensor, lambda x: np.from_dlpack(x).tobytes())
arraylib.where.register(Tensor, torch.where)
arraylib.nbytes.register(Tensor, lambda x: x.nbytes)
arraylib.shape.register(Tensor, lambda x: tuple(x.shape))
arraylib.dtype.register(Tensor, lambda x: x.dtype)
arraylib.min.register(Tensor, torch.min)
arraylib.max.register(Tensor, torch.max)
arraylib.mean.register(Tensor, torch.mean)
arraylib.std.register(Tensor, torch.std)
arraylib.all.register(Tensor, torch.all)
arraylib.is_floating.register(Tensor, lambda x: x.dtype in floatings)
arraylib.is_integer.register(Tensor, lambda x: x.dtype in integers)
arraylib.is_inexact.register(Tensor, lambda x: x.dtype in floatings + complexes)
arraylib.is_bool.register(Tensor, lambda x: x.dtype == torch.bool)
arraylib.ndarrays += (Tensor,)
122 changes: 122 additions & 0 deletions sepes/_src/backend/treelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,125 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
import os
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Hashable, Iterable, Literal, Tuple, TypedDict, TypeVar

# optree namespace identifier
namespace: str = os.environ.get("SEPES_NAMESPACE", "sepes")

Tree = TypeVar("Tree", bound=Any)
Leaf = TypeVar("Leaf", bound=Any)
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
KeyPath = Tuple[KeyEntry, ...]
KeyPathLeaf = Tuple[KeyPath, Leaf]
pool_map = dict(thread=ThreadPoolExecutor, process=ProcessPoolExecutor)


class ParallelConfig(TypedDict):
max_workers: int | None
kind: Literal["thread", "process"]


def raise_future_execption(future):
raise future.exception()


def concurrent_map(
func: Callable[..., Any],
flat: Iterable[Any],
max_workers: int | None = None,
kind: Literal["thread", "process"] = "thread",
) -> Iterable[Any]:
with (executor := pool_map[kind](max_workers)) as executor:
futures = [executor.submit(func, *args) for args in zip(*flat)]

return [
future.result()
if future.exception() is None
else raise_future_execption(future)
for future in futures
]


class AbstractTreeLib(abc.ABC):
"""The minimal interface for tree operations used by sepes."""

@staticmethod
@abc.abstractmethod
def tree_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[[Any], None] | None = None,
is_parallel: bool | ParallelConfig = False,
) -> Any:
...

@staticmethod
@abc.abstractmethod
def tree_path_map(
func: Callable[..., Any],
tree: Any,
*rest: Any,
is_leaf: Callable[[Any], bool] | None = None,
is_parallel: bool | ParallelConfig = False,
) -> Any:
...

@staticmethod
@abc.abstractmethod
def tree_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
) -> tuple[Iterable[Leaf], Any]:
...

@staticmethod
@abc.abstractmethod
def tree_path_flatten(
tree: Any,
*,
is_leaf: Callable[[Any], bool] | None = None,
) -> tuple[Iterable[KeyPathLeaf], Any]:
...

@staticmethod
@abc.abstractmethod
def tree_unflatten(treedef: Any, leaves: Iterable[Any]) -> Any:
...

@staticmethod
@abc.abstractmethod
def register_treeclass(klass: type[Tree]) -> None:
...

@staticmethod
@abc.abstractmethod
def register_static(klass: type[Tree]) -> None:
...

@staticmethod
@abc.abstractmethod
def attribute_key(name: str) -> Any:
...

@staticmethod
@abc.abstractmethod
def sequence_key(index: int) -> Any:
...

@staticmethod
@abc.abstractmethod
def dict_key(key: Hashable) -> Any:
...

@staticmethod
@abc.abstractmethod
def keystr(keys: Any) -> str:
...
Loading

0 comments on commit cb67b3a

Please sign in to comment.