feat: Add Fast Geometric Ensembling (#56)
* Update CHANGELOG.rst

* primal update

* primal update

* fix typos

* primal update

* improve FastGeometricClassifier

* improve docstrings

* add FastGeometricRegressor

* flake8 formatting

* flake8 formatting

* add unit tests

* add unit tests

* add unit tests

* flake8 formatting

* improve documentation

* improve unit tests

* revert classification script

* fix the workflow

* add example

* improve api and internal workflow

* improve docstrings

* fix the example
xuyxu authored Mar 23, 2021
1 parent e469314 commit 464e084
Showing 11 changed files with 1,083 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -4,6 +4,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
* |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro <https://github.com/cspsampedro>`__
* |Feature| |API| Add support for accepting instantiated base estimators as valid input | `@xuyxu <https://github.com/xuyxu>`__
* |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu <https://github.com/xuyxu>`__
39 changes: 39 additions & 0 deletions docs/parameters.rst
@@ -134,3 +134,42 @@ AdversarialTrainingRegressor

.. autoclass:: torchensemble.adversarial_training.AdversarialTrainingRegressor
    :members:

Fast Geometric Ensemble
-----------------------

Motivated by geometric insights on the loss surface of deep neural networks,
Fast Geometric Ensembling (FGE) is an efficient ensemble method that uses a
customized learning rate scheduler to generate base estimators, similar to
snapshot ensemble.

Reference:
    T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode
    Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018.
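
The "customized learning rate scheduler" can be illustrated with the triangular cyclical schedule described in the referenced paper. The block below is a minimal sketch under that assumption, not the code added by this commit: the function names ``fge_learning_rate`` and ``is_snapshot_step`` are hypothetical, and how the library maps its ``cycle`` argument onto iterations or epochs is not shown here.

```python
def fge_learning_rate(iteration, cycle_length, lr_1, lr_2):
    """Triangular cyclical learning rate as described by Garipov et al. (2018).

    `iteration` is 1-based; `cycle_length` is the number of iterations per
    cycle (assumed even, so that the midpoint of a cycle is hit exactly).
    """
    # Relative position inside the current cycle, in (0, 1].
    t = ((iteration - 1) % cycle_length + 1) / cycle_length
    if t <= 0.5:
        # First half of the cycle: anneal linearly from lr_1 down to lr_2.
        return (1.0 - 2.0 * t) * lr_1 + 2.0 * t * lr_2
    # Second half of the cycle: warm back up linearly from lr_2 to lr_1.
    return (2.0 - 2.0 * t) * lr_2 + (2.0 * t - 1.0) * lr_1


def is_snapshot_step(iteration, cycle_length):
    """True in the middle of a cycle, where the learning rate equals lr_2."""
    return 2 * ((iteration - 1) % cycle_length + 1) == cycle_length


if __name__ == "__main__":
    # Two cycles of length 4 with the learning rates used in the example script.
    for i in range(1, 9):
        rate = fge_learning_rate(i, cycle_length=4, lr_1=5e-2, lr_2=5e-4)
        print(i, round(rate, 5), is_snapshot_step(i, cycle_length=4))
```

In this sketch a base estimator would be collected whenever ``is_snapshot_step`` is true, that is, at the point of each cycle where the learning rate is lowest.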

Note that, unlike all ensembles above, using fast geometric ensemble (FGE) is
**a two-stage process**. First, call :meth:`fit` to build a dummy base
estimator that will be used to generate the ensemble. Second, call
:meth:`ensemble` to generate the real base estimators in the ensemble. The
pipeline is shown in the following code snippet:

.. code:: python

    model = FastGeometricClassifier(**ensemble_related_args)
    estimator = model.fit(train_loader, **base_estimator_related_args)  # train the base estimator
    model.ensemble(estimator, train_loader, **fge_related_args)  # generate the ensemble using the base estimator

You can refer to scripts in `examples <https://github.com/xuyxu/Ensemble-Pytorch/tree/master/examples>`__ for
a detailed example.

FastGeometricClassifier
***********************

.. autoclass:: torchensemble.fast_geometric.FastGeometricClassifier
    :members:

FastGeometricRegressor
***********************

.. autoclass:: torchensemble.fast_geometric.FastGeometricRegressor
    :members:
1 change: 1 addition & 0 deletions examples/classification_cifar10_cnn.py
@@ -12,6 +12,7 @@
from torchensemble.bagging import BaggingClassifier
from torchensemble.gradient_boosting import GradientBoostingClassifier
from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier
from torchensemble.fast_geometric import FastGeometricClassifier

from torchensemble.utils.logging import set_logger

172 changes: 172 additions & 0 deletions examples/fast_geometric_ensemble_cifar10_resnet18.py
@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torchensemble import FastGeometricClassifier
from torchensemble.utils.logging import set_logger


# The classes `BasicBlock` and `ResNet` are modified from:
# https://github.com/kuangliu/pytorch-cifar
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 10
    lr = 1e-1
    weight_decay = 5e-4
    momentum = 0.9
    epochs = 200

    # Utils
    batch_size = 128
    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    # Load data
    train_transformer = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    test_transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    train_loader = DataLoader(
        datasets.CIFAR10(
            data_dir, train=True, download=True, transform=train_transformer
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    # Set the Logger
    logger = set_logger("FastGeometricClassifier_cifar10_resnet")

    # Choose the Ensemble Method
    model = FastGeometricClassifier(
        estimator=ResNet,
        estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]},
        n_estimators=n_estimators,
        cuda=True,
    )

    # Set the Optimizer
    model.set_optimizer(
        "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum
    )

    # Set the Scheduler
    model.set_scheduler("CosineAnnealingLR", T_max=epochs)

    # Train
    estimator = model.fit(train_loader, epochs=epochs, test_loader=test_loader)

    # Ensemble
    model.ensemble(
        estimator,
        train_loader,
        cycle=4,
        lr_1=5e-2,
        lr_2=5e-4,
        test_loader=test_loader,
    )

    # Evaluate
    acc = model.predict(test_loader)
    print("Testing Acc: {:.3f}".format(acc))
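
The `ResNet` in the example script above hard-codes `F.avg_pool2d(out, 4)` and a linear layer with `512 * block.expansion` inputs, which only line up for 32x32 inputs such as CIFAR-10. The following sketch checks that arithmetic without PyTorch, using the standard convolution output-size formula. The helper names `conv_out` and `resnet_feature_shapes` are made up for illustration and are not part of the commit.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_feature_shapes(input_size=32):
    """Trace (name, channels, height/width) through the example ResNet."""
    shapes = []
    # conv1: 3x3 kernel, stride 1, padding 1 keeps the spatial size.
    size = conv_out(input_size, kernel=3, stride=1, padding=1)
    shapes.append(("conv1", 64, size))
    # Only the first block of each stage applies the stage stride; the
    # remaining 3x3 convolutions use stride 1 and padding 1 and keep the size.
    for name, planes, stride in [
        ("layer1", 64, 1),
        ("layer2", 128, 2),
        ("layer3", 256, 2),
        ("layer4", 512, 2),
    ]:
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        shapes.append((name, planes, size))
    # F.avg_pool2d(out, 4): kernel 4, stride defaults to the kernel size.
    size = conv_out(size, kernel=4, stride=4, padding=0)
    shapes.append(("avg_pool", 512, size))
    return shapes


if __name__ == "__main__":
    for name, channels, size in resnet_feature_shapes(32):
        print(f"{name}: {channels} x {size} x {size}")
```

For a 32x32 input the spatial size goes 32, 32, 16, 8, 4 and finally 1 after pooling, so the flattened feature vector has 512 entries, matching the input width of the final linear layer when `expansion` is 1.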
4 changes: 4 additions & 0 deletions torchensemble/__init__.py
@@ -10,6 +10,8 @@
from .snapshot_ensemble import SnapshotEnsembleRegressor
from .adversarial_training import AdversarialTrainingClassifier
from .adversarial_training import AdversarialTrainingRegressor
from .fast_geometric import FastGeometricClassifier
from .fast_geometric import FastGeometricRegressor


__all__ = [
@@ -25,4 +27,6 @@
    "SnapshotEnsembleRegressor",
    "AdversarialTrainingClassifier",
    "AdversarialTrainingRegressor",
    "FastGeometricClassifier",
    "FastGeometricRegressor",
]
1 change: 1 addition & 0 deletions torchensemble/_base.py
@@ -25,6 +25,7 @@ def get_doc(item):
    """Return the selected item."""
    __doc = {
        "model": const.__model_doc,
        "seq_model": const.__seq_model_doc,
        "fit": const.__fit_doc,
        "set_optimizer": const.__set_optimizer_doc,
        "set_scheduler": const.__set_scheduler_doc,
27 changes: 27 additions & 0 deletions torchensemble/_constants.py
@@ -31,6 +31,33 @@
"""


__seq_model_doc = """
    Parameters
    ----------
    estimator : torch.nn.Module
        The class or object of your base estimator.

        - If :obj:`class`, it should inherit from :mod:`torch.nn.Module`.
        - If :obj:`object`, it should be instantiated from a class inherited
          from :mod:`torch.nn.Module`.
    n_estimators : int
        The number of base estimators in the ensemble.
    estimator_args : dict, default=None
        The dictionary of hyper-parameters used to instantiate base
        estimators. This parameter will have no effect if ``estimator`` is a
        base estimator object after instantiation.
    cuda : bool, default=True

        - If ``True``, use GPU to train and evaluate the ensemble.
        - If ``False``, use CPU to train and evaluate the ensemble.

    Attributes
    ----------
    estimators_ : torch.nn.ModuleList
        An internal container that stores all fitted base estimators.
"""
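
The docstring above says that `estimator` may be either a class or an already instantiated object, and that `estimator_args` only matters in the first case. A minimal sketch of how such a dispatch could look is given below. It is an assumption for illustration only: the helper name `make_estimator` and the warning are hypothetical, and the plain `Toy` class stands in for a `torch.nn.Module` subclass so the snippet runs without PyTorch.

```python
import copy
import warnings


def make_estimator(estimator, estimator_args=None):
    """Hypothetical helper: build one base estimator from a class or an object."""
    if isinstance(estimator, type):
        # A class was passed: instantiate it with the given hyper-parameters.
        return estimator(**(estimator_args or {}))
    if estimator_args:
        # Mirrors the documented behaviour: the arguments have no effect here.
        warnings.warn("estimator_args is ignored when estimator is an object.")
    # An object was passed: copy it so ensemble members do not share state.
    return copy.deepcopy(estimator)


class Toy:
    """Stand-in for a user-defined base estimator."""

    def __init__(self, width=1):
        self.width = width


if __name__ == "__main__":
    from_class = make_estimator(Toy, {"width": 3})
    prototype = Toy(width=5)
    from_object = make_estimator(prototype)
    print(from_class.width, from_object.width, from_object is prototype)
```

Copying the object rather than reusing it is a design choice of this sketch: it keeps each ensemble member independent of the prototype passed in by the user.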


__set_optimizer_doc = """
    Parameters
    ----------