import torch
from copy import deepcopy
from argparse import ArgumentParser

from .incremental_learning import Inc_Learning_Appr
from datasets.exemplars_dataset import ExemplarsDataset


class Appr(Inc_Learning_Appr):
"""Class implementing the Learning Without Forgetting (LwF) approach
described in https://arxiv.org/abs/1606.09282
"""
# Weight decay of 0.0005 is used in the original article (page 4).
# Page 4: "The warm-up step greatly enhances fine-tuning’s old-task performance, but is not so crucial to either our
# method or the compared Less Forgetting Learning (see Table 2(b))."
def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=10000,
momentum=0, wd=0, multi_softmax=False, wu_nepochs=0, wu_lr_factor=1, fix_bn=False, eval_on_train=False,
logger=None, exemplars_dataset=None, lamb=1, T=2):
super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
exemplars_dataset)
self.model_old = None
        self.lamb = lamb  # trade-off between the distillation loss (old tasks) and the cross-entropy loss (new task)
        self.T = T  # temperature applied to both softmaxes in the distillation loss

@staticmethod
def exemplars_dataset_class():
        return ExemplarsDataset

@staticmethod
def extra_parser(args):
"""Returns a parser containing the approach specific parameters"""
parser = ArgumentParser()
# Page 5: "lambda is a loss balance weight, set to 1 for most our experiments. Making lambda larger will favor
# the old task performance over the new task’s, so we can obtain a old-task-new-task performance line by
# changing lambda."
parser.add_argument('--lamb', default=1, type=float, required=False,
help='Forgetting-intransigence trade-off (default=%(default)s)')
# Page 5: "We use T=2 according to a grid search on a held out set, which aligns with the authors’
# recommendations." -- Using a higher value for T produces a softer probability distribution over classes.
parser.add_argument('--T', default=2, type=int, required=False,
help='Temperature scaling (default=%(default)s)')
        return parser.parse_known_args(args)

def _get_optimizer(self):
"""Returns the optimizer"""
if len(self.exemplars_dataset) == 0 and len(self.model.heads) > 1:
# if there are no exemplars, previous heads are not modified
params = list(self.model.model.parameters()) + list(self.model.heads[-1].parameters())
else:
params = self.model.parameters()
        return torch.optim.SGD(params, lr=self.lr, weight_decay=self.wd, momentum=self.momentum)

def train_loop(self, t, trn_loader, val_loader):
"""Contains the epochs loop"""
# add exemplars to train_loader
if len(self.exemplars_dataset) > 0 and t > 0:
trn_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
batch_size=trn_loader.batch_size,
shuffle=True,
num_workers=trn_loader.num_workers,
pin_memory=trn_loader.pin_memory)
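            # (the '+' on torch Datasets yields a ConcatDataset, so stored exemplars are replayed alongside new data)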
# FINETUNING TRAINING -- contains the epochs loop
super().train_loop(t, trn_loader, val_loader)
# EXEMPLAR MANAGEMENT -- select training subset
        self.exemplars_dataset.collect_exemplars(self.model, trn_loader, val_loader.dataset.transform)

def post_train_process(self, t, trn_loader):
"""Runs after training all the epochs of the task (after the train session)"""
        # Save a frozen copy of the (restored best) model; it provides the old-task targets for distillation
self.model_old = deepcopy(self.model)
self.model_old.eval()
        self.model_old.freeze_all()

def train_epoch(self, t, trn_loader):
"""Runs a single epoch"""
self.model.train()
if self.fix_bn and t > 0:
self.model.freeze_bn()
for images, targets in trn_loader:
# Forward old model
targets_old = None
if t > 0:
targets_old = self.model_old(images.to(self.device))
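                # model_old is frozen and in eval mode, so these logits act as fixed soft targets for distillation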
# Forward current model
outputs = self.model(images.to(self.device))
loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
# Backward
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()

def eval(self, t, val_loader):
"""Contains the evaluation code"""
with torch.no_grad():
total_loss, total_acc_taw, total_acc_tag, total_num = 0, 0, 0, 0
self.model.eval()
for images, targets in val_loader:
# Forward old model
targets_old = None
if t > 0:
targets_old = self.model_old(images.to(self.device))
# Forward current model
outputs = self.model(images.to(self.device))
loss = self.criterion(t, outputs, targets.to(self.device), targets_old)
hits_taw, hits_tag = self.calculate_metrics(outputs, targets)
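                # TAw: task-aware accuracy (task id known at test time); TAg: task-agnostic accuracy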
# Log
                total_loss += loss.item() * len(targets)
                total_acc_taw += hits_taw.sum().item()
                total_acc_tag += hits_tag.sum().item()
total_num += len(targets)
        return total_loss / total_num, total_acc_taw / total_num, total_acc_tag / total_num

def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):
"""Calculates cross-entropy with temperature scaling"""
out = torch.nn.functional.softmax(outputs, dim=1)
tar = torch.nn.functional.softmax(targets, dim=1)
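        # Raising a softmax to a power and renormalizing is the same as a softmax of the logits scaled
        # by that power, so calling this with exp=1/T reproduces the usual distillation temperature T.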
if exp != 1:
out = out.pow(exp)
out = out / out.sum(1).view(-1, 1).expand_as(out)
tar = tar.pow(exp)
tar = tar / tar.sum(1).view(-1, 1).expand_as(tar)
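        # Smooth the predicted distribution with a small eps so the log below never sees an exact zero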
out = out + eps / out.size(1)
out = out / out.sum(1).view(-1, 1).expand_as(out)
ce = -(tar * out.log()).sum(1)
if size_average:
ce = ce.mean()
        return ce

def criterion(self, t, outputs, targets, outputs_old=None):
"""Returns the loss value"""
loss = 0
if t > 0:
# Knowledge distillation loss for all previous tasks
loss += self.lamb * self.cross_entropy(torch.cat(outputs[:t], dim=1),
torch.cat(outputs_old[:t], dim=1), exp=1.0 / self.T)
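            # (all previous heads are concatenated so the KD term matches the old model's full
            # distribution over previously seen classes, softened with temperature T)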
# Current cross-entropy loss -- with exemplars use all heads
if len(self.exemplars_dataset) > 0:
return loss + torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
return loss + torch.nn.functional.cross_entropy(outputs[t], targets - self.model.task_offset[t])