import logging
import os
import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter


class PartialFC(Module):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint.
    Partial FC: Training 10 Million Identities on a Single Machine.
    See the original paper:
    https://arxiv.org/abs/2010.05222
    """

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size, batch_size, resume,
                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
        """
        rank: int
            Unique process (GPU) ID, from 0 to world_size - 1.
        local_rank: int
            Unique process (GPU) ID within the server, from 0 to 7.
        world_size: int
            Number of GPUs.
        batch_size: int
            Batch size on the current rank (GPU).
        resume: bool
            Whether to restore the softmax weight from a saved checkpoint.
        margin_softmax: callable
            A margin softmax function, e.g. cosface, arcface.
        num_classes: int
            Total number of classes; each rank stores roughly num_classes // world_size
            class centers. Required.
        sample_rate: float
            The partial fc sampling rate. When the number of classes grows beyond about
            2 million, sampling greatly speeds up training and saves a lot of GPU memory.
            Default is 1.0.
        embedding_size: int
            The feature dimension. Default is 512.
        prefix: str
            Path for saving checkpoints. Default is './'.
        """
        super(PartialFC, self).__init__()

        self.num_classes: int = num_classes
        self.rank: int = rank
        self.local_rank: int = local_rank
        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
        self.world_size: int = world_size
        self.batch_size: int = batch_size
        self.margin_softmax: callable = margin_softmax
        self.sample_rate: float = sample_rate
        self.embedding_size: int = embedding_size
        self.prefix: str = prefix
        # Evenly partition the class centers across ranks: the first
        # (num_classes % world_size) ranks hold one extra center.
        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
        self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))

        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
                    raise IndexError
                logging.info("softmax weight resume successfully!")
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight init!")
                logging.info("softmax weight mom init!")
        else:
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight init successfully!")
            logging.info("softmax weight mom init successfully!")

        # Dedicated CUDA stream used to overlap label gathering and sampling with backbone compute.
        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            # No sampling: the full local weight is the trainable parameter and update() is a no-op.
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            # With sampling, sub_weight is re-created from the sampled rows at every step.
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
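
    # Worked example (illustrative, with assumed values): with num_classes = 10 and
    # world_size = 4, ranks 0 and 1 each hold num_local = 3 centers (class_start = 0 and 3),
    # while ranks 2 and 3 hold num_local = 2 centers (class_start = 6 and 8), so the local
    # ranges [0, 3), [3, 6), [6, 8), [8, 10) cover every class exactly once.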

    def save_params(self):
        """ Save the softmax weight of this rank under `prefix`.
        """
        torch.save(self.weight.data, self.weight_name)
        torch.save(self.weight_mom, self.weight_mom_name)

    @torch.no_grad()
    def sample(self, total_label):
        """
        Sample all positive class centers on this rank, and randomly select negative class
        centers to fill up to a fixed `num_sample`.

        total_label: tensor
            Labels gathered across all GPUs.
        """
        # Keep only labels that fall into this rank's class range; map them to local indices.
        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
        total_label[~index_positive] = -1
        total_label[index_positive] -= self.class_start
        if int(self.sample_rate) != 1:
            positive = torch.unique(total_label[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                # Give every positive center a score of 2.0 so it always beats the
                # uniform [0, 1) scores of the negatives, then keep the top num_sample.
                perm = torch.rand(size=[self.num_local], device=self.device)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            # Remap local class indices to positions within the sampled subset.
            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]
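
    # Worked example (illustrative, with assumed values): since `index` above is sorted,
    # `torch.searchsorted(index, ...)` rewrites each positive label to its row in `sub_weight`.
    # Assuming index = [2, 5, 9], a positive local label of 5 becomes 1, which is exactly the
    # row of class 5 inside the sampled weight matrix.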

    def forward(self, total_features, norm_weight):
        """ Partial fc forward, `logits = X * sample(W)`
        """
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        """ Write the updated sub_weight and sub_weight_mom back to the memory bank.
        """
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        """
        Get the sampled class centers for computing the softmax.

        label: tensor
            Label tensor on each rank.
        optimizer: opt
            Optimizer for partial fc, whose momentum buffer must be kept in sync
            with the sampled weight.
        """
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(
                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
            self.sample(total_label)
            # Swap the sampled weight (and its momentum buffer) into the optimizer's last param group.
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = normalize(self.sub_weight)
            return total_label, norm_weight
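
    # Note on the distributed softmax in forward_backward below: each rank only holds the logit
    # columns of its own class shard, so the full softmax
    #     p_ij = exp(l_ij - m_i) / sum_j' exp(l_ij' - m_i),  with m_i the row maximum over all classes,
    # is recovered by all-reducing the per-rank row maxima (MAX) and the per-rank sums of
    # exponentials (SUM). The cross-entropy gradient p - onehot(y) is then formed locally,
    # without ever gathering the full logit matrix on a single GPU.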

    def forward_backward(self, label, features, optimizer):
        """
        Partial fc forward and backward with model parallel.

        label: tensor
            Label tensor on each rank (GPU).
        features: tensor
            Feature tensor on each rank (GPU).
        optimizer: optimizer
            Optimizer for partial fc.

        Returns:
        --------
        x_grad: tensor
            The gradient of the features.
        loss_v: tensor
            Loss value of the cross entropy.
        """
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdim=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad: softmax probabilities minus the one-hot targets,
            # averaged over the global batch.
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
        # feature gradient reduce-scatter: each rank keeps the gradient of its own feature slice
        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        x_grad = x_grad * self.world_size
        # the backbone backward is run by the caller with x_grad
        return x_grad, loss_v
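

if __name__ == "__main__":
    # Minimal usage sketch. It assumes torch.distributed has been launched with one process
    # per GPU (e.g. via torchrun); `backbone` and `margin_softmax` below are placeholders
    # for the real embedding network and margin function used by the surrounding training code.
    import torch.nn.functional as F

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    batch_size, embedding_size, num_classes = 64, 512, 100000
    backbone = torch.nn.Linear(112 * 112 * 3, embedding_size).cuda(local_rank)  # placeholder backbone
    margin_softmax = lambda logits, labels: logits  # placeholder for e.g. an ArcFace margin

    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, batch_size=batch_size,
        resume=False, margin_softmax=margin_softmax, num_classes=num_classes,
        sample_rate=0.1, embedding_size=embedding_size, prefix="./")

    # prepare() swaps sub_weight into the last param group of this optimizer, so it should be
    # a momentum SGD whose last (here: only) param group holds the partial fc weight.
    opt_backbone = torch.optim.SGD(backbone.parameters(), lr=0.1, momentum=0.9)
    opt_pfc = torch.optim.SGD(module_partial_fc.parameters(), lr=0.1, momentum=0.9)

    img = torch.randn(batch_size, 112 * 112 * 3).cuda(local_rank)            # dummy batch
    label = torch.randint(0, num_classes, (batch_size,)).cuda(local_rank)    # dummy labels

    features = F.normalize(backbone(img))
    x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
    features.backward(x_grad)        # backward through the backbone with the received gradient
    opt_backbone.step()
    opt_pfc.step()
    module_partial_fc.update()       # write sub_weight / sub_weight_mom back to the memory bank
    opt_backbone.zero_grad()
    opt_pfc.zero_grad()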