From 812c0231d89ead54fc865838de2ebabb1bd6a71c Mon Sep 17 00:00:00 2001
From: Bernard Han
Date: Wed, 4 Sep 2024 23:59:56 +0000
Subject: [PATCH 1/3] add logging statement

---
 MaxText/torch_datasets/parquet.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/MaxText/torch_datasets/parquet.py b/MaxText/torch_datasets/parquet.py
index ce15a3f25..98b90638c 100644
--- a/MaxText/torch_datasets/parquet.py
+++ b/MaxText/torch_datasets/parquet.py
@@ -113,9 +113,12 @@ class FileParallelSequentialRead(ParquetIterableDataset):
 
   def _iter_impl(self, assigned_parquet_files: Iterable[str]) -> Iterable:
     """File Parallel, Sequential Read iterator."""
+    worker_info = torch.utils.data.get_worker_info()
     for each_parquet_file in assigned_parquet_files:
       table = pq.ParquetFile(each_parquet_file)
       for batch in table.iter_batches(
           batch_size=self.batch_size, columns=self.columns
       ):
-        yield from batch.to_pylist()
+        res = batch.to_pylist()
+        max_logging.log(f"Worker {worker_info.id} retrieving a batch from {each_parquet_file}")
+        yield from res

From c00e7a61c4f2e719f493d3f2e382e3a2bcf48334 Mon Sep 17 00:00:00 2001
From: Bernard Han
Date: Mon, 9 Sep 2024 00:21:39 +0000
Subject: [PATCH 2/3] check in test code

---
 MaxText/configs/base.yml          |  3 +++
 MaxText/standalone_dataloader.py  |  8 +++++++-
 MaxText/torch_datasets/parquet.py | 23 ++++++++++++++++++++---
 3 files changed, 30 insertions(+), 4 deletions(-)
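Note on the three DataLoader-related settings this patch wires through: pin_memory=True copies each batch into page-locked (pinned) host memory, which speeds up later host-to-device transfers; persistent_workers=True keeps worker processes alive between epochs instead of re-forking them; and prefetch_factor is the number of batches each worker keeps queued ahead of the consumer. A minimal runnable sketch of the same wiring outside MaxText (the tiny dataset is only a stand-in for ParquetIterableDataset, and all values are illustrative):

from torch.utils.data import DataLoader, IterableDataset

class TinyDataset(IterableDataset):
  # Stand-in for ParquetIterableDataset; with num_workers > 1 each worker
  # re-yields the full range, which is fine for a wiring sketch.
  def __iter__(self):
    yield from range(1024)

loader = DataLoader(
    TinyDataset(),
    num_workers=2,            # config.data_loader_num_workers
    batch_size=256,           # LOCAL_BATCH_SIZE in deployment.yaml below
    prefetch_factor=5,        # batches queued ahead, per worker
    pin_memory=True,          # page-locked host memory for faster H2D copies
    persistent_workers=True,  # workers survive across epochs
)

if __name__ == "__main__":
  for batch in loader:
    pass  # one pass over the synthetic samples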
diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index aff57e035..07bca2ee4 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -375,3 +375,6 @@ gcs_metrics_bucket: "distributed-training-metrics"
 # The strategy used for reading the dataset. The options for this field are specified by the keys of
 # `DATA_LOADER_STRATEGIES_BY_NAME` in MaxText/standalone_dataloader.py.
 data_loader_strategy_name: FileParallelSequentialRead
+pin_memory: True
+persistent_workers: True
+rough_desired_simulated_data_sample_size: 53000
\ No newline at end of file

diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py
index b004475fe..4055434d9 100644
--- a/MaxText/standalone_dataloader.py
+++ b/MaxText/standalone_dataloader.py
@@ -136,12 +136,16 @@ def parquet_data_loader(config):
         allocated_parquet_files=sublists[worker_id],
         batch_size=batch_size,
         columns=["outputs", "image_base64_str"],
+        config=config,
     )
 
     data_loader = DataLoader(
         dataset=dataset,
         num_workers=config.data_loader_num_workers,
         batch_size=batch_size,
+        # batch_size=1,
         prefetch_factor=config.prefetch_factor,
+        pin_memory=config.pin_memory,
+        persistent_workers=config.persistent_workers,
     )
     return data_loader
@@ -206,7 +210,9 @@ def data_load_loop(config):
   local_steps = 0
   step_data_loading_start = datetime.datetime.now()
   step_start = datetime.datetime.now()
-  for _ in data_loader:
+  # for outputs, image_base64_strs in data_loader:
+  for batch in data_loader:
+    max_logging.log(f"STANDALONE DATALOADER : Obtained {len(batch['outputs'])} outputs and {len(batch['image_base64_str'])} images.")
     step_data_loading_end = datetime.datetime.now()
     data_loading_interval = (step_data_loading_end - step_data_loading_start).total_seconds()
     max_logging.log(

diff --git a/MaxText/torch_datasets/parquet.py b/MaxText/torch_datasets/parquet.py
index 98b90638c..f547e0a30 100644
--- a/MaxText/torch_datasets/parquet.py
+++ b/MaxText/torch_datasets/parquet.py
@@ -17,10 +17,13 @@
 import random
 import abc
 from typing import Iterable
+import time
 
 import torch
 from torch.utils.data import IterableDataset
 import pyarrow.parquet as pq
+import string
+import random
 
 import max_logging
@@ -32,12 +35,13 @@ class ParquetIterableDataset(abc.ABC, IterableDataset):
   Implementers must override the `_iter_impl` method.
   """
 
-  def __init__(self, allocated_parquet_files: Iterable[str], columns=None, batch_size=1000):
+  def __init__(self, allocated_parquet_files: Iterable[str], columns=None, batch_size=1000, config=None):
     max_logging.log(f'Using {self.__class__.__name__} strategy.')
     max_logging.log(f'Allocated with the following data files: {allocated_parquet_files}.')
     self.allocated_parquet_files = allocated_parquet_files
     self.columns = columns
     self.batch_size = batch_size
+    self.config = config
 
   @abc.abstractmethod
   def _iter_impl(self, assigned_parquet_files: Iterable[str]) -> Iterable:
@@ -110,7 +114,14 @@ def _random_iterator(self, itr: Iterable):
 
 class FileParallelSequentialRead(ParquetIterableDataset):
   """File Parallel, Sequential Read implementation for Parquet files."""
-
+  def _construct_data_sample(self, target_size):
+    if target_size < 101:
+      target_size = 101
+    output_len = 100
+    output = "".join(random.choices(string.ascii_letters, k=output_len))
+    image_base64_str = "".join(random.choices(string.ascii_letters, k=target_size - output_len))
+    return {"outputs": output, "image_base64_str": image_base64_str}
+
   def _iter_impl(self, assigned_parquet_files: Iterable[str]) -> Iterable:
     """File Parallel, Sequential Read iterator."""
     worker_info = torch.utils.data.get_worker_info()
@@ -119,6 +130,12 @@ def _iter_impl(self, assigned_parquet_files: Iterable[str]) -> Iterable:
       for batch in table.iter_batches(
           batch_size=self.batch_size, columns=self.columns
       ):
+        # The batch read here is never yielded; the read is kept so that the
+        # real data-loading time from the filesystem is still incurred.
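+        # What is yielded instead is synthetic: _construct_data_sample (added
+        # above) builds each sample from random ASCII characters sized to
+        # roughly rough_desired_simulated_data_sample_size bytes.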
         res = batch.to_pylist()
         max_logging.log(f"Worker {worker_info.id} retrieving a batch from {each_parquet_file}")
-        yield from res
+
+        # generated_res = [self._construct_data_sample(self.config.rough_desired_simulated_data_sample_size) for _ in range(self.batch_size)]
+        # yield from generated_res
+        for _ in range(self.batch_size):
+          yield self._construct_data_sample(self.config.rough_desired_simulated_data_sample_size)
\ No newline at end of file

From f5c57153c237914bb63b8f1982bcc69521ff8938 Mon Sep 17 00:00:00 2001
From: Bernard Han
Date: Mon, 9 Sep 2024 19:02:44 +0000
Subject: [PATCH 3/3] check in everything for framework validation

---
 .vscode/launch.json     |  70 ------------------------
 MaxText/configs/cpu.sh  |  55 +++++++++++++++++++
 MaxText/deployment.yaml | 117 +++++++++++++++++++++++++---------------
 3 files changed, 128 insertions(+), 114 deletions(-)
 delete mode 100644 .vscode/launch.json
 create mode 100644 MaxText/configs/cpu.sh

diff --git a/.vscode/launch.json b/.vscode/launch.json
deleted file mode 100644
index ddd8eb0f6..000000000
--- a/.vscode/launch.json
+++ /dev/null
@@ -1,70 +0,0 @@
-{
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "Debug MaxText Decode",
-            "type": "python",
-            "request": "launch",
-            "console": "integratedTerminal",
-            "justMyCode": false,
-            "python": "python3",
-            "program": "${workspaceFolder}/MaxText/decode.py",
-            "args": ["MaxText/configs/base.yml",
-                "run_name=runner_$(date +%Y-%m-%d-%H-%M)",
-                "base_output_directory=gs://test-maxtext-output",
-                "dataset_path=gs://test-maxtext-dataset",
-                "steps=2",
-                "attention=dot_product",
-                "enable_checkpointing=false"]
-        },
-        {
-            "name": "Debug MaxText Train",
-            "type": "python",
-            "request": "launch",
-            "console": "integratedTerminal",
-            "justMyCode": false,
-            "python": "python3",
-            "program": "${workspaceFolder}/MaxText/train.py",
-            "args": ["MaxText/configs/base.yml",
-                "run_name=runner_$(date +%Y-%m-%d-%H-%M)",
-                "base_output_directory=gs://test-maxtext-output",
-                "dataset_path=gs://test-maxtext-dataset",
-                "steps=2",
-                "enable_checkpointing=false"]
-        },
-        {
-            "name": "Debug MaxText Inference Microbenchmark",
-            "type": "python",
-            "request": "launch",
-            "console": "integratedTerminal",
-            "justMyCode": false,
-            "python": "python3",
-            "program": "${workspaceFolder}/MaxText/inference_microbenchmark.py",
-            "args": [
-                "MaxText/configs/base.yml",
-                "model_name=llama2-7b",
-                "tokenizer_path=assets/tokenizer.llama2",
-                "weight_dtype=bfloat16",
-                "scan_layers=false",
-                "attention=dot_product",
-                "max_prefill_predict_length=1024",
-                "max_target_length=2048",
-                "ici_fsdp_parallelism=1",
-                "ici_tensor_parallelism=-1",
-                "ici_autoregressive_parallelism=1",
-                "inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
-                "inference_microbenchmark_stages=generate",
-                "inference_microbenchmark_loop_iters=1",
-                "run_name=runner_$(date +%Y-%m-%d-%H-%M)",
-                "base_output_directory=gs://test-maxtext-output",
-                "prefill_cache_axis_order=0,2,1,3",
-                "ar_cache_axis_order=0,2,1,3",
-                "compute_axis_order=0,2,1,3",
-                "reshape_q=true",
-                "per_device_batch_size=24",
-                "quantization=int8",
-                "quantize_kvcache=True",
-            ]
-        },
-    ]
-}
\ No newline at end of file

diff --git a/MaxText/configs/cpu.sh b/MaxText/configs/cpu.sh
new file mode 100644
index 000000000..55639d8ae
--- /dev/null
+++ b/MaxText/configs/cpu.sh
@@ -0,0 +1,55 @@
+echo "Running cpu.sh"
+# Configurable-size parameter model running on CPU.
+# Adapted from the v5e-256 config; it works out of the box for any number of slices.
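+# Note: the training invocation at the bottom of this script forces
+# JAX_PLATFORMS=cpu, so it is intended for CPU-only framework and
+# data-loading validation; the TPU-oriented LIBTPU_INIT_ARGS exported
+# below should be inert in that mode.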
+#
+# Command Flags:
+# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
+# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
+# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
+# PLATFORM (Optional, can be "gke" or "gce", default is "gce")
+#
+# Example to invoke this script:
+# bash MaxText/configs/cpu.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke"
+#
+# Example to AOT compile:
+# bash MaxText/configs/cpu.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2
+
+
+# Stop execution if any command exits with error
+set -e
+
+export PLATFORM="gce"
+export EXECUTABLE="train.py" # or train_compile.py
+
+# Set environment variables
+for ARGUMENT in "$@"; do
+    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
+    export "$KEY"="$VALUE"
+done
+
+# Use 64b parameters if not set.
+PARAMETERS="${PARAMETERS:-64}"
+echo "Using ${PARAMETERS}b parameters"
+
+# The setup accommodates two cases:
+# 1) Passing the 'RUN_NAME' variable at runtime
+# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
+if [ -n "$RUN_NAME" ];
+then
+    export M_RUN_NAME=$RUN_NAME
+fi
+
+# Set up network optimizations
+bash preflight.sh PLATFORM=$PLATFORM
+
+# Train
+export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+JAX_PLATFORMS=cpu python3 MaxText/$EXECUTABLE MaxText/configs/base.yml\
+    steps=$STEPS checkpoint_period=$CHECKPOINT_PERIOD per_device_batch_size=1 enable_checkpointing=true\
+    async_checkpointing=false\
+    remat_policy=full global_parameter_scale=$PARAMETERS\
+    max_target_length=2048 base_output_directory=$OUTPUT_PATH\
+    hardware=$HARDWARE\
+    use_iota_embed=true reuse_example_batch=1\
+    dataset_type=synthetic attention='flash' gcs_metrics=true\
+    load_full_state_path=$PREVIOUS_STATE
\ No newline at end of file

diff --git a/MaxText/deployment.yaml b/MaxText/deployment.yaml
index 72d43b9b7..c98e39bf9 100644
--- a/MaxText/deployment.yaml
+++ b/MaxText/deployment.yaml
@@ -3,7 +3,7 @@
 apiVersion: v1
 kind: PersistentVolume
 metadata:
-  name: gcs-distributed-training-pv
+  name: bernardhan-gcs-distributed-training-pv
 spec:
   accessModes:
     - ReadWriteMany
@@ -12,7 +12,7 @@ spec:
   persistentVolumeReclaimPolicy: Retain
   storageClassName: gcsfuse-sc # dummy storage class
   claimRef:
-    name: training-test-claim
+    name: bernardhan-gcs-training-test-claim
    # Running in the "default" namespace so you can submit to the local queue
    # created in the default namespace.
    namespace: default
@@ -24,12 +24,13 @@ spec:
    - metadata-cache:stat-cache-max-size-mb:-1
    - metadata-cache:type-cache-max-size-mb:-1
    - file-system:kernel-list-cache-ttl-secs:-1
-    - file-cache:max-size-mb:-1
-    - file-cache:cache-file-for-range-read:false
-    - file-cache:enable-parallel-downloads:false
+    # DISABLES CACHE
+    # - file-cache:max-size-mb:-1
+    # - file-cache:cache-file-for-range-read:true
+    # - file-cache:enable-parallel-downloads:true
   csi:
     driver: gcsfuse.csi.storage.gke.io
-    volumeHandle: xai-hf-dataset-parquet # unique bucket name. xai-hf-dataset-parquet-10g for the 10G * 12K dataset.
+    volumeHandle: xai-hf-dataset-parquet-10g # unique bucket name; this is the bucket with the 10G * 12K dataset.
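+    # With the file-cache mount options above commented out, every read is
+    # served from GCS rather than a local cache, so benchmark epochs measure
+    # the GCSFuse read path itself.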
     volumeAttributes:
       enableMetrics: "true"
       skipCSIBucketAccessCheck: "true"
@@ -37,7 +38,7 @@
 ---
 apiVersion: v1
 kind: PersistentVolumeClaim
 metadata:
-  name: training-test-claim
+  name: bernardhan-gcs-training-test-claim
   namespace: default
 spec:
@@ -45,34 +46,34 @@ spec:
   accessModes:
     - ReadWriteMany
   resources:
     requests:
       storage: 64Gi
-  volumeName: gcs-distributed-training-pv
+  volumeName: bernardhan-gcs-distributed-training-pv
   storageClassName: gcsfuse-sc # dummy storage class
 ---
 apiVersion: jobset.x-k8s.io/v1alpha2
 kind: JobSet
 metadata:
   # Modify this name to distinguish your workload from others.
-  name: xpk-test-workload
+  name: bernardhan-gcs-training-workload
   labels:
-    kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
-    xpk.google.com/workload: xpk-test-workload
+    # kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
+    xpk.google.com/workload: bernardhan-gcs-training-workload
   annotations:
     alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool # 1:1 job replica to node pool assignment
 spec:
   failurePolicy:
     maxRestarts: 0
   replicatedJobs:
-  - name: slice-job
+  - name: benchmark-job
    replicas: 1
    template:
      spec:
-        parallelism: 512 # Equal to the number of VMs per slice
-        completions: 512 # Same as the above.
+        parallelism: 1 # Equal to the number of VMs per slice
+        completions: 1 # Same as the above.
        backoffLimit: 0 # When any pod fails, the job is failed
        template:
          metadata:
            labels:
-              xpk.google.com/workload: xpk-test-workload
+              xpk.google.com/workload: bernardhan-gcs-training-workload
            # Required for GCSFuse.
            # For other storage solutions, please modify this section.
            annotations:
              gke-gcsfuse/volumes: "true"
@@ -83,20 +84,21 @@
          spec:
            initContainers:
-            # Metadata Prefetch native sidecar.
-            # Added to test the GCSfuse - Tuning and best practices for AI/ML workloads:
-            # https://docs.google.com/document/d/1NI64_qfTPBOQBmn_AOUwwFt7XQBQYrCqeLKIeKYkx5w/edit?tab=t.0#bookmark=id.i4mbb8t99ic2
-            - name: metadata-prefetch-container
-              image: ubuntu:22.04
+            # manually inject the gcsfuse sidecar container
+            - args:
+              - --v=5
+              env:
+              - name: NATIVE_SIDECAR
+                value: "TRUE"
+              image: jiaxun/gcs-fuse-csi-driver-sidecar-mounter:v999.999.999
+              imagePullPolicy: IfNotPresent
+              name: gke-gcsfuse-sidecar
+              resources:
+                requests:
+                  cpu: 250m
+                  ephemeral-storage: 5Gi
+                  memory: 256Mi
              restartPolicy: Always
-              command:
-              - "/bin/sh"
-              - "-c"
-              - |
-                echo "Starting ls on the bucket..."
-                # Redirect output to /dev/null to prevent storage of output.
-                echo "Metadata prefetch for /mnt/gcsfuse..." && ls -R /mnt/gcsfuse > /dev/null && echo "Metadata prefetch for /mnt/gcsfuse complete." &
-                tail -f /dev/null
              securityContext:
                allowPrivilegeEscalation: false
                capabilities:
                  drop:
@@ -109,8 +111,12 @@
                seccompProfile:
                  type: RuntimeDefault
              volumeMounts:
-              - name: gcs-pvc
-                mountPath: /mnt/gcsfuse
+              - mountPath: /gcsfuse-tmp
+                name: gke-gcsfuse-tmp
+              - mountPath: /gcsfuse-buffer
+                name: gke-gcsfuse-buffer
+              - mountPath: /gcsfuse-cache
+                name: gke-gcsfuse-cache
 
            schedulerName: default-scheduler
            restartPolicy: Never
@@ -135,10 +141,10 @@
            terminationGracePeriodSeconds: 30
            # For GCSFuse: the setup of K8S SA is needed. https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver#authentication
            # For other storage solutions that do not need the K8S SA, please remove this line.
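+            # The Kubernetes service account below is expected to be bound (for
+            # example via Workload Identity) to a GCP service account that can
+            # read the dataset bucket; see the authentication link above.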
-            serviceAccountName: bernardhan-benchmark
+            serviceAccountName: tess-dataloading-benchmarks
            containers:
            - name: jax-cpu
-              image: gcr.io/gcs-tess/distributed_pytorch_training_benchmark
+              image: gcr.io/gcs-tess/bernardhan_test_framework_yield_random_chars_no_from
 
              env:
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: JOB_INDEX
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index']
              - name: JOB_COMPLETION_INDEX
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
@@ -155,13 +161,29 @@
              # Modify the following two values too, if you intend to run the workload in smaller scale.
              - name: PROCESSES_IN_JOB
-                value: "512"
+                value: "1"
              - name: JAX_PROCESS_COUNT
-                value: "512"
+                value: "1"
              - name: JOBSET_NAME
-                value: "xpk-test-workload"
+                value: "bernardhan-gcs-training-workload"
              - name: JAX_COORDINATOR_ADDRESS
                value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
+              - name: MY_NODE_NAME
+                valueFrom:
+                  fieldRef:
+                    fieldPath: spec.nodeName
+              - name: MY_POD_NAME
+                valueFrom:
+                  fieldRef:
+                    fieldPath: metadata.name
+              - name: MY_POD_IP
+                valueFrom:
+                  fieldRef:
+                    fieldPath: status.podIP
+              - name: MY_NODE_IP
+                valueFrom:
+                  fieldRef:
+                    fieldPath: status.hostIP
 
              ports:
              - containerPort: 8471
@@ -174,20 +196,20 @@
              command:
              - bash
              - -c
              - |
                # Modify the parameters here.
-                export RUN_NAME="YOUR_RUN_NAME"
+                export RUN_NAME=bernardhan-test-framework-09-07-06
                export DATASET_DIRECTORY="/mnt/gcsfuse"
                export EPOCHS=2
-                export MAX_STEPS=-1
-                export LOCAL_BATCH_SIZE=32
-                export PREFETCH_FACTOR=2
-                export DATA_LOADER_NUM_WORKERS=10
-                export PER_STEP_INTERVAL=0.1
+                export MAX_STEPS=100
+                export LOCAL_BATCH_SIZE=256
+                export PREFETCH_FACTOR=5
+                export DATA_LOADER_NUM_WORKERS=2
+                export PER_STEP_INTERVAL=5
                export DATA_LOADER_STRATEGY_NAME="FileParallelSequentialRead"
                export GCS_METRICS_BUCKET="distributed-training-metrics"
-
+                export TARGET_DATA_SAMPLE_SIZE=120
                # Not recommended to modify the flags below.
                export COMMON_RUN_FLAGS="enable_checkpointing=False hardware=cpu";
-                export BENCHMARK_RUN_FLAGS="run_name=${RUN_NAME} dataset_directory=${DATASET_DIRECTORY} epochs=${EPOCHS} max_steps=${MAX_STEPS} local_batch_size=${LOCAL_BATCH_SIZE} prefetch_factor=${PREFETCH_FACTOR} data_loader_num_workers=${DATA_LOADER_NUM_WORKERS} per_step_interval=${PER_STEP_INTERVAL} data_loader_strategy_name=${DATA_LOADER_STRATEGY_NAME} gcs_metrics_bucket=${GCS_METRICS_BUCKET}";
+                export BENCHMARK_RUN_FLAGS="run_name=${RUN_NAME} dataset_directory=${DATASET_DIRECTORY} epochs=${EPOCHS} max_steps=${MAX_STEPS} local_batch_size=${LOCAL_BATCH_SIZE} prefetch_factor=${PREFETCH_FACTOR} data_loader_num_workers=${DATA_LOADER_NUM_WORKERS} per_step_interval=${PER_STEP_INTERVAL} data_loader_strategy_name=${DATA_LOADER_STRATEGY_NAME} gcs_metrics_bucket=${GCS_METRICS_BUCKET} rough_desired_simulated_data_sample_size=${TARGET_DATA_SAMPLE_SIZE}";
                echo XPK Start: $(date) ; _sigterm() ( kill -SIGTERM $! 2>/dev/null;); trap _sigterm SIGTERM;(JAX_PLATFORMS=cpu python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml ${BENCHMARK_RUN_FLAGS} ${COMMON_RUN_FLAGS}) & PID=$!; while kill -0 $PID 2>/dev/null; do sleep 5; done; wait $PID; EXIT_CODE=$? ; echo XPK End: $(date); echo EXIT_CODE=$EXIT_CODE;
              volumeMounts:
              - name: gcs-pvc
                mountPath: /mnt/gcsfuse
                readOnly: true
 
@@ -198,6 +220,13 @@
          volumes:
+          # manually inject gcsfuse related emptyDir
+          - emptyDir: {}
+            name: gke-gcsfuse-tmp
+          - emptyDir: {}
+            name: gke-gcsfuse-buffer
+          - emptyDir: {}
+            name: gke-gcsfuse-cache
          - name: gcs-pvc
            persistentVolumeClaim:
-              claimName: training-test-claim
\ No newline at end of file
+              claimName: bernardhan-gcs-training-test-claim
\ No newline at end of file
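A final back-of-the-envelope check on the simulated sample sizes used in this series: base.yml defaults rough_desired_simulated_data_sample_size to 53000, deployment.yaml overrides it to 120 via TARGET_DATA_SAMPLE_SIZE, and _construct_data_sample clamps any target below 101 up to 101 characters (100 for outputs plus at least 1 for image_base64_str). With LOCAL_BATCH_SIZE=256, the synthetic payload per step works out as in this sketch (plain arithmetic, not MaxText code):

# Synthetic payload per step; one ASCII character is one byte.
LOCAL_BATCH_SIZE = 256

for target in (53_000, 120):  # base.yml default vs. deployment.yaml override
  sample_chars = max(target, 101)             # _construct_data_sample clamps small targets
  per_step = LOCAL_BATCH_SIZE * sample_chars  # characters yielded per training step
  print(f"target={target}: {per_step:,} chars (~{per_step / 1e6:.2f} MB) per step")

That is roughly 13.57 MB of generated characters per step for the default, versus about 0.03 MB per step with the 120-byte override, so with the small override the measured step time should be dominated by the real Parquet reads rather than by sample construction.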