Skip to content

Commit

Permalink
removing tensorflow_text for aarch64 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
rdyro committed Sep 13, 2024
1 parent 08f68a3 commit bb1666a
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 9 deletions.
3 changes: 1 addition & 2 deletions MaxText/requirements_with_jax_ss.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ pyink
pre-commit
pytype
sentencepiece==0.1.97
tensorflow-text>=2.13.0
tensorflow-datasets
tiktoken
transformers
mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream
jsonlines
jsonlines
8 changes: 5 additions & 3 deletions MaxText/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Dict, Iterable, Union, Literal, Sequence, Collection, List
from pathlib import Path
import tensorflow as tf
import tensorflow_text as tftxt
import sentencepiece as sp
import max_logging
import tiktoken
from tiktoken.load import load_tiktoken_bpe
Expand Down Expand Up @@ -199,7 +199,8 @@ def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
max_logging.log(f"Tokenizer path: {model_path}")
with tf.io.gfile.GFile(model_path, "rb") as model_fp:
sp_model = model_fp.read()
self.sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False)
# this tokenizer is ONLY compatible with previous tftxt sp tokenizer if reverse=False
self.sp_tokenizer = sp.SentencePieceProcessor(model_proto=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=False)

def encode(self, s: str) -> List[int]:
return self.sp_tokenizer.tokenize(s)
Expand All @@ -224,9 +225,10 @@ def _process_string(string_tensor):
# encode and extract the tokenized integers
modified_string = tokenizer.encode(string_value)
return [modified_string]

for k in data_keys:
if isinstance(tokenizer, TikTokenTokenizer):
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
elif isinstance(tokenizer, SentencePieceTokenizer):
features[k] = tokenizer.encode(features[k])
features[k] = tf.py_function(_process_string, [features[k]], Tout=[tf.int32])[0]
return features
1 change: 0 additions & 1 deletion constraints_gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ tensorflow-estimator==2.13.0
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.0
tensorflow-metadata==1.15.0
tensorflow-text==2.13.0
tensorstore==0.1.63
termcolor==2.4.0
tf-keras==2.15.0
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ pyink
pre-commit
pytype
sentencepiece==0.1.97
tensorflow-text>=2.13.0
tensorflow>=2.13.0
tensorflow-datasets
tensorboardx
Expand Down
3 changes: 1 addition & 2 deletions requirements_with_jax_stable_stack.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ pyink
pre-commit
pytype
sentencepiece==0.1.97
tensorflow-text>=2.13.0
tensorflow-datasets
tiktoken
transformers
mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream
jsonlines
jsonlines

0 comments on commit bb1666a

Please sign in to comment.