From 4ebff993a955898be7bb77c5aefb03715366219e Mon Sep 17 00:00:00 2001 From: aireenmei Date: Thu, 2 Nov 2023 06:11:55 +0000 Subject: [PATCH] add grain to 1028 head --- MaxText/checkpointing.py | 52 +++++++-- MaxText/configs/base.yml | 19 ++-- MaxText/decode.py | 4 +- MaxText/input_pipeline.py | 174 ++++++++++++++++++++++++++++++- MaxText/max_utils.py | 20 +++- MaxText/multihost_dataloading.py | 82 ++++++++++++++- MaxText/pyconfig.py | 2 +- MaxText/pygrain_operations.py | 105 +++++++++++++++++++ MaxText/pygrain_tokenizer.py | 164 +++++++++++++++++++++++++++++ MaxText/train.py | 12 ++- 10 files changed, 603 insertions(+), 31 deletions(-) create mode 100644 MaxText/pygrain_operations.py create mode 100644 MaxText/pygrain_tokenizer.py diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index 99919651e..ff93849f7 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -24,6 +24,7 @@ from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions, Checkpointer, AsyncCheckpointer from orbax.checkpoint import type_handlers import socket +import grain.python as pygrain import max_logging @@ -55,7 +56,8 @@ def create_orbax_checkpoint_manager( checkpoint_dir: str, enable_checkpointing: bool, use_async: bool, - save_interval_steps: int + save_interval_steps: int, + dataset_type: str, ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -65,9 +67,17 @@ def create_orbax_checkpoint_manager( p = epath.Path(checkpoint_dir) if use_async: _multislice_distribute_initialize() - checkpointer = AsyncCheckpointer(checkpoint.PyTreeCheckpointHandler()) + if dataset_type == "array_record": + checkpointer = {'state':AsyncCheckpointer(checkpoint.PyTreeCheckpointHandler()), + 'iter':Checkpointer(pygrain.PyGrainCheckpointHandler())} + else: + checkpointer = AsyncCheckpointer(checkpoint.PyTreeCheckpointHandler()) else: - checkpointer = Checkpointer(checkpoint.PyTreeCheckpointHandler()) + if dataset_type == "array_record": + checkpointer = {'state':Checkpointer(checkpoint.PyTreeCheckpointHandler()), + 'iter':Checkpointer(pygrain.PyGrainCheckpointHandler())} + else: + checkpointer = Checkpointer(checkpoint.PyTreeCheckpointHandler()) mngr = CheckpointManager( p, @@ -86,6 +96,8 @@ def load_state_if_possible(checkpoint_manager: CheckpointManager, load_from_other_directory: str, load_from_other_directory_step: int, abstract_unboxed_pre_state: train_state.TrainState, + dataset_type, + iterator, mesh, state_mesh_annotations): """Loads TrainState as possible from the inputs. 
@@ -118,22 +130,38 @@ def map_to_pspec(data, pspec): else: return type_handlers.RestoreArgs() - restore_args = jax.tree_util.tree_map(map_to_pspec, - abstract_unboxed_pre_state, - state_mesh_annotations) + if dataset_type=="array_record": + restore_state = jax.tree_util.tree_map(map_to_pspec, + abstract_unboxed_pre_state, + state_mesh_annotations) + restore_args = {'state':restore_state, 'iter':iterator} + else: + restore_args = jax.tree_util.tree_map(map_to_pspec, + abstract_unboxed_pre_state, + state_mesh_annotations) + latest_step = checkpoint_manager.latest_step() if latest_step is not None: max_logging.log(f"restoring state from this run's directory latest step \ {latest_step}") - return checkpoint_manager.restore(latest_step, abstract_unboxed_pre_state, + if dataset_type=="array_record": + return checkpoint_manager.restore(latest_step, {'state':abstract_unboxed_pre_state,'iter':iterator}, + {"restore_args" : restore_args}), None + else: + return checkpoint_manager.restore(latest_step, abstract_unboxed_pre_state, {"restore_args" : restore_args}), None elif first_checkpoint_path != "": max_logging.log(f"restoring state from first_checkpoint_path {first_checkpoint_path}") p = epath.Path(first_checkpoint_path) checkpointer = Checkpointer(checkpoint.PyTreeCheckpointHandler()) - return None, checkpointer.restore(p, - item=abstract_unboxed_pre_state, + if dataset_type=="array_record": + return None, checkpointer.restore(p, + item={'state':abstract_unboxed_pre_state,'iter':iterator}, restore_args=restore_args).params + else: + return None, checkpointer.restore(p, + item=abstract_unboxed_pre_state, + restore_args=restore_args).params elif load_from_other_directory != "": p = epath.Path(load_from_other_directory) checkpointer_loader = Checkpointer(checkpoint.PyTreeCheckpointHandler()) @@ -144,8 +172,12 @@ def map_to_pspec(data, pspec): else: step = load_from_other_directory_step max_logging.log(f"restoring state from {load_from_other_directory} step {step}") - return mngr_loader.restore(step, abstract_unboxed_pre_state, + if dataset_type=="array_record": + return mngr_loader.restore(step, {'state':abstract_unboxed_pre_state,'iter':iterator}, {"restore_args" : restore_args}), None + else: + return mngr_loader.restore(step, abstract_unboxed_pre_state, + {"restore_args" : restore_args}), None else: max_logging.log("No existing checkpoints found, not restoring checkpoint.") return None, None diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 36484e0db..31c01554a 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -61,7 +61,7 @@ global_parameter_scale: 1 base_emb_dim: 2560 base_num_heads: 8 base_mlp_dim: 8192 -base_num_decoder_layers: 16 +base_num_decoder_layers: 2 head_dim: 256 # activation functions are . 
mlp_activations: ["relu"] @@ -114,13 +114,18 @@ dataset_path: "" vocab_size: 32_768 # powers of 2 for sharding assets_path: "assets" vocab_relative_path: "tokenizer" # Assumes we're allowed -dataset_name: 'c4/en:3.0.1' -eval_dataset_name: 'c4/en:3.0.1' -eval_split: 'validation' +# dataset_name: 'c4/en:3.0.1' +# eval_dataset_name: 'c4/en:3.0.1' +# dataset_name: 'array-record/c4/en/3.0.1' +# eval_dataset_name: 'array-record/c4/en/3.0.1' +dataset_name: 'lm1b/1.1.0' +eval_dataset_name: 'lm1b/1.1.0' +eval_split: 'test' per_device_batch_size: 12.0 eval_per_device_batch_size: 0 max_corpus_chars: 10_000_000 -dataset_type: c4 # must be c4 or synthetic +dataset_type: array_record # must be c4, array_record or synthetic +# dataset_type: c4 # Training loop steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps @@ -140,9 +145,9 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. # Maximum length cutoff for training examples. -max_target_length: 2048 +max_target_length: 1024 # Maximum length cutoff for held-out evaluation examples. -max_eval_target_length: 512 +max_eval_target_length: 1024 # Maximum length cutoff for predicted tokens. max_predict_length: 64 diff --git a/MaxText/decode.py b/MaxText/decode.py index 4f55dc3e2..66937c3f8 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -139,9 +139,9 @@ def decode_loop(config, state=None): ) # TODO: we need an optax.GradientTransformation to form a TrainState, but we don't use it when decoding - _, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh) + data_iter, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh) - state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager) + state, state_mesh_annotations = max_utils.setup_initial_state(model, data_iter, tx, config, rng, mesh, checkpoint_manager) state_mesh_shardings = jax.tree_map( lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) diff --git a/MaxText/input_pipeline.py b/MaxText/input_pipeline.py index 67890a0c2..dd5c71153 100644 --- a/MaxText/input_pipeline.py +++ b/MaxText/input_pipeline.py @@ -17,18 +17,25 @@ """Input pipeline for a LM1B dataset.""" import os +import re from typing import Optional import functools +# from array_record.python import array_record_data_source import ml_collections import tensorflow as tf import tensorflow_datasets as tfds +import grain.python as pygrain import jax from jax.sharding import PartitionSpec as P import tokenizer import multihost_dataloading import sequence_packing +import pygrain_operations +from transformers import T5Tokenizer +from sentencepiece import SentencePieceProcessor +import pygrain_tokenizer AUTOTUNE = tf.data.experimental.AUTOTUNE @@ -163,6 +170,65 @@ def filter_fn(x): # Return multi-host jax.Array prep iterator return multihost_gen +def preprocessing_pipeline_pygrain( + dataset, + operations, + batch_size: int, + global_mesh, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + data_sharding = None, + data_shuffle_seed = 0, +): + operations.append(pygrain.FilterOperation(condition_function = pygrain_operations.length_filter(max_length))) + + # Pack and Batch examples. 
+  if pack_examples:
+    operations.append(pygrain.experimental.PackAndBatchOperation(
+                        batch_size=batch_size // jax.process_count(),
+                        length_struct={'inputs':max_length,'targets':max_length}))
+    operations.append(pygrain.MapOperation(map_function=pygrain_operations.CombineKeys()))
+  else:
+    # operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
+    operations.append(pygrain.BatchOperation(batch_size=batch_size // jax.process_count(), drop_remainder=drop_remainder))
+
+  # Shift inputs for teacher-forced training
+  if shift:
+    operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=0,segmented=pack_examples)))
+
+  index_sampler = pygrain.IndexSampler(
+    num_records=len(dataset),
+    num_epochs = num_epochs,
+    shard_options=pygrain.ShardOptions(
+      shard_index = jax.process_index(), shard_count = jax.process_count(), drop_remainder = True
+    ),
+    shuffle = shuffle,
+    seed = data_shuffle_seed
+  )
+
+  dataloader = pygrain.DataLoader(
+    data_source = dataset,
+    operations = operations,
+    sampler = index_sampler,
+    worker_count=0,
+  )
+  data_iter = iter(dataloader)
+  global_shape = (batch_size, max_length)
+  # Option to return the raw PyGrain iterator directly:
+  # return data_iter
+  multihost_gen = (
+      multihost_dataloading.get_next_batch_sharded_pygrain(
+        data_iter, data_sharding, global_shape, global_mesh
+      )
+  )
+  # Return multi-host jax.Array prep iterator
+  return multihost_gen
+
 
 def get_datasets(
   config: ml_collections.ConfigDict,
@@ -193,6 +259,21 @@ def get_datasets(
 
   return train_ds, eval_ds
 
+def get_datasets_pygrain(
+  config: ml_collections.ConfigDict,
+  read_config = None,
+):
+  """Load array_record files from the dataset directory as PyGrain data sources."""
+  data_dir = os.path.join(config.dataset_path, config.dataset_name)
+  train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)]
+  train_ds = pygrain.ArrayRecordDataSource(train_files)
+  if config.eval_dataset_name:
+    eval_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(rf'.*{config.eval_split}.*', f)]
+    eval_ds = pygrain.ArrayRecordDataSource(eval_files)
+  else:
+    eval_ds = train_ds
+
+  return train_ds, eval_ds
+
 def preprocess_dataset(config: ml_collections.ConfigDict,
                        global_mesh,
                        train_ds, eval_ds,
@@ -202,10 +283,11 @@ def preprocess_dataset(config: ml_collections.ConfigDict,
   if vocab_path is None:
     vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')
 
+  # Load tokenizer
   sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
                                           vocab_size=config.vocab_size)
-
+  # sp_tokenizer = T5Tokenizer.from_pretrained('t5-base')
   # Tokenize data.
train_ds = train_ds.map( tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) @@ -262,6 +344,76 @@ def filter_keys(record): return train_iter, eval_iter, predict_iter, sp_tokenizer +def preprocess_dataset_pygrain(config: ml_collections.ConfigDict, + global_mesh, + train_ds, eval_ds, + vocab_path: Optional[str] = None, + data_shuffle_seed = 0,): + """PyGrain: Pre-process the dataset and return iterators""" + if vocab_path is None: + vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model') + + # Load tokenizer + # sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path, + # vocab_size=config.vocab_size) + # sp_tokenizer = T5Tokenizer.from_pretrained('t5-base', model_max_length=1024) + sp_tokenizer = SentencePieceProcessor(vocab_path) + + operations = [pygrain.MapOperation(map_function=pygrain_operations.normalize_features())] + #operations.append(pygrain.MapOperation(map_function=pygrain_operations.TokenizeOperation(sp_tokenizer))) + operations.append(pygrain_tokenizer.TokenizeAndPad(["inputs","targets"], config.max_target_length, vocab_path)) + + # Set global batch size. + global_batch_size_to_load = config.global_batch_size_to_load + + if config.eval_per_device_batch_size > 0: + eval_batch_size = config.eval_per_device_batch_size * global_mesh.size + else: + eval_batch_size = global_batch_size_to_load + + def filter_keys(record): + return {'inputs': record['inputs'], 'targets': record['targets']} + operations.append(pygrain.MapOperation(map_function=filter_keys)) + + train_iter = preprocessing_pipeline_pygrain( + train_ds, + operations, + global_batch_size_to_load, + global_mesh, + shuffle=config.enable_data_shuffling, + num_epochs=1, + pack_examples=False, + max_length=config.max_target_length, + shift=False, + data_sharding=config.data_sharding, + data_shuffle_seed = data_shuffle_seed,) + + eval_iter = preprocessing_pipeline_pygrain( + eval_ds, + operations, + eval_batch_size, + global_mesh, + shuffle=config.enable_data_shuffling, + pack_examples=False, + max_length=config.max_eval_target_length, + shift=False, + data_sharding=config.data_sharding, + data_shuffle_seed = data_shuffle_seed,) + + predict_iter = preprocessing_pipeline_pygrain( + eval_ds, + operations, + eval_batch_size, + global_mesh, + shuffle=config.enable_data_shuffling, + pack_examples=False, + max_length=config.max_eval_target_length, + shift=False, + data_sharding=config.data_sharding, + data_shuffle_seed = data_shuffle_seed,) + + return train_iter, eval_iter, predict_iter, sp_tokenizer + def make_c4_train_iterator_and_tokenizer(config, mesh): """ Make train iterator and tokenizer for C4 dataset""" @@ -281,6 +433,24 @@ def make_c4_train_iterator_and_tokenizer(config, mesh): ) return train_iter, sp_tokenizer +def make_pygrain_train_iterator_and_tokenizer(config, mesh): + """ Make train iterator and tokenizer for C4 dataset""" + read_config = tfds.ReadConfig( + shuffle_seed = config.data_shuffle_seed, + ) + train_ds, eval_ds = get_datasets_pygrain( + config=config, + read_config = read_config, + ) + train_iter, _, _, sp_tokenizer = preprocess_dataset_pygrain( + config, + mesh, + train_ds, eval_ds, + vocab_path=os.path.join(config.assets_path, config.vocab_relative_path), + data_shuffle_seed = config.data_shuffle_seed, + ) + return train_iter, sp_tokenizer + class SyntheticDataIterator(): """Creates a synthetic data iterator for performance testing work""" def __init__(self, config, mesh): @@ -319,5 +489,7 @@ def create_data_iterator_with_tokenizer(config, mesh): return SyntheticDataIterator(config, 
mesh), None
   elif config.dataset_type == "c4":
     return make_c4_train_iterator_and_tokenizer(config, mesh)
+  elif config.dataset_type == "array_record":
+    return make_pygrain_train_iterator_and_tokenizer(config, mesh)
   else:
     assert False, "dataset type not implemented"
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index 57b595de3..995fdcca4 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -180,7 +180,7 @@ def init_train_state(model, tx, config, key):
   return state
 
 
-def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
+def setup_initial_state(model, iterator, tx, config, rng, mesh, checkpoint_manager):
   """ We initialize the model and optimizer state, and optionally load from a
   checkpoint as necessary.
 
@@ -205,16 +205,29 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
   # Initialization
   with nn_partitioning.axis_rules(config.logical_axis_rules):
     state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
-    state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
+    restore, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
                                                config.load_parameters_path,
                                                config.load_from_other_directory,
                                                config.load_from_other_directory_step,
                                                unboxed_abstract_state,
+                                               config.dataset_type,
+                                               iterator,
                                                mesh,
                                                state_mesh_annotations)
 
     state_mesh_shardings = jax.tree_map(
         lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+
+    # For array_record the restored object is a dict holding both the train
+    # state and the data iterator; otherwise it is the train state itself.
+    state = None
+    if restore is not None:
+      if config.dataset_type == "array_record":
+        state = restore['state']
+        if restore['iter']:
+          iterator = restore['iter']
+      else:
+        state = restore
+
     if not state:
       state = jax.jit(
           init_train_state_partial,
@@ -226,7 +239,8 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
       raw_params = None
 
   state = unbox_logicallypartioned_trainstate(state)
-  return state, state_mesh_annotations
+  return state, state_mesh_annotations, iterator
+
 
 # Learning Rate Schedule
diff --git a/MaxText/multihost_dataloading.py b/MaxText/multihost_dataloading.py
index 8bc6c2c51..65535426c 100644
--- a/MaxText/multihost_dataloading.py
+++ b/MaxText/multihost_dataloading.py
@@ -43,12 +43,15 @@
 DATA_DIM = 0  # assume data dimension is the first
 
 
-def check_inputs(dataset, global_data_shape, data_axes):
+def check_inputs(dataset_type, dataset, global_data_shape, data_axes):
   # pylint: disable=missing-function-docstring
   # dataset_structure = jax.tree_util.tree_structure(iter(dataset).next())
-  dataset_structure = jax.tree_util.tree_structure(
-      tf.data.experimental.get_structure(dataset)
-  )
+  if dataset_type == "array_record":
+    dataset_structure = jax.tree_util.tree_structure(dataset)
+  else:
+    dataset_structure = jax.tree_util.tree_structure(
+        tf.data.experimental.get_structure(dataset)
+    )
   global_data_shape_structure = jax.tree_util.tree_structure(global_data_shape)
   data_axes_structure = jax.tree_util.tree_structure(data_axes)
   try:
@@ -92,7 +95,7 @@ def get_batch_sharded_data_pipeline(
   Returns:
     sharded_dataset: per_host dataset
   """
-  _ = check_inputs(dataset, global_data_shape, data_axes)
+  _ = check_inputs("c4", dataset, global_data_shape, data_axes)
 
   dataset = iter(dataset.as_numpy_iterator())
 
@@ -148,6 +151,9 @@ def form_gda(local_data, shape):
     device_buffers = _put_to_devices(local_data)
     # Wrap device buffers as GDA
     shape = tuple(shape)
+    print("####################### Debug")
+    print(f"shape: {shape}; global_mesh: {global_mesh}; ")
+    print(f"input_sharding_constraint: {input_sharding_constraint};")
     input_gda =
jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(global_mesh, input_sharding_constraint), device_buffers) return input_gda @@ -156,3 +162,69 @@ def form_gda(local_data, shape): return input_gdas +def get_next_batch_sharded_pygrain(data_iter, + data_sharding, + global_shape: Pytree, + global_mesh: Mesh) -> jax.Array: + """Splits the host loaded data equally over all devices.""" + + SLEEP_TIME = 10 + MAX_DATA_LOAD_ATTEMPTS = 30 + data_load_attempts = 0 + loaded_data_success = False + while not loaded_data_success and data_load_attempts < MAX_DATA_LOAD_ATTEMPTS: + data_load_attempts += 1 + try: + local_data = next(data_iter) + loaded_data_success = True + except tf.errors.FailedPreconditionError: + max_logging.log("Failed to get next data batch, retrying") + time.sleep(SLEEP_TIME) + # Try one last time, if this fails we will see the full stack trace. + if not loaded_data_success: + local_data = next(data_iter) + + global_data_shape = jax.tree_map( + lambda x: PartitionSpec(*global_shape), local_data + ) + data_axes = jax.tree_map(lambda x: PartitionSpec(*data_sharding), local_data) + _ = check_inputs("array_record", local_data, global_data_shape, data_axes) + + # local_devices = jax.local_devices() + local_devices = global_mesh.local_devices + local_device_count = jax.local_device_count() + print(f"local_device: {local_devices}") + print(f"local_device_count: {local_device_count}") + + def _put_to_devices(x): + try: + per_device_arrays = np.split(x, local_device_count, axis=0) + except ValueError as array_split_error: + raise ValueError( + f'Unable to put to devices shape {x.shape} with ' + f'local device count {local_device_count}') from array_split_error + device_buffers = [ + jax.device_put(arr, d) + for arr, d in zip(per_device_arrays, local_devices) + ] + return device_buffers + # 'fully shard' the data (first) axis across both axes + # of the hardware mesh. This is layout matches the + # manual device placing we just did. 
+ input_sharding_constraint = PartitionSpec(*data_sharding, None) + + def form_gda(local_data, shape): + device_buffers = _put_to_devices(local_data) + # Wrap device buffers as GDA + shape = tuple(shape) + print("####################### Debug") + print(f"shape: {shape}; global_mesh: {global_mesh}; ") + print(f"input_sharding_constraint: {input_sharding_constraint};") + print(f"jax.sharding.NamedSharding: {jax.sharding.NamedSharding(global_mesh, input_sharding_constraint)};") + input_gda = jax.make_array_from_single_device_arrays(shape, + jax.sharding.NamedSharding(global_mesh, input_sharding_constraint), device_buffers) + return input_gda + + input_gdas = jax.tree_map(form_gda, local_data, global_data_shape) + + return input_gdas \ No newline at end of file diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 737c46615..d765a5c32 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -92,7 +92,7 @@ def user_init(raw_keys): base_output_directory = raw_keys["base_output_directory"] validate_gcs_bucket_name(base_output_directory, "base_output_directory") dataset_path = raw_keys["dataset_path"] - validate_gcs_bucket_name(dataset_path, "dataset_path") + # validate_gcs_bucket_name(dataset_path, "dataset_path") assert ((raw_keys["load_parameters_path"]=="" and raw_keys["load_from_other_directory"]=="") or raw_keys["enable_checkpointing"]), "You must set enable_checkpointing to load a checkpoint" assert raw_keys["load_parameters_path"]=="" or raw_keys["load_from_other_directory"]=="" \ diff --git a/MaxText/pygrain_operations.py b/MaxText/pygrain_operations.py new file mode 100644 index 000000000..aae2f86e8 --- /dev/null +++ b/MaxText/pygrain_operations.py @@ -0,0 +1,105 @@ +from typing import Dict +import grain.python as pygrain +import numpy as np +import tensorflow as tf +Features = Dict[str, tf.Tensor] + +class normalize_features(): + """Normalize text feature keys. 
+ """ + def __call__(self, features): + def _normalize_features(features): + # features['inputs'] = features.pop('text') + # features['targets'] = features['inputs'] + return {'inputs':features, 'targets': features} + return _normalize_features(features) + +class TokenizeOperation(): + """ TokenizeOp + """ + def __init__(self, sp_tokenizer): + self.sp_tokenizer = sp_tokenizer + + def __call__(self, features: Features): + data_keys = ('inputs', 'targets') + for k in data_keys: + features[k] = np.asarray(self.sp_tokenizer.tokenize(str(features[k]))) + # features[k] = self.sp_tokenizer.tokenize(str(features[k])) + # import pdb;pdb.set_trace() + return features + +class length_filter(): + """pygrain max length filter + """ + def __init__(self,max_length): + self.max_len = max_length + def __call__(self, x): + source, target = x['inputs'], x['targets'] + l = np.maximum(np.shape(source)[0], np.shape(target)[0]) + return np.less(l, self.max_len + 1) + +class PadToMaxLength(): + """Pads each input to the specified length + """ + def __init__(self, feature_lengths): + self.feature_lengths = feature_lengths + + def __call__(self, data): + def pad(x, max_length): + pad_amount = max(max_length - x.shape[0], 0) + pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1) + return np.pad(x, pad_amount) + data['inputs_segmentation'] = np.ones(data['inputs'].shape) + data['inputs_position'] = np.ones(data['inputs'].shape, dtype = np.int32) + for key, _ in data.items(): + data[key] = pad(data[key], self.feature_lengths) + return data + +class CombineKeys(): + """ Combine tuples of sequence packing output in different keys + """ + def __call__(self, data): + combined_data = data[0] + segments = data[1] + segments['inputs_segmentation'] = segments.pop('inputs') + segments['targets_segmentation'] = segments.pop('targets') + positions = data[2] + positions['inputs_position'] = positions.pop('inputs') + positions['targets_position'] = positions.pop('targets') + combined_data.update(segments) + combined_data.update(positions) + return combined_data + +def shift_right_tf(x, axis=1): + """Shift the input to the right by padding and slicing on axis.""" + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[axis] = (1, 0) + slices = [slice(None),] * len(x.shape) + slices[axis] = slice(0, -1) + padded = tf.pad( + x, + tf.constant(pad_widths), + mode='constant', + constant_values=tf.constant(0, x.dtype)) + return padded[tuple(slices)] + +def shift_inputs_tf(x, segment_ids=None, axis=1): + """Shift inputs and replace EOS by 0 for packed inputs.""" + shifted = shift_right_tf(x, axis=axis) + # For packed targets, the first shifted token of a new sequence is made + # 0, rather than being the EOS token for the last sequence. 
+  if segment_ids is not None:
+    shifted *= tf.cast(
+        segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype
+    )
+  return shifted
+
+class ShiftData():
+  def __init__(self, axis = 0, segmented=True):
+    self.axis = axis
+    self.segmented = segmented
+
+  def __call__(self, x):
+    segment_ids = x['inputs_segmentation'] if self.segmented else None
+    x['inputs'] = shift_inputs_tf(x['inputs'], segment_ids=segment_ids, axis=self.axis)
+    return x
\ No newline at end of file
diff --git a/MaxText/pygrain_tokenizer.py b/MaxText/pygrain_tokenizer.py
new file mode 100644
index 000000000..22d486229
--- /dev/null
+++ b/MaxText/pygrain_tokenizer.py
@@ -0,0 +1,164 @@
+import abc
+from collections.abc import Mapping, Sequence
+import copy
+import dataclasses
+import math
+import threading
+from typing import Any
+import numpy as np
+from sentencepiece import SentencePieceProcessor
+import grain.python as grain
+
+class AbstractTokenizeAndSplit(grain.MapTransform):
+  """Tokenize and split text features.
+
+  The split of the tokenized features will replace the text features.
+
+  This transform makes 2 assumptions:
+  - Records are flat dictionaries with 1 or more text features.
+  - It follows a DataSourceWithSplitInfo which should produce elements as:
+    (example, (split_index, expected_split_count))
+
+  The transform produces None if the actual example doesn't have the
+  corresponding split.
+  """
+
+  def __init__(
+      self,
+      feature_names: str | Sequence[str],
+      sequence_length: int | Sequence[int],
+  ):
+    """Creates a new TokenizeAndSplit transform.
+
+    Args:
+      feature_names: One or multiple feature names that contain text.
+      sequence_length: One or multiple sequence lengths to use for the text
+        features. Output features will have [0, sequence_length] tokens.
+    """
+    super().__init__()
+    if isinstance(feature_names, str):
+      feature_names = [feature_names]
+    if isinstance(sequence_length, int):
+      sequence_length = [sequence_length] * len(feature_names)
+    elif len(sequence_length) != len(feature_names):
+      raise ValueError(
+          f"Number of features and sequence lengths mismatch: {feature_names=},"
+          f" {sequence_length=}"
+      )
+    self._feature_names = feature_names
+    self._sequence_length = sequence_length
+    self._stats = {
+        "empty_splits": 0,
+        "discarded_splits": 0,
+    }
+
+  def map(
+      self, features: tuple[dict[str, Any], tuple[int, int]]
+  ) -> dict[str, Any] | None:
+    features, (split_index, expected_split_count) = features
+    actual_split_count = 0
+    for feature_name, sequence_length in zip(
+        self._feature_names, self._sequence_length, strict=True
+    ):
+      text = features[feature_name]
+      token_ids = self._tokenize(text)
+      start = split_index * sequence_length
+      end = (split_index + 1) * sequence_length
+      if start >= len(token_ids):
+        self._stats["empty_splits"] += 1
+        return None
+      actual_split_count = max(
+          actual_split_count, int(math.ceil(len(token_ids) / sequence_length))
+      )
+      features[feature_name] = np.asarray(token_ids[start:end])
+    if split_index == 0 and actual_split_count > expected_split_count:
+      self._stats["discarded_splits"] += (
+          actual_split_count - expected_split_count
+      )
+    return features
+
+  def get_stats(self) -> Mapping[str, int]:
+    return copy.copy(self._stats)
+
+  @abc.abstractmethod
+  def _tokenize(self, text: str) -> Sequence[int]:
+    """Tokenizes the text."""
+
+
+class SentencePieceTokenizeAndSplit(AbstractTokenizeAndSplit):
+  """Tokenize and split text features using a SentencePiece tokenizer."""
+
+  def __init__(
+      self,
+      feature_names: str | Sequence[str],
+      sequence_length: int | Sequence[int],
+      sentencepiece_model_path: str,
+  ):
+    super().__init__(feature_names, sequence_length)
+    self._sentencepiece_model_path = sentencepiece_model_path
+    self._initialize_processor_lock = threading.Lock()
+    self._tokenizer = None
+
+  def _tokenize(self, text: str) -> Sequence[int]:
+    if self._tokenizer is None:
+      with self._initialize_processor_lock:
+        if self._tokenizer is None:
+          self._tokenizer = SentencePieceProcessor()
+          self._tokenizer.Load(self._sentencepiece_model_path)
+    return self._tokenizer.EncodeAsIds(text)
+
+  def __getstate__(self):
+    state = self.__dict__.copy()
+    del state["_tokenizer"]
+    del state["_initialize_processor_lock"]
+    return state
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+    self._tokenizer = None
+    self._initialize_processor_lock = threading.Lock()
+
+
+@dataclasses.dataclass
+class TokenizeAndPad(grain.MapTransform):
+  """Tokenize, truncate and pad features to sequence length."""
+
+  feature_names: str | Sequence[str]
+  sequence_length: int | Sequence[int]
+  model_path: str
+
+  def __post_init__(self):
+    self._processor = None
+    self._initialize_processor_lock = threading.Lock()
+    if isinstance(self.feature_names, str):
+      self.feature_names = [self.feature_names]
+    if isinstance(self.sequence_length, int):
+      self.sequence_length = [self.sequence_length] * len(self.feature_names)
+
+  def map(self, features: dict[str, Any]) -> dict[str, Any]:
+    if self._processor is None:
+      with self._initialize_processor_lock:
+        if self._processor is None:  # Ensures only one thread initializes SPP.
+          self._processor = SentencePieceProcessor(self.model_path)
+          # self._processor.Load(filename=self.model_path)
+    for feature_name, sequence_length in zip(
+        self.feature_names, self.sequence_length, strict=True
+    ):
+      text = features[feature_name]
+      token_ids = self._processor.EncodeAsIds(text)
+      token_ids = token_ids[:sequence_length]
+      token_ids = token_ids + [self._processor.pad_id()] * (
+          sequence_length - len(token_ids)
+      )
+      features[feature_name] = token_ids
+    return features
+
+  def __getstate__(self):
+    state = self.__dict__.copy()
+    del state["_processor"]
+    del state["_initialize_processor_lock"]
+    return state
+
+  def __setstate__(self, state):
+    self.__dict__.update(state)
+    self._processor = None
+    self._initialize_processor_lock = threading.Lock()
diff --git a/MaxText/train.py b/MaxText/train.py
index 6b37faed8..69d8b2deb 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -217,6 +217,7 @@ def train_loop(config, state=None):
       config.enable_checkpointing,
       config.async_checkpointing,
       config.save_period,
+      config.dataset_type,
   )
   # Initial PRNG Keys
   init_rng, nextrng = random.split(random.PRNGKey(config.init_weights_seed), 2)
@@ -242,7 +243,12 @@ def train_loop(config, state=None):
 
   data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh)
 
-  state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, init_rng, mesh, checkpoint_manager)
+  # if config.dataset_type == "array_record":
+  #   initial_iterator = data_iterator
+  # else:
+  #   initial_iterator = None
+
+  state, state_mesh_annotations, data_iterator = max_utils.setup_initial_state(model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager)
   data_pspec = P(*config.data_sharding)
 
   num_model_parameters = calculate_num_params_from_pytree(state.params)
@@ -280,7 +286,9 @@ def train_loop(config, state=None):
     last_step_completion = new_time
 
     if checkpoint_manager is not None:
-      if checkpoint_manager.save(step, state):
+      if config.dataset_type == "array_record" and checkpoint_manager.save(step, {'state':state,'iter':data_iterator}):
+        max_logging.log(f"saved a checkpoint (containing state and iter) at step {step}")
+      elif checkpoint_manager.save(step, state):
         max_logging.log(f"saved a checkpoint at step {step}")
       # Upon preemption, exit when and only when all ongoing saves are complete.
       if checkpoint_manager.reached_preemption(step):