forked from WilliamRo/tframe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
53 lines (39 loc) · 1.64 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import absolute_import
import six
import tensorflow as tf
def accuracy(labels, outputs):
    """Build a scalar tensor holding the fraction of matching predictions.

    When ``labels`` has at least two dimensions and its second dimension is
    wider than one (or unknown), both ``labels`` and ``outputs`` are reduced
    with ``tf.argmax`` along axis 1 before comparison. Otherwise they are
    compared element-wise as given.

    Args:
        labels: ground-truth ``tf.Tensor``.
        outputs: model-output ``tf.Tensor``; presumably shaped like
            ``labels`` -- TODO confirm against callers.

    Returns:
        A float32 scalar tensor named ``'accuracy'``: the mean of the
        element-wise equality of (possibly argmax-reduced) labels and outputs.

    Raises:
        AssertionError: if either argument is not a ``tf.Tensor`` (the check
            is skipped when Python runs with ``-O``).
    """
    assert isinstance(labels, tf.Tensor) and isinstance(outputs, tf.Tensor)
    label_shape = labels.get_shape().as_list()
    # Reduce along axis 1 only when that axis exists AND holds more than one
    # column. The former `or` test read label_shape[1] exactly when the rank
    # was <= 1 (IndexError), and argmax-ed single-column labels to a constant
    # 0, which made accuracy always 1.0. An unknown (None) width keeps the
    # previous argmax behaviour.
    # NOTE(review): single-column outputs that are probabilities presumably
    # need thresholding by the caller before an exact comparison is meaningful.
    num_columns = label_shape[1] if len(label_shape) > 1 else 1
    if num_columns is None or num_columns > 1:
        labels = tf.argmax(labels, 1, name='labels')
        outputs = tf.argmax(outputs, 1, name='prediction')
    correct_prediction = tf.equal(labels, outputs, 'correct_prediction')
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                          name='accuracy')
def delta(truth, output):
    """Return ``tf.norm`` (default arguments) of the residual ``truth - output``.

    Both arguments must be ``tf.Tensor`` instances; this is checked with an
    ``assert``.
    """
    assert isinstance(truth, tf.Tensor) and isinstance(output, tf.Tensor)
    residual = truth - output
    distance = tf.norm(residual)
    return distance
def norm_error_ratio(truth, output):
    """Return the residual norm as a percentage of the norm of ``truth``.

    Computes ``tf.norm(truth - output) / tf.norm(truth) * 100``. Both
    arguments must be ``tf.Tensor`` instances (checked with ``assert``).
    No guard is applied when ``tf.norm(truth)`` is zero.
    """
    assert isinstance(truth, tf.Tensor) and isinstance(output, tf.Tensor)
    error_norm = tf.norm(truth - output)
    reference_norm = tf.norm(truth)
    relative_error = error_norm / reference_norm
    return relative_error * 100
def rms_error_ratio(truth, output):
    """Return the RMS of the residual as a percentage of the RMS of ``truth``.

    RMS here is ``sqrt(reduce_mean(square(x)))`` taken over the whole tensor.
    Both arguments must be ``tf.Tensor`` instances (checked with ``assert``).
    No guard is applied when the RMS of ``truth`` is zero.
    """
    assert isinstance(truth, tf.Tensor) and isinstance(output, tf.Tensor)

    def _root_mean_square(tensor):
        # sqrt of the mean of the squared entries, over all elements.
        return tf.sqrt(tf.reduce_mean(tf.square(tensor)))

    residual_rms = _root_mean_square(truth - output)
    truth_rms = _root_mean_square(truth)
    return residual_rms / truth_rms * 100
def get(identifier):
    """Resolve ``identifier`` to a metric function.

    Args:
        identifier: ``None``, a callable, or a case-insensitive metric name.
            Recognised names are ``'accuracy'``/``'acc'``,
            ``'delta'``/``'distance'``, ``'ratio'``/``'norm_ratio'``
            and ``'rms_ratio'``.

    Returns:
        ``identifier`` itself when it is ``None`` or callable; otherwise the
        metric function registered under the lower-cased name.

    Raises:
        ValueError: if a string does not match any recognised name.
        TypeError: if ``identifier`` is not ``None``, callable, or a string.
    """
    # None and callables pass straight through.
    if identifier is None or callable(identifier):
        return identifier
    if not isinstance(identifier, six.string_types):
        raise TypeError('identifier must be a function or a string')

    name = identifier.lower()
    # The alias table is built only on the string path.
    registry = {
        'accuracy': accuracy,
        'acc': accuracy,
        'delta': delta,
        'distance': delta,
        'ratio': norm_error_ratio,
        'norm_ratio': norm_error_ratio,
        'rms_ratio': rms_error_ratio,
    }
    metric = registry.get(name)
    if metric is None:
        raise ValueError('Can not resolve "{}"'.format(name))
    return metric