# loss_functions.py
# Loss functions for knowledge distillation (KD), teacher-free KD (Tf-KD)
# and label smoothing regularization.
import torch
import torch.nn as nn
import torch.nn.functional as F
def loss_kd(outputs, labels, teacher_outputs, *, alpha=0.95, temperature=6):
    """Return the knowledge-distillation (KD) loss for a student model.

    The result is ``(1 - alpha) * CE + alpha * temperature**2 * KL`` where
    ``CE`` is the cross-entropy between ``outputs`` and ``labels`` and ``KL``
    is the KL divergence between the temperature-softened teacher and student
    distributions (softmax taken over ``dim=1``).

    Args:
        outputs: Student logits with the class scores along dim 1.
        labels: Ground-truth targets accepted by ``F.cross_entropy``.
        teacher_outputs: Teacher logits, same shape as ``outputs``.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the
            cross-entropy term.
        temperature: Softmax temperature applied to both sets of logits.

    Returns:
        A scalar tensor holding the combined loss.
    """
    hard_loss = F.cross_entropy(outputs, labels)

    student_log_probs = F.log_softmax(outputs / temperature, dim=1)
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)

    # NOTE: reduction="mean" averages over every element (samples x classes),
    # not per sample as "batchmean" would. It is spelled out here (it is the
    # nn.KLDivLoss default) so the loss scale stays exactly what it was.
    kl_term = nn.KLDivLoss(reduction="mean")(student_log_probs, teacher_probs)

    # temperature**2 compensates for the gradient scaling caused by dividing
    # the logits by the temperature.
    soft_loss = kl_term * (temperature * temperature)

    return (1.0 - alpha) * hard_loss + alpha * soft_loss
def loss_kd_self(outputs, labels, teacher_outputs, *, alpha=0.95,
                 temperature=6, multiplier=20):
    """Return the self-training distillation loss (Tf-KD_{self}).

    Same form as the standard KD loss, with the distillation term further
    scaled by ``multiplier``:
    ``(1 - alpha) * CE + alpha * multiplier * temperature**2 * KL``.

    Args:
        outputs: Student logits with the class scores along dim 1.
        labels: Ground-truth targets accepted by ``F.cross_entropy``.
        teacher_outputs: Logits of the teacher (the pre-trained copy of the
            model in self-training), same shape as ``outputs``.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the
            cross-entropy term.
        temperature: Softmax temperature applied to both sets of logits.
        multiplier: Extra scale on the KL term. Defaults to 20; the original
            author notes 1.0 is used in most cases and 10 or 50 in some.

    Returns:
        A scalar tensor holding the combined loss.
    """
    hard_loss = F.cross_entropy(outputs, labels)

    student_log_probs = F.log_softmax(outputs / temperature, dim=1)
    teacher_probs = F.softmax(teacher_outputs / temperature, dim=1)

    # NOTE: reduction="mean" averages over every element (samples x classes),
    # not per sample as "batchmean" would. It is spelled out here (it is the
    # nn.KLDivLoss default) so the loss scale stays exactly what it was.
    kl_term = nn.KLDivLoss(reduction="mean")(student_log_probs, teacher_probs)
    soft_loss = kl_term * (temperature * temperature) * multiplier

    return (1.0 - alpha) * hard_loss + alpha * soft_loss
def loss_kd_regularization(outputs, labels, *, alpha=0.95, temperature=6,
                           correct_prob=0.99, multiplier=20):
    """Return the manually-designed regularization loss (Tf-KD_{reg}).

    A hand-built "virtual teacher" assigns ``correct_prob`` to the true class
    and spreads the remainder uniformly over the other classes. That
    distribution is softened with ``temperature`` and used as the KL target
    for the student's (un-tempered) log-softmax. The result is
    ``(1 - alpha) * CE + alpha * multiplier * KL``.

    Args:
        outputs: Student logits of shape (batch, num_classes).
        labels: Integer (int64) class indices of shape (batch,), on the same
            device as ``outputs``.
        alpha: Weight of the regularization term; ``1 - alpha`` weights the
            cross-entropy term.
        temperature: Temperature applied to the virtual teacher before the
            softmax.
        correct_prob: Score given to the correct class in the virtual teacher.
        multiplier: Scale applied to the KL term.

    Returns:
        A scalar tensor holding the combined loss.
    """
    hard_loss = F.cross_entropy(outputs, labels)
    num_classes = outputs.size(1)

    # Build the virtual teacher directly on the logits' device and dtype
    # instead of forcing it onto the default CUDA device, so the function
    # works on CPU and on any GPU index. max(..., 1) keeps the single-class
    # case finite; that entry is overwritten by correct_prob anyway.
    off_class_value = (1.0 - correct_prob) / max(num_classes - 1, 1)
    teacher_soft = torch.full_like(outputs, off_class_value)

    # One scatter writes correct_prob at (i, labels[i]) for every sample,
    # replacing a per-sample Python loop.
    teacher_soft.scatter_(1, labels.unsqueeze(1), correct_prob)

    # NOTE: only the teacher is tempered; the student log-softmax is not.
    # reduction="mean" (the nn.KLDivLoss default) averages over every
    # element and is spelled out so the loss scale is unchanged.
    kl_term = nn.KLDivLoss(reduction="mean")(
        F.log_softmax(outputs, dim=1),
        F.softmax(teacher_soft / temperature, dim=1),
    )
    soft_loss = kl_term * multiplier

    return (1.0 - alpha) * hard_loss + alpha * soft_loss
def loss_label_smoothing(outputs, labels, *, alpha=0.1):
    """Return the cross-entropy loss against label-smoothed targets.

    The target distribution puts ``1 - alpha`` on the true class and
    ``alpha / (num_classes - 1)`` on every other class. The loss is the
    negative sum of ``target * log_softmax(outputs)`` divided by the batch
    size.

    Args:
        outputs: Logits of shape (batch, num_classes).
        labels: Integer (int64) class indices of shape (batch,), on the same
            device as ``outputs``.
        alpha: Total probability mass moved from the true class to the
            other classes.

    Returns:
        A scalar tensor holding the batch-averaged loss.

    Raises:
        ZeroDivisionError: If ``outputs`` has a single class.
    """
    batch_size = outputs.size(0)
    num_classes = outputs.size(1)

    # Allocate the targets on the logits' device rather than forcing the
    # default CUDA device, so the function works on CPU and on any GPU index.
    smoothed_labels = torch.full(
        (batch_size, num_classes),
        alpha / (num_classes - 1),
        device=outputs.device,
    )
    smoothed_labels.scatter_(dim=1, index=labels.unsqueeze(1), value=1 - alpha)

    log_prob = F.log_softmax(outputs, dim=1)
    return -torch.sum(log_prob * smoothed_labels) / batch_size