diff --git a/torchensemble/soft_gradient_boosting.py b/torchensemble/soft_gradient_boosting.py
index 8ffe5d9..ed41f49 100644
--- a/torchensemble/soft_gradient_boosting.py
+++ b/torchensemble/soft_gradient_boosting.py
@@ -209,11 +209,14 @@ def fit(
         test_loader=None,
         save_model=True,
         save_dir=None,
+        on_epoch_end_cb=None,
     ):
         # Instantiate base estimators and set attributes
-        for _ in range(self.n_estimators):
-            self.estimators_.append(self._make_estimator())
+        # Don't instantiate if estimators were loaded from save_dir
+        if len(self.estimators_) != self.n_estimators:
+            for _ in range(self.n_estimators):
+                self.estimators_.append(self._make_estimator())
 
         self._validate_parameters(epochs, log_interval)
         self.n_outputs = self._decide_n_outputs(train_loader)
 
@@ -295,6 +298,9 @@ def fit(
             else:
                 scheduler.step()
 
+            # Call on epoch end
+            if on_epoch_end_cb:
+                on_epoch_end_cb(epoch)
 
         if save_model and not test_loader:
             io.save(self, save_dir, self.logger)
@@ -390,6 +396,7 @@ def fit(
         test_loader=None,
         save_model=True,
         save_dir=None,
+        on_epoch_end_cb=None,
     ):
         super().fit(
             train_loader=train_loader,
@@ -399,6 +406,7 @@ def fit(
             test_loader=test_loader,
             save_model=save_model,
             save_dir=save_dir,
+            on_epoch_end_cb=on_epoch_end_cb,
         )
 
     @torchensemble_model_doc(