diff --git a/src/disentangle.py b/src/disentangle.py
index 8d96cd8..6efaeb2 100755
--- a/src/disentangle.py
+++ b/src/disentangle.py
@@ -57,6 +57,7 @@
 EPOCHS = args.epochs
 DROP = args.drop
 MAX_DIST = args.max_dist
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def header(args, out=sys.stdout):
     head_text = "# "+ time.ctime(time.time())
@@ -446,31 +447,32 @@ def __init__(self):
                 pretrained.append(vector)
             NWORDS = len(self.id_to_token)
             DIM_WORDS = len(pretrained[0])
-            self.pEmbedding = nn.Embedding.from_pretrained(torch.tensor(pretrained))
+            self.pEmbedding = nn.Embedding.from_pretrained(torch.tensor(pretrained)).to(DEVICE)
             input_size += 4 * DIM_WORDS
 
         # Create hidden layers
         self.hidden = []
-        self.hidden.append(nn.Linear(input_size, HIDDEN))
+        self.hidden.append(nn.Linear(input_size, HIDDEN, device=DEVICE))
         for i in range(args.layers - 1):
-            self.hidden.append(nn.Linear(HIDDEN, HIDDEN))
-        self.final_sum = nn.Linear(HIDDEN, 1)
+            self.hidden.append(nn.Linear(HIDDEN, HIDDEN, device=DEVICE))
+        self.final_sum = nn.Linear(HIDDEN, 1, device=DEVICE)
 
     def forward(self, query, options, gold, lengths, query_no):
         if len(options) == 1:
             return None, 0
         final = []
         if args.word_vectors:
-            qvecs = self.pEmbedding(torch.tensor(query))
+            qvecs = self.pEmbedding(torch.tensor(query, device=DEVICE))
             qvec_max = torch.max(qvecs, 0)[0]
             qvec_mean = torch.mean(qvecs, 0)
         for otext, features in options:
-            inputs = torch.tensor(features)
+            inputs = torch.tensor(features, device=DEVICE)
             if args.word_vectors:
-                ovecs = self.pEmbedding(torch.tensor(otext))
+                ovecs = self.pEmbedding(torch.tensor(otext, device=DEVICE))
                 ovec_max = torch.max(ovecs, 0)[0]
                 ovec_mean = torch.mean(ovecs, 0)
-                inputs = torch.cat([inputs, qvec_max, qvec_mean, ovec_max, ovec_mean])
+                inputs = torch.cat((inputs, qvec_max, qvec_mean, ovec_max, ovec_mean))
+
             if args.drop > 0:
                 inputs = F.dropout(inputs, args.drop)
             h = inputs
@@ -500,9 +502,9 @@ def forward(self, query, options, gold, lengths, query_no):
         dense_gold = []
         for i in range(len(options)):
             dense_gold.append(1.0 / len(gold) if i in gold else 0.0)
-        answer = torch.tensor(dense_gold)
+        answer = torch.tensor(dense_gold, device=DEVICE)
         loss = torch.dot(answer, nll)
-        predicted_link = np.argmax(final.data.numpy())
+        predicted_link = torch.argmax(final)
         return loss, predicted_link
 
     def get_ids(self, words):
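
The hunks above follow a single pattern: pick one torch.device up front, construct every weight tensor on it, and build every input tensor on it inside forward. Because the hidden layers here are kept in a plain Python list rather than an nn.ModuleList, they are not registered as submodules, so a blanket model.to(DEVICE) would not reach them; constructing each nn.Linear directly on the device sidesteps that. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch >= 1.10 (for the device= keyword on nn.Linear); the TinyScorer name, layer sizes, and tanh activation are illustrative placeholders, not code from disentangle.py.

import torch
import torch.nn as nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TinyScorer(nn.Module):
    def __init__(self, input_size, hidden, layers):
        super().__init__()
        # Layers held in a plain Python list are invisible to .to() and
        # .parameters(), so each one is created on DEVICE at construction time.
        self.hidden = [nn.Linear(input_size, hidden, device=DEVICE)]
        for _ in range(layers - 1):
            self.hidden.append(nn.Linear(hidden, hidden, device=DEVICE))
        self.final_sum = nn.Linear(hidden, 1, device=DEVICE)

    def forward(self, features):
        # Inputs are built on the same device as the weights, mirroring
        # torch.tensor(features, device=DEVICE) in the patch.
        h = torch.tensor(features, dtype=torch.float, device=DEVICE)
        for layer in self.hidden:
            h = torch.tanh(layer(h))
        return self.final_sum(h)

if __name__ == "__main__":
    model = TinyScorer(input_size=8, hidden=16, layers=2)
    print(model([0.0] * 8).item())

Note that with a plain list, model.parameters() will not include the hidden layers, so an optimiser needs to be handed those parameters explicitly (or the list replaced with an nn.ModuleList); that consideration is outside the scope of this sketch.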