Skip to content

Commit

Permalink
minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
aireenmei committed Dec 19, 2023
1 parent 12cb195 commit 5618295
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 20 deletions.
2 changes: 1 addition & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ eval_per_device_batch_size: 0
max_corpus_chars: 10_000_000
# dataset_type: c4 # must be c4, array_record or synthetic
dataset_type: array_record
grain_worker_count: 0
grain_worker_count: 1

# Training loop
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
Expand Down
21 changes: 2 additions & 19 deletions MaxText/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,6 @@ def preprocessing_pipeline(
# Return multi-host jax.Array prep iterator
return multihost_gen

# def preprocessing_pipeline_lazydata(
# dataset,
# vocab_path,
# batch_size: int,
# global_mesh,
# shuffle: bool,
# num_epochs: Optional[int] = 1,
# pack_examples: bool = True,
# shuffle_buffer_size: int = 1024,
# max_length: int = 512,
# shift: bool = True,
# drop_remainder: bool = True,
# data_sharding = None,
# data_shuffle_seed = 0,
# ):
# dataset = normalize_features(dataset)
# dataset = dataset.filter(length_filter(max_length))


def preprocessing_pipeline_pygrain(
dataset,
Expand All @@ -202,7 +184,7 @@ def preprocessing_pipeline_pygrain(
data_sharding = None,
data_shuffle_seed = 0,
):

"""Apply pygrain operations to preprocess the given dataset."""
operations = []
operations.append(pygrain_operations.ParseFeatures())
operations.append(pygrain_operations.NormalizeFeatures())
Expand Down Expand Up @@ -279,6 +261,7 @@ def get_datasets_pygrain(
config: ml_collections.ConfigDict,
read_config = None,
):
  """Load dataset from array_record files for use with pygrain."""
data_dir = os.path.join(config.dataset_path, config.dataset_name)
train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)]
train_ds = pygrain.ArrayRecordDataSource(train_files)
Expand Down

0 comments on commit 5618295

Please sign in to comment.