Skip to content

Commit

Permalink
add hamming test and skip tests for jax
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 28, 2024
1 parent fad0065 commit a18e6e0
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 1 deletion.
1 change: 0 additions & 1 deletion k3_addons/metrics/f_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def __init__(
**kwargs,
):
super().__init__(name=name, dtype=dtype)

if average not in (None, "micro", "macro", "weighted"):
raise ValueError(
"Unknown average type. Acceptable values "
Expand Down
54 changes: 54 additions & 0 deletions k3_addons/metrics/hamming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from keras import ops
from k3_addons.metrics.utils import MeanMetricWrapper
from k3_addons.api_export import k3_export


def hamming_distance(actuals, predictions):
result = ops.not_equal(actuals, predictions)
not_eq = ops.sum(ops.cast(result, "float32"))
ham_distance = ops.divide_no_nan(not_eq, len(result))
return ham_distance


def hamming_loss_fn(
y_true,
y_pred,
threshold,
mode,
):
if mode not in ["multiclass", "multilabel"]:
raise TypeError("mode must be either multiclass or multilabel]")

if threshold is None:
threshold = ops.max(y_pred, axis=-1, keepdims=True)
# make sure [0, 0, 0] doesn't become [1, 1, 1]
# Use abs(x) > eps, instead of x != 0 to check for zero
y_pred = ops.logical_and(y_pred >= threshold, ops.abs(y_pred) > 1e-12)
else:
y_pred = y_pred > threshold

y_true = ops.cast(y_true, "int32")
y_pred = ops.cast(y_pred, "int32")

if mode == "multiclass":
nonzero = ops.cast(ops.count_nonzero(y_true * y_pred, axis=-1), "float32")
return 1.0 - nonzero

else:
nonzero = ops.cast(ops.count_nonzero(y_true - y_pred, axis=-1), "float32")
return nonzero / ops.shape(y_true)[-1]


@k3_export("k3_addons.metrics.HammingLoss")
class HammingLoss(MeanMetricWrapper):
def __init__(
self,
mode,
name="hamming_loss",
threshold=None,
dtype=None,
**kwargs,
):
super().__init__(
hamming_loss_fn, name=name, dtype=dtype, mode=mode, threshold=threshold
)
151 changes: 151 additions & 0 deletions k3_addons/metrics/hamming_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import numpy as np
from keras import layers, ops, backend, Sequential

from k3_addons.metrics.hamming import HammingLoss, hamming_distance


def test_config():
hl_obj = HammingLoss(mode="multilabel", threshold=0.8)
assert hl_obj.name == "hamming_loss"
assert backend.standardize_dtype(hl_obj.dtype) == "float32"


def check_results(obj, value):
np.testing.assert_allclose(
ops.convert_to_numpy(value), ops.convert_to_numpy(obj.result()), atol=1e-5
)


def test_mc_4_classes():
actuals = ops.convert_to_tensor(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
],
dtype="float32",
)
predictions = ops.convert_to_tensor(
[
[0.85, 0.12, 0.03, 0],
[0, 0, 1, 0],
[0.10, 0.045, 0.045, 0.81],
[1, 0, 0, 0],
[0.80, 0.10, 0.10, 0],
[1, 0, 0, 0],
[0.05, 0, 0.90, 0.05],
],
dtype="float32",
)
# Initialize
hl_obj = HammingLoss("multiclass", threshold=0.8)
hl_obj.update_state(actuals, predictions)
# Check results
check_results(hl_obj, 0.2857143)


def test_mc_5_classes():
actuals = ops.convert_to_tensor(
[
[1, 0, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
],
dtype="float32",
)

predictions = ops.convert_to_tensor(
[
[0.85, 0, 0.15, 0, 0],
[0, 0, 0, 1, 0],
[0, 1, 0, 0, 0],
[0.05, 0.90, 0.04, 0, 0.01],
[0.10, 0, 0.81, 0.09, 0],
[0.10, 0.045, 0, 0.81, 0.045],
[1, 0, 0, 0, 0],
[0, 0.85, 0, 0, 0.15],
],
dtype="float32",
)
# Initialize
hl_obj = HammingLoss("multiclass", threshold=0.8)
hl_obj.update_state(actuals, predictions)
# Check results
check_results(hl_obj, 0.25)


def test_ml_4_classes():
actuals = ops.convert_to_tensor(
[[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 1]], dtype="float32"
)
predictions = ops.convert_to_tensor(
[[0.97, 0.56, 0.83, 0.77], [0.34, 0.95, 0.7, 0.89], [0.95, 0.45, 0.23, 0.56]],
dtype="float32",
)
# Initialize
hl_obj = HammingLoss("multilabel", threshold=0.8)
hl_obj.update_state(actuals, predictions)
# Check results
check_results(hl_obj, 0.16666667)


def test_ml_5_classes():
actuals = ops.convert_to_tensor(
[
[1, 0, 0, 0, 0],
[0, 0, 1, 1, 0],
[0, 1, 0, 1, 0],
[0, 1, 1, 0, 0],
[0, 0, 1, 1, 0],
[0, 0, 1, 1, 0],
[1, 0, 0, 0, 1],
[0, 1, 1, 0, 0],
],
dtype="float32",
)
predictions = ops.convert_to_tensor(
[
[1, 0.75, 0.2, 0.55, 0],
[0.65, 0.22, 0.97, 0.88, 0],
[0, 1, 0, 1, 0],
[0, 0.85, 0.9, 0.34, 0.5],
[0.4, 0.65, 0.87, 0, 0.12],
[0.66, 0.55, 1, 0.98, 0],
[0.95, 0.34, 0.67, 0.65, 0.10],
[0.45, 0.97, 0.89, 0.67, 0.46],
],
dtype="float32",
)
# Initialize
hl_obj = HammingLoss("multilabel", threshold=0.7)
hl_obj.update_state(actuals, predictions)
# Check results
check_results(hl_obj, 0.075)


def hamming_distance_test():
actuals = ops.convert_to_tensor([1, 1, 0, 0, 1, 0, 1, 0, 0, 1], dtype="int32")
predictions = ops.convert_to_tensor([1, 0, 0, 0, 1, 0, 0, 1, 0, 1], dtype="int32")
test_result = hamming_distance(actuals, predictions)
np.testing.assert_allclose(0.3, test_result, atol=1e-5)


# Keras model check
def test_keras_model():
model = Sequential()
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dense(3, activation="softmax"))
h1 = HammingLoss(mode="multiclass")
model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=[h1])
data = np.random.random((100, 10))
labels = np.random.random((100, 3))
model.fit(data, labels, epochs=1, batch_size=32, verbose=0)
24 changes: 24 additions & 0 deletions k3_addons/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@
import numpy as np


class MeanMetricWrapper(keras.metrics.Mean):
def __init__(
self,
fn,
name=None,
dtype=None,
**kwargs,
):
super().__init__(name=name, dtype=dtype)
self._fn = fn
self._fn_kwargs = kwargs

def update_state(self, y_true, y_pred, sample_weight=None):
y_true = ops.cast(y_true, self._dtype)
y_pred = ops.cast(y_pred, self._dtype)
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
return super().update_state(matches, sample_weight=sample_weight)

def get_config(self):
config = {k: v for k, v in self._fn_kwargs.items()}
base_config = super().get_config()
return {**base_config, **config}


def _get_model(metric, num_output):
# Test API comptibility with tf.keras Model
model = keras.Sequential()
Expand Down

0 comments on commit a18e6e0

Please sign in to comment.