Skip to content

Commit

Permalink
Add get_vocabulary, id_to_token and token_to_id methods to ByteTokenizer and UnicodeCodepointTokenizer.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Jun 13, 2024
1 parent 50e0414 commit c7f9075
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
24 changes: 24 additions & 0 deletions keras_nlp/src/tokenizers/byte_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ def vocabulary_size(self):
"""Get the integer size of the tokenizer vocabulary."""
return 256

def get_vocabulary(self):
    """Return the tokenizer vocabulary as a `{token: id}` dict.

    Each id in `[0, vocabulary_size())` is keyed by the one-character
    string whose Unicode codepoint equals that id.
    """
    return {
        chr(token_id): token_id
        for token_id in range(self.vocabulary_size())
    }

def tokenize(self, inputs):
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
inputs = tf.convert_to_tensor(inputs)
Expand Down Expand Up @@ -264,6 +270,24 @@ def detokenize(self, inputs):
outputs = tf.squeeze(outputs, 0)
return outputs

def id_to_token(self, id):
    """Convert an integer id to a string token.

    Args:
        id: An integer in `[0, vocabulary_size())`.

    Returns:
        The one-character string whose codepoint is `id`.

    Raises:
        ValueError: If `id` is outside the vocabulary range.
    """
    if 0 <= id < self.vocabulary_size():
        return chr(id)
    raise ValueError(
        f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
        f"Received: {id}"
    )

def token_to_id(self, token):
    """Convert a string token to an integer id.

    Args:
        token: A single-character string.

    Returns:
        The integer codepoint of `token`.

    Raises:
        ValueError: If `token` is not exactly one character, or its
            codepoint is outside the vocabulary.
    """
    # `ord()` raises a generic TypeError on empty or multi-character
    # input; validate length first so callers get this tokenizer's
    # informative ValueError instead.
    if len(token) != 1:
        raise ValueError(
            f"Token {token} is not supported by `ByteTokenizer`."
        )
    # Avoid shadowing the builtin `id`.
    token_id = ord(token)
    if token_id >= self.vocabulary_size():
        raise ValueError(
            f"Token {token} is not supported by `ByteTokenizer`."
        )
    # NOTE(review): ids 128-255 name raw UTF-8 bytes during tokenize(),
    # while this maps a *codepoint* (e.g. "é" tokenizes to two bytes but
    # maps to one id here) — confirm this asymmetry is intended.
    return token_id

def get_config(self):
config = super().get_config()
config.update(
Expand Down
14 changes: 14 additions & 0 deletions keras_nlp/src/tokenizers/byte_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,17 @@ def test_config(self):
tokenizer(input_data),
cloned_tokenizer(input_data),
)

def test_token_to_id(self):
    """Single characters map to their byte values."""
    tokenizer = ByteTokenizer()
    input_tokens = ["f", "u", "n"]
    expected_ids = [102, 117, 110]
    ids = list(map(tokenizer.token_to_id, input_tokens))
    self.assertAllEqual(ids, expected_ids)

def test_id_to_token(self):
    """Byte values map back to their single-character tokens."""
    tokenizer = ByteTokenizer()
    input_ids = [102, 117, 110]
    expected_tokens = ["f", "u", "n"]
    tokens = list(map(tokenizer.id_to_token, input_ids))
    self.assertAllEqual(tokens, expected_tokens)
24 changes: 24 additions & 0 deletions keras_nlp/src/tokenizers/unicode_codepoint_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ def vocabulary_size(self):
size was provided"""
return self._vocabulary_size

def get_vocabulary(self):
    """Return the tokenizer vocabulary as a `{token: id}` dict.

    Each id in `[0, vocabulary_size())` is keyed by the one-character
    string whose Unicode codepoint equals that id.

    NOTE(review): assumes `vocabulary_size()` returns an int (i.e. a
    size was configured) — confirm behavior when it was not.
    """
    size = self.vocabulary_size()
    return {chr(codepoint): codepoint for codepoint in range(size)}

def tokenize(self, inputs):
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
inputs = tf.convert_to_tensor(inputs)
Expand Down Expand Up @@ -331,3 +337,21 @@ def detokenize(self, inputs):
if unbatched:
outputs = tf.squeeze(outputs, 0)
return outputs

def id_to_token(self, id):
    """Convert an integer id to a string token.

    Args:
        id: An integer in `[0, vocabulary_size())`.

    Returns:
        The one-character string whose codepoint is `id`.

    Raises:
        ValueError: If `id` is outside the vocabulary range.
    """
    out_of_range = id < 0 or id >= self.vocabulary_size()
    if out_of_range:
        raise ValueError(
            f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
            f"Received: {id}"
        )
    return chr(id)

def token_to_id(self, token):
    """Convert a string token to an integer id.

    Args:
        token: A single-character string.

    Returns:
        The integer codepoint of `token`.

    Raises:
        ValueError: If `token` is not exactly one character, or its
            codepoint is outside the vocabulary.
    """
    # `ord()` raises a generic TypeError on empty or multi-character
    # input; validate length first so callers get this tokenizer's
    # informative ValueError instead.
    if len(token) != 1:
        raise ValueError(
            f"Token {token} is not supported by `UnicodeCodepointTokenizer`."
        )
    # Avoid shadowing the builtin `id`.
    token_id = ord(token)
    if token_id >= self.vocabulary_size():
        raise ValueError(
            f"Token {token} is not supported by `UnicodeCodepointTokenizer`."
        )
    return token_id
14 changes: 14 additions & 0 deletions keras_nlp/src/tokenizers/unicode_codepoint_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,17 @@ def test_config(self):
tokenizer(input_data),
cloned_tokenizer(input_data),
)

def test_token_to_id(self):
    """Unicode characters map to their codepoints within the vocab."""
    tokenizer = UnicodeCodepointTokenizer(vocabulary_size=2000)
    input_tokens = ["ب", "و", "خ"]
    expected_ids = [1576, 1608, 1582]
    ids = list(map(tokenizer.token_to_id, input_tokens))
    self.assertAllEqual(ids, expected_ids)

def test_id_to_token(self):
    """Codepoints within the vocab map back to their characters."""
    tokenizer = UnicodeCodepointTokenizer(vocabulary_size=2000)
    input_ids = [1576, 1608, 1582]
    expected_tokens = ["ب", "و", "خ"]
    tokens = list(map(tokenizer.id_to_token, input_ids))
    self.assertAllEqual(tokens, expected_tokens)

0 comments on commit c7f9075

Please sign in to comment.