Skip to content

Commit

Permalink
Add Collection (#777)
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw authored Oct 9, 2024
1 parent 6e46b9c commit 42afcf0
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 2 deletions.
211 changes: 210 additions & 1 deletion src/tests/test_iterables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum, auto
from itertools import repeat
from typing import TYPE_CHECKING, Any, ClassVar, Literal
Expand Down Expand Up @@ -32,6 +32,7 @@
CheckSubSetError,
CheckSuperMappingError,
CheckSuperSetError,
Collection,
EnsureIterableError,
EnsureIterableNotStrError,
OneEmptyError,
Expand Down Expand Up @@ -426,6 +427,214 @@ def test_odd(self) -> None:
assert result == expected


@dataclass(order=True, unsafe_hash=True, slots=True)
class _Item:
n: int


class TestCollection:
def test_and_singleton(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection & _Item(1)
assert isinstance(result, Collection)
expected = Collection(_Item(1))
assert result == expected

def test_and_collection(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection & Collection(_Item(1))
assert isinstance(result, Collection)
expected = Collection(_Item(1))
assert result == expected

def test_and_iterable(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection & [_Item(1)]
assert isinstance(result, Collection)
expected = Collection(_Item(1))
assert result == expected

def test_filter(self) -> None:
collection = Collection(map(_Item, range(4)))
result = collection.filter(lambda item: item.n % 2 == 0)
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_get_single_int_ok(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection.get(1)
expected = _Item(1)
assert result == expected

def test_get_single_int_fail(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection.get(3)
assert result is None

def test_get_single_item_ok(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection.get(_Item(1))
expected = _Item(1)
assert result == expected

def test_get_single_item_fail(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection.get(_Item(3))
assert result is None

def test_get_item_single_int_ok(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection[1]
expected = _Item(1)
assert result == expected

def test_get_item_single_int_fail(self) -> None:
collection = Collection(map(_Item, range(3)))
with raises(IndexError):
_ = collection[3]

def test_get_item_single_item_ok(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection[_Item(1)]
expected = _Item(1)
assert result == expected

def test_get_item_single_item_fail(self) -> None:
collection = Collection(map(_Item, range(3)))
with raises(KeyError):
_ = collection[_Item(3)]

def test_get_item_slice(self) -> None:
collection = Collection(map(_Item, range(4)))
result = collection[1:3]
assert isinstance(result, Collection)
expected = Collection(_Item(1), _Item(2))
assert result == expected

def test_get_item_multiple_ints(self) -> None:
collection = Collection(map(_Item, range(4)))
result = collection[1, 2]
assert isinstance(result, Collection)
expected = Collection(_Item(1), _Item(2))
assert result == expected

def test_get_item_multiple_items(self) -> None:
collection = Collection(map(_Item, range(4)))
result = collection[_Item(1), _Item(2)]
assert isinstance(result, Collection)
expected = Collection(_Item(1), _Item(2))
assert result == expected

def test_get_item_sequence_ints(self) -> None:
collection = Collection(map(_Item, range(4)))
result = collection[[1, 2]]
assert isinstance(result, Collection)
expected = Collection(_Item(1), _Item(2))
assert result == expected

def test_hash(self) -> None:
collection = Collection(map(_Item, range(3)))
_ = hash(collection)

def test_init_one_singleton(self) -> None:
collection = Collection(_Item(1))
assert isinstance(collection, Collection)
assert len(collection) == 1
assert one(collection) == _Item(1)

def test_init_one_iterable(self) -> None:
collection = Collection(map(_Item, range(3)))
assert isinstance(collection, Collection)
assert len(collection) == 3

def test_init_many_singletons(self) -> None:
collection = Collection(_Item(1), _Item(2), _Item(3))
assert isinstance(collection, Collection)
assert len(collection) == 3

def test_init_many_iterables(self) -> None:
collection = Collection(map(_Item, range(3)), map(_Item, range(3)))
assert isinstance(collection, Collection)
assert len(collection) == 3

def test_iter(self) -> None:
collection = Collection(map(_Item, range(3)))
result = list(collection)
expected = list(map(_Item, range(3)))
assert result == expected

def test_map_return_same_type(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection.map(lambda item: replace(item, n=item.n + 1))
assert isinstance(result, Collection)
expected = Collection(map(_Item, range(1, 4)))
assert result == expected

def test_map_return_different_type(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection.map(lambda item: item.n)
assert isinstance(result, Collection)
expected = Collection(range(3))
assert result == expected

def test_or_singleton(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection | _Item(3)
assert isinstance(result, Collection)
expected = Collection(map(_Item, range(4)))
assert result == expected

def test_or_collection(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection | Collection(map(_Item, range(1, 4)))
assert isinstance(result, Collection)
expected = Collection(map(_Item, range(4)))
assert result == expected

def test_or_iterable(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection | map(_Item, range(1, 4))
assert isinstance(result, Collection)
expected = Collection(map(_Item, range(4)))
assert result == expected

def test_sub_single_int(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - 1
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_single_item(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - _Item(1)
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_collection(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - Collection(_Item(1))
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_iterable_items(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - [_Item(1)]
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_iterable_ints(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - [1]
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected


class TestEnsureHashables:
def test_main(self) -> None:
assert ensure_hashables(1, 2, a=3, b=4) == ([1, 2], {"a": 3, "b": 4})
Expand Down
2 changes: 1 addition & 1 deletion src/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "0.58.12"
__version__ = "0.58.13"
106 changes: 106 additions & 0 deletions src/utilities/iterables.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
Generic,
Literal,
Never,
Protocol,
Self,
TypeGuard,
TypeVar,
assert_never,
Expand Down Expand Up @@ -482,6 +484,110 @@ def chunked(iterable: Iterable[_T], n: int, /) -> Iterator[Sequence[_T]]:
return iter(partial(take, n, iter(iterable)), [])


class _SupportsHashAndLT(Protocol):
@override
def __hash__(self) -> int: ... # pragma: no cover
def __lt__(self, other: Any, /) -> bool: ... # pragma: no cover


_TSupportsHashAndSort = TypeVar("_TSupportsHashAndSort", bound=_SupportsHashAndLT)
_USupportsHashAndSort = TypeVar("_USupportsHashAndSort", bound=_SupportsHashAndLT)


class Collection(frozenset[_TSupportsHashAndSort]):
"""A collection of hashable, sortable items."""

@override
def __new__(cls, *item_or_items: MaybeIterable[_TSupportsHashAndSort]) -> Self:
items = list(chain(*map(always_iterable, item_or_items)))
return super().__new__(cls, items)

@override
def __and__(self, other: MaybeIterable[_TSupportsHashAndSort], /) -> Self:
if isinstance(other, type(self)):
return type(self)(super().__and__(other))
return self.__and__(type(self)(always_iterable(other)))

@overload
def __getitem__(
self, item: int | _TSupportsHashAndSort, /
) -> _TSupportsHashAndSort: ...
@overload
def __getitem__(
self,
item: Sequence[int]
| slice
| tuple[int, ...]
| tuple[_TSupportsHashAndSort, ...],
/,
) -> Self: ...
def __getitem__(
self,
item: int
| slice
| Sequence[int]
| tuple[int, ...]
| tuple[_TSupportsHashAndSort, ...]
| _TSupportsHashAndSort,
/,
) -> _TSupportsHashAndSort | Self:
if isinstance(item, int):
return sorted(self)[item]
if isinstance(item, slice):
return type(self)(sorted(self)[item])
if isinstance(item, tuple):
if all(isinstance(i, int) for i in item):
item = cast(tuple[int, ...], item)
return type(self)(v for i, v in enumerate(sorted(self)) if i in item)
item = cast(tuple[_TSupportsHashAndSort, ...], item)
return self & item
if isinstance(item, Sequence):
return type(self)(v for i, v in enumerate(sorted(self)) if i in item)
try:
return one(i for i in self if i == item)
except OneEmptyError:
raise KeyError(item) from None

@override
def __iter__(self) -> Iterator[_TSupportsHashAndSort]:
yield from sorted(super().__iter__())

@override
def __or__(self, other: MaybeIterable[_TSupportsHashAndSort], /) -> Self: # pyright: ignore[reportIncompatibleMethodOverride]
if isinstance(other, type(self)):
return type(self)(super().__or__(other))
return self.__or__(type(self)(other))

@override
def __sub__(
self, other: int | Sequence[int] | MaybeIterable[_TSupportsHashAndSort], /
) -> Self:
if isinstance(other, int):
return self - self[other]
if isinstance(other, Sequence) and all(isinstance(i, int) for i in other):
other = cast(Sequence[int], other)
return self - self[other]
if isinstance(other, type(self)):
return type(self)(super().__sub__(other))
other = cast(Iterable[_TSupportsHashAndSort], other)
return self.__sub__(type(self)(other))

def filter(self, func: Callable[[_TSupportsHashAndSort], bool], /) -> Self:
return type(self)(filter(func, self))

def get(self, item: int | _TSupportsHashAndSort, /) -> _TSupportsHashAndSort | None:
try:
return self[item]
except (IndexError, KeyError):
return None

def map(
self, func: Callable[[_TSupportsHashAndSort], _USupportsHashAndSort], /
) -> Collection[_USupportsHashAndSort]:
values = cast(Any, map(func, self))
return cast(Any, type(self)(values))


def ensure_hashables(
*args: Any, **kwargs: Any
) -> tuple[list[Hashable], dict[str, Hashable]]:
Expand Down

0 comments on commit 42afcf0

Please sign in to comment.