diff --git a/keras_tqdm/tqdm_callback.py b/keras_tqdm/tqdm_callback.py index 70d03f5..0799cc5 100644 --- a/keras_tqdm/tqdm_callback.py +++ b/keras_tqdm/tqdm_callback.py @@ -96,9 +96,9 @@ def on_epoch_begin(self, epoch, logs={}): self.running_logs = {} def on_epoch_end(self, epoch, logs={}): - metrics = self.format_metrics(logs) - desc = self.inner_description_update.format(epoch=epoch, metrics=metrics) if self.show_inner: + metrics = self.format_metrics(logs) + desc = self.inner_description_update.format(epoch=epoch, metrics=metrics) self.tqdm_inner.desc = desc # set miniters and mininterval to 0 so last update displays self.tqdm_inner.miniters = 0 @@ -119,9 +119,9 @@ def on_batch_end(self, batch, logs={}): self.inner_count += update if self.inner_count < self.inner_total: self.append_logs(logs) - metrics = self.format_metrics(self.running_logs) - desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics) if self.show_inner: + metrics = self.format_metrics(self.running_logs) + desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics) self.tqdm_inner.desc = desc self.tqdm_inner.update(update) @@ -137,7 +137,9 @@ def on_train_end(self, logs={}): self.tqdm_outer.close() def append_logs(self, logs): - metrics = self.params['metrics'] + metrics = self.params.get('metrics') + if not metrics: + return for metric, value in six.iteritems(logs): if metric in metrics: if metric in self.running_logs: @@ -146,7 +148,10 @@ def append_logs(self, logs): self.running_logs[metric] = [value[()]] def format_metrics(self, logs): - metrics = self.params['metrics'] + metrics = self.params.get('metrics') + if not metrics: + return '' + strings = [self.metric_format.format(name=metric, value=np.mean(logs[metric], axis=None)) for metric in metrics if metric in logs]