fix for convergence test
aireenmei committed Dec 29, 2023
1 parent 5618295 commit 79268e9
Showing 5 changed files with 162 additions and 58 deletions.
5 changes: 3 additions & 2 deletions MaxText/configs/base.yml
@@ -114,8 +114,8 @@ dataset_path: ""
vocab_size: 32_768 # powers of 2 for sharding
assets_path: "assets"
vocab_relative_path: "tokenizer" # Assumes we're allowed
# dataset_name: 'c4/en:3.1.0'
# eval_dataset_name: 'c4/en:3.1.0'
# dataset_name: 'c4/en:3.0.1'
# eval_dataset_name: 'c4/en:3.0.1'
dataset_name: 'array-record/c4/en/3.0.1'
eval_dataset_name: 'array-record/c4/en/3.0.1'
eval_split: 'validation' # for c4 data
@@ -128,6 +128,7 @@ max_corpus_chars: 10_000_000
# dataset_type: c4 # must be c4, array_record or synthetic
dataset_type: array_record
grain_worker_count: 1
pack_examples: True

# Training loop
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
60 changes: 45 additions & 15 deletions MaxText/input_pipeline.py
@@ -91,6 +91,17 @@ def filter_fn(x):
return tf.less(l, max_len + 1)
return filter_fn

def add_annotations(ds):
def _add_annotations(features):
features['inputs_segmentation'] = tf.ones_like(features['inputs'], dtype = tf.int32)
features['inputs_position'] = tf.range(tf.size(features['inputs']), dtype = tf.int32)
features['targets_segmentation'] = tf.ones_like(features['inputs'], dtype = tf.int32)
features['targets_position'] = tf.range(tf.size(features['inputs']), dtype = tf.int32)
return features

return ds.map(
_add_annotations,
num_parallel_calls=AUTOTUNE)
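
For reference, a minimal sketch of what the new unpacked path produces, assuming TensorFlow and a single already-tokenized example (the values below are hypothetical): every token is assigned segment id 1 and a running position, mirroring the fields that `sequence_packing.pack_dataset` would otherwise emit.

```python
import tensorflow as tf

# Hypothetical tokenized example; not taken from the diff.
features = {'inputs': tf.constant([101, 7, 9, 1]),
            'targets': tf.constant([101, 7, 9, 1])}

features['inputs_segmentation'] = tf.ones_like(features['inputs'], dtype=tf.int32)   # [1, 1, 1, 1]
features['inputs_position'] = tf.range(tf.size(features['inputs']), dtype=tf.int32)  # [0, 1, 2, 3]
# targets_segmentation / targets_position are filled the same way.
```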
# -----------------------------------------------------------------------------
# Main dataset preparation.
# -----------------------------------------------------------------------------
@@ -125,11 +136,14 @@ def preprocessing_pipeline(
# Perform greedy sequence packing
if pack_examples:
dataset = sequence_packing.pack_dataset(dataset, max_length)
else:
dataset = add_annotations(dataset)
# dataset.apply(add_annotations)

# Shift inputs for teacher-forced training
if shift:
dataset = dataset.map(
functools.partial(shift_data, axis=0, segmented=pack_examples),
functools.partial(shift_data, axis=0, segmented=1),
num_parallel_calls=tf.data.AUTOTUNE,
deterministic=True)

@@ -152,8 +166,20 @@ def preprocessing_pipeline(
# simple (static-shape) padded batching
dataset = dataset.padded_batch(
batch_size // jax.process_count(),
padded_shapes={'inputs': max_length, 'targets': max_length},
padding_values={'inputs': 0, 'targets': 0},
padded_shapes={
'inputs': max_length,
'targets': max_length,
'inputs_position': max_length,
'targets_position': max_length,
'inputs_segmentation': max_length,
'targets_segmentation': max_length,
},
padding_values={'inputs': 0, 'targets': 0,
'inputs_position': 0,
'targets_position': 0,
'inputs_segmentation': 0,
'targets_segmentation': 0,
},
drop_remainder=drop_remainder)

if prefetch_size:
@@ -188,23 +214,27 @@ def preprocessing_pipeline_pygrain(
operations = []
operations.append(pygrain_operations.ParseFeatures())
operations.append(pygrain_operations.NormalizeFeatures())
operations.append(pygrain_tokenizer.Tokenize(["inputs","targets"], max_length, vocab_path))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.filter_keys))
operations.append(pygrain.FilterOperation(condition_function = pygrain_operations.length_filter(max_length)))
operations.append(pygrain_tokenizer.Tokenize(["inputs","targets"], max_length, vocab_path, 32768))
# operations.append(pygrain.MapOperation(map_function=pygrain_operations.filter_keys))
operations.append(pygrain_operations.LengthFilter(max_length))
#operations.append(pygrain.FilterOperation(condition_function = pygrain_operations.length_filter(max_length)))

# Pack and Batch examples.
if pack_examples:
operations.append(pygrain.experimental.PackAndBatchOperation(
batch_size=batch_size // jax.process_count(),
length_struct={'inputs':max_length,'targets':max_length}))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.CombineKeys()))
# operations.append(pygrain.MapOperation(map_function=pygrain_operations.CombineKeys()))
operations.append(pygrain_operations.ReformatPacking())
else:
operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
operations.append(pygrain.BatchOperation(batch_size=batch_size // jax.process_count(), drop_remainder=drop_remainder))

# Shift inputs for teacher-forced training
if shift:
operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=1,segmented=pack_examples)))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.ShiftData(axis=1)))

# operations.append(pygrain_operations.ConvertToTF())

index_sampler = pygrain.IndexSampler(
num_records=len(dataset),
@@ -261,7 +291,7 @@ def get_datasets_pygrain(
config: ml_collections.ConfigDict,
read_config = None,
):
"""Load dataset from array_record files for using with pygrain"""
"""Load dataset from array_record files for using with pygrain"""
data_dir = os.path.join(config.dataset_path, config.dataset_name)
train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)]
train_ds = pygrain.ArrayRecordDataSource(train_files)
Expand Down Expand Up @@ -313,7 +343,7 @@ def filter_keys(record):
global_mesh,
shuffle=config.enable_data_shuffling,
num_epochs=None,
pack_examples=True,
pack_examples=config.pack_examples,
max_length=config.max_target_length,
shift=True,
data_sharding = config.data_sharding,
@@ -324,7 +354,7 @@ def filter_keys(record):
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
pack_examples=config.pack_examples,
max_length=config.max_eval_target_length,
shift=False,
data_sharding = config.data_sharding,
@@ -335,7 +365,7 @@ def filter_keys(record):
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
pack_examples=config.pack_examples,
max_length=config.max_predict_length,
shift=False,
drop_remainder=False,
@@ -373,7 +403,7 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
global_mesh,
shuffle=config.enable_data_shuffling,
num_epochs=1,
pack_examples=False,
pack_examples=config.pack_examples,
max_length=config.max_target_length,
shift=True,
data_sharding=config.data_sharding,
@@ -386,7 +416,7 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
pack_examples=config.pack_examples,
max_length=config.max_eval_target_length,
shift=True,
data_sharding=config.data_sharding,
@@ -399,7 +429,7 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
pack_examples=config.pack_examples,
max_length=config.max_eval_target_length,
shift=True,
data_sharding=config.data_sharding,
111 changes: 71 additions & 40 deletions MaxText/pygrain_operations.py
@@ -6,16 +6,6 @@
import tensorflow as tf
Features = Dict[str, tf.Tensor]

class normalize_features():
"""Normalize text feature keys.
"""
def __call__(self, features):
def _normalize_features(features):
# features['inputs'] = features.pop('text')
# features['targets'] = features['inputs']
return {'inputs':features, 'targets': features}
return _normalize_features(features)

@dataclasses.dataclass
class ParseFeatures(pygrain.MapTransform):
def map(self, features):
@@ -36,6 +26,47 @@ def map(self, features):
'targets': features['text'].numpy().decode()
}

# @dataclasses.dataclass
# class ConvertToTF(pygrain.MapTransform):
# def map(self, data):
# for key in data:
# data[key] = tf.convert_to_tensor(data[key], dtype=tf.int32)
# return data

@dataclasses.dataclass
class ReformatPacking(pygrain.MapTransform):
def map(self, data):
return{
'inputs':data[0]['inputs'],
'targets':data[0]['targets'],
'inputs_segmentation':data[1]['inputs'],
'targets_segmentation':data[1]['targets'],
'inputs_position':data[2]['inputs'],
'targets_position':data[2]['targets'],
}
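
For context, `pygrain.experimental.PackAndBatchOperation` returns each batch as a tuple of three dicts — packed values, segment ids, and positions — which is also what the commented-out `CombineKeys` helper below consumed. A minimal sketch of `ReformatPacking` on hypothetical data:

```python
import numpy as np

# Hypothetical packing output: (values, segmentations, positions) for one batch.
packed = (
    {'inputs': np.array([[11, 12, 21, 22]]), 'targets': np.array([[11, 12, 21, 22]])},
    {'inputs': np.array([[1, 1, 2, 2]]),     'targets': np.array([[1, 1, 2, 2]])},
    {'inputs': np.array([[0, 1, 0, 1]]),     'targets': np.array([[0, 1, 0, 1]])},
)

flat = ReformatPacking().map(packed)
# flat now holds the six flat keys the rest of the pipeline expects:
# inputs, targets, inputs_segmentation, targets_segmentation, inputs_position, targets_position
```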


@dataclasses.dataclass
class LengthFilter(pygrain.FilterTransform):
def __init__(self, max_length):
self.max_length = max_length
def filter(self, data):
# source, target = data['inputs'], data['targets']
# l = np.maximum(np.shape(source)[0], np.shape(target)[0])
# print(data['inputs'].shape)
return data['inputs'].shape[0] < self.max_length
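
A quick usage sketch of the new filter transform, with a hypothetical over-length example (assumes numpy arrays, as produced after tokenization):

```python
import numpy as np

keep_short = LengthFilter(max_length=2048)
too_long = {'inputs': np.zeros(4096, dtype=np.int32),
            'targets': np.zeros(4096, dtype=np.int32)}
assert not keep_short.filter(too_long)  # examples at or above max_length are dropped
```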


class length_filter():
  """pygrain max length filter"""
def __init__(self,max_length):
self.max_length = max_length
def __call__(self, x):
source, target = x['inputs'], x['targets']
l = np.maximum(np.shape(source)[0], np.shape(target)[0])
return np.less(l, self.max_length + 1)

def filter_keys(record):
return {'inputs': record['inputs'], 'targets': record['targets']}

@@ -74,26 +105,28 @@ def pad(x, max_length):
pad_amount = max(max_length - x.shape[0], 0)
pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
return np.pad(x, pad_amount)
data['inputs_segmentation'] = np.ones(data['inputs'].shape)
data['inputs_position'] = np.ones(data['inputs'].shape, dtype = np.int32)
data['inputs_segmentation'] = np.ones(data['inputs'].shape, dtype = np.int32)
data['inputs_position'] = np.arange(data['inputs'].shape[0], dtype = np.int32)
data['targets_segmentation'] = np.ones(data['targets'].shape, dtype = np.int32)
data['targets_position'] = np.arange(data['targets'].shape[0], dtype = np.int32)
for key, _ in data.items():
data[key] = pad(data[key], self.feature_lengths)
return data

class CombineKeys():
""" Combine tuples of sequence packing output in different keys
"""
def __call__(self, data):
combined_data = data[0]
segments = data[1]
segments['inputs_segmentation'] = segments.pop('inputs')
segments['targets_segmentation'] = segments.pop('targets')
positions = data[2]
positions['inputs_position'] = positions.pop('inputs')
positions['targets_position'] = positions.pop('targets')
combined_data.update(segments)
combined_data.update(positions)
return combined_data
# class CombineKeys():
# """ Combine tuples of sequence packing output in different keys
# """
# def __call__(self, data):
# combined_data = data[0]
# segments = data[1]
# segments['inputs_segmentation'] = segments.pop('inputs')
# segments['targets_segmentation'] = segments.pop('targets')
# positions = data[2]
# positions['inputs_position'] = positions.pop('inputs')
# positions['targets_position'] = positions.pop('targets')
# combined_data.update(segments)
# combined_data.update(positions)
# return combined_data

def shift_right(x, axis=1):
"""Shift the input to the right by padding and slicing on axis."""
@@ -106,27 +139,25 @@ def shift_right(x, axis=1):
pad_widths,
mode='constant',
constant_values=x.dtype.type(0)
# constant_values=tf.constant(0, x.dtype)
)
return padded[tuple(slices)]

def shift_inputs(x, segment_ids=None, axis=1):
def shift_and_refine(x, axis=1):
"""Shift inputs and replace EOS by 0 for packed inputs."""
shifted = shift_right(x, axis=axis)
x['inputs'] = shift_right(x['inputs'], axis=axis)
targets_nonzero = (x['targets'] != 0)
x['inputs_segmentation'] *= targets_nonzero
x['targets_segmentation'] *= targets_nonzero
# For packed targets, the first shifted token of a new sequence is made
# 0, rather than being the EOS token for the last sequence.
if segment_ids is not None:
shifted *= tf.cast(
segment_ids == shift_right(segment_ids, axis=axis), x.dtype
)
return shifted
# 0, rather than being the EOS token for the last sequence.
x['inputs'] *= (x['inputs_segmentation'] == shift_right(x['inputs_segmentation'], axis=axis))

return x
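
As a toy trace of the new `shift_and_refine` (assuming numpy arrays shaped `[batch, length]`, as produced by the packing path; the values are hypothetical): inputs are shifted right, and the token that would leak across a packed-sequence boundary is zeroed via the segment ids.

```python
import numpy as np

# Hypothetical packed row: two sequences of length 2, then padding.
batch = {
    'inputs':  np.array([[11, 12, 21, 22, 0, 0]], dtype=np.int32),
    'targets': np.array([[11, 12, 21, 22, 0, 0]], dtype=np.int32),
    'inputs_segmentation':  np.array([[1, 1, 2, 2, 0, 0]], dtype=np.int32),
    'targets_segmentation': np.array([[1, 1, 2, 2, 0, 0]], dtype=np.int32),
}
out = shift_and_refine(batch, axis=1)
# out['inputs'] -> [[0, 11, 0, 21, 0, 0]]: shifted right, with the value carried over
# from sequence 1 into sequence 2 (and into padding) masked out by the segmentation check.
```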

class ShiftData():
def __init__(self, axis = 0, segmented=True):
def __init__(self, axis = 1):
self.axis = axis
self.segmented = segmented

def __call__(self, x):
segment_ids = x['inputs_segmentation'] if self.segmented else None
x['inputs'] = shift_inputs(x['inputs'], segment_ids=segment_ids, axis=self.axis)
return x
x = shift_and_refine(x, axis=self.axis)
return x
3 changes: 2 additions & 1 deletion MaxText/pygrain_tokenizer.py
@@ -16,6 +16,7 @@ class Tokenize(grain.MapTransform):
feature_names: str | Sequence[str]
sequence_length: int | Sequence[int]
model_path: str
vocab_size: int

def __post_init__(self):
self._processor = None
@@ -36,7 +37,7 @@ def map(self, features: dict[str, Any]) -> dict[str, Any]:
text = features[feature_name]
token_ids = self._processor.EncodeAsIds(text)
token_ids = token_ids[:sequence_length]
features[feature_name] = np.asarray(token_ids)
features[feature_name] = np.asarray(token_ids, dtype=np.int32)
return features

def __getstate__(self):
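
Since `vocab_size` is now a declared field of the `Tokenize` dataclass, call sites supply it alongside the other fields, as the updated `input_pipeline.py` does positionally. A construction sketch with placeholder values:

```python
import pygrain_tokenizer  # as imported inside MaxText

# Placeholder sequence length and tokenizer path; mirrors the call added in input_pipeline.py.
tokenize_op = pygrain_tokenizer.Tokenize(
    feature_names=["inputs", "targets"],
    sequence_length=2048,
    model_path="assets/tokenizer",  # hypothetical SentencePiece model location
    vocab_size=32768,
)
```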
41 changes: 41 additions & 0 deletions end_to_end/test_convergence_1b_params_grain.sh
@@ -0,0 +1,41 @@
#!/bin/bash
set -e

USER=${1}
LOSS_THRESHOLD=${2}
OUTPUT_PATH=${3}
DATASET_PATH=${4}
PACKING=${5}
WORKER=${6}
# if [ -z ${5} ]
# then
# RUN_NAME=${USER}
# else
# RUN_NAME=${5}
# fi

if [ -z ${7} ]
then
STEPS=3400
else
STEPS=${7}
fi

# Train
export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true"

python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\
steps=$STEPS per_device_batch_size=12.0 learning_rate=1e-3 enable_checkpointing=false\
max_target_length=2048 global_parameter_scale=1\
enable_profiler=false metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\
dataset_path=$DATASET_PATH log_period=150 pack_examples=$PACKING\
grain_worker_count=$WORKER

# Assert training loss is smaller than input LOSS_THRESHOLD
python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD
