Skip to content

Commit

Permalink
Fix confusion matrix type (#20584)
Browse files Browse the repository at this point in the history
* fix: fix confusion matrix float32 problem

use int

* Use float32 for threshold comparisons and include a warning when the weights are float but the dtype is int
  • Loading branch information
edge7 authored Dec 10, 2024
1 parent c6c0720 commit 9294db1
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 18 deletions.
27 changes: 21 additions & 6 deletions keras/src/metrics/iou_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from keras.src import backend
from keras.src import initializers
from keras.src import ops
Expand Down Expand Up @@ -55,8 +57,8 @@ def __init__(
sparse_y_pred=True,
axis=-1,
):
# defaulting to float32 to avoid issues with confusion matrix
super().__init__(name=name, dtype=dtype or "float32")
# defaulting to int to avoid issues with confusion matrix
super().__init__(name=name, dtype=dtype or "int")
# Metric should be maximized during optimization.
self._direction = "up"
self.num_classes = num_classes
Expand All @@ -69,6 +71,7 @@ def __init__(
name="total_confusion_matrix",
shape=(num_classes, num_classes),
initializer=initializers.Zeros(),
dtype=self.dtype,
)

def update_state(self, y_true, y_pred, sample_weight=None):
Expand Down Expand Up @@ -102,7 +105,17 @@ def update_state(self, y_true, y_pred, sample_weight=None):

if sample_weight is None:
sample_weight = 1

else:
if (
hasattr(sample_weight, "dtype")
and "float" in str(sample_weight.dtype)
and "int" in str(self.dtype)
):
warnings.warn(
"You are passing weight as `float`, but dtype is `int`. "
"This may result in an incorrect weight due to type casting"
" Consider using integer weights."
)
sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)

if len(sample_weight.shape) > 1:
Expand Down Expand Up @@ -131,7 +144,7 @@ def update_state(self, y_true, y_pred, sample_weight=None):
y_pred,
self.num_classes,
weights=sample_weight,
dtype="float32",
dtype=self.dtype,
)

return self.total_cm.assign(self.total_cm + current_cm)
Expand Down Expand Up @@ -272,10 +285,11 @@ def result(self):
denominator = ops.take_along_axis(
denominator, target_class_ids, axis=-1
)
denominator = ops.cast(denominator, dtype="float32")

# If the denominator is 0, we need to ignore the class.
num_valid_entries = ops.sum(
ops.cast(ops.greater(denominator, 1e-9), dtype=self.dtype)
ops.cast(ops.greater(denominator, 1e-9), dtype="float32")
)

iou = ops.divide(true_positives, denominator + backend.epsilon())
Expand Down Expand Up @@ -406,7 +420,8 @@ def update_state(self, y_true, y_pred, sample_weight=None):
Update op.
"""
y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)
y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)
# convert y_pred on float 32 and cast just after to dtype
y_pred = ops.convert_to_tensor(y_pred, dtype="float32")
y_pred = ops.cast(y_pred >= self.threshold, self.dtype)
return super().update_state(y_true, y_pred, sample_weight)

Expand Down
145 changes: 133 additions & 12 deletions keras/src/metrics/iou_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import models
from keras.src import testing
from keras.src.metrics import iou_metrics as metrics
from keras.src.ops import convert_to_tensor


class IoUTest(testing.TestCase):
Expand All @@ -25,9 +26,7 @@ def test_unweighted(self):
y_pred = [0, 1, 0, 1]
y_true = [0, 0, 1, 1]

obj = metrics.IoU(
num_classes=2, target_class_ids=[0, 1], dtype="float32"
)
obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])

result = obj(y_true, y_pred)

Expand Down Expand Up @@ -64,7 +63,9 @@ def test_multi_dim_input(self):
y_true = np.array([[0, 0], [1, 1]])
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])

obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])
obj = metrics.IoU(
num_classes=2, target_class_ids=[0, 1], dtype="float32"
)

result = obj(y_true, y_pred, sample_weight=sample_weight)

Expand Down Expand Up @@ -136,7 +137,9 @@ def test_different_thresholds_weighted(self):
expected_result = (
0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)
obj = metrics.BinaryIoU(
target_class_ids=[0, 1], threshold=0.3, dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand All @@ -150,7 +153,9 @@ def test_different_thresholds_weighted(self):
expected_result = (
0.5 / (0.5 + 0.7 - 0.5) + 0.3 / (0.5 + 0.3 - 0.3)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)
obj = metrics.BinaryIoU(
target_class_ids=[0, 1], threshold=0.5, dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand Down Expand Up @@ -191,7 +196,9 @@ def test_multi_dim_input(self):
expected_result = (
0.2 / (0.6 + 0.3 - 0.2) + 0.3 / (0.4 + 0.7 - 0.3)
) / 2
obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=threshold)
obj = metrics.BinaryIoU(
target_class_ids=[0, 1], threshold=threshold, dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand Down Expand Up @@ -281,7 +288,7 @@ def test_weighted(self):
y_true = np.array([0, 0, 1, 1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])

m_obj = metrics.MeanIoU(num_classes=2)
m_obj = metrics.MeanIoU(num_classes=2, dtype="float32")

result = m_obj(y_true, y_pred, sample_weight=sample_weight)

Expand All @@ -300,7 +307,7 @@ def test_weighted_ignore_class_1(self):
y_true = np.array([0, 0, 1, -1])
sample_weight = np.array([0.2, 0.3, 0.4, 0.1])

m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1)
m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1, dtype="float32")

result = m_obj(y_true, y_pred, sample_weight=sample_weight)

Expand All @@ -319,7 +326,7 @@ def test_multi_dim_input(self):
y_true = np.array([[0, 0], [1, 1]])
sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])

m_obj = metrics.MeanIoU(num_classes=2)
m_obj = metrics.MeanIoU(num_classes=2, dtype="float32")

result = m_obj(y_true, y_pred, sample_weight=sample_weight)

Expand Down Expand Up @@ -351,6 +358,112 @@ def test_zero_and_non_zero_entries(self):
expected_result = (0 + 1 / (1 + 1 - 1)) / 1
self.assertAllClose(result, expected_result, atol=1e-3)

@staticmethod
def _confusion_matrix(y_true, y_pred, num_classes):
"""
Creates a confusion matrix as a numpy array using vectorized operations.
Parameters:
- y_true: array-like, true class labels.
- y_pred: array-like, predicted class labels.
- num_classes: int, number of classes.
Returns:
- conf_matrix: np.ndarray, confusion matrix of shape (num_classes,
num_classes).
"""
# Map pairs of (y_true, y_pred) to indices in the confusion matrix
indices = y_true * num_classes + y_pred
# Count occurrences of each index
conf_matrix = np.bincount(indices, minlength=num_classes * num_classes)
# Reshape the flat array into a 2D confusion matrix
conf_matrix = conf_matrix.reshape((num_classes, num_classes))
return conf_matrix

@staticmethod
def _get_big_chunk(dtype):
np.random.seed(14)
all_y_true = np.random.choice([0, 1, 2], size=(10, 530, 530))
# Generate random probabilities for each channel
random_probs = np.random.rand(10, 530, 530, 3)
# Normalize to ensure the last dimension sums to 1
all_y_pred = random_probs / random_probs.sum(axis=-1, keepdims=True)
# Convert predictions to class indices
all_y_pred_arg = np.argmax(all_y_pred, axis=-1)
mean_iou_metric = metrics.MeanIoU(num_classes=3, dtype=dtype)
conf_matrix_start_point = np.array(
[
[18729664, 18728760, 18731196],
[18727297, 18726105, 18728071],
[18727917, 18717835, 18723155],
]
)
mean_iou_metric.total_cm = mean_iou_metric.add_variable(
name="total_confusion_matrix",
shape=(3, 3),
initializer=convert_to_tensor(conf_matrix_start_point),
dtype=dtype or "int",
)
mean_iou_metric.update_state(all_y_true, all_y_pred_arg)
tmp_true = np.reshape(all_y_true, -1)
tmp_pred = np.reshape(all_y_pred_arg, -1)
return (
all_y_true,
all_y_pred_arg,
mean_iou_metric,
tmp_true,
tmp_pred,
conf_matrix_start_point,
)

def test_big_chunk(self):
# Init. process with dtype=None which will default to int
(
all_y_true,
all_y_pred_arg,
mean_iou_metric_all,
tmp_true,
tmp_pred,
conf_matrix_start_point,
) = self._get_big_chunk(dtype=None)
conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)
# Validate confusion matrices and results
conf_matrix_manual = (
self._confusion_matrix(tmp_true, tmp_pred, 3)
+ conf_matrix_start_point
)
self.assertTrue(
np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
msg="Confusion matrices do not match!",
)
# Now same but with float32 dtype, in here the confusion matrix
# should not match. Likely this can be removed
(
all_y_true,
all_y_pred_arg,
mean_iou_metric_all,
tmp_true,
tmp_pred,
conf_matrix_start_point,
) = self._get_big_chunk(dtype="float32")
conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)
# Validate confusion matrices and results
conf_matrix_manual = (
self._confusion_matrix(tmp_true, tmp_pred, 3)
+ conf_matrix_start_point
)
self.assertFalse(
np.array_equal(conf_matrix_from_keras, conf_matrix_manual),
msg="Confusion matrices match, but they should not!",
)

def test_user_warning_float_weight(self):
y_pred = [0, 1, 1, 1]
y_true = [0, 1, 1, 0]
m_obj = metrics.MeanIoU(num_classes=3)
with pytest.warns(Warning, match=r"weight.*float.*int.*casting"):
m_obj(y_true, y_pred, sample_weight=np.array([0.2, 0.3, 0.4, 0.1]))


class OneHotIoUTest(testing.TestCase):
def test_unweighted(self):
Expand Down Expand Up @@ -385,7 +498,9 @@ def test_weighted(self):
# true_positives = [0, 0, 0.1]
# iou = true_positives / (sum_row + sum_col - true_positives))
expected_result = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2
obj = metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])
obj = metrics.OneHotIoU(
num_classes=3, target_class_ids=[0, 2], dtype="float32"
)
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

Expand Down Expand Up @@ -439,6 +554,12 @@ def test_weighted(self):
expected_result = (
0.1 / (0.4 + 0.6 - 0.1) + 0 + 0.1 / (0.6 + 0.1 - 0.1)
) / 3
obj = metrics.OneHotMeanIoU(num_classes=3)
obj = metrics.OneHotMeanIoU(num_classes=3, dtype="float32")
result = obj(y_true, y_pred, sample_weight=sample_weight)
self.assertAllClose(result, expected_result, atol=1e-3)

# Check same result with int weights
sample_weight_int = [1, 2, 3, 3, 1]
obj_int = metrics.OneHotMeanIoU(num_classes=3)
result_int = obj_int(y_true, y_pred, sample_weight=sample_weight_int)
self.assertAllClose(result_int, expected_result, atol=1e-3)

0 comments on commit 9294db1

Please sign in to comment.