Skip to content

Commit

Permalink
Add patience to tensorboard logging (#1835)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bram Vanroy authored Sep 3, 2020
1 parent 0f40cc1 commit a5401ac
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
22 changes: 15 additions & 7 deletions onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _update_average(self, step):
self.moving_average = copy_params
else:
average_decay = max(self.average_decay,
1 - (step + 1)/(step + 10))
1 - (step + 1) / (step + 10))
for (i, avg), cpt in zip(enumerate(self.moving_average),
self.model.parameters()):
self.moving_average[i] = \
Expand Down Expand Up @@ -295,8 +295,8 @@ def train(self,
break

if (self.model_saver is not None
and (save_checkpoint_steps != 0
and step % save_checkpoint_steps == 0)):
and (save_checkpoint_steps != 0
and step % save_checkpoint_steps == 0)):
self.model_saver.save(step, moving_average=self.moving_average)

if train_steps > 0 and step >= train_steps:
Expand Down Expand Up @@ -331,7 +331,7 @@ def validate(self, valid_iter, moving_average=None):

for batch in valid_iter:
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
else (batch.src, None)
else (batch.src, None)
tgt = batch.tgt

with torch.cuda.amp.autocast(enabled=self.optim.amp):
Expand Down Expand Up @@ -377,7 +377,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
tgt_outer = batch.tgt

bptt = False
for j in range(0, target_size-1, trunc_size):
for j in range(0, target_size - 1, trunc_size):
# 1. Create truncated target.
tgt = tgt_outer[j: j + trunc_size]

Expand Down Expand Up @@ -475,7 +475,12 @@ def _maybe_report_training(self, step, num_steps, learning_rate,
"""
if self.report_manager is not None:
return self.report_manager.report_training(
step, num_steps, learning_rate, report_stats,
step,
num_steps,
learning_rate,
None if self.earlystopper is None
else self.earlystopper.current_tolerance,
report_stats,
multigpu=self.n_gpu > 1)

def _report_step(self, learning_rate, step, train_stats=None,
Expand All @@ -486,7 +491,10 @@ def _report_step(self, learning_rate, step, train_stats=None,
"""
if self.report_manager is not None:
return self.report_manager.report_step(
learning_rate, step, train_stats=train_stats,
learning_rate,
None if self.earlystopper is None
else self.earlystopper.current_tolerance,
step, train_stats=train_stats,
valid_stats=valid_stats)

def maybe_noise_source(self, batch):
Expand Down
29 changes: 20 additions & 9 deletions onmt/utils/report_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def start(self):
def log(self, *args, **kwargs):
logger.info(*args, **kwargs)

def report_training(self, step, num_steps, learning_rate,
def report_training(self, step, num_steps, learning_rate, patience,
report_stats, multigpu=False):
"""
This is the user-defined batch-level traing progress
Expand All @@ -72,7 +72,7 @@ def report_training(self, step, num_steps, learning_rate,
report_stats = \
onmt.utils.Statistics.all_gather_stats(report_stats)
self._report_training(
step, num_steps, learning_rate, report_stats)
step, num_steps, learning_rate, patience, report_stats)
return onmt.utils.Statistics()
else:
return report_stats
Expand All @@ -81,17 +81,22 @@ def _report_training(self, *args, **kwargs):
""" To be overridden """
raise NotImplementedError()

def report_step(self, lr, step, train_stats=None, valid_stats=None):
def report_step(self, lr, patience, step, train_stats=None,
valid_stats=None):
"""
Report stats of a step
Args:
lr(float): current learning rate
patience(int): current patience
step(int): current step
train_stats(Statistics): training stats
valid_stats(Statistics): validation stats
lr(float): current learning rate
"""
self._report_step(
lr, step, train_stats=train_stats, valid_stats=valid_stats)
lr, patience, step,
train_stats=train_stats,
valid_stats=valid_stats)

def _report_step(self, *args, **kwargs):
raise NotImplementedError()
Expand All @@ -111,12 +116,13 @@ def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
super(ReportMgr, self).__init__(report_every, start_time)
self.tensorboard_writer = tensorboard_writer

def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
def maybe_log_tensorboard(self, stats, prefix, learning_rate,
patience, step):
if self.tensorboard_writer is not None:
stats.log_tensorboard(
prefix, self.tensorboard_writer, learning_rate, step)
prefix, self.tensorboard_writer, learning_rate, patience, step)

def _report_training(self, step, num_steps, learning_rate,
def _report_training(self, step, num_steps, learning_rate, patience,
report_stats):
"""
See base class method `ReportMgrBase.report_training`.
Expand All @@ -127,12 +133,15 @@ def _report_training(self, step, num_steps, learning_rate,
self.maybe_log_tensorboard(report_stats,
"progress",
learning_rate,
patience,
step)
report_stats = onmt.utils.Statistics()

return report_stats

def _report_step(self, lr, step, train_stats=None, valid_stats=None):
def _report_step(self, lr, patience, step,
train_stats=None,
valid_stats=None):
"""
See base class method `ReportMgrBase.report_step`.
"""
Expand All @@ -143,6 +152,7 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
self.maybe_log_tensorboard(train_stats,
"train",
lr,
patience,
step)

if valid_stats is not None:
Expand All @@ -152,4 +162,5 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
self.maybe_log_tensorboard(valid_stats,
"valid",
lr,
patience,
step)
4 changes: 3 additions & 1 deletion onmt/utils/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,13 @@ def output(self, step, num_steps, learning_rate, start):
time.time() - start))
sys.stdout.flush()

def log_tensorboard(self, prefix, writer, learning_rate, step):
def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
""" display statistics to tensorboard """
t = self.elapsed_time()
writer.add_scalar(prefix + "/xent", self.xent(), step)
writer.add_scalar(prefix + "/ppl", self.ppl(), step)
writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
writer.add_scalar(prefix + "/lr", learning_rate, step)
if patience is not None:
writer.add_scalar(prefix + "/patience", patience, step)

0 comments on commit a5401ac

Please sign in to comment.