Skip to content

Commit

Permalink
naming refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Apr 9, 2024
1 parent c787143 commit 1b65b01
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions dominoes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ def listDominoes(highestDominoe):
return np.array([np.array(quake) for quake in itertools.combinations_with_replacement(np.arange(highestDominoe + 1), 2)], dtype=int)


def twohotDominoe(dominoeIndex, dominoes, highestDominoe, available=None, available_token=False, null_token=False, with_batch=True):
def twohot_dominoe(selected, dominoes, highest_dominoe, available=None, available_token=False, null_token=False, with_batch=True):
"""
converts an index of dominoes to a stacked two-hot representation
converts an index of selected dominoes to a stacked two-hot representation
dominoes are paired values (combinations with replacement) of integers
from 0 to <highestDominoe>.
from 0 to <highest_dominoe>.
This simple representation is a two-hot vector where the first
<highestDominoe+1> elements represent the first value of the dominoe, and
the second <highestDominoe+1> elements represent the second value of the
dominoe. Here are some examples for highest dominoe = 3:
<highest_dominoe>+1 elements represent the first value of the dominoe, and
the second <highest_dominoe>+1 elements represent the second value of the
dominoe. Here are some examples for <highest_dominoe> = 3:
(0 | 0): [1, 0, 0, 0, 1, 0, 0, 0]
(0 | 1): [1, 0, 0, 0, 0, 1, 0, 0]
Expand All @@ -110,17 +110,17 @@ def twohotDominoe(dominoeIndex, dominoes, highestDominoe, available=None, availa
(2 | 1): [0, 0, 1, 0, 0, 1, 0, 0]
"""
assert dominoeIndex.ndim == 1, "dominoeIndex must have shape (numDominoesSelected, 1)"
assert selected.ndim == 1, "selected must have shape (num_selected, )"
if available_token:
assert available is not None, "if with_available=True, then available needs to be provided"
(num_dominoes,) = dominoeIndex.shape
num_dominoes = selected.shape[0]

# input dimension determined by highest dominoe (twice the number of possible values on a dominoe)
input_dim = (2 if not (available_token) else 3) * (highestDominoe + 1) + (1 if null_token else 0)
input_dim = (2 if not (available_token) else 3) * (highest_dominoe + 1) + (1 if null_token else 0)

# first & second value are index and shifted index
firstValue = torch.tensor(dominoes[dominoeIndex, 0], dtype=torch.int64).unsqueeze(1)
secondValue = torch.tensor(dominoes[dominoeIndex, 1] + highestDominoe + 1, dtype=torch.int64).unsqueeze(1)
firstValue = torch.tensor(dominoes[selected, 0], dtype=torch.int64).unsqueeze(1)
secondValue = torch.tensor(dominoes[selected, 1] + highest_dominoe + 1, dtype=torch.int64).unsqueeze(1)

# scatter data into two-hot vectors
src = torch.ones((num_dominoes, 1), dtype=torch.float)
Expand All @@ -134,7 +134,7 @@ def twohotDominoe(dominoeIndex, dominoes, highestDominoe, available=None, availa

if available_token:
rep_available = torch.zeros((1, input_dim), dtype=torch.float)
availableidx = int((highestDominoe + 1) * 2 + available)
availableidx = int((highest_dominoe + 1) * 2 + available)
rep_available[0, availableidx] = 1.0
twohot = torch.cat((twohot, rep_available), dim=0)

Expand Down Expand Up @@ -212,7 +212,7 @@ def gameSequenceToString(dominoes, sequence, direction, player=None, playNumber=
print(name[idx], sequenceString)


def constructLineRecursive(dominoes, myHand, available, previousSequence=[], previousDirection=[], maxLineLength=None):
def construct_line_recursive(dominoes, myHand, available, previousSequence=[], previousDirection=[], maxLineLength=None):
# this version of the function uses absolute dominoe numbers, rather than indexing based on which order they are in the hand
# if there are too many dominoes in hand, constructing all possible lines takes way too long...
if (maxLineLength is not None) and (len(previousSequence) == maxLineLength):
Expand Down Expand Up @@ -245,7 +245,7 @@ def constructLineRecursive(dominoes, myHand, available, previousSequence=[], pre
cdir = copy(previousDirection)
cdir.append(0)
# then recursively construct line from this standpoint
cSequence, cDirection = constructLineRecursive(
cSequence, cDirection = construct_line_recursive(
dominoes, myHand, hand[idxPlay][1], previousSequence=cseq, previousDirection=cdir, maxLineLength=maxLineLength
)
# once lines are constructed, add them all to "sequence" and "direction", which will be a list of lists of all possible sequences
Expand All @@ -259,7 +259,7 @@ def constructLineRecursive(dominoes, myHand, available, previousSequence=[], pre
cseq.append(myHand[idxPlay])
cdir = copy(previousDirection)
cdir.append(1)
cSequence, cDirection = constructLineRecursive(
cSequence, cDirection = construct_line_recursive(
dominoes, myHand, hand[idxPlay][0], previousSequence=cseq, previousDirection=cdir, maxLineLength=maxLineLength
)
for cns, cnd in zip(cSequence, cDirection):
Expand Down

0 comments on commit 1b65b01

Please sign in to comment.