landmark_loss.py
import torch
from torch.nn.modules import Module


class LMK_Loss(Module):
    """Cosine-similarity loss between predicted and ground-truth landmark maps."""

    def __init__(self):
        super(LMK_Loss, self).__init__()

    def forward(self, gt, pred, weight=None):
        # Cosine similarity along the channel dimension (dim=1).
        dot_sum = (pred * gt).sum(dim=1)
        predm = torch.sqrt((pred * pred).sum(dim=1))
        gtm = torch.sqrt((gt * gt).sum(dim=1))
        # Loss is 1 - cos(theta) at every spatial location, summed and
        # averaged over the batch; an optional weight map scales each term.
        if weight is None:
            loss = (1 - dot_sum / (predm * gtm)).sum() / pred.shape[0]
        else:
            loss = ((1 - dot_sum / (predm * gtm)) * weight).sum() / pred.shape[0]
        return loss


if __name__ == '__main__':
    criterion = LMK_Loss()
    a = torch.abs(torch.randn(2, 2, 16, 16))
    # requires_grad_() belongs on the input tensor, not inside forward():
    # calling it on a non-leaf tensor (e.g. a model output) raises an error.
    b = torch.abs(torch.randn(2, 2, 16, 16)).requires_grad_()
    c = torch.abs(torch.randn(2, 16, 16))
    loss = criterion(a, b, c)
    loss.backward()
    print(loss)
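
For reference, a minimal cross-check sketch (not part of the original file): the unweighted loss should equal 1 minus torch.nn.functional.cosine_similarity averaged over the batch, since cosine_similarity reduces over dim=1 and additionally clamps the norms with an eps argument, guarding against division by zero when a vector at some spatial location is all zeros.

import torch
import torch.nn.functional as F

# Hypothetical sanity check against F.cosine_similarity. With the abs(randn)
# inputs used in the demo above, every per-location norm is nonzero, so the
# eps clamp does not change the result.
gt = torch.abs(torch.randn(2, 2, 16, 16))
pred = torch.abs(torch.randn(2, 2, 16, 16))

cos = F.cosine_similarity(pred, gt, dim=1, eps=1e-8)   # shape (2, 16, 16)
loss_ref = (1 - cos).sum() / pred.shape[0]

loss = LMK_Loss()(gt, pred)
print(torch.allclose(loss, loss_ref))                  # expected: True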