[DRAFT] Add In Memory Changes for Pathways #854

Draft · wants to merge 8 commits into base: main
41 changes: 31 additions & 10 deletions MaxText/checkpointing.py
@@ -27,6 +27,8 @@
import numpy as np
import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.pathways_checkpoint_manager as pw_emergency_checkpoint_manager
from orbax.checkpoint.multihost.utils import is_pathways_on_cloud_backend

# pylint: disable=too-many-positional-arguments

@@ -101,16 +103,34 @@ def create_orbax_emergency_checkpoint_manager(
local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
)
emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
local_checkpoint_dir,
epath.Path(persistent_checkpoint_dir),
global_mesh=global_mesh,
abstract_state=abstract_state,
options=options,
local_state_handler=emergency_checkpoint_manager.local_checkpoint_handler(),
logger=orbax_logger,
max_logging.log(
"Determining if this is a pathways on cloud backend:"
f" {is_pathways_on_cloud_backend()}"
)

if is_pathways_on_cloud_backend():
local_state_handler = pw_emergency_checkpoint_manager.local_checkpoint_handler()
emergency_mngr = pw_emergency_checkpoint_manager.PathwaysCheckpointManager(
local_checkpoint_dir,
epath.Path(persistent_checkpoint_dir),
global_mesh=global_mesh,
abstract_state=abstract_state,
options=options,
local_state_handler=local_state_handler,
logger=orbax_logger,
)
else:
local_state_handler = emergency_checkpoint_manager.local_checkpoint_handler()
emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
local_checkpoint_dir,
epath.Path(persistent_checkpoint_dir),
global_mesh=global_mesh,
abstract_state=abstract_state,
options=options,
local_state_handler=local_state_handler,
logger=orbax_logger,
)

max_logging.log("Emergency checkpoint manager created!")
return emergency_mngr
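Both branches above pass an identical argument list and differ only in the manager class and the module that supplies local_checkpoint_handler(). A sketch (not part of this PR) of how the duplicated constructor call could be collapsed, assuming both managers accept exactly the keywords shown in the diff:

```python
# Sketch only: pick the class and handler module first, then construct once.
if is_pathways_on_cloud_backend():
  manager_module = pw_emergency_checkpoint_manager
  manager_cls = pw_emergency_checkpoint_manager.PathwaysCheckpointManager
else:
  manager_module = emergency_checkpoint_manager
  manager_cls = emergency_checkpoint_manager.CheckpointManager

emergency_mngr = manager_cls(
    local_checkpoint_dir,
    epath.Path(persistent_checkpoint_dir),
    global_mesh=global_mesh,
    abstract_state=abstract_state,
    options=options,
    local_state_handler=manager_module.local_checkpoint_handler(),
    logger=orbax_logger,
)
```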

@@ -184,7 +204,7 @@ def load_state_if_possible(
def map_to_pspec(data):
pspec = data.sharding.spec
mesh = data.sharding.mesh
if not enable_single_replica_ckpt_restoring:
if not enable_single_replica_ckpt_restoring or isinstance(checkpoint_manager, pw_emergency_checkpoint_manager.PathwaysCheckpointManager):
return ocp.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
@@ -197,6 +217,7 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

# TODO: These restore args may not be correct for Pathways.
return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
single_replica_sharding=single_replica_sharding,
@@ -209,7 +230,7 @@ def map_to_pspec(data):
abstract_unboxed_pre_state,
)

if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
if isinstance(checkpoint_manager, (emergency_checkpoint_manager.CheckpointManager, pw_emergency_checkpoint_manager.PathwaysCheckpointManager)):
return (
checkpoint_manager.restore(
latest_step,
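The combined check for the standard and Pathways emergency managers appears three times in this PR (load_state_if_possible above, setup_initial_state in MaxText/max_utils.py, and save_checkpoint in MaxText/train.py). A sketch of a shared predicate those call sites could use instead — the helper name is hypothetical and not part of the change:

```python
# Hypothetical helper (not in the PR): one place to answer "is this one of the
# emergency checkpoint managers?", covering both the standard and Pathways variants.
_EMERGENCY_MANAGER_TYPES = (
    emergency_checkpoint_manager.CheckpointManager,
    pw_emergency_checkpoint_manager.PathwaysCheckpointManager,
)


def is_emergency_checkpoint_manager(manager) -> bool:
  """True if `manager` is a standard or Pathways emergency CheckpointManager."""
  return isinstance(manager, _EMERGENCY_MANAGER_TYPES)
```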
438 changes: 438 additions & 0 deletions MaxText/configs/pathways_in_memory.yml

Large diffs are not rendered by default.

441 changes: 441 additions & 0 deletions MaxText/configs/pathways_in_memory.yml.orig

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions MaxText/configs/v5e/1b.sh
@@ -0,0 +1,54 @@
echo "Running 1b.sh"
# 1B parameter model.
# This config will work out of the box for any number of v5e-256 slices.
#
# Command Flags:
# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
#
# Example to invoke this script:
# bash MaxText/configs/v5e/1b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
#
# Example to AOT compile:
# bash MaxText/configs/v5e/1b.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2


# Stop execution if any command exits with error
set -e

export EXECUTABLE="train.py" # or train_compile.py
export RUN_PREFLIGHT="true"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

# The setup accommodates two cases:
# 1) Passing the 'RUN_NAME' variable at runtime
# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
if [ -n "$RUN_NAME" ];
then
export M_RUN_NAME=$RUN_NAME
fi

# Set up network optimizations
# if [ "$RUN_PREFLIGHT" = "true" ]; then
# bash preflight.sh
# fi

# Train using default global_parameter_scale
export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
python3 MaxText/$EXECUTABLE MaxText/configs/pathways_in_memory.yml\
  steps=10000 per_device_batch_size=1\
  remat_policy=full\
  max_target_length=2048 base_output_directory=$OUTPUT_PATH\
  dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
  dataset_type=synthetic attention='flash' gcs_metrics=true\
  enable_checkpointing=true\
  enable_single_controller=true\
  async_checkpointing=false\
  checkpoint_period=3
  # global_parameter_scale=8
49 changes: 49 additions & 0 deletions MaxText/configs/v5e/8b.sh.orig
@@ -0,0 +1,49 @@
echo "Running 16b.sh"
# 16B parameter model.
# This config will work out of the box for any number of v5e-256 slices.
#
# Command Flags:
# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
#
# Example to invoke this script:
# bash MaxText/configs/v5e/16b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
#
# Example to AOT compile:
# bash MaxText/configs/v5e/16b.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2


# Stop execution if any command exits with error
set -e

export EXECUTABLE="train.py" # or train_compile.py
export RUN_PREFLIGHT="true"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

# The setup accommodates two cases:
# 1) Passing the 'RUN_NAME' variable at runtime
# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
if [ -n "$RUN_NAME" ];
then
export M_RUN_NAME=$RUN_NAME
fi

# Set up network optimizations
if [ "$RUN_PREFLIGHT" = "true" ]; then
bash preflight.sh
fi

# Train
export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
python3 MaxText/$EXECUTABLE MaxText/configs/base.yml\
  steps=15 per_device_batch_size=6 enable_checkpointing=false\
  remat_policy=full global_parameter_scale=16\
  max_target_length=2048 base_output_directory=$OUTPUT_PATH\
  dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
  dataset_type=synthetic attention='flash' gcs_metrics=true
3 changes: 2 additions & 1 deletion MaxText/max_utils.py
@@ -31,6 +31,7 @@
from jax.experimental import mesh_utils
import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.pathways_checkpoint_manager as pw_emergency_checkpoint_manager


import json
@@ -556,7 +557,7 @@ def setup_initial_state(
)

if restored:
if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
if isinstance(checkpoint_manager, (emergency_checkpoint_manager.CheckpointManager, pw_emergency_checkpoint_manager.PathwaysCheckpointManager)):
state = restored
else:
if "iter" in restored and restored["iter"] is not None:
5 changes: 4 additions & 1 deletion MaxText/train.py
@@ -1,3 +1,5 @@
import pathwaysutils

"""
Copyright 2023 Google LLC

@@ -35,6 +37,7 @@
import numpy as np
import orbax.checkpoint
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.pathways_checkpoint_manager as pw_emergency_checkpoint_manager

import checkpointing
import max_utils
@@ -210,7 +213,7 @@ def save_checkpoint(
# specify chunk_byte_size to force orbax to control maximum file size in checkpoint
save_args = jax.tree.map(lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state)

if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
if isinstance(checkpoint_manager, (emergency_checkpoint_manager.CheckpointManager, pw_emergency_checkpoint_manager.PathwaysCheckpointManager)):
return checkpoint_manager.save(
step,
args=orbax.checkpoint.args.PyTreeSave(item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE),
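The first hunk in MaxText/train.py places `import pathwaysutils` above the license string, which turns that string into an ordinary expression rather than the module docstring. A sketch of a more conventional layout, assuming the import only needs to happen (for its side effects) before JAX is used, not before the docstring:

```python
"""
Copyright 2023 Google LLC
...
"""

# Imported early for its side effects on the Pathways backend (assumption: import-time
# registration is all that is required here, as in the current diff).
import pathwaysutils  # pylint: disable=unused-import
```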
10 changes: 6 additions & 4 deletions docker_build_dependency_image.sh
@@ -25,7 +25,7 @@ set -e

export LOCAL_IMAGE_NAME=maxtext_base_image

# Use Docker BuildKit so we can cache pip packages.
export DOCKER_BUILDKIT=1

echo "Starting to build your docker image. This will take a few minutes but the image can be reused as you iterate."
Expand Down Expand Up @@ -62,7 +62,7 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --no-cache --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
if [[ ${MODE} == "stable_stack" ]]; then
if [[ ! -v BASEIMAGE ]]; then
Expand All @@ -86,8 +86,8 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
fi
fi
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --no-cache --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
docker build --no-cache --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} .
fi

echo ""
@@ -102,3 +102,5 @@ echo ""
echo "You can run MaxText and your development tests inside of the docker image. Changes to your workspace will automatically
be reflected inside the docker container."
echo "Once you want you upload your docker container to GCR, take a look at docker_upload_runner.sh"


10 changes: 7 additions & 3 deletions requirements.txt
@@ -1,9 +1,12 @@
jax>=0.4.30
jaxlib>=0.4.30
--pre
jax>=0.4.34.dev20240922
jaxlib>=0.4.34.dev20240922
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html

orbax-checkpoint>=0.5.21
absl-py
array-record
aqtp
aqtp==0.7.5
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
@@ -33,3 +36,4 @@ transformers
mlperf-logging@git+https://github.com/mlperf/logging.git
google-jetstream
jsonlines
pathwaysutils@git+https://github.com/google/pathways-utils.git@test_675739460
2 changes: 2 additions & 0 deletions rto_setup.sh
@@ -1,3 +1,5 @@
echo "Skipping rto_setup.sh"; exit 0

echo "Running rto_setup.sh"

# Stop execution if any command exits with error
9 changes: 9 additions & 0 deletions setup.sh
@@ -191,3 +191,12 @@ if [[ "$MODE" == "pinned" ]]; then
else
pip3 install -U -r requirements.txt
fi

#################################
# Changes needed to test in-memory checkpointing
#################################

# Add dev orbax from github changes
yes | pip3 uninstall orbax-checkpoint
pip3 install git+https://github.com/google/orbax.git@test_675739460#subdirectory=checkpoint
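A quick, optional sanity check (not part of this change) that the dev orbax build installed above actually exposes the Pathways emergency checkpoint manager that MaxText/checkpointing.py now imports:

```python
# Sanity-check sketch: confirm the installed orbax-checkpoint provides the module and
# class referenced by this PR. Module path and class name are taken from the imports
# added in MaxText/checkpointing.py.
from importlib import import_module, metadata

print("orbax-checkpoint", metadata.version("orbax-checkpoint"))
pw_mod = import_module("orbax.checkpoint.experimental.emergency.pathways_checkpoint_manager")
print("PathwaysCheckpointManager available:", hasattr(pw_mod, "PathwaysCheckpointManager"))
```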
