From 464e0846ea0c5e121dbcae3e3b80f475a16b27ea Mon Sep 17 00:00:00 2001
From: Yi-Xuan Xu
Date: Tue, 23 Mar 2021 15:29:32 +0800
Subject: [PATCH] feat: Add Fast Geometric Ensembling (#56)

* Update CHANGELOG.rst
* primal update
* primal update
* fix typos
* primal update
* improve FastGeometricClassifier
* improve docstrings
* add FastGeometricRegressor
* flake8 formatting
* flake8 formatting
* add unit tests
* add unit tests
* add unit tests
* flake8 formatting
* improve documentation
* improve unit tests
* revert classification script
* fix the workflow
* add example
* improve api and internal workflow
* improve docstrings
* fix the example
---
 CHANGELOG.rst                                 |   1 +
 docs/parameters.rst                           |  39 +
 examples/classification_cifar10_cnn.py        |   1 +
 ...ast_geometric_ensemble_cifar10_resnet18.py | 172 +++++
 torchensemble/__init__.py                     |   4 +
 torchensemble/_base.py                        |   1 +
 torchensemble/_constants.py                   |  27 +
 torchensemble/fast_geometric.py               | 721 ++++++++++++++++++
 torchensemble/snapshot_ensemble.py            |   4 +-
 torchensemble/tests/test_all_models.py        |  58 +-
 torchensemble/tests/test_fast_geometric.py    |  85 +++
 11 files changed, 1083 insertions(+), 30 deletions(-)
 create mode 100644 examples/fast_geometric_ensemble_cifar10_resnet18.py
 create mode 100644 torchensemble/fast_geometric.py
 create mode 100644 torchensemble/tests/test_fast_geometric.py

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ecb5589..7b5a013 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -4,6 +4,7 @@ Changelog
 Ver 0.1.*
 ---------
 
+* |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu `__
 * |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro `__
 * |Feature| |API| Add support on accepting instantiated base estimators as valid input | `@xuyxu `__
 * |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu `__
diff --git a/docs/parameters.rst b/docs/parameters.rst
index e4ee3a5..8ab9619 100644
--- a/docs/parameters.rst
+++ b/docs/parameters.rst
@@ -134,3 +134,42 @@ AdversarialTrainingRegressor
 
 .. autoclass:: torchensemble.adversarial_training.AdversarialTrainingRegressor
    :members:
+
+Fast Geometric Ensemble
+-----------------------
+
+Motivated by geometric insights on the loss surface of deep neural networks,
+Fast Geometric Ensembling (FGE) is an efficient ensemble that uses a
+customized learning rate scheduler to generate base estimators, similar to
+the snapshot ensemble.
+
+Reference:
+    T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode
+    Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018.
+
+Notice that unlike the ensembles above, using fast geometric ensembling (FGE)
+is **a two-stage process**. Concretely, you first need to call :meth:`fit` to
+build a dummy base estimator that will be used to generate the ensemble.
+Second, you need to call :meth:`ensemble` to generate the real base estimators
+in the ensemble. The pipeline is shown in the following code snippet:
+
+.. code:: python
+
+    model = FastGeometricClassifier(**ensemble_related_args)
+    estimator = model.fit(train_loader, **base_estimator_related_args)  # train the base estimator
+    model.ensemble(estimator, train_loader, **fge_related_args)  # generate the ensemble using the base estimator
+
+You can refer to scripts in `examples `__ for
+a detailed example.
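+
+For reference, the snippet below expands the pipeline into a minimal runnable
+sketch. The two-layer ``MLP`` and the random-data ``train_loader`` are
+illustrative placeholders only; they are not part of torchensemble itself:
+
+.. code:: python
+
+    import torch
+    import torch.nn as nn
+    from torch.utils.data import DataLoader, TensorDataset
+
+    from torchensemble import FastGeometricClassifier
+
+    class MLP(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.net = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
+
+        def forward(self, x):
+            return self.net(x)
+
+    # Toy data: 100 random samples with binary labels
+    X, y = torch.rand(100, 2), torch.randint(0, 2, (100,))
+    train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
+
+    model = FastGeometricClassifier(estimator=MLP, n_estimators=4, cuda=False)
+    model.set_optimizer("Adam", lr=1e-3)
+
+    # Stage 1: train the dummy base estimator
+    estimator = model.fit(train_loader, epochs=5)
+
+    # Stage 2: generate the real base estimators from it
+    model.ensemble(estimator, train_loader, cycle=4, lr_1=5e-2, lr_2=1e-4,
+                   save_model=False)
+
+    acc = model.predict(train_loader)  # accuracy in percent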
+
+FastGeometricClassifier
+***********************
+
+.. autoclass:: torchensemble.fast_geometric.FastGeometricClassifier
+   :members:
+
+FastGeometricRegressor
+**********************
+
+.. autoclass:: torchensemble.fast_geometric.FastGeometricRegressor
+   :members:
diff --git a/examples/classification_cifar10_cnn.py b/examples/classification_cifar10_cnn.py
index d15c5f8..bd82118 100644
--- a/examples/classification_cifar10_cnn.py
+++ b/examples/classification_cifar10_cnn.py
@@ -12,6 +12,7 @@
 from torchensemble.bagging import BaggingClassifier
 from torchensemble.gradient_boosting import GradientBoostingClassifier
 from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier
+from torchensemble.fast_geometric import FastGeometricClassifier
 
 from torchensemble.utils.logging import set_logger
 
diff --git a/examples/fast_geometric_ensemble_cifar10_resnet18.py b/examples/fast_geometric_ensemble_cifar10_resnet18.py
new file mode 100644
index 0000000..4905ba0
--- /dev/null
+++ b/examples/fast_geometric_ensemble_cifar10_resnet18.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+from torchensemble import FastGeometricClassifier
+from torchensemble.utils.logging import set_logger
+
+
+# The classes `BasicBlock` and `ResNet` are modified from:
+# https://github.com/kuangliu/pytorch-cifar
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=False,
+        )
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(
+                    in_planes,
+                    self.expansion * planes,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(ResNet, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+if __name__ == "__main__":
+
+    # Hyper-parameters
+    n_estimators = 10
+    lr = 1e-1
+    weight_decay = 5e-4
+    momentum = 0.9
+    epochs = 200
+
+    # Utils
+    batch_size = 128
+    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
+    torch.manual_seed(0)
+    torch.cuda.set_device(0)
+
+    # Load data
+    
train_transformer = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32, 4), + transforms.ToTensor(), + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ), + ] + ) + + test_transformer = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ), + ] + ) + + train_loader = DataLoader( + datasets.CIFAR10( + data_dir, train=True, download=True, transform=train_transformer + ), + batch_size=batch_size, + shuffle=True, + ) + + test_loader = DataLoader( + datasets.CIFAR10(data_dir, train=False, transform=test_transformer), + batch_size=batch_size, + shuffle=True, + ) + + # Set the Logger + logger = set_logger("FastGeometricClassifier_cifar10_resnet") + + # Choose the Ensemble Method + model = FastGeometricClassifier( + estimator=ResNet, + estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]}, + n_estimators=n_estimators, + cuda=True, + ) + + # Set the Optimizer + model.set_optimizer( + "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum + ) + + # Set the Scheduler + model.set_scheduler("CosineAnnealingLR", T_max=epochs) + + # Train + estimator = model.fit(train_loader, epochs=epochs, test_loader=test_loader) + + # Ensemble + model.ensemble( + estimator, + train_loader, + cycle=4, + lr_1=5e-2, + lr_2=5e-4, + test_loader=test_loader, + ) + + # Evaluate + acc = model.predict(test_loader) + print("Testing Acc: {:.3f}".format(acc)) diff --git a/torchensemble/__init__.py b/torchensemble/__init__.py index bf2da37..f1ffbcd 100644 --- a/torchensemble/__init__.py +++ b/torchensemble/__init__.py @@ -10,6 +10,8 @@ from .snapshot_ensemble import SnapshotEnsembleRegressor from .adversarial_training import AdversarialTrainingClassifier from .adversarial_training import AdversarialTrainingRegressor +from .fast_geometric import FastGeometricClassifier +from .fast_geometric import FastGeometricRegressor __all__ = [ @@ -25,4 +27,6 @@ "SnapshotEnsembleRegressor", "AdversarialTrainingClassifier", "AdversarialTrainingRegressor", + "FastGeometricClassifier", + "FastGeometricRegressor", ] diff --git a/torchensemble/_base.py b/torchensemble/_base.py index fe83649..bc987d0 100644 --- a/torchensemble/_base.py +++ b/torchensemble/_base.py @@ -25,6 +25,7 @@ def get_doc(item): """Return the selected item.""" __doc = { "model": const.__model_doc, + "seq_model": const.__seq_model_doc, "fit": const.__fit_doc, "set_optimizer": const.__set_optimizer_doc, "set_scheduler": const.__set_scheduler_doc, diff --git a/torchensemble/_constants.py b/torchensemble/_constants.py index ecf91e3..81a7abd 100644 --- a/torchensemble/_constants.py +++ b/torchensemble/_constants.py @@ -31,6 +31,33 @@ """ +__seq_model_doc = """ + Parameters + ---------- + estimator : torch.nn.Module + The class or object of your base estimator. + + - If :obj:`class`, it should inherit from :mod:`torch.nn.Module`. + - If :obj:`object`, it should be instantiated from a class inherited + from :mod:`torch.nn.Module`. + n_estimators : int + The number of base estimators in the ensemble. + estimator_args : dict, default=None + The dictionary of hyper-parameters used to instantiate base + estimators. This parameter will have no effect if ``estimator`` is a + base estimator object after instantiation. + cuda : bool, default=True + + - If ``True``, use GPU to train and evaluate the ensemble. + - If ``False``, use CPU to train and evaluate the ensemble. 
+
+    Attributes
+    ----------
+    estimators_ : torch.nn.ModuleList
+        An internal container that stores all fitted base estimators.
+"""
+
+
 __set_optimizer_doc = """
     Parameters
     ----------
diff --git a/torchensemble/fast_geometric.py b/torchensemble/fast_geometric.py
new file mode 100644
index 0000000..90b79df
--- /dev/null
+++ b/torchensemble/fast_geometric.py
@@ -0,0 +1,721 @@
+"""
+    Motivated by geometric insights on the loss surface of deep neural
+    networks, Fast Geometric Ensembling (FGE) is an efficient ensemble that
+    uses a customized learning rate scheduler to generate base estimators,
+    similar to the snapshot ensemble.
+
+    Reference:
+        T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode
+        Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018.
+"""
+
+
+import copy
+import torch
+import logging
+import warnings
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ._base import BaseModule, torchensemble_model_doc
+from .utils import io
+from .utils import set_module
+from .utils import operator as op
+
+
+__all__ = [
+    "_BaseFastGeometric",
+    "FastGeometricClassifier",
+    "FastGeometricRegressor",
+]
+
+
+__fit_doc = """
+    Parameters
+    ----------
+    train_loader : torch.utils.data.DataLoader
+        A :mod:`DataLoader` container that contains the training data.
+    epochs : int, default=100
+        The number of training epochs used to fit the dummy base estimator.
+    log_interval : int, default=100
+        The number of batches to wait before logging the training status.
+    test_loader : torch.utils.data.DataLoader, default=None
+        A :mod:`DataLoader` container that contains the evaluating data.
+
+        - If ``None``, no validation is conducted during the training stage
+          of the dummy base estimator.
+        - If not ``None``, the dummy base estimator will be evaluated on this
+          dataloader after each training epoch, and the checkpoint with the
+          best validation performance will be reserved.
+
+    Returns
+    -------
+    estimator_ : :obj:`object`
+        The fitted base estimator.
+
+        - If test_loader is ``None``, the fully trained base estimator will
+          be returned.
+        - If test_loader is not ``None``, the base estimator with the best
+          validation performance will be returned.
+"""
+
+
+__fge_doc = """
+    Parameters
+    ----------
+    estimator : :obj:`object`
+        The fitted base estimator.
+    train_loader : torch.utils.data.DataLoader
+        A :mod:`DataLoader` container that contains the training data.
+    cycle : int, default=4
+        The number of training epochs in one cycle, with each cycle
+        generating one base estimator in the ensemble.
+    lr_1 : float, default=5e-2
+        ``alpha_1`` in the original paper used to adjust the learning rate,
+        also serves as the initial learning rate of the internal optimizer.
+    lr_2 : float, default=1e-4
+        ``alpha_2`` in the original paper used to adjust the learning rate,
+        also serves as the smallest learning rate of the internal optimizer.
+    test_loader : torch.utils.data.DataLoader, default=None
+        A :mod:`DataLoader` container that contains the evaluating data.
+
+        - If ``None``, no validation is conducted after each base
+          estimator is generated.
+        - If not ``None``, the ensemble will be evaluated on this
+          dataloader after each base estimator is generated.
+    log_interval : int, default=100
+        The number of batches to wait before logging the training status.
+    save_model : bool, default=True
+        Specify whether to save the model parameters.
+
+        - If test_loader is ``None``, the fully trained ensemble will be
+          saved.
+        - If test_loader is not ``None``, the ensemble with the best
+          validation performance will be saved.
+    save_dir : string, default=None
+        Specify where to save the model parameters.
+
+        - If ``None``, the model will be saved in the current directory.
+        - If not ``None``, the model will be saved in the specified
+          directory: ``save_dir``.
+"""
+
+
+def _fast_geometric_model_doc(header, item="fit"):
+    """
+    Decorator on obtaining documentation for different fast geometric models.
+    """
+
+    def get_doc(item):
+        """Return selected item"""
+        __doc = {"fit": __fit_doc, "fge": __fge_doc}
+        return __doc[item]
+
+    def adddoc(cls):
+        doc = [header + "\n\n"]
+        doc.extend(get_doc(item))
+        cls.__doc__ = "".join(doc)
+        return cls
+
+    return adddoc
+
+
+class _BaseFastGeometric(BaseModule):
+    def __init__(
+        self, estimator, n_estimators, estimator_args=None, cuda=True
+    ):
+        super(BaseModule, self).__init__()
+
+        self.base_estimator_ = estimator
+        self.n_estimators = n_estimators
+        self.estimator_args = estimator_args
+
+        if estimator_args and not isinstance(estimator, type):
+            msg = (
+                "The input `estimator_args` will have no effect since"
+                " `estimator` is already an object after instantiation."
+            )
+            warnings.warn(msg, RuntimeWarning)
+
+        self.device = torch.device("cuda" if cuda else "cpu")
+        self.logger = logging.getLogger()
+
+        self.estimators_ = nn.ModuleList()
+        self.use_scheduler_ = False
+
+    def _forward(self, x):
+        """
+        Implementation on the internal data forwarding in fast geometric
+        ensemble.
+        """
+        # Average
+        results = [estimator(x) for estimator in self.estimators_]
+        output = op.average(results)
+
+        return output
+
+    def _adjust_lr(
+        self, optimizer, epoch, i, n_iters, cycle, alpha_1, alpha_2
+    ):
+        """
+        Set the internal learning rate scheduler for fast geometric ensemble.
+        Please refer to the original paper for details.
+        """
+
+        def scheduler(i):
+            t = ((epoch % cycle) + i) / cycle
+            if t < 0.5:
+                return alpha_1 * (1.0 - 2.0 * t) + alpha_2 * 2.0 * t
+            else:
+                return alpha_1 * (2.0 * t - 1.0) + alpha_2 * (2.0 - 2.0 * t)
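+
+        # Here ``t`` is the fractional position within the current cycle
+        # (the epoch index within the cycle plus the within-epoch batch
+        # fraction, divided by ``cycle``). With the defaults alpha_1=5e-2
+        # and alpha_2=1e-4, the learning rate decays linearly from alpha_1
+        # to alpha_2 over the first half of each cycle, then climbs back to
+        # alpha_1 over the second half, tracing the triangular schedule
+        # described in Garipov et al. (2018).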
""" + + """Notice that keyword arguments specified here will also be """ + + """used in the ensembling stage except the learning rate..""" + ), + "set_optimizer", + ) + def set_optimizer(self, optimizer_name, **kwargs): + super().set_optimizer(optimizer_name=optimizer_name, **kwargs) + + @torchensemble_model_doc( + ( + """Set the attributes on scheduler for FastGeometricClassifier. """ + + """Notice that this scheduler will only be used in the stage on """ # noqa: E501 + + """fitting the dummy base estimator.""" + ), + "set_scheduler", + ) + def set_scheduler(self, scheduler_name, **kwargs): + super().set_scheduler(scheduler_name=scheduler_name, **kwargs) + + @_fast_geometric_model_doc( + """Implementation on the training stage of FastGeometricClassifier.""", # noqa: E501 + "fit", + ) + def fit( + self, train_loader, epochs=100, log_interval=100, test_loader=None + ): + self._validate_parameters(epochs, log_interval) + self.n_outputs = self._decide_n_outputs( + train_loader, self.is_classification + ) + + # A dummy base estimator + estimator_ = self._make_estimator() + ret_estimator = None + + # Set the optimizer and scheduler + optimizer = set_module.set_optimizer( + estimator_, self.optimizer_name, **self.optimizer_args + ) + + if self.use_scheduler_: + scheduler = set_module.set_scheduler( + optimizer, self.scheduler_name, **self.scheduler_args + ) + + # Utils + criterion = nn.CrossEntropyLoss() + best_acc = 0.0 + + for epoch in range(epochs): + + # Training + estimator_.train() + for batch_idx, (data, target) in enumerate(train_loader): + + batch_size = data.size(0) + data, target = data.to(self.device), target.to(self.device) + + optimizer.zero_grad() + output = estimator_(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # Print training status + if batch_idx % log_interval == 0: + with torch.no_grad(): + _, predicted = torch.max(output.data, 1) + correct = (predicted == target).sum().item() + + msg = ( + "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f} |" + " Correct: {:d}/{:d}" + ) + self.logger.info( + msg.format( + epoch, + batch_idx, + loss, + correct, + batch_size, + ) + ) + + # Validation + if test_loader: + estimator_.eval() + with torch.no_grad(): + correct = 0 + total = 0 + for _, (data, target) in enumerate(test_loader): + data = data.to(self.device) + target = target.to(self.device) + output = estimator_(data) + _, predicted = torch.max(output.data, 1) + correct += (predicted == target).sum().item() + total += target.size(0) + acc = 100 * correct / total + + if acc > best_acc: + best_acc = acc + ret_estimator = copy.deepcopy(estimator_) + + msg = ( + "Validation Acc: {:.3f} % | Historical Best: {:.3f} %" + ) + self.logger.info(msg.format(acc, best_acc)) + + if self.use_scheduler_: + scheduler.step() + + # Extra step if `test_loader` is None + if ret_estimator is None: + ret_estimator = copy.deepcopy(estimator_) + + return ret_estimator + + @_fast_geometric_model_doc( + """Implementation on the ensembling stage of FastGeometricClassifier.""", # noqa: E501 + "fge", + ) + def ensemble( + self, + estimator, + train_loader, + cycle=4, + lr_1=5e-2, + lr_2=1e-4, + log_interval=100, + test_loader=None, + save_model=True, + save_dir=None, + ): + + # Set the internal optimizer + optimizer = set_module.set_optimizer( + estimator, self.optimizer_name, **self.optimizer_args + ) + + # Utils + criterion = nn.CrossEntropyLoss() + best_acc = 0.0 + n_iters = len(train_loader) + updated = False + epoch = 0 + + while len(self.estimators_) < self.n_estimators: + + 
# Training + estimator.train() + for batch_idx, (data, target) in enumerate(train_loader): + + # Update learning rate + self._adjust_lr( + optimizer, epoch, batch_idx, n_iters, cycle, lr_1, lr_2 + ) + + batch_size = data.size(0) + data, target = data.to(self.device), target.to(self.device) + + optimizer.zero_grad() + output = estimator(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # Print training status + if batch_idx % log_interval == 0: + with torch.no_grad(): + _, predicted = torch.max(output.data, 1) + correct = (predicted == target).sum().item() + + msg = ( + "lr: {:.5f} | Epoch: {:03d} | Batch: {:03d} |" + " Loss: {:.5f} | Correct: {:d}/{:d}" + ) + self.logger.info( + msg.format( + optimizer.param_groups[0]["lr"], + epoch, + batch_idx, + loss, + correct, + batch_size, + ) + ) + + # Update the ensemble + if (epoch % cycle + 1) == cycle // 2: + self.estimators_.append(copy.deepcopy(estimator)) + updated = True + + msg = "Save the base estimator with index: {}" + self.logger.info(msg.format(len(self.estimators_) - 1)) + + # Validation after each base estimator being added + if test_loader and updated: + self.eval() + with torch.no_grad(): + correct = 0 + total = 0 + for _, (data, target) in enumerate(test_loader): + data = data.to(self.device) + target = target.to(self.device) + output = self.forward(data) + _, predicted = torch.max(output.data, 1) + correct += (predicted == target).sum().item() + total += target.size(0) + acc = 100 * correct / total + + if acc > best_acc: + best_acc = acc + if save_model: + io.save(self, save_dir, self.logger) + + msg = ( + "n_estimators: {} | Validation Acc: {:.3f} %" + " | Historical Best: {:.3f} %" + ) + self.logger.info( + msg.format(len(self.estimators_), acc, best_acc) + ) + updated = False # reset the updating flag + epoch += 1 + + if save_model and not test_loader: + io.save(self, save_dir, self.logger) + self.is_fitted_ = True + + @torchensemble_model_doc( + """Implementation on the evaluating stage of FastGeometricClassifier.""", # noqa: E501 + "classifier_predict", + ) + def predict(self, test_loader): + + if len(self.estimators_) == 0: + msg = ( + "Please call the `ensemble` method to build the ensemble" + " first." + ) + self.logger.error(msg) + raise RuntimeError(msg) + + self.eval() + correct = 0 + total = 0 + + for _, (data, target) in enumerate(test_loader): + data, target = data.to(self.device), target.to(self.device) + output = self.forward(data) + _, predicted = torch.max(output.data, 1) + correct += (predicted == target).sum().item() + total += target.size(0) + + acc = 100 * correct / total + + return acc + + +@torchensemble_model_doc( + """Implementation on the FastGeometricRegressor.""", "seq_model" +) +class FastGeometricRegressor(_BaseFastGeometric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.is_classification = False + + @torchensemble_model_doc( + """Implementation on the data forwarding in FastGeometricRegressor.""", # noqa: E501 + "regressor_forward", + ) + def forward(self, x): + pred = self._forward(x) + return pred + + @torchensemble_model_doc( + ( + """Set the attributes on optimizer for FastGeometricRegressor. 
""" + + """Notice that keyword arguments specified here will also be """ + + """used in the ensembling stage except the learning rate.""" + ), + "set_optimizer", + ) + def set_optimizer(self, optimizer_name, **kwargs): + super().set_optimizer(optimizer_name=optimizer_name, **kwargs) + + @torchensemble_model_doc( + ( + """Set the attributes on scheduler for FastGeometricRegressor. """ + + """Notice that this scheduler will only be used in the stage on """ # noqa: E501 + + """fitting the dummy base estimator.""" + ), + "set_scheduler", + ) + def set_scheduler(self, scheduler_name, **kwargs): + super().set_scheduler(scheduler_name=scheduler_name, **kwargs) + + @_fast_geometric_model_doc( + """Implementation on the training stage of FastGeometricRegressor.""", # noqa: E501 + "fit", + ) + def fit( + self, + train_loader, + lr_clip=None, + epochs=100, + log_interval=100, + test_loader=None, + save_model=True, + save_dir=None, + ): + self._validate_parameters(epochs, log_interval) + self.n_outputs = self._decide_n_outputs( + train_loader, self.is_classification + ) + + # A dummy base estimator + estimator_ = self._make_estimator() + ret_estimator = None + + # Set the optimizer and scheduler + optimizer = set_module.set_optimizer( + estimator_, self.optimizer_name, **self.optimizer_args + ) + + if self.use_scheduler_: + scheduler = set_module.set_scheduler( + optimizer, self.scheduler_name, **self.scheduler_args + ) + + # Utils + criterion = nn.MSELoss() + best_mse = float("inf") + + for epoch in range(epochs): + + # Training + estimator_.train() + for batch_idx, (data, target) in enumerate(train_loader): + + data, target = data.to(self.device), target.to(self.device) + + optimizer.zero_grad() + output = estimator_(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # Print training status + if batch_idx % log_interval == 0: + with torch.no_grad(): + msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}" + self.logger.info(msg.format(epoch, batch_idx, loss)) + + # Validation + if test_loader: + estimator_.eval() + with torch.no_grad(): + mse = 0 + for _, (data, target) in enumerate(test_loader): + data = data.to(self.device) + target = target.to(self.device) + output = estimator_(data) + mse += criterion(output, target) + mse /= len(test_loader) + + if mse < best_mse: + best_mse = mse + ret_estimator = copy.deepcopy(estimator_) + + msg = ( + "Epoch: {:03d} | Validation MSE: {:.5f} |" + " Historical Best: {:.5f}" + ) + self.logger.info(msg.format(epoch, mse, best_mse)) + + if self.use_scheduler_: + scheduler.step() + + # Extra step if `test_loader` is None + if ret_estimator is None: + ret_estimator = copy.deepcopy(estimator_) + + return estimator_ + + @_fast_geometric_model_doc( + """Implementation on the ensembling stage of FastGeometricRegressor.""", # noqa: E501 + "fge", + ) + def ensemble( + self, + estimator, + train_loader, + cycle=4, + lr_1=5e-2, + lr_2=1e-4, + log_interval=100, + test_loader=None, + save_model=True, + save_dir=None, + ): + + # Set the internal optimizer + optimizer = set_module.set_optimizer( + estimator, self.optimizer_name, **self.optimizer_args + ) + + # Utils + criterion = nn.MSELoss() + best_mse = float("inf") + n_iters = len(train_loader) + updated = False + epoch = 0 + + while len(self.estimators_) < self.n_estimators: + + # Training + estimator.train() + for batch_idx, (data, target) in enumerate(train_loader): + + # Update learning rate + self._adjust_lr( + optimizer, epoch, batch_idx, n_iters, cycle, lr_1, lr_2 + ) + + data, target = 
data.to(self.device), target.to(self.device) + + optimizer.zero_grad() + output = estimator(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + # Print training status + if batch_idx % log_interval == 0: + with torch.no_grad(): + msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}" + self.logger.info(msg.format(epoch, batch_idx, loss)) + + # Update the ensemble + if (epoch % cycle + 1) == cycle // 2: + self.estimators_.append(copy.deepcopy(estimator)) + updated = True + + msg = "Save the base estimator with index: {}" + self.logger.info(msg.format(len(self.estimators_) - 1)) + + # Validation after each base estimator being added + if test_loader and updated: + self.eval() + with torch.no_grad(): + mse = 0 + for _, (data, target) in enumerate(test_loader): + data = data.to(self.device) + target = target.to(self.device) + output = self.forward(data) + mse += criterion(output, target) + mse /= len(test_loader) + + if mse < best_mse: + best_mse = mse + if save_model: + io.save(self, save_dir, self.logger) + + msg = ( + "Epoch: {:03d} | Validation MSE: {:.5f} |" + " Historical Best: {:.5f}" + ) + self.logger.info(msg.format(epoch, mse, best_mse)) + updated = False # reset the updating flag + epoch += 1 + + if save_model and not test_loader: + io.save(self, save_dir, self.logger) + self.is_fitted_ = True + + @torchensemble_model_doc( + """Implementation on the evaluating stage of FastGeometricRegressor.""", # noqa: E501 + "regressor_predict", + ) + def predict(self, test_loader): + + if len(self.estimators_) == 0: + msg = ( + "Please call the `ensemble` method to build the ensemble" + " first." + ) + self.logger.error(msg) + raise RuntimeError(msg) + + self.eval() + mse = 0 + criterion = nn.MSELoss() + + for batch_idx, (data, target) in enumerate(test_loader): + data, target = data.to(self.device), target.to(self.device) + output = self.forward(data) + mse += criterion(output, target) + + return mse / len(test_loader) diff --git a/torchensemble/snapshot_ensemble.py b/torchensemble/snapshot_ensemble.py index e640349..af8c19d 100644 --- a/torchensemble/snapshot_ensemble.py +++ b/torchensemble/snapshot_ensemble.py @@ -208,7 +208,7 @@ def set_scheduler(self, scheduler_name, **kwargs): @torchensemble_model_doc( - """Implementation on the SnapshotEnsembleClassifier.""", "model" + """Implementation on the SnapshotEnsembleClassifier.""", "seq_model" ) class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble): def __init__(self, **kwargs): @@ -372,7 +372,7 @@ def predict(self, test_loader): @torchensemble_model_doc( - """Implementation on the SnapshotEnsembleRegressor.""", "model" + """Implementation on the SnapshotEnsembleRegressor.""", "seq_model" ) class SnapshotEnsembleRegressor(_BaseSnapshotEnsemble): def __init__(self, **kwargs): diff --git a/torchensemble/tests/test_all_models.py b/torchensemble/tests/test_all_models.py index 9744759..fe45198 100644 --- a/torchensemble/tests/test_all_models.py +++ b/torchensemble/tests/test_all_models.py @@ -16,6 +16,7 @@ torchensemble.GradientBoostingClassifier, torchensemble.SnapshotEnsembleClassifier, torchensemble.AdversarialTrainingClassifier, + torchensemble.FastGeometricClassifier, ] @@ -26,6 +27,7 @@ torchensemble.GradientBoostingRegressor, torchensemble.SnapshotEnsembleRegressor, torchensemble.AdversarialTrainingRegressor, + torchensemble.FastGeometricRegressor, ] @@ -108,20 +110,20 @@ def test_clf_class(clf): epochs = 6 # Train - model.fit( - train_loader, epochs=epochs, test_loader=test_loader, save_model=True - ) + ret = 
model.fit(train_loader, epochs=epochs, test_loader=test_loader) + + # Extra step for Fast Geometric Ensemble + if isinstance(model, torchensemble.FastGeometricClassifier): + model.ensemble(ret, train_loader, test_loader=test_loader) # Test - prev_acc = model.predict(test_loader) + model.predict(test_loader) # Reload new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) io.load(new_model) - post_acc = new_model.predict(test_loader) - - assert prev_acc == post_acc # ensure the same performance + new_model.predict(test_loader) @pytest.mark.parametrize("clf", all_clf) @@ -152,20 +154,20 @@ def test_clf_object(clf): epochs = 6 # Train - model.fit( - train_loader, epochs=epochs, test_loader=test_loader, save_model=True - ) + ret = model.fit(train_loader, epochs=epochs, test_loader=test_loader) + + # Extra step for Fast Geometric Ensemble + if isinstance(model, torchensemble.FastGeometricClassifier): + model.ensemble(ret, train_loader, test_loader=test_loader) # Test - prev_acc = model.predict(test_loader) + model.predict(test_loader) # Reload new_model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) io.load(new_model) - post_acc = new_model.predict(test_loader) - - assert prev_acc == post_acc # ensure the same performance + new_model.predict(test_loader) @pytest.mark.parametrize("reg", all_reg) @@ -196,20 +198,20 @@ def test_reg_class(reg): epochs = 6 # Train - model.fit( - train_loader, epochs=epochs, test_loader=test_loader, save_model=True - ) + ret = model.fit(train_loader, epochs=epochs, test_loader=test_loader) + + # Extra step for Fast Geometric Ensemble + if isinstance(model, torchensemble.FastGeometricRegressor): + model.ensemble(ret, train_loader, test_loader=test_loader) # Test - prev_mse = model.predict(test_loader) + model.predict(test_loader) # Reload new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) io.load(new_model) - post_mse = new_model.predict(test_loader) - - assert prev_mse == post_mse # ensure the same performance + new_model.predict(test_loader) @pytest.mark.parametrize("reg", all_reg) @@ -240,17 +242,17 @@ def test_reg_object(reg): epochs = 6 # Train - model.fit( - train_loader, epochs=epochs, test_loader=test_loader, save_model=True - ) + ret = model.fit(train_loader, epochs=epochs, test_loader=test_loader) + + # Extra step for Fast Geometric Ensemble + if isinstance(model, torchensemble.FastGeometricRegressor): + model.ensemble(ret, train_loader, test_loader=test_loader) # Test - prev_mse = model.predict(test_loader) + model.predict(test_loader) # Reload new_model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) io.load(new_model) - post_mse = new_model.predict(test_loader) - - assert prev_mse == post_mse # ensure the same performance + new_model.predict(test_loader) diff --git a/torchensemble/tests/test_fast_geometric.py b/torchensemble/tests/test_fast_geometric.py new file mode 100644 index 0000000..0b8b135 --- /dev/null +++ b/torchensemble/tests/test_fast_geometric.py @@ -0,0 +1,85 @@ +import torch +import pytest +import numpy as np +import torch.nn as nn +from torch.utils.data import TensorDataset, DataLoader + +from torchensemble import FastGeometricClassifier as clf +from torchensemble import FastGeometricRegressor as reg +from torchensemble.utils.logging import set_logger + + +set_logger("pytest_fast_geometric") + + +# Testing data +X_test = torch.Tensor(np.array(([0.5, 0.5], [0.6, 0.6]))) + +y_test_clf = torch.LongTensor(np.array(([1, 0]))) +y_test_reg = 
torch.FloatTensor(np.array(([0.5, 0.6])))
+y_test_reg = y_test_reg.view(-1, 1)
+
+
+# Base estimator
+class MLP_clf(nn.Module):
+    def __init__(self):
+        super(MLP_clf, self).__init__()
+        self.linear1 = nn.Linear(2, 2)
+        self.linear2 = nn.Linear(2, 2)
+
+    def forward(self, X):
+        X = X.view(X.size()[0], -1)
+        output = self.linear1(X)
+        output = self.linear2(output)
+        return output
+
+
+class MLP_reg(nn.Module):
+    def __init__(self):
+        super(MLP_reg, self).__init__()
+        self.linear1 = nn.Linear(2, 2)
+        self.linear2 = nn.Linear(2, 1)
+
+    def forward(self, X):
+        X = X.view(X.size()[0], -1)
+        output = self.linear1(X)
+        output = self.linear2(output)
+        return output
+
+
+def test_fast_geometric_workflow_clf():
+    """
+    This unit test checks the error message when calling `predict` before
+    calling `ensemble`.
+    """
+    model = clf(estimator=MLP_clf, n_estimators=2, cuda=False)
+
+    model.set_optimizer("Adam")
+
+    # Prepare data
+    test = TensorDataset(X_test, y_test_clf)
+    test_loader = DataLoader(test, batch_size=2, shuffle=False)
+
+    # Calling `predict` before `ensemble` should raise a RuntimeError
+    with pytest.raises(RuntimeError) as excinfo:
+        model.predict(test_loader)
+    assert "Please call the `ensemble` method to build" in str(excinfo.value)
+
+
+def test_fast_geometric_workflow_reg():
+    """
+    This unit test checks the error message when calling `predict` before
+    calling `ensemble`.
+    """
+    model = reg(estimator=MLP_reg, n_estimators=2, cuda=False)
+
+    model.set_optimizer("Adam")
+
+    # Prepare data
+    test = TensorDataset(X_test, y_test_reg)
+    test_loader = DataLoader(test, batch_size=2, shuffle=False)
+
+    # Calling `predict` before `ensemble` should raise a RuntimeError
+    with pytest.raises(RuntimeError) as excinfo:
+        model.predict(test_loader)
+    assert "Please call the `ensemble` method to build" in str(excinfo.value)
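+
+
+# The sketch below is an illustrative end-to-end smoke test, assuming the toy
+# `MLP_clf` defined above; it exercises the two-stage fit/ensemble/predict
+# workflow on the tiny testing data (a hypothetical addition, not exhaustive).
+def test_fast_geometric_workflow_end_to_end():
+    model = clf(estimator=MLP_clf, n_estimators=2, cuda=False)
+    model.set_optimizer("Adam")
+
+    # Reuse the tiny testing data for both stages
+    data = TensorDataset(X_test, y_test_clf)
+    loader = DataLoader(data, batch_size=2, shuffle=False)
+
+    # Stage 1: train the dummy base estimator
+    estimator = model.fit(loader, epochs=1)
+
+    # Stage 2: generate the real base estimators in the ensemble
+    model.ensemble(estimator, loader, cycle=2, lr_1=1e-3, lr_2=1e-4)
+
+    # `predict` now works and returns the accuracy in percent
+    acc = model.predict(loader)
+    assert 0.0 <= acc <= 100.0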