Skip to content

Commit

Permalink
Merge pull request #547 from google:yiinho-prebuilt-te
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 624316618
  • Loading branch information
maxtext authors committed Apr 12, 2024
2 parents e947d62 + dd8fef5 commit ebd39aa
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 28 deletions.
31 changes: 20 additions & 11 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ jobs:
- name: Test int8_training
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false'
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false'
- name: Test fp8_training
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
Expand All @@ -123,47 +123,56 @@ jobs:
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
'python3 pedagogical_examples/shmap_collective_matmul.py'
# IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job
gpu:
strategy:
fail-fast: false
matrix:
device-type: ["a100-40gb-4"]
name: "GPU test (${{ matrix.device-type }})"
build-mode: ["pinned"]
name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})"
runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"]
env:
LOCAL_IMAGE_NAME: "maxtext_base_image_${{ matrix.build-mode }}_${{ github.sha }}"
steps:
- uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Cleanup old docker images
run: |
docker system prune --all --force
- name: Install dependencies
run: |
bash docker_build_dependency_image.sh DEVICE=gpu
bash docker_build_dependency_image.sh DEVICE=gpu MODE=${{ matrix.build-mode }} LOCAL_IMAGE_NAME="$LOCAL_IMAGE_NAME"
- name: Test gsutil installation
run: |
docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;}'
- name: Test with pytest
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest -m "not tpu"'
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c 'cd MaxText;python3 -m pytest -m "not tpu"'
- name: Test train.py
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=dot_product'
- name: Test train.py with flash attention
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=cudnn_flash_te'
- name: Test train.py with per_device_batch_size < 1
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=dot_product'
- name: Test int8_training
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=dot_product'
- name: Test decode.py
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1'
- name: Test decode.py with per_device_batch_size < 1
run: |
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \
docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25'
5 changes: 4 additions & 1 deletion .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,7 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable MODE=stable DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_stable
- name: build jax nightly image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_nightly MODE=nightly DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_nightly
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_nightly MODE=nightly DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_nightly
- name: build jax pinned image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_pinned MODE=pinned DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_pinned
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,6 @@ dmypy.json

# DS_Store files
**/.DS_Store

# Wheel build
*.whl
155 changes: 155 additions & 0 deletions constraints_gpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
absl-py==1.4.0
aqtp==0.6.1
array-record==0.5.0
astroid==3.1.0
astunparse==1.6.3
attrs==23.2.0
cachetools==5.3.3
certifi==2024.2.2
charset-normalizer==3.3.2
chex==0.1.85
click==8.1.7
cloud-accelerator-diagnostics==0.1.0
cloud-tpu-diagnostics==0.1.5
cloudpickle==3.0.0
clu==0.0.12
contextlib2==21.6.0
dill==0.3.8
dm-tree==0.1.8
docstring_parser==0.16
editdistance==0.8.1
etils==1.7.0
exceptiongroup==1.2.0
flatbuffers==24.3.7
flax==0.8.1
fsspec==2024.2.0
gast==0.4.0
google-api-core==2.17.1
google-auth==2.28.2
google-auth-oauthlib==1.0.0
google-cloud-aiplatform==1.47.0
google-cloud-appengine-logging==1.4.3
google-cloud-audit-log==0.2.5
google-cloud-bigquery==3.20.1
google-cloud-core==2.4.1
google-cloud-logging==3.10.0
google-cloud-resource-manager==1.12.3
google-cloud-storage==2.15.0
google-crc32c==1.5.0
google-jetstream==0.2.0
google-pasta==0.2.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
grain-nightly==0.0.6
grpc-google-iam-v1==0.13.0
grpcio==1.62.1
grpcio-status==1.48.2
gviz-api==1.10.0
h5py==3.10.0
idna==3.6
immutabledict==4.2.0
importlab==0.8.1
importlib_resources==6.3.0
iniconfig==2.0.0
isort==5.13.2
jax==0.4.25
jaxlib==0.4.25
jaxtyping==0.2.28
Jinja2==3.1.3
keras==2.13.1
libclang==16.0.6
libcst==1.2.0
Markdown==3.5.2
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mccabe==0.7.0
mdurl==0.1.2
ml-collections==0.1.1
ml-dtypes==0.3.2
ml_goodput_measurement==0.0.2
mlperf-logging==3.0.0
more-itertools==10.2.0
msgpack==1.0.8
msgspec==0.18.6
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx==3.1
ninja==1.11.1.1
numpy==1.24.3
nvidia-cublas-cu12==12.4.2.65
nvidia-cuda-cupti-cu12==12.4.99
nvidia-cuda-nvcc-cu12==12.4.99
nvidia-cuda-nvrtc-cu12==12.4.99
nvidia-cuda-runtime-cu12==12.4.99
nvidia-cudnn-cu12==8.9.7.29
nvidia-cufft-cu12==11.2.0.44
nvidia-cusolver-cu12==11.6.0.99
nvidia-cusparse-cu12==12.3.0.142
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
oauthlib==3.2.2
opt-einsum==3.3.0
optax==0.2.1
orbax-checkpoint==0.5.5
packaging==24.0
pandas==2.2.1
platformdirs==4.2.0
pluggy==1.4.0
portpicker==1.6.0
promise==2.3
proto-plus==1.23.0
protobuf==3.20.3
psutil==5.9.8
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycnite==2023.10.11
pydantic==1.10.14
pydot==2.0.0
pyglove==0.4.4
Pygments==2.17.2
pylint==3.1.0
pyparsing==3.1.2
pytest==8.1.1
python-dateutil==2.9.0.post0
pytype==2024.3.11
pytz==2024.1
PyYAML==6.0.1
requests==2.31.0
requests-oauthlib==1.4.0
rich==13.7.1
rsa==4.9
scipy==1.12.0
sentencepiece==0.1.97
seqio==0.0.19
shapely==2.0.3
six==1.16.0
tabulate==0.9.0
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorboard_plugin_profile==2.15.1
tensorboardX==2.6.2.2
tensorflow==2.13.1
tensorflow-datasets==4.9.4
tensorflow-estimator==2.13.0
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-metadata==1.14.0
tensorflow-text==2.13.0
tensorstore==0.1.54
termcolor==2.4.0
tf-keras==2.15.0
tfds-nightly==4.9.2.dev202308090034
toml==0.10.2
tomli==2.0.1
tomlkit==0.12.4
toolz==0.12.1
tqdm==4.66.2
transformer-engine==1.5.0+297459b
typeguard==2.13.3
typing-inspect==0.9.0
typing_extensions==4.5.0
tzdata==2024.1
urllib3==2.2.1
Werkzeug==3.0.1
wrapt==1.16.0
zipp==3.18.0
13 changes: 10 additions & 3 deletions docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Example command:
# Example command:
# bash docker_build_dependency_image.sh MODE=stable
# bash docker_build_dependency_image.sh MODE=nightly
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
Expand All @@ -24,6 +24,9 @@ set -e

export LOCAL_IMAGE_NAME=maxtext_base_image

# Use Docker BuildKit so we can cache pip packages.
export DOCKER_BUILDKIT=1

echo "Starting to build your docker image. This will take a few minutes but the image can be reused as you iterate."

# Set environment variables
Expand All @@ -42,7 +45,6 @@ fi
if [[ -z ${MODE} ]]; then
export MODE=stable
echo "Default MODE=${MODE}"

fi

if [[ -z ${DEVICE} ]]; then
Expand All @@ -54,7 +56,12 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
export LIBTPU_GCS_PATH=NONE
echo "Default LIBTPU_GCS_PATH=${LIBTPU_GCS_PATH}"
if [[ ${DEVICE} == "gpu" ]]; then
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
if [[ ${MODE} == "pinned" ]]; then
export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-03-13
else
export BASEIMAGE=ghcr.io/nvidia/jax:base
fi
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
fi
Expand Down
4 changes: 2 additions & 2 deletions docker_upload_runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# Each time you update the base image via a "bash docker_build_dependency_image.sh", there will be a slow upload process
# (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds.

# Example command:
# Example command:
# bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner

set -e
Expand Down Expand Up @@ -49,4 +49,4 @@ docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f ./maxtext_runner.Docke
docker tag ${LOCAL_IMAGE_NAME_RUNNER} gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest
docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest

echo "All done, check out your artifacts at: gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}"
echo "All done, check out your artifacts at: gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}"
3 changes: 2 additions & 1 deletion maxtext_dependencies.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# syntax=docker/dockerfile:experimental
# Use Python 3.10 as the base image
FROM python:3.10-slim-bullseye

Expand Down Expand Up @@ -42,6 +43,6 @@ COPY . .
RUN ls .

RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}"
RUN bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}

WORKDIR /deps
12 changes: 8 additions & 4 deletions maxtext_gpu_dependencies.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
FROM ghcr.io/nvidia/jax:base
# syntax=docker/dockerfile:experimental
ARG BASEIMAGE=ghcr.io/nvidia/jax:base
FROM $BASEIMAGE

# Install dependencies for adjusting network rto
RUN apt-get update && apt-get install -y iproute2 ethtool lsof
Expand Down Expand Up @@ -33,11 +35,13 @@ RUN mkdir -p /deps
# Set the working directory in the container
WORKDIR /deps

# Copy all files from local workspace into docker container
COPY . .
# Copy necessary build files to docker container
COPY setup.sh requirements.txt constraints_gpu.txt /deps/
RUN ls .

RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
RUN bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}

COPY . .

WORKDIR /deps
2 changes: 1 addition & 1 deletion maxtext_runner.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ WORKDIR /deps
# Copy all files from local workspace into docker container
COPY . .

WORKDIR /deps
WORKDIR /deps
11 changes: 11 additions & 0 deletions maxtext_transformerengine_builder.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM ghcr.io/nvidia/jax:base

WORKDIR /root
ENV NVTE_FRAMEWORK=jax


RUN git clone https://github.com/NVIDIA/TransformerEngine
WORKDIR /root/TransformerEngine
RUN git checkout 297459bd08e1b791ca7a2872cfa8582220477782
RUN git submodule update --init --recursive
RUN python setup.py bdist_wheel
Loading

0 comments on commit ebd39aa

Please sign in to comment.