Commit
Merge pull request #139 from aws-samples/feature/#137_jax_example
Feature/#137 jax example
Showing 6 changed files with 206 additions and 0 deletions.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,48 @@
# JAX container for Amazon EC2 GPU-accelerated instances

This directory contains a sample Dockerfile, `jax_paxml.Dockerfile`, to run [JAX](https://github.com/google/jax) and [Paxml](https://github.com/google/paxml) on AWS.

## Container description

In principle, the reference `Dockerfile` does the following (a quick spot check of the resulting image is sketched after this list):

- Provide JAX built for NVIDIA CUDA devices, by using a recent NVIDIA CUDA image as the parent image.
- Remove unnecessary networking packages that might conflict with AWS technologies.
- Install the EFA user-space libraries. It is important to avoid building the kernel drivers during `docker build` and to skip the self-tests, as both of these steps are expected to fail when run during a container build.
- Install the recommended NCCL version.
- Install [aws-ofi-nccl](https://github.com/aws/aws-ofi-nccl) so that NCCL uses EFA.
- Install JAX.
- Install Paxml.
- Install Praxis.
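
Once the image is built (see the next section), a quick spot check such as the minimal sketch below can confirm that the EFA, NCCL, and aws-ofi-nccl artifacts landed in the paths the runtime scripts expect; the paths mirror the `LD_LIBRARY_PATH` set in the Dockerfile, and the image tag assumes the build command used later in this README.

```bash
# Spot check (assumes the image was tagged paxml:jax-0.4.18-1.2.0 as in the build step below).
docker run --rm paxml:jax-0.4.18-1.2.0 \
    ls /opt/amazon/efa /opt/nccl/build/lib /opt/aws-ofi-nccl/install/lib
```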

## Build the container

Build the JAX container as follows:

```bash
# Build a container image
DOCKER_BUILDKIT=1 docker build --progress=plain -f jax_paxml.Dockerfile -t paxml:jax-0.4.18-1.2.0 .

# Verify the image has been built
docker images
```

Convert the container image to the enroot format:

```bash
# Remove any existing .sqsh first; enroot refuses to run if the output .sqsh file already exists.
rm /fsx/paxml_jax-0.4.18-1.2.0.sqsh ; enroot import -o /fsx/paxml_jax-0.4.18-1.2.0.sqsh dockerd://paxml:jax-0.4.18-1.2.0
```
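
As an optional sanity check (not part of the sample), the imported image can be started locally on a node where enroot is available before submitting a job. The container name `paxml` below is arbitrary; whether GPUs are visible inside enroot depends on your node's enroot/NVIDIA hook configuration, so the check only verifies that the image unpacks and that JAX imports.

```bash
# Create a container root filesystem from the squashfs image, run a quick import test, then clean up.
enroot create --name paxml /fsx/paxml_jax-0.4.18-1.2.0.sqsh
enroot start paxml python3 -c "import jax; print(jax.__version__)"
enroot remove -f paxml
```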

## Run

Once the container is converted to the enroot format, you can run the lm_cloud example of Pax on a **Slurm** cluster.
The following command submits a job to the **Slurm** cluster to train a 2B-parameter transformer-based SPMD language model on synthetic data.

```bash
sbatch jax.sbatch
```
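
To follow the job once it is queued, standard Slurm commands can be used; the log file names come from the `#SBATCH -o/-e` directives in `jax.sbatch`.

```bash
# Check the job state, then follow the training log once the job is running.
squeue -u $USER
tail -f jax_<job id>.out
```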
jax.sbatch
@@ -0,0 +1,45 @@
#!/bin/bash

#SBATCH -o jax_%j.out
#SBATCH -e jax_%j.err
#SBATCH -n 384
#SBATCH --gpus-per-node=8
#SBATCH --exclusive

GPU_PER_NODE=8
TOTAL_NB_GPUS=$(($SLURM_JOB_NUM_NODES * $GPU_PER_NODE))

CHECKPOINT_DIR=/data/700/$SLURM_JOBID
if [ ! -d ${CHECKPOINT_DIR} ]; then
    mkdir -p ${CHECKPOINT_DIR}
fi

# EFA Flags
export FI_PROVIDER=efa
export FI_EFA_USE_DEVICE_RDMA=1
export FI_EFA_FORK_SAFE=1

# NCCL Flags
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0

export CUDA_DEVICE_MAX_CONNECTIONS=1

# Library Path
export LD_LIBRARY_PATH=/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/cuda-12/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

# XLA Configuration
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true"
export TPU_TYPE=gpu
export TF_FORCE_GPU_ALLOW_GROWTH=true

# Setup and checkpoint directory
export LEAD_NODE=${SLURMD_NODENAME}
export BASE_DIR=${CHECKPOINT_DIR}

# JAX Configuration
export TRAINING_CONFIG=paxml.tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitsteps
export JAX_FLAGS="--fdl.ICI_MESH_SHAPE=[1,${TOTAL_NB_GPUS},1] --fdl.PERCORE_BATCH_SIZE=32"

srun --container-image /fsx/paxml_jax-0.4.18-1.2.0.sqsh --container-mounts /fsx/data:/data -n ${TOTAL_NB_GPUS} -N ${SLURM_JOB_NUM_NODES} /bin/bash run_paxml.sh
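
For illustration only (the node count below is hypothetical, not part of the sample): because `TOTAL_NB_GPUS` is derived from `SLURM_JOB_NUM_NODES`, the sharding flags scale with the allocation. On a 2-node allocation the script would resolve to:

```bash
# 2 nodes x 8 GPUs per node
TOTAL_NB_GPUS=16
JAX_FLAGS="--fdl.ICI_MESH_SHAPE=[1,16,1] --fdl.PERCORE_BATCH_SIZE=32"
```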
jax_paxml.Dockerfile
@@ -0,0 +1,99 @@
FROM nvcr.io/nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04

ARG EFA_INSTALLER_VERSION=1.29.1
ARG NCCL_VERSION=v2.18.6-1
ARG AWS_OFI_NCCL_VERSION=v1.7.4-aws
ARG JAX_VERSION=0.4.18
ARG PRAXIS_VERSION=1.2.0
ARG PAXML_VERSION=1.2.0

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHON_VERSION=3.10
ENV LD_LIBRARY_PATH=/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/cuda-12/lib64:$LD_LIBRARY_PATH
ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/local/cuda-12/bin:$PATH
ENV CUDA_HOME=/usr/local/cuda-12

#########################
# Packages and Pre-reqs #
RUN apt-get update -y && \
    apt-get purge -y --allow-change-held-packages libmlx5-1 ibverbs-utils libibverbs-dev libibverbs1 libnccl-dev libnccl2
RUN apt-get install -y --allow-unauthenticated \
    autoconf \
    automake \
    bash \
    build-essential \
    ca-certificates \
    curl \
    debianutils \
    dnsutils \
    g++ \
    git \
    libtool \
    libhwloc-dev \
    netcat \
    openssh-client \
    openssh-server \
    openssl \
    python3-distutils \
    python"${PYTHON_VERSION}"-dev \
    python-is-python3 \
    util-linux

RUN update-ca-certificates

###########################
# Python/Pip dependencies #
RUN curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \
    && python"${PYTHON_VERSION}" /tmp/get-pip.py
RUN pip"${PYTHON_VERSION}" install numpy wheel build

######################################
# Install EFA Libfabric and Open MPI #
RUN cd /tmp \
    && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \
    && tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \
    && cd aws-efa-installer \
    && ./efa_installer.sh -y -d --skip-kmod --skip-limit-conf --no-verify

############################
# Compile and Install NCCL #
RUN git clone -b "${NCCL_VERSION}" https://github.com/NVIDIA/nccl.git /opt/nccl \
    && cd /opt/nccl \
    && make -j src.build CUDA_HOME=${CUDA_HOME} \
    && cp -R /opt/nccl/build/* /usr/

###############################
# Compile AWS OFI NCCL Plugin #
RUN git clone -b "${AWS_OFI_NCCL_VERSION}" https://github.com/aws/aws-ofi-nccl.git /opt/aws-ofi-nccl \
    && cd /opt/aws-ofi-nccl \
    && ./autogen.sh \
    && ./configure --prefix=/opt/aws-ofi-nccl/install \
        --with-libfabric=/opt/amazon/efa/ \
        --with-cuda=${CUDA_HOME} \
        --with-mpi=/opt/amazon/openmpi/ \
        --with-nccl=/opt/nccl/build \
        --enable-platform-aws \
    && make -j && make install

###############
# Install JAX #
RUN pip install --upgrade "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install "orbax-checkpoint>=0.4.0,<0.5.0"

##################
# Install Praxis #
RUN pip install praxis==${PRAXIS_VERSION}

#################
# Install Paxml #
RUN pip install paxml==${PAXML_VERSION}

#####################################
# Allow unauthenticated SSH for MPI #
RUN mkdir -p /var/run/sshd \
    && sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config \
    && echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config \
    && sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config

COPY run_paxml.sh /run_paxml.sh
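
A hypothetical smoke test for the resulting image (not part of the sample): on a GPU instance with the NVIDIA container toolkit installed, the command below should list the CUDA devices JAX sees; the exact device list depends on the instance type.

```bash
docker run --rm --gpus all paxml:jax-0.4.18-1.2.0 \
    python3 -c "import jax; print(jax.devices())"
```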
run_paxml.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash
set -ex

# TRAINING_CONFIG example paxml.tasks.lm.params.lm_cloud.LmCloudSpmd2B
python3.10 -m paxml.main \
    --job_log_dir="${BASE_DIR}/LOG_DIR" \
    --fdl_config=${TRAINING_CONFIG} \
    ${JAX_FLAGS} \
    --multiprocess_gpu=true \
    --server_addr=${LEAD_NODE}:12345 \
    --num_hosts=${SLURM_NPROCS} \
    --host_idx=${SLURM_PROCID} \
    --alsologtostderr
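
After a run, the logs and checkpoints written under `--job_log_dir` live on the shared filesystem. With the mounts used in `jax.sbatch` (`/fsx/data` mounted as `/data`), the container path `/data/700/<job id>/LOG_DIR` corresponds to the host path below; the exact layout inside `LOG_DIR` is determined by Paxml.

```bash
# Inspect the job's output directory on the host (replace <job id> with the Slurm job ID).
ls /fsx/data/700/<job id>/LOG_DIR
```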