Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw committed Oct 9, 2024
1 parent 6a09a08 commit 5043d20
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
57 changes: 53 additions & 4 deletions src/tests/test_iterables.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,18 @@ def test_and_singleton(self) -> None:
expected = Collection(_Item(1))
assert result == expected

def test_and_collection(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection & Collection(_Item(1))
assert isinstance(result, Collection)
expected = Collection(_Item(1))
assert result == expected

def test_and_iterable(self) -> None:
collection = Collection(map(_Item, range(4)))
result = collection & Collection(_Item(1), _Item(2))
collection = Collection(map(_Item, range(3)))
result = collection & [_Item(1)]
assert isinstance(result, Collection)
expected = Collection(_Item(1), _Item(2))
expected = Collection(_Item(1))
assert result == expected

def test_filter(self) -> None:
Expand Down Expand Up @@ -574,13 +581,55 @@ def test_or_singleton(self) -> None:
expected = Collection(map(_Item, range(4)))
assert result == expected

def test_or_iterable(self) -> None:
def test_or_collection(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection | Collection(map(_Item, range(1, 4)))
assert isinstance(result, Collection)
expected = Collection(map(_Item, range(4)))
assert result == expected

def test_or_iterable(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection | map(_Item, range(1, 4))
assert isinstance(result, Collection)
expected = Collection(map(_Item, range(4)))
assert result == expected

def test_sub_single_int(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - 1
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_single_item(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - _Item(1)
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_collection(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - Collection(_Item(1))
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_iterable_items(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - [_Item(1)]
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected

def test_sub_iterable_ints(self) -> None:
collection = Collection(map(_Item, range(3)))
result = collection - [1]
assert isinstance(result, Collection)
expected = Collection(_Item(0), _Item(2))
assert result == expected


class TestEnsureHashables:
def test_main(self) -> None:
Expand Down
14 changes: 11 additions & 3 deletions src/utilities/iterables.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ def chunked(iterable: Iterable[_T], n: int, /) -> Iterator[Sequence[_T]]:

class _SupportsHashAndLT(Protocol):
@override
def __hash__(self) -> int: ...
def __lt__(self, other: Any, /) -> bool: ...
def __hash__(self) -> int: ... # pragma: no cover
def __lt__(self, other: Any, /) -> bool: ... # pragma: no cover


_TSupportsHashAndSort = TypeVar("_TSupportsHashAndSort", bound=_SupportsHashAndLT)
Expand Down Expand Up @@ -559,9 +559,17 @@ def __or__(self, other: MaybeIterable[_TSupportsHashAndSort], /) -> Self: # pyr
return self.__or__(type(self)(other))

@override
def __sub__(self, other: MaybeIterable[_TSupportsHashAndSort], /) -> Self:
def __sub__(
self, other: int | Sequence[int] | MaybeIterable[_TSupportsHashAndSort], /
) -> Self:
if isinstance(other, int):
return self - self[other]
if isinstance(other, Sequence) and all(isinstance(i, int) for i in other):
other = cast(Sequence[int], other)
return self - self[other]
if isinstance(other, type(self)):
return type(self)(super().__sub__(other))
other = cast(Iterable[_TSupportsHashAndSort], other)
return self.__sub__(type(self)(other))

def filter(self, func: Callable[[_TSupportsHashAndSort], bool], /) -> Self:
Expand Down

0 comments on commit 5043d20

Please sign in to comment.