
Commit

fix
aireenmei committed Nov 16, 2023
1 parent 4ebff99 commit 077c21b
Showing 8 changed files with 139 additions and 186 deletions.
7 changes: 4 additions & 3 deletions MaxText/configs/base.yml
@@ -118,14 +118,15 @@ vocab_relative_path: "tokenizer" # Assumes we're allowed
# eval_dataset_name: 'c4/en:3.0.1'
# dataset_name: 'array-record/c4/en/3.0.1'
# eval_dataset_name: 'array-record/c4/en/3.0.1'
# eval_split: 'validation' # for c4 data
dataset_name: 'lm1b/1.1.0'
eval_dataset_name: 'lm1b/1.1.0'
eval_split: 'test'
eval_split: 'test' # for lm1b
per_device_batch_size: 12.0
eval_per_device_batch_size: 0
max_corpus_chars: 10_000_000
dataset_type: array_record # must be c4, array_record or synthetic
# dataset_type: c4
# dataset_type: c4 # must be c4, array_record or synthetic
dataset_type: array_record

# Training loop
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
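
As an aside, the inline comment pins the allowed values for dataset_type. The snippet below is a minimal standalone check of that constraint — my own sketch, assuming PyYAML is installed and that base.yml parses as plain YAML; it is not code from this repository.

```python
import yaml

ALLOWED_DATASET_TYPES = {"c4", "array_record", "synthetic"}

with open("MaxText/configs/base.yml") as f:
    cfg = yaml.safe_load(f)

# enforce the constraint documented in the inline comment above
if cfg["dataset_type"] not in ALLOWED_DATASET_TYPES:
    raise ValueError(
        f"dataset_type must be one of {sorted(ALLOWED_DATASET_TYPES)}, "
        f"got {cfg['dataset_type']!r}"
    )
print(cfg["dataset_name"], cfg["dataset_type"])  # expected: lm1b/1.1.0 array_record
```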
63 changes: 25 additions & 38 deletions MaxText/input_pipeline.py
@@ -33,8 +33,6 @@
import multihost_dataloading
import sequence_packing
import pygrain_operations
from transformers import T5Tokenizer
from sentencepiece import SentencePieceProcessor
import pygrain_tokenizer

AUTOTUNE = tf.data.experimental.AUTOTUNE
@@ -172,7 +170,7 @@ def filter_fn(x):

def preprocessing_pipeline_pygrain(
dataset,
operations,
vocab_path,
batch_size: int,
global_mesh,
shuffle: bool,
@@ -185,6 +183,12 @@ def preprocessing_pipeline_pygrain(
data_sharding = None,
data_shuffle_seed = 0,
):

operations = []
operations.append(pygrain_operations.ParseFeatures())
operations.append(pygrain_operations.NormalizeFeatures())
operations.append(pygrain_tokenizer.Tokenize(["inputs","targets"], max_length, vocab_path))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.filter_keys))
operations.append(pygrain.FilterOperation(condition_function = pygrain_operations.length_filter(max_length)))

# Pack and Batch examples.
@@ -194,7 +198,7 @@
length_struct={'inputs':max_length,'targets':max_length}))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.CombineKeys()))
else:
# operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
operations.append(pygrain.MapOperation(map_function=pygrain_operations.PadToMaxLength(max_length)))
operations.append(pygrain.BatchOperation(batch_size=batch_size // jax.process_count(), drop_remainder=drop_remainder))

# Shift inputs for teacher-forced training
@@ -212,22 +216,15 @@
)

dataloader = pygrain.DataLoader(
data_source = dataset,
operations = operations,
sampler = index_sampler,
worker_count=0,
data_source = dataset,
operations = operations,
sampler = index_sampler,
worker_count=1,
)

data_iter = iter(dataloader)
global_shape = (batch_size, max_length)
# Return PyGrainIterator
# return data_iter
multihost_gen = (
multihost_dataloading.get_next_batch_sharded_pygrain(
data_iter, data_sharding, global_shape, global_mesh
)
)
# Return multi-host jax.Array prep iterator
return multihost_gen

return data_iter
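
To make the new contract concrete, here is a hedged usage sketch — the data source, mesh, and parameter values are placeholders of mine, not from the commit. The caller now passes vocab_path instead of a pre-built operations list, and receives the raw PyGrain iterator; the multi-host sharded-array conversion shown above is presumably applied further downstream.

```python
# hypothetical call site; train_ds is a Grain/ArrayRecord data source and mesh a
# jax Mesh built elsewhere — both are assumptions for illustration only
train_iter = preprocessing_pipeline_pygrain(
    train_ds,
    vocab_path="~/lm1b_sentencepiece_model",
    batch_size=32,                  # global batch; divided by jax.process_count() internally
    global_mesh=mesh,
    shuffle=True,
    num_epochs=1,
    pack_examples=False,
    max_length=1024,
    shift=True,
    data_sharding=("data",),
    data_shuffle_seed=0,
)

batch = next(train_iter)            # host-local dict of numpy arrays
print(batch["inputs"].shape)        # (batch_size // jax.process_count(), max_length)
```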


def get_datasets(
@@ -287,7 +284,7 @@ def preprocess_dataset(config: ml_collections.ConfigDict,
# Load tokenizer
sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
vocab_size=config.vocab_size)
# sp_tokenizer = T5Tokenizer.from_pretrained('t5-base')

# Tokenize data.
train_ds = train_ds.map(
tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE)
@@ -354,15 +351,9 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model')

# Load tokenizer
# sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
# vocab_size=config.vocab_size)
# sp_tokenizer = T5Tokenizer.from_pretrained('t5-base', model_max_length=1024)
sp_tokenizer = SentencePieceProcessor(vocab_path)

operations = [pygrain.MapOperation(map_function=pygrain_operations.normalize_features())]
#operations.append(pygrain.MapOperation(map_function=pygrain_operations.TokenizeOperation(sp_tokenizer)))
operations.append(pygrain_tokenizer.TokenizeAndPad(["inputs","targets"], config.max_target_length, vocab_path))

sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path,
vocab_size=config.vocab_size)

# Set global batch size.
global_batch_size_to_load = config.global_batch_size_to_load

@@ -371,44 +362,40 @@
else:
eval_batch_size = global_batch_size_to_load

def filter_keys(record):
return {'inputs': record['inputs'], 'targets': record['targets']}
operations.append(pygrain.MapOperation(map_function=filter_keys))

train_iter = preprocessing_pipeline_pygrain(
train_ds,
operations,
vocab_path,
global_batch_size_to_load,
global_mesh,
shuffle=config.enable_data_shuffling,
num_epochs=1,
pack_examples=False,
max_length=config.max_target_length,
shift=False,
shift=True,
data_sharding=config.data_sharding,
data_shuffle_seed = data_shuffle_seed,)

eval_iter = preprocessing_pipeline_pygrain(
eval_ds,
operations,
vocab_path,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
max_length=config.max_eval_target_length,
shift=False,
shift=True,
data_sharding=config.data_sharding,
data_shuffle_seed = data_shuffle_seed,)

predict_iter = preprocessing_pipeline_pygrain(
eval_ds,
operations,
vocab_path,
eval_batch_size,
global_mesh,
shuffle=config.enable_data_shuffling,
pack_examples=False,
max_length=config.max_eval_target_length,
shift=False,
shift=True,
data_sharding=config.data_sharding,
data_shuffle_seed = data_shuffle_seed,)

8 changes: 0 additions & 8 deletions MaxText/multihost_dataloading.py
@@ -126,7 +126,6 @@ def get_next_batch_sharded(local_dataset: tf.data.Dataset,
if not loaded_data_success:
local_data = local_dataset.next()

# local_devices = jax.local_devices()
local_devices = global_mesh.local_devices
local_device_count = jax.local_device_count()

@@ -190,11 +189,8 @@ def get_next_batch_sharded_pygrain(data_iter,
data_axes = jax.tree_map(lambda x: PartitionSpec(*data_sharding), local_data)
_ = check_inputs("array_record", local_data, global_data_shape, data_axes)

# local_devices = jax.local_devices()
local_devices = global_mesh.local_devices
local_device_count = jax.local_device_count()
print(f"local_device: {local_devices}")
print(f"local_device_count: {local_device_count}")

def _put_to_devices(x):
try:
@@ -217,10 +213,6 @@ def form_gda(local_data, shape):
device_buffers = _put_to_devices(local_data)
# Wrap device buffers as GDA
shape = tuple(shape)
print("####################### Debug")
print(f"shape: {shape}; global_mesh: {global_mesh}; ")
print(f"input_sharding_constraint: {input_sharding_constraint};")
print(f"jax.sharding.NamedSharding: {jax.sharding.NamedSharding(global_mesh, input_sharding_constraint)};")
input_gda = jax.make_array_from_single_device_arrays(shape,
jax.sharding.NamedSharding(global_mesh, input_sharding_constraint), device_buffers)
return input_gda
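
For reference, the surviving call here is jax.make_array_from_single_device_arrays. Below is a self-contained, single-host sketch of the same pattern; the toy shapes and axis names are my own, not taken from this file.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())            # 1-D device array for a single mesh axis
mesh = Mesh(devices, axis_names=("data",))
pspec = PartitionSpec("data")                # shard the leading (batch) axis

global_shape = (2 * devices.size, 4)
# one host-local shard per device, each of shape (2, 4)
shards = [
    jax.device_put(np.full((2, 4), i, dtype=np.float32), d)
    for i, d in enumerate(devices)
]
global_array = jax.make_array_from_single_device_arrays(
    global_shape, NamedSharding(mesh, pspec), shards
)
print(global_array.shape)                    # (2 * num_devices, 4)
```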
43 changes: 35 additions & 8 deletions MaxText/pygrain_operations.py
@@ -1,3 +1,5 @@
from collections.abc import Mapping, Sequence
import dataclasses
from typing import Dict
import grain.python as pygrain
import numpy as np
@@ -14,6 +16,29 @@ def _normalize_features(features):
return {'inputs':features, 'targets': features}
return _normalize_features(features)

@dataclasses.dataclass
class ParseFeatures(pygrain.MapTransform):
def map(self, features):
def _parse(example):
parsed = tf.io.parse_example(
example, {
'text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
})
return parsed
return _parse(features)


@dataclasses.dataclass
class NormalizeFeatures(pygrain.MapTransform):
def map(self, features):
return {
'inputs':features['text'].numpy().decode(),
'targets': features['text'].numpy().decode()
}

def filter_keys(record):
return {'inputs': record['inputs'], 'targets': record['targets']}
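
As a quick standalone check (not part of the commit; it assumes TensorFlow eager mode and that the definitions above are in scope), one serialized record can be pushed through the two new transforms and filter_keys by hand:

```python
import tensorflow as tf

# build one serialized tf.Example with a 'text' feature, the way an
# ArrayRecord data source would yield it
example = tf.train.Example(features=tf.train.Features(feature={
    "text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"hello world"]))
}))
serialized = example.SerializeToString()

record = ParseFeatures().map(serialized)   # {'text': <tf.Tensor: b'hello world'>}
record = NormalizeFeatures().map(record)   # {'inputs': 'hello world', 'targets': 'hello world'}
record = filter_keys(record)               # drops any keys other than inputs/targets
print(record)
```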

class TokenizeOperation():
""" TokenizeOp
"""
@@ -70,27 +95,29 @@ def __call__(self, data):
combined_data.update(positions)
return combined_data

def shift_right_tf(x, axis=1):
def shift_right(x, axis=1):
"""Shift the input to the right by padding and slicing on axis."""
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[axis] = (1, 0)
slices = [slice(None),] * len(x.shape)
slices[axis] = slice(0, -1)
padded = tf.pad(
padded = np.pad(
x,
tf.constant(pad_widths),
pad_widths,
mode='constant',
constant_values=tf.constant(0, x.dtype))
constant_values=x.dtype.type(0)
# constant_values=tf.constant(0, x.dtype)
)
return padded[tuple(slices)]

def shift_inputs_tf(x, segment_ids=None, axis=1):
def shift_inputs(x, segment_ids=None, axis=1):
"""Shift inputs and replace EOS by 0 for packed inputs."""
shifted = shift_right_tf(x, axis=axis)
shifted = shift_right(x, axis=axis)
# For packed targets, the first shifted token of a new sequence is made
# 0, rather than being the EOS token for the last sequence.
if segment_ids is not None:
shifted *= tf.cast(
segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype
segment_ids == shift_right(segment_ids, axis=axis), x.dtype
)
return shifted
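
A quick worked example of the numpy port (standalone, not in the commit): each row is shifted right by one position and zero-padded on the left; with segment_ids, shift_inputs additionally zeroes the first token of each packed segment.

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]], dtype=np.int32)

print(shift_right(x, axis=1))
# [[0 1 2]
#  [0 4 5]]

# without packing, shift_inputs reduces to shift_right
print(shift_inputs(x, segment_ids=None, axis=1))
# [[0 1 2]
#  [0 4 5]]
```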

@@ -101,5 +128,5 @@ def __init__(self, axis = 0, segmented=True):

def __call__(self, x):
segment_ids = x['inputs_segmentation'] if self.segmented else None
x['inputs'] = shift_inputs_tf(x['inputs'], segment_ids=segment_ids, axis=self.axis)
x['inputs'] = shift_inputs(x['inputs'], segment_ids=segment_ids, axis=self.axis)
return x
