Skip to content

Commit

Permalink
Allow rank > 2 for input shapes in numerical_utils (#19020)
Browse files Browse the repository at this point in the history
* Allow rank > 2 for input shapes in numerical_utils

* Update `utils/numerical_utils.py` to support input shapes with a rank greater than two when the output_mode is `multi_hot`, `one_hot`, or `int`.

#18995

* Add more test cases for one_hot and multi_hot

* Refactor binary_output check for input rank validation

Moved the `binary_output` evaluation ahead of the input rank check so the check can use it.
  • Loading branch information
dugujiujian1999 authored Jan 5, 2024
1 parent 62471f2 commit 6f35f2e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 40 deletions.
60 changes: 31 additions & 29 deletions keras/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
from absl.testing import parameterized
from tensorflow import data as tf_data

from keras import backend
Expand All @@ -10,7 +11,7 @@
from keras.saving import saving_api


class DicretizationTest(testing.TestCase):
class DicretizationTest(testing.TestCase, parameterized.TestCase):
def test_discretization_basics(self):
self.run_layer_test(
layers.Discretization,
Expand All @@ -35,38 +36,39 @@ def test_adapt_flow(self):
output = layer(np.array([[0.0, 0.1, 0.3]]))
self.assertTrue(output.dtype, "int32")

def test_correctness(self):
# int mode
layer = layers.Discretization(
bin_boundaries=[0.0, 0.5, 1.0], output_mode="int"
)
output = layer(np.array([[-1.0, 0.0, 0.1, 0.8, 1.2]]))
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([[0, 1, 1, 2, 3]]))

# one_hot mode
layer = layers.Discretization(
bin_boundaries=[0.0, 0.5, 1.0], output_mode="one_hot"
)
output = layer(np.array([0.1, 0.8]))
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([[0, 1, 0, 0], [0, 0, 1, 0]]))

# multi_hot mode
layer = layers.Discretization(
bin_boundaries=[0.0, 0.5, 1.0], output_mode="multi_hot"
)
output = layer(np.array([[0.1, 0.8]]))
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([[0, 1, 1, 0]]))
@parameterized.parameters(
[
("int", [[-1.0, 0.0, 0.1, 0.8, 1.2]], [[0, 1, 1, 2, 3]]),
("one_hot", [0.1, 0.8], [[0, 1, 0, 0], [0, 0, 1, 0]]),
("multi_hot", [[0.1, 0.8]], [[0, 1, 1, 0]]),
(
"one_hot",
[[[0.15, 0.75], [0.85, 0.45]]],
[
[
[[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
[[0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
]
],
),
(
"multi_hot",
[[[0.15, 0.75], [0.85, 0.45]]],
[[[0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]]],
),
("count", [[0.1, 0.8, 0.9]], [[0, 1, 2, 0]]),
]
)
def test_correctness(self, output_mode, input_array, expected_output):
input_array = np.array(input_array)
expected_output = np.array(expected_output)

# count mode
layer = layers.Discretization(
bin_boundaries=[0.0, 0.5, 1.0], output_mode="count"
bin_boundaries=[0.0, 0.5, 1.0], output_mode=output_mode
)
output = layer(np.array([[0.1, 0.8, 0.9]]))
output = layer(input_array)
self.assertTrue(backend.is_tensor(output))
self.assertAllClose(output, np.array([[0, 1, 2, 0]]))
self.assertAllClose(output, expected_output)

def test_tf_data_compatibility(self):
# With fixed bins
Expand Down
29 changes: 18 additions & 11 deletions keras/utils/numerical_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,27 @@ def encode_categorical_inputs(
if output_mode == "int":
return backend_module.cast(inputs, dtype=dtype)

original_shape = inputs.shape
binary_output = output_mode in ("multi_hot", "one_hot")
original_shape = backend_module.shape(inputs)
rank_of_inputs = len(original_shape)

# In all cases, we should uprank scalar input to a single sample.
if len(backend_module.shape(inputs)) == 0:
if rank_of_inputs == 0:
# We need to update `rank_of_inputs`
# If necessary.
inputs = backend_module.numpy.expand_dims(inputs, -1)
elif rank_of_inputs > 2:
# The `count` mode does not support inputs with a rank greater than 2.
if not binary_output:
raise ValueError(
"When output_mode is anything other than "
"`'multi_hot', 'one_hot', or 'int'`, "
"the rank must be 2 or less. "
f"Received output_mode: {output_mode} "
f"and input shape: {original_shape}, "
f"which would result in output rank {rank_of_inputs}."
)

if len(backend_module.shape(inputs)) > 2:
raise ValueError(
"When output_mode is not `'int'`, maximum supported output rank "
f"is 2. Received output_mode {output_mode} and input shape "
f"{original_shape}, "
f"which would result in output rank {inputs.shape.rank}."
)

binary_output = output_mode in ("multi_hot", "one_hot")
if binary_output:
if output_mode == "one_hot":
bincounts = backend_module.nn.one_hot(inputs, depth)
Expand Down

0 comments on commit 6f35f2e

Please sign in to comment.