Use normalized probability for eval in NCE
Refactor the train function into a global-variable-independent style.
Stonesjtu committed Nov 15, 2017
1 parent 73cbe2f commit 0a6dfa4
Showing 2 changed files with 20 additions and 22 deletions.
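
For context on the first half of the commit message: with NCE the model never computes a full softmax, so an evaluation "probability" can only be formed as exp(score - norm_term), where norm_term is a fixed estimate of lnZ. Those values need not sum to one, so the perplexity derived from them is only approximate. A tiny illustrative snippet (toy numbers, not the repository's API) contrasting the two notions of probability:

import torch

scores = torch.randn(100)   # model scores s(., h) over a toy 100-word vocab
norm_term = 9.0             # fixed lnZ estimate, matching nce.py's default

p_unnormed = (scores - norm_term).exp()   # NCE-style: no full softmax
p_normed = torch.softmax(scores, dim=0)   # normalized: a true distribution

print(p_unnormed.sum())  # arbitrary value, not a distribution
print(p_normed.sum())    # exactly 1.0, so PPL computed from it is exact
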
main.py (29 changes: 11 additions & 18 deletions)
@@ -99,11 +99,7 @@ def setup_parser():
pin_memory=args.cuda,
)

print(corpus.train.dataset.dictionary.idx2word[0])


eval_batch_size = args.batch_size

#################################################################
# Build the criterion and model
#################################################################

@@ -122,35 +118,34 @@ def setup_parser():
noise=noise,
noise_ratio=args.noise_ratio,
norm_term=args.norm_term,
normed_eval=True, # evaluate PPL using normalized prob
)
else:
criterion = crossEntropy.CELoss(
ntokens=ntokens,
nhidden=args.nhid,
)

evaluate_criterion = crossEntropy.CELoss(
ntokens=ntokens,
nhidden=args.nhid,
decoder_weight=(criterion.decoder.weight, criterion.decoder.bias),
)

model = RNNModel(ntokens, args.emsize, args.nhid, args.nlayers,
criterion=criterion,
dropout=args.dropout,
tie_weights=args.tied)
print(model)
if args.cuda:
model.cuda()
print(model)
#################################################################
# Training code
#################################################################


def train():
def train(model, data_source, lr=1.0, weight_decay=1e-5, momentum=0.9):
params = model.parameters()
optimizer = optim.SGD(params=params, lr=lr,
momentum=0.9, weight_decay=1e-5)
optimizer = optim.SGD(
params=params,
lr=lr,
momentum=momentum,
weight_decay=weight_decay
)
# Turn on training mode which enables dropout.
model.train()
total_loss = 0
@@ -177,15 +172,14 @@ def train():
cur_loss, math.exp(cur_loss)))
total_loss = 0
print('-' * 87)
num_batch += 1

def evaluate(model, data_source, cuda=args.cuda):
# Turn on evaluation mode which disables dropout.
model.eval()
eval_loss = 0
total_length = 0

data_source.batch_size = 32
data_source.batch_size = eval_batch_size
for data_batch in data_source:
data, target, length = process_data(data_batch, cuda=cuda, eval=True)

@@ -208,7 +202,7 @@ def evaluate(model, data_source, cuda=args.cuda):
try:
for epoch in range(1, args.epochs + 1):
epoch_start_time = time.time()
train()
train(model, corpus.train, lr=lr)
if args.prof:
break
val_ppl = evaluate(model, corpus.valid)
@@ -249,4 +243,3 @@ def evaluate(model, data_source, cuda=args.cuda):

if args.tb_name:
writer.close()

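Taken together, the main.py changes make the training loop explicit: train now receives the model, the data source, and the optimizer hyperparameters as arguments rather than reading module-level globals. A rough sketch of the resulting shape (logging and gradient clipping are omitted; that batches yield (data, target) pairs and that model(data, target) returns the loss are assumptions about this repository, not verified behavior):

import torch.optim as optim

def train(model, data_source, lr=1.0, weight_decay=1e-5, momentum=0.9):
    optimizer = optim.SGD(
        params=model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    model.train()  # enable dropout
    for data, target in data_source:
        # main.py actually routes batches through its process_data helper
        optimizer.zero_grad()
        loss = model(data, target)  # the criterion lives inside the model
        loss.backward()
        optimizer.step()

# called once per epoch, with everything passed in explicitly:
#     train(model, corpus.train, lr=lr)
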
nce.py (13 changes: 9 additions & 4 deletions)
@@ -24,6 +24,7 @@ class NCELoss(nn.Module):
norm_term: the normalization term (lnZ in paper)
size_average: average the loss by batch size
decoder: the decoder matrix
normed_eval: use normalized probability during evaluation
Shape:
- noise: :math:`(V)` where `V = vocabulary size`
@@ -37,8 +38,8 @@ def __init__(self,
noise_ratio=10,
norm_term=9,
size_average=True,
decoder_weight=None,
per_word=True,
normed_eval=True,
):
super(NCELoss, self).__init__()

@@ -49,10 +50,10 @@ def __init__(self,
self.ntokens = ntokens
self.size_average = size_average
self.per_word = per_word
if normed_eval:
self.normed_eval = normed_eval
self.ce = nn.CrossEntropyLoss(size_average=False)
self.decoder = IndexLinear(nhidden, ntokens)
# Weight tying
if decoder_weight:
self.decoder.weight = decoder_weight

def forward(self, input, target=None):
"""compute the loss with output and the desired target
Expand Down Expand Up @@ -92,6 +93,10 @@ def forward(self, input, target=None):

loss = -1 * torch.sum(rnn_loss + noise_loss)

elif self.normed_eval:
# Fall back to conventional cross entropy
out = self.decoder(input)
loss = self.ce(out, target)
else:
out = self.decoder(input, indices=target.unsqueeze(1))
nll = out.sub(self.norm_term)
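
Putting the nce.py half together, forward now has three paths: sampled NCE loss during training, exact cross entropy when normed_eval is set, and the unnormalized norm_term shortcut otherwise. A self-contained toy version of just the two evaluation paths (nn.Linear stands in for the repository's IndexLinear; shapes and defaults are assumptions):

import torch
import torch.nn as nn

class EvalOnlyNCE(nn.Module):
    """Toy stand-in for NCELoss's evaluation branches only."""
    def __init__(self, nhidden, ntokens, norm_term=9.0, normed_eval=True):
        super().__init__()
        self.decoder = nn.Linear(nhidden, ntokens)
        self.norm_term = norm_term
        self.normed_eval = normed_eval
        self.ce = nn.CrossEntropyLoss(reduction='sum')

    def forward(self, input, target):
        if self.normed_eval:
            # fall back to conventional cross entropy: normalized, exact PPL
            return self.ce(self.decoder(input), target)
        # unnormalized path: target score minus the fixed lnZ estimate
        score = self.decoder(input).gather(1, target.unsqueeze(1)).squeeze(1)
        return -(score - self.norm_term).sum()

hidden = torch.randn(4, 8)            # (batch, nhidden)
target = torch.randint(0, 100, (4,))  # word indices in a 100-word vocab
print(EvalOnlyNCE(8, 100)(hidden, target))
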
