-
Notifications
You must be signed in to change notification settings - Fork 8
/
loss.py
33 lines (22 loc) · 996 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import tensorflow as tf
import tensorflow_probability as tfp
# custom loss function
def custom_loss(y, p_base, minval=1e-9, maxval=1e9, scale=512):
    """Return the summed negative log-likelihood of ``y`` under a 4-D Gaussian.

    Each row of ``p_base`` parameterises one
    ``MultivariateNormalTriL`` distribution:

    * columns 0, 2, 4, 6  -> the four components of the mean;
    * columns 1, 3, 5, 7  -> the diagonal of the lower-triangular scale
      factor (passed through ``max(., 0) + minval`` so it stays positive);
    * columns 8..13       -> the six strictly-lower-triangular entries, in
      row-major order (L10, L20, L21, L30, L31, L32).

    Args:
        y: Observed targets, evaluated with ``dist.log_prob``.
            Presumably shaped ``(batch, 4)`` -- TODO confirm against caller.
        p_base: Rank-2 tensor of raw distribution parameters with at least
            14 columns, laid out as described above.
        minval: Positive floor. Added to the scale diagonal and used as the
            lower bound on the per-sample likelihood.
        maxval: Upper bound on the per-sample likelihood.
        scale: Unused; kept so existing callers keep working.

    Returns:
        Scalar tensor: the sum over all samples of the clipped negative
        log-likelihood.
    """
    mu = p_base[:, 0:8:2]
    # ReLU plus a small floor keeps every diagonal entry strictly positive,
    # which MultivariateNormalTriL requires of its scale factor.
    diag = minval + tf.math.maximum(p_base[:, 1:8:2], 0.0)
    off = p_base[:, 8:]
    zeros = tf.zeros_like(diag[:, 0])

    # Each row is stacked along the last axis -> (batch, 4); stacking the
    # rows on axis 1 then yields (batch, row, col) without a transpose.
    rows = [
        tf.stack([diag[:, 0], zeros, zeros, zeros], axis=-1),
        tf.stack([off[:, 0], diag[:, 1], zeros, zeros], axis=-1),
        tf.stack([off[:, 1], off[:, 2], diag[:, 2], zeros], axis=-1),
        tf.stack([off[:, 3], off[:, 4], off[:, 5], diag[:, 3]], axis=-1),
    ]
    scale_tril = tf.stack(rows, axis=1)

    dist = tfp.distributions.MultivariateNormalTriL(
        loc=mu, scale_tril=scale_tril
    )

    # Work in log space: log(clip(p, lo, hi)) == clip(log p, log lo, log hi)
    # because log is monotone, but this never materialises the raw density,
    # so it cannot overflow to inf (and poison gradients with NaN) when the
    # distribution becomes sharply peaked.
    log_lik = dist.log_prob(y)
    log_lo = tf.math.log(tf.cast(minval, log_lik.dtype))
    log_hi = tf.math.log(tf.cast(maxval, log_lik.dtype))
    log_lik = tf.clip_by_value(log_lik, log_lo, log_hi)

    return tf.reduce_sum(-log_lik)