diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 31c01554a..5c46d29f1 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -118,14 +118,15 @@ vocab_relative_path: "tokenizer" # Assumes we're allowed
 # eval_dataset_name: 'c4/en:3.0.1'
 # dataset_name: 'array-record/c4/en/3.0.1'
 # eval_dataset_name: 'array-record/c4/en/3.0.1'
+# eval_split: 'validation' # for c4 data
 dataset_name: 'lm1b/1.1.0'
 eval_dataset_name: 'lm1b/1.1.0'
-eval_split: 'test'
+eval_split: 'test' # for lm1b
 per_device_batch_size: 12.0
 eval_per_device_batch_size: 0
 max_corpus_chars: 10_000_000
-dataset_type: array_record # must be c4, array_record or synthetic
-# dataset_type: c4
+# dataset_type: c4 # must be c4, array_record or synthetic
+dataset_type: array_record

 # Training loop
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
diff --git a/MaxText/input_pipeline.py b/MaxText/input_pipeline.py
index dd5c71153..f6fe27116 100644
--- a/MaxText/input_pipeline.py
+++ b/MaxText/input_pipeline.py
@@ -33,8 +33,6 @@ import multihost_dataloading
 import sequence_packing
 import pygrain_operations
-from transformers import T5Tokenizer
-from sentencepiece import SentencePieceProcessor
 import pygrain_tokenizer

 AUTOTUNE = tf.data.experimental.AUTOTUNE
@@ -172,7 +170,7 @@ def filter_fn(x):

 def preprocessing_pipeline_pygrain(
   dataset,
-  operations,
+  vocab_path,
   batch_size: int,
   global_mesh,
   shuffle: bool,
@@ -185,6 +183,12 @@ def preprocessing_pipeline_pygrain(
   data_sharding = None,
   data_shuffle_seed = 0,
 ):
+
+  operations = []
+  operations.append(pygrain_operations.ParseFeatures())
+  operations.append(pygrain_operations.NormalizeFeatures())
+  operations.append(pygrain_tokenizer.Tokenize(["inputs","targets"], max_length, vocab_path))
+  operations.append(pygrain.MapOperation(map_function=pygrain_operations.filter_keys))
   operations.append(pygrain.FilterOperation(condition_function = pygrain_operations.length_filter(max_length)))

   # Pack and Batch examples.
@@ -194,7 +198,7 @@ def preprocessing_pipeline_pygrain(
           length_struct={'inputs':max_length,'targets':max_length}))
     operations.append(pygrain.MapOperation(map_function=pygrain_operations.CombineKeys()))
   else:
-    # operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
+    operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
     operations.append(pygrain.BatchOperation(batch_size=batch_size // jax.process_count(), drop_remainder=drop_remainder))

   # Shift inputs for teacher-forced training
@@ -212,22 +216,15 @@ def preprocessing_pipeline_pygrain(
   )

   dataloader = pygrain.DataLoader(
-    data_source = dataset,
-    operations = operations,
-    sampler = index_sampler,
-    worker_count=0,
+    data_source = dataset,
+    operations = operations,
+    sampler = index_sampler,
+    worker_count=1,
   )
+
   data_iter = iter(dataloader)
-  global_shape = (batch_size, max_length)
-  # Return PyGrainIterator
-  # return data_iter
-  multihost_gen = (
-      multihost_dataloading.get_next_batch_sharded_pygrain(
-          data_iter, data_sharding, global_shape, global_mesh
-      )
-  )
-  # Return multi-host jax.Array prep iterator
-  return multihost_gen
+
+  return data_iter


 def get_datasets(
@@ -287,7 +284,7 @@ def preprocess_dataset(config: ml_collections.ConfigDict,
   # Load tokenizer
   sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
                                           vocab_size=config.vocab_size)
-  # sp_tokenizer = T5Tokenizer.from_pretrained('t5-base')
+
   # Tokenize data.
   train_ds = train_ds.map(
       tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
@@ -354,15 +351,9 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
     vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')

   # Load tokenizer
-  # sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
-  #                                         vocab_size=config.vocab_size)
-  # sp_tokenizer = T5Tokenizer.from_pretrained('t5-base', model_max_length=1024)
-  sp_tokenizer = SentencePieceProcessor(vocab_path)
-
-  operations = [pygrain.MapOperation(map_function=pygrain_operations.normalize_features())]
-  #operations.append(pygrain.MapOperation(map_function=pygrain_operations.TokenizeOperation(sp_tokenizer)))
-  operations.append(pygrain_tokenizer.TokenizeAndPad(["inputs","targets"], config.max_target_length, vocab_path))
-
+  sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
+                                          vocab_size=config.vocab_size)
+
   # Set global batch size.
   global_batch_size_to_load = config.global_batch_size_to_load
@@ -371,44 +362,40 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
   else:
     eval_batch_size = global_batch_size_to_load

-  def filter_keys(record):
-    return {'inputs': record['inputs'], 'targets': record['targets']}
-  operations.append(pygrain.MapOperation(map_function=filter_keys))
-
   train_iter = preprocessing_pipeline_pygrain(
       train_ds,
-      operations,
+      vocab_path,
       global_batch_size_to_load,
       global_mesh,
       shuffle=config.enable_data_shuffling,
       num_epochs=1,
       pack_examples=False,
       max_length=config.max_target_length,
-      shift=False,
+      shift=True,
       data_sharding=config.data_sharding,
       data_shuffle_seed = data_shuffle_seed,)

   eval_iter = preprocessing_pipeline_pygrain(
       eval_ds,
-      operations,
+      vocab_path,
       eval_batch_size,
       global_mesh,
       shuffle=config.enable_data_shuffling,
       pack_examples=False,
       max_length=config.max_eval_target_length,
-      shift=False,
+      shift=True,
       data_sharding=config.data_sharding,
       data_shuffle_seed = data_shuffle_seed,)

   predict_iter = preprocessing_pipeline_pygrain(
       eval_ds,
-      operations,
+      vocab_path,
       eval_batch_size,
       global_mesh,
       shuffle=config.enable_data_shuffling,
       pack_examples=False,
       max_length=config.max_eval_target_length,
-      shift=False,
+      shift=True,
       data_sharding=config.data_sharding,
       data_shuffle_seed = data_shuffle_seed,)
diff --git a/MaxText/multihost_dataloading.py b/MaxText/multihost_dataloading.py
index 65535426c..993b9f419 100644
--- a/MaxText/multihost_dataloading.py
+++ b/MaxText/multihost_dataloading.py
@@ -126,7 +126,6 @@ def get_next_batch_sharded(local_dataset: tf.data.Dataset,
     if not loaded_data_success:
       local_data = local_dataset.next()

-  # local_devices = jax.local_devices()
   local_devices = global_mesh.local_devices
   local_device_count = jax.local_device_count()

@@ -190,11 +189,8 @@ def get_next_batch_sharded_pygrain(data_iter,
   data_axes = jax.tree_map(lambda x: PartitionSpec(*data_sharding), local_data)
   _ = check_inputs("array_record", local_data, global_data_shape, data_axes)

-  # local_devices = jax.local_devices()
   local_devices = global_mesh.local_devices
   local_device_count = jax.local_device_count()
-  print(f"local_device: {local_devices}")
-  print(f"local_device_count: {local_device_count}")

   def _put_to_devices(x):
     try:
@@ -217,10 +213,6 @@ def form_gda(local_data, shape):
     device_buffers = _put_to_devices(local_data)
     # Wrap device buffers as GDA
     shape = tuple(shape)
-    print("####################### Debug")
-    print(f"shape: {shape}; global_mesh: {global_mesh}; ")
-    print(f"input_sharding_constraint: {input_sharding_constraint};")
print(f"jax.sharding.NamedSharding: {jax.sharding.NamedSharding(global_mesh, input_sharding_constraint)};") input_gda = jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(global_mesh, input_sharding_constraint), device_buffers) return input_gda diff --git a/MaxText/pygrain_operations.py b/MaxText/pygrain_operations.py index aae2f86e8..7e62fcd72 100644 --- a/MaxText/pygrain_operations.py +++ b/MaxText/pygrain_operations.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping, Sequence +import dataclasses from typing import Dict import grain.python as pygrain import numpy as np @@ -14,6 +16,29 @@ def _normalize_features(features): return {'inputs':features, 'targets': features} return _normalize_features(features) +@dataclasses.dataclass +class ParseFeatures(pygrain.MapTransform): + def map(self, features): + def _parse(example): + parsed = tf.io.parse_example( + example, { + 'text': tf.io.FixedLenFeature(shape=(), dtype=tf.string) + }) + return parsed + return _parse(features) + + +@dataclasses.dataclass +class NormalizeFeatures(pygrain.MapTransform): + def map(self, features): + return { + 'inputs':features['text'].numpy().decode(), + 'targets': features['text'].numpy().decode() + } + +def filter_keys(record): + return {'inputs': record['inputs'], 'targets': record['targets']} + class TokenizeOperation(): """ TokenizeOp """ @@ -70,27 +95,29 @@ def __call__(self, data): combined_data.update(positions) return combined_data -def shift_right_tf(x, axis=1): +def shift_right(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) slices = [slice(None),] * len(x.shape) slices[axis] = slice(0, -1) - padded = tf.pad( + padded = np.pad( x, - tf.constant(pad_widths), + pad_widths, mode='constant', - constant_values=tf.constant(0, x.dtype)) + constant_values=x.dtype.type(0) + # constant_values=tf.constant(0, x.dtype) + ) return padded[tuple(slices)] -def shift_inputs_tf(x, segment_ids=None, axis=1): +def shift_inputs(x, segment_ids=None, axis=1): """Shift inputs and replace EOS by 0 for packed inputs.""" - shifted = shift_right_tf(x, axis=axis) + shifted = shift_right(x, axis=axis) # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. if segment_ids is not None: shifted *= tf.cast( - segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype + segment_ids == shift_right(segment_ids, axis=axis), x.dtype ) return shifted @@ -101,5 +128,5 @@ def __init__(self, axis = 0, segmented=True): def __call__(self, x): segment_ids = x['inputs_segmentation'] if self.segmented else None - x['inputs'] = shift_inputs_tf(x['inputs'], segment_ids=segment_ids, axis=self.axis) + x['inputs'] = shift_inputs(x['inputs'], segment_ids=segment_ids, axis=self.axis) return x \ No newline at end of file diff --git a/MaxText/pygrain_tokenizer.py b/MaxText/pygrain_tokenizer.py index 22d486229..f5647cc4c 100644 --- a/MaxText/pygrain_tokenizer.py +++ b/MaxText/pygrain_tokenizer.py @@ -7,119 +7,10 @@ from typing import Any from sentencepiece import SentencePieceProcessor import grain.python as grain - -class AbstractTokenizeAndSplit(grain.MapTransform): - """Tokenize and split text features. - - The split of the tokenized features will replace the text features. - - This transform makes 2 assumptions: - - Records are flat dictionaries with 1 or more text features. 
-  - It follows a DataSourceWithSplitInfo which should produce elements as:
-    (example, (split_index, expected_split_count))
-
-  The transform will produces None if the actual example doesn't have the
-  corresponding split.
-  """
-
-  def __init__(
-      self,
-      feature_names: str | Sequence[str],
-      sequence_length: int | Sequence[int],
-  ):
-    """Creates a new TokenizeAndSplit transform.
-
-    Args:
-      feature_names: One or multiple feature names that contain text.
-      sequence_length: One or multiple sequence lengths to use for the text
-        features. Output features will have [0, sequence_length] tokens.
-    """
-    super().__init__()
-    if isinstance(feature_names, str):
-      feature_names = [feature_names]
-    if isinstance(sequence_length, int):
-      sequence_length = [sequence_length] * len(feature_names)
-    elif len(sequence_length) != len(feature_names):
-      raise ValueError(
-          f"Number of features and sequence lengths mismatch: {feature_names=},"
-          f" {sequence_length=}"
-      )
-    self._feature_names = feature_names
-    self._sequence_length = sequence_length
-    self._stats = {
-        "empty_splits": 0,
-        "discarded_splits": 0,
-    }
-
-  def map(
-      self, features: tuple[dict[str, Any], tuple[int, int]]
-  ) -> dict[str, Any] | None:
-    features, (split_index, expected_split_count) = features
-    actual_split_count = 0
-    for feature_name, sequence_length in zip(
-        self._feature_names, self._sequence_length, strict=True
-    ):
-      text = features[feature_name]
-      token_ids = self._tokenize(text)
-      start = split_index * sequence_length
-      end = (split_index + 1) * sequence_length
-      if start >= len(token_ids):
-        self._stats["empty_splits"] += 1
-        return None
-      actual_split_count = max(
-          actual_split_count, int(math.ceil(len(token_ids) / sequence_length))
-      )
-      features[feature_name] = np.asarray(token_ids[start:end])
-    if split_index == 0 and actual_split_count > expected_split_count:
-      self._stats["discarded_splits"] += (
-          actual_split_count - expected_split_count
-      )
-    return features
-
-  def get_stats(self) -> Mapping[str, int]:
-    return copy.copy(self._stats)
-
-  @abc.abstractmethod
-  def _tokenize(self, text: str) -> Sequence[int]:
-    """Tokenizes the text."""
-
-
-class SentencePieceTokenizeAndSplit(AbstractTokenizeAndSplit):
-  """Tokenize and split text features using a Gemini tokenizer."""
-
-  def __init__(
-      self,
-      feature_names: str | Sequence[str],
-      sequence_length: int | Sequence[int],
-      sentencepiece_model_path: str,
-  ):
-    super().__init__(feature_names, sequence_length)
-    self._sentencepiece_model_path = sentencepiece_model_path
-    self._initialize_processor_lock = threading.Lock()
-    self._tokenizer = None
-
-  def _tokenize(self, text: str) -> Sequence[int]:
-    if self._tokenizer is None:
-      with self._initialize_processor_lock:
-        if self._tokenizer is None:
-          self._tokenizer = sentencepiece_processor.SentencePieceProcessor()
-          self._tokenizer.Load(filename=self._sentencepiece_model_path)
-    return self._tokenizer.EncodeAsIds(text)
-
-  def __getstate__(self):
-    state = self.__dict__.copy()
-    del state["_tokenizer"]
-    del state["_initialize_processor_lock"]
-    return state
-
-  def __setstate__(self, state):
-    self.__dict__.update(state)
-    self._tokenizer = None
-    self._initialize_processor_lock = threading.Lock()
-
+import numpy as np

 @dataclasses.dataclass
-class TokenizeAndPad(grain.MapTransform):
+class Tokenize(grain.MapTransform):
   """Tokenize, truncate and pad features to sequence length."""

   feature_names: str | Sequence[str]
@@ -139,17 +30,13 @@ def map(self, features: dict[str, Any]) -> dict[str, Any]:
     with self._initialize_processor_lock:
       if self._processor is None:  # Ensures only one thread initializes SPP.
         self._processor = SentencePieceProcessor(self.model_path)
-        # self._processor.Load(filename=self.model_path)
     for feature_name, sequence_length in zip(
         self.feature_names, self.sequence_length, strict=True
     ):
       text = features[feature_name]
       token_ids = self._processor.EncodeAsIds(text)
       token_ids = token_ids[:sequence_length]
-      token_ids = token_ids + [self._processor.pad_id()] * (
-          sequence_length - len(token_ids)
-      )
-      features[feature_name] = token_ids
+      features[feature_name] = np.asarray(token_ids)
     return features

   def __getstate__(self):
diff --git a/MaxText/train.py b/MaxText/train.py
index 69d8b2deb..f6ba13aa7 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -23,10 +23,6 @@ import os
 import sys

-jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
-print(f"Found {jax.device_count()} devices.")
 from typing import Sequence
 import datetime
@@ -55,6 +51,8 @@ from cloud_tpu_diagnostics.configuration import diagnostic_configuration
 from cloud_tpu_diagnostics.configuration import stack_trace_configuration

+from multihost_dataloading import get_next_batch_sharded_pygrain
+
 import max_logging

 cc.initialize_cache(os.path.expanduser("~/jax_cache"))
@@ -81,7 +79,15 @@ def load_next_batch(train_iter, example_batch, config):
   if config.reuse_example_batch and example_batch is not None:
     return example_batch
   else:
-    return train_iter()
+    return train_iter()
+
+def load_next_batch_pygrain(train_iter, example_batch, config, mesh):
+  if config.reuse_example_batch and example_batch is not None:
+    return example_batch
+  else:
+    global_shape = (config.global_batch_size_to_load, config.max_target_length)
+    return get_next_batch_sharded_pygrain(
+        train_iter, config.data_sharding, global_shape, mesh)

 def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
   """Records scalar metrics to be written to tensorboard"""
@@ -243,11 +249,6 @@ def train_loop(config, state=None):

   data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh)

-  # if config.dataset_type == "array_record":
-  #   inital_iterator = data_iterator
-  # else:
-  #   inital_iterator = None
-
   state, state_mesh_annotations, data_iterator = max_utils.setup_initial_state(model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager)
   data_pspec = P(*config.data_sharding)

@@ -274,7 +275,10 @@ def train_loop(config, state=None):
   running_gcs_metrics = [] if config.gcs_metrics else None

   for step in np.arange(get_first_step(state), config.steps):
-    example_batch = load_next_batch(data_iterator, example_batch, config)
+    if config.dataset_type == "array_record":
+      example_batch = load_next_batch_pygrain(data_iterator, example_batch, config, mesh)
+    else:
+      example_batch = load_next_batch(data_iterator, example_batch, config)
     with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
       state, metrics, nextrng = p_train_step(
           model, config, state, example_batch, nextrng
       )
@@ -311,6 +315,10 @@ def train_loop(config, state=None):
   return state

 def main(argv: Sequence[str]) -> None:
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+  os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
+  jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
+  print(f"Found {jax.device_count()} devices.")
   pyconfig.initialize(argv)
   os.environ["TFDS_DATA_DIR"] = pyconfig.config.dataset_path
   debug_config = debug_configuration.DebugConfig(
@@ -319,7 +327,7 @@ def main(argv: Sequence[str]) -> None:
                     stack_trace_to_cloud = pyconfig.config.stack_trace_to_cloud,
                     stack_trace_interval_seconds = pyconfig.config.stack_trace_interval_seconds))
   diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
-  with diagnostic.diagnose(diagnostic_config): 
+  with diagnostic.diagnose(diagnostic_config):
     train_loop(pyconfig.config)
diff --git a/requirements.txt b/requirements.txt
index 704122695..50e116ae2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,11 @@
-orbax-checkpoint
+orbax-checkpoint==0.4.1
 absl-py
+array-record
 argparse
 cloud-tpu-diagnostics
 datetime
 google-cloud-storage
+grain-nightly==0.0.3
 flax
 ml-collections
 numpy
diff --git a/setup_gcsfuse.sh b/setup_gcsfuse.sh
new file mode 100644
index 000000000..8fe77dab6
--- /dev/null
+++ b/setup_gcsfuse.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxtext-dataset MOUNT_PATH=dataset
+
+set -e
+
+# Set environment variables
+for ARGUMENT in "$@"; do
+    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
+    export "$KEY"="$VALUE"
+    echo "$KEY"="$VALUE"
+done
+
+if [[ -z ${DATASET_GCS_BUCKET} || -z ${MOUNT_PATH} ]]; then
+  echo "Please set arguments: DATASET_GCS_BUCKET and MOUNT_PATH"
+  exit 1
+fi
+
+if [[ $DATASET_GCS_BUCKET == gs://* ]] ;
+then
+  echo "Remove gs:// from GCS bucket name"
+  exit 1
+fi
+
+sudo apt-get -y install fuse
+export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
+echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
+curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
+sudo apt-get update
+sudo apt-get -y install gcsfuse
+
+mkdir -p $MOUNT_PATH
+
+gcsfuse --implicit-dirs "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
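
Reviewer note (not part of the patch): after this change, preprocessing_pipeline_pygrain assembles its own operations list from vocab_path and returns a plain pygrain iterator, while the conversion to a sharded jax.Array moves into train.py via load_next_batch_pygrain and get_next_batch_sharded_pygrain. The sketch below walks the same operation chain end to end as a standalone, single-process check. It is a minimal sketch, assuming the grain-nightly 0.0.3 API for ArrayRecordDataSource, IndexSampler and ShardOptions; the file path dataset/test.array_record, the batch size, and the sampler settings are illustrative only, and the packing/shift steps of the real pipeline are omitted.

# Sketch only: mirrors the unpacked (pack_examples=False) branch of
# preprocessing_pipeline_pygrain. Paths and sampler settings are assumptions.
import os
import grain.python as pygrain
import pygrain_operations
import pygrain_tokenizer

vocab_path = os.path.expanduser("~/lm1b_sentencepiece_model")  # assumed model location
max_length = 1024
batch_size = 8

operations = [
    pygrain_operations.ParseFeatures(),      # serialized tf.train.Example -> {'text': ...}
    pygrain_operations.NormalizeFeatures(),  # 'text' -> 'inputs'/'targets' strings
    pygrain_tokenizer.Tokenize(["inputs", "targets"], max_length, vocab_path),
    pygrain.MapOperation(map_function=pygrain_operations.filter_keys),
    pygrain.FilterOperation(condition_function=pygrain_operations.length_filter(max_length)),
    pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)),
    pygrain.BatchOperation(batch_size=batch_size, drop_remainder=True),
]

data_source = pygrain.ArrayRecordDataSource("dataset/test.array_record")  # assumed file
index_sampler = pygrain.IndexSampler(
    num_records=len(data_source),
    num_epochs=1,
    shard_options=pygrain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
)
data_iter = iter(pygrain.DataLoader(
    data_source=data_source,
    operations=operations,
    sampler=index_sampler,
    worker_count=1,
))
first_batch = next(data_iter)  # dict with 'inputs'/'targets' arrays of shape [batch_size, max_length]

In the training loop, the same kind of iterator is consumed by load_next_batch_pygrain, which calls get_next_batch_sharded_pygrain with global_shape = (config.global_batch_size_to_load, config.max_target_length) to assemble the sharded global batch.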