add hamming test and skip tests for jax
Showing 4 changed files with 229 additions and 1 deletion.
@@ -0,0 +1,54 @@
from keras import ops
from k3_addons.metrics.utils import MeanMetricWrapper
from k3_addons.api_export import k3_export


def hamming_distance(actuals, predictions):
    # Fraction of positions where `actuals` and `predictions` disagree.
    result = ops.not_equal(actuals, predictions)
    not_eq = ops.sum(ops.cast(result, "float32"))
    ham_distance = ops.divide_no_nan(not_eq, len(result))
    return ham_distance


def hamming_loss_fn(
    y_true,
    y_pred,
    threshold,
    mode,
):
    if mode not in ["multiclass", "multilabel"]:
        raise TypeError("mode must be either multiclass or multilabel")

    if threshold is None:
        threshold = ops.max(y_pred, axis=-1, keepdims=True)
        # make sure [0, 0, 0] doesn't become [1, 1, 1]
        # Use abs(x) > eps, instead of x != 0 to check for zero
        y_pred = ops.logical_and(y_pred >= threshold, ops.abs(y_pred) > 1e-12)
    else:
        y_pred = y_pred > threshold

    y_true = ops.cast(y_true, "int32")
    y_pred = ops.cast(y_pred, "int32")

    if mode == "multiclass":
        # 0 if the thresholded prediction hits the true class, else 1.
        nonzero = ops.cast(ops.count_nonzero(y_true * y_pred, axis=-1), "float32")
        return 1.0 - nonzero
    else:
        # Fraction of labels that disagree after thresholding.
        nonzero = ops.cast(ops.count_nonzero(y_true - y_pred, axis=-1), "float32")
        return nonzero / ops.shape(y_true)[-1]


@k3_export("k3_addons.metrics.HammingLoss")
class HammingLoss(MeanMetricWrapper):
    def __init__(
        self,
        mode,
        name="hamming_loss",
        threshold=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(
            hamming_loss_fn, name=name, dtype=dtype, mode=mode, threshold=threshold
        )
@@ -0,0 +1,151 @@
import numpy as np
from keras import layers, ops, backend, Sequential

from k3_addons.metrics.hamming import HammingLoss, hamming_distance


def test_config():
    hl_obj = HammingLoss(mode="multilabel", threshold=0.8)
    assert hl_obj.name == "hamming_loss"
    assert backend.standardize_dtype(hl_obj.dtype) == "float32"


def check_results(obj, value):
    np.testing.assert_allclose(
        ops.convert_to_numpy(value), ops.convert_to_numpy(obj.result()), atol=1e-5
    )


def test_mc_4_classes():
    actuals = ops.convert_to_tensor(
        [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
            [0, 1, 0, 0],
            [0, 1, 0, 0],
            [1, 0, 0, 0],
            [0, 0, 1, 0],
        ],
        dtype="float32",
    )
    predictions = ops.convert_to_tensor(
        [
            [0.85, 0.12, 0.03, 0],
            [0, 0, 1, 0],
            [0.10, 0.045, 0.045, 0.81],
            [1, 0, 0, 0],
            [0.80, 0.10, 0.10, 0],
            [1, 0, 0, 0],
            [0.05, 0, 0.90, 0.05],
        ],
        dtype="float32",
    )
    # Initialize
    hl_obj = HammingLoss("multiclass", threshold=0.8)
    hl_obj.update_state(actuals, predictions)
    # Check results
    check_results(hl_obj, 0.2857143)


def test_mc_5_classes():
    actuals = ops.convert_to_tensor(
        [
            [1, 0, 0, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
        ],
        dtype="float32",
    )

    predictions = ops.convert_to_tensor(
        [
            [0.85, 0, 0.15, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 1, 0, 0, 0],
            [0.05, 0.90, 0.04, 0, 0.01],
            [0.10, 0, 0.81, 0.09, 0],
            [0.10, 0.045, 0, 0.81, 0.045],
            [1, 0, 0, 0, 0],
            [0, 0.85, 0, 0, 0.15],
        ],
        dtype="float32",
    )
    # Initialize
    hl_obj = HammingLoss("multiclass", threshold=0.8)
    hl_obj.update_state(actuals, predictions)
    # Check results
    check_results(hl_obj, 0.25)


def test_ml_4_classes():
    actuals = ops.convert_to_tensor(
        [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 1]], dtype="float32"
    )
    predictions = ops.convert_to_tensor(
        [[0.97, 0.56, 0.83, 0.77], [0.34, 0.95, 0.7, 0.89], [0.95, 0.45, 0.23, 0.56]],
        dtype="float32",
    )
    # Initialize
    hl_obj = HammingLoss("multilabel", threshold=0.8)
    hl_obj.update_state(actuals, predictions)
    # Check results
    check_results(hl_obj, 0.16666667)


def test_ml_5_classes():
    actuals = ops.convert_to_tensor(
        [
            [1, 0, 0, 0, 0],
            [0, 0, 1, 1, 0],
            [0, 1, 0, 1, 0],
            [0, 1, 1, 0, 0],
            [0, 0, 1, 1, 0],
            [0, 0, 1, 1, 0],
            [1, 0, 0, 0, 1],
            [0, 1, 1, 0, 0],
        ],
        dtype="float32",
    )
    predictions = ops.convert_to_tensor(
        [
            [1, 0.75, 0.2, 0.55, 0],
            [0.65, 0.22, 0.97, 0.88, 0],
            [0, 1, 0, 1, 0],
            [0, 0.85, 0.9, 0.34, 0.5],
            [0.4, 0.65, 0.87, 0, 0.12],
            [0.66, 0.55, 1, 0.98, 0],
            [0.95, 0.34, 0.67, 0.65, 0.10],
            [0.45, 0.97, 0.89, 0.67, 0.46],
        ],
        dtype="float32",
    )
    # Initialize
    hl_obj = HammingLoss("multilabel", threshold=0.7)
    hl_obj.update_state(actuals, predictions)
    # Check results
    check_results(hl_obj, 0.075)


def test_hamming_distance():
    # Named with a `test_` prefix so pytest collects it.
    actuals = ops.convert_to_tensor([1, 1, 0, 0, 1, 0, 1, 0, 0, 1], dtype="int32")
    predictions = ops.convert_to_tensor([1, 0, 0, 0, 1, 0, 0, 1, 0, 1], dtype="int32")
    test_result = hamming_distance(actuals, predictions)
    np.testing.assert_allclose(0.3, ops.convert_to_numpy(test_result), atol=1e-5)


# Keras model check
def test_keras_model():
    model = Sequential()
    model.add(layers.Dense(64, activation="relu"))
    model.add(layers.Dense(3, activation="softmax"))
    h1 = HammingLoss(mode="multiclass")
    model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=[h1])
    data = np.random.random((100, 10))
    labels = np.random.random((100, 3))
    model.fit(data, labels, epochs=1, batch_size=32, verbose=0)
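As a quick sanity check of one of the expected values above, the multilabel case in test_ml_4_classes can be reproduced with plain NumPy. This is an illustrative sketch, not part of the commit:

import numpy as np

# Ground truth and raw scores from test_ml_4_classes, threshold = 0.8.
y_true = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 0, 1]])
y_pred = np.array(
    [[0.97, 0.56, 0.83, 0.77], [0.34, 0.95, 0.7, 0.89], [0.95, 0.45, 0.23, 0.56]]
)

binarized = (y_pred > 0.8).astype(int)  # [[1,0,1,0], [0,1,0,1], [1,0,0,0]]
per_sample = np.count_nonzero(y_true - binarized, axis=-1) / y_true.shape[-1]
print(per_sample)         # [0.  0.  0.5] -> only the last row has 2 of 4 labels wrong
print(per_sample.mean())  # 0.16666667, the value checked by the test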