Skip to content

Commit

Permalink
convert some of the defaults to torch
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Apr 16, 2024
1 parent 0bceb61 commit 19a1d43
Showing 1 changed file with 23 additions and 150 deletions.
173 changes: 23 additions & 150 deletions dominoes/datasets/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,21 @@ def pad_best_lines(best_seq, max_output, null_index, ignore_index=-1):
returns:
padded_best_seq: the best sequence padded to a fixed length
"""

def as_tensor(seq):
return torch.tensor(seq, dtype=torch.long)

padded_best_seq = []
for seq in best_seq:
c_length = len(seq)
append_null = [null_index] if max_output > c_length else []
append_ignore = [ignore_index] * (max_output - (c_length + 1))
seq += append_null + append_ignore
seq = torch.cat((seq, as_tensor(append_null), as_tensor(append_ignore)))
padded_best_seq.append(seq)
return padded_best_seq


def padBestLine(bestSequence, max_output, null_index, ignore_index=-1):
for bs in bestSequence:
c_length = len(bs)
append_null = [null_index] if max_output > c_length else []
append_ignore = [ignore_index] * (max_output - (c_length + 1))
bs += append_null + append_ignore
# bs += [ignore_index]*(max_output-len(bs))
return bestSequence


def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[], prev_dir=[], max_length=None):
def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=None, prev_dir=None, max_length=None):
"""
recursively construct all possible lines given a set of dominoes, an available value to play on,
and the previous played/direction dominoe index sequences.
Expand All @@ -126,28 +120,33 @@ def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[],
to the dominoes in the set or hand_index)
direction: the list of the direction each dominoe must be played within each sequence
"""
# if prev_seq and prev_dir are not provided, that means this is the first call
if prev_seq is None:
prev_seq = torch.tensor([], dtype=torch.long)
prev_dir = torch.tensor([], dtype=torch.long)

# if the maximum length of the sequence is reached, return sequence up to this point
if max_length is not None and len(prev_seq) == max_length:
return [prev_seq], [prev_dir]

# check if previous sequence end position matches the available value
if len(prev_seq) > 0:
msg = "the end of the last sequence doesn't match what is defined as available!"
assert dominoes[prev_seq[-1]][0 if prev_dir[-1] == 1 else 1] == available, msg

# convert dominoes to torch tensor if it is a numpy array
if isinstance(dominoes, np.ndarray):
dominoes = torch.tensor(dominoes)

# if hand_index is not provided, use all dominoes in the set
if hand_index is None:
hand_index = torch.arange(len(dominoes))
hand_index = torch.arange(len(dominoes), dtype=torch.long)

# set hand ("playable dominoes")
hand = dominoes[hand_index]

# check if previous sequence end position matches the available value
if len(prev_seq) > 0:
msg = "the end of the last sequence doesn't match what is defined as available!"
assert hand[prev_seq[-1]][0 if prev_dir[-1] == 1 else 1] == available, msg

# find all dominoes in hand that can be played on the available token
possible_plays = torch.where(torch.any(hand == available, axis=1) & ~torch.isin(hand_index, prev_seq))[0]
possible_plays = torch.where(torch.any(hand == available, dim=1) & ~torch.isin(hand_index, prev_seq))[0]

# if no more plays are possible, return the finished sequence and direction
if len(possible_plays) == 0:
Expand All @@ -160,11 +159,9 @@ def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[],
# if the first value of the possible play matches the available value
if hand[idx_play][0] == available:
# add to sequence
cseq = copy(prev_seq)
cseq.append(hand_index[idx_play])
cseq = torch.cat((prev_seq.clone().view(-1), hand_index[idx_play].view(1)))
# play in forward direction
cdir = copy(prev_dir)
cdir.append(0)
cdir = torch.cat((prev_dir.clone().view(-1), torch.tensor(0, dtype=torch.long).view(1)))
# construct sequence recursively from this new sequence
cseq, cdir = construct_line_recursive(
dominoes, hand[idx_play][1], hand_index=hand_index, prev_seq=cseq, prev_dir=cdir, max_length=max_length
Expand All @@ -178,11 +175,9 @@ def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[],
# then play it in the reverse direction
else:
# add to sequence
cseq = copy(prev_seq)
cseq.append(hand_index[idx_play])
# play in reverse direction
cdir = copy(prev_dir)
cdir.append(1)
cseq = torch.cat((prev_seq.clone().view(-1), hand_index[idx_play].view(1)))
# play in forward direction
cdir = torch.cat((prev_dir.clone().view(-1), torch.tensor(1, dtype=torch.long).view(1)))
# construct sequence recursively from this new sequence
cseq, cdir = construct_line_recursive(
dominoes, hand[idx_play][0], hand_index=hand_index, prev_seq=cseq, prev_dir=cdir, max_length=max_length
Expand All @@ -196,128 +191,6 @@ def construct_line_recursive(dominoes, available, hand_index=None, prev_seq=[],
return sequence, direction


def dominoeUnevenBatch(batchSize, minSeq, maxSeq, listDominoes, dominoeValue, highestDominoe, ignoreIndex=-1, return_full=False):
"""
retrieve a batch of dominoes and their target order given the value of each dominoe
dominoes are paired values (combinations with replacement) of integers
from 0 to <highestDominoe>. The total value of each dominoe is the sum of
the two integers associated with that dominoe. For example, the dominoe
(7|3) has value 10.
Each element in the batch contains an input and target. The input is
composed of a sequence of dominoes in a random order, transformed into a
simple representation (explained below). The target is a list of the order
of dominoes by the one with the highest value to the one with the lowest
value. Note that many dominoes share the same value, but since the dominoe
list is always the same, equal value dominoes will always be sorted in the
same way.
Each element can have a different sequence length, they will be padded
with zeros to whatever the longest sequence is. The ignoreIndex is used to
determine what to label targets for any padded elements (i.e. any place
where no prediction is needed). The nll_loss function then accepts this as
an input to ignore. This is part of the reason why pointer networks are
awesome... the input and output can vary in size!!!
The simple representation is a two-hot vector where the first
<highestDominoe+1> elements represent the first value of the dominoe, and
the second <highestDominoe+1> elements represent the second value of the
dominoe. Here are some examples for highest dominoe = 3:
(0 | 0): [1, 0, 0, 0, 1, 0, 0, 0]
(0 | 1): [1, 0, 0, 0, 0, 1, 0, 0]
(0 | 2): [1, 0, 0, 0, 0, 0, 1, 0]
(0 | 3): [1, 0, 0, 0, 0, 0, 0, 1]
(1 | 0): [0, 1, 0, 0, 1, 0, 0, 0]
(2 | 1): [0, 0, 1, 0, 0, 1, 0, 0]
"""
numDominoes = len(listDominoes)
input_dim = 2 * (highestDominoe + 1)

# choose how long each sequence in the batch will be
seqLength = np.random.randint(minSeq, maxSeq + 1, batchSize)
maxSeqLength = max(seqLength) # max sequence length for padding

# choose dominoes from the batch, and get their value (in points)
selection = [np.random.choice(numDominoes, sl, replace=False).tolist() for sl in seqLength]
value = [dominoeValue[sel] for sel in selection]

# index of first and second value in two-hot representation
pad = [[0] * (maxSeqLength - sl) for sl in seqLength]
firstValue = np.stack([listDominoes[sel, 0].tolist() + p for p, sel in zip(pad, selection)])
secondValue = np.stack([(listDominoes[sel, 1] + highestDominoe + 1).tolist() + p for p, sel in zip(pad, selection)])
firstValue = torch.tensor(firstValue, dtype=torch.int64).unsqueeze(2)
secondValue = torch.tensor(secondValue, dtype=torch.int64).unsqueeze(2)

# create mask (used for scattering and also as an output)
mask = 1.0 * (torch.arange(maxSeqLength).view(1, -1).expand(batchSize, -1) < torch.tensor(seqLength).view(-1, 1))

# scatter data into two-hot vectors, except where sequence length is exceed where the mask is 0
input = torch.zeros((batchSize, maxSeqLength, input_dim), dtype=torch.float)
input.scatter_(2, firstValue, mask.float().unsqueeze(2))
input.scatter_(2, secondValue, mask.float().unsqueeze(2))

# sort and pad each list of dominoes by value
def sortPad(val, padTo, ignoreIndex=-1):
s = sorted(range(len(val)), key=lambda i: -val[i])
p = [ignoreIndex] * (padTo - len(val))
return s + p

# create a padded sort index, then turn into a torch tensor as the target vector
sortIdx = [sortPad(val, maxSeqLength, ignoreIndex) for val in value] # pad with ignore index so nll_loss ignores them
target = torch.stack([torch.LongTensor(idx) for idx in sortIdx])

if return_full:
return input, target, mask, selection
else:
return input, target, mask


def generateBatch(
highestDominoe,
dominoes,
batch_size,
numInHand,
return_target=True,
value_method="dominoe",
available_token=False,
null_token=False,
ignore_index=-1,
return_full=False,
):

input, selection, available = random_dominoe_hand(
numInHand, dominoes, highestDominoe, batch_size=batch_size, null_token=null_token, available_token=available_token
)

mask_tokens = numInHand + (1 if null_token else 0) + (1 if available_token else 0)
mask = torch.ones((batch_size, mask_tokens), dtype=torch.float)

if return_target:
# then measure best line and convert it to a "target" array
if available_token:
bestSequence, bestDirection = getBestLineFromAvailable(dominoes, selection, available, value_method=value_method)
else:
bestSequence, bestDirection = getBestLine(dominoes, selection, highestDominoe, value_method=value_method)

# convert sequence to hand index
iseq = convertToHandIndex(selection, bestSequence)

# create target and append null_index once, then ignore_index afterwards
# the idea is that the agent should play the best line, then indicate that the line is over, then anything else doesn't matter
null_index = numInHand if null_token else ignore_index
target = torch.tensor(np.stack(padBestLine(iseq, numInHand + 1, null_index, ignore_index=ignore_index)), dtype=torch.long)
else:
# otherwise set these to None so we can use the same return structure
target, bestSequence, bestDirection = None, None, None

if return_full:
return input, target, mask, bestSequence, bestDirection, selection, available
return input, target, mask


def held_karp(dists):
"""
Implementation of Held-Karp, an algorithm that solves the Traveling
Expand Down

0 comments on commit 19a1d43

Please sign in to comment.