Skip to content

Commit

Permalink
Implement a new unit test and correct some typos. (#19029)
Browse files Browse the repository at this point in the history
* test: implement sparse inputs unit test for Discretization layer

Co-authored-by: AlvaroMaza <[email protected]>

* chore: correct spelling errors in some unit tests

Co-authored-by: AlvaroMaza <[email protected]>

* fix: correct the implementation of the 'sparse' argument in Discretization.

* fix `self.sparse` not working.
* set `test_sparse_inputs` to `test_sparse_output`

Co-authored-by: AlvaroMaza <[email protected]>

---------

Co-authored-by: AlvaroMaza <[email protected]>
  • Loading branch information
dugujiujian1999 and AlvaroMaza authored Jan 8, 2024
1 parent 85c9a58 commit 4993d49
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
2 changes: 1 addition & 1 deletion keras/backend/common/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_variable_numpy(self):
self.assertAllClose(v.numpy(), np.array([1, 2, 3]))

@pytest.mark.skipif(
backend.backend() != "tf",
backend.backend() != "tensorflow",
reason="Tests for MirroredVariable under tf backend",
)
def test_variable_numpy_scalar(self):
Expand Down
2 changes: 2 additions & 0 deletions keras/layers/preprocessing/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ def call(self, inputs):
dtype=self.compute_dtype,
backend_module=self.backend,
)
if self.sparse:
return tf.sparse.from_dense(outputs)
return outputs

def get_config(self):
Expand Down
58 changes: 54 additions & 4 deletions keras/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow import data as tf_data

Expand All @@ -11,7 +12,7 @@
from keras.saving import saving_api


class DicretizationTest(testing.TestCase, parameterized.TestCase):
class DiscretizationTest(testing.TestCase, parameterized.TestCase):
def test_discretization_basics(self):
self.run_layer_test(
layers.Discretization,
Expand Down Expand Up @@ -125,6 +126,55 @@ def test_saving(self):
model = saving_api.load_model(fpath)
self.assertAllClose(layer(ref_input), ref_output)

def test_sparse_inputs(self):
# TODO
pass
@parameterized.parameters(
[
(
"one_hot",
[[-1.0, 0.2, 0.7, 1.2]],
[
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
],
),
(
"multi_hot",
[[[-1.0], [0.2], [0.7], [1.2]]],
[
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
],
),
(
"count",
[[-1.0], [0.2], [0.7], [1.2]],
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
],
),
]
)
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Sparse tensor only works in TensorFlow",
)
def test_sparse_output(self, output_mode, input_array, expected_output):
from keras.utils.module_utils import tensorflow as tf

x = np.array(input_array)
layer = layers.Discretization(
bin_boundaries=[0.0, 0.5, 1.0], sparse=True, output_mode=output_mode
)
output = layer(x)
self.assertTrue(isinstance(output, tf.SparseTensor))
self.assertAllClose(output, np.array(expected_output))

0 comments on commit 4993d49

Please sign in to comment.