From cd23a9895bda1e7f8a6ab547f81c40ed2ce57a53 Mon Sep 17 00:00:00 2001 From: landoskape Date: Thu, 11 Apr 2024 21:39:18 +0100 Subject: [PATCH] refactor many old functions, move to torch --- dominoes/datasets/support.py | 269 ++++++++++++++++++----------------- 1 file changed, 142 insertions(+), 127 deletions(-) diff --git a/dominoes/datasets/support.py b/dominoes/datasets/support.py index 1a1cd75..28e8e0c 100644 --- a/dominoes/datasets/support.py +++ b/dominoes/datasets/support.py @@ -33,105 +33,162 @@ def get_dominoe_set(highest_dominoe, as_torch=False): return stack_function(dominoe_set) -def get_best_line(dominoes, selection, highest_dominoe, value_method="dominoe"): - # check value method - if not (value_method == "dominoe" or value_method == "length"): - raise ValueError("did not recognize value_method, it has to be either 'dominoe' or 'length'") +def get_best_line(dominoes, available, value_method="dominoe"): + """ + get the best line of dominoes given a set of dominoes and an available token + + args: + dominoes: torch.tensor of shape (num_dominoes, 2) + available: (int) the value that is available to play on + value_method: (str) either "dominoe" or "length" to measure the value of the line + if "dominoe" the value is the sum of the dominoes in the line + if "length" the value is the length of the line + + returns: + best_sequence: the best sequence of dominoes + best_direction: the direction of each dominoe in the sequence - bestSequence = [] - bestDirection = [] - for sel in selection: - cBestSeq = [] - cBestDir = [] - cBestVal = [] - for available in range(highest_dominoe + 1): - cseq, cdir = construct_line_recursive(dominoes, sel, available) - if value_method == "dominoe": - cval = [np.sum(dominoes[cs]) for cs in cseq] - else: - cval = [len(cs) for cs in cseq] - cidx = max(enumerate(cval), key=lambda x: x[1])[0] - cBestSeq.append(cseq[cidx]) - cBestDir.append(cdir[cidx]) - cBestVal.append(cval[cidx]) - - cBestIdx = max(enumerate(cBestVal), key=lambda x: x[1])[0] - bestSequence.append(cBestSeq[cBestIdx]) - bestDirection.append(cBestDir[cBestIdx]) - - return bestSequence, bestDirection - - -def get_best_line_from_available(dominoes, selection, available, value_method="dominoe"): + """ # check value method if not (value_method == "dominoe" or value_method == "length"): raise ValueError("did not recognize value_method, it has to be either 'dominoe' or 'length'") - bestSequence = [] - bestDirection = [] - for sel, ava in zip(selection, available): - cseq, cdir = construct_line_recursive(dominoes, sel, ava) - if value_method == "dominoe": - cval = [np.sum(dominoes[cs]) for cs in cseq] - else: - cval = [len(cs) for cs in cseq] - cidx = max(enumerate(cval), key=lambda x: x[1])[0] - bestSequence.append(cseq[cidx]) - bestDirection.append(cdir[cidx]) - return bestSequence, bestDirection - - -def construct_line_recursive(dominoes, myHand, available, previousSequence=[], previousDirection=[], maxLineLength=None): - # this version of the function uses absolute dominoe numbers, rather than indexing based on which order they are in the hand - # if there are too many dominoes in hand, constructing all possible lines takes way too long... - if (maxLineLength is not None) and (len(previousSequence) == maxLineLength): - return [previousSequence], [previousDirection] - - assert type(previousSequence) == list and type(previousDirection) == list, "previous sequence and direction must be lists" - if len(previousSequence) > 0: - # if a previous sequence was provided, make sure the end of it matches what is defined as available - assert ( - dominoes[previousSequence[-1]][0 if previousDirection[-1] == 1 else 1] == available - ), "the end of the last sequence doesn't match what is defined as available!" - - # recursively constructs all possible lines given a hand (value pairs in list), an available value to play on, and the previous played/direction dominoe index sequences - hand = dominoes[myHand] - possiblePlays = np.where(np.any(hand == available, axis=1) & ~np.isin(myHand, previousSequence))[0] - - # if there are no possible plays, the return the finished sequence - if len(possiblePlays) == 0: - return [previousSequence], [previousDirection] - - # otherwise, make new lines for each possible play + # get all possible lines with this set of dominoes and the available token + allseqs, alldirs = construct_line_recursive(dominoes, available) + + # measure value with either dominoe method or length method + if value_method == "dominoe": + allval = [torch.sum(dominoes[seq]) for seq in allseqs] + else: + allval = [len(seq) for seq in allseqs] + + # get index to the best sequence + best_idx = max(enumerate(allval), key=lambda x: x[1])[0] + + # return the best sequence and direction + return allseqs[best_idx], alldirs[best_idx] + + +def pad_best_lines(best_seq, max_output, null_index, ignore_index=-1): + """ + pad the best sequence of dominoes to a fixed length + + args: + best_seq: the best sequence of dominoes + max_output: the maximum length of the sequence + null_index: the index of the null index (set to ignore_index if no null token) + ignore_index: the index of the ignore index + + returns: + padded_best_seq: the best sequence padded to a fixed length + """ + padded_best_seq = [] + for seq in best_seq: + c_length = len(seq) + append_null = [null_index] if max_output > c_length else [] + append_ignore = [ignore_index] * (max_output - (c_length + 1)) + seq += append_null + append_ignore + padded_best_seq.append(seq) + return padded_best_seq + + +def padBestLine(bestSequence, max_output, null_index, ignore_index=-1): + for bs in bestSequence: + c_length = len(bs) + append_null = [null_index] if max_output > c_length else [] + append_ignore = [ignore_index] * (max_output - (c_length + 1)) + bs += append_null + append_ignore + # bs += [ignore_index]*(max_output-len(bs)) + return bestSequence + + +def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[], prev_dir=[], max_length=None): + """ + recursively construct all possible lines given a set of dominoes, an available value to play on, + and the previous played/direction dominoe index sequences. + + This method can be used in two ways: + 1. if hand_index is not provided, it will use all dominoes in the set and the resulting + sequences will use the indices of the dominoes in the set provided in the first argument. + 2. if hand_index is provided, it will only use those dominoes in the set and the resulting + sequences will use the indices of the dominoes in the hand_index list. + + args: + dominoes: torch.tensor or numpy nd.array of shape (num_dominoes, 2) + available: (int) the value that is available to play on + hand_index: (optional, list[int]) the index of the dominoes in the hand + prev_seq: the previous sequence of dominoes -- is used for recursion + prev_dir: the previous direction of the dominoes -- is used for recursion + max_length: the maximum length of the line + + returns: + sequence: the list of all possible sequences of dominoes (with indices corresponding + to the dominoes in the set or hand_index) + direction: the list of the direction each dominoe must be played within each sequence + """ + # if the maximum length of the sequence is reached, return sequence up to this point + if max_length is not None and len(prev_seq) == max_length: + return [prev_seq], [prev_dir] + + # check if previous sequence end position matches the available value + if len(prev_seq) > 0: + msg = "the end of the last sequence doesn't match what is defined as available!" + assert dominoes[prev_seq[-1]][0 if prev_dir[-1] == 1 else 1] == available, msg + + # convert dominoes to torch tensor if it is a numpy array + if isinstance(dominoes, np.ndarray): + dominoes = torch.tensor(dominoes) + + # if hand_index is not provided, use all dominoes in the set + if hand_index is None: + hand_index = torch.arange(len(dominoes)) + + # set hand ("playable dominoes") + hand = dominoes[hand_index] + + # find all dominoes in hand that can be played on the available token + possible_plays = torch.where(torch.any(hand == available, axis=1) & ~torch.isin(hand_index, prev_seq))[0] + + # if no more plays are possible, return the finished sequence and direction + if len(possible_plays) == 0: + return [prev_seq], [prev_dir] + + # otherwise create new lines for each possible play sequence = [] direction = [] - for idxPlay in possiblePlays: - # if the first value of the possible play matches the available, then play it in the forward direction - if hand[idxPlay][0] == available: - # copy previousSequence and previousDirection, append new play in forward direction to it - cseq = copy(previousSequence) - cseq.append(myHand[idxPlay]) - cdir = copy(previousDirection) + for idx_play in possible_plays: + # if the first value of the possible play matches the available value + if hand[idx_play][0] == available: + # add to sequence + cseq = copy(prev_seq) + cseq.append(hand_index[idx_play]) + # play in forward direction + cdir = copy(prev_dir) cdir.append(0) - # then recursively construct line from this standpoint - cSequence, cDirection = construct_line_recursive( - dominoes, myHand, hand[idxPlay][1], previousSequence=cseq, previousDirection=cdir, maxLineLength=maxLineLength + # construct sequence recursively from this new sequence + cseq, cdir = construct_line_recursive( + dominoes, hand[idx_play][1], hand_index=hand_index, prev_seq=cseq, prev_dir=cdir, max_length=max_length ) - # once lines are constructed, add them all to "sequence" and "direction", which will be a list of lists of all possible sequences - for cns, cnd in zip(cSequence, cDirection): + # add all sequence/direction lists to possible sequences + for cns, cnd in zip(cseq, cdir): sequence.append(cns) direction.append(cnd) - # if the second value of the possible play matches the available and it isn't a double, then play it in the reverse direction (all same except direction and next available) - if (hand[idxPlay][0] != hand[idxPlay][1]) and (hand[idxPlay][1] == available): - cseq = copy(previousSequence) - cseq.append(myHand[idxPlay]) - cdir = copy(previousDirection) + # if the second value of the possible play matches the available and it isn't a double, + # then play it in the reverse direction + else: + # add to sequence + cseq = copy(prev_seq) + cseq.append(hand_index[idx_play]) + # play in reverse direction + cdir = copy(prev_dir) cdir.append(1) - cSequence, cDirection = construct_line_recursive( - dominoes, myHand, hand[idxPlay][0], previousSequence=cseq, previousDirection=cdir, maxLineLength=maxLineLength + # construct sequence recursively from this new sequence + cseq, cdir = construct_line_recursive( + dominoes, hand[idx_play][0], hand_index=hand_index, prev_seq=cseq, prev_dir=cdir, max_length=max_length ) - for cns, cnd in zip(cSequence, cDirection): + # add all sequence/direction lists to possible sequences + for cns, cnd in zip(cseq, cdir): sequence.append(cns) direction.append(cnd) @@ -218,48 +275,6 @@ def sortPad(val, padTo, ignoreIndex=-1): return input, target, mask -def makeLines(input, dominoes, value_method="dominoe"): - selection, available = input # unpack - cseq, cdir = construct_line_recursive(dominoes, selection, available) - if value_method == "dominoe": - cval = [np.sum(dominoes[cs]) for cs in cseq] - else: - cval = [len(cs) for cs in cseq] - cidx = max(enumerate(cval), key=lambda x: x[1])[0] - return cseq[cidx], cdir[cidx] - - -def getBestLineFromAvailablePool(dominoes, selection, available, value_method="dominoe", threads=18): - # check value method - if not (value_method == "dominoe" or value_method == "length"): - raise ValueError("did not recognize value_method, it has to be either 'dominoe' or 'length'") - p_makeLines = partial(makeLines, dominoes=dominoes, value_method=value_method) - - with Pool(threads) as p: - lines = p.map(p_makeLines, zip(selection, available)) - bestSequence, bestDirection = map(list, zip(*lines)) - return bestSequence, bestDirection - - -def convertToHandIndex(selection, bestSequence): - indices = [] - for sel, seq in zip(selection, bestSequence): - # look up table for current selection - elementIdx = {element: idx for idx, element in enumerate(sel)} - indices.append([elementIdx[element] for element in seq]) - return indices - - -def padBestLine(bestSequence, max_output, null_index, ignore_index=-1): - for bs in bestSequence: - c_length = len(bs) - append_null = [null_index] if max_output > c_length else [] - append_ignore = [ignore_index] * (max_output - (c_length + 1)) - bs += append_null + append_ignore - # bs += [ignore_index]*(max_output-len(bs)) - return bestSequence - - def generateBatch( highestDominoe, dominoes,