test code to produce Lab Notes - 2024-09-07.ipynb #866

Draft · wants to merge 3 commits into base: gcs-distributed-training-benchmark
70 changes: 0 additions & 70 deletions .vscode/launch.json

This file was deleted.

3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -375,3 +375,6 @@ gcs_metrics_bucket: "distributed-training-metrics"
# The strategy used for reading the dataset. The options for this field are specified by the keys of
# `DATA_LOADER_STRATEGIES_BY_NAME` in MaxText/standalone_dataloader.py.
data_loader_strategy_name: FileParallelSequentialRead
pin_memory: True
persistent_workers: True
rough_desired_simulated_data_sample_size: 53000
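
For context on how the three new fields are meant to be consumed: `pin_memory` and `persistent_workers` are passed straight to the PyTorch `DataLoader` in `MaxText/standalone_dataloader.py` (see the diff below), and `rough_desired_simulated_data_sample_size` presumably sets a rough per-sample payload size for the simulated dataset. A minimal, self-contained sketch of that plumbing — `SimulatedParquetDataset` and `build_loader` are hypothetical stand-ins, not code from this PR:

```python
from torch.utils.data import DataLoader, IterableDataset


class SimulatedParquetDataset(IterableDataset):
    """Yields synthetic records padded toward a rough per-sample byte size."""

    def __init__(self, sample_size: int, num_samples: int = 1024):
        self.sample_size = sample_size
        self.num_samples = num_samples

    def __iter__(self):
        for i in range(self.num_samples):
            # Pad the image field so each record is roughly `sample_size` bytes.
            yield {"outputs": f"sample-{i}", "image_base64_str": "x" * self.sample_size}


def build_loader(cfg: dict) -> DataLoader:
    dataset = SimulatedParquetDataset(cfg.get("rough_desired_simulated_data_sample_size", 53000))
    return DataLoader(
        dataset,
        batch_size=cfg.get("local_batch_size", 32),
        num_workers=cfg.get("data_loader_num_workers", 2),
        prefetch_factor=cfg.get("prefetch_factor", 2),
        pin_memory=cfg.get("pin_memory", True),                   # new base.yml field
        persistent_workers=cfg.get("persistent_workers", True),   # new base.yml field
    )
```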
55 changes: 55 additions & 0 deletions MaxText/configs/cpu.sh
@@ -0,0 +1,55 @@
echo "Running cpu.sh"
# Parameterized-size model running on CPU (parameter count set via PARAMETERS, default 64B).
# This config will work out of the box for any number of v5e-256 slices.
#
# Command Flags:
# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
# PLATFORM (Optional, can be "gke" or "gce", default is "gce")
#
# Example to invoke this script:
# bash MaxText/configs/cpu.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>" PLATFORM="gke"
#
# Example to AOT compile:
# bash MaxText/configs/cpu.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2


# Stop execution if any command exits with error
set -e

export PLATFORM="gce"
export EXECUTABLE="train.py" # or train_compile.py

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

# Use 64b parameters if not set.
PARAMETERS="${PARAMETERS:-64}"
echo "Using ${PARAMETERS}b parameters"

# The setup accommodates two cases:
# 1) Passing the 'RUN_NAME' variable at runtime
# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
if [ -n "$RUN_NAME" ];
then
export M_RUN_NAME=$RUN_NAME
fi

# Set up network optimizations
bash preflight.sh PLATFORM=$PLATFORM

# Train
export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
JAX_PLATFORMS=cpu python3 MaxText/$EXECUTABLE MaxText/configs/base.yml \
  steps=$STEPS checkpoint_period=$CHECKPOINT_PERIOD per_device_batch_size=1 enable_checkpointing=true \
  async_checkpointing=false \
  remat_policy=full global_parameter_scale=$PARAMETERS \
  max_target_length=2048 base_output_directory=$OUTPUT_PATH \
  hardware=$HARDWARE \
  use_iota_embed=true reuse_example_batch=1 \
  dataset_type=synthetic attention='flash' gcs_metrics=true \
  load_full_state_path=$PREVIOUS_STATE
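
The script forwards everything after `base.yml` as `key=value` overrides and relies on `M_`-prefixed environment variables (e.g. `M_RUN_NAME`) for values set outside the command line. A rough illustration of that override pattern — a simplified stand-in, not MaxText's actual `pyconfig` logic:

```python
# Simplified sketch: base YAML config, then key=value CLI overrides,
# then M_-prefixed environment variables on top.
import os
import sys
import yaml


def load_config(argv):
    base_path, *overrides = argv
    with open(base_path) as f:
        cfg = yaml.safe_load(f)
    # key=value pairs on the command line override base.yml values.
    for item in overrides:
        key, value = item.split("=", 1)
        cfg[key] = value
    # Environment variables prefixed with M_ override both.
    for env_key, value in os.environ.items():
        if env_key.startswith("M_"):
            cfg[env_key[len("M_"):].lower()] = value
    return cfg


if __name__ == "__main__":
    config = load_config(sys.argv[1:])
    print(config.get("run_name"), config.get("steps"))
```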
117 changes: 73 additions & 44 deletions MaxText/deployment.yaml
@@ -3,7 +3,7 @@
apiVersion: v1
kind: PersistentVolume
metadata:
name: gcs-distributed-training-pv
name: bernardhan-gcs-distributed-training-pv
spec:
accessModes:
- ReadWriteMany
@@ -12,7 +12,7 @@ spec:
persistentVolumeReclaimPolicy: Retain
storageClassName: gcsfuse-sc # dummy storage class
claimRef:
name: training-test-claim
name: bernardhan-gcs-training-test-claim
# Running in the "default" namespace so you can submit to the local queue
# created in the default namespace.
namespace: default
@@ -24,55 +24,56 @@ spec:
- metadata-cache:stat-cache-max-size-mb:-1
- metadata-cache:type-cache-max-size-mb:-1
- file-system:kernel-list-cache-ttl-secs:-1
- file-cache:max-size-mb:-1
- file-cache:cache-file-for-range-read:false
- file-cache:enable-parallel-downloads:false
# DISABLES CACHE
# - file-cache:max-size-mb:-1
# - file-cache:cache-file-for-range-read:true
# - file-cache:enable-parallel-downloads:true
csi:
driver: gcsfuse.csi.storage.gke.io
volumeHandle: xai-hf-dataset-parquet # unique bucket name. xai-hf-dataset-parquet-10g for the 10G * 12K dataset.
volumeHandle: xai-hf-dataset-parquet-10g # unique bucket name. xai-hf-dataset-parquet-10g for the 10G * 12K dataset.
volumeAttributes:
enableMetrics: "true"
skipCSIBucketAccessCheck: "true"
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: training-test-claim
name: bernardhan-gcs-training-test-claim
namespace: default
spec:
accessModes:
- ReadWriteMany
resources:
requests:
storage: 64Gi
volumeName: gcs-distributed-training-pv
volumeName: bernardhan-gcs-distributed-training-pv
storageClassName: gcsfuse-sc # dummy storage class
---
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
# Modify this name to distinguish your workload from others.
name: xpk-test-workload
name: bernardhan-gcs-training-workload
labels:
kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
xpk.google.com/workload: xpk-test-workload
# kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
xpk.google.com/workload: bernardhan-gcs-training-workload
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool # 1:1 job replica to node pool assignment
spec:
failurePolicy:
maxRestarts: 0
replicatedJobs:
- name: slice-job
- name: benchmark-job
replicas: 1
template:
spec:
parallelism: 512 # Equal to the number of VMs per slice
completions: 512 # Same as the above.
parallelism: 1 # Equal to the number of VMs per slice
completions: 1 # Same as the above.
backoffLimit: 0 # When any pod fails, the job is failed
template:
metadata:
labels:
xpk.google.com/workload: xpk-test-workload
xpk.google.com/workload: bernardhan-gcs-training-workload
# Required for GCSFuse.
# For other storage solutions, please modify this section.
annotations:
@@ -83,20 +84,21 @@ spec:

spec:
initContainers:
# Metadata Prefetch native sidecar.
# Added to test the GCSfuse - Tuning and best practices for AI/ML workloads:
# https://docs.google.com/document/d/1NI64_qfTPBOQBmn_AOUwwFt7XQBQYrCqeLKIeKYkx5w/edit?tab=t.0#bookmark=id.i4mbb8t99ic2
- name: metadata-prefetch-container
image: ubuntu:22.04
# manually inject the gcsfuse sidecar container
- args:
- --v=5
env:
- name: NATIVE_SIDECAR
value: "TRUE"
image: jiaxun/gcs-fuse-csi-driver-sidecar-mounter:v999.999.999
imagePullPolicy: IfNotPresent
name: gke-gcsfuse-sidecar
resources:
requests:
cpu: 250m
ephemeral-storage: 5Gi
memory: 256Mi
restartPolicy: Always
command:
- "/bin/sh"
- "-c"
- |
echo "Starting ls on the bucket..."
# Redirect output to /dev/null to prevent storage of output.
echo "Metadata prefetch for /mnt/gcsfuse..." && ls -R /mnt/gcsfuse > /dev/null && echo "Metadata prefetch for /mnt/gcsfuse complete." &
tail -f /dev/null
securityContext:
allowPrivilegeEscalation: false
capabilities:
@@ -109,8 +111,12 @@ spec:
seccompProfile:
type: RuntimeDefault
volumeMounts:
- name: gcs-pvc
mountPath: /mnt/gcsfuse
- mountPath: /gcsfuse-tmp
name: gke-gcsfuse-tmp
- mountPath: /gcsfuse-buffer
name: gke-gcsfuse-buffer
- mountPath: /gcsfuse-cache
name: gke-gcsfuse-cache

schedulerName: default-scheduler
restartPolicy: Never
@@ -135,10 +141,10 @@ spec:
terminationGracePeriodSeconds: 30
# For GCSFuse: the setup of K8S SA is needed. https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver#authentication
# For other storage solutions that do not need the K8S SA, please remove this line.
serviceAccountName: bernardhan-benchmark
serviceAccountName: tess-dataloading-benchmarks
containers:
- name: jax-cpu
image: gcr.io/gcs-tess/distributed_pytorch_training_benchmark
image: gcr.io/gcs-tess/bernardhan_test_framework_yield_random_chars_no_from

env:
- name: REPLICATED_JOB_NAME
@@ -155,13 +161,29 @@ spec:
fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
# Modify the following two values too, if you intend to run the workload in smaller scale.
- name: PROCESSES_IN_JOB
value: "512"
value: "1"
- name: JAX_PROCESS_COUNT
value: "512"
value: "1"
- name: JOBSET_NAME
value: "xpk-test-workload"
value: "bernardhan-gcs-training-workload"
- name: JAX_COORDINATOR_ADDRESS
value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
- name: MY_NODE_NAME
valueFrom:
fieldRef:
fieldPath: spec.nodeName
- name: MY_POD_NAME
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: MY_POD_IP
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: MY_NODE_IP
valueFrom:
fieldRef:
fieldPath: status.hostIP

ports:
- containerPort: 8471
@@ -174,20 +196,20 @@ spec:
- -c
- |
# Modify the parameters here.
export RUN_NAME="YOUR_RUN_NAME"
export RUN_NAME=bernardhan-test-framework-09-07-06
export DATASET_DIRECTORY="/mnt/gcsfuse"
export EPOCHS=2
export MAX_STEPS=-1
export LOCAL_BATCH_SIZE=32
export PREFETCH_FACTOR=2
export DATA_LOADER_NUM_WORKERS=10
export PER_STEP_INTERVAL=0.1
export MAX_STEPS=100
export LOCAL_BATCH_SIZE=256
export PREFETCH_FACTOR=5
export DATA_LOADER_NUM_WORKERS=2
export PER_STEP_INTERVAL=5
export DATA_LOADER_STRATEGY_NAME="FileParallelSequentialRead"
export GCS_METRICS_BUCKET="distributed-training-metrics"

export TARGET_DATA_SAMPLE_SIZE=120
# Not recommended to modify the flags below.
export COMMON_RUN_FLAGS="enable_checkpointing=False hardware=cpu";
export BENCHMARK_RUN_FLAGS="run_name=${RUN_NAME} dataset_directory=${DATASET_DIRECTORY} epochs=${EPOCHS} max_steps=${MAX_STEPS} local_batch_size=${LOCAL_BATCH_SIZE} prefetch_factor=${PREFETCH_FACTOR} data_loader_num_workers=${DATA_LOADER_NUM_WORKERS} per_step_interval=${PER_STEP_INTERVAL} data_loader_strategy_name=${DATA_LOADER_STRATEGY_NAME} gcs_metrics_bucket=${GCS_METRICS_BUCKET}";
export BENCHMARK_RUN_FLAGS="run_name=${RUN_NAME} dataset_directory=${DATASET_DIRECTORY} epochs=${EPOCHS} max_steps=${MAX_STEPS} local_batch_size=${LOCAL_BATCH_SIZE} prefetch_factor=${PREFETCH_FACTOR} data_loader_num_workers=${DATA_LOADER_NUM_WORKERS} per_step_interval=${PER_STEP_INTERVAL} data_loader_strategy_name=${DATA_LOADER_STRATEGY_NAME} gcs_metrics_bucket=${GCS_METRICS_BUCKET} rough_desired_simulated_data_sample_size=${TARGET_DATA_SAMPLE_SIZE}";
echo XPK Start: $(date) ; _sigterm() ( kill -SIGTERM $! 2>/dev/null;); trap _sigterm SIGTERM;(JAX_PLATFORMS=cpu python3 MaxText/standalone_dataloader.py MaxText/configs/base.yml ${BENCHMARK_RUN_FLAGS} ${COMMON_RUN_FLAGS}) & PID=$!; while kill -0 $PID 2>/dev/null; do sleep 5; done; wait $PID; EXIT_CODE=$? ; echo XPK End: $(date); echo EXIT_CODE=$EXIT_CODE;

volumeMounts:
@@ -198,6 +220,13 @@ spec:
readOnly: true

volumes:
# manually inject gcsfuse related emptyDir
- emptyDir: {}
name: gke-gcsfuse-tmp
- emptyDir: {}
name: gke-gcsfuse-buffer
- emptyDir: {}
name: gke-gcsfuse-cache
- name: gcs-pvc
persistentVolumeClaim:
claimName: training-test-claim
claimName: bernardhan-gcs-training-test-claim
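
The `PROCESSES_IN_JOB`, `JAX_PROCESS_COUNT`, `JOBSET_NAME`, and `JAX_COORDINATOR_ADDRESS` variables exist so each pod can join a single JAX multi-process runtime. Below is a sketch of how such variables are typically consumed — an assumption for illustration, not the wiring MaxText actually uses; the variable name holding the job-completion index is hypothetical here:

```python
# Illustrative only: bring up JAX's multi-process runtime from the JobSet
# environment variables defined in the deployment above.
import os
import jax


def init_distributed():
    coordinator = os.environ["JAX_COORDINATOR_ADDRESS"]   # "<jobset>-<job>-0-0.<jobset>"
    num_processes = int(os.environ["JAX_PROCESS_COUNT"])  # 1 in this deployment
    # The pod's batch.kubernetes.io/job-completion-index annotation identifies
    # this process; the env var name used for it is an assumption.
    process_id = int(os.environ.get("JOB_COMPLETION_INDEX", "0"))
    jax.distributed.initialize(
        coordinator_address=f"{coordinator}:8471",  # port matches containerPort above
        num_processes=num_processes,
        process_id=process_id,
    )
    print(f"process {jax.process_index()} of {jax.process_count()} initialized")
```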
8 changes: 7 additions & 1 deletion MaxText/standalone_dataloader.py
@@ -136,12 +136,16 @@ def parquet_data_loader(config):
allocated_parquet_files=sublists[worker_id],
batch_size=batch_size,
columns=["outputs", "image_base64_str"],
config=config,
)
data_loader = DataLoader(
dataset=dataset,
num_workers=config.data_loader_num_workers,
batch_size=batch_size,
# batch_size=1,
prefetch_factor=config.prefetch_factor,
pin_memory=config.pin_memory,
persistent_workers=config.persistent_workers,
)
return data_loader

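For reference, a hypothetical sketch of what a `FileParallelSequentialRead`-style dataset could look like: each DataLoader worker takes a slice of the allocated parquet files and reads them sequentially. The class below is illustrative only and is not the implementation behind `DATA_LOADER_STRATEGIES_BY_NAME`:

```python
# Hypothetical FileParallelSequentialRead-style dataset: files are split
# across DataLoader workers; each worker reads its files one after another.
import pyarrow.parquet as pq
from torch.utils.data import IterableDataset, get_worker_info


class FileParallelSequentialDataset(IterableDataset):
    def __init__(self, allocated_parquet_files, columns):
        self.files = allocated_parquet_files
        self.columns = columns

    def __iter__(self):
        info = get_worker_info()
        worker_id = info.id if info else 0
        num_workers = info.num_workers if info else 1
        # Each worker handles every num_workers-th file, reading sequentially.
        for path in self.files[worker_id::num_workers]:
            table = pq.read_table(path, columns=self.columns)
            for row in table.to_pylist():
                yield row
```
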
@@ -206,7 +210,9 @@ def data_load_loop(config):
local_steps = 0
step_data_loading_start = datetime.datetime.now()
step_start = datetime.datetime.now()
for _ in data_loader:
# for outputs, image_base64_strs in data_loader:
for batch in data_loader:
max_logging.log(f"STANDALONE DATALOADER : Obtained {len(batch['outputs'])} outputs and {len(batch['image_base64_str'])} images.")
step_data_loading_end = datetime.datetime.now()
data_loading_interval = (step_data_loading_end - step_data_loading_start).total_seconds()
max_logging.log(
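
The hunk above is truncated, so here is a condensed sketch of the measurement pattern in `data_load_loop`: time how long each batch takes to arrive, log it, then sleep `per_step_interval` seconds as a stand-in for a training step. Variable names follow the PR; the simplified signature and `print` logging are assumptions:

```python
# Condensed sketch of the benchmark loop, not the PR's full implementation.
import datetime
import time


def data_load_loop(data_loader, per_step_interval=5.0, max_steps=100):
    step_data_loading_start = datetime.datetime.now()
    for step, batch in enumerate(data_loader):
        data_loading_interval = (datetime.datetime.now() - step_data_loading_start).total_seconds()
        print(f"step {step}: loaded {len(batch['outputs'])} samples in {data_loading_interval:.3f}s")
        time.sleep(per_step_interval)  # stand-in for the training step
        if max_steps > 0 and step + 1 >= max_steps:
            break
        step_data_loading_start = datetime.datetime.now()
```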