forked from preetam22n/DeepTrans-HSU
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
98 lines (73 loc) · 3.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
import numpy as np
def compute_rmse(x_true, x_pre):
    """Per-band and overall RMSE between two (width, height, bands) arrays.

    Returns a tuple ``(per_band_rmse, overall_rmse)`` where the first item
    is a list with one RMSE value per spectral band.
    """
    width, height, num_bands = x_true.shape
    diff = x_true - x_pre  # hoist the subtraction out of the per-band loop
    per_band = [
        np.sqrt((diff[:, :, band] ** 2).sum() / (width * height))
        for band in range(num_bands)
    ]
    overall = np.sqrt((diff ** 2).sum() / (width * height * num_bands))
    return per_band, overall
def compute_re(x_true, x_pred):
    """Root-mean-square reconstruction error over a (width, height, bands) pair."""
    img_w, img_h, img_c = x_true.shape
    squared_diff = (x_true - x_pred) ** 2
    return np.sqrt(squared_diff.sum() / (img_w * img_h * img_c))
def compute_sad(inp, target):
    """Spectral Angle Distance between matching columns of two signature matrices.

    Args:
        inp: (bands, p) array whose columns are estimated endmember spectra.
        target: (bands, p) array whose columns are the reference spectra.

    Returns:
        Tuple ``(per_endmember_angles, mean_angle)`` — angles in radians.
    """
    p = inp.shape[-1]
    sad_err = [0] * p
    for i in range(p):
        inp_norm = np.linalg.norm(inp[:, i], 2)
        tar_norm = np.linalg.norm(target[:, i], 2)
        summation = np.matmul(inp[:, i].T, target[:, i])
        # Clip the cosine into [-1, 1]: floating-point round-off can push the
        # ratio slightly past 1 for (near-)identical vectors, which would make
        # np.arccos return NaN.
        cos_angle = np.clip(summation / (inp_norm * tar_norm), -1.0, 1.0)
        sad_err[i] = np.arccos(cos_angle)
    mean_sad = np.mean(sad_err)
    return sad_err, mean_sad
def Nuclear_norm(inputs):
    """Nuclear norm (sum of singular values) of a 4-D tensor flattened to
    a (band, h*w) matrix.

    NOTE(review): the reshape only succeeds when the leading (batch) dim is
    1 — confirm callers always pass a single-sample tensor.
    """
    _, num_bands, height, width = inputs.shape
    flattened = inputs.reshape(num_bands, height * width)
    return torch.norm(flattened, p='nuc')
class SparseKLloss(nn.Module):
    """Sparsity regularizer: nuclear norm of the batch-summed input, scaled.

    NOTE(review): despite the "KL" in the name, this computes a nuclear-norm
    penalty (via ``Nuclear_norm``), not a KL divergence — confirm naming.
    """

    def __init__(self):
        super(SparseKLloss, self).__init__()

    def __call__(self, inp, decay):
        # Collapse the batch dimension (keepdim so Nuclear_norm sees 4-D input).
        batch_sum = torch.sum(inp, 0, keepdim=True)
        return decay * Nuclear_norm(batch_sum)
class SumToOneLoss(nn.Module):
    """L1 penalty pushing per-pixel abundance sums (over dim 1) toward 1."""

    def __init__(self, device):
        super(SumToOneLoss, self).__init__()
        # Scalar target 1.0 kept as a buffer so it follows the module's device.
        self.register_buffer('one', torch.tensor(1, dtype=torch.float, device=device))
        self.loss = nn.L1Loss(reduction='sum')

    def get_target_tensor(self, inp):
        # Broadcast the scalar target to the shape of the summed abundances.
        return self.one.expand_as(inp)

    def __call__(self, inp, gamma_reg):
        summed = torch.sum(inp, 1)
        target = self.get_target_tensor(summed)
        return gamma_reg * self.loss(summed, target)
class SAD(nn.Module):
    """Spectral Angle Distance between batched spectra of ``num_bands`` channels."""

    def __init__(self, num_bands):
        super(SAD, self).__init__()
        self.num_bands = num_bands

    def forward(self, inp, target):
        n = self.num_bands
        try:
            inp_row = inp.view(-1, 1, n)
            inp_col = inp.view(-1, n, 1)
            tar_col = target.view(-1, n, 1)
            # Per-sample L2 norms via batched 1xN @ Nx1 products.
            inp_norm = torch.sqrt(torch.bmm(inp_row, inp_col))
            tar_norm = torch.sqrt(torch.bmm(target.view(-1, 1, n), tar_col))
            dot = torch.bmm(inp_row, tar_col)
            angle = torch.acos(dot / (inp_norm * tar_norm))
        except ValueError:
            # Preserved from the original: fall back to 0.0 on a ValueError.
            # NOTE(review): torch reshape failures raise RuntimeError, not
            # ValueError, so this branch may never trigger — confirm intent.
            return 0.0
        return angle
class SID(nn.Module):
    """Spectral Information Divergence: symmetric KL-style divergence between
    spectra normalized along dim 0.

    NOTE(review): the default ``epsilon = 1e5`` is suspiciously large for a
    numerical-stability offset (1e-5 would be typical) — confirm intent.
    """

    def __init__(self, epsilon: float = 1e5):
        super(SID, self).__init__()
        self.eps = epsilon

    def forward(self, inp, target):
        p = inp / torch.sum(inp, dim=0) + self.eps
        q = target / torch.sum(target, dim=0) + self.eps
        divergence = torch.sum(p * torch.log(p / q) + q * torch.log(q / p))
        return divergence