From 42afcf0c83726ab9dc0496fcaebe8d8b7363fe65 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Wed, 9 Oct 2024 20:28:21 +0900 Subject: [PATCH] Add `Collection` (#777) --- src/tests/test_iterables.py | 211 +++++++++++++++++++++++++++++++++++- src/utilities/__init__.py | 2 +- src/utilities/iterables.py | 106 ++++++++++++++++++ 3 files changed, 317 insertions(+), 2 deletions(-) diff --git a/src/tests/test_iterables.py b/src/tests/test_iterables.py index bb7bf5055..1245cf308 100644 --- a/src/tests/test_iterables.py +++ b/src/tests/test_iterables.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum, auto from itertools import repeat from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -32,6 +32,7 @@ CheckSubSetError, CheckSuperMappingError, CheckSuperSetError, + Collection, EnsureIterableError, EnsureIterableNotStrError, OneEmptyError, @@ -426,6 +427,214 @@ def test_odd(self) -> None: assert result == expected +@dataclass(order=True, unsafe_hash=True, slots=True) +class _Item: + n: int + + +class TestCollection: + def test_and_singleton(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection & _Item(1) + assert isinstance(result, Collection) + expected = Collection(_Item(1)) + assert result == expected + + def test_and_collection(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection & Collection(_Item(1)) + assert isinstance(result, Collection) + expected = Collection(_Item(1)) + assert result == expected + + def test_and_iterable(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection & [_Item(1)] + assert isinstance(result, Collection) + expected = Collection(_Item(1)) + assert result == expected + + def test_filter(self) -> None: + collection = Collection(map(_Item, range(4))) + result = collection.filter(lambda item: item.n % 2 == 0) + assert isinstance(result, Collection) + expected = Collection(_Item(0), _Item(2)) + assert result == expected + + def test_get_single_int_ok(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection.get(1) + expected = _Item(1) + assert result == expected + + def test_get_single_int_fail(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection.get(3) + assert result is None + + def test_get_single_item_ok(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection.get(_Item(1)) + expected = _Item(1) + assert result == expected + + def test_get_single_item_fail(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection.get(_Item(3)) + assert result is None + + def test_get_item_single_int_ok(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection[1] + expected = _Item(1) + assert result == expected + + def test_get_item_single_int_fail(self) -> None: + collection = Collection(map(_Item, range(3))) + with raises(IndexError): + _ = collection[3] + + def test_get_item_single_item_ok(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection[_Item(1)] + expected = _Item(1) + assert result == expected + + def test_get_item_single_item_fail(self) -> None: + collection = Collection(map(_Item, range(3))) + with raises(KeyError): + _ = collection[_Item(3)] + + def test_get_item_slice(self) -> None: + collection = Collection(map(_Item, range(4))) + result = collection[1:3] + assert isinstance(result, Collection) + expected = Collection(_Item(1), _Item(2)) + assert result == expected + + def test_get_item_multiple_ints(self) -> None: + collection = Collection(map(_Item, range(4))) + result = collection[1, 2] + assert isinstance(result, Collection) + expected = Collection(_Item(1), _Item(2)) + assert result == expected + + def test_get_item_multiple_items(self) -> None: + collection = Collection(map(_Item, range(4))) + result = collection[_Item(1), _Item(2)] + assert isinstance(result, Collection) + expected = Collection(_Item(1), _Item(2)) + assert result == expected + + def test_get_item_sequence_ints(self) -> None: + collection = Collection(map(_Item, range(4))) + result = collection[[1, 2]] + assert isinstance(result, Collection) + expected = Collection(_Item(1), _Item(2)) + assert result == expected + + def test_hash(self) -> None: + collection = Collection(map(_Item, range(3))) + _ = hash(collection) + + def test_init_one_singleton(self) -> None: + collection = Collection(_Item(1)) + assert isinstance(collection, Collection) + assert len(collection) == 1 + assert one(collection) == _Item(1) + + def test_init_one_iterable(self) -> None: + collection = Collection(map(_Item, range(3))) + assert isinstance(collection, Collection) + assert len(collection) == 3 + + def test_init_many_singletons(self) -> None: + collection = Collection(_Item(1), _Item(2), _Item(3)) + assert isinstance(collection, Collection) + assert len(collection) == 3 + + def test_init_many_iterables(self) -> None: + collection = Collection(map(_Item, range(3)), map(_Item, range(3))) + assert isinstance(collection, Collection) + assert len(collection) == 3 + + def test_iter(self) -> None: + collection = Collection(map(_Item, range(3))) + result = list(collection) + expected = list(map(_Item, range(3))) + assert result == expected + + def test_map_return_same_type(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection.map(lambda item: replace(item, n=item.n + 1)) + assert isinstance(result, Collection) + expected = Collection(map(_Item, range(1, 4))) + assert result == expected + + def test_map_return_different_type(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection.map(lambda item: item.n) + assert isinstance(result, Collection) + expected = Collection(range(3)) + assert result == expected + + def test_or_singleton(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection | _Item(3) + assert isinstance(result, Collection) + expected = Collection(map(_Item, range(4))) + assert result == expected + + def test_or_collection(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection | Collection(map(_Item, range(1, 4))) + assert isinstance(result, Collection) + expected = Collection(map(_Item, range(4))) + assert result == expected + + def test_or_iterable(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection | map(_Item, range(1, 4)) + assert isinstance(result, Collection) + expected = Collection(map(_Item, range(4))) + assert result == expected + + def test_sub_single_int(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection - 1 + assert isinstance(result, Collection) + expected = Collection(_Item(0), _Item(2)) + assert result == expected + + def test_sub_single_item(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection - _Item(1) + assert isinstance(result, Collection) + expected = Collection(_Item(0), _Item(2)) + assert result == expected + + def test_sub_collection(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection - Collection(_Item(1)) + assert isinstance(result, Collection) + expected = Collection(_Item(0), _Item(2)) + assert result == expected + + def test_sub_iterable_items(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection - [_Item(1)] + assert isinstance(result, Collection) + expected = Collection(_Item(0), _Item(2)) + assert result == expected + + def test_sub_iterable_ints(self) -> None: + collection = Collection(map(_Item, range(3))) + result = collection - [1] + assert isinstance(result, Collection) + expected = Collection(_Item(0), _Item(2)) + assert result == expected + + class TestEnsureHashables: def test_main(self) -> None: assert ensure_hashables(1, 2, a=3, b=4) == ([1, 2], {"a": 3, "b": 4}) diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index 5f1fe437a..4eff862e2 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.58.12" +__version__ = "0.58.13" diff --git a/src/utilities/iterables.py b/src/utilities/iterables.py index 8693ae81c..82b5a1a7e 100644 --- a/src/utilities/iterables.py +++ b/src/utilities/iterables.py @@ -22,6 +22,8 @@ Generic, Literal, Never, + Protocol, + Self, TypeGuard, TypeVar, assert_never, @@ -482,6 +484,110 @@ def chunked(iterable: Iterable[_T], n: int, /) -> Iterator[Sequence[_T]]: return iter(partial(take, n, iter(iterable)), []) +class _SupportsHashAndLT(Protocol): + @override + def __hash__(self) -> int: ... # pragma: no cover + def __lt__(self, other: Any, /) -> bool: ... # pragma: no cover + + +_TSupportsHashAndSort = TypeVar("_TSupportsHashAndSort", bound=_SupportsHashAndLT) +_USupportsHashAndSort = TypeVar("_USupportsHashAndSort", bound=_SupportsHashAndLT) + + +class Collection(frozenset[_TSupportsHashAndSort]): + """A collection of hashable, sortable items.""" + + @override + def __new__(cls, *item_or_items: MaybeIterable[_TSupportsHashAndSort]) -> Self: + items = list(chain(*map(always_iterable, item_or_items))) + return super().__new__(cls, items) + + @override + def __and__(self, other: MaybeIterable[_TSupportsHashAndSort], /) -> Self: + if isinstance(other, type(self)): + return type(self)(super().__and__(other)) + return self.__and__(type(self)(always_iterable(other))) + + @overload + def __getitem__( + self, item: int | _TSupportsHashAndSort, / + ) -> _TSupportsHashAndSort: ... + @overload + def __getitem__( + self, + item: Sequence[int] + | slice + | tuple[int, ...] + | tuple[_TSupportsHashAndSort, ...], + /, + ) -> Self: ... + def __getitem__( + self, + item: int + | slice + | Sequence[int] + | tuple[int, ...] + | tuple[_TSupportsHashAndSort, ...] + | _TSupportsHashAndSort, + /, + ) -> _TSupportsHashAndSort | Self: + if isinstance(item, int): + return sorted(self)[item] + if isinstance(item, slice): + return type(self)(sorted(self)[item]) + if isinstance(item, tuple): + if all(isinstance(i, int) for i in item): + item = cast(tuple[int, ...], item) + return type(self)(v for i, v in enumerate(sorted(self)) if i in item) + item = cast(tuple[_TSupportsHashAndSort, ...], item) + return self & item + if isinstance(item, Sequence): + return type(self)(v for i, v in enumerate(sorted(self)) if i in item) + try: + return one(i for i in self if i == item) + except OneEmptyError: + raise KeyError(item) from None + + @override + def __iter__(self) -> Iterator[_TSupportsHashAndSort]: + yield from sorted(super().__iter__()) + + @override + def __or__(self, other: MaybeIterable[_TSupportsHashAndSort], /) -> Self: # pyright: ignore[reportIncompatibleMethodOverride] + if isinstance(other, type(self)): + return type(self)(super().__or__(other)) + return self.__or__(type(self)(other)) + + @override + def __sub__( + self, other: int | Sequence[int] | MaybeIterable[_TSupportsHashAndSort], / + ) -> Self: + if isinstance(other, int): + return self - self[other] + if isinstance(other, Sequence) and all(isinstance(i, int) for i in other): + other = cast(Sequence[int], other) + return self - self[other] + if isinstance(other, type(self)): + return type(self)(super().__sub__(other)) + other = cast(Iterable[_TSupportsHashAndSort], other) + return self.__sub__(type(self)(other)) + + def filter(self, func: Callable[[_TSupportsHashAndSort], bool], /) -> Self: + return type(self)(filter(func, self)) + + def get(self, item: int | _TSupportsHashAndSort, /) -> _TSupportsHashAndSort | None: + try: + return self[item] + except (IndexError, KeyError): + return None + + def map( + self, func: Callable[[_TSupportsHashAndSort], _USupportsHashAndSort], / + ) -> Collection[_USupportsHashAndSort]: + values = cast(Any, map(func, self)) + return cast(Any, type(self)(values)) + + def ensure_hashables( *args: Any, **kwargs: Any ) -> tuple[list[Hashable], dict[str, Hashable]]: