forked from WilliamRo/tframe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
38 lines (27 loc) · 1 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from __future__ import absolute_import
import six
import tensorflow as tf
def cross_entropy(labels, logits):
    """Return the batch-mean softmax cross-entropy of `logits` vs. `labels`.

    Args:
        labels: target tensor passed as `labels=` to
            `tf.nn.softmax_cross_entropy_with_logits`; presumably one
            distribution (e.g. one-hot) per example -- confirm with callers.
        logits: unnormalized scores; the TF op requires the same shape as
            `labels`.

    Returns:
        A scalar tensor: the mean of the per-example cross-entropies.
    """
    with tf.name_scope('cross_entropy'):
        per_example = tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return tf.reduce_mean(per_example)
def mean_squared_error(y_true, y_predict):
    """Return the mean of the element-wise squared error.

    The residual's magnitude is taken with `tf.abs` before squaring, so the
    result stays real-valued even for complex inputs; for real inputs the
    `tf.abs` does not change the value.

    Returns:
        A scalar tensor (mean over every element of the residual).
    """
    residual = y_true - y_predict
    squared_error = tf.square(tf.abs(residual))
    return tf.reduce_mean(squared_error)
def euclidean(y_true, y_predict, axis=-1):
    """Return the mean Euclidean distance between targets and predictions.

    Args:
        y_true: target tensor.
        y_predict: prediction tensor; subtracted from `y_true`, so the two
            must be broadcast-compatible.
        axis: axis along which each Euclidean distance is measured. The
            default (last axis) yields one distance per example for inputs
            laid out as (batch, features) -- an assumption about callers'
            layout, confirm for higher-rank data. Pass `axis=None` to take a
            single norm over the whole flattened tensor instead.

    Returns:
        A scalar tensor: the mean of the distances.
    """
    # Without an explicit `axis`, tf.norm flattens the entire batch into one
    # scalar norm, which makes the reduce_mean below a no-op and lets the
    # loss grow with batch size. Measuring per-example distances first makes
    # the mean meaningful.
    distances = tf.norm(y_true - y_predict, axis=axis)
    return tf.reduce_mean(distances)
def get(identifier):
    """Resolve `identifier` to a loss function.

    Args:
        identifier: either a callable, which is returned unchanged, or a
            case-insensitive alias string:
            'mean_squared' / 'mean_squared_error' / 'mse' -> mean_squared_error
            'cross_entropy'                               -> cross_entropy
            'euclid' / 'euclidean'                        -> euclidean

    Returns:
        The resolved loss callable.

    Raises:
        ValueError: if a string does not match any known alias.
        TypeError: if `identifier` is neither callable nor a string.
    """
    if callable(identifier):
        return identifier
    if not isinstance(identifier, six.string_types):
        raise TypeError('identifier must be a function or a string')

    name = identifier.lower()
    loss_by_alias = {
        'mean_squared': mean_squared_error,
        'mean_squared_error': mean_squared_error,
        'mse': mean_squared_error,
        'cross_entropy': cross_entropy,
        'euclid': euclidean,
        'euclidean': euclidean,
    }
    loss_fn = loss_by_alias.get(name)
    if loss_fn is None:
        # The message reports the lower-cased alias, as it always has.
        raise ValueError('Can not resolve "{}"'.format(name))
    return loss_fn