Skip to content

Commit

Permalink
Refactoring Maxtext build process with stable stack
Browse files Browse the repository at this point in the history
  • Loading branch information
parambole committed Sep 19, 2024
1 parent 46d704a commit 2cc8166
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 94 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_nightly MODE=nightly DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_nightly
- name: build jax stable stack image
run : |
bash docker_maxtext_jax_stable_stack_image_upload.sh PROJECT_ID=tpu-prod-env-multipod BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.30-rev1 CLOUD_IMAGE_NAME=maxtext-jax-stable-stack IMAGE_TAG=jax0.4.30-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt DELETE_LOCAL_IMAGE=true
gpu:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable_stack_0.4.33 MODE=stable_stack DEVICE=TPU PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable_stack_0.4.33 BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu:jax0.4.33-rev1 MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
strategy:
fail-fast: false
matrix:
Expand Down
23 changes: 22 additions & 1 deletion docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# Example command:
# bash docker_build_dependency_image.sh MODE=stable
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}} MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
# bash docker_build_dependency_image.sh MODE=nightly
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13

Expand Down Expand Up @@ -63,7 +64,27 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
if [[ ${MODE} == "stable_stack" ]]; then
if [[ ! -v BASEIMAGE ]]; then
echo "Erroring out because BASEIMAGE is unset, please set it!"
exit 1
fi
if [[ ! -v MAXTEXT_REQUIREMENTS_FILE ]]; then
echo "Erroring out because MAXTEXT_REQUIREMENTS_FILE is unset, please set it!"
exit 1
fi
COMMIT_HASH=$(git rev-parse --short HEAD)
echo "Building JAX Stable Stack MaxText at commit hash ${COMMIT_HASH} . . ."
docker build --no-cache \
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
--build-arg COMMIT_HASH=${COMMIT_HASH} \
--build-arg MAXTEXT_REQUIREMENTS_FILE=${MAXTEXT_REQUIREMENTS_FILE} \
--network=host \
-t ${LOCAL_IMAGE_NAME} \
-f ./maxtext_jax_stable_stack_tpu.Dockerfile .
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
fi
fi
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
Expand Down
90 changes: 0 additions & 90 deletions docker_maxtext_jax_stable_stack_image_upload.sh

This file was deleted.

2 changes: 1 addition & 1 deletion maxtext_jax_stable_stack_tpu.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ RUN if [ ! -z "${MAXTEXT_REQUIREMENTS_FILE}" ]; then \
fi

# Run the script available in JAX Stable Stack base image to generate the manifest file
RUN bash /generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH

0 comments on commit 2cc8166

Please sign in to comment.