-
Notifications
You must be signed in to change notification settings - Fork 6
/
loss_weighted_crossentropy.py
31 lines (28 loc) · 1.22 KB
/
loss_weighted_crossentropy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
from keras import backend as K
import numpy as np
def weighted_categorical_crossentropy(weights):
    """Build a class-weighted categorical crossentropy loss for Keras.

    Adapted from
    https://forums.fast.ai/t/unbalanced-classes-in-image-segmentation/18289

    Args:
        weights: Crossentropy weights as a Keras/TF tensor, a NumPy array,
            a list or a tuple. Lists, tuples and arrays are wrapped in a
            Keras variable; any other value (e.g. an existing tensor) is
            used as-is. Presumably one weight per class, broadcast against
            the last axis of ``target``/``output`` -- confirm with callers.

    Returns:
        A ``loss(target, output, from_logits=False)`` callable that returns
        the negated weighted log-likelihood summed over the last axis.
    """
    # The original check was ``isinstance(np.ndarray)`` (missing its first
    # argument), which raised TypeError for every non-list ``weights``.
    if isinstance(weights, (list, tuple, np.ndarray)):
        weights = K.variable(weights)

    def loss(target, output, from_logits=False):
        """Weighted crossentropy between ``target`` and probability ``output``.

        Args:
            target: Target distribution tensor, same shape as ``output``.
            output: Predicted (non-negative) scores; renormalised to sum
                to 1 along the last axis before taking the log.
            from_logits: Must be False; logits are not supported.

        Returns:
            Tensor of losses with the last axis reduced away.

        Raises:
            ValueError: if ``from_logits`` is True.
        """
        if from_logits:
            raise ValueError('WeightedCategoricalCrossentropy: not valid with logits')

        # Renormalise so each distribution along the last axis sums to 1.
        output = output / tf.reduce_sum(output, axis=-1, keepdims=True)
        # Keep probabilities strictly inside (0, 1) so log() stays finite.
        epsilon = tf.convert_to_tensor(K.epsilon(), dtype=output.dtype.base_dtype)
        output = tf.clip_by_value(output, epsilon, 1. - epsilon)

        weighted_losses = target * tf.math.log(output) * weights
        return -tf.reduce_sum(weighted_losses, axis=-1)

    return loss