Commit 1332d8b: fix
aireenmei committed Nov 19, 2023
1 parent 077c21b
Showing 3 changed files with 57 additions and 20 deletions.
MaxText/configs/base.yml (19 changes: 10 additions & 9 deletions)
@@ -61,7 +61,7 @@ global_parameter_scale: 1
 base_emb_dim: 2560
 base_num_heads: 8
 base_mlp_dim: 8192
-base_num_decoder_layers: 2
+base_num_decoder_layers: 16
 head_dim: 256
 # activation functions are .
 mlp_activations: ["relu"]
@@ -114,19 +114,20 @@ dataset_path: ""
 vocab_size: 32_768 # powers of 2 for sharding
 assets_path: "assets"
 vocab_relative_path: "tokenizer" # Assumes we're allowed
-# dataset_name: 'c4/en:3.0.1'
-# eval_dataset_name: 'c4/en:3.0.1'
-# dataset_name: 'array-record/c4/en/3.0.1'
-# eval_dataset_name: 'array-record/c4/en/3.0.1'
-# eval_split: 'validation' # for c4 data
-dataset_name: 'lm1b/1.1.0'
-eval_dataset_name: 'lm1b/1.1.0'
-eval_split: 'test' # for lm1b
+# dataset_name: 'c4/en:3.1.0'
+# eval_dataset_name: 'c4/en:3.1.0'
+dataset_name: 'array-record/c4/en/3.0.1'
+eval_dataset_name: 'array-record/c4/en/3.0.1'
+eval_split: 'validation' # for c4 data
+# dataset_name: 'lm1b/1.1.0'
+# eval_dataset_name: 'lm1b/1.1.0'
+# eval_split: 'test' # for lm1b
 per_device_batch_size: 12.0
 eval_per_device_batch_size: 0
 max_corpus_chars: 10_000_000
 # dataset_type: c4 # must be c4, array_record or synthetic
 dataset_type: array_record
+grain_worker_count: 0
 
 # Training loop
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
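Note: the hunk above makes the array_record copy of C4 the active dataset (lm1b is commented out) and introduces a grain_worker_count setting, defaulting to 0, which the Grain-based input pipeline below consumes. The following is a minimal sketch, assuming PyYAML and a local checkout, for inspecting the keys this commit touches; it is illustration only, not MaxText's own config loader:

```python
# Illustration only: print the dataset-related keys after this change.
import yaml  # assumes PyYAML is installed

with open("MaxText/configs/base.yml") as f:  # path relative to a repo checkout
  cfg = yaml.safe_load(f)

for key in ("dataset_type", "dataset_name", "eval_dataset_name",
            "eval_split", "grain_worker_count"):
  print(f"{key} = {cfg[key]}")

# Expected with this commit applied:
#   dataset_type = array_record
#   dataset_name = array-record/c4/en/3.0.1
#   grain_worker_count = 0
```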
MaxText/input_pipeline.py (52 changes: 42 additions & 10 deletions)
@@ -83,6 +83,13 @@ def _normalize_features(features):
       _normalize_features,
       num_parallel_calls=AUTOTUNE)
 
+# Max length filter.
+def length_filter(max_len):
+  def filter_fn(x):
+    source, target = x['inputs'], x['targets']
+    l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
+    return tf.less(l, max_len + 1)
+  return filter_fn
 
 # -----------------------------------------------------------------------------
 # Main dataset preparation.
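Note: length_filter now lives at module level so it can be reused outside preprocessing_pipeline (for example by the commented-out lazy-dataset path further down). A minimal usage sketch on a toy tf.data dataset, assuming the usual 'inputs'/'targets' feature layout; the helper is copied verbatim from the hunk above:

```python
# Illustration only: filter out examples longer than max_length tokens.
import tensorflow as tf

def length_filter(max_len):  # copied from the diff above
  def filter_fn(x):
    source, target = x['inputs'], x['targets']
    l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
    return tf.less(l, max_len + 1)
  return filter_fn

def gen():
  yield {'inputs': [1, 2, 3], 'targets': [4, 5]}
  yield {'inputs': [1, 2, 3, 4, 5, 6], 'targets': [7, 8]}

toy = tf.data.Dataset.from_generator(
    gen,
    output_signature={'inputs': tf.TensorSpec([None], tf.int32),
                      'targets': tf.TensorSpec([None], tf.int32)})

kept = toy.filter(length_filter(4))  # drops the 6-token example
for ex in kept:
  print(ex['inputs'].numpy(), ex['targets'].numpy())
```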
@@ -106,14 +113,6 @@ def preprocessing_pipeline(
 ):
   """Shuffle and batch/pack the given dataset."""
 
-  # Max length filter.
-  def length_filter(max_len):
-    def filter_fn(x):
-      source, target = x['inputs'], x['targets']
-      l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
-      return tf.less(l, max_len + 1)
-    return filter_fn
-
   if max_length > 0:
     dataset = dataset.filter(length_filter(max_length))
 
@@ -168,8 +167,28 @@ def filter_fn(x):
   # Return multi-host jax.Array prep iterator
   return multihost_gen
 
+# def preprocessing_pipeline_lazydata(
+# dataset,
+# vocab_path,
+# batch_size: int,
+# global_mesh,
+# shuffle: bool,
+# num_epochs: Optional[int] = 1,
+# pack_examples: bool = True,
+# shuffle_buffer_size: int = 1024,
+# max_length: int = 512,
+# shift: bool = True,
+# drop_remainder: bool = True,
+# data_sharding = None,
+# data_shuffle_seed = 0,
+# ):
+#   dataset = normalize_features(dataset)
+#   dataset = dataset.filter(length_filter(max_length))
+
+
 def preprocessing_pipeline_pygrain(
   dataset,
+  grain_worker_count,
   vocab_path,
   batch_size: int,
   global_mesh,
@@ -219,7 +238,7 @@ def preprocessing_pipeline_pygrain(
     data_source = dataset,
     operations = operations,
     sampler = index_sampler,
-    worker_count=1,
+    worker_count=grain_worker_count,
   )
 
   data_iter = iter(dataloader)
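Note: the hard-coded worker_count=1 is replaced by the new grain_worker_count config value (0 runs the loader in-process; larger values spawn that many worker processes). Below is a minimal, self-contained sketch of the Grain loading pattern this parameterizes; the file path, record layout, and worker count are placeholders, and the exact grain.python API may vary between versions:

```python
# Illustration only: build a Grain DataLoader over array_record shards.
import grain.python as pygrain

class NoopTransform(pygrain.MapTransform):
  """Identity transform; the real pipeline tokenizes/packs examples here."""
  def map(self, element):
    return element

grain_worker_count = 4  # hypothetical override of the new config knob (default 0)
files = ["/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record-00000-of-00512"]

data_source = pygrain.ArrayRecordDataSource(files)
index_sampler = pygrain.IndexSampler(
    num_records=len(data_source),
    num_epochs=1,
    shard_options=pygrain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
)
dataloader = pygrain.DataLoader(
    data_source=data_source,
    operations=[NoopTransform()],
    sampler=index_sampler,
    worker_count=grain_worker_count,  # 0 = in-process; >0 = that many worker processes
)
data_iter = iter(dataloader)
```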
@@ -267,7 +286,17 @@ def get_datasets_pygrain(
     eval_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(rf'.*{config.eval_split}.*', f)]
     eval_ds = pygrain.ArrayRecordDataSource(eval_files)
   else:
-    eval_ds_builder = train_ds_builder
+    eval_ds = train_ds
+
+  # train_ds = tfds.data_source(config.dataset_name, split="train")
+  # if config.eval_dataset_name:
+  #   eval_ds = tfds.data_source(config.dataset_name, split=config.eval_split)
+  # else:
+  #   eval_ds = train_ds
+
+  # lazy_dataset = pygrain.experimental.lazy_dataset
+  # train_ds = lazy_dataset.SourceLazyMapDataset(train_ds)
+  # eval_ds = lazy_dataset.SourceLazyMapDataset(eval_ds)
 
   return train_ds, eval_ds
 
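Note: get_datasets_pygrain returns raw Grain data sources and picks evaluation shards by matching config.eval_split against file names under the mounted dataset directory. A small self-contained sketch of that selection logic, with a hypothetical mount path and shard names:

```python
# Illustration only: mimic the split-based shard selection with fake file names.
import re

data_dir = "/tmp/gcsfuse/array-record/c4/en/3.0.1"  # placeholder mount path
eval_split = "validation"
file_names = [  # stand-in for os.listdir(data_dir)
    "c4-train.array_record-00000-of-00512",
    "c4-validation.array_record-00000-of-00008",
]
eval_files = [data_dir + '/' + f for f in file_names
              if re.match(rf'.*{eval_split}.*', f)]
print(eval_files)  # only the validation shard is selected
```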
@@ -364,6 +393,7 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
 
   train_iter = preprocessing_pipeline_pygrain(
     train_ds,
+    config.grain_worker_count,
     vocab_path,
     global_batch_size_to_load,
     global_mesh,
@@ -377,6 +407,7 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
 
   eval_iter = preprocessing_pipeline_pygrain(
     eval_ds,
+    config.grain_worker_count,
     vocab_path,
     eval_batch_size,
     global_mesh,
@@ -389,6 +420,7 @@ def preprocess_dataset_pygrain(config: ml_collections.ConfigDict,
 
   predict_iter = preprocessing_pipeline_pygrain(
     eval_ds,
+    config.grain_worker_count,
     vocab_path,
     eval_batch_size,
     global_mesh,
setup_gcsfuse.sh (6 changes: 5 additions & 1 deletion)
@@ -46,4 +46,8 @@ sudo apt-get -y install gcsfuse
 
 mkdir -p $MOUNT_PATH
 
-gcsfuse --implicit-dirs "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
+# gcsfuse --implicit-dirs "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
+
+gcsfuse --implicit-dirs --http-client-timeout=5s --max-conns-per-host=2000 \
+--debug_fuse_errors --debug_fuse --debug_gcs --debug_invariants --debug_mutex \
+--log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"