"""Cycle-consistency losses over cross-view similarity matrices."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import config as C


class CycleS(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()  # defined but not used by the losses below
        self.delta = 0.5         # target soft-assignment mass on the true match
        self.m = C.MARGIN        # ranking margin
        self.epsilon = 0.1       # temperature; smaller -> sharper softmax
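
    # Temperature note (added commentary, not in the original source): with one
    # strong match (similarity ~1) among n candidates in a row, softmax(S * scale)
    # assigns roughly exp(scale) / (exp(scale) + n - 1) of the row's mass to the
    # match; scale = log(delta / (1 - delta) * n) makes that fraction ~delta, and
    # dividing by epsilon < 1 sharpens it further.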

    def _cycle_margin_loss(self, S):
        """Relaxed cycle-consistency margin loss for one similarity block.

        Orients S so soft assignments go from the larger set to the smaller one
        and back, then applies a margin ranking loss that pushes the round-trip
        assignment matrix towards the identity. Shared by the pairwise and
        triplewise losses below.
        """
        # Orient: S12 maps the larger set to the smaller set.
        if S.shape[0] < S.shape[1]:
            S21 = S
            S12 = S21.transpose(1, 0)
        else:
            S12 = S
            S21 = S12.transpose(1, 0)
        # Per-direction temperatures (see the temperature note above).
        scale12 = np.log(self.delta / (1 - self.delta) * S12.size(1)) / self.epsilon
        scale21 = np.log(self.delta / (1 - self.delta) * S21.size(1)) / self.epsilon
        S12_hat = F.softmax(S12 * scale12, dim=1)
        S21_hat = F.softmax(S21 * scale21, dim=1)
        # Round-trip soft assignment; ideally close to the identity matrix.
        S1221_hat = torch.mm(S12_hat, S21_hat)
        n = S1221_hat.shape[0]
        # Follow the input's device instead of the original hard-coded .cuda().
        I = torch.eye(n, device=S1221_hat.device)
        pos = S1221_hat * I
        neg = S1221_hat * (1 - I)
        # Hinge against the hardest negative in each row and each column.
        loss = torch.sum(F.relu(torch.max(neg, 1)[0] + self.m - torch.diag(pos)))
        loss += torch.sum(F.relu(torch.max(neg, 0)[0] + self.m - torch.diag(pos)))
        return loss / (2 * n)

    def pairwise_loss(self, all_S):
        """Cycle loss averaged over all unordered pairs of views."""
        loss_num = 0
        loss_sum = 0
        for i in range(len(all_S)):
            for j in range(i + 1, len(all_S)):
                loss_num += 1
                loss_sum += self._cycle_margin_loss(all_S[i][j])
        return loss_sum / loss_num

    def triplewise_loss(self, all_S):
        """Cycle loss over triples: compose similarities i -> k -> j through an
        intermediate view k, then score the composed block."""
        loss_num = 0
        loss_sum = 0
        for i in range(len(all_S)):
            for j in range(i + 1, len(all_S)):
                for k in range(len(all_S)):
                    if k != i and k != j:
                        loss_num += 1
                        S = torch.mm(all_S[i][k], all_S[k][j])
                        loss_sum += self._cycle_margin_loss(S)
        return loss_sum / loss_num

    def gen_X_S(self, feature_ls: list):
        """Build the per-pair raw similarity blocks S and the dense
        row-softmaxed assignment matrix X from per-view feature tensors."""
        norm_feature = [F.normalize(i, dim=-1) for i in feature_ls]
        all_blocks_S = []
        all_blocks_X = []
        for idx, x in enumerate(norm_feature):
            row_blocks_S = []
            row_blocks_X = []
            for idy, y in enumerate(norm_feature):
                # Cosine similarities between view idx and view idy.
                S = torch.mm(x, y.transpose(0, 1))
                scale = np.log(self.delta / (1 - self.delta) * S.size(1)) / self.epsilon
                S_hat = F.softmax(S * scale, dim=1)
                row_blocks_X.append(S_hat)
                row_blocks_S.append(S)
            row_blocks_X = torch.cat(row_blocks_X, dim=1)
            all_blocks_S.append(row_blocks_S)
            all_blocks_X.append(row_blocks_X)
        all_blocks_X = torch.cat(all_blocks_X, dim=0)
        return all_blocks_S, all_blocks_X

    def forward(self, feature_ls):
        # X (the dense assignment matrix) is computed but not used here; it is
        # kept for callers that need it.
        S, X = self.gen_X_S(feature_ls)
        loss = 0
        if 'pairwise' in C.LOSS:
            loss += self.pairwise_loss(S)
        if 'triplewise' in C.LOSS:
            loss += self.triplewise_loss(S)
        return loss
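
# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# file). Assumes config.py defines MARGIN (e.g. 0.2) and LOSS (e.g.
# ['pairwise', 'triplewise']), and that each view is an
# (num_frames, feat_dim) embedding tensor; views may have different lengths.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    criterion = CycleS()
    # Three hypothetical views with different sequence lengths.
    features = [torch.randn(16, 128), torch.randn(20, 128), torch.randn(12, 128)]
    print(criterion(features).item())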