-
Notifications
You must be signed in to change notification settings - Fork 0
/
TripletLoss.py
67 lines (34 loc) · 1.28 KB
/
TripletLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from torch import nn
from torch.autograd import Variable
import tensorflow as tf
class TripletLoss(object):
    """Triplet loss over precomputed anchor-positive / anchor-negative distances.

    Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    Two variants are supported, mirroring the original open-reid PyTorch
    implementation (``nn.MarginRankingLoss`` / ``nn.SoftMarginLoss``):

    * ``margin`` given:   hinge loss   ``mean(max(0, margin + d_ap - d_an))``
    * ``margin is None``: soft-margin  ``mean(log(1 + exp(d_ap - d_an)))``
    """

    def __init__(self, margin=None):
        """
        Args:
            margin: hinge margin (a number, 0 allowed), or None to use the
                soft-margin variant instead of a hard margin.
        """
        self.margin = margin

    def __call__(self, dist_ap, dist_an):
        """Compute the mean triplet loss.

        Args:
            dist_ap: TensorFlow tensor (or value accepted by TF ops), distance
                between anchor and positive sample, shape [N].
            dist_an: same kind as ``dist_ap``, distance between anchor and
                negative sample, shape [N].
        Returns:
            loss: scalar TensorFlow tensor, the per-triplet loss averaged over
                all N triplets.
        """
        # Larger when the negative is closer to the anchor than the positive;
        # this is the quantity both variants penalise.
        violation = dist_ap - dist_an
        if self.margin is not None:
            # Same as nn.MarginRankingLoss(margin)(dist_an, dist_ap, y=1):
            # max(0, margin + d_ap - d_an). relu(x) == max(x, 0).
            per_triplet = tf.nn.relu(violation + self.margin)
        else:
            # Same as nn.SoftMarginLoss()(dist_an - dist_ap, y=1):
            # log(1 + exp(-(d_an - d_ap))) == softplus(d_ap - d_an).
            per_triplet = tf.nn.softplus(violation)
        return tf.reduce_mean(per_triplet)