From f0e0b54fb6333aeff3ad62139d6acbd7d0a7ac5d Mon Sep 17 00:00:00 2001
From: bennyjg
Date: Wed, 9 Sep 2020 12:48:29 +0300
Subject: [PATCH] Feat: add ReduceLROnPlateau support

Move the scheduler step from train_epoch to the outer training loop in
fit(), and pass stopping_loss to step() when the scheduler is
ReduceLROnPlateau, which needs a metric to decide when to reduce the
learning rate.
---
 pytorch_tabnet/tab_model.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index 0b40364e..3e5b8daa 100755
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -243,6 +243,12 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
             else:
                 self.patience_counter += 1
 
+            if self.scheduler is not None:
+                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                    self.scheduler.step(stopping_loss)
+                else:
+                    self.scheduler.step()
+
             self.epoch += 1
             total_time += time.time() - starting_time
             if self.verbose > 0:
@@ -637,8 +643,6 @@ def train_epoch(self, train_loader):
             'stopping_loss': stopping_loss,
         }
 
-        if self.scheduler is not None:
-            self.scheduler.step()
         return epoch_metrics
 
     def train_batch(self, data, targets):
@@ -894,8 +898,6 @@ def train_epoch(self, train_loader):
             'stopping_loss': stopping_loss,
         }
 
-        if self.scheduler is not None:
-            self.scheduler.step()
         return epoch_metrics
 
     def train_batch(self, data, targets):
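
Example usage after this patch, as a minimal sketch rather than part of the
diff: it assumes the pytorch-tabnet 1.x API, where scheduler_fn and
scheduler_params are passed to the model constructor and fit() takes
X_valid/y_valid positionally as shown in the first hunk. The toy arrays are
hypothetical stand-ins for real data.

    import numpy as np
    import torch
    from pytorch_tabnet.tab_model import TabNetClassifier

    # Hypothetical toy data, for illustration only.
    X_train = np.random.rand(256, 16).astype(np.float32)
    y_train = np.random.randint(0, 2, 256)
    X_valid = np.random.rand(64, 16).astype(np.float32)
    y_valid = np.random.randint(0, 2, 64)

    # stopping_loss is minimized, so mode="min" matches the value that
    # fit() now passes to ReduceLROnPlateau.step() once per epoch. Note
    # that "patience" here is the scheduler's own patience, separate from
    # TabNet's early-stopping patience.
    clf = TabNetClassifier(
        scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
        scheduler_params={"mode": "min", "factor": 0.5, "patience": 3},
    )
    clf.fit(X_train, y_train, X_valid, y_valid)

Design note: ReduceLROnPlateau.step() requires the monitored metric as an
argument, so it cannot be stepped inside train_epoch, where stopping_loss is
not in scope; stepping in fit() keeps one step per epoch for all schedulers.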