diff --git a/MaxText/requirements_with_jax_ss.txt b/MaxText/requirements_with_jax_ss.txt
index f5afb5684..345039cfc 100644
--- a/MaxText/requirements_with_jax_ss.txt
+++ b/MaxText/requirements_with_jax_ss.txt
@@ -8,10 +8,9 @@ pyink
 pre-commit
 pytype
 sentencepiece==0.1.97
-tensorflow-text>=2.13.0
 tensorflow-datasets
 tiktoken
 transformers
 mlperf-logging@git+https://github.com/mlperf/logging.git
 google-jetstream
-jsonlines
\ No newline at end of file
+jsonlines
diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py
index e6f389569..e93d0fa52 100644
--- a/MaxText/tokenizer.py
+++ b/MaxText/tokenizer.py
@@ -19,7 +19,7 @@ from typing import Dict, Iterable, Union, Literal, Sequence, Collection, List
 from pathlib import Path

 import tensorflow as tf
-import tensorflow_text as tftxt
+import sentencepiece as sp
 import max_logging
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
@@ -199,12 +199,17 @@ def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
     max_logging.log(f"Tokenizer path: {model_path}")
     with tf.io.gfile.GFile(model_path, "rb") as model_fp:
       sp_model = model_fp.read()
-    self.sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False)
+    # this tokenizer is only compatible with the previous tftxt SentencepieceTokenizer when reverse=False
+    self.sp_tokenizer = sp.SentencePieceProcessor(model_proto=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False)

   def encode(self, s: str) -> List[int]:
+    if isinstance(s, tf.Tensor):
+      s = s.numpy()
     return self.sp_tokenizer.tokenize(s)

   def decode(self, t: Sequence[int]) -> str:
+    if isinstance(t, tf.Tensor):
+      t = t.numpy().tolist()
     return self.sp_tokenizer.detokenize(t)

 def build_tokenizer(tokenizer_path, add_bos, add_eos):
diff --git a/constraints_gpu.txt b/constraints_gpu.txt
index beab2a6d2..1485446a2 100644
--- a/constraints_gpu.txt
+++ b/constraints_gpu.txt
@@ -136,7 +136,6 @@ tensorflow-estimator==2.13.0
 tensorflow-hub==0.16.1
 tensorflow-io-gcs-filesystem==0.37.0
 tensorflow-metadata==1.15.0
-tensorflow-text==2.13.0
 tensorstore==0.1.63
 termcolor==2.4.0
 tf-keras==2.15.0
diff --git a/requirements.txt b/requirements.txt
index 092cd3127..6b01be6f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,6 @@ pyink
 pre-commit
 pytype
 sentencepiece==0.1.97
-tensorflow-text>=2.13.0
 tensorflow>=2.13.0
 tensorflow-datasets
 tensorboardx
diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt
index f5afb5684..345039cfc 100644
--- a/requirements_with_jax_stable_stack.txt
+++ b/requirements_with_jax_stable_stack.txt
@@ -8,10 +8,9 @@ pyink
 pre-commit
 pytype
 sentencepiece==0.1.97
-tensorflow-text>=2.13.0
 tensorflow-datasets
 tiktoken
 transformers
 mlperf-logging@git+https://github.com/mlperf/logging.git
 google-jetstream
-jsonlines
\ No newline at end of file
+jsonlines
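
Not part of the diff: a minimal standalone sketch of the replacement path, showing the sentencepiece package loading the same .model proto bytes that the removed tftxt.SentencepieceTokenizer consumed. The tokenizer path is a placeholder, and add_bos/add_eos are set to True only for illustration.

import sentencepiece as sp

# Placeholder path; substitute the actual SentencePiece .model asset from the run config.
with open("assets/tokenizer.model", "rb") as model_fp:
  sp_model = model_fp.read()

# reverse=False mirrors the behavior of the removed tftxt.SentencepieceTokenizer.
tokenizer = sp.SentencePieceProcessor(model_proto=sp_model, add_bos=True, add_eos=True, reverse=False)

ids = tokenizer.tokenize("hello world")   # token ids, with BOS/EOS added
text = tokenizer.detokenize(ids)          # decodes back to the original text; control tokens are dropped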