Skip to content

Commit

Permalink
Add final plugin call on training end
Browse files Browse the repository at this point in the history
If you want your plugin to be called at the end of the whole training run
(very useful for one last save of progress, e.g. network snapshots or
generated output), just add it to the 'end' queue with interval 1.
Implemented by default for SaverPlugin and OutputGenerator.
  • Loading branch information
Michalaq committed Feb 6, 2018
1 parent a6de55c commit 00b032c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 8 additions & 2 deletions plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class SaverPlugin(Plugin):
last_pattern = 'network-snapshot-{}-{}.dat'

def __init__(self, checkpoints_path, keep_old_checkpoints=False, network_snapshot_ticks=40):
    """Plugin that saves network snapshots during training.

    Args:
        checkpoints_path: directory where snapshot files are written.
        keep_old_checkpoints: when False, older snapshots are cleared
            (see _clear) so only the latest is kept.
        network_snapshot_ticks: save a snapshot every this many epochs.
    """
    # Register on the periodic 'epoch' queue and, with interval 1, on the
    # one-shot 'end' queue so one final snapshot is saved when training stops.
    super().__init__([(network_snapshot_ticks, 'epoch'), (1, 'end')])
    self.checkpoints_path = checkpoints_path
    self.keep_old_checkpoints = keep_old_checkpoints
    # Best validation loss seen so far; starts at +infinity so the first
    # observed value always improves on it.
    self._best_val_loss = float('+inf')
Expand All @@ -165,6 +165,9 @@ def epoch(self, epoch_index):
)
)

def end(self, *args):
    # Invoked once from the 'end' queue when the whole training run
    # finishes; reuse the per-epoch snapshot logic for one final save.
    self.epoch(*args)

def _clear(self, pattern):
pattern = os.path.join(self.checkpoints_path, pattern)
for file_name in glob(pattern):
Expand All @@ -174,7 +177,7 @@ def _clear(self, pattern):
class OutputGenerator(Plugin):

def __init__(self, sample_fn, output_postprocessors, samples_count=6, output_snapshot_ticks=3):
    """Plugin that periodically generates sample outputs from the model.

    Args:
        sample_fn: callable producing samples (invoked with latent inputs).
        output_postprocessors: callables applied to each generated batch
            (e.g. writing images/audio to disk).
        samples_count: number of samples to generate per snapshot.
        output_snapshot_ticks: generate outputs every this many epochs.
    """
    # Register on the periodic 'epoch' queue and, with interval 1, on the
    # one-shot 'end' queue so a final batch of outputs is generated when
    # training stops.
    super(OutputGenerator, self).__init__([(output_snapshot_ticks, 'epoch'), (1, 'end')])
    self.sample_fn = sample_fn
    self.output_postprocessors = output_postprocessors
    self.samples_count = samples_count
Expand All @@ -188,6 +191,9 @@ def epoch(self, epoch_index):
for proc in self.output_postprocessors:
proc(out, self.trainer.cur_nimg // 1000)

def end(self, *args):
    # Invoked once from the 'end' queue when the whole training run
    # finishes; reuse the per-epoch output-generation logic one last time.
    self.epoch(*args)


class CometPlugin(Plugin):

Expand Down
4 changes: 3 additions & 1 deletion trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(self,
self.plugin_queues = {
'iteration': [],
'epoch': [],
's': []
's': [],
'end': []
}

def register_plugin(self, plugin):
Expand Down Expand Up @@ -79,6 +80,7 @@ def run(self, total_kimg=1):
self.stats['kimg_stat']['val'] = self.cur_nimg / 1000.
self.stats['tick_stat']['val'] = self.cur_tick
self.call_plugins('epoch', self.cur_tick)
self.call_plugins('end', 1)

def train(self):
fake_latents_in = self.random_latents_generator().cuda()
Expand Down

0 comments on commit 00b032c

Please sign in to comment.