From c50ddfdbcf7a68c8076cb6f3434d082d360ad3d5 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Mon, 25 Mar 2024 15:42:26 +0000 Subject: [PATCH 01/26] Prebuild and install Transformer Engine package --- .gitignore | 3 + constraints.txt | 135 ++++++++++++++++++ docker_build_dependency_image.sh | 5 +- maxtext_dependencies.Dockerfile | 2 +- maxtext_gpu_dependencies.Dockerfile | 8 +- maxtext_runner.Dockerfile | 2 +- maxtext_transformerenginer_builder.Dockerfile | 11 ++ setup.sh | 15 +- 8 files changed, 173 insertions(+), 8 deletions(-) create mode 100644 constraints.txt create mode 100644 maxtext_transformerenginer_builder.Dockerfile diff --git a/.gitignore b/.gitignore index 615825017..69382978b 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,6 @@ dmypy.json # DS_Store files **/.DS_Store + +# Wheel build +*.whl diff --git a/constraints.txt b/constraints.txt new file mode 100644 index 000000000..ed39c9a19 --- /dev/null +++ b/constraints.txt @@ -0,0 +1,135 @@ +absl-py==1.4.0 +aqtp==0.6.1 +array-record==0.5.0 +astroid==3.1.0 +astunparse==1.6.3 +attrs==23.2.0 +cachetools==5.3.3 +certifi==2024.2.2 +charset-normalizer==3.3.2 +chex==0.1.85 +click==8.1.7 +cloud-tpu-diagnostics==0.1.5 +cloudpickle==3.0.0 +contextlib2==21.6.0 +dill==0.3.8 +dm-tree==0.1.8 +etils==1.7.0 +exceptiongroup==1.2.0 +flatbuffers==24.3.7 +flax==0.8.1 +fsspec==2024.2.0 +gast==0.4.0 +google-api-core==2.17.1 +google-auth==2.28.2 +google-auth-oauthlib==1.0.0 +google-cloud-core==2.4.1 +google-cloud-storage==2.15.0 +google-crc32c==1.5.0 +google-pasta==0.2.0 +google-resumable-media==2.7.0 +googleapis-common-protos==1.63.0 +grain-nightly==0.0.6 +grpcio==1.62.1 +gviz-api==1.10.0 +h5py==3.10.0 +idna==3.6 +immutabledict==4.2.0 +importlab==0.8.1 +importlib_resources==6.3.0 +iniconfig==2.0.0 +isort==5.13.2 +jax==0.4.25 +jaxlib==0.4.25 +jaxtyping==0.2.28 +Jinja2==3.1.3 +keras==2.13.1 +libclang==16.0.6 +libcst==1.2.0 +Markdown==3.5.2 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +mccabe==0.7.0 +mdurl==0.1.2 +ml-collections==0.1.1 +ml-dtypes==0.3.2 +mlperf-logging==3.0.0 +more-itertools==10.2.0 +msgpack==1.0.8 +msgspec==0.18.6 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +networkx==3.1 +ninja==1.11.1.1 +numpy==1.24.3 +nvidia-cublas-cu12==12.4.2.65 +nvidia-cuda-cupti-cu12==12.4.99 +nvidia-cuda-nvcc-cu12==12.4.99 +nvidia-cuda-nvrtc-cu12==12.4.99 +nvidia-cuda-runtime-cu12==12.4.99 +nvidia-cudnn-cu12==8.9.7.29 +nvidia-cufft-cu12==11.2.0.44 +nvidia-cusolver-cu12==11.6.0.99 +nvidia-cusparse-cu12==12.3.0.142 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.4.99 +oauthlib==3.2.2 +opt-einsum==3.3.0 +optax==0.2.1 +orbax-checkpoint==0.5.5 +packaging==24.0 +pandas==2.2.1 +platformdirs==4.2.0 +pluggy==1.4.0 +promise==2.3 +protobuf==3.20.3 +psutil==5.9.8 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycnite==2023.10.11 +pydantic==1.10.14 +pydot==2.0.0 +Pygments==2.17.2 +pylint==3.1.0 +pyparsing==3.1.2 +pytest==8.1.1 +python-dateutil==2.9.0.post0 +pytype==2024.3.11 +pytz==2024.1 +PyYAML==6.0.1 +requests==2.31.0 +requests-oauthlib==1.4.0 +rich==13.7.1 +rsa==4.9 +scipy==1.12.0 +sentencepiece==0.1.97 +six==1.16.0 +tabulate==0.9.0 +tensorboard==2.13.0 +tensorboard-data-server==0.7.2 +tensorboard_plugin_profile==2.15.1 +tensorboardX==2.6.2.2 +tensorflow==2.13.1 +tensorflow-datasets==4.9.4 +tensorflow-estimator==2.13.0 +tensorflow-hub==0.16.1 +tensorflow-io-gcs-filesystem==0.36.0 +tensorflow-metadata==1.14.0 +tensorflow-text==2.13.0 +tensorstore==0.1.54 +termcolor==2.4.0 +tf-keras==2.15.0 +toml==0.10.2 +tomli==2.0.1 +tomlkit==0.12.4 +toolz==0.12.1 +tqdm==4.66.2 
+transformer-engine==1.4.0+0fbc76a +typeguard==2.13.3 +typing-inspect==0.9.0 +typing_extensions==4.5.0 +tzdata==2024.1 +urllib3==2.2.1 +Werkzeug==3.0.1 +wrapt==1.16.0 +zipp==3.18.0 diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index a1b0cfeaf..3b5b8048b 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Example command: +# Example command: # bash docker_build_dependency_image.sh MODE=stable # bash docker_build_dependency_image.sh MODE=nightly # bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13 @@ -24,6 +24,9 @@ set -e export LOCAL_IMAGE_NAME=maxtext_base_image +# Use Docker BuildKit so we can cache pip packages. +export DOCKER_BUILDKIT=1 + echo "Starting to build your docker image. This will take a few minutes but the image can be reused as you iterate." # Set environment variables diff --git a/maxtext_dependencies.Dockerfile b/maxtext_dependencies.Dockerfile index cfc16547d..5305de5ff 100644 --- a/maxtext_dependencies.Dockerfile +++ b/maxtext_dependencies.Dockerfile @@ -42,6 +42,6 @@ COPY . . RUN ls . RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" -RUN bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} +RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} WORKDIR /app diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile index 242435b29..affe14c4c 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/maxtext_gpu_dependencies.Dockerfile @@ -33,11 +33,13 @@ RUN mkdir -p /deps # Set the working directory in the container WORKDIR /deps -# Copy all files from local workspace into docker container -COPY . . +# Copy necessary build files to docker container +COPY setup.sh requirements.txt constraints.txt /deps/ RUN ls . RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" -RUN bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} +RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} + +COPY . . WORKDIR /app diff --git a/maxtext_runner.Dockerfile b/maxtext_runner.Dockerfile index 15106dd89..64c7a2c35 100644 --- a/maxtext_runner.Dockerfile +++ b/maxtext_runner.Dockerfile @@ -9,4 +9,4 @@ WORKDIR /app # Copy all files from local workspace into docker container COPY . . -WORKDIR /app \ No newline at end of file +WORKDIR /app diff --git a/maxtext_transformerenginer_builder.Dockerfile b/maxtext_transformerenginer_builder.Dockerfile new file mode 100644 index 000000000..10092bafb --- /dev/null +++ b/maxtext_transformerenginer_builder.Dockerfile @@ -0,0 +1,11 @@ +FROM ghcr.io/nvidia/jax:base + +WORKDIR /root +COPY ./constraints.txt . 
+ENV NVTE_FRAMEWORK=jax + +RUN git clone https://github.com/NVIDIA/TransformerEngine +WORKDIR /root/TransformerEngine +RUN git checkout 0fbc76af3733ae997394eaf82b78ff9c0498fe9 +RUN git submodule update --init --recursive +RUN python setup.by bdist_wheel diff --git a/setup.sh b/setup.sh index aa36a7ddf..d3c390b13 100644 --- a/setup.sh +++ b/setup.sh @@ -131,8 +131,11 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then pip3 install -U "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu" - pip3 install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip3 install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints.txt fi + pip3 install "transformer-engine==1.4.0+0fbc76a" \ + --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ \ + -c constraints.txt fi elif [[ $MODE == "nightly" ]]; then # Nightly mode @@ -142,6 +145,9 @@ elif [[ $MODE == "nightly" ]]; then pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # Install jaxlib-nightly pip3 install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html + # Install prebuilt Transformer Engine for GPU builds. + pip3 install "transformer-engine==1.4.0+0fbc76a" \ + --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ elif [[ $DEVICE == "tpu" ]]; then echo "Installing jax-nightly, jaxlib-nightly" # Install jax-nightly @@ -172,4 +178,9 @@ else fi # Install dependencies from requirements.txt -cd $run_name_folder_path && pip install --upgrade pip && pip3 install -r requirements.txt +cd $run_name_folder_path && pip install --upgrade pip +if [[ $DEVICE == "gpu" ]] && [[ "$MODE" == "stable" || ! -v MODE ]] && [[ ! 
-v JAX_VERSION ]]; then + pip3 install -r requirements.txt -c constraints.txt +else + pip3 install -U -r requirements.txt +fi From 67202ec86e8df8fbaef110b34ae9a022b6a720af Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Mon, 25 Mar 2024 16:29:06 +0000 Subject: [PATCH 02/26] Add buildx to GHA to enable buildkit --- .github/workflows/UnitTests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 69117160e..87abb7c31 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -122,7 +122,7 @@ jobs: run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ 'python3 pedagogical_examples/shmap_collective_matmul.py' - + # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job gpu: strategy: @@ -133,6 +133,8 @@ jobs: runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"] steps: - uses: actions/checkout@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - name: Cleanup old docker images run: | docker system prune --all --force From 9a3ca8e29b2c20375e0e463d644d7a2214a77061 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Mon, 1 Apr 2024 19:29:05 +0000 Subject: [PATCH 03/26] Activate version pinning for pinned mode only --- constraints.txt => constraints_gpu.txt | 0 maxtext_gpu_dependencies.Dockerfile | 2 +- setup.sh | 21 +++++++++++++++------ 3 files changed, 16 insertions(+), 7 deletions(-) rename constraints.txt => constraints_gpu.txt (100%) diff --git a/constraints.txt b/constraints_gpu.txt similarity index 100% rename from constraints.txt rename to constraints_gpu.txt diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile index affe14c4c..d19904857 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/maxtext_gpu_dependencies.Dockerfile @@ -34,7 +34,7 @@ RUN mkdir -p /deps WORKDIR /deps # Copy necessary build files to docker container -COPY setup.sh requirements.txt constraints.txt /deps/ +COPY setup.sh requirements.txt constraints_gpu.txt /deps/ RUN ls . RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" diff --git a/setup.sh b/setup.sh index d3c390b13..10fb209e0 100644 --- a/setup.sh +++ b/setup.sh @@ -96,7 +96,17 @@ if [ -e "$libtpu_path" ]; then rm "$libtpu_path" fi -if [[ "$MODE" == "stable" || ! -v MODE ]]; then +if [[ "$MODE" == "pinned" ]]; then + if [[ "$DEVICE" != "gpu" ]]; then + echo "pinned mode is supported for GPU builds only." + exit 1 + fi + echo "Installing pinned jax, jaxlib for NVIDIA gpu." + pip3 install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints_gpu.txt + pip3 install "transformer-engine==1.4.0+0fbc76a" \ + --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ \ + -c constraints_gpu.txt +elif [[ "$MODE" == "stable" || ! -v MODE ]]; then # Stable mode if [[ $DEVICE == "tpu" ]]; then echo "Installing stable jax, jaxlib for tpu" @@ -131,11 +141,10 @@ if [[ "$MODE" == "stable" || ! 
-v MODE ]]; then pip3 install -U "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu" - pip3 install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints.txt + pip3 install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html fi pip3 install "transformer-engine==1.4.0+0fbc76a" \ - --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ \ - -c constraints.txt + --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ fi elif [[ $MODE == "nightly" ]]; then # Nightly mode @@ -179,8 +188,8 @@ fi # Install dependencies from requirements.txt cd $run_name_folder_path && pip install --upgrade pip -if [[ $DEVICE == "gpu" ]] && [[ "$MODE" == "stable" || ! -v MODE ]] && [[ ! -v JAX_VERSION ]]; then - pip3 install -r requirements.txt -c constraints.txt +if [[ "$MODE" == "pinned" ]]; then + pip3 install -r requirements.txt -c constraints_gpu.txt else pip3 install -U -r requirements.txt fi From 7387daec0a3bb22b43912ab990dabcbe9e82e04c Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Mon, 1 Apr 2024 19:44:55 +0000 Subject: [PATCH 04/26] Add pinned mode GHA --- .github/workflows/UnitTests.yml | 21 ++++++++++++--------- docker_build_dependency_image.sh | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index fce5db3d8..301cb1e31 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -122,8 +122,11 @@ jobs: fail-fast: false matrix: device-type: ["a100-40gb-4"] - name: "GPU test (${{ matrix.device-type }})" + build-mode: ["stable", "pinned"] + name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})" runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"] + env: + LOCAL_IMAGE_NAME: "maxtext_base_image_${{ matrix.build-mode }}_${{ github.sha }}" steps: - uses: actions/checkout@v3 - name: Set up Docker Buildx @@ -133,31 +136,31 @@ jobs: docker system prune --all --force - name: Install dependencies run: | - bash docker_build_dependency_image.sh DEVICE=gpu + bash docker_build_dependency_image.sh DEVICE=gpu MODE=${{ matrix.build-mode }} - name: Test gsutil installation run: | - docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. 
Aborting"; exit 24;}' - name: Test with pytest run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest -m "not tpu"' + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c 'cd MaxText;python3 -m pytest -m "not tpu"' - name: Test train.py run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=dot_product' - name: Test train.py with per_device_batch_size < 1 run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=dot_product' - name: Test int8_training run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=dot_product' - name: Test decode.py run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} 
base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1' - name: Test decode.py with per_device_batch_size < 1 run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25' diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 3b5b8048b..2fc1fe89e 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -22,7 +22,7 @@ # Enable "exit immediately if any command fails" option set -e -export LOCAL_IMAGE_NAME=maxtext_base_image +export LOCAL_IMAGE_NAME="${LOCAL_IMAGE_NAME:-maxtext_base_image}" # Use Docker BuildKit so we can cache pip packages. export DOCKER_BUILDKIT=1 From 756e963faad96a491f08d5f78b0db7024bd9242b Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Tue, 2 Apr 2024 22:39:08 +0000 Subject: [PATCH 05/26] TE builder fix --- maxtext_transformerenginer_builder.Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/maxtext_transformerenginer_builder.Dockerfile b/maxtext_transformerenginer_builder.Dockerfile index 10092bafb..004e753a3 100644 --- a/maxtext_transformerenginer_builder.Dockerfile +++ b/maxtext_transformerenginer_builder.Dockerfile @@ -1,11 +1,11 @@ FROM ghcr.io/nvidia/jax:base WORKDIR /root -COPY ./constraints.txt . 
ENV NVTE_FRAMEWORK=jax + RUN git clone https://github.com/NVIDIA/TransformerEngine WORKDIR /root/TransformerEngine RUN git checkout 0fbc76af3733ae997394eaf82b78ff9c0498fe9 RUN git submodule update --init --recursive -RUN python setup.by bdist_wheel +RUN python setup.py bdist_wheel From b938fa81069291e28fb62caed4cf92aa639a6626 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Wed, 3 Apr 2024 21:50:53 +0000 Subject: [PATCH 06/26] Add pinned build to daily build workflow Also undo some unnecessary change --- .github/workflows/UploadDockerImages.yml | 5 ++++- docker_build_dependency_image.sh | 2 +- docker_upload_runner.sh | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index 3f11c3127..cf4099d85 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -50,4 +50,7 @@ jobs: bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable MODE=stable DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_stable - name: build jax nightly image run : | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_nightly MODE=nightly DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_nightly \ No newline at end of file + bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_nightly MODE=nightly DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_nightly + - name: build jax pinned image + run : | + bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_pinned MODE=pinned DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_pinned diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 2fc1fe89e..3b5b8048b 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -22,7 +22,7 @@ # Enable "exit immediately if any command fails" option set -e -export LOCAL_IMAGE_NAME="${LOCAL_IMAGE_NAME:-maxtext_base_image}" +export LOCAL_IMAGE_NAME=maxtext_base_image # Use Docker BuildKit so we can cache pip packages. export DOCKER_BUILDKIT=1 diff --git a/docker_upload_runner.sh b/docker_upload_runner.sh index c59c02566..65dfafbf1 100644 --- a/docker_upload_runner.sh +++ b/docker_upload_runner.sh @@ -20,7 +20,7 @@ # Each time you update the base image via a "bash docker_build_dependency_image.sh", there will be a slow upload process # (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds. 
-# Example command: +# Example command: # bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner set -e @@ -49,4 +49,4 @@ docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} -f ./maxtext_runner.Docke docker tag ${LOCAL_IMAGE_NAME_RUNNER} gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest docker push gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}:latest -echo "All done, check out your artifacts at: gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}" \ No newline at end of file +echo "All done, check out your artifacts at: gcr.io/$PROJECT/${CLOUD_IMAGE_NAME}" From 41ea174ceaa0b465d27e347b7baa7dc507845f99 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Wed, 3 Apr 2024 23:06:13 +0000 Subject: [PATCH 07/26] Fix GHA issue --- .github/workflows/UnitTests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 66a729e5f..69d157ea6 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -140,7 +140,7 @@ jobs: docker system prune --all --force - name: Install dependencies run: | - bash docker_build_dependency_image.sh DEVICE=gpu MODE=${{ matrix.build-mode }} + bash docker_build_dependency_image.sh DEVICE=gpu MODE=${{ matrix.build-mode }} LOCAL_IMAGE_NAME="$LOCAL_IMAGE_NAME" - name: Test gsutil installation run: | docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ From 1120463c0a3ffc7ab67f4b43ca5ae858c6b7fef4 Mon Sep 17 00:00:00 2001 From: prrathi <53785742+prrathi@users.noreply.github.com> Date: Thu, 4 Apr 2024 00:57:53 -0700 Subject: [PATCH 08/26] Update embeddings.py --- MaxText/layers/embeddings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index ac773ab71..6c954941c 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -164,8 +164,8 @@ def __call__( ) position = position[:, :, jnp.newaxis, jnp.newaxis] sinusoid_inp = position / timescale - sin = jnp.sin(sinusoid_inp) - cos = jnp.cos(sinusoid_inp) + sin = jnp.sin(sinusoid_inp).astype(inputs.dtype) + cos = jnp.cos(sinusoid_inp).astype(inputs.dtype) first_half, second_half = jnp.split(inputs, 2, axis=-1) first_part = first_half * cos - second_half * sin second_part = second_half * cos + first_half * sin @@ -198,4 +198,4 @@ def __call__( signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = -1) # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]]) position_embedding = signal.astype(jnp.float32) - return input_embedding + position_embedding \ No newline at end of file + return input_embedding + position_embedding From 271b58fc4868d1649bae2f3ca608dd0900957787 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Fri, 5 Apr 2024 00:33:58 +0000 Subject: [PATCH 09/26] Update TE version --- constraints_gpu.txt | 2 +- maxtext_transformerenginer_builder.Dockerfile | 2 +- setup.sh | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/constraints_gpu.txt b/constraints_gpu.txt index ed39c9a19..0a86a9c0e 100644 --- a/constraints_gpu.txt +++ b/constraints_gpu.txt @@ -124,7 +124,7 @@ tomli==2.0.1 tomlkit==0.12.4 toolz==0.12.1 tqdm==4.66.2 -transformer-engine==1.4.0+0fbc76a +transformer-engine==1.5.0+297459b typeguard==2.13.3 typing-inspect==0.9.0 typing_extensions==4.5.0 diff --git a/maxtext_transformerenginer_builder.Dockerfile b/maxtext_transformerenginer_builder.Dockerfile index 004e753a3..22a66e960 100644 --- 
a/maxtext_transformerenginer_builder.Dockerfile +++ b/maxtext_transformerenginer_builder.Dockerfile @@ -6,6 +6,6 @@ ENV NVTE_FRAMEWORK=jax RUN git clone https://github.com/NVIDIA/TransformerEngine WORKDIR /root/TransformerEngine -RUN git checkout 0fbc76af3733ae997394eaf82b78ff9c0498fe9 +RUN git checkout 297459bd08e1b791ca7a2872cfa8582220477782 RUN git submodule update --init --recursive RUN python setup.py bdist_wheel diff --git a/setup.sh b/setup.sh index 10fb209e0..6c1f858e0 100644 --- a/setup.sh +++ b/setup.sh @@ -103,7 +103,7 @@ if [[ "$MODE" == "pinned" ]]; then fi echo "Installing pinned jax, jaxlib for NVIDIA gpu." pip3 install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -c constraints_gpu.txt - pip3 install "transformer-engine==1.4.0+0fbc76a" \ + pip3 install "transformer-engine==1.5.0+297459b" \ --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ \ -c constraints_gpu.txt elif [[ "$MODE" == "stable" || ! -v MODE ]]; then @@ -143,7 +143,7 @@ elif [[ "$MODE" == "stable" || ! -v MODE ]]; then echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu" pip3 install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html fi - pip3 install "transformer-engine==1.4.0+0fbc76a" \ + pip3 install "transformer-engine==1.5.0+297459b" \ --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ fi elif [[ $MODE == "nightly" ]]; then @@ -155,7 +155,7 @@ elif [[ $MODE == "nightly" ]]; then # Install jaxlib-nightly pip3 install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html # Install prebuilt Transformer Engine for GPU builds. - pip3 install "transformer-engine==1.4.0+0fbc76a" \ + pip3 install "transformer-engine==1.5.0+297459b" \ --extra-index-url https://us-python.pkg.dev/gce-ai-infra/maxtext-build-support-packages/simple/ elif [[ $DEVICE == "tpu" ]]; then echo "Installing jax-nightly, jaxlib-nightly" From b3df5aad40251f8bcf965282032272f0c0c0267e Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Fri, 5 Apr 2024 16:25:20 +0000 Subject: [PATCH 10/26] Add nightly unit test in GHA Also add flash attention smoketest --- .github/workflows/UnitTests.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 69d157ea6..2758af18f 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -126,7 +126,7 @@ jobs: fail-fast: false matrix: device-type: ["a100-40gb-4"] - build-mode: ["stable", "pinned"] + build-mode: ["nightly", "stable", "pinned"] name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})" runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"] env: @@ -152,6 +152,10 @@ jobs: run: | docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=dot_product' + - name: Test train.py with flash attention + run: | + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v 
/home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ + 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=cudnn_flash_te' - name: Test train.py with per_device_batch_size < 1 run: | docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged "$LOCAL_IMAGE_NAME" bash -c \ From 04262f3cf5d4f1b064c484e3cbc2ca2a722f9d32 Mon Sep 17 00:00:00 2001 From: Rafi Witten Date: Fri, 22 Mar 2024 23:43:50 +0000 Subject: [PATCH 11/26] README updates --- MaxText/configs/README.md | 2 +- README.md | 134 ++++++++++------------------------- getting_started/First_run.md | 65 +++++++++++++++++ 3 files changed, 102 insertions(+), 99 deletions(-) create mode 100644 getting_started/First_run.md diff --git a/MaxText/configs/README.md b/MaxText/configs/README.md index 4ca6557ed..e2f98454f 100644 --- a/MaxText/configs/README.md +++ b/MaxText/configs/README.md @@ -15,7 +15,7 @@ --> # High Performance Model Configs -This directory contains high performance model configurations for different generations of TPU hardware. +This directory contains high performance model configurations for different generations of TPU and GPU hardware. These configurations do 3 things: * Sets various XLA compiler flags as `LIBTPU_INIT_ARGS` to optimize runtime performance. diff --git a/README.md b/README.md index 9dafc9e05..1b19f93aa 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ +## Use Vertex AI Tensorboard +MaxText supports automatic upload of logs collected in a directory to a Tensorboard instance in Vertex AI. For more information on how MaxText supports this feature, visit [cloud-accelerator-diagnostics](https://pypi.org/project/cloud-accelerator-diagnostics) PyPI package documentation. + +### What is Vertex AI Tensorboard and Vertex AI Experiment +Vertex AI Tensorboard is a fully managed and enterprise-ready version of open-source Tensorboard. To learn more about Vertex AI Tensorboard, visit [this](https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-introduction). Vertex AI Experiment is a tool that helps to track and analyze an experiment run on Vertex AI Tensorboard. To learn more about Vertex AI Experiments, visit [this](https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments). + +You can use a single Vertex AI Tensorboard instance to track and compare metrics from multiple Vertex AI Experiments. While you can view metrics from multiple Vertex AI Experiments within a single Tensorboard instance, the underlying log data for each experiment remains separate. + +### Prerequisites +* Enable [Vertex AI API](https://cloud.google.com/vertex-ai/docs/start/cloud-environment#enable_vertexai_apis) in your Google Cloud console. +* Assign [Vertex AI User IAM role](https://cloud.google.com/vertex-ai/docs/general/access-control#aiplatform.user) to the service account used by the TPU VMs. This is required to create and access the Vertex AI Tensorboard in Google Cloud console. If you are using XPK for MaxText, the necessary Vertex AI User IAM role will be automatically assigned to your node pools by XPK – no need to assign it manually. 
+ +### Upload Logs to Vertex AI Tensorboard +**Scenario 1: Using XPK to run MaxText on GKE** + +XPK simplifies MaxText's Vertex AI Tensorboard integration. A Vertex Tensorboard instance and Experiment are automatically created by XPK during workload scheduling. Also, XPK automatically sets the necessary environment variables, eliminating the need to manually configure this in MaxText. Set `use_vertex_tensorboard=False` to avoid setting up Vertex Tensorboard again in MaxText. This is how the configuration will look like for running MaxText via XPK: +``` +use_vertex_tensorboard: False +vertex_tensorboard_project: "" +vertex_tensorboard_region: "" +``` +The above configuration will upload logs in `config.tensorboard_dir` to Vertex Tensorboard instance set as an environment variable by XPK. + +**Scenario 2: Running MaxText on GCE** + +Set `use_vertex_tensorboard=True` to upload logs in `config.tensorboard_dir` to a Tensorboard instance in Vertex AI. You can manually create a Tensorboard instance named `-tb-instance` and an Experiment named `config.run_name` in Vertex AI on Google Cloud console. Otherwise, MaxText will create those resources for you when `use_vertex_tensorboard=True`. Note that Vertex AI is available in only [these](https://cloud.google.com/vertex-ai/docs/general/locations#available-regions) regions. + +**Scenario 2.1: Configuration to upload logs to Vertex AI Tensorboard** + +``` +run_name: "test-run" +use_vertex_tensorboard: True +vertex_tensorboard_project: "test-project" # or vertex_tensorboard_project: "" +vertex_tensorboard_location: "us-central1" +``` +The above configuration will try to create a Vertex AI Tensorboard instance named `test-project-tb-instance` and a Vertex AI Experiment named `test-run` in the `us-central1` region of `test-project`. If you set `vertex_tensorboard_project=""`, then the default project (`gcloud config get project`) set on the VM will be used to create the Vertex AI resources. It will only create these resources if they do not already exist. Also, the logs in `config.tensorboard_dir` will be uploaded to `test-project-tb-instance` Tensorboard instance and `test-run` Experiment in Vertex AI. + +**Scenario 2.2: Configuration to not upload logs to Vertex AI Tensorboard** + +The following configuration will not upload any log data collected in `config.tensorboard_dir` to Tensorboard in Vertex AI. 
+``` +use_vertex_tensorboard: False +vertex_tensorboard_project: "" +vertex_tensorboard_location: "" +``` diff --git a/requirements.txt b/requirements.txt index cae6c739a..ebe9a2e4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ orbax-checkpoint>=0.5.5 absl-py array-record aqtp +cloud-accelerator-diagnostics cloud-tpu-diagnostics google-cloud-storage grain-nightly From 24d24b68b872ca88be9e5b049dd64a6dac782272 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Tue, 9 Apr 2024 17:43:40 -0700 Subject: [PATCH 15/26] Fix the sync of cloud_accelerator_diagnostics import to Github PiperOrigin-RevId: 623329502 --- MaxText/vertex_tensorboard.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/vertex_tensorboard.py b/MaxText/vertex_tensorboard.py index f8056d22c..1cb438db5 100644 --- a/MaxText/vertex_tensorboard.py +++ b/MaxText/vertex_tensorboard.py @@ -23,8 +23,8 @@ import max_logging import max_utils -from cloud_accelerator_diagnostics.pip_package.cloud_accelerator_diagnostics.src.tensorboard_uploader import tensorboard -from cloud_accelerator_diagnostics.pip_package.cloud_accelerator_diagnostics.src.tensorboard_uploader import uploader +from cloud_accelerator_diagnostics import tensorboard +from cloud_accelerator_diagnostics import uploader class VertexTensorboardManager: From fc216ffb809e168904ed46dc72adbb3dec159374 Mon Sep 17 00:00:00 2001 From: Nina Cai Date: Wed, 10 Apr 2024 20:44:54 +0000 Subject: [PATCH 16/26] Separate gpu and tpu end-to-end scripts --- end_to_end/gpu/a3/test_llama2_7b.sh | 71 +++++++++ end_to_end/tpu/eval_assert.py | 139 ++++++++++++++++++ end_to_end/tpu/gemma/2b/test_gemma.sh | 64 ++++++++ end_to_end/tpu/gemma/7b/1_test_gemma.sh | 31 ++++ end_to_end/tpu/gemma/7b/2_test_gemma.sh | 55 +++++++ end_to_end/tpu/gemma/Run_Gemma.md | 31 ++++ end_to_end/tpu/llama_finetuning_test.sh | 19 +++ .../tpu/test_checkpoint_compatibility.sh | 49 ++++++ end_to_end/tpu/test_checkpoint_resharding.sh | 18 +++ end_to_end/tpu/test_checkpointing.sh | 63 ++++++++ end_to_end/tpu/test_convergence_1b_params.sh | 52 +++++++ end_to_end/tpu/test_decode.sh | 36 +++++ end_to_end/tpu/test_determinism.sh | 31 ++++ .../test_generate_param_only_checkpoint.sh | 123 ++++++++++++++++ end_to_end/tpu/test_gpt3.sh | 16 ++ end_to_end/tpu/test_llama2_7b.sh | 73 +++++++++ end_to_end/tpu/test_mistral.sh | 22 +++ end_to_end/tpu/test_mixtral.sh | 22 +++ end_to_end/tpu/test_tflops.sh | 22 +++ end_to_end/tpu/test_tflops_16b_params.sh | 38 +++++ end_to_end/tpu/test_tflops_32b_params.sh | 38 +++++ end_to_end/tpu/test_tflops_64b_params.sh | 38 +++++ end_to_end/tpu/test_vocab_creation.sh | 14 ++ 23 files changed, 1065 insertions(+) create mode 100644 end_to_end/gpu/a3/test_llama2_7b.sh create mode 100644 end_to_end/tpu/eval_assert.py create mode 100644 end_to_end/tpu/gemma/2b/test_gemma.sh create mode 100644 end_to_end/tpu/gemma/7b/1_test_gemma.sh create mode 100644 end_to_end/tpu/gemma/7b/2_test_gemma.sh create mode 100644 end_to_end/tpu/gemma/Run_Gemma.md create mode 100644 end_to_end/tpu/llama_finetuning_test.sh create mode 100644 end_to_end/tpu/test_checkpoint_compatibility.sh create mode 100644 end_to_end/tpu/test_checkpoint_resharding.sh create mode 100644 end_to_end/tpu/test_checkpointing.sh create mode 100644 end_to_end/tpu/test_convergence_1b_params.sh create mode 100644 end_to_end/tpu/test_decode.sh create mode 100644 end_to_end/tpu/test_determinism.sh create mode 100644 end_to_end/tpu/test_generate_param_only_checkpoint.sh create mode 100644 
end_to_end/tpu/test_gpt3.sh create mode 100644 end_to_end/tpu/test_llama2_7b.sh create mode 100644 end_to_end/tpu/test_mistral.sh create mode 100644 end_to_end/tpu/test_mixtral.sh create mode 100644 end_to_end/tpu/test_tflops.sh create mode 100644 end_to_end/tpu/test_tflops_16b_params.sh create mode 100644 end_to_end/tpu/test_tflops_32b_params.sh create mode 100644 end_to_end/tpu/test_tflops_64b_params.sh create mode 100644 end_to_end/tpu/test_vocab_creation.sh diff --git a/end_to_end/gpu/a3/test_llama2_7b.sh b/end_to_end/gpu/a3/test_llama2_7b.sh new file mode 100644 index 000000000..7a6a9e0c3 --- /dev/null +++ b/end_to_end/gpu/a3/test_llama2_7b.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# This file is both an integration test that runs once a day on a A3 and documentation for how to get started with Llama2-7b + +# The flow of this file is as follows: +# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) in your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText. +# 2. Run training of Llama2-7b. +# 3. Run decoding from the trained checkpoint. + + +set -ex +idx=$(date +%Y-%m-%d-%H-%M) + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run +export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +export ASYNC_CHECKPOINTING=false + +# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU +pip install torch --index-url https://download.pytorch.org/whl/cpu + +# We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint +export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt + +# In the following command, we are copying Meta's checkpoint into a local directory `tmp`. +# You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py` +gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/ + +# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own +export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext-gpu + +#Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` +python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH} + +# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory exactly inside `CONVERTED_CHECKPOINT_PATH`. This way it is easier to use this path in the `train.py` and `decode.py` commands +export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items + +# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. +# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. 
+export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx} +python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} + +export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)" + +# Set environment variables +for ARGUMENT in "$@"; do + IFS='=' read -r KEY VALUE <<< "$ARGUMENT" + export "$KEY"="$VALUE" +done + +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_FUSED_ATTN=1 +export NCCL_DEBUG=VERSION + +export XLA_FLAGS="--xla_dump_to=$BASE_OUTPUT_PATH/$RUN_NAME/HLO_dumps/ +--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true +--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions + --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true + --xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728 + --xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true + --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true + --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false + --xla_disable_hlo_passes=rematerialization" + +python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b enable_checkpointing=true attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} base_output_directory=$BASE_OUTPUT_DIRECTORY enable_profiler=false + +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 +export TF_FORCE_GPU_ALLOW_GROWTH=true + +python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py new file mode 100644 index 000000000..c879ce99f --- /dev/null +++ b/end_to_end/tpu/eval_assert.py @@ -0,0 +1,139 @@ +""" + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + +# pylint: skip-file +"""Reads and asserts over target values""" +from absl import app +from typing import Sequence +from math import isclose +from google.cloud import storage +import json + + +def compute_avg_metric(metrics_file, target, start_line=10): + """ Reads and computes average of target value + If start_line is negative then uses the last lines, e.g. start from end + 1 - |start_line|""" + + + avg = 0 + i = 0 + with open(metrics_file, 'r', encoding='utf8') as file: + lines = file.readlines() + if start_line < 0: + start_line = len(lines) + start_line + for line in lines: + # skip the first start_line lines for burn in + if i >= start_line: + vals = json.loads(line) + avg += vals[target] + i+=1 + avg /= (i-start_line) + + return avg + + +def assert_metric_average(metrics_file, threshold, target): + avg_value = compute_avg_metric(metrics_file, target) + # Checks for acceptable performance by asserting that the average metric (e.g. TFLOPs) + # is greater than the threshold. + print(f'avg value of target {target} is {avg_value}') + assert avg_value >= float(threshold) + print('assert metric average passed.') + +def test_final_loss(metrics_file, target_loss): + target_loss = float(target_loss) + with open(metrics_file, 'r', encoding='utf8') as metrics: + use_last_n_data = 10 + avg_final_loss = compute_avg_metric(metrics_file, 'learning/loss', start_line= -1 * use_last_n_data) + print(f"Mean of last {use_last_n_data} losses is {avg_final_loss}") + print(f"Target loss is {target_loss}") + assert avg_final_loss < target_loss + print('Final loss test passed.') + +def test_checkpointing(metrics_file, target, dataset_type): + """Asserts over loss values from loaded checkpoint""" + metrics_file_saved = 'saved_' + metrics_file + metrics_file_restored = 'restored_' + metrics_file + + with open(metrics_file_saved, 'r', encoding='utf8') as saved,\ + open(metrics_file_restored, 'r', encoding='utf8') as restored: + saved_loss = json.loads(saved.readlines()[-1])[target] + restored_loss = json.loads(restored.readlines()[0])[target] + # Checks that checkpoint restore was successful by comparing loss of last + # step in saved checkpoint to loss of first step in restored checkpoint + print("saved loss: ", saved_loss) + print("restored loss: ", restored_loss) + if dataset_type=='c4': + assert isclose(saved_loss, restored_loss, rel_tol=0.1) + elif dataset_type=='c4-array_record': + assert saved_loss==restored_loss + else: + raise ValueError(f"Unknown dataset_type {dataset_type}. 
dataset_type must be c4, c4-array_record or synthetic") + print('checkpointing test passed.') + +def test_determinism(metrics_file, target): + """Asserts over loss values from two runs""" + run_1 = 'run_1_' + metrics_file + run_2 = 'run_2_' + metrics_file + + with open(run_1, 'r', encoding='utf8') as run_1_file,\ + open(run_2, 'r', encoding='utf8') as run_2_file: + run_1_loss = json.loads(run_1_file.readlines()[-1])[target] + run_2_loss = json.loads(run_2_file.readlines()[-1])[target] + # Check that the two runs have the same loss + print(f"Run 1 loss:{run_1_loss}", flush=True) + print(f"Run 2 loss:{run_2_loss}", flush=True) + assert run_1_loss==run_2_loss + print('determinism test passed.') + +def test_vocab_creation(target): + bucket_name = target.split("/")[2] + vocab_path = "/".join(target.split("/")[3:]) + storage_client = storage.Client() + assert storage.Blob(bucket=storage_client.bucket(bucket_name), name=vocab_path).exists(storage_client) + print('vocab creation test passed.') + +def test_start_step(metrics_file, start_step_target): + with open(metrics_file, 'r', encoding='utf8') as metrics: + start_step = json.loads(metrics.readlines()[0])["step"] + print(f"Start step is {start_step}, start step target is {start_step_target}") + assert start_step==float(start_step_target) + print("Start step test passed.") + +def main(argv: Sequence[str]) -> None: + + _, test_scenario, *test_vars = argv + + if test_scenario == 'metrics_average': + assert_metric_average(*test_vars) + elif test_scenario == 'checkpoint_save_restore': + test_checkpointing(*test_vars, dataset_type='c4') + elif test_scenario == 'grain_checkpoint_save_restore': + test_checkpointing(*test_vars, dataset_type='c4-array_record') + elif test_scenario == 'determinism': + test_determinism(*test_vars) + elif test_scenario == 'vocab_creation': + test_vocab_creation(*test_vars) + elif test_scenario == 'final_loss': + test_final_loss(*test_vars) + elif test_scenario == 'test_start_step': + test_start_step(*test_vars) + else: + raise ValueError(f"Unrecognized test_scenario {test_scenario}") + + +if __name__ == "__main__": + app.run(main) diff --git a/end_to_end/tpu/gemma/2b/test_gemma.sh b/end_to_end/tpu/gemma/2b/test_gemma.sh new file mode 100644 index 000000000..74d776951 --- /dev/null +++ b/end_to_end/tpu/gemma/2b/test_gemma.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b. + +# The flow of this file is as follows: +# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText +# 2. Run decoding, finetuning of Gemma 2B with the converted checkpoint. Also, run pretraining of Gemma 2B +# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. +# 4. Run decoding from the finetuned checkpoint from step 2 +# 5. Ahead of Time Compilation for running Gemma 2B on v5e-256 + + +set -ex +idx=$(date +%Y-%m-%d-%H-%M) +export MODEL_VARIATION='2b' + +# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ +# Non-Googlers please remember to use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). +# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. 
+export CHKPT_BUCKET=gs://maxtext-gemma/flax +export MODEL_BUCKET=gs://maxtext-gemma +python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION} + +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +export DATASET_PATH=gs://maxtext-dataset +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run +export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands +export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items +export RUN_NAME=unscanned_chkpt_${idx} +# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. +# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. +python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true + +export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items + +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` +# We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert` +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write about it" + +# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat" + +# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. 
Note that scanned checkpoint helps with efficient finetuning +export FINETUNE_RUN_NAME=runner_finetune_${idx} +python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5 + +# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from +python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b + +# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. +# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. +# `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding +export PARAM_RUN_NAME=param_chkpt_${idx} +python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true + +# Now, run decoding on the checkpoint generated from our finetune run. +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" + +# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. +# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B. +# To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. +python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/end_to_end/tpu/gemma/7b/1_test_gemma.sh b/end_to_end/tpu/gemma/7b/1_test_gemma.sh new file mode 100644 index 000000000..a718e844b --- /dev/null +++ b/end_to_end/tpu/gemma/7b/1_test_gemma.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# This file, combined with step 2 in the same directory, demonstrates converting a Gemma checkpoint from Kaggle and running various MaxText operations on it. +# This step is tested nightly on an ordinary CPU VM. + +# The flow of this file is as follows: +# 1. 
Pull the checkpoint from a GCS bucket and upload the new MaxText-compatible checkpoint to a destination GCS bucket.
+# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
+
+# Example Usage: bash end_to_end/gemma/7b/1_test_gemma.sh
+set -ex
+idx=$(date +%Y-%m-%d-%H-%M)
+MODEL_VARIATION='7b'
+
+# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
+export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
+# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET \
+# Please use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
+# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
+export CHKPT_BUCKET=gs://maxtext-gemma/flax
+export MODEL_BUCKET=gs://maxtext-gemma
+JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
+echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}"
+
+# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
+export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
+# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
+export RUN_NAME=unscanned_chkpt_${idx}
+JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true
+echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items"
diff --git a/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/end_to_end/tpu/gemma/7b/2_test_gemma.sh
new file mode 100644
index 000000000..6f8e37b79
--- /dev/null
+++ b/end_to_end/tpu/gemma/7b/2_test_gemma.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b.
+# Please make sure you have run end_to_end/gemma/7b/1_test_gemma.sh before running commands from this file.
+
+# The flow of this file is as follows:
+# 1. Run decoding, finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B
+# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
+# 3. Run decoding from the finetuned checkpoint from step 1
+# 4. Ahead of Time Compilation for running Gemma 7B on v5e-256
+
+set -ex
+idx=$(date +%Y-%m-%d-%H-%M)
+export MODEL_VARIATION='7b'
+
+# Non-Googlers please remember to point `MODEL_BUCKET` to a GCS bucket that you own; this script uses internal buckets for testing.
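+# An illustrative sketch only (the bucket names below are placeholders, not real buckets): an external
+# user would typically repoint the variables in this block before running the rest of the file, e.g.
+#   export MODEL_BUCKET=gs://<your-project>-gemma-maxtext        # MaxText-compatible weights written by 1_test_gemma.sh
+#   export DATASET_PATH=gs://<your-project>-training-data        # training data used by the finetuning step below
+#   export BASE_OUTPUT_DIRECTORY=gs://<your-project>-maxtext-logs  # run logs and checkpoints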
+export MODEL_BUCKET=gs://maxtext-gemma +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +export DATASET_PATH=gs://maxtext-dataset +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run +export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands +export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items +export RUN_NAME=unscanned_chkpt_${idx} +# We defined path to unscanned checkpoint created in 1_test_gemma.sh +export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items + +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. +# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` +# We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert` +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people’s faces" + +# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people's faces" + +# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. 
Note that scanned checkpoint helps with efficient finetuning +export FINETUNE_RUN_NAME=runner_finetune_${idx} +python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5 + +# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from +python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b + +# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. +# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. +# `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding +export PARAM_RUN_NAME=param_chkpt_${idx} +python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true + +# Now, run decoding on the checkpoint generated from our finetune run. +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" + +# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. +# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B. +# To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. +python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/end_to_end/tpu/gemma/Run_Gemma.md b/end_to_end/tpu/gemma/Run_Gemma.md new file mode 100644 index 000000000..b099cf883 --- /dev/null +++ b/end_to_end/tpu/gemma/Run_Gemma.md @@ -0,0 +1,31 @@ + + +# Gemma +[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the art open models built from research and technology that we used to create the Gemini models. + +Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. 
You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). + +After downloading the weights run [test_convert_chkpt.sh](https://github.com/google/maxtext/blob/main/end_to_end/gemma/test_convert_chkpt.sh), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma). + +## MaxText supports pretraining and finetuning with high performance + +Model Flop utilization for training on v5e and v5p TPUs. + +| Model | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) | +| -------- | -------------- | -------------- | -------------- | -------------- | +| Gemma-2b | 58% | 55% | 64% | 68% | +| Gemma-7b | 58% | 60% | 70% | 70% | diff --git a/end_to_end/tpu/llama_finetuning_test.sh b/end_to_end/tpu/llama_finetuning_test.sh new file mode 100644 index 000000000..4758379dd --- /dev/null +++ b/end_to_end/tpu/llama_finetuning_test.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# This script is designed for internal use within Google. External users can adapt it by: +# - Updating GCS paths (gs://) to your accessible locations. +# - Using the checkpoint generated from train.py or available one in open source (https://llama.meta.com/llama-downloads/). + +set -e +idx=$(date +%Y-%m-%d-%H-%M) + +base_ckpt_path=gs://maxtext-llama/test/2024-01-15-06-49/decode-ckpt-maxtext/0/items +BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +DATASET_PATH=gs://maxtext-dataset + +export LOSS_THRESHOLD=2.5 + +python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt' + +# Assert training loss is smaller than input LOSS_THRESHOLD +python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD diff --git a/end_to_end/tpu/test_checkpoint_compatibility.sh b/end_to_end/tpu/test_checkpoint_compatibility.sh new file mode 100644 index 000000000..ae37801e9 --- /dev/null +++ b/end_to_end/tpu/test_checkpoint_compatibility.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set -ex + +if [ -f "run_*_metrics.txt" ]; then + rm run_*_metrics.txt + echo "removed existing run_*_metrics.txt" +fi + +RUN_NAME=${1}-$(date +%Y-%m-%d-%H-%M) +OUTPUT_PATH=${2} +DATASET_PATH=${3} +model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" + +echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" +bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ + +echo "Run_1: Starting the first run using the grain input pipeline" + +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=3 ${model_params}\ + max_target_length=128 per_device_batch_size=1\ + metrics_file=run_1_metrics.txt checkpoint_period=2 async_checkpointing=false\ + dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ + dataset_type=c4-array_record grain_worker_count=0\ + dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1 + +echo +echo "Finished Run_1 at step 2" +echo "Run_2: Resuming using the tfds input pipeline" +echo + +python3 MaxText/train.py 
MaxText/configs/base.yml run_name=$RUN_NAME steps=5 ${model_params}\ + max_target_length=128 per_device_batch_size=1\ + metrics_file=run_2_metrics.txt checkpoint_period=2 async_checkpointing=false\ + dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ + +echo +echo "Finished Run_2 at step 4" +echo "Run_3: Resuming using the grain input pipeline" +echo + +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=7 ${model_params}\ + max_target_length=128 per_device_batch_size=1\ + metrics_file=run_3_metrics.txt checkpoint_period=2 async_checkpointing=false\ + dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ + dataset_type=c4-array_record grain_worker_count=0\ + dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1 + +python3 end_to_end/eval_assert.py test_start_step run_2_metrics.txt 3.0 +python3 end_to_end/eval_assert.py test_start_step run_3_metrics.txt 5.0 diff --git a/end_to_end/tpu/test_checkpoint_resharding.sh b/end_to_end/tpu/test_checkpoint_resharding.sh new file mode 100644 index 000000000..ae3b741a5 --- /dev/null +++ b/end_to_end/tpu/test_checkpoint_resharding.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -ex + +RUN_NAME=${1}_$(date +%Y-%m-%d-%H) +OUTPUT_PATH=${2} +DATASET_PATH=${3} + +# Train and save checkpoint - sharded with DCN Data Parallelism + ICI FSDP Parallelism +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=101\ + metrics_file='saved_metrics.txt' checkpoint_period=20 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ + dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 collect_stack_trace=False + +# Retrieve checkpoint - sharded with DCN Data Parallelism + ICI FSDP + Tensor Parallelism +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=102\ + metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ + dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False + +python3 end_to_end/eval_assert.py checkpoint_save_restore metrics.txt learning/loss diff --git a/end_to_end/tpu/test_checkpointing.sh b/end_to_end/tpu/test_checkpointing.sh new file mode 100644 index 000000000..e3fd9e6dd --- /dev/null +++ b/end_to_end/tpu/test_checkpointing.sh @@ -0,0 +1,63 @@ +#!/bin/bash +set -ex + +if [ -f "saved_metrics.txt" ]; then + rm saved_metrics.txt + echo "removed existing saved_metrics.txt" +fi + +if [ -f "restored_metrics.txt" ]; then + rm restored_metrics.txt + echo "removed existing restored_metrics.txt" +fi + +RUN_NAME=${1}-${4}-$(date +%Y-%m-%d-%H-%M) +OUTPUT_PATH=${2} +DATASET_PATH=${3} +COLLECT_STACK_TRACE=${4} +DATASET_TYPE=${5} +eval_metrics=checkpoint_save_restore +model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" +CMD_DATA="" + +if [ "$DATASET_TYPE" == "c4-array_record" ] +then + eval_metrics=grain_checkpoint_save_restore + echo "Using c4-array_record dataset type" + echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" + bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ + DATASET_PATH=/tmp/gcsfuse/ + CMD_DATA=" grain_worker_count=0 dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1" +fi + +#Train +CMD1="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ + 
metrics_file=saved_metrics.txt checkpoint_period=3 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ + async_checkpointing=false collect_stack_trace=$COLLECT_STACK_TRACE" +CMD1+=$model_params +CMD1+=$CMD_DATA + +CMD2="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ + metrics_file=restored_metrics.txt base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ + async_checkpointing=false collect_stack_trace=$COLLECT_STACK_TRACE" +CMD2+=$model_params +CMD2+=$CMD_DATA + +echo +echo "Start the first training run" +echo "Command is:" +echo $CMD1 + +$CMD1 +# Wait for first train to finish +# process_id=$! +# wait $process_id +echo +echo "First training run done" +echo "Start the second training run" +echo "Command is:" +echo $CMD2 + +$CMD2 + +python3 end_to_end/eval_assert.py $eval_metrics metrics.txt learning/loss diff --git a/end_to_end/tpu/test_convergence_1b_params.sh b/end_to_end/tpu/test_convergence_1b_params.sh new file mode 100644 index 000000000..38108b1c5 --- /dev/null +++ b/end_to_end/tpu/test_convergence_1b_params.sh @@ -0,0 +1,52 @@ +#!/bin/bash +set -ex + +echo "Running test_convergence_1b_params.sh" +# Run this on 64 chips to achieve a loss value of ~2.5 after 20400 steps, or ~2.7 after 10200 steps (v4-128) +# +# Command Flags: +# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) +# DATASET_PATH (Required, unless dataset_path is already set in base.yml) +# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) +# LOSS_THRESHOLD (Optional, default is 100.0 ) +# +# Example to invoke this script: +# bash end_to_end/test_convergence_1b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" LOSS_THRESHOLD=100.0 + +export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass. 
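+# LOSS_THRESHOLD above (and STEPS, set next) are only defaults: the KEY=VALUE parsing loop further down
+# exports any flag passed on the command line, so callers can tighten the check. An illustrative invocation
+# (placeholder buckets, values taken from the v4-128 guidance in the header comment):
+#   bash end_to_end/tpu/test_convergence_1b_params.sh RUN_NAME=convergence-1b \
+#     OUTPUT_PATH=gs://<your-bucket>/logs DATASET_PATH=gs://<your-bucket>/c4 \
+#     LOSS_THRESHOLD=2.7 STEPS=10200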
+export STEPS=20400 # Run for 20B tokens for a 1B sized mode for "chinchilla" scaling https://arxiv.org/abs/2203.15556 + +# Set environment variables +for ARGUMENT in "$@"; do + IFS='=' read -r KEY VALUE <<< "$ARGUMENT" + export "$KEY"="$VALUE" +done + +if [ -n "$RUN_NAME" ]; +then + export M_RUN_NAME=$RUN_NAME +fi + +if [ "$DATASET_TYPE" == "c4-array_record" ] +then + EVAL_METRICS=grain_checkpoint_save_restore + echo "Using c4-array_record dataset type" + echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" + bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ + DATASET_PATH=/tmp/gcsfuse/ + CMD_DATA=" dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1" +fi + +TRAIN_CMD="python3 MaxText/train.py MaxText/configs/base.yml\ + steps=$STEPS per_device_batch_size=8.0 learning_rate=3e-4 enable_checkpointing=false \ + max_target_length=2048 global_parameter_scale=1 \ + enable_profiler=false metrics_file=metrics.txt base_output_directory=$OUTPUT_PATH\ + dataset_path=$DATASET_PATH log_period=150 remat_policy=minimal enable_data_shuffling=false" +TRAIN_CMD+=$CMD_DATA + +# Train +export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +$TRAIN_CMD + +# Assert training loss is smaller than input LOSS_THRESHOLD +python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD diff --git a/end_to_end/tpu/test_decode.sh b/end_to_end/tpu/test_decode.sh new file mode 100644 index 000000000..fc432e025 --- /dev/null +++ b/end_to_end/tpu/test_decode.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -ex + +NUM_TOKEN_THRESHOLD=${1} +OUTPUT_PATH=${2} +DATASET_PATH=${3} +# Run name is optional 4th input - our daily XLML tests will use one. + + +if [ -z ${4} ] +then + RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S) +else + RUN_NAME=${4}_$(date +%Y-%m-%d-%H) +fi + +if [ -z ${5} ] +then + ICI_TENSOR_PARALLELISM=4 +else + ICI_TENSOR_PARALLELISM=${5} +fi + +# Decode without checkpoint +python3 MaxText/decode.py MaxText/configs/base.yml run_name=$RUN_NAME\ + steps=50 enable_checkpointing=False metrics_file=/tmp/${RUN_NAME}_metrics.txt \ + base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ + attention=dot_product ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} + + +# Get latest converted Gemma 2B checkpoint from internal GCS bucket +export GEMMA_2B_CKPT_PATH=$(gsutil ls gs://maxtext-gemma/2b | sort -r | head -1) +# Decode with different sampling strategies. +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=weighted decode_sampling_temperature=.00001 prompt="I love to" autoregressive_decode_assert=" cook and bake. 
I love to eat" +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=nucleus decode_sampling_nucleus_p=0 prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat" +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=topk decode_sampling_top_k=1 prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat" diff --git a/end_to_end/tpu/test_determinism.sh b/end_to_end/tpu/test_determinism.sh new file mode 100644 index 000000000..9d9c066fc --- /dev/null +++ b/end_to_end/tpu/test_determinism.sh @@ -0,0 +1,31 @@ +#!/bin/bash +set -ex + +RUN_NAME=${1}_$(date +%Y-%m-%d-%H) +OUTPUT_PATH=${2} +DATASET_PATH=${3} +DATASET_TYPE=${4} + +if [ "$DATASET_TYPE" == "c4-array_record" ] +then + EVAL_METRICS=grain_checkpoint_save_restore + echo "Using c4-array_record dataset type" + echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" + bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ + DATASET_PATH=/tmp/gcsfuse/ + CMD_DATA=" dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1" +fi + +#Train +CMD1="python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME}_1 steps=5 metrics_file=run_1_metrics.txt\ + enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH" +CMD1+=$CMD_DATA + + +CMD2="python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME}_2 steps=5 metrics_file=run_2_metrics.txt\ + enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH" +CMD2+=$CMD_DATA + +$CMD1 +$CMD2 +python3 end_to_end/eval_assert.py determinism metrics.txt learning/loss diff --git a/end_to_end/tpu/test_generate_param_only_checkpoint.sh b/end_to_end/tpu/test_generate_param_only_checkpoint.sh new file mode 100644 index 000000000..17d6c3463 --- /dev/null +++ b/end_to_end/tpu/test_generate_param_only_checkpoint.sh @@ -0,0 +1,123 @@ +#!/bin/bash + +set -uex + +helpFunction() +{ + echo "" + echo "Usage: $0 " + echo -e "\t-n dry_run is true " + echo -e "\t-r runid: run_test_model_0b" + echo -e "\t-d dataset_path: gs://test-maxtext-dataset" + echo -e "\t-o output_path: gs://test-maxtext-output" + echo -e "\t-i ici_tensor_parallelism: 8" + echo -e "\t-a attention: flash" + echo -e "\t-q quantization: int8" + exit 1 # Exit script after printing help +} + +# Default option values +dry_run=false +run_id=test_model_0b_$(date +%Y-%m-%d-%H) +dataset_path=gs://test-maxtext-dataset +base_output_directory=gs://test-maxtext-output +ici_tensor_parallelism=8 +attention=flash +quantization="" + +while getopts "nr:d:o:t:i:a:q:" opt +do + case "$opt" in + n ) dry_run=true ;; + r ) run_id="$OPTARG" ;; + d ) dataset_path="$OPTARG";; + o ) base_output_directory="$OPTARG";; + i ) ici_tensor_parallelism="$OPTARG" ;; + a ) 
attention="$OPTARG" ;; + q ) quantization="int8" ;; + ? ) helpFunction ;; # Print helpFunction in case parameter is non-existent + esac +done + +echo +echo "Running: ./$0 dataset_path=${dataset_path} base_output_directory=${base_output_directory}" +echo " dry_run=${dry_run} run_id=${run_id} " +echo " ici_tensor_parallelism=${ici_tensor_parallelism} attention=${attention} quantization=${quantization}" +echo + +if "$dry_run"; then + cmd=echo +else + cmd='' +fi + +training_ckpt_run_id=${run_id}-ckpt-train-steps-5 +decode_ckpt_run_id=${run_id}-decode-ckpt-train-steps-5 +model_params="base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" + +echo +echo "Create a test training checkpoint" +echo +$cmd python3 MaxText/train.py MaxText/configs/base.yml \ +run_name=${training_ckpt_run_id} \ +base_output_directory=${base_output_directory} \ +dataset_path=${dataset_path} attention=${attention} \ +steps=5 checkpoint_period=3 async_checkpointing=false \ +quantization=${quantization} \ +${model_params} \ + + +if [ $? -eq 0 ] +then + echo + echo "Successfully created a training checkpoint" + echo "Checkpoint path: ${base_output_directory}/${training_ckpt_run_id}/checkpoints/3/items" +else + echo + echo "Could not create a training checkpoint" >&2 + exit 1 +fi + +echo +echo "Generate a decode checkpoint from the test training checkpoint" +echo + +$cmd python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml \ +run_name=${decode_ckpt_run_id} attention=${attention} \ +base_output_directory=${base_output_directory} \ +dataset_path=${dataset_path} async_checkpointing=false \ +load_full_state_path=${base_output_directory}/${training_ckpt_run_id}/checkpoints/3/items \ +quantization=${quantization} \ +${model_params} \ + + +if [ $? -eq 0 ] +then + echo "Successfully created an decode checkpoint" + echo "Checkpoint path: ${base_output_directory}/${decode_ckpt_run_id}/checkpoints/0/items" + +else + echo "Could not create an decode checkpoint" >&2 + exit 1 +fi + +echo +echo "Run decode using the generated checkpoint" +echo +$cmd python3 MaxText/decode.py MaxText/configs/base.yml \ +run_name=${run_id}-decode-steps-50 \ +base_output_directory=${base_output_directory} \ +dataset_path=${dataset_path} \ +load_parameters_path=${base_output_directory}/${decode_ckpt_run_id}/checkpoints/0/items \ +attention=dot_product ici_tensor_parallelism=${ici_tensor_parallelism} steps=50 \ +metrics_file=/tmp/${run_id}_metrics.txt async_checkpointing=false max_target_length=128 per_device_batch_size=1 \ +quantization=${quantization} \ +${model_params} \ + +if [ $? 
-eq 0 ] +then + echo "Successfully ran decode using decode optimized checkpoint" +else + echo "Could not run decode decode optimized checkpoint" >&2 + exit 1 +fi diff --git a/end_to_end/tpu/test_gpt3.sh b/end_to_end/tpu/test_gpt3.sh new file mode 100644 index 000000000..7dd0b315a --- /dev/null +++ b/end_to_end/tpu/test_gpt3.sh @@ -0,0 +1,16 @@ +set -euox pipefail + +TIMESTAMP=$(date +%Y%m%d-%H%M) +export PAXML_CKPT_PATH=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state +export OUTPUT_PATH=gs://maxtext-gpt3/tests +export RUN_NAME=test_${TIMESTAMP} + +# convert gpt3-52k model +python3 MaxText/convert_gpt3_ckpt_from_paxml.py --paxml-ckpt-path=${PAXML_CKPT_PATH} --maxtext-model-name=gpt3-52k --run-name=${RUN_NAME} --base-output-directory=${OUTPUT_PATH} + +# Run gpt3-52k with the converted ckpt +python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\ + steps=10 per_device_batch_size=6 enable_checkpointing=true async_checkpointing=false\ + enable_profiler=false remat_policy=full\ + max_target_length=2048 base_output_directory=${OUTPUT_PATH}\ + dataset_type=synthetic diff --git a/end_to_end/tpu/test_llama2_7b.sh b/end_to_end/tpu/test_llama2_7b.sh new file mode 100644 index 000000000..c61663c77 --- /dev/null +++ b/end_to_end/tpu/test_llama2_7b.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Llama2-7b + +# The flow of this file is as follows: +# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) in your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText. +# 2. Run decoding, finetuning of Llama2-7b with this converted checkpoint. Also, run pretraining of Llama2-7b. +# 3. Run decoding from the finetuned weights +# 4. Convert the scanned checkpoint from step #1 into unscanned checkpoint format and run more efficient decoding. + + +set -ex +idx=$(date +%Y-%m-%d-%H-%M) + +# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run +export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +export DATASET_PATH=gs://maxtext-dataset +export ASYNC_CHECKPOINTING=false + +# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU +pip install torch --index-url https://download.pytorch.org/whl/cpu + +# We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint +export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt + +# In the following command, we are copying Meta's checkpoint into a local directory `tmp`. +# You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py` +gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/ + +# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. 
Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own +export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext + +#Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` +python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH} + +# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory exactly inside `CONVERTED_CHECKPOINT_PATH`. This way it is easier to use this path in the `train.py` and `decode.py` commands +export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items + +# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. +# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. +export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx} +python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true + +# Like before, we define `UNSCANNED_CKPT_PATH` to refer to the checkpoint subdirectory exactly +export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items + +# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint aka `CONVERTED_CHECKPOINT`. Note that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` +# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert` +python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product scan_layers=false + + +# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` +# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert` +python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product + +# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. 
Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning +python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 checkpoint_period=5 + +# We also run pre-training of Llama2-7b, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from +python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 + +# Now, run decoding on the checkpoint generated from our finetune run. Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert +# the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, say the checkpoint saved at finetuning step #5 +# Also, `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding +export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx} +python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true + +# Like before, we define `NEW_CKPT_PATH` to refer to the checkpoint subdirectory exactly +export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items + +# We run decoding on the fine-tuned parameter checkpoint +python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false diff --git a/end_to_end/tpu/test_mistral.sh b/end_to_end/tpu/test_mistral.sh new file mode 100644 index 000000000..2a80904d5 --- /dev/null +++ b/end_to_end/tpu/test_mistral.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# This script is designed for internal use within Google. External users can adapt it by: +# - Updating GCS paths (gs://) to your accessible locations. +# - Using the checkpoint generated from train.py or available one in open source (i.e. https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar). 
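+#
+# The `M_`-prefixed environment variables exported below are read by MaxText as config overrides
+# (equivalent to passing the matching key=value flags to train.py/decode.py), so redirecting outputs
+# usually only needs something like the following (placeholder bucket, illustrative only):
+#   export M_BASE_OUTPUT_DIRECTORY=gs://<your-bucket>/maxtext-logs
+#   export M_DATASET_PATH=gs://<your-bucket>/maxtext-dataset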
+ +set -ex +idx=$(date +%Y-%m-%d-%H-%M) + +export M_ENABLE_CHECKPOINTING=true +export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +export M_DATASET_PATH=gs://maxtext-dataset +export M_ASYNC_CHECKPOINTING=false + +# Download checkpoint, convert it to MaxText, and run inference +pip3 install torch +gsutil -m cp -r gs://maxtext-external/mistral-7B-v0.1 /tmp +python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mistral-7B-v0.1 --model-size mistral-7b --maxtext-model-path gs://maxtext-mistral/test/${idx}/decode-ckpt-maxtext/ +python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mistral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_direct_${idx} per_device_batch_size=1 model_name='mistral-7b' tokenizer_path=gs://maxtext-external/mistral-7B-v0.1/tokenizer.mistral ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to read about the Bible. I love" attention=dot_product + +# Training +python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mistral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_${idx} per_device_batch_size=1 model_name='mistral-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mistral-7B-v0.1/tokenizer.mistral diff --git a/end_to_end/tpu/test_mixtral.sh b/end_to_end/tpu/test_mixtral.sh new file mode 100644 index 000000000..f5c3fd97a --- /dev/null +++ b/end_to_end/tpu/test_mixtral.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# This script is designed for internal use within Google. External users can adapt it by: +# - Updating GCS paths (gs://) to your accessible locations. +# - Using the checkpoint generated from train.py or available one in open source (i.e. https://files.mixtral-8x7b-v0-1.mistral.ai/Mixtral-8x7B-v0.1-Instruct.tar). + +set -ex +idx=$(date +%Y-%m-%d-%H-%M) + +export M_ENABLE_CHECKPOINTING=true +export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs +export M_DATASET_PATH=gs://maxtext-dataset +export M_ASYNC_CHECKPOINTING=false + +# Download checkpoint, convert it to MaxText, and run inference +pip3 install torch +gsutil -m cp -r gs://maxtext-external/mixtral-8x7B-v0.1-Instruct /tmp +python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mixtral-8x7B-v0.1-Instruct --model-size mixtral-8x7b --maxtext-model-path gs://maxtext-mixtral/test/${idx}/decode-ckpt-maxtext/ +python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mixtral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_direct_${idx} per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=28 prompt="[INST] I love to [/INST]" autoregressive_decode_assert="That's great to hear! 
I love to learn new things and explore different interests" attention=dot_product + +# Training +python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mixtral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_${idx} per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral diff --git a/end_to_end/tpu/test_tflops.sh b/end_to_end/tpu/test_tflops.sh new file mode 100644 index 000000000..f0543d95e --- /dev/null +++ b/end_to_end/tpu/test_tflops.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -ex + +USER=${1} +TFLOP_THRESHOLD=${2} +OUTPUT_PATH=${3} +DATASET_PATH=${4} + + +if [ -z ${5} ] +then + RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S) +else + RUN_NAME=${5}_$(date +%Y-%m-%d-%H) +fi + +#Train +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ + steps=150 reuse_example_batch=1 remat_policy='full' enable_profiler=True enable_checkpointing=False metrics_file='metrics.txt'\ + base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH log_period=150 + +python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_tflops_16b_params.sh b/end_to_end/tpu/test_tflops_16b_params.sh new file mode 100644 index 000000000..9e992305f --- /dev/null +++ b/end_to_end/tpu/test_tflops_16b_params.sh @@ -0,0 +1,38 @@ +#!/bin/bash +echo "Running test_tflops_16b_params.sh" + +# Command Flags: +# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) +# DATASET_PATH (Required, unless dataset_path is already set in base.yml) +# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) +# PLATFORM (Optional, can be "gke" or "gce", default is "gce") +# TFLOP_THRESHOLD (Optional, default is 0 ) +# +# Example to invoke this script: +# bash end_to_end/test_tflops_16b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 + +# Stop execution if any command exits with error +set -ex + +export TFLOP_THRESHOLD=0 +export PLATFORM="gce" + +# Set environment variables +for ARGUMENT in "$@"; do + IFS='=' read -r KEY VALUE <<< "$ARGUMENT" + export "$KEY"="$VALUE" +done + +# Set up network optimizations +bash preflight.sh PLATFORM=$PLATFORM + +# Train +export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ + steps=150 per_device_batch_size=6 enable_checkpointing=false\ + enable_profiler=false remat_policy=full\ + max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ + dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=16 + +# Assert TFLOP/s +python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_tflops_32b_params.sh b/end_to_end/tpu/test_tflops_32b_params.sh new file mode 100644 index 000000000..59f0585a3 --- /dev/null +++ b/end_to_end/tpu/test_tflops_32b_params.sh @@ -0,0 +1,38 @@ +#!/bin/bash +echo "Running test_tflops_32b_params.sh" + +# 
Command Flags: +# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) +# DATASET_PATH (Required, unless dataset_path is already set in base.yml) +# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) +# PLATFORM (Optional, can be "gke" or "gce", default is "gce") +# TFLOP_THRESHOLD (Optional, default is 0 ) +# +# Example to invoke this script: +# bash end_to_end/test_tflops_32b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 + +# Stop execution if any command exits with error +set -ex + +export TFLOP_THRESHOLD=0 +export PLATFORM="gce" + +# Set environment variables +for ARGUMENT in "$@"; do + IFS='=' read -r KEY VALUE <<< "$ARGUMENT" + export "$KEY"="$VALUE" +done + +# Set up network optimizations +bash preflight.sh PLATFORM=$PLATFORM + +# Train +export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ + steps=150 per_device_batch_size=4 enable_checkpointing=false\ + enable_profiler=false remat_policy=full\ + max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ + dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=32 + +# Assert TFLOP/s +python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_tflops_64b_params.sh b/end_to_end/tpu/test_tflops_64b_params.sh new file mode 100644 index 000000000..7c05d7413 --- /dev/null +++ b/end_to_end/tpu/test_tflops_64b_params.sh @@ -0,0 +1,38 @@ +#!/bin/bash +echo "Running test_tflops_64b_params.sh" + +# Command Flags: +# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) +# DATASET_PATH (Required, unless dataset_path is already set in base.yml) +# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) +# PLATFORM (Optional, can be "gke" or "gce", default is "gce") +# TFLOP_THRESHOLD (Optional, default is 0 ) +# +# Example to invoke this script: +# bash end_to_end/test_tflops_64b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 + +# Stop execution if any command exits with error +set -ex + +export TFLOP_THRESHOLD=0 +export PLATFORM="gce" + +# Set environment variables +for ARGUMENT in "$@"; do + IFS='=' read -r KEY VALUE <<< "$ARGUMENT" + export "$KEY"="$VALUE" +done + +# Set up network optimizations +bash preflight.sh PLATFORM=$PLATFORM + +# Train +export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ + steps=150 per_device_batch_size=2 enable_checkpointing=false\ + enable_profiler=false remat_policy=full\ + max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ + dataset_path=$DATASET_PATH 
log_period=150 global_parameter_scale=64 + +# Assert TFLOP/s +python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_vocab_creation.sh b/end_to_end/tpu/test_vocab_creation.sh new file mode 100644 index 000000000..dafd4f00f --- /dev/null +++ b/end_to_end/tpu/test_vocab_creation.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -ex + +RUN_NAME=${1}_$(date +%Y-%m-%d-%H) +OUTPUT_PATH=${2} +DATASET_PATH=${3} +VOCAB_PATH=$OUTPUT_PATH/vocab_test_creation_$RUN_NAME + + +#Train +python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\ + base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH tokenizer_path=$VOCAB_PATH + +python3 end_to_end/eval_assert.py vocab_creation $VOCAB_PATH From 1a677f26a45d92ef07ec6b88cbe760019ace31b4 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Wed, 10 Apr 2024 14:15:06 -0700 Subject: [PATCH 17/26] squashed and rebased fp8 implementation --- .github/workflows/UnitTests.yml | 6 +++++- MaxText/configs/base.yml | 2 +- MaxText/decode.py | 2 +- MaxText/layers/linears.py | 3 ++- MaxText/layers/models.py | 1 + MaxText/layers/quantizations.py | 23 +++++++++++++++++++++++ 6 files changed, 33 insertions(+), 4 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index fb7698c12..c94344afe 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -98,7 +98,11 @@ jobs: - name: Test int8_training run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false' + 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false' + - name: Test fp8_training + run: | + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false' - name: Test generate_param_only_checkpoint run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index d76e1b0e6..76b255565 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -59,7 +59,7 @@ save_config_to_gcs: False # Activation dtypes. dtype: "bfloat16" -quantization: "" #defaults to no quantization, i.e. bf16. possible alternative setting is 'int8' +quantization: "" #defaults to no quantization, i.e. bf16. possible alternative setting is 'int8' or use fp8 to run with 8-bit floating-point GeMMs on NVIDIA GPUs. quantize_kvcache: False # Shard the range finding operation for quantization. By default this is set to number of slices. 
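A minimal smoke-test sketch for the new fp8 setting above, assuming buckets you own (the paths below are placeholders) and mirroring the fp8 unit-test command added to the workflow earlier in this patch; note that fp8 targets 8-bit floating-point GeMMs on NVIDIA GPUs and decode.py rejects it for now:

python3 MaxText/train.py MaxText/configs/base.yml run_name=fp8_smoke_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://<your-bucket>/maxtext-logs dataset_path=gs://<your-dataset> quantization=fp8 steps=2 enable_checkpointing=false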
diff --git a/MaxText/decode.py b/MaxText/decode.py index 1003d7701..15606b18e 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -35,7 +35,7 @@ def main(config): tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[config.max_prefill_predict_length]) assert tokens.size <= config.max_prefill_predict_length, "can't take too many tokens" - + assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" prefill_result = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 01ee36a88..4cf3d1939 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -282,7 +282,8 @@ def __call__(self, inputs, deterministic: bool = False): dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=self.kernel_axes, - name='gate')(inputs) + name='gate', + quant=self.quant,)(inputs) weights, selected_experts = lax.top_k(gate_logits, self.num_experts_per_tok) weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index d15f76444..cc170650f 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -258,6 +258,7 @@ def __call__(self, 'cache': cache_spec, 'intermediates': 0, 'aqt':0, + '_overwrite_with_gradient': 0, }, split_rngs={ 'params': True, diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index e1e83a690..c6450ea17 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -20,12 +20,22 @@ from aqt.jax.v2.flax import aqt_flax from common_types import Array, Config from dataclasses import dataclass +import flax.linen as nn import jax import jax.numpy as jnp from jax.tree_util import tree_flatten_with_path, tree_unflatten MAX_INT8 = 127.5 +@dataclass +class Quantization: + """Base class for quantization configurations""" + + def dot_general_cls(self): + """ Placeholder for dot_general implementation in subclasses. """ + pass + + @dataclass class AqtQuantization: """ Configures AQT quantization github.com/google/aqt. """ @@ -50,6 +60,15 @@ def einsum(self): ) return aqt_einsum +@dataclass +class Fp8Quantization(Quantization): + """ Configures Fp8 quantization for NVIDIA GPUs""" + quant_mode = "train" + + def dot_general_cls(self): + """ Returns dot_general configured with aqt params. 
""" + return nn.Fp8DotGeneralOp + def _get_quant_config(config): """Set quantization params based on user configuration.""" if not config.quantization or config.quantization == '': @@ -74,6 +93,8 @@ def _get_quant_config(config): dlhs_accumulator_dtype=jnp.int32, drhs_accumulator_dtype=drhs_accumulator_dtype, ) + elif config.quantization == "fp8": + return "fp8" else: raise ValueError(f'Invalid value configured for quantization {config.quantization}.') @@ -99,6 +120,8 @@ def configure_quantization(config: Config, quant_mode_str: str = 'train'): """ Configure quantization based on user config and quant mode.""" quant_cfg = _get_quant_config(config) if quant_cfg: + if quant_cfg == "fp8": + return Fp8Quantization() quant_mode = get_quant_mode(quant_mode_str) return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode) return None From 4d3c94b8b09948948d5d317d6ad4e118fdcc3c29 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Thu, 11 Apr 2024 16:23:27 +0000 Subject: [PATCH 18/26] GHA unit test pinned mode only --- .github/workflows/UnitTests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 9c44769eb..6fcb22805 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -98,7 +98,7 @@ jobs: - name: Test int8_training run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false' + 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false' - name: Test fp8_training run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ @@ -130,7 +130,7 @@ jobs: fail-fast: false matrix: device-type: ["a100-40gb-4"] - build-mode: ["nightly", "stable", "pinned"] + build-mode: ["pinned"] name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})" runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"] env: From c8565be44418d87e213502fe1588c930c071f571 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Thu, 11 Apr 2024 16:48:07 +0000 Subject: [PATCH 19/26] Pin base image Also update constraints --- constraints_gpu.txt | 20 ++++++++++++++++++++ docker_build_dependency_image.sh | 8 ++++++-- maxtext_dependencies.Dockerfile | 1 + maxtext_gpu_dependencies.Dockerfile | 4 +++- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/constraints_gpu.txt b/constraints_gpu.txt index 0a86a9c0e..1991093fb 100644 --- a/constraints_gpu.txt +++ b/constraints_gpu.txt @@ -9,11 +9,15 @@ certifi==2024.2.2 charset-normalizer==3.3.2 chex==0.1.85 click==8.1.7 +cloud-accelerator-diagnostics==0.1.0 cloud-tpu-diagnostics==0.1.5 cloudpickle==3.0.0 +clu==0.0.12 contextlib2==21.6.0 dill==0.3.8 dm-tree==0.1.8 +docstring_parser==0.16 +editdistance==0.8.1 etils==1.7.0 exceptiongroup==1.2.0 flatbuffers==24.3.7 @@ -23,14 +27,23 @@ gast==0.4.0 google-api-core==2.17.1 google-auth==2.28.2 google-auth-oauthlib==1.0.0 +google-cloud-aiplatform==1.47.0 +google-cloud-appengine-logging==1.4.3 +google-cloud-audit-log==0.2.5 +google-cloud-bigquery==3.20.1 google-cloud-core==2.4.1 
+google-cloud-logging==3.10.0 +google-cloud-resource-manager==1.12.3 google-cloud-storage==2.15.0 google-crc32c==1.5.0 +google-jetstream==0.2.0 google-pasta==0.2.0 google-resumable-media==2.7.0 googleapis-common-protos==1.63.0 grain-nightly==0.0.6 +grpc-google-iam-v1==0.13.0 grpcio==1.62.1 +grpcio-status==1.48.2 gviz-api==1.10.0 h5py==3.10.0 idna==3.6 @@ -53,6 +66,7 @@ mccabe==0.7.0 mdurl==0.1.2 ml-collections==0.1.1 ml-dtypes==0.3.2 +ml_goodput_measurement==0.0.2 mlperf-logging==3.0.0 more-itertools==10.2.0 msgpack==1.0.8 @@ -81,7 +95,9 @@ packaging==24.0 pandas==2.2.1 platformdirs==4.2.0 pluggy==1.4.0 +portpicker==1.6.0 promise==2.3 +proto-plus==1.23.0 protobuf==3.20.3 psutil==5.9.8 pyasn1==0.5.1 @@ -89,6 +105,7 @@ pyasn1-modules==0.3.0 pycnite==2023.10.11 pydantic==1.10.14 pydot==2.0.0 +pyglove==0.4.4 Pygments==2.17.2 pylint==3.1.0 pyparsing==3.1.2 @@ -103,6 +120,8 @@ rich==13.7.1 rsa==4.9 scipy==1.12.0 sentencepiece==0.1.97 +seqio==0.0.19 +shapely==2.0.3 six==1.16.0 tabulate==0.9.0 tensorboard==2.13.0 @@ -119,6 +138,7 @@ tensorflow-text==2.13.0 tensorstore==0.1.54 termcolor==2.4.0 tf-keras==2.15.0 +tfds-nightly==4.9.2.dev202308090034 toml==0.10.2 tomli==2.0.1 tomlkit==0.12.4 diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 3b5b8048b..6ce43225a 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -45,7 +45,6 @@ fi if [[ -z ${MODE} ]]; then export MODE=stable echo "Default MODE=${MODE}" - fi if [[ -z ${DEVICE} ]]; then @@ -57,7 +56,12 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then export LIBTPU_GCS_PATH=NONE echo "Default LIBTPU_GCS_PATH=${LIBTPU_GCS_PATH}" if [[ ${DEVICE} == "gpu" ]]; then - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + if [[ ${MODE} == "pinned" ]]; then + export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-03-13 + else + export BASEIMAGE=ghcr.io/nvidia/jax:base + fi + docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . 
fi diff --git a/maxtext_dependencies.Dockerfile b/maxtext_dependencies.Dockerfile index 5305de5ff..972765e20 100644 --- a/maxtext_dependencies.Dockerfile +++ b/maxtext_dependencies.Dockerfile @@ -1,3 +1,4 @@ +# syntax=docker/dockerfile:experimental # Use Python 3.10 as the base image FROM python:3.10-slim-bullseye diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile index d19904857..c7d075519 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/maxtext_gpu_dependencies.Dockerfile @@ -1,4 +1,6 @@ -FROM ghcr.io/nvidia/jax:base +# syntax=docker/dockerfile:experimental +ARG BASEIMAGE=ghcr.io/nvidia/jax:base +FROM $BASEIMAGE # Install dependencies for adjusting network rto RUN apt-get update && apt-get install -y iproute2 ethtool lsof From 959c7f46bdcc39163cc7a082ab4e0b24280baa60 Mon Sep 17 00:00:00 2001 From: In-Ho Yi Date: Thu, 11 Apr 2024 16:49:21 +0000 Subject: [PATCH 20/26] Fix typo --- ...der.Dockerfile => maxtext_transformerengine_builder.Dockerfile | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename maxtext_transformerenginer_builder.Dockerfile => maxtext_transformerengine_builder.Dockerfile (100%) diff --git a/maxtext_transformerenginer_builder.Dockerfile b/maxtext_transformerengine_builder.Dockerfile similarity index 100% rename from maxtext_transformerenginer_builder.Dockerfile rename to maxtext_transformerengine_builder.Dockerfile From 8992c9e83484658c1fb786ecf47125cf40e135be Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Thu, 11 Apr 2024 18:24:56 +0000 Subject: [PATCH 21/26] Call max_utils.get_project() only when Vertex Tensorboard is enabled --- MaxText/tests/train_gpu_smoke_test.py | 1 - MaxText/tests/train_int8_smoke_test.py | 1 - MaxText/tests/train_smoke_test.py | 1 - MaxText/train.py | 3 ++- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/MaxText/tests/train_gpu_smoke_test.py b/MaxText/tests/train_gpu_smoke_test.py index 80a4676bb..d7f9df3a5 100644 --- a/MaxText/tests/train_gpu_smoke_test.py +++ b/MaxText/tests/train_gpu_smoke_test.py @@ -25,7 +25,6 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - os.environ["TENSORBOARD_PROJECT"] = "test-project" train_main([ None, "third_party/py/maxtext/configs/gpu_smoke_test.yml", diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 89b8e5b55..5efc5ebeb 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -25,7 +25,6 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - os.environ["TENSORBOARD_PROJECT"] = "test-project" train_main([None, "third_party/py/maxtext/configs/base.yml", f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", r"dataset_path=gs://maxtext-dataset", diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index 71a0c5813..8cd41fb33 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -25,7 +25,6 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - os.environ["TENSORBOARD_PROJECT"] = "test-project" train_main([None, "third_party/py/maxtext/configs/base.yml", f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", r"dataset_path=gs://maxtext-dataset", diff --git a/MaxText/train.py b/MaxText/train.py index 3faaca4f7..4a1966945 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -508,7 +508,8 @@ def 
main(argv: Sequence[str]) -> None: validate_train_config(config) os.environ["TFDS_DATA_DIR"] = config.dataset_path vertex_tensorboard_manager = VertexTensorboardManager() - vertex_tensorboard_manager.configure_vertex_tensorboard(config) + if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): + vertex_tensorboard_manager.configure_vertex_tensorboard(config) debug_config = debug_configuration.DebugConfig( stack_trace_config = stack_trace_configuration.StackTraceConfig( From d6769933c1f80d741ff662e2e26c2b545837154e Mon Sep 17 00:00:00 2001 From: Nina Cai Date: Thu, 11 Apr 2024 21:27:14 +0000 Subject: [PATCH 22/26] Move tpu end-to-end test scripts to tpu folder --- .github/workflows/UnitTests.yml | 8 +- end_to_end/eval_assert.py | 139 ------------------ end_to_end/gemma/2b/test_gemma.sh | 64 -------- end_to_end/gemma/7b/1_test_gemma.sh | 31 ---- end_to_end/gemma/7b/2_test_gemma.sh | 55 ------- end_to_end/gemma/Run_Gemma.md | 31 ---- end_to_end/llama_finetuning_test.sh | 19 --- end_to_end/test_checkpoint_compatibility.sh | 49 ------ end_to_end/test_checkpoint_resharding.sh | 18 --- end_to_end/test_checkpointing.sh | 63 -------- end_to_end/test_convergence_1b_params.sh | 52 ------- end_to_end/test_decode.sh | 36 ----- end_to_end/test_determinism.sh | 31 ---- .../test_generate_param_only_checkpoint.sh | 123 ---------------- end_to_end/test_gpt3.sh | 16 -- end_to_end/test_llama2_7b.sh | 73 --------- end_to_end/test_mistral.sh | 22 --- end_to_end/test_mixtral.sh | 22 --- end_to_end/test_tflops.sh | 22 --- end_to_end/test_tflops_16b_params.sh | 38 ----- end_to_end/test_tflops_32b_params.sh | 38 ----- end_to_end/test_tflops_64b_params.sh | 38 ----- end_to_end/test_vocab_creation.sh | 14 -- end_to_end/tpu/gemma/7b/1_test_gemma.sh | 2 +- end_to_end/tpu/gemma/7b/2_test_gemma.sh | 4 +- end_to_end/tpu/gemma/Run_Gemma.md | 2 +- end_to_end/tpu/llama_finetuning_test.sh | 2 +- .../tpu/test_checkpoint_compatibility.sh | 4 +- end_to_end/tpu/test_checkpoint_resharding.sh | 2 +- end_to_end/tpu/test_checkpointing.sh | 2 +- end_to_end/tpu/test_convergence_1b_params.sh | 4 +- end_to_end/tpu/test_determinism.sh | 2 +- end_to_end/tpu/test_tflops.sh | 2 +- end_to_end/tpu/test_tflops_16b_params.sh | 4 +- end_to_end/tpu/test_tflops_32b_params.sh | 4 +- end_to_end/tpu/test_tflops_64b_params.sh | 4 +- end_to_end/tpu/test_vocab_creation.sh | 2 +- 37 files changed, 24 insertions(+), 1018 deletions(-) delete mode 100644 end_to_end/eval_assert.py delete mode 100644 end_to_end/gemma/2b/test_gemma.sh delete mode 100644 end_to_end/gemma/7b/1_test_gemma.sh delete mode 100644 end_to_end/gemma/7b/2_test_gemma.sh delete mode 100644 end_to_end/gemma/Run_Gemma.md delete mode 100644 end_to_end/llama_finetuning_test.sh delete mode 100644 end_to_end/test_checkpoint_compatibility.sh delete mode 100644 end_to_end/test_checkpoint_resharding.sh delete mode 100644 end_to_end/test_checkpointing.sh delete mode 100644 end_to_end/test_convergence_1b_params.sh delete mode 100644 end_to_end/test_decode.sh delete mode 100644 end_to_end/test_determinism.sh delete mode 100644 end_to_end/test_generate_param_only_checkpoint.sh delete mode 100644 end_to_end/test_gpt3.sh delete mode 100644 end_to_end/test_llama2_7b.sh delete mode 100644 end_to_end/test_mistral.sh delete mode 100644 end_to_end/test_mixtral.sh delete mode 100644 end_to_end/test_tflops.sh delete mode 100644 end_to_end/test_tflops_16b_params.sh delete mode 100644 end_to_end/test_tflops_32b_params.sh delete mode 100644 end_to_end/test_tflops_64b_params.sh delete mode 
100644 end_to_end/test_vocab_creation.sh diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index c94344afe..a2feaf336 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -106,19 +106,19 @@ jobs: - name: Test generate_param_only_checkpoint run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4' + 'bash end_to_end/tpu/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4' - name: Test generate_param_only_checkpoint with int8 quantization run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8' + 'bash end_to_end/tpu/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8' - name: Test grain checkpoint determinism run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'bash end_to_end/test_checkpointing.sh runner gs://runner-maxtext-logs gs://maxtext-dataset False c4-array_record' + 'bash end_to_end/tpu/test_checkpointing.sh runner gs://runner-maxtext-logs gs://maxtext-dataset False c4-array_record' - name: Test checkpoint compatibility run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ - 'bash end_to_end/test_checkpoint_compatibility.sh runner gs://runner-maxtext-logs gs://maxtext-dataset' + 'bash end_to_end/tpu/test_checkpoint_compatibility.sh runner gs://runner-maxtext-logs gs://maxtext-dataset' - name: Validate Pedagogical Example, Shmap_collective_matmul run: | docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ diff --git a/end_to_end/eval_assert.py b/end_to_end/eval_assert.py deleted file mode 100644 index c879ce99f..000000000 --- a/end_to_end/eval_assert.py +++ /dev/null @@ -1,139 +0,0 @@ -""" - Copyright 2023 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -# pylint: skip-file -"""Reads and asserts over target values""" -from absl import app -from typing import Sequence -from math import isclose -from google.cloud import storage -import json - - -def compute_avg_metric(metrics_file, target, start_line=10): - """ Reads and computes average of target value - If start_line is negative then uses the last lines, e.g. 
start from end + 1 - |start_line|""" - - - avg = 0 - i = 0 - with open(metrics_file, 'r', encoding='utf8') as file: - lines = file.readlines() - if start_line < 0: - start_line = len(lines) + start_line - for line in lines: - # skip the first start_line lines for burn in - if i >= start_line: - vals = json.loads(line) - avg += vals[target] - i+=1 - avg /= (i-start_line) - - return avg - - -def assert_metric_average(metrics_file, threshold, target): - avg_value = compute_avg_metric(metrics_file, target) - # Checks for acceptable performance by asserting that the average metric (e.g. TFLOPs) - # is greater than the threshold. - print(f'avg value of target {target} is {avg_value}') - assert avg_value >= float(threshold) - print('assert metric average passed.') - -def test_final_loss(metrics_file, target_loss): - target_loss = float(target_loss) - with open(metrics_file, 'r', encoding='utf8') as metrics: - use_last_n_data = 10 - avg_final_loss = compute_avg_metric(metrics_file, 'learning/loss', start_line= -1 * use_last_n_data) - print(f"Mean of last {use_last_n_data} losses is {avg_final_loss}") - print(f"Target loss is {target_loss}") - assert avg_final_loss < target_loss - print('Final loss test passed.') - -def test_checkpointing(metrics_file, target, dataset_type): - """Asserts over loss values from loaded checkpoint""" - metrics_file_saved = 'saved_' + metrics_file - metrics_file_restored = 'restored_' + metrics_file - - with open(metrics_file_saved, 'r', encoding='utf8') as saved,\ - open(metrics_file_restored, 'r', encoding='utf8') as restored: - saved_loss = json.loads(saved.readlines()[-1])[target] - restored_loss = json.loads(restored.readlines()[0])[target] - # Checks that checkpoint restore was successful by comparing loss of last - # step in saved checkpoint to loss of first step in restored checkpoint - print("saved loss: ", saved_loss) - print("restored loss: ", restored_loss) - if dataset_type=='c4': - assert isclose(saved_loss, restored_loss, rel_tol=0.1) - elif dataset_type=='c4-array_record': - assert saved_loss==restored_loss - else: - raise ValueError(f"Unknown dataset_type {dataset_type}. 
dataset_type must be c4, c4-array_record or synthetic") - print('checkpointing test passed.') - -def test_determinism(metrics_file, target): - """Asserts over loss values from two runs""" - run_1 = 'run_1_' + metrics_file - run_2 = 'run_2_' + metrics_file - - with open(run_1, 'r', encoding='utf8') as run_1_file,\ - open(run_2, 'r', encoding='utf8') as run_2_file: - run_1_loss = json.loads(run_1_file.readlines()[-1])[target] - run_2_loss = json.loads(run_2_file.readlines()[-1])[target] - # Check that the two runs have the same loss - print(f"Run 1 loss:{run_1_loss}", flush=True) - print(f"Run 2 loss:{run_2_loss}", flush=True) - assert run_1_loss==run_2_loss - print('determinism test passed.') - -def test_vocab_creation(target): - bucket_name = target.split("/")[2] - vocab_path = "/".join(target.split("/")[3:]) - storage_client = storage.Client() - assert storage.Blob(bucket=storage_client.bucket(bucket_name), name=vocab_path).exists(storage_client) - print('vocab creation test passed.') - -def test_start_step(metrics_file, start_step_target): - with open(metrics_file, 'r', encoding='utf8') as metrics: - start_step = json.loads(metrics.readlines()[0])["step"] - print(f"Start step is {start_step}, start step target is {start_step_target}") - assert start_step==float(start_step_target) - print("Start step test passed.") - -def main(argv: Sequence[str]) -> None: - - _, test_scenario, *test_vars = argv - - if test_scenario == 'metrics_average': - assert_metric_average(*test_vars) - elif test_scenario == 'checkpoint_save_restore': - test_checkpointing(*test_vars, dataset_type='c4') - elif test_scenario == 'grain_checkpoint_save_restore': - test_checkpointing(*test_vars, dataset_type='c4-array_record') - elif test_scenario == 'determinism': - test_determinism(*test_vars) - elif test_scenario == 'vocab_creation': - test_vocab_creation(*test_vars) - elif test_scenario == 'final_loss': - test_final_loss(*test_vars) - elif test_scenario == 'test_start_step': - test_start_step(*test_vars) - else: - raise ValueError(f"Unrecognized test_scenario {test_scenario}") - - -if __name__ == "__main__": - app.run(main) diff --git a/end_to_end/gemma/2b/test_gemma.sh b/end_to_end/gemma/2b/test_gemma.sh deleted file mode 100644 index 74d776951..000000000 --- a/end_to_end/gemma/2b/test_gemma.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash - -# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b. - -# The flow of this file is as follows: -# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText -# 2. Run decoding, finetuning of Gemma 2B with the converted checkpoint. Also, run pretraining of Gemma 2B -# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. -# 4. Run decoding from the finetuned checkpoint from step 2 -# 5. Ahead of Time Compilation for running Gemma 2B on v5e-256 - - -set -ex -idx=$(date +%Y-%m-%d-%H-%M) -export MODEL_VARIATION='2b' - -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ -# Non-Googlers please remember to use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. 
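For reference, the assertion pattern that the eval_assert.py helpers above (and the relocated end-to-end scripts) rely on is simple: train.py appends one JSON object per step to the metrics file, and the test averages a target key over a window of steps. A minimal sketch of that check, with hypothetical file and key names, looks like this:

import json

def avg_last(metrics_file, key, last_n=10):
    # Read the JSON-lines metrics file that train.py appends to and average the
    # target key over the last `last_n` steps (burn-in lines are skipped simply
    # by slicing from the end, mirroring the negative start_line behavior above).
    with open(metrics_file, encoding="utf8") as f:
        rows = [json.loads(line) for line in f]
    vals = [row[key] for row in rows[-last_n:]]
    return sum(vals) / len(vals)

# Hypothetical usage mirroring the LOSS_THRESHOLD checks in these scripts:
# assert avg_last("metrics.txt", "learning/loss") < 2.5
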
-export CHKPT_BUCKET=gs://maxtext-gemma/flax -export MODEL_BUCKET=gs://maxtext-gemma -python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION} - -# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data -export DATASET_PATH=gs://maxtext-dataset -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items -export RUN_NAME=unscanned_chkpt_${idx} -# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. -# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true - -export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items - -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. -# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -# We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert` -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write about it" - -# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat" - -# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. 
Note that scanned checkpoint helps with efficient finetuning -export FINETUNE_RUN_NAME=runner_finetune_${idx} -python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5 - -# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b - -# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. -# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. -# `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding -export PARAM_RUN_NAME=param_chkpt_${idx} -python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true - -# Now, run decoding on the checkpoint generated from our finetune run. -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" - -# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. -# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B. -# To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. -python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/end_to_end/gemma/7b/1_test_gemma.sh b/end_to_end/gemma/7b/1_test_gemma.sh deleted file mode 100644 index a718e844b..000000000 --- a/end_to_end/gemma/7b/1_test_gemma.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -# This file, combined with step 2 in the same directory, demonstrates converting a Gemma checkpoint from Kaggle and running various MaxText operations on it. -# This step is tested nightly on an ordinary CPU VM. - -# The flow of this file is as follows: -# 1. 
Pull the checkpoint from a GCS bucket and uploads the new MaxText compatible checkpoint to destination GCS bucket. -# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. - -# Example Usage: bash end_to_end/gemma/7b/1_test_gemma.sh -set -ex -idx=$(date +%Y-%m-%d-%H-%M) -MODEL_VARIATION='7b' - -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ -# Please use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. -export CHKPT_BUCKET=gs://maxtext-gemma/flax -export MODEL_BUCKET=gs://maxtext-gemma -JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION} -echo "Writen MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}" - -# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items -# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. -# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -export RUN_NAME=unscanned_chkpt_${idx} -JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true -echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items" diff --git a/end_to_end/gemma/7b/2_test_gemma.sh b/end_to_end/gemma/7b/2_test_gemma.sh deleted file mode 100644 index 6f8e37b79..000000000 --- a/end_to_end/gemma/7b/2_test_gemma.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b. -# Please make sure you have run end_to_end/gemma/7b/1_test_gemma.sh before running commands from this file. - -# The flow of this file is as follows: -# 1. Run decoding, finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B -# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. -# 3. Run decoding from the finetuned checkpoint from step 1 -# 4. Ahead of Time Compilation for running Gemma 7B on v5e-256 - -set -ex -idx=$(date +%Y-%m-%d-%H-%M) -export MODEL_VARIATION='7b' - -# Non-Googlers please remember to MODEL_BUCKET to GCS bucket where this script uses internal buckets for testing. 
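The scanned/unscanned distinction these conversion scripts keep referring to is about checkpoint layout: a scanned checkpoint stores each per-layer parameter once, stacked along a leading layer axis (which suits jax.lax.scan during training), while an unscanned (force_unroll=true) checkpoint stores a separate entry per decoder layer, which decodes more efficiently. A rough illustration with made-up shapes and a hypothetical unroll helper, not the actual generate_param_only_checkpoint.py logic:

import numpy as np

num_layers, emb, mlp = 4, 8, 16
# Scanned layout: one array per parameter with a leading [num_layers] axis.
scanned = {"decoder": {"layers": {"mlp_kernel": np.zeros((num_layers, emb, mlp))}}}

def unroll(tree, n):
    # Split the stacked layer axis into per-layer entries (illustrative only).
    stacked = tree["decoder"]["layers"]
    return {"decoder": {f"layers_{i}": {k: v[i] for k, v in stacked.items()}
                        for i in range(n)}}

unscanned = unroll(scanned, num_layers)
print(unscanned["decoder"]["layers_0"]["mlp_kernel"].shape)  # (8, 16)
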
-export MODEL_BUCKET=gs://maxtext-gemma -# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data -export DATASET_PATH=gs://maxtext-dataset -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items -export RUN_NAME=unscanned_chkpt_${idx} -# We defined path to unscanned checkpoint created in 1_test_gemma.sh -export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items - -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. -# So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -# We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert` -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people’s faces" - -# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people's faces" - -# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. 
Note that scanned checkpoint helps with efficient finetuning -export FINETUNE_RUN_NAME=runner_finetune_${idx} -python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5 - -# We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b - -# Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. -# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. -# `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding -export PARAM_RUN_NAME=param_chkpt_${idx} -python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true - -# Now, run decoding on the checkpoint generated from our finetune run. -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" - -# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. -# This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B. -# To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. -python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/end_to_end/gemma/Run_Gemma.md b/end_to_end/gemma/Run_Gemma.md deleted file mode 100644 index 627cd1e7b..000000000 --- a/end_to_end/gemma/Run_Gemma.md +++ /dev/null @@ -1,31 +0,0 @@ - - -# Gemma -[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the art open models built from research and technology that we used to create the Gemini models. - -Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. 
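The Model Flop Utilization figures quoted in the table that follows use the usual back-of-the-envelope definition: achieved FLOP/s is approximated as 6 x parameter count x tokens/second, and MFU is that divided by the aggregate peak FLOP/s of the slice. A small sketch with hypothetical throughput numbers, taking the published v5e bf16 peak as roughly 197 TFLOP/s per chip:

def estimate_mfu(param_count, tokens_per_second, num_chips, peak_flops_per_chip):
    # 6 * N FLOPs per trained token is the standard dense-transformer estimate
    # (attention FLOPs ignored), so this is an approximation, not a measurement.
    achieved_flops = 6 * param_count * tokens_per_second
    return achieved_flops / (num_chips * peak_flops_per_chip)

# Hypothetical: a 7B model at 700k tokens/s on 256 chips of ~197 TFLOP/s each.
print(f"{estimate_mfu(7e9, 7.0e5, 256, 197e12):.0%}")  # ~58%
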
You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). - -After downloading the weights run [convert_gemma_chkpt.py](../../MaxText/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/gemma](../../end_to_end/gemma). - -## MaxText supports pretraining and finetuning with high performance - -Model Flop utilization for training on v5e and v5p TPUs. - -| Model | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) | -| -------- | -------------- | -------------- | -------------- | -------------- | -| Gemma-2b | 58% | 55% | 64% | 68% | -| Gemma-7b | 58% | 60% | 70% | 70% | diff --git a/end_to_end/llama_finetuning_test.sh b/end_to_end/llama_finetuning_test.sh deleted file mode 100644 index 4758379dd..000000000 --- a/end_to_end/llama_finetuning_test.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# This script is designed for internal use within Google. External users can adapt it by: -# - Updating GCS paths (gs://) to your accessible locations. -# - Using the checkpoint generated from train.py or available one in open source (https://llama.meta.com/llama-downloads/). - -set -e -idx=$(date +%Y-%m-%d-%H-%M) - -base_ckpt_path=gs://maxtext-llama/test/2024-01-15-06-49/decode-ckpt-maxtext/0/items -BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -DATASET_PATH=gs://maxtext-dataset - -export LOSS_THRESHOLD=2.5 - -python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt' - -# Assert training loss is smaller than input LOSS_THRESHOLD -python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD diff --git a/end_to_end/test_checkpoint_compatibility.sh b/end_to_end/test_checkpoint_compatibility.sh deleted file mode 100644 index ae37801e9..000000000 --- a/end_to_end/test_checkpoint_compatibility.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -set -ex - -if [ -f "run_*_metrics.txt" ]; then - rm run_*_metrics.txt - echo "removed existing run_*_metrics.txt" -fi - -RUN_NAME=${1}-$(date +%Y-%m-%d-%H-%M) -OUTPUT_PATH=${2} -DATASET_PATH=${3} -model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" - -echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" -bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ - -echo "Run_1: Starting the first run using the grain input pipeline" - -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=3 ${model_params}\ - max_target_length=128 per_device_batch_size=1\ - metrics_file=run_1_metrics.txt checkpoint_period=2 async_checkpointing=false\ - dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ - dataset_type=c4-array_record grain_worker_count=0\ - dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1 - -echo -echo "Finished Run_1 at step 2" -echo "Run_2: Resuming using the tfds input pipeline" -echo - -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 ${model_params}\ - max_target_length=128 per_device_batch_size=1\ - 
metrics_file=run_2_metrics.txt checkpoint_period=2 async_checkpointing=false\ - dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ - -echo -echo "Finished Run_2 at step 4" -echo "Run_3: Resuming using the grain input pipeline" -echo - -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=7 ${model_params}\ - max_target_length=128 per_device_batch_size=1\ - metrics_file=run_3_metrics.txt checkpoint_period=2 async_checkpointing=false\ - dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ - dataset_type=c4-array_record grain_worker_count=0\ - dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1 - -python3 end_to_end/eval_assert.py test_start_step run_2_metrics.txt 3.0 -python3 end_to_end/eval_assert.py test_start_step run_3_metrics.txt 5.0 diff --git a/end_to_end/test_checkpoint_resharding.sh b/end_to_end/test_checkpoint_resharding.sh deleted file mode 100644 index ae3b741a5..000000000 --- a/end_to_end/test_checkpoint_resharding.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -set -ex - -RUN_NAME=${1}_$(date +%Y-%m-%d-%H) -OUTPUT_PATH=${2} -DATASET_PATH=${3} - -# Train and save checkpoint - sharded with DCN Data Parallelism + ICI FSDP Parallelism -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=101\ - metrics_file='saved_metrics.txt' checkpoint_period=20 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 collect_stack_trace=False - -# Retrieve checkpoint - sharded with DCN Data Parallelism + ICI FSDP + Tensor Parallelism -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=102\ - metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False - -python3 end_to_end/eval_assert.py checkpoint_save_restore metrics.txt learning/loss diff --git a/end_to_end/test_checkpointing.sh b/end_to_end/test_checkpointing.sh deleted file mode 100644 index e3fd9e6dd..000000000 --- a/end_to_end/test_checkpointing.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash -set -ex - -if [ -f "saved_metrics.txt" ]; then - rm saved_metrics.txt - echo "removed existing saved_metrics.txt" -fi - -if [ -f "restored_metrics.txt" ]; then - rm restored_metrics.txt - echo "removed existing restored_metrics.txt" -fi - -RUN_NAME=${1}-${4}-$(date +%Y-%m-%d-%H-%M) -OUTPUT_PATH=${2} -DATASET_PATH=${3} -COLLECT_STACK_TRACE=${4} -DATASET_TYPE=${5} -eval_metrics=checkpoint_save_restore -model_params=" base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" -CMD_DATA="" - -if [ "$DATASET_TYPE" == "c4-array_record" ] -then - eval_metrics=grain_checkpoint_save_restore - echo "Using c4-array_record dataset type" - echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" - bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ - DATASET_PATH=/tmp/gcsfuse/ - CMD_DATA=" grain_worker_count=0 dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1" -fi - -#Train -CMD1="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ - metrics_file=saved_metrics.txt checkpoint_period=3 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - async_checkpointing=false 
collect_stack_trace=$COLLECT_STACK_TRACE" -CMD1+=$model_params -CMD1+=$CMD_DATA - -CMD2="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ - metrics_file=restored_metrics.txt base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ - async_checkpointing=false collect_stack_trace=$COLLECT_STACK_TRACE" -CMD2+=$model_params -CMD2+=$CMD_DATA - -echo -echo "Start the first training run" -echo "Command is:" -echo $CMD1 - -$CMD1 -# Wait for first train to finish -# process_id=$! -# wait $process_id -echo -echo "First training run done" -echo "Start the second training run" -echo "Command is:" -echo $CMD2 - -$CMD2 - -python3 end_to_end/eval_assert.py $eval_metrics metrics.txt learning/loss diff --git a/end_to_end/test_convergence_1b_params.sh b/end_to_end/test_convergence_1b_params.sh deleted file mode 100644 index 38108b1c5..000000000 --- a/end_to_end/test_convergence_1b_params.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash -set -ex - -echo "Running test_convergence_1b_params.sh" -# Run this on 64 chips to achieve a loss value of ~2.5 after 20400 steps, or ~2.7 after 10200 steps (v4-128) -# -# Command Flags: -# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) -# DATASET_PATH (Required, unless dataset_path is already set in base.yml) -# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) -# LOSS_THRESHOLD (Optional, default is 100.0 ) -# -# Example to invoke this script: -# bash end_to_end/test_convergence_1b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" LOSS_THRESHOLD=100.0 - -export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass. -export STEPS=20400 # Run for 20B tokens for a 1B sized mode for "chinchilla" scaling https://arxiv.org/abs/2203.15556 - -# Set environment variables -for ARGUMENT in "$@"; do - IFS='=' read -r KEY VALUE <<< "$ARGUMENT" - export "$KEY"="$VALUE" -done - -if [ -n "$RUN_NAME" ]; -then - export M_RUN_NAME=$RUN_NAME -fi - -if [ "$DATASET_TYPE" == "c4-array_record" ] -then - EVAL_METRICS=grain_checkpoint_save_restore - echo "Using c4-array_record dataset type" - echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" - bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ - DATASET_PATH=/tmp/gcsfuse/ - CMD_DATA=" dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1" -fi - -TRAIN_CMD="python3 MaxText/train.py MaxText/configs/base.yml\ - steps=$STEPS per_device_batch_size=8.0 learning_rate=3e-4 enable_checkpointing=false \ - max_target_length=2048 global_parameter_scale=1 \ - enable_profiler=false metrics_file=metrics.txt base_output_directory=$OUTPUT_PATH\ - dataset_path=$DATASET_PATH log_period=150 remat_policy=minimal enable_data_shuffling=false" -TRAIN_CMD+=$CMD_DATA - -# Train -export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -$TRAIN_CMD - -# Assert training loss is smaller than input LOSS_THRESHOLD -python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD diff --git a/end_to_end/test_decode.sh b/end_to_end/test_decode.sh deleted file mode 100644 index fc432e025..000000000 
--- a/end_to_end/test_decode.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash -set -ex - -NUM_TOKEN_THRESHOLD=${1} -OUTPUT_PATH=${2} -DATASET_PATH=${3} -# Run name is optional 4th input - our daily XLML tests will use one. - - -if [ -z ${4} ] -then - RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S) -else - RUN_NAME=${4}_$(date +%Y-%m-%d-%H) -fi - -if [ -z ${5} ] -then - ICI_TENSOR_PARALLELISM=4 -else - ICI_TENSOR_PARALLELISM=${5} -fi - -# Decode without checkpoint -python3 MaxText/decode.py MaxText/configs/base.yml run_name=$RUN_NAME\ - steps=50 enable_checkpointing=False metrics_file=/tmp/${RUN_NAME}_metrics.txt \ - base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ - attention=dot_product ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} - - -# Get latest converted Gemma 2B checkpoint from internal GCS bucket -export GEMMA_2B_CKPT_PATH=$(gsutil ls gs://maxtext-gemma/2b | sort -r | head -1) -# Decode with different sampling strategies. -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=weighted decode_sampling_temperature=.00001 prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat" -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=nucleus decode_sampling_nucleus_p=0 prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat" -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=topk decode_sampling_top_k=1 prompt="I love to" autoregressive_decode_assert=" cook and bake. 
I love to eat" diff --git a/end_to_end/test_determinism.sh b/end_to_end/test_determinism.sh deleted file mode 100644 index 9d9c066fc..000000000 --- a/end_to_end/test_determinism.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -set -ex - -RUN_NAME=${1}_$(date +%Y-%m-%d-%H) -OUTPUT_PATH=${2} -DATASET_PATH=${3} -DATASET_TYPE=${4} - -if [ "$DATASET_TYPE" == "c4-array_record" ] -then - EVAL_METRICS=grain_checkpoint_save_restore - echo "Using c4-array_record dataset type" - echo "Mounting $DATASET_PATH to /tmp/gcsfuse/" - bash setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/tmp/gcsfuse/ - DATASET_PATH=/tmp/gcsfuse/ - CMD_DATA=" dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1" -fi - -#Train -CMD1="python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME}_1 steps=5 metrics_file=run_1_metrics.txt\ - enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH" -CMD1+=$CMD_DATA - - -CMD2="python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME}_2 steps=5 metrics_file=run_2_metrics.txt\ - enable_checkpointing=False enable_data_shuffling=True enable_dropout=False base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH" -CMD2+=$CMD_DATA - -$CMD1 -$CMD2 -python3 end_to_end/eval_assert.py determinism metrics.txt learning/loss diff --git a/end_to_end/test_generate_param_only_checkpoint.sh b/end_to_end/test_generate_param_only_checkpoint.sh deleted file mode 100644 index 17d6c3463..000000000 --- a/end_to_end/test_generate_param_only_checkpoint.sh +++ /dev/null @@ -1,123 +0,0 @@ -#!/bin/bash - -set -uex - -helpFunction() -{ - echo "" - echo "Usage: $0 " - echo -e "\t-n dry_run is true " - echo -e "\t-r runid: run_test_model_0b" - echo -e "\t-d dataset_path: gs://test-maxtext-dataset" - echo -e "\t-o output_path: gs://test-maxtext-output" - echo -e "\t-i ici_tensor_parallelism: 8" - echo -e "\t-a attention: flash" - echo -e "\t-q quantization: int8" - exit 1 # Exit script after printing help -} - -# Default option values -dry_run=false -run_id=test_model_0b_$(date +%Y-%m-%d-%H) -dataset_path=gs://test-maxtext-dataset -base_output_directory=gs://test-maxtext-output -ici_tensor_parallelism=8 -attention=flash -quantization="" - -while getopts "nr:d:o:t:i:a:q:" opt -do - case "$opt" in - n ) dry_run=true ;; - r ) run_id="$OPTARG" ;; - d ) dataset_path="$OPTARG";; - o ) base_output_directory="$OPTARG";; - i ) ici_tensor_parallelism="$OPTARG" ;; - a ) attention="$OPTARG" ;; - q ) quantization="int8" ;; - ? 
) helpFunction ;; # Print helpFunction in case parameter is non-existent - esac -done - -echo -echo "Running: ./$0 dataset_path=${dataset_path} base_output_directory=${base_output_directory}" -echo " dry_run=${dry_run} run_id=${run_id} " -echo " ici_tensor_parallelism=${ici_tensor_parallelism} attention=${attention} quantization=${quantization}" -echo - -if "$dry_run"; then - cmd=echo -else - cmd='' -fi - -training_ckpt_run_id=${run_id}-ckpt-train-steps-5 -decode_ckpt_run_id=${run_id}-decode-ckpt-train-steps-5 -model_params="base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=192 base_num_decoder_layers=8 head_dim=128" - -echo -echo "Create a test training checkpoint" -echo -$cmd python3 MaxText/train.py MaxText/configs/base.yml \ -run_name=${training_ckpt_run_id} \ -base_output_directory=${base_output_directory} \ -dataset_path=${dataset_path} attention=${attention} \ -steps=5 checkpoint_period=3 async_checkpointing=false \ -quantization=${quantization} \ -${model_params} \ - - -if [ $? -eq 0 ] -then - echo - echo "Successfully created a training checkpoint" - echo "Checkpoint path: ${base_output_directory}/${training_ckpt_run_id}/checkpoints/3/items" -else - echo - echo "Could not create a training checkpoint" >&2 - exit 1 -fi - -echo -echo "Generate a decode checkpoint from the test training checkpoint" -echo - -$cmd python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml \ -run_name=${decode_ckpt_run_id} attention=${attention} \ -base_output_directory=${base_output_directory} \ -dataset_path=${dataset_path} async_checkpointing=false \ -load_full_state_path=${base_output_directory}/${training_ckpt_run_id}/checkpoints/3/items \ -quantization=${quantization} \ -${model_params} \ - - -if [ $? -eq 0 ] -then - echo "Successfully created an decode checkpoint" - echo "Checkpoint path: ${base_output_directory}/${decode_ckpt_run_id}/checkpoints/0/items" - -else - echo "Could not create an decode checkpoint" >&2 - exit 1 -fi - -echo -echo "Run decode using the generated checkpoint" -echo -$cmd python3 MaxText/decode.py MaxText/configs/base.yml \ -run_name=${run_id}-decode-steps-50 \ -base_output_directory=${base_output_directory} \ -dataset_path=${dataset_path} \ -load_parameters_path=${base_output_directory}/${decode_ckpt_run_id}/checkpoints/0/items \ -attention=dot_product ici_tensor_parallelism=${ici_tensor_parallelism} steps=50 \ -metrics_file=/tmp/${run_id}_metrics.txt async_checkpointing=false max_target_length=128 per_device_batch_size=1 \ -quantization=${quantization} \ -${model_params} \ - -if [ $? 
-eq 0 ] -then - echo "Successfully ran decode using decode optimized checkpoint" -else - echo "Could not run decode decode optimized checkpoint" >&2 - exit 1 -fi diff --git a/end_to_end/test_gpt3.sh b/end_to_end/test_gpt3.sh deleted file mode 100644 index 7dd0b315a..000000000 --- a/end_to_end/test_gpt3.sh +++ /dev/null @@ -1,16 +0,0 @@ -set -euox pipefail - -TIMESTAMP=$(date +%Y%m%d-%H%M) -export PAXML_CKPT_PATH=gs://maxtext-gpt3/ckpt_test/paxml/checkpoints/checkpoint_00000000/state -export OUTPUT_PATH=gs://maxtext-gpt3/tests -export RUN_NAME=test_${TIMESTAMP} - -# convert gpt3-52k model -python3 MaxText/convert_gpt3_ckpt_from_paxml.py --paxml-ckpt-path=${PAXML_CKPT_PATH} --maxtext-model-name=gpt3-52k --run-name=${RUN_NAME} --base-output-directory=${OUTPUT_PATH} - -# Run gpt3-52k with the converted ckpt -python3 MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\ - steps=10 per_device_batch_size=6 enable_checkpointing=true async_checkpointing=false\ - enable_profiler=false remat_policy=full\ - max_target_length=2048 base_output_directory=${OUTPUT_PATH}\ - dataset_type=synthetic diff --git a/end_to_end/test_llama2_7b.sh b/end_to_end/test_llama2_7b.sh deleted file mode 100644 index c61663c77..000000000 --- a/end_to_end/test_llama2_7b.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash - -# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Llama2-7b - -# The flow of this file is as follows: -# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) in your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText. -# 2. Run decoding, finetuning of Llama2-7b with this converted checkpoint. Also, run pretraining of Llama2-7b. -# 3. Run decoding from the finetuned weights -# 4. Convert the scanned checkpoint from step #1 into unscanned checkpoint format and run more efficient decoding. - - -set -ex -idx=$(date +%Y-%m-%d-%H-%M) - -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data -export DATASET_PATH=gs://maxtext-dataset -export ASYNC_CHECKPOINTING=false - -# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU -pip install torch --index-url https://download.pytorch.org/whl/cpu - -# We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint -export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt - -# In the following command, we are copying Meta's checkpoint into a local directory `tmp`. -# You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py` -gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/ - -# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. 
Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own -export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext - -#Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` -python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH} - -# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory exactly inside `CONVERTED_CHECKPOINT_PATH`. This way it is easier to use this path in the `train.py` and `decode.py` commands -export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items - -# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. -# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx} -python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true - -# Like before, we define `UNSCANNED_CKPT_PATH` to refer to the checkpoint subdirectory exactly -export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items - -# We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint aka `CONVERTED_CHECKPOINT`. Note that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert` -python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product scan_layers=false - - -# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -# We compare our decoded results by asserting with golden PyTorch outputs using `autoregressive_decode_assert` -python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to share." attention=dot_product - -# Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. 
Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning -python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 checkpoint_period=5 - -# We also run pre-training of Llama2-7b, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 - -# Now, run decoding on the checkpoint generated from our finetune run. Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert -# the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, say the checkpoint saved at finetuning step #5 -# Also, `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding -export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx} -python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true - -# Like before, we define `NEW_CKPT_PATH` to refer to the checkpoint subdirectory exactly -export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items - -# We run decoding on the fine-tuned parameter checkpoint -python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false diff --git a/end_to_end/test_mistral.sh b/end_to_end/test_mistral.sh deleted file mode 100644 index 2a80904d5..000000000 --- a/end_to_end/test_mistral.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# This script is designed for internal use within Google. External users can adapt it by: -# - Updating GCS paths (gs://) to your accessible locations. -# - Using the checkpoint generated from train.py or available one in open source (i.e. https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar). 
- -set -ex -idx=$(date +%Y-%m-%d-%H-%M) - -export M_ENABLE_CHECKPOINTING=true -export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -export M_DATASET_PATH=gs://maxtext-dataset -export M_ASYNC_CHECKPOINTING=false - -# Download checkpoint, convert it to MaxText, and run inference -pip3 install torch -gsutil -m cp -r gs://maxtext-external/mistral-7B-v0.1 /tmp -python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mistral-7B-v0.1 --model-size mistral-7b --maxtext-model-path gs://maxtext-mistral/test/${idx}/decode-ckpt-maxtext/ -python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mistral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_direct_${idx} per_device_batch_size=1 model_name='mistral-7b' tokenizer_path=gs://maxtext-external/mistral-7B-v0.1/tokenizer.mistral ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" autoregressive_decode_assert="read. I love to read about the Bible. I love" attention=dot_product - -# Training -python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mistral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_${idx} per_device_batch_size=1 model_name='mistral-7b' ici_tensor_parallelism=4 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mistral-7B-v0.1/tokenizer.mistral diff --git a/end_to_end/test_mixtral.sh b/end_to_end/test_mixtral.sh deleted file mode 100644 index f5c3fd97a..000000000 --- a/end_to_end/test_mixtral.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# This script is designed for internal use within Google. External users can adapt it by: -# - Updating GCS paths (gs://) to your accessible locations. -# - Using the checkpoint generated from train.py or available one in open source (i.e. https://files.mixtral-8x7b-v0-1.mistral.ai/Mixtral-8x7B-v0.1-Instruct.tar). - -set -ex -idx=$(date +%Y-%m-%d-%H-%M) - -export M_ENABLE_CHECKPOINTING=true -export M_BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs -export M_DATASET_PATH=gs://maxtext-dataset -export M_ASYNC_CHECKPOINTING=false - -# Download checkpoint, convert it to MaxText, and run inference -pip3 install torch -gsutil -m cp -r gs://maxtext-external/mixtral-8x7B-v0.1-Instruct /tmp -python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mixtral-8x7B-v0.1-Instruct --model-size mixtral-8x7b --maxtext-model-path gs://maxtext-mixtral/test/${idx}/decode-ckpt-maxtext/ -python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mixtral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_direct_${idx} per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=28 prompt="[INST] I love to [/INST]" autoregressive_decode_assert="That's great to hear! 
I love to learn new things and explore different interests" attention=dot_product - -# Training -python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://maxtext-mixtral/test/${idx}/decode-ckpt-maxtext/0/items run_name=runner_${idx} per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral diff --git a/end_to_end/test_tflops.sh b/end_to_end/test_tflops.sh deleted file mode 100644 index f0543d95e..000000000 --- a/end_to_end/test_tflops.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -set -ex - -USER=${1} -TFLOP_THRESHOLD=${2} -OUTPUT_PATH=${3} -DATASET_PATH=${4} - - -if [ -z ${5} ] -then - RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S) -else - RUN_NAME=${5}_$(date +%Y-%m-%d-%H) -fi - -#Train -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ - steps=150 reuse_example_batch=1 remat_policy='full' enable_profiler=True enable_checkpointing=False metrics_file='metrics.txt'\ - base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH log_period=150 - -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/test_tflops_16b_params.sh b/end_to_end/test_tflops_16b_params.sh deleted file mode 100644 index 9e992305f..000000000 --- a/end_to_end/test_tflops_16b_params.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -echo "Running test_tflops_16b_params.sh" - -# Command Flags: -# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) -# DATASET_PATH (Required, unless dataset_path is already set in base.yml) -# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) -# PLATFORM (Optional, can be "gke" or "gce", default is "gce") -# TFLOP_THRESHOLD (Optional, default is 0 ) -# -# Example to invoke this script: -# bash end_to_end/test_tflops_16b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 - -# Stop execution if any command exits with error -set -ex - -export TFLOP_THRESHOLD=0 -export PLATFORM="gce" - -# Set environment variables -for ARGUMENT in "$@"; do - IFS='=' read -r KEY VALUE <<< "$ARGUMENT" - export "$KEY"="$VALUE" -done - -# Set up network optimizations -bash preflight.sh PLATFORM=$PLATFORM - -# Train -export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ - steps=150 per_device_batch_size=6 enable_checkpointing=false\ - enable_profiler=false remat_policy=full\ - max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ - dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=16 - -# Assert TFLOP/s -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/test_tflops_32b_params.sh b/end_to_end/test_tflops_32b_params.sh deleted file mode 100644 index 59f0585a3..000000000 --- a/end_to_end/test_tflops_32b_params.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -echo "Running test_tflops_32b_params.sh" - -# Command Flags: -# 
OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) -# DATASET_PATH (Required, unless dataset_path is already set in base.yml) -# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) -# PLATFORM (Optional, can be "gke" or "gce", default is "gce") -# TFLOP_THRESHOLD (Optional, default is 0 ) -# -# Example to invoke this script: -# bash end_to_end/test_tflops_32b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 - -# Stop execution if any command exits with error -set -ex - -export TFLOP_THRESHOLD=0 -export PLATFORM="gce" - -# Set environment variables -for ARGUMENT in "$@"; do - IFS='=' read -r KEY VALUE <<< "$ARGUMENT" - export "$KEY"="$VALUE" -done - -# Set up network optimizations -bash preflight.sh PLATFORM=$PLATFORM - -# Train -export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ - steps=150 per_device_batch_size=4 enable_checkpointing=false\ - enable_profiler=false remat_policy=full\ - max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ - dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=32 - -# Assert TFLOP/s -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/test_tflops_64b_params.sh b/end_to_end/test_tflops_64b_params.sh deleted file mode 100644 index 7c05d7413..000000000 --- a/end_to_end/test_tflops_64b_params.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -echo "Running test_tflops_64b_params.sh" - -# Command Flags: -# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml) -# DATASET_PATH (Required, unless dataset_path is already set in base.yml) -# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) -# PLATFORM (Optional, can be "gke" or "gce", default is "gce") -# TFLOP_THRESHOLD (Optional, default is 0 ) -# -# Example to invoke this script: -# bash end_to_end/test_tflops_64b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 - -# Stop execution if any command exits with error -set -ex - -export TFLOP_THRESHOLD=0 -export PLATFORM="gce" - -# Set environment variables -for ARGUMENT in "$@"; do - IFS='=' read -r KEY VALUE <<< "$ARGUMENT" - export "$KEY"="$VALUE" -done - -# Set up network optimizations -bash preflight.sh PLATFORM=$PLATFORM - -# Train -export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ - steps=150 per_device_batch_size=2 enable_checkpointing=false\ - enable_profiler=false remat_policy=full\ - max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\ - dataset_path=$DATASET_PATH log_period=150 
global_parameter_scale=64 - -# Assert TFLOP/s -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/test_vocab_creation.sh b/end_to_end/test_vocab_creation.sh deleted file mode 100644 index dafd4f00f..000000000 --- a/end_to_end/test_vocab_creation.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash -set -ex - -RUN_NAME=${1}_$(date +%Y-%m-%d-%H) -OUTPUT_PATH=${2} -DATASET_PATH=${3} -VOCAB_PATH=$OUTPUT_PATH/vocab_test_creation_$RUN_NAME - - -#Train -python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\ - base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH tokenizer_path=$VOCAB_PATH - -python3 end_to_end/eval_assert.py vocab_creation $VOCAB_PATH diff --git a/end_to_end/tpu/gemma/7b/1_test_gemma.sh b/end_to_end/tpu/gemma/7b/1_test_gemma.sh index a718e844b..2b7a30b65 100644 --- a/end_to_end/tpu/gemma/7b/1_test_gemma.sh +++ b/end_to_end/tpu/gemma/7b/1_test_gemma.sh @@ -7,7 +7,7 @@ # 1. Pull the checkpoint from a GCS bucket and uploads the new MaxText compatible checkpoint to destination GCS bucket. # 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. -# Example Usage: bash end_to_end/gemma/7b/1_test_gemma.sh +# Example Usage: bash end_to_end/tpu/gemma/7b/1_test_gemma.sh set -ex idx=$(date +%Y-%m-%d-%H-%M) MODEL_VARIATION='7b' diff --git a/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/end_to_end/tpu/gemma/7b/2_test_gemma.sh index 6f8e37b79..7353b6e38 100644 --- a/end_to_end/tpu/gemma/7b/2_test_gemma.sh +++ b/end_to_end/tpu/gemma/7b/2_test_gemma.sh @@ -1,10 +1,10 @@ #!/bin/bash # This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b. -# Please make sure you have run end_to_end/gemma/7b/1_test_gemma.sh before running commands from this file. +# Please make sure you have run end_to_end/tpu/gemma/7b/1_test_gemma.sh before running commands from this file. # The flow of this file is as follows: -# 1. Run decoding, finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B +# 1. Run decoding, finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/tpu/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B # 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. # 3. Run decoding from the finetuned checkpoint from step 1 # 4. Ahead of Time Compilation for running Gemma 7B on v5e-256 diff --git a/end_to_end/tpu/gemma/Run_Gemma.md b/end_to_end/tpu/gemma/Run_Gemma.md index b099cf883..8fbfc94c6 100644 --- a/end_to_end/tpu/gemma/Run_Gemma.md +++ b/end_to_end/tpu/gemma/Run_Gemma.md @@ -19,7 +19,7 @@ Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). -After downloading the weights run [test_convert_chkpt.sh](https://github.com/google/maxtext/blob/main/end_to_end/gemma/test_convert_chkpt.sh), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. 
You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma). +After downloading the weights run [convert_gemma_chkpt.py](../../MaxText/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [end_to_end/tpu/gemma](../../end_to_end/tpu/gemma). ## MaxText supports pretraining and finetuning with high performance diff --git a/end_to_end/tpu/llama_finetuning_test.sh b/end_to_end/tpu/llama_finetuning_test.sh index 4758379dd..ff68f7236 100644 --- a/end_to_end/tpu/llama_finetuning_test.sh +++ b/end_to_end/tpu/llama_finetuning_test.sh @@ -16,4 +16,4 @@ export LOSS_THRESHOLD=2.5 python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt' # Assert training loss is smaller than input LOSS_THRESHOLD -python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD +python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD \ No newline at end of file diff --git a/end_to_end/tpu/test_checkpoint_compatibility.sh b/end_to_end/tpu/test_checkpoint_compatibility.sh index ae37801e9..20adeb5ec 100644 --- a/end_to_end/tpu/test_checkpoint_compatibility.sh +++ b/end_to_end/tpu/test_checkpoint_compatibility.sh @@ -45,5 +45,5 @@ python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=7 ${m dataset_type=c4-array_record grain_worker_count=0\ dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1 -python3 end_to_end/eval_assert.py test_start_step run_2_metrics.txt 3.0 -python3 end_to_end/eval_assert.py test_start_step run_3_metrics.txt 5.0 +python3 end_to_end/tpu/eval_assert.py test_start_step run_2_metrics.txt 3.0 +python3 end_to_end/tpu/eval_assert.py test_start_step run_3_metrics.txt 5.0 diff --git a/end_to_end/tpu/test_checkpoint_resharding.sh b/end_to_end/tpu/test_checkpoint_resharding.sh index ae3b741a5..b136a5ee8 100644 --- a/end_to_end/tpu/test_checkpoint_resharding.sh +++ b/end_to_end/tpu/test_checkpoint_resharding.sh @@ -15,4 +15,4 @@ python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=102\ metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False -python3 end_to_end/eval_assert.py checkpoint_save_restore metrics.txt learning/loss +python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss diff --git a/end_to_end/tpu/test_checkpointing.sh b/end_to_end/tpu/test_checkpointing.sh index e3fd9e6dd..337352f8b 100644 --- a/end_to_end/tpu/test_checkpointing.sh +++ b/end_to_end/tpu/test_checkpointing.sh @@ -60,4 +60,4 @@ echo $CMD2 $CMD2 -python3 end_to_end/eval_assert.py $eval_metrics metrics.txt learning/loss +python3 end_to_end/tpu/eval_assert.py $eval_metrics metrics.txt learning/loss diff --git a/end_to_end/tpu/test_convergence_1b_params.sh b/end_to_end/tpu/test_convergence_1b_params.sh index 38108b1c5..73c04cbde 100644 --- a/end_to_end/tpu/test_convergence_1b_params.sh +++ 
b/end_to_end/tpu/test_convergence_1b_params.sh @@ -11,7 +11,7 @@ echo "Running test_convergence_1b_params.sh" # LOSS_THRESHOLD (Optional, default is 100.0 ) # # Example to invoke this script: -# bash end_to_end/test_convergence_1b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" LOSS_THRESHOLD=100.0 +# bash end_to_end/tpu/test_convergence_1b_params.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" LOSS_THRESHOLD=100.0 export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass. export STEPS=20400 # Run for 20B tokens for a 1B sized mode for "chinchilla" scaling https://arxiv.org/abs/2203.15556 @@ -49,4 +49,4 @@ export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xl $TRAIN_CMD # Assert training loss is smaller than input LOSS_THRESHOLD -python3 end_to_end/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD +python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD diff --git a/end_to_end/tpu/test_determinism.sh b/end_to_end/tpu/test_determinism.sh index 9d9c066fc..853bcb107 100644 --- a/end_to_end/tpu/test_determinism.sh +++ b/end_to_end/tpu/test_determinism.sh @@ -28,4 +28,4 @@ CMD2+=$CMD_DATA $CMD1 $CMD2 -python3 end_to_end/eval_assert.py determinism metrics.txt learning/loss +python3 end_to_end/tpu/eval_assert.py determinism metrics.txt learning/loss diff --git a/end_to_end/tpu/test_tflops.sh b/end_to_end/tpu/test_tflops.sh index f0543d95e..078215e70 100644 --- a/end_to_end/tpu/test_tflops.sh +++ b/end_to_end/tpu/test_tflops.sh @@ -19,4 +19,4 @@ python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ steps=150 reuse_example_batch=1 remat_policy='full' enable_profiler=True enable_checkpointing=False metrics_file='metrics.txt'\ base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH log_period=150 -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec +python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_tflops_16b_params.sh b/end_to_end/tpu/test_tflops_16b_params.sh index 9e992305f..abe0ba478 100644 --- a/end_to_end/tpu/test_tflops_16b_params.sh +++ b/end_to_end/tpu/test_tflops_16b_params.sh @@ -9,7 +9,7 @@ echo "Running test_tflops_16b_params.sh" # TFLOP_THRESHOLD (Optional, default is 0 ) # # Example to invoke this script: -# bash end_to_end/test_tflops_16b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 +# bash end_to_end/tpu/test_tflops_16b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 # Stop execution if any command exits with error set -ex @@ -35,4 +35,4 @@ python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=16 # Assert TFLOP/s -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec +python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_tflops_32b_params.sh b/end_to_end/tpu/test_tflops_32b_params.sh index 59f0585a3..33bf7d656 100644 --- a/end_to_end/tpu/test_tflops_32b_params.sh +++ b/end_to_end/tpu/test_tflops_32b_params.sh @@ -9,7 +9,7 @@ echo "Running test_tflops_32b_params.sh" # TFLOP_THRESHOLD (Optional, default is 0 ) # # Example to invoke this script: -# bash end_to_end/test_tflops_32b_params.sh 
RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 +# bash end_to_end/tpu/test_tflops_32b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 # Stop execution if any command exits with error set -ex @@ -35,4 +35,4 @@ python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=32 # Assert TFLOP/s -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec +python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_tflops_64b_params.sh b/end_to_end/tpu/test_tflops_64b_params.sh index 7c05d7413..290ac302a 100644 --- a/end_to_end/tpu/test_tflops_64b_params.sh +++ b/end_to_end/tpu/test_tflops_64b_params.sh @@ -9,7 +9,7 @@ echo "Running test_tflops_64b_params.sh" # TFLOP_THRESHOLD (Optional, default is 0 ) # # Example to invoke this script: -# bash end_to_end/test_tflops_64b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 +# bash end_to_end/tpu/test_tflops_64b_params.sh RUN_NAME=""" OUTPUT_PATH="gs://" DATASET_PATH="gs://" PLATFORM="gke" TFLOP_THRESHOLD=0 # Stop execution if any command exits with error set -ex @@ -35,4 +35,4 @@ python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\ dataset_path=$DATASET_PATH log_period=150 global_parameter_scale=64 # Assert TFLOP/s -python3 end_to_end/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec +python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt $TFLOP_THRESHOLD perf/per_device_tflops_per_sec diff --git a/end_to_end/tpu/test_vocab_creation.sh b/end_to_end/tpu/test_vocab_creation.sh index dafd4f00f..3b7b4603e 100644 --- a/end_to_end/tpu/test_vocab_creation.sh +++ b/end_to_end/tpu/test_vocab_creation.sh @@ -11,4 +11,4 @@ VOCAB_PATH=$OUTPUT_PATH/vocab_test_creation_$RUN_NAME python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\ base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH tokenizer_path=$VOCAB_PATH -python3 end_to_end/eval_assert.py vocab_creation $VOCAB_PATH +python3 end_to_end/tpu/eval_assert.py vocab_creation $VOCAB_PATH From 16a05c022e86f598df521621cbbeb3148006dd20 Mon Sep 17 00:00:00 2001 From: michelle-yooh Date: Thu, 11 Apr 2024 17:47:29 +0000 Subject: [PATCH 23/26] unify WORKDIR to /deps --- .github/workflows/UnitTests.yml | 42 ++++++++++++++--------------- docker_build_dependency_image.sh | 2 +- maxtext_dependencies.Dockerfile | 2 +- maxtext_gpu_dependencies.Dockerfile | 2 +- maxtext_libtpu_path.Dockerfile | 2 +- maxtext_runner.Dockerfile | 4 +-- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index a2feaf336..712b9efef 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -70,58 +70,58 @@ jobs: bash docker_build_dependency_image.sh - name: Test gsutil installation run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. 
Aborting"; exit 24;}' - name: Test with pytest run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest' + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest' - name: Test train.py with c4 run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false' - name: Test train.py with synthetic data run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false dataset_type=synthetic' - name: Test train.py with per_device_batch_size < 1 run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false' - name: Test decode.py run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1' - name: Test decode.py with per_device_batch_size < 1 run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25' - name: Test int8_training run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs 
dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false' - name: Test fp8_training run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false' - name: Test generate_param_only_checkpoint run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'bash end_to_end/tpu/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4' - name: Test generate_param_only_checkpoint with int8 quantization run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'bash end_to_end/tpu/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8' - name: Test grain checkpoint determinism run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'bash end_to_end/tpu/test_checkpointing.sh runner gs://runner-maxtext-logs gs://maxtext-dataset False c4-array_record' - name: Test checkpoint compatibility run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'bash end_to_end/tpu/test_checkpoint_compatibility.sh runner gs://runner-maxtext-logs gs://maxtext-dataset' - name: Validate Pedagogical Example, Shmap_collective_matmul run: | - docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 pedagogical_examples/shmap_collective_matmul.py' # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'tpu' job @@ -142,28 +142,28 @@ jobs: bash docker_build_dependency_image.sh DEVICE=gpu - name: Test gsutil installation run: | - docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. 
Aborting"; exit 24;}' - name: Test with pytest run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest -m "not tpu"' + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest -m "not tpu"' - name: Test train.py run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=dot_product' - name: Test train.py with per_device_batch_size < 1 run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=dot_product' - name: Test int8_training run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=dot_product' - name: Test decode.py run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} 
base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=1' - name: Test decode.py with per_device_batch_size < 1 run: | - docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \ + docker run -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 -e TF_FORCE_GPU_ALLOW_GROWTH=true --shm-size=2g --runtime=nvidia --gpus all -v /home/runner/actions-runner/_work/maxtext/maxtext:/deps --rm --privileged maxtext_base_image bash -c \ 'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=dot_product enable_checkpointing=false max_target_length=128 per_device_batch_size=.25' diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index a1b0cfeaf..8da78ac94 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -70,7 +70,7 @@ echo "" echo "Built your base docker image and named it ${LOCAL_IMAGE_NAME}. It only has the dependencies installed. Assuming you're on a TPUVM, to run the docker image locally and mirror your local working directory run:" -echo "docker run -v $(pwd):/app --rm -it --privileged --entrypoint bash ${LOCAL_IMAGE_NAME}" +echo "docker run -v $(pwd):/deps --rm -it --privileged --entrypoint bash ${LOCAL_IMAGE_NAME}" echo "" echo "You can run MaxText and your development tests inside of the docker image. Changes to your workspace will automatically be reflected inside the docker container." diff --git a/maxtext_dependencies.Dockerfile b/maxtext_dependencies.Dockerfile index cfc16547d..da2807dd2 100644 --- a/maxtext_dependencies.Dockerfile +++ b/maxtext_dependencies.Dockerfile @@ -44,4 +44,4 @@ RUN ls . RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" RUN bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} -WORKDIR /app +WORKDIR /deps diff --git a/maxtext_gpu_dependencies.Dockerfile b/maxtext_gpu_dependencies.Dockerfile index 242435b29..ecc8fc83a 100644 --- a/maxtext_gpu_dependencies.Dockerfile +++ b/maxtext_gpu_dependencies.Dockerfile @@ -40,4 +40,4 @@ RUN ls . 
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}" RUN bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE} -WORKDIR /app +WORKDIR /deps diff --git a/maxtext_libtpu_path.Dockerfile b/maxtext_libtpu_path.Dockerfile index 9443db05c..57df45bd7 100644 --- a/maxtext_libtpu_path.Dockerfile +++ b/maxtext_libtpu_path.Dockerfile @@ -5,4 +5,4 @@ FROM $BASEIMAGE # Set the TPU_LIBRARY_PATH ENV TPU_LIBRARY_PATH='/root/custom_libtpu/libtpu.so' -WORKDIR /app \ No newline at end of file +WORKDIR /deps \ No newline at end of file diff --git a/maxtext_runner.Dockerfile b/maxtext_runner.Dockerfile index 15106dd89..b1e232f73 100644 --- a/maxtext_runner.Dockerfile +++ b/maxtext_runner.Dockerfile @@ -4,9 +4,9 @@ FROM $BASEIMAGE #FROM maxtext_base_image # Set the working directory in the container -WORKDIR /app +WORKDIR /deps # Copy all files from local workspace into docker container COPY . . -WORKDIR /app \ No newline at end of file +WORKDIR /deps \ No newline at end of file From 5b8a3c3bd7f0d746bd25d89169d744b9ee1e550a Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 12 Apr 2024 20:26:34 +0000 Subject: [PATCH 24/26] Share GCS path between Gemma-7b tests --- end_to_end/tpu/gemma/7b/1_test_gemma.sh | 38 ++++++++++++++-------- end_to_end/tpu/gemma/7b/2_test_gemma.sh | 43 ++++++++++++++++--------- 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/end_to_end/tpu/gemma/7b/1_test_gemma.sh b/end_to_end/tpu/gemma/7b/1_test_gemma.sh index 2b7a30b65..521dd6550 100644 --- a/end_to_end/tpu/gemma/7b/1_test_gemma.sh +++ b/end_to_end/tpu/gemma/7b/1_test_gemma.sh @@ -7,25 +7,37 @@ # 1. Pull the checkpoint from a GCS bucket and uploads the new MaxText compatible checkpoint to destination GCS bucket. # 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. -# Example Usage: bash end_to_end/tpu/gemma/7b/1_test_gemma.sh +# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/gemma/7b/1_test_gemma.sh +# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma/7b/2_test_gemma.sh. +# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and +# the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs. + set -ex -idx=$(date +%Y-%m-%d-%H-%M) MODEL_VARIATION='7b' -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs + # After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ -# Please use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). -# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. +# Please use seperate GCS paths for uploading model weights from kaggle ($CHKPT_BUCKET) and MaxText compatible weights ($BASE_OUTPUT_PATH). 
+# Non-Googlers please remember to point CHKPT_BUCKET to GCS buckets that you own export CHKPT_BUCKET=gs://maxtext-gemma/flax -export MODEL_BUCKET=gs://maxtext-gemma -JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION} -echo "Writen MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}" + +if [ -z "${BASE_OUTPUT_PATH}" ]; then + # Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing. + # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma/7b/2_test_gemma.sh + export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M) + echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}" +fi + +echo "Converted checkpoints are stored at ${BASE_OUTPUT_PATH}" + + +JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt --model_size ${MODEL_VARIATION} +echo "Wrote MaxText compatible checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt" # We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items +export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -export RUN_NAME=unscanned_chkpt_${idx} -JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true -echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items" +export RUN_NAME=unscanned_chkpt +JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true +echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/end_to_end/tpu/gemma/7b/2_test_gemma.sh index 7353b6e38..172ca6e23 100644 --- a/end_to_end/tpu/gemma/7b/2_test_gemma.sh +++ b/end_to_end/tpu/gemma/7b/2_test_gemma.sh @@ -9,21 +9,32 @@ # 3. Run decoding from the finetuned checkpoint from step 1 # 4. Ahead of Time Compilation for running Gemma 7B on v5e-256 +# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/gemma/7b/2_test_gemma.sh +# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma/7b/1_test_gemma.sh +# Please note that in these two scripts (1_test_gemma.sh and 2_test_gemma.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and +# the subfolders names aka RUN_NAMEs are static. Please remember to change BASE_OUTPUT_PATH across different runs.
+ set -ex -idx=$(date +%Y-%m-%d-%H-%M) export MODEL_VARIATION='7b' -# Non-Googlers please remember to MODEL_BUCKET to GCS bucket where this script uses internal buckets for testing. -export MODEL_BUCKET=gs://maxtext-gemma +if [ -z "${BASE_OUTPUT_PATH}" ]; then + # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run + # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/gemma/7b/1_test_gemma.sh + export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M) + echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}" +fi + + + # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data export DATASET_PATH=gs://maxtext-dataset -# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run -export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs + + # We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands -export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items -export RUN_NAME=unscanned_chkpt_${idx} +export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items +export RUN_NAME=unscanned_chkpt # We defined path to unscanned checkpoint created in 1_test_gemma.sh -export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items +export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` @@ -31,23 +42,23 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/it python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people’s faces" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. 
So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people's faces" +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people’s faces" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning -export FINETUNE_RUN_NAME=runner_finetune_${idx} -python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5 +export FINETUNE_RUN_NAME=runner_finetune +python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b +python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. -# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. +# So, we can use the `MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. 
Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding -export PARAM_RUN_NAME=param_chkpt_${idx} -python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true +export PARAM_RUN_NAME=param_chkpt +python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" +python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. # This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B. 
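The Ahead-of-Time compilation step referenced just above is driven by MaxText/train_compile.py, which compiles the training step for a target topology from any machine (even a CPU-only VM), so configuration and out-of-memory problems surface before a v5e-256 slice is ever provisioned. A representative invocation for the recommended Gemma-7B setup is sketched below; compile_topology and compile_topology_num_slices are real train_compile.py options, but the batch size and sharding values shown are illustrative assumptions rather than a quote of 2_test_gemma.sh.

# Sketch: AOT cross-compile a Gemma-7B training step for a single v5e-256 slice, with no TPU attached.
# per_device_batch_size and the ici_* sharding value are assumptions chosen for illustration.
python MaxText/train_compile.py MaxText/configs/base.yml \
    model_name=gemma-7b \
    compile_topology=v5e-256 compile_topology_num_slices=1 \
    per_device_batch_size=2 ici_fsdp_transpose_parallelism=16

If this command runs to completion, the same flags can then be reused with MaxText/train.py on the real v5e-256 slice.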
From efb6d9e4a2856198f8272dcbfb7d9c91eda0728e Mon Sep 17 00:00:00 2001 From: michelle-yooh Date: Tue, 9 Apr 2024 20:12:17 +0000 Subject: [PATCH 25/26] Add README for llama2-7B --- MaxText/configs/a3/llama_2_7b/16vm.sh | 2 +- MaxText/configs/a3/llama_2_7b/README.md | 28 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 MaxText/configs/a3/llama_2_7b/README.md diff --git a/MaxText/configs/a3/llama_2_7b/16vm.sh b/MaxText/configs/a3/llama_2_7b/16vm.sh index 76ee26b32..fa07a470e 100644 --- a/MaxText/configs/a3/llama_2_7b/16vm.sh +++ b/MaxText/configs/a3/llama_2_7b/16vm.sh @@ -37,6 +37,6 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ # 16 nodes python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \ - steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=6 max_target_length=4096 model_name=llama2-7b \ + steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \ enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \ dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true diff --git a/MaxText/configs/a3/llama_2_7b/README.md b/MaxText/configs/a3/llama_2_7b/README.md new file mode 100644 index 000000000..966049b8e --- /dev/null +++ b/MaxText/configs/a3/llama_2_7b/README.md @@ -0,0 +1,28 @@ + + +# High Performance Model Configs on A3 GPU +Expected performance results for Llama2-7B model running on A3 GPU: + + +### Llama2-7B +| Hardware | TFLOP/sec/chip | +| ---------------------- | ---------------- | +| 1x A3 (h100-80gb-8) | 492 | +| 2x A3 (h100-80gb-8) | 422 | +| 4x A3 (h100-80gb-8) | 407 | +| 8x A3 (h100-80gb-8) | 409 | +| 16x A3 (h100-80gb-8) | 375 | From 9343eec18cf3943bb7598aa915c3625f23fb597d Mon Sep 17 00:00:00 2001 From: ssusie Date: Tue, 16 Apr 2024 16:59:15 +0000 Subject: [PATCH 26/26] adding script to fix the style and adding modified/fixed files with line length 125 --- .github/workflows/CPUTests.yml | 38 + .github/workflows/UnitTests.yml | 25 - MaxText/__init__.py | 22 +- MaxText/accelerator_to_spec_map.py | 479 +++-------- MaxText/checkpointing.py | 140 ++-- MaxText/common_types.py | 16 +- MaxText/convert_gemma_chkpt.py | 198 ++--- MaxText/convert_gpt3_ckpt_from_paxml.py | 203 +++-- MaxText/decode.py | 35 +- MaxText/generate_param_only_checkpoint.py | 63 +- MaxText/inference_microbenchmark.py | 134 +-- MaxText/inference_scratch/analyze_sharegpt.py | 42 +- MaxText/inference_utils.py | 68 +- .../input_pipeline/_grain_data_processing.py | 132 +-- MaxText/input_pipeline/_grain_operations.py | 93 ++- MaxText/input_pipeline/_grain_tokenizer.py | 30 +- .../input_pipeline/_tfds_data_processing.py | 180 ++-- .../_tfds_data_processing_c4_mlperf.py | 134 +-- .../input_pipeline_interface.py | 205 ++--- MaxText/layers/attentions.py | 781 +++++++++--------- MaxText/layers/embeddings.py | 34 +- MaxText/layers/gemma.py | 132 ++- MaxText/layers/gpt3.py | 304 ++++--- MaxText/layers/initializers.py | 12 +- MaxText/layers/linears.py | 117 ++- MaxText/layers/llama2.py | 135 ++- MaxText/layers/mistral.py | 221 +++-- MaxText/layers/models.py | 262 +++--- MaxText/layers/normalizations.py | 3 +- MaxText/layers/quantizations.py | 97 +-- MaxText/llama_or_mistral_ckpt.py | 468 +++++------ MaxText/max_logging.py | 25 +- MaxText/max_utils.py | 374 +++++---- MaxText/maxengine.py | 233 +++--- 
MaxText/maxengine_config.py | 23 +- MaxText/maxengine_server.py | 8 +- MaxText/maxtext_utils.py | 149 ++-- MaxText/multihost_dataloading.py | 40 +- MaxText/optimizers.py | 60 +- MaxText/pyconfig.py | 237 +++--- MaxText/sequence_packing.py | 83 +- MaxText/standalone_checkpointer.py | 62 +- MaxText/standalone_dataloader.py | 33 +- MaxText/tests/attention_test.py | 140 ++-- MaxText/tests/gpt3_test.py | 84 +- MaxText/tests/grain_data_processing_test.py | 259 +++--- .../inference_microbenchmark_smoke_test.py | 22 +- MaxText/tests/llama_test.py | 97 ++- MaxText/tests/max_utils_test.py | 127 ++- MaxText/tests/model_test.py | 113 ++- MaxText/tests/multihost_dataloading_test.py | 55 +- MaxText/tests/profiler_test.py | 79 +- MaxText/tests/quantizations_test.py | 102 ++- MaxText/tests/standalone_dl_ckpt_test.py | 82 +- MaxText/tests/tfds_data_processing_test.py | 130 +-- MaxText/tests/tokenizer_test.py | 56 +- MaxText/tests/train_compile_test.py | 199 +++-- MaxText/tests/train_int8_smoke_test.py | 54 +- MaxText/tests/train_smoke_test.py | 53 +- MaxText/tests/weight_dtypes_test.py | 87 +- MaxText/tokenizer.py | 39 +- MaxText/train.py | 331 ++++---- MaxText/train_compile.py | 97 ++- MaxText/train_tokenizer.py | 135 ++- MaxText/vertex_tensorboard.py | 64 +- code_style.sh | 33 + pedagogical_examples/non_spmd.py | 40 +- pedagogical_examples/shardings.py | 166 ++-- .../shmap_collective_matmul.py | 238 +++--- pylintrc | 2 + 70 files changed, 4581 insertions(+), 4433 deletions(-) create mode 100644 .github/workflows/CPUTests.yml create mode 100644 code_style.sh diff --git a/.github/workflows/CPUTests.yml b/.github/workflows/CPUTests.yml new file mode 100644 index 000000000..a6cd25a1e --- /dev/null +++ b/.github/workflows/CPUTests.yml @@ -0,0 +1,38 @@ +name: Linter + +on: + push: + branches: + - '**' + +jobs: + cpu: + name: "CPU tests" + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-20.04] + python-version: ['3.10'] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install pylint pyink pytype==2024.2.27 + - name: Typecheck the code with pytype + run: | + pytype --jobs auto --disable import-error MaxText/ + - name: Analysing the code with pylint in Maxtext/ + run: | + pylint MaxText/ && \ + echo 'Maxtext PyLint check successful' || { echo \ + 'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; } + - name: Analysing the code with pylint in pedagogical_examples/ + run: | + pylint pedagogical_examples/ && \ + echo 'PyLint check on pedagogical_examples/ is successful' || { echo \ + 'PyLint check has failed. 
Please run bash code_style.sh to fix issues'; exit 20; } diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 1efcfb4e7..08ff79a8c 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -27,31 +27,6 @@ on: - cron: '0 */2 * * *' jobs: - cpu: - name: "CPU test" - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-20.04] - python-version: ['3.10'] - steps: - - uses: actions/checkout@v3 - - name: setup python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install Dependencies - run: | - pip install pytype==2024.2.27 - pip install pylint - - name: Typecheck the code with pytype - run: | - pytype --jobs auto --disable import-error MaxText/ - - name: Analysing the code with pylint - run: | - pylint MaxText/ - - # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'gpu' job tpu: strategy: diff --git a/MaxText/__init__.py b/MaxText/__init__.py index 83f918e62..c133d2d71 100644 --- a/MaxText/__init__.py +++ b/MaxText/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/MaxText/accelerator_to_spec_map.py b/MaxText/accelerator_to_spec_map.py index fa6b64d0a..255aef965 100644 --- a/MaxText/accelerator_to_spec_map.py +++ b/MaxText/accelerator_to_spec_map.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" """ Static map of TPU names such as v4-8 to properties such as chip layout.""" @@ -22,358 +22,135 @@ from dataclasses import dataclass + @dataclass class SystemCharacteristics: platform: str topology_name: str - chip_config_name: str # 'megacore' or 'default' + chip_config_name: str # 'megacore' or 'default' chips_per_host_bounds: tuple devices_per_slice: int + UserFacingNameToSystemCharacteristics = { # v5e - 'v5e-16': SystemCharacteristics( - 'tpu', 'v5e:4x4', 'default', (2, 2, 1), 16 - ), - 'v5e-32': SystemCharacteristics( - 'tpu', 'v5e:4x8', 'default', (2, 2, 1), 32 - ), - 'v5e-64': SystemCharacteristics( - 'tpu', 'v5e:8x8', 'default', (2, 2, 1), 64 - ), - 'v5e-128': SystemCharacteristics( - 'tpu', 'v5e:8x16', 'default', (2, 2, 1), 128 - ), - 'v5e-256': SystemCharacteristics( - 'tpu', 'v5e:16x16', 'default', (2, 2, 1), 256 - ), + "v5e-16": SystemCharacteristics("tpu", "v5e:4x4", "default", (2, 2, 1), 16), + "v5e-32": SystemCharacteristics("tpu", "v5e:4x8", "default", (2, 2, 1), 32), + "v5e-64": SystemCharacteristics("tpu", "v5e:8x8", "default", (2, 2, 1), 64), + "v5e-128": SystemCharacteristics("tpu", "v5e:8x16", "default", (2, 2, 1), 128), + "v5e-256": SystemCharacteristics("tpu", "v5e:16x16", "default", (2, 2, 1), 256), # v4 - 'v4-8': SystemCharacteristics( - 'tpu', 'v4:2x2x1', 'megacore', (2, 2, 1), 4 - ), - 'v4-16': SystemCharacteristics( - 'tpu', 'v4:2x2x2', 'megacore', (2, 2, 1), 8 - ), - 'v4-32': SystemCharacteristics( - 'tpu', 'v4:2x2x4', 'megacore', (2, 2, 1), 16 - ), - 'v4-64': SystemCharacteristics( - 'tpu', 'v4:2x4x4', 'megacore', (2, 2, 1), 32 - ), - 'v4-128': SystemCharacteristics( - 'tpu', 'v4:4x4x4', 'megacore', (2, 2, 1), 64 - ), - 'v4-256': SystemCharacteristics( - 'tpu', 'v4:4x4x8', 'megacore', (2, 2, 1), 128 - ), - 'v4-384': SystemCharacteristics( - 'tpu', 'v4:4x4x12', 'megacore', (2, 2, 1), 192 - ), - 'v4-512': SystemCharacteristics( - 'tpu', 'v4:4x8x8', 'megacore', (2, 2, 1), 256 - ), - 'v4-1024': SystemCharacteristics( - 'tpu', 'v4:8x8x8', 'megacore', (2, 2, 1), 512 - ), - 'v4-1536': SystemCharacteristics( - 'tpu', 'v4:8x8x12','megacore', (2, 2, 1), 768 - ), - 'v4-2048': SystemCharacteristics( - 'tpu', 'v4:8x8x16','megacore', (2, 2, 1), 1024 - ), - 'v4-4096': SystemCharacteristics( - 'tpu', 'v4:8x16x16', 'megacore', (2, 2, 1), 2048 - ), + "v4-8": SystemCharacteristics("tpu", "v4:2x2x1", "megacore", (2, 2, 1), 4), + "v4-16": SystemCharacteristics("tpu", "v4:2x2x2", "megacore", (2, 2, 1), 8), + "v4-32": SystemCharacteristics("tpu", "v4:2x2x4", "megacore", (2, 2, 1), 16), + "v4-64": SystemCharacteristics("tpu", "v4:2x4x4", "megacore", (2, 2, 1), 32), + "v4-128": SystemCharacteristics("tpu", "v4:4x4x4", "megacore", (2, 2, 1), 64), + "v4-256": SystemCharacteristics("tpu", "v4:4x4x8", "megacore", (2, 2, 1), 128), + "v4-384": SystemCharacteristics("tpu", "v4:4x4x12", "megacore", (2, 2, 1), 192), + "v4-512": SystemCharacteristics("tpu", "v4:4x8x8", "megacore", (2, 2, 1), 256), + "v4-1024": SystemCharacteristics("tpu", "v4:8x8x8", "megacore", (2, 2, 1), 512), + "v4-1536": SystemCharacteristics("tpu", "v4:8x8x12", "megacore", (2, 2, 1), 768), + "v4-2048": SystemCharacteristics("tpu", "v4:8x8x16", "megacore", (2, 2, 1), 1024), + "v4-4096": SystemCharacteristics("tpu", "v4:8x16x16", "megacore", (2, 2, 1), 2048), # v5p - 'v5p-8': SystemCharacteristics( - 'tpu', 'v5:2x2x1', 'megacore', (2, 2, 1), 4 - ), - 'v5p-16': SystemCharacteristics( - 'tpu', 'v5:2x2x2', 'megacore', 
(2, 2, 1), 8 - ), - 'v5p-32': SystemCharacteristics( - 'tpu', 'v5:2x2x4', 'megacore', (2, 2, 1), 16 - ), - 'v5p-64': SystemCharacteristics( - 'tpu', 'v5:2x4x4', 'megacore', (2, 2, 1), 32 - ), - 'v5p-128': SystemCharacteristics( - 'tpu', 'v5:4x4x4', 'megacore', (2, 2, 1), 64 - ), - 'v5p-256': SystemCharacteristics( - 'tpu', 'v5:4x4x8', 'megacore', (2, 2, 1), 128 - ), - 'v5p-384': SystemCharacteristics( - 'tpu', 'v5:4x4x12', 'megacore', (2, 2, 1), 192 - ), - 'v5p-512': SystemCharacteristics( - 'tpu', 'v5:4x8x8', 'megacore', (2, 2, 1), 256 - ), - 'v5p-640': SystemCharacteristics( - 'tpu', 'v5:4x4x20', 'megacore', (2, 2, 1), 320 - ), - 'v5p-768': SystemCharacteristics( - 'tpu', 'v5:4x8x12', 'megacore', (2, 2, 1), 384 - ), - 'v5p-896': SystemCharacteristics( - 'tpu', 'v5:4x4x28', 'megacore', (2, 2, 1), 448 - ), - 'v5p-1024': SystemCharacteristics( - 'tpu', 'v5:8x8x8', 'megacore', (2, 2, 1), 512 - ), - 'v5p-1152': SystemCharacteristics( - 'tpu', 'v5:4x12x12', 'megacore', (2, 2, 1), 576 - ), - 'v5p-1280': SystemCharacteristics( - 'tpu', 'v5:4x8x20', 'megacore', (2, 2, 1), 640 - ), - 'v5p-1408': SystemCharacteristics( - 'tpu', 'v5:4x4x44', 'megacore', (2, 2, 1), 704 - ), - 'v5p-1536': SystemCharacteristics( - 'tpu', 'v5:8x8x12', 'megacore', (2, 2, 1), 768 - ), - 'v5p-1664': SystemCharacteristics( - 'tpu', 'v5:4x4x52', 'megacore', (2, 2, 1), 832 - ), - 'v5p-1792': SystemCharacteristics( - 'tpu', 'v5:4x8x28', 'megacore', (2, 2, 1), 896 - ), - 'v5p-1920': SystemCharacteristics( - 'tpu', 'v5:4x12x20', 'megacore', (2, 2, 1), 960 - ), - 'v5p-2048': SystemCharacteristics( - 'tpu', 'v5:8x8x16', 'megacore', (2, 2, 1), 1024 - ), - 'v5p-2176': SystemCharacteristics( - 'tpu', 'v5:4x4x68', 'megacore', (2, 2, 1), 1088 - ), - 'v5p-2304': SystemCharacteristics( - 'tpu', 'v5:8x12x12', 'megacore', (2, 2, 1), 1152 - ), - 'v5p-2432': SystemCharacteristics( - 'tpu', 'v5:4x4x76', 'megacore', (2, 2, 1), 1216 - ), - 'v5p-2560': SystemCharacteristics( - 'tpu', 'v5:8x8x20', 'megacore', (2, 2, 1), 1280 - ), - 'v5p-2688': SystemCharacteristics( - 'tpu', 'v5:4x12x28', 'megacore', (2, 2, 1), 1344 - ), - 'v5p-2816': SystemCharacteristics( - 'tpu', 'v5:4x8x44', 'megacore', (2, 2, 1), 1408 - ), - 'v5p-2944': SystemCharacteristics( - 'tpu', 'v5:4x4x92', 'megacore', (2, 2, 1), 1472 - ), - 'v5p-3072': SystemCharacteristics( - 'tpu', 'v5:8x12x16', 'megacore', (2, 2, 1), 1536 - ), - 'v5p-3200': SystemCharacteristics( - 'tpu', 'v5:4x20x20', 'megacore', (2, 2, 1), 1600 - ), - 'v5p-3328': SystemCharacteristics( - 'tpu', 'v5:4x8x52', 'megacore', (2, 2, 1), 1664 - ), - 'v5p-3456': SystemCharacteristics( - 'tpu', 'v5:12x12x12', 'megacore', (2, 2, 1), 1728 - ), - 'v5p-3584': SystemCharacteristics( - 'tpu', 'v5:8x8x28', 'megacore', (2, 2, 1), 1792 - ), - 'v5p-3712': SystemCharacteristics( - 'tpu', 'v5:4x4x116', 'megacore', (2, 2, 1), 1856 - ), - 'v5p-3840': SystemCharacteristics( - 'tpu', 'v5:8x12x20', 'megacore', (2, 2, 1), 1920 - ), - 'v5p-3968': SystemCharacteristics( - 'tpu', 'v5:4x4x124', 'megacore', (2, 2, 1), 1984 - ), - 'v5p-4096': SystemCharacteristics( - 'tpu', 'v5:8x16x16', 'megacore', (2, 2, 1), 2048 - ), - 'v5p-4224': SystemCharacteristics( - 'tpu', 'v5:4x12x44', 'megacore', (2, 2, 1), 2112 - ), - 'v5p-4352': SystemCharacteristics( - 'tpu', 'v5:4x8x68', 'megacore', (2, 2, 1), 2176 - ), - 'v5p-4480': SystemCharacteristics( - 'tpu', 'v5:4x20x28', 'megacore', (2, 2, 1), 2240 - ), - 'v5p-4608': SystemCharacteristics( - 'tpu', 'v5:12x12x16', 'megacore', (2, 2, 1), 2304 - ), - 'v5p-4736': SystemCharacteristics( - 'tpu', 'v5:4x4x148', 
'megacore', (2, 2, 1), 2368 - ), - 'v5p-4864': SystemCharacteristics( - 'tpu', 'v5:4x8x76', 'megacore', (2, 2, 1), 2432 - ), - 'v5p-4992': SystemCharacteristics( - 'tpu', 'v5:4x12x52', 'megacore', (2, 2, 1), 2496 - ), - 'v5p-5120': SystemCharacteristics( - 'tpu', 'v5:8x16x20', 'megacore', (2, 2, 1), 2560 - ), - 'v5p-5248': SystemCharacteristics( - 'tpu', 'v5:4x4x164', 'megacore', (2, 2, 1), 2624 - ), - 'v5p-5376': SystemCharacteristics( - 'tpu', 'v5:8x12x28', 'megacore', (2, 2, 1), 2688 - ), - 'v5p-5504': SystemCharacteristics( - 'tpu', 'v5:4x4x172', 'megacore', (2, 2, 1), 2752 - ), - 'v5p-5632': SystemCharacteristics( - 'tpu', 'v5:8x8x44', 'megacore', (2, 2, 1), 2816 - ), - 'v5p-5760': SystemCharacteristics( - 'tpu', 'v5:12x12x20', 'megacore', (2, 2, 1), 2880 - ), - 'v5p-5888': SystemCharacteristics( - 'tpu', 'v5:4x8x92', 'megacore', (2, 2, 1), 2944 - ), - 'v5p-6016': SystemCharacteristics( - 'tpu', 'v5:4x4x188', 'megacore', (2, 2, 1), 3008 - ), - 'v5p-6144': SystemCharacteristics( - 'tpu', 'v5:12x16x16', 'megacore', (2, 2, 1), 3072 - ), - 'v5p-6272': SystemCharacteristics( - 'tpu', 'v5:4x28x28', 'megacore', (2, 2, 1), 3136 - ), - 'v5p-6400': SystemCharacteristics( - 'tpu', 'v5:8x20x20', 'megacore', (2, 2, 1), 3200 - ), - 'v5p-6528': SystemCharacteristics( - 'tpu', 'v5:4x12x68', 'megacore', (2, 2, 1), 3264 - ), - 'v5p-6656': SystemCharacteristics( - 'tpu', 'v5:8x8x52', 'megacore', (2, 2, 1), 3328 - ), - 'v5p-6784': SystemCharacteristics( - 'tpu', 'v5:4x4x212', 'megacore', (2, 2, 1), 3392 - ), - 'v5p-6912': SystemCharacteristics( - 'tpu', 'v5:12x12x24', 'megacore', (2, 2, 1), 3456 - ), - 'v5p-7040': SystemCharacteristics( - 'tpu', 'v5:4x20x44', 'megacore', (2, 2, 1), 3520 - ), - 'v5p-7168': SystemCharacteristics( - 'tpu', 'v5:8x16x28', 'megacore', (2, 2, 1), 3584 - ), - 'v5p-7296': SystemCharacteristics( - 'tpu', 'v5:4x12x76', 'megacore', (2, 2, 1), 3648 - ), - 'v5p-7424': SystemCharacteristics( - 'tpu', 'v5:4x8x116', 'megacore', (2, 2, 1), 3712 - ), - 'v5p-7552': SystemCharacteristics( - 'tpu', 'v5:4x4x236', 'megacore', (2, 2, 1), 3776 - ), - 'v5p-7680': SystemCharacteristics( - 'tpu', 'v5:12x16x20', 'megacore', (2, 2, 1), 3840 - ), - 'v5p-7808': SystemCharacteristics( - 'tpu', 'v5:4x4x244', 'megacore', (2, 2, 1), 3904 - ), - 'v5p-7936': SystemCharacteristics( - 'tpu', 'v5:4x8x124', 'megacore', (2, 2, 1), 3968 - ), - 'v5p-8064': SystemCharacteristics( - 'tpu', 'v5:12x12x28', 'megacore', (2, 2, 1), 4032 - ), - 'v5p-8192': SystemCharacteristics( - 'tpu', 'v5:16x16x16', 'megacore', (2, 2, 1), 4096 - ), - 'v5p-8320': SystemCharacteristics( - 'tpu', 'v5:4x20x52', 'megacore', (2, 2, 1), 4160 - ), - 'v5p-8448': SystemCharacteristics( - 'tpu', 'v5:8x12x44', 'megacore', (2, 2, 1), 4224 - ), - 'v5p-8704': SystemCharacteristics( - 'tpu', 'v5:8x8x68', 'megacore', (2, 2, 1), 4352 - ), - 'v5p-8832': SystemCharacteristics( - 'tpu', 'v5:4x12x92', 'megacore', (2, 2, 1), 4416 - ), - 'v5p-8960': SystemCharacteristics( - 'tpu', 'v5:8x20x28', 'megacore', (2, 2, 1), 4480 - ), - 'v5p-9216': SystemCharacteristics( - 'tpu', 'v5:12x16x24', 'megacore', (2, 2, 1), 4608 - ), - 'v5p-9472': SystemCharacteristics( - 'tpu', 'v5:4x8x148', 'megacore', (2, 2, 1), 4736 - ), - 'v5p-9600': SystemCharacteristics( - 'tpu', 'v5:12x20x20', 'megacore', (2, 2, 1), 4800 - ), - 'v5p-9728': SystemCharacteristics( - 'tpu', 'v5:8x8x76', 'megacore', (2, 2, 1), 4864 - ), - 'v5p-9856': SystemCharacteristics( - 'tpu', 'v5:4x28x44', 'megacore', (2, 2, 1), 4928 - ), - 'v5p-9984': SystemCharacteristics( - 'tpu', 'v5:8x12x52', 'megacore', (2, 2, 
1), 4992 - ), - 'v5p-10240': SystemCharacteristics( - 'tpu', 'v5:16x16x20', 'megacore', (2, 2, 1), 5120 - ), - 'v5p-10368': SystemCharacteristics( - 'tpu', 'v5:12x12x36', 'megacore', (2, 2, 1), 5184 - ), - 'v5p-10496': SystemCharacteristics( - 'tpu', 'v5:4x8x164', 'megacore', (2, 2, 1), 5248 - ), - 'v5p-10752': SystemCharacteristics( - 'tpu', 'v5:12x16x28', 'megacore', (2, 2, 1), 5376 - ), - 'v5p-10880': SystemCharacteristics( - 'tpu', 'v5:4x20x68', 'megacore', (2, 2, 1), 5440 - ), - 'v5p-11008': SystemCharacteristics( - 'tpu', 'v5:4x8x172', 'megacore', (2, 2, 1), 5504 - ), - 'v5p-11136': SystemCharacteristics( - 'tpu', 'v5:4x12x116', 'megacore', (2, 2, 1), 5568 - ), - 'v5p-11264': SystemCharacteristics( - 'tpu', 'v5:8x16x44', 'megacore', (2, 2, 1), 5632 - ), - 'v5p-11520': SystemCharacteristics( - 'tpu', 'v5:12x20x24', 'megacore', (2, 2, 1), 5760 - ), - 'v5p-11648': SystemCharacteristics( - 'tpu', 'v5:4x28x52', 'megacore', (2, 2, 1), 5824 - ), - 'v5p-11776': SystemCharacteristics( - 'tpu', 'v5:8x8x92', 'megacore', (2, 2, 1), 5888 - ), - 'v5p-11904': SystemCharacteristics( - 'tpu', 'v5:4x12x124', 'megacore', (2, 2, 1), 5952 - ), - 'v5p-12032': SystemCharacteristics( - 'tpu', 'v5:4x8x188', 'megacore', (2, 2, 1), 6016 - ), - 'v5p-12160': SystemCharacteristics( - 'tpu', 'v5:4x20x76', 'megacore', (2, 2, 1), 6080 - ), - 'v5p-12288': SystemCharacteristics( - 'tpu', 'v5:16x16x24', 'megacore', (2, 2, 1), 6144 - ), - 'v5p-13824': SystemCharacteristics( - 'tpu', 'v5:12x24x24', 'megacore', (2, 2, 1), 6912 - ), - 'v5p-17920': SystemCharacteristics( - 'tpu', 'v5:16x20x28', 'megacore', (2, 2, 1), 8960 - ), + "v5p-8": SystemCharacteristics("tpu", "v5:2x2x1", "megacore", (2, 2, 1), 4), + "v5p-16": SystemCharacteristics("tpu", "v5:2x2x2", "megacore", (2, 2, 1), 8), + "v5p-32": SystemCharacteristics("tpu", "v5:2x2x4", "megacore", (2, 2, 1), 16), + "v5p-64": SystemCharacteristics("tpu", "v5:2x4x4", "megacore", (2, 2, 1), 32), + "v5p-128": SystemCharacteristics("tpu", "v5:4x4x4", "megacore", (2, 2, 1), 64), + "v5p-256": SystemCharacteristics("tpu", "v5:4x4x8", "megacore", (2, 2, 1), 128), + "v5p-384": SystemCharacteristics("tpu", "v5:4x4x12", "megacore", (2, 2, 1), 192), + "v5p-512": SystemCharacteristics("tpu", "v5:4x8x8", "megacore", (2, 2, 1), 256), + "v5p-640": SystemCharacteristics("tpu", "v5:4x4x20", "megacore", (2, 2, 1), 320), + "v5p-768": SystemCharacteristics("tpu", "v5:4x8x12", "megacore", (2, 2, 1), 384), + "v5p-896": SystemCharacteristics("tpu", "v5:4x4x28", "megacore", (2, 2, 1), 448), + "v5p-1024": SystemCharacteristics("tpu", "v5:8x8x8", "megacore", (2, 2, 1), 512), + "v5p-1152": SystemCharacteristics("tpu", "v5:4x12x12", "megacore", (2, 2, 1), 576), + "v5p-1280": SystemCharacteristics("tpu", "v5:4x8x20", "megacore", (2, 2, 1), 640), + "v5p-1408": SystemCharacteristics("tpu", "v5:4x4x44", "megacore", (2, 2, 1), 704), + "v5p-1536": SystemCharacteristics("tpu", "v5:8x8x12", "megacore", (2, 2, 1), 768), + "v5p-1664": SystemCharacteristics("tpu", "v5:4x4x52", "megacore", (2, 2, 1), 832), + "v5p-1792": SystemCharacteristics("tpu", "v5:4x8x28", "megacore", (2, 2, 1), 896), + "v5p-1920": SystemCharacteristics("tpu", "v5:4x12x20", "megacore", (2, 2, 1), 960), + "v5p-2048": SystemCharacteristics("tpu", "v5:8x8x16", "megacore", (2, 2, 1), 1024), + "v5p-2176": SystemCharacteristics("tpu", "v5:4x4x68", "megacore", (2, 2, 1), 1088), + "v5p-2304": SystemCharacteristics("tpu", "v5:8x12x12", "megacore", (2, 2, 1), 1152), + "v5p-2432": SystemCharacteristics("tpu", "v5:4x4x76", "megacore", (2, 2, 1), 1216), + 
"v5p-2560": SystemCharacteristics("tpu", "v5:8x8x20", "megacore", (2, 2, 1), 1280), + "v5p-2688": SystemCharacteristics("tpu", "v5:4x12x28", "megacore", (2, 2, 1), 1344), + "v5p-2816": SystemCharacteristics("tpu", "v5:4x8x44", "megacore", (2, 2, 1), 1408), + "v5p-2944": SystemCharacteristics("tpu", "v5:4x4x92", "megacore", (2, 2, 1), 1472), + "v5p-3072": SystemCharacteristics("tpu", "v5:8x12x16", "megacore", (2, 2, 1), 1536), + "v5p-3200": SystemCharacteristics("tpu", "v5:4x20x20", "megacore", (2, 2, 1), 1600), + "v5p-3328": SystemCharacteristics("tpu", "v5:4x8x52", "megacore", (2, 2, 1), 1664), + "v5p-3456": SystemCharacteristics("tpu", "v5:12x12x12", "megacore", (2, 2, 1), 1728), + "v5p-3584": SystemCharacteristics("tpu", "v5:8x8x28", "megacore", (2, 2, 1), 1792), + "v5p-3712": SystemCharacteristics("tpu", "v5:4x4x116", "megacore", (2, 2, 1), 1856), + "v5p-3840": SystemCharacteristics("tpu", "v5:8x12x20", "megacore", (2, 2, 1), 1920), + "v5p-3968": SystemCharacteristics("tpu", "v5:4x4x124", "megacore", (2, 2, 1), 1984), + "v5p-4096": SystemCharacteristics("tpu", "v5:8x16x16", "megacore", (2, 2, 1), 2048), + "v5p-4224": SystemCharacteristics("tpu", "v5:4x12x44", "megacore", (2, 2, 1), 2112), + "v5p-4352": SystemCharacteristics("tpu", "v5:4x8x68", "megacore", (2, 2, 1), 2176), + "v5p-4480": SystemCharacteristics("tpu", "v5:4x20x28", "megacore", (2, 2, 1), 2240), + "v5p-4608": SystemCharacteristics("tpu", "v5:12x12x16", "megacore", (2, 2, 1), 2304), + "v5p-4736": SystemCharacteristics("tpu", "v5:4x4x148", "megacore", (2, 2, 1), 2368), + "v5p-4864": SystemCharacteristics("tpu", "v5:4x8x76", "megacore", (2, 2, 1), 2432), + "v5p-4992": SystemCharacteristics("tpu", "v5:4x12x52", "megacore", (2, 2, 1), 2496), + "v5p-5120": SystemCharacteristics("tpu", "v5:8x16x20", "megacore", (2, 2, 1), 2560), + "v5p-5248": SystemCharacteristics("tpu", "v5:4x4x164", "megacore", (2, 2, 1), 2624), + "v5p-5376": SystemCharacteristics("tpu", "v5:8x12x28", "megacore", (2, 2, 1), 2688), + "v5p-5504": SystemCharacteristics("tpu", "v5:4x4x172", "megacore", (2, 2, 1), 2752), + "v5p-5632": SystemCharacteristics("tpu", "v5:8x8x44", "megacore", (2, 2, 1), 2816), + "v5p-5760": SystemCharacteristics("tpu", "v5:12x12x20", "megacore", (2, 2, 1), 2880), + "v5p-5888": SystemCharacteristics("tpu", "v5:4x8x92", "megacore", (2, 2, 1), 2944), + "v5p-6016": SystemCharacteristics("tpu", "v5:4x4x188", "megacore", (2, 2, 1), 3008), + "v5p-6144": SystemCharacteristics("tpu", "v5:12x16x16", "megacore", (2, 2, 1), 3072), + "v5p-6272": SystemCharacteristics("tpu", "v5:4x28x28", "megacore", (2, 2, 1), 3136), + "v5p-6400": SystemCharacteristics("tpu", "v5:8x20x20", "megacore", (2, 2, 1), 3200), + "v5p-6528": SystemCharacteristics("tpu", "v5:4x12x68", "megacore", (2, 2, 1), 3264), + "v5p-6656": SystemCharacteristics("tpu", "v5:8x8x52", "megacore", (2, 2, 1), 3328), + "v5p-6784": SystemCharacteristics("tpu", "v5:4x4x212", "megacore", (2, 2, 1), 3392), + "v5p-6912": SystemCharacteristics("tpu", "v5:12x12x24", "megacore", (2, 2, 1), 3456), + "v5p-7040": SystemCharacteristics("tpu", "v5:4x20x44", "megacore", (2, 2, 1), 3520), + "v5p-7168": SystemCharacteristics("tpu", "v5:8x16x28", "megacore", (2, 2, 1), 3584), + "v5p-7296": SystemCharacteristics("tpu", "v5:4x12x76", "megacore", (2, 2, 1), 3648), + "v5p-7424": SystemCharacteristics("tpu", "v5:4x8x116", "megacore", (2, 2, 1), 3712), + "v5p-7552": SystemCharacteristics("tpu", "v5:4x4x236", "megacore", (2, 2, 1), 3776), + "v5p-7680": SystemCharacteristics("tpu", "v5:12x16x20", "megacore", (2, 2, 1), 
3840), + "v5p-7808": SystemCharacteristics("tpu", "v5:4x4x244", "megacore", (2, 2, 1), 3904), + "v5p-7936": SystemCharacteristics("tpu", "v5:4x8x124", "megacore", (2, 2, 1), 3968), + "v5p-8064": SystemCharacteristics("tpu", "v5:12x12x28", "megacore", (2, 2, 1), 4032), + "v5p-8192": SystemCharacteristics("tpu", "v5:16x16x16", "megacore", (2, 2, 1), 4096), + "v5p-8320": SystemCharacteristics("tpu", "v5:4x20x52", "megacore", (2, 2, 1), 4160), + "v5p-8448": SystemCharacteristics("tpu", "v5:8x12x44", "megacore", (2, 2, 1), 4224), + "v5p-8704": SystemCharacteristics("tpu", "v5:8x8x68", "megacore", (2, 2, 1), 4352), + "v5p-8832": SystemCharacteristics("tpu", "v5:4x12x92", "megacore", (2, 2, 1), 4416), + "v5p-8960": SystemCharacteristics("tpu", "v5:8x20x28", "megacore", (2, 2, 1), 4480), + "v5p-9216": SystemCharacteristics("tpu", "v5:12x16x24", "megacore", (2, 2, 1), 4608), + "v5p-9472": SystemCharacteristics("tpu", "v5:4x8x148", "megacore", (2, 2, 1), 4736), + "v5p-9600": SystemCharacteristics("tpu", "v5:12x20x20", "megacore", (2, 2, 1), 4800), + "v5p-9728": SystemCharacteristics("tpu", "v5:8x8x76", "megacore", (2, 2, 1), 4864), + "v5p-9856": SystemCharacteristics("tpu", "v5:4x28x44", "megacore", (2, 2, 1), 4928), + "v5p-9984": SystemCharacteristics("tpu", "v5:8x12x52", "megacore", (2, 2, 1), 4992), + "v5p-10240": SystemCharacteristics("tpu", "v5:16x16x20", "megacore", (2, 2, 1), 5120), + "v5p-10368": SystemCharacteristics("tpu", "v5:12x12x36", "megacore", (2, 2, 1), 5184), + "v5p-10496": SystemCharacteristics("tpu", "v5:4x8x164", "megacore", (2, 2, 1), 5248), + "v5p-10752": SystemCharacteristics("tpu", "v5:12x16x28", "megacore", (2, 2, 1), 5376), + "v5p-10880": SystemCharacteristics("tpu", "v5:4x20x68", "megacore", (2, 2, 1), 5440), + "v5p-11008": SystemCharacteristics("tpu", "v5:4x8x172", "megacore", (2, 2, 1), 5504), + "v5p-11136": SystemCharacteristics("tpu", "v5:4x12x116", "megacore", (2, 2, 1), 5568), + "v5p-11264": SystemCharacteristics("tpu", "v5:8x16x44", "megacore", (2, 2, 1), 5632), + "v5p-11520": SystemCharacteristics("tpu", "v5:12x20x24", "megacore", (2, 2, 1), 5760), + "v5p-11648": SystemCharacteristics("tpu", "v5:4x28x52", "megacore", (2, 2, 1), 5824), + "v5p-11776": SystemCharacteristics("tpu", "v5:8x8x92", "megacore", (2, 2, 1), 5888), + "v5p-11904": SystemCharacteristics("tpu", "v5:4x12x124", "megacore", (2, 2, 1), 5952), + "v5p-12032": SystemCharacteristics("tpu", "v5:4x8x188", "megacore", (2, 2, 1), 6016), + "v5p-12160": SystemCharacteristics("tpu", "v5:4x20x76", "megacore", (2, 2, 1), 6080), + "v5p-12288": SystemCharacteristics("tpu", "v5:16x16x24", "megacore", (2, 2, 1), 6144), + "v5p-13824": SystemCharacteristics("tpu", "v5:12x24x24", "megacore", (2, 2, 1), 6912), + "v5p-17920": SystemCharacteristics("tpu", "v5:16x20x28", "megacore", (2, 2, 1), 8960), } + def get_system_characteristics(user_facing_name): return UserFacingNameToSystemCharacteristics.get(user_facing_name) diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index bd229cc91..072117680 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" @@ -28,12 +28,13 @@ from multihost_dataloading import MultiHostDataLoadIterator from flax.training import train_state + def create_orbax_checkpoint_manager( checkpoint_dir: str, enable_checkpointing: bool, use_async: bool, save_interval_steps: int, - dataset_type: Optional[str] = 'c4' + dataset_type: Optional[str] = "c4", ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -42,19 +43,19 @@ def create_orbax_checkpoint_manager( max_logging.log("Creating checkpoint manager...") p = epath.Path(checkpoint_dir) - if dataset_type=='c4-array_record': - item_names = ('items', 'iter') + if dataset_type == "c4-array_record": + item_names = ("items", "iter") else: - item_names = ('items',) + item_names = ("items",) mngr = CheckpointManager( p, - item_names = item_names, - options = CheckpointManagerOptions( + item_names=item_names, + options=CheckpointManagerOptions( create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async, - ) + ), ) max_logging.log("Checkpoint manager created!") return mngr @@ -82,19 +83,19 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): devices inside the replica that current host is in """ idx = _find_idx(device_array, replica_axis_idx) - replica_result = np.take(device_array, - idx, - axis=replica_axis_idx) + replica_result = np.take(device_array, idx, axis=replica_axis_idx) return np.expand_dims(replica_result, axis=replica_axis_idx) -def load_state_if_possible(checkpoint_manager: CheckpointManager, - data_iterator: Union[MultiHostDataLoadIterator, None], - load_parameters_from_path: str, - load_full_state_from_path: str, - abstract_unboxed_pre_state: train_state.TrainState, - enable_single_replica_ckpt_restoring: Optional[bool] = False, - dataset_type: Optional[str] = 'c4'): +def load_state_if_possible( + checkpoint_manager: CheckpointManager, + data_iterator: Union[MultiHostDataLoadIterator, None], + load_parameters_from_path: str, + load_full_state_from_path: str, + abstract_unboxed_pre_state: train_state.TrainState, + enable_single_replica_ckpt_restoring: Optional[bool] = False, + dataset_type: Optional[str] = "c4", +): """Loads TrainState as possible from the inputs. 
Args: @@ -121,57 +122,59 @@ def load_state_if_possible(checkpoint_manager: CheckpointManager, latest_step = checkpoint_manager.latest_step() if latest_step is not None: - max_logging.log(f"restoring from this run's directory latest step \ - {latest_step}") + max_logging.log( + f"restoring from this run's directory latest step \ + {latest_step}" + ) - def map_to_pspec(data): + def map_to_pspec(data): pspec = data.sharding.spec mesh = data.sharding.mesh if not enable_single_replica_ckpt_restoring: return orbax.checkpoint.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec) orbax.checkpoint.type_handlers.register_type_handler( - jax.Array, - orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), - override=True) + jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True + ) orbax.checkpoint.type_handlers.register_type_handler( - jax.Array, - orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), - override=True) + jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True + ) replica_axis_index = 0 # for maxtext data is the first dimension replica_devices = _replica_devices(mesh.devices, replica_axis_index) replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) return orbax.checkpoint.type_handlers.SingleReplicaArrayRestoreArgs( - sharding=jax.sharding.NamedSharding(mesh, pspec), - single_replica_sharding=single_replica_sharding, - replica_axis_index=replica_axis_index, - global_shape=data.shape, - dtype=data.dtype, - ) - - restore_args = jax.tree_util.tree_map(map_to_pspec, - abstract_unboxed_pre_state, - ) - if dataset_type == 'c4-array_record' and data_iterator is not None: - return checkpoint_manager.restore( - latest_step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeRestore( - item=abstract_unboxed_pre_state, - restore_args=restore_args), - iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)) - ), None + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + replica_axis_index=replica_axis_index, + global_shape=data.shape, + dtype=data.dtype, + ) + + restore_args = jax.tree_util.tree_map( + map_to_pspec, + abstract_unboxed_pre_state, + ) + if dataset_type == "c4-array_record" and data_iterator is not None: + return ( + checkpoint_manager.restore( + latest_step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args), + iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator), + ), + ), + None, + ) else: return ( - checkpoint_manager.restore( - latest_step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeRestore( - item=abstract_unboxed_pre_state, - restore_args=restore_args) + checkpoint_manager.restore( + latest_step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + ), ), - ), - None) + None, + ) if load_parameters_from_path != "": max_logging.log(f"restoring params from {load_parameters_from_path=}") @@ -182,16 +185,17 @@ def map_to_pspec(data): # memory, we instead specify here that we are just restoring the params field of the checkpoint # (which itself may be a dictionary containing a key named 'params'). 
restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_pre_state.params) - restored = ckptr.restore(p, item = {'params': abstract_unboxed_pre_state.params}, transforms={}, - restore_args = {'params': restore_args}) - return None, restored['params'] + restored = ckptr.restore( + p, item={"params": abstract_unboxed_pre_state.params}, transforms={}, restore_args={"params": restore_args} + ) + return None, restored["params"] elif load_full_state_from_path != "": max_logging.log(f"restoring full state from {load_full_state_from_path=}") p = epath.Path(load_full_state_from_path) ckptr = orbax.checkpoint.StandardCheckpointer() restored = ckptr.restore(p, args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state)) - return {'items': restored}, None + return {"items": restored}, None else: max_logging.log("No existing checkpoints found, not restoring checkpoint.") diff --git a/MaxText/common_types.py b/MaxText/common_types.py index a2e0f389b..2961104f3 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -32,13 +32,13 @@ AxisNames = tuple[str, ...] -BATCH = 'activation_batch' -LENGTH = 'activation_length' -HEAD = 'activation_heads' -D_KV = 'activation_kv' - -MODEL_MODE_AUTOREGRESSIVE = 'autoregressive' -MODEL_MODE_PREFILL = 'prefill' -MODEL_MODE_TRAIN = 'train' +BATCH = "activation_batch" +LENGTH = "activation_length" +HEAD = "activation_heads" +D_KV = "activation_kv" + +MODEL_MODE_AUTOREGRESSIVE = "autoregressive" +MODEL_MODE_PREFILL = "prefill" +MODEL_MODE_TRAIN = "train" DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py index a3cf1fd39..c690130c2 100644 --- a/MaxText/convert_gemma_chkpt.py +++ b/MaxText/convert_gemma_chkpt.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=line-too-long """ Convert orbax Gemma checkpoint to MaxText compatible checkpoint. 
@@ -18,7 +18,8 @@ import jax import jax.numpy as jnp import numpy as np -jax.config.update('jax_platform_name', 'cpu') + +jax.config.update("jax_platform_name", "cpu") import argparse import copy from flax.training import train_state @@ -35,44 +36,35 @@ Params = dict[str, Any] + def nest_params(params: Params) -> Params: """Nests params as a dict of dicts rather than a flat dict.""" nested_params = {} for path, param in params.items(): - *path, leaf = path.split('/') + *path, leaf = path.split("/") subdict = nested_params for key in path: subdict = subdict.setdefault(key, {}) subdict[leaf] = param return nested_params + def main(raw_args=None) -> None: parser = argparse.ArgumentParser() - parser.add_argument('--base_model_path', type=str, required=True) - parser.add_argument('--maxtext_model_path', type=str, required=True) - parser.add_argument('--model_size', type=str, required=True) + parser.add_argument("--base_model_path", type=str, required=True) + parser.add_argument("--maxtext_model_path", type=str, required=True) + parser.add_argument("--model_size", type=str, required=True) args = parser.parse_args(raw_args) - if args.model_size not in ('2b','7b'): + if args.model_size not in ("2b", "7b"): raise NotImplementedError print("Loading checkpoint") checkpointer = orbax.checkpoint.PyTreeCheckpointer() params = checkpointer.restore(args.base_model_path) params = nest_params(params) - num_layers = ( - max(( - int(k.split('_')[1]) - for k in params['transformer'].keys() - if 'layer_' in k - )) - + 1 - ) - hidden_dim, embed_dim = ( - params['transformer']['layer_0']['mlp']['linear']['w'].shape - ) - num_heads, head_dim, _ = ( - params['transformer']['layer_0']['attn']['attn_vec_einsum']['w'].shape - ) + num_layers = max((int(k.split("_")[1]) for k in params["transformer"].keys() if "layer_" in k)) + 1 + hidden_dim, embed_dim = params["transformer"]["layer_0"]["mlp"]["linear"]["w"].shape + num_heads, head_dim, _ = params["transformer"]["layer_0"]["attn"]["attn_vec_einsum"]["w"].shape print("Model configurations from checkpoint") print(f"num_layers: {num_layers}") print(f"hidden_dim: {hidden_dim}") @@ -81,109 +73,96 @@ def main(raw_args=None) -> None: print(f"head_dim: {head_dim}") jax_weights = { - 'decoder': { - 'decoder_norm': { - 'scale': params['transformer']['final_norm']['scale'] + 1 - }, + "decoder": { + "decoder_norm": {"scale": params["transformer"]["final_norm"]["scale"] + 1}, }, - 'token_embedder':{ - 'embedding': params['transformer']['embedder']['input_embedding'] * jnp.sqrt(embed_dim) - } - + "token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)}, } self_attention = dict({ - 'query': { - 'kernel' : [] - }, - 'key': { - 'kernel' : [] - }, - 'value': { - 'kernel' : [] - }, - 'out': { - 'kernel' : [] - }, + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, }) layer_weight = dict({ - 'mlp': { - 'wi_0': { - 'kernel' : [] - }, - 'wi_1': { - 'kernel' : [] - }, - 'wo': { - 'kernel' : [] - }, - }, - 'pre_self_attention_norm': { - 'scale': [] - }, - 'pre_ffw_norm': { - 'scale': [] - }, + "mlp": { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, + }, + "pre_self_attention_norm": {"scale": []}, + "pre_ffw_norm": {"scale": []}, }) for layer_idx in range(num_layers): - in_layer_name = 'layer_' + str(layer_idx) + in_layer_name = "layer_" + str(layer_idx) # attention block - if args.model_size == '2b': # MQA - 
self_attention['query']['kernel'].append(params['transformer'][in_layer_name]['attn']['q_einsum']['w'].transpose((1, 0, 2)) * head_dim**-0.5) - self_attention['key']['kernel'].append(params['transformer'][in_layer_name]['attn']['kv_einsum']['w'][0].transpose((1, 0, 2))) - self_attention['value']['kernel'].append(params['transformer'][in_layer_name]['attn']['kv_einsum']['w'][1].transpose((1, 0, 2))) + if args.model_size == "2b": # MQA + self_attention["query"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["q_einsum"]["w"].transpose((1, 0, 2)) * head_dim**-0.5 + ) + self_attention["key"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["kv_einsum"]["w"][0].transpose((1, 0, 2)) + ) + self_attention["value"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["kv_einsum"]["w"][1].transpose((1, 0, 2)) + ) else: - self_attention['query']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][0].transpose((1, 0, 2)) * head_dim**-0.5) - self_attention['key']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][1].transpose((1, 0, 2))) - self_attention['value']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][2].transpose((1, 0, 2))) - self_attention['out']['kernel'].append(params['transformer'][in_layer_name]['attn']['attn_vec_einsum']['w']) + self_attention["query"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][0].transpose((1, 0, 2)) * head_dim**-0.5 + ) + self_attention["key"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][1].transpose((1, 0, 2)) + ) + self_attention["value"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][2].transpose((1, 0, 2)) + ) + self_attention["out"]["kernel"].append(params["transformer"][in_layer_name]["attn"]["attn_vec_einsum"]["w"]) # mlp - layer_weight['mlp']['wi_0']['kernel'].append(params['transformer'][in_layer_name]['mlp']['gating_einsum']['w'][0]) - layer_weight['mlp']['wi_1']['kernel'].append(params['transformer'][in_layer_name]['mlp']['gating_einsum']['w'][1]) - layer_weight['mlp']['wo']['kernel'].append(params['transformer'][in_layer_name]['mlp']['linear']['w']) - layer_weight['pre_self_attention_norm']['scale'].append(params['transformer'][in_layer_name]['pre_attention_norm']['scale'] + 1) - layer_weight['pre_ffw_norm']['scale'].append(params['transformer'][in_layer_name]['pre_ffw_norm']['scale'] + 1) - - self_attention['query']['kernel'] = np.array(self_attention['query']['kernel']).transpose((1, 0, 2, 3)) - self_attention['key']['kernel'] = np.array(self_attention['key']['kernel']).transpose((1, 0, 2, 3)) - self_attention['value']['kernel'] = np.array(self_attention['value']['kernel']).transpose((1, 0, 2, 3)) - self_attention['out']['kernel'] = np.array(self_attention['out']['kernel']).transpose((1, 0, 2, 3)) - - layer_weight['mlp']['wi_0']['kernel'] = np.array(layer_weight['mlp']['wi_0']['kernel']).transpose((1, 0, 2)) - layer_weight['mlp']['wi_1']['kernel'] = np.array(layer_weight['mlp']['wi_1']['kernel']).transpose((1, 0, 2)) - layer_weight['mlp']['wo']['kernel'] = np.array(layer_weight['mlp']['wo']['kernel']).transpose((1, 0, 2)) - layer_weight['pre_self_attention_norm']['scale'] = np.array(layer_weight['pre_self_attention_norm']['scale']).transpose((1, 0)) - layer_weight['pre_ffw_norm']['scale'] = np.array(layer_weight['pre_ffw_norm']['scale']).transpose((1, 0)) - - layer_weight['self_attention'] = 
copy.deepcopy(self_attention) - jax_weights['decoder']['layers'] = copy.deepcopy(layer_weight) + layer_weight["mlp"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name]["mlp"]["gating_einsum"]["w"][0]) + layer_weight["mlp"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name]["mlp"]["gating_einsum"]["w"][1]) + layer_weight["mlp"]["wo"]["kernel"].append(params["transformer"][in_layer_name]["mlp"]["linear"]["w"]) + layer_weight["pre_self_attention_norm"]["scale"].append( + params["transformer"][in_layer_name]["pre_attention_norm"]["scale"] + 1 + ) + layer_weight["pre_ffw_norm"]["scale"].append(params["transformer"][in_layer_name]["pre_ffw_norm"]["scale"] + 1) + + self_attention["query"]["kernel"] = np.array(self_attention["query"]["kernel"]).transpose((1, 0, 2, 3)) + self_attention["key"]["kernel"] = np.array(self_attention["key"]["kernel"]).transpose((1, 0, 2, 3)) + self_attention["value"]["kernel"] = np.array(self_attention["value"]["kernel"]).transpose((1, 0, 2, 3)) + self_attention["out"]["kernel"] = np.array(self_attention["out"]["kernel"]).transpose((1, 0, 2, 3)) + + layer_weight["mlp"]["wi_0"]["kernel"] = np.array(layer_weight["mlp"]["wi_0"]["kernel"]).transpose((1, 0, 2)) + layer_weight["mlp"]["wi_1"]["kernel"] = np.array(layer_weight["mlp"]["wi_1"]["kernel"]).transpose((1, 0, 2)) + layer_weight["mlp"]["wo"]["kernel"] = np.array(layer_weight["mlp"]["wo"]["kernel"]).transpose((1, 0, 2)) + layer_weight["pre_self_attention_norm"]["scale"] = np.array(layer_weight["pre_self_attention_norm"]["scale"]).transpose( + (1, 0) + ) + layer_weight["pre_ffw_norm"]["scale"] = np.array(layer_weight["pre_ffw_norm"]["scale"]).transpose((1, 0)) + + layer_weight["self_attention"] = copy.deepcopy(self_attention) + jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight) jax_weights = jax.tree_map(jnp.array, jax_weights) + def astype_fn(x): if isinstance(x, jnp.ndarray): return x.astype(jnp.bfloat16) else: return x - jax_weights = jax.tree_map(astype_fn, jax_weights) - enable_checkpointing=True - async_checkpointing=False - save_interval_steps=1 + jax_weights = jax.tree_map(astype_fn, jax_weights) + enable_checkpointing = True + async_checkpointing = False + save_interval_steps = 1 checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - args.maxtext_model_path, - enable_checkpointing, - async_checkpointing, - save_interval_steps + args.maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps ) state_new = train_state.TrainState( - step=0, - apply_fn=None, - params={'params': jax_weights}, - tx=None, # type: ignore - opt_state={} + step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore ) if checkpoint_manager is not None: @@ -194,5 +173,6 @@ def astype_fn(x): checkpoint_manager.wait_until_finished() sys.exit() + if __name__ == "__main__": main() diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py index 78d43c47b..3ec57f8a2 100644 --- a/MaxText/convert_gpt3_ckpt_from_paxml.py +++ b/MaxText/convert_gpt3_ckpt_from_paxml.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=line-too-long """Convert weights from a paxml gpt3 model to a MaxText one. @@ -53,6 +53,7 @@ from train import save_checkpoint import argparse + def fmt_size(num_bytes: int) -> str: assert num_bytes > 0 for unit in ["B", "KiB", "MiB", "GiB"]: @@ -61,14 +62,15 @@ def fmt_size(num_bytes: int) -> str: num_bytes /= 1024.0 return f"{num_bytes:.2f} {unit}" + def check_memory(): """print out cpu/tpu memory.""" cpu_bytes = Process().memory_info().rss max_logging.log(f"cpu memory: {fmt_size(cpu_bytes)}") for d in jax.local_devices(): stats = d.memory_stats() - used = stats['bytes_in_use'] - limit = stats['bytes_limit'] + used = stats["bytes_in_use"] + limit = stats["bytes_limit"] max_logging.log(f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}") @@ -76,13 +78,16 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name """convert ckpt.""" base_args = [ - '', 'MaxText/configs/base.yml', # base arg - 'per_device_batch_size=1', - 'ici_fsdp_parallelism=-1', 'ici_tensor_parallelism=1', - f'model_name={maxtext_model_name}', - f'run_name={run_name}', f'base_output_directory={base_output_directory}', - 'checkpoint_period=1', - 'async_checkpointing=false', + "", + "MaxText/configs/base.yml", # base arg + "per_device_batch_size=1", + "ici_fsdp_parallelism=-1", + "ici_tensor_parallelism=1", + f"model_name={maxtext_model_name}", + f"run_name={run_name}", + f"base_output_directory={base_output_directory}", + "checkpoint_period=1", + "async_checkpointing=false", ] pyconfig.initialize(base_args) cfg = pyconfig.config @@ -96,10 +101,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name tx = optimizers.get_optimizer(cfg, learning_rate_schedule) checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - cfg.checkpoint_dir, - cfg.enable_checkpointing, - cfg.async_checkpointing, - cfg.checkpoint_period, + cfg.checkpoint_dir, + cfg.enable_checkpointing, + cfg.async_checkpointing, + cfg.checkpoint_period, ) state, _, _ = max_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager) @@ -108,33 +113,87 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name # maxtext keystr: (paxml keystr, transform_fn) keystr_map = { - "['token_embedder']['embedding']": (".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T), - "['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None), - "['decoder']['layers']['pre_self_attention_norm']['scale']": 
(".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['pre_self_attention_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['query']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['query']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['key']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['key']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['value']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['value']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis])), - "['decoder']['layers']['self_attention']['qkv_proj']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['out']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1])), - "['decoder']['layers']['self_attention']['out']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wi']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wi']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wo']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wo']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T), - "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), + "['token_embedder']['embedding']": 
(".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T), + "['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None), + "['decoder']['layers']['pre_self_attention_norm']['scale']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['pre_self_attention_norm']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['query']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 0], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['query']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 0], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['key']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 1], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['key']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 1], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['value']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 2], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['value']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 2], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis]), + ), + "['decoder']['layers']['self_attention']['qkv_proj']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['out']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", + lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1]), + ), + "['decoder']['layers']['self_attention']['out']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wi']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wi']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wo']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + 
"['decoder']['layers']['mlp']['wo']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T), + "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), } state_map = { - ".step": ("step", None), - ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), + ".step": ("step", None), + ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), } def get_layer_prefix(keystr_pax): @@ -151,9 +210,15 @@ def get_layer_prefix(keystr_pax): state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) prefix_pax_opt_state = get_layer_prefix(keystr_pax) # first momentum in optimizer state - state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn) + state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) # second momentum in optimizer state - state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn) + state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) def verify_fn(key_path, _): keystr = jax.tree_util.keystr(key_path) @@ -161,19 +226,19 @@ def verify_fn(key_path, _): jax.tree_util.tree_map_with_path(verify_fn, state) - memory_metrics = {'max_cpu_bytes': 0} + memory_metrics = {"max_cpu_bytes": 0} - bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://"):].split('/', 1) + bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://") :].split("/", 1) def map_fn(key_path, value): key_path_str = jax.tree_util.keystr(key_path) file_path, transform_fn = state_map[key_path_str] full_path = os.path.join(paxml_ckpt_prefix, file_path) - spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} - spec['kvstore'] = { - 'bucket': bucket_name, - 'driver': 'gcs', - 'path': full_path, + spec = {"driver": "zarr", "metadata_key": ".zarray", "kvstore": {}} + spec["kvstore"] = { + "bucket": bucket_name, + "driver": "gcs", + "path": full_path, } arr = ts.open(ts.Spec(spec), open=True).result().read().result() @@ -184,10 +249,9 @@ def map_fn(key_path, value): shape = value.shape sharding = value.sharding result = jax.make_array_from_single_device_arrays( - shape, - sharding, - [jax.device_put(np.array(arr[index]), d) - for d, index in sharding.addressable_devices_indices_map(shape).items()], + shape, + sharding, + [jax.device_put(np.array(arr[index]), d) for d, index in sharding.addressable_devices_indices_map(shape).items()], ) # log peak cpu memory @@ -216,15 +280,18 @@ def map_fn(key_path, value): max_logging.log(f"Peak cpu memory in a single process: {fmt_size(memory_metrics['max_cpu_bytes'])}") max_logging.log("checkpoint converted and saved successfully.") + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--paxml-ckpt-path', - type=str, - default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000", - required=True) - parser.add_argument('--maxtext-model-name', choices=['gpt3-175b', 'gpt3-52k'], type=str, required=True) - parser.add_argument('--base-output-directory', type=str, required=True) - parser.add_argument('--run-name', type=str, required=True) + parser.add_argument( + "--paxml-ckpt-path", + 
type=str, + default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000", + required=True, + ) + parser.add_argument("--maxtext-model-name", choices=["gpt3-175b", "gpt3-52k"], type=str, required=True) + parser.add_argument("--base-output-directory", type=str, required=True) + parser.add_argument("--run-name", type=str, required=True) args = parser.parse_args() if not args.paxml_ckpt_path.startswith("gs://"): diff --git a/MaxText/decode.py b/MaxText/decode.py index 15606b18e..56a92f94f 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -'''CLI Utility for Running Inference on a Single Stream''' +"""CLI Utility for Running Inference on a Single Stream""" import jax @@ -32,26 +32,21 @@ def main(config): metadata = engine.get_tokenizer() vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) tokenizer = vocab.tokenizer - tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, - prefill_lengths=[config.max_prefill_predict_length]) + tokens, true_length = token_utils.tokenize_and_pad( + text, vocab, is_bos=True, prefill_lengths=[config.max_prefill_predict_length] + ) assert tokens.size <= config.max_prefill_predict_length, "can't take too many tokens" assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" - prefill_result = engine.prefill( - params=params, padded_tokens=tokens, true_length=true_length - ) - slot=0 + prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) + slot = 0 decode_state = engine.init_decode_state() - decode_state = engine.insert( - prefill_result, decode_state, slot=slot - ) + decode_state = engine.insert(prefill_result, decode_state, slot=slot) steps = range(config.max_prefill_predict_length, config.max_target_length) sampled_tokens_list = [] for _ in steps: - decode_state, sampled_tokens = engine.generate( - params, decode_state - ) + decode_state, sampled_tokens = engine.generate(params, decode_state) sampled_tokens_list.append(sampled_tokens) results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list] @@ -59,15 +54,19 @@ def main(config): print(f"Input `{text}` -> `{output}`") if config.autoregressive_decode_assert != "": - assert output==config.autoregressive_decode_assert, \ - f"generated text mismatch {output=} {config.autoregressive_decode_assert=}" + assert ( + output == config.autoregressive_decode_assert + ), f"generated text mismatch {output=} {config.autoregressive_decode_assert=}" + def validate_config(config): - assert config.load_full_state_path == "", "Decode doesn't operate on full states! Convert to parameter checkpoint first."\ - "Using generate_param_only_checkpoint." + assert config.load_full_state_path == "", ( + "Decode doesn't operate on full states! Convert to parameter checkpoint first." "Using generate_param_only_checkpoint." 
+ ) + if __name__ == "__main__": - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(sys.argv) cfg = pyconfig.config diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py index cda4273ca..09ea7412b 100644 --- a/MaxText/generate_param_only_checkpoint.py +++ b/MaxText/generate_param_only_checkpoint.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Trasforms a "full state" including optimzer state to a bfloat16 "parameter state" without optimizer state. @@ -40,35 +40,38 @@ Transformer = models.Transformer + def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): - """ If input layers are scanned, and force_unroll is set, - return modify training_state and train_state_annotations to be "unrolled". - Otherwise do nothing.""" + """If input layers are scanned, and force_unroll is set, + return modify training_state and train_state_annotations to be "unrolled". 
+ Otherwise do nothing.""" if not config.scan_layers or not config.force_unroll: return - training_state_layers = training_state.params['params']['decoder']['layers'] - training_state_annotations_layers = training_state_annotations.params['params']['decoder']['layers'] + training_state_layers = training_state.params["params"]["decoder"]["layers"] + training_state_annotations_layers = training_state_annotations.params["params"]["decoder"]["layers"] def new_pspec(x): - return jax.sharding.PartitionSpec(*x[0:config.param_scan_axis] + x[config.param_scan_axis+1:]) + return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :]) new_per_layer_state_annotation = jax.tree_map(new_pspec, training_state_annotations_layers) - new_per_layer_state_sharding = jax.tree_map(lambda x : jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation) + new_per_layer_state_sharding = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation) for i in range(config.num_decoder_layers): + def slice_ith(input_layers): - return jax.tree_map(lambda x : jax.numpy.take(x, i, axis = config.param_scan_axis), input_layers) + return jax.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers) + + new_layer = jax.jit(slice_ith, out_shardings=new_per_layer_state_sharding)(training_state_layers) - new_layer = jax.jit(slice_ith, out_shardings = new_per_layer_state_sharding)(training_state_layers) + training_state.params["params"]["decoder"][f"layers_{i}"] = new_layer + training_state_annotations.params["params"]["decoder"][f"layers_{i}"] = new_per_layer_state_annotation - training_state.params['params']['decoder'][f'layers_{i}'] = new_layer - training_state_annotations.params['params']['decoder'][f'layers_{i}'] = new_per_layer_state_annotation + del training_state.params["params"]["decoder"]["layers"] + del training_state_annotations.params["params"]["decoder"]["layers"] - del training_state.params['params']['decoder']['layers'] - del training_state_annotations.params['params']['decoder']['layers'] + jax.tree_map(lambda x: x.delete(), training_state_layers) - jax.tree_map(lambda x : x.delete(), training_state_layers) def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" @@ -78,22 +81,22 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): rng = random.PRNGKey(0) learning_rate_schedule = max_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - state, state_mesh_notations, _ = max_utils.setup_training_state( - model, None, tx, config, rng, mesh, checkpoint_manager - ) + state, state_mesh_notations, _ = max_utils.setup_training_state(model, None, tx, config, rng, mesh, checkpoint_manager) num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") return state, state_mesh_notations + def _save_decode_checkpoint(config, state, checkpoint_manager): """Generate checkpoint for decode from the training_state.""" - with jax.spmd_mode('allow_all'): - decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x : x.astype(jax.numpy.bfloat16), state.params)) + with jax.spmd_mode("allow_all"): + decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params)) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, 
0, decode_state): max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}") checkpoint_manager.wait_until_finished() + def generate_decode_checkpoint(config): """ Generate an decode checkpoint from a given training checkpoint. diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py index ec46b137f..8fe46b3e2 100644 --- a/MaxText/inference_microbenchmark.py +++ b/MaxText/inference_microbenchmark.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Inference microbenchmark for prefill and autoregressive steps.""" import datetime @@ -28,19 +28,21 @@ def summarize_pytree_data(params, name="Params"): - """ Generate basic metrics of a given Pytree. """ + """Generate basic metrics of a given Pytree.""" num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params) num_params_in_billions = num_params / 1e9 total_param_size_in_gb = total_param_size / 1e9 - print(f"{name} stats: \n" - f"\tTotal number of params: {num_params_in_billions:.3f} billion \n" - f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n" - f"\tAvg size: {avg_param_size:.3f} bytes\n") + print( + f"{name} stats: \n" + f"\tTotal number of params: {num_params_in_billions:.3f} billion \n" + f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n" + f"\tAvg size: {avg_param_size:.3f} bytes\n" + ) return num_params, total_param_size, avg_param_size def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""): - """ Inner loop for benchmarking prefill step. """ + """Inner loop for benchmarking prefill step.""" max_utils.activate_profiler(config, profile_name) start = datetime.datetime.now() for i in range(iters): @@ -53,9 +55,10 @@ def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_le return (end - start).total_seconds(), decode_state -def prefill_benchmark(config, engine, params, decode_state, tokens, true_length, - iters=100, profile_name="", num_model_params=None): - """ Handles init, warmup, running prefill benchmark, and printing results. 
""" +def prefill_benchmark( + config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None +): + """Handles init, warmup, running prefill benchmark, and printing results.""" if num_model_params is None: num_model_params, _, _ = summarize_pytree_data(params, name="Params") @@ -69,22 +72,27 @@ def prefill_benchmark(config, engine, params, decode_state, tokens, true_length, print(f"Prefill results for length {tokens.size}:\n") profile_name = f"prefill_{tokens.size}" if profile_name == "" else profile_name - time_in_s, decode_state = prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, - profile_name=profile_name) + time_in_s, decode_state = prefill_benchmark_loop( + config, engine, decode_state, params, tokens, true_length, iters, profile_name=profile_name + ) prefill_average_ms = 1000 * time_in_s / iters total_prefill_tflops, _, _ = maxtext_utils.calculate_tflops_prefill(num_model_params, tokens.size, config) - tflops_per_sec_per_device = total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000. - print(f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n" - f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n" - f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n") - result_dict = {"prefill_time_in_ms": prefill_average_ms, - "prefill_total_tflops": total_prefill_tflops, - "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device} + tflops_per_sec_per_device = total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000.0 + print( + f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n" + f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n" + f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n" + ) + result_dict = { + "prefill_time_in_ms": prefill_average_ms, + "prefill_total_tflops": total_prefill_tflops, + "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device, + } return result_dict, decode_state def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""): - """ Inner loop for benchmarking ar step. """ + """Inner loop for benchmarking ar step.""" max_utils.activate_profiler(config, profile_name) start = datetime.datetime.now() for _ in range(iters): @@ -96,9 +104,9 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name= def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100): - """ Handles init, warmup, running ar benchmark, and printing results. 
""" + """Handles init, warmup, running ar benchmark, and printing results.""" if cache_size is None: - _, cache_size, _ = summarize_pytree_data(decode_state['cache'], name="Cache") + _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache") if model_size is None: _, model_size, _ = summarize_pytree_data(params, name="Params") global_batch_size = jax.device_count() * config.per_device_batch_size @@ -112,38 +120,41 @@ def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_si profile_name = "autoregress" if profile_name == "" else profile_name time_in_s, decode_state = ar_benchmark_loop(config, engine, decode_state, params, profile_name=profile_name, iters=iters) seconds_per_step = time_in_s / iters - ar_average_ms = seconds_per_step*1000 + ar_average_ms = seconds_per_step * 1000 total_throughput = jax.device_count() * config.per_device_batch_size / seconds_per_step GB_per_step_per_device = (model_size + cache_size) / 1e9 / jax.device_count() - bw_per_device = GB_per_step_per_device/seconds_per_step - print(f"AutoRegressive results:\n" - f"\tAR step average time: {ar_average_ms:.3f}ms\n" - f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n" - f"\tAR global batch size: {global_batch_size}\n" - f"\tAR throughput: {total_throughput:.3f} tokens/second\n" - f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n") - - - result_dict = {"ar_step_in_ms": ar_average_ms, - "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size, - "ar_global_batch_size": global_batch_size, - "ar_total_throughput_tokens_per_second": total_throughput, - "ar_device_bandwidth_GB_per_second": bw_per_device} + bw_per_device = GB_per_step_per_device / seconds_per_step + print( + f"AutoRegressive results:\n" + f"\tAR step average time: {ar_average_ms:.3f}ms\n" + f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n" + f"\tAR global batch size: {global_batch_size}\n" + f"\tAR throughput: {total_throughput:.3f} tokens/second\n" + f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n" + ) + + result_dict = { + "ar_step_in_ms": ar_average_ms, + "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size, + "ar_global_batch_size": global_batch_size, + "ar_total_throughput_tokens_per_second": total_throughput, + "ar_device_bandwidth_GB_per_second": bw_per_device, + } return result_dict, decode_state def collate_results(config, results, model_size, cache_size, num_model_params, incl_config=False): - """ Adds model/cache size info and optionally config info to results. 
""" + """Adds model/cache size info and optionally config info to results.""" results["sizes"] = { - "Model_size_in_GB": model_size / 1e9, - "cache_size_in_GB": cache_size / 1e9, - "model_params_in_billions": num_model_params / 1e9, + "Model_size_in_GB": model_size / 1e9, + "cache_size_in_GB": cache_size / 1e9, + "model_params_in_billions": num_model_params / 1e9, } if incl_config: results["config"] = {} for k, v in dict(config.get_keys()).items(): - results["config"][k] = str(v) if k == "dtype" else v # json fails with original dtype + results["config"][k] = str(v) if k == "dtype" else v # json fails with original dtype return results @@ -172,18 +183,25 @@ def main(config): vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) decode_state = engine.init_decode_state() - _, cache_size, _ = summarize_pytree_data(decode_state['cache'], name="Cache") + _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache") num_model_params, model_size, _ = summarize_pytree_data(params, name="Model") benchmark_results = {"Prefill": {}} benchmark_results["AutoRegressive"], decode_state = ar_benchmark( - config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size) + config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size + ) for prefill_length in prefill_lengths: - tokens, true_length = token_utils.tokenize_and_pad( - text, vocab, is_bos=True, prefill_lengths=[prefill_length]) + tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[prefill_length]) benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark( - config, engine, params, decode_state, tokens, true_length, - iters=benchmark_loop_iters, num_model_params=num_model_params) + config, + engine, + params, + decode_state, + tokens, + true_length, + iters=benchmark_loop_iters, + num_model_params=num_model_params, + ) results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params) write_results(results, filename="") diff --git a/MaxText/inference_scratch/analyze_sharegpt.py b/MaxText/inference_scratch/analyze_sharegpt.py index 47d8715cf..47847275d 100644 --- a/MaxText/inference_scratch/analyze_sharegpt.py +++ b/MaxText/inference_scratch/analyze_sharegpt.py @@ -20,13 +20,16 @@ MAX_INPUT_TOKENS = 1024 MAX_OUTPUT_TOKENS = 1024 + def next_power_of_2(x): - return 1 if x == 0 else 2**(x - 1).bit_length() + return 1 if x == 0 else 2 ** (x - 1).bit_length() + def tokens_in_input_str(s): - return_val = int(1.3 * len(s.split())) + return_val = int(1.3 * len(s.split())) return return_val + def get_prefill_and_generate_times(filename=""): if filename == "": return PREFILL_BUCKET_SIZE_TO_MS, SYSTEM_TIME_PER_DECODE_TOKEN_MS @@ -37,17 +40,18 @@ def get_prefill_and_generate_times(filename=""): for k, v in microbenchmark_results["Prefill"].items(): prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3) - return prefill_bucket_size_to_ms, microbenchmark_results['AutoRegressive']['ar_step_in_ms_per_seq'] + return prefill_bucket_size_to_ms, microbenchmark_results["AutoRegressive"]["ar_step_in_ms_per_seq"] + def get_conversations_from_file(filename, max_input_tokens, max_output_tokens): convo_token_numbers = [] - with open(filename, 'r') as f: + with open(filename, "r") as f: loaded_share_gpt = json.load(f) for example in loaded_share_gpt: - if len(example['conversations']) < 2: + if len(example["conversations"]) < 2: continue - 
num_input_tokens = tokens_in_input_str(example['conversations'][0]['value']) - num_output_tokens = tokens_in_input_str(example['conversations'][1]['value']) + num_input_tokens = tokens_in_input_str(example["conversations"][0]["value"]) + num_output_tokens = tokens_in_input_str(example["conversations"][1]["value"]) convo_token_numbers.append((num_input_tokens, num_output_tokens)) num_convos = len(convo_token_numbers) @@ -78,9 +82,11 @@ def compute_times(convos, prefill_bucket_size_to_ms, system_time_per_decode_toke total_generate_time_seconds = total_generate_system_ms / 1000 total_time_s = total_prefill_time_seconds + total_generate_time_seconds - print(f"\nTotal time {total_time_s:.3f} seconds: " - f"\n\tPrefill time: {total_prefill_time_seconds:.3f} seconds" - f"\n\tGenerate time: {total_generate_time_seconds:.3f} seconds") + print( + f"\nTotal time {total_time_s:.3f} seconds: " + f"\n\tPrefill time: {total_prefill_time_seconds:.3f} seconds" + f"\n\tGenerate time: {total_generate_time_seconds:.3f} seconds" + ) return total_time_s, total_prefill_time_seconds, total_generate_time_seconds @@ -92,11 +98,11 @@ def get_num_tokens_in_convos(convos): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('convo_file', type=str, - help='a json file containing conversations') - parser.add_argument('-t', '--mb_timing_file', type=str, default="", - help='a json file containing microbenchmark timing results') - parser.add_argument('-v', '--verbose', action="store_true") + parser.add_argument("convo_file", type=str, help="a json file containing conversations") + parser.add_argument( + "-t", "--mb_timing_file", type=str, default="", help="a json file containing microbenchmark timing results" + ) + parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() convos = get_conversations_from_file(args.convo_file, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS) @@ -104,5 +110,7 @@ def get_num_tokens_in_convos(convos): prefill_time_ms_buckets, generate_time_ms = get_prefill_and_generate_times(filename=args.mb_timing_file) total_time_seconds, _, _ = compute_times(convos, prefill_time_ms_buckets, generate_time_ms, args.verbose) - print(f"Output {total_output_tokens} tokens in {total_time_seconds:.3f} seconds " - f"= {total_output_tokens/total_time_seconds:.3f} out tok/s") + print( + f"Output {total_output_tokens} tokens in {total_time_seconds:.3f} seconds " + f"= {total_output_tokens/total_time_seconds:.3f} out tok/s" + ) diff --git a/MaxText/inference_utils.py b/MaxText/inference_utils.py index 786ecaae7..96c727c0c 100644 --- a/MaxText/inference_utils.py +++ b/MaxText/inference_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax import jax.numpy as jnp @@ -26,14 +26,15 @@ Inspired by an Google-internal implementation, Global Vision Transformer. """ + def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): - """ - logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], before logit - rng: rng key to use - algorithm: string representing supported algorithms - topk: restricting to topk logits before sampling - nucleus_topp: restricting to p probability mass before sampling - temperature: temperature parameter for scaling probability + """ + logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], before logit + rng: rng key to use + algorithm: string representing supported algorithms + topk: restricting to topk logits before sampling + nucleus_topp: restricting to p probability mass before sampling + temperature: temperature parameter for scaling probability """ if algorithm == "greedy": return jnp.argmax(logits, axis=-1) @@ -46,36 +47,29 @@ def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): else: raise ValueError(f"Sampling {algorithm=} not supported!") + def sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng): """Restrict sampling to the top logits with cumulative probability >= nucleus_topp. - - The nucleus sampling method is proposed in the paper `The Curious Case of - Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` - + + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + """ if nucleus_topp < 0: raise ValueError("Can't apply nucleus with parameter {nucleus_topp=} less zero") logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending - sorted_cum_probs = jnp.cumsum( - jax.nn.softmax(logits_sorted, axis=-1), axis=-1) # get cumsum probs - cutoff_index = jnp.sum( - sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True) # find cutoff index + sorted_cum_probs = jnp.cumsum(jax.nn.softmax(logits_sorted, axis=-1), axis=-1) # get cumsum probs + cutoff_index = jnp.sum(sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True) # find cutoff index cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) - logits = jnp.where(logits < cutoff_logit, - jnp.full_like(logits, NEG_INF), logits) + logits = jnp.where(logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits) return jax.random.categorical(rng, logits / temperature) + def sample_topk_logits(logits, topk, temperature, rng): - """ Restricting sampling to the best k logits. 
""" + """Restricting sampling to the best k logits.""" if topk <= 0: raise ValueError("Can't apply algorithm topk with parameter {topk=} less than or equal to zero") topk_logits, topk_idxs = jax.lax.top_k(logits, topk) - topk_token = jnp.expand_dims( - jax.random.categorical(rng, topk_logits/temperature).astype(jnp.int32), - axis=-1) - sampled_tokens = jnp.squeeze( - jnp.take_along_axis(topk_idxs, topk_token, axis=-1), - axis=-1).astype(jnp.int32) + topk_token = jnp.expand_dims(jax.random.categorical(rng, topk_logits / temperature).astype(jnp.int32), axis=-1) + sampled_tokens = jnp.squeeze(jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1).astype(jnp.int32) return sampled_tokens - - diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index 92c90854d..57b8f1e5e 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Input pipeline using Grain.""" @@ -30,31 +30,33 @@ import multihost_dataloading -def get_datasets( - config: ml_collections.ConfigDict -): + +def get_datasets(config: ml_collections.ConfigDict): """Load dataset from array_record files for using with grain""" data_dir = os.path.join(config.dataset_path, config.dataset_name) - train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)] + train_files = [data_dir + "/" + f for f in os.listdir(data_dir) if re.match(r".*train.*", f)] train_ds = grain.ArrayRecordDataSource(train_files) if config.eval_dataset_name: - eval_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(rf'.*{config.eval_split}.*', f)] + eval_files = [data_dir + "/" + f for f in os.listdir(data_dir) if re.match(rf".*{config.eval_split}.*", f)] eval_ds = grain.ArrayRecordDataSource(eval_files) else: eval_ds = train_ds return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, - global_mesh, - train_ds, eval_ds, - vocab_path: Optional[str] = None, - data_shuffle_seed = 0, - add_bos = True, - add_eos = True - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, + global_mesh, + train_ds, + eval_ds, + vocab_path: Optional[str] = None, + data_shuffle_seed=0, + add_bos=True, + add_eos=True, +): """Use grain to pre-process the dataset and return iterators""" # Set global batch size. global_batch_size_to_load = config.global_batch_size_to_load @@ -78,7 +80,8 @@ def preprocess_dataset(config: ml_collections.ConfigDict, num_epochs=1, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) eval_iter = preprocessing_pipeline( eval_ds, @@ -93,7 +96,8 @@ def preprocess_dataset(config: ml_collections.ConfigDict, shuffle=config.enable_data_shuffling, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) predict_iter = preprocessing_pipeline( eval_ds, @@ -108,45 +112,45 @@ def preprocess_dataset(config: ml_collections.ConfigDict, shuffle=config.enable_data_shuffling, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed) + data_shuffle_seed=data_shuffle_seed, + ) return train_iter, eval_iter, predict_iter + def preprocessing_pipeline( - dataset, - vocab_path, - add_bos: bool, - add_eos: bool, - grain_worker_count: int, - batch_size: int, - global_mesh, - dataloading_host_index, - dataloading_host_count, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - max_length: int = 512, - shift: bool = True, - drop_remainder: bool = True, - data_shuffle_seed = 0, + dataset, + vocab_path, + add_bos: bool, + add_eos: bool, + grain_worker_count: int, + batch_size: int, + global_mesh, + dataloading_host_index, + dataloading_host_count, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + data_shuffle_seed=0, ): """Apply grain operations to preprocess the given dataset.""" - assert ( - batch_size % global_mesh.size == 0 - ), 'Batch size should be divisible number of global devices.' + assert batch_size % global_mesh.size == 0, "Batch size should be divisible number of global devices." 
operations = [] operations.append(_grain_operations.ParseFeatures()) operations.append(_grain_operations.NormalizeFeatures()) - operations.append(_grain_tokenizer.TokenizeAndTrim(["inputs","targets"], - max_length, vocab_path, - add_bos, add_eos)) + operations.append(_grain_tokenizer.TokenizeAndTrim(["inputs", "targets"], max_length, vocab_path, add_bos, add_eos)) # Pack and Batch examples. if pack_examples: - operations.append(grain.experimental.PackAndBatchOperation( - batch_size=batch_size // jax.process_count(), - length_struct={'inputs':max_length,'targets':max_length})) + operations.append( + grain.experimental.PackAndBatchOperation( + batch_size=batch_size // jax.process_count(), length_struct={"inputs": max_length, "targets": max_length} + ) + ) operations.append(_grain_operations.ReformatPacking()) else: operations.append(_grain_operations.PadToMaxLength(max_length)) @@ -157,19 +161,19 @@ def preprocessing_pipeline( operations.append(_grain_operations.ShiftData(axis=1)) index_sampler = grain.IndexSampler( - num_records=len(dataset), - num_epochs = num_epochs, - shard_options=grain.ShardOptions( - shard_index = dataloading_host_index, shard_count = dataloading_host_count, drop_remainder = True - ), - shuffle = shuffle, - seed = data_shuffle_seed + num_records=len(dataset), + num_epochs=num_epochs, + shard_options=grain.ShardOptions( + shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=True + ), + shuffle=shuffle, + seed=data_shuffle_seed, ) dataloader = grain.DataLoader( - data_source = dataset, - operations = operations, - sampler = index_sampler, + data_source=dataset, + operations=operations, + sampler=index_sampler, worker_count=grain_worker_count, ) diff --git a/MaxText/input_pipeline/_grain_operations.py b/MaxText/input_pipeline/_grain_operations.py index 25508376b..685165381 100644 --- a/MaxText/input_pipeline/_grain_operations.py +++ b/MaxText/input_pipeline/_grain_operations.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Operations used by Grain""" @@ -21,60 +21,64 @@ import grain.python as grain import numpy as np import tensorflow as tf + Features = Dict[str, tf.Tensor] + @dataclasses.dataclass class ParseFeatures(grain.MapTransform): """Parse serialized example""" + def map(self, features): def _parse(example): - parsed = tf.io.parse_example( - example, { - 'text': tf.io.FixedLenFeature(shape=(), dtype=tf.string) - }) + parsed = tf.io.parse_example(example, {"text": tf.io.FixedLenFeature(shape=(), dtype=tf.string)}) return parsed + return _parse(features) @dataclasses.dataclass class NormalizeFeatures(grain.MapTransform): """Normalize text feature keys.""" + def map(self, features): - return { - 'inputs':features['text'].numpy().decode(), - 'targets': features['text'].numpy().decode() - } + return {"inputs": features["text"].numpy().decode(), "targets": features["text"].numpy().decode()} @dataclasses.dataclass class ReformatPacking(grain.MapTransform): """Reformat packing outputs.""" + def map(self, data): - return{ - 'inputs':data[0]['inputs'], - 'targets':data[0]['targets'], - 'inputs_segmentation':data[1]['inputs'], - 'targets_segmentation':data[1]['targets'], - 'inputs_position':data[2]['inputs'], - 'targets_position':data[2]['targets'], + return { + "inputs": data[0]["inputs"], + "targets": data[0]["targets"], + "inputs_segmentation": data[1]["inputs"], + "targets_segmentation": data[1]["targets"], + "inputs_position": data[2]["inputs"], + "targets_position": data[2]["targets"], } @dataclasses.dataclass class PadToMaxLength(grain.MapTransform): - """Pads each input to the specified length""" + """Pads each input to the specified length""" + def __init__(self, max_length): self.max_length = max_length + def map(self, data): """map to each element""" + def _pad(x, max_length): pad_amount = max(max_length - x.shape[0], 0) pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1) return np.pad(x, pad_amount) - data['inputs_segmentation'] = np.ones(data['inputs'].shape, dtype = np.int32) - data['inputs_position'] = np.arange(data['inputs'].shape[0], dtype = np.int32) - data['targets_segmentation'] = np.ones(data['targets'].shape, dtype = np.int32) - data['targets_position'] = np.arange(data['targets'].shape[0], dtype = np.int32) + + data["inputs_segmentation"] = np.ones(data["inputs"].shape, dtype=np.int32) + data["inputs_position"] = np.arange(data["inputs"].shape[0], dtype=np.int32) + data["targets_segmentation"] = np.ones(data["targets"].shape, dtype=np.int32) + data["targets_position"] = np.arange(data["targets"].shape[0], dtype=np.int32) for key, _ in data.items(): data[key] = _pad(data[key], self.max_length) return data @@ -84,33 +88,34 @@ def shift_right(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - slices = [slice(None),] * len(x.shape) + slices = [ + slice(None), + ] * len(x.shape) slices[axis] = slice(0, -1) - padded = np.pad( - x, - pad_widths, - mode='constant', - constant_values=x.dtype.type(0) - ) + padded = np.pad(x, pad_widths, mode="constant", constant_values=x.dtype.type(0)) return padded[tuple(slices)] + def shift_and_refine(x, axis=1): """Shift inputs, set segmentation to 0 when target element is 0. 
Replace EOS by 0 for packed inputs.""" - x['inputs'] = shift_right(x['inputs'], axis=axis) - targets_nonzero = x['targets'] != 0 - x['inputs_segmentation'] *= targets_nonzero - x['targets_segmentation'] *= targets_nonzero + x["inputs"] = shift_right(x["inputs"], axis=axis) + targets_nonzero = x["targets"] != 0 + x["inputs_segmentation"] *= targets_nonzero + x["targets_segmentation"] *= targets_nonzero # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. - x['inputs'] *= x['inputs_segmentation'] == shift_right(x['inputs_segmentation'], axis=axis) + x["inputs"] *= x["inputs_segmentation"] == shift_right(x["inputs_segmentation"], axis=axis) return x + @dataclasses.dataclass class ShiftData(grain.MapTransform): """Shift inputs and refine annotations.""" - def __init__(self, axis = 1): + + def __init__(self, axis=1): self.axis = axis + def map(self, data): return shift_and_refine(data, axis=self.axis) diff --git a/MaxText/input_pipeline/_grain_tokenizer.py b/MaxText/input_pipeline/_grain_tokenizer.py index e9436799a..9a79af448 100644 --- a/MaxText/input_pipeline/_grain_tokenizer.py +++ b/MaxText/input_pipeline/_grain_tokenizer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Tokenize Op used by Grain""" @@ -24,9 +24,11 @@ import grain.python as grain import numpy as np + @dataclasses.dataclass class TokenizeAndTrim(grain.MapTransform): """Tokenize and trim features to sequence length.""" + # pylint: disable=attribute-defined-outside-init feature_names: str | Sequence[str] sequence_length: int | Sequence[int] @@ -49,16 +51,14 @@ def map(self, features: dict[str, Any]) -> dict[str, Any]: if self._processor is None: # Ensures only one thread initializes SPP. 
self._processor = SentencePieceProcessor() self._processor.Load(self.model_path) - for feature_name, sequence_length in zip( - self.feature_names, self.sequence_length, strict=True - ): + for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True): text = features[feature_name] token_ids = self._processor.EncodeAsIds(text) if self.add_bos: token_ids = [self._processor.bos_id()] + token_ids if self.add_eos: - token_ids = token_ids[:sequence_length-1] + token_ids = token_ids[: sequence_length - 1] token_ids = token_ids + [self._processor.eos_id()] else: token_ids = token_ids[:sequence_length] diff --git a/MaxText/input_pipeline/_tfds_data_processing.py b/MaxText/input_pipeline/_tfds_data_processing.py index 506e7bd87..865278831 100644 --- a/MaxText/input_pipeline/_tfds_data_processing.py +++ b/MaxText/input_pipeline/_tfds_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline for a LM1B dataset.""" @@ -34,17 +34,16 @@ # Right-shifting token inputs for teacher-forced training. # ----------------------------------------------------------------------------- + def shift_right_tf(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - slices = [slice(None),] * len(x.shape) + slices = [ + slice(None), + ] * len(x.shape) slices[axis] = slice(0, -1) - padded = tf.pad( - x, - tf.constant(pad_widths), - mode='constant', - constant_values=tf.constant(0, x.dtype)) + padded = tf.pad(x, tf.constant(pad_widths), mode="constant", constant_values=tf.constant(0, x.dtype)) return padded[tuple(slices)] @@ -54,46 +53,45 @@ def shift_inputs_tf(x, segment_ids=None, axis=1): # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. 
if segment_ids is not None: - shifted *= tf.cast( - segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype - ) + shifted *= tf.cast(segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype) return shifted + def shift_data(x, axis=0, segmented=True): - segment_ids = x['inputs_segmentation'] if segmented else None - x['inputs'] = shift_inputs_tf(x['inputs'], segment_ids=segment_ids, axis=axis) + segment_ids = x["inputs_segmentation"] if segmented else None + x["inputs"] = shift_inputs_tf(x["inputs"], segment_ids=segment_ids, axis=axis) return x + def shift_data_by_truncation(x): - x['inputs'] = x['inputs'][:-1] - x['targets'] = x['targets'][1:] + x["inputs"] = x["inputs"][:-1] + x["targets"] = x["targets"][1:] return x def normalize_features(ds): """Normalize text feature keys.""" + def _normalize_features(features): - features['inputs'] = features.pop('text') - features['targets'] = features['inputs'] + features["inputs"] = features.pop("text") + features["targets"] = features["inputs"] return features - return ds.map( - _normalize_features, - num_parallel_calls=AUTOTUNE) + return ds.map(_normalize_features, num_parallel_calls=AUTOTUNE) + def length_trim(ds, max_len): - """"Trim to Max length""" + """ "Trim to Max length""" + def _trim_fn(features): - if tf.shape(features['inputs'])[0] > max_len: - features['inputs'] = features['inputs'][:max_len] - if tf.shape(features['targets'])[0] > max_len: - features['targets'] = features['targets'][:max_len] + if tf.shape(features["inputs"])[0] > max_len: + features["inputs"] = features["inputs"][:max_len] + if tf.shape(features["targets"])[0] > max_len: + features["targets"] = features["targets"][:max_len] return features - return ds.map( - _trim_fn, - num_parallel_calls=AUTOTUNE - ) + return ds.map(_trim_fn, num_parallel_calls=AUTOTUNE) + # ----------------------------------------------------------------------------- # Main dataset preparation. @@ -101,52 +99,45 @@ def _trim_fn(features): def preprocessing_pipeline( - dataset, - batch_size: int, - global_mesh, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - shift: bool = True, - drop_remainder: bool = True, - prefetch_size = tf.data.experimental.AUTOTUNE, - data_shuffle_seed = 0, + dataset, + batch_size: int, + global_mesh, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + prefetch_size=tf.data.experimental.AUTOTUNE, + data_shuffle_seed=0, ): """Shuffle and batch/pack the given dataset.""" def truncate_to_max_allowable_length(x, max_length): - x['inputs'] = x['inputs'][:max_length] - x['targets'] = x['targets'][:max_length] + x["inputs"] = x["inputs"][:max_length] + x["targets"] = x["targets"][:max_length] return x - if max_length > 0: # We can take upto max_length+1 because there would be truncation by 1 token # for both inputs and targets - dataset = dataset.map(lambda x: truncate_to_max_allowable_length(x, max_length+1)) + dataset = dataset.map(lambda x: truncate_to_max_allowable_length(x, max_length + 1)) # Shuffle and repeat. 
if shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, seed = data_shuffle_seed) + dataset = dataset.shuffle(shuffle_buffer_size, seed=data_shuffle_seed) dataset = dataset.repeat(num_epochs) - # Shift inputs for teacher-forced training if shift: - dataset = dataset.map( - shift_data_by_truncation, - num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) + dataset = dataset.map(shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True) # Perform greedy sequence packing if pack_examples: dataset = sequence_packing.pack_dataset(dataset, max_length) - assert ( - batch_size % global_mesh.size == 0 - ), 'Batch size should be divisible number of global devices.' + assert batch_size % global_mesh.size == 0, "Batch size should be divisible number of global devices." # Batch examples. if pack_examples: @@ -155,9 +146,10 @@ def truncate_to_max_allowable_length(x, max_length): # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size // jax.process_count(), - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=drop_remainder) + padded_shapes={"inputs": max_length, "targets": max_length}, + padding_values={"inputs": 0, "targets": 0}, + drop_remainder=drop_remainder, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) @@ -169,20 +161,18 @@ def truncate_to_max_allowable_length(x, max_length): def get_datasets( - config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, - read_config = None, + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, + read_config=None, ): """Load and return dataset of batched examples for use during training.""" # Training dataset. train_ds_builder = tfds.builder(config.dataset_name) # train_data = get_raw_dataset(train_ds_builder, 'train') - train_ds = train_ds_builder.as_dataset(split='train', - read_config = read_config, - shuffle_files=config.enable_data_shuffling) + train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=config.enable_data_shuffling) # shard the dataset as soon as it is loaded - train_ds = train_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) + train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) train_ds = normalize_features(train_ds) # Evaluation dataset. @@ -191,25 +181,25 @@ def get_datasets( else: eval_ds_builder = train_ds_builder # eval_data = get_raw_dataset(eval_ds_builder, config.eval_split) - eval_ds = eval_ds_builder.as_dataset(split=config.eval_split, - read_config = read_config, - shuffle_files=False) - eval_ds = eval_ds.shard(num_shards = jax.process_count(), index = jax.process_index()) + eval_ds = eval_ds_builder.as_dataset(split=config.eval_split, read_config=read_config, shuffle_files=False) + eval_ds = eval_ds.shard(num_shards=jax.process_count(), index=jax.process_index()) eval_ds = normalize_features(eval_ds) return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - global_mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed = 0, - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + global_mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=0, +): """Pre-process the dataset and return iterators""" # Tokenize data. 
- train_ds = train_ds.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) - eval_ds = eval_ds.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + train_ds = train_ds.map(tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + eval_ds = eval_ds.map(tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) # Set global batch size. global_batch_size_to_load = config.global_batch_size_to_load @@ -220,9 +210,10 @@ def preprocess_dataset(config: ml_collections.ConfigDict, eval_batch_size = global_batch_size_to_load def filter_keys(record): - return {'inputs': record['inputs'], 'targets': record['targets']} - train_ds = train_ds.map(filter_keys,num_parallel_calls=tf.data.AUTOTUNE) - eval_ds = eval_ds.map(filter_keys,num_parallel_calls=tf.data.AUTOTUNE) + return {"inputs": record["inputs"], "targets": record["targets"]} + + train_ds = train_ds.map(filter_keys, num_parallel_calls=tf.data.AUTOTUNE) + eval_ds = eval_ds.map(filter_keys, num_parallel_calls=tf.data.AUTOTUNE) train_iter = preprocessing_pipeline( train_ds, @@ -233,7 +224,8 @@ def filter_keys(record): pack_examples=True, max_length=config.max_target_length, shift=True, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) eval_iter = preprocessing_pipeline( eval_ds, @@ -244,7 +236,8 @@ def filter_keys(record): max_length=config.max_target_length, shift=False, drop_remainder=False, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) predict_iter = preprocessing_pipeline( eval_ds, @@ -255,6 +248,7 @@ def filter_keys(record): max_length=config.max_target_length, shift=False, drop_remainder=False, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) return train_iter, eval_iter, predict_iter diff --git a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index c5a6d8d06..0ca44bdb1 100644 --- a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Input pipeline for gpt3 c4 mlperf dataset.""" @@ -35,6 +35,7 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE + # data processing functions: # _shift_left_and_pad, rekey, reduce_concat_tokens and split_tokens_to_targets_length # Adapted from: @@ -58,8 +59,10 @@ def _shift_left_and_pad(tensor, pad_val): v = v[0] return v + def rekey(ds, key_map=None): """normalization with key mapping""" + def _rekey(x, key_map=None): """Replace the feature keys according to the mapping in `key_map`. For example, if the dataset returns examples of the format: @@ -75,20 +78,17 @@ def _rekey(x, key_map=None): A preprocessed example with the format listed above. """ if key_map: - return { - new_key: x[old_key] - for new_key, old_key in key_map.items() if old_key - } + return {new_key: x[old_key] for new_key, old_key in key_map.items() if old_key} return x - return ds.map( - functools.partial(_rekey, key_map=key_map), - num_parallel_calls=AUTOTUNE) + return ds.map(functools.partial(_rekey, key_map=key_map), num_parallel_calls=AUTOTUNE) + -def reduce_concat_tokens(dataset, - feature_key='targets', - batch_size=128, - ): +def reduce_concat_tokens( + dataset, + feature_key="targets", + batch_size=128, +): """Token-preprocessor to concatenate multiple unrelated documents. If we want to generate examples of exactly the right length, (to avoid wasting space on padding), then we use this function, folowed by @@ -100,9 +100,9 @@ def reduce_concat_tokens(dataset, Returns: a dataset """ - dataset = dataset.map( - lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE) + dataset = dataset.map(lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE) dataset = dataset.padded_batch(batch_size, padded_shapes={feature_key: [-1]}) + def _my_fn(x): tokens = tf.reshape(x[feature_key], [-1]) # strip padding @@ -111,10 +111,12 @@ def _my_fn(x): return dataset.map(_my_fn, num_parallel_calls=AUTOTUNE) -def split_tokens(dataset, - max_tokens_per_segment=128, - feature_key='targets', - ): + +def split_tokens( + dataset, + max_tokens_per_segment=128, + feature_key="targets", +): """Split examples into multiple examples each. The intended use case is to break up long examples for use in unsupervised transfer-learning. @@ -127,6 +129,7 @@ def split_tokens(dataset, Returns: a dataset """ + def _split_tokens(x): """Split one token sequence into multiple multiple.""" tokens = x[feature_key] @@ -135,9 +138,7 @@ def _split_tokens(x): # Pad to a multiple of length, then use tf.reshape to split up the tokens # into num_segments segments each of the given length. 
- num_segments = tf.cast( - tf.math.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)), - tf.int32) + num_segments = tf.cast(tf.math.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)), tf.int32) padding = num_segments * length - tf.size(tokens) tokens = tf.pad(tokens, [[0, padding]]) return tf.reshape(tokens, [-1, length]) @@ -149,19 +150,25 @@ def _strip_padding(x): dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0)) dataset = dataset.map(_split_tokens, num_parallel_calls=AUTOTUNE) dataset = dataset.unbatch() - return dataset.map( - _strip_padding, num_parallel_calls=AUTOTUNE) + return dataset.map(_strip_padding, num_parallel_calls=AUTOTUNE) + def split_tokens_to_targets_length(dataset, sequence_length): return split_tokens(dataset, max_tokens_per_segment=sequence_length) -def _pad_to_batch_size(ds: tf.data.Dataset, batch_size: int, num_examples: Optional[int] = None,) -> tf.data.Dataset: + +def _pad_to_batch_size( + ds: tf.data.Dataset, + batch_size: int, + num_examples: Optional[int] = None, +) -> tf.data.Dataset: """Pad unevenly distributed eval data in each shard with new entries to multiples of batch size.""" # local_num represents the total number of examples in eval dataset, if num_examples: local_num = num_examples else: + def _get_num_examples(ds: tf.data.Dataset) -> int: # Iterate one-by-one instead of len(list(...)) to reduce peak memory. num_examples = 0 @@ -173,61 +180,66 @@ def _get_num_examples(ds: tf.data.Dataset) -> int: local_num = _get_num_examples(ds) local_num_batches = (local_num + batch_size - 1) // batch_size # Find the max number of batches required across all Jax processes. - num_batches_all = multihost_utils.process_allgather( - jnp.array([local_num_batches]), tiled=False) + num_batches_all = multihost_utils.process_allgather(jnp.array([local_num_batches]), tiled=False) num_batches = np.max(num_batches_all) pad_num = num_batches * batch_size - local_num assert pad_num >= 0 print( - f'Eval data has {local_num} local entries, padding now with ' - f'{pad_num} extra entries to get {num_batches} batches.') + f"Eval data has {local_num} local entries, padding now with " f"{pad_num} extra entries to get {num_batches} batches." + ) + # Repeat a random example to make the last batch full. def _add_pad(x): - x['targets_segmentation'] *= 0 + x["targets_segmentation"] *= 0 return x + pad_ds = ds.take(1).map(_add_pad).repeat(pad_num) return ds.concatenate(pad_ds) + def get_datasets( - config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, ): """Load and return dataset of batched examples for use during training.""" # Training dataset. 
read_config = tfds.ReadConfig( - shuffle_seed = config.data_shuffle_seed, - ) + shuffle_seed=config.data_shuffle_seed, + ) train_ds_builder = tfds.builder(config.dataset_name) - train_ds = train_ds_builder.as_dataset(split='train2', read_config=read_config, shuffle_files=config.enable_data_shuffling) + train_ds = train_ds_builder.as_dataset(split="train2", read_config=read_config, shuffle_files=config.enable_data_shuffling) eval_ds_builder = tfds.builder(config.eval_dataset_name) - eval_ds = eval_ds_builder.as_dataset(split='validation_tokenized_5662seqs', read_config=read_config, shuffle_files=False) + eval_ds = eval_ds_builder.as_dataset(split="validation_tokenized_5662seqs", read_config=read_config, shuffle_files=False) # shard the dataset as soon as it is loaded - train_ds = train_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) - train_ds = rekey(train_ds, {'inputs': None, 'targets': 'text'}) + train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) + train_ds = rekey(train_ds, {"inputs": None, "targets": "text"}) - eval_ds = eval_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) + eval_ds = eval_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and splitted to target_length # mainly to avoid eval sequences change depending on the number of hosts - eval_ds = rekey(eval_ds, {'inputs': None, 'targets': 'ids'}) + eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"}) return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - global_mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed: int = 0, - shuffle_buffer_size: int = 128, - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + global_mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed: int = 0, + shuffle_buffer_size: int = 128, +): """Pre-process the dataset and return iterators for mlperf training.""" # tokenize - train_ds = train_ds.map( - tokenizer.TokenizeOp(sp_tokenizer, data_keys=('targets',)), num_parallel_calls=AUTOTUNE) + train_ds = train_ds.map(tokenizer.TokenizeOp(sp_tokenizer, data_keys=("targets",)), num_parallel_calls=AUTOTUNE) - train_ds = reduce_concat_tokens(train_ds, feature_key='targets', batch_size=4096) + train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096) train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length) train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed) @@ -241,8 +253,8 @@ def format_fn(x, eos_id: int = 1, pad_id: int = 0): x["inputs_position"] = x["targets_position"] x["targets"] = _shift_left_and_pad(x["targets"], eos_id) x["inputs_segmentation"] = tf.where( - tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), - x["targets_segmentation"], 0) + tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0 + ) x["targets_segmentation"] = x["inputs_segmentation"] return x diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py index 58227feee..37ecd8f4a 100644 --- a/MaxText/input_pipeline/input_pipeline_interface.py +++ b/MaxText/input_pipeline/input_pipeline_interface.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the 
License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline""" @@ -26,79 +26,80 @@ from input_pipeline import _tfds_data_processing_c4_mlperf import tokenizer + def get_tokenizer(tokenizer_path, add_bos=True, add_eos=True): # Load tokenizer - sp_tokenizer = tokenizer.load_tokenizer(tokenizer_path=tokenizer_path, - add_bos=add_bos, - add_eos=add_eos) + sp_tokenizer = tokenizer.load_tokenizer(tokenizer_path=tokenizer_path, add_bos=add_bos, add_eos=add_eos) return sp_tokenizer + def make_c4_mlperf_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for customized C4 dataset for mlperf gpt3 training.""" + """Make train iterator and tokenizer for customized C4 dataset for mlperf gpt3 training.""" train_ds, eval_ds = _tfds_data_processing_c4_mlperf.get_datasets( - config=config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), + config=config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), ) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, eval_iter = _tfds_data_processing_c4_mlperf.preprocess_dataset( - config, - mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed=config.data_shuffle_seed + config, mesh, train_ds, eval_ds, sp_tokenizer, data_shuffle_seed=config.data_shuffle_seed ) return train_iter, eval_iter, sp_tokenizer + def make_c4_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for C4 dataset""" + """Make train iterator and tokenizer for C4 dataset""" read_config = tfds.ReadConfig( - shuffle_seed = config.data_shuffle_seed, + shuffle_seed=config.data_shuffle_seed, ) train_ds, eval_ds = _tfds_data_processing.get_datasets( - config=config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - read_config = read_config, + config=config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + read_config=read_config, ) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, _, _ = _tfds_data_processing.preprocess_dataset( - config, - mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed = config.data_shuffle_seed, + config, + mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=config.data_shuffle_seed, ) return train_iter, None, sp_tokenizer + def make_grain_train_iterator_and_tokenizer(config, mesh, add_bos, 
add_eos, process_indices): - """ Make train iterator and tokenizer for C4 dataset""" - train_ds, eval_ds = _grain_data_processing.get_datasets( - config=config - ) + """Make train iterator and tokenizer for C4 dataset""" + train_ds, eval_ds = _grain_data_processing.get_datasets(config=config) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, _, _ = _grain_data_processing.preprocess_dataset( - config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - global_mesh = mesh, - train_ds = train_ds, eval_ds = eval_ds, - vocab_path=config.tokenizer_path, - data_shuffle_seed = config.data_shuffle_seed, - add_bos = add_bos, - add_eos = add_eos + config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=mesh, + train_ds=train_ds, + eval_ds=eval_ds, + vocab_path=config.tokenizer_path, + data_shuffle_seed=config.data_shuffle_seed, + add_bos=add_bos, + add_eos=add_eos, ) return train_iter, None, sp_tokenizer -class SyntheticDataIterator(): + +class SyntheticDataIterator: """Creates a synthetic data iterator for performance testing work""" + def __init__(self, config, mesh): self.mesh = mesh self.config = config data_pspec = P(*config.data_sharding) - data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - self.data_generator = jax.jit(SyntheticDataIterator.raw_generate_synthetic_data, - out_shardings=data_pspec_shardings, - static_argnums=0) + data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + self.data_generator = jax.jit( + SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0 + ) def __iter__(self): return self @@ -111,31 +112,34 @@ def __next__(self): def raw_generate_synthetic_data(config): """Generates a single batch of syntehtic data""" output = {} - output['inputs'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['inputs_position'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['inputs_segmentation'] = jax.numpy.ones( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['targets'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['targets_position'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['targets_segmentation'] = jax.numpy.ones( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) + output["inputs"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32) + output["inputs_position"] = jax.numpy.zeros( + (config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32 + ) + output["inputs_segmentation"] = jax.numpy.ones( + (config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32 + ) + output["targets"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32) + output["targets_position"] = jax.numpy.zeros( + (config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32 + ) + output["targets_segmentation"] = jax.numpy.ones( + (config.global_batch_size_to_load, 
config.max_target_length), dtype=jax.numpy.int32 + ) return output -class BadSyntheticDataIterator(): + +class BadSyntheticDataIterator: """Creates a Bad synthetic data iterator for loading on subset of hosts""" + def __init__(self, config, mesh): self.mesh = mesh self.config = config data_pspec = P(*config.data_sharding) - data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - self.data_generator = jax.jit(BadSyntheticDataIterator.get_bad_synthetic_data, - out_shardings=data_pspec_shardings, - static_argnums=0) + data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + self.data_generator = jax.jit( + BadSyntheticDataIterator.get_bad_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0 + ) def __iter__(self): return self @@ -146,25 +150,31 @@ def __next__(self): @staticmethod def get_bad_synthetic_data(config): - """fill negative value in synthetic data """ + """fill negative value in synthetic data""" output = {} - output['inputs'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['inputs_position'] = jax.numpy.full((config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['inputs_segmentation'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets_position'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets_segmentation'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) + output["inputs"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["inputs_position"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["inputs_segmentation"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["targets"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["targets_position"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["targets_segmentation"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) return output + def get_process_loading_real_data(config, mesh): - """ Get list of processes loading data from GCS when expansion_factor_real_data != -1 - """ + """Get list of processes loading data from GCS when expansion_factor_real_data != -1""" sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) devices_indices_map = sharding.devices_indices_map((config.global_batch_size_to_load, config.max_target_length)) batch_cutoff = config.global_batch_size_to_train_on @@ -174,10 +184,11 @@ def get_process_loading_real_data(config, mesh): process_loading_real_data.add(p.process_index) return list(process_loading_real_data) + def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos): process_indices = get_process_loading_real_data(config, mesh) - print(len(process_indices),"hosts out of",jax.process_count(),"are loading real data") - if 
config.expansion_factor_real_data != -1: # assert number of hosts loading real data + print(len(process_indices), "hosts out of", jax.process_count(), "are loading real data") + if config.expansion_factor_real_data != -1: # assert number of hosts loading real data assert len(process_indices) == jax.process_count() // config.expansion_factor_real_data if jax.process_index() in process_indices: if config.dataset_type == "c4": @@ -186,11 +197,14 @@ def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos): return make_grain_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices) elif config.dataset_type == "c4_mlperf": print("Overwrite both add_bos and add_eos to False") - return make_c4_mlperf_train_iterator_and_tokenizer(config, mesh, add_bos=False, add_eos=False, process_indices = process_indices) + return make_c4_mlperf_train_iterator_and_tokenizer( + config, mesh, add_bos=False, add_eos=False, process_indices=process_indices + ) else: return BadSyntheticDataIterator(config, mesh), None, get_tokenizer(config.tokenizer_path, add_bos, add_eos) -def create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos = True): + +def create_data_iterator_with_tokenizer(config, mesh, add_bos=True, add_eos=True): if config.dataset_type == "synthetic": return SyntheticDataIterator(config, mesh), None, get_tokenizer(config.tokenizer_path, add_bos, add_eos) elif config.dataset_type in ("c4", "c4-array_record", "c4_mlperf"): @@ -198,15 +212,16 @@ def create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos = else: assert False, "dataset type not implemented" + def get_shaped_batch(config): - """ Return the shape of the batch - this is what eval_shape would return for the + """Return the shape of the batch - this is what eval_shape would return for the output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078.""" batch_shape = (config.global_batch_size_to_load, config.max_target_length) shaped_batch = {} - shaped_batch['inputs'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['inputs_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['inputs_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) return shaped_batch diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 8ea5e7bd0..b1c0803d7 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -58,8 +58,7 @@ nd_dense_init = initializers.nd_dense_init shard_map = shard_map.shard_map -dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) +dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, 
g-inconsistent-quotes # pytype: disable=attribute-error @@ -89,10 +88,12 @@ def apply_mask_to_logits(logits: Array, mask: Array): """ return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) + def _maybe_aqt_einsum(quant: Quant): """Maybe overwrite dot general with aqt_dot_general.""" return jnp.einsum if quant is None else quant.einsum() + class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str @@ -100,44 +101,33 @@ class AttentionOp(nn.Module): num_query_heads: int num_kv_heads: int float32_qk_product: bool = False - max_prefill_predict_length: int = -1 + max_prefill_predict_length: int = -1 float32_logits: bool = False flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) - dropout_rate: float = 0. + dropout_rate: float = 0.0 dtype: DType = jnp.float32 quant: Optional[Quant] = None quantize_kvcache: bool = False - def check_attention_inputs( - self, - query: Array, - key: Array, - value: Array) -> None: + def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: """Check attention inputs.""" - assert key.ndim == value.ndim, 'k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert key.shape[-2] == value.shape[-2], ('k, v num_kv_heads must match.') - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + assert key.ndim == value.ndim, "k, v must have same rank." + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." + assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." + assert key.shape[-3] == value.shape[-3], "k, v lengths must match." + assert query.shape[-1] == key.shape[-1], "q, k depths must match." # Following Pallas MHA Flash Attention Reference. 
# https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py # This mask models (1) separate sequences (decoder_segment_ids) and (2) causality - def generate_attention_mask( - self, - query, - key, - decoder_segment_ids: Array | None, - model_mode: str - ) -> Array | None: + def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, model_mode: str) -> Array | None: mask = None if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: mask = decoder_segment_ids[:, None, None, None, :] == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR elif decoder_segment_ids is not None: mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :] - mask = mask[:, None, None,:, :] + mask = mask[:, None, None, :, :] causal_mask = None # We enforce causality except for AUTOREGRESSION @@ -160,53 +150,43 @@ def generate_attention_mask( return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None - def apply_attention(self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None, - model_mode: str): + def apply_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, model_mode: str): self.check_attention_inputs(query, key, value) length = query.shape[-3] - if self.attention_kernel == 'dot_product' or\ - (self.attention_kernel == 'autoselected' and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) or\ - (self.attention_kernel == 'autoselected' and length < 128): + if ( + self.attention_kernel == "dot_product" + or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) + or (self.attention_kernel == "autoselected" and length < 128) + ): return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode) - elif self.attention_kernel == 'flash' or\ - self.attention_kernel == 'autoselected': + elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected": if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") + raise ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None - elif self.attention_kernel == 'cudnn_flash_te': + elif self.attention_kernel == "cudnn_flash_te": if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") + raise ValueError( + """Decode not supported with flash attention. 
+ Use `dot_product` instead.""" + ) return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None else: - raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.') - - def tpu_flash_attention( - self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None) -> Array: + raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") + + def tpu_flash_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None) -> Array: """TPU Flash Attention.""" # Transpose to ('batch', 'heads', 'length', 'kv') query = jnp.transpose(query, axes=(0, 2, 1, 3)) key = jnp.transpose(key, axes=(0, 2, 1, 3)) value = jnp.transpose(value, axes=(0, 2, 1, 3)) - if decoder_segment_ids is not None: - decoder_segment_ids = splash_attention_kernel.SegmentIds( - decoder_segment_ids, decoder_segment_ids - ) + decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids) axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) - segment_axis_names = nn.logical_to_mesh_axes( - (BATCH, 'activation_length_no_heads') - ) + segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads")) @functools.partial( shard_map, @@ -223,76 +203,73 @@ def tpu_flash_attention( def wrap_flash_attention(query, key, value, decoder_segment_ids): if decoder_segment_ids is not None: assert ( - query.shape[2] - == decoder_segment_ids.q.shape[1] - ), 'Sharding along sequence dimension not allowed in tpu kernel attention' + query.shape[2] == decoder_segment_ids.q.shape[1] + ), "Sharding along sequence dimension not allowed in tpu kernel attention" block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(512, query.shape[2]), - block_kv_compute=min(512, key.shape[2]), - block_kv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_kv_dkv=min(512, key.shape[2]), - block_kv_dkv_compute=min(512, query.shape[2]), - block_q_dq=min(512, query.shape[2]), - block_kv_dq=min(512, query.shape[2]), + block_q=min(512, query.shape[2]), + block_kv_compute=min(512, key.shape[2]), + block_kv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_kv_dkv=min(512, key.shape[2]), + block_kv_dkv_compute=min(512, query.shape[2]), + block_q_dq=min(512, query.shape[2]), + block_kv_dq=min(512, query.shape[2]), ) - masks = [splash_attention_mask.CausalMask( shape=(query.shape[2],query.shape[2])) for i in range(query.shape[1])] + masks = [splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2])) for i in range(query.shape[1])] multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) - splash_kernel = splash_attention_kernel.make_splash_mha(mask = multi_head_mask, - head_shards = 1, - q_seq_shards = 1, - block_sizes = block_sizes) - - return jax.vmap(splash_kernel)(query,key,value, segment_ids = decoder_segment_ids) - - devices_in_data_fsdp = self.mesh.shape['data'] * self.mesh.shape['fsdp'] + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes + ) + + return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids) + + devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( - 'Batch dimension should be shardable among the devices in data and fsdp' - ' axis' + "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) x = 
wrap_flash_attention(query, key, value, decoder_segment_ids) x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x - + def cudnn_flash_attention( - self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None, - model_mode: str = common_types.MODEL_MODE_TRAIN, - ) -> Array: + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + model_mode: str = common_types.MODEL_MODE_TRAIN, + ) -> Array: """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon """ # These imports are only meant to work in a GPU build. - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error _, _, _, head_dim = query.shape # pylint: disable=unused-variable - #generate attn_mask + # generate attn_mask attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - dpa_layer = DotProductAttention(head_dim=head_dim, - num_attention_heads=self.num_query_heads, - num_gqa_groups=self.num_kv_heads, - attn_mask_type='causal', # 'causal' or 'padding' - attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - attention_dropout=self.dropout_rate, - dropout_rng_name='aqt', - dtype=self.dtype, - float32_logits=self.float32_logits, - qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=1.0/math.sqrt(head_dim), - transpose_batch_sequence=False) + dpa_layer = DotProductAttention( + head_dim=head_dim, + num_attention_heads=self.num_query_heads, + num_gqa_groups=self.num_kv_heads, + attn_mask_type="causal", # 'causal' or 'padding' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=1.0 / math.sqrt(head_dim), + transpose_batch_sequence=False, + ) return dpa_layer(query, key, value, mask=attn_mask) - def compute_local_attention(self, - attn_weights: Array, - value: Array) -> tuple[Array, Array, Array]: - """Computes the attention of a local subset of the kv cache. + def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Array, Array, Array]: + """Computes the attention of a local subset of the kv cache. 
Local attention results will need to be combined with any other local attentions and normalized Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py @@ -314,25 +291,17 @@ def compute_local_attention(self, local_sum = jnp.moveaxis(local_sum, -2, 1) local_max = jnp.moveaxis(local_max, -2, 1) - local_max = jnp.reshape(local_max, - (local_max.shape[0], - local_max.shape[1], - local_max.shape[2] * local_max.shape[3], - 1)) - local_sum = jnp.reshape(local_sum, - (local_sum.shape[0], - local_sum.shape[1], - local_sum.shape[2] * local_sum.shape[3], - 1)) + local_max = jnp.reshape(local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1)) + local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)) local_out = self.wv_product(local_exps, value) return local_out, local_max, local_sum def apply_attention_dot( self, - query: Array, - key: Array, - value: Array, + query: Array, + key: Array, + value: Array, decoder_segment_ids: Array | None, model_mode: str = common_types.MODEL_MODE_TRAIN, ): @@ -364,28 +333,24 @@ def qk_product(self, query: Array, key: Array) -> Array: Returns: results in shape [b, n_kv, n // n_kv, t, s]. """ - b, t, n, d = query.shape + b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) - result = jnp.einsum('btkgd,bskd->bkgts', query, key) - return result # (4, 8, 1, 1, 6) - + result = jnp.einsum("btkgd,bskd->bkgts", query, key) + return result # (4, 8, 1, 1, 6) - def wv_product( - self, - attn_weights: Array, - value: Array) -> Array: + def wv_product(self, attn_weights: Array, value: Array) -> Array: """weighted value product. Args: - attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. + attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. value: Value projection, in shape of [batch_size, v_len, num_kv_heads, kv_dim]. Returns: result in shape [batch_size, q_len, num_kv_heads * group_size, kv_dim] """ - out = jnp.einsum('bkgts,bskd->btkgd', attn_weights, value) + out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) return result @@ -399,8 +364,7 @@ def revert_kvlen_axis(self, kv): Returns: reshaped kv as [b, ..., s, n, d] """ - return jax.numpy.moveaxis(kv, (0,1,2,3), (1,2,0,3)) - + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (1, 2, 0, 3)) def move_kvlen_axis(self, kv): """Move key/value length axis to the end. @@ -411,7 +375,7 @@ def move_kvlen_axis(self, kv): Returns: reshaped kv as [b, ..., n, d, s] """ - return jax.numpy.moveaxis(kv, (0,1,2,3), (2,0,1,3)) + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3)) def cached_kv_shape(self, kv_shape): """Cached KV shape. 
@@ -430,26 +394,51 @@ def cached_kv_shape(self, kv_shape): def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 - kv_cache_layout = ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv', ) + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size) - cached_key = self.variable('cache', 'cached_prefill_key', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_value = self.variable('cache', 'cached_prefill_value', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_segment_id = self.variable('cache', 'cache_prefill_segment_id', - nn.with_logical_partitioning(jnp.zeros, ('cache_batch', 'cache_sequence')), - (cache_logical_shape[0], self.max_prefill_predict_length), jnp.int32) + cached_key = self.variable( + "cache", + "cached_prefill_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_prefill_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_prefill_segment_id", + nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")), + (cache_logical_shape[0], self.max_prefill_predict_length), + jnp.int32, + ) if self.quantize_kvcache: cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1) - cached_key_scale_var = self.variable('cache', 'cached_prefill_key_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) - cached_value_scale_var = self.variable('cache', 'cached_prefill_value_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) + cached_key_scale_var = self.variable( + "cache", + "cached_prefill_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_prefill_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) else: cached_key_scale_var = None cached_value_scale_var = None @@ -461,85 +450,112 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache): dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 cache_length = self.max_target_length - self.max_prefill_predict_length - kv_cache_layout = ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv', ) + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) cache_logical_shape = (batch, cache_length, heads, kv_head_size) - cached_key = self.variable('cache', 'cached_ar_key', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_value = self.variable('cache', 'cached_ar_value', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_segment_id = self.variable('cache', 
'cache_ar_segment_id', - nn.with_logical_partitioning(jnp.zeros, ('cache_batch', 'cache_sequence')), - (cache_logical_shape[0], cache_length), jnp.int32) + cached_key = self.variable( + "cache", + "cached_ar_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_ar_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_ar_segment_id", + nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")), + (cache_logical_shape[0], cache_length), + jnp.int32, + ) if self.quantize_kvcache: cache_logical_shape_scale = (batch, cache_length, heads, 1) - cached_key_scale_var = self.variable('cache', 'cached_ar_key_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) - cached_value_scale_var = self.variable('cache', 'cached_ar_value_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) + cached_key_scale_var = self.variable( + "cache", + "cached_ar_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_ar_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) else: cached_key_scale_var = None cached_value_scale_var = None - cache_index = self.variable('cache', 'cache_ar_index', - nn.with_logical_partitioning(jnp.zeros, ()), - (1,), jnp.int32) + cache_index = self.variable("cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32) key_vars = (cached_key, cached_key_scale_var) value_vars = (cached_value, cached_value_scale_var) return key_vars, value_vars, cached_segment_id, cache_index - def kv_cache_prefill(self, - key: Array, - value: Array, - decoder_segment_ids: Array, - ): - """In prefill mode, we zero out the existing cache, run the computation and - prepare the cache as necessary. - - Args: - key: in shape [b, s, n, d]. - value: in shape [b, s, n, d]. - decoder_segment_ids: [b, s] -- marking segment ids for tokens - - Returns: - key, value, decoder_segment_id. - - """ - batch, sequence, heads, kv_head_size = key.shape - assert key.dtype == value.dtype, "Key and Value Dtypes should match." 
- - cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(batch, heads, kv_head_size, self.quantize_kvcache) - self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now - - key_shaped_for_cache = self.move_kvlen_axis(key) - value_shaped_for_cache = self.move_kvlen_axis(value) - - if self.quantize_kvcache: - key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache) - value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache) - cached_prefill_key_var[1].value = key_scale - cached_prefill_value_var[1].value = value_scale - - cached_prefill_key_var[0].value = key_shaped_for_cache - cached_prefill_value_var[0].value = value_shaped_for_cache + def kv_cache_prefill( + self, + key: Array, + value: Array, + decoder_segment_ids: Array, + ): + """In prefill mode, we zero out the existing cache, run the computation and + prepare the cache as necessary. - if decoder_segment_ids is not None: - cached_prefill_segment_id.value = decoder_segment_ids + Args: + key: in shape [b, s, n, d]. + value: in shape [b, s, n, d]. + decoder_segment_ids: [b, s] -- marking segment ids for tokens + + Returns: + key, value, decoder_segment_id. - return key, value, decoder_segment_ids - + """ + batch, sequence, heads, kv_head_size = key.shape + assert key.dtype == value.dtype, "Key and Value Dtypes should match." + + cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now + + key_shaped_for_cache = self.move_kvlen_axis(key) + value_shaped_for_cache = self.move_kvlen_axis(value) + + if self.quantize_kvcache: + key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache) + value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache) + cached_prefill_key_var[1].value = key_scale + cached_prefill_value_var[1].value = value_scale + + cached_prefill_key_var[0].value = key_shaped_for_cache + cached_prefill_value_var[0].value = value_shaped_for_cache - def update_ar_key_value(self, - one_token_key: Array, - one_token_value: Array, - cached_key_vars: tuple[nn.Variable, nn.Variable|None], - cached_value_vars: tuple[nn.Variable, nn.Variable|None], - one_hot_indices: Array) -> tuple[Array, Array]: - """Adds a single token's results to the ar kv cache + if decoder_segment_ids is not None: + cached_prefill_segment_id.value = decoder_segment_ids + + return key, value, decoder_segment_ids + + def update_ar_key_value( + self, + one_token_key: Array, + one_token_value: Array, + cached_key_vars: tuple[nn.Variable, nn.Variable | None], + cached_value_vars: tuple[nn.Variable, nn.Variable | None], + one_hot_indices: Array, + ) -> tuple[Array, Array]: + """Adds a single token's results to the ar kv cache Args: one_token_key (Array): Key of one token to add to the cache @@ -566,20 +582,39 @@ def update_ar_key_value(self, one_hot_indices = one_hot_indices.astype(int) - ar_key = cached_key_var.value ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0) - ar_key = nn.with_logical_constraint(ar_key, ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv',)) + ar_key = nn.with_logical_constraint( + ar_key, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) cached_key_var.value = ar_key ar_value = 
cached_value_var.value ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0) - ar_value = nn.with_logical_constraint(ar_value, ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv',)) + ar_value = nn.with_logical_constraint( + ar_value, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) cached_value_var.value = ar_value if self.quantize_kvcache: - ar_key_scale = jax.lax.dynamic_update_index_in_dim(cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0) - ar_value_scale = jax.lax.dynamic_update_index_in_dim(cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0) + ar_key_scale = jax.lax.dynamic_update_index_in_dim( + cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0 + ) + ar_value_scale = jax.lax.dynamic_update_index_in_dim( + cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0 + ) cached_key_scale_var.value = ar_key_scale cached_value_scale_var.value = ar_value_scale @@ -594,58 +629,61 @@ def prefill_cache_var_model_var(self, cache_var, target_dtype): return self.revert_kvlen_axis(cache_var[0].value) else: raw_cache, quant_scale = cache_var - raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype) + raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype) return self.revert_kvlen_axis(raw_cache_unquantized) - - - def kv_cache_autoregressive(self, - key: Array, - value: Array, - ): - """In autoregressive mode, we update the cache for this entry and - then return the full cache. - - Args: - key: in shape [b, 1, n, d]. - value: in shape [b, 1, n, d]. - decoder_segment_ids: [b, 1] -- marking segment ids for tokens - - Returns: - tuple of (key, value, segment_id) for both prefill and ar cache, - Raises: - ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. 
- """ - batch, sequence, heads, kv_head_size = key.shape - if sequence != 1: - raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") - is_initialized = self.has_variable('cache', 'cache_ar_index') - if not is_initialized: - raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") - - cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) - - key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) - value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) - - ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value) - active_indicator = jnp.zeros((batch, 1), dtype = jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR - cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim(cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1) - cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length) - - # Prep and return both prefill and ar caches - cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(self.max_target_length, heads, kv_head_size, self.quantize_kvcache) - - cached_prefill = self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), cached_prefill_segment_id.value - return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) - - def kv_cache( + def kv_cache_autoregressive( self, key: Array, value: Array, - decoder_segment_ids: Array, - model_mode: str - ) -> tuple: + ): + """In autoregressive mode, we update the cache for this entry and + then return the full cache. + + Args: + key: in shape [b, 1, n, d]. + value: in shape [b, 1, n, d]. + decoder_segment_ids: [b, 1] -- marking segment ids for tokens + + Returns: + tuple of (key, value, segment_id) for both prefill and ar cache, + Raises: + ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. 
+ """ + batch, sequence, heads, kv_head_size = key.shape + if sequence != 1: + raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") + is_initialized = self.has_variable("cache", "cache_ar_index") + if not is_initialized: + raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") + + cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + + key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) + value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) + + ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value) + active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim( + cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1 + ) + cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length) + + # Prep and return both prefill and ar caches + cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache( + self.max_target_length, heads, kv_head_size, self.quantize_kvcache + ) + + cached_prefill = ( + self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), + self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), + cached_prefill_segment_id.value, + ) + return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) + + def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str) -> tuple: """KV cache takes the current state and updates the state accordingly. The key and value have dimension [batch, length, num_heads, head_dim], @@ -665,7 +703,6 @@ def kv_cache( """ if key.shape != value.shape: raise ValueError(f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}") - if model_mode == common_types.MODEL_MODE_TRAIN: return (key, value, decoder_segment_ids), None @@ -675,12 +712,8 @@ def kv_cache( return self.kv_cache_autoregressive(key, value) else: raise ValueError(f"Model Mode isn't supported! 
{model_mode=}") - - - def normalize_attention(self, - local_outs, - local_maxes, - local_sums): + + def normalize_attention(self, local_outs, local_maxes, local_sums): """Normalize across multiple localized attentions Args: @@ -689,14 +722,13 @@ def normalize_attention(self, local_sums (list): List of exponential sum entries for each local attention Returns: - Array: Combined attention that has been normalized + Array: Combined attention that has been normalized """ # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py global_max = functools.reduce(jnp.maximum, local_maxes) - global_sum = sum([ - jnp.exp(local_max - global_max) * local_sum - for (local_sum, local_max) in zip(local_sums, local_maxes) - ]) + global_sum = sum( + [jnp.exp(local_max - global_max) * local_sum for (local_sum, local_max) in zip(local_sums, local_maxes)] + ) attn_out = 0 for local_max, local_out in zip(local_maxes, local_outs): @@ -704,17 +736,16 @@ def normalize_attention(self, attn_out += local_normalizer * local_out return attn_out - @nn.compact def __call__(self, query, key, value, decoder_segment_ids, model_mode): prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode) prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention( - query=query, - key=prefill_kv_cache[0], - value=prefill_kv_cache[1], - decoder_segment_ids=prefill_kv_cache[2], - model_mode=model_mode, + query=query, + key=prefill_kv_cache[0], + value=prefill_kv_cache[1], + decoder_segment_ids=prefill_kv_cache[2], + model_mode=model_mode, ) # Return the "prefill" cache if it actually the combined prefill+ar kv cache @@ -723,12 +754,12 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode): return prefill_unnormalized_output / prefill_exponentials_sum return prefill_unnormalized_output - ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention( - query=query, - key=ar_kv_cache[0], - value=ar_kv_cache[1], - decoder_segment_ids=ar_kv_cache[2], - model_mode=model_mode, + ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention( + query=query, + key=ar_kv_cache[0], + value=ar_kv_cache[1], + decoder_segment_ids=ar_kv_cache[2], + model_mode=model_mode, ) unnormalized_outputs = [prefill_unnormalized_output, ar_unnormalized_output] @@ -738,27 +769,27 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode): class Attention(nn.Module): - """ Generic Attention. - - Attributes: - num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - num_kv_heads: number of kv attention heads. - head_dim: dimension of each head. - mesh: Mesh, device mesh - attention_kernel: str, guidance on if we should use an attention kernel - dtype: the dtype of the computation. - weight_dtype: the dtype of the weights. - max_target_length: maximum target length - max_prefill_predict_length: size of the maximum prefill - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid - numerical issues with bfloat16. - float32_logits: bool, if True then cast logits to float32 before softmax to avoid - numerical issues with bfloat16. - quant: Quant, stores quantization parameters, defaults to None implying no quantization. 
- quantize_kvcache: bool, quantize the kv cache. + """Generic Attention. + + Attributes: + num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + num_kv_heads: number of kv attention heads. + head_dim: dimension of each head. + mesh: Mesh, device mesh + attention_kernel: str, guidance on if we should use an attention kernel + dtype: the dtype of the computation. + weight_dtype: the dtype of the weights. + max_target_length: maximum target length + max_prefill_predict_length: size of the maximum prefill + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + quant: Quant, stores quantization parameters, defaults to None implying no quantization. + quantize_kvcache: bool, quantize the kv cache. """ config: Config @@ -771,14 +802,13 @@ class Attention(nn.Module): dtype: DType = jnp.float32 weight_dtype: DType = jnp.float32 max_prefill_predict_length: int = -1 - dropout_rate: float = 0. - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") float32_qk_product: bool = False # computes logits in float32 for stability. float32_logits: bool = False # cast logits in float32 for stability. quant: Optional[Quant] = None quantize_kvcache: bool = False - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) @@ -791,19 +821,21 @@ def query_projection(self, inputs_q: Array) -> Array: # 1/sqrt(depth_kq)! This is folded into the initializers of the # linear transformations, which is equivalent under Adafactor. depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + def query_init(*args): - #pylint: disable=no-value-for-parameter + # pylint: disable=no-value-for-parameter return self.kernel_init(*args) / depth_scaling query_proj = DenseGeneral( - features=(self.num_query_heads, self.head_dim), - axis=-1, - kernel_init=query_init, - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='query', - quant=self.quant)(inputs_q) + features=(self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=query_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="query", + quant=self.quant, + )(inputs_q) return query_proj def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: @@ -818,66 +850,69 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: Projection of key or value, in shape of `[batch, kv_length, head_dim]`. 
""" if self.num_kv_heads == -1: - raise ValueError('num_kv_heads is not defined.') + raise ValueError("num_kv_heads is not defined.") if self.num_query_heads % self.num_kv_heads != 0: - raise ValueError('Invaid num_kv_heads for GQA.') + raise ValueError("Invaid num_kv_heads for GQA.") kv_proj = DenseGeneral( features=(self.num_kv_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, - kernel_axes=('embed', 'heads', 'kv'), + kernel_axes=("embed", "heads", "kv"), dtype=self.dtype, weight_dtype=self.weight_dtype, name=proj_name, - quant=self.quant)(inputs_kv) + quant=self.quant, + )(inputs_kv) return kv_proj def qkv_projection(self, inputs: Array, proj_name: str): - """ Fused QKV projection""" + """Fused QKV projection""" qkv_proj = DenseGeneral( - features=(3, self.num_query_heads, self.head_dim), - axis = -1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'qkv', 'heads', 'kv'), + features=(3, self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), dtype=self.dtype, weight_dtype=self.weight_dtype, name=proj_name, - quant=self.quant)(inputs) - qkv_proj = checkpoint_name(qkv_proj, 'qkv_proj') - query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...] + quant=self.quant, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] return query, key, value def out_projection(self, output_dim: int, out: Array) -> Array: out_proj = DenseGeneral( - features=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='out', - quant=self.quant)(out) + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + )(out) return out_proj def key_rotary(self, key: Array, inputs_positions: Array): """Apply Rotary Embedding to key.""" - key = RotaryEmbedding( - embedding_dims=self.head_dim, - name='key_rotary')(inputs=key, position=inputs_positions) + key = RotaryEmbedding(embedding_dims=self.head_dim, name="key_rotary")(inputs=key, position=inputs_positions) return key @nn.compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - inputs_positions: Array, - decoder_segment_ids: Array | None = None, - *, - model_mode: str = common_types.MODEL_MODE_TRAIN, - deterministic: bool = False): + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + inputs_positions: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): """Applies Attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -902,38 +937,38 @@ def __call__(self, """ # apply projection. 
if self.config.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: query = self.query_projection(inputs_q) - key = self.kv_projection(inputs_kv, proj_name='key') - value = self.kv_projection(inputs_kv, proj_name='value') + key = self.kv_projection(inputs_kv, proj_name="key") + value = self.kv_projection(inputs_kv, proj_name="value") # apply ROPE - query = RotaryEmbedding( - embedding_dims=self.head_dim, name='query_rotary' - )(inputs=query, position=inputs_positions) + query = RotaryEmbedding(embedding_dims=self.head_dim, name="query_rotary")(inputs=query, position=inputs_positions) key = self.key_rotary(key, inputs_positions) # annotate with sharding constraint. query = nn.with_logical_constraint(query, self.query_axis_names) - query = checkpoint_name(query, 'query_proj') + query = checkpoint_name(query, "query_proj") key = nn.with_logical_constraint(key, self.key_axis_names) - key = checkpoint_name(key, 'key_proj') + key = checkpoint_name(key, "key_proj") value = nn.with_logical_constraint(value, self.value_axis_names) - value = checkpoint_name(value, 'value_proj') - - attention_op = AttentionOp(mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - quantize_kvcache=self.quantize_kvcache, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - dropout_rate = self.dropout_rate, - dtype=self.dtype) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.quantize_kvcache, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + ) out = attention_op(query, key, value, decoder_segment_ids, model_mode) @@ -941,5 +976,5 @@ def __call__(self, # apply output projection, output dim is set to the input dim. out = self.out_projection(inputs_q.shape[-1], out) - out = checkpoint_name(out, 'out_proj') + out = checkpoint_name(out, "out_proj") return out diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 6c954941c..9337986a0 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -32,6 +32,7 @@ _MAX_WAVELENGTH = 10_000 + class Embed(nn.Module): """A parameterized function from integers [0, n) to d-dimensional vectors. 
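The checkpoint_name tags applied to the query/key/value/out projections above are what the remat policies defined later in models.py (save_qkv_proj, save_dot_except_mlp, qkv_proj_offloaded, ...) key on: only activations whose tag names the policy lists are saved for the backward pass, everything else is recomputed. A minimal sketch with a made-up two-matmul layer:

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def tiny_layer(x, w_q, w_o):
  q = checkpoint_name(x @ w_q, "query_proj")              # saved under the policy below
  out = checkpoint_name(jnp.tanh(q) @ w_o, "out_proj")    # not listed, so recomputed
  return out

policy = jax.checkpoint_policies.save_only_these_names("query_proj")
tiny_layer_remat = jax.checkpoint(tiny_layer, policy=policy)

x, w_q, w_o = jnp.ones((4, 8)), jnp.ones((8, 8)), jnp.ones((8, 8))
loss = lambda x, w_q, w_o: tiny_layer_remat(x, w_q, w_o).sum()
grads = jax.grad(loss, argnums=(1, 2))(x, w_q, w_o)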
@@ -53,8 +54,8 @@ class Embed(nn.Module): def setup(self): self.embedding = self.param( - 'embedding', - with_logical_partitioning(self.embedding_init, ('vocab', 'embed')), + "embedding", + with_logical_partitioning(self.embedding_init, ("vocab", "embed")), (self.num_embeddings, self.features), self.config.weight_dtype, ) @@ -73,7 +74,7 @@ def __call__(self, inputs: Array) -> Array: if self.cast_input_dtype: inputs = inputs.astype(self.cast_input_dtype) if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') + raise ValueError("Input type must be an integer or unsigned integer.") if cfg.use_iota_embed: iota = lax.iota(jnp.int32, self.num_embeddings) @@ -81,9 +82,7 @@ def __call__(self, inputs: Array) -> Array: output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) else: output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = nn.with_logical_constraint( - output, ('activation_batch', 'activation_length', 'activation_embed') - ) + output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed")) return output def attend(self, query: Array) -> Array: @@ -122,9 +121,7 @@ class RotaryEmbedding(nn.Module): def setup(self) -> None: if self.embedding_dims % 2: - raise ValueError( - 'Embedding dim for rotary position embedding must be a multiple of 2.' - ) + raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") def __call__( self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks @@ -147,21 +144,14 @@ def __call__( """ assert position is not None if len(inputs.shape) != 4: - raise ValueError( - 'Input is assumed to be a rank 4 tensor of shape' - '[batch, sequence, heads, dims].' - ) + raise ValueError("Input is assumed to be a rank 4 tensor of shape" "[batch, sequence, heads, dims].") if self.embedding_dims != inputs.shape[3]: raise ValueError( - 'The embedding dims of the rotary position embedding' - 'must match the hidden dimension of the inputs.' + "The embedding dims of the rotary position embedding" "must match the hidden dimension of the inputs." 
) half_embedding_dim = self.embedding_dims // 2 fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims - timescale = ( - self.min_timescale - * (self.max_timescale / self.min_timescale) ** fraction - ) + timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction position = position[:, :, jnp.newaxis, jnp.newaxis] sinusoid_inp = position / timescale sin = jnp.sin(sinusoid_inp).astype(inputs.dtype) @@ -189,13 +179,11 @@ def __call__( log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum( jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1 ) - inv_timescales = jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment - ) + inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) position = position[:, :, jnp.newaxis] inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :] scaled_time = position * inv_timescales - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = -1) + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1) # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]]) position_embedding = signal.astype(jnp.float32) return input_embedding + position_embedding diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py index cbbadf7bc..fb909f985 100644 --- a/MaxText/layers/gemma.py +++ b/MaxText/layers/gemma.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
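For reference, the timescale computed in RotaryEmbedding above expands to the standard geometric frequency schedule. A standalone sketch with illustrative values:

import jax.numpy as jnp

embedding_dims, min_timescale, max_timescale = 8, 1.0, 10_000.0
half = embedding_dims // 2
fraction = 2 * jnp.arange(0, half) / embedding_dims
timescale = min_timescale * (max_timescale / min_timescale) ** fraction   # [half]
position = jnp.arange(16.0)[:, None]                                      # [seq, 1]
sinusoid_inp = position / timescale                                       # [seq, half]
sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
# The rotary embedding then rotates each (first-half, second-half) feature pair by these angles.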
+""" from flax import linen as nn import common_types @@ -52,69 +52,65 @@ # Decoder and Model definitions class GemmaDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - lnx = RMSNorm( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='pre_self_attention_norm', - kernel_axes=('embed',))(inputs) + lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("embed",))( + inputs + ) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) attention_layer = Attention( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - float32_qk_product = True, - float32_logits = True, - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + float32_qk_product=True, + float32_logits=True, + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) attention_lnx += inputs residual = attention_lnx - attn_output = RMSNorm( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='pre_ffw_norm', - kernel_axes=('embed',))(attention_lnx) + attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("embed",))( + attention_lnx + ) # MLP block. 
mlp_lnx = MlpBlock( @@ -123,32 +119,30 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(attn_output, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) next_layer_addition = mlp_lnx + residual - next_layer_addition_dropped_out = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,) - )(next_layer_addition, deterministic=deterministic) + next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition, deterministic=deterministic + ) layer_output = next_layer_addition_dropped_out layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py index 518ec2912..853ec43bb 100644 --- a/MaxText/layers/gpt3.py +++ b/MaxText/layers/gpt3.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
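The record_internal_nn_metrics branch above uses Flax's sow to stash activation statistics in the "intermediates" collection; a caller retrieves them by marking that collection mutable at apply time. A self-contained sketch with a toy module (not the MaxText layer):

import jax
import jax.numpy as jnp
from flax import linen as nn

class Probe(nn.Module):
  @nn.compact
  def __call__(self, x):
    y = nn.Dense(4)(x)
    self.sow("intermediates", "activation_mean", jnp.mean(y))
    return y

m = Probe()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
y, state = m.apply(variables, jnp.ones((2, 3)), mutable=["intermediates"])
sown_mean = state["intermediates"]["activation_mean"]   # a tuple holding each sown value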
+""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -55,12 +55,14 @@ Quant = quantizations.AqtQuantization -#----------------------------------------- +# ----------------------------------------- # The Normalization Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3LayerNorm(nn.Module): """GPT3 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -82,10 +84,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: features = x.shape[-1] scale = self.param( - 'scale', - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (features,), - self.weight_dtype + "scale", nn.with_logical_partitioning(self.scale_init, self.kernel_axes), (features,), self.weight_dtype ) scale = jnp.asarray(scale, self.dtype) @@ -93,40 +92,41 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: if self.use_bias: bias = self.param( - 'bias', - nn.with_logical_partitioning(initializers.default_bias_init, self.kernel_axes), - (features,), - self.weight_dtype, + "bias", + nn.with_logical_partitioning(initializers.default_bias_init, self.kernel_axes), + (features,), + self.weight_dtype, ) bias = jnp.asarray(bias, self.dtype) output += bias return output -#----------------------------------------- +# ----------------------------------------- # The Attention Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3MultiHeadAttention(nn.Module): """Multi-head attention in gpt3. - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - max_target_length: maximum length of output - max_prefill_predict_length: size of the maximum prefill - mesh: device mesh - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid - numerical issues with bfloat16. - float32_logits: bool, if True then cast logits to float32 before softmax to avoid - numerical issues with bfloat16. - fused_qkv: whether to fuse query, key and value into one projection. - quant: Quant, stores quantization config, defaults to None implying no quantization. - use_bias: whether to add bias in linear transformation. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + max_target_length: maximum length of output + max_prefill_predict_length: size of the maximum prefill + mesh: device mesh + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + fused_qkv: whether to fuse query, key and value into one projection. + quant: Quant, stores quantization config, defaults to None implying no quantization. + use_bias: whether to add bias in linear transformation. 
""" config: Config @@ -138,8 +138,8 @@ class Gpt3MultiHeadAttention(nn.Module): attention_kernel: str dtype: DType = jnp.float32 weight_dtype: DType = jnp.float32 - dropout_rate: float = 0. - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") float32_qk_product: bool = False # computes logits in float32 for stability. float32_logits: bool = True # cast logits in float32 for stability. fused_qkv: bool = True @@ -152,88 +152,92 @@ class Gpt3MultiHeadAttention(nn.Module): out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) def qkv_projection(self, inputs: Array, proj_name: str): - """ Fused QKV projection""" + """Fused QKV projection""" qkv_proj = DenseGeneral( - features=(3, self.num_heads, self.head_dim), - axis = -1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'qkv', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - )(inputs) - qkv_proj = checkpoint_name(qkv_proj, 'qkv_proj') - query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...] + features=(3, self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] return query, key, value def projection(self, inputs: Array, proj_name: str) -> Array: """individual projection for one of q, k and v.""" proj = DenseGeneral( - features=(self.num_heads, self.head_dim), - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - )(inputs) + features=(self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) return proj def out_projection(self, output_dim: int, out: Array) -> Array: """output projection""" out_proj = DenseGeneral( - features=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='out', - quant=self.quant, - use_bias=self.use_bias, - )(out) + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + use_bias=self.use_bias, + )(out) return out_proj @nn.compact - def __call__(self, - inputs_q: Array, - decoder_segment_ids: Array | None = None, - *, - model_mode: str = common_types.MODEL_MODE_TRAIN, - deterministic: bool = False): + def __call__( + self, + inputs_q: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): if self.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: - query = self.projection(inputs_q, proj_name='query') - key = self.projection(inputs_q, proj_name='key') - value = 
self.projection(inputs_q, proj_name='value') + query = self.projection(inputs_q, proj_name="query") + key = self.projection(inputs_q, proj_name="key") + value = self.projection(inputs_q, proj_name="value") depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) query /= depth_scaling # annotate with sharding constraint. query = nn.with_logical_constraint(query, self.query_axis_names) - query = checkpoint_name(query, 'query_proj') + query = checkpoint_name(query, "query_proj") key = nn.with_logical_constraint(key, self.key_axis_names) - key = checkpoint_name(key, 'key_proj') + key = checkpoint_name(key, "key_proj") value = nn.with_logical_constraint(value, self.value_axis_names) - value = checkpoint_name(value, 'value_proj') - - attention_op = AttentionOp(mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - quantize_kvcache=self.config.quantize_kvcache, - num_query_heads=self.num_heads, - num_kv_heads=self.num_heads, - dtype=self.dtype) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.config.quantize_kvcache, + num_query_heads=self.num_heads, + num_kv_heads=self.num_heads, + dtype=self.dtype, + ) out = attention_op(query, key, value, decoder_segment_ids, model_mode) @@ -241,76 +245,74 @@ def __call__(self, # apply output projection, output dim is set to the input dim. out = self.out_projection(inputs_q.shape[-1], out) - out = checkpoint_name(out, 'out_proj') + out = checkpoint_name(out, "out_proj") return out -#----------------------------------------- +# ----------------------------------------- # The Decoder Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) - + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) lnx_layer_norm = Gpt3LayerNorm( dtype=cfg.dtype, - name='pre_self_attention_norm', - kernel_axes=('embed',), + name="pre_self_attention_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, reductions_in_fp32=False, use_bias=True, - ) + ) lnx = lnx_layer_norm(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) # Self-attention block - assert cfg.num_query_heads == cfg.num_kv_heads, \ - f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" + assert ( + cfg.num_query_heads == cfg.num_kv_heads + ), f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" attention_layer = Gpt3MultiHeadAttention( - 
config=cfg, - num_heads=cfg.num_query_heads, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dropout_rate=cfg.dropout_rate, - name='self_attention', - fused_qkv=cfg.fused_qkv, - use_bias=True, - quant=self.quant) + config=cfg, + num_heads=cfg.num_query_heads, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dropout_rate=cfg.dropout_rate, + name="self_attention", + fused_qkv=cfg.fused_qkv, + use_bias=True, + quant=self.quant, + ) attention_lnx = attention_layer( - lnx, - decoder_segment_ids=decoder_segment_ids, - model_mode=model_mode, - deterministic=deterministic) - - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) attention_lnx += inputs # MLP block. @@ -320,33 +322,29 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", use_bias=True, use_pre_norm=True, config=cfg, quant=self.quant, )(attention_lnx, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = attention_lnx + mlp_lnx - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/initializers.py b/MaxText/layers/initializers.py index 6f0bb9c23..5916ecb0c 100644 --- a/MaxText/layers/initializers.py +++ b/MaxText/layers/initializers.py @@ -27,13 +27,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] InitializerAxis = Union[int, Tuple[int, ...]] -NdInitializer = Callable[ - [PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array -] +NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array] -default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 -) +default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0) default_bias_init = jax.nn.initializers.constant(0.0) @@ -42,9 +38,7 @@ def 
nd_dense_init(scale, mode, distribution): """Initializer with in_axis, out_axis set at call time.""" def init_fn(key, shape, dtype, in_axis, out_axis): - fn = jax.nn.initializers.variance_scaling( - scale, mode, distribution, in_axis, out_axis - ) + fn = jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis, out_axis) return fn(key, shape, dtype) return init_fn diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 4cf3d1939..3d3f35b9b 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -40,18 +40,20 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization -def _convert_to_activation_function( - fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: + +def _convert_to_activation_function(fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: """Convert a string to an activation function.""" - if fn_or_string == 'linear': + if fn_or_string == "linear": return lambda x: x elif isinstance(fn_or_string, str): return getattr(nn, fn_or_string) elif callable(fn_or_string): return fn_or_string else: - raise ValueError(f"""Don't know how to convert {fn_or_string} - to an activation function""") + raise ValueError( + f"""Don't know how to convert {fn_or_string} + to an activation function""" + ) def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: @@ -83,7 +85,7 @@ class DenseGeneral(nn.Module): axis: Union[Iterable[int], int] = -1 weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") kernel_axes: Tuple[str, ...] = () quant: Optional[Quant] = None use_bias: bool = False @@ -105,8 +107,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): if self.quant: dot_general_cls = self.quant.dot_general_cls() dot_general = dot_general_cls() - return dot_general( - inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -123,12 +124,12 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): kernel = jnp.zeros(kernel_shape) else: kernel = self.param( - 'kernel', - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), - kernel_shape, - self.weight_dtype, - kernel_in_axis, - kernel_out_axis, + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, ) kernel = jnp.asarray(kernel, self.dtype) @@ -136,9 +137,9 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): output = compute_dot_general(inputs, kernel, axis, contract_ind) if self.use_bias: - bias_axes, bias_shape = self.kernel_axes[-len(features):], kernel_shape[-len(features):] + bias_axes, bias_shape = self.kernel_axes[-len(features) :], kernel_shape[-len(features) :] bias = self.param( - 'bias', + "bias", nn.with_logical_partitioning(bias_init, bias_axes), bias_shape, self.weight_dtype, @@ -167,8 +168,8 @@ class MlpBlock(nn.Module): config: Config intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable[..., Any]]] = ('relu',) - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + activations: Sequence[Union[str, Callable[..., Any]]] = ("relu",) + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", 
"truncated_normal") intermediate_dropout_rate: float = 0.1 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -181,6 +182,7 @@ def get_norm_layer(self): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 + return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=self.use_bias) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @@ -192,39 +194,39 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): if self.use_pre_norm: inputs = self.get_norm_layer()( - name='mlp_layer_norm', - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - kernel_axes=('embed',), - epsilon=cfg.normalization_layer_epsilon, - )(inputs) + name="mlp_layer_norm", + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + kernel_axes=("embed",), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) # Iterate over specified MLP input activation functions. # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. activations = [] if cfg.fused_mlp: x = DenseGeneral( - (len(self.activations), self.intermediate_dim), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'num_activations', 'mlp'), - name='wi', - quant=self.quant, - use_bias=self.use_bias, + (len(self.activations), self.intermediate_dim), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "num_activations", "mlp"), + name="wi", + quant=self.quant, + use_bias=self.use_bias, )(inputs) for idx, act_fn in enumerate(self.activations): - y = _convert_to_activation_function(act_fn)(x[:,:,idx,...]) + y = _convert_to_activation_function(act_fn)(x[:, :, idx, ...]) activations.append(y) else: for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" x = DenseGeneral( self.intermediate_dim, dtype=self.dtype, weight_dtype=self.weight_dtype, kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), + kernel_axes=("embed", "mlp"), name=dense_name, quant=self.quant, use_bias=self.use_bias, @@ -234,26 +236,24 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): # Take elementwise product of above intermediate activations. x = functools.reduce(operator.mul, activations) - x = checkpoint_name(x, 'mlpwi') + x = checkpoint_name(x, "mlpwi") # Apply dropout and final dense output projection. x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( x, deterministic=deterministic ) # Broadcast along length. 
- x = nn.with_logical_constraint( - x, ('activation_batch', 'activation_length', 'activation_mlp') - ) + x = nn.with_logical_constraint(x, ("activation_batch", "activation_length", "activation_mlp")) output = DenseGeneral( inputs.shape[-1], dtype=self.dtype, weight_dtype=self.weight_dtype, kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo', + kernel_axes=("mlp", "embed"), + name="wo", quant=self.quant, use_bias=self.use_bias, )(x) - output = checkpoint_name(output, 'mlpwo') + output = checkpoint_name(output, "mlpwo") return output @@ -278,38 +278,35 @@ class MoeBlock(nn.Module): @nn.compact def __call__(self, inputs, deterministic: bool = False): gate_logits = DenseGeneral( - self.num_experts, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - name='gate', - quant=self.quant,)(inputs) + self.num_experts, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + name="gate", + quant=self.quant, + )(inputs) weights, selected_experts = lax.top_k(gate_logits, self.num_experts_per_tok) weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) mlp_lnx = jnp.zeros_like(inputs) weights = weights.astype(self.dtype) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) # TODO(ranran): have a better solution to remove the loop here for k in range(self.num_experts): - weights_exp = jnp.sum(jnp.multiply(selected_experts==k, weights), axis=-1) - mlp_lnx_exp = MlpBlock( + weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1) + mlp_lnx_exp = MlpBlock( intermediate_dim=self.config.mlp_dim, activations=self.config.mlp_activations, intermediate_dropout_rate=self.config.dropout_rate, dtype=self.dtype, weight_dtype=self.weight_dtype, - name=f'mlp_{k}', + name=f"mlp_{k}", config=self.config, - )(inputs, deterministic=deterministic) + )(inputs, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ('activation_batch', 'activation_length', 'activation_embed') - ) - mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp - mlp_lnx += mlp_lnx_exp + mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp + mlp_lnx += mlp_lnx_exp return mlp_lnx diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index bc0fd1861..7723d6548 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
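MoeBlock above routes each token to its top-k experts: the gate logits are top-k'd, softmaxed in float32, and the chosen experts' MLP outputs are mixed with those weights via the dense per-expert loop flagged in the TODO. A routing-only sketch with toy linear "experts" in place of MlpBlock:

import jax
import jax.numpy as jnp
from jax import lax

num_experts, k, embed = 4, 2, 8
tokens = jnp.ones((2, 5, embed))
expert_w = jnp.stack([jnp.eye(embed) * (e + 1) for e in range(num_experts)])  # toy expert weights

gate_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 5, num_experts))
weights, selected = lax.top_k(gate_logits, k)                  # [batch, len, k]
weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1)

out = jnp.zeros_like(tokens)
for e in range(num_experts):                                   # dense loop over experts, as in the diff
  w_e = jnp.sum(jnp.where(selected == e, weights, 0.0), axis=-1)   # routing weight if expert e was picked
  out = out + w_e[..., None] * (tokens @ expert_w[e])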
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -43,86 +43,82 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization -#----------------------------------------- +# ----------------------------------------- # The Decoder Layer specific for Llama2 -#----------------------------------------- +# ----------------------------------------- class LlamaDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) - + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) lnx_rms = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_layer_norm', - kernel_axes=('embed',), + name="pre_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - ) + ) lnx = lnx_rms(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) # Self-attention block attention_layer = Attention( - config = cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) - - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='post_self_attention_layer_norm', - kernel_axes=('embed',), + 
name="post_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) - hidden_states = nn.with_logical_constraint( - hidden_states, - ('activation_batch', 'activation_length', 'activation_embed') - ) + )(intermediate_inputs) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) # MLP block. mlp_lnx = linears.MlpBlock( @@ -131,32 +127,27 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) - + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 2992cf5ab..6954c157f 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -51,153 +51,144 @@ class MistralDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) lnx_rms = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_layer_norm', - kernel_axes=('embed',), - epsilon=cfg.normalization_layer_epsilon - ) + name="pre_self_attention_layer_norm", + kernel_axes=("embed",), + epsilon=cfg.normalization_layer_epsilon, + ) lnx = lnx_rms(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) # Self-attention block attention_layer = Attention( - config = cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) - - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='post_self_attention_layer_norm', - kernel_axes=('embed',), + name="post_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, ('activation_batch', 'activation_length', 'activation_embed')) + )(intermediate_inputs) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) if cfg.num_experts > 1: - # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR. 
- - # mlp_lnx = linears.MoeBlock( - # config=cfg, - # num_experts=cfg.num_experts, - # num_experts_per_tok=cfg.num_experts_per_tok, - # kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), - # kernel_axes=('embed', 'mlp'), - # dtype=cfg.dtype, - # )(hidden_states, deterministic=deterministic) - - gate_logits = linears.DenseGeneral( - cfg.num_experts, - weight_dtype=cfg.weight_dtype, - dtype=cfg.dtype, - kernel_init=initializers.nd_dense_init( - 1.0, 'fan_in', 'truncated_normal'), - kernel_axes=('embed', 'mlp'), - name="gate", - quant=self.quant, - )(hidden_states) - weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok) - weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) - mlp_lnx = jnp.zeros_like(hidden_states) - weights = weights.astype(cfg.dtype) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) - - # TODO(ranran): have a better solution to remove the loop here - for k in range(cfg.num_experts): - weights_exp = jnp.sum(jnp.multiply( - selected_experts == k, weights), axis=-1) - mlp_lnx_exp = linears.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name=f'mlp_{k}', - config=cfg, - )(hidden_states, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ('activation_batch', 'activation_length', 'activation_embed') - ) - mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp - mlp_lnx += mlp_lnx_exp - else: - mlp_lnx = linears.MlpBlock( + # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR. + + # mlp_lnx = linears.MoeBlock( + # config=cfg, + # num_experts=cfg.num_experts, + # num_experts_per_tok=cfg.num_experts_per_tok, + # kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), + # kernel_axes=('embed', 'mlp'), + # dtype=cfg.dtype, + # )(hidden_states, deterministic=deterministic) + + gate_logits = linears.DenseGeneral( + cfg.num_experts, + weight_dtype=cfg.weight_dtype, + dtype=cfg.dtype, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", "mlp"), + name="gate", + quant=self.quant, + )(hidden_states) + weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok) + weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) + mlp_lnx = jnp.zeros_like(hidden_states) + weights = weights.astype(cfg.dtype) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + + # TODO(ranran): have a better solution to remove the loop here + for k in range(cfg.num_experts): + weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1) + mlp_lnx_exp = linears.MlpBlock( intermediate_dim=cfg.mlp_dim, activations=cfg.mlp_activations, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name=f"mlp_{k}", config=cfg, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp + mlp_lnx += mlp_lnx_exp + else: + mlp_lnx = linears.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, 
+ intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + )(hidden_states, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( - layer_output, ('activation_batch', 'activation_length', 'activation_embed'), + layer_output, + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 571a77a95..ff3c246b1 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -41,69 +41,69 @@ PositionalEmbedding = embeddings.PositionalEmbedding Quant = quantizations.AqtQuantization -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # The network: Decoder & Transformer Definitions -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ class DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_norm', + name="pre_self_attention_norm", epsilon=cfg.normalization_layer_epsilon, - kernel_axes=('embed',))(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + kernel_axes=("embed",), + )(inputs) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) attention_layer = Attention( - config = self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) - + 
config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) # MLP block. mlp_lnx = linears.MlpBlock( @@ -112,32 +112,30 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(lnx, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) next_layer_addition = mlp_lnx + attention_lnx - next_layer_addition_dropped_out = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,) - )(next_layer_addition, deterministic=deterministic) + next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition, deterministic=deterministic + ) layer_output = next_layer_addition_dropped_out + inputs layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) @@ -146,6 +144,7 @@ def __call__(self, class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" + config: Config shared_embedding: nn.Module mesh: Mesh @@ -156,16 +155,20 @@ def get_decoder_layer(self): return DecoderLayer elif self.config.decoder_block == "llama2": from layers import llama2 + return llama2.LlamaDecoderLayer elif self.config.decoder_block == "mistral": # TODO(ranran): update to Mistral with sliding window attention from layers import mistral + return mistral.MistralDecoderLayer elif self.config.decoder_block == "gemma": from layers import gemma + return gemma.GemmaDecoderLayer elif self.config.decoder_block == "gpt3": from layers import gpt3 + return gpt3.Gpt3DecoderLayer else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @@ -175,74 +178,89 @@ def get_norm_layer(self): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 + return 
functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=True) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @nn.compact - def __call__(self, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=common_types.MODEL_MODE_TRAIN, - ): + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=common_types.MODEL_MODE_TRAIN, + ): cfg = self.config mesh = self.mesh assert decoder_input_tokens.ndim == 2 # [batch, len] # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype('int32')) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) + y = self.shared_embedding(decoder_input_tokens.astype("int32")) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) y = y.astype(cfg.dtype) if cfg.use_untrainable_positional_embedding: - y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) + y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) if cfg.trainable_position_size > 0: y += Embed( - num_embeddings=cfg.trainable_position_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name='position_embedder', - config=cfg)(decoder_positions) + num_embeddings=cfg.trainable_position_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name="position_embedder", + config=cfg, + )(decoder_positions) BlockLayer = self.get_decoder_layer() - if cfg.remat_policy != 'none': - if cfg.remat_policy == 'minimal': + if cfg.remat_policy != "none": + if cfg.remat_policy == "minimal": policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - elif cfg.remat_policy == 'save_dot_except_mlpwi': + elif cfg.remat_policy == "save_dot_except_mlpwi": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', 'out_proj', 'mlpwo', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", ) - elif cfg.remat_policy == 'save_dot_except_mlp': + elif cfg.remat_policy == "save_dot_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', 'out_proj', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", ) - elif cfg.remat_policy == 'save_qkv_proj': + elif cfg.remat_policy == "save_qkv_proj": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", ) - elif cfg.remat_policy == 'qkv_proj_offloaded': + elif cfg.remat_policy == "qkv_proj_offloaded": policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=['query_proj', 'value_proj', 'key_proj'], - offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == 'minimal_offloaded': + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == 'minimal_flash': + elif cfg.remat_policy == "minimal_flash": policy = 
jax.checkpoint_policies.save_from_both_policies( - jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - jax.checkpoint_policies.save_only_these_names('context',), + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + jax.checkpoint_policies.save_only_these_names( + "context", + ), ) else: - assert ( - cfg.remat_policy == 'full' - ), 'Remat policy needs to be on list of remat policies' + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" policy = None BlockLayer = nn.remat( # pylint: disable=invalid-name BlockLayer, @@ -251,23 +269,21 @@ def __call__(self, static_argnums=(-1, -2, -3, -4, -5), ) if cfg.scan_layers: - initializing = self.is_mutable_collection('params') - params_spec = ( - cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - ) + initializing = self.is_mutable_collection("params") + params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) cache_spec = 0 y, _ = nn.scan( BlockLayer, variable_axes={ - 'params': params_spec, - 'cache': cache_spec, - 'intermediates': 0, - 'aqt':0, - '_overwrite_with_gradient': 0, + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, }, split_rngs={ - 'params': True, - 'dropout': cfg.enable_dropout, + "params": True, + "dropout": cfg.enable_dropout, }, in_axes=( nn.broadcast, @@ -276,8 +292,8 @@ def __call__(self, nn.broadcast, ), length=cfg.num_decoder_layers, - metadata_params={nn.PARTITION_NAME: 'layers'}, - )(config=cfg, mesh=mesh, name='layers', quant=self.quant)( + metadata_params={nn.PARTITION_NAME: "layers"}, + )(config=cfg, mesh=mesh, name="layers", quant=self.quant)( y, decoder_segment_ids, decoder_positions, @@ -286,8 +302,7 @@ def __call__(self, ) else: for lyr in range(cfg.num_decoder_layers): - y = BlockLayer(config=cfg, mesh=mesh, name=f'layers_{lyr}', - quant=self.quant)( + y = BlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( y, decoder_segment_ids, decoder_positions, @@ -296,15 +311,13 @@ def __call__(self, ) y = self.get_norm_layer()( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='decoder_norm', - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=('embed',), - )(y) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="decoder_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("embed",), + )(y) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) # [batch, length, emb_dim] -> [batch, length, vocab_size] if cfg.logits_via_embedding: @@ -318,16 +331,19 @@ def __call__(self, cfg.vocab_size, weight_dtype=cfg.weight_dtype, dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=('embed', 'vocab'), - name='logits_dense')(y) # We do not quantize the logits matmul. - logits = nn.with_logical_constraint( - logits, ('activation_batch', 'activation_length', 'activation_vocab')) + kernel_axes=("embed", "vocab"), + name="logits_dense", + )( + y + ) # We do not quantize the logits matmul. 
+ logits = nn.with_logical_constraint(logits, ("activation_batch", "activation_length", "activation_vocab")) logits = logits.astype(jnp.float32) return logits class Transformer(nn.Module): """An decoder-only Transformer model.""" + # Make new attributes required, so that all Transformer dependencies (train, decode, compile, etc) will error instead of silently use defaults. # pylint: disable=attribute-defined-outside-init config: Config @@ -345,14 +361,11 @@ def setup(self): dtype=cfg.dtype, attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability embedding_init=nn.initializers.normal(stddev=1.0), - name='token_embedder', + name="token_embedder", config=cfg, ) - self.decoder = Decoder( - config=cfg, shared_embedding=self.shared_embedding, - mesh=mesh, quant=self.quant - ) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) def __call__( self, @@ -360,14 +373,15 @@ def __call__( decoder_positions, decoder_segment_ids=None, enable_dropout=True, - model_mode=common_types.MODEL_MODE_TRAIN + model_mode=common_types.MODEL_MODE_TRAIN, ): """Applies Transformer decoder-branch on encoded-input and target.""" if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: raise ValueError( - f'During autoregressive decoding we assume the tokens are in the active sequence' - f' which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}.') + f"During autoregressive decoding we assume the tokens are in the active sequence" + f" which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}." + ) logits = self.decoder( decoder_input_tokens=decoder_input_tokens, diff --git a/MaxText/layers/normalizations.py b/MaxText/layers/normalizations.py index 6d451d4fe..862c586c9 100644 --- a/MaxText/layers/normalizations.py +++ b/MaxText/layers/normalizations.py @@ -26,6 +26,7 @@ class RMSNorm(nn.Module): """RMS normalization.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -40,7 +41,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) scale = self.param( - 'scale', + "scale", nn.with_logical_partitioning(self.scale_init, self.kernel_axes), (features,), self.weight_dtype, diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index c6450ea17..dba7658bd 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -27,117 +27,119 @@ MAX_INT8 = 127.5 + @dataclass class Quantization: - """Base class for quantization configurations""" + """Base class for quantization configurations""" - def dot_general_cls(self): - """ Placeholder for dot_general implementation in subclasses. """ - pass + def dot_general_cls(self): + """Placeholder for dot_general implementation in subclasses.""" + pass @dataclass class AqtQuantization: - """ Configures AQT quantization github.com/google/aqt. """ + """Configures AQT quantization github.com/google/aqt.""" + quant_dg: aqt_config.DotGeneral quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN def dot_general_cls(self): - """ Returns dot_general configured with aqt params. 
""" - aqt_dg_cls = functools.partial( - aqt_flax.AqtDotGeneral, - self.quant_dg, - rhs_quant_mode=self.quant_mode - ) + """Returns dot_general configured with aqt params.""" + aqt_dg_cls = functools.partial(aqt_flax.AqtDotGeneral, self.quant_dg, rhs_quant_mode=self.quant_mode) return aqt_dg_cls def einsum(self): - """ Returns einsum configured with aqt params """ - aqt_einsum = functools.partial(aqt_flax.AqtEinsum( - cfg=self.quant_dg, - lhs_quant_mode=self.quant_mode - ) - ) + """Returns einsum configured with aqt params""" + aqt_einsum = functools.partial(aqt_flax.AqtEinsum(cfg=self.quant_dg, lhs_quant_mode=self.quant_mode)) return aqt_einsum + @dataclass class Fp8Quantization(Quantization): - """ Configures Fp8 quantization for NVIDIA GPUs""" + """Configures Fp8 quantization for NVIDIA GPUs""" + quant_mode = "train" def dot_general_cls(self): - """ Returns dot_general configured with aqt params. """ + """Returns dot_general configured with aqt params.""" return nn.Fp8DotGeneralOp + def _get_quant_config(config): """Set quantization params based on user configuration.""" - if not config.quantization or config.quantization == '': + if not config.quantization or config.quantization == "": return None elif config.quantization == "int8": if config.quantization_local_shard_count == 0: drhs_bits = None drhs_accumulator_dtype = None - drhs_local_aqt=None + drhs_local_aqt = None else: drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 drhs_local_aqt = aqt_config.LocalAqt(config.quantization_local_shard_count) return aqt_config.config_v3( - fwd_bits=8, - dlhs_bits=8, - drhs_bits=drhs_bits, - rng_type='jax.uniform', - dlhs_local_aqt=None, - drhs_local_aqt=drhs_local_aqt, - fwd_accumulator_dtype=jnp.int32, - dlhs_accumulator_dtype=jnp.int32, - drhs_accumulator_dtype=drhs_accumulator_dtype, + fwd_bits=8, + dlhs_bits=8, + drhs_bits=drhs_bits, + rng_type="jax.uniform", + dlhs_local_aqt=None, + drhs_local_aqt=drhs_local_aqt, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=drhs_accumulator_dtype, ) elif config.quantization == "fp8": return "fp8" else: - raise ValueError(f'Invalid value configured for quantization {config.quantization}.') + raise ValueError(f"Invalid value configured for quantization {config.quantization}.") + def in_convert_mode(quant): return quant and (quant.quant_mode == aqt_flax.QuantMode.CONVERT) + def in_serve_mode(quant): return quant and (quant.quant_mode == aqt_flax.QuantMode.SERVE) -def get_quant_mode(quant_mode_str: str = 'train'): - """ Set quant mode.""" - if quant_mode_str == 'train': + +def get_quant_mode(quant_mode_str: str = "train"): + """Set quant mode.""" + if quant_mode_str == "train": return aqt_flax.QuantMode.TRAIN - elif quant_mode_str == 'serve': + elif quant_mode_str == "serve": return aqt_flax.QuantMode.SERVE - elif quant_mode_str == 'convert': + elif quant_mode_str == "convert": return aqt_flax.QuantMode.CONVERT else: - raise ValueError(f'Invalid quantization mode {quant_mode_str}.') + raise ValueError(f"Invalid quantization mode {quant_mode_str}.") return None -def configure_quantization(config: Config, quant_mode_str: str = 'train'): - """ Configure quantization based on user config and quant mode.""" + +def configure_quantization(config: Config, quant_mode_str: str = "train"): + """Configure quantization based on user config and quant mode.""" quant_cfg = _get_quant_config(config) if quant_cfg: if quant_cfg == "fp8": - return Fp8Quantization() + return Fp8Quantization() quant_mode = 
get_quant_mode(quant_mode_str) return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode) return None + def _get_aqt_key_paths(aqt_vars): - """ Generate a list of paths which have aqt state """ + """Generate a list of paths which have aqt state""" aqt_tree_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars) aqt_key_paths = [] for k, _ in aqt_tree_flat: pruned_keys = [] for d in list(k): - if 'AqtDotGeneral' in d.key: - pruned_keys.append(jax.tree_util.DictKey(key='kernel')) + if "AqtDotGeneral" in d.key: + pruned_keys.append(jax.tree_util.DictKey(key="kernel")) break else: - assert 'Aqt' not in d.key, f"Unexpected Aqt op {d.key} in {k}." + assert "Aqt" not in d.key, f"Unexpected Aqt op {d.key} in {k}." pruned_keys.append(d) aqt_key_paths.append(tuple(pruned_keys)) return aqt_key_paths @@ -153,16 +155,19 @@ def remove_quantized_params(params, aqt_vars): tree_flat[i] = v return tree_unflatten(tree_struct, tree_flat) + def configure_kv_quantization(config: Config): - """ Configure kv quantization based on user config.""" + """Configure kv quantization based on user config.""" return False if not config.quantize_kvcache else True + def quantize_kv(kv: Array): """Quantize key/values stored in kvcache.""" scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True) value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale))) return value, scale -def unquantize_kv(value: Array, scale:Array, dtype:jnp.dtype): + +def unquantize_kv(value: Array, scale: Array, dtype: jnp.dtype): """Unquantize key/values stored in kvcache.""" return value.astype(dtype) * scale / MAX_INT8 diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index b9a7aefe4..97a456764 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" r"""Convert weights from a Llama or Mistral model to a MaxText one. 
@@ -42,54 +42,55 @@ import sys import os -jax.config.update('jax_platform_name', 'cpu') +jax.config.update("jax_platform_name", "cpu") + def permute_to_match_maxtext_rope(arr): evens = arr[..., ::2] odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1) + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) MODEL_PARAMS_DICT = { - 'llama2-70b': { - 'num_layers': 80, - 'num_heads': 64, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-70b": { + "num_layers": 80, + "num_heads": 64, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, }, - 'llama2-13b': { - 'num_layers': 40, - 'num_heads': 40, - 'num_kv_heads': 40, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-13b": { + "num_layers": 40, + "num_heads": 40, + "num_kv_heads": 40, + "dims_per_head": 128, + "vocab": 32000, }, - 'llama2-7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 32, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 32, + "dims_per_head": 128, + "vocab": 32000, }, - 'mistral-7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, - 'base_emb_dim': 4096, - 'base_mlp_dim': 14336, + "mistral-7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, + "base_emb_dim": 4096, + "base_mlp_dim": 14336, }, - 'mixtral-8x7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, - 'base_emb_dim': 4096, - 'base_mlp_dim': 14336, - 'num_experts': 8, + "mixtral-8x7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, + "base_emb_dim": 4096, + "base_mlp_dim": 14336, + "num_experts": 8, }, } @@ -108,256 +109,213 @@ def convert(base_model_path, maxtext_model_path, model_size): """ """Convert model to maxtext.""" model_params = MODEL_PARAMS_DICT[model_size] - base_num_decoder_layers = model_params['num_layers'] - base_num_query_heads = model_params['num_heads'] - head_dim = model_params['dims_per_head'] - base_num_kv_heads = model_params['num_kv_heads'] - vocab_size = model_params['vocab'] - num_experts = model_params['num_experts'] if 'num_experts' in model_params else None - - print(f'Loading the base model from {base_model_path}') + base_num_decoder_layers = model_params["num_layers"] + base_num_query_heads = model_params["num_heads"] + head_dim = model_params["dims_per_head"] + base_num_kv_heads = model_params["num_kv_heads"] + vocab_size = model_params["vocab"] + num_experts = model_params["num_experts"] if "num_experts" in model_params else None + + print(f"Loading the base model from {base_model_path}") # Skip any hidden files for checkpoints - ckpt_paths = sorted(pathlib.Path(base_model_path).glob('[!.]*.pth')) + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth")) pytorch_vars = {} for i, ckpt_path in enumerate(ckpt_paths): - print(f'Loading checkpoint {i+1} of {len(ckpt_paths)} ...') - checkpoint = torch.load(ckpt_path, map_location='cpu') - pytorch_vars[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint + print(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") + checkpoint = torch.load(ckpt_path, map_location="cpu") + pytorch_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint pytorch_vars = [pytorch_vars[i] for i in sorted(list(pytorch_vars.keys()))] - layer_key = 'gate' if num_experts else 'mlp' + layer_key = "gate" if num_experts 
else "mlp" jax_weights = { - 'decoder': { - 'layers': { + "decoder": { + "layers": { layer_key: {}, - 'pre_self_attention_layer_norm': {}, - 'post_self_attention_layer_norm': {}, - 'self_attention': {}, + "pre_self_attention_layer_norm": {}, + "post_self_attention_layer_norm": {}, + "self_attention": {}, }, - 'decoder_norm': { - 'scale': pytorch_vars[0]['norm.weight'].type(torch.float16).numpy() + "decoder_norm": {"scale": pytorch_vars[0]["norm.weight"].type(torch.float16).numpy()}, + "logits_dense": { + "kernel": np.concatenate( + [var["output.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose()[:, :vocab_size] }, - 'logits_dense': { - 'kernel': np.concatenate([var['output.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose()[:, :vocab_size] - } }, - 'token_embedder': { - 'embedding': np.concatenate([var['tok_embeddings.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1)[:vocab_size, :] - - } - - } - - layer_weight = { - 'pre_self_attention_layer_norm': { - 'scale': [] + "token_embedder": { + "embedding": np.concatenate( + [var["tok_embeddings.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=1 + )[:vocab_size, :] }, - 'post_self_attention_layer_norm': { - 'scale': [] - } } + layer_weight = {"pre_self_attention_layer_norm": {"scale": []}, "post_self_attention_layer_norm": {"scale": []}} + if num_experts is None: - layer_weight['mlp'] = { - 'wi_0': { - 'kernel': [] - }, - 'wi_1': { - 'kernel': [] - }, - 'wo': { - 'kernel': [] - }, + layer_weight["mlp"] = { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, } else: - layer_weight['gate'] = { - 'kernel': [] - } + layer_weight["gate"] = {"kernel": []} for k in range(num_experts): - jax_weights['decoder']['layers'][f'mlp_{k}'] = {} - layer_weight[f'mlp_{k}'] = { - 'wi_0': { - 'kernel': [] - }, - 'wi_1': { - 'kernel': [] - }, - 'wo': { - 'kernel': [] - }, + jax_weights["decoder"]["layers"][f"mlp_{k}"] = {} + layer_weight[f"mlp_{k}"] = { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, } self_attention = { - 'query': { - 'kernel': [] - }, - 'key': { - 'kernel': [] - }, - 'value': { - 'kernel': [] - }, - 'out': { - 'kernel': [] - }, + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, } for layer_idx in range(base_num_decoder_layers): - wq = np.concatenate([var[f'layers.{layer_idx}.attention.wq.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wk = np.concatenate([var[f'layers.{layer_idx}.attention.wk.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wv = np.concatenate([var[f'layers.{layer_idx}.attention.wv.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - - wq = np.reshape(wq, [base_num_query_heads * head_dim, - base_num_query_heads, head_dim]) - wk = np.reshape(wk, [base_num_query_heads * head_dim, - base_num_kv_heads, head_dim]) - wv = np.reshape(wv, [base_num_query_heads * head_dim, - base_num_kv_heads, head_dim]) + wq = np.concatenate( + [var[f"layers.{layer_idx}.attention.wq.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wk = np.concatenate( + [var[f"layers.{layer_idx}.attention.wk.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wv = np.concatenate( + [var[f"layers.{layer_idx}.attention.wv.weight"].type(torch.float16).numpy() for var in 
pytorch_vars], axis=0 + ).transpose() + + wq = np.reshape(wq, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) + wk = np.reshape(wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim]) + wv = np.reshape(wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim]) wq = permute_to_match_maxtext_rope(wq) wk = permute_to_match_maxtext_rope(wk) w_post = np.concatenate( - [ - var[f'layers.{layer_idx}.attention.wo.weight'].type( - torch.float16).numpy() - for var in pytorch_vars - ], + [var[f"layers.{layer_idx}.attention.wo.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=1, ) - w_post = np.reshape( - w_post, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) - - self_attention['query']['kernel'].append(wq) - self_attention['key']['kernel'].append(wk) - self_attention['value']['kernel'].append(wv) - self_attention['out']['kernel'].append(w_post) - pre_self_attention_layernorm = pytorch_vars[0][f'layers.{layer_idx}.attention_norm.weight'].type( - torch.float16).numpy() - post_self_attention_layernorm = pytorch_vars[0][f'layers.{layer_idx}.ffn_norm.weight'].type( - torch.float16).numpy() - layer_weight['pre_self_attention_layer_norm']['scale'].append( - pre_self_attention_layernorm) - layer_weight['post_self_attention_layer_norm']['scale'].append( - post_self_attention_layernorm) + w_post = np.reshape(w_post, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) + + self_attention["query"]["kernel"].append(wq) + self_attention["key"]["kernel"].append(wk) + self_attention["value"]["kernel"].append(wv) + self_attention["out"]["kernel"].append(w_post) + pre_self_attention_layernorm = pytorch_vars[0][f"layers.{layer_idx}.attention_norm.weight"].type(torch.float16).numpy() + post_self_attention_layernorm = pytorch_vars[0][f"layers.{layer_idx}.ffn_norm.weight"].type(torch.float16).numpy() + layer_weight["pre_self_attention_layer_norm"]["scale"].append(pre_self_attention_layernorm) + layer_weight["post_self_attention_layer_norm"]["scale"].append(post_self_attention_layernorm) if num_experts is None: - wi_0 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w1.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wi_1 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w3.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wo = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w2.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1).transpose() - layer_weight['mlp']['wi_0']['kernel'].append(wi_0) - layer_weight['mlp']['wi_1']['kernel'].append(wi_1) - layer_weight['mlp']['wo']['kernel'].append(wo) + wi_0 = np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.w1.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wi_1 = np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.w3.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wo = np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.w2.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=1 + ).transpose() + layer_weight["mlp"]["wi_0"]["kernel"].append(wi_0) + layer_weight["mlp"]["wi_1"]["kernel"].append(wi_1) + layer_weight["mlp"]["wo"]["kernel"].append(wo) else: - gate = np.concatenate([var[f'layers.{layer_idx}.feed_forward.gate.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - layer_weight['gate']['kernel'].append(gate) + gate = 
np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.gate.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + layer_weight["gate"]["kernel"].append(gate) for k in range(num_experts): - wi_0 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w1.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wi_1 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w3.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wo = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w2.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1).transpose() - layer_weight[f'mlp_{k}']['wi_0']['kernel'].append(wi_0) - layer_weight[f'mlp_{k}']['wi_1']['kernel'].append(wi_1) - layer_weight[f'mlp_{k}']['wo']['kernel'].append(wo) - - self_attention['query']['kernel'] = np.array( - self_attention['query']['kernel']) - self_attention['key']['kernel'] = np.array(self_attention['key']['kernel']) - self_attention['value']['kernel'] = np.array( - self_attention['value']['kernel']) - self_attention['out']['kernel'] = np.array(self_attention['out']['kernel']) - self_attention['query']['kernel'] = np.transpose( - self_attention['query']['kernel'], axes=(1, 0, 2, 3)) - self_attention['key']['kernel'] = np.transpose( - self_attention['key']['kernel'], axes=(1, 0, 2, 3)) - self_attention['value']['kernel'] = np.transpose( - self_attention['value']['kernel'], axes=(1, 0, 2, 3)) + wi_0 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w1.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wi_1 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w3.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wo = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w2.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=1, + ).transpose() + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"].append(wi_0) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"].append(wi_1) + layer_weight[f"mlp_{k}"]["wo"]["kernel"].append(wo) + + self_attention["query"]["kernel"] = np.array(self_attention["query"]["kernel"]) + self_attention["key"]["kernel"] = np.array(self_attention["key"]["kernel"]) + self_attention["value"]["kernel"] = np.array(self_attention["value"]["kernel"]) + self_attention["out"]["kernel"] = np.array(self_attention["out"]["kernel"]) + self_attention["query"]["kernel"] = np.transpose(self_attention["query"]["kernel"], axes=(1, 0, 2, 3)) + self_attention["key"]["kernel"] = np.transpose(self_attention["key"]["kernel"], axes=(1, 0, 2, 3)) + self_attention["value"]["kernel"] = np.transpose(self_attention["value"]["kernel"], axes=(1, 0, 2, 3)) # layers, base_num_query_heads * head_dim, base_num_query_heads, head_dim => # base_num_query_heads, layers,head_dim, base_num_query_heads * head_dim - self_attention['out']['kernel'] = np.transpose( - self_attention['out']['kernel'], axes=(2, 0, 3, 1)) + self_attention["out"]["kernel"] = np.transpose(self_attention["out"]["kernel"], axes=(2, 0, 3, 1)) # scale the query weights - self_attention['query']['kernel'] = self_attention['query']['kernel'] / \ - np.sqrt(head_dim) + self_attention["query"]["kernel"] = self_attention["query"]["kernel"] / np.sqrt(head_dim) - jax_weights['decoder']['layers']['self_attention'] = self_attention + jax_weights["decoder"]["layers"]["self_attention"] 
= self_attention # self attention layer norm and swap the layer index - layer_weight['pre_self_attention_layer_norm']['scale'] = np.array( - layer_weight['pre_self_attention_layer_norm']['scale']) - layer_weight['post_self_attention_layer_norm']['scale'] = np.array( - layer_weight['post_self_attention_layer_norm']['scale']) - layer_weight['pre_self_attention_layer_norm']['scale'] = np.transpose( - layer_weight['pre_self_attention_layer_norm']['scale'], - axes=(1, 0)) - layer_weight['post_self_attention_layer_norm']['scale'] = np.transpose( - layer_weight['post_self_attention_layer_norm']['scale'], - axes=(1, 0)) - - jax_weights['decoder']['layers']['pre_self_attention_layer_norm'] = layer_weight['pre_self_attention_layer_norm'] - jax_weights['decoder']['layers']['post_self_attention_layer_norm'] = layer_weight['post_self_attention_layer_norm'] + layer_weight["pre_self_attention_layer_norm"]["scale"] = np.array(layer_weight["pre_self_attention_layer_norm"]["scale"]) + layer_weight["post_self_attention_layer_norm"]["scale"] = np.array(layer_weight["post_self_attention_layer_norm"]["scale"]) + layer_weight["pre_self_attention_layer_norm"]["scale"] = np.transpose( + layer_weight["pre_self_attention_layer_norm"]["scale"], axes=(1, 0) + ) + layer_weight["post_self_attention_layer_norm"]["scale"] = np.transpose( + layer_weight["post_self_attention_layer_norm"]["scale"], axes=(1, 0) + ) + + jax_weights["decoder"]["layers"]["pre_self_attention_layer_norm"] = layer_weight["pre_self_attention_layer_norm"] + jax_weights["decoder"]["layers"]["post_self_attention_layer_norm"] = layer_weight["post_self_attention_layer_norm"] if num_experts is None: - layer_weight['mlp']['wi_0']['kernel'] = np.array( - layer_weight['mlp']['wi_0']['kernel']) - layer_weight['mlp']['wi_1']['kernel'] = np.array( - layer_weight['mlp']['wi_1']['kernel']) - layer_weight['mlp']['wo']['kernel'] = np.array( - layer_weight['mlp']['wo']['kernel']) + layer_weight["mlp"]["wi_0"]["kernel"] = np.array(layer_weight["mlp"]["wi_0"]["kernel"]) + layer_weight["mlp"]["wi_1"]["kernel"] = np.array(layer_weight["mlp"]["wi_1"]["kernel"]) + layer_weight["mlp"]["wo"]["kernel"] = np.array(layer_weight["mlp"]["wo"]["kernel"]) # swap the layer index - layer_weight['mlp']['wi_0']['kernel'] = np.transpose( - layer_weight['mlp']['wi_0']['kernel'], axes=(1, 0, 2)) - layer_weight['mlp']['wi_1']['kernel'] = np.transpose( - layer_weight['mlp']['wi_1']['kernel'], axes=(1, 0, 2)) - layer_weight['mlp']['wo']['kernel'] = np.transpose( - layer_weight['mlp']['wo']['kernel'], axes=(1, 0, 2)) - - jax_weights['decoder']['layers']['mlp'] = layer_weight['mlp'] + layer_weight["mlp"]["wi_0"]["kernel"] = np.transpose(layer_weight["mlp"]["wi_0"]["kernel"], axes=(1, 0, 2)) + layer_weight["mlp"]["wi_1"]["kernel"] = np.transpose(layer_weight["mlp"]["wi_1"]["kernel"], axes=(1, 0, 2)) + layer_weight["mlp"]["wo"]["kernel"] = np.transpose(layer_weight["mlp"]["wo"]["kernel"], axes=(1, 0, 2)) + + jax_weights["decoder"]["layers"]["mlp"] = layer_weight["mlp"] else: - layer_weight['gate']['kernel'] = np.array(layer_weight['gate']['kernel']) - layer_weight['gate']['kernel'] = np.transpose( - layer_weight['gate']['kernel'], axes=(1, 0, 2)) - jax_weights['decoder']['layers']['gate'] = layer_weight['gate'] + layer_weight["gate"]["kernel"] = np.array(layer_weight["gate"]["kernel"]) + layer_weight["gate"]["kernel"] = np.transpose(layer_weight["gate"]["kernel"], axes=(1, 0, 2)) + jax_weights["decoder"]["layers"]["gate"] = layer_weight["gate"] for k in range(num_experts): - 
layer_weight[f'mlp_{k}']['wi_0']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wi_0']['kernel']) - layer_weight[f'mlp_{k}']['wi_1']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wi_1']['kernel']) - layer_weight[f'mlp_{k}']['wo']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wo']['kernel']) + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"]) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"]) + layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wo"]["kernel"]) # swap the layer index - layer_weight[f'mlp_{k}']['wi_0']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wi_0']['kernel'], axes=(1, 0, 2)) - layer_weight[f'mlp_{k}']['wi_1']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wi_1']['kernel'], axes=(1, 0, 2)) - layer_weight[f'mlp_{k}']['wo']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wo']['kernel'], axes=(1, 0, 2)) + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"], axes=(1, 0, 2)) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"], axes=(1, 0, 2)) + layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wo"]["kernel"], axes=(1, 0, 2)) - jax_weights['decoder']['layers'][f'mlp_{k}'] = layer_weight[f'mlp_{k}'] + jax_weights["decoder"]["layers"][f"mlp_{k}"] = layer_weight[f"mlp_{k}"] mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis") - s1=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec("checkpoint_sharding_axis")) #shards first axis - s2=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec(None,"checkpoint_sharding_axis")) #shards second axis - s3=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec(None)) #no sharding + s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis")) # shards first axis + s2 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis")) # shards second axis + s3 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) # no sharding def checkpoint_device_put(arr): - if arr.shape[0]%SIMULATED_CPU_DEVICES_COUNT==0: + if arr.shape[0] % SIMULATED_CPU_DEVICES_COUNT == 0: print("sharding first axis") return jax.device_put(arr, device=s1) - elif len(arr.shape)>1 and arr.shape[1]%SIMULATED_CPU_DEVICES_COUNT==0: + elif len(arr.shape) > 1 and arr.shape[1] % SIMULATED_CPU_DEVICES_COUNT == 0: print("sharding second axis") return jax.device_put(arr, device=s2) else: @@ -374,41 +332,33 @@ def checkpoint_device_put(arr): save_interval_steps = 1 checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - maxtext_model_path, - enable_checkpointing, - async_checkpointing, - save_interval_steps + maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps ) state_new = train_state.TrainState( - step=0, - apply_fn=None, - params={'params': jax_weights}, - tx=None, # type: ignore - opt_state={} + step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore ) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new): - max_logging.log( - f"saved a checkpoint at step {step_number_to_save_new_ckpt}") + max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}") # Upon preemption, exit when and only when all ongoing saves are complete. 
if checkpoint_manager.reached_preemption(0): checkpoint_manager.wait_until_finished() sys.exit() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--base-model-path', type=str, required=True) - parser.add_argument('--maxtext-model-path', type=str, required=True) - parser.add_argument('--model-size', type=str, required=True) + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--maxtext-model-path", type=str, required=True) + parser.add_argument("--model-size", type=str, required=True) args = parser.parse_args() if args.model_size not in MODEL_PARAMS_DICT: raise NotImplementedError - os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}' + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}" convert(args.base_model_path, args.maxtext_model_path, args.model_size) diff --git a/MaxText/max_logging.py b/MaxText/max_logging.py index 61d984f7b..23f7cc5cf 100644 --- a/MaxText/max_logging.py +++ b/MaxText/max_logging.py @@ -1,20 +1,21 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Stub for logging utilities. Right now just meant to avoid raw prints""" + def log(user_str): - print(user_str, flush = True) + print(user_str, flush=True) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 92b789bb4..2e218c48a 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Common Max Utils needed by multiple modules""" import checkpointing @@ -44,17 +44,19 @@ from google.cloud import storage + def find_nans_and_infs(pytree): def finder(x): return jnp.any(jnp.isinf(x) | jnp.isnan(x)) + bad_pytree = jax.tree_map(finder, pytree) return jax.tree_util.tree_flatten(bad_pytree) + def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jax.tree_util.tree_reduce( - lambda x, y: x + jax.numpy.sum(jax.numpy.square(y)), x, initializer=0.0 - ) ** 0.5 + return jax.tree_util.tree_reduce(lambda x, y: x + jax.numpy.sum(jax.numpy.square(y)), x, initializer=0.0) ** 0.5 + def calculate_num_params_from_pytree(params): params_sizes = jax.tree_util.tree_map(jax.numpy.size, params) @@ -68,70 +70,78 @@ def calculate_leaf_params_per_chip(arr): shard = arr.addressable_shards[0] return np.prod(shard.data.shape) - params_sizes_per_chip = jax.tree_util.tree_map( - calculate_leaf_params_per_chip, params) - total_parameters_per_chip = jax.tree_util.tree_reduce( - lambda x, y: x + y, params_sizes_per_chip) + params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params) + total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip) return total_parameters_per_chip def calculate_bytes_from_pytree(params): - params_bytes = jax.tree_util.tree_map(lambda x : x.nbytes, params) + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) return total_bytes + def summarize_size_from_pytree(params): num_params = calculate_num_params_from_pytree(params) num_bytes = calculate_bytes_from_pytree(params) - return num_params, num_bytes, num_bytes/num_params + return num_params, num_bytes, num_bytes / num_params + def activate_profiler(config, optional_postfix=""): if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0): output_path = os.path.join(config.tensorboard_dir, optional_postfix) jax.profiler.start_trace(output_path) + def deactivate_profiler(config): if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0): jax.profiler.stop_trace() + def initialize_summary_writer(config): return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None + def close_summary_writer(summary_writer): if jax.process_index() == 0: summary_writer.close() + def _prepare_metrics_for_json(metrics, step, run_name): """Converts metric dictionary into json supported types (e.g. 
float)""" metrics_dict = {} - for val in metrics['scalar']: - metrics_dict[val] = float(metrics['scalar'][val]) - metrics_dict['step'] = float(step) - metrics_dict['run_name'] = run_name + for val in metrics["scalar"]: + metrics_dict[val] = float(metrics["scalar"][val]) + metrics_dict["step"] = float(step) + metrics_dict["run_name"] = run_name return metrics_dict + def write_metrics_locally(metrics, step, config, file): """Writes metrics locally for testing""" if step == 0: file.truncate(0) metrics_dict = _prepare_metrics_for_json(metrics, step, config.run_name) - file.write(str(json.dumps(metrics_dict))+'\n') + file.write(str(json.dumps(metrics_dict)) + "\n") if step == config.steps - 1: file.close() + def add_config_to_summary_writer(config, summary_writer): """Writes config params to tensorboard""" if jax.process_index() == 0: for key, value in config.get_keys().items(): add_text_to_summary_writer(key, str(value), summary_writer) + def add_text_to_summary_writer(key, value, summary_writer): """Writes given key-value pair to tensorboard as text/summary""" if jax.process_index() == 0: summary_writer.add_text(key, value) + def write_metrics_for_gcs(metrics, step, config, running_metrics): """Writes metrics to gcs""" metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) @@ -139,18 +149,19 @@ def write_metrics_for_gcs(metrics, step, config, running_metrics): if (step + 1) % config.log_period == 0 or step == config.steps - 1: start_step = (step // config.log_period) * config.log_period metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" - with open(metrics_filename, 'w', encoding="utf8") as metrics_for_gcs: + with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: for metrics_step in running_metrics: - metrics_for_gcs.write(str(json.dumps(metrics_step))+'\n') + metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n") metrics_for_gcs.close() - gcs_filename=os.path.join(config.metrics_dir, metrics_filename) + gcs_filename = os.path.join(config.metrics_dir, metrics_filename) max_logging.log(f"Moving file {metrics_filename} to GCS...") upload_blob(gcs_filename, metrics_filename) max_logging.log(f"File {metrics_filename} moved successfully!") - running_metrics = [] # reset running_metrics to empty list + running_metrics = [] # reset running_metrics to empty list return running_metrics + def write_config_raw_keys_for_gcs(raw_keys): """Writes config raw keys to GCS""" if not raw_keys["save_config_to_gcs"] or jax.process_index() != 0: @@ -159,21 +170,23 @@ def write_config_raw_keys_for_gcs(raw_keys): raw_keys_dict = dict(raw_keys) filename = "config.yml" - with open(filename, 'w', encoding="utf8") as config_for_gcs: + with open(filename, "w", encoding="utf8") as config_for_gcs: yaml.dump(raw_keys_dict, config_for_gcs) config_for_gcs.close() - gcs_filename=os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], filename) + gcs_filename = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], filename) max_logging.log(f"Moving file {filename} to GCS...") upload_blob(gcs_filename, filename) max_logging.log(f"File {filename} moved successfully!") + def parse_gcs_bucket_and_prefix(destination_gcs_name): path_parts = destination_gcs_name.replace("gs://", "").split("/") bucket = path_parts.pop(0) key = "/".join(path_parts) return bucket, key + def upload_blob(destination_gcs_name, source_file_name): """Uploads a file to a GCS location""" bucket_name, prefix_name = parse_gcs_bucket_and_prefix(destination_gcs_name) @@ 
-182,16 +195,18 @@ def upload_blob(destination_gcs_name, source_file_name): blob = bucket.blob(prefix_name) blob.upload_from_filename(source_file_name) + def maybe_initialize_jax_distributed_system(raw_keys): - """ The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of - indirection in MaxText to avoid breaking the call sites unnecessarily. + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. - Currently jax.distributed.initialize() fully works as expected! + Currently jax.distributed.initialize() fully works as expected! - For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ - if (raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] - and raw_keys["compile_topology_num_slices"]==-1) or raw_keys["hardware"]=='gpu_multiprocess': + if ( + raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1 + ) or raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") jax.distributed.initialize() max_logging.log("Jax distributed system initialized!") @@ -204,6 +219,7 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_cpu() max_logging.log("Jax distributed system initialized on CPUs!") + def initialize_jax_for_gpu(): """Jax distributed initialize for GPUs.""" if os.environ.get("JAX_COORDINATOR_IP") is not None: @@ -212,14 +228,15 @@ def initialize_jax_for_gpu(): jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", num_processes=int(os.getenv("NNODES")), - process_id=int(os.getenv("NODE_RANK"))) + process_id=int(os.getenv("NODE_RANK")), + ) max_logging.log(f"JAX global devices: {jax.devices()}") + def initialize_jax_for_cpu(): - """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready. - """ + """Jax distributed initialize for CPUs. 
Includes retries until the coordinator is ready.""" coordinator_ip_address = get_coordinator_ip_address() - coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK # Env variables to be set in XPK or otherwise job_index = int(os.environ.get("JOB_INDEX")) job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) @@ -227,17 +244,20 @@ def initialize_jax_for_cpu(): pid = job_index * processes_in_job + job_completion_index max_logging.log(f" Jax process id is {pid} ") # Explicit initialize is needed only for CPUs - jax.distributed.initialize(coordinator_address=coordinator_address, - process_id=pid, - num_processes=int(os.environ.get("JAX_PROCESS_COUNT"))) + jax.distributed.initialize( + coordinator_address=coordinator_address, process_id=pid, num_processes=int(os.environ.get("JAX_PROCESS_COUNT")) + ) + def is_cpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a CPU backend.""" - return raw_keys["hardware"] == 'cpu' + return raw_keys["hardware"] == "cpu" + def is_gpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a GPU backend.""" - return raw_keys["hardware"] == 'gpu' + return raw_keys["hardware"] == "gpu" + def get_coordinator_ip_address(): """Get coordinator IP Address with retries""" @@ -260,48 +280,66 @@ def get_coordinator_ip_address(): max_logging.log(f"Coordinator IP address: {coordinator_ip_address}") return coordinator_ip_address + def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type): """Evaluates unspecified DCN/ICI parallelism values""" if -1 in parallelism_vals: - assert parallelism_vals.count(-1) == 1, f"Found unspecified values (-1) for more than one {parallelism_type}\ + assert ( + parallelism_vals.count(-1) == 1 + ), f"Found unspecified values (-1) for more than one {parallelism_type}\ parallelism axis. At most one axis can be unspecified." - determined_val = target_product/np.product(parallelism_vals)*-1 + determined_val = target_product / np.product(parallelism_vals) * -1 - assert determined_val >= 1 and determined_val.is_integer, f"Unspecified value unable to be determined with the given\ + assert ( + determined_val >= 1 and determined_val.is_integer + ), f"Unspecified value unable to be determined with the given\ {parallelism_type} parallelism values" parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) - target_type = "slices" if parallelism_type == 'DCN' else "devices per slice" + target_type = "slices" if parallelism_type == "DCN" else "devices per slice" - assert np.product(parallelism_vals) == target_product, f"Number of {target_type} {target_product} does not match\ + assert ( + np.product(parallelism_vals) == target_product + ), f"Number of {target_type} {target_product} does not match\ the product of the {parallelism_type} parallelism {np.product(parallelism_vals)}" return parallelism_vals + def create_device_mesh(config, devices=None): - """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas """ + """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas""" if devices is None: devices = jax.devices() num_devices = len(devices) num_slices = config.num_slices - num_devices_per_slice = num_devices//num_slices + num_devices_per_slice = num_devices // num_slices multi_slice_env = num_slices > 1 - dcn_parallelism = [config.dcn_data_parallelism, config.dcn_fsdp_parallelism, - config.dcn_fsdp_transpose_parallelism, config.dcn_sequence_parallelism, - config.dcn_tensor_parallelism, config.dcn_autoregressive_parallelism] - ici_parallelism = [config.ici_data_parallelism, config.ici_fsdp_parallelism, - config.ici_fsdp_transpose_parallelism, config.ici_sequence_parallelism, - config.ici_tensor_parallelism, config.ici_autoregressive_parallelism] + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_fsdp_transpose_parallelism, + config.dcn_sequence_parallelism, + config.dcn_tensor_parallelism, + config.dcn_autoregressive_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_fsdp_transpose_parallelism, + config.ici_sequence_parallelism, + config.ici_tensor_parallelism, + config.ici_autoregressive_parallelism, + ] # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, 'ICI') + ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") if multi_slice_env: - dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, 'DCN') + dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) else: mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) @@ -310,41 +348,35 @@ def create_device_mesh(config, devices=None): return mesh -def unbox_logicallypartioned( - boxed_pytree): - """ Unboxes the flax.LogicallyPartitioned pieces - Args: - boxed_pytree: a pytree that includes LogicallyPartitioned - leaves. - Returns: - a pytree where all all LogicallyPartitioned leaves have been unboxed. +def unbox_logicallypartioned(boxed_pytree): + """Unboxes the flax.LogicallyPartitioned pieces + + Args: + boxed_pytree: a pytree that includes LogicallyPartitioned + leaves. + Returns: + a pytree where all all LogicallyPartitioned leaves have been unboxed. 
""" - return jax.tree_util.tree_map(lambda x: x.unbox() if \ - isinstance(x, flax.linen.spmd.LogicallyPartitioned) \ - else x, boxed_pytree, \ - is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned)) + return jax.tree_util.tree_map( + lambda x: x.unbox() if isinstance(x, flax.linen.spmd.LogicallyPartitioned) else x, + boxed_pytree, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + def init_decode_state(apply_fn, params): """Init train state with null opt state for decode.""" - state = train_state.TrainState( - step=0, - apply_fn=apply_fn, - params=params, - tx=None, # type: ignore - opt_state={} - ) + state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state + def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create( - apply_fn=apply_fn, - params=params, - tx=tx - ) + state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state + def init_initial_state(model, tx, config, is_training, key): """ We pass in "static" objects like model, tx, config as JAX compares them by @@ -353,34 +385,37 @@ def init_initial_state(model, tx, config, is_training, key): Args: model, tx, config, is_training, key """ - input_shape = ( - config.global_batch_size_to_load, - config.max_target_length + input_shape = (config.global_batch_size_to_load, config.max_target_length) + model_vars = model.init( + {"params": key, "dropout": key, "aqt": key}, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(input_shape, dtype=jnp.int32), ) - model_vars = model.init({'params': key, 'dropout': key, 'aqt': key}, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(input_shape, dtype=jnp.int32)) if is_training: return init_training_state(model.apply, model_vars, tx) return init_decode_state(model.apply, model_vars) + def load_decode_model_vars(model, config, rng, mesh): state, _ = setup_decode_state(model, config, rng, mesh, None) return state.params + def setup_decode_state(model, config, rng, mesh, checkpoint_manager): is_training = False - state, state_mesh_annotations, _ = setup_initial_state(model, None, None, config, - rng, mesh, checkpoint_manager, - is_training) + state, state_mesh_annotations, _ = setup_initial_state( + model, None, None, config, rng, mesh, checkpoint_manager, is_training + ) return state, state_mesh_annotations + def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): is_training = True return setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training) + def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training=True): - """ We initialize the model and optimizer state, and optionally load from a + """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
Args: @@ -397,33 +432,31 @@ def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_ state_mesh_annotations: the mesh annotations for the train state """ - unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(model, tx, config, - rng, mesh, is_training) + unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( + model, tx, config, rng, mesh, is_training + ) # Initialization with nn_partitioning.axis_rules(config.logical_axis_rules): - restored, raw_params = checkpointing.load_state_if_possible(checkpoint_manager, - data_iterator, - config.load_parameters_path, - config.load_full_state_path, - unboxed_abstract_state, - config.enable_single_replica_ckpt_restoring, - config.dataset_type, - ) + restored, raw_params = checkpointing.load_state_if_possible( + checkpoint_manager, + data_iterator, + config.load_parameters_path, + config.load_full_state_path, + unboxed_abstract_state, + config.enable_single_replica_ckpt_restoring, + config.dataset_type, + ) if restored: - if 'iter' in restored and restored['iter'] is not None: - data_iterator.local_iterator = restored['iter'] - state = restored['items'] + if "iter" in restored and restored["iter"] is not None: + data_iterator.local_iterator = restored["iter"] + state = restored["items"] else: init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings - )(rng) - if raw_params: # If we loaded a partial state, we need to merge it. - state = state.replace(params = raw_params) + state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)(rng) + if raw_params: # If we loaded a partial state, we need to merge it. + state = state.replace(params=raw_params) state = unbox_logicallypartioned(state) return state, state_mesh_annotations, data_iterator @@ -432,6 +465,7 @@ def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_ # Learning Rate Schedule # ----------------------------------------------------------------------------- + def create_learning_rate_schedule(config): """Creates a warmup and cosine decay learning rate schedule: We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 @@ -441,12 +475,14 @@ def create_learning_rate_schedule(config): 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. The zero learning rate section can be used to more accurately measure the fully trained model's performance. 
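  As a purely illustrative example (values assumed, not config defaults): with
  learning_rate_schedule_steps=1000, 100 warmup steps and steps=1200, the learning
  rate rises linearly from 0 to the peak value over steps 0-100, follows a
  half-cosine from the peak down to the final cosine value over steps 100-1000,
  and is held at 0 for the remaining 200 steps.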
""" + def make_cos_schedule(init_lr, final_lr, len_steps): def schedule(step): pct = (step) / len_steps - a = 0.5 * (jnp.cos(jnp.pi*pct) + 1) + a = 0.5 * (jnp.cos(jnp.pi * pct) + 1) lr = init_lr * a + final_lr * (1 - a) return lr + return schedule lr = config.learning_rate @@ -456,19 +492,15 @@ def schedule(step): cos_steps = config.learning_rate_schedule_steps - warmup_steps constant_zero_steps = config.steps - config.learning_rate_schedule_steps - warmup_schedule = optax.linear_schedule( - init_value=0.0, - end_value=lr, - transition_steps=warmup_steps - ) + warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps) constant_schedule = optax.constant_schedule(0.0) pieces = [warmup_schedule, cos_schedule] - boundaries=[ - warmup_steps, - warmup_steps + cos_steps, - ] + boundaries = [ + warmup_steps, + warmup_steps + cos_steps, + ] if constant_zero_steps > 0: pieces.append(constant_schedule) @@ -480,8 +512,7 @@ def schedule(step): # Cross entropy implementation is taken from original T5X codebase: # https://github.com/google-research/t5x/blob/ace831eea1e2742b4299cd1a9af7e4f302038351/t5x/losses.py#L25-L101 @jax.custom_vjp -def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, - z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: +def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: """Computes cross entropy loss with stable custom gradient. Computes a stabilized-gradient version of: -jnp.sum(targets * nn.log_softmax(logits), axis=-1) @@ -511,12 +542,11 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, def _cross_entropy_with_logits_fwd( - logits: jnp.ndarray, - targets: jnp.ndarray, - z_loss: float = 0.0 -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], - Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0 +) -> Tuple[ + Tuple[jnp.ndarray, jnp.ndarray], + Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], +]: """Forward-mode of `cross_entropy_with_logits`.""" max_logit = logits.max(axis=-1, keepdims=True) shifted = logits - max_logit @@ -528,32 +558,40 @@ def _cross_entropy_with_logits_fwd( log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) total_z_loss = z_loss * jax.lax.square(log_z) loss += total_z_loss - return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp, #pytype: disable=bad-return-type #jax-ndarray - log_softmax, log_z) + return (loss, total_z_loss), ( + logits, + targets, + z_loss, + exp_shifted, + sum_exp, # pytype: disable=bad-return-type #jax-ndarray + log_softmax, + log_z, + ) def _cross_entropy_with_logits_bwd( - res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray] + res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], + g: Tuple[jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Backward-mode of `cross_entropy_with_logits`.""" g = g[0] # Ignore z_loss component as that is only used for logging. logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res # z-loss term adds the (2 * z_loss * log_z) factor. 
- deriv = ( - jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - - targets) + deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets g_logits = jnp.expand_dims(g, axis=-1) * deriv g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax - return (jnp.asarray(g_logits, - logits.dtype), jnp.asarray(g_targets, targets.dtype), - jnp.array(0.0)) # sets z-loss coeff gradient to 0 + return ( + jnp.asarray(g_logits, logits.dtype), + jnp.asarray(g_targets, targets.dtype), + jnp.array(0.0), + ) # sets z-loss coeff gradient to 0 + + +cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) -cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, - _cross_entropy_with_logits_bwd) def get_abstract_state(model, tx, config, rng, mesh, is_training=True): - """ Get a shaped abstraction of the state (including optimizer)""" + """Get a shaped abstraction of the state (including optimizer)""" init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -561,14 +599,9 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): state_logical_annotations = nn.get_partition_spec(abstract_state) - state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, - config.logical_axis_rules) + state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) - abstract_sharded_state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings - ).eval_shape(rng) + abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape(rng) unboxed_abstract_sharded_state = unbox_logicallypartioned(abstract_sharded_state) # Initialization @@ -576,24 +609,23 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) return unboxed_abstract_sharded_state, state_mesh_annotations, state_mesh_shardings + def get_kv_cache_annotations(model, config, rng, mesh): - """ Get a shaped abstraction of the state (including optimizer)""" + """Get a shaped abstraction of the state (including optimizer)""" def init_kv_cache(model, config): - input_shape = ( - config.global_batch_size_to_load, - config.max_prefill_predict_length - ) + input_shape = (config.global_batch_size_to_load, config.max_prefill_predict_length) - model_vars = model.init({'params': rng, 'dropout': rng, 'aqt': rng}, - jnp.ones(input_shape), - jnp.ones(input_shape), - model_mode=common_types.MODEL_MODE_PREFILL) - return model_vars['cache'] + model_vars = model.init( + {"params": rng, "dropout": rng, "aqt": rng}, + jnp.ones(input_shape), + jnp.ones(input_shape), + model_mode=common_types.MODEL_MODE_PREFILL, + ) + return model_vars["cache"] with nn_partitioning.axis_rules(config.logical_axis_rules): - init_kv_cache_partial = functools.partial(init_kv_cache, model, - config) + init_kv_cache_partial = functools.partial(init_kv_cache, model, config) abstract_state = jax.eval_shape(init_kv_cache_partial) state_logical_annotations = nn.get_partition_spec(abstract_state) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): @@ -604,17 +636,19 @@ def init_kv_cache(model, config): def print_pytree_shape(print_str, ptree): print("\n") print(print_str) - print(jax.tree_util.tree_map(lambda x : x.shape, ptree)) + 
print(jax.tree_util.tree_map(lambda x: x.shape, ptree)) + def print_model_vars(print_str, model_vars): for k in model_vars: - print(f'{print_str} key{k}:') - print(f'\t {model_vars[k]}') + print(f"{print_str} key{k}:") + print(f"\t {model_vars[k]}") + def get_project(): completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) - project_outputs = completed_command.stdout.decode().strip().split('\n') - if len(project_outputs) < 1 or project_outputs[-1]=='': + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '") return None return project_outputs[-1] diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index 275ffac47..ed38929e2 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' Implementation of Engine API for MaxText ''' +"""Implementation of Engine API for MaxText""" import functools from typing import Any, Optional, Tuple @@ -39,10 +39,10 @@ Params = Any - @struct.dataclass class DecodeState: """The inputs into a generation step.""" + prefill_cache: jax.Array generate_cache: jax.Array generate_cache_index: int @@ -66,7 +66,7 @@ def __init__(self, config): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - self.model = models.Transformer(config, mesh = self._mesh, quant=quant) + self.model = models.Transformer(config, mesh=self._mesh, quant=quant) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -76,47 +76,48 @@ def __init__(self, config): self.state_mesh_annotations = None def load_params(self, *args, **kwargs) -> Params: - ''' Load Parameters, typically from GCS ''' + """Load Parameters, typically from GCS""" # pylint: disable=unused-argument - state, self.state_mesh_annotations = max_utils.setup_decode_state( - self.model, self.config, self.rng, self._mesh, None + state, self.state_mesh_annotations = max_utils.setup_decode_state(self.model, self.config, self.rng, self._mesh, None) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params ) - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - state.params) self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh) - self.kv_cache_shardings = jax.tree_map(lambda x : jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations) + self.kv_cache_shardings = jax.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations) if not self.model.quant: - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - state.params) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params + ) return state.params else: - self.model.quant.quant_mode = quantizations.get_quant_mode('convert') + self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @jax.jit def model_apply(_p, _rng): return self.model.apply( - _p | {"aqt": {}}, - jnp.ones( (1, self.config.max_prefill_predict_length), dtype=jnp.int32), - jnp.ones( 
(1, self.config.max_prefill_predict_length), dtype=jnp.int32), - decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': _rng}, - mutable=True + _p | {"aqt": {}}, + jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": _rng}, + mutable=True, ) _, new_vars = model_apply(state.params, self.rng) params = {} - params['aqt'] = new_vars['aqt'] + params["aqt"] = new_vars["aqt"] # Remove param values which have corresponding qtensors in aqt to save memory. - params['params'] = quantizations.remove_quantized_params(state.params['params'], new_vars['aqt']) + params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"]) - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - params) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params + ) - self.model.quant.quant_mode = quantizations.get_quant_mode('serve') + self.model.quant.quant_mode = quantizations.get_quant_mode("serve") return params @functools.partial(jax.jit, static_argnums=(0,)) @@ -143,7 +144,7 @@ def prefill( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") - input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) zero_to_n = jnp.arange(0, padded_tokens.shape[0]) @@ -153,45 +154,52 @@ def prefill( with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): flat_logits, new_vars = self.model.apply( - params, - input_tokens, - positions, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': self.rng}, - mutable=["cache"] + params, + input_tokens, + positions, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": self.rng}, + mutable=["cache"], ) - next_pos = jnp.full((1,1), true_length, dtype = jnp.int32) - generated_tokens = jnp.zeros((1,1), dtype = jnp.int32) - selected_logits = jax.lax.dynamic_slice(flat_logits, (0, true_length-1,0), - (flat_logits.shape[0], 1, flat_logits.shape[2])) + next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32) + generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32) + selected_logits = jax.lax.dynamic_slice( + flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2]) + ) selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding) - return {"logits" : selected_logits, "cache" : new_vars['cache'], - "next_pos" : next_pos, "generated_tokens" : generated_tokens} + return { + "logits": selected_logits, + "cache": new_vars["cache"], + "next_pos": next_pos, + "generated_tokens": generated_tokens, + } @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,)) - def generate( - self, params: Params, decode_state: DecodeState - ) -> Tuple[DecodeState, engine_api.ResultTokens]: - '''Run one generate step''' - previous_logits 
= decode_state['logits'] - - new_token = inference_utils.sampling(previous_logits, self.rng, self.config.decode_sampling_strategy, - topk=self.config.decode_sampling_top_k, - nucleus_topp=self.config.decode_sampling_nucleus_p, - temperature=self.config.decode_sampling_temperature) + def generate(self, params: Params, decode_state: DecodeState) -> Tuple[DecodeState, engine_api.ResultTokens]: + """Run one generate step""" + previous_logits = decode_state["logits"] + + new_token = inference_utils.sampling( + previous_logits, + self.rng, + self.config.decode_sampling_strategy, + topk=self.config.decode_sampling_top_k, + nucleus_topp=self.config.decode_sampling_nucleus_p, + temperature=self.config.decode_sampling_temperature, + ) with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): out_logits, new_vars = self.model.apply( - params | { 'cache': decode_state['cache']}, - new_token, - decode_state['next_pos'], - enable_dropout=False, - model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'params': self.rng}, - mutable=['cache'] + params | {"cache": decode_state["cache"]}, + new_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": self.rng}, + mutable=["cache"], ) all_valid = jnp.ones(new_token.shape, dtype=jnp.int8) @@ -212,35 +220,46 @@ def generate( out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) - return {"logits" : out_logits, "cache" : new_cache, - "next_pos" : decode_state["next_pos"]+1, "generated_tokens" : decode_state["generated_tokens"]+1}, result - - @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(1, 2,)) + return { + "logits": out_logits, + "cache": new_cache, + "next_pos": decode_state["next_pos"] + 1, + "generated_tokens": decode_state["generated_tokens"] + 1, + }, result + + @functools.partial( + jax.jit, + static_argnums=(0,), + donate_argnums=( + 1, + 2, + ), + ) def insert( self, prefix: Prefix, decode_state: DecodeState, slot: int, ) -> DecodeState: - ''' Insert into KV cache ''' + """Insert into KV cache""" unboxed_prefix = max_utils.unbox_logicallypartioned(prefix) def copy(path, partial_cache, full_cache, annotations): path_key = path[-1].key - if path_key in ['cache_ar_index', 'cached_ar_key', 'cached_ar_value', 'cached_ar_key_scale', 'cached_ar_value_scale']: - return full_cache # we don't even zero these out because we can mask them out. + if path_key in ["cache_ar_index", "cached_ar_key", "cached_ar_value", "cached_ar_key_scale", "cached_ar_value_scale"]: + return full_cache # we don't even zero these out because we can mask them out. 
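      # Every other supported cache entry is written per-slot along the "cache_batch"
      # axis: segment-id buffers are zeroed (with the prefill segment ids then
      # overwritten from the prefix), and the prefill key/value (and scale) buffers
      # are copied in via jax.lax.dynamic_update_index_in_dim; unknown keys raise a
      # ValueError.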
batch_idx = annotations.index("cache_batch") if "cache_batch" in annotations else -1 if batch_idx < 0: raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}") - if path_key == 'cache_ar_segment_id': + if path_key == "cache_ar_segment_id": ### goal: zero this out in case there is existing data s = list(full_cache.shape) s[batch_idx] = 1 zeros = jnp.zeros(tuple(s), dtype=jnp.int32) return jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx) - elif path_key == 'cache_prefill_segment_id': + elif path_key == "cache_prefill_segment_id": s = list(full_cache.shape) s[batch_idx] = 1 zeros = jnp.zeros(tuple(s), dtype=jnp.int32) @@ -249,31 +268,39 @@ def copy(path, partial_cache, full_cache, annotations): ## copy prefill cachce full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) return full_cache - elif path_key in ['cached_prefill_key', 'cached_prefill_value', - 'cached_prefill_key_scale', 'cached_prefill_value_scale']: + elif path_key in [ + "cached_prefill_key", + "cached_prefill_value", + "cached_prefill_key_scale", + "cached_prefill_value_scale", + ]: return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) else: raise ValueError(f"We don't have a strategy for inserting {path_key}") - inserted_cache = jax.tree_util.tree_map_with_path(copy, unboxed_prefix['cache'], decode_state['cache'], - self.kv_cache_annotations_named) - inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state['logits'], unboxed_prefix['logits'], slot, 0) - inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state['next_pos'], unboxed_prefix['next_pos'], slot, 0) - inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(decode_state['generated_tokens'], - unboxed_prefix['generated_tokens'], slot, 0) + inserted_cache = jax.tree_util.tree_map_with_path( + copy, unboxed_prefix["cache"], decode_state["cache"], self.kv_cache_annotations_named + ) + inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0) + inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0) + inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim( + decode_state["generated_tokens"], unboxed_prefix["generated_tokens"], slot, 0 + ) inserted_logits = jax.lax.with_sharding_constraint(inserted_logits, self.replicated_sharding) inserted_generated_tokens = jax.lax.with_sharding_constraint(inserted_generated_tokens, self.replicated_sharding) inserted_next_pos = jax.lax.with_sharding_constraint(inserted_next_pos, self.replicated_sharding) inserted_cache = jax.lax.with_sharding_constraint(inserted_cache, self.kv_cache_shardings) - return {'logits' : inserted_logits, 'cache' : inserted_cache, - 'next_pos' : inserted_next_pos, 'generated_tokens' : inserted_generated_tokens } + return { + "logits": inserted_logits, + "cache": inserted_cache, + "next_pos": inserted_next_pos, + "generated_tokens": inserted_generated_tokens, + } def get_prefix_destination_sharding(self) -> Any: - return jax.sharding.NamedSharding( - mesh=self.mesh, spec=jax.sharding.PartitionSpec() - ) + return jax.sharding.NamedSharding(mesh=self.mesh, spec=jax.sharding.PartitionSpec()) def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters: """Return a protobuf of tokenizer info, callable from Py or C++.""" @@ -281,28 +308,32 @@ def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters: def 
init_decode_state(self, *args, **kwargs) -> DecodeState: """Initialises any state which a generation step transforms.""" + # pylint: disable=unused-argument def init(abstract_params): - x = jnp.ones( (int(self.config.per_device_batch_size * jax.device_count()), self.config.max_prefill_predict_length), - dtype=jnp.int32) + x = jnp.ones( + (int(self.config.per_device_batch_size * jax.device_count()), self.config.max_prefill_predict_length), + dtype=jnp.int32, + ) _, cache = self.model.apply( - abstract_params, - x, - x, - decoder_segment_ids=jnp.zeros(x.shape, dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': self.rng}, - mutable=["cache"] + abstract_params, + x, + x, + decoder_segment_ids=jnp.zeros(x.shape, dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": self.rng}, + mutable=["cache"], ) next_pos = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32) generated_tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32) - return {"logits" : jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)), - "cache" : cache["cache"], - "next_pos" : next_pos, - "generated_tokens" : generated_tokens - } + return { + "logits": jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)), + "cache": cache["cache"], + "next_pos": next_pos, + "generated_tokens": generated_tokens, + } with nn_partitioning.axis_rules(self.config.logical_axis_rules): abstract_outputs = jax.eval_shape(init, self.abstract_params) @@ -311,18 +342,20 @@ def init(abstract_params): with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): mesh_annotations = nn.logical_to_mesh(logical_annotations) - shardings = jax.tree_map(lambda mesh_annotation : jax.sharding.NamedSharding(self._mesh, mesh_annotation), - mesh_annotations) + shardings = jax.tree_map( + lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations + ) - @functools.partial(jax.jit, out_shardings = shardings) + @functools.partial(jax.jit, out_shardings=shardings) def initialize(): - return jax.tree_map( lambda x : jnp.zeros(x.shape, x.dtype), abstract_outputs) + return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs) - cache = initialize()['cache'] + cache = initialize()["cache"] def is_lp(k): return isinstance(k, flax.linen.spmd.LogicallyPartitioned) - self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x : tuple(x.names), cache, is_leaf=is_lp) + + self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x: tuple(x.names), cache, is_leaf=is_lp) del cache zeroed = max_utils.unbox_logicallypartioned(initialize()) return zeroed diff --git a/MaxText/maxengine_config.py b/MaxText/maxengine_config.py index da96c06f9..967b0eb52 100644 --- a/MaxText/maxengine_config.py +++ b/MaxText/maxengine_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-'''Configure MaxText For JetStream''' +"""Configure MaxText For JetStream""" import functools import jax @@ -22,26 +22,23 @@ from jetstream.engine import engine_api import maxengine + def create_maxengine(devices: config_lib.Devices, config: Any) -> engine_api.Engine: del devices return maxengine.MaxEngine(config) def get_server_config(config_str: str, config: Any) -> Type[config_lib.ServerConfig]: - ''' Gets the Server Config Required by JetStream ''' + """Gets the Server Config Required by JetStream""" match config_str: - case 'MaxtextInterleavedServer': + case "MaxtextInterleavedServer": server_config = config_lib.ServerConfig( - prefill_slices = (), - generate_slices = (), - interleaved_slices = ('tpu='+str(jax.device_count()),), - prefill_engine_create_fns = (), - generate_engine_create_fns = (), - interleaved_engine_create_fns = (functools.partial( - create_maxengine, - config=config - ), - ) + prefill_slices=(), + generate_slices=(), + interleaved_slices=("tpu=" + str(jax.device_count()),), + prefill_engine_create_fns=(), + generate_engine_create_fns=(), + interleaved_engine_create_fns=(functools.partial(create_maxengine, config=config),), ) case _: raise NotImplementedError diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py index 3cf46c8a5..39fcde33d 100644 --- a/MaxText/maxengine_server.py +++ b/MaxText/maxengine_server.py @@ -19,7 +19,7 @@ import sys import pyconfig -import maxengine_config +import maxengine_config from jetstream.core import server_lib # _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -36,7 +36,7 @@ def main(config): # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() - server_config = maxengine_config.get_server_config('MaxtextInterleavedServer', config) + server_config = maxengine_config.get_server_config("MaxtextInterleavedServer", config) # We separate credential from run so that we can unit test it with # local credentials. # TODO: Add grpc credentials for OSS. @@ -49,8 +49,8 @@ def main(config): jetstream_server.wait_for_termination() -if __name__ == '__main__': - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') +if __name__ == "__main__": + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(sys.argv) cfg = pyconfig.config diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 6b7a81b03..b41085a18 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting to MaxText. """ @@ -28,45 +28,45 @@ from input_pipeline import input_pipeline_interface - def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config): - """ Get the shardings (both state and data) for train_step """ + """Get the shardings (both state and data) for train_step""" functional_train = get_functional_train_step(train_step, model, config) functional_train.__name__ = "train_step" data_pspec = P(*config.data_sharding) - state_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) - data_sharding = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng - out_shardings = (state_mesh_shardings, None) # State, metrics - static_argnums = () # We partial out the static argnums of model and config - donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. + state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shardings = (state_mesh_shardings, None) # State, metrics + static_argnums = () # We partial out the static argnums of model and config + donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums + def get_functional_train_step(train_step, model, config): return functools.partial(train_step, model, config) + def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config): - """ Get the shardings (both state and data) for eval_step """ + """Get the shardings (both state and data) for eval_step""" functional_eval = get_functional_eval_step(eval_step, model, config) functional_eval.__name__ = "eval_step" data_pspec = P(*config.data_sharding) - state_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) - data_sharding = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng - out_shardings = None # metrics - static_argnums = () # We partial out the static argnums of model, config - donate_argnums = () # state will be kept instead of being donated in eval_step + state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shardings = None # metrics + static_argnums = () # We partial out the static argnums of model, config + donate_argnums = () # state will be kept instead of being donated in eval_step return functional_eval, in_shardings, out_shardings, static_argnums, donate_argnums + def get_functional_eval_step(eval_step, model, config): return functools.partial(eval_step, model, config) + def load_compiled(config, partial_train, state): - """ # Loading a serialized compiled train step function.""" + """# Loading a serialized compiled train step function.""" + # Currently partial_train and state are needed to reconstruct # input/output shapes to construct the in_trees and out_trees for load API # Parker is working on a serializing these @@ -90,40 +90,58 @@ def get_train_input_output_trees(func, input_args, input_kwargs): p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree) return p_train_step + # https://arxiv.org/pdf/2204.02311.pdf Appendix B def calculate_tflops_training_per_device(num_model_parameters, config, log=True): - """ Calculate training TFLOP""" - learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size \ - / 10**12 - noncasual_attention_flops = 12 * config.num_query_heads * config.num_decoder_layers * config.head_dim \ - * config.max_target_length**2 * config.per_device_batch_size / 10**12 - causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention + """Calculate training TFLOP""" + learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size / 10**12 + noncasual_attention_flops = ( + 12 + * config.num_query_heads + * config.num_decoder_layers + * config.head_dim + * config.max_target_length**2 + * config.per_device_batch_size + / 10**12 + ) + causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention total_tflops = learnable_weight_tflops + causal_attention_tflops if log: - print('Per train step:\n', - f'Total TFLOPs: {total_tflops:.2f} \n', - f'split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops', - f'and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops') + 
print( + "Per train step:\n", + f"Total TFLOPs: {total_tflops:.2f} \n", + f"split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops", + f"and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops", + ) return total_tflops, learnable_weight_tflops, causal_attention_tflops + # https://arxiv.org/pdf/2204.02311.pdf Appendix B def calculate_tflops_prefill(num_model_parameters, prefill_length, config, log=True): - """ Calculate training TFLOP""" - learnable_weight_tflops = 2 * num_model_parameters * prefill_length \ - / 10**12 - noncasual_attention_flops = 4 * config.num_query_heads * config.num_decoder_layers * config.head_dim \ - * prefill_length**2 * config.per_device_batch_size / 10**12 - causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention + """Calculate training TFLOP""" + learnable_weight_tflops = 2 * num_model_parameters * prefill_length / 10**12 + noncasual_attention_flops = ( + 4 + * config.num_query_heads + * config.num_decoder_layers + * config.head_dim + * prefill_length**2 + * config.per_device_batch_size + / 10**12 + ) + causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention total_tflops = learnable_weight_tflops + causal_attention_tflops if log: - print('Per prefill step: \n', - f'\tTotal TFLOPs: {total_tflops:.2f} \n', - f'\t\tLearnable weight TFLOPs: {learnable_weight_tflops} ', - f'({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n', - f'\t\tCausal attention TFLOPs: {causal_attention_tflops} ', - f'({100 * causal_attention_tflops/total_tflops:.2f})% of Total') + print( + "Per prefill step: \n", + f"\tTotal TFLOPs: {total_tflops:.2f} \n", + f"\t\tLearnable weight TFLOPs: {learnable_weight_tflops} ", + f"({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n", + f"\t\tCausal attention TFLOPs: {causal_attention_tflops} ", + f"({100 * causal_attention_tflops/total_tflops:.2f})% of Total", + ) return total_tflops, learnable_weight_tflops, causal_attention_tflops @@ -144,22 +162,15 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01): bool: True if the majority of parameters are sufficiently sharded """ total_num_params = max_utils.calculate_num_params_from_pytree(params) - product_num_devices_for_weight_sharding = 1 - for axis in ['fsdp', 'fsdp_transpose', 'sequence', 'tensor']: + product_num_devices_for_weight_sharding = 1 + for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] - total_num_params_per_chip = ( - max_utils.calculate_total_params_per_chip( - params) + total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) + perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding + assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( + "Number of parameters per chip must not be less than in the ideal sharded " + "scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes." ) - perfectly_sharded_params_per_chip = ( - total_num_params / product_num_devices_for_weight_sharding + assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance, ( + f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters." ) - assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( - 'Number of parameters per chip must not be less than in the ideal sharded ' - 'scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes.' 
- ) - assert ( - total_num_params_per_chip/perfectly_sharded_params_per_chip - 1 < tolerance - ), (f'Number of unsharded parameters exceeds tolerance {tolerance * 100}% ' - 'of total parameters.') - diff --git a/MaxText/multihost_dataloading.py b/MaxText/multihost_dataloading.py index 8c9088961..fca337691 100644 --- a/MaxText/multihost_dataloading.py +++ b/MaxText/multihost_dataloading.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=unused-import """SPMD Multihost Dataloading Utilities. @@ -36,6 +36,7 @@ import max_logging + def _build_global_shape_and_sharding( local_shape: tuple[int, ...], global_mesh: Mesh ) -> tuple[tuple[int, ...], NamedSharding]: @@ -47,27 +48,23 @@ def _build_global_shape_and_sharding( def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: - """ Put local sharded array into local devices - """ + """Put local sharded array into local devices""" global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) try: local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) except ValueError as array_split_error: raise ValueError( - f"Unable to put to devices shape {array.shape} with " - f"local device count {len(global_mesh.local_devices)} " - f"at {jtu.keystr(path)}" + f"Unable to put to devices shape {array.shape} with " + f"local device count {len(global_mesh.local_devices)} " + f"at {jtu.keystr(path)}" ) from array_split_error local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices) return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) - -def get_next_batch_sharded( - local_iterator: Iterator, global_mesh: Mesh -) -> jax.Array: +def get_next_batch_sharded(local_iterator: Iterator, global_mesh: Mesh) -> jax.Array: """Splits the host loaded data equally over all devices.""" SLEEP_TIME = 10 @@ -88,13 +85,14 @@ def get_next_batch_sharded( if not loaded_data_success: local_data = next(local_iterator) - input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh = global_mesh), local_data) + input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data) return input_gdas class MultiHostDataLoadIterator: """fold get_next_batch_sharded into a iterator class""" + def __init__(self, dataloader: Union[tf.data.Dataset, grain.DataLoader], global_mesh: Mesh): self.global_mesh = 
global_mesh self.dataloader = dataloader diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py index e2a23abda..63fcc42b1 100644 --- a/MaxText/optimizers.py +++ b/MaxText/optimizers.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator, ungrouped-imports """Utils that are only interesting to MaxText. """ @@ -29,25 +29,26 @@ def get_optimizer(config, learning_rate_schedule): if config.opt_type == "adamw": # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 return optax.adamw( - learning_rate_schedule, - b1=config.adam_b1, - b2=config.adam_b2, - eps=config.adam_eps, - eps_root=config.adam_eps_root, - weight_decay=config.adam_weight_decay, + learning_rate_schedule, + b1=config.adam_b1, + b2=config.adam_b2, + eps=config.adam_eps, + eps_root=config.adam_eps_root, + weight_decay=config.adam_weight_decay, ) elif config.opt_type == "adam_pax": return adam_pax( - learning_rate_schedule, - beta1=config.adam_b1, - beta2=config.adam_b2, - epsilon=config.adam_eps, - epsilon_root=config.adam_eps_root, - weight_decay=config.adam_weight_decay, + learning_rate_schedule, + beta1=config.adam_b1, + beta2=config.adam_b2, + epsilon=config.adam_eps, + epsilon_root=config.adam_eps_root, + weight_decay=config.adam_weight_decay, ) else: raise ValueError(f"{config.opt_type=} is not a supported.") + def adam_pax( learning_rate_fn: optax.Schedule, beta1: float, @@ -55,7 +56,7 @@ def adam_pax( epsilon: float, epsilon_root: float, weight_decay: float, - ) -> optax.GradientTransformation: +) -> optax.GradientTransformation: """Standard Adam optimizer that supports weight decay. Follows the implemenation in pax/praxis sharded_adam @@ -77,8 +78,7 @@ def adam_pax( """ def init_fn(params): - mu = jax.tree_util.tree_map( # First moment - jnp.zeros_like, params) + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment return optax.ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) @@ -102,8 +102,8 @@ def bias_corrected_decay(step: jnp.int32, decay: float): Returns: Bias corrected decay. """ - t = step.astype(jnp.float32) + 1. - return decay * (1. - jnp.power(decay, t - 1.)) / (1. 
- jnp.power(decay, t)) + t = step.astype(jnp.float32) + 1.0 + return decay * (1.0 - jnp.power(decay, t - 1.0)) / (1.0 - jnp.power(decay, t)) def update_fn(updates, state, params=None): # Sanitize updates just in case. @@ -112,6 +112,7 @@ def update_fn(updates, state, params=None): count = state.count class _slot_opt_state: + def __init__(self, mu, nu): self.mu = mu self.nu = nu @@ -133,8 +134,7 @@ def _update_momentum(update, mu, nu): mu = jax.tree_map(lambda x: x.mu, updated_moments) nu = jax.tree_map(lambda x: x.nu, updated_moments) - updates = jax.tree_map( - lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu) + updates = jax.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu) if weight_decay > 0: updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index e19af5456..204b32463 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring, bare-except, consider-using-generator from collections import OrderedDict @@ -35,9 +35,11 @@ # YAML attribute to specify inheritance. _BASE_CONFIG_ATTR = "base_config" + def yaml_key_to_env_key(s: str) -> str: return _MAX_PREFIX + s.upper() + def string_to_bool(s: str) -> bool: if s.lower() == "true": return True @@ -45,60 +47,80 @@ def string_to_bool(s: str) -> bool: return False raise ValueError(f"Can't convert {s} to bool") -_yaml_types_to_parser = {str : str, int : int, float : float, bool : string_to_bool} + +_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool} + def validate_attention_type(s: str) -> None: - valid_attention_types = ('autoselected', 'dot_product', 'flash', 'cudnn_flash_te') - if s not in valid_attention_types: # currently supported attention - raise ValueError( - "Invalid attention type was passed. Valid options ", valid_attention_types - ) + valid_attention_types = ("autoselected", "dot_product", "flash", "cudnn_flash_te") + if s not in valid_attention_types: # currently supported attention + raise ValueError("Invalid attention type was passed. 
Valid options ", valid_attention_types) + def validate_keys(keys): - validate_attention_type(keys['attention']) + validate_attention_type(keys["attention"]) + + assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[ + "enable_checkpointing" + ], "You must set enable_checkpointing to load a checkpoint" + assert ( + keys["load_parameters_path"] == "" or keys["load_full_state_path"] == "" + ), "At most one of `load_parameters_path` or `load_full_state_path` should be set" - assert ((keys["load_parameters_path"]=="" and keys["load_full_state_path"]=="") or - keys["enable_checkpointing"]), "You must set enable_checkpointing to load a checkpoint" - assert keys["load_parameters_path"]=="" or keys["load_full_state_path"]=="",\ - "At most one of `load_parameters_path` or `load_full_state_path` should be set" def validate_model_name(s: str) -> bool: + """Validate provided model name.""" # currently supported models - valid_model_names = ('default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'mistral-7b', - 'mixtral-8x7b', 'gemma-7b','gemma-2b', - 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k') + valid_model_names = ( + "default", + "llama2-7b", + "llama2-13b", + "llama2-70b", + "mistral-7b", + "mixtral-8x7b", + "gemma-7b", + "gemma-2b", + "gpt3-175b", + "gpt3-22b", + "gpt3-6b", + "gpt3-52k", + ) if s not in valid_model_names: - raise ValueError( - "Invalid model name was passed. Valid options ", valid_model_names - ) + raise ValueError("Invalid model name was passed. Valid options ", valid_model_names) + def validate_no_keys_overwritten_twice(keys1: list[str], keys2: list[str]): overwritten_keys = [k for k in keys1 if k in keys2] if overwritten_keys: raise ValueError( f"Keys {overwritten_keys} are overwritten from both the model" - " and the environment/command line. This isn't allowed.") + " and the environment/command line. This isn't allowed." 
+ ) + _config = None config = None + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}") -def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any],list[Any]]: + +def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]: return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l -class _HyperParameters(): + +class _HyperParameters: # pylint: disable=missing-class-docstring def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]): for environment_var in os.environ: - if environment_var[:len(_MAX_PREFIX)] == _MAX_PREFIX: - proposed_key = environment_var[len(_MAX_PREFIX):].lower() + if environment_var[: len(_MAX_PREFIX)] == _MAX_PREFIX: + proposed_key = environment_var[len(_MAX_PREFIX) :].lower() if proposed_key not in raw_data_from_yaml: raise ValueError(f"We received env `{environment_var}` but it doesn't match a key, so it is assumed a mistake.") - if not environment_var[len(_MAX_PREFIX):].isupper(): + if not environment_var[len(_MAX_PREFIX) :].isupper(): raise ValueError(f"We received env `{environment_var}` but it isn't all uppercase.") def _load_kwargs(self, argv: list[str], **kwargs): @@ -107,21 +129,17 @@ def _load_kwargs(self, argv: list[str], **kwargs): return args_dict def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]: - ''' Update model config from environment and command line - ''' + """Update model config from environment and command line""" raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs) updated_keys = [] for k in raw_data_from_cmd_line: if k not in raw_data_from_yaml: - raise ValueError( - f"Key {k} was passed at the command line but isn't in config." - ) + raise ValueError(f"Key {k} was passed at the command line but isn't in config.") for k in raw_data_from_yaml: if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ: - raise ValueError( - f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.") + raise ValueError(f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.") if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ: raw_keys[k] = raw_data_from_yaml[k] @@ -133,8 +151,9 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, else: new_proposal = os.environ.get(yaml_key_to_env_key(k)) - if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and \ - (type(raw_data_from_yaml[k]) not in _yaml_types_to_parser): + if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and ( + type(raw_data_from_yaml[k]) not in _yaml_types_to_parser + ): raise ValueError( f"For key '{k}', type {type(raw_data_from_yaml[k])} not in {_yaml_types_to_parser.keys()}, can't pass" " at the CLI or ENV" @@ -148,8 +167,7 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, new_proposal ) # take the command line value, but type it like the config value. 
except ValueError as e: - raise ValueError( - f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'") from e + raise ValueError(f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'") from e return updated_keys @@ -163,14 +181,10 @@ def _load_config(self, config_name: str) -> dict[str, Any]: if _BASE_CONFIG_ATTR in raw_data_from_yaml: parent_config_filename = raw_data_from_yaml[_BASE_CONFIG_ATTR] if not os.path.isabs(parent_config_filename): - loaded_parent_config_filename = os.path.join( - os.path.dirname(config_name), parent_config_filename - ) + loaded_parent_config_filename = os.path.join(os.path.dirname(config_name), parent_config_filename) if not os.path.isfile(loaded_parent_config_filename): dir_path = os.path.dirname(os.path.realpath(__file__)) - loaded_parent_config_filename = os.path.join( - dir_path, f"configs/{parent_config_filename}" - ) + loaded_parent_config_filename = os.path.join(dir_path, f"configs/{parent_config_filename}") else: loaded_parent_config_filename = parent_config_filename @@ -181,15 +195,14 @@ def _load_config(self, config_name: str) -> dict[str, Any]: return raw_data_from_yaml def __init__(self, argv: list[str], **kwargs): - config_name : str = argv[1] + config_name: str = argv[1] raw_data_from_yaml = self._load_config(config_name) self._validate_env_variables(raw_data_from_yaml) raw_keys = OrderedDict() keys_from_env_and_command_line = self._update_from_env_and_command_line(raw_keys, raw_data_from_yaml, argv, **kwargs) - max_logging.log( - f"Updating keys from env and command line: {keys_from_env_and_command_line}") + max_logging.log(f"Updating keys from env and command line: {keys_from_env_and_command_line}") keys_from_model = _HyperParameters.update_model_vars(argv[1], raw_keys, config_name) max_logging.log(f"Updating keys from model: {keys_from_model}") validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model) @@ -197,24 +210,24 @@ def __init__(self, argv: list[str], **kwargs): # We initialize the jax distributed system here because it must be done before device backend is initialized. 
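For readers skimming the `_load_config` hunk above: a config file may name a parent via `base_config`; a relative parent path is resolved next to the child file first, then under the package's `configs/` directory, and the parent's values act as defaults that the child overrides. A simplified sketch of that resolution, assuming a plain recursive merge (file layout and merge order here are assumptions, not the patch's code):

import os
import yaml

def load_with_base(path):
  """Recursively load a YAML config, layering it over its `base_config` parent."""
  with open(path, "r", encoding="utf-8") as f:
    child = yaml.safe_load(f)
  parent_name = child.get("base_config")
  if not parent_name:
    return child
  parent_path = parent_name
  if not os.path.isabs(parent_name):
    parent_path = os.path.join(os.path.dirname(path), parent_name)
    if not os.path.isfile(parent_path):
      parent_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", parent_name)
  merged = load_with_base(parent_path)
  merged.update(child)  # the child's keys win over the parent's defaults
  return merged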
max_utils.maybe_initialize_jax_distributed_system(raw_keys) - if raw_keys['jax_cache_dir']: - compilation_cache.set_cache_dir(os.path.expanduser(raw_keys['jax_cache_dir'])) + if raw_keys["jax_cache_dir"]: + compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"])) - if raw_keys['model_name'] == "gpt3-175b": + if raw_keys["model_name"] == "gpt3-175b": _HyperParameters.configure_gpt3_task(raw_keys) _HyperParameters.user_init(raw_keys) self.keys = raw_keys - keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension + keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension keys.sort() for k in keys: max_logging.log(f"Config param {k}: {raw_keys[k]}") @staticmethod def user_init(raw_keys): - '''Transformations between the config data and configs used at runtime''' + """Transformations between the config data and configs used at runtime""" if raw_keys["run_name"] == "": - raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default + raw_keys["run_name"] = os.environ.get("JOBSET_NAME") # using XPK default run_name = raw_keys["run_name"] base_output_directory = raw_keys["base_output_directory"] if run_name: @@ -222,22 +235,21 @@ def user_init(raw_keys): raw_keys["checkpoint_dir"] = os.path.join(base_output_directory, run_name, "checkpoints", "") raw_keys["metrics_dir"] = os.path.join(base_output_directory, run_name, "metrics", "") - if raw_keys["learning_rate_schedule_steps"]==-1: + if raw_keys["learning_rate_schedule_steps"] == -1: raw_keys["learning_rate_schedule_steps"] = raw_keys["steps"] - if raw_keys["steps"]==-1: + if raw_keys["steps"] == -1: raw_keys["steps"] = raw_keys["learning_rate_schedule_steps"] - emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys['global_parameter_scale']) - raw_keys['emb_dim'] = 2**emb_scale * raw_keys['base_emb_dim'] - raw_keys['num_query_heads'] = 2**num_head_scale * raw_keys['base_num_query_heads'] - raw_keys['num_kv_heads'] = 2**num_head_scale * raw_keys['base_num_kv_heads'] - raw_keys['mlp_dim'] = 2**mlp_dim_scale * raw_keys['base_mlp_dim'] - raw_keys['num_decoder_layers'] = 2**layer_scale * raw_keys['base_num_decoder_layers'] + emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys["global_parameter_scale"]) + raw_keys["emb_dim"] = 2**emb_scale * raw_keys["base_emb_dim"] + raw_keys["num_query_heads"] = 2**num_head_scale * raw_keys["base_num_query_heads"] + raw_keys["num_kv_heads"] = 2**num_head_scale * raw_keys["base_num_kv_heads"] + raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] + raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"] - raw_keys['global_batch_size_to_load'], raw_keys['global_batch_size_to_train_on'] = \ - calculate_global_batch_sizes(raw_keys) - raw_keys['num_slices'] = get_num_slices(raw_keys) - raw_keys['quantization_local_shard_count'] = get_quantization_local_shard_count(raw_keys) + raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys) + raw_keys["num_slices"] = get_num_slices(raw_keys) + raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) print_system_information() @@ -253,73 +265,72 @@ def user_init(raw_keys): @staticmethod def configure_gpt3_task(raw_keys): - '''dynamically configure gpt3 task based on training rules''' + """dynamically configure gpt3 task based on training rules""" # follow 
https://github.com/google/paxml/blob/19db52eed85ae0d2365339b83a97cd0b873bbf73/paxml/tasks/lm/params/c4.py#L280 # according to training_rules of mlperf gpt3 training global_batch_size = calculate_global_batch_sizes(raw_keys)[1] if global_batch_size <= 3584: - raw_keys['learning_rate'] = 2e-5 + raw_keys["learning_rate"] = 2e-5 else: - raw_keys['learning_rate'] = 3e-5 + raw_keys["learning_rate"] = 3e-5 warmup_steps = math.ceil(265.0 * 1536 / global_batch_size - 1e-6) decay_end_step = math.ceil(108600.0 * 1536 / global_batch_size - 1e-6) - raw_keys['learning_rate_schedule_steps'] = decay_end_step - raw_keys['warmup_steps_fraction'] = warmup_steps / decay_end_step + raw_keys["learning_rate_schedule_steps"] = decay_end_step + raw_keys["warmup_steps_fraction"] = warmup_steps / decay_end_step global_batch_size_to_train_on = calculate_global_batch_sizes(raw_keys)[1] - raw_keys['eval_interval'] = math.ceil(24567 / global_batch_size_to_train_on) + raw_keys["eval_interval"] = math.ceil(24567 / global_batch_size_to_train_on) @staticmethod - def update_model_vars(base_config_path, raw_keys, config_name : str): - ''' Update model config variables - ''' - validate_model_name(raw_keys['model_name']) + def update_model_vars(base_config_path, raw_keys, config_name: str): + """Update model config variables""" + validate_model_name(raw_keys["model_name"]) max_logging.log(f"Running Model: {raw_keys['model_name']}") updated_keys = [] - if raw_keys['model_name'] != 'default': - model_name = raw_keys['model_name'] + if raw_keys["model_name"] != "default": + model_name = raw_keys["model_name"] # First look at the model configs next to the base_config_path, and # fallback to the python codebase if the config cannot be found. - file_path = os.path.join( - os.path.dirname(base_config_path), f"models/{model_name}.yml" - ) + file_path = os.path.join(os.path.dirname(base_config_path), f"models/{model_name}.yml") if not os.path.isfile(file_path): dir_path = os.path.dirname(os.path.realpath(__file__)) file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml") - with open(file_path, 'r', encoding="utf-8") as file: + with open(file_path, "r", encoding="utf-8") as file: model_vars = yaml.safe_load(file) updated_keys = list(model_vars.keys()) raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name) return updated_keys -def validate_and_update_keys(raw_keys, model_keys, config_name : str): - ''' Validate and update model specific config keys - ''' + +def validate_and_update_keys(raw_keys, model_keys, config_name: str): + """Validate and update model specific config keys""" max_logging.log("Updating following parameters in config\n") for k in model_keys: max_logging.log(f"{k}: {model_keys[k]}") if k not in raw_keys: - raise ValueError(f'Key {k} does not exist in config {config_name}.') + raise ValueError(f"Key {k} does not exist in config {config_name}.") elif not isinstance(raw_keys[k], type(model_keys[k])): - raise ValueError(f'Type of key:{k} does not match with {type(model_keys[k])}') + raise ValueError(f"Type of key:{k} does not match with {type(model_keys[k])}") else: raw_keys[k] = model_keys[k] return raw_keys + def get_individual_scales(scale): - '''Choose appropriate scales for individual dimensions based on global scale + """Choose appropriate scales for individual dimensions based on global scale We choose to rotate between doubling: num_head and mlp_dim embed_dim num_layers Any one of these steps is not a perfect doubling, although going through a cycle - of three is a near perfect 8x 
scaling except for the linear -> softmax -> output step''' - + of three is a near perfect 8x scaling except for the linear -> softmax -> output step""" log_2_scale = math.floor((math.log2(scale))) if 2**log_2_scale != scale: - raise ValueError("Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " - "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim.") + raise ValueError( + "Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " + "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim." + ) base_scale, rem = divmod(log_2_scale, 3) num_head_scale = base_scale + int(rem > 0) mlp_dim_scale = num_head_scale @@ -327,10 +338,11 @@ def get_individual_scales(scale): layer_scale = base_scale return emb_scale, num_head_scale, mlp_dim_scale, layer_scale + def calculate_global_batch_sizes(raw_keys): - """ Calculates target global batch size from target devices and per_device_batch""" - per_device_batch_size = raw_keys['per_device_batch_size'] - expansion_factor_real_data = raw_keys['expansion_factor_real_data'] + """Calculates target global batch size from target devices and per_device_batch""" + per_device_batch_size = raw_keys["per_device_batch_size"] + expansion_factor_real_data = raw_keys["expansion_factor_real_data"] num_devices = get_num_target_devices(raw_keys) if per_device_batch_size < 1.0: # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 @@ -347,17 +359,19 @@ def calculate_global_batch_sizes(raw_keys): global_batch_size_to_train_on = int(num_devices * per_device_batch_size) return global_batch_size_to_load, global_batch_size_to_train_on + def get_num_target_devices(raw_keys): - compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys.get('compile_topology', "")) + compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys.get("compile_topology", "")) if compile_topology is not None: devices_per_slice = compile_topology.devices_per_slice - return int(devices_per_slice * raw_keys['compile_topology_num_slices']) + return int(devices_per_slice * raw_keys["compile_topology_num_slices"]) else: return len(jax.devices()) + def get_num_slices(raw_keys): - if int(raw_keys['compile_topology_num_slices']) > 0: - return raw_keys['compile_topology_num_slices'] + if int(raw_keys["compile_topology_num_slices"]) > 0: + return raw_keys["compile_topology_num_slices"] else: devices = jax.devices() try: @@ -365,13 +379,16 @@ def get_num_slices(raw_keys): except: return 1 + def get_quantization_local_shard_count(raw_keys): - if raw_keys['quantization_local_shard_count'] == -1: - return raw_keys['num_slices'] + if raw_keys["quantization_local_shard_count"] == -1: + return raw_keys["num_slices"] else: - return raw_keys['quantization_local_shard_count'] + return raw_keys["quantization_local_shard_count"] + + +class HyperParameters: # pylint: disable=missing-class-docstring -class HyperParameters(): # pylint: disable=missing-class-docstring def __init__(self): pass @@ -386,11 +403,13 @@ def __setattr__(self, attr, value): def get_keys(self): return _config.keys + def initialize(argv, **kwargs): global _config, config _config = _HyperParameters(argv, **kwargs) config = HyperParameters() + if __name__ == "__main__": initialize(sys.argv) print(config.steps) diff --git a/MaxText/sequence_packing.py b/MaxText/sequence_packing.py index 
36bccde0c..d8ba28082 100644 --- a/MaxText/sequence_packing.py +++ b/MaxText/sequence_packing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Packed Sequence Op.""" @@ -23,9 +23,9 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, key2length: Union[int, Dict[str, int]], keys: Optional[List[str]] = None +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate @@ -66,29 +66,28 @@ def pack_dataset(dataset: tf.data.Dataset, keys = list(shapes.keys()) for k in keys: if k not in shapes: - raise ValueError(f"""Key {k} not found in dataset. Available keys are - {shapes.keys()}""") + raise ValueError( + f"""Key {k} not found in dataset. Available keys are + {shapes.keys()}""" + ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): - raise ValueError('Tensors to be packed must be one-dimensional.') + raise ValueError("Tensors to be packed must be one-dimensional.") # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" if isinstance(key2length, int): key2length = {k: key2length for k in keys} for k in keys: - for suffix in ['_segmentation', '_position']: + for suffix in ["_segmentation", "_position"]: key2length[k + suffix] = key2length[k] # trim to length - dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + dataset = dataset.map(lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) # We pad with a negative value instead of the default 0 because 0 is a # valid token for some tokenizers for e.g., representing unknown value - dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1) + dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. 
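To make the packing behaviour documented in `pack_dataset` above concrete, here is a small worked example with made-up token ids (illustrative only, not output from the patch): two short sequences are packed into a single example of the target length, `*_segmentation` numbers the source sequences (0 marks padding), and `*_position` restarts from 0 at each sequence boundary.

# Packing "inputs" sequences [5, 6] and [7, 8, 9] into one example of length 8.
packed_example = {
    "inputs":              [5, 6, 7, 8, 9, 0, 0, 0],
    "inputs_segmentation": [1, 1, 2, 2, 2, 0, 0, 0],  # which source sequence each token belongs to
    "inputs_position":     [0, 1, 0, 1, 2, 0, 0, 0],  # position within its source sequence
}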
@@ -98,8 +97,7 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int]) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: @@ -112,16 +110,14 @@ def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], empty_example = {} for k in keys: empty_example[k] = tf.zeros([0], dtype=tf.int32) - empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) + empty_example[k + "_position"] = tf.zeros([0], dtype=tf.int32) keys_etc = empty_example.keys() def write_packed_example(partial, outputs): new_partial = empty_example.copy() new_outputs = {} for k in keys_etc: - new_outputs[k] = outputs[k].write( - outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + new_outputs[k] = outputs[k].write(outputs[k].size(), tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) return new_partial, new_outputs def map_fn(x): @@ -138,10 +134,8 @@ def map_fn(x): dynamic_batch_size = tf.shape(x[keys[0]])[0] outputs = {} for k in keys: - outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) - outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + outputs[k] = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + outputs[k + "_position"] = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) def body_fn(i, partial, outputs): """Body function for while_loop. 
@@ -157,13 +151,10 @@ def body_fn(i, partial, outputs): for k in keys: val = tf.cast(x[k][i], tf.int32) # We consider only the valid tokens i.e., token_id != -1 - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, -1), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, -1), tf.int32))] one_example[k] = val for k in keys: - can_append = tf.logical_and( - can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + can_append = tf.logical_and(can_append, tf.less_equal(tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) def false_fn(): return write_packed_example(partial, outputs) @@ -174,12 +165,10 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) - new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + new_partial[k + "_position"] = tf.concat([partial[k + "_position"], tf.range(new_seq_len)], 0) partial = new_partial return i + 1, partial, outputs @@ -193,14 +182,14 @@ def true_fn(): {k: tf.TensorShape([None]) for k in keys_etc}, {k: tf.TensorShape(None) for k in keys_etc}, ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + "_segmentation"] = tf.cumsum(tf.cast(tf.equal(packed[k + "_position"], 0), tf.int32), axis=1) * tf.cast( + tf.not_equal(packed[k], 0), tf.int32 + ) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py index 7e2bb2c1c..fcfb9631d 100644 --- a/MaxText/standalone_checkpointer.py +++ b/MaxText/standalone_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Standalone checkpointer - only saves and restores checkpoints at regular intervals, accesses storage needs.""" @@ -41,6 +41,7 @@ Transformer = models.Transformer + def checkpoint_loop(config, state=None): """Main Checkpointing loop. Saves checkpoints. @@ -50,33 +51,29 @@ def checkpoint_loop(config, state=None): ckpt_path: Returns: """ - init_rng, _ , checkpoint_manager, mesh, model, _, tx = setup_mesh_and_model(config) + init_rng, _, checkpoint_manager, mesh, model, _, tx = setup_mesh_and_model(config) - unboxed_abstract_state, _, _ = max_utils.get_abstract_state(model, tx, - config, init_rng, mesh, is_training=True) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() with nn_partitioning.axis_rules(config.logical_axis_rules): - state, _ = checkpointing.load_state_if_possible(checkpoint_manager, - None, - config.load_parameters_path, - config.load_full_state_path, - unboxed_abstract_state) + state, _ = checkpointing.load_state_if_possible( + checkpoint_manager, None, config.load_parameters_path, config.load_full_state_path, unboxed_abstract_state + ) if state: - state = state['items'] + state = state["items"] jax.block_until_ready(state) checkpoint_load_end = datetime.datetime.now() - if state is not None: # Checkpoint was available for restore + if state is not None: # Checkpoint was available for restore if jax.process_index() == 0: max_logging.log(f"STANDALONE CHECKPOINTER : Checkpoint restored in : {checkpoint_load_end - checkpoint_load_start}") - else: # Checkpoint was unavailable, state needs to be initialized - state, _, _ = max_utils.setup_training_state(model, None, - tx, config, init_rng, mesh, checkpoint_manager) + else: # Checkpoint was unavailable, state needs to be initialized + state, _, _ = max_utils.setup_training_state(model, None, tx, config, init_rng, mesh, checkpoint_manager) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() @@ -90,6 +87,7 @@ def checkpoint_loop(config, state=None): return state + def add_entropy_to_checkpoint(state): """Introduce randomness in checkpoints. This is useful to simulate real checkpoints, without training. Args: @@ -98,14 +96,17 @@ def add_entropy_to_checkpoint(state): state: Returns state with entropy added to the optimizer state. 
""" opt_0 = state.opt_state[0] - opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda x: - jax.random.normal(create_random_keys(x), shape=x.shape), state.params)) - opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda x: - jax.random.normal(create_random_keys(x), shape=x.shape), state.params)) + opt_0 = opt_0._replace( + mu=jax.tree_util.tree_map(lambda x: jax.random.normal(create_random_keys(x), shape=x.shape), state.params) + ) + opt_0 = opt_0._replace( + nu=jax.tree_util.tree_map(lambda x: jax.random.normal(create_random_keys(x), shape=x.shape), state.params) + ) new_opt = [opt_0] + list(state.opt_state[1:]) state = state.replace(opt_state=new_opt) return state + def create_random_keys(x): """Create random keys to help alter the checkpoint state. Args: @@ -115,8 +116,9 @@ def create_random_keys(x): """ return random.PRNGKey(int(jnp.sum(jnp.abs(x)))) + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_cpu_enable_gloo_collectives', True) + jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index 4bb13e4dc..8d484d19c 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """ Standalone data loader - only loads data for each training step, accesses storage needs.""" @@ -32,7 +32,7 @@ def data_load_loop(config, state=None): """Main data loader loop. - Loads batches of data for each training step. + Loads batches of data for each training step. 
""" _, _, _, _, _, _, _, data_iterator, _, state = setup_train_loop(config) @@ -43,14 +43,14 @@ def data_load_loop(config, state=None): example_batch = load_next_batch(data_iterator, example_batch, config) jax.block_until_ready(example_batch) first_end = datetime.datetime.now() - time_to_load_first_batch = first_end-start + time_to_load_first_batch = first_end - start if jax.process_index() == 0: max_logging.log(f"STANDALONE DATALOADER : First step completed in {time_to_load_first_batch} seconds, on host 0") - for _ in np.arange(start_step+1, config.steps): + for _ in np.arange(start_step + 1, config.steps): example_batch = load_next_batch(data_iterator, example_batch, config) - jax.block_until_ready(example_batch) # wait until the last batch is read + jax.block_until_ready(example_batch) # wait until the last batch is read end = datetime.datetime.now() if jax.process_index() == 0: max_logging.log(f"STANDALONE DATALOADER : {config.steps} batches loaded in {end-start} seconds, on host 0") @@ -58,7 +58,7 @@ def data_load_loop(config, state=None): def main(argv: Sequence[str]) -> None: - jax.config.update('jax_cpu_enable_gloo_collectives', True) + jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config @@ -70,6 +70,5 @@ def main(argv: Sequence[str]) -> None: data_load_loop(config) - if __name__ == "__main__": app.run(main) diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index e69c2306d..2a4e2ab97 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -36,10 +36,18 @@ class AttentionTest(unittest.TestCase): - """Test for the Attention """ + """Test for the Attention""" + def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size = 1.0, run_name='test', enable_checkpointing=False, max_target_length=128, max_prefill_predict_length=16 ) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + max_target_length=128, + max_prefill_predict_length=16, + ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(0) @@ -62,20 +70,17 @@ def setUp(self): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "dot_product", + attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) self._attention_as_mha_generic_variable = self._attention_as_mha_generic.init( - {'params': self.rng, 'aqt': self.rng}, - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length)), + {"params": self.rng, "aqt": self.rng}, + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length)), ) def get_data(self, dtype): @@ -86,7 +91,9 @@ def get_data(self, dtype): ) decoder_segment_ids = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, 4) - decoder_positions = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, self.max_target_length) + decoder_positions = jax.random.randint( + self.rng, 
(self.global_batch_size, self.max_target_length), 0, self.max_target_length + ) return lnx, decoder_segment_ids, decoder_positions @@ -97,23 +104,22 @@ def get_structured_data(self, dtype): dtype=dtype, ) - decoder_positions = jnp.stack([ - jnp.arange(self.max_target_length, dtype=jnp.int32) - for _ in range(self.global_batch_size) - ]) + decoder_positions = jnp.stack( + [jnp.arange(self.max_target_length, dtype=jnp.int32) for _ in range(self.global_batch_size)] + ) - decoder_segment_ids = jax.numpy.zeros((self.global_batch_size, self.max_target_length))\ - + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + decoder_segment_ids = ( + jax.numpy.zeros((self.global_batch_size, self.max_target_length)) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) return lnx, decoder_segment_ids, decoder_positions - + @pytest.mark.tpu def test_autoregression(self): prefill_length = self.cfg.max_prefill_predict_length decode_total_length = self.cfg.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - self.dtype) - + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) + mha_full = self._attention_as_mha_generic.apply( self._attention_as_mha_generic_variable, lnx, @@ -122,13 +128,13 @@ def test_autoregression(self): inputs_positions=decoder_positions, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) - + lnx_prefill = lnx[:, 0:prefill_length, :] decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - + mha_prefill, output_cache = self._attention_as_mha_generic.apply( self._attention_as_mha_generic_variable, lnx_prefill, @@ -137,40 +143,32 @@ def test_autoregression(self): inputs_positions=decoder_positions_prefill, deterministic=True, model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'aqt': self.rng}, - mutable=["cache"] + rngs={"aqt": self.rng}, + mutable=["cache"], ) self.assertTrue( - jax.numpy.allclose( - mha_prefill, mha_full[:,:prefill_length,:], rtol=1e-02, atol=1e-02, equal_nan=False - ) + jax.numpy.allclose(mha_prefill, mha_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) ) for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx:idx+1, :] - decoder_positions_idx = decoder_positions[:, idx:idx+1] + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] self._attention_as_mha_generic_variable.update(output_cache) mha_idx, output_cache = self._attention_as_mha_generic.apply( - self._attention_as_mha_generic_variable, - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'aqt': self.rng}, - mutable=["cache"] + self._attention_as_mha_generic_variable, + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], ) - mha_full_this_idx = mha_full[:,idx:idx+1,:] - self.assertTrue( - mha_full_this_idx.shape == mha_idx.shape - ) - self.assertTrue( - jax.numpy.allclose( - mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False - ) - ) + mha_full_this_idx = mha_full[:, idx : idx + 1, :] + self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) + self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) @pytest.mark.tpu def 
test_tpu_kernel_attention_mha(self): @@ -187,8 +185,7 @@ def test_tpu_kernel_attention_mqa(self): def tpu_kernel_attention_helper(self, num_kv_heads): """Test equalvant between dot_product and TPU accelerated""" - lnx, decoder_segment_ids, decoder_positions = self.get_data( - self.dtype) + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) attention_as_mha_generic = Attention( config=self.cfg, @@ -198,20 +195,17 @@ def tpu_kernel_attention_helper(self, num_kv_heads): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "dot_product", + attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) attention_as_mha_generic_variable = attention_as_mha_generic.init( - {'params': self.rng, 'aqt': self.rng}, - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length)), + {"params": self.rng, "aqt": self.rng}, + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length)), ) mha_generic_output = attention_as_mha_generic.apply( @@ -222,7 +216,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): inputs_positions=decoder_segment_ids, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) attention_as_mha_flash = Attention( @@ -233,20 +227,17 @@ def tpu_kernel_attention_helper(self, num_kv_heads): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "flash", + attention_kernel="flash", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) attention_as_mha_flash_variable = attention_as_mha_flash.init( - {'params': self.rng, 'aqt': self.rng}, - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length)), + {"params": self.rng, "aqt": self.rng}, + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length)), ) mha_generic_flash_output = attention_as_mha_flash.apply( @@ -257,14 +248,13 @@ def tpu_kernel_attention_helper(self, num_kv_heads): inputs_positions=decoder_segment_ids, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) self.assertTrue( - jax.numpy.allclose( - mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False - ) + jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False) ) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py index 4650824bf..7dc4246d7 100644 --- a/MaxText/tests/gpt3_test.py +++ b/MaxText/tests/gpt3_test.py @@ -1,19 +1,18 @@ - """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 
(the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for GPT3 """ import sys @@ -38,11 +37,12 @@ def init_random_model_vars(model, rng, example_batch): """initialze random model vars.""" model_vars = model.init( - {'params': rng, 'aqt': rng}, - example_batch['inputs'], - example_batch['inputs_position'], + {"params": rng, "aqt": rng}, + example_batch["inputs"], + example_batch["inputs_position"], enable_dropout=False, ) + def _replace_initialization(key, value): keystr = jax.tree_util.keystr(key) # replace zero initializer to ensure strong test cases @@ -57,14 +57,15 @@ def _replace_initialization(key, value): class GPT3(unittest.TestCase): """numerical tests for GPT3.""" + def setUp(self): super().setUp() pyconfig.initialize( - [sys.argv[0], 'configs/base.yml'], - run_name='test', - enable_checkpointing=False, - model_name='gpt3-52k', - dtype='float32', + [sys.argv[0], "configs/base.yml"], + run_name="test", + enable_checkpointing=False, + model_name="gpt3-52k", + dtype="float32", ) self.cfg = pyconfig.config @@ -73,14 +74,14 @@ def setUp(self): devices_array = max_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) quant = quantizations.configure_quantization(self.cfg) - self.model = models.Transformer(config = self.cfg, mesh = mesh, quant = quant) + self.model = models.Transformer(config=self.cfg, mesh=mesh, quant=quant) self.example_batch = { - 'inputs': jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32), - 'inputs_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), - 'inputs_segmentation': jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32), - 'targets': jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32), - 'targets_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), - 'targets_segmentation': jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32), + "inputs": jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32), + "inputs_position": jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), + "inputs_segmentation": jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32), + "targets": jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32), + "targets_position": jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), + "targets_segmentation": jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32), } self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch) @@ -91,21 +92,20 @@ def test_logits_numerically(self): # paxml applies padding in mlp layer # while maxtext implementaiton applies padding in attention mask instead # the two implementation are equivalent in valid non-padding tokens - per_example_xent_truth = 
jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.]], dtype=jnp.float32) - logits, _ = self.model.apply(self.model_vars, - self.example_batch['inputs'], - self.example_batch['inputs_position'], - decoder_segment_ids=self.example_batch['inputs_segmentation'], - enable_dropout=self.cfg.enable_dropout, - rngs={'dropout': self.rng, 'aqt': self.rng}, mutable='intermediates') - - one_hot_targets = jax.nn.one_hot(self.example_batch['targets'], self.cfg.vocab_size) + per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.0]], dtype=jnp.float32) + logits, _ = self.model.apply( + self.model_vars, + self.example_batch["inputs"], + self.example_batch["inputs_position"], + decoder_segment_ids=self.example_batch["inputs_segmentation"], + enable_dropout=self.cfg.enable_dropout, + rngs={"dropout": self.rng, "aqt": self.rng}, + mutable="intermediates", + ) + + one_hot_targets = jax.nn.one_hot(self.example_batch["targets"], self.cfg.vocab_size) per_example_xent = -jnp.sum(jax.nn.log_softmax(logits) * one_hot_targets, axis=-1, dtype=jnp.float32) # Mask out paddings at the end of each example. - per_example_xent = per_example_xent * (self.example_batch['targets_segmentation'] != 0) + per_example_xent = per_example_xent * (self.example_batch["targets_segmentation"] != 0) - self.assertTrue( - jax.numpy.allclose( - per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03 - ) - ) + self.assertTrue(jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03)) diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 8af271f53..36167f67b 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import subprocess @@ -27,117 +27,130 @@ from input_pipeline import _grain_data_processing from input_pipeline import input_pipeline_interface + class GrainDataProcessingTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - exit_code = subprocess.call(['bash','../setup_gcsfuse.sh', - 'DATASET_GCS_BUCKET=maxtext-dataset', - 'MOUNT_PATH=/tmp/gcsfuse']) - if exit_code != 0: - raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") - - def setUp(self): - super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], - per_device_batch_size=1, - run_name='test', - mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "/tmp/gcsfuse", - tokenizer_path = "../assets/tokenizer", - enable_checkpointing=False, - dataset_type="c4-array_record", - dataset_name='array-record/c4/en/3.0.1', - eval_dataset_name='array-record/c4/en/3.0.1') - self.config = pyconfig.config - self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) - self.train_ds, self.eval_ds = self._get_datasets() - self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() - - def _get_datasets(self): - print("Sharding dataset in ", jax.process_count(), " shards") - train_ds, eval_ds = _grain_data_processing.get_datasets( - config=self.config) - return train_ds, eval_ds - - def _get_preprocessed_datasets(self): - process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) - train_iter, eval_iter, test_iter = _grain_data_processing.preprocess_dataset( - self.config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - global_mesh = self.mesh, - train_ds = self.train_ds, eval_ds = self.eval_ds, - vocab_path=self.config.tokenizer_path) - return train_iter, eval_iter, test_iter - - def test_train_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - # For training we pack multiple short examples in one example. - # *_position and *_segmentation indicate the boundaries. 
- batch = next(self.train_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - def test_eval_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - batch = next(self.eval_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - - def test_predict_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - batch = next(self.predict_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - def test_batch_determinism(self): - batch1 = next(self.train_iter) - self.train_ds, _ = self._get_datasets() - train_iter, _, _= self._get_preprocessed_datasets() - batch2 = next(train_iter) - self.assertTrue((batch1['inputs']==batch2['inputs']).all()) - self.assertTrue((batch1['targets']==batch2['targets']).all()) - self.assertTrue((batch1['inputs_segmentation']==batch2['inputs_segmentation']).all()) - self.assertTrue((batch1['targets_segmentation']==batch2['targets_segmentation']).all()) - self.assertTrue((batch1['inputs_position']==batch2['inputs_position']).all()) - self.assertTrue((batch1['targets_position']==batch2['targets_position']).all()) - - def test_for_loop_repeatable(self): - def get_first_batch(iterator): - batch = None - for batch in iterator: - break - return batch - - eval_batch1 = get_first_batch(self.eval_iter) - eval_batch2 = get_first_batch(self.eval_iter) - self.assertTrue((eval_batch1['inputs']==eval_batch2['inputs']).all()) - self.assertTrue((eval_batch1['targets']==eval_batch2['targets']).all()) - - -if __name__ == '__main__': + + @classmethod + def setUpClass(cls): + super().setUpClass() + exit_code = subprocess.call( + ["bash", "../setup_gcsfuse.sh", "DATASET_GCS_BUCKET=maxtext-dataset", "MOUNT_PATH=/tmp/gcsfuse"] + ) + if exit_code != 0: + raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") + + def setUp(self): + super().setUp() + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="/tmp/gcsfuse", + tokenizer_path="../assets/tokenizer", + enable_checkpointing=False, + dataset_type="c4-array_record", + dataset_name="array-record/c4/en/3.0.1", + eval_dataset_name="array-record/c4/en/3.0.1", + ) + self.config = pyconfig.config + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.train_ds, self.eval_ds = self._get_datasets() + self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() + + def _get_datasets(self): + print("Sharding dataset in ", jax.process_count(), " shards") + train_ds, eval_ds = _grain_data_processing.get_datasets(config=self.config) + return train_ds, 
eval_ds + + def _get_preprocessed_datasets(self): + process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) + train_iter, eval_iter, test_iter = _grain_data_processing.preprocess_dataset( + self.config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=self.mesh, + train_ds=self.train_ds, + eval_ds=self.eval_ds, + vocab_path=self.config.tokenizer_path, + ) + return train_iter, eval_iter, test_iter + + def test_train_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + # For training we pack multiple short examples in one example. + # *_position and *_segmentation indicate the boundaries. + batch = next(self.train_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_eval_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.eval_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_predict_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.predict_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_batch_determinism(self): + batch1 = next(self.train_iter) + self.train_ds, _ = self._get_datasets() + train_iter, _, _ = self._get_preprocessed_datasets() + batch2 = next(train_iter) + self.assertTrue((batch1["inputs"] == batch2["inputs"]).all()) + self.assertTrue((batch1["targets"] == batch2["targets"]).all()) + self.assertTrue((batch1["inputs_segmentation"] == batch2["inputs_segmentation"]).all()) + self.assertTrue((batch1["targets_segmentation"] == batch2["targets_segmentation"]).all()) + self.assertTrue((batch1["inputs_position"] == batch2["inputs_position"]).all()) + self.assertTrue((batch1["targets_position"] == batch2["targets_position"]).all()) + + def test_for_loop_repeatable(self): + def get_first_batch(iterator): + batch = None + for batch in iterator: + break + return batch + + eval_batch1 = get_first_batch(self.eval_iter) + eval_batch2 = get_first_batch(self.eval_iter) + self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all()) + self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all()) + + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index 274739a7f..b305562c2 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -22,18 +22,20 @@ class Inference_Microbenchmark(unittest.TestCase): + @pytest.mark.tpu def test(self): - pyconfig.initialize([None, - "configs/tpu_smoke_test.yml", - "tokenizer_path=../assets/tokenizer.llama2", - 
"ici_autoregressive_parallelism=-1", - "ici_fsdp_parallelism=1", - "max_prefill_predict_length=1024", - "max_target_length=2048", - "scan_layers=false", - "weight_dtype=bfloat16", - ]) + pyconfig.initialize([ + None, + "configs/tpu_smoke_test.yml", + "tokenizer_path=../assets/tokenizer.llama2", + "ici_autoregressive_parallelism=-1", + "ici_fsdp_parallelism=1", + "max_prefill_predict_length=1024", + "max_target_length=2048", + "scan_layers=false", + "weight_dtype=bfloat16", + ]) inference_microbenchmark_main(pyconfig.config) diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py index 38e1a7c62..6d7b7827c 100644 --- a/MaxText/tests/llama_test.py +++ b/MaxText/tests/llama_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Tests for Llama """ import jax @@ -31,13 +31,10 @@ """ -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32 -) -> jnp.ndarray: + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32) -> jnp.ndarray: """Calculate the frequencies""" - freqs = 1.0 / ( - theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim) - ) + freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)) t = np.arange(end) # type: ignore freqs = np.outer(t, freqs).astype(dtype) # type: ignore sin, cos = np.sin(freqs), np.cos(freqs) @@ -45,14 +42,13 @@ def precompute_freqs_cis( return jnp.asarray(freqs_cis) - def apply_rotary_emb( - xq: jnp.ndarray, - xk: jnp.ndarray, - freqs_cis: jnp.ndarray, - dtype: jnp.dtype = jnp.bfloat16, + xq: jnp.ndarray, + xk: jnp.ndarray, + freqs_cis: jnp.ndarray, + dtype: jnp.dtype = jnp.bfloat16, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ Apply the computed Rotary Postional Embedding""" + """Apply the computed Rotary Postional Embedding""" reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) @@ -60,29 +56,26 @@ def apply_rotary_emb( xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1]) # add head dim - freqs_cis = jnp.reshape( - freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]) - ) + freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])) xq_out = xq_ * freqs_cis - xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape( - *xq_out.shape[:-1], -1 - ) + xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1) xk_out = xk_ * freqs_cis - xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape( - *xk_out.shape[:-1], -1 - ) + xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1) return xq_out.astype(dtype), xk_out.astype(dtype) + def permute_to_match_maxtext_rope(arr): evens = arr[..., ::2] odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1) + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) + class RoPETest(unittest.TestCase): - """Test for the RoPE implementation """ + """Test for the RoPE implementation""" + def test_rope(self): dim_per_head = 128 seq_len = 8 @@ -93,24 +86,28 @@ def test_rope(self): # Calculate RoPE embeddings from Sea-Snell implementation freqs_cis = precompute_freqs_cis(dim_per_head, seq_len * 2) - freqs_cis = jnp.take( - freqs_cis, jnp.arange(seq_len, dtype=np.int32)[None, :], axis=0 - ) + freqs_cis = jnp.take(freqs_cis, jnp.arange(seq_len, dtype=np.int32)[None, :], axis=0) - llama_output = apply_rotary_emb( - jnp.asarray(x_q), jnp.asarray(x_k), freqs_cis - ) + llama_output = apply_rotary_emb(jnp.asarray(x_q), jnp.asarray(x_k), freqs_cis) seq_length = x_q.shape[1] position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings from MaxText implementation - query_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_q), position = position) - key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position) + query_proj = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)( + permute_to_match_maxtext_rope(x_q), position=position + ) + key_proj = 
embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(permute_to_match_maxtext_rope(x_k), position=position) # Compare results - self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) - self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False + ) + ) + self.assertTrue( + jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False) + ) def test_scaling_rope(self): dim_per_head = 128 @@ -121,15 +118,15 @@ def test_scaling_rope(self): position = jnp.arange(seq_len, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings and then scale - query_proj_1 = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(x_q, position = position) - query_proj_1 = query_proj_1 * (dim_per_head ** -0.5) + query_proj_1 = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(x_q, position=position) + query_proj_1 = query_proj_1 * (dim_per_head**-0.5) # scale first and then apply RoPE - query_proj_2 = x_q * (dim_per_head ** -0.5) - query_proj_2 = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(query_proj_2, position=position) + query_proj_2 = x_q * (dim_per_head**-0.5) + query_proj_2 = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(query_proj_2, position=position) self.assertTrue(jax.numpy.allclose(query_proj_2, query_proj_1, rtol=1e-01, atol=1e-04, equal_nan=False)) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/max_utils_test.py b/MaxText/tests/max_utils_test.py index c4c1480dd..061446024 100644 --- a/MaxText/tests/max_utils_test.py +++ b/MaxText/tests/max_utils_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Tests for the common Max Utils """ import jax @@ -30,55 +30,53 @@ Transformer = models.Transformer + class MaxUtilsSummaryStats(unittest.TestCase): """Tests for the summary stats functions in max_utils.py""" + def test_l2norm_pytree(self): - x = {'a': jax.numpy.array([0, 2, 0]), 'b': jax.numpy.array([0, 3, 6])} + x = {"a": jax.numpy.array([0, 2, 0]), "b": jax.numpy.array([0, 3, 6])} pytree_l2_norm = max_utils.l2norm_pytree(x) self.assertTrue(jax.numpy.allclose(pytree_l2_norm, 7, rtol=1e-05, atol=1e-08, equal_nan=False)) + class MaxUtilsInitState(unittest.TestCase): """Tests initialization of training and decode states in max_utils.py""" + def setUp(self): self.model = nn.Dense(features=5) self.key1, self.key2 = random.split(random.key(0)) - self.input = random.normal(self.key1, (10,)) # Dummy input data + self.input = random.normal(self.key1, (10,)) # Dummy input data self.params = self.model.init(self.key2, self.input) self.output = self.model.apply(self.params, self.input) self.tx = optax.adam(learning_rate=0.001) def test_calculate_num_params_from_pytree(self): example_tree = [ - [1, 'a', object()], - (1, (2, 3), ()), - [1, {'k1': 2, 'k2': (3, 4)}, 5], - {'a': 2, 'b': (2, 3)}, - jnp.array([1, 2, 3]), - ] + [1, "a", object()], + (1, (2, 3), ()), + [1, {"k1": 2, "k2": (3, 4)}, 5], + {"a": 2, "b": (2, 3)}, + jnp.array([1, 2, 3]), + ] self.assertEqual(max_utils.calculate_num_params_from_pytree(example_tree), 17) # Model params self.assertEqual(max_utils.calculate_num_params_from_pytree(self.params), 55) def test_init_train_state(self): state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.params, - tx=None, # type: ignore - opt_state={} + step=0, apply_fn=self.model.apply, params=self.params, tx=None, opt_state={} # type: ignore ) self.assertEqual(state.tx, None) self.assertEqual(state.step, 0) self.assertEqual(state.opt_state, {}) self.assertEqual(state.apply_fn, self.model.apply) - self.assertEqual(max_utils.calculate_num_params_from_pytree(state.params), - max_utils.calculate_num_params_from_pytree(self.params)) - + self.assertEqual( + max_utils.calculate_num_params_from_pytree(state.params), max_utils.calculate_num_params_from_pytree(self.params) + ) def test_init_decode_state(self): - decode_state = max_utils.init_decode_state( - self.model.apply, self.params - ) + decode_state = max_utils.init_decode_state(self.model.apply, self.params) self.assertEqual(decode_state.apply_fn, self.model.apply) output = decode_state.apply_fn(self.params, self.input) self.assertEqual(output.tolist(), self.output.tolist()) @@ -86,8 +84,8 @@ def test_init_decode_state(self): self.assertEqual(decode_state.opt_state, {}) self.assertEqual(decode_state.step, 0) self.assertEqual( - max_utils.calculate_num_params_from_pytree(decode_state.params), - max_utils.calculate_num_params_from_pytree(self.params) + max_utils.calculate_num_params_from_pytree(decode_state.params), + max_utils.calculate_num_params_from_pytree(self.params), ) def test_init_training_state(self): @@ -96,24 +94,24 @@ def test_init_training_state(self): self.assertEqual(state.tx, self.tx) self.assertNotEqual(state.opt_state, {}) self.assertEqual( - max_utils.calculate_num_params_from_pytree(state.params), - max_utils.calculate_num_params_from_pytree(self.params) + max_utils.calculate_num_params_from_pytree(state.params), max_utils.calculate_num_params_from_pytree(self.params) ) + class ModelWithMultipleCollections(nn.Module): - """ - A simple model that has variables in multiple collections - "params" and 
"special_variables" - """ - def setup(self): - self.dense = nn.Dense(4) - self.kernel = self.variable( - "special_variables", "my_first_kernel", lambda: jnp.ones((4, 5)) - ) - - def __call__(self, x, y): - x = self.dense(x) - x = x @ self.kernel.value - return x + """ + A simple model that has variables in multiple collections - "params" and "special_variables" + """ + + def setup(self): + self.dense = nn.Dense(4) + self.kernel = self.variable("special_variables", "my_first_kernel", lambda: jnp.ones((4, 5))) + + def __call__(self, x, y): + x = self.dense(x) + x = x @ self.kernel.value + return x + class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase): @@ -122,8 +120,7 @@ def setUp(self): self.config = pyconfig.config self.model = ModelWithMultipleCollections() self.key1, self.key2, self.key3 = random.split(random.key(0), num=3) - self.input = random.normal(self.key1, - (self.config.global_batch_size_to_load, self.config.max_target_length)) + self.input = random.normal(self.key1, (self.config.global_batch_size_to_load, self.config.max_target_length)) self.params = self.model.init(self.key2, self.input, self.input) self.tx = optax.adam(learning_rate=0.001) @@ -137,19 +134,16 @@ def _test_init_initial_state_driver(self, is_training): self.assertIsNone(state_under_test.tx) self.assertEqual(state_under_test.opt_state, {}) self.assertEqual( - max_utils.calculate_num_params_from_pytree(state_under_test.params), - max_utils.calculate_num_params_from_pytree(self.params) - ) - self.assertEqual( - len(self.params), - len(state_under_test.params) + max_utils.calculate_num_params_from_pytree(state_under_test.params), + max_utils.calculate_num_params_from_pytree(self.params), ) + self.assertEqual(len(self.params), len(state_under_test.params)) self.assertIn("special_variables", state_under_test.params) self.assertIn("params", state_under_test.params) - + def test_initial_train_state(self): self._test_init_initial_state_driver(True) - + def test_initial_decode_state(self): self._test_init_initial_state_driver(False) @@ -167,28 +161,26 @@ def setUp(self): def test_setup_decode_state(self): rng = random.PRNGKey(0) - state, _ = max_utils.setup_decode_state( - self.model, self.config, rng, self.mesh, None) + state, _ = max_utils.setup_decode_state(self.model, self.config, rng, self.mesh, None) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) - state, _, _ = max_utils.setup_initial_state( - self.model, None, tx, self.config, rng, self.mesh, None) + state, _, _ = max_utils.setup_initial_state(self.model, None, tx, self.config, rng, self.mesh, None) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) + class MaxUtilsT5XCrossEntropy(unittest.TestCase): """Tests for the cross entropy functions in max_utils.py""" + def test_t5x_cross_entropy(self): # Generate random targets and logits key = jax.random.PRNGKey(0) - targets = jax.random.randint(key, shape=(48, 2048), - dtype=jax.numpy.int32, minval=1, maxval=10) - logits = jax.random.uniform(key, shape=(48, 2048, 4096), - dtype=jax.numpy.float32) + targets = jax.random.randint(key, shape=(48, 2048), dtype=jax.numpy.int32, minval=1, maxval=10) + logits = jax.random.uniform(key, shape=(48, 2048, 4096), dtype=jax.numpy.float32) # Calculate xent from optax implementation optax_xent = optax.softmax_cross_entropy_with_integer_labels(logits, targets) @@ -196,10 +188,11 @@ def test_t5x_cross_entropy(self): # Calculate xent 
from custom T5X implementation one_hot_targets = jax.nn.one_hot(targets, 4096) t5x_xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - t5x_xent = nn.with_logical_constraint(t5x_xent, ('activation_batch', 'activation_length')) + t5x_xent = nn.with_logical_constraint(t5x_xent, ("activation_batch", "activation_length")) # Compare results self.assertTrue(jax.numpy.allclose(optax_xent, t5x_xent, rtol=1e-05, atol=1e-08, equal_nan=False)) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py index 5ab89406e..8b3bcc318 100644 --- a/MaxText/tests/model_test.py +++ b/MaxText/tests/model_test.py @@ -30,35 +30,41 @@ from layers import quantizations Mesh = jax.sharding.Mesh -MAX_PREFILL_PREDICT_LENGTH = 4 +MAX_PREFILL_PREDICT_LENGTH = 4 + class TestModel(unittest.TestCase): - """Test the Whole Model """ + """Test the Whole Model""" + def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size = 1.0, run_name='test', - enable_checkpointing=False, base_num_decoder_layers=2, attention="dot_product", - max_target_length=16, base_emb_dim=256, base_num_query_heads=2, base_num_kv_heads=2, max_prefill_predict_length=4) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + base_num_decoder_layers=2, + attention="dot_product", + max_target_length=16, + base_emb_dim=256, + base_num_query_heads=2, + base_num_kv_heads=2, + max_prefill_predict_length=4, + ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(0) def get_data(self): s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length) - ids = jax.random.randint( - self.rng, - s, - 0, - self.cfg.vocab_size - ) + ids = jax.random.randint(self.rng, s, 0, self.cfg.vocab_size) decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR - decoder_positions = jnp.stack([ - jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) - for _ in range(self.cfg.global_batch_size_to_train_on) - ]) + decoder_positions = jnp.stack( + [jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) for _ in range(self.cfg.global_batch_size_to_train_on)] + ) return ids, decoder_segment_ids, decoder_positions - + @pytest.mark.tpu def test_train_vs_prefill_and_autoregress(self): PREFILL_RANGE = MAX_PREFILL_PREDICT_LENGTH @@ -66,68 +72,59 @@ def test_train_vs_prefill_and_autoregress(self): devices_array = max_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) quant = quantizations.configure_quantization(self.cfg) - model = models.Transformer(config = self.cfg, mesh = mesh, quant=quant) + model = models.Transformer(config=self.cfg, mesh=mesh, quant=quant) ids, decoder_segment_ids, decoder_positions = self.get_data() transformer_vars = model.init( - {'params': self.rng, 'aqt': self.rng}, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False + {"params": self.rng, "aqt": self.rng}, ids, decoder_positions, decoder_segment_ids, enable_dropout=False ) full_train_logits = model.apply( - transformer_vars, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - model_mode = common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng} + transformer_vars, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_TRAIN, + rngs={"aqt": self.rng}, ) partial_prefill_logits, partial_cache = 
model.apply( - transformer_vars, - ids[:, :PREFILL_RANGE], - decoder_positions[:, :PREFILL_RANGE], - decoder_segment_ids=decoder_segment_ids[:, :PREFILL_RANGE], - enable_dropout=False, - model_mode = common_types.MODEL_MODE_PREFILL, - rngs={'aqt': self.rng}, - mutable=["cache"], + transformer_vars, + ids[:, :PREFILL_RANGE], + decoder_positions[:, :PREFILL_RANGE], + decoder_segment_ids=decoder_segment_ids[:, :PREFILL_RANGE], + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"aqt": self.rng}, + mutable=["cache"], ) self.assertTrue( jax.numpy.allclose( - full_train_logits[:,:PREFILL_RANGE,:], partial_prefill_logits, rtol=1e-01, atol=1e-01, equal_nan=False + full_train_logits[:, :PREFILL_RANGE, :], partial_prefill_logits, rtol=1e-01, atol=1e-01, equal_nan=False ) ) for idx in range(PREFILL_RANGE, self.cfg.max_target_length): - ids_idx = ids[:, idx:idx+1] - decoder_positions_idx = decoder_positions[:, idx:idx+1] + ids_idx = ids[:, idx : idx + 1] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] transformer_vars.update(partial_cache) ar_logits, partial_cache = model.apply( - transformer_vars, - ids_idx, - decoder_positions_idx, - enable_dropout=False, - model_mode = common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'aqt': self.rng}, - mutable=["cache"], + transformer_vars, + ids_idx, + decoder_positions_idx, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], ) - full_train_logits_idx = full_train_logits[:,idx:idx+1,:] - self.assertTrue( - full_train_logits_idx.shape == ar_logits.shape - ) - self.assertTrue( - jax.numpy.allclose( - full_train_logits_idx, ar_logits, rtol=1e-01, atol=1e-01, equal_nan=False - ) - ) + full_train_logits_idx = full_train_logits[:, idx : idx + 1, :] + self.assertTrue(full_train_logits_idx.shape == ar_logits.shape) + self.assertTrue(jax.numpy.allclose(full_train_logits_idx, ar_logits, rtol=1e-01, atol=1e-01, equal_nan=False)) + -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index e620c2e1e..ba289c040 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" # pylint: disable=missing-module-docstring, missing-function-docstring import sys @@ -35,29 +35,32 @@ class MultihostDataloadingTest(unittest.TestCase): def setUp(self): super().setUp() batch_size = 4 - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "gs://maxtext-dataset/", - enable_checkpointing=False) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="gs://maxtext-dataset/", + enable_checkpointing=False, + ) config = pyconfig.config global_data_shape = PartitionSpec(batch_size, config.max_target_length) - data_sharding = ('data',) + data_sharding = ("data",) mesh_shape_1d = (len(jax.devices()),) self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes) - data_axes = PartitionSpec('data',) + data_axes = PartitionSpec( + "data", + ) # creating 2 batches of data - global_data = np.arange(np.prod(global_data_shape)*2).reshape((batch_size * 2, config.max_target_length)) + global_data = np.arange(np.prod(global_data_shape) * 2).reshape((batch_size * 2, config.max_target_length)) dataset = tf.data.Dataset.from_tensor_slices(global_data) dataset = dataset.repeat() dataset = dataset.batch(batch_size) - self.multihost_gen = ( - multihost_dataloading.MultiHostDataLoadIterator( - dataset, self.mesh - ) - ) + self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh) @pytest.mark.tpu def test_batch_sharded_data_pipeline(self): @@ -66,5 +69,5 @@ def test_batch_sharded_data_pipeline(self): self.assertTrue(not np.array_equal(first_batch, sec_batch, equal_nan=True)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/profiler_test.py b/MaxText/tests/profiler_test.py index 095073008..1c2cf3780 100644 --- a/MaxText/tests/profiler_test.py +++ b/MaxText/tests/profiler_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Profiler tests for TPUs.""" import glob @@ -29,9 +29,9 @@ class TpuJAXTest(unittest.TestCase): def _get_session_snapshot(self): """Gets a session snapshot of current session. 
assume only one session.""" - profile_plugin_root ="tensorboard/plugins/profile" + profile_plugin_root = "tensorboard/plugins/profile" # The session exists under a director whose name is time-dependent. - profile_session_glob = os.path.join(profile_plugin_root, '*', '*.xplane.pb') + profile_session_glob = os.path.join(profile_plugin_root, "*", "*.xplane.pb") return glob.glob(profile_session_glob) def test_xplane_is_present(self): @@ -40,47 +40,42 @@ def test_xplane_is_present(self): def test_overview_page(self): xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, - 'overview_page^', {}) + result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "overview_page^", {}) result = json.loads(result) run_environment = result[2] - self.assertEqual(run_environment['p']['host_count'], '1') - self.assertRegex(run_environment['p']['device_type'], 'TPU.*') + self.assertEqual(run_environment["p"]["host_count"], "1") + self.assertRegex(run_environment["p"]["device_type"], "TPU.*") def test_op_profile(self): xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data( - xspace_filenames, 'op_profile^', {} - ) + result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "op_profile^", {}) result = json.loads(result) - self.assertIn('byCategory', result) - self.assertIn('metrics', result['byCategory']) - overall_metrics = result['byCategory']['metrics'] - self.assertIn('flops', overall_metrics) - self.assertIn('bandwidthUtils', overall_metrics) - self.assertGreater(overall_metrics['flops'], 0) + self.assertIn("byCategory", result) + self.assertIn("metrics", result["byCategory"]) + overall_metrics = result["byCategory"]["metrics"] + self.assertIn("flops", overall_metrics) + self.assertIn("bandwidthUtils", overall_metrics) + self.assertGreater(overall_metrics["flops"], 0) def test_device_trace_contains_threads(self): xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data( - xspace_filenames, 'trace_viewer^', {} - ) + result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "trace_viewer^", {}) result = json.loads(result) thread_names = [] - for event in result['traceEvents']: - if 'name' in event and event['name'] == 'thread_name': - thread_names.append((event['args']['name'])) - expected_threads = [ - 'TensorFlow Name Scope', - 'TensorFlow Ops', - 'XLA Modules', - 'XLA Ops', - 'XLA TraceMe', - 'Steps', - ] + for event in result["traceEvents"]: + if "name" in event and event["name"] == "thread_name": + thread_names.append((event["args"]["name"])) + expected_threads = [ + "TensorFlow Name Scope", + "TensorFlow Ops", + "XLA Modules", + "XLA Ops", + "XLA TraceMe", + "Steps", + ] # Ensure that thread_names contains at least all expected threads. - self.assertEqual(set(expected_threads)-set(thread_names), set()) + self.assertEqual(set(expected_threads) - set(thread_names), set()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py index 7f35ffb0f..ef1b54df4 100644 --- a/MaxText/tests/quantizations_test.py +++ b/MaxText/tests/quantizations_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the quantizations """ from jax import numpy as jnp @@ -23,8 +23,10 @@ from layers import quantizations import unittest + class QuantTestModule(nn.Module): """Test module for einsum.""" + quantization: quantizations.AqtQuantization @nn.compact @@ -36,23 +38,25 @@ def __call__(self, inputs): einsum = self.quantization.einsum() dot_general_cls = self.quantization.dot_general_cls() dot_general = dot_general_cls() - res_einsum = einsum('bc,ab->ac', inputs, identity) + res_einsum = einsum("bc,ab->ac", inputs, identity) res_dg = dot_general(inputs, inputs, (((), ()), ((), ())), precision=None) return res_einsum, res_dg -def _configure_quantization(quant_str="", mode_str='train'): + +def _configure_quantization(quant_str="", mode_str="train"): pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False, quantization=quant_str) config = pyconfig.config quant = quantizations.configure_quantization(config, mode_str) return quant + def _apply(quant_str=""): quant = _configure_quantization(quant_str) test_module = QuantTestModule(quant) rng = random.PRNGKey(0) - variables = test_module.init({'params': rng}, jnp.ones((2, 2))) + variables = test_module.init({"params": rng}, jnp.ones((2, 2))) inputs = jnp.ones((2, 2)) - res_einsum, res_dg = test_module.apply(variables, inputs, rngs={'params': random.PRNGKey(0)}) + res_einsum, res_dg = test_module.apply(variables, inputs, rngs={"params": random.PRNGKey(0)}) return inputs, res_einsum, res_dg @@ -60,11 +64,10 @@ class QuantizationTest(unittest.TestCase): """Tests for quantization.""" def test_in_quant_mode(self): - quant = _configure_quantization(quant_str="int8", mode_str='convert') + quant = _configure_quantization(quant_str="int8", mode_str="convert") self.assertTrue(quantizations.in_convert_mode(quant)) self.assertFalse(quantizations.in_serve_mode(quant)) - def test_configure_quantization_is_null(self): for quant_mode in ["train", "serve", "convert"]: quant = _configure_quantization(quant_str="", mode_str=quant_mode) @@ -88,37 +91,52 @@ def test_aqt_quantization(self): self.assertTrue(jnp.greater(jnp.max(inputs), jnp.max(res_einsum))) self.assertEqual(res_einsum.dtype, np.dtype(np.float32)) self.assertTrue(jnp.greater(jnp.max(inputs), jnp.max(res_dg[0][0]))) - #self.assertEqual(res_dg.dtype, np.dtype(np.float32)) + # self.assertEqual(res_dg.dtype, np.dtype(np.float32)) def test_remove_quantized_params(self): _params = { - 'decoder': { - 'decoder_norm': {'scale': 1.0}, - 'layers': { - 'mlp': {'wi_0': {'kernel': 1.0}, 'wi_1': {'kernel': 1.0}, 'wo': {'kernel': 1.0}}, - 'self_attention': {'key': 
{'kernel': 1.0},}}, - 'logits_dense': {'kernel': 1.0}}, - } - _aqt_vars = { - 'decoder': { - 'layers': { - 'mlp': { - 'wi_0': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}}, - 'wi_1': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}}, - 'wo': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}} + "decoder": { + "decoder_norm": {"scale": 1.0}, + "layers": { + "mlp": {"wi_0": {"kernel": 1.0}, "wi_1": {"kernel": 1.0}, "wo": {"kernel": 1.0}}, + "self_attention": { + "key": {"kernel": 1.0}, + }, }, - 'self_attention': {'key': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0}},}}}} - } + "logits_dense": {"kernel": 1.0}, + }, + } + _aqt_vars = { + "decoder": { + "layers": { + "mlp": { + "wi_0": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}}, + "wi_1": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}}, + "wo": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}}, + }, + "self_attention": { + "key": { + "AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}, + } + }, + } + } + } _expected = { - 'decoder': { - 'decoder_norm': {'scale': 1.0}, - 'layers': { - 'mlp': {'wi_0': {'kernel': {}}, 'wi_1': {'kernel': {}}, 'wo': {'kernel': {}}}, - 'self_attention': {'key': {'kernel': {}},}}, - 'logits_dense': {'kernel': 1.0},} - } + "decoder": { + "decoder_norm": {"scale": 1.0}, + "layers": { + "mlp": {"wi_0": {"kernel": {}}, "wi_1": {"kernel": {}}, "wo": {"kernel": {}}}, + "self_attention": { + "key": {"kernel": {}}, + }, + }, + "logits_dense": {"kernel": 1.0}, + } + } result = quantizations.remove_quantized_params(_params, _aqt_vars) self.assertEqual(_expected, result) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index 8b51246b2..d9befd1e8 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the standalone_checkpointer.py """ import unittest @@ -25,35 +25,67 @@ class Standalone_DL_CKPT(unittest.TestCase): - """Tests for standalone_checkpointer.py, checkpoint and restore. 
""" + """Tests for standalone_checkpointer.py, checkpoint and restore.""" def _get_random_test_name(self, test_name): now = datetime.now() date_time = now.strftime("_%Y-%m-%d-%H-%M_") - random_string = ''.join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(6)) + random_string = "".join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(6)) random_run_name = test_name + date_time + random_string return random_run_name @pytest.mark.tpu def test_standalone_dataloader(self): random_run_name = self._get_random_test_name("standalone_dataloader") - sdl_main((None, "configs/base.yml", "run_name="+random_run_name, "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", "steps=100", "enable_checkpointing=false", - "tokenizer_path=../assets/tokenizer.llama2")) # need to pass relative path to tokenizer + sdl_main(( + None, + "configs/base.yml", + "run_name=" + random_run_name, + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "steps=100", + "enable_checkpointing=false", + "tokenizer_path=../assets/tokenizer.llama2", + )) # need to pass relative path to tokenizer @pytest.mark.tpu def test_standalone_checkpointer(self): random_run_name = self._get_random_test_name("standalone_checkpointer") # checkpoint at 50 - sckpt_main((None, "configs/base.yml", f"run_name={random_run_name}", "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset","base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", - "base_mlp_dim=128", "base_num_decoder_layers=2", "steps=60", "enable_checkpointing=True", - "checkpoint_period=50", "async_checkpointing=False")) + sckpt_main(( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=60", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + )) # restore at 50 and checkpoint at 100 - sckpt_main((None, "configs/base.yml", f"run_name={random_run_name}", "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset","base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", - "base_mlp_dim=128", "base_num_decoder_layers=2", "steps=110", "enable_checkpointing=True", - "checkpoint_period=50", "async_checkpointing=False")) + sckpt_main(( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=110", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + )) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py index 1c998fe82..098334b15 100644 --- a/MaxText/tests/tfds_data_processing_test.py +++ b/MaxText/tests/tfds_data_processing_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring, missing-function-docstring import os @@ -29,26 +29,32 @@ from input_pipeline import _tfds_data_processing from input_pipeline import input_pipeline_interface + class TfdsDataProcessingTest(unittest.TestCase): def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "gs://maxtext-dataset/", - tokenizer_path = "../assets/tokenizer", - enable_checkpointing=False) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="gs://maxtext-dataset/", + tokenizer_path="../assets/tokenizer", + enable_checkpointing=False, + ) os.environ["TFDS_DATA_DIR"] = pyconfig.config.dataset_path self.config = pyconfig.config self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) self.read_config = tfds.ReadConfig( - shuffle_seed = self.config.data_shuffle_seed, + shuffle_seed=self.config.data_shuffle_seed, ) self.read_config.add_tfds_id = True - + self.train_ds, self.eval_ds = self._get_datasets() self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() @@ -56,10 +62,11 @@ def _get_datasets(self): print("Sharding dataset in ", jax.process_count(), " shards") process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) train_ds, eval_ds = _tfds_data_processing.get_datasets( - config=self.config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - read_config = self.read_config) + config=self.config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + read_config=self.read_config, + ) return train_ds, eval_ds def _get_preprocessed_datasets(self): @@ -67,9 +74,8 @@ def _get_preprocessed_datasets(self): mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), self.config.mesh_axes) sp_tokenizer = input_pipeline_interface.get_tokenizer(self.config.tokenizer_path) train_iter, eval_iter, test_iter = _tfds_data_processing.preprocess_dataset( - 
self.config, - mesh, - self.train_ds, self.eval_ds, sp_tokenizer) + self.config, mesh, self.train_ds, self.eval_ds, sp_tokenizer + ) return train_iter, eval_iter, test_iter def test_train_ds(self): @@ -77,33 +83,39 @@ def test_train_ds(self): # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. batch = next(self.train_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) def test_eval_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] batch = next(self.eval_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "targets": expected_shape, + }, + ) def test_predict_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] batch = next(self.predict_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "targets": expected_shape, + }, + ) def test_ds_determinism(self): train_ds1 = self.train_ds.batch(64) @@ -113,20 +125,19 @@ def test_ds_determinism(self): train_ds = train_ds.batch(64) train_ds2 = next(train_ds.as_numpy_iterator()) - self.assertCountEqual(train_ds1['tfds_id'], train_ds2['tfds_id']) - + self.assertCountEqual(train_ds1["tfds_id"], train_ds2["tfds_id"]) def test_batch_determinism(self): batch1 = next(self.train_iter) self.train_ds, _ = self._get_datasets() - train_iter2, _, _= self._get_preprocessed_datasets() + train_iter2, _, _ = self._get_preprocessed_datasets() batch2 = next(train_iter2) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs'], batch2['inputs']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets'], batch2['targets']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs_segmentation'], batch2['inputs_segmentation']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets_segmentation'], batch2['targets_segmentation']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs_position'], batch2['inputs_position']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets_position'], batch2['targets_position']))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs"], batch2["inputs"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["targets"], batch2["targets"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs_segmentation"], batch2["inputs_segmentation"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["targets_segmentation"], batch2["targets_segmentation"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs_position"], batch2["inputs_position"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["targets_position"], batch2["targets_position"]))) def test_for_loop_repeatable(self): def get_first_batch(iterator): @@ 
-137,10 +148,9 @@ def get_first_batch(iterator): eval_batch1 = get_first_batch(self.eval_iter) eval_batch2 = get_first_batch(self.eval_iter) - self.assertTrue((eval_batch1['inputs']==eval_batch2['inputs']).all()) - self.assertTrue((eval_batch1['targets']==eval_batch2['targets']).all()) + self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all()) + self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index 6797f6ee3..c24cd2786 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -1,17 +1,17 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" """ Tests for tokenizer @@ -30,41 +30,43 @@ class TokenizerTest(unittest.TestCase): @classmethod def setUpClass(cls): - dataset_name = 'c4/en:3.0.1' - dataset_path = 'gs://maxtext-dataset' + dataset_name = "c4/en:3.0.1" + dataset_path = "gs://maxtext-dataset" cls.vocab_size = 32_768 cls.max_corpus_chars = 10_000_000 - assets_path = 'tests' - vocab_model_name = 'test_tokenizer' + assets_path = "tests" + vocab_model_name = "test_tokenizer" cls.tokenizer_path = os.path.join(assets_path, vocab_model_name) os.environ["TFDS_DATA_DIR"] = dataset_path read_config = tfds.ReadConfig( - shuffle_seed = 0, + shuffle_seed=0, ) train_ds_builder = tfds.builder(dataset_name) - cls.dataset = train_ds_builder.as_dataset(split='train', read_config=read_config, shuffle_files=True) - train_tokenizer.train_tokenizer(cls.dataset, - assets_path=assets_path, - vocab_path=cls.tokenizer_path, - vocab_size=cls.vocab_size, - max_corpus_chars=cls.max_corpus_chars) + cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) + train_tokenizer.train_tokenizer( + cls.dataset, + assets_path=assets_path, + vocab_path=cls.tokenizer_path, + vocab_size=cls.vocab_size, + max_corpus_chars=cls.max_corpus_chars, + ) @classmethod def tearDownClass(cls): os.remove(cls.tokenizer_path) def test_tokenize(self): - source_tokenizer = tokenizer.load_tokenizer('../assets/tokenizer') + source_tokenizer = tokenizer.load_tokenizer("../assets/tokenizer") test_tokenizer = tokenizer.load_tokenizer(self.tokenizer_path) - text = 'This is a test' + text = "This is a test" self.assertTrue((np.asarray(source_tokenizer.tokenize(text)) & np.asarray(test_tokenizer.tokenize(text))).all()) def test_detokenize(self): - source_tokenizer = tokenizer.load_tokenizer('../assets/tokenizer') + source_tokenizer = tokenizer.load_tokenizer("../assets/tokenizer") test_tokenizer = tokenizer.load_tokenizer(self.tokenizer_path) - tokens = [66,12,10,698,2] - self.assertEqual(np.asarray(source_tokenizer.detokenize(tokens)),np.asarray(test_tokenizer.detokenize(tokens))) + tokens = [66, 12, 10, 698, 2] + self.assertEqual(np.asarray(source_tokenizer.detokenize(tokens)), np.asarray(test_tokenizer.detokenize(tokens))) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 3bf46e753..cc64aea91 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the common Max Utils """ import unittest @@ -26,73 +26,158 @@ class TrainCompile(unittest.TestCase): @pytest.mark.tpu def test_save_compiled_v4(self): - compiled_trainstep_file='/tmp/test_compiled_v4.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v4-8", "compile_topology_num_slices=1", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v4.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v4-8", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_save_compiled_v5e(self): - compiled_trainstep_file='/tmp/test_compiled_v5e.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-16", "compile_topology_num_slices=1", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-16", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_minimal_offloaded_v5e(self): - compiled_trainstep_file='/tmp/test_compiled_v5e_offload.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=1", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=minimal_offloaded", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_compiled_v5e_offload.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=minimal_offloaded", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_save_compiled_v5p_two_slices(self): - compiled_trainstep_file='/tmp/test_compiled_v5p_two_slices.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-8", "compile_topology_num_slices=2", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v5p_two_slices.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=2", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_sequence_parallelism(self): - compiled_trainstep_file='/tmp/test_compiled.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "use_iota_embed=true", "compile_topology_num_slices=1", - 
"ici_sequence_parallelism=16", "global_parameter_scale=32", "per_device_batch_size=0.0625", "max_target_length=65536")) + compiled_trainstep_file = "/tmp/test_compiled.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "ici_sequence_parallelism=16", + "global_parameter_scale=32", + "per_device_batch_size=0.0625", + "max_target_length=65536", + )) @pytest.mark.tpu def test_remat_save_dot_except_mlpwi(self): - compiled_trainstep_file='/tmp/test_remat_save_dot_except_mlpwi.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.125", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_dot_except_mlpwi", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlpwi.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.125", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlpwi", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_save_dot_except_mlp(self): - compiled_trainstep_file='/tmp/test_remat_save_dot_except_mlp.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.25", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_dot_except_mlp", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlp.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.25", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlp", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_save_qkv_proj(self): - compiled_trainstep_file='/tmp/test_remat_save_qkv_proj.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.375", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_qkv_proj", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_qkv_proj.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.375", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + 
"remat_policy=save_qkv_proj", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_full(self): - compiled_trainstep_file='/tmp/test_remat_full.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=1", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=full", - "use_iota_embed=true", "global_parameter_scale=128")) \ No newline at end of file + compiled_trainstep_file = "/tmp/test_remat_full.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=full", + "use_iota_embed=true", + "global_parameter_scale=128", + )) diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 5efc5ebeb..a05e0fb6e 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Smoke test for int8""" import os @@ -20,18 +20,32 @@ from train import main as train_main from absl.testing import absltest + class Train(unittest.TestCase): """Smoke test for int8 G3 only""" def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([None, "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", "dataset_type=synthetic", "steps=10", - "enable_checkpointing=False", "quantization=int8"]) - -if __name__ == '__main__': + train_main([ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + "quantization=int8", + ]) + + +if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index 8cd41fb33..b3046fd35 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Smoke test """ import os @@ -20,18 +20,31 @@ from train import main as train_main from absl.testing import absltest + class Train(unittest.TestCase): """Smoke test G3 only""" def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([None, "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", "dataset_type=synthetic", "steps=10", - "enable_checkpointing=False"]) - -if __name__ == '__main__': + train_main([ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + ]) + + +if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/weight_dtypes_test.py b/MaxText/tests/weight_dtypes_test.py index 49829d43c..579e23310 100644 --- a/MaxText/tests/weight_dtypes_test.py +++ b/MaxText/tests/weight_dtypes_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Test that all weights are expected dtype (default float32) """ import unittest @@ -30,37 +30,36 @@ Transformer = models.Transformer -class WeightDtypes(unittest.TestCase): - """Test that all weights are expected dtype (default float32) """ - - def get_weights(self, argv): - """ Gets model weights """ - - # Setup necessary inputs to build a model state - pyconfig.initialize(argv) - config = pyconfig.config - quant = quantizations.configure_quantization(config) - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - model = Transformer(config, mesh, quant=quant) - learning_rate_schedule = max_utils.create_learning_rate_schedule(config) - tx = optimizers.get_optimizer(config, learning_rate_schedule) - _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - - abstract_state, _ , _ = max_utils.get_abstract_state(model, tx, config, example_rng, mesh) - return abstract_state.params - - def assert_weights_are_dtype(self, weights, expected_dtype): - jax.tree_util.tree_map_with_path(lambda x,y: self.assertEqual(y.dtype, expected_dtype), weights) - - def test_default_float32(self): - argv = [None, "configs/base.yml", "enable_checkpointing=False"] - weights = self.get_weights(argv) - self.assert_weights_are_dtype(weights, jnp.float32) - - def test_set_bf16(self): - argv = [None, "configs/base.yml", "enable_checkpointing=False", "weight_dtype=bfloat16"] - weights = self.get_weights(argv) - self.assert_weights_are_dtype(weights, jnp.bfloat16) - +class WeightDtypes(unittest.TestCase): + """Test that all weights are expected dtype (default float32)""" + + def get_weights(self, argv): + """Gets model weights""" + + # Setup necessary inputs to build a model state + pyconfig.initialize(argv) + config = pyconfig.config + quant = quantizations.configure_quantization(config) + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + model = Transformer(config, mesh, quant=quant) + learning_rate_schedule = max_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) + _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) + + abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, example_rng, mesh) + return abstract_state.params + + def assert_weights_are_dtype(self, weights, expected_dtype): + jax.tree_util.tree_map_with_path(lambda x, y: self.assertEqual(y.dtype, expected_dtype), weights) + + def test_default_float32(self): + argv = [None, "configs/base.yml", "enable_checkpointing=False"] + weights = self.get_weights(argv) + self.assert_weights_are_dtype(weights, jnp.float32) + + def test_set_bf16(self): + argv = [None, "configs/base.yml", "enable_checkpointing=False", "weight_dtype=bfloat16"] + weights = self.get_weights(argv) + self.assert_weights_are_dtype(weights, jnp.bfloat16) diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py index 9e29b1e68..0e7c374ad 100644 --- a/MaxText/tokenizer.py +++ b/MaxText/tokenizer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Provides op for tokenizing a dataset.""" @@ -29,34 +29,29 @@ Features = Dict[str, tf.Tensor] - -def _load_sentencepiece_tokenizer(tokenizer_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def _load_sentencepiece_tokenizer(tokenizer_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False): """Load a tf-text SentencePiece tokenizer from given model filepath.""" max_logging.log(f"Tokenizer path: {tokenizer_path}") - with tf.io.gfile.GFile(tokenizer_path, 'rb') as model_fp: + with tf.io.gfile.GFile(tokenizer_path, "rb") as model_fp: sp_model = model_fp.read() - sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) return sp_tokenizer + def load_tokenizer(tokenizer_path: str, add_bos=False, add_eos=True): """Loads the tokenizer at `tokenizer_path` or trains a one from `dataset`.""" try: sp_tokenizer = _load_sentencepiece_tokenizer(tokenizer_path, add_bos, add_eos) return sp_tokenizer except (tf.errors.NotFoundError, tf.errors.InvalidArgumentError): - logging.info('SentencePiece vocab not found, Run train_tokenizer.py') + logging.info("SentencePiece vocab not found, Run train_tokenizer.py") return None @dataclasses.dataclass class TokenizeOp: - sp_tokenizer: Any - data_keys: Iterable[str] = ('inputs', 'targets') + data_keys: Iterable[str] = ("inputs", "targets") def __call__(self, features: Features) -> Features: for k in self.data_keys: diff --git a/MaxText/train.py b/MaxText/train.py index 4a1966945..be67419cc 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
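A hypothetical usage sketch (not part of this patch) for the helpers in MaxText/tokenizer.py above: it assumes a trained SentencePiece model already exists at assets/tokenizer, that the code is run from the MaxText/ directory so the module imports as tokenizer, and that TokenizeOp keeps its sp_tokenizer field.

import tensorflow as tf
from tokenizer import TokenizeOp, load_tokenizer

sp_tokenizer = load_tokenizer("assets/tokenizer", add_bos=False, add_eos=True)
op = TokenizeOp(sp_tokenizer)

# The op rewrites each data_keys entry ("inputs", "targets") into token-id tensors.
features = {"inputs": tf.constant("hello world"), "targets": tf.constant("hello world")}
features = op(features)
print(features["inputs"])

# In an input pipeline the same op is typically applied with tf.data.Dataset.map(op).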
+See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Training loop and Decoding of the model.""" @@ -64,56 +64,52 @@ Transformer = models.Transformer EPS = 1e-8 + def validate_train_config(config): - """ Validates the configuration is set correctly for train.py""" + """Validates the configuration is set correctly for train.py""" assert config.run_name, "Erroring out, need a real run_name" - if not config.dataset_path.startswith('gs://'): + if not config.dataset_path.startswith("gs://"): max_logging.log("WARNING: 'dataset_path' might be pointing your local file system") - if not config.base_output_directory.startswith('gs://'): + if not config.base_output_directory.startswith("gs://"): max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive interger." - def get_first_step(state): - with jax.spmd_mode('allow_all'): + with jax.spmd_mode("allow_all"): return int(state.step) def load_next_batch(train_iter, example_batch, config): - """Loads the next batch. Can keep reusing the same batch for performance reasons """ + """Loads the next batch. Can keep reusing the same batch for performance reasons""" if config.reuse_example_batch and example_batch is not None: return example_batch else: return next(train_iter) + def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr): """Records scalar metrics to be written to tensorboard""" - metrics['scalar'].update({ - 'perf/step_time_seconds': step_time_delta.total_seconds() - }) - metrics['scalar'].update({ - 'perf/per_device_tflops' : per_device_tflops - }) - metrics['scalar'].update({ - 'perf/per_device_tflops_per_sec': - per_device_tflops / - step_time_delta.total_seconds() - }) - metrics['scalar'].update({'learning/current_learning_rate': lr }) + metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()}) + metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops}) + metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()}) + metrics["scalar"].update({"learning/current_learning_rate": lr}) + _buffered_step = None _buffered_metrics = None + + def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config): """Entry point for all metrics writing in Train's Main. - TODO: would be better as a Class in the future (that initialized all state!) + TODO: would be better as a Class in the future (that initialized all state!) - To avoid introducing an unnecessary dependency, we "double buffer" -- we hold - onto the last metrics and step and only publish when we receive a new metrics and step. - The logic is that this ensures that Jax is able to queues train_steps and we - don't block when turning "lazy" Jax arrays into real Python numbers. + To avoid introducing an unnecessary dependency, we "double buffer" -- we hold + onto the last metrics and step and only publish when we receive a new metrics and step. + The logic is that this ensures that Jax is able to queues train_steps and we + don't block when turning "lazy" Jax arrays into real Python numbers. 
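Stripped to its essentials, the double buffering described in this docstring looks roughly like the sketch below (illustration only; writer stands in for the TensorBoard writer and GCS bookkeeping, and the nested scalar/scalars structure is flattened away).

import jax.numpy as jnp

_buffered_step = None
_buffered_metrics = None

def write_metrics_sketch(writer, metrics, step):
  """Publish the previous step's metrics, then buffer the current ones."""
  global _buffered_step, _buffered_metrics
  if _buffered_metrics is not None:
    # By now the buffered arrays are finished on device, so float() does not
    # stall the step that produced them.
    writer(_buffered_step, {k: float(v) for k, v in _buffered_metrics.items()})
  _buffered_step = step
  _buffered_metrics = metrics  # possibly still "lazy" device arrays

log = lambda step, scalars: print(step, scalars)
write_metrics_sketch(log, {"loss": jnp.float32(2.0)}, 0)  # buffers, prints nothing
write_metrics_sketch(log, {"loss": jnp.float32(1.5)}, 1)  # publishes step 0 metrics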
""" global _buffered_step, _buffered_metrics @@ -131,68 +127,72 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step _buffered_step = step _buffered_metrics = metrics + def write_metrics_to_tensorboard(writer, metrics, step, config): - """ Writes metrics to tensorboard""" - with jax.spmd_mode('allow_all'): + """Writes metrics to tensorboard""" + with jax.spmd_mode("allow_all"): if jax.process_index() == 0: - for metric_name in metrics.get("scalar",[]): + for metric_name in metrics.get("scalar", []): writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) - for metric_name in metrics.get("scalars",[]): + for metric_name in metrics.get("scalars", []): writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) full_log = step % config.log_period == 0 - max_logging.log(f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " - f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " - f"loss: {metrics['scalar']['learning/loss']:.3f}") + max_logging.log( + f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " + f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " + f"loss: {metrics['scalar']['learning/loss']:.3f}" + ) if full_log and jax.process_index() == 0: - max_logging.log( - f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'" - ) + max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") writer.flush() -def save_checkpoint(checkpoint_manager, step, state, dataset_type='c4', data_iterator=None): + +def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None): """Wrapper for saving checkpoint""" - if dataset_type == 'c4-array_record': + if dataset_type == "c4-array_record": return checkpoint_manager.save( - step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeSave(item=state), - iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator) - ) - ) + step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeSave(item=state), + iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator), + ), + ) else: return checkpoint_manager.save( - step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeSave(item=state) - )) + step, args=orbax.checkpoint.args.Composite(items=orbax.checkpoint.args.PyTreeSave(item=state)) + ) + # ----------------------------------------------------------------------------- # Top-level Functions # ----------------------------------------------------------------------------- + def record_activation_metrics(output_metrics, intermediate_outputs, config): - """ Adds the activation metrics to the metrics dict""" + """Adds the activation metrics to the metrics dict""" if config.scan_layers: - metrics_dict = intermediate_outputs['intermediates']['decoder']['decoder'] + metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"] for layer_num in range(config.num_decoder_layers): - output_metrics['scalar'][f'activ_fraction_zero/layer_{layer_num:03d}'] = \ - metrics_dict["activation_fraction_zero"][0][layer_num] - output_metrics['scalar'][f'activ_mean/layer_{layer_num:03d}'] = metrics_dict["activation_mean"][0][layer_num] - output_metrics['scalar'][f'activ_stdev/layer_{layer_num:03d}'] = metrics_dict["activation_stdev"][0][layer_num] + output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = 
metrics_dict["activation_fraction_zero"][0][ + layer_num + ] + output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num] + output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num] else: for layer_num in range(config.num_decoder_layers): - layer = intermediate_outputs['intermediates']['decoder'][f'layers_{layer_num}'] - output_metrics['scalar'][f'activ_fraction_zero/layer_{layer_num:03d}'] = layer["activation_fraction_zero"][0] - output_metrics['scalar'][f'activ_mean/layer_{layer_num:03d}'] = layer["activation_mean"][0] - output_metrics['scalar'][f'activ_stdev/layer_{layer_num:03d}'] = layer["activation_stdev"][0] + layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"] + output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0] + output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0] + output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0] + def loss_fn(model, config, data, dropout_rng, params, is_train=True): - '''loss_fn for both train and eval. + """loss_fn for both train and eval. Args: model: A nn.Module @@ -205,33 +205,36 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): Returns: loss: average loss aux: a dictionary including intermediate_outputs, total_loss, and total_weights - ''' + """ # inputs, targets, segments, positions = apply_args rng1, aqt_rng = jax.random.split(dropout_rng) # decimate proportion of data when per_device_batch_size<1 if is_train: for k, v in data.items(): - data[k] = v[:config.global_batch_size_to_train_on,:] - - logits, intermediate_outputs = model.apply(params, - data['inputs'], - data['inputs_position'], - decoder_segment_ids=data['inputs_segmentation'], - enable_dropout=config.enable_dropout if is_train else False, - rngs={'dropout': rng1, 'params': aqt_rng}, mutable='intermediates') - one_hot_targets = jax.nn.one_hot(data['targets'], config.vocab_size) + data[k] = v[: config.global_batch_size_to_train_on, :] + + logits, intermediate_outputs = model.apply( + params, + data["inputs"], + data["inputs_position"], + decoder_segment_ids=data["inputs_segmentation"], + enable_dropout=config.enable_dropout if is_train else False, + rngs={"dropout": rng1, "params": aqt_rng}, + mutable="intermediates", + ) + one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - xent = nn.with_logical_constraint(xent, ('activation_batch', 'activation_length')) + xent = nn.with_logical_constraint(xent, ("activation_batch", "activation_length")) # Mask out paddings at the end of each example. 
- xent = xent * (data['targets_segmentation'] != 0) + xent = xent * (data["targets_segmentation"] != 0) total_loss = jnp.sum(xent) - total_weights = jnp.sum(data['targets_segmentation'] != 0) + total_weights = jnp.sum(data["targets_segmentation"] != 0) loss = total_loss / (total_weights + EPS) aux = { - 'intermediate_outputs': intermediate_outputs, - 'total_loss': total_loss, - 'total_weights': total_weights, + "intermediate_outputs": intermediate_outputs, + "total_loss": total_loss, + "total_weights": total_weights, } return loss, aux @@ -254,42 +257,50 @@ def train_step(model, config, state, data, dropout_rng): train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True) grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True) (loss, aux), raw_grads = grad_fn(state.params) - intermediate_outputs = aux['intermediate_outputs'] + intermediate_outputs = aux["intermediate_outputs"] if config.gradient_clipping_threshold > 0: grads, _ = optax.clip_by_global_norm(config.gradient_clipping_threshold).update(raw_grads, state, None) else: grads = raw_grads new_state = state.apply_gradients(grads=grads) - metrics = {'scalar': {'learning/loss': loss, 'learning/grad_norm': max_utils.l2norm_pytree(grads), - 'learning/raw_grad_norm': max_utils.l2norm_pytree(raw_grads), - 'learning/param_norm': max_utils.l2norm_pytree(new_state.params)}, 'scalars': {}} + metrics = { + "scalar": { + "learning/loss": loss, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + "learning/param_norm": max_utils.l2norm_pytree(new_state.params), + }, + "scalars": {}, + } if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) return new_state, metrics + def eval_step(model, config, state, data, dropout_rng): """eval_step no backprop and new state compared with train_step.""" eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False) loss, aux = eval_loss_fn(state.params) - total_loss = aux['total_loss'] - total_weights = aux['total_weights'] - metrics = {'scalar': - {'evaluation/loss': loss, - 'evaluation/total_loss': total_loss, - 'evaluation/total_weights': total_weights}} + total_loss = aux["total_loss"] + total_weights = aux["total_weights"] + metrics = { + "scalar": {"evaluation/loss": loss, "evaluation/total_loss": total_loss, "evaluation/total_weights": total_weights} + } return metrics + def create_goodput_recorder(config): if config.enable_goodput_recording: - logger_name = f'goodput_{config.run_name}' + logger_name = f"goodput_{config.run_name}" recorder = goodput.GoodputRecorder(config.run_name, logger_name, jax.process_index() == 0) return recorder return None + def record_goodput(recorder, config, step=None, job_start=False, job_end=False): if recorder and config.enable_goodput_recording: if job_start and step is None: @@ -299,8 +310,9 @@ def record_goodput(recorder, config, step=None, job_start=False, job_end=False): if step is not None: recorder.record_step_start_time(step) + def setup_mesh_and_model(config): - """ Set up the mesh and the model for training + """Set up the mesh and the model for training Args: config @@ -336,8 +348,9 @@ def setup_mesh_and_model(config): tx = optimizers.get_optimizer(config, learning_rate_schedule) return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx + def setup_train_loop(config): - """ Set up prerequisites for the training loop - + """Set up prerequisites for the training 
loop - checkpoint_manager, PRNG keys, Mesh, Model and optimizer. Set up data iterator and tokenizer, initialize the model. @@ -358,13 +371,24 @@ def setup_train_loop(config): init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config) data_iterator, eval_data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh) - state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator, - tx, config, init_rng, mesh, checkpoint_manager) + state, state_mesh_annotations, data_iterator = max_utils.setup_training_state( + model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + ) maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh) - return ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model, - mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state) + return ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) def train_loop(config, state=None): @@ -379,27 +403,36 @@ def train_loop(config, state=None): recorder = create_goodput_recorder(config) record_goodput(recorder, config, job_start=True) - ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model, - mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state) = setup_train_loop(config) + ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) = setup_train_loop(config) # pylint: disable=line-too-long - functional_train, in_shard_train, out_shard_train, static_argnums_train, donate_argnums_train = maxtext_utils.get_functional_train_with_signature( - train_step, - mesh, - state_mesh_annotations, - model, - config - ) + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config) if eval_data_iterator: # pylint: disable=line-too-long - functional_eval, in_shard_eval, out_shard_eval, static_argnums_eval, donate_argnums_eval = maxtext_utils.get_functional_eval_with_signature( - eval_step, - mesh, - state_mesh_annotations, - model, - config - ) - + ( + functional_eval, + in_shard_eval, + out_shard_eval, + static_argnums_eval, + donate_argnums_eval, + ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config) num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") @@ -411,31 +444,33 @@ def train_loop(config, state=None): max_utils.add_config_to_summary_writer(config, writer) # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit - if config.compiled_trainstep_file != '': + if config.compiled_trainstep_file != "": print("Loading the compiled function...", flush=True) # Need to pass train signature and state to determine i/o shapes of train_state for now. 
p_train_step = maxtext_utils.load_compiled(config, functional_train, state) print("Loaded compiled function!", flush=True) else: p_train_step = jax.jit( - functional_train, - in_shardings=in_shard_train, - out_shardings=out_shard_train, - static_argnums=static_argnums_train, - donate_argnums=donate_argnums_train) + functional_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + static_argnums=static_argnums_train, + donate_argnums=donate_argnums_train, + ) if eval_data_iterator: p_eval_step = jax.jit( - functional_eval, - in_shardings=in_shard_eval, - out_shardings=out_shard_eval, - static_argnums=static_argnums_eval, - donate_argnums=donate_argnums_eval) + functional_eval, + in_shardings=in_shard_eval, + out_shardings=out_shard_eval, + static_argnums=static_argnums_eval, + donate_argnums=donate_argnums_eval, + ) - local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None + local_metrics_file = open(config.metrics_file, "a", encoding="utf8") if config.metrics_file else None running_gcs_metrics = [] if config.gcs_metrics else None - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training first_profiling_step = start_step + config.skip_first_n_steps_for_profiler if config.enable_profiler and first_profiling_step >= config.steps: raise ValueError("Profiling requested but initial profiling step set past training final step") @@ -453,12 +488,10 @@ def train_loop(config, state=None): nextrng = jax.jit(jax.random.fold_in)(init_rng, step) record_goodput(recorder, config, step=step) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - state, metrics = p_train_step( - state, example_batch, nextrng - ) + state, metrics = p_train_step(state, example_batch, nextrng) new_time = datetime.datetime.now() - record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step)) + record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step)) last_step_completion = new_time if checkpoint_manager is not None: @@ -474,15 +507,13 @@ def train_loop(config, state=None): if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0: assert eval_data_iterator - cumulative_eval_metrics = {"total_loss": 0., "total_weights": 0.} + cumulative_eval_metrics = {"total_loss": 0.0, "total_weights": 0.0} for eval_batch in eval_data_iterator: with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step( - state, eval_batch, nextrng - ) - cumulative_eval_metrics['total_loss'] += float(eval_metrics['scalar']['evaluation/total_loss']) - cumulative_eval_metrics['total_weights'] += float(eval_metrics['scalar']['evaluation/total_weights']) - eval_loss = cumulative_eval_metrics['total_loss'] / (cumulative_eval_metrics['total_weights'] + EPS) + eval_metrics = p_eval_step(state, eval_batch, nextrng) + cumulative_eval_metrics["total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"]) + cumulative_eval_metrics["total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"]) + eval_loss = cumulative_eval_metrics["total_loss"] / (cumulative_eval_metrics["total_weights"] + EPS) max_logging.log(f"average loss after {step=}: {eval_loss=}, total_weights={cumulative_eval_metrics['total_weights']}") if eval_loss <= config.target_eval_loss: max_logging.log(f"Early stop and exit loop after reaching 
{config.target_eval_loss=}") @@ -494,15 +525,16 @@ def train_loop(config, state=None): if checkpoint_manager is not None: checkpoint_manager.wait_until_finished() - write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics + write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics max_utils.close_summary_writer(writer) record_goodput(recorder, config, job_end=True) return state + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" pyconfig.initialize(argv) config = pyconfig.config validate_train_config(config) @@ -512,13 +544,16 @@ def main(argv: Sequence[str]) -> None: vertex_tensorboard_manager.configure_vertex_tensorboard(config) debug_config = debug_configuration.DebugConfig( - stack_trace_config = stack_trace_configuration.StackTraceConfig( - collect_stack_trace = config.collect_stack_trace, - stack_trace_to_cloud = config.stack_trace_to_cloud, - stack_trace_interval_seconds = config.stack_trace_interval_seconds)) + stack_trace_config=stack_trace_configuration.StackTraceConfig( + collect_stack_trace=config.collect_stack_trace, + stack_trace_to_cloud=config.stack_trace_to_cloud, + stack_trace_interval_seconds=config.stack_trace_interval_seconds, + ) + ) diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) with diagnostic.diagnose(diagnostic_config): train_loop(config) + if __name__ == "__main__": app.run(main) diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py index 55384c13e..43789ea3f 100644 --- a/MaxText/train_compile.py +++ b/MaxText/train_compile.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step @@ -45,14 +45,15 @@ def validate_config(config): - """ Validates the config is is setup correctly to compile, returning a useful error message if not. """ - assert config.compile_topology != '',\ - "You must pass your desired target hardware in compile_topology, e.g. 
compile_topology=v5e-256" - assert config.compile_topology_num_slices > 0,\ - "You must set compile_topology_num_slices to a positive integer" + """Validates the config is is setup correctly to compile, returning a useful error message if not.""" + assert ( + config.compile_topology != "" + ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256" + assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" + def get_topology_mesh(config): - """ Get the target hardware devices, and create configured mesh with them """ + """Get the target hardware devices, and create configured mesh with them""" target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) topology_devices = get_topology_desc( platform=target_hardware.platform, @@ -65,8 +66,9 @@ def get_topology_mesh(config): topology_mesh = Mesh(topology_device_mesh, config.mesh_axes) return topology_mesh + def get_shaped_inputs(topology_mesh, config): - """ Get shaped abstractions of inputs to train_step: state, batch and rng """ + """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizier to get shaped versions of the state quant = quantizations.configure_quantization(config) model = Transformer(config, topology_mesh, quant=quant) @@ -79,7 +81,7 @@ def get_shaped_inputs(topology_mesh, config): shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) # Shaped state - abstract_state, state_mesh_annotations, _ = max_utils.get_abstract_state(model, tx, config, example_rng, topology_mesh) + abstract_state, state_mesh_annotations, _ = max_utils.get_abstract_state(model, tx, config, example_rng, topology_mesh) # Shaped batch shaped_batch = input_pipeline_interface.get_shaped_batch(config) @@ -89,29 +91,40 @@ def get_shaped_inputs(topology_mesh, config): return shaped_train_args, shaped_train_kwargs, state_mesh_annotations, model -def jit_and_compile(func, func_input_args, func_input_kwargs, mesh, in_shardings, - out_shardings, static_argnums, donate_argnums, logical_axis_rules): - """ Jit, lower, and compile func.""" +def jit_and_compile( + func, + func_input_args, + func_input_kwargs, + mesh, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + logical_axis_rules, +): + """Jit, lower, and compile func.""" with mesh, logical_axis_rules: jitted = jax.jit( - func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - donate_argnums=donate_argnums + func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, ) lowered = jitted.lower(*func_input_args, **func_input_kwargs) compiled = lowered.compile() return compiled + def save_compiled(compiled, save_name): - """ Serialize and save the compiled function. 
""" + """Serialize and save the compiled function.""" serialized, _, _ = serialize(compiled) with open(save_name, "wb") as f: pickle.dump(serialized, f) + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") print("Starting train_compile.py...", flush=True) # Parse and validate configuration @@ -127,30 +140,26 @@ def main(argv: Sequence[str]) -> None: # Get function to compile and shardings func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = maxtext_utils.get_functional_train_with_signature( - train.train_step, - topology_mesh, - state_mesh_annotations, - model, - config + train.train_step, topology_mesh, state_mesh_annotations, model, config ) # Compile print("Jitting and compiling train step...", flush=True) compiled = jit_and_compile( - func_to_compile, - shaped_train_args, - shaped_train_kwargs, - topology_mesh, - in_shard, - out_shard, - static_argnums, - donate_argnums, - nn_partitioning.axis_rules(config.logical_axis_rules) + func_to_compile, + shaped_train_args, + shaped_train_kwargs, + topology_mesh, + in_shard, + out_shard, + static_argnums, + donate_argnums, + nn_partitioning.axis_rules(config.logical_axis_rules), ) print("Jitting and compilation complete!", flush=True) # Serialize and save the compiled object - if config.compiled_trainstep_file != '': + if config.compiled_trainstep_file != "": print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") diff --git a/MaxText/train_tokenizer.py b/MaxText/train_tokenizer.py index dab4d99be..03c0d5269 100644 --- a/MaxText/train_tokenizer.py +++ b/MaxText/train_tokenizer.py @@ -1,14 +1,14 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" """ Train tokenizer @@ -29,28 +29,15 @@ from sentencepiece import SentencePieceTrainer -_DATASET_PATH = flags.DEFINE_string( - 'dataset_path', None, 'Path to the dataset', required=True -) -_DATASET_NAME = flags.DEFINE_string( - 'dataset_name', None, 'Name to the dataset', required=True -) -_VOCAB_SIZE = flags.DEFINE_integer('vocab_size', 32_768, 'Vocab size') -_MAX_CORPUS_CHARS = flags.DEFINE_integer( - 'max_corpus_chars', 10_000_000, 'Max corpus chars' -) -_ASSETS_PATH = flags.DEFINE_string( - 'assets_path', 'assets', 'Name to the dataset' -) -_VOCAB_MODEL_NAME = flags.DEFINE_string( - 'vocab_model_name', 'tokenizer', 'Name to the dataset' -) - -def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('text',) -) -> Tuple[str, int]: +_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True) +_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True) +_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size") +_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars") +_ASSETS_PATH = flags.DEFINE_string("assets_path", "assets", "Name to the dataset") +_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset") + + +def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. Args: dataset: tf.dataset containing string-data. @@ -61,25 +48,27 @@ def _dump_chars_to_textfile( """ char_count = 0 ds_iter = dataset.as_numpy_iterator() - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + with tempfile.NamedTemporaryFile(delete=False, prefix="/tmp/ds_chars") as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: - line = example[k] + b'\n' + line = example[k] + b"\n" char_count += len(line) outfp.write(line) return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - assets_path: str, - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('text',)): + +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + assets_path: str, + model_path: str, + model_type: str = "unigram", + character_coverage: float = 1.0, + data_keys=("text",), +): """Train SentencePiece tokenizer from subset of tf dataset. Args: dataset: tf.dataset @@ -94,65 +83,69 @@ def _train_sentencepiece(dataset: tf.data.Dataset, Returns: path to the trained sentencepiece vocabulary model. 
""" - if model_path.startswith('gs://'): + if model_path.startswith("gs://"): abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) abs_assets_path = os.path.abspath(os.path.expanduser(assets_path)) - fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) + with tempfile.NamedTemporaryFile(delete=False, prefix="/tmp/sp_tmp") as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ - f'--input={fname}', f'--vocab_size={vocab_size}', - f'--character_coverage={character_coverage}', - f'--model_prefix={model_fp.name}', f'--model_type={model_type}' + argstr = " ".join([ + f"--input={fname}", + f"--vocab_size={vocab_size}", + f"--character_coverage={character_coverage}", + f"--model_prefix={model_fp.name}", + f"--model_type={model_type}", ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address # create and fill delays. - copy_rename_path = abs_model_path + '.rntmp' - if not model_path.startswith('gs://'): + copy_rename_path = abs_model_path + ".rntmp" + if not model_path.startswith("gs://"): tf.io.gfile.makedirs(abs_assets_path) - tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) + tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) - logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) + logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path) else: while not tf.io.gfile.exists(abs_model_path): time.sleep(1) time.sleep(1) return abs_model_path -def train_tokenizer(dataset: tf.data.Dataset, - *, - assets_path: str, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str] = ('text',)): + +def train_tokenizer( + dataset: tf.data.Dataset, + *, + assets_path: str, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str] = ("text",), +): """tokenizer training function""" - logging.info('SentencePiece vocab not found, building one from data.') + logging.info("SentencePiece vocab not found, building one from data.") vocab_path = _train_sentencepiece( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, assets_path=assets_path, model_path=vocab_path, - data_keys=data_keys) - logging.info('Model saved at %s', vocab_path) + data_keys=data_keys, + ) + logging.info("Model saved at %s", vocab_path) def main(argv): del argv - os.environ['TFDS_DATA_DIR'] = _DATASET_PATH.value + os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value read_config = tfds.ReadConfig( - shuffle_seed = 0, + shuffle_seed=0, ) train_ds_builder = tfds.builder(_DATASET_NAME.value) - train_ds = train_ds_builder.as_dataset(split='train', read_config=read_config, shuffle_files=True) + train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) train_tokenizer( train_ds, assets_path=_ASSETS_PATH.value, @@ -162,5 +155,5 @@ def main(argv): ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/MaxText/vertex_tensorboard.py b/MaxText/vertex_tensorboard.py index 1cb438db5..9a106c32b 100644 --- a/MaxText/vertex_tensorboard.py +++ b/MaxText/vertex_tensorboard.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC 
+Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Utilities for Tensorboard in Vertex AI.""" @@ -40,7 +40,7 @@ def __del__(self): def setup(self): """Creates Tensorboard instance and Experiment in Vertex AI. - + Returns: URL to view Vertex Tensorboard created in Google Cloud Project. """ @@ -54,19 +54,21 @@ def setup(self): # Create Vertex Tensorboard instance vertex_tensorboard_name = os.environ.get("TENSORBOARD_NAME") - instance_id = tensorboard.create_instance(project=vertex_tensorboard_project, - location=vertex_tensorboard_region, - tensorboard_name=vertex_tensorboard_name) + instance_id = tensorboard.create_instance( + project=vertex_tensorboard_project, location=vertex_tensorboard_region, tensorboard_name=vertex_tensorboard_name + ) # Failed to create Vertex Tensorboard instance if instance_id is None: return None # Create Vertex Experiment vertex_experiment_name = os.environ.get("EXPERIMENT_NAME") - _, tensorboard_url = tensorboard.create_experiment(project=vertex_tensorboard_project, - location=vertex_tensorboard_region, - experiment_name=vertex_experiment_name, - tensorboard_name=vertex_tensorboard_name) + _, tensorboard_url = tensorboard.create_experiment( + project=vertex_tensorboard_project, + location=vertex_tensorboard_region, + experiment_name=vertex_experiment_name, + tensorboard_name=vertex_tensorboard_name, + ) return tensorboard_url def upload_data(self, tensorboard_dir): @@ -84,18 +86,22 @@ def upload_data(self, tensorboard_dir): max_logging.log("Vertex Tensorboard configurations are not set. Data will not be uploaded to Vertex AI.") self.uploader_flag = False - max_logging.log(f"Data will be uploaded to Vertex Tensorboard instance: {tensorboard_name} " - f"and Experiment: {experiment_name} in {tensorboard_region}.") - uploader.start_upload_to_tensorboard(project=tensorboard_project, - location=tensorboard_region, - experiment_name=experiment_name, - tensorboard_name=tensorboard_name, - logdir=tensorboard_dir) + max_logging.log( + f"Data will be uploaded to Vertex Tensorboard instance: {tensorboard_name} " + f"and Experiment: {experiment_name} in {tensorboard_region}." 
+ ) + uploader.start_upload_to_tensorboard( + project=tensorboard_project, + location=tensorboard_region, + experiment_name=experiment_name, + tensorboard_name=tensorboard_name, + logdir=tensorboard_dir, + ) self.uploader_flag = True def configure_vertex_tensorboard(self, config): """Creates Vertex Tensorboard and start thread to upload data to Vertex Tensorboard.""" - if jax.process_index()==0: + if jax.process_index() == 0: if not os.environ.get("TENSORBOARD_PROJECT"): if not config.vertex_tensorboard_project: os.environ["TENSORBOARD_PROJECT"] = max_utils.get_project() @@ -112,11 +118,11 @@ def configure_vertex_tensorboard(self, config): if not os.environ.get("EXPERIMENT_NAME"): os.environ["EXPERIMENT_NAME"] = config.run_name - if config.use_vertex_tensorboard: # running MaxText on GCE + if config.use_vertex_tensorboard: # running MaxText on GCE tensorboard_url = self.setup() if tensorboard_url is None: raise ValueError("Unable to create Tensorboard and Experiment in Vertex AI.") max_logging.log(f"View your Vertex AI Tensorboard at: {tensorboard_url}") self.upload_data(config.tensorboard_dir) - elif os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): # running MaxText via XPK + elif os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): # running MaxText via XPK self.upload_data(config.tensorboard_dir) diff --git a/code_style.sh b/code_style.sh new file mode 100644 index 000000000..588ef70fa --- /dev/null +++ b/code_style.sh @@ -0,0 +1,33 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Clean up Python codes using Pylint & Pyink +# Googlers: please run `sudo apt install pipx; pipx install pylint --force; pipx install pyink==23.10.0` in advance + +set -e + +FOLDERS_TO_FORMAT=("MaxText" "pedagogical_examples") +LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2) + +for folder in "${FOLDERS_TO_FORMAT[@]}" +do + pyink "$folder" --pyink-indentation=2 --line-length=${LINE_LENGTH} +done + +for folder in "${FOLDERS_TO_FORMAT[@]}" +do + pylint "./$folder" +done + +echo "Successfully clean up all codes." diff --git a/pedagogical_examples/non_spmd.py b/pedagogical_examples/non_spmd.py index 743bd4138..9918cbec3 100644 --- a/pedagogical_examples/non_spmd.py +++ b/pedagogical_examples/non_spmd.py @@ -1,22 +1,22 @@ #!/usr/bin/python3 """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" -''' +""" This programs demonstrates embarrassingly parallelizable non-SPMD computations in Jax, in this case by having each process_index run its own computation. The same approach can be extended for non-embarrassingly parallelizable computations. @@ -24,7 +24,7 @@ then using a `host_local_array_to_global_array` to reshard into a new global array. An important limitation of this approach is that we cannot overlap communication and computation between the different kernel calls. -''' +""" import jax @@ -34,23 +34,21 @@ import numpy as np - - # Notice this is jax.local_devices(), not jax.devices(). Hence each process (on TPUVMs, each VM) will run separate programs # on its mesh. mesh = Mesh(np.array(jax.local_devices()), ["data"]) sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None)) idx = jax.process_index() + # Example step depends on idx which is different on each program def example_step(): - return idx * jax.numpy.ones((idx+1)) + return idx * jax.numpy.ones((idx + 1)) + jit_func = jax.jit( - example_step, - out_shardings=sharding, - ) + example_step, + out_shardings=sharding, +) print(f"{idx=} -> {jit_func()=}") - - diff --git a/pedagogical_examples/shardings.py b/pedagogical_examples/shardings.py index 85198c3ee..912266667 100644 --- a/pedagogical_examples/shardings.py +++ b/pedagogical_examples/shardings.py @@ -1,22 +1,22 @@ #!/usr/bin/python3 """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
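The reshard step that the non_spmd.py docstring mentions, host_local_array_to_global_array, is not demonstrated in the example itself; below is a hedged sketch of how the per-process results could be stitched into one global jax.Array. The (local_device_count,) shape is chosen only so the global array divides evenly across the "data" mesh axis.

import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec

# Note: a global mesh over jax.devices(), unlike the per-process mesh used above.
mesh = Mesh(np.array(jax.devices()), ("data",))

# Pretend each process produced this host-local result.
local_result = np.full((jax.local_device_count(),), float(jax.process_index()), dtype=np.float32)

# Concatenate every process's shard along "data" into a single global array.
global_result = multihost_utils.host_local_array_to_global_array(local_result, mesh, PartitionSpec("data"))
print(global_result.shape, global_result.sharding)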
+""" -'''This script is used to measure the performance of different sharding schemes on TPU.''' +"""This script is used to measure the performance of different sharding schemes on TPU.""" from absl import app from absl import flags @@ -32,79 +32,67 @@ from typing import Sequence parser = argparse.ArgumentParser( - description="Experiment different sharding techniques with a simple NN.\ + description="Experiment different sharding techniques with a simple NN.\ Ensure 1) The product of dcn dimensions == number of slices \ 2) product of ici dimension = number of devices per slice" - ) +) parser.add_argument( - "--profiler_path", "-p", + "--profiler_path", + "-p", required=False, default="", help="Path to the profiler where the script will write to.", - type=str -) -parser.add_argument( - "--embedding_dimension", "-d", - required=False, - default=2048, - type=int -) -parser.add_argument( - "--batch_size", "-b", - required=False, - default=131072, - type=int -) -parser.add_argument( - "--num_layers", "-n", - required=False, - default=4, - type=int + type=str, ) +parser.add_argument("--embedding_dimension", "-d", required=False, default=2048, type=int) +parser.add_argument("--batch_size", "-b", required=False, default=131072, type=int) +parser.add_argument("--num_layers", "-n", required=False, default=4, type=int) parser.add_argument( - "--dcn_data_parallelism", "-dd", - help="N-way Data Parallelism across slices", - required=False, - default=1, - type=int + "--dcn_data_parallelism", "-dd", help="N-way Data Parallelism across slices", required=False, default=1, type=int ) parser.add_argument( - "--dcn_fsdp_parallelism", "-df", + "--dcn_fsdp_parallelism", + "-df", help="Fsdp parallelism across slices that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--dcn_tensor_parallelism", "-dt", + "--dcn_tensor_parallelism", + "-dt", help="Tensor parallelism across slices that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--ici_data_parallelism", "-id", + "--ici_data_parallelism", + "-id", help="Data parallelism within each slice that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--ici_fsdp_parallelism", "-if", + "--ici_fsdp_parallelism", + "-if", help="Number of shards for Fsdp Parallelism within each slice.", required=False, default=4, - type=int + type=int, ) parser.add_argument( - "--ici_tensor_parallelism", "-it", + "--ici_tensor_parallelism", + "-it", help="Number of shards for Tensor Parallelism within each slice.", required=False, default=1, - type=int + type=int, ) args = parser.parse_args() + def main(_argv: Sequence[str]) -> None: def activate_profiler(profiler_path): if profiler_path: @@ -115,16 +103,16 @@ def deactivate_profiler(profiler_path): if profiler_path: jax.profiler.stop_trace() - def simple_timeit(f, tries = 5, verbose = True): - '''Simple utility to time a function for multiple runs''' + def simple_timeit(f, tries=5, verbose=True): + """Simple utility to time a function for multiple runs""" outcomes = [] - f() #warm it up! + f() # warm it up! 
for _ in range(tries): s = datetime.datetime.now() f() e = datetime.datetime.now() - outcomes.append((e-s).total_seconds()) - average_time = sum(outcomes)/len(outcomes) + outcomes.append((e - s).total_seconds()) + average_time = sum(outcomes) / len(outcomes) if verbose: print(f"average time: {average_time}, timings (seconds) {outcomes}") return average_time @@ -138,15 +126,18 @@ def simple_timeit(f, tries = 5, verbose = True): assert len(devices) > 1, "You must have at least two devices" # Assert that we have correct inputs of sharding that fit the number of chips - assert np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices, f"Number of devices {num_devices} \ + assert ( + np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices + ), f"Number of devices {num_devices} \ does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}" - multi_slice_env = hasattr(jax.devices()[0], 'slice_index') + multi_slice_env = hasattr(jax.devices()[0], "slice_index") # Create device mesh if multi_slice_env: - assert args.dcn_data_parallelism == 1 + max(x.slice_index for x in jax.devices()), \ - f"Number of slices given {args.dcn_data_parallelism} \ + assert args.dcn_data_parallelism == 1 + max( + x.slice_index for x in jax.devices() + ), f"Number of slices given {args.dcn_data_parallelism} \ does not match the number fetched from jax devices {jax.devices()[0]}" devices_array = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism) else: @@ -156,27 +147,27 @@ def simple_timeit(f, tries = 5, verbose = True): mesh = Mesh(devices_array, ["data", "fsdp", "tensor"]) - data_sharding = PartitionSpec(("data", "fsdp"), "tensor") + data_sharding = PartitionSpec(("data", "fsdp"), "tensor") # We assume parameters are stored in a decreasing order of dimension size parameter_sharding = PartitionSpec("tensor", "fsdp") BATCH = len(jax.devices()) * args.batch_size D_EMB = args.embedding_dimension - D_FF = 4 * D_EMB + D_FF = 4 * D_EMB NUM_LAYERS = args.num_layers parameters = 2 * D_FF * D_EMB * NUM_LAYERS parameter_bytes = 2 * parameters - activation_bytes = 2 * ( BATCH * ( D_FF+D_EMB) ) * NUM_LAYERS + activation_bytes = 2 * (BATCH * (D_FF + D_EMB)) * NUM_LAYERS memory_bytes = parameter_bytes + activation_bytes print(f"total {memory_bytes/1e9} GB, parameters {parameter_bytes/1e9} GB, activations {activation_bytes/1e9} GB") def gen_layer(random_key): - keys = jax.random.split(random_key, num = 4) + keys = jax.random.split(random_key, num=4) return { - "EMB2FF" : 1e-4 * jax.random.normal( keys[0], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), - "FF2EMB" : 1e-4 * jax.random.normal( keys[1], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), + "EMB2FF": 1e-4 * jax.random.normal(keys[0], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), + "FF2EMB": 1e-4 * jax.random.normal(keys[1], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), } def gen_layers(random_key): @@ -187,8 +178,7 @@ def gen_layers(random_key): return tuple(layers) def gen_data(random_key): - return jax.random.uniform(random_key, (BATCH, D_EMB), dtype=jax.numpy.bfloat16 ) - + return jax.random.uniform(random_key, (BATCH, D_EMB), dtype=jax.numpy.bfloat16) def multiply_layer(in_act, in_layer): with jax.named_scope("M1"): @@ -210,7 +200,7 @@ def multiply_layers(in_act, in_layers): return x, in_layers def multiply_layers_with_loss(in_act, in_layers): - x, _ = multiply_layers(in_act, in_layers) + x, _ = multiply_layers(in_act, in_layers) return jax.numpy.sum(x) multiply_layers_and_grad = 
jax.value_and_grad(multiply_layers_with_loss, argnums=[1]) @@ -220,39 +210,27 @@ def training_step(in_act, in_layers): out_layers = jax.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0]) return out_layers - print("finished includes ", flush = True) + print("finished includes ", flush=True) replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding) - parameter_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) - data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) jit_func = jax.jit( - training_step, - in_shardings=(replicated_sharding, parameter_mesh_shardings), - out_shardings=data_pspec_shardings, - ) + training_step, + in_shardings=(replicated_sharding, parameter_mesh_shardings), + out_shardings=data_pspec_shardings, + ) - data_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding) + data_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding) - jit_gen_data = jax.jit( - gen_data, - in_shardings=None, - out_shardings=data_mesh_shardings - ) + jit_gen_data = jax.jit(gen_data, in_shardings=None, out_shardings=data_mesh_shardings) - parameter_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) - jit_gen_layers = jax.jit( - gen_layers, - in_shardings=None, - out_shardings=parameter_mesh_shardings - ) + jit_gen_layers = jax.jit(gen_layers, in_shardings=None, out_shardings=parameter_mesh_shardings) # starting the profiler outside `with` statement, # will call it right before the computation once b/301309635 is resolved @@ -261,14 +239,16 @@ def training_step(in_act, in_layers): key = jax.random.PRNGKey(0) presharded_X = jax.block_until_ready(jit_gen_data(key)) presharded_layers = jax.block_until_ready(jit_gen_layers(key)) - TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices()) - time = simple_timeit(lambda : jax.block_until_ready(jit_func(presharded_X, presharded_layers))) - print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush = True) + TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices()) + time = simple_timeit(lambda: jax.block_until_ready(jit_func(presharded_X, presharded_layers))) + print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush=True) deactivate_profiler(args.profiler_path) + def parse_flags(argv): return parser.parse_args(argv[1:]) + if __name__ == "__main__": flags.FLAGS.mark_as_parsed() app.run(main, flags_parser=parse_flags) diff --git a/pedagogical_examples/shmap_collective_matmul.py b/pedagogical_examples/shmap_collective_matmul.py index fe8c38be9..de80fcf77 100644 --- a/pedagogical_examples/shmap_collective_matmul.py +++ b/pedagogical_examples/shmap_collective_matmul.py @@ -1,24 +1,25 @@ #!/usr/bin/python3 """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
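Plugging the default flag values into the formulas above gives a feel for the numbers the script prints. The four-device host in this sketch is hypothetical; the formulas are the same ones used in the script.

    # Hypothetical 4-device run with the parser defaults.
    n_devices, per_device_batch = 4, 131072
    D_EMB, NUM_LAYERS = 2048, 4
    D_FF = 4 * D_EMB
    BATCH = n_devices * per_device_batch

    parameters = 2 * D_FF * D_EMB * NUM_LAYERS                       # 134,217,728 weights
    parameter_bytes = 2 * parameters                                 # bf16 -> ~0.27 GB
    activation_bytes = 2 * (BATCH * (D_FF + D_EMB)) * NUM_LAYERS     # ~42.9 GB
    tflops_per_device = parameters * 6 * BATCH / 10**12 / n_devices  # ~105.6 TFLOP per step

    print(parameter_bytes / 1e9, activation_bytes / 1e9, tflops_per_device)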
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" -'''This script is an example collective matmul.''' +"""This script is an example collective matmul.""" import os + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" import numpy as np @@ -27,7 +28,6 @@ import jax.numpy as jnp from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P -from jax.experimental import mesh_utils from jax.sharding import Mesh from jax.experimental.shard_map import shard_map @@ -50,51 +50,57 @@ import string import datetime -def simple_timeit(f, *args, tries = 10, trace_base_dir = None, task = None): - '''Simple utility to time a function for multiple runs''' - assert task is not None - trace_name = f"t_{task}_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) +def simple_timeit(f, *args, tries=10, trace_base_dir=None, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + + if trace_base_dir: + trace_dir = f"{trace_base_dir}/{trace_name}" + else: + trace_dir = None - if trace_base_dir: - trace_dir = f"{trace_base_dir}/{trace_name}" - else: - trace_dir = None + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! + if trace_dir: + jax.profiler.start_trace(trace_dir) - outcomes_ms = [] - jax.block_until_ready(f(*args)) #warm it up! 
- if trace_dir: - jax.profiler.start_trace(trace_dir) + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + if trace_dir: + jax.profiler.stop_trace() - for _ in range(tries): - s = datetime.datetime.now() - jax.block_until_ready(f(*args)) - e = datetime.datetime.now() - outcomes_ms.append(1000*(e-s).total_seconds()) - if trace_dir: - jax.profiler.stop_trace() + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}") + return average_time_ms - average_time_ms = sum(outcomes_ms)/len(outcomes_ms) - print(f"{task}: average time milliseconds: {average_time_ms:.2f}") - return average_time_ms # gen data def gen_data_fn(): - key = jax.random.PRNGKey(np.random.randint(0, 256)) - activations = jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16) - weights = jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16) - return activations, weights + key = jax.random.PRNGKey(np.random.randint(0, 256)) + activations = jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16) # pylint: disable=redefined-outer-name + weights = jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16) # pylint: disable=redefined-outer-name + return activations, weights + data_fn = pjit( gen_data_fn, out_shardings=(P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None), P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None)), ) -def matmul(activations, weights): - return jnp.einsum("bsE,Ehd->bshd", activations, weights) + +def matmul(activations, weights): # pylint: disable=redefined-outer-name + return jnp.einsum("bsE,Ehd->bshd", activations, weights) + jit_matmul = pjit(matmul, out_shardings=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None)) + @partial( shard_map, mesh=global_mesh, @@ -105,30 +111,48 @@ def matmul(activations, weights): out_specs=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None), check_rep=False, ) -def collective_matmul(activations, weights): - print(f"sh_map {activations.shape=} {weights.shape=}") - - axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS) - axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS) - # The current sequence chunk - chunk_size = activations.shape[1] - mid_chunk = chunk_size // 2 - # create accum buffer - accum = jnp.zeros( - ( - activations.shape[0], - activations.shape[1] * axis_size, - weights.shape[-2], - weights.shape[-1], - ), - dtype=activations.dtype, - ) +def collective_matmul(activations, weights): # pylint: disable=redefined-outer-name + """Collective matrix multiply""" + print(f"sh_map {activations.shape=} {weights.shape=}") + + axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS) + axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS) + # The current sequence chunk + chunk_size = activations.shape[1] + mid_chunk = chunk_size // 2 + # create accum buffer + accum = jnp.zeros( + ( + activations.shape[0], + activations.shape[1] * axis_size, + weights.shape[-2], + weights.shape[-1], + ), + dtype=activations.dtype, + ) + + # compute first chunk + update = jnp.einsum("bsE,Ehd->bshd", activations, weights) + update_index = (0, axis_index * chunk_size, 0, 0) + accum = jax.lax.dynamic_update_slice(accum, update, update_index) + activation_forward, activation_backward = jnp.split(activations, 2, axis=1) + activation_forward = jax.lax.ppermute( + activation_forward, + axis_name=MESH_TENSOR_AXIS, + perm=[(j, (j + 1) % 
axis_size) for j in range(axis_size)], + ) + activation_backward = jax.lax.ppermute( + activation_backward, + axis_name=MESH_TENSOR_AXIS, + perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], + ) + + # split activations into chunks and send + def scanned_call(i, carrys): + accum, activation_forward, activation_backward = carrys + update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights) + update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights) - # compute first chunk - update = jnp.einsum("bsE,Ehd->bshd", activations, weights) - update_index = (0, axis_index * chunk_size, 0, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - activation_forward, activation_backward = jnp.split(activations, 2, axis=1) activation_forward = jax.lax.ppermute( activation_forward, axis_name=MESH_TENSOR_AXIS, @@ -140,70 +164,46 @@ def collective_matmul(activations, weights): perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], ) - # split activations into chunks and send - def scanned_call(i, carrys): - accum, activation_forward, activation_backward = carrys - update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights) - update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights) - - activation_forward = jax.lax.ppermute( - activation_forward, - axis_name=MESH_TENSOR_AXIS, - perm=[(j, (j + 1) % axis_size) for j in range(axis_size)], - ) - activation_backward = jax.lax.ppermute( - activation_backward, - axis_name=MESH_TENSOR_AXIS, - perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], - ) - - forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size - backward_update_index = ((axis_index + i + 1) % axis_size) * chunk_size + mid_chunk - - accum = jax.lax.dynamic_update_slice(accum, update_forward, (0, forward_update_index, 0, 0)) - accum = jax.lax.dynamic_update_slice(accum, update_backward, (0, backward_update_index, 0, 0)) - return (accum, activation_forward, activation_backward) - - print(f"{accum.shape=}") - - accum, _, _ = jax.lax.fori_loop( - 0, (axis_size - 1), scanned_call, (accum, activation_forward, activation_backward) - ) - return accum - -with global_mesh: - activations, weights = data_fn() + forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size + backward_update_index = ((axis_index + i + 1) % axis_size) * chunk_size + mid_chunk - jax.block_until_ready(activations) - jax.block_until_ready(weights) + accum = jax.lax.dynamic_update_slice(accum, update_forward, (0, forward_update_index, 0, 0)) + accum = jax.lax.dynamic_update_slice(accum, update_backward, (0, backward_update_index, 0, 0)) + return (accum, activation_forward, activation_backward) - @jax.jit - def run_naive(_activations, _weights): - with jax.named_scope("naive_matmul"): - outputs = jit_matmul(_activations, _weights) - return outputs + print(f"{accum.shape=}") - @jax.jit - def run_collective(_activations, _weights): - with jax.named_scope("collective_matmul"): - manual_outputs = jax.jit(collective_matmul)(_activations, _weights) - return manual_outputs + accum, _, _ = jax.lax.fori_loop(0, (axis_size - 1), scanned_call, (accum, activation_forward, activation_backward)) + return accum +with global_mesh: + activations, weights = data_fn() - - naive_outputs = run_naive(activations, weights) - collective_outputs = run_collective(activations, weights) + jax.block_until_ready(activations) + jax.block_until_ready(weights) - print(f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}") - 
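The collective version leans on `jax.lax.ppermute` to rotate activation chunks around the tensor axis. Below is a minimal sketch of that ring permutation in isolation; the axis name "x" and the single-axis mesh over local devices are assumptions for illustration, not part of the script above.

    from functools import partial

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.local_devices()), ("x",))

    @partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_rep=False)
    def rotate(chunk):
        n = jax.lax.psum(1, axis_name="x")
        # Send each shard to the next device on the ring.
        return jax.lax.ppermute(chunk, axis_name="x", perm=[(j, (j + 1) % n) for j in range(n)])

    x = jnp.arange(jax.local_device_count(), dtype=jnp.float32)
    print(rotate(x))  # input values shifted by one position around the ring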
print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}") - print(f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}") - print(f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}") + @jax.jit + def run_naive(_activations, _weights): + with jax.named_scope("naive_matmul"): + outputs = jit_matmul(_activations, _weights) + return outputs + @jax.jit + def run_collective(_activations, _weights): + with jax.named_scope("collective_matmul"): + manual_outputs = jax.jit(collective_matmul)(_activations, _weights) + return manual_outputs + naive_outputs = run_naive(activations, weights) + collective_outputs = run_collective(activations, weights) - assert jnp.allclose(naive_outputs, collective_outputs), "Two algorithms should match but don't" + print(f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}") + print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}") + print(f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}") + print(f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}") - simple_timeit(run_naive, activations, weights, task = "naive") - simple_timeit(run_collective, activations, weights, task = "collective") + assert jnp.allclose(naive_outputs, collective_outputs), "Two algorithms should match but don't" + simple_timeit(run_naive, activations, weights, task="naive") + simple_timeit(run_collective, activations, weights, task="collective") diff --git a/pylintrc b/pylintrc index 4897238f0..4ebb6438a 100644 --- a/pylintrc +++ b/pylintrc @@ -29,6 +29,8 @@ jobs=4 # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no +# check python modules in the dir recursively +recursive=y [MESSAGES CONTROL]