From 7afdb976a063f8c7b414155ef145c8a03cd25bc1 Mon Sep 17 00:00:00 2001
From: A9isha
Date: Sat, 16 Dec 2023 00:17:52 +0000
Subject: [PATCH] use hf tokenizer with grain

---
 MaxText/input_pipeline.py | 9 +++++++--
 MaxText/tokenizer.py      | 4 +++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/MaxText/input_pipeline.py b/MaxText/input_pipeline.py
index 9d9e1faa0..fcf6d4253 100644
--- a/MaxText/input_pipeline.py
+++ b/MaxText/input_pipeline.py
@@ -35,6 +35,10 @@
 import pygrain_operations
 import pygrain_tokenizer
 
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
 AUTOTUNE = tf.data.experimental.AUTOTUNE
 
 
@@ -380,8 +384,9 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
     vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')
 
   # Load tokenizer
-  sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
-                                          vocab_size=config.vocab_size)
+  # sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
+  #                                         vocab_size=config.vocab_size)
+  sp_tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablecode-completion-alpha-3b")
 
   # Set global batch size.
   global_batch_size_to_load = config.global_batch_size_to_load
diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py
index b8e7f50a3..e47caa0d1 100644
--- a/MaxText/tokenizer.py
+++ b/MaxText/tokenizer.py
@@ -25,6 +25,7 @@
 import max_logging
 
+import numpy as np
 
 Features = Dict[str, tf.Tensor]
 
 
@@ -64,5 +65,6 @@ class TokenizeOp:
 
   def __call__(self, features: Features) -> Features:
     for k in self.data_keys:
-      features[k] = self.sp_tokenizer.tokenize(features[k])
+      # features[k] = self.sp_tokenizer.tokenize(features[k])
+      features[k] = np.asarray(self.sp_tokenizer.encode(str(features[k])))
     return features
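
Note on the tokenize step: the patch swaps MaxText's SentencePiece tokenizer for a Hugging Face AutoTokenizer inside the grain (pygrain) input pipeline. Below is a minimal standalone sketch of what the new TokenizeOp body does; it is not part of the patch, and the tokenize_feature helper is illustrative only. It assumes grain hands the op raw utf-8 bytes records; under that assumption, the patch's str(features[k]) would produce the repr (e.g. "b'hello'") rather than the underlying text, so the sketch decodes explicitly before encoding.

import numpy as np
from transformers import AutoTokenizer

# Same checkpoint the patch loads; any HF tokenizer exposing .encode() works.
hf_tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stablecode-completion-alpha-3b")

def tokenize_feature(value):
  """Encode one text feature into a numpy array of token ids."""
  if isinstance(value, bytes):
    # Decode bytes first: str(b"hello") yields the repr "b'hello'",
    # which would be tokenized as literal characters.
    value = value.decode("utf-8")
  return np.asarray(hf_tokenizer.encode(value))

print(tokenize_feature(b"def hello():"))  # -> numpy array of token ids

Wrapping the Python int list from encode() in np.asarray mirrors the patch and keeps the output array-typed for the downstream grain operations. Also worth noting: AutoModelForCausalLM is imported in input_pipeline.py but unused by these hunks.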