Skip to content

Commit

Permalink
add train_step_nncg
Browse files — browse the repository at this point in the history
  • Loading branch information
pratikrathore8 committed Oct 31, 2024
1 parent 60d3ff3 commit 6e739e1
Showing 1 changed file with 15 additions and 31 deletions.
46 changes: 15 additions & 31 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,23 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
"backend pytorch."
)

def train_step(inputs, targets, auxiliary_vars, perform_backward=True):
def train_step(inputs, targets, auxiliary_vars):
    """Run one optimizer step on a single training batch (pytorch backend)."""

    def closure():
        # Recompute the training losses for this batch; the [1] index
        # selects the losses from the (outputs, losses) pair returned
        # by outputs_losses_train — TODO confirm against its definition.
        losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
        # NOTE(review): plain sum of the individual losses — presumably
        # any loss weighting was already applied upstream; verify.
        total_loss = torch.sum(losses)
        self.opt.zero_grad()
        total_loss.backward()
        return total_loss

    # Passing a closure lets optimizers that re-evaluate the loss
    # multiple times per step (e.g. L-BFGS) work with the same code path.
    self.opt.step(closure)
    # Advance the learning-rate schedule once per optimizer step, if one
    # was configured.
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

def train_step_nncg(inputs, targets, auxiliary_vars):
def closure():
losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
total_loss = torch.sum(losses)
self.opt.zero_grad()
if perform_backward:
total_loss.backward()
return total_loss

self.opt.step(closure)
Expand All @@ -370,7 +380,7 @@ def closure():
self.outputs = outputs
self.outputs_losses_train = outputs_losses_train
self.outputs_losses_test = outputs_losses_test
self.train_step = train_step
self.train_step = train_step if self.opt_name != "NNCG" else train_step_nncg

def _compile_jax(self, lr, loss_fn, decay):
"""jax"""
Expand Down Expand Up @@ -652,7 +662,7 @@ def train(
if self.opt_name == "L-BFGS":
self._train_pytorch_lbfgs()
elif self.opt_name == "NNCG":
self._train_pytorch_nncg(iterations, display_every)
self._train_sgd(iterations, display_every)
elif backend_name == "paddle":
self._train_paddle_lbfgs()
else:
Expand Down Expand Up @@ -796,32 +806,6 @@ def _train_pytorch_lbfgs(self):
if self.stop_training:
break

def _train_pytorch_nncg(self, iterations, display_every):
    """Training loop for the NNCG optimizer (pytorch backend).

    Unlike the generic loop, this passes perform_backward=False to
    train_step, so the closure returns the loss without calling
    backward() — presumably NNCG computes gradients internally;
    verify against the optimizer implementation.

    Args:
        iterations: Number of optimizer steps to run.
        display_every: Run self._test() every this many steps.
    """
    for i in range(iterations):
        self.callbacks.on_epoch_begin()
        self.callbacks.on_batch_begin()

        # Fetch the next training batch and store it on the train state.
        self.train_state.set_data_train(
            *self.data.train_next_batch(self.batch_size)
        )
        self.train_step(
            self.train_state.X_train,
            self.train_state.y_train,
            self.train_state.train_aux_vars,
            # Skip the explicit backward pass inside the closure.
            perform_backward=False,
        )

        # One optimizer step counts as one epoch and one step here.
        self.train_state.epoch += 1
        self.train_state.step += 1
        # Evaluate periodically, and always on the final iteration.
        if self.train_state.step % display_every == 0 or i + 1 == iterations:
            self._test()

        self.callbacks.on_batch_end()
        self.callbacks.on_epoch_end()

        # Callbacks may request early termination via stop_training.
        if self.stop_training:
            break

def _train_paddle_lbfgs(self):
prev_n_iter = 0

Expand Down

0 comments on commit 6e739e1

Please sign in to comment.