Skip to content

Commit

Permalink
finish refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Apr 10, 2024
1 parent 1f3300e commit 4cda78e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 81 deletions.
4 changes: 2 additions & 2 deletions dominoes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 30,
"id": "40c715a5",
"metadata": {},
"outputs": [],
"source": [
"# TODO For refactoring\n",
"# continue working on generate_batch method in DominoesDataset\n",
"# "
"# generate_batch method needs help with \"set_target\""
]
},
{
Expand Down
79 changes: 0 additions & 79 deletions dominoes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,85 +94,6 @@ def listDominoes(highestDominoe):
return get_dominoe_set(highestDominoe, as_torch=False)


def get_dominoe_set(highest_dominoe, as_torch=False):
"""
Create a list of dominoes in a set with highest value of <highest_dominoe>
The dominoes are paired values (combinations with replacement) of integers
from 0 to <highest_dominoe>. This method returns either a numpy array or a
torch tensor of the dominoes as integers.
The shape will be (num_dominoes, 2) where the first column is the first value
of the dominoe, and the second column is the second value of the dominoe.
args:
highest_dominoe: the highest value of a dominoe
as_torch: return dominoes as torch tensor if True, otherwise return numpy array
returns:
dominoes: an array or tensor of dominoes in the set
"""
# given a standard rule for how to organize the list of dominoes as one-hot arrays, list the dominoes present in a one hot array
array_function = torch.tensor if as_torch else np.array
stack_function = torch.stack if as_torch else np.stack
dominoe_set = [array_function(quake, dtype=int) for quake in itertools.combinations_with_replacement(np.arange(highest_dominoe + 1), 2)]
return stack_function(dominoe_set)


def twohot_dominoe(selected, dominoes, highest_dominoe, available=None, available_token=False, null_token=False, with_batch=True):
"""
converts an index of selected dominoes to a stacked two-hot representation
dominoes are paired values (combinations with replacement) of integers
from 0 to <highest_dominoe>.
This simple representation is a two-hot vector where the first
<highest_dominoe>+1 elements represent the first value of the dominoe, and
the second <highest_dominoe>+1 elements represent the second value of the
dominoe. Here are some examples for <highest_dominoe> = 3:
(0 | 0): [1, 0, 0, 0, 1, 0, 0, 0]
(0 | 1): [1, 0, 0, 0, 0, 1, 0, 0]
(0 | 2): [1, 0, 0, 0, 0, 0, 1, 0]
(0 | 3): [1, 0, 0, 0, 0, 0, 0, 1]
(1 | 0): [0, 1, 0, 0, 1, 0, 0, 0]
(2 | 1): [0, 0, 1, 0, 0, 1, 0, 0]
"""
assert selected.ndim == 1, "selected must have shape (num_selected, )"
if available_token:
assert available is not None, "if with_available=True, then available needs to be provided"
num_dominoes = selected.shape[0]

# input dimension determined by highest dominoe (twice the number of possible values on a dominoe)
input_dim = (2 if not (available_token) else 3) * (highest_dominoe + 1) + (1 if null_token else 0)

# first & second value are index and shifted index
firstValue = torch.tensor(dominoes[selected, 0], dtype=torch.int64).unsqueeze(1)
secondValue = torch.tensor(dominoes[selected, 1] + highest_dominoe + 1, dtype=torch.int64).unsqueeze(1)

# scatter data into two-hot vectors
src = torch.ones((num_dominoes, 1), dtype=torch.float)
twohot = torch.zeros((num_dominoes, input_dim), dtype=torch.float)
twohot.scatter_(1, firstValue, src)
twohot.scatter_(1, secondValue, src)

if null_token:
null = torch.zeros((1, input_dim), dtype=torch.float).scatter_(1, torch.tensor(input_dim - 1).view(1, 1), torch.tensor(1.0).view(1, 1))
twohot = torch.cat((twohot, null), dim=0)

if available_token:
rep_available = torch.zeros((1, input_dim), dtype=torch.float)
availableidx = int((highest_dominoe + 1) * 2 + available)
rep_available[0, availableidx] = 1.0
twohot = torch.cat((twohot, rep_available), dim=0)

if with_batch:
twohot = twohot.unsqueeze(0)

return twohot


def dominoesString(dominoe):
return f"{dominoe[0]:>2}|{dominoe[1]:<2}"

Expand Down

0 comments on commit 4cda78e

Please sign in to comment.