minor updates in garom.py
* Removing `retain_graph` in the discriminator's backward pass
* Fixing issues with different-precision training in Lightning
dario-coscia authored and ndem0 committed Nov 17, 2023
1 parent 48b2e33 commit 56bd0da
Showing 1 changed file with 6 additions and 4 deletions.
pina/solvers/garom.py (6 additions, 4 deletions)
@@ -131,6 +131,8 @@ def __init__(
        self.gamma = gamma
        self.lambda_k = lambda_k
        self.regularizer = float(regularizer)
+        self._generator = self.models[0]
+        self._discriminator = self.models[1]

    def forward(self, x, mc_steps=20, variance=False):
        """
@@ -215,7 +217,7 @@ def _train_discriminator(self, parameters, snapshots):
        d_loss = d_loss_real - self.k * d_loss_fake

        # backward step
-        d_loss.backward(retain_graph=True)
+        d_loss.backward()
        optimizer.step()

        return d_loss_real, d_loss_fake, d_loss
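The `retain_graph=True` flag is only needed when the same computation graph is backpropagated through more than once. In the usual alternating GAN loop the discriminator loss comes from a fresh forward pass each step, with the generated samples detached, so the graph is consumed exactly once and the flag can be dropped. A minimal sketch of that pattern (illustrative stand-in modules, not PINA's actual code):

```python
import torch

gen = torch.nn.Linear(4, 4)    # stand-in generator
disc = torch.nn.Linear(4, 1)   # stand-in discriminator
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)

real = torch.randn(8, 4)
fake = gen(torch.randn(8, 4)).detach()  # cut the generator out of the graph

d_loss = disc(real).mean() - disc(fake).mean()
opt_d.zero_grad()
d_loss.backward()  # no retain_graph needed: the graph is used exactly once
opt_d.step()
```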
@@ -251,7 +253,7 @@ def training_step(self, batch, batch_idx):

        condition_name = dataloader.condition_names[condition_id]
        condition = self.problem.conditions[condition_name]
-        pts = batch['pts']
+        pts = batch['pts'].detach()
        out = batch['output']

        if condition_name not in self.problem.conditions:
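Calling `.detach()` gives a leaf tensor with no autograd history, so gradients from the generator and discriminator losses stop at the batch instead of flowing back into tensors held by the dataloader; a stale history on the input is one way to hit backward-pass errors when Lightning runs at a non-default precision. A minimal, hypothetical sketch of the effect:

```python
import torch

pts = torch.randn(8, 2, requires_grad=True)  # batch tensor carrying grad history
net = torch.nn.Linear(2, 1)

safe_pts = pts.detach()   # new leaf tensor, same storage, no history
loss = net(safe_pts).sum()
loss.backward()           # gradients stop at safe_pts, never reaching pts
assert pts.grad is None
```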
@@ -282,11 +284,11 @@ def training_step(self, batch, batch_idx):

    @property
    def generator(self):
-        return self.models[0]
+        return self._generator

    @property
    def discriminator(self):
-        return self.models[1]
+        return self._discriminator

    @property
    def optimizer_generator(self):
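The two networks are now looked up in `self.models` once, in `__init__`, and the properties return the cached references under descriptive names. A condensed sketch of the pattern (illustrative class, not the full solver):

```python
class Solver:
    def __init__(self, models):
        self.models = models
        self._generator = models[0]       # cache once at construction
        self._discriminator = models[1]

    @property
    def generator(self):
        return self._generator            # read-only, descriptive access

    @property
    def discriminator(self):
        return self._discriminator
```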