diff --git a/.github/container/Dockerfile.maxtext.amd64 b/.github/container/Dockerfile.maxtext.amd64 index 63c6767c0..8289a6099 100644 --- a/.github/container/Dockerfile.maxtext.amd64 +++ b/.github/container/Dockerfile.maxtext.amd64 @@ -17,13 +17,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT} echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in EOF -############################################################################### -## Apply patch -############################################################################### - -ADD maxtext-mha.patch /opt -RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff - ############################################################################### ## Add test script to the path ############################################################################### diff --git a/.github/container/Dockerfile.maxtext.arm64 b/.github/container/Dockerfile.maxtext.arm64 index a971d2405..bd64c5b6d 100644 --- a/.github/container/Dockerfile.maxtext.arm64 +++ b/.github/container/Dockerfile.maxtext.arm64 @@ -58,13 +58,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT} echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in EOF -############################################################################### -## Apply patch -############################################################################### - -ADD maxtext-mha.patch /opt -RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff - ############################################################################### ## Add test script to the path ############################################################################### diff --git a/.github/container/maxtext-mha.patch b/.github/container/maxtext-mha.patch deleted file mode 100644 index af2f2feb0..000000000 --- a/.github/container/maxtext-mha.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff --git 
a/requirements.txt b/requirements.txt -index cae6c73..4b7a214 100644 ---- a/requirements.txt -+++ b/requirements.txt -@@ -17,8 +17,8 @@ pylint - pytest - pytype - sentencepiece==0.1.97 --tensorflow-text>=2.13.0 --tensorflow>=2.13.0 -+tensorflow-text==2.13.0 -+tensorflow==2.13.0 - tensorflow-datasets - tensorboardx - tensorboard-plugin-profile diff --git a/README.md b/README.md index 1764c5f00..054f49ae8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,23 @@ -# JAX Toolbox +# **JAX Toolbox** +[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md) +[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status) + +JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html). + +## Frameworks and Supported Models +We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs. 
+ +| Framework | Models | Use cases | Container | +| :--- | :---: | :---: | :---: | +| [maxtext](./rosetta/rosetta/projects/maxtext) | GPT, LLaMA, Gemma, Mistral, Mixtral | pretraining | `ghcr.io/nvidia/jax:maxtext` | +| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` | +| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pretraining, fine-tuning | `ghcr.io/nvidia/jax:t5x` | +| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pretraining | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` | +| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` | +| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` | + +# Build Pipeline Status
- + |
ghcr.io/nvidia/jax:base
|
+ |
- + | + [no tests] + |
- + |
ghcr.io/nvidia/jax:jax
|
+ |
-
-
+
- - - + - - + - - + |
|
- + |
ghcr.io/nvidia/jax:levanter
|
-
-
+
+
+
+ + + + |
- - + | |
- + |
ghcr.io/nvidia/jax:equinox
|
-
-
+
+
+
+ + + + + |
+ + [tests disabled] + | -|
- + |
ghcr.io/nvidia/jax:triton
|
- + + + | - - + | |
- + |
ghcr.io/nvidia/jax:upstream-t5x
|
-
-
+
+
+
+ + + + |
- + | |
- + |
ghcr.io/nvidia/jax:t5x
|
-
-
+
+
+
+ + + + |
- + | |
- + |
ghcr.io/nvidia/jax:upstream-pax
|
-
-
+
+
+
+ + + + |
- + | |
- + |
ghcr.io/nvidia/jax:pax
|
-
-
+
+
+
+ + + + |
- + | |
- + |
ghcr.io/nvidia/jax:maxtext
|
-
-
+
+
+
+ + + + |
- + | |
- + |
ghcr.io/nvidia/jax:gemma
|
- + + + | - - + |