From afd0e23a523ecd8a82363445a1a320e702423086 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Thu, 12 Sep 2024 15:26:20 -0700 Subject: [PATCH] check with main --- MaxText/tokenizer.py | 8 +++----- constraints_gpu.txt | 1 + requirements.txt | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py index 692a5bbd5..e6f389569 100644 --- a/MaxText/tokenizer.py +++ b/MaxText/tokenizer.py @@ -19,7 +19,7 @@ from typing import Dict, Iterable, Union, Literal, Sequence, Collection, List from pathlib import Path import tensorflow as tf -import sentencepiece as sp +import tensorflow_text as tftxt import max_logging import tiktoken from tiktoken.load import load_tiktoken_bpe @@ -199,8 +199,7 @@ def __init__(self, model_path: str, add_bos: bool, add_eos: bool): max_logging.log(f"Tokenizer path: {model_path}") with tf.io.gfile.GFile(model_path, "rb") as model_fp: sp_model = model_fp.read() - # this tokenizer is ONLY compatible with previous tftxt sp tokenizer if reverse=False - self.sp_tokenizer = sp.SentencePieceProcessor(model_proto=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False) + self.sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False) def encode(self, s: str) -> List[int]: return self.sp_tokenizer.tokenize(s) @@ -225,10 +224,9 @@ def _process_string(string_tensor): # encode and extract the tokenized integers modified_string = tokenizer.encode(string_value) return [modified_string] - for k in data_keys: if isinstance(tokenizer, TikTokenTokenizer): features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0] elif isinstance(tokenizer, SentencePieceTokenizer): - features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0] + features[k] = tokenizer.encode(features[k]) return features diff --git a/constraints_gpu.txt b/constraints_gpu.txt index 1485446a2..beab2a6d2 100644 --- a/constraints_gpu.txt +++ b/constraints_gpu.txt @@ -136,6 +136,7 @@ tensorflow-estimator==2.13.0 tensorflow-hub==0.16.1 tensorflow-io-gcs-filesystem==0.37.0 tensorflow-metadata==1.15.0 +tensorflow-text==2.13.0 tensorstore==0.1.63 termcolor==2.4.0 tf-keras==2.15.0 diff --git a/requirements.txt b/requirements.txt index 6b01be6f0..092cd3127 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,7 @@ pyink pre-commit pytype sentencepiece==0.1.97 +tensorflow-text>=2.13.0 tensorflow>=2.13.0 tensorflow-datasets tensorboardx