Skip to content

Commit

Permalink
improve some multiset typing
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Feb 26, 2025
1 parent 1ede6cc commit 9a9381f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
29 changes: 15 additions & 14 deletions src/icepool/evaluator/multiset_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@

if TYPE_CHECKING:
from icepool.generator.alignment import Alignment
from icepool import MultisetExpression
from icepool.expression.multiset_expression import MultisetExpressionBase

PREFERRED_ORDER_COST_FACTOR = 10
"""The preferred order will be favored this times as much."""

EvaluationCache: TypeAlias = 'MutableMapping[tuple[Alignment, tuple[MultisetExpression, ...], Hashable], Mapping[Any, int]]'
EvaluationCache: TypeAlias = 'MutableMapping[tuple[Alignment, tuple[MultisetExpressionBase, ...], Hashable], Mapping[Any, int]]'
"""Type representing the cache used within an evaluation."""


class MultisetEvaluator(ABC, Generic[T, U_co]):
"""An abstract, immutable, callable class for evaulating one or more input `MultisetExpression`s.
"""An abstract, immutable, callable class for evaulating one or more input `MultisetExpressionBase`s.
There is one abstract method to implement: `next_state()`.
This should incrementally calculate the result given one outcome at a time
Expand Down Expand Up @@ -57,8 +57,7 @@ class MultisetEvaluator(ABC, Generic[T, U_co]):
Otherwise, values in the cache may be incorrect.
"""

def next_state(self, state: Hashable, outcome: T, /, *counts:
int) -> Hashable:
def next_state(self, state: Hashable, outcome: T, /, *counts) -> Hashable:
"""State transition function.
This should produce a state given the previous state, an outcome,
Expand Down Expand Up @@ -117,8 +116,8 @@ def next_state(self, state: Hashable, outcome: T, /, *counts:
"""
raise NotImplementedError()

def next_state_ascending(self, state: Hashable, outcome: T, /, *counts:
int) -> Hashable:
def next_state_ascending(self, state: Hashable, outcome: T, /,
*counts) -> Hashable:
"""As next_state() but handles outcomes in ascending order only.
You can implement both `next_state_ascending()` and
Expand All @@ -127,8 +126,8 @@ def next_state_ascending(self, state: Hashable, outcome: T, /, *counts:
"""
raise NotImplementedError()

def next_state_descending(self, state: Hashable, outcome: T, /, *counts:
int) -> Hashable:
def next_state_descending(self, state: Hashable, outcome: T, /,
*counts) -> Hashable:
"""As next_state() but handles outcomes in descending order only.
You can implement both `next_state_ascending()` and
Expand Down Expand Up @@ -265,18 +264,20 @@ def evaluate(

@overload
def evaluate(
self,
*args: 'MultisetExpression[T]') -> 'MultisetEvaluator[T, U_co]':
self, *args:
'MultisetExpressionBase[T, Q]') -> 'MultisetEvaluator[T, U_co]':
...

@overload
def evaluate(
self, *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
self, *args:
'MultisetExpressionBase[T, Q] | Mapping[T, int] | Sequence[T]'
) -> 'icepool.Die[U_co] | MultisetEvaluator[T, U_co]':
...

def evaluate(
self, *args: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
self, *args:
'MultisetExpressionBase[T, Q] | Mapping[T, int] | Sequence[T]'
) -> 'icepool.Die[U_co] | MultisetEvaluator[T, U_co]':
"""Evaluates input expression(s).
Expand Down Expand Up @@ -378,7 +379,7 @@ def next_state_method(self, order: Order, /) -> Callable[..., Hashable]:
f'Could not find next_state* implementation for order {order}.')

def _select_algorithm(
self, *inputs: 'icepool.MultisetExpression[T]'
self, *inputs: 'icepool.MultisetExpressionBase[T, Any]'
) -> tuple[
'Callable[[Order, EvaluationCache, Callable[..., Hashable], Alignment[T], tuple[icepool.MultisetExpression[T], ...]], Mapping[Any, int]]',
Order]:
Expand Down
4 changes: 2 additions & 2 deletions src/icepool/evaluator/multiset_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inspect
from functools import cached_property, update_wrapper

from typing import Callable, Collection, Hashable, TypeAlias, overload
from typing import Any, Callable, Collection, Hashable, TypeAlias, overload

from icepool.order import Order, OrderReason, merge_order_preferences
from icepool.typing import T, U_co
Expand Down Expand Up @@ -142,7 +142,7 @@ def bad(a, b)
class MultisetFunctionEvaluator(MultisetEvaluator[T, U_co]):
__name__ = '(unnamed)'

def __init__(self, *inputs: 'icepool.MultisetExpression[T]',
def __init__(self, *inputs: 'icepool.MultisetExpressionBase[T, Any]',
evaluator: MultisetEvaluator[T, U_co]) -> None:
self._evaluator = evaluator
bound_inputs: 'list[icepool.MultisetExpressionBase]' = []
Expand Down
38 changes: 30 additions & 8 deletions src/icepool/expression/multiset_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,42 @@
import operator
import random

from icepool.typing import Q, T, U, Expandable, ImplicitConversionError, T
from icepool.typing import Q, T, U, Expandable, ImplicitConversionError
from types import EllipsisType
from typing import (Callable, Collection, Iterator, Literal, Mapping, Sequence,
Type, cast, overload)
from typing import (TYPE_CHECKING, Any, Callable, Collection, Iterator,
Literal, Mapping, Sequence, Type, cast, overload)

if TYPE_CHECKING:
from icepool.expression.multiset_tuple_expression import MultisetTupleExpression


class MultisetArityError(ValueError):
"""Indicates that an arity was not the same as required."""


@overload
def implicit_convert_to_expression(
arg: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
arg: 'MultisetExpression[T]| Mapping[T, int] | Sequence[T]'
) -> 'MultisetExpression[T]':
...


@overload
def implicit_convert_to_expression(
arg: 'MultisetTupleExpression[T]') -> 'MultisetTupleExpression[T]':
...


@overload
def implicit_convert_to_expression(
arg: 'MultisetExpressionBase | Mapping[T, int] | Sequence[T]'
) -> 'MultisetExpression[T] | MultisetTupleExpression[T]':
...


def implicit_convert_to_expression(
arg: 'MultisetExpressionBase[T, Q] | Mapping[T, int] | Sequence[T]'
) -> 'MultisetExpressionBase[T, Q]':
"""Implcitly converts the argument to a `MultisetExpression` with `int` counts.
Args:
Expand Down Expand Up @@ -370,8 +393,8 @@ def symmetric_difference(
[1, 2, 2, 3] ^ [1, 2, 4] -> [2, 3, 4]
```
"""
other = implicit_convert_to_expression(other)
return icepool.operator.MultisetSymmetricDifference(self, other)
return icepool.operator.MultisetSymmetricDifference(
self, implicit_convert_to_expression(other))

def keep_outcomes(
self, target:
Expand Down Expand Up @@ -936,8 +959,7 @@ def largest_count_and_outcome(
def __rfloordiv__(
self, other: 'MultisetExpression[T] | Mapping[T, int] | Sequence[T]'
) -> 'icepool.Die[int] | icepool.MultisetEvaluator[T, int]':
other = implicit_convert_to_expression(other)
return other.count_subset(self)
return implicit_convert_to_expression(other).count_subset(self)

def count_subset(
self,
Expand Down

0 comments on commit 9a9381f

Please sign in to comment.