-
Notifications
You must be signed in to change notification settings - Fork 1
/
ArcMarginProduct.py
59 lines (54 loc) · 1.93 KB
/
ArcMarginProduct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class ArcMarginProduct(tf.keras.layers.Layer):
    """Additive angular margin (ArcFace-style) classification head.

    The layer is called on ``[X, y]``: ``X`` is a batch of embeddings and ``y``
    holds integer class labels. It computes the cosine similarity between the
    l2-normalised embeddings and the l2-normalised class weight columns. For
    the target class, ``cos(theta)`` is replaced by ``cos(theta + m)``. The
    resulting logits are multiplied by ``s`` and returned.

    Args:
        s: Scale applied to the output logits.
        m: Additive angular margin (radians).
        easy_margin: If True, apply the margin only where ``cos(theta) > 0``.
            If False, use the ``cos(theta) - mm`` fallback once
            ``cos(theta) <= cos(pi - m)``.
        ls_eps: Label-smoothing factor for the one-hot targets (off when 0).
        n_classes: Number of output classes, i.e. the width of the weight
            matrix. Required. It is placed after the original parameters so
            existing positional calls for ``s``/``m``/... keep their meaning.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``n_classes`` is not provided.
    """

    def __init__(self, s=30, m=0.5, easy_margin=False, ls_eps=0.0,
                 n_classes=None, **kwargs):
        if n_classes is None:
            raise ValueError("ArcMarginProduct requires `n_classes`.")
        super().__init__(**kwargs)
        self.n_classes = int(n_classes)
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Plain Python floats adopt the dtype of the tensor they are combined
        # with in `call`, instead of being pinned to float32 tensors.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Past this cosine threshold, theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        # Penalty subtracted from cos(theta) beyond the threshold. The
        # resulting target logit stays decreasing in theta.
        self.mm = math.sin(math.pi - m) * m

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        config.update({
            "n_classes": self.n_classes,
            "s": self.s,
            "m": self.m,
            "ls_eps": self.ls_eps,
            "easy_margin": self.easy_margin,
        })
        return config

    def build(self, input_shape):
        """Create the (embedding_dim, n_classes) weight matrix ``W``.

        ``input_shape`` is the pair of shapes for ``[X, y]``. Only the
        embedding shape ``input_shape[0]`` determines the weight size.
        """
        super().build(input_shape[0])
        self.w = self.add_weight(
            name="W",
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer="glorot_uniform",
            dtype="float32",
            trainable=True,
            regularizer=None,
        )

    def call(self, inputs):
        """Return margin-adjusted, scaled logits for ``inputs = [X, y]``.

        Assumes ``X`` is rank-2 ``(batch, embedding_dim)`` and ``y`` is
        ``(batch,)`` integer labels. TODO confirm: a ``(batch, 1)`` label
        tensor would make ``tf.one_hot`` return rank 3.
        """
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.w, axis=0),
        )
        # Rounding can push |cosine| marginally above 1. Clip so sqrt never
        # sees a negative argument (which would yield NaN).
        sine = tf.math.sqrt(
            tf.clip_by_value(1.0 - tf.math.square(cosine), 0.0, 1.0)
        )
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype,
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
        # Target class gets phi; all other classes keep the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output