Commit 0558cea

Support tokenizer failing when length of tokens exceeds max length through param fail_on_input_truncation.

PiperOrigin-RevId: 644013051
Change-Id: I7eac38bbb346a19c82e378dfaef88d10d0425bea
Sax Authors authored and copybara-github committed Jun 17, 2024
1 parent 25684e3 commit 0558cea
Showing 4 changed files with 14 additions and 55 deletions.
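For context, fail_on_input_truncation gates truncation in LMTokenizer: when the flag is set and the tokenized labels are longer than max_length, tokenization raises a ValueError instead of silently slicing. A minimal standalone sketch of that behavior (the helper name truncate_or_fail is illustrative, not part of the change; the example values come from lm_tokenizer_test.py in this diff):

import tensorflow as tf


def truncate_or_fail(labels, max_length, fail_on_input_truncation=False,
                     slice_left=True):
  # Mirrors LMTokenizer._truncate_labels from the diff below: refuse to
  # truncate when the flag is set and any row exceeds max_length.
  if fail_on_input_truncation and labels.bounding_shape(axis=1) > max_length:
    raise ValueError(f'Labels size exceeds max length of {max_length}.')
  return labels[:, :max_length] if slice_left else labels[:, -max_length:]


labels = tf.ragged.constant([[151, 88, 21], [887]])
print(truncate_or_fail(labels, 2).to_list())  # [[151, 88], [887]]
# With the flag set, the same input raises instead of truncating:
# truncate_or_fail(labels, 2, fail_on_input_truncation=True)  -> ValueError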
38 changes: 14 additions & 24 deletions saxml/server/pax/lm/lm_tokenizer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenizer for language models."""
-
 from __future__ import annotations
 
 import dataclasses
@@ -72,7 +71,6 @@ class LMTokenizer(base_hyperparams.FiddleBaseParameterizable):
   vocabulary_path: str = None
   extra_ids: int = 0
   reverse_extra_ids: bool = True
-  fail_on_input_truncation: bool = False
 
   _vocab: vocabularies.Vocabulary = dataclasses.field(init=False, repr=False)
 
@@ -113,25 +111,13 @@ def __post_init__(self):
         os.makedirs(os.path.dirname(local_vocabulary_path))
 
       tf.io.gfile.copy(
-          vocabulary_path, local_vocabulary_path, overwrite=True
-      )
+          vocabulary_path,
+          local_vocabulary_path,
+          overwrite=True)
       vocabulary_path = local_vocabulary_path
 
     self._vocab = vocab_cls(vocabulary_path)
 
-  def _truncate_labels(self, max_length: int, labels: tf.Tensor):
-    p = self.hparams
-    if (
-        p.fail_on_input_truncation
-        and labels.bounding_shape(axis=1) > max_length
-    ):
-      raise ValueError(f'Labels size exceeds max length of {max_length}.')
-
-    if p.slice_left:
-      return labels[:, :max_length]
-    else:
-      return labels[:, -(max_length):]
-
   @property
   def Vocabulary(self) -> vocabularies.Vocabulary:
     """Get the vocabulary."""
@@ -158,9 +144,11 @@ def StringsToIdsTokenized(
         lambda x: empty_str_tensor if x.shape == [1] and x[0] == b'' else x,  # pylint: disable=g-explicit-bool-comparison
         labels_in_str,
     )
-    labels = self._truncate_labels(
-        max_length, tf.strings.to_number(labels_in_str, out_type=tf.int32)
-    )
+    labels = tf.strings.to_number(labels_in_str, out_type=tf.int32)
+    if p.slice_left:
+      labels = labels[:, :max_length]
+    else:
+      labels = labels[:, -(max_length):]
 
     # Get the shape of each ragged tensor and drop the dimension of the shape.
     padding_indices = max_length - (
@@ -237,8 +225,10 @@ def StringsToIds(
       labels = tf.strings.to_number(labels_in_str, out_type=tf.int32)
     else:
       labels = self._vocab.encode_tf(strs)
-
-    labels = self._truncate_labels(max_length - 1, labels)
+    if p.slice_left:
+      labels = labels[:, : max_length - 1]
+    else:
+      labels = labels[:, -(max_length - 1) :]
 
     if p.prepend_sos:
       sos_ids = tf.fill(
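The two slice_left branches added above keep their established semantics: slice_left=True keeps the first max_length tokens of each row, slice_left=False keeps the last. A quick eager-mode illustration on a ragged batch (values borrowed from lm_tokenizer_test.py below):

import tensorflow as tf

labels = tf.ragged.constant([[151, 88, 21], [887]])
print(labels[:, :2].to_list())   # slice_left=True  -> [[151, 88], [887]]
print(labels[:, -2:].to_list())  # slice_left=False -> [[88, 21], [887]]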
@@ -318,8 +308,8 @@ def IdToString(self, ids: tf.Tensor) -> tf.Tensor:
     """Converts each token ID to a token string.
 
     Args:
-      ids: A tensor of shape [batch, seqlen] and int32 data type. ids[n, i] is
-        the token ID at decoding step i for the n-th sample.
+      ids: A tensor of shape [batch, seqlen] and int32 data type.
+        ids[n, i] is the token ID at decoding step i for the n-th sample.
 
     Returns:
       A tensor of token strings with the same shape as the input ids.
20 changes: 0 additions & 20 deletions saxml/server/pax/lm/lm_tokenizer_test.py
@@ -207,26 +207,6 @@ def testTokenizedStringsToIdsSliceLeft(self):
     self.assertAllEqual([[151, 88], [887, 50256]], labels)
     self.assertAllEqual([[0.0, 0.0], [0.0, 1.0]], paddings)
 
-  def testTokenizedStringsToIdsFailOnInputTruncation(self):
-    p = _CreateTokenizedParams()
-    p.fail_on_input_truncation = True
-    tokenizer = instantiate(p)
-    max_length = 2
-    strs = tf.ragged.constant(['151,88,21', '887'])
-    with self.assertRaises(ValueError):
-      tokenizer.StringsToIds(strs, max_length)
-
-  def testTokenizedStringsToIdsFailOnInputTruncationSetButNoTruncation(self):
-    p = _CreateTokenizedParams()
-    p.fail_on_input_truncation = True
-    tokenizer = instantiate(p)
-    max_length = 2
-    strs = tf.ragged.constant(['151,88', '887'])
-    ids, labels, paddings = tokenizer.StringsToIds(strs, max_length)
-    self.assertAllEqual([[151, 88], [887, 50256]], ids)
-    self.assertAllEqual([[151, 88], [887, 50256]], labels)
-    self.assertAllEqual([[0.0, 0.0], [0.0, 1.0]], paddings)
-
   def testTokenizedStringsToIdsSliceRight(self):
     p = _CreateTokenizedParams()
     p.slice_left = False
2 changes: 0 additions & 2 deletions saxml/server/pax/lm/params/template.py
@@ -49,7 +49,6 @@ class CommonServingTemplate:
   BATCH_SIZE = 1
   BATCH_WAIT_SECS = None
   INPUT_SEQ_LEN = 256
-  FAIL_ON_INPUT_TRUNCATION = False
   SUFFIX_SEQ_LEN = 0  # Deprecating this attribute.
   MIN_DECODE_STEPS = 0  # Currently ignored by all but BeamSearchHParams
   MAX_DECODE_STEPS = 32
@@ -170,7 +169,6 @@ def serving_tokenizer(self):
         vocabulary_path=self.VOCABULARY_PATH,
         extra_ids=self.NUM_EXTRA_IDS,
         reverse_extra_ids=self.REVERSE_EXTRA_IDS,
-        fail_on_input_truncation=self.FAIL_ON_INPUT_TRUNCATION,
     )
 
   def score(self) -> Optional[servable_lm_model.ScoreHParams]:
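When this plumbing is present, a serving config opts in by overriding the class attribute, which serving_tokenizer() passes into the tokenizer params. A hypothetical subclass, mirroring the TestModelFailOnInputTruncation pattern in template_test.py below (the class name here is illustrative only):

@template.make_servable()
class MyStrictLengthModel(TestModel):  # hypothetical example class
  FAIL_ON_INPUT_TRUNCATION = True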
9 changes: 0 additions & 9 deletions saxml/server/pax/lm/params/template_test.py
@@ -35,11 +35,6 @@ class TestModelMaxSequenceLength(TestModel):
   MAX_SEQ_LEN = 128
 
 
-@template.make_servable()
-class TestModelFailOnInputTruncation(TestModel):
-  FAIL_ON_INPUT_TRUNCATION = True
-
-
 @template.make_servable()
 class TestLayerwiseModel(TestModel):
 
@@ -97,10 +92,6 @@ def test_seqlen(self):
         TestLayerwiseModel.INPUT_SEQ_LEN + TestLayerwiseModel.MAX_DECODE_STEPS,
     )
 
-  def test_fail_on_input_truncation(self):
-    config = TestModelFailOnInputTruncation()
-    self.assertEqual(config.serving_tokenizer().fail_on_input_truncation, True)
-
   def test_precompute_kv_cache(self):
     model_cls = TestModelPrecomputeKVCache
     model_cls.SPM_MODEL = os.path.join(
