
Commit 464e084

feat: Add Fast Geometric Ensembling (#56)
* Update CHANGELOG.rst
* primal update
* primal update
* fix typos
* primal update
* improve FastGeometricClassifier
* improve docstrings
* add FastGeometricRegressor
* flake8 formatting
* flake8 formatting
* add unit tests
* add unit tests
* add unit tests
* flake8 formatting
* improve documentation
* improve unit tests
* revert classification script
* fix the workflow
* add example
* improve api and internal workflow
* improve docstrings
* fix the example
1 parent e469314 commit 464e084

11 files changed (+1083, −30 lines)

CHANGELOG.rst (+1)

@@ -4,6 +4,7 @@ Changelog
 Ver 0.1.*
 ---------
 
+* |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
 * |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro <https://github.com/cspsampedro>`__
 * |Feature| |API| Add support on accepting instantiated base estimators as valid input | `@xuyxu <https://github.com/xuyxu>`__
 * |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu <https://github.com/xuyxu>`__

docs/parameters.rst (+39)

@@ -134,3 +134,42 @@ AdversarialTrainingRegressor
 
 .. autoclass:: torchensemble.adversarial_training.AdversarialTrainingRegressor
    :members:
+
+Fast Geometric Ensemble
+-----------------------
+
+Motivated by geometric insights on the loss surface of deep neural networks,
+Fast Geometric Ensembling (FGE) is an efficient ensemble that uses a
+customized learning rate scheduler to generate base estimators, similar to
+the snapshot ensemble.
+
+Reference:
+    T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode
+    Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018.
+
+Notice that unlike all of the ensembles above, fast geometric ensembling (FGE)
+is **a two-stage process**. Concretely, you first need to call :meth:`fit` to
+build a dummy base estimator that will be used to generate the ensemble.
+Second, you need to call :meth:`ensemble` to generate the real base estimators
+in the ensemble. The pipeline is shown in the following code snippet:
+
+.. code:: python
+
+    model = FastGeometricClassifier(**ensemble_related_args)
+    estimator = model.fit(train_loader, **base_estimator_related_args)  # train the dummy base estimator
+    model.ensemble(estimator, train_loader, **fge_related_args)  # generate the ensemble using the base estimator
+
+You can refer to the scripts in `examples <https://github.com/xuyxu/Ensemble-Pytorch/tree/master/examples>`__ for
+a detailed example.
+
+FastGeometricClassifier
+***********************
+
+.. autoclass:: torchensemble.fast_geometric.FastGeometricClassifier
+   :members:
+
+FastGeometricRegressor
+**********************
+
+.. autoclass:: torchensemble.fast_geometric.FastGeometricRegressor
+   :members:
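
For intuition about the "customized learning rate scheduler" mentioned in the new documentation above: FGE, as described in Garipov et al. (2018), cycles the learning rate between a high and a low bound and collects one base estimator per cycle. The sketch below only illustrates that triangular schedule; the function name fge_cyclic_lr is hypothetical and the library's internal scheduler may differ. The arguments cycle, lr_1, and lr_2 mirror the arguments used in the CIFAR-10 example later in this commit:

    def fge_cyclic_lr(step, cycle, lr_1, lr_2):
        """Illustrative triangular FGE schedule (Garipov et al., 2018).

        The rate ramps from lr_1 (high) down to lr_2 (low) during the first
        half of each cycle and back up during the second half; a new base
        estimator is typically collected when the rate reaches lr_2.
        """
        t = ((step % cycle) + 1) / cycle
        if t <= 0.5:
            return lr_1 + (lr_2 - lr_1) * (2.0 * t)       # ramp down: lr_1 -> lr_2
        return lr_2 + (lr_1 - lr_2) * (2.0 * t - 1.0)     # ramp up:   lr_2 -> lr_1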

examples/classification_cifar10_cnn.py (+1)

@@ -12,6 +12,7 @@
 from torchensemble.bagging import BaggingClassifier
 from torchensemble.gradient_boosting import GradientBoostingClassifier
 from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier
+from torchensemble.fast_geometric import FastGeometricClassifier
 
 from torchensemble.utils.logging import set_logger
 
New example script: FastGeometricClassifier on CIFAR-10 with ResNet (+172)

@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torchensemble import FastGeometricClassifier
from torchensemble.utils.logging import set_logger


# The classes `BasicBlock` and `ResNet` are modified from:
# https://github.com/kuangliu/pytorch-cifar
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 10
    lr = 1e-1
    weight_decay = 5e-4
    momentum = 0.9
    epochs = 200

    # Utils
    batch_size = 128
    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
    torch.manual_seed(0)
    torch.cuda.set_device(0)

    # Load data
    train_transformer = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    test_transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    train_loader = DataLoader(
        datasets.CIFAR10(
            data_dir, train=True, download=True, transform=train_transformer
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    # Set the Logger
    logger = set_logger("FastGeometricClassifier_cifar10_resnet")

    # Choose the Ensemble Method
    model = FastGeometricClassifier(
        estimator=ResNet,
        estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]},
        n_estimators=n_estimators,
        cuda=True,
    )

    # Set the Optimizer
    model.set_optimizer(
        "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum
    )

    # Train the dummy base estimator
    estimator = model.fit(train_loader, epochs=epochs, test_loader=test_loader)

    # Set the Scheduler
    model.set_scheduler("CosineAnnealingLR", T_max=epochs)

    # Generate the ensemble from the dummy base estimator
    model.ensemble(
        estimator,
        train_loader,
        cycle=4,
        lr_1=5e-2,
        lr_2=5e-4,
        test_loader=test_loader,
    )

    # Evaluate
    acc = model.predict(test_loader)
    print("Testing Acc: {:.3f}".format(acc))

torchensemble/__init__.py (+4)

@@ -10,6 +10,8 @@
 from .snapshot_ensemble import SnapshotEnsembleRegressor
 from .adversarial_training import AdversarialTrainingClassifier
 from .adversarial_training import AdversarialTrainingRegressor
+from .fast_geometric import FastGeometricClassifier
+from .fast_geometric import FastGeometricRegressor
 
 
 __all__ = [
@@ -25,4 +27,6 @@
     "SnapshotEnsembleRegressor",
     "AdversarialTrainingClassifier",
     "AdversarialTrainingRegressor",
+    "FastGeometricClassifier",
+    "FastGeometricRegressor",
 ]
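
With these two exports in place, the new ensembles can be imported directly from the package root, matching the other ensemble classes:

    # Both classes are now available at the top level of the package.
    from torchensemble import FastGeometricClassifier, FastGeometricRegressor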

torchensemble/_base.py (+1)

@@ -25,6 +25,7 @@ def get_doc(item):
     """Return the selected item."""
     __doc = {
         "model": const.__model_doc,
+        "seq_model": const.__seq_model_doc,
         "fit": const.__fit_doc,
         "set_optimizer": const.__set_optimizer_doc,
         "set_scheduler": const.__set_scheduler_doc,

torchensemble/_constants.py (+27)

@@ -31,6 +31,33 @@
 """
 
 
+__seq_model_doc = """
+Parameters
+----------
+estimator : torch.nn.Module
+    The class or object of your base estimator.
+
+    - If :obj:`class`, it should inherit from :mod:`torch.nn.Module`.
+    - If :obj:`object`, it should be instantiated from a class inherited
+      from :mod:`torch.nn.Module`.
+n_estimators : int
+    The number of base estimators in the ensemble.
+estimator_args : dict, default=None
+    The dictionary of hyper-parameters used to instantiate base
+    estimators. This parameter will have no effect if ``estimator`` is a
+    base estimator object after instantiation.
+cuda : bool, default=True
+
+    - If ``True``, use GPU to train and evaluate the ensemble.
+    - If ``False``, use CPU to train and evaluate the ensemble.
+
+Attributes
+----------
+estimators_ : torch.nn.ModuleList
+    An internal container that stores all fitted base estimators.
+"""
+
+
 __set_optimizer_doc = """
 Parameters
 ----------
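
To make the estimator / estimator_args contract documented above concrete, here is a small illustration of the two accepted forms. The MLP module is a made-up placeholder; the constructor arguments follow the docstring above and the CIFAR-10 example earlier in this commit:

    import torch.nn as nn

    from torchensemble import FastGeometricClassifier


    class MLP(nn.Module):  # a tiny placeholder base estimator
        def __init__(self, num_classes=10):
            super().__init__()
            self.fc = nn.Linear(784, num_classes)

        def forward(self, x):
            return self.fc(x.flatten(1))


    # 1) Pass the class plus `estimator_args`; the ensemble instantiates it.
    model = FastGeometricClassifier(
        estimator=MLP,
        estimator_args={"num_classes": 10},
        n_estimators=5,
        cuda=False,
    )

    # 2) Pass an already-instantiated object; `estimator_args` is then ignored.
    model = FastGeometricClassifier(
        estimator=MLP(num_classes=10),
        n_estimators=5,
        cuda=False,
    )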
