From fbcb36d87e10f903992a4af07fcc2ff6f8eef4f0 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 9 Apr 2024 00:20:33 +0200 Subject: [PATCH] Remove invalid shapes from Shape internal table. --- redeal/redeal.py | 53 ++++++++++++++++++++++---------------------- redeal/smartstack.py | 4 +++- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/redeal/redeal.py b/redeal/redeal.py index 545515f..c103759 100644 --- a/redeal/redeal.py +++ b/redeal/redeal.py @@ -1,7 +1,7 @@ from array import array from bisect import bisect from collections import Counter -from itertools import permutations, product +from itertools import combinations_with_replacement, permutations from operator import itemgetter, attrgetter import functools import random @@ -46,9 +46,15 @@ class Shape: ``accept`` function of a simulation, for example.. """ - JOKER = "x" - TABLE = {JOKER: -1, "t": 10, "j": 11, "q": 12, "k": 13, "(": "(", ")": ")"} - TABLE.update({str(n): n for n in range(10)}) + _str_to_val = { + "x": -1, "t": 10, "j": 11, "q": 12, "k": 13, "(": "(", ")": ")", + **{str(n): n for n in range(10)}} + _all_shapes = [ + (s, sh - s, shd - sh, len(Rank) - shd) + for s, sh, shd + in combinations_with_replacement(range(len(Rank) + 1), len(Suit) - 1) + ] + _shape_to_index = {shape: idx for idx, shape in enumerate(_all_shapes)} _cls_cache = {} def __new__(cls, init=None): @@ -57,13 +63,13 @@ def __new__(cls, init=None): return cls._cls_cache[init] except KeyError: self = object.__new__(cls) - self.table = array("b") - self.table.fromlist([0] * (len(Rank) + 1) ** len(Suit)) + self._table = array("b") + self._table.fromlist([0] * len(cls._all_shapes)) self.min_ls = [len(Rank) for _ in Suit] self.max_ls = [0 for _ in Suit] self._op_cache = {} if init: - self.insert([self.TABLE[char.lower()] for char in init]) + self.insert([self._str_to_val[char.lower()] for char in init]) cls._cls_cache[init] = self return self @@ -71,16 +77,16 @@ def __new__(cls, init=None): def from_table(cls, table, min_max_hint=None): """Initialize from a table.""" self = cls() - self.table = array("b") - self.table.fromlist(list(table)) + self._table = array("b") + self._table.fromlist(list(table)) if min_max_hint is not None: self.min_ls, self.max_ls = min_max_hint else: self.min_ls = [len(Rank) for _ in Suit] self.max_ls = [0 for _ in Suit] - for nonflat in product(*[range(len(Rank) + 1) for _ in Suit]): - if self.table[self._flatten(nonflat)]: - for dim, coord in enumerate(nonflat): + for idx, shape in enumerate(cls._all_shapes): + if self._table[idx]: + for dim, coord in enumerate(shape): self.min_ls[dim] = min(self.min_ls[dim], coord) self.max_ls[dim] = max(self.max_ls[dim], coord) return self @@ -89,28 +95,21 @@ def from_table(cls, table, min_max_hint=None): def from_cond(cls, func): """Initialize from a shape-accepting function.""" self = cls() - for nonflat in product(*[range(len(Rank) + 1) for _ in Suit]): - if sum(nonflat) == len(Rank) and func(*nonflat): - self.table[self._flatten(nonflat)] = True - for dim, coord in enumerate(nonflat): + for idx, shape in enumerate(cls._all_shapes): + if func(*shape): + self._table[idx] = True + for dim, coord in enumerate(shape): self.min_ls[dim] = min(self.min_ls[dim], coord) self.max_ls[dim] = max(self.max_ls[dim], coord) return self - @staticmethod - def _flatten(index): - """Transform a 4D index into a 1D index.""" - s, h, d, c = index - mul = len(Rank) + 1 - return ((((s * mul + h) * mul) + d) * mul) + c - def _insert1(self, shape, safe=True): """Insert an element, possibly with "x" but no "()" terms.""" jokers = any(l == -1 for l in shape) pre_set = sum(l for l in shape if l >= 0) if not jokers: if pre_set == len(Rank): - self.table[self._flatten(shape)] = 1 + self._table[self._shape_to_index[shape]] = 1 for suit in Suit: self.min_ls[suit] = min(self.min_ls[suit], shape[suit]) self.max_ls[suit] = max(self.max_ls[suit], shape[suit]) @@ -143,7 +142,7 @@ def insert(self, it, acc=()): def __contains__(self, int_shape): """Check if the given shape is included.""" - return self.table[self._flatten(int_shape)] + return self._table[self._shape_to_index[int_shape]] def __call__(self, hand): """Check if the shape of the given hand is included.""" @@ -155,7 +154,7 @@ def __add__(self, other): return self._op_cache["+", other] except KeyError: table = array("b") - table.fromlist([x or y for x, y in zip(self.table, other.table)]) + table.fromlist([x or y for x, y in zip(self._table, other._table)]) min_ls = [min(self.min_ls[suit], other.min_ls[suit]) for suit in Suit] max_ls = [max(self.max_ls[suit], other.max_ls[suit]) @@ -171,7 +170,7 @@ def __sub__(self, other): except KeyError: table = array("b") table.fromlist( - [x and not y for x, y in zip(self.table, other.table)]) + [x and not y for x, y in zip(self._table, other._table)]) result = Shape.from_table(table, (self.min_ls, self.max_ls)) self._op_cache["-", other] = result return result diff --git a/redeal/smartstack.py b/redeal/smartstack.py index 3545632..c3e8226 100644 --- a/redeal/smartstack.py +++ b/redeal/smartstack.py @@ -36,7 +36,9 @@ def _prepare(self): for lvs_hs in product(*[holdings[suit].items() for suit in Suit]): lvs, hs = zip(*lvs_hs) ls, vs = zip(*lvs) - if ls in self._shape and sum(vs) in self._values: + if (sum(ls) == len(Rank) + and ls in self._shape + and sum(vs) in self._values): counter[ls, vs] += reduce(operator.mul, map(len, hs)) patterns, cumsum = zip(*counter.items()) cumsum = list(cumsum)