add grain to 1028 head
aireenmei committed Nov 2, 2023
1 parent 1a526cc commit 4ebff99
Showing 10 changed files with 603 additions and 31 deletions.
52 changes: 42 additions & 10 deletions MaxText/checkpointing.py
@@ -24,6 +24,7 @@
from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions, Checkpointer, AsyncCheckpointer
from orbax.checkpoint import type_handlers
import socket
import grain.python as pygrain

import max_logging

@@ -55,7 +56,8 @@ def create_orbax_checkpoint_manager(
checkpoint_dir: str,
enable_checkpointing: bool,
use_async: bool,
save_interval_steps: int
save_interval_steps: int,
dataset_type: str,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
@@ -65,9 +67,17 @@
p = epath.Path(checkpoint_dir)
if use_async:
_multislice_distribute_initialize()
checkpointer = AsyncCheckpointer(checkpoint.PyTreeCheckpointHandler())
if dataset_type == "array_record":
checkpointer = {'state':AsyncCheckpointer(checkpoint.PyTreeCheckpointHandler()),
'iter':Checkpointer(pygrain.PyGrainCheckpointHandler())}
else:
checkpointer = AsyncCheckpointer(checkpoint.PyTreeCheckpointHandler())
else:
checkpointer = Checkpointer(checkpoint.PyTreeCheckpointHandler())
if dataset_type == "array_record":
checkpointer = {'state':Checkpointer(checkpoint.PyTreeCheckpointHandler()),
'iter':Checkpointer(pygrain.PyGrainCheckpointHandler())}
else:
checkpointer = Checkpointer(checkpoint.PyTreeCheckpointHandler())

mngr = CheckpointManager(
p,
@@ -86,6 +96,8 @@ def load_state_if_possible(checkpoint_manager: CheckpointManager,
load_from_other_directory: str,
load_from_other_directory_step: int,
abstract_unboxed_pre_state: train_state.TrainState,
dataset_type,
iterator,
mesh,
state_mesh_annotations):
"""Loads TrainState if possible from the inputs.
@@ -118,22 +130,38 @@ def map_to_pspec(data, pspec):
else:
return type_handlers.RestoreArgs()

restore_args = jax.tree_util.tree_map(map_to_pspec,
abstract_unboxed_pre_state,
state_mesh_annotations)
if dataset_type=="array_record":
restore_state = jax.tree_util.tree_map(map_to_pspec,
abstract_unboxed_pre_state,
state_mesh_annotations)
restore_args = {'state':restore_state, 'iter':iterator}
else:
restore_args = jax.tree_util.tree_map(map_to_pspec,
abstract_unboxed_pre_state,
state_mesh_annotations)

latest_step = checkpoint_manager.latest_step()
if latest_step is not None:
max_logging.log(f"restoring state from this run's directory latest step \
{latest_step}")
return checkpoint_manager.restore(latest_step, abstract_unboxed_pre_state,
if dataset_type=="array_record":
return checkpoint_manager.restore(latest_step, {'state':abstract_unboxed_pre_state,'iter':iterator},
{"restore_args" : restore_args}), None
else:
return checkpoint_manager.restore(latest_step, abstract_unboxed_pre_state,
{"restore_args" : restore_args}), None
elif first_checkpoint_path != "":
max_logging.log(f"restoring state from first_checkpoint_path {first_checkpoint_path}")
p = epath.Path(first_checkpoint_path)
checkpointer = Checkpointer(checkpoint.PyTreeCheckpointHandler())
return None, checkpointer.restore(p,
item=abstract_unboxed_pre_state,
if dataset_type=="array_record":
return None, checkpointer.restore(p,
item={'state':abstract_unboxed_pre_state,'iter':iterator},
restore_args=restore_args).params
else:
return None, checkpointer.restore(p,
item=abstract_unboxed_pre_state,
restore_args=restore_args).params
elif load_from_other_directory != "":
p = epath.Path(load_from_other_directory)
checkpointer_loader = Checkpointer(checkpoint.PyTreeCheckpointHandler())
@@ -144,8 +172,12 @@ def map_to_pspec(data, pspec):
else:
step = load_from_other_directory_step
max_logging.log(f"restoring state from {load_from_other_directory} step {step}")
return mngr_loader.restore(step, abstract_unboxed_pre_state,
if dataset_type=="array_record":
return mngr_loader.restore(step, {'state':abstract_unboxed_pre_state,'iter':iterator},
{"restore_args" : restore_args}), None
else:
return mngr_loader.restore(step, abstract_unboxed_pre_state,
{"restore_args" : restore_args}), None
else:
max_logging.log("No existing checkpoints found, not restoring checkpoint.")
return None, None
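
When dataset_type is "array_record", the manager is built from a dict of checkpointers so the Grain data iterator is saved and restored under its own 'iter' key next to the model state. The snippet below is a minimal sketch of that wiring outside MaxText, assuming placeholder values for the directory and options; the save call is only a commented-out assumption, since the save path is not part of the rendered diff.

import grain.python as pygrain
from etils import epath
from orbax import checkpoint
from orbax.checkpoint.checkpoint_manager import (
    Checkpointer, CheckpointManager, CheckpointManagerOptions)

# One checkpointer per item: an Orbax PyTree handler for the train state and a
# Grain handler for the data iterator, mirroring the dict layout above.
checkpointers = {
    'state': Checkpointer(checkpoint.PyTreeCheckpointHandler()),
    'iter': Checkpointer(pygrain.PyGrainCheckpointHandler()),
}
mngr = CheckpointManager(
    epath.Path('/tmp/ckpts'),  # placeholder directory
    checkpointers,
    options=CheckpointManagerOptions(create=True, save_interval_steps=1000),
)

# Assumed save-side counterpart of the restore calls above: both items are
# written under the same step, so a restart can resume the input pipeline.
# mngr.save(step, {'state': train_state, 'iter': train_iter})
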
19 changes: 12 additions & 7 deletions MaxText/configs/base.yml
@@ -61,7 +61,7 @@ global_parameter_scale: 1
base_emb_dim: 2560
base_num_heads: 8
base_mlp_dim: 8192
base_num_decoder_layers: 16
base_num_decoder_layers: 2
head_dim: 256
# activation functions for the mlp block.
mlp_activations: ["relu"]
@@ -114,13 +114,18 @@ dataset_path: ""
vocab_size: 32_768 # powers of 2 for sharding
assets_path: "assets"
vocab_relative_path: "tokenizer" # Assumes we're allowed
dataset_name: 'c4/en:3.0.1'
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
# dataset_name: 'c4/en:3.0.1'
# eval_dataset_name: 'c4/en:3.0.1'
# dataset_name: 'array-record/c4/en/3.0.1'
# eval_dataset_name: 'array-record/c4/en/3.0.1'
dataset_name: 'lm1b/1.1.0'
eval_dataset_name: 'lm1b/1.1.0'
eval_split: 'test'
per_device_batch_size: 12.0
eval_per_device_batch_size: 0
max_corpus_chars: 10_000_000
dataset_type: c4 # must be c4 or synthetic
dataset_type: array_record # must be c4, array_record or synthetic
# dataset_type: c4

# Training loop
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
@@ -140,9 +145,9 @@ learning_rate_schedule_steps: -1 # By default the length of the schedule is set
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

# Maximum length cutoff for training examples.
max_target_length: 2048
max_target_length: 1024
# Maximum length cutoff for held-out evaluation examples.
max_eval_target_length: 512
max_eval_target_length: 1024

# Maximum length cutoff for predicted tokens.
max_predict_length: 64
4 changes: 2 additions & 2 deletions MaxText/decode.py
@@ -139,9 +139,9 @@ def decode_loop(config, state=None):
) # TODO: we need an optax.GradientTransformation to form a TrainState, but we don't use it when decoding


_, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh)
data_iter, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh)

state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager)
state, state_mesh_annotations = max_utils.setup_initial_state(model, data_iter, tx, config, rng, mesh, checkpoint_manager)

state_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
174 changes: 173 additions & 1 deletion MaxText/input_pipeline.py
@@ -17,18 +17,25 @@
"""Input pipeline for an LM1B dataset."""

import os
import re
from typing import Optional
import functools

# from array_record.python import array_record_data_source
import ml_collections
import tensorflow as tf
import tensorflow_datasets as tfds
import grain.python as pygrain
import jax
from jax.sharding import PartitionSpec as P

import tokenizer
import multihost_dataloading
import sequence_packing
import pygrain_operations
from transformers import T5Tokenizer
from sentencepiece import SentencePieceProcessor
import pygrain_tokenizer

AUTOTUNE = tf.data.experimental.AUTOTUNE

@@ -163,6 +170,65 @@ def filter_fn(x):
# Return multi-host jax.Array prep iterator
return multihost_gen

def preprocessing_pipeline_pygrain(
dataset,
operations,
batch_size: int,
global_mesh,
shuffle: bool,
num_epochs: Optional[int] = 1,
pack_examples: bool = True,
shuffle_buffer_size: int = 1024,
max_length: int = 512,
shift: bool = True,
drop_remainder: bool = True,
data_sharding = None,
data_shuffle_seed = 0,
):
operations.append(pygrain.FilterOperation(condition_function = pygrain_operations.length_filter(max_length)))

# Pack and Batch examples.
if pack_examples:
operations.append(pygrain.experimental.PackAndBatchOperation(
batch_size=batch_size // jax.process_count(),
length_struct={'inputs':max_length,'targets':max_length}))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.CombineKeys()))
else:
# operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
operations.append(pygrain.BatchOperation(batch_size=batch_size // jax.process_count(), drop_remainder=drop_remainder))

# Shift inputs for teacher-forced training
if shift:
operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=0,segmented=pack_examples)))

index_sampler = pygrain.IndexSampler(
num_records=len(dataset),
num_epochs = num_epochs,
shard_options=pygrain.ShardOptions(
shard_index = jax.process_index(), shard_count = jax.process_count(), drop_remainder = True
),
shuffle = shuffle,
seed = data_shuffle_seed
)

dataloader = pygrain.DataLoader(
data_source = dataset,
operations = operations,
sampler = index_sampler,
worker_count=0,
)
data_iter = iter(dataloader)
global_shape = (batch_size, max_length)
# Return PyGrainIterator
# return data_iter
multihost_gen = (
multihost_dataloading.get_next_batch_sharded_pygrain(
data_iter, data_sharding, global_shape, global_mesh
)
)
# Return multi-host jax.Array prep iterator
return multihost_gen
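
Stripped of the MaxText-specific packing, shifting and multihost wrapping, the Grain pattern this helper composes is roughly the sketch below. The ArrayRecord path, the batch size and the identity-style map are placeholders standing in for the real operations defined in pygrain_operations.

import jax
import numpy as np
import grain.python as pygrain

# Placeholder ArrayRecord shard; real runs point at the files produced for the
# dataset configured in base.yml.
source = pygrain.ArrayRecordDataSource(['/data/c4-train.array_record-00000-of-00001'])

sampler = pygrain.IndexSampler(
    num_records=len(source),
    num_epochs=1,
    shard_options=pygrain.ShardOptions(
        shard_index=jax.process_index(),
        shard_count=jax.process_count(),
        drop_remainder=True),
    shuffle=True,
    seed=0)

operations = [
    # Stand-in for the normalize/tokenize steps: map every raw record to a
    # fixed-length token array so that batching is well defined.
    pygrain.MapOperation(map_function=lambda raw: {'inputs': np.zeros(16, np.int32)}),
    pygrain.BatchOperation(batch_size=8, drop_remainder=True),  # per-host batch size
]

loader = pygrain.DataLoader(
    data_source=source, operations=operations, sampler=sampler, worker_count=0)

for batch in loader:
    # Each element is a per-host batch of stacked features; the helper above
    # then forms a global jax.Array via multihost_dataloading.
    break
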


def get_datasets(
config: ml_collections.ConfigDict,
@@ -193,6 +259,21 @@ def get_datasets(

return train_ds, eval_ds

def get_datasets_pygrain(
config: ml_collections.ConfigDict,
read_config = None,
):
data_dir = os.path.join(config.dataset_path, config.dataset_name)
train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)]
train_ds = pygrain.ArrayRecordDataSource(train_files)
if config.eval_dataset_name:
eval_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(rf'.*{config.eval_split}.*', f)]
eval_ds = pygrain.ArrayRecordDataSource(eval_files)
else:
eval_ds = train_ds

return train_ds, eval_ds

def preprocess_dataset(config: ml_collections.ConfigDict,
global_mesh,
train_ds, eval_ds,
@@ -202,10 +283,11 @@ def preprocess_dataset(config: ml_collections.ConfigDict,
if vocab_path is None:
vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')


# Load tokenizer
sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
vocab_size=config.vocab_size)

# sp_tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Tokenize data.
train_ds = train_ds.map(
tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
@@ -262,6 +344,76 @@ def filter_keys(record):

return train_iter, eval_iter, predict_iter, sp_tokenizer

def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
global_mesh,
train_ds, eval_ds,
vocab_path: Optional[str] = None,
data_shuffle_seed = 0,):
"""PyGrain: Pre-process the dataset and return iterators"""
if vocab_path is None:
vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')

# Load tokenizer
# sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
# vocab_size=config.vocab_size)
# sp_tokenizer = T5Tokenizer.from_pretrained('t5-base', model_max_length=1024)
sp_tokenizer = SentencePieceProcessor(vocab_path)

operations = [pygrain.MapOperation(map_function=pygrain_operations.normalize_features())]
#operations.append(pygrain.MapOperation(map_function=pygrain_operations.TokenizeOperation(sp_tokenizer)))
operations.append(pygrain_tokenizer.TokenizeAndPad(["inputs","targets"], config.max_target_length, vocab_path))

# Set global batch size.
global_batch_size_to_load = config.global_batch_size_to_load

if config.eval_per_device_batch_size > 0:
eval_batch_size = config.eval_per_device_batch_size * global_mesh.size
else:
eval_batch_size = global_batch_size_to_load

def filter_keys(record):
return {'inputs': record['inputs'], 'targets': record['targets']}
operations.append(pygrain.MapOperation(map_function=filter_keys))

train_iter = preprocessing_pipeline_pygrain(
train_ds,
operations,
global_batch_size_to_load,
global_mesh,
shuffle=config.enable_data_shuffling,
num_epochs=1,
pack_examples=False,
max_length=config.max_target_length,
shift=False,
data_sharding=config.data_sharding,
data_shuffle_seed = data_shuffle_seed,)

eval_iter = preprocessing_pipeline_pygrain(
eval_ds,
operations,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
max_length=config.max_eval_target_length,
shift=False,
data_sharding=config.data_sharding,
data_shuffle_seed = data_shuffle_seed,)

predict_iter = preprocessing_pipeline_pygrain(
eval_ds,
operations,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
max_length=config.max_eval_target_length,
shift=False,
data_sharding=config.data_sharding,
data_shuffle_seed = data_shuffle_seed,)

return train_iter, eval_iter, predict_iter, sp_tokenizer
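
preprocess_dataset_pygrain relies on pygrain_tokenizer.TokenizeAndPad, which is added elsewhere in this commit and not rendered above. Purely as a hedged guess at its behavior, a Grain transform of that shape could look like the sketch below: SentencePiece-tokenize the named features, then truncate and pad each to max_length. The class body, the pad id of 0 and the int32 dtype are assumptions, not the committed implementation.

import dataclasses
import numpy as np
import grain.python as pygrain
from sentencepiece import SentencePieceProcessor

@dataclasses.dataclass
class TokenizeAndPadSketch(pygrain.MapTransform):
  """Hypothetical stand-in for pygrain_tokenizer.TokenizeAndPad."""
  feature_names: list
  max_length: int
  model_path: str

  def __post_init__(self):
    # Load the SentencePiece model once per worker.
    self._sp = SentencePieceProcessor(self.model_path)

  def map(self, element):
    for name in self.feature_names:
      ids = self._sp.EncodeAsIds(element[name])[:self.max_length]
      padded = ids + [0] * (self.max_length - len(ids))  # pad id assumed to be 0
      element[name] = np.asarray(padded, dtype=np.int32)
    return element
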


def make_c4_train_iterator_and_tokenizer(config, mesh):
""" Make train iterator and tokenizer for C4 dataset"""
@@ -281,6 +433,24 @@ def make_c4_train_iterator_and_tokenizer(config, mesh):
)
return train_iter, sp_tokenizer

def make_pygrain_train_iterator_and_tokenizer(config, mesh):
""" Make train iterator and tokenizer for an ArrayRecord dataset using Grain"""
read_config = tfds.ReadConfig(
shuffle_seed = config.data_shuffle_seed,
)
train_ds, eval_ds = get_datasets_pygrain(
config=config,
read_config = read_config,
)
train_iter, _, _, sp_tokenizer = preprocess_dataset_pygrain(
config,
mesh,
train_ds, eval_ds,
vocab_path=os.path.join(config.assets_path, config.vocab_relative_path),
data_shuffle_seed = config.data_shuffle_seed,
)
return train_iter, sp_tokenizer

class SyntheticDataIterator():
"""Creates a synthetic data iterator for performance testing work"""
def __init__(self, config, mesh):
Expand Down Expand Up @@ -319,5 +489,7 @@ def create_data_iterator_with_tokenizer(config, mesh):
return SyntheticDataIterator(config, mesh), None
elif config.dataset_type == "c4":
return make_c4_train_iterator_and_tokenizer(config, mesh)
elif config.dataset_type == "array_record":
return make_pygrain_train_iterator_and_tokenizer(config, mesh)
else:
assert False, "dataset type not implemented"
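
For completeness, a hedged sketch of how the new branch is reached from a caller such as decode.py above; how the returned iterator is consumed by the train step is outside the rendered portion of this commit.

# dataset_type in base.yml selects the branch:
#   array_record -> make_pygrain_train_iterator_and_tokenizer (Grain/ArrayRecord)
#   c4           -> make_c4_train_iterator_and_tokenizer (tf.data/TFDS)
#   synthetic    -> SyntheticDataIterator (tokenizer is None)
train_iter, sp_tokenizer = create_data_iterator_with_tokenizer(config, mesh)
# Every branch hands back an (iterator, tokenizer) pair, so callers do not need
# to branch on dataset_type themselves.
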
