Skip to content

Commit

Permalink
change Wallenius into a class that deals MultisetExpressions #215
Browse files Browse the repository at this point in the history
  • Loading branch information
HighDiceRoller committed Feb 22, 2025
1 parent 77e5dba commit d7f01fd
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 91 deletions.
23 changes: 14 additions & 9 deletions apps/honkai_star_rail_relic.html
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ <h2>Notes</h2>
setLoadingText('Loading icepool')
await pyodide.runPythonAsync(`
import micropip
await micropip.install('icepool==1.6.0')
await micropip.install('icepool==1.7.2a0')
import js
import pyodide
from functools import cache
from icepool import Die
from icepool import wallenius
stats = {
'spd' : ('SPD', 4),
Expand All @@ -164,7 +164,13 @@ <h2>Notes</h2>
'def' : ('DEF', 10),
}
base = Die({stat: weight for stat, (_, weight) in stats.items()})
def wanted_substat_count(haves, wants, more):
possible_weights = []
for stat, (full_name, weight) in stats.items():
if stat not in haves:
possible_weights.append((stat in wants, weight))
return wallenius(possible_weights, more).map(sum, star=False)
@cache
def roll_substats(current, more):
Expand Down Expand Up @@ -210,8 +216,7 @@ <h2>Notes</h2>
for fixed_count in [0, 1, 2]:
fixed = choose_fixed(wants, fixed_count)
substats = roll_substats(haves + fixed, 4 - fixed_count)
wanted_count = substats.map(lambda s: len(set(s) & wants))
wanted_count = wanted_substat_count(haves + fixed, wants, 4 - fixed_count) + fixed_count
for hits in [1, 2, 3, 4]:
value = (wanted_count >= hits).mean()
results[(hits, fixed_count)] = value
Expand All @@ -231,24 +236,24 @@ <h2>Notes</h2>
if fixed_count == best_fixed_count:
chance_cell.style.background = '#cfc'
else:
chance_cell.style.background = '#ccc'
chance_cell.style.background = '#ddd'
score = results[(hits, fixed_count)] / results[(hits, 0)] / cost
score_cell = js.document.getElementById(f'score{hits}{fixed_count}')
score_cell.innerHTML = f'{score:0.2%}'
if fixed_count == best_fixed_count:
score_cell.style.background = '#cfc'
else:
score_cell.style.background = '#ccc'
score_cell.style.background = '#ddd'
else:
for fixed_count in [0, 1, 2]:
chance_cell = js.document.getElementById(f'chance{hits}{fixed_count}')
chance_cell.innerHTML = f''
chance_cell.style.background = '#ccc'
chance_cell.style.background = '#ddd'
score_cell = js.document.getElementById(f'score{hits}{fixed_count}')
score_cell.innerHTML = f''
score_cell.style.background = '#ccc'
score_cell.style.background = '#ddd'
`)
}
Expand Down
4 changes: 2 additions & 2 deletions src/icepool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@

from icepool.population.format import format_probability_inverse

from icepool.noncentral_hypergeometric import wallenius
from icepool.wallenius import Wallenius

import icepool.generator as generator
import icepool.evaluator as evaluator
Expand All @@ -165,5 +165,5 @@
'standard_pool', 'MultisetGenerator', 'MultisetExpression',
'MultisetEvaluator', 'Order', 'Deck', 'Deal', 'MultiDeal',
'multiset_function', 'function', 'typing', 'evaluator',
'format_probability_inverse', 'wallenius'
'format_probability_inverse', 'Wallenius'
]
75 changes: 0 additions & 75 deletions src/icepool/noncentral_hypergeometric.py

This file was deleted.

87 changes: 87 additions & 0 deletions src/icepool/wallenius.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
__docformat__ = 'google'

import icepool

from collections import Counter, defaultdict
from functools import cache

from typing import Generic, Iterable, Mapping, MutableMapping
from icepool.typing import T


@cache
def _wallenius_weights(weight_die: icepool.Die[int], hand_size: int,
/) -> 'icepool.Die[tuple[int, ...]]':
"""A die whose outcomes are sorted tuples of weights to pull."""
if hand_size == 0:
return icepool.Die([()])

def inner(weight):
return (_wallenius_weights(weight_die.remove(weight, weight),
hand_size - 1) +
(weight, )).map(lambda x: tuple(sorted(x)))

return weight_die.map(inner)


class Wallenius(Generic[T]):
"""EXPERIMENTAL: Wallenius' noncentral hypergeometric distribution.
This is sampling without replacement with weights, where the entire weight
of a card goes away when it is pulled.
"""
_weight_decks: 'MutableMapping[int, icepool.Deck[T]]'
_weight_die: 'icepool.Die[int]'

def __init__(self, data: Iterable[tuple[T, int]]
| Mapping[T, int | tuple[int, int]]):
"""Constructor.
Args:
data: Either an iterable of (outcome, weight), or a mapping from
outcomes to either weights or (weight, quantity).
hand_size: The number of outcomes to pull.
"""
self._weight_decks = {}

if isinstance(data, Mapping):
for outcome, v in data.items():
if isinstance(v, int):
weight = v
quantity = 1
else:
weight, quantity = v
self._weight_decks[weight] = self._weight_decks.get(
weight, icepool.Deck()).append(outcome, quantity)
else:
for outcome, weight in data:
self._weight_decks[weight] = self._weight_decks.get(
weight, icepool.Deck()).append(outcome)

self._weight_die = icepool.Die({
weight: weight * deck.denominator()
for weight, deck in self._weight_decks.items()
})

def deal(self, hand_size: int, /) -> 'icepool.MultisetExpression[T]':
"""Deals the specified number of outcomes from the Wallenius.
The result is a `MultisetExpression` representing the multiset of
outcomes dealt.
"""
if hand_size == 0:
return icepool.Pool([])

def inner(weights):
weight_counts = Counter(weights)
result = None
for weight, count in weight_counts.items():
deal = self._weight_decks[weight].deal(count)
if result is None:
result = deal
else:
result = result + deal
return result

hand_weights = _wallenius_weights(self._weight_die, hand_size)
return hand_weights.map_to_pool(inner, star=False)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import icepool
import pytest

from icepool import wallenius, Die
from icepool import Wallenius, Die


def test_wallenius_singleton_dict():
Expand All @@ -11,7 +11,7 @@ def test_wallenius_singleton_dict():
expected = base.map(lambda x: base.remove(x).map(lambda y: (x, y))).map(
lambda x: tuple(sorted(x)))

assert wallenius(data, 2) == expected
assert Wallenius(data).deal(2).expand() == expected


def test_wallenius_singleton_list():
Expand All @@ -21,7 +21,7 @@ def test_wallenius_singleton_list():
expected = base.map(lambda x: base.remove(x).map(lambda y: (x, y))).map(
lambda x: tuple(sorted(x)))

assert wallenius(data, 2) == expected
assert Wallenius(data).deal(2).expand() == expected


def test_wallenius_weighted_dict():
Expand All @@ -31,7 +31,7 @@ def test_wallenius_weighted_dict():
expected = base.map(lambda x: base.remove(x, x).map(lambda y: (x, y))).map(
lambda x: tuple(sorted(x)))

assert wallenius(data, 2).simplify() == expected.simplify()
assert Wallenius(data).deal(2).expand().simplify() == expected.simplify()


def test_wallenius_weighted_list():
Expand All @@ -41,4 +41,4 @@ def test_wallenius_weighted_list():
expected = base.map(lambda x: base.remove(x, x).map(lambda y: (x, y))).map(
lambda x: tuple(sorted(x)))

assert wallenius(data, 2).simplify() == expected.simplify()
assert Wallenius(data).deal(2).expand().simplify() == expected.simplify()

0 comments on commit d7f01fd

Please sign in to comment.