Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated DyNet to PyTorch #7

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 53 additions & 63 deletions src/disentangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import sys
import string
import time

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

FEATURES = 77
Expand All @@ -32,7 +34,6 @@

# Inference arguments
parser.add_argument('--max-dist', default=101, type=int, help="Maximum number of messages to consider when forming a link (count includes the current message).")
parser.add_argument('--dynet-autobatch', action='store_true', help="Use dynet autobatching.")

# Training arguments
parser.add_argument('--report-freq', default=5000, type=int, help="How frequently to evaluate on the development set.")
Expand All @@ -56,6 +57,7 @@
EPOCHS = args.epochs
DROP = args.drop
MAX_DIST = args.max_dist
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def header(args, out=sys.stdout):
head_text = "# "+ time.ctime(time.time())
Expand All @@ -66,11 +68,6 @@ def header(args, out=sys.stdout):
log_file = open(args.prefix +".log", 'w')
header(sys.argv, [log_file, sys.stdout])

import dynet_config
batching = 1 if args.dynet_autobatch else 0
dynet_config.set(mem=512, autobatch=batching, weight_decay=WEIGHT_DECAY, random_seed=args.seed)
import dynet as dy

from reserved_words import reserved


Expand Down Expand Up @@ -430,11 +427,9 @@ def simplify_token(token):
return ''.join(chars)


class DyNetModel():
class PyTorchModel(nn.Module):
def __init__(self):
super().__init__()

self.model = dy.ParameterCollection()
super(PyTorchModel, self).__init__()

input_size = FEATURES

Expand All @@ -452,78 +447,75 @@ def __init__(self):
pretrained.append(vector)
NWORDS = len(self.id_to_token)
DIM_WORDS = len(pretrained[0])
self.pEmbedding = self.model.add_lookup_parameters((NWORDS, DIM_WORDS))
self.pEmbedding.init_from_array(np.array(pretrained))
self.pEmbedding = nn.Embedding.from_pretrained(torch.tensor(pretrained)).to(DEVICE)
input_size += 4 * DIM_WORDS

# Create hidden layers
self.hidden = []
self.bias = []
self.hidden.append(self.model.add_parameters((HIDDEN, input_size)))
self.bias.append(self.model.add_parameters((HIDDEN,)))
self.hidden.append(nn.Linear(input_size, HIDDEN, device=DEVICE))
for i in range(args.layers - 1):
self.hidden.append(self.model.add_parameters((HIDDEN, HIDDEN)))
self.bias.append(self.model.add_parameters((HIDDEN,)))
self.final_sum = self.model.add_parameters((HIDDEN, 1))
self.hidden.append(nn.Linear(HIDDEN, HIDDEN, device=DEVICE))
self.final_sum = nn.Linear(HIDDEN, 1, device=DEVICE)

def __call__(self, query, options, gold, lengths, query_no):
def forward(self, query, options, gold, lengths, query_no):
if len(options) == 1:
return None, 0

final = []
if args.word_vectors:
qvecs = [dy.lookup(self.pEmbedding, w) for w in query]
qvec_max = dy.emax(qvecs)
qvec_mean = dy.average(qvecs)
qvecs = self.pEmbedding(torch.tensor(query, device=DEVICE))
qvec_max = torch.max(qvecs, 0)[0]
qvec_mean = torch.mean(qvecs, 0)
for otext, features in options:
inputs = dy.inputTensor(features)
inputs = torch.tensor(features, device=DEVICE)
if args.word_vectors:
ovecs = [dy.lookup(self.pEmbedding, w) for w in otext]
ovec_max = dy.emax(ovecs)
ovec_mean = dy.average(ovecs)
inputs = dy.concatenate([inputs, qvec_max, qvec_mean, ovec_max, ovec_mean])
ovecs = self.pEmbedding(torch.tensor(otext, device=DEVICE))
ovec_max = torch.max(ovecs, 0)[0]
ovec_mean = torch.mean(ovecs, 0)
inputs = torch.cat((inputs, qvec_max, qvec_mean, ovec_max, ovec_mean))

if args.drop > 0:
inputs = dy.dropout(inputs, args.drop)
inputs = F.dropout(inputs, args.drop)
h = inputs
for pH, pB in zip(self.hidden, self.bias):
h = dy.affine_transform([pB, pH, h])
for layer in self.hidden:
h = layer(h)
if args.nonlin == "linear":
pass
elif args.nonlin == "tanh":
h = dy.tanh(h)
h = F.tanh(h)
elif args.nonlin == "cube":
h = dy.cube(h)
h = h ** 3
elif args.nonlin == "logistic":
h = dy.logistic(h)
h = F.sigmoid(h)
elif args.nonlin == "relu":
h = dy.rectify(h)
h = F.relu(h)
elif args.nonlin == "elu":
h = dy.elu(h)
h = F.elu(h)
elif args.nonlin == "selu":
h = dy.selu(h)
h = F.selu(h)
elif args.nonlin == "softsign":
h = dy.softsign(h)
h = h / (1 + torch.abs(h))
elif args.nonlin == "swish":
h = dy.cmult(h, dy.logistic(h))
final.append(dy.sum_dim(h, [0]))

final = dy.concatenate(final)
nll = -dy.log_softmax(final)
h = h * F.sigmoid(h)
final.append(torch.sum(h, 0))
final = torch.stack(final)
nll = -F.log_softmax(final, dim=0)
dense_gold = []
for i in range(len(options)):
dense_gold.append(1.0 / len(gold) if i in gold else 0.0)
answer = dy.inputTensor(dense_gold)
loss = dy.transpose(answer) * nll
predicted_link = np.argmax(final.npvalue())

answer = torch.tensor(dense_gold, device=DEVICE)
loss = torch.dot(answer, nll)
predicted_link = torch.argmax(final)
return loss, predicted_link

def get_ids(self, words):
ans = []
backup = self.token_to_id.get('<unka>', 0)
for word in words:
ans.append(self.token_to_id.get(word, backup))
return ans

###############################################################################

def do_instance(instance, train, model, optimizer, do_cache=True):
name, query, gold, text_ascii, text_tok, info, target_info = instance

Expand All @@ -549,8 +541,9 @@ def do_instance(instance, train, model, optimizer, do_cache=True):
loss = 0.0
if train and example_loss is not None:
example_loss.backward()
optimizer.update()
loss = example_loss.scalar_value()
optimizer.step()
optimizer.zero_grad()
loss = example_loss.item()
predicted = output
matched = (predicted in gold)

Expand All @@ -576,13 +569,14 @@ def do_instance(instance, train, model, optimizer, do_cache=True):
model = None
optimizer = None
scheduler = None
model = DyNetModel()
model = PyTorchModel()
optimizer = None
if args.opt == 'sgd':
optimizer = dy.SimpleSGDTrainer(model.model, learning_rate=LEARNING_RATE)
elif args.opt == 'mom':
optimizer = dy.MomentumSGDTrainer(model.model, learning_rate=LEARNING_RATE, mom=MOMENTUM)
optimizer.set_clip_threshold(args.clip)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
elif args.opt == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
nn.utils.clip_grad_norm_(model.parameters(), args.clip)


prev_best = None
if args.train:
Expand All @@ -591,7 +585,7 @@ def do_instance(instance, train, model, optimizer, do_cache=True):
random.shuffle(train)

# Update learning rate
optimizer.learning_rate = LEARNING_RATE / (1+ LEARNING_DECAY_RATE * epoch)
optimizer.param_groups[0]['lr'] = LEARNING_RATE / (1+ LEARNING_DECAY_RATE * epoch)

# Loop over batches
loss = 0
Expand All @@ -600,8 +594,6 @@ def do_instance(instance, train, model, optimizer, do_cache=True):
loss_steps = 0
for instance in train:
step += 1

dy.renew_cg()
ex_loss, matched, _ = do_instance(instance, True, model, optimizer)
loss += ex_loss
loss_steps += 1
Expand All @@ -615,7 +607,6 @@ def do_instance(instance, train, model, optimizer, do_cache=True):
dev_match = 0
dev_total = 0
for dinstance in dev:
dy.renew_cg()
_, matched, _ = do_instance(dinstance, False, model, optimizer)
if matched:
dev_match += 1
Expand All @@ -628,7 +619,7 @@ def do_instance(instance, train, model, optimizer, do_cache=True):

if prev_best is None or prev_best[0] < dacc:
prev_best = (dacc, epoch)
model.model.save(args.prefix + ".dy.model")
torch.save(model.state_dict(), args.prefix + ".torch.model")

if prev_best is not None and epoch - prev_best[1] > 5:
break
Expand All @@ -637,12 +628,11 @@ def do_instance(instance, train, model, optimizer, do_cache=True):
if prev_best is not None or args.model:
location = args.model
if location is None:
location = args.prefix +".dy.model"
location = args.prefix + ".torch.model"
model.model.populate(location)

# Run on test instances
for instance in test:
dy.renew_cg()
_, _, prediction = do_instance(instance, False, model, optimizer, False)
print("{}:{} {} -".format(instance[0], instance[1], instance[1] - prediction))

Expand Down