Skip to content

Commit

Permalink
sync to head (#771)
Browse files Browse the repository at this point in the history
* Move tpu end-to-end test scripts to tpu folder

* unify WORKDIR to /deps

* Share GCS path between Gemma-7b tests

* Add README for llama2-7B

* adding script to fix the style and adding modified/fixed files with line length 125

* Move apt install from `rto_setup.sh` to `setup.sh`

* Update instructions for installing snap.

* Removes batch size from prefill attention calculation.

* Fixes for inf testing.

* Revert "Fixes for inf testing."

This reverts commit b15b1d5.

* Fixes

* Fix subset of hosts dataloading

* inference microbenchmark

  - allow run specified stages
  - allow run specific prefill length(s)
  - delete prefill result
  - printout prefill result

added funcs in max_utils

* Update Run_MaxText_via_xpk.md

Fixing typo.

* inference_microbenchmark:

  - time prefill only
  - benchmark prefill and insert

* Mark nvidia devtools repo as trusted

This is a stopgaps measure to circumvent the nvidia repo's gpg signature issue

* Explicitly set AQT Freezer mode in MaxText.

PiperOrigin-RevId: 627250589

* Move aqtp pin up

* Pre-commit config

* Update 128B config on v5e to use qkv_proj_offloaded remat_policy

* [MaxText] Rename llama2_7b_single_host_gpu.yml to make it clear that it can be used for any number of host.

PiperOrigin-RevId: 627804089

* Split Mixtral test into two scripts

* Update jax.tree_map to jax.tree_util.tree_map

* change norm sharding

fix lint

Revert "fix lint"

This reverts commit d8dc450.

fix lint

* Change l2norm to use jnp.sqrt

* Fix test_tokenize

* Streamlined setup.sh to have fewer apt install calls

* loosen tolerance in assert_params_sufficiently_sharded

* Enable entropy on multihost CPUs.

* Add tests to GPU runner

* Replace deprecated np.product with np.prod

* fix norm sharding

* Add Llama2-70b test

* Internal change only.

PiperOrigin-RevId: 630446330

* Add more tests for Mixtral

* Make some AQT dataclasses to use keyword-only fields (1/N)

This cl introduces an temporary decorator that will be temporarily used during this migration. The eventual goal is to enforce kw_only=True in all dataclasses unless it's not feasible, aiming to make AQT less error-prune and improve readability.

PiperOrigin-RevId: 631132072

* Reverts e8b53e5

PiperOrigin-RevId: 631465526

* Update tflops calculation

* fix sharding on generate cache in prefill results.

* Remove async XLA_FLAGS from A3 configs.

XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are:

--xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_async_all_reduce=true

Such flags had no impact before the XLA PR as the async collectives were already enabled by default.

* Update llama2_7b_gpu.yml

PiperOrigin-RevId: 631752008

* Add forward pass logit check test for Llama2-7b

* Eval the command string from XPK for GPU script

* Remove cases where the deprecated --xla_gpu_simplify_all_fp_conversions is set to its default value.

PiperOrigin-RevId: 633645462

* streamline CI test structure

* fix pylint

fix pylint: Using variable 'p_eval_step' before assignment (#651)

* Remove async XLA_FLAGS from A3 configs

* Add llama-70b gpu config.

PiperOrigin-RevId: 634267313

* Support data input from HuggingFace

* Update the NCCL flags for A3+.

* add gemma logit test

* Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

* fix hf input pipeline

* Fix prefill assertion

* Remove decode asserts from Gemma test files

* add single controller flag

* fix OOM issue running inference microbenchmark with llama13b on v5e4

* Add Llama2 13B Tests

* Don't clip fp8 stats

* Integrate nsys profiler

Remove 'enable_profiler' config and add 'profiler' config instead

* Add MoE matmul implementation

* fix OUTPUT_PATH in v5e/128b.sh

* squash

* Update flops calculation to active experts in moe

* Enable kv cache layout control

* Fix Gemma Readme link

* Internal change only.

Reverts a28f518

PiperOrigin-RevId: 639890999

* Upgrade Pinned Base Image for GPU

* Metrics bug: server_lib should be config_lib

* Fix MoE matmul scale issue

* Removed unused Pallas import from layers/attentions.py

PiperOrigin-RevId: 640481280

* Change norm sharding for llama2-7b to fsdp.

PiperOrigin-RevId: 640498890

* Copybara import of the project:

--
d7d694f by RissyRan <[email protected]>:

Fix forward test for Mixtral

COPYBARA_INTEGRATE_REVIEW=#679 from google:ranran_fix_forward_test d7d694f
PiperOrigin-RevId: 640537456

* Set additional flags for a3 and a3plus

* Use run_id instead of sha for docker tag

* refactor data input pipeline and add perf data

* Add gpt3 175b on v5e config

* Pipeline parallelism support (linear only)

* Turn on layer scanning for llama2-7b on GPU.

This better utilizes recent optimizations to collective approximation in the XLA latency hiding scheduler.

PiperOrigin-RevId: 642559284

* reshape q

* Add profiler flags to JetStream server

Add jetstream config

backward compatible

* fix tfds instruction

* Add vanilla megablox to MoE

* Add llama2 70b training config for v5e

* base.yml changes

circular changes to pipeline.py

pyconfig circ changes

pipeline parallel tests circular style

tree map, half passed tests

Total iterations circularized

improved iteration comment

run all tests

test both circular and non-circular

circ storage comment

circ storage pushing index comment

* Account for new mesh axes for llama2-7b, and llama2-70b on GPUs.

PiperOrigin-RevId: 643999933

* Sharding the llama2 70b on v5e-16 more efficiently.

https://arxiv.org/pdf/2211.05102
https://arxiv.org/pdf/1909.08053

* add compute_axis_order

* Add maxengine_server configs to base.yml

* Add FSDP + Megablox

* Llama3-8b model config

* MaxText package

* fix data loading from HF hub

* Fix llama2-{7,70}b sharding on GPU.

PiperOrigin-RevId: 645365795

* Move stage to second axis in mesh

Move stage to second axis in mesh

* Copybara import of the project:

--
1718b89 by RissyRan <[email protected]>:

Refactor permute and unpermute operations

COPYBARA_INTEGRATE_REVIEW=#714 from google:refactor_mega b101cbc
PiperOrigin-RevId: 645591567

* Fix Mesh setup for multiprocess CPUs.

* add kv_quant_axis

* Add a directory check for the . If it fails, attempt to check a path relative to the base config, similar to what is done for model configurations.

Minor update

Remove the raised exception

* Add mistral tokenizer to maxtext/assets

* Update the dependencies to prepare for integration of emergency checkpointing

Withhold some package versions

Update version of typing_extensions

* Make broadcasting from one replica to all more memory efficient

PiperOrigin-RevId: 646526020

* Inference Microbenchmark Sweep

* Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.

PiperOrigin-RevId: 646795068

* Allow owners to have any approver

Fix AddLabel syntax

Fix punctuation

* Enable saving using Orbax's emergency checkpoint manager

fix data loading from HF hub

Add explanation to the emergency checkpoint feature

Fix pylint issues

Minor changes to the config file

resolve conflicts

Inference Microbenchmark Sweep

Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.

PiperOrigin-RevId: 646795068

* Add Llama2 7B, 13B high performance training configs

* Load/Save Aqt quantized checkpoint.

* modify prefill to return first token

* Fix and protect simple_layer

Fix and protect simple_layer

Fix and protect simple_layer

Fix and protect simple_layer

* Adding option for int4 quantization to kvcache.

* support eval dataset and refactor

* Support partial overrides for logical_axis_rules.

* Fix simple test step count

* Clean up MoE brute force implementation

* Preliminary restore with lots of hardcoding and hacking

Refactor the code and remove the hardcoding

More refactoring

Cleanup for pull request

Address linting issues

Preliminary restore with lots of hardcoding and hacking

Refactor the code and remove the hardcoding

More refactoring

Cleanup for pull request

Address linting issues

Small formatting

Fix merging issues

* Add convergence tests on A3 GPU

* Update tile size

* Handle cases where memstats are not available for the device.

Memstats are not guaranteed to be available and can throw an error or return None. This change will handle both `jaxlib.xla_extension.XlaRuntimeError` if the device is not a PjRt addressable device or `KeyError` if the memstats returns None if they are not available.

* Fix validation error for other models

* Fix decode.py to also use first_token from prefill_call

* Add moe perf number

* move num_experts pyconfig assertion to fix tests

* Cast type for inputs before kernel call

* Move sharding overrides to models/ directory.

PiperOrigin-RevId: 650994392

* Enable quantization for MoE Matmul implementation

* Integrate and test Goodput Monitor with MaxText

* Adding Tokens/s/device to the log.

* Adding support for mixed precision quantization configs.

---------

Co-authored-by: maxtext authors <[email protected]>
Co-authored-by: Nina Cai <[email protected]>
Co-authored-by: NinaCai <[email protected]>
Co-authored-by: michelle-yooh <[email protected]>
Co-authored-by: In-Ho Yi <[email protected]>
Co-authored-by: A9isha <[email protected]>
Co-authored-by: In-Ho Yi <[email protected]>
Co-authored-by: ssusie <[email protected]>
Co-authored-by: tonyjohnchen <[email protected]>
Co-authored-by: Roshani Narasimhan <[email protected]>
Co-authored-by: Pate Motter <[email protected]>
Co-authored-by: khatwanimohit <[email protected]>
Co-authored-by: Morgan Du <[email protected]>
Co-authored-by: DongHyun Choi <[email protected]>
Co-authored-by: gobbleturk <[email protected]>
Co-authored-by: Raymond Zou <[email protected]>
Co-authored-by: Bixia Zheng <[email protected]>
Co-authored-by: Ran Ran <[email protected]>
Co-authored-by: Zhiyu Li <[email protected]>
Co-authored-by: Rafi Witten <[email protected]>
Co-authored-by: RissyRan <[email protected]>
Co-authored-by: Greg Olechwierowicz <[email protected]>
Co-authored-by: Junwei Yang <[email protected]>
Co-authored-by: Reed Wanderman-Milne <[email protected]>
Co-authored-by: Dimitar (Mitko) Asenov <[email protected]>
Co-authored-by: aireenmei <[email protected]>
Co-authored-by: yangyuwei <[email protected]>
Co-authored-by: Abhinav Singh <[email protected]>
Co-authored-by: Sadi Kneipp <[email protected]>
Co-authored-by: jwyang-google <[email protected]>
Co-authored-by: Anfal Siddiqui <[email protected]>
Co-authored-by: Brendan Slabe <[email protected]>
Co-authored-by: Sergei Lebedev <[email protected]>
Co-authored-by: Jon Bolin <[email protected]>
Co-authored-by: Zijun Zhou <[email protected]>
Co-authored-by: Zhihao Shan <[email protected]>
Co-authored-by: Adam O'Brien <[email protected]>
Co-authored-by: Vipan Nalla <[email protected]>
Co-authored-by: Vipan Nalla <[email protected]>
Co-authored-by: Xuefeng Gu <[email protected]>
Co-authored-by: Andy Ye <[email protected]>
Co-authored-by: Mitali Singh <[email protected]>
Co-authored-by: xuefgu <[email protected]>
Co-authored-by: Luke Baumann <[email protected]>
Co-authored-by: Dipannita Shaw <[email protected]>
  • Loading branch information
Show file tree
Hide file tree
Showing 243 changed files with 147,764 additions and 4,714 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Changes in this file should match with requiredReviewers in .github/workflows/AddLabel.yml
* @rwitten
* @gobbleturk
31 changes: 22 additions & 9 deletions .github/workflows/AddLabel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
name: Add Label

on:
workflow_call:
workflow_run:
# workflows: [Unit Test, CodeQL]
workflows: [CodeQL]
workflows: [Unit Test, CodeQL]
types:
- completed
pull_request_review:
Expand Down Expand Up @@ -54,23 +52,33 @@ jobs:
}
// This list should match with CODEOWNERS
let requiredReviewers = { rwitten: "" }
let requiredReviewers = { gobbleturk: "" }
const reviews = await github.rest.pulls.listReviews({
owner,
repo,
pull_number,
})
const pullRequest = await github.rest.pulls.get({
owner,
repo,
pull_number,
});
const pullRequester = pullRequest.data.user.login;
if (reviews.data.length === 0) {
console.log("Not adding pull ready because the PR is not approved yet")
console.log("Not adding pull ready because the PR is not approved yet.")
process.exit()
}
let is_approved=false
for (const review of reviews.data) {
if (review.state === "APPROVED") {
delete requiredReviewers[review.user.login]
if (review.state === "APPROVED" && (review.user.login in requiredReviewers || pullRequester in requiredReviewers)) {
is_approved=true
break;
}
}
if (Object.keys(requiredReviewers).length !== 0) {
console.log("Not adding pull ready because the PR is not approved yet")
if (!is_approved) {
console.log("Not adding pull ready because the PR is not approved yet by a code owner.")
process.exit()
}
Expand All @@ -80,6 +88,11 @@ jobs:
pull_number,
per_page: 100,
})
// Check that the number of commits in the PR is 1.
if (commits.data.length !== 1) {
console.log("Not adding pull ready because the PR has more than one commit. Please squash your commits.")
process.exit(1)
}
const ref = commits.data.slice(-1)[0].sha
const checkRuns = await github.rest.checks.listForRef({
owner,
Expand Down
38 changes: 38 additions & 0 deletions .github/workflows/CPUTests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Linter

on:
push:
branches:
- '**'

jobs:
cpu:
name: "CPU tests"
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04]
python-version: ['3.10']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install pylint pyink pytype==2024.2.27
- name: Typecheck the code with pytype
run: |
pytype --jobs auto --disable import-error MaxText/
- name: Analysing the code with pylint in Maxtext/
run: |
pylint MaxText/ && \
echo 'Maxtext PyLint check successful' || { echo \
'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
- name: Analysing the code with pylint in pedagogical_examples/
run: |
pylint pedagogical_examples/ && \
echo 'PyLint check on pedagogical_examples/ is successful' || { echo \
'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
170 changes: 137 additions & 33 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,151 @@ on:
branches: [ "main" ]
workflow_dispatch:
schedule:
# Run the job every 60 mins
- cron: '*/60 * * * *'
# Run the job every 2 hours
- cron: '0 */2 * * *'

jobs:
build:
build_and_upload_image:
strategy:
fail-fast: false
matrix:
tpu-type: ["v4-8"]
name: "TPU test (${{ matrix.tpu-type }})"
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
fail-fast: false
matrix:
device:
- type: tpu
name: v4-8
mode: stable
- type: gpu
name: a100-40gb-4
mode: pinned
name: Build and upload image (${{ matrix.device.name }})
runs-on: ["self-hosted", "${{ matrix.device.type }}", "${{ matrix.device.name }}"]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Cleanup old docker images
run: docker system prune --all --force
- name: Build an image
run: |
docker system prune --all --force
- name: Install dependencies
bash docker_build_dependency_image.sh MODE=${{ matrix.device.mode }} DEVICE=${{ matrix.device.type }}
- name: Tag the image
run: |
bash docker_build_dependency_image.sh
- name: Analysing the code with pylint
docker tag maxtext_base_image gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }}
- name: Upload the image
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c "pylint MaxText/"
docker push gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }}
common:
needs: build_and_upload_image
strategy:
fail-fast: False
matrix:
device:
- type: tpu
name: v4-8
attention: autoselected
pytest_marker: ''
container_env:
XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75
TF_FORCE_GPU_ALLOW_GROWTH: false
container_resource_option: "--privileged"
- type: gpu
name: a100-40gb-4
image_suffix: gpu_jax_pinned
attention: dot_product
pytest_marker: -m 'not tpu'
container_env:
XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65
TF_FORCE_GPU_ALLOW_GROWTH: true
container_resource_option: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
name: Common test (${{ matrix.device.name }})
runs-on: ["self-hosted", "${{ matrix.device.type }}", "${{ matrix.device.name }}"]
container:
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }}
volumes:
- /home/runner/actions-runner/_work/maxtext/maxtext:/deps
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ matrix.device.container_env.XLA_PYTHON_CLIENT_MEM_FRACTION }}
TF_FORCE_GPU_ALLOW_GROWTH: ${{ matrix.device.container_env.TF_FORCE_GPU_ALLOW_GROWTH }}
options: ${{ matrix.device.container_resource_option }}
steps:
- uses: actions/checkout@v4
- name: Test gsutil installation
run: which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;}
- name: Test with pytest
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest'
- name: Test train.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2'
run: cd MaxText;python3 -m pytest ${{ matrix.device.pytest_marker }}
- name: Test train.py with TFDS c4
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test train.py with HF c4
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs hf_train_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet hf_path=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large attention=${{ matrix.device.attention }} enable_checkpointing=false
- name: Test train.py with synthetic data
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} dataset_type=synthetic
- name: Test train.py with per_device_batch_size < 1
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test decode.py
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4'
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1
- name: Test decode.py with per_device_batch_size < 1
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=.25
- name: Test int8_training
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset int8_training=true steps=2'
add_pull_ready:
if: github.ref != 'refs/heads/main'
permissions:
checks: read
pull-requests: write
needs: build
uses: ./.github/workflows/AddLabel.yml
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test fp8_training
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test generate_param_only_checkpoint
run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }}
- name: Test generate_param_only_checkpoint with int8 quantization
run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8 -a ${{ matrix.device.attention }}
- name: Test grain checkpoint determinism
run: bash end_to_end/test_checkpointing.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset False grain ${{ matrix.device.attention }}
- name: Test checkpoint compatibility
run: bash end_to_end/test_checkpoint_compatibility.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset ${{ matrix.device.attention }}

tpu:
needs: build_and_upload_image
strategy:
fail-fast: false
matrix:
device-type: ["v4-8"]
name: "TPU test (${{ matrix.device-type }})"
runs-on: ["self-hosted", "tpu", "${{ matrix.device-type }}"]
container:
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu
volumes:
- /home/runner/actions-runner/_work/maxtext/maxtext:/deps
options: "--privileged"
steps:
- uses: actions/checkout@v4
- name: Validate Pedagogical Example, Shmap_collective_matmul
run: python3 pedagogical_examples/shmap_collective_matmul.py

gpu:
needs: build_and_upload_image
strategy:
fail-fast: false
matrix:
device-type: ["a100-40gb-4"]
build-mode: ["pinned"]
name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})"
runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"]
container:
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu
volumes:
- /home/runner/actions-runner/_work/maxtext/maxtext:/deps
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65
TF_FORCE_GPU_ALLOW_GROWTH: true
options: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Test train.py with flash attention
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=cudnn_flash_te

clean_up:
if: ${{ always() }}
needs: [common, gpu, tpu]
name: "Clean up"
runs-on: ["self-hosted"]
steps:
- name: Delete GPU image
run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu --force-delete-tags --quiet
- name: Delete TPU image
run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu --force-delete-tags --quiet

56 changes: 56 additions & 0 deletions .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Build Images

on:
schedule:
# Run the job daily at 12AM UTC
- cron: '0 0 * * *'

jobs:
tpu:
strategy:
fail-fast: false
matrix:
device-type: ["v4-8"]
runs-on: ["self-hosted", "tpu", "${{ matrix.device-type }}"]
steps:
- uses: actions/checkout@v3
- name: build jax stable image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_stable MODE=stable DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_stable
- name: build jax nightly image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_jax_nightly MODE=nightly DEVICE=tpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_jax_nightly
gpu:
strategy:
fail-fast: false
matrix:
device-type: ["a100-40gb-4"]
runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"]
steps:
- uses: actions/checkout@v3
- name: build jax stable image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable MODE=stable DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_stable
- name: build jax nightly image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_nightly MODE=nightly DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_nightly
- name: build jax pinned image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_pinned MODE=pinned DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_pinned
Loading

0 comments on commit aac50a8

Please sign in to comment.