Skip to content

Commit

Permalink
GPU support added
Browse files Browse the repository at this point in the history
  • Loading branch information
DebadityaPal committed May 4, 2024
1 parent c35f6aa commit 98b3567
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/disentangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
EPOCHS = args.epochs
DROP = args.drop
MAX_DIST = args.max_dist
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def header(args, out=sys.stdout):
head_text = "# "+ time.ctime(time.time())
Expand Down Expand Up @@ -446,31 +447,32 @@ def __init__(self):
pretrained.append(vector)
NWORDS = len(self.id_to_token)
DIM_WORDS = len(pretrained[0])
self.pEmbedding = nn.Embedding.from_pretrained(torch.tensor(pretrained))
self.pEmbedding = nn.Embedding.from_pretrained(torch.tensor(pretrained)).to(DEVICE)
input_size += 4 * DIM_WORDS

# Create hidden layers
self.hidden = []
self.hidden.append(nn.Linear(input_size, HIDDEN))
self.hidden.append(nn.Linear(input_size, HIDDEN, device=DEVICE))
for i in range(args.layers - 1):
self.hidden.append(nn.Linear(HIDDEN, HIDDEN))
self.final_sum = nn.Linear(HIDDEN, 1)
self.hidden.append(nn.Linear(HIDDEN, HIDDEN, device=DEVICE))
self.final_sum = nn.Linear(HIDDEN, 1, device=DEVICE)

def forward(self, query, options, gold, lengths, query_no):
if len(options) == 1:
return None, 0
final = []
if args.word_vectors:
qvecs = self.pEmbedding(torch.tensor(query))
qvecs = self.pEmbedding(torch.tensor(query, device=DEVICE))
qvec_max = torch.max(qvecs, 0)[0]
qvec_mean = torch.mean(qvecs, 0)
for otext, features in options:
inputs = torch.tensor(features)
inputs = torch.tensor(features, device=DEVICE)
if args.word_vectors:
ovecs = self.pEmbedding(torch.tensor(otext))
ovecs = self.pEmbedding(torch.tensor(otext, device=DEVICE))
ovec_max = torch.max(ovecs, 0)[0]
ovec_mean = torch.mean(ovecs, 0)
inputs = torch.cat([inputs, qvec_max, qvec_mean, ovec_max, ovec_mean])
inputs = torch.cat((inputs, qvec_max, qvec_mean, ovec_max, ovec_mean))

if args.drop > 0:
inputs = F.dropout(inputs, args.drop)
h = inputs
Expand Down Expand Up @@ -500,9 +502,9 @@ def forward(self, query, options, gold, lengths, query_no):
dense_gold = []
for i in range(len(options)):
dense_gold.append(1.0 / len(gold) if i in gold else 0.0)
answer = torch.tensor(dense_gold)
answer = torch.tensor(dense_gold, device=DEVICE)
loss = torch.dot(answer, nll)
predicted_link = np.argmax(final.data.numpy())
predicted_link = torch.argmax(final)
return loss, predicted_link

def get_ids(self, words):
Expand Down

0 comments on commit 98b3567

Please sign in to comment.