From bb1666a632d10e86786ab69b114b32e42bf16cce Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Wed, 11 Sep 2024 17:57:29 -0700 Subject: [PATCH] removing tensorflow_text for aarch64 compatibility --- MaxText/requirements_with_jax_ss.txt | 3 +-- MaxText/tokenizer.py | 8 +++++--- constraints_gpu.txt | 1 - requirements.txt | 1 - requirements_with_jax_stable_stack.txt | 3 +-- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/MaxText/requirements_with_jax_ss.txt b/MaxText/requirements_with_jax_ss.txt index f5afb5684..345039cfc 100644 --- a/MaxText/requirements_with_jax_ss.txt +++ b/MaxText/requirements_with_jax_ss.txt @@ -8,10 +8,9 @@ pyink pre-commit pytype sentencepiece==0.1.97 -tensorflow-text>=2.13.0 tensorflow-datasets tiktoken transformers mlperf-logging@git+https://github.com/mlperf/logging.git google-jetstream -jsonlines \ No newline at end of file +jsonlines diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py index e6f389569..692a5bbd5 100644 --- a/MaxText/tokenizer.py +++ b/MaxText/tokenizer.py @@ -19,7 +19,7 @@ from typing import Dict, Iterable, Union, Literal, Sequence, Collection, List from pathlib import Path import tensorflow as tf -import tensorflow_text as tftxt +import sentencepiece as sp import max_logging import tiktoken from tiktoken.load import load_tiktoken_bpe @@ -199,7 +199,8 @@ def __init__(self, model_path: str, add_bos: bool, add_eos: bool): max_logging.log(f"Tokenizer path: {model_path}") with tf.io.gfile.GFile(model_path, "rb") as model_fp: sp_model = model_fp.read() - self.sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False) + # this tokenizer is ONLY compatible with previous tftxt sp tokenizer if reverse=False + self.sp_tokenizer = sp.SentencePieceProcessor(model_proto=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False) def encode(self, s: str) -> List[int]: return self.sp_tokenizer.tokenize(s) @@ -224,9 +225,10 @@ def _process_string(string_tensor): # encode 
and extract the tokenized integers modified_string = tokenizer.encode(string_value) return [modified_string] + for k in data_keys: if isinstance(tokenizer, TikTokenTokenizer): features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0] elif isinstance(tokenizer, SentencePieceTokenizer): - features[k] = tokenizer.encode(features[k]) + features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0] return features diff --git a/constraints_gpu.txt b/constraints_gpu.txt index beab2a6d2..1485446a2 100644 --- a/constraints_gpu.txt +++ b/constraints_gpu.txt @@ -136,7 +136,6 @@ tensorflow-estimator==2.13.0 tensorflow-hub==0.16.1 tensorflow-io-gcs-filesystem==0.37.0 tensorflow-metadata==1.15.0 -tensorflow-text==2.13.0 tensorstore==0.1.63 termcolor==2.4.0 tf-keras==2.15.0 diff --git a/requirements.txt b/requirements.txt index 092cd3127..6b01be6f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,6 @@ pyink pre-commit pytype sentencepiece==0.1.97 -tensorflow-text>=2.13.0 tensorflow>=2.13.0 tensorflow-datasets tensorboardx diff --git a/requirements_with_jax_stable_stack.txt b/requirements_with_jax_stable_stack.txt index f5afb5684..345039cfc 100644 --- a/requirements_with_jax_stable_stack.txt +++ b/requirements_with_jax_stable_stack.txt @@ -8,10 +8,9 @@ pyink pre-commit pytype sentencepiece==0.1.97 -tensorflow-text>=2.13.0 tensorflow-datasets tiktoken transformers mlperf-logging@git+https://github.com/mlperf/logging.git google-jetstream -jsonlines \ No newline at end of file +jsonlines