diff --git a/code/models/base.py b/code/models/base.py index 8610530..08c852e 100644 --- a/code/models/base.py +++ b/code/models/base.py @@ -67,7 +67,7 @@ def general_step(self, batch, batch_idx, step: str): batch_size=batch_len, on_step=False, on_epoch=True, - sync_dist=True + sync_dist=True, ) self.log_accuracies(x_hat, batch.y, batch_len, step) return loss @@ -90,12 +90,7 @@ def log_accuracies( else: ValueError("Unknown step Literal") - accuracies = self.accuracies_fn( - acc_fun, - x_hat, - y, - step - ) + accuracies = self.accuracies_fn(acc_fun, x_hat, y, step) for accuracy in accuracies: self.log( **accuracy, @@ -103,7 +98,7 @@ def log_accuracies( on_step=False, on_epoch=True, batch_size=batch_len, - sync_dist=True + sync_dist=True, ) def test_step(self, batch, batch_idx):