Skip to content

Commit

Permalink
Make metrics stateful (#19)
Browse files Browse the repository at this point in the history
This patch makes metrics stateful.
  • Loading branch information
ybubnov authored Nov 16, 2018
1 parent 6a9a43c commit f12f335
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 26 deletions.
2 changes: 1 addition & 1 deletion keras_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.4"
__version__ = "0.0.5"


from keras_metrics.metrics import *
33 changes: 23 additions & 10 deletions keras_metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ class layer(Layer):

def __init__(self, label=None, **kwargs):
super(layer, self).__init__(**kwargs)
self.stateful = True
self.epsilon = K.constant(K.epsilon(), dtype="float64")

# If layer metric is explicitly created to evaluate specified class,
# then use a binary transformation of the output arrays, otherwise
# calculate an "overal" metric.
# calculate an "overall" metric.
if label:
self.cast_strategy = partial(self._binary, label=label)
else:
Expand All @@ -36,7 +38,7 @@ def _categorical(self, y_true, y_pred, dtype):
# of the output vector has exactly two elements, we can choose
# the label automatically.
#
# When the shape had dimenstion 3 and more and the label is
# When the shape had dimension 3 and more and the label is
# not specified, we should throw an error as long as calculated
# metric is incorrect.
_, labels = y_pred.shape
Expand All @@ -46,7 +48,7 @@ def _categorical(self, y_true, y_pred, dtype):
raise ValueError("With 2 and more output classes a "
"metric label must be specified")

y_true = K.cast(y_true, dtype)
y_true = K.cast(K.round(y_true), dtype)
y_pred = K.cast(K.round(y_pred), dtype)
return y_true, y_pred

Expand Down Expand Up @@ -167,15 +169,14 @@ def __call__(self, y_true, y_pred):
class recall(layer):
"""Create a metric for model's recall calculation.
Recall measures propotion of actual positives that was indetified correctly.
Recall measures proportion of actual positives that was identified correctly.
"""

def __init__(self, name="recall", **kwargs):
super(recall, self).__init__(name=name, **kwargs)

self.tp = true_positive()
self.fn = false_negative()
self.capping = K.constant(1, dtype="int32")

def reset_states(self):
"""Reset the state of the metrics."""
Expand All @@ -186,8 +187,13 @@ def __call__(self, y_true, y_pred):
tp = self.tp(y_true, y_pred)
fn = self.fn(y_true, y_pred)

div = K.maximum((tp + fn), self.capping)
return truediv(tp, div)
self.add_update(self.tp.updates)
self.add_update(self.fn.updates)

tp = K.cast(tp, self.epsilon.dtype)
fn = K.cast(fn, self.epsilon.dtype)

return truediv(tp, tp + fn + self.epsilon)


class precision(layer):
Expand All @@ -202,7 +208,6 @@ def __init__(self, name="precision", **kwargs):

self.tp = true_positive()
self.fp = false_positive()
self.capping = K.constant(1, dtype="int32")

def reset_states(self):
"""Reset the state of the metrics."""
Expand All @@ -213,8 +218,13 @@ def __call__(self, y_true, y_pred):
tp = self.tp(y_true, y_pred)
fp = self.fp(y_true, y_pred)

div = K.maximum((tp + fp), self.capping)
return truediv(tp, div)
self.add_update(self.tp.updates)
self.add_update(self.fp.updates)

tp = K.cast(tp, self.epsilon.dtype)
fp = K.cast(fp, self.epsilon.dtype)

return truediv(tp, tp + fp + self.epsilon)


class f1_score(layer):
Expand All @@ -238,4 +248,7 @@ def __call__(self, y_true, y_pred):
pr = self.precision(y_true, y_pred)
rec = self.recall(y_true, y_pred)

self.add_update(self.precision.updates)
self.add_update(self.recall.updates)

return 2 * truediv(pr * rec, pr + rec + K.epsilon())
47 changes: 32 additions & 15 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import keras
import keras.backend
import keras_metrics
import itertools
import numpy
import unittest

Expand All @@ -8,6 +10,7 @@ class TestMetrics(unittest.TestCase):

def test_metrics(self):
tp = keras_metrics.true_positive()
tn = keras_metrics.true_negative()
fp = keras_metrics.false_positive()
fn = keras_metrics.false_negative()

Expand All @@ -16,27 +19,33 @@ def test_metrics(self):
f1 = keras_metrics.f1_score()

model = keras.models.Sequential()
model.add(keras.layers.Dense(1, activation="sigmoid", input_dim=2))
model.add(keras.layers.Dense(1, activation="softmax"))
model.add(keras.layers.Activation(keras.backend.sin))
model.add(keras.layers.Activation(keras.backend.abs))

model.compile(optimizer="sgd",
loss="binary_crossentropy",
metrics=[tp, fp, fn, precision, recall, f1])
metrics=[tp, tn, fp, fn, precision, recall, f1])

samples = 1000
x = numpy.random.random((samples, 2))
samples = 10000
batch_size = 100
lim = numpy.pi/2

x = numpy.random.uniform(0, lim, (samples, 1))
y = numpy.random.randint(2, size=(samples, 1))

model.fit(x, y, epochs=1, batch_size=10)
metrics = model.evaluate(x, y, batch_size=10)[1:]
model.fit(x, y, epochs=10, batch_size=batch_size)
metrics = model.evaluate(x, y, batch_size=batch_size)[1:]

metrics = list(map(float, metrics))

tp_val = metrics[0]
fp_val = metrics[1]
fn_val = metrics[2]
tn_val = metrics[1]
fp_val = metrics[2]
fn_val = metrics[3]

precision = metrics[3]
recall = metrics[4]
f1 = metrics[5]
precision = metrics[4]
recall = metrics[5]
f1 = metrics[6]

expected_precision = tp_val / (tp_val + fp_val)
expected_recall = tp_val / (tp_val + fn_val)
Expand All @@ -45,9 +54,17 @@ def test_metrics(self):
f1_divisor = (expected_precision+expected_recall)
expected_f1 = (2 * f1_divident / f1_divisor)

self.assertAlmostEqual(expected_precision, precision, delta=0.05)
self.assertAlmostEqual(expected_recall, recall, delta=0.05)
self.assertAlmostEqual(expected_f1, f1, delta=0.05)
self.assertGreaterEqual(tp_val, 0.0)
self.assertGreaterEqual(fp_val, 0.0)
self.assertGreaterEqual(fn_val, 0.0)
self.assertGreaterEqual(tn_val, 0.0)

self.assertEqual(sum(metrics[0:4]), samples)

places = 4
self.assertAlmostEqual(expected_precision, precision, places=places)
self.assertAlmostEqual(expected_recall, recall, places=places)
self.assertAlmostEqual(expected_f1, f1, places=places)


if __name__ == "__main__":
Expand Down

0 comments on commit f12f335

Please sign in to comment.