Commit

use hf tokenizer with grain
A9isha committed Dec 18, 2023
1 parent 06d079e commit 7afdb97
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 7 additions & 2 deletions MaxText/input_pipeline.py
@@ -35,6 +35,10 @@
 import pygrain_operations
 import pygrain_tokenizer
 
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
 AUTOTUNE = tf.data.experimental.AUTOTUNE

@@ -380,8 +384,9 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
   vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')
 
   # Load tokenizer
-  sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
-                                          vocab_size=config.vocab_size)
+  # sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
+  #                                         vocab_size=config.vocab_size)
+  sp_tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablecode-completion-alpha-3b")
 
   # Set global batch size.
   global_batch_size_to_load = config.global_batch_size_to_load
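
For reference, a minimal standalone sketch of the tokenizer swap made in this hunk, assuming the "stabilityai/stablecode-completion-alpha-3b" tokenizer can be fetched from the Hugging Face Hub; the sample string and the `ids` variable are illustrative only, not part of the commit:

from transformers import AutoTokenizer

# Load the Hugging Face tokenizer that replaces the SentencePiece model.
sp_tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stablecode-completion-alpha-3b")

# encode() maps a raw text string to a Python list of token ids.
ids = sp_tokenizer.encode("def add(a, b): return a + b")
print(ids)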
4 changes: 3 additions & 1 deletion MaxText/tokenizer.py
@@ -25,6 +25,7 @@
 
 import max_logging
 
+import numpy as np
 
 Features = Dict[str, tf.Tensor]

@@ -64,5 +65,6 @@ class TokenizeOp:
 
   def __call__(self, features: Features) -> Features:
     for k in self.data_keys:
-      features[k] = self.sp_tokenizer.tokenize(features[k])
+      # features[k] = self.sp_tokenizer.tokenize(features[k])
+      features[k] = np.asarray(self.sp_tokenizer.encode(str(features[k])))
     return features
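
A minimal sketch of what the updated TokenizeOp now does per data key, assuming the same Hugging Face tokenizer and hypothetical string-valued features; in the real pipeline the values come from pygrain records, so str() may also be applied to bytes or tensors rather than plain strings:

import numpy as np
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("stabilityai/stablecode-completion-alpha-3b")

features = {"inputs": "print('hello')", "targets": "print('hello')"}
for k in ("inputs", "targets"):
  # Stringify the raw feature, encode it to token ids, and store the
  # result as a numpy array, mirroring the replaced tokenize() call.
  features[k] = np.asarray(tok.encode(str(features[k])))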
