Feat: add ReduceLROnPlateau support
Move the scheduler step from train_epoch to the outer training loop, and pass stopping_loss to step() when the scheduler is ReduceLROnPlateau.
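For context: ReduceLROnPlateau is the built-in PyTorch scheduler whose step() expects the monitored metric, so it can only be stepped where that metric (here, stopping_loss) is available — the outer epoch loop rather than train_epoch. A minimal standalone sketch of the two calling protocols (the model and parameter values are illustrative only):

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Metric-driven: step() needs the quantity being monitored.
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)
# Epoch-driven: step() takes no argument.
step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

for epoch in range(5):
    valid_loss = 1.0 / (epoch + 1)  # stand-in for the real validation loss
    plateau.step(valid_loss)
    # step_lr.step()  # how a non-plateau scheduler would be stepped instead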
bennyjg authored and Optimox committed Sep 11, 2020
1 parent c557349 commit f0e0b54
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions pytorch_tabnet/tab_model.py
@@ -243,6 +243,12 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
             else:
                 self.patience_counter += 1
 
+            if self.scheduler is not None:
+                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                    self.scheduler.step(stopping_loss)
+                else:
+                    self.scheduler.step()
+
             self.epoch += 1
             total_time += time.time() - starting_time
             if self.verbose > 0:
@@ -637,8 +643,6 @@ def train_epoch(self, train_loader):
             'stopping_loss': stopping_loss,
         }
 
-        if self.scheduler is not None:
-            self.scheduler.step()
         return epoch_metrics
 
     def train_batch(self, data, targets):
@@ -894,8 +898,6 @@ def train_epoch(self, train_loader):
             'stopping_loss': stopping_loss,
         }
 
-        if self.scheduler is not None:
-            self.scheduler.step()
         return epoch_metrics
 
     def train_batch(self, data, targets):
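With this change a plateau scheduler can be passed to the model directly. A hedged usage sketch, following the scheduler_fn / scheduler_params interface from the project README (the concrete parameter values are illustrative, not part of this commit):

import torch
from pytorch_tabnet.tab_model import TabNetClassifier

# ReduceLROnPlateau is now stepped once per epoch inside fit(), with
# stopping_loss as the monitored metric.
clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params=dict(mode="min", factor=0.5, patience=5),
)
# clf.fit(X_train, y_train, X_valid, y_valid)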
