Skip to content

Commit

Permalink
Added new toolnet sequence model
Browse files Browse the repository at this point in the history
  • Loading branch information
shreshthtuli committed Aug 24, 2020
1 parent ed847b4 commit aae0be0
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 21 deletions.
83 changes: 83 additions & 0 deletions src/GNN/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,86 @@ def forward(self, g, goalVec, goalObjectsVec, tool_vec):
probNoTool = torch.sigmoid(self.activation(self.p2(probNoTool))).flatten()
output = torch.cat(((1-probNoTool)*tools.flatten(), probNoTool), dim=0)
return output


######################################################################################

# The following models are for tool sequence prediction

class GGCN_Metric_Attn_L_NT_Tseq_C(nn.Module):
"""
The best performing model for the sequential tool prediction task.
Separate likelihood prediction of no tool for more robust tool output considering any tool
to be used or not as prior.
"""
def __init__(self,
in_feats,
n_objects,
n_hidden,
n_classes,
n_layers,
etypes,
activation,
dropout,
embedding,
weighted):
super(GGCN_Metric_Attn_L_NT_Tseq_C, self).__init__()
self.n_classes = n_classes
self.etypes = etypes
self.name = "GGCN_Metric_Attn_L_NT_Tseq_C_" + str(n_hidden) + "_" + str(n_layers)
self.n_hidden = n_hidden
self.layers = nn.ModuleList()
self.layers.append(nn.Linear(in_feats + n_objects*4, n_hidden))
for i in range(n_layers - 1):
self.layers.append(nn.Linear(n_hidden, n_hidden))
self.tool_lstm = nn.LSTM(n_hidden, n_hidden)
self.attention = nn.Linear(n_hidden + n_hidden + n_hidden, n_hidden)
self.attention2 = nn.Linear(n_hidden, 1)
self.embed = nn.Linear(PRETRAINED_VECTOR_SIZE, n_hidden)
self.fc1 = nn.Linear(4 * n_hidden, n_hidden)
self.fc2 = nn.Linear(n_hidden, n_hidden)
self.fc3 = nn.Linear(n_hidden, n_hidden)
self.fc4 = nn.Linear(n_hidden, 1)
self.p1 = nn.Linear(3 * n_hidden, n_hidden)
self.p2 = nn.Linear(n_hidden, 1)
self.final = nn.Sigmoid()
self.activation = nn.PReLU()

def forward(self, g_list, goalVec, goalObjectsVec, tool_vec, t_list):
tool_embedding = self.activation(self.embed(tool_vec))
t_list = [(tool_embedding[TOOLS.index(i)] if i!='no-tool' else torch.zeros(self.n_hidden)).view(1, -1) for i in t_list]
lstm_hidden = (torch.randn(1, 1, self.n_hidden), torch.randn(1, 1, self.n_hidden))
goalObjectsVec = self.activation(self.embed(goalObjectsVec))
goal_embed = self.activation(self.embed(goalVec))
predicted_tools = []
for ind,g in enumerate(g_list):
h = g.ndata['feat']
edgeMatrices = [g.adjacency_matrix(etype=t) for t in self.etypes]
edges = torch.cat(edgeMatrices, 1).to_dense()
h = torch.cat((h, edges), 1)
for i, layer in enumerate(self.layers):
h = self.activation(layer(h))
if (ind != 0):
lstm_out, lstm_hidden = self.tool_lstm(t_list[ind-1].view(1,1,-1), lstm_hidden)
else:
lstm_out = torch.zeros(1, 1, self.n_hidden)
lstm_out = lstm_out.view(-1)
attn_embedding = torch.cat([h, goalObjectsVec.repeat(h.size(0)).view(h.size(0), -1), lstm_out.repeat(h.size(0)).view(h.size(0), -1)], 1)
attn_embedding = self.activation(self.attention(attn_embedding))
attn_weights = F.softmax(self.attention2(attn_embedding), dim=0)
scene_embedding = torch.mm(attn_weights.t(), h)
scene_and_goal = torch.cat([scene_embedding, goal_embed.view(1,-1), lstm_out.view(1,-1)], 1)
l = []
for i in range(NUMTOOLS-1):
final_to_decode = torch.cat([scene_and_goal, tool_embedding[i].view(1, -1)], 1)
h = self.activation(self.fc1(final_to_decode))
h = self.activation(self.fc2(h))
h = self.activation(self.fc3(h))
h = self.final(self.fc4(h))
l.append(h.flatten())
tools = torch.stack(l)
probNoTool = self.activation(self.p1(scene_and_goal))
probNoTool = torch.sigmoid(self.activation(self.p2(probNoTool))).flatten()
output = torch.cat(((1-probNoTool)*tools.flatten(), probNoTool), dim=0)
predicted_tools.append(output)
return predicted_tools
66 changes: 45 additions & 21 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,19 +233,33 @@ def accuracy_score(dset, graphs, model, modelEnc, num_objects = 0, verbose = Fal
goal_num, world_num, tools, g, t = graph
if 'gcn_seq' in training:
actionSeq, graphSeq = g; loss = 0; toolSeq = tools
for i, g in enumerate(graphSeq):
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
y_true = torch.zeros(NUMTOOLS)
y_true[TOOLS.index(toolSeq[i])] = 1
total_test_loss += l(y_pred.view(1,-1), y_true)
y_pred = list(y_pred.reshape(-1))
# tools_possible = dset.goal_scene_to_tools[(goal_num,world_num)]
tool_predicted = TOOLS[y_pred.index(max(y_pred))]
if tool_predicted == toolSeq[i]:
total_correct += 1
elif verbose:
print (goal_num, world_num, tool_predicted, toolSeq[i])
denominator += 1
if 'Tseq' in model.name:
y_pred = model(graphSeq, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec, tools)
for i in range(len(y_pred)):
y_pred_i = list(y_pred[i].reshape(-1))
tool_predicted = TOOLS[y_pred_i.index(max(y_pred_i))]
y_true = torch.zeros(NUMTOOLS)
y_true[TOOLS.index(toolSeq[i])] = 1
total_test_loss += l(y_pred[i].view(1,-1), y_true)
if tool_predicted == toolSeq[i]:
total_correct += 1
elif verbose:
print (goal_num, world_num, tool_predicted, toolSeq[i])
denominator += 1
else:
for i, g in enumerate(graphSeq):
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
y_true = torch.zeros(NUMTOOLS)
y_true[TOOLS.index(toolSeq[i])] = 1
total_test_loss += l(y_pred.view(1,-1), y_true)
y_pred = list(y_pred.reshape(-1))
# tools_possible = dset.goal_scene_to_tools[(goal_num,world_num)]
tool_predicted = TOOLS[y_pred.index(max(y_pred))]
if tool_predicted == toolSeq[i]:
total_correct += 1
elif verbose:
print (goal_num, world_num, tool_predicted, toolSeq[i])
denominator += 1
continue
elif 'gcn' in training:
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
Expand Down Expand Up @@ -394,13 +408,20 @@ def backprop(data, optimizer, graphs, model, num_objects, modelEnc=None, batch_s
for iter_num, graph in tqdm(list(enumerate(graphs)), ncols=80):
goal_num, world_num, tools, g, t = graph
if 'gcn_seq' in training:
actionSeq, graphSeq = g; loss = 0; toolSeq = tools
for i, g in enumerate(graphSeq):
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
y_true = torch.zeros(NUMTOOLS)
y_true[TOOLS.index(tools[i])] = 1
loss += l(y_pred.view(1,-1), y_true)
if weighted: loss *= (1 if t == data.min_time[(goal_num, world_num)] else 0.5)
actionSeq, graphSeq = g; loss = 0
if 'Tseq' in model.name:
y_pred = model(graphSeq, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec, tools)
for i in range(len(y_pred)):
y_true = torch.zeros(NUMTOOLS)
y_true[TOOLS.index(tools[i])] = 1
loss += l(y_pred[i].view(1,-1), y_true)
else:
for i,g in enumerate(graphSeq):
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
y_true = torch.zeros(NUMTOOLS)
y_true[TOOLS.index(tools[i])] = 1
loss += l(y_pred.view(1,-1), y_true)
if weighted: loss *= (1 if t == data.min_time[(goal_num, world_num)] else 0.5)
batch_loss += loss
elif 'gcn' in training:
y_pred = model(g, goal2vec[goal_num], goalObjects2vec[goal_num], tool_vec)
Expand Down Expand Up @@ -518,7 +539,10 @@ def get_model(model_name):
if training == 'gcn' or training == 'gcn_seq':
size, layers = (4, 5) if training == 'gcn' else (2, 3)
modelEnc = None
if ("Final" not in model_name and "_NT" in model_name) or "Final_W" in model_name:
if "Tseq" in model_name:
model_class = getattr(src.GNN.models, "GGCN_Metric_Attn_L_NT_Tseq_C")
model = model_class(data.features, data.num_objects, size * GRAPH_HIDDEN, NUMTOOLS, layers, etypes, torch.tanh, 0.5, embedding, weighted)
elif ("Final" not in model_name and "_NT" in model_name) or "Final_W" in model_name:
model_class = getattr(src.GNN.models, "DGL_Simple_Likelihood")
model = model_class(data.features, data.num_objects, size * GRAPH_HIDDEN, NUMTOOLS, layers, etypes, torch.tanh, 0.5, embedding, weighted)
else:
Expand Down
Binary file not shown.

0 comments on commit aae0be0

Please sign in to comment.