diff --git a/msd_pytorch/msd_model.py b/msd_pytorch/msd_model.py index eaeac02..81440fb 100644 --- a/msd_pytorch/msd_model.py +++ b/msd_pytorch/msd_model.py @@ -191,6 +191,8 @@ def forward(self, input=None, target=None): if target is not None: self.set_target(target) + self.output = None + self.loss = None self.output = self.net(self.input) self.loss = self.criterion(self.output, self.target)