Merge pull request #139 from aws-samples/feature/#137_jax_example
Feature/#137 jax example
mhuguesaws authored Feb 15, 2024
2 parents e77f67c + 53ef008 commit bd700b3
Showing 6 changed files with 206 additions and 0 deletions.
File renamed without changes.
48 changes: 48 additions & 0 deletions 3.test_cases/jax/README.md
@@ -0,0 +1,48 @@
# JAX container for Amazon EC2 GPU-accelerated instances

This directory contains a sample Dockerfile `jax_paxml.Dockerfile` to run [JAX](https://github.com/google/jax) and [Paxml](https://github.com/google/paxml) on AWS.

## Container description

At a high level, the reference `Dockerfile` does the following:

- Provide JAX built for NVIDIA CUDA devices, by using a recent NVIDIA CUDA image as the
parent image.
- Remove unnecessary networking packages that might conflict with AWS technologies.
- Install the EFA user-space libraries. Building the kernel drivers and running the self-tests
  are skipped, as both are expected to fail when run during a container build.
- Install the recommended NCCL version.
- Install [aws-ofi-nccl](https://github.com/aws/aws-ofi-nccl) to get NCCL to utilize EFA.
- Install JAX.
- Install Paxml.
- Install Praxis.
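The steps above rely on library search order to make the EFA-aware pieces win: the `LD_LIBRARY_PATH` set in `jax_paxml.Dockerfile` lists the custom NCCL build and the aws-ofi-nccl plugin before the CUDA and system directories. A minimal sketch of checking that ordering:

```shell
# Library path as set in jax_paxml.Dockerfile (system suffix omitted).
LD_LIBRARY_PATH=/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/cuda-12/lib64

# Sanity check: the aws-ofi-nccl plugin directory must appear on the path
# so NCCL can load the EFA transport at run time.
case ":${LD_LIBRARY_PATH}:" in
  *:/opt/aws-ofi-nccl/install/lib:*) echo "aws-ofi-nccl on library path" ;;
  *) echo "aws-ofi-nccl missing" ;;
esac
```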

## Build the container

Build the JAX container as follows:

```bash
# Build a container image
DOCKER_BUILDKIT=1 docker build --progress=plain -f jax_paxml.Dockerfile -t paxml:jax-0.4.18-1.2.0 .

# Verify the image has been built
docker images
```
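The image tag used above follows a simple convention: it encodes the JAX and Paxml versions that the Dockerfile installs (`JAX_VERSION=0.4.18`, `PAXML_VERSION=1.2.0`). A minimal sketch of parsing that convention, handy when scripting around several builds:

```shell
# Parse a tag of the form jax-<jax_version>-<paxml_version>, matching the
# build command above. Pure parameter expansion, no external tools needed.
TAG="jax-0.4.18-1.2.0"
JAX_VERSION="${TAG#jax-}"          # strip the leading "jax-" prefix
JAX_VERSION="${JAX_VERSION%-*}"    # drop the trailing "-<paxml_version>"
PAXML_VERSION="${TAG##*-}"         # keep only the last "-"-separated field
echo "JAX ${JAX_VERSION}, Paxml ${PAXML_VERSION}"   # → JAX 0.4.18, Paxml 1.2.0
```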

Convert the container image to the enroot format:

```bash
# Convert to enroot format. Remove any existing .sqsh first, since enroot
# refuses to run when the output .sqsh file already exists.
rm -f /fsx/paxml_jax-0.4.18-1.2.0.sqsh && enroot import -o /fsx/paxml_jax-0.4.18-1.2.0.sqsh dockerd://paxml:jax-0.4.18-1.2.0
```
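The removal step exists because `enroot import` refuses to overwrite an existing image. A small sketch of that idempotency pattern with `rm -f`, which stays quiet when the file is absent (the path here is hypothetical, for illustration only):

```shell
# Hypothetical output path, standing in for /fsx/paxml_jax-0.4.18-1.2.0.sqsh
SQSH="$(mktemp -d)/paxml_demo.sqsh"

rm -f "$SQSH"        # succeeds silently whether or not the file exists
[ ! -e "$SQSH" ] && echo "safe to run enroot import"
```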

## Run

Once the container is converted to the enroot format, you can run the lm_cloud example of Pax on a **Slurm** cluster.
The following command submits a job to the **Slurm** cluster that trains a 2B-parameter transformer-based SPMD language model on synthetic data.

```bash
sbatch jax.sbatch
```
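To see how the submission sizes the job: `jax.sbatch` requests 384 tasks (`-n 384`) with 8 GPUs per node and one task per GPU, so Slurm allocates 48 nodes; with `--fdl.PERCORE_BATCH_SIZE=32`, the global batch is 32 samples per GPU across all 384 GPUs. A sketch of that arithmetic:

```shell
# Job geometry implied by jax.sbatch (one task per GPU).
NTASKS=384              # matches "#SBATCH -n 384"
GPUS_PER_NODE=8         # matches "#SBATCH --gpus-per-node=8"
PERCORE_BATCH_SIZE=32   # matches "--fdl.PERCORE_BATCH_SIZE=32"

NODES=$((NTASKS / GPUS_PER_NODE))
GLOBAL_BATCH=$((PERCORE_BATCH_SIZE * NTASKS))
echo "nodes=${NODES} global_batch=${GLOBAL_BATCH}"   # → nodes=48 global_batch=12288
```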
45 changes: 45 additions & 0 deletions 3.test_cases/jax/jax.sbatch
@@ -0,0 +1,45 @@
#!/bin/bash

#SBATCH -o jax_%j.out
#SBATCH -e jax_%j.err
#SBATCH -n 384
#SBATCH --gpus-per-node=8
#SBATCH --exclusive

GPU_PER_NODE=8
TOTAL_NB_GPUS=$(($SLURM_JOB_NUM_NODES * $GPU_PER_NODE))

CHECKPOINT_DIR=/data/700/$SLURM_JOBID
mkdir -p "${CHECKPOINT_DIR}"

# EFA Flags
export FI_PROVIDER=efa
export FI_EFA_USE_DEVICE_RDMA=1
export FI_EFA_FORK_SAFE=1

# NCCL Flags
export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=0

export CUDA_DEVICE_MAX_CONNECTIONS=1

# Library Path
export LD_LIBRARY_PATH=/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/cuda-12/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64

# XLA Configuration
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true"
export TPU_TYPE=gpu
export TF_FORCE_GPU_ALLOW_GROWTH=true

# Setup and checkpoint directory
export LEAD_NODE=${SLURMD_NODENAME}
export BASE_DIR=${CHECKPOINT_DIR}

# JAX Configuration
export TRAINING_CONFIG=paxml.tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitsteps
export JAX_FLAGS="--fdl.ICI_MESH_SHAPE=[1,${TOTAL_NB_GPUS},1] --fdl.PERCORE_BATCH_SIZE=32"

srun --container-image /fsx/paxml_jax-0.4.18-1.2.0.sqsh --container-mounts /fsx/data:/data -n ${TOTAL_NB_GPUS} -N ${SLURM_JOB_NUM_NODES} /bin/bash run_paxml.sh
99 changes: 99 additions & 0 deletions 3.test_cases/jax/jax_paxml.Dockerfile
@@ -0,0 +1,99 @@
FROM nvcr.io/nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04

ARG EFA_INSTALLER_VERSION=1.29.1
ARG NCCL_VERSION=v2.18.6-1
ARG AWS_OFI_NCCL_VERSION=v1.7.4-aws
ARG JAX_VERSION=0.4.18
ARG PRAXIS_VERSION=1.2.0
ARG PAXML_VERSION=1.2.0

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHON_VERSION=3.10
ENV LD_LIBRARY_PATH=/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/cuda-12/lib64:$LD_LIBRARY_PATH
ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/local/cuda-12/bin:$PATH
ENV CUDA_HOME=/usr/local/cuda-12


#########################
# Packages and Pre-reqs #
RUN apt-get update -y && \
apt-get purge -y --allow-change-held-packages libmlx5-1 ibverbs-utils libibverbs-dev libibverbs1 libnccl-dev libnccl2
RUN apt-get install -y --allow-unauthenticated \
autoconf \
automake \
bash \
build-essential \
ca-certificates \
curl \
debianutils \
dnsutils \
g++ \
git \
libtool \
libhwloc-dev \
netcat \
openssh-client \
openssh-server \
openssl \
python3-distutils \
python"${PYTHON_VERSION}"-dev \
python-is-python3 \
util-linux

RUN update-ca-certificates

###########################
# Python/Pip dependencies #
RUN curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \
&& python"${PYTHON_VERSION}" /tmp/get-pip.py
RUN pip"${PYTHON_VERSION}" install numpy wheel build

######################################
# Install EFA Libfabric and Open MPI #
RUN cd /tmp \
&& curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \
&& tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \
&& cd aws-efa-installer \
&& ./efa_installer.sh -y -d --skip-kmod --skip-limit-conf --no-verify

############################
# Compile and Install NCCL #
RUN git clone -b "${NCCL_VERSION}" https://github.com/NVIDIA/nccl.git /opt/nccl \
&& cd /opt/nccl \
&& make -j src.build CUDA_HOME=${CUDA_HOME} \
&& cp -R /opt/nccl/build/* /usr/

###############################
# Compile AWS OFI NCCL Plugin #
RUN git clone -b "${AWS_OFI_NCCL_VERSION}" https://github.com/aws/aws-ofi-nccl.git /opt/aws-ofi-nccl \
&& cd /opt/aws-ofi-nccl \
&& ./autogen.sh \
&& ./configure --prefix=/opt/aws-ofi-nccl/install \
--with-libfabric=/opt/amazon/efa/ \
--with-cuda=${CUDA_HOME} \
--with-mpi=/opt/amazon/openmpi/ \
--with-nccl=/opt/nccl/build \
--enable-platform-aws \
&& make -j && make install

###############
# Install JAX #
RUN pip install --upgrade "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install "orbax-checkpoint>=0.4.0,<0.5.0"

##################
# Install Praxis #
RUN pip install praxis==${PRAXIS_VERSION}

#################
# Install Paxml #
RUN pip install paxml==${PAXML_VERSION}

#####################################
# Allow unauthenticated SSH for MPI #
RUN mkdir -p /var/run/sshd \
&& sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config \
&& echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config \
&& sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config

COPY run_paxml.sh /run_paxml.sh
14 changes: 14 additions & 0 deletions 3.test_cases/jax/run_paxml.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash
set -ex


# TRAINING_CONFIG example paxml.tasks.lm.params.lm_cloud.LmCloudSpmd2B
python3.10 -m paxml.main \
--job_log_dir="${BASE_DIR}/LOG_DIR" \
--fdl_config=${TRAINING_CONFIG} \
${JAX_FLAGS} \
--multiprocess_gpu=true \
--server_addr=${LEAD_NODE}:12345 \
--num_hosts=${SLURM_NPROCS} \
--host_idx=${SLURM_PROCID} \
--alsologtostderr
