mtp_loss.py
import random

import torch
from torch.nn import functional as F

random.seed(0)
torch.manual_seed(0)
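

# MTP-style ("multiple trajectory prediction") loss: the network emits
# `modes` hypotheses, each consisting of `prediction_steps` values plus one
# probability logit. The loss regresses only the hypothesis closest to the
# ground truth (winner-takes-all, mean L1 distance) and adds a classification
# term pushing the mode probabilities toward that best hypothesis, weighted
# by `alpha_class`.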
class MTPLoss:
    def __init__(self,
                 modes=3,
                 prediction_steps=16,
                 alpha_class=0.1
                 ):
        self.modes = modes
        self.prediction_steps = prediction_steps
        # each mode occupies prediction_steps values plus one probability logit
        self.mode_size = self._get_mode_size()
        self.alpha_class = alpha_class

    def get_output_size(self):
        # total width of the flat network output: modes * (steps + 1)
        return (self.prediction_steps + 1) * self.modes

    def _get_mode_size(self):
        return self.prediction_steps + 1
    def my_dist(self, outputs, data_flat):
        # reshape the ground truth so it broadcasts against every mode
        data = data_flat.view(-1, 1, self.mode_size - 1)
        # use all entries of each mode except the trailing probability logit
        prediction = outputs[:, :, 0:self.mode_size - 1]
        unscaled_norm = torch.norm(prediction - data, dim=2, p=1)
        # mean L1 distance per prediction step, one value per (sample, mode)
        return unscaled_norm / data_flat.size()[1]
    def _expand(self, output):
        # flat [batch, modes * mode_size] -> [batch, modes, mode_size]
        return output.view(-1, self.modes, self.mode_size)
    def __call__(self, output, expected_inversion, my_arange):
        """
        Compute the loss between the network output and the expected output.

        :param output: flat network output containing the predictions and
            probability logits of every mode
        :param expected_inversion: the inversion expected for the data point
        :param my_arange: a precomputed torch.arange(batch_len), passed in to
            avoid recomputation at every call
        :return: the per-sample loss, shape [batch_len]
        """
        expanded = self._expand(output)
        data_flat = expected_inversion.view(-1, self.mode_size - 1)
        dists = self.my_dist(expanded, data_flat)
        # winner-takes-all: pick the mode closest to the ground truth;
        # detach so the mode selection itself receives no gradients
        best_mode = torch.argmin(dists, 1).detach()
        # the last index of each mode is reserved for its probability logit
        prob_raw = expanded[:, :, self.mode_size - 1]
        log_prob = F.log_softmax(prob_raw, dim=1)
        # classification term: negative log-probability of the best mode
        prob_contrib = -log_prob[my_arange, best_mode]
        # regression term: distance of the best mode to the ground truth
        norm_contrib = dists[my_arange, best_mode]
        result = self.alpha_class * prob_contrib + norm_contrib
        return result
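

# A minimal, hypothetical sketch of using MTPLoss inside a training step. The
# linear backbone, optimizer, and tensor shapes below are illustrative
# assumptions, not part of the original code.
def example_training_step():
    batch_len, in_features = 4, 8
    loss_fn = MTPLoss(modes=3, prediction_steps=16)
    # hypothetical backbone mapping input features to the flat MTP layout
    model = torch.nn.Linear(in_features, loss_fn.get_output_size())
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    my_arange = torch.arange(batch_len, dtype=torch.long)
    features = torch.randn(batch_len, in_features)              # dummy inputs
    targets = torch.randn(batch_len, loss_fn.prediction_steps)  # dummy targets
    losses = loss_fn(model(features), targets, my_arange)       # per-sample losses
    loss = losses.mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()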


if __name__ == '__main__':
    batch_len = 1
    device = 'cpu'
    modes = 7
    prediction_steps = 32
    mtp_loss = MTPLoss(modes=modes, prediction_steps=prediction_steps)
    # creating some reference data
    expected_inversion = (torch.arange(prediction_steps).to(device) * 0.01).unsqueeze(0)
    # my_arange is an auxiliary index vector of length batch_len
    my_arange = torch.arange(batch_len, dtype=torch.long).to(device)
    # creating some output in the shape expected from the neural network
    steps_with_probability = prediction_steps + 1
    # reference layout of a single mode (not used below)
    expected_inversion_with_probability = torch.ones(batch_len, steps_with_probability)
    expected_inversion_with_probability[:, 0:prediction_steps] = expected_inversion[:, :]
    some_output = torch.ones(batch_len, mtp_loss.get_output_size())
    for i in range(modes):
        # the prediction part of mode i: a scaled copy of the reference data
        some_output[:, steps_with_probability * i:steps_with_probability * i + prediction_steps] = \
            expected_inversion[:, :] * i / 5.
        # the probability logit of mode i
        some_output[:, steps_with_probability * i + prediction_steps] = 1. / (i + 1)
    # computing the loss
    cur_mtp_losses_batch = mtp_loss(some_output, expected_inversion, my_arange)
    cur_mtp_loss = torch.mean(cur_mtp_losses_batch)
    print(cur_mtp_loss)
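    # Sanity check: mode i == 5 copies expected_inversion exactly (5 / 5. == 1),
    # so the regression term is zero and the loss reduces to alpha_class times
    # the negative log-softmax probability of mode 5; verifying that directly:
    logits = torch.tensor([[1. / (i + 1) for i in range(modes)]])
    expected_loss = -mtp_loss.alpha_class * F.log_softmax(logits, dim=1)[0, 5]
    print(expected_loss)  # should match the loss printed above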