Skip to content

Commit

Permalink
[TextVectorization Layer] Added tests for testing the funtionality (#…
Browse files Browse the repository at this point in the history
…20586)

* [TextVectorization Layer] Added tests for funtionality

* Fix formatting
  • Loading branch information
Frightera authored Dec 4, 2024
1 parent 1aff379 commit 5354cf9
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions keras/src/layers/preprocessing/text_vectorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,115 @@ def test_raises_exception_ragged_tensor(self):
vocabulary=["baz", "bar", "foo"],
ragged=True,
)

def test_multi_hot_output(self):
layer = layers.TextVectorization(
output_mode="multi_hot", vocabulary=["foo", "bar", "baz"]
)
input_data = [["foo bar"], ["baz foo foo"]]
output = layer(input_data)

"""
First batch
Tokens present: ["foo", "bar"]
For each token in vocabulary:
foo (index 1): present -> 1
bar (index 2): present -> 1
baz (index 3): absent -> 0
Result: [0, 1, 1, 0]
Second batch
Tokens: ["baz", "foo", "foo"]
For each token in vocabulary:
foo (index 1): present -> 1
bar (index 2): absent -> 0
baz (index 3): present -> 1
Result: [0, 1, 0, 1]
"""
self.assertAllClose(output, [[0, 1, 1, 0], [0, 1, 0, 1]])

def test_output_mode_count_output(self):
layer = layers.TextVectorization(
output_mode="count", vocabulary=["foo", "bar", "baz"]
)
output = layer(["foo bar", "baz foo foo"])
self.assertAllClose(output, [[0, 1, 1, 0], [0, 2, 0, 1]])

def test_output_mode_tf_idf_output(self):
layer = layers.TextVectorization(
output_mode="tf_idf",
vocabulary=["foo", "bar", "baz"],
idf_weights=[0.3, 0.5, 0.2],
)
output = layer(["foo bar", "baz foo foo"])
self.assertAllClose(
output, [[0.0, 0.3, 0.5, 0.0], [0.0, 0.6, 0.0, 0.2]]
)

def test_lower_and_strip_punctuation_standardization(self):
layer = layers.TextVectorization(
standardize="lower_and_strip_punctuation",
vocabulary=["hello", "world", "this", "is", "nice", "test"],
)
output = layer(["Hello, World!. This is just a nice test!"])
self.assertTrue(backend.is_tensor(output))

# test output sequence length, taking first batch.
self.assertEqual(len(output[0]), 8)

self.assertAllEqual(output, [[2, 3, 4, 5, 1, 1, 6, 7]])

def test_lower_standardization(self):
layer = layers.TextVectorization(
standardize="lower",
vocabulary=[
"hello,",
"hello",
"world",
"this",
"is",
"nice",
"test",
],
)
output = layer(["Hello, World!. This is just a nice test!"])
self.assertTrue(backend.is_tensor(output))
self.assertEqual(len(output[0]), 8)
"""
The input is lowercased and tokenized into words. The vocab is:
{0: '',
1: '[UNK]',
2: 'hello,',
3: 'hello',
4: 'world',
5: 'this',
6: 'is',
7: 'nice',
8: 'test'}
"""
self.assertAllEqual(output, [[2, 1, 5, 6, 1, 1, 7, 1]])

def test_char_splitting(self):
layer = layers.TextVectorization(
split="character", vocabulary=list("abcde"), output_mode="int"
)
output = layer(["abcf"])
self.assertTrue(backend.is_tensor(output))
self.assertEqual(len(output[0]), 4)
self.assertAllEqual(output, [[2, 3, 4, 1]])

def test_custom_splitting(self):
def custom_split(text):
return tf.strings.split(text, sep="|")

layer = layers.TextVectorization(
split=custom_split,
vocabulary=["foo", "bar", "foobar"],
output_mode="int",
)
output = layer(["foo|bar"])
self.assertTrue(backend.is_tensor(output))

# after custom split, the outputted index should be the last
# token in the vocab.
self.assertAllEqual(output, [[4]])

0 comments on commit 5354cf9

Please sign in to comment.