diff --git a/.github/container/Dockerfile.maxtext.amd64 b/.github/container/Dockerfile.maxtext.amd64
index 63c6767c0..8289a6099 100644
--- a/.github/container/Dockerfile.maxtext.amd64
+++ b/.github/container/Dockerfile.maxtext.amd64
@@ -17,13 +17,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
 echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
 EOF
 
-###############################################################################
-## Apply patch
-###############################################################################
-
-ADD maxtext-mha.patch /opt
-RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
-
 ###############################################################################
 ## Add test script to the path
 ###############################################################################
diff --git a/.github/container/Dockerfile.maxtext.arm64 b/.github/container/Dockerfile.maxtext.arm64
index a971d2405..bd64c5b6d 100644
--- a/.github/container/Dockerfile.maxtext.arm64
+++ b/.github/container/Dockerfile.maxtext.arm64
@@ -58,13 +58,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
 echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
 EOF
 
-###############################################################################
-## Apply patch
-###############################################################################
-
-ADD maxtext-mha.patch /opt
-RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
-
 ###############################################################################
 ## Add test script to the path
 ###############################################################################
diff --git a/.github/container/maxtext-mha.patch b/.github/container/maxtext-mha.patch
deleted file mode 100644
index af2f2feb0..000000000
--- a/.github/container/maxtext-mha.patch
+++ /dev/null
@@ -1,15 +0,0 @@
-diff --git a/requirements.txt b/requirements.txt
-index cae6c73..4b7a214 100644
---- a/requirements.txt
-+++ b/requirements.txt
-@@ -17,8 +17,8 @@ pylint
- pytest
- pytype
- sentencepiece==0.1.97
--tensorflow-text>=2.13.0
--tensorflow>=2.13.0
-+tensorflow-text==2.13.0
-+tensorflow==2.13.0
- tensorflow-datasets
- tensorboardx
- tensorboard-plugin-profile
diff --git a/README.md b/README.md
index 1764c5f00..054f49ae8 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,23 @@
-# JAX Toolbox
+# **JAX Toolbox**
+[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md)
+[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status)
+
+JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).
+
+## Frameworks and Supported Models
+We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs.
+
+| Framework | Models | Use cases | Container |
+| :--- | :---: | :---: | :---: |
+| [maxtext](./rosetta/rosetta/projects/maxtext) | GPT, LLaMA, Gemma, Mistral, Mixtral | pretraining | `ghcr.io/nvidia/jax:maxtext` |
+| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
+| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
+| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
+| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
+| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
+
+# Build Pipeline Status
@@ -22,242 +40,294 @@
[... HTML build-status table rows: markup lost in extraction ...]
@@ -267,26 +337,9 @@ In all of the above cases, `ghcr.io/nvidia/jax:XXX` points to the most recent nightly build of the container for `XXX`. These containers are also tagged as `ghcr.io/nvidia/jax:XXX-YYYY-MM-DD`, if a stable reference is required.
 
-## Note
-This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries like: [T5x](https://github.com/google-research/t5x), [PAXML](https://github.com/google/paxml), [Transformer Engine](https://github.com/NVIDIA/TransformerEngine), [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html) and others to come soon.
-
-## Frameworks and Supported Models
-We currently support the following frameworks and models. More details about each model and the available containers can be found in their respective READMEs.
-
-| Framework | Supported Models | Use-cases | Container |
-| :--- | :---: | :---: | :---: |
-| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
-| [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
-| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
-| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
-| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
-| maxtext | LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` |
-
-We will update this table as new models become available, so stay tuned.
-
 ## Environment Variables
 
-The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning:
+The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning of XLA and NCCL:
 
 | XLA Flags | Value | Explanation |
 | --------- | ----- | ----------- |
@@ -302,10 +355,10 @@ There are various other XLA flags users can set to improve performance. For a de
 
 For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
 
-## Profiling JAX programs on GPU
+## Profiling
 
 See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
 
-## FAQ (Frequently Asked Questions)
+## Frequently asked questions (FAQ)
 
 `bus error` when running JAX in a docker container
 
@@ -340,7 +393,6 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
 ## JAX on Public Clouds
-
 * AWS
     * [Add EFA integration](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-efa.html)
     * [SageMaker code sample](https://github.com/aws-samples/aws-samples-for-ray/tree/main/sagemaker/jax_alpa_language_model)
[Content of hunk @@ -22,242 +40,294 @@ (README.md build-pipeline status table): the HTML table markup was lost in extraction. Recoverable details: the status rows were rewritten for the containers ghcr.io/nvidia/jax:base (marked "[no tests]"), :jax, :levanter, :equinox (marked "[tests disabled]"), :triton, :upstream-t5x, :t5x, :upstream-pax, :pax, :maxtext, and :gemma.]
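The Dockerfile step removed by this diff applied the now-deleted `maxtext-mha.patch` with `patch -p1`, where `-p1` strips the leading `a/`/`b/` path component from the diff headers before resolving file paths. A minimal sketch of that mechanism, using a hypothetical scratch directory and a two-line requirements file (not part of the repository):

```shell
# Reproduce the essence of the removed step: apply a unified diff with patch -p1.
workdir=$(mktemp -d)
cd "$workdir"

# A requirements file with loose version bounds, like the one the patch targeted.
printf 'tensorflow-text>=2.13.0\ntensorflow>=2.13.0\n' > requirements.txt

# A unified diff that pins both packages to exact versions.
cat > pin-tf.patch <<'EOF'
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-tensorflow-text>=2.13.0
-tensorflow>=2.13.0
+tensorflow-text==2.13.0
+tensorflow==2.13.0
EOF

# -p1 turns "a/requirements.txt" into "requirements.txt", so the patch
# applies relative to the current directory.
patch -p1 < pin-tf.patch
cat requirements.txt
```

The removed `RUN` instruction followed the same pattern inside `${SRC_PATH_MAXTEXT}`, with a trailing `git diff` to log the applied change in the build output.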