diff --git a/deepxde/model.py b/deepxde/model.py index feefe7a9f..34a58c027 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -353,13 +353,23 @@ def outputs_losses_test(inputs, targets, auxiliary_vars): "backend pytorch." ) - def train_step(inputs, targets, auxiliary_vars, perform_backward=True): + def train_step(inputs, targets, auxiliary_vars): + def closure(): + losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] + total_loss = torch.sum(losses) + self.opt.zero_grad() + total_loss.backward() + return total_loss + + self.opt.step(closure) + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + def train_step_nncg(inputs, targets, auxiliary_vars): def closure(): losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1] total_loss = torch.sum(losses) self.opt.zero_grad() - if perform_backward: - total_loss.backward() return total_loss self.opt.step(closure) @@ -370,7 +380,7 @@ def closure(): self.outputs = outputs self.outputs_losses_train = outputs_losses_train self.outputs_losses_test = outputs_losses_test - self.train_step = train_step + self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg def _compile_jax(self, lr, loss_fn, decay): """jax""" @@ -652,7 +662,7 @@ def train( if self.opt_name == "L-BFGS": self._train_pytorch_lbfgs() elif self.opt_name == "NNCG": - self._train_pytorch_nncg(iterations, display_every) + self._train_sgd(iterations, display_every) elif backend_name == "paddle": self._train_paddle_lbfgs() else: @@ -796,32 +806,6 @@ def _train_pytorch_lbfgs(self): if self.stop_training: break - def _train_pytorch_nncg(self, iterations, display_every): - for i in range(iterations): - self.callbacks.on_epoch_begin() - self.callbacks.on_batch_begin() - - self.train_state.set_data_train( - *self.data.train_next_batch(self.batch_size) - ) - self.train_step( - self.train_state.X_train, - self.train_state.y_train, - self.train_state.train_aux_vars, - perform_backward=False, - ) - - self.train_state.epoch += 1 - self.train_state.step += 1 - if self.train_state.step % display_every == 0 or i + 1 == iterations: - self._test() - - self.callbacks.on_batch_end() - self.callbacks.on_epoch_end() - - if self.stop_training: - break - def _train_paddle_lbfgs(self): prev_n_iter = 0