From 25ce3e1991fd17d10f712407248e7d65be4b323a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wang=20Ran=20=28=E6=B1=AA=E7=84=B6=29?= Date: Sun, 17 Feb 2019 18:42:11 +0800 Subject: [PATCH] bug fix --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index c5c9900..ef9071a 100644 --- a/train.py +++ b/train.py @@ -155,7 +155,7 @@ def train(epoch): if opt.model == 'gated': model.current_epoch = epoch - global e, updates, total_loss, start_time, report_total + global updates, total_loss, start_time, report_total for raw_src, src, src_len, raw_tgt, tgt, tgt_len in trainloader: @@ -192,7 +192,7 @@ def train(epoch): model.train() total_loss = 0 - start_time = 0 + start_time = time.time() report_total = 0 if updates % config.save_interval == 0: @@ -231,7 +231,7 @@ def eval(epoch): score = {} result = utils.eval_metrics(reference, candidate, label_dict, log_path) - logging_csv([e, updates, result['hamming_loss'], \ + logging_csv([epoch, updates, result['hamming_loss'], \ result['micro_f1'], result['micro_precision'], result['micro_recall']]) print('hamming_loss: %.8f | micro_f1: %.4f' % (result['hamming_loss'], result['micro_f1']))