diff --git a/code/models/base.py b/code/models/base.py index 57d4bd2..8610530 100644 --- a/code/models/base.py +++ b/code/models/base.py @@ -67,6 +67,7 @@ def general_step(self, batch, batch_idx, step: str): batch_size=batch_len, on_step=False, on_epoch=True, + sync_dist=True ) self.log_accuracies(x_hat, batch.y, batch_len, step) return loss