Data Input Pipeline

MaxText currently supports two data input pipelines: the tfds (tensorflow_datasets) pipeline, which is the default, and the Grain pipeline, which provides determinism.

Deterministic Data Input Pipeline - Grain

MaxText users can optionally use Grain, a deterministic data input pipeline. With Grain, the indices of the data already trained on are saved in a small JSON file inside each checkpoint. This lets you preserve the data order, restart from exactly the same position in the data, and reproduce the same losses, so the whole training process becomes reproducible, disruption-proof, and debuggable. To use this pipeline:

  1. The dataset must be in ArrayRecord format, which supports random access. To convert a dataset to ArrayRecord, see instructions.
  2. An ArrayRecord dataset hosted in a GCS bucket can only be read through Cloud Storage FUSE. The installation of Cloud Storage FUSE is included in setup.sh. You then need to mount the GCS bucket to a local path on each worker using the script setup_gcsfuse.sh, which also configures some parameters for the mount.
bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$BUCKET_NAME MOUNT_PATH=$MOUNT_PATH
  3. Set dataset_type=c4-array_record and set dataset_path and dataset_name accordingly. dataset_path should be the same as $MOUNT_PATH in the step above. dataset_name is the path to the folder that contains the ArrayRecord dataset, so that os.path.join(config.dataset_path, config.dataset_name) is the full path to the ArrayRecord files.
  4. Tune grain_worker_count for performance. This parameter controls the number of child processes used by Grain (more details in behind_the_scene, code). If you use a large number of workers, check your gcsfuse config in setup_gcsfuse.sh to avoid gcsfuse throttling.
  5. Example command:
bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxtext-dataset MOUNT_PATH=/tmp/gcsfuse && python3 MaxText/train.py MaxText/configs/base.yml run_name=<RUN_NAME> base_output_directory=gs://<MY_BUCKET>  dataset_path=/tmp/gcsfuse/ dataset_name='array-record/c4/en/3.0.1' dataset_type=c4-array_record grain_worker_count=2
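
Before launching a long run, it can help to confirm that dataset_path and dataset_name combine into a directory that actually contains ArrayRecord files. The following is a minimal sketch of such a check, not part of MaxText: the helper name `find_array_record_files` is hypothetical, and the default file-name pattern `*.array_record*` is an assumption about how the files are named, so adjust it to your dataset.

```python
import glob
import os


def find_array_record_files(dataset_path, dataset_name, pattern="*.array_record*"):
    """Return the sorted files under os.path.join(dataset_path, dataset_name).

    Hypothetical helper for illustration only. `pattern` is an assumed
    file-name pattern, not something MaxText defines.
    """
    # Same composition the config uses: dataset_path is the mount point,
    # dataset_name is the folder (relative to it) holding the dataset.
    data_dir = os.path.join(dataset_path, dataset_name)
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(
            f"{data_dir} does not exist; is the bucket mounted at {dataset_path}?"
        )
    files = sorted(glob.glob(os.path.join(data_dir, pattern)))
    if not files:
        raise FileNotFoundError(f"No files matching {pattern!r} in {data_dir}")
    return files
```

With the values from the example command, this would be called as `find_array_record_files("/tmp/gcsfuse/", "array-record/c4/en/3.0.1")`. Note that dataset_name should stay a relative path: if it begins with `/`, `os.path.join` discards dataset_path entirely.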