Skip to content

Commit

Permalink
refactor many old functions, move to torch
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Apr 11, 2024
1 parent 8140754 commit cd23a98
Showing 1 changed file with 142 additions and 127 deletions.
269 changes: 142 additions & 127 deletions dominoes/datasets/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,105 +33,162 @@ def get_dominoe_set(highest_dominoe, as_torch=False):
return stack_function(dominoe_set)


def get_best_line(dominoes, selection, highest_dominoe, value_method="dominoe"):
# check value method
if not (value_method == "dominoe" or value_method == "length"):
raise ValueError("did not recognize value_method, it has to be either 'dominoe' or 'length'")
def get_best_line(dominoes, available, value_method="dominoe"):
"""
get the best line of dominoes given a set of dominoes and an available token
args:
dominoes: torch.tensor of shape (num_dominoes, 2)
available: (int) the value that is available to play on
value_method: (str) either "dominoe" or "length" to measure the value of the line
if "dominoe" the value is the sum of the dominoes in the line
if "length" the value is the length of the line
returns:
best_sequence: the best sequence of dominoes
best_direction: the direction of each dominoe in the sequence
bestSequence = []
bestDirection = []
for sel in selection:
cBestSeq = []
cBestDir = []
cBestVal = []
for available in range(highest_dominoe + 1):
cseq, cdir = construct_line_recursive(dominoes, sel, available)
if value_method == "dominoe":
cval = [np.sum(dominoes[cs]) for cs in cseq]
else:
cval = [len(cs) for cs in cseq]
cidx = max(enumerate(cval), key=lambda x: x[1])[0]
cBestSeq.append(cseq[cidx])
cBestDir.append(cdir[cidx])
cBestVal.append(cval[cidx])

cBestIdx = max(enumerate(cBestVal), key=lambda x: x[1])[0]
bestSequence.append(cBestSeq[cBestIdx])
bestDirection.append(cBestDir[cBestIdx])

return bestSequence, bestDirection


def get_best_line_from_available(dominoes, selection, available, value_method="dominoe"):
"""
# check value method
if not (value_method == "dominoe" or value_method == "length"):
raise ValueError("did not recognize value_method, it has to be either 'dominoe' or 'length'")

bestSequence = []
bestDirection = []
for sel, ava in zip(selection, available):
cseq, cdir = construct_line_recursive(dominoes, sel, ava)
if value_method == "dominoe":
cval = [np.sum(dominoes[cs]) for cs in cseq]
else:
cval = [len(cs) for cs in cseq]
cidx = max(enumerate(cval), key=lambda x: x[1])[0]
bestSequence.append(cseq[cidx])
bestDirection.append(cdir[cidx])
return bestSequence, bestDirection


def construct_line_recursive(dominoes, myHand, available, previousSequence=[], previousDirection=[], maxLineLength=None):
# this version of the function uses absolute dominoe numbers, rather than indexing based on which order they are in the hand
# if there are too many dominoes in hand, constructing all possible lines takes way too long...
if (maxLineLength is not None) and (len(previousSequence) == maxLineLength):
return [previousSequence], [previousDirection]

assert type(previousSequence) == list and type(previousDirection) == list, "previous sequence and direction must be lists"
if len(previousSequence) > 0:
# if a previous sequence was provided, make sure the end of it matches what is defined as available
assert (
dominoes[previousSequence[-1]][0 if previousDirection[-1] == 1 else 1] == available
), "the end of the last sequence doesn't match what is defined as available!"

# recursively constructs all possible lines given a hand (value pairs in list), an available value to play on, and the previous played/direction dominoe index sequences
hand = dominoes[myHand]
possiblePlays = np.where(np.any(hand == available, axis=1) & ~np.isin(myHand, previousSequence))[0]

# if there are no possible plays, the return the finished sequence
if len(possiblePlays) == 0:
return [previousSequence], [previousDirection]

# otherwise, make new lines for each possible play
# get all possible lines with this set of dominoes and the available token
allseqs, alldirs = construct_line_recursive(dominoes, available)

# measure value with either dominoe method or length method
if value_method == "dominoe":
allval = [torch.sum(dominoes[seq]) for seq in allseqs]
else:
allval = [len(seq) for seq in allseqs]

# get index to the best sequence
best_idx = max(enumerate(allval), key=lambda x: x[1])[0]

# return the best sequence and direction
return allseqs[best_idx], alldirs[best_idx]


def pad_best_lines(best_seq, max_output, null_index, ignore_index=-1):
"""
pad the best sequence of dominoes to a fixed length
args:
best_seq: the best sequence of dominoes
max_output: the maximum length of the sequence
null_index: the index of the null index (set to ignore_index if no null token)
ignore_index: the index of the ignore index
returns:
padded_best_seq: the best sequence padded to a fixed length
"""
padded_best_seq = []
for seq in best_seq:
c_length = len(seq)
append_null = [null_index] if max_output > c_length else []
append_ignore = [ignore_index] * (max_output - (c_length + 1))
seq += append_null + append_ignore
padded_best_seq.append(seq)
return padded_best_seq


def padBestLine(bestSequence, max_output, null_index, ignore_index=-1):
for bs in bestSequence:
c_length = len(bs)
append_null = [null_index] if max_output > c_length else []
append_ignore = [ignore_index] * (max_output - (c_length + 1))
bs += append_null + append_ignore
# bs += [ignore_index]*(max_output-len(bs))
return bestSequence


def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[], prev_dir=[], max_length=None):
"""
recursively construct all possible lines given a set of dominoes, an available value to play on,
and the previous played/direction dominoe index sequences.
This method can be used in two ways:
1. if hand_index is not provided, it will use all dominoes in the set and the resulting
sequences will use the indices of the dominoes in the set provided in the first argument.
2. if hand_index is provided, it will only use those dominoes in the set and the resulting
sequences will use the indices of the dominoes in the hand_index list.
args:
dominoes: torch.tensor or numpy nd.array of shape (num_dominoes, 2)
available: (int) the value that is available to play on
hand_index: (optional, list[int]) the index of the dominoes in the hand
prev_seq: the previous sequence of dominoes -- is used for recursion
prev_dir: the previous direction of the dominoes -- is used for recursion
max_length: the maximum length of the line
returns:
sequence: the list of all possible sequences of dominoes (with indices corresponding
to the dominoes in the set or hand_index)
direction: the list of the direction each dominoe must be played within each sequence
"""
# if the maximum length of the sequence is reached, return sequence up to this point
if max_length is not None and len(prev_seq) == max_length:
return [prev_seq], [prev_dir]

# check if previous sequence end position matches the available value
if len(prev_seq) > 0:
msg = "the end of the last sequence doesn't match what is defined as available!"
assert dominoes[prev_seq[-1]][0 if prev_dir[-1] == 1 else 1] == available, msg

# convert dominoes to torch tensor if it is a numpy array
if isinstance(dominoes, np.ndarray):
dominoes = torch.tensor(dominoes)

# if hand_index is not provided, use all dominoes in the set
if hand_index is None:
hand_index = torch.arange(len(dominoes))

# set hand ("playable dominoes")
hand = dominoes[hand_index]

# find all dominoes in hand that can be played on the available token
possible_plays = torch.where(torch.any(hand == available, axis=1) & ~torch.isin(hand_index, prev_seq))[0]

# if no more plays are possible, return the finished sequence and direction
if len(possible_plays) == 0:
return [prev_seq], [prev_dir]

# otherwise create new lines for each possible play
sequence = []
direction = []
for idxPlay in possiblePlays:
# if the first value of the possible play matches the available, then play it in the forward direction
if hand[idxPlay][0] == available:
# copy previousSequence and previousDirection, append new play in forward direction to it
cseq = copy(previousSequence)
cseq.append(myHand[idxPlay])
cdir = copy(previousDirection)
for idx_play in possible_plays:
# if the first value of the possible play matches the available value
if hand[idx_play][0] == available:
# add to sequence
cseq = copy(prev_seq)
cseq.append(hand_index[idx_play])
# play in forward direction
cdir = copy(prev_dir)
cdir.append(0)
# then recursively construct line from this standpoint
cSequence, cDirection = construct_line_recursive(
dominoes, myHand, hand[idxPlay][1], previousSequence=cseq, previousDirection=cdir, maxLineLength=maxLineLength
# construct sequence recursively from this new sequence
cseq, cdir = construct_line_recursive(
dominoes, hand[idx_play][1], hand_index=hand_index, prev_seq=cseq, prev_dir=cdir, max_length=max_length
)
# once lines are constructed, add them all to "sequence" and "direction", which will be a list of lists of all possible sequences
for cns, cnd in zip(cSequence, cDirection):
# add all sequence/direction lists to possible sequences
for cns, cnd in zip(cseq, cdir):
sequence.append(cns)
direction.append(cnd)

# if the second value of the possible play matches the available and it isn't a double, then play it in the reverse direction (all same except direction and next available)
if (hand[idxPlay][0] != hand[idxPlay][1]) and (hand[idxPlay][1] == available):
cseq = copy(previousSequence)
cseq.append(myHand[idxPlay])
cdir = copy(previousDirection)
# if the second value of the possible play matches the available and it isn't a double,
# then play it in the reverse direction
else:
# add to sequence
cseq = copy(prev_seq)
cseq.append(hand_index[idx_play])
# play in reverse direction
cdir = copy(prev_dir)
cdir.append(1)
cSequence, cDirection = construct_line_recursive(
dominoes, myHand, hand[idxPlay][0], previousSequence=cseq, previousDirection=cdir, maxLineLength=maxLineLength
# construct sequence recursively from this new sequence
cseq, cdir = construct_line_recursive(
dominoes, hand[idx_play][0], hand_index=hand_index, prev_seq=cseq, prev_dir=cdir, max_length=max_length
)
for cns, cnd in zip(cSequence, cDirection):
# add all sequence/direction lists to possible sequences
for cns, cnd in zip(cseq, cdir):
sequence.append(cns)
direction.append(cnd)

Expand Down Expand Up @@ -218,48 +275,6 @@ def sortPad(val, padTo, ignoreIndex=-1):
return input, target, mask


def makeLines(input, dominoes, value_method="dominoe"):
selection, available = input # unpack
cseq, cdir = construct_line_recursive(dominoes, selection, available)
if value_method == "dominoe":
cval = [np.sum(dominoes[cs]) for cs in cseq]
else:
cval = [len(cs) for cs in cseq]
cidx = max(enumerate(cval), key=lambda x: x[1])[0]
return cseq[cidx], cdir[cidx]


def getBestLineFromAvailablePool(dominoes, selection, available, value_method="dominoe", threads=18):
# check value method
if not (value_method == "dominoe" or value_method == "length"):
raise ValueError("did not recognize value_method, it has to be either 'dominoe' or 'length'")
p_makeLines = partial(makeLines, dominoes=dominoes, value_method=value_method)

with Pool(threads) as p:
lines = p.map(p_makeLines, zip(selection, available))
bestSequence, bestDirection = map(list, zip(*lines))
return bestSequence, bestDirection


def convertToHandIndex(selection, bestSequence):
indices = []
for sel, seq in zip(selection, bestSequence):
# look up table for current selection
elementIdx = {element: idx for idx, element in enumerate(sel)}
indices.append([elementIdx[element] for element in seq])
return indices


def padBestLine(bestSequence, max_output, null_index, ignore_index=-1):
for bs in bestSequence:
c_length = len(bs)
append_null = [null_index] if max_output > c_length else []
append_ignore = [ignore_index] * (max_output - (c_length + 1))
bs += append_null + append_ignore
# bs += [ignore_index]*(max_output-len(bs))
return bestSequence


def generateBatch(
highestDominoe,
dominoes,
Expand Down

0 comments on commit cd23a98

Please sign in to comment.