From 9343eec18cf3943bb7598aa915c3625f23fb597d Mon Sep 17 00:00:00 2001
From: ssusie
Date: Tue, 16 Apr 2024 16:59:15 +0000
Subject: [PATCH] Add a script to fix code style, and reformat files to a line length of 125

---
 .github/workflows/CPUTests.yml | 38 +
 .github/workflows/UnitTests.yml | 25 -
 MaxText/__init__.py | 22 +-
 MaxText/accelerator_to_spec_map.py | 479 +++-----
 MaxText/checkpointing.py | 140 ++--
 MaxText/common_types.py | 16 +-
 MaxText/convert_gemma_chkpt.py | 198 ++---
 MaxText/convert_gpt3_ckpt_from_paxml.py | 203 +++--
 MaxText/decode.py | 35 +-
 MaxText/generate_param_only_checkpoint.py | 63 +-
 MaxText/inference_microbenchmark.py | 134 +--
 MaxText/inference_scratch/analyze_sharegpt.py | 42 +-
 MaxText/inference_utils.py | 68 +-
 .../input_pipeline/_grain_data_processing.py | 132 +--
 MaxText/input_pipeline/_grain_operations.py | 93 ++-
 MaxText/input_pipeline/_grain_tokenizer.py | 30 +-
 .../input_pipeline/_tfds_data_processing.py | 180 ++--
 .../_tfds_data_processing_c4_mlperf.py | 134 +--
 .../input_pipeline_interface.py | 205 ++---
 MaxText/layers/attentions.py | 781 +++++++++---------
 MaxText/layers/embeddings.py | 34 +-
 MaxText/layers/gemma.py | 132 ++-
 MaxText/layers/gpt3.py | 304 ++++---
 MaxText/layers/initializers.py | 12 +-
 MaxText/layers/linears.py | 117 ++-
 MaxText/layers/llama2.py | 135 ++-
 MaxText/layers/mistral.py | 221 +++--
 MaxText/layers/models.py | 262 +++---
 MaxText/layers/normalizations.py | 3 +-
 MaxText/layers/quantizations.py | 97 +--
 MaxText/llama_or_mistral_ckpt.py | 468 +++++------
 MaxText/max_logging.py | 25 +-
 MaxText/max_utils.py | 374 +++++----
 MaxText/maxengine.py | 233 +++---
 MaxText/maxengine_config.py | 23 +-
 MaxText/maxengine_server.py | 8 +-
 MaxText/maxtext_utils.py | 149 ++--
 MaxText/multihost_dataloading.py | 40 +-
 MaxText/optimizers.py | 60 +-
 MaxText/pyconfig.py | 237 +++---
 MaxText/sequence_packing.py | 83 +-
 MaxText/standalone_checkpointer.py | 62 +-
 MaxText/standalone_dataloader.py | 33 +-
 MaxText/tests/attention_test.py | 140 ++--
 MaxText/tests/gpt3_test.py | 84 +-
 MaxText/tests/grain_data_processing_test.py | 259 +++---
 .../inference_microbenchmark_smoke_test.py | 22 +-
 MaxText/tests/llama_test.py | 97 ++-
 MaxText/tests/max_utils_test.py | 127 ++-
 MaxText/tests/model_test.py | 113 ++-
 MaxText/tests/multihost_dataloading_test.py | 55 +-
 MaxText/tests/profiler_test.py | 79 +-
 MaxText/tests/quantizations_test.py | 102 ++-
 MaxText/tests/standalone_dl_ckpt_test.py | 82 +-
 MaxText/tests/tfds_data_processing_test.py | 130 +--
 MaxText/tests/tokenizer_test.py | 56 +-
 MaxText/tests/train_compile_test.py | 199 +++--
 MaxText/tests/train_int8_smoke_test.py | 54 +-
 MaxText/tests/train_smoke_test.py | 53 +-
 MaxText/tests/weight_dtypes_test.py | 87 +-
 MaxText/tokenizer.py | 39 +-
 MaxText/train.py | 331 ++++----
 MaxText/train_compile.py | 97 ++-
 MaxText/train_tokenizer.py | 135 ++-
 MaxText/vertex_tensorboard.py | 64 +-
 code_style.sh | 33 +
 pedagogical_examples/non_spmd.py | 40 +-
 pedagogical_examples/shardings.py | 166 ++--
 .../shmap_collective_matmul.py | 238 +++---
 pylintrc | 2 +
 70 files changed, 4581 insertions(+), 4433 deletions(-)
 create mode 100644 .github/workflows/CPUTests.yml
 create mode 100644 code_style.sh

diff --git a/.github/workflows/CPUTests.yml b/.github/workflows/CPUTests.yml
new file mode 100644
index 000000000..a6cd25a1e
--- /dev/null
+++ b/.github/workflows/CPUTests.yml
@@ -0,0 +1,38 @@
+name: Linter
+
+on:
+  push:
+    branches:
+      - '**'
+
+jobs:
+  cpu:
+    name: "CPU tests"
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-20.04]
+        python-version: ['3.10']
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install Dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint pyink pytype==2024.2.27
+    - name: Typecheck the code with pytype
+      run: |
+        pytype --jobs auto --disable import-error MaxText/
+    - name: Analysing the code with pylint in MaxText/
+      run: |
+        pylint MaxText/ && \
+        echo 'MaxText PyLint check successful' || { echo \
+        'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
+    - name: Analysing the code with pylint in pedagogical_examples/
+      run: |
+        pylint pedagogical_examples/ && \
+        echo 'PyLint check on pedagogical_examples/ is successful' || { echo \
+        'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 1efcfb4e7..08ff79a8c 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -27,31 +27,6 @@ on:
     - cron: '0 */2 * * *'
 
 jobs:
-  cpu:
-    name: "CPU test"
-    runs-on: ${{ matrix.os }}
-    strategy:
-      matrix:
-        os: [ubuntu-20.04]
-        python-version: ['3.10']
-    steps:
-    - uses: actions/checkout@v3
-    - name: setup python
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install Dependencies
-      run: |
-        pip install pytype==2024.2.27
-        pip install pylint
-    - name: Typecheck the code with pytype
-      run: |
-        pytype --jobs auto --disable import-error MaxText/
-    - name: Analysing the code with pylint
-      run: |
-        pylint MaxText/
-
-  # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODICATIONS TO 'gpu' job
   tpu:
     strategy:
diff --git a/MaxText/__init__.py b/MaxText/__init__.py
index 83f918e62..c133d2d71 100644
--- a/MaxText/__init__.py
+++ b/MaxText/__init__.py
@@ -1,15 +1,15 @@
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-  """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/MaxText/accelerator_to_spec_map.py b/MaxText/accelerator_to_spec_map.py
index fa6b64d0a..255aef965 100644
--- a/MaxText/accelerator_to_spec_map.py
+++ b/MaxText/accelerator_to_spec_map.py
@@ -1,18 +1,18 @@
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Static map of TPU names such as v4-8 to properties such as chip layout.""" @@ -22,358 +22,135 @@ from dataclasses import dataclass + @dataclass class SystemCharacteristics: platform: str topology_name: str - chip_config_name: str # 'megacore' or 'default' + chip_config_name: str # 'megacore' or 'default' chips_per_host_bounds: tuple devices_per_slice: int + UserFacingNameToSystemCharacteristics = { # v5e - 'v5e-16': SystemCharacteristics( - 'tpu', 'v5e:4x4', 'default', (2, 2, 1), 16 - ), - 'v5e-32': SystemCharacteristics( - 'tpu', 'v5e:4x8', 'default', (2, 2, 1), 32 - ), - 'v5e-64': SystemCharacteristics( - 'tpu', 'v5e:8x8', 'default', (2, 2, 1), 64 - ), - 'v5e-128': SystemCharacteristics( - 'tpu', 'v5e:8x16', 'default', (2, 2, 1), 128 - ), - 'v5e-256': SystemCharacteristics( - 'tpu', 'v5e:16x16', 'default', (2, 2, 1), 256 - ), + "v5e-16": SystemCharacteristics("tpu", "v5e:4x4", "default", (2, 2, 1), 16), + "v5e-32": SystemCharacteristics("tpu", "v5e:4x8", "default", (2, 2, 1), 32), + "v5e-64": SystemCharacteristics("tpu", "v5e:8x8", "default", (2, 2, 1), 64), + "v5e-128": SystemCharacteristics("tpu", "v5e:8x16", "default", (2, 2, 1), 128), + "v5e-256": SystemCharacteristics("tpu", "v5e:16x16", "default", (2, 2, 1), 256), # v4 - 'v4-8': SystemCharacteristics( - 'tpu', 'v4:2x2x1', 'megacore', (2, 2, 1), 4 - ), - 'v4-16': SystemCharacteristics( - 'tpu', 'v4:2x2x2', 'megacore', (2, 2, 1), 8 - ), - 'v4-32': SystemCharacteristics( - 'tpu', 'v4:2x2x4', 'megacore', (2, 2, 1), 16 - ), - 'v4-64': SystemCharacteristics( - 'tpu', 'v4:2x4x4', 'megacore', (2, 2, 1), 32 - ), - 'v4-128': SystemCharacteristics( - 'tpu', 'v4:4x4x4', 'megacore', (2, 2, 1), 64 - ), - 'v4-256': SystemCharacteristics( - 'tpu', 'v4:4x4x8', 'megacore', (2, 2, 1), 128 - ), - 'v4-384': SystemCharacteristics( - 'tpu', 'v4:4x4x12', 'megacore', (2, 2, 1), 192 - ), - 'v4-512': SystemCharacteristics( - 'tpu', 'v4:4x8x8', 'megacore', (2, 2, 1), 256 - ), - 'v4-1024': SystemCharacteristics( - 'tpu', 'v4:8x8x8', 'megacore', (2, 2, 1), 512 - ), - 'v4-1536': SystemCharacteristics( - 'tpu', 'v4:8x8x12','megacore', (2, 2, 1), 768 - ), - 'v4-2048': SystemCharacteristics( - 'tpu', 'v4:8x8x16','megacore', (2, 2, 1), 1024 - ), - 'v4-4096': SystemCharacteristics( - 'tpu', 'v4:8x16x16', 'megacore', (2, 2, 1), 2048 - ), + "v4-8": SystemCharacteristics("tpu", "v4:2x2x1", "megacore", (2, 2, 1), 4), + "v4-16": SystemCharacteristics("tpu", "v4:2x2x2", "megacore", (2, 2, 1), 8), + "v4-32": SystemCharacteristics("tpu", "v4:2x2x4", "megacore", (2, 2, 1), 16), + "v4-64": SystemCharacteristics("tpu", "v4:2x4x4", "megacore", (2, 
2, 1), 32), + "v4-128": SystemCharacteristics("tpu", "v4:4x4x4", "megacore", (2, 2, 1), 64), + "v4-256": SystemCharacteristics("tpu", "v4:4x4x8", "megacore", (2, 2, 1), 128), + "v4-384": SystemCharacteristics("tpu", "v4:4x4x12", "megacore", (2, 2, 1), 192), + "v4-512": SystemCharacteristics("tpu", "v4:4x8x8", "megacore", (2, 2, 1), 256), + "v4-1024": SystemCharacteristics("tpu", "v4:8x8x8", "megacore", (2, 2, 1), 512), + "v4-1536": SystemCharacteristics("tpu", "v4:8x8x12", "megacore", (2, 2, 1), 768), + "v4-2048": SystemCharacteristics("tpu", "v4:8x8x16", "megacore", (2, 2, 1), 1024), + "v4-4096": SystemCharacteristics("tpu", "v4:8x16x16", "megacore", (2, 2, 1), 2048), # v5p - 'v5p-8': SystemCharacteristics( - 'tpu', 'v5:2x2x1', 'megacore', (2, 2, 1), 4 - ), - 'v5p-16': SystemCharacteristics( - 'tpu', 'v5:2x2x2', 'megacore', (2, 2, 1), 8 - ), - 'v5p-32': SystemCharacteristics( - 'tpu', 'v5:2x2x4', 'megacore', (2, 2, 1), 16 - ), - 'v5p-64': SystemCharacteristics( - 'tpu', 'v5:2x4x4', 'megacore', (2, 2, 1), 32 - ), - 'v5p-128': SystemCharacteristics( - 'tpu', 'v5:4x4x4', 'megacore', (2, 2, 1), 64 - ), - 'v5p-256': SystemCharacteristics( - 'tpu', 'v5:4x4x8', 'megacore', (2, 2, 1), 128 - ), - 'v5p-384': SystemCharacteristics( - 'tpu', 'v5:4x4x12', 'megacore', (2, 2, 1), 192 - ), - 'v5p-512': SystemCharacteristics( - 'tpu', 'v5:4x8x8', 'megacore', (2, 2, 1), 256 - ), - 'v5p-640': SystemCharacteristics( - 'tpu', 'v5:4x4x20', 'megacore', (2, 2, 1), 320 - ), - 'v5p-768': SystemCharacteristics( - 'tpu', 'v5:4x8x12', 'megacore', (2, 2, 1), 384 - ), - 'v5p-896': SystemCharacteristics( - 'tpu', 'v5:4x4x28', 'megacore', (2, 2, 1), 448 - ), - 'v5p-1024': SystemCharacteristics( - 'tpu', 'v5:8x8x8', 'megacore', (2, 2, 1), 512 - ), - 'v5p-1152': SystemCharacteristics( - 'tpu', 'v5:4x12x12', 'megacore', (2, 2, 1), 576 - ), - 'v5p-1280': SystemCharacteristics( - 'tpu', 'v5:4x8x20', 'megacore', (2, 2, 1), 640 - ), - 'v5p-1408': SystemCharacteristics( - 'tpu', 'v5:4x4x44', 'megacore', (2, 2, 1), 704 - ), - 'v5p-1536': SystemCharacteristics( - 'tpu', 'v5:8x8x12', 'megacore', (2, 2, 1), 768 - ), - 'v5p-1664': SystemCharacteristics( - 'tpu', 'v5:4x4x52', 'megacore', (2, 2, 1), 832 - ), - 'v5p-1792': SystemCharacteristics( - 'tpu', 'v5:4x8x28', 'megacore', (2, 2, 1), 896 - ), - 'v5p-1920': SystemCharacteristics( - 'tpu', 'v5:4x12x20', 'megacore', (2, 2, 1), 960 - ), - 'v5p-2048': SystemCharacteristics( - 'tpu', 'v5:8x8x16', 'megacore', (2, 2, 1), 1024 - ), - 'v5p-2176': SystemCharacteristics( - 'tpu', 'v5:4x4x68', 'megacore', (2, 2, 1), 1088 - ), - 'v5p-2304': SystemCharacteristics( - 'tpu', 'v5:8x12x12', 'megacore', (2, 2, 1), 1152 - ), - 'v5p-2432': SystemCharacteristics( - 'tpu', 'v5:4x4x76', 'megacore', (2, 2, 1), 1216 - ), - 'v5p-2560': SystemCharacteristics( - 'tpu', 'v5:8x8x20', 'megacore', (2, 2, 1), 1280 - ), - 'v5p-2688': SystemCharacteristics( - 'tpu', 'v5:4x12x28', 'megacore', (2, 2, 1), 1344 - ), - 'v5p-2816': SystemCharacteristics( - 'tpu', 'v5:4x8x44', 'megacore', (2, 2, 1), 1408 - ), - 'v5p-2944': SystemCharacteristics( - 'tpu', 'v5:4x4x92', 'megacore', (2, 2, 1), 1472 - ), - 'v5p-3072': SystemCharacteristics( - 'tpu', 'v5:8x12x16', 'megacore', (2, 2, 1), 1536 - ), - 'v5p-3200': SystemCharacteristics( - 'tpu', 'v5:4x20x20', 'megacore', (2, 2, 1), 1600 - ), - 'v5p-3328': SystemCharacteristics( - 'tpu', 'v5:4x8x52', 'megacore', (2, 2, 1), 1664 - ), - 'v5p-3456': SystemCharacteristics( - 'tpu', 'v5:12x12x12', 'megacore', (2, 2, 1), 1728 - ), - 'v5p-3584': SystemCharacteristics( - 'tpu', 'v5:8x8x28', 
'megacore', (2, 2, 1), 1792 - ), - 'v5p-3712': SystemCharacteristics( - 'tpu', 'v5:4x4x116', 'megacore', (2, 2, 1), 1856 - ), - 'v5p-3840': SystemCharacteristics( - 'tpu', 'v5:8x12x20', 'megacore', (2, 2, 1), 1920 - ), - 'v5p-3968': SystemCharacteristics( - 'tpu', 'v5:4x4x124', 'megacore', (2, 2, 1), 1984 - ), - 'v5p-4096': SystemCharacteristics( - 'tpu', 'v5:8x16x16', 'megacore', (2, 2, 1), 2048 - ), - 'v5p-4224': SystemCharacteristics( - 'tpu', 'v5:4x12x44', 'megacore', (2, 2, 1), 2112 - ), - 'v5p-4352': SystemCharacteristics( - 'tpu', 'v5:4x8x68', 'megacore', (2, 2, 1), 2176 - ), - 'v5p-4480': SystemCharacteristics( - 'tpu', 'v5:4x20x28', 'megacore', (2, 2, 1), 2240 - ), - 'v5p-4608': SystemCharacteristics( - 'tpu', 'v5:12x12x16', 'megacore', (2, 2, 1), 2304 - ), - 'v5p-4736': SystemCharacteristics( - 'tpu', 'v5:4x4x148', 'megacore', (2, 2, 1), 2368 - ), - 'v5p-4864': SystemCharacteristics( - 'tpu', 'v5:4x8x76', 'megacore', (2, 2, 1), 2432 - ), - 'v5p-4992': SystemCharacteristics( - 'tpu', 'v5:4x12x52', 'megacore', (2, 2, 1), 2496 - ), - 'v5p-5120': SystemCharacteristics( - 'tpu', 'v5:8x16x20', 'megacore', (2, 2, 1), 2560 - ), - 'v5p-5248': SystemCharacteristics( - 'tpu', 'v5:4x4x164', 'megacore', (2, 2, 1), 2624 - ), - 'v5p-5376': SystemCharacteristics( - 'tpu', 'v5:8x12x28', 'megacore', (2, 2, 1), 2688 - ), - 'v5p-5504': SystemCharacteristics( - 'tpu', 'v5:4x4x172', 'megacore', (2, 2, 1), 2752 - ), - 'v5p-5632': SystemCharacteristics( - 'tpu', 'v5:8x8x44', 'megacore', (2, 2, 1), 2816 - ), - 'v5p-5760': SystemCharacteristics( - 'tpu', 'v5:12x12x20', 'megacore', (2, 2, 1), 2880 - ), - 'v5p-5888': SystemCharacteristics( - 'tpu', 'v5:4x8x92', 'megacore', (2, 2, 1), 2944 - ), - 'v5p-6016': SystemCharacteristics( - 'tpu', 'v5:4x4x188', 'megacore', (2, 2, 1), 3008 - ), - 'v5p-6144': SystemCharacteristics( - 'tpu', 'v5:12x16x16', 'megacore', (2, 2, 1), 3072 - ), - 'v5p-6272': SystemCharacteristics( - 'tpu', 'v5:4x28x28', 'megacore', (2, 2, 1), 3136 - ), - 'v5p-6400': SystemCharacteristics( - 'tpu', 'v5:8x20x20', 'megacore', (2, 2, 1), 3200 - ), - 'v5p-6528': SystemCharacteristics( - 'tpu', 'v5:4x12x68', 'megacore', (2, 2, 1), 3264 - ), - 'v5p-6656': SystemCharacteristics( - 'tpu', 'v5:8x8x52', 'megacore', (2, 2, 1), 3328 - ), - 'v5p-6784': SystemCharacteristics( - 'tpu', 'v5:4x4x212', 'megacore', (2, 2, 1), 3392 - ), - 'v5p-6912': SystemCharacteristics( - 'tpu', 'v5:12x12x24', 'megacore', (2, 2, 1), 3456 - ), - 'v5p-7040': SystemCharacteristics( - 'tpu', 'v5:4x20x44', 'megacore', (2, 2, 1), 3520 - ), - 'v5p-7168': SystemCharacteristics( - 'tpu', 'v5:8x16x28', 'megacore', (2, 2, 1), 3584 - ), - 'v5p-7296': SystemCharacteristics( - 'tpu', 'v5:4x12x76', 'megacore', (2, 2, 1), 3648 - ), - 'v5p-7424': SystemCharacteristics( - 'tpu', 'v5:4x8x116', 'megacore', (2, 2, 1), 3712 - ), - 'v5p-7552': SystemCharacteristics( - 'tpu', 'v5:4x4x236', 'megacore', (2, 2, 1), 3776 - ), - 'v5p-7680': SystemCharacteristics( - 'tpu', 'v5:12x16x20', 'megacore', (2, 2, 1), 3840 - ), - 'v5p-7808': SystemCharacteristics( - 'tpu', 'v5:4x4x244', 'megacore', (2, 2, 1), 3904 - ), - 'v5p-7936': SystemCharacteristics( - 'tpu', 'v5:4x8x124', 'megacore', (2, 2, 1), 3968 - ), - 'v5p-8064': SystemCharacteristics( - 'tpu', 'v5:12x12x28', 'megacore', (2, 2, 1), 4032 - ), - 'v5p-8192': SystemCharacteristics( - 'tpu', 'v5:16x16x16', 'megacore', (2, 2, 1), 4096 - ), - 'v5p-8320': SystemCharacteristics( - 'tpu', 'v5:4x20x52', 'megacore', (2, 2, 1), 4160 - ), - 'v5p-8448': SystemCharacteristics( - 'tpu', 'v5:8x12x44', 'megacore', (2, 2, 
1), 4224 - ), - 'v5p-8704': SystemCharacteristics( - 'tpu', 'v5:8x8x68', 'megacore', (2, 2, 1), 4352 - ), - 'v5p-8832': SystemCharacteristics( - 'tpu', 'v5:4x12x92', 'megacore', (2, 2, 1), 4416 - ), - 'v5p-8960': SystemCharacteristics( - 'tpu', 'v5:8x20x28', 'megacore', (2, 2, 1), 4480 - ), - 'v5p-9216': SystemCharacteristics( - 'tpu', 'v5:12x16x24', 'megacore', (2, 2, 1), 4608 - ), - 'v5p-9472': SystemCharacteristics( - 'tpu', 'v5:4x8x148', 'megacore', (2, 2, 1), 4736 - ), - 'v5p-9600': SystemCharacteristics( - 'tpu', 'v5:12x20x20', 'megacore', (2, 2, 1), 4800 - ), - 'v5p-9728': SystemCharacteristics( - 'tpu', 'v5:8x8x76', 'megacore', (2, 2, 1), 4864 - ), - 'v5p-9856': SystemCharacteristics( - 'tpu', 'v5:4x28x44', 'megacore', (2, 2, 1), 4928 - ), - 'v5p-9984': SystemCharacteristics( - 'tpu', 'v5:8x12x52', 'megacore', (2, 2, 1), 4992 - ), - 'v5p-10240': SystemCharacteristics( - 'tpu', 'v5:16x16x20', 'megacore', (2, 2, 1), 5120 - ), - 'v5p-10368': SystemCharacteristics( - 'tpu', 'v5:12x12x36', 'megacore', (2, 2, 1), 5184 - ), - 'v5p-10496': SystemCharacteristics( - 'tpu', 'v5:4x8x164', 'megacore', (2, 2, 1), 5248 - ), - 'v5p-10752': SystemCharacteristics( - 'tpu', 'v5:12x16x28', 'megacore', (2, 2, 1), 5376 - ), - 'v5p-10880': SystemCharacteristics( - 'tpu', 'v5:4x20x68', 'megacore', (2, 2, 1), 5440 - ), - 'v5p-11008': SystemCharacteristics( - 'tpu', 'v5:4x8x172', 'megacore', (2, 2, 1), 5504 - ), - 'v5p-11136': SystemCharacteristics( - 'tpu', 'v5:4x12x116', 'megacore', (2, 2, 1), 5568 - ), - 'v5p-11264': SystemCharacteristics( - 'tpu', 'v5:8x16x44', 'megacore', (2, 2, 1), 5632 - ), - 'v5p-11520': SystemCharacteristics( - 'tpu', 'v5:12x20x24', 'megacore', (2, 2, 1), 5760 - ), - 'v5p-11648': SystemCharacteristics( - 'tpu', 'v5:4x28x52', 'megacore', (2, 2, 1), 5824 - ), - 'v5p-11776': SystemCharacteristics( - 'tpu', 'v5:8x8x92', 'megacore', (2, 2, 1), 5888 - ), - 'v5p-11904': SystemCharacteristics( - 'tpu', 'v5:4x12x124', 'megacore', (2, 2, 1), 5952 - ), - 'v5p-12032': SystemCharacteristics( - 'tpu', 'v5:4x8x188', 'megacore', (2, 2, 1), 6016 - ), - 'v5p-12160': SystemCharacteristics( - 'tpu', 'v5:4x20x76', 'megacore', (2, 2, 1), 6080 - ), - 'v5p-12288': SystemCharacteristics( - 'tpu', 'v5:16x16x24', 'megacore', (2, 2, 1), 6144 - ), - 'v5p-13824': SystemCharacteristics( - 'tpu', 'v5:12x24x24', 'megacore', (2, 2, 1), 6912 - ), - 'v5p-17920': SystemCharacteristics( - 'tpu', 'v5:16x20x28', 'megacore', (2, 2, 1), 8960 - ), + "v5p-8": SystemCharacteristics("tpu", "v5:2x2x1", "megacore", (2, 2, 1), 4), + "v5p-16": SystemCharacteristics("tpu", "v5:2x2x2", "megacore", (2, 2, 1), 8), + "v5p-32": SystemCharacteristics("tpu", "v5:2x2x4", "megacore", (2, 2, 1), 16), + "v5p-64": SystemCharacteristics("tpu", "v5:2x4x4", "megacore", (2, 2, 1), 32), + "v5p-128": SystemCharacteristics("tpu", "v5:4x4x4", "megacore", (2, 2, 1), 64), + "v5p-256": SystemCharacteristics("tpu", "v5:4x4x8", "megacore", (2, 2, 1), 128), + "v5p-384": SystemCharacteristics("tpu", "v5:4x4x12", "megacore", (2, 2, 1), 192), + "v5p-512": SystemCharacteristics("tpu", "v5:4x8x8", "megacore", (2, 2, 1), 256), + "v5p-640": SystemCharacteristics("tpu", "v5:4x4x20", "megacore", (2, 2, 1), 320), + "v5p-768": SystemCharacteristics("tpu", "v5:4x8x12", "megacore", (2, 2, 1), 384), + "v5p-896": SystemCharacteristics("tpu", "v5:4x4x28", "megacore", (2, 2, 1), 448), + "v5p-1024": SystemCharacteristics("tpu", "v5:8x8x8", "megacore", (2, 2, 1), 512), + "v5p-1152": SystemCharacteristics("tpu", "v5:4x12x12", "megacore", (2, 2, 1), 576), + "v5p-1280": 
SystemCharacteristics("tpu", "v5:4x8x20", "megacore", (2, 2, 1), 640), + "v5p-1408": SystemCharacteristics("tpu", "v5:4x4x44", "megacore", (2, 2, 1), 704), + "v5p-1536": SystemCharacteristics("tpu", "v5:8x8x12", "megacore", (2, 2, 1), 768), + "v5p-1664": SystemCharacteristics("tpu", "v5:4x4x52", "megacore", (2, 2, 1), 832), + "v5p-1792": SystemCharacteristics("tpu", "v5:4x8x28", "megacore", (2, 2, 1), 896), + "v5p-1920": SystemCharacteristics("tpu", "v5:4x12x20", "megacore", (2, 2, 1), 960), + "v5p-2048": SystemCharacteristics("tpu", "v5:8x8x16", "megacore", (2, 2, 1), 1024), + "v5p-2176": SystemCharacteristics("tpu", "v5:4x4x68", "megacore", (2, 2, 1), 1088), + "v5p-2304": SystemCharacteristics("tpu", "v5:8x12x12", "megacore", (2, 2, 1), 1152), + "v5p-2432": SystemCharacteristics("tpu", "v5:4x4x76", "megacore", (2, 2, 1), 1216), + "v5p-2560": SystemCharacteristics("tpu", "v5:8x8x20", "megacore", (2, 2, 1), 1280), + "v5p-2688": SystemCharacteristics("tpu", "v5:4x12x28", "megacore", (2, 2, 1), 1344), + "v5p-2816": SystemCharacteristics("tpu", "v5:4x8x44", "megacore", (2, 2, 1), 1408), + "v5p-2944": SystemCharacteristics("tpu", "v5:4x4x92", "megacore", (2, 2, 1), 1472), + "v5p-3072": SystemCharacteristics("tpu", "v5:8x12x16", "megacore", (2, 2, 1), 1536), + "v5p-3200": SystemCharacteristics("tpu", "v5:4x20x20", "megacore", (2, 2, 1), 1600), + "v5p-3328": SystemCharacteristics("tpu", "v5:4x8x52", "megacore", (2, 2, 1), 1664), + "v5p-3456": SystemCharacteristics("tpu", "v5:12x12x12", "megacore", (2, 2, 1), 1728), + "v5p-3584": SystemCharacteristics("tpu", "v5:8x8x28", "megacore", (2, 2, 1), 1792), + "v5p-3712": SystemCharacteristics("tpu", "v5:4x4x116", "megacore", (2, 2, 1), 1856), + "v5p-3840": SystemCharacteristics("tpu", "v5:8x12x20", "megacore", (2, 2, 1), 1920), + "v5p-3968": SystemCharacteristics("tpu", "v5:4x4x124", "megacore", (2, 2, 1), 1984), + "v5p-4096": SystemCharacteristics("tpu", "v5:8x16x16", "megacore", (2, 2, 1), 2048), + "v5p-4224": SystemCharacteristics("tpu", "v5:4x12x44", "megacore", (2, 2, 1), 2112), + "v5p-4352": SystemCharacteristics("tpu", "v5:4x8x68", "megacore", (2, 2, 1), 2176), + "v5p-4480": SystemCharacteristics("tpu", "v5:4x20x28", "megacore", (2, 2, 1), 2240), + "v5p-4608": SystemCharacteristics("tpu", "v5:12x12x16", "megacore", (2, 2, 1), 2304), + "v5p-4736": SystemCharacteristics("tpu", "v5:4x4x148", "megacore", (2, 2, 1), 2368), + "v5p-4864": SystemCharacteristics("tpu", "v5:4x8x76", "megacore", (2, 2, 1), 2432), + "v5p-4992": SystemCharacteristics("tpu", "v5:4x12x52", "megacore", (2, 2, 1), 2496), + "v5p-5120": SystemCharacteristics("tpu", "v5:8x16x20", "megacore", (2, 2, 1), 2560), + "v5p-5248": SystemCharacteristics("tpu", "v5:4x4x164", "megacore", (2, 2, 1), 2624), + "v5p-5376": SystemCharacteristics("tpu", "v5:8x12x28", "megacore", (2, 2, 1), 2688), + "v5p-5504": SystemCharacteristics("tpu", "v5:4x4x172", "megacore", (2, 2, 1), 2752), + "v5p-5632": SystemCharacteristics("tpu", "v5:8x8x44", "megacore", (2, 2, 1), 2816), + "v5p-5760": SystemCharacteristics("tpu", "v5:12x12x20", "megacore", (2, 2, 1), 2880), + "v5p-5888": SystemCharacteristics("tpu", "v5:4x8x92", "megacore", (2, 2, 1), 2944), + "v5p-6016": SystemCharacteristics("tpu", "v5:4x4x188", "megacore", (2, 2, 1), 3008), + "v5p-6144": SystemCharacteristics("tpu", "v5:12x16x16", "megacore", (2, 2, 1), 3072), + "v5p-6272": SystemCharacteristics("tpu", "v5:4x28x28", "megacore", (2, 2, 1), 3136), + "v5p-6400": SystemCharacteristics("tpu", "v5:8x20x20", "megacore", (2, 2, 1), 3200), + "v5p-6528": 
SystemCharacteristics("tpu", "v5:4x12x68", "megacore", (2, 2, 1), 3264), + "v5p-6656": SystemCharacteristics("tpu", "v5:8x8x52", "megacore", (2, 2, 1), 3328), + "v5p-6784": SystemCharacteristics("tpu", "v5:4x4x212", "megacore", (2, 2, 1), 3392), + "v5p-6912": SystemCharacteristics("tpu", "v5:12x12x24", "megacore", (2, 2, 1), 3456), + "v5p-7040": SystemCharacteristics("tpu", "v5:4x20x44", "megacore", (2, 2, 1), 3520), + "v5p-7168": SystemCharacteristics("tpu", "v5:8x16x28", "megacore", (2, 2, 1), 3584), + "v5p-7296": SystemCharacteristics("tpu", "v5:4x12x76", "megacore", (2, 2, 1), 3648), + "v5p-7424": SystemCharacteristics("tpu", "v5:4x8x116", "megacore", (2, 2, 1), 3712), + "v5p-7552": SystemCharacteristics("tpu", "v5:4x4x236", "megacore", (2, 2, 1), 3776), + "v5p-7680": SystemCharacteristics("tpu", "v5:12x16x20", "megacore", (2, 2, 1), 3840), + "v5p-7808": SystemCharacteristics("tpu", "v5:4x4x244", "megacore", (2, 2, 1), 3904), + "v5p-7936": SystemCharacteristics("tpu", "v5:4x8x124", "megacore", (2, 2, 1), 3968), + "v5p-8064": SystemCharacteristics("tpu", "v5:12x12x28", "megacore", (2, 2, 1), 4032), + "v5p-8192": SystemCharacteristics("tpu", "v5:16x16x16", "megacore", (2, 2, 1), 4096), + "v5p-8320": SystemCharacteristics("tpu", "v5:4x20x52", "megacore", (2, 2, 1), 4160), + "v5p-8448": SystemCharacteristics("tpu", "v5:8x12x44", "megacore", (2, 2, 1), 4224), + "v5p-8704": SystemCharacteristics("tpu", "v5:8x8x68", "megacore", (2, 2, 1), 4352), + "v5p-8832": SystemCharacteristics("tpu", "v5:4x12x92", "megacore", (2, 2, 1), 4416), + "v5p-8960": SystemCharacteristics("tpu", "v5:8x20x28", "megacore", (2, 2, 1), 4480), + "v5p-9216": SystemCharacteristics("tpu", "v5:12x16x24", "megacore", (2, 2, 1), 4608), + "v5p-9472": SystemCharacteristics("tpu", "v5:4x8x148", "megacore", (2, 2, 1), 4736), + "v5p-9600": SystemCharacteristics("tpu", "v5:12x20x20", "megacore", (2, 2, 1), 4800), + "v5p-9728": SystemCharacteristics("tpu", "v5:8x8x76", "megacore", (2, 2, 1), 4864), + "v5p-9856": SystemCharacteristics("tpu", "v5:4x28x44", "megacore", (2, 2, 1), 4928), + "v5p-9984": SystemCharacteristics("tpu", "v5:8x12x52", "megacore", (2, 2, 1), 4992), + "v5p-10240": SystemCharacteristics("tpu", "v5:16x16x20", "megacore", (2, 2, 1), 5120), + "v5p-10368": SystemCharacteristics("tpu", "v5:12x12x36", "megacore", (2, 2, 1), 5184), + "v5p-10496": SystemCharacteristics("tpu", "v5:4x8x164", "megacore", (2, 2, 1), 5248), + "v5p-10752": SystemCharacteristics("tpu", "v5:12x16x28", "megacore", (2, 2, 1), 5376), + "v5p-10880": SystemCharacteristics("tpu", "v5:4x20x68", "megacore", (2, 2, 1), 5440), + "v5p-11008": SystemCharacteristics("tpu", "v5:4x8x172", "megacore", (2, 2, 1), 5504), + "v5p-11136": SystemCharacteristics("tpu", "v5:4x12x116", "megacore", (2, 2, 1), 5568), + "v5p-11264": SystemCharacteristics("tpu", "v5:8x16x44", "megacore", (2, 2, 1), 5632), + "v5p-11520": SystemCharacteristics("tpu", "v5:12x20x24", "megacore", (2, 2, 1), 5760), + "v5p-11648": SystemCharacteristics("tpu", "v5:4x28x52", "megacore", (2, 2, 1), 5824), + "v5p-11776": SystemCharacteristics("tpu", "v5:8x8x92", "megacore", (2, 2, 1), 5888), + "v5p-11904": SystemCharacteristics("tpu", "v5:4x12x124", "megacore", (2, 2, 1), 5952), + "v5p-12032": SystemCharacteristics("tpu", "v5:4x8x188", "megacore", (2, 2, 1), 6016), + "v5p-12160": SystemCharacteristics("tpu", "v5:4x20x76", "megacore", (2, 2, 1), 6080), + "v5p-12288": SystemCharacteristics("tpu", "v5:16x16x24", "megacore", (2, 2, 1), 6144), + "v5p-13824": SystemCharacteristics("tpu", "v5:12x24x24", 
"megacore", (2, 2, 1), 6912), + "v5p-17920": SystemCharacteristics("tpu", "v5:16x20x28", "megacore", (2, 2, 1), 8960), } + def get_system_characteristics(user_facing_name): return UserFacingNameToSystemCharacteristics.get(user_facing_name) diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index bd229cc91..072117680 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" @@ -28,12 +28,13 @@ from multihost_dataloading import MultiHostDataLoadIterator from flax.training import train_state + def create_orbax_checkpoint_manager( checkpoint_dir: str, enable_checkpointing: bool, use_async: bool, save_interval_steps: int, - dataset_type: Optional[str] = 'c4' + dataset_type: Optional[str] = "c4", ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -42,19 +43,19 @@ def create_orbax_checkpoint_manager( max_logging.log("Creating checkpoint manager...") p = epath.Path(checkpoint_dir) - if dataset_type=='c4-array_record': - item_names = ('items', 'iter') + if dataset_type == "c4-array_record": + item_names = ("items", "iter") else: - item_names = ('items',) + item_names = ("items",) mngr = CheckpointManager( p, - item_names = item_names, - options = CheckpointManagerOptions( + item_names=item_names, + options=CheckpointManagerOptions( create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async, - ) + ), ) max_logging.log("Checkpoint manager created!") return mngr @@ -82,19 +83,19 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int): devices inside the replica that current host is in """ idx = _find_idx(device_array, replica_axis_idx) - replica_result = np.take(device_array, - idx, - axis=replica_axis_idx) + replica_result = np.take(device_array, idx, axis=replica_axis_idx) return np.expand_dims(replica_result, axis=replica_axis_idx) -def load_state_if_possible(checkpoint_manager: CheckpointManager, - data_iterator: Union[MultiHostDataLoadIterator, None], - load_parameters_from_path: str, - load_full_state_from_path: str, - abstract_unboxed_pre_state: train_state.TrainState, - enable_single_replica_ckpt_restoring: Optional[bool] = False, - dataset_type: Optional[str] = 'c4'): +def load_state_if_possible( + checkpoint_manager: 
CheckpointManager, + data_iterator: Union[MultiHostDataLoadIterator, None], + load_parameters_from_path: str, + load_full_state_from_path: str, + abstract_unboxed_pre_state: train_state.TrainState, + enable_single_replica_ckpt_restoring: Optional[bool] = False, + dataset_type: Optional[str] = "c4", +): """Loads TrainState as possible from the inputs. Args: @@ -121,57 +122,59 @@ def load_state_if_possible(checkpoint_manager: CheckpointManager, latest_step = checkpoint_manager.latest_step() if latest_step is not None: - max_logging.log(f"restoring from this run's directory latest step \ - {latest_step}") + max_logging.log( + f"restoring from this run's directory latest step \ + {latest_step}" + ) - def map_to_pspec(data): + def map_to_pspec(data): pspec = data.sharding.spec mesh = data.sharding.mesh if not enable_single_replica_ckpt_restoring: return orbax.checkpoint.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec) orbax.checkpoint.type_handlers.register_type_handler( - jax.Array, - orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), - override=True) + jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True + ) orbax.checkpoint.type_handlers.register_type_handler( - jax.Array, - orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), - override=True) + jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True + ) replica_axis_index = 0 # for maxtext data is the first dimension replica_devices = _replica_devices(mesh.devices, replica_axis_index) replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names) single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec) return orbax.checkpoint.type_handlers.SingleReplicaArrayRestoreArgs( - sharding=jax.sharding.NamedSharding(mesh, pspec), - single_replica_sharding=single_replica_sharding, - replica_axis_index=replica_axis_index, - global_shape=data.shape, - dtype=data.dtype, - ) - - restore_args = jax.tree_util.tree_map(map_to_pspec, - abstract_unboxed_pre_state, - ) - if dataset_type == 'c4-array_record' and data_iterator is not None: - return checkpoint_manager.restore( - latest_step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeRestore( - item=abstract_unboxed_pre_state, - restore_args=restore_args), - iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator)) - ), None + sharding=jax.sharding.NamedSharding(mesh, pspec), + single_replica_sharding=single_replica_sharding, + replica_axis_index=replica_axis_index, + global_shape=data.shape, + dtype=data.dtype, + ) + + restore_args = jax.tree_util.tree_map( + map_to_pspec, + abstract_unboxed_pre_state, + ) + if dataset_type == "c4-array_record" and data_iterator is not None: + return ( + checkpoint_manager.restore( + latest_step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args), + iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator), + ), + ), + None, + ) else: return ( - checkpoint_manager.restore( - latest_step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeRestore( - item=abstract_unboxed_pre_state, - restore_args=restore_args) + checkpoint_manager.restore( + latest_step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + ), ), - ), - None) + None, + ) if load_parameters_from_path != "": max_logging.log(f"restoring params from 
{load_parameters_from_path=}") @@ -182,16 +185,17 @@ def map_to_pspec(data): # memory, we instead specify here that we are just restoring the params field of the checkpoint # (which itself may be a dictionary containing a key named 'params'). restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_pre_state.params) - restored = ckptr.restore(p, item = {'params': abstract_unboxed_pre_state.params}, transforms={}, - restore_args = {'params': restore_args}) - return None, restored['params'] + restored = ckptr.restore( + p, item={"params": abstract_unboxed_pre_state.params}, transforms={}, restore_args={"params": restore_args} + ) + return None, restored["params"] elif load_full_state_from_path != "": max_logging.log(f"restoring full state from {load_full_state_from_path=}") p = epath.Path(load_full_state_from_path) ckptr = orbax.checkpoint.StandardCheckpointer() restored = ckptr.restore(p, args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state)) - return {'items': restored}, None + return {"items": restored}, None else: max_logging.log("No existing checkpoints found, not restoring checkpoint.") diff --git a/MaxText/common_types.py b/MaxText/common_types.py index a2e0f389b..2961104f3 100644 --- a/MaxText/common_types.py +++ b/MaxText/common_types.py @@ -32,13 +32,13 @@ AxisNames = tuple[str, ...] -BATCH = 'activation_batch' -LENGTH = 'activation_length' -HEAD = 'activation_heads' -D_KV = 'activation_kv' - -MODEL_MODE_AUTOREGRESSIVE = 'autoregressive' -MODEL_MODE_PREFILL = 'prefill' -MODEL_MODE_TRAIN = 'train' +BATCH = "activation_batch" +LENGTH = "activation_length" +HEAD = "activation_heads" +D_KV = "activation_kv" + +MODEL_MODE_AUTOREGRESSIVE = "autoregressive" +MODEL_MODE_PREFILL = "prefill" +MODEL_MODE_TRAIN = "train" DECODING_ACTIVE_SEQUENCE_INDICATOR = 1 diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py index a3cf1fd39..c690130c2 100644 --- a/MaxText/convert_gemma_chkpt.py +++ b/MaxText/convert_gemma_chkpt.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=line-too-long """ Convert orbax Gemma checkpoint to MaxText compatible checkpoint. 
@@ -18,7 +18,8 @@ import jax import jax.numpy as jnp import numpy as np -jax.config.update('jax_platform_name', 'cpu') + +jax.config.update("jax_platform_name", "cpu") import argparse import copy from flax.training import train_state @@ -35,44 +36,35 @@ Params = dict[str, Any] + def nest_params(params: Params) -> Params: """Nests params as a dict of dicts rather than a flat dict.""" nested_params = {} for path, param in params.items(): - *path, leaf = path.split('/') + *path, leaf = path.split("/") subdict = nested_params for key in path: subdict = subdict.setdefault(key, {}) subdict[leaf] = param return nested_params + def main(raw_args=None) -> None: parser = argparse.ArgumentParser() - parser.add_argument('--base_model_path', type=str, required=True) - parser.add_argument('--maxtext_model_path', type=str, required=True) - parser.add_argument('--model_size', type=str, required=True) + parser.add_argument("--base_model_path", type=str, required=True) + parser.add_argument("--maxtext_model_path", type=str, required=True) + parser.add_argument("--model_size", type=str, required=True) args = parser.parse_args(raw_args) - if args.model_size not in ('2b','7b'): + if args.model_size not in ("2b", "7b"): raise NotImplementedError print("Loading checkpoint") checkpointer = orbax.checkpoint.PyTreeCheckpointer() params = checkpointer.restore(args.base_model_path) params = nest_params(params) - num_layers = ( - max(( - int(k.split('_')[1]) - for k in params['transformer'].keys() - if 'layer_' in k - )) - + 1 - ) - hidden_dim, embed_dim = ( - params['transformer']['layer_0']['mlp']['linear']['w'].shape - ) - num_heads, head_dim, _ = ( - params['transformer']['layer_0']['attn']['attn_vec_einsum']['w'].shape - ) + num_layers = max((int(k.split("_")[1]) for k in params["transformer"].keys() if "layer_" in k)) + 1 + hidden_dim, embed_dim = params["transformer"]["layer_0"]["mlp"]["linear"]["w"].shape + num_heads, head_dim, _ = params["transformer"]["layer_0"]["attn"]["attn_vec_einsum"]["w"].shape print("Model configurations from checkpoint") print(f"num_layers: {num_layers}") print(f"hidden_dim: {hidden_dim}") @@ -81,109 +73,96 @@ def main(raw_args=None) -> None: print(f"head_dim: {head_dim}") jax_weights = { - 'decoder': { - 'decoder_norm': { - 'scale': params['transformer']['final_norm']['scale'] + 1 - }, + "decoder": { + "decoder_norm": {"scale": params["transformer"]["final_norm"]["scale"] + 1}, }, - 'token_embedder':{ - 'embedding': params['transformer']['embedder']['input_embedding'] * jnp.sqrt(embed_dim) - } - + "token_embedder": {"embedding": params["transformer"]["embedder"]["input_embedding"] * jnp.sqrt(embed_dim)}, } self_attention = dict({ - 'query': { - 'kernel' : [] - }, - 'key': { - 'kernel' : [] - }, - 'value': { - 'kernel' : [] - }, - 'out': { - 'kernel' : [] - }, + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, }) layer_weight = dict({ - 'mlp': { - 'wi_0': { - 'kernel' : [] - }, - 'wi_1': { - 'kernel' : [] - }, - 'wo': { - 'kernel' : [] - }, - }, - 'pre_self_attention_norm': { - 'scale': [] - }, - 'pre_ffw_norm': { - 'scale': [] - }, + "mlp": { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, + }, + "pre_self_attention_norm": {"scale": []}, + "pre_ffw_norm": {"scale": []}, }) for layer_idx in range(num_layers): - in_layer_name = 'layer_' + str(layer_idx) + in_layer_name = "layer_" + str(layer_idx) # attention block - if args.model_size == '2b': # MQA - 
self_attention['query']['kernel'].append(params['transformer'][in_layer_name]['attn']['q_einsum']['w'].transpose((1, 0, 2)) * head_dim**-0.5) - self_attention['key']['kernel'].append(params['transformer'][in_layer_name]['attn']['kv_einsum']['w'][0].transpose((1, 0, 2))) - self_attention['value']['kernel'].append(params['transformer'][in_layer_name]['attn']['kv_einsum']['w'][1].transpose((1, 0, 2))) + if args.model_size == "2b": # MQA + self_attention["query"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["q_einsum"]["w"].transpose((1, 0, 2)) * head_dim**-0.5 + ) + self_attention["key"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["kv_einsum"]["w"][0].transpose((1, 0, 2)) + ) + self_attention["value"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["kv_einsum"]["w"][1].transpose((1, 0, 2)) + ) else: - self_attention['query']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][0].transpose((1, 0, 2)) * head_dim**-0.5) - self_attention['key']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][1].transpose((1, 0, 2))) - self_attention['value']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][2].transpose((1, 0, 2))) - self_attention['out']['kernel'].append(params['transformer'][in_layer_name]['attn']['attn_vec_einsum']['w']) + self_attention["query"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][0].transpose((1, 0, 2)) * head_dim**-0.5 + ) + self_attention["key"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][1].transpose((1, 0, 2)) + ) + self_attention["value"]["kernel"].append( + params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][2].transpose((1, 0, 2)) + ) + self_attention["out"]["kernel"].append(params["transformer"][in_layer_name]["attn"]["attn_vec_einsum"]["w"]) # mlp - layer_weight['mlp']['wi_0']['kernel'].append(params['transformer'][in_layer_name]['mlp']['gating_einsum']['w'][0]) - layer_weight['mlp']['wi_1']['kernel'].append(params['transformer'][in_layer_name]['mlp']['gating_einsum']['w'][1]) - layer_weight['mlp']['wo']['kernel'].append(params['transformer'][in_layer_name]['mlp']['linear']['w']) - layer_weight['pre_self_attention_norm']['scale'].append(params['transformer'][in_layer_name]['pre_attention_norm']['scale'] + 1) - layer_weight['pre_ffw_norm']['scale'].append(params['transformer'][in_layer_name]['pre_ffw_norm']['scale'] + 1) - - self_attention['query']['kernel'] = np.array(self_attention['query']['kernel']).transpose((1, 0, 2, 3)) - self_attention['key']['kernel'] = np.array(self_attention['key']['kernel']).transpose((1, 0, 2, 3)) - self_attention['value']['kernel'] = np.array(self_attention['value']['kernel']).transpose((1, 0, 2, 3)) - self_attention['out']['kernel'] = np.array(self_attention['out']['kernel']).transpose((1, 0, 2, 3)) - - layer_weight['mlp']['wi_0']['kernel'] = np.array(layer_weight['mlp']['wi_0']['kernel']).transpose((1, 0, 2)) - layer_weight['mlp']['wi_1']['kernel'] = np.array(layer_weight['mlp']['wi_1']['kernel']).transpose((1, 0, 2)) - layer_weight['mlp']['wo']['kernel'] = np.array(layer_weight['mlp']['wo']['kernel']).transpose((1, 0, 2)) - layer_weight['pre_self_attention_norm']['scale'] = np.array(layer_weight['pre_self_attention_norm']['scale']).transpose((1, 0)) - layer_weight['pre_ffw_norm']['scale'] = np.array(layer_weight['pre_ffw_norm']['scale']).transpose((1, 0)) - - layer_weight['self_attention'] = 
copy.deepcopy(self_attention) - jax_weights['decoder']['layers'] = copy.deepcopy(layer_weight) + layer_weight["mlp"]["wi_0"]["kernel"].append(params["transformer"][in_layer_name]["mlp"]["gating_einsum"]["w"][0]) + layer_weight["mlp"]["wi_1"]["kernel"].append(params["transformer"][in_layer_name]["mlp"]["gating_einsum"]["w"][1]) + layer_weight["mlp"]["wo"]["kernel"].append(params["transformer"][in_layer_name]["mlp"]["linear"]["w"]) + layer_weight["pre_self_attention_norm"]["scale"].append( + params["transformer"][in_layer_name]["pre_attention_norm"]["scale"] + 1 + ) + layer_weight["pre_ffw_norm"]["scale"].append(params["transformer"][in_layer_name]["pre_ffw_norm"]["scale"] + 1) + + self_attention["query"]["kernel"] = np.array(self_attention["query"]["kernel"]).transpose((1, 0, 2, 3)) + self_attention["key"]["kernel"] = np.array(self_attention["key"]["kernel"]).transpose((1, 0, 2, 3)) + self_attention["value"]["kernel"] = np.array(self_attention["value"]["kernel"]).transpose((1, 0, 2, 3)) + self_attention["out"]["kernel"] = np.array(self_attention["out"]["kernel"]).transpose((1, 0, 2, 3)) + + layer_weight["mlp"]["wi_0"]["kernel"] = np.array(layer_weight["mlp"]["wi_0"]["kernel"]).transpose((1, 0, 2)) + layer_weight["mlp"]["wi_1"]["kernel"] = np.array(layer_weight["mlp"]["wi_1"]["kernel"]).transpose((1, 0, 2)) + layer_weight["mlp"]["wo"]["kernel"] = np.array(layer_weight["mlp"]["wo"]["kernel"]).transpose((1, 0, 2)) + layer_weight["pre_self_attention_norm"]["scale"] = np.array(layer_weight["pre_self_attention_norm"]["scale"]).transpose( + (1, 0) + ) + layer_weight["pre_ffw_norm"]["scale"] = np.array(layer_weight["pre_ffw_norm"]["scale"]).transpose((1, 0)) + + layer_weight["self_attention"] = copy.deepcopy(self_attention) + jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight) jax_weights = jax.tree_map(jnp.array, jax_weights) + def astype_fn(x): if isinstance(x, jnp.ndarray): return x.astype(jnp.bfloat16) else: return x - jax_weights = jax.tree_map(astype_fn, jax_weights) - enable_checkpointing=True - async_checkpointing=False - save_interval_steps=1 + jax_weights = jax.tree_map(astype_fn, jax_weights) + enable_checkpointing = True + async_checkpointing = False + save_interval_steps = 1 checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - args.maxtext_model_path, - enable_checkpointing, - async_checkpointing, - save_interval_steps + args.maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps ) state_new = train_state.TrainState( - step=0, - apply_fn=None, - params={'params': jax_weights}, - tx=None, # type: ignore - opt_state={} + step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore ) if checkpoint_manager is not None: @@ -194,5 +173,6 @@ def astype_fn(x): checkpoint_manager.wait_until_finished() sys.exit() + if __name__ == "__main__": main() diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py index 78d43c47b..3ec57f8a2 100644 --- a/MaxText/convert_gpt3_ckpt_from_paxml.py +++ b/MaxText/convert_gpt3_ckpt_from_paxml.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=line-too-long """Convert weights from a paxml gpt3 model to a MaxText one. @@ -53,6 +53,7 @@ from train import save_checkpoint import argparse + def fmt_size(num_bytes: int) -> str: assert num_bytes > 0 for unit in ["B", "KiB", "MiB", "GiB"]: @@ -61,14 +62,15 @@ def fmt_size(num_bytes: int) -> str: num_bytes /= 1024.0 return f"{num_bytes:.2f} {unit}" + def check_memory(): """print out cpu/tpu memory.""" cpu_bytes = Process().memory_info().rss max_logging.log(f"cpu memory: {fmt_size(cpu_bytes)}") for d in jax.local_devices(): stats = d.memory_stats() - used = stats['bytes_in_use'] - limit = stats['bytes_limit'] + used = stats["bytes_in_use"] + limit = stats["bytes_limit"] max_logging.log(f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}") @@ -76,13 +78,16 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name """convert ckpt.""" base_args = [ - '', 'MaxText/configs/base.yml', # base arg - 'per_device_batch_size=1', - 'ici_fsdp_parallelism=-1', 'ici_tensor_parallelism=1', - f'model_name={maxtext_model_name}', - f'run_name={run_name}', f'base_output_directory={base_output_directory}', - 'checkpoint_period=1', - 'async_checkpointing=false', + "", + "MaxText/configs/base.yml", # base arg + "per_device_batch_size=1", + "ici_fsdp_parallelism=-1", + "ici_tensor_parallelism=1", + f"model_name={maxtext_model_name}", + f"run_name={run_name}", + f"base_output_directory={base_output_directory}", + "checkpoint_period=1", + "async_checkpointing=false", ] pyconfig.initialize(base_args) cfg = pyconfig.config @@ -96,10 +101,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name tx = optimizers.get_optimizer(cfg, learning_rate_schedule) checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - cfg.checkpoint_dir, - cfg.enable_checkpointing, - cfg.async_checkpointing, - cfg.checkpoint_period, + cfg.checkpoint_dir, + cfg.enable_checkpointing, + cfg.async_checkpointing, + cfg.checkpoint_period, ) state, _, _ = max_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager) @@ -108,33 +113,87 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name # maxtext keystr: (paxml keystr, transform_fn) keystr_map = { - "['token_embedder']['embedding']": (".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T), - "['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None), - "['decoder']['layers']['pre_self_attention_norm']['scale']": 
(".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['pre_self_attention_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['query']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['query']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['key']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['key']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['value']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['value']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis])), - "['decoder']['layers']['self_attention']['qkv_proj']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['out']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1])), - "['decoder']['layers']['self_attention']['out']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wi']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wi']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wo']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wo']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T), - "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), + "['token_embedder']['embedding']": 
(".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T), + "['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None), + "['decoder']['layers']['pre_self_attention_norm']['scale']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['pre_self_attention_norm']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['query']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 0], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['query']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 0], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['key']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 1], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['key']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 1], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['value']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 2], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['value']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 2], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis]), + ), + "['decoder']['layers']['self_attention']['qkv_proj']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['out']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", + lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1]), + ), + "['decoder']['layers']['self_attention']['out']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wi']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wi']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wo']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + 
"['decoder']['layers']['mlp']['wo']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T), + "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), } state_map = { - ".step": ("step", None), - ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), + ".step": ("step", None), + ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), } def get_layer_prefix(keystr_pax): @@ -151,9 +210,15 @@ def get_layer_prefix(keystr_pax): state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) prefix_pax_opt_state = get_layer_prefix(keystr_pax) # first momentum in optimizer state - state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn) + state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) # second momentum in optimizer state - state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn) + state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) def verify_fn(key_path, _): keystr = jax.tree_util.keystr(key_path) @@ -161,19 +226,19 @@ def verify_fn(key_path, _): jax.tree_util.tree_map_with_path(verify_fn, state) - memory_metrics = {'max_cpu_bytes': 0} + memory_metrics = {"max_cpu_bytes": 0} - bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://"):].split('/', 1) + bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://") :].split("/", 1) def map_fn(key_path, value): key_path_str = jax.tree_util.keystr(key_path) file_path, transform_fn = state_map[key_path_str] full_path = os.path.join(paxml_ckpt_prefix, file_path) - spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} - spec['kvstore'] = { - 'bucket': bucket_name, - 'driver': 'gcs', - 'path': full_path, + spec = {"driver": "zarr", "metadata_key": ".zarray", "kvstore": {}} + spec["kvstore"] = { + "bucket": bucket_name, + "driver": "gcs", + "path": full_path, } arr = ts.open(ts.Spec(spec), open=True).result().read().result() @@ -184,10 +249,9 @@ def map_fn(key_path, value): shape = value.shape sharding = value.sharding result = jax.make_array_from_single_device_arrays( - shape, - sharding, - [jax.device_put(np.array(arr[index]), d) - for d, index in sharding.addressable_devices_indices_map(shape).items()], + shape, + sharding, + [jax.device_put(np.array(arr[index]), d) for d, index in sharding.addressable_devices_indices_map(shape).items()], ) # log peak cpu memory @@ -216,15 +280,18 @@ def map_fn(key_path, value): max_logging.log(f"Peak cpu memory in a single process: {fmt_size(memory_metrics['max_cpu_bytes'])}") max_logging.log("checkpoint converted and saved successfully.") + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--paxml-ckpt-path', - type=str, - default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000", - required=True) - parser.add_argument('--maxtext-model-name', choices=['gpt3-175b', 'gpt3-52k'], type=str, required=True) - parser.add_argument('--base-output-directory', type=str, required=True) - parser.add_argument('--run-name', type=str, required=True) + parser.add_argument( + "--paxml-ckpt-path", + 
type=str,
+      default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000",
+      required=True,
+  )
+  parser.add_argument("--maxtext-model-name", choices=["gpt3-175b", "gpt3-52k"], type=str, required=True)
+  parser.add_argument("--base-output-directory", type=str, required=True)
+  parser.add_argument("--run-name", type=str, required=True)
   args = parser.parse_args()
 
   if not args.paxml_ckpt_path.startswith("gs://"):
diff --git a/MaxText/decode.py b/MaxText/decode.py
index 15606b18e..56a92f94f 100644
--- a/MaxText/decode.py
+++ b/MaxText/decode.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-'''CLI Utility for Running Inference on a Single Stream'''
+"""CLI Utility for Running Inference on a Single Stream"""
 
 import jax
 
@@ -32,26 +32,21 @@ def main(config):
   metadata = engine.get_tokenizer()
   vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
   tokenizer = vocab.tokenizer
-  tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True,
-                                                     prefill_lengths=[config.max_prefill_predict_length])
+  tokens, true_length = token_utils.tokenize_and_pad(
+      text, vocab, is_bos=True, prefill_lengths=[config.max_prefill_predict_length]
+  )
   assert tokens.size <= config.max_prefill_predict_length, "can't take too many tokens"
   assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
 
-  prefill_result = engine.prefill(
-      params=params, padded_tokens=tokens, true_length=true_length
-  )
-  slot=0
+  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
+  slot = 0
 
   decode_state = engine.init_decode_state()
-  decode_state = engine.insert(
-      prefill_result, decode_state, slot=slot
-  )
+  decode_state = engine.insert(prefill_result, decode_state, slot=slot)
 
   steps = range(config.max_prefill_predict_length, config.max_target_length)
   sampled_tokens_list = []
   for _ in steps:
-    decode_state, sampled_tokens = engine.generate(
-        params, decode_state
-    )
+    decode_state, sampled_tokens = engine.generate(params, decode_state)
     sampled_tokens_list.append(sampled_tokens)
 
   results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list]
@@ -59,15 +54,19 @@
   print(f"Input `{text}` -> `{output}`")
 
   if config.autoregressive_decode_assert != "":
-    assert output==config.autoregressive_decode_assert, \
-    f"generated text mismatch {output=} {config.autoregressive_decode_assert=}"
+    assert (
+        output == config.autoregressive_decode_assert
+    ), f"generated text mismatch {output=} {config.autoregressive_decode_assert=}"
+
 
 def validate_config(config):
-  assert config.load_full_state_path == "", "Decode doesn't operate on full states! Convert to parameter checkpoint first."\
-  "Using generate_param_only_checkpoint."
+  assert config.load_full_state_path == "", (
+      "Decode doesn't operate on full states! Convert to parameter checkpoint first. " "Using generate_param_only_checkpoint."
+  )
+
 
 if __name__ == "__main__":
-  jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   pyconfig.initialize(sys.argv)
   cfg = pyconfig.config
diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py
index cda4273ca..09ea7412b 100644
--- a/MaxText/generate_param_only_checkpoint.py
+++ b/MaxText/generate_param_only_checkpoint.py
@@ -1,18 +1,18 @@
 """
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
 
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 
 # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
 """Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state.
@@ -40,35 +40,38 @@
 
 Transformer = models.Transformer
 
+
 def _possibly_unroll_params(config, training_state, training_state_annotations, mesh):
-  """ If input layers are scanned, and force_unroll is set,
-  return modify training_state and train_state_annotations to be "unrolled".
-  Otherwise do nothing."""
+  """If input layers are scanned, and force_unroll is set,
+  modify training_state and train_state_annotations to be "unrolled".
+ Otherwise do nothing.""" if not config.scan_layers or not config.force_unroll: return - training_state_layers = training_state.params['params']['decoder']['layers'] - training_state_annotations_layers = training_state_annotations.params['params']['decoder']['layers'] + training_state_layers = training_state.params["params"]["decoder"]["layers"] + training_state_annotations_layers = training_state_annotations.params["params"]["decoder"]["layers"] def new_pspec(x): - return jax.sharding.PartitionSpec(*x[0:config.param_scan_axis] + x[config.param_scan_axis+1:]) + return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :]) new_per_layer_state_annotation = jax.tree_map(new_pspec, training_state_annotations_layers) - new_per_layer_state_sharding = jax.tree_map(lambda x : jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation) + new_per_layer_state_sharding = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation) for i in range(config.num_decoder_layers): + def slice_ith(input_layers): - return jax.tree_map(lambda x : jax.numpy.take(x, i, axis = config.param_scan_axis), input_layers) + return jax.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers) + + new_layer = jax.jit(slice_ith, out_shardings=new_per_layer_state_sharding)(training_state_layers) - new_layer = jax.jit(slice_ith, out_shardings = new_per_layer_state_sharding)(training_state_layers) + training_state.params["params"]["decoder"][f"layers_{i}"] = new_layer + training_state_annotations.params["params"]["decoder"][f"layers_{i}"] = new_per_layer_state_annotation - training_state.params['params']['decoder'][f'layers_{i}'] = new_layer - training_state_annotations.params['params']['decoder'][f'layers_{i}'] = new_per_layer_state_annotation + del training_state.params["params"]["decoder"]["layers"] + del training_state_annotations.params["params"]["decoder"]["layers"] - del training_state.params['params']['decoder']['layers'] - del training_state_annotations.params['params']['decoder']['layers'] + jax.tree_map(lambda x: x.delete(), training_state_layers) - jax.tree_map(lambda x : x.delete(), training_state_layers) def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" @@ -78,22 +81,22 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): rng = random.PRNGKey(0) learning_rate_schedule = max_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - state, state_mesh_notations, _ = max_utils.setup_training_state( - model, None, tx, config, rng, mesh, checkpoint_manager - ) + state, state_mesh_notations, _ = max_utils.setup_training_state(model, None, tx, config, rng, mesh, checkpoint_manager) num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") return state, state_mesh_notations + def _save_decode_checkpoint(config, state, checkpoint_manager): """Generate checkpoint for decode from the training_state.""" - with jax.spmd_mode('allow_all'): - decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x : x.astype(jax.numpy.bfloat16), state.params)) + with jax.spmd_mode("allow_all"): + decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params)) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, 
0, decode_state):
       max_logging.log(f"saved a decode checkpoint at {config.checkpoint_dir}")
     checkpoint_manager.wait_until_finished()
 
+
 def generate_decode_checkpoint(config):
   """
   Generate a decode checkpoint from a given training checkpoint.
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index ec46b137f..8fe46b3e2 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -1,18 +1,18 @@
 """
- Copyright 2024 Google LLC
+Copyright 2024 Google LLC
 
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
 
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 """Inference microbenchmark for prefill and autoregressive steps."""
 
 import datetime
@@ -28,19 +28,21 @@
 
 
 def summarize_pytree_data(params, name="Params"):
-  """ Generate basic metrics of a given Pytree. """
+  """Generate basic metrics of a given Pytree."""
   num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params)
   num_params_in_billions = num_params / 1e9
   total_param_size_in_gb = total_param_size / 1e9
-  print(f"{name} stats: \n"
-        f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
-        f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
-        f"\tAvg size: {avg_param_size:.3f} bytes\n")
+  print(
+      f"{name} stats: \n"
+      f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
+      f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
+      f"\tAvg size: {avg_param_size:.3f} bytes\n"
+  )
   return num_params, total_param_size, avg_param_size
 
 
 def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
-  """ Inner loop for benchmarking prefill step. """
+  """Inner loop for benchmarking prefill step."""
   max_utils.activate_profiler(config, profile_name)
   start = datetime.datetime.now()
   for i in range(iters):
@@ -53,9 +55,10 @@
   return (end - start).total_seconds(), decode_state
 
 
-def prefill_benchmark(config, engine, params, decode_state, tokens, true_length,
-                      iters=100, profile_name="", num_model_params=None):
-  """ Handles init, warmup, running prefill benchmark, and printing results.
""" +def prefill_benchmark( + config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None +): + """Handles init, warmup, running prefill benchmark, and printing results.""" if num_model_params is None: num_model_params, _, _ = summarize_pytree_data(params, name="Params") @@ -69,22 +72,27 @@ def prefill_benchmark(config, engine, params, decode_state, tokens, true_length, print(f"Prefill results for length {tokens.size}:\n") profile_name = f"prefill_{tokens.size}" if profile_name == "" else profile_name - time_in_s, decode_state = prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, - profile_name=profile_name) + time_in_s, decode_state = prefill_benchmark_loop( + config, engine, decode_state, params, tokens, true_length, iters, profile_name=profile_name + ) prefill_average_ms = 1000 * time_in_s / iters total_prefill_tflops, _, _ = maxtext_utils.calculate_tflops_prefill(num_model_params, tokens.size, config) - tflops_per_sec_per_device = total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000. - print(f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n" - f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n" - f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n") - result_dict = {"prefill_time_in_ms": prefill_average_ms, - "prefill_total_tflops": total_prefill_tflops, - "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device} + tflops_per_sec_per_device = total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000.0 + print( + f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n" + f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n" + f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n" + ) + result_dict = { + "prefill_time_in_ms": prefill_average_ms, + "prefill_total_tflops": total_prefill_tflops, + "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device, + } return result_dict, decode_state def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""): - """ Inner loop for benchmarking ar step. """ + """Inner loop for benchmarking ar step.""" max_utils.activate_profiler(config, profile_name) start = datetime.datetime.now() for _ in range(iters): @@ -96,9 +104,9 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name= def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100): - """ Handles init, warmup, running ar benchmark, and printing results. 
""" + """Handles init, warmup, running ar benchmark, and printing results.""" if cache_size is None: - _, cache_size, _ = summarize_pytree_data(decode_state['cache'], name="Cache") + _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache") if model_size is None: _, model_size, _ = summarize_pytree_data(params, name="Params") global_batch_size = jax.device_count() * config.per_device_batch_size @@ -112,38 +120,41 @@ def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_si profile_name = "autoregress" if profile_name == "" else profile_name time_in_s, decode_state = ar_benchmark_loop(config, engine, decode_state, params, profile_name=profile_name, iters=iters) seconds_per_step = time_in_s / iters - ar_average_ms = seconds_per_step*1000 + ar_average_ms = seconds_per_step * 1000 total_throughput = jax.device_count() * config.per_device_batch_size / seconds_per_step GB_per_step_per_device = (model_size + cache_size) / 1e9 / jax.device_count() - bw_per_device = GB_per_step_per_device/seconds_per_step - print(f"AutoRegressive results:\n" - f"\tAR step average time: {ar_average_ms:.3f}ms\n" - f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n" - f"\tAR global batch size: {global_batch_size}\n" - f"\tAR throughput: {total_throughput:.3f} tokens/second\n" - f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n") - - - result_dict = {"ar_step_in_ms": ar_average_ms, - "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size, - "ar_global_batch_size": global_batch_size, - "ar_total_throughput_tokens_per_second": total_throughput, - "ar_device_bandwidth_GB_per_second": bw_per_device} + bw_per_device = GB_per_step_per_device / seconds_per_step + print( + f"AutoRegressive results:\n" + f"\tAR step average time: {ar_average_ms:.3f}ms\n" + f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n" + f"\tAR global batch size: {global_batch_size}\n" + f"\tAR throughput: {total_throughput:.3f} tokens/second\n" + f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n" + ) + + result_dict = { + "ar_step_in_ms": ar_average_ms, + "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size, + "ar_global_batch_size": global_batch_size, + "ar_total_throughput_tokens_per_second": total_throughput, + "ar_device_bandwidth_GB_per_second": bw_per_device, + } return result_dict, decode_state def collate_results(config, results, model_size, cache_size, num_model_params, incl_config=False): - """ Adds model/cache size info and optionally config info to results. 
""" + """Adds model/cache size info and optionally config info to results.""" results["sizes"] = { - "Model_size_in_GB": model_size / 1e9, - "cache_size_in_GB": cache_size / 1e9, - "model_params_in_billions": num_model_params / 1e9, + "Model_size_in_GB": model_size / 1e9, + "cache_size_in_GB": cache_size / 1e9, + "model_params_in_billions": num_model_params / 1e9, } if incl_config: results["config"] = {} for k, v in dict(config.get_keys()).items(): - results["config"][k] = str(v) if k == "dtype" else v # json fails with original dtype + results["config"][k] = str(v) if k == "dtype" else v # json fails with original dtype return results @@ -172,18 +183,25 @@ def main(config): vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) decode_state = engine.init_decode_state() - _, cache_size, _ = summarize_pytree_data(decode_state['cache'], name="Cache") + _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache") num_model_params, model_size, _ = summarize_pytree_data(params, name="Model") benchmark_results = {"Prefill": {}} benchmark_results["AutoRegressive"], decode_state = ar_benchmark( - config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size) + config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size + ) for prefill_length in prefill_lengths: - tokens, true_length = token_utils.tokenize_and_pad( - text, vocab, is_bos=True, prefill_lengths=[prefill_length]) + tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[prefill_length]) benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark( - config, engine, params, decode_state, tokens, true_length, - iters=benchmark_loop_iters, num_model_params=num_model_params) + config, + engine, + params, + decode_state, + tokens, + true_length, + iters=benchmark_loop_iters, + num_model_params=num_model_params, + ) results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params) write_results(results, filename="") diff --git a/MaxText/inference_scratch/analyze_sharegpt.py b/MaxText/inference_scratch/analyze_sharegpt.py index 47d8715cf..47847275d 100644 --- a/MaxText/inference_scratch/analyze_sharegpt.py +++ b/MaxText/inference_scratch/analyze_sharegpt.py @@ -20,13 +20,16 @@ MAX_INPUT_TOKENS = 1024 MAX_OUTPUT_TOKENS = 1024 + def next_power_of_2(x): - return 1 if x == 0 else 2**(x - 1).bit_length() + return 1 if x == 0 else 2 ** (x - 1).bit_length() + def tokens_in_input_str(s): - return_val = int(1.3 * len(s.split())) + return_val = int(1.3 * len(s.split())) return return_val + def get_prefill_and_generate_times(filename=""): if filename == "": return PREFILL_BUCKET_SIZE_TO_MS, SYSTEM_TIME_PER_DECODE_TOKEN_MS @@ -37,17 +40,18 @@ def get_prefill_and_generate_times(filename=""): for k, v in microbenchmark_results["Prefill"].items(): prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3) - return prefill_bucket_size_to_ms, microbenchmark_results['AutoRegressive']['ar_step_in_ms_per_seq'] + return prefill_bucket_size_to_ms, microbenchmark_results["AutoRegressive"]["ar_step_in_ms_per_seq"] + def get_conversations_from_file(filename, max_input_tokens, max_output_tokens): convo_token_numbers = [] - with open(filename, 'r') as f: + with open(filename, "r") as f: loaded_share_gpt = json.load(f) for example in loaded_share_gpt: - if len(example['conversations']) < 2: + if len(example["conversations"]) < 2: continue - 
num_input_tokens = tokens_in_input_str(example['conversations'][0]['value']) - num_output_tokens = tokens_in_input_str(example['conversations'][1]['value']) + num_input_tokens = tokens_in_input_str(example["conversations"][0]["value"]) + num_output_tokens = tokens_in_input_str(example["conversations"][1]["value"]) convo_token_numbers.append((num_input_tokens, num_output_tokens)) num_convos = len(convo_token_numbers) @@ -78,9 +82,11 @@ def compute_times(convos, prefill_bucket_size_to_ms, system_time_per_decode_toke total_generate_time_seconds = total_generate_system_ms / 1000 total_time_s = total_prefill_time_seconds + total_generate_time_seconds - print(f"\nTotal time {total_time_s:.3f} seconds: " - f"\n\tPrefill time: {total_prefill_time_seconds:.3f} seconds" - f"\n\tGenerate time: {total_generate_time_seconds:.3f} seconds") + print( + f"\nTotal time {total_time_s:.3f} seconds: " + f"\n\tPrefill time: {total_prefill_time_seconds:.3f} seconds" + f"\n\tGenerate time: {total_generate_time_seconds:.3f} seconds" + ) return total_time_s, total_prefill_time_seconds, total_generate_time_seconds @@ -92,11 +98,11 @@ def get_num_tokens_in_convos(convos): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('convo_file', type=str, - help='a json file containing conversations') - parser.add_argument('-t', '--mb_timing_file', type=str, default="", - help='a json file containing microbenchmark timing results') - parser.add_argument('-v', '--verbose', action="store_true") + parser.add_argument("convo_file", type=str, help="a json file containing conversations") + parser.add_argument( + "-t", "--mb_timing_file", type=str, default="", help="a json file containing microbenchmark timing results" + ) + parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() convos = get_conversations_from_file(args.convo_file, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS) @@ -104,5 +110,7 @@ def get_num_tokens_in_convos(convos): prefill_time_ms_buckets, generate_time_ms = get_prefill_and_generate_times(filename=args.mb_timing_file) total_time_seconds, _, _ = compute_times(convos, prefill_time_ms_buckets, generate_time_ms, args.verbose) - print(f"Output {total_output_tokens} tokens in {total_time_seconds:.3f} seconds " - f"= {total_output_tokens/total_time_seconds:.3f} out tok/s") + print( + f"Output {total_output_tokens} tokens in {total_time_seconds:.3f} seconds " + f"= {total_output_tokens/total_time_seconds:.3f} out tok/s" + ) diff --git a/MaxText/inference_utils.py b/MaxText/inference_utils.py index 786ecaae7..96c727c0c 100644 --- a/MaxText/inference_utils.py +++ b/MaxText/inference_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax import jax.numpy as jnp @@ -26,14 +26,15 @@ Inspired by an Google-internal implementation, Global Vision Transformer. """ + def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): - """ - logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], before logit - rng: rng key to use - algorithm: string representing supported algorithms - topk: restricting to topk logits before sampling - nucleus_topp: restricting to p probability mass before sampling - temperature: temperature parameter for scaling probability + """ + logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], before logit + rng: rng key to use + algorithm: string representing supported algorithms + topk: restricting to topk logits before sampling + nucleus_topp: restricting to p probability mass before sampling + temperature: temperature parameter for scaling probability """ if algorithm == "greedy": return jnp.argmax(logits, axis=-1) @@ -46,36 +47,29 @@ def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): else: raise ValueError(f"Sampling {algorithm=} not supported!") + def sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng): """Restrict sampling to the top logits with cumulative probability >= nucleus_topp. - - The nucleus sampling method is proposed in the paper `The Curious Case of - Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` - + + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + """ if nucleus_topp < 0: raise ValueError("Can't apply nucleus with parameter {nucleus_topp=} less zero") logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending - sorted_cum_probs = jnp.cumsum( - jax.nn.softmax(logits_sorted, axis=-1), axis=-1) # get cumsum probs - cutoff_index = jnp.sum( - sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True) # find cutoff index + sorted_cum_probs = jnp.cumsum(jax.nn.softmax(logits_sorted, axis=-1), axis=-1) # get cumsum probs + cutoff_index = jnp.sum(sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True) # find cutoff index cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) - logits = jnp.where(logits < cutoff_logit, - jnp.full_like(logits, NEG_INF), logits) + logits = jnp.where(logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits) return jax.random.categorical(rng, logits / temperature) + def sample_topk_logits(logits, topk, temperature, rng): - """ Restricting sampling to the best k logits. 
""" + """Restricting sampling to the best k logits.""" if topk <= 0: raise ValueError("Can't apply algorithm topk with parameter {topk=} less than or equal to zero") topk_logits, topk_idxs = jax.lax.top_k(logits, topk) - topk_token = jnp.expand_dims( - jax.random.categorical(rng, topk_logits/temperature).astype(jnp.int32), - axis=-1) - sampled_tokens = jnp.squeeze( - jnp.take_along_axis(topk_idxs, topk_token, axis=-1), - axis=-1).astype(jnp.int32) + topk_token = jnp.expand_dims(jax.random.categorical(rng, topk_logits / temperature).astype(jnp.int32), axis=-1) + sampled_tokens = jnp.squeeze(jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1).astype(jnp.int32) return sampled_tokens - - diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index 92c90854d..57b8f1e5e 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Input pipeline using Grain.""" @@ -30,31 +30,33 @@ import multihost_dataloading -def get_datasets( - config: ml_collections.ConfigDict -): + +def get_datasets(config: ml_collections.ConfigDict): """Load dataset from array_record files for using with grain""" data_dir = os.path.join(config.dataset_path, config.dataset_name) - train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)] + train_files = [data_dir + "/" + f for f in os.listdir(data_dir) if re.match(r".*train.*", f)] train_ds = grain.ArrayRecordDataSource(train_files) if config.eval_dataset_name: - eval_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(rf'.*{config.eval_split}.*', f)] + eval_files = [data_dir + "/" + f for f in os.listdir(data_dir) if re.match(rf".*{config.eval_split}.*", f)] eval_ds = grain.ArrayRecordDataSource(eval_files) else: eval_ds = train_ds return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, - global_mesh, - train_ds, eval_ds, - vocab_path: Optional[str] = None, - data_shuffle_seed = 0, - add_bos = True, - add_eos = True - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, + global_mesh, + train_ds, + eval_ds, + vocab_path: Optional[str] = None, + data_shuffle_seed=0, + add_bos=True, + add_eos=True, +): """Use grain to pre-process the dataset and return iterators""" # Set global batch size. global_batch_size_to_load = config.global_batch_size_to_load @@ -78,7 +80,8 @@ def preprocess_dataset(config: ml_collections.ConfigDict, num_epochs=1, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) eval_iter = preprocessing_pipeline( eval_ds, @@ -93,7 +96,8 @@ def preprocess_dataset(config: ml_collections.ConfigDict, shuffle=config.enable_data_shuffling, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) predict_iter = preprocessing_pipeline( eval_ds, @@ -108,45 +112,45 @@ def preprocess_dataset(config: ml_collections.ConfigDict, shuffle=config.enable_data_shuffling, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed) + data_shuffle_seed=data_shuffle_seed, + ) return train_iter, eval_iter, predict_iter + def preprocessing_pipeline( - dataset, - vocab_path, - add_bos: bool, - add_eos: bool, - grain_worker_count: int, - batch_size: int, - global_mesh, - dataloading_host_index, - dataloading_host_count, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - max_length: int = 512, - shift: bool = True, - drop_remainder: bool = True, - data_shuffle_seed = 0, + dataset, + vocab_path, + add_bos: bool, + add_eos: bool, + grain_worker_count: int, + batch_size: int, + global_mesh, + dataloading_host_index, + dataloading_host_count, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + data_shuffle_seed=0, ): """Apply grain operations to preprocess the given dataset.""" - assert ( - batch_size % global_mesh.size == 0 - ), 'Batch size should be divisible number of global devices.' + assert batch_size % global_mesh.size == 0, "Batch size should be divisible number of global devices." 
operations = [] operations.append(_grain_operations.ParseFeatures()) operations.append(_grain_operations.NormalizeFeatures()) - operations.append(_grain_tokenizer.TokenizeAndTrim(["inputs","targets"], - max_length, vocab_path, - add_bos, add_eos)) + operations.append(_grain_tokenizer.TokenizeAndTrim(["inputs", "targets"], max_length, vocab_path, add_bos, add_eos)) # Pack and Batch examples. if pack_examples: - operations.append(grain.experimental.PackAndBatchOperation( - batch_size=batch_size // jax.process_count(), - length_struct={'inputs':max_length,'targets':max_length})) + operations.append( + grain.experimental.PackAndBatchOperation( + batch_size=batch_size // jax.process_count(), length_struct={"inputs": max_length, "targets": max_length} + ) + ) operations.append(_grain_operations.ReformatPacking()) else: operations.append(_grain_operations.PadToMaxLength(max_length)) @@ -157,19 +161,19 @@ def preprocessing_pipeline( operations.append(_grain_operations.ShiftData(axis=1)) index_sampler = grain.IndexSampler( - num_records=len(dataset), - num_epochs = num_epochs, - shard_options=grain.ShardOptions( - shard_index = dataloading_host_index, shard_count = dataloading_host_count, drop_remainder = True - ), - shuffle = shuffle, - seed = data_shuffle_seed + num_records=len(dataset), + num_epochs=num_epochs, + shard_options=grain.ShardOptions( + shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=True + ), + shuffle=shuffle, + seed=data_shuffle_seed, ) dataloader = grain.DataLoader( - data_source = dataset, - operations = operations, - sampler = index_sampler, + data_source=dataset, + operations=operations, + sampler=index_sampler, worker_count=grain_worker_count, ) diff --git a/MaxText/input_pipeline/_grain_operations.py b/MaxText/input_pipeline/_grain_operations.py index 25508376b..685165381 100644 --- a/MaxText/input_pipeline/_grain_operations.py +++ b/MaxText/input_pipeline/_grain_operations.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Operations used by Grain""" @@ -21,60 +21,64 @@ import grain.python as grain import numpy as np import tensorflow as tf + Features = Dict[str, tf.Tensor] + @dataclasses.dataclass class ParseFeatures(grain.MapTransform): """Parse serialized example""" + def map(self, features): def _parse(example): - parsed = tf.io.parse_example( - example, { - 'text': tf.io.FixedLenFeature(shape=(), dtype=tf.string) - }) + parsed = tf.io.parse_example(example, {"text": tf.io.FixedLenFeature(shape=(), dtype=tf.string)}) return parsed + return _parse(features) @dataclasses.dataclass class NormalizeFeatures(grain.MapTransform): """Normalize text feature keys.""" + def map(self, features): - return { - 'inputs':features['text'].numpy().decode(), - 'targets': features['text'].numpy().decode() - } + return {"inputs": features["text"].numpy().decode(), "targets": features["text"].numpy().decode()} @dataclasses.dataclass class ReformatPacking(grain.MapTransform): """Reformat packing outputs.""" + def map(self, data): - return{ - 'inputs':data[0]['inputs'], - 'targets':data[0]['targets'], - 'inputs_segmentation':data[1]['inputs'], - 'targets_segmentation':data[1]['targets'], - 'inputs_position':data[2]['inputs'], - 'targets_position':data[2]['targets'], + return { + "inputs": data[0]["inputs"], + "targets": data[0]["targets"], + "inputs_segmentation": data[1]["inputs"], + "targets_segmentation": data[1]["targets"], + "inputs_position": data[2]["inputs"], + "targets_position": data[2]["targets"], } @dataclasses.dataclass class PadToMaxLength(grain.MapTransform): - """Pads each input to the specified length""" + """Pads each input to the specified length""" + def __init__(self, max_length): self.max_length = max_length + def map(self, data): """map to each element""" + def _pad(x, max_length): pad_amount = max(max_length - x.shape[0], 0) pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1) return np.pad(x, pad_amount) - data['inputs_segmentation'] = np.ones(data['inputs'].shape, dtype = np.int32) - data['inputs_position'] = np.arange(data['inputs'].shape[0], dtype = np.int32) - data['targets_segmentation'] = np.ones(data['targets'].shape, dtype = np.int32) - data['targets_position'] = np.arange(data['targets'].shape[0], dtype = np.int32) + + data["inputs_segmentation"] = np.ones(data["inputs"].shape, dtype=np.int32) + data["inputs_position"] = np.arange(data["inputs"].shape[0], dtype=np.int32) + data["targets_segmentation"] = np.ones(data["targets"].shape, dtype=np.int32) + data["targets_position"] = np.arange(data["targets"].shape[0], dtype=np.int32) for key, _ in data.items(): data[key] = _pad(data[key], self.max_length) return data @@ -84,33 +88,34 @@ def shift_right(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - slices = [slice(None),] * len(x.shape) + slices = [ + slice(None), + ] * len(x.shape) slices[axis] = slice(0, -1) - padded = np.pad( - x, - pad_widths, - mode='constant', - constant_values=x.dtype.type(0) - ) + padded = np.pad(x, pad_widths, mode="constant", constant_values=x.dtype.type(0)) return padded[tuple(slices)] + def shift_and_refine(x, axis=1): """Shift inputs, set segmentation to 0 when target element is 0. 
Replace EOS by 0 for packed inputs.""" - x['inputs'] = shift_right(x['inputs'], axis=axis) - targets_nonzero = x['targets'] != 0 - x['inputs_segmentation'] *= targets_nonzero - x['targets_segmentation'] *= targets_nonzero + x["inputs"] = shift_right(x["inputs"], axis=axis) + targets_nonzero = x["targets"] != 0 + x["inputs_segmentation"] *= targets_nonzero + x["targets_segmentation"] *= targets_nonzero # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. - x['inputs'] *= x['inputs_segmentation'] == shift_right(x['inputs_segmentation'], axis=axis) + x["inputs"] *= x["inputs_segmentation"] == shift_right(x["inputs_segmentation"], axis=axis) return x + @dataclasses.dataclass class ShiftData(grain.MapTransform): """Shift inputs and refine annotations.""" - def __init__(self, axis = 1): + + def __init__(self, axis=1): self.axis = axis + def map(self, data): return shift_and_refine(data, axis=self.axis) diff --git a/MaxText/input_pipeline/_grain_tokenizer.py b/MaxText/input_pipeline/_grain_tokenizer.py index e9436799a..9a79af448 100644 --- a/MaxText/input_pipeline/_grain_tokenizer.py +++ b/MaxText/input_pipeline/_grain_tokenizer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Tokenize Op used by Grain""" @@ -24,9 +24,11 @@ import grain.python as grain import numpy as np + @dataclasses.dataclass class TokenizeAndTrim(grain.MapTransform): """Tokenize and trim features to sequence length.""" + # pylint: disable=attribute-defined-outside-init feature_names: str | Sequence[str] sequence_length: int | Sequence[int] @@ -49,16 +51,14 @@ def map(self, features: dict[str, Any]) -> dict[str, Any]: if self._processor is None: # Ensures only one thread initializes SPP. 
self._processor = SentencePieceProcessor() self._processor.Load(self.model_path) - for feature_name, sequence_length in zip( - self.feature_names, self.sequence_length, strict=True - ): + for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True): text = features[feature_name] token_ids = self._processor.EncodeAsIds(text) if self.add_bos: token_ids = [self._processor.bos_id()] + token_ids if self.add_eos: - token_ids = token_ids[:sequence_length-1] + token_ids = token_ids[: sequence_length - 1] token_ids = token_ids + [self._processor.eos_id()] else: token_ids = token_ids[:sequence_length] diff --git a/MaxText/input_pipeline/_tfds_data_processing.py b/MaxText/input_pipeline/_tfds_data_processing.py index 506e7bd87..865278831 100644 --- a/MaxText/input_pipeline/_tfds_data_processing.py +++ b/MaxText/input_pipeline/_tfds_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline for a LM1B dataset.""" @@ -34,17 +34,16 @@ # Right-shifting token inputs for teacher-forced training. # ----------------------------------------------------------------------------- + def shift_right_tf(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - slices = [slice(None),] * len(x.shape) + slices = [ + slice(None), + ] * len(x.shape) slices[axis] = slice(0, -1) - padded = tf.pad( - x, - tf.constant(pad_widths), - mode='constant', - constant_values=tf.constant(0, x.dtype)) + padded = tf.pad(x, tf.constant(pad_widths), mode="constant", constant_values=tf.constant(0, x.dtype)) return padded[tuple(slices)] @@ -54,46 +53,45 @@ def shift_inputs_tf(x, segment_ids=None, axis=1): # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. 
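# A small worked example of the masking below, with hypothetical packed values:
#
#   x           = [7, 8, 9, 3, 4],  segment_ids = [1, 1, 1, 2, 2]
#   shifted x   = [0, 7, 8, 9, 3],  shifted ids = [0, 1, 1, 1, 2]
#   equality mask = [0, 1, 1, 0, 1], so the result [0, 7, 8, 0, 3] zeroes
#   position 3, where packed sequence 2 would otherwise start with
#   sequence 1's last token (9).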
if segment_ids is not None: - shifted *= tf.cast( - segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype - ) + shifted *= tf.cast(segment_ids == shift_right_tf(segment_ids, axis=axis), x.dtype) return shifted + def shift_data(x, axis=0, segmented=True): - segment_ids = x['inputs_segmentation'] if segmented else None - x['inputs'] = shift_inputs_tf(x['inputs'], segment_ids=segment_ids, axis=axis) + segment_ids = x["inputs_segmentation"] if segmented else None + x["inputs"] = shift_inputs_tf(x["inputs"], segment_ids=segment_ids, axis=axis) return x + def shift_data_by_truncation(x): - x['inputs'] = x['inputs'][:-1] - x['targets'] = x['targets'][1:] + x["inputs"] = x["inputs"][:-1] + x["targets"] = x["targets"][1:] return x def normalize_features(ds): """Normalize text feature keys.""" + def _normalize_features(features): - features['inputs'] = features.pop('text') - features['targets'] = features['inputs'] + features["inputs"] = features.pop("text") + features["targets"] = features["inputs"] return features - return ds.map( - _normalize_features, - num_parallel_calls=AUTOTUNE) + return ds.map(_normalize_features, num_parallel_calls=AUTOTUNE) + def length_trim(ds, max_len): - """"Trim to Max length""" + """ "Trim to Max length""" + def _trim_fn(features): - if tf.shape(features['inputs'])[0] > max_len: - features['inputs'] = features['inputs'][:max_len] - if tf.shape(features['targets'])[0] > max_len: - features['targets'] = features['targets'][:max_len] + if tf.shape(features["inputs"])[0] > max_len: + features["inputs"] = features["inputs"][:max_len] + if tf.shape(features["targets"])[0] > max_len: + features["targets"] = features["targets"][:max_len] return features - return ds.map( - _trim_fn, - num_parallel_calls=AUTOTUNE - ) + return ds.map(_trim_fn, num_parallel_calls=AUTOTUNE) + # ----------------------------------------------------------------------------- # Main dataset preparation. @@ -101,52 +99,45 @@ def _trim_fn(features): def preprocessing_pipeline( - dataset, - batch_size: int, - global_mesh, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - shift: bool = True, - drop_remainder: bool = True, - prefetch_size = tf.data.experimental.AUTOTUNE, - data_shuffle_seed = 0, + dataset, + batch_size: int, + global_mesh, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + prefetch_size=tf.data.experimental.AUTOTUNE, + data_shuffle_seed=0, ): """Shuffle and batch/pack the given dataset.""" def truncate_to_max_allowable_length(x, max_length): - x['inputs'] = x['inputs'][:max_length] - x['targets'] = x['targets'][:max_length] + x["inputs"] = x["inputs"][:max_length] + x["targets"] = x["targets"][:max_length] return x - if max_length > 0: # We can take upto max_length+1 because there would be truncation by 1 token # for both inputs and targets - dataset = dataset.map(lambda x: truncate_to_max_allowable_length(x, max_length+1)) + dataset = dataset.map(lambda x: truncate_to_max_allowable_length(x, max_length + 1)) # Shuffle and repeat. 
if shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, seed = data_shuffle_seed) + dataset = dataset.shuffle(shuffle_buffer_size, seed=data_shuffle_seed) dataset = dataset.repeat(num_epochs) - # Shift inputs for teacher-forced training if shift: - dataset = dataset.map( - shift_data_by_truncation, - num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) + dataset = dataset.map(shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True) # Perform greedy sequence packing if pack_examples: dataset = sequence_packing.pack_dataset(dataset, max_length) - assert ( - batch_size % global_mesh.size == 0 - ), 'Batch size should be divisible number of global devices.' + assert batch_size % global_mesh.size == 0, "Batch size should be divisible number of global devices." # Batch examples. if pack_examples: @@ -155,9 +146,10 @@ def truncate_to_max_allowable_length(x, max_length): # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size // jax.process_count(), - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=drop_remainder) + padded_shapes={"inputs": max_length, "targets": max_length}, + padding_values={"inputs": 0, "targets": 0}, + drop_remainder=drop_remainder, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) @@ -169,20 +161,18 @@ def truncate_to_max_allowable_length(x, max_length): def get_datasets( - config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, - read_config = None, + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, + read_config=None, ): """Load and return dataset of batched examples for use during training.""" # Training dataset. train_ds_builder = tfds.builder(config.dataset_name) # train_data = get_raw_dataset(train_ds_builder, 'train') - train_ds = train_ds_builder.as_dataset(split='train', - read_config = read_config, - shuffle_files=config.enable_data_shuffling) + train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=config.enable_data_shuffling) # shard the dataset as soon as it is loaded - train_ds = train_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) + train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) train_ds = normalize_features(train_ds) # Evaluation dataset. @@ -191,25 +181,25 @@ def get_datasets( else: eval_ds_builder = train_ds_builder # eval_data = get_raw_dataset(eval_ds_builder, config.eval_split) - eval_ds = eval_ds_builder.as_dataset(split=config.eval_split, - read_config = read_config, - shuffle_files=False) - eval_ds = eval_ds.shard(num_shards = jax.process_count(), index = jax.process_index()) + eval_ds = eval_ds_builder.as_dataset(split=config.eval_split, read_config=read_config, shuffle_files=False) + eval_ds = eval_ds.shard(num_shards=jax.process_count(), index=jax.process_index()) eval_ds = normalize_features(eval_ds) return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - global_mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed = 0, - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + global_mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=0, +): """Pre-process the dataset and return iterators""" # Tokenize data. 
- train_ds = train_ds.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) - eval_ds = eval_ds.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + train_ds = train_ds.map(tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + eval_ds = eval_ds.map(tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) # Set global batch size. global_batch_size_to_load = config.global_batch_size_to_load @@ -220,9 +210,10 @@ def preprocess_dataset(config: ml_collections.ConfigDict, eval_batch_size = global_batch_size_to_load def filter_keys(record): - return {'inputs': record['inputs'], 'targets': record['targets']} - train_ds = train_ds.map(filter_keys,num_parallel_calls=tf.data.AUTOTUNE) - eval_ds = eval_ds.map(filter_keys,num_parallel_calls=tf.data.AUTOTUNE) + return {"inputs": record["inputs"], "targets": record["targets"]} + + train_ds = train_ds.map(filter_keys, num_parallel_calls=tf.data.AUTOTUNE) + eval_ds = eval_ds.map(filter_keys, num_parallel_calls=tf.data.AUTOTUNE) train_iter = preprocessing_pipeline( train_ds, @@ -233,7 +224,8 @@ def filter_keys(record): pack_examples=True, max_length=config.max_target_length, shift=True, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) eval_iter = preprocessing_pipeline( eval_ds, @@ -244,7 +236,8 @@ def filter_keys(record): max_length=config.max_target_length, shift=False, drop_remainder=False, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) predict_iter = preprocessing_pipeline( eval_ds, @@ -255,6 +248,7 @@ def filter_keys(record): max_length=config.max_target_length, shift=False, drop_remainder=False, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) return train_iter, eval_iter, predict_iter diff --git a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index c5a6d8d06..0ca44bdb1 100644 --- a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Input pipeline for gpt3 c4 mlperf dataset.""" @@ -35,6 +35,7 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE + # data processing functions: # _shift_left_and_pad, rekey, reduce_concat_tokens and split_tokens_to_targets_length # Adapted from: @@ -58,8 +59,10 @@ def _shift_left_and_pad(tensor, pad_val): v = v[0] return v + def rekey(ds, key_map=None): """normalization with key mapping""" + def _rekey(x, key_map=None): """Replace the feature keys according to the mapping in `key_map`. For example, if the dataset returns examples of the format: @@ -75,20 +78,17 @@ def _rekey(x, key_map=None): A preprocessed example with the format listed above. """ if key_map: - return { - new_key: x[old_key] - for new_key, old_key in key_map.items() if old_key - } + return {new_key: x[old_key] for new_key, old_key in key_map.items() if old_key} return x - return ds.map( - functools.partial(_rekey, key_map=key_map), - num_parallel_calls=AUTOTUNE) + return ds.map(functools.partial(_rekey, key_map=key_map), num_parallel_calls=AUTOTUNE) + -def reduce_concat_tokens(dataset, - feature_key='targets', - batch_size=128, - ): +def reduce_concat_tokens( + dataset, + feature_key="targets", + batch_size=128, +): """Token-preprocessor to concatenate multiple unrelated documents. If we want to generate examples of exactly the right length, (to avoid wasting space on padding), then we use this function, folowed by @@ -100,9 +100,9 @@ def reduce_concat_tokens(dataset, Returns: a dataset """ - dataset = dataset.map( - lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE) + dataset = dataset.map(lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE) dataset = dataset.padded_batch(batch_size, padded_shapes={feature_key: [-1]}) + def _my_fn(x): tokens = tf.reshape(x[feature_key], [-1]) # strip padding @@ -111,10 +111,12 @@ def _my_fn(x): return dataset.map(_my_fn, num_parallel_calls=AUTOTUNE) -def split_tokens(dataset, - max_tokens_per_segment=128, - feature_key='targets', - ): + +def split_tokens( + dataset, + max_tokens_per_segment=128, + feature_key="targets", +): """Split examples into multiple examples each. The intended use case is to break up long examples for use in unsupervised transfer-learning. @@ -127,6 +129,7 @@ def split_tokens(dataset, Returns: a dataset """ + def _split_tokens(x): """Split one token sequence into multiple multiple.""" tokens = x[feature_key] @@ -135,9 +138,7 @@ def _split_tokens(x): # Pad to a multiple of length, then use tf.reshape to split up the tokens # into num_segments segments each of the given length. 
- num_segments = tf.cast( - tf.math.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)), - tf.int32) + num_segments = tf.cast(tf.math.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)), tf.int32) padding = num_segments * length - tf.size(tokens) tokens = tf.pad(tokens, [[0, padding]]) return tf.reshape(tokens, [-1, length]) @@ -149,19 +150,25 @@ def _strip_padding(x): dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0)) dataset = dataset.map(_split_tokens, num_parallel_calls=AUTOTUNE) dataset = dataset.unbatch() - return dataset.map( - _strip_padding, num_parallel_calls=AUTOTUNE) + return dataset.map(_strip_padding, num_parallel_calls=AUTOTUNE) + def split_tokens_to_targets_length(dataset, sequence_length): return split_tokens(dataset, max_tokens_per_segment=sequence_length) -def _pad_to_batch_size(ds: tf.data.Dataset, batch_size: int, num_examples: Optional[int] = None,) -> tf.data.Dataset: + +def _pad_to_batch_size( + ds: tf.data.Dataset, + batch_size: int, + num_examples: Optional[int] = None, +) -> tf.data.Dataset: """Pad unevenly distributed eval data in each shard with new entries to multiples of batch size.""" # local_num represents the total number of examples in eval dataset, if num_examples: local_num = num_examples else: + def _get_num_examples(ds: tf.data.Dataset) -> int: # Iterate one-by-one instead of len(list(...)) to reduce peak memory. num_examples = 0 @@ -173,61 +180,66 @@ def _get_num_examples(ds: tf.data.Dataset) -> int: local_num = _get_num_examples(ds) local_num_batches = (local_num + batch_size - 1) // batch_size # Find the max number of batches required across all Jax processes. - num_batches_all = multihost_utils.process_allgather( - jnp.array([local_num_batches]), tiled=False) + num_batches_all = multihost_utils.process_allgather(jnp.array([local_num_batches]), tiled=False) num_batches = np.max(num_batches_all) pad_num = num_batches * batch_size - local_num assert pad_num >= 0 print( - f'Eval data has {local_num} local entries, padding now with ' - f'{pad_num} extra entries to get {num_batches} batches.') + f"Eval data has {local_num} local entries, padding now with " f"{pad_num} extra entries to get {num_batches} batches." + ) + # Repeat a random example to make the last batch full. def _add_pad(x): - x['targets_segmentation'] *= 0 + x["targets_segmentation"] *= 0 return x + pad_ds = ds.take(1).map(_add_pad).repeat(pad_num) return ds.concatenate(pad_ds) + def get_datasets( - config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, ): """Load and return dataset of batched examples for use during training.""" # Training dataset. 
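# Minimal numeric sketch of the padding math in _pad_to_batch_size above
# (example values assumed): each host pads its eval shard up to the maximum
# batch count across hosts, so all processes step through eval in lockstep.
batch_size = 8
local_num = 21  # eval entries on this host
local_num_batches = (local_num + batch_size - 1) // batch_size  # ceil -> 3
num_batches = max((3, 4, 2))  # allgathered per-host batch counts -> 4
pad_num = num_batches * batch_size - local_num  # 4 * 8 - 21 = 11 padded entries
assert pad_num >= 0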
read_config = tfds.ReadConfig( - shuffle_seed = config.data_shuffle_seed, - ) + shuffle_seed=config.data_shuffle_seed, + ) train_ds_builder = tfds.builder(config.dataset_name) - train_ds = train_ds_builder.as_dataset(split='train2', read_config=read_config, shuffle_files=config.enable_data_shuffling) + train_ds = train_ds_builder.as_dataset(split="train2", read_config=read_config, shuffle_files=config.enable_data_shuffling) eval_ds_builder = tfds.builder(config.eval_dataset_name) - eval_ds = eval_ds_builder.as_dataset(split='validation_tokenized_5662seqs', read_config=read_config, shuffle_files=False) + eval_ds = eval_ds_builder.as_dataset(split="validation_tokenized_5662seqs", read_config=read_config, shuffle_files=False) # shard the dataset as soon as it is loaded - train_ds = train_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) - train_ds = rekey(train_ds, {'inputs': None, 'targets': 'text'}) + train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) + train_ds = rekey(train_ds, {"inputs": None, "targets": "text"}) - eval_ds = eval_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) + eval_ds = eval_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index) # note validation_tokenized_5662seqs split is pre tokenized, reduce_concated and splitted to target_length # mainly to avoid eval sequences change depending on the number of hosts - eval_ds = rekey(eval_ds, {'inputs': None, 'targets': 'ids'}) + eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"}) return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - global_mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed: int = 0, - shuffle_buffer_size: int = 128, - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + global_mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed: int = 0, + shuffle_buffer_size: int = 128, +): """Pre-process the dataset and return iterators for mlperf training.""" # tokenize - train_ds = train_ds.map( - tokenizer.TokenizeOp(sp_tokenizer, data_keys=('targets',)), num_parallel_calls=AUTOTUNE) + train_ds = train_ds.map(tokenizer.TokenizeOp(sp_tokenizer, data_keys=("targets",)), num_parallel_calls=AUTOTUNE) - train_ds = reduce_concat_tokens(train_ds, feature_key='targets', batch_size=4096) + train_ds = reduce_concat_tokens(train_ds, feature_key="targets", batch_size=4096) train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length) train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed) @@ -241,8 +253,8 @@ def format_fn(x, eos_id: int = 1, pad_id: int = 0): x["inputs_position"] = x["targets_position"] x["targets"] = _shift_left_and_pad(x["targets"], eos_id) x["inputs_segmentation"] = tf.where( - tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), - x["targets_segmentation"], 0) + tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), x["targets_segmentation"], 0 + ) x["targets_segmentation"] = x["inputs_segmentation"] return x diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py index 58227feee..37ecd8f4a 100644 --- a/MaxText/input_pipeline/input_pipeline_interface.py +++ b/MaxText/input_pipeline/input_pipeline_interface.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the 
License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline""" @@ -26,79 +26,80 @@ from input_pipeline import _tfds_data_processing_c4_mlperf import tokenizer + def get_tokenizer(tokenizer_path, add_bos=True, add_eos=True): # Load tokenizer - sp_tokenizer = tokenizer.load_tokenizer(tokenizer_path=tokenizer_path, - add_bos=add_bos, - add_eos=add_eos) + sp_tokenizer = tokenizer.load_tokenizer(tokenizer_path=tokenizer_path, add_bos=add_bos, add_eos=add_eos) return sp_tokenizer + def make_c4_mlperf_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for customized C4 dataset for mlperf gpt3 training.""" + """Make train iterator and tokenizer for customized C4 dataset for mlperf gpt3 training.""" train_ds, eval_ds = _tfds_data_processing_c4_mlperf.get_datasets( - config=config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), + config=config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), ) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, eval_iter = _tfds_data_processing_c4_mlperf.preprocess_dataset( - config, - mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed=config.data_shuffle_seed + config, mesh, train_ds, eval_ds, sp_tokenizer, data_shuffle_seed=config.data_shuffle_seed ) return train_iter, eval_iter, sp_tokenizer + def make_c4_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for C4 dataset""" + """Make train iterator and tokenizer for C4 dataset""" read_config = tfds.ReadConfig( - shuffle_seed = config.data_shuffle_seed, + shuffle_seed=config.data_shuffle_seed, ) train_ds, eval_ds = _tfds_data_processing.get_datasets( - config=config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - read_config = read_config, + config=config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + read_config=read_config, ) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, _, _ = _tfds_data_processing.preprocess_dataset( - config, - mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed = config.data_shuffle_seed, + config, + mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=config.data_shuffle_seed, ) return train_iter, None, sp_tokenizer + def make_grain_train_iterator_and_tokenizer(config, mesh, add_bos, 
add_eos, process_indices):
-  """ Make train iterator and tokenizer for C4 dataset"""
-  train_ds, eval_ds = _grain_data_processing.get_datasets(
-    config=config
-  )
+  """Make train iterator and tokenizer for C4 dataset"""
+  train_ds, eval_ds = _grain_data_processing.get_datasets(config=config)
   sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos)
   train_iter, _, _ = _grain_data_processing.preprocess_dataset(
-    config,
-    dataloading_host_index = process_indices.index(jax.process_index()),
-    dataloading_host_count = len(process_indices),
-    global_mesh = mesh,
-    train_ds = train_ds, eval_ds = eval_ds,
-    vocab_path=config.tokenizer_path,
-    data_shuffle_seed = config.data_shuffle_seed,
-    add_bos = add_bos,
-    add_eos = add_eos
+      config,
+      dataloading_host_index=process_indices.index(jax.process_index()),
+      dataloading_host_count=len(process_indices),
+      global_mesh=mesh,
+      train_ds=train_ds,
+      eval_ds=eval_ds,
+      vocab_path=config.tokenizer_path,
+      data_shuffle_seed=config.data_shuffle_seed,
+      add_bos=add_bos,
+      add_eos=add_eos,
   )
   return train_iter, None, sp_tokenizer
 
-class SyntheticDataIterator():
+
+class SyntheticDataIterator:
   """Creates a synthetic data iterator for performance testing work"""
+
   def __init__(self, config, mesh):
     self.mesh = mesh
     self.config = config
     data_pspec = P(*config.data_sharding)
-    data_pspec_shardings = jax.tree_map(
-        lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
-    self.data_generator = jax.jit(SyntheticDataIterator.raw_generate_synthetic_data,
-                                  out_shardings=data_pspec_shardings,
-                                  static_argnums=0)
+    data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+    self.data_generator = jax.jit(
+        SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
+    )
 
   def __iter__(self):
     return self
@@ -111,31 +112,34 @@ def __next__(self):
   def raw_generate_synthetic_data(config):
     """Generates a single batch of synthetic data"""
     output = {}
-    output['inputs'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length),
-                                       dtype=jax.numpy.int32)
-    output['inputs_position'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length),
-                                                dtype=jax.numpy.int32)
-    output['inputs_segmentation'] = jax.numpy.ones( (config.global_batch_size_to_load, config.max_target_length),
-                                                   dtype=jax.numpy.int32)
-    output['targets'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length),
-                                        dtype=jax.numpy.int32)
-    output['targets_position'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length),
-                                                 dtype=jax.numpy.int32)
-    output['targets_segmentation'] = jax.numpy.ones( (config.global_batch_size_to_load, config.max_target_length),
-                                                    dtype=jax.numpy.int32)
+    output["inputs"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
+    output["inputs_position"] = jax.numpy.zeros(
+        (config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
+    )
+    output["inputs_segmentation"] = jax.numpy.ones(
+        (config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
+    )
+    output["targets"] = jax.numpy.zeros((config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32)
+    output["targets_position"] = jax.numpy.zeros(
+        (config.global_batch_size_to_load, config.max_target_length), dtype=jax.numpy.int32
+    )
+    output["targets_segmentation"] = jax.numpy.ones(
+        (config.global_batch_size_to_load,
config.max_target_length), dtype=jax.numpy.int32 + ) return output -class BadSyntheticDataIterator(): + +class BadSyntheticDataIterator: """Creates a Bad synthetic data iterator for loading on subset of hosts""" + def __init__(self, config, mesh): self.mesh = mesh self.config = config data_pspec = P(*config.data_sharding) - data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - self.data_generator = jax.jit(BadSyntheticDataIterator.get_bad_synthetic_data, - out_shardings=data_pspec_shardings, - static_argnums=0) + data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + self.data_generator = jax.jit( + BadSyntheticDataIterator.get_bad_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0 + ) def __iter__(self): return self @@ -146,25 +150,31 @@ def __next__(self): @staticmethod def get_bad_synthetic_data(config): - """fill negative value in synthetic data """ + """fill negative value in synthetic data""" output = {} - output['inputs'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['inputs_position'] = jax.numpy.full((config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['inputs_segmentation'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets_position'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets_segmentation'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) + output["inputs"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["inputs_position"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["inputs_segmentation"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["targets"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["targets_position"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) + output["targets_segmentation"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32 + ) return output + def get_process_loading_real_data(config, mesh): - """ Get list of processes loading data from GCS when expansion_factor_real_data != -1 - """ + """Get list of processes loading data from GCS when expansion_factor_real_data != -1""" sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) devices_indices_map = sharding.devices_indices_map((config.global_batch_size_to_load, config.max_target_length)) batch_cutoff = config.global_batch_size_to_train_on @@ -174,10 +184,11 @@ def get_process_loading_real_data(config, mesh): process_loading_real_data.add(p.process_index) return list(process_loading_real_data) + def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos): process_indices = get_process_loading_real_data(config, mesh) - print(len(process_indices),"hosts out of",jax.process_count(),"are loading real data") - if 
config.expansion_factor_real_data != -1: # assert number of hosts loading real data + print(len(process_indices), "hosts out of", jax.process_count(), "are loading real data") + if config.expansion_factor_real_data != -1: # assert number of hosts loading real data assert len(process_indices) == jax.process_count() // config.expansion_factor_real_data if jax.process_index() in process_indices: if config.dataset_type == "c4": @@ -186,11 +197,14 @@ def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos): return make_grain_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices) elif config.dataset_type == "c4_mlperf": print("Overwrite both add_bos and add_eos to False") - return make_c4_mlperf_train_iterator_and_tokenizer(config, mesh, add_bos=False, add_eos=False, process_indices = process_indices) + return make_c4_mlperf_train_iterator_and_tokenizer( + config, mesh, add_bos=False, add_eos=False, process_indices=process_indices + ) else: return BadSyntheticDataIterator(config, mesh), None, get_tokenizer(config.tokenizer_path, add_bos, add_eos) -def create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos = True): + +def create_data_iterator_with_tokenizer(config, mesh, add_bos=True, add_eos=True): if config.dataset_type == "synthetic": return SyntheticDataIterator(config, mesh), None, get_tokenizer(config.tokenizer_path, add_bos, add_eos) elif config.dataset_type in ("c4", "c4-array_record", "c4_mlperf"): @@ -198,15 +212,16 @@ def create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos = else: assert False, "dataset type not implemented" + def get_shaped_batch(config): - """ Return the shape of the batch - this is what eval_shape would return for the + """Return the shape of the batch - this is what eval_shape would return for the output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078.""" batch_shape = (config.global_batch_size_to_load, config.max_target_length) shaped_batch = {} - shaped_batch['inputs'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['inputs_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['inputs_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) return shaped_batch diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 8ea5e7bd0..b1c0803d7 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -58,8 +58,7 @@ nd_dense_init = initializers.nd_dense_init shard_map = shard_map.shard_map -dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) +dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, 
g-inconsistent-quotes # pytype: disable=attribute-error @@ -89,10 +88,12 @@ def apply_mask_to_logits(logits: Array, mask: Array): """ return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) + def _maybe_aqt_einsum(quant: Quant): """Maybe overwrite dot general with aqt_dot_general.""" return jnp.einsum if quant is None else quant.einsum() + class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str @@ -100,44 +101,33 @@ class AttentionOp(nn.Module): num_query_heads: int num_kv_heads: int float32_qk_product: bool = False - max_prefill_predict_length: int = -1 + max_prefill_predict_length: int = -1 float32_logits: bool = False flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) - dropout_rate: float = 0. + dropout_rate: float = 0.0 dtype: DType = jnp.float32 quant: Optional[Quant] = None quantize_kvcache: bool = False - def check_attention_inputs( - self, - query: Array, - key: Array, - value: Array) -> None: + def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: """Check attention inputs.""" - assert key.ndim == value.ndim, 'k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert key.shape[-2] == value.shape[-2], ('k, v num_kv_heads must match.') - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + assert key.ndim == value.ndim, "k, v must have same rank." + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." + assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." + assert key.shape[-3] == value.shape[-3], "k, v lengths must match." + assert query.shape[-1] == key.shape[-1], "q, k depths must match." # Following Pallas MHA Flash Attention Reference. 
# https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py # This mask models (1) separate sequences (decoder_segment_ids) and (2) causality - def generate_attention_mask( - self, - query, - key, - decoder_segment_ids: Array | None, - model_mode: str - ) -> Array | None: + def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, model_mode: str) -> Array | None: mask = None if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: mask = decoder_segment_ids[:, None, None, None, :] == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR elif decoder_segment_ids is not None: mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :] - mask = mask[:, None, None,:, :] + mask = mask[:, None, None, :, :] causal_mask = None # We enforce causality except for AUTOREGRESSION @@ -160,53 +150,43 @@ def generate_attention_mask( return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None - def apply_attention(self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None, - model_mode: str): + def apply_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, model_mode: str): self.check_attention_inputs(query, key, value) length = query.shape[-3] - if self.attention_kernel == 'dot_product' or\ - (self.attention_kernel == 'autoselected' and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) or\ - (self.attention_kernel == 'autoselected' and length < 128): + if ( + self.attention_kernel == "dot_product" + or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) + or (self.attention_kernel == "autoselected" and length < 128) + ): return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode) - elif self.attention_kernel == 'flash' or\ - self.attention_kernel == 'autoselected': + elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected": if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") + raise ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None - elif self.attention_kernel == 'cudnn_flash_te': + elif self.attention_kernel == "cudnn_flash_te": if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") + raise ValueError( + """Decode not supported with flash attention. 
+ Use `dot_product` instead.""" + ) return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None else: - raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.') - - def tpu_flash_attention( - self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None) -> Array: + raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") + + def tpu_flash_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None) -> Array: """TPU Flash Attention.""" # Transpose to ('batch', 'heads', 'length', 'kv') query = jnp.transpose(query, axes=(0, 2, 1, 3)) key = jnp.transpose(key, axes=(0, 2, 1, 3)) value = jnp.transpose(value, axes=(0, 2, 1, 3)) - if decoder_segment_ids is not None: - decoder_segment_ids = splash_attention_kernel.SegmentIds( - decoder_segment_ids, decoder_segment_ids - ) + decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids) axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) - segment_axis_names = nn.logical_to_mesh_axes( - (BATCH, 'activation_length_no_heads') - ) + segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads")) @functools.partial( shard_map, @@ -223,76 +203,73 @@ def tpu_flash_attention( def wrap_flash_attention(query, key, value, decoder_segment_ids): if decoder_segment_ids is not None: assert ( - query.shape[2] - == decoder_segment_ids.q.shape[1] - ), 'Sharding along sequence dimension not allowed in tpu kernel attention' + query.shape[2] == decoder_segment_ids.q.shape[1] + ), "Sharding along sequence dimension not allowed in tpu kernel attention" block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(512, query.shape[2]), - block_kv_compute=min(512, key.shape[2]), - block_kv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_kv_dkv=min(512, key.shape[2]), - block_kv_dkv_compute=min(512, query.shape[2]), - block_q_dq=min(512, query.shape[2]), - block_kv_dq=min(512, query.shape[2]), + block_q=min(512, query.shape[2]), + block_kv_compute=min(512, key.shape[2]), + block_kv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_kv_dkv=min(512, key.shape[2]), + block_kv_dkv_compute=min(512, query.shape[2]), + block_q_dq=min(512, query.shape[2]), + block_kv_dq=min(512, query.shape[2]), ) - masks = [splash_attention_mask.CausalMask( shape=(query.shape[2],query.shape[2])) for i in range(query.shape[1])] + masks = [splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2])) for i in range(query.shape[1])] multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) - splash_kernel = splash_attention_kernel.make_splash_mha(mask = multi_head_mask, - head_shards = 1, - q_seq_shards = 1, - block_sizes = block_sizes) - - return jax.vmap(splash_kernel)(query,key,value, segment_ids = decoder_segment_ids) - - devices_in_data_fsdp = self.mesh.shape['data'] * self.mesh.shape['fsdp'] + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes + ) + + return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids) + + devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( - 'Batch dimension should be shardable among the devices in data and fsdp' - ' axis' + "Batch dimension should be shardable among the devices in data and fsdp" " axis" ) x = 
wrap_flash_attention(query, key, value, decoder_segment_ids) x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x - + def cudnn_flash_attention( - self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None, - model_mode: str = common_types.MODEL_MODE_TRAIN, - ) -> Array: + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + model_mode: str = common_types.MODEL_MODE_TRAIN, + ) -> Array: """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon """ # These imports are only meant to work in a GPU build. - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error _, _, _, head_dim = query.shape # pylint: disable=unused-variable - #generate attn_mask + # generate attn_mask attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - dpa_layer = DotProductAttention(head_dim=head_dim, - num_attention_heads=self.num_query_heads, - num_gqa_groups=self.num_kv_heads, - attn_mask_type='causal', # 'causal' or 'padding' - attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - attention_dropout=self.dropout_rate, - dropout_rng_name='aqt', - dtype=self.dtype, - float32_logits=self.float32_logits, - qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=1.0/math.sqrt(head_dim), - transpose_batch_sequence=False) + dpa_layer = DotProductAttention( + head_dim=head_dim, + num_attention_heads=self.num_query_heads, + num_gqa_groups=self.num_kv_heads, + attn_mask_type="causal", # 'causal' or 'padding' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=1.0 / math.sqrt(head_dim), + transpose_batch_sequence=False, + ) return dpa_layer(query, key, value, mask=attn_mask) - def compute_local_attention(self, - attn_weights: Array, - value: Array) -> tuple[Array, Array, Array]: - """Computes the attention of a local subset of the kv cache. + def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Array, Array, Array]: + """Computes the attention of a local subset of the kv cache. 
Local attention results will need to be combined with any other local attentions and normalized Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py @@ -314,25 +291,17 @@ def compute_local_attention(self, local_sum = jnp.moveaxis(local_sum, -2, 1) local_max = jnp.moveaxis(local_max, -2, 1) - local_max = jnp.reshape(local_max, - (local_max.shape[0], - local_max.shape[1], - local_max.shape[2] * local_max.shape[3], - 1)) - local_sum = jnp.reshape(local_sum, - (local_sum.shape[0], - local_sum.shape[1], - local_sum.shape[2] * local_sum.shape[3], - 1)) + local_max = jnp.reshape(local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1)) + local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)) local_out = self.wv_product(local_exps, value) return local_out, local_max, local_sum def apply_attention_dot( self, - query: Array, - key: Array, - value: Array, + query: Array, + key: Array, + value: Array, decoder_segment_ids: Array | None, model_mode: str = common_types.MODEL_MODE_TRAIN, ): @@ -364,28 +333,24 @@ def qk_product(self, query: Array, key: Array) -> Array: Returns: results in shape [b, n_kv, n // n_kv, t, s]. """ - b, t, n, d = query.shape + b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) - result = jnp.einsum('btkgd,bskd->bkgts', query, key) - return result # (4, 8, 1, 1, 6) - + result = jnp.einsum("btkgd,bskd->bkgts", query, key) + return result # (4, 8, 1, 1, 6) - def wv_product( - self, - attn_weights: Array, - value: Array) -> Array: + def wv_product(self, attn_weights: Array, value: Array) -> Array: """weighted value product. Args: - attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. + attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. value: Value projection, in shape of [batch_size, v_len, num_kv_heads, kv_dim]. Returns: result in shape [batch_size, q_len, num_kv_heads * group_size, kv_dim] """ - out = jnp.einsum('bkgts,bskd->btkgd', attn_weights, value) + out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) return result @@ -399,8 +364,7 @@ def revert_kvlen_axis(self, kv): Returns: reshaped kv as [b, ..., s, n, d] """ - return jax.numpy.moveaxis(kv, (0,1,2,3), (1,2,0,3)) - + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (1, 2, 0, 3)) def move_kvlen_axis(self, kv): """Move key/value length axis to the end. @@ -411,7 +375,7 @@ def move_kvlen_axis(self, kv): Returns: reshaped kv as [b, ..., n, d, s] """ - return jax.numpy.moveaxis(kv, (0,1,2,3), (2,0,1,3)) + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3)) def cached_kv_shape(self, kv_shape): """Cached KV shape. 
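# Illustrative sketch (assumed small shapes) of the grouped-query einsums in
# qk_product / wv_product above: the n query heads are folded into n_kv groups,
# so each KV head serves n // n_kv query heads.
import jax.numpy as jnp

b, t, s, n, n_kv, d = 2, 3, 5, 8, 2, 4
query = jnp.ones((b, t, n, d))
key = jnp.ones((b, s, n_kv, d))
value = jnp.ones((b, s, n_kv, d))
grouped_q = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
weights = jnp.einsum("btkgd,bskd->bkgts", grouped_q, key)  # [b, n_kv, g, t, s]
out = jnp.einsum("bkgts,bskd->btkgd", weights, value)      # [b, t, n_kv, g, d]
assert jnp.reshape(out, (b, t, n, d)).shape == (b, t, n, d)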
@@ -430,26 +394,51 @@ def cached_kv_shape(self, kv_shape): def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 - kv_cache_layout = ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv', ) + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size) - cached_key = self.variable('cache', 'cached_prefill_key', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_value = self.variable('cache', 'cached_prefill_value', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_segment_id = self.variable('cache', 'cache_prefill_segment_id', - nn.with_logical_partitioning(jnp.zeros, ('cache_batch', 'cache_sequence')), - (cache_logical_shape[0], self.max_prefill_predict_length), jnp.int32) + cached_key = self.variable( + "cache", + "cached_prefill_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_prefill_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_prefill_segment_id", + nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")), + (cache_logical_shape[0], self.max_prefill_predict_length), + jnp.int32, + ) if self.quantize_kvcache: cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1) - cached_key_scale_var = self.variable('cache', 'cached_prefill_key_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) - cached_value_scale_var = self.variable('cache', 'cached_prefill_value_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) + cached_key_scale_var = self.variable( + "cache", + "cached_prefill_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_prefill_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) else: cached_key_scale_var = None cached_value_scale_var = None @@ -461,85 +450,112 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache): dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 cache_length = self.max_target_length - self.max_prefill_predict_length - kv_cache_layout = ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv', ) + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) cache_logical_shape = (batch, cache_length, heads, kv_head_size) - cached_key = self.variable('cache', 'cached_ar_key', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_value = self.variable('cache', 'cached_ar_value', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_segment_id = self.variable('cache', 
'cache_ar_segment_id', - nn.with_logical_partitioning(jnp.zeros, ('cache_batch', 'cache_sequence')), - (cache_logical_shape[0], cache_length), jnp.int32) + cached_key = self.variable( + "cache", + "cached_ar_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_ar_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_ar_segment_id", + nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")), + (cache_logical_shape[0], cache_length), + jnp.int32, + ) if self.quantize_kvcache: cache_logical_shape_scale = (batch, cache_length, heads, 1) - cached_key_scale_var = self.variable('cache', 'cached_ar_key_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) - cached_value_scale_var = self.variable('cache', 'cached_ar_value_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) + cached_key_scale_var = self.variable( + "cache", + "cached_ar_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_ar_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) else: cached_key_scale_var = None cached_value_scale_var = None - cache_index = self.variable('cache', 'cache_ar_index', - nn.with_logical_partitioning(jnp.zeros, ()), - (1,), jnp.int32) + cache_index = self.variable("cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32) key_vars = (cached_key, cached_key_scale_var) value_vars = (cached_value, cached_value_scale_var) return key_vars, value_vars, cached_segment_id, cache_index - def kv_cache_prefill(self, - key: Array, - value: Array, - decoder_segment_ids: Array, - ): - """In prefill mode, we zero out the existing cache, run the computation and - prepare the cache as necessary. - - Args: - key: in shape [b, s, n, d]. - value: in shape [b, s, n, d]. - decoder_segment_ids: [b, s] -- marking segment ids for tokens - - Returns: - key, value, decoder_segment_id. - - """ - batch, sequence, heads, kv_head_size = key.shape - assert key.dtype == value.dtype, "Key and Value Dtypes should match." 
- - cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(batch, heads, kv_head_size, self.quantize_kvcache) - self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now - - key_shaped_for_cache = self.move_kvlen_axis(key) - value_shaped_for_cache = self.move_kvlen_axis(value) - - if self.quantize_kvcache: - key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache) - value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache) - cached_prefill_key_var[1].value = key_scale - cached_prefill_value_var[1].value = value_scale - - cached_prefill_key_var[0].value = key_shaped_for_cache - cached_prefill_value_var[0].value = value_shaped_for_cache + def kv_cache_prefill( + self, + key: Array, + value: Array, + decoder_segment_ids: Array, + ): + """In prefill mode, we zero out the existing cache, run the computation and + prepare the cache as necessary. - if decoder_segment_ids is not None: - cached_prefill_segment_id.value = decoder_segment_ids + Args: + key: in shape [b, s, n, d]. + value: in shape [b, s, n, d]. + decoder_segment_ids: [b, s] -- marking segment ids for tokens + + Returns: + key, value, decoder_segment_id. - return key, value, decoder_segment_ids - + """ + batch, sequence, heads, kv_head_size = key.shape + assert key.dtype == value.dtype, "Key and Value Dtypes should match." + + cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now + + key_shaped_for_cache = self.move_kvlen_axis(key) + value_shaped_for_cache = self.move_kvlen_axis(value) + + if self.quantize_kvcache: + key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache) + value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache) + cached_prefill_key_var[1].value = key_scale + cached_prefill_value_var[1].value = value_scale + + cached_prefill_key_var[0].value = key_shaped_for_cache + cached_prefill_value_var[0].value = value_shaped_for_cache - def update_ar_key_value(self, - one_token_key: Array, - one_token_value: Array, - cached_key_vars: tuple[nn.Variable, nn.Variable|None], - cached_value_vars: tuple[nn.Variable, nn.Variable|None], - one_hot_indices: Array) -> tuple[Array, Array]: - """Adds a single token's results to the ar kv cache + if decoder_segment_ids is not None: + cached_prefill_segment_id.value = decoder_segment_ids + + return key, value, decoder_segment_ids + + def update_ar_key_value( + self, + one_token_key: Array, + one_token_value: Array, + cached_key_vars: tuple[nn.Variable, nn.Variable | None], + cached_value_vars: tuple[nn.Variable, nn.Variable | None], + one_hot_indices: Array, + ) -> tuple[Array, Array]: + """Adds a single token's results to the ar kv cache Args: one_token_key (Array): Key of one token to add to the cache @@ -566,20 +582,39 @@ def update_ar_key_value(self, one_hot_indices = one_hot_indices.astype(int) - ar_key = cached_key_var.value ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0) - ar_key = nn.with_logical_constraint(ar_key, ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv',)) + ar_key = nn.with_logical_constraint( + ar_key, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) cached_key_var.value = ar_key ar_value = 
cached_value_var.value ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0) - ar_value = nn.with_logical_constraint(ar_value, ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv',)) + ar_value = nn.with_logical_constraint( + ar_value, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) cached_value_var.value = ar_value if self.quantize_kvcache: - ar_key_scale = jax.lax.dynamic_update_index_in_dim(cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0) - ar_value_scale = jax.lax.dynamic_update_index_in_dim(cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0) + ar_key_scale = jax.lax.dynamic_update_index_in_dim( + cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0 + ) + ar_value_scale = jax.lax.dynamic_update_index_in_dim( + cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0 + ) cached_key_scale_var.value = ar_key_scale cached_value_scale_var.value = ar_value_scale @@ -594,58 +629,61 @@ def prefill_cache_var_model_var(self, cache_var, target_dtype): return self.revert_kvlen_axis(cache_var[0].value) else: raw_cache, quant_scale = cache_var - raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype) + raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype) return self.revert_kvlen_axis(raw_cache_unquantized) - - - def kv_cache_autoregressive(self, - key: Array, - value: Array, - ): - """In autoregressive mode, we update the cache for this entry and - then return the full cache. - - Args: - key: in shape [b, 1, n, d]. - value: in shape [b, 1, n, d]. - decoder_segment_ids: [b, 1] -- marking segment ids for tokens - - Returns: - tuple of (key, value, segment_id) for both prefill and ar cache, - Raises: - ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. 
- """ - batch, sequence, heads, kv_head_size = key.shape - if sequence != 1: - raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") - is_initialized = self.has_variable('cache', 'cache_ar_index') - if not is_initialized: - raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") - - cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) - - key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) - value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) - - ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value) - active_indicator = jnp.zeros((batch, 1), dtype = jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR - cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim(cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1) - cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length) - - # Prep and return both prefill and ar caches - cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(self.max_target_length, heads, kv_head_size, self.quantize_kvcache) - - cached_prefill = self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), cached_prefill_segment_id.value - return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) - - def kv_cache( + def kv_cache_autoregressive( self, key: Array, value: Array, - decoder_segment_ids: Array, - model_mode: str - ) -> tuple: + ): + """In autoregressive mode, we update the cache for this entry and + then return the full cache. + + Args: + key: in shape [b, 1, n, d]. + value: in shape [b, 1, n, d]. + decoder_segment_ids: [b, 1] -- marking segment ids for tokens + + Returns: + tuple of (key, value, segment_id) for both prefill and ar cache, + Raises: + ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. 
+ """ + batch, sequence, heads, kv_head_size = key.shape + if sequence != 1: + raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") + is_initialized = self.has_variable("cache", "cache_ar_index") + if not is_initialized: + raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") + + cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + + key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) + value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) + + ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value) + active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim( + cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1 + ) + cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length) + + # Prep and return both prefill and ar caches + cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache( + self.max_target_length, heads, kv_head_size, self.quantize_kvcache + ) + + cached_prefill = ( + self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), + self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), + cached_prefill_segment_id.value, + ) + return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) + + def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str) -> tuple: """KV cache takes the current state and updates the state accordingly. The key and value have dimension [batch, length, num_heads, head_dim], @@ -665,7 +703,6 @@ def kv_cache( """ if key.shape != value.shape: raise ValueError(f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}") - if model_mode == common_types.MODEL_MODE_TRAIN: return (key, value, decoder_segment_ids), None @@ -675,12 +712,8 @@ def kv_cache( return self.kv_cache_autoregressive(key, value) else: raise ValueError(f"Model Mode isn't supported! 
{model_mode=}")
-
-
-  def normalize_attention(self,
-                          local_outs,
-                          local_maxes,
-                          local_sums):
+
+  def normalize_attention(self, local_outs, local_maxes, local_sums):
     """Normalize across multiple localized attentions
 
     Args:
@@ -689,14 +722,13 @@ def normalize_attention(self,
       local_sums (list): List of exponential sum entries for each local attention
 
     Returns:
-      Array: Combined attention that has been normalized
+      Array: Combined attention that has been normalized
     """
     # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py
     global_max = functools.reduce(jnp.maximum, local_maxes)
-    global_sum = sum([
-      jnp.exp(local_max - global_max) * local_sum
-      for (local_sum, local_max) in zip(local_sums, local_maxes)
-    ])
+    global_sum = sum(
+        [jnp.exp(local_max - global_max) * local_sum for (local_sum, local_max) in zip(local_sums, local_maxes)]
+    )
 
     attn_out = 0
     for local_max, local_out in zip(local_maxes, local_outs):
@@ -704,17 +736,16 @@ def normalize_attention(self,
       attn_out += local_normalizer * local_out
     return attn_out
 
-
   @nn.compact
   def __call__(self, query, key, value, decoder_segment_ids, model_mode):
     prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode)
 
     prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
-      query=query,
-      key=prefill_kv_cache[0],
-      value=prefill_kv_cache[1],
-      decoder_segment_ids=prefill_kv_cache[2],
-      model_mode=model_mode,
+        query=query,
+        key=prefill_kv_cache[0],
+        value=prefill_kv_cache[1],
+        decoder_segment_ids=prefill_kv_cache[2],
+        model_mode=model_mode,
     )
 
     # Return the "prefill" cache if it is actually the combined prefill+ar kv cache
     if ar_kv_cache is None:
       if prefill_exponentials_sum is not None:
         return prefill_unnormalized_output / prefill_exponentials_sum
       return prefill_unnormalized_output
 
-    ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention(
-      query=query,
-      key=ar_kv_cache[0],
-      value=ar_kv_cache[1],
-      decoder_segment_ids=ar_kv_cache[2],
-      model_mode=model_mode,
+    ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention(
+        query=query,
+        key=ar_kv_cache[0],
+        value=ar_kv_cache[1],
+        decoder_segment_ids=ar_kv_cache[2],
+        model_mode=model_mode,
     )
 
     unnormalized_outputs = [prefill_unnormalized_output, ar_unnormalized_output]
@@ -738,27 +769,27 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode):
 
 class Attention(nn.Module):
-  """ Generic Attention.
-
-  Attributes:
-    num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1])
-      should be divisible by the number of heads.
-    num_kv_heads: number of kv attention heads.
-    head_dim: dimension of each head.
-    mesh: Mesh, device mesh
-    attention_kernel: str, guidance on if we should use an attention kernel
-    dtype: the dtype of the computation.
-    weight_dtype: the dtype of the weights.
-    max_target_length: maximum target length
-    max_prefill_predict_length: size of the maximum prefill
-    dropout_rate: dropout rate
-    kernel_init: initializer for the kernel of the Dense layers.
-    float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid
-      numerical issues with bfloat16.
-    float32_logits: bool, if True then cast logits to float32 before softmax to avoid
-      numerical issues with bfloat16.
-    quant: Quant, stores quantization parameters, defaults to None implying no quantization.
- quantize_kvcache: bool, quantize the kv cache. + """Generic Attention. + + Attributes: + num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + num_kv_heads: number of kv attention heads. + head_dim: dimension of each head. + mesh: Mesh, device mesh + attention_kernel: str, guidance on if we should use an attention kernel + dtype: the dtype of the computation. + weight_dtype: the dtype of the weights. + max_target_length: maximum target length + max_prefill_predict_length: size of the maximum prefill + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + quant: Quant, stores quantization parameters, defaults to None implying no quantization. + quantize_kvcache: bool, quantize the kv cache. """ config: Config @@ -771,14 +802,13 @@ class Attention(nn.Module): dtype: DType = jnp.float32 weight_dtype: DType = jnp.float32 max_prefill_predict_length: int = -1 - dropout_rate: float = 0. - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") float32_qk_product: bool = False # computes logits in float32 for stability. float32_logits: bool = False # cast logits in float32 for stability. quant: Optional[Quant] = None quantize_kvcache: bool = False - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) @@ -791,19 +821,21 @@ def query_projection(self, inputs_q: Array) -> Array: # 1/sqrt(depth_kq)! This is folded into the initializers of the # linear transformations, which is equivalent under Adafactor. depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + def query_init(*args): - #pylint: disable=no-value-for-parameter + # pylint: disable=no-value-for-parameter return self.kernel_init(*args) / depth_scaling query_proj = DenseGeneral( - features=(self.num_query_heads, self.head_dim), - axis=-1, - kernel_init=query_init, - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='query', - quant=self.quant)(inputs_q) + features=(self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=query_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="query", + quant=self.quant, + )(inputs_q) return query_proj def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: @@ -818,66 +850,69 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: Projection of key or value, in shape of `[batch, kv_length, head_dim]`. 
""" if self.num_kv_heads == -1: - raise ValueError('num_kv_heads is not defined.') + raise ValueError("num_kv_heads is not defined.") if self.num_query_heads % self.num_kv_heads != 0: - raise ValueError('Invaid num_kv_heads for GQA.') + raise ValueError("Invaid num_kv_heads for GQA.") kv_proj = DenseGeneral( features=(self.num_kv_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, - kernel_axes=('embed', 'heads', 'kv'), + kernel_axes=("embed", "heads", "kv"), dtype=self.dtype, weight_dtype=self.weight_dtype, name=proj_name, - quant=self.quant)(inputs_kv) + quant=self.quant, + )(inputs_kv) return kv_proj def qkv_projection(self, inputs: Array, proj_name: str): - """ Fused QKV projection""" + """Fused QKV projection""" qkv_proj = DenseGeneral( - features=(3, self.num_query_heads, self.head_dim), - axis = -1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'qkv', 'heads', 'kv'), + features=(3, self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), dtype=self.dtype, weight_dtype=self.weight_dtype, name=proj_name, - quant=self.quant)(inputs) - qkv_proj = checkpoint_name(qkv_proj, 'qkv_proj') - query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...] + quant=self.quant, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] return query, key, value def out_projection(self, output_dim: int, out: Array) -> Array: out_proj = DenseGeneral( - features=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='out', - quant=self.quant)(out) + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + )(out) return out_proj def key_rotary(self, key: Array, inputs_positions: Array): """Apply Rotary Embedding to key.""" - key = RotaryEmbedding( - embedding_dims=self.head_dim, - name='key_rotary')(inputs=key, position=inputs_positions) + key = RotaryEmbedding(embedding_dims=self.head_dim, name="key_rotary")(inputs=key, position=inputs_positions) return key @nn.compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - inputs_positions: Array, - decoder_segment_ids: Array | None = None, - *, - model_mode: str = common_types.MODEL_MODE_TRAIN, - deterministic: bool = False): + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + inputs_positions: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): """Applies Attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -902,38 +937,38 @@ def __call__(self, """ # apply projection. 
if self.config.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: query = self.query_projection(inputs_q) - key = self.kv_projection(inputs_kv, proj_name='key') - value = self.kv_projection(inputs_kv, proj_name='value') + key = self.kv_projection(inputs_kv, proj_name="key") + value = self.kv_projection(inputs_kv, proj_name="value") # apply ROPE - query = RotaryEmbedding( - embedding_dims=self.head_dim, name='query_rotary' - )(inputs=query, position=inputs_positions) + query = RotaryEmbedding(embedding_dims=self.head_dim, name="query_rotary")(inputs=query, position=inputs_positions) key = self.key_rotary(key, inputs_positions) # annotate with sharding constraint. query = nn.with_logical_constraint(query, self.query_axis_names) - query = checkpoint_name(query, 'query_proj') + query = checkpoint_name(query, "query_proj") key = nn.with_logical_constraint(key, self.key_axis_names) - key = checkpoint_name(key, 'key_proj') + key = checkpoint_name(key, "key_proj") value = nn.with_logical_constraint(value, self.value_axis_names) - value = checkpoint_name(value, 'value_proj') - - attention_op = AttentionOp(mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - quantize_kvcache=self.quantize_kvcache, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - dropout_rate = self.dropout_rate, - dtype=self.dtype) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.quantize_kvcache, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + ) out = attention_op(query, key, value, decoder_segment_ids, model_mode) @@ -941,5 +976,5 @@ def __call__(self, # apply output projection, output dim is set to the input dim. out = self.out_projection(inputs_q.shape[-1], out) - out = checkpoint_name(out, 'out_proj') + out = checkpoint_name(out, "out_proj") return out diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 6c954941c..9337986a0 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -32,6 +32,7 @@ _MAX_WAVELENGTH = 10_000 + class Embed(nn.Module): """A parameterized function from integers [0, n) to d-dimensional vectors. 
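On the GQA invariant checked in kv_projection (num_query_heads divisible by num_kv_heads): each group of num_query_heads // num_kv_heads query heads shares one key/value head. One way a kernel can realize the sharing is by repeating kv heads across the group, sketched below (an illustration only; AttentionOp handles this internally):

    import jax.numpy as jnp

    num_query_heads, num_kv_heads = 8, 2
    group_size = num_query_heads // num_kv_heads       # query heads per kv head
    key = jnp.zeros((2, 16, num_kv_heads, 64))         # [batch, length, kv_heads, head_dim]
    key_per_query_head = jnp.repeat(key, group_size, axis=2)
    assert key_per_query_head.shape == (2, 16, num_query_heads, 64)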
@@ -53,8 +54,8 @@ class Embed(nn.Module): def setup(self): self.embedding = self.param( - 'embedding', - with_logical_partitioning(self.embedding_init, ('vocab', 'embed')), + "embedding", + with_logical_partitioning(self.embedding_init, ("vocab", "embed")), (self.num_embeddings, self.features), self.config.weight_dtype, ) @@ -73,7 +74,7 @@ def __call__(self, inputs: Array) -> Array: if self.cast_input_dtype: inputs = inputs.astype(self.cast_input_dtype) if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') + raise ValueError("Input type must be an integer or unsigned integer.") if cfg.use_iota_embed: iota = lax.iota(jnp.int32, self.num_embeddings) @@ -81,9 +82,7 @@ def __call__(self, inputs: Array) -> Array: output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) else: output = jnp.asarray(self.embedding, self.dtype)[inputs] - output = nn.with_logical_constraint( - output, ('activation_batch', 'activation_length', 'activation_embed') - ) + output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed")) return output def attend(self, query: Array) -> Array: @@ -122,9 +121,7 @@ class RotaryEmbedding(nn.Module): def setup(self) -> None: if self.embedding_dims % 2: - raise ValueError( - 'Embedding dim for rotary position embedding must be a multiple of 2.' - ) + raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") def __call__( self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks @@ -147,21 +144,14 @@ def __call__( """ assert position is not None if len(inputs.shape) != 4: - raise ValueError( - 'Input is assumed to be a rank 4 tensor of shape' - '[batch, sequence, heads, dims].' - ) + raise ValueError("Input is assumed to be a rank 4 tensor of shape" "[batch, sequence, heads, dims].") if self.embedding_dims != inputs.shape[3]: raise ValueError( - 'The embedding dims of the rotary position embedding' - 'must match the hidden dimension of the inputs.' + "The embedding dims of the rotary position embedding" "must match the hidden dimension of the inputs." 
) half_embedding_dim = self.embedding_dims // 2 fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims - timescale = ( - self.min_timescale - * (self.max_timescale / self.min_timescale) ** fraction - ) + timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction position = position[:, :, jnp.newaxis, jnp.newaxis] sinusoid_inp = position / timescale sin = jnp.sin(sinusoid_inp).astype(inputs.dtype) @@ -189,13 +179,11 @@ def __call__( log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum( jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1 ) - inv_timescales = jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment - ) + inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) position = position[:, :, jnp.newaxis] inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :] scaled_time = position * inv_timescales - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = -1) + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1) # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]]) position_embedding = signal.astype(jnp.float32) return input_embedding + position_embedding diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py index cbbadf7bc..fb909f985 100644 --- a/MaxText/layers/gemma.py +++ b/MaxText/layers/gemma.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
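Both the rotary timescale line and the sinusoidal PositionalEmbedding hunk in embeddings.py above build the same geometric frequency ladder: wavelengths grow from min_timescale to max_timescale across channel pairs. Worked with toy numbers (embedding_dims = 8 assumed):

    import jax.numpy as jnp

    min_timescale, max_timescale, embedding_dims = 1.0, 10_000.0, 8
    fraction = 2 * jnp.arange(0, embedding_dims // 2) / embedding_dims
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    # timescale == [1., 10., 100., 1000.]: one frequency per pair of channels.
    position = jnp.arange(4.0)[:, jnp.newaxis]         # [positions, 1]
    sinusoid_inp = position / timescale                # [positions, dims // 2]
    sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)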
+""" from flax import linen as nn import common_types @@ -52,69 +52,65 @@ # Decoder and Model definitions class GemmaDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - lnx = RMSNorm( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='pre_self_attention_norm', - kernel_axes=('embed',))(inputs) + lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("embed",))( + inputs + ) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) attention_layer = Attention( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - float32_qk_product = True, - float32_logits = True, - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + float32_qk_product=True, + float32_logits=True, + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) attention_lnx += inputs residual = attention_lnx - attn_output = RMSNorm( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='pre_ffw_norm', - kernel_axes=('embed',))(attention_lnx) + attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("embed",))( + attention_lnx + ) # MLP block. 
mlp_lnx = MlpBlock( @@ -123,32 +119,30 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(attn_output, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) next_layer_addition = mlp_lnx + residual - next_layer_addition_dropped_out = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,) - )(next_layer_addition, deterministic=deterministic) + next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition, deterministic=deterministic + ) layer_output = next_layer_addition_dropped_out layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py index 518ec2912..853ec43bb 100644 --- a/MaxText/layers/gpt3.py +++ b/MaxText/layers/gpt3.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
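The record_internal_nn_metrics branch above sows three summary statistics of the layer output. On a toy tensor they read as follows (values are illustrative):

    import jax.numpy as jnp

    layer_output = jnp.array([[0.0, 1.5], [0.0, -2.0]])
    activation_mean = jnp.mean(layer_output)                                          # -0.125
    activation_stdev = jnp.std(layer_output)
    activation_fraction_zero = jnp.sum(layer_output == 0) / jnp.size(layer_output)    # 0.5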
+""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -55,12 +55,14 @@ Quant = quantizations.AqtQuantization -#----------------------------------------- +# ----------------------------------------- # The Normalization Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3LayerNorm(nn.Module): """GPT3 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -82,10 +84,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: features = x.shape[-1] scale = self.param( - 'scale', - nn.with_logical_partitioning(self.scale_init, self.kernel_axes), - (features,), - self.weight_dtype + "scale", nn.with_logical_partitioning(self.scale_init, self.kernel_axes), (features,), self.weight_dtype ) scale = jnp.asarray(scale, self.dtype) @@ -93,40 +92,41 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: if self.use_bias: bias = self.param( - 'bias', - nn.with_logical_partitioning(initializers.default_bias_init, self.kernel_axes), - (features,), - self.weight_dtype, + "bias", + nn.with_logical_partitioning(initializers.default_bias_init, self.kernel_axes), + (features,), + self.weight_dtype, ) bias = jnp.asarray(bias, self.dtype) output += bias return output -#----------------------------------------- +# ----------------------------------------- # The Attention Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3MultiHeadAttention(nn.Module): """Multi-head attention in gpt3. - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - max_target_length: maximum length of output - max_prefill_predict_length: size of the maximum prefill - mesh: device mesh - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid - numerical issues with bfloat16. - float32_logits: bool, if True then cast logits to float32 before softmax to avoid - numerical issues with bfloat16. - fused_qkv: whether to fuse query, key and value into one projection. - quant: Quant, stores quantization config, defaults to None implying no quantization. - use_bias: whether to add bias in linear transformation. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + max_target_length: maximum length of output + max_prefill_predict_length: size of the maximum prefill + mesh: device mesh + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + fused_qkv: whether to fuse query, key and value into one projection. + quant: Quant, stores quantization config, defaults to None implying no quantization. + use_bias: whether to add bias in linear transformation. 
""" config: Config @@ -138,8 +138,8 @@ class Gpt3MultiHeadAttention(nn.Module): attention_kernel: str dtype: DType = jnp.float32 weight_dtype: DType = jnp.float32 - dropout_rate: float = 0. - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") float32_qk_product: bool = False # computes logits in float32 for stability. float32_logits: bool = True # cast logits in float32 for stability. fused_qkv: bool = True @@ -152,88 +152,92 @@ class Gpt3MultiHeadAttention(nn.Module): out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) def qkv_projection(self, inputs: Array, proj_name: str): - """ Fused QKV projection""" + """Fused QKV projection""" qkv_proj = DenseGeneral( - features=(3, self.num_heads, self.head_dim), - axis = -1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'qkv', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - )(inputs) - qkv_proj = checkpoint_name(qkv_proj, 'qkv_proj') - query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...] + features=(3, self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] return query, key, value def projection(self, inputs: Array, proj_name: str) -> Array: """individual projection for one of q, k and v.""" proj = DenseGeneral( - features=(self.num_heads, self.head_dim), - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - )(inputs) + features=(self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) return proj def out_projection(self, output_dim: int, out: Array) -> Array: """output projection""" out_proj = DenseGeneral( - features=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='out', - quant=self.quant, - use_bias=self.use_bias, - )(out) + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + use_bias=self.use_bias, + )(out) return out_proj @nn.compact - def __call__(self, - inputs_q: Array, - decoder_segment_ids: Array | None = None, - *, - model_mode: str = common_types.MODEL_MODE_TRAIN, - deterministic: bool = False): + def __call__( + self, + inputs_q: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): if self.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: - query = self.projection(inputs_q, proj_name='query') - key = self.projection(inputs_q, proj_name='key') - value = 
self.projection(inputs_q, proj_name='value') + query = self.projection(inputs_q, proj_name="query") + key = self.projection(inputs_q, proj_name="key") + value = self.projection(inputs_q, proj_name="value") depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) query /= depth_scaling # annotate with sharding constraint. query = nn.with_logical_constraint(query, self.query_axis_names) - query = checkpoint_name(query, 'query_proj') + query = checkpoint_name(query, "query_proj") key = nn.with_logical_constraint(key, self.key_axis_names) - key = checkpoint_name(key, 'key_proj') + key = checkpoint_name(key, "key_proj") value = nn.with_logical_constraint(value, self.value_axis_names) - value = checkpoint_name(value, 'value_proj') - - attention_op = AttentionOp(mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - quantize_kvcache=self.config.quantize_kvcache, - num_query_heads=self.num_heads, - num_kv_heads=self.num_heads, - dtype=self.dtype) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.config.quantize_kvcache, + num_query_heads=self.num_heads, + num_kv_heads=self.num_heads, + dtype=self.dtype, + ) out = attention_op(query, key, value, decoder_segment_ids, model_mode) @@ -241,76 +245,74 @@ def __call__(self, # apply output projection, output dim is set to the input dim. out = self.out_projection(inputs_q.shape[-1], out) - out = checkpoint_name(out, 'out_proj') + out = checkpoint_name(out, "out_proj") return out -#----------------------------------------- +# ----------------------------------------- # The Decoder Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) - + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) lnx_layer_norm = Gpt3LayerNorm( dtype=cfg.dtype, - name='pre_self_attention_norm', - kernel_axes=('embed',), + name="pre_self_attention_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, reductions_in_fp32=False, use_bias=True, - ) + ) lnx = lnx_layer_norm(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) # Self-attention block - assert cfg.num_query_heads == cfg.num_kv_heads, \ - f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" + assert ( + cfg.num_query_heads == cfg.num_kv_heads + ), f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" attention_layer = Gpt3MultiHeadAttention( - 
config=cfg, - num_heads=cfg.num_query_heads, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dropout_rate=cfg.dropout_rate, - name='self_attention', - fused_qkv=cfg.fused_qkv, - use_bias=True, - quant=self.quant) + config=cfg, + num_heads=cfg.num_query_heads, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dropout_rate=cfg.dropout_rate, + name="self_attention", + fused_qkv=cfg.fused_qkv, + use_bias=True, + quant=self.quant, + ) attention_lnx = attention_layer( - lnx, - decoder_segment_ids=decoder_segment_ids, - model_mode=model_mode, - deterministic=deterministic) - - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + lnx, decoder_segment_ids=decoder_segment_ids, model_mode=model_mode, deterministic=deterministic + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) attention_lnx += inputs # MLP block. @@ -320,33 +322,29 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", use_bias=True, use_pre_norm=True, config=cfg, quant=self.quant, )(attention_lnx, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = attention_lnx + mlp_lnx - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/initializers.py b/MaxText/layers/initializers.py index 6f0bb9c23..5916ecb0c 100644 --- a/MaxText/layers/initializers.py +++ b/MaxText/layers/initializers.py @@ -27,13 +27,9 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] InitializerAxis = Union[int, Tuple[int, ...]] -NdInitializer = Callable[ - [PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array -] +NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array] -default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 -) +default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0) default_bias_init = jax.nn.initializers.constant(0.0) @@ -42,9 +38,7 @@ def 
nd_dense_init(scale, mode, distribution): """Initializer with in_axis, out_axis set at call time.""" def init_fn(key, shape, dtype, in_axis, out_axis): - fn = jax.nn.initializers.variance_scaling( - scale, mode, distribution, in_axis, out_axis - ) + fn = jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis, out_axis) return fn(key, shape, dtype) return init_fn diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 4cf3d1939..3d3f35b9b 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -40,18 +40,20 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization -def _convert_to_activation_function( - fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: + +def _convert_to_activation_function(fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: """Convert a string to an activation function.""" - if fn_or_string == 'linear': + if fn_or_string == "linear": return lambda x: x elif isinstance(fn_or_string, str): return getattr(nn, fn_or_string) elif callable(fn_or_string): return fn_or_string else: - raise ValueError(f"""Don't know how to convert {fn_or_string} - to an activation function""") + raise ValueError( + f"""Don't know how to convert {fn_or_string} + to an activation function""" + ) def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: @@ -83,7 +85,7 @@ class DenseGeneral(nn.Module): axis: Union[Iterable[int], int] = -1 weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") kernel_axes: Tuple[str, ...] = () quant: Optional[Quant] = None use_bias: bool = False @@ -105,8 +107,7 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): if self.quant: dot_general_cls = self.quant.dot_general_cls() dot_general = dot_general_cls() - return dot_general( - inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -123,12 +124,12 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): kernel = jnp.zeros(kernel_shape) else: kernel = self.param( - 'kernel', - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), - kernel_shape, - self.weight_dtype, - kernel_in_axis, - kernel_out_axis, + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, ) kernel = jnp.asarray(kernel, self.dtype) @@ -136,9 +137,9 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): output = compute_dot_general(inputs, kernel, axis, contract_ind) if self.use_bias: - bias_axes, bias_shape = self.kernel_axes[-len(features):], kernel_shape[-len(features):] + bias_axes, bias_shape = self.kernel_axes[-len(features) :], kernel_shape[-len(features) :] bias = self.param( - 'bias', + "bias", nn.with_logical_partitioning(bias_init, bias_axes), bias_shape, self.weight_dtype, @@ -167,8 +168,8 @@ class MlpBlock(nn.Module): config: Config intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable[..., Any]]] = ('relu',) - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + activations: Sequence[Union[str, Callable[..., Any]]] = ("relu",) + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", 
"truncated_normal") intermediate_dropout_rate: float = 0.1 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -181,6 +182,7 @@ def get_norm_layer(self): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 + return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=self.use_bias) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @@ -192,39 +194,39 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): if self.use_pre_norm: inputs = self.get_norm_layer()( - name='mlp_layer_norm', - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - kernel_axes=('embed',), - epsilon=cfg.normalization_layer_epsilon, - )(inputs) + name="mlp_layer_norm", + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + kernel_axes=("embed",), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) # Iterate over specified MLP input activation functions. # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. activations = [] if cfg.fused_mlp: x = DenseGeneral( - (len(self.activations), self.intermediate_dim), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'num_activations', 'mlp'), - name='wi', - quant=self.quant, - use_bias=self.use_bias, + (len(self.activations), self.intermediate_dim), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "num_activations", "mlp"), + name="wi", + quant=self.quant, + use_bias=self.use_bias, )(inputs) for idx, act_fn in enumerate(self.activations): - y = _convert_to_activation_function(act_fn)(x[:,:,idx,...]) + y = _convert_to_activation_function(act_fn)(x[:, :, idx, ...]) activations.append(y) else: for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" x = DenseGeneral( self.intermediate_dim, dtype=self.dtype, weight_dtype=self.weight_dtype, kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), + kernel_axes=("embed", "mlp"), name=dense_name, quant=self.quant, use_bias=self.use_bias, @@ -234,26 +236,24 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): # Take elementwise product of above intermediate activations. x = functools.reduce(operator.mul, activations) - x = checkpoint_name(x, 'mlpwi') + x = checkpoint_name(x, "mlpwi") # Apply dropout and final dense output projection. x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( x, deterministic=deterministic ) # Broadcast along length. 
- x = nn.with_logical_constraint( - x, ('activation_batch', 'activation_length', 'activation_mlp') - ) + x = nn.with_logical_constraint(x, ("activation_batch", "activation_length", "activation_mlp")) output = DenseGeneral( inputs.shape[-1], dtype=self.dtype, weight_dtype=self.weight_dtype, kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo', + kernel_axes=("mlp", "embed"), + name="wo", quant=self.quant, use_bias=self.use_bias, )(x) - output = checkpoint_name(output, 'mlpwo') + output = checkpoint_name(output, "mlpwo") return output @@ -278,38 +278,35 @@ class MoeBlock(nn.Module): @nn.compact def __call__(self, inputs, deterministic: bool = False): gate_logits = DenseGeneral( - self.num_experts, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - name='gate', - quant=self.quant,)(inputs) + self.num_experts, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + name="gate", + quant=self.quant, + )(inputs) weights, selected_experts = lax.top_k(gate_logits, self.num_experts_per_tok) weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) mlp_lnx = jnp.zeros_like(inputs) weights = weights.astype(self.dtype) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) # TODO(ranran): have a better solution to remove the loop here for k in range(self.num_experts): - weights_exp = jnp.sum(jnp.multiply(selected_experts==k, weights), axis=-1) - mlp_lnx_exp = MlpBlock( + weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1) + mlp_lnx_exp = MlpBlock( intermediate_dim=self.config.mlp_dim, activations=self.config.mlp_activations, intermediate_dropout_rate=self.config.dropout_rate, dtype=self.dtype, weight_dtype=self.weight_dtype, - name=f'mlp_{k}', + name=f"mlp_{k}", config=self.config, - )(inputs, deterministic=deterministic) + )(inputs, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ('activation_batch', 'activation_length', 'activation_embed') - ) - mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp - mlp_lnx += mlp_lnx_exp + mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp + mlp_lnx += mlp_lnx_exp return mlp_lnx diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index bc0fd1861..7723d6548 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
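MoeBlock above routes each token to its top-k experts: lax.top_k selects expert ids with their logits, softmax renormalizes just those k logits, and inside the loop selected_experts == k masks out tokens not routed to expert k. End to end on one toy token (illustrative values):

    import jax
    import jax.numpy as jnp
    from jax import lax

    gate_logits = jnp.array([[0.1, 2.0, -1.0, 1.5]])       # [tokens, num_experts]
    weights, selected_experts = lax.top_k(gate_logits, 2)  # picks experts 1 and 3
    weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1)

    k = 3
    weights_exp = jnp.sum((selected_experts == k) * weights, axis=-1)
    # Nonzero only for tokens whose top-k included expert 3 (~0.378 here).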
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -43,86 +43,82 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization -#----------------------------------------- +# ----------------------------------------- # The Decoder Layer specific for Llama2 -#----------------------------------------- +# ----------------------------------------- class LlamaDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) - + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) lnx_rms = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_layer_norm', - kernel_axes=('embed',), + name="pre_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - ) + ) lnx = lnx_rms(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) # Self-attention block attention_layer = Attention( - config = cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) - - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='post_self_attention_layer_norm', - kernel_axes=('embed',), + 
name="post_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) - hidden_states = nn.with_logical_constraint( - hidden_states, - ('activation_batch', 'activation_length', 'activation_embed') - ) + )(intermediate_inputs) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) # MLP block. mlp_lnx = linears.MlpBlock( @@ -131,32 +127,27 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) - + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 2992cf5ab..6954c157f 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -51,153 +51,144 @@ class MistralDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) lnx_rms = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_layer_norm', - kernel_axes=('embed',), - epsilon=cfg.normalization_layer_epsilon - ) + name="pre_self_attention_layer_norm", + kernel_axes=("embed",), + epsilon=cfg.normalization_layer_epsilon, + ) lnx = lnx_rms(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) # Self-attention block attention_layer = Attention( - config = cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) - - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='post_self_attention_layer_norm', - kernel_axes=('embed',), + name="post_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, ('activation_batch', 'activation_length', 'activation_embed')) + )(intermediate_inputs) + hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) if cfg.num_experts > 1: - # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR. 
- - # mlp_lnx = linears.MoeBlock( - # config=cfg, - # num_experts=cfg.num_experts, - # num_experts_per_tok=cfg.num_experts_per_tok, - # kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), - # kernel_axes=('embed', 'mlp'), - # dtype=cfg.dtype, - # )(hidden_states, deterministic=deterministic) - - gate_logits = linears.DenseGeneral( - cfg.num_experts, - weight_dtype=cfg.weight_dtype, - dtype=cfg.dtype, - kernel_init=initializers.nd_dense_init( - 1.0, 'fan_in', 'truncated_normal'), - kernel_axes=('embed', 'mlp'), - name="gate", - quant=self.quant, - )(hidden_states) - weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok) - weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) - mlp_lnx = jnp.zeros_like(hidden_states) - weights = weights.astype(cfg.dtype) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) - - # TODO(ranran): have a better solution to remove the loop here - for k in range(cfg.num_experts): - weights_exp = jnp.sum(jnp.multiply( - selected_experts == k, weights), axis=-1) - mlp_lnx_exp = linears.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name=f'mlp_{k}', - config=cfg, - )(hidden_states, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ('activation_batch', 'activation_length', 'activation_embed') - ) - mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp - mlp_lnx += mlp_lnx_exp - else: - mlp_lnx = linears.MlpBlock( + # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR. + + # mlp_lnx = linears.MoeBlock( + # config=cfg, + # num_experts=cfg.num_experts, + # num_experts_per_tok=cfg.num_experts_per_tok, + # kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), + # kernel_axes=('embed', 'mlp'), + # dtype=cfg.dtype, + # )(hidden_states, deterministic=deterministic) + + gate_logits = linears.DenseGeneral( + cfg.num_experts, + weight_dtype=cfg.weight_dtype, + dtype=cfg.dtype, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", "mlp"), + name="gate", + quant=self.quant, + )(hidden_states) + weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok) + weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) + mlp_lnx = jnp.zeros_like(hidden_states) + weights = weights.astype(cfg.dtype) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + + # TODO(ranran): have a better solution to remove the loop here + for k in range(cfg.num_experts): + weights_exp = jnp.sum(jnp.multiply(selected_experts == k, weights), axis=-1) + mlp_lnx_exp = linears.MlpBlock( intermediate_dim=cfg.mlp_dim, activations=cfg.mlp_activations, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name=f"mlp_{k}", config=cfg, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx_exp = nn.with_logical_constraint(mlp_lnx_exp, ("activation_batch", "activation_length", "activation_embed")) + mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp + mlp_lnx += mlp_lnx_exp + else: + mlp_lnx = linears.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, 
+ intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + )(hidden_states, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( - layer_output, ('activation_batch', 'activation_length', 'activation_embed'), + layer_output, + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 571a77a95..ff3c246b1 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -41,69 +41,69 @@ PositionalEmbedding = embeddings.PositionalEmbedding Quant = quantizations.AqtQuantization -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # The network: Decoder & Transformer Definitions -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ class DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh - inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_norm', + name="pre_self_attention_norm", epsilon=cfg.normalization_layer_epsilon, - kernel_axes=('embed',))(inputs) - lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + kernel_axes=("embed",), + )(inputs) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) attention_layer = Attention( - config = self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) - + 
config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) - attention_lnx = nn.with_logical_constraint( - attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) # MLP block. mlp_lnx = linears.MlpBlock( @@ -112,32 +112,30 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(lnx, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) next_layer_addition = mlp_lnx + attention_lnx - next_layer_addition_dropped_out = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,) - )(next_layer_addition, deterministic=deterministic) + next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition, deterministic=deterministic + ) layer_output = next_layer_addition_dropped_out + inputs layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) @@ -146,6 +144,7 @@ def __call__(self, class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" + config: Config shared_embedding: nn.Module mesh: Mesh @@ -156,16 +155,20 @@ def get_decoder_layer(self): return DecoderLayer elif self.config.decoder_block == "llama2": from layers import llama2 + return llama2.LlamaDecoderLayer elif self.config.decoder_block == "mistral": # TODO(ranran): update to Mistral with sliding window attention from layers import mistral + return mistral.MistralDecoderLayer elif self.config.decoder_block == "gemma": from layers import gemma + return gemma.GemmaDecoderLayer elif self.config.decoder_block == "gpt3": from layers import gpt3 + return gpt3.Gpt3DecoderLayer else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @@ -175,74 +178,89 @@ def get_norm_layer(self): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 + return 
functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=True) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") @nn.compact - def __call__(self, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=common_types.MODEL_MODE_TRAIN, - ): + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=common_types.MODEL_MODE_TRAIN, + ): cfg = self.config mesh = self.mesh assert decoder_input_tokens.ndim == 2 # [batch, len] # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype('int32')) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) + y = self.shared_embedding(decoder_input_tokens.astype("int32")) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) y = y.astype(cfg.dtype) if cfg.use_untrainable_positional_embedding: - y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) + y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) if cfg.trainable_position_size > 0: y += Embed( - num_embeddings=cfg.trainable_position_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name='position_embedder', - config=cfg)(decoder_positions) + num_embeddings=cfg.trainable_position_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name="position_embedder", + config=cfg, + )(decoder_positions) BlockLayer = self.get_decoder_layer() - if cfg.remat_policy != 'none': - if cfg.remat_policy == 'minimal': + if cfg.remat_policy != "none": + if cfg.remat_policy == "minimal": policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - elif cfg.remat_policy == 'save_dot_except_mlpwi': + elif cfg.remat_policy == "save_dot_except_mlpwi": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', 'out_proj', 'mlpwo', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", ) - elif cfg.remat_policy == 'save_dot_except_mlp': + elif cfg.remat_policy == "save_dot_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', 'out_proj', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", ) - elif cfg.remat_policy == 'save_qkv_proj': + elif cfg.remat_policy == "save_qkv_proj": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", ) - elif cfg.remat_policy == 'qkv_proj_offloaded': + elif cfg.remat_policy == "qkv_proj_offloaded": policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=['query_proj', 'value_proj', 'key_proj'], - offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == 'minimal_offloaded': + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == 'minimal_flash': + elif cfg.remat_policy == "minimal_flash": policy = 
jax.checkpoint_policies.save_from_both_policies( - jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - jax.checkpoint_policies.save_only_these_names('context',), + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + jax.checkpoint_policies.save_only_these_names( + "context", + ), ) else: - assert ( - cfg.remat_policy == 'full' - ), 'Remat policy needs to be on list of remat policies' + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" policy = None BlockLayer = nn.remat( # pylint: disable=invalid-name BlockLayer, @@ -251,23 +269,21 @@ def __call__(self, static_argnums=(-1, -2, -3, -4, -5), ) if cfg.scan_layers: - initializing = self.is_mutable_collection('params') - params_spec = ( - cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - ) + initializing = self.is_mutable_collection("params") + params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) cache_spec = 0 y, _ = nn.scan( BlockLayer, variable_axes={ - 'params': params_spec, - 'cache': cache_spec, - 'intermediates': 0, - 'aqt':0, - '_overwrite_with_gradient': 0, + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, }, split_rngs={ - 'params': True, - 'dropout': cfg.enable_dropout, + "params": True, + "dropout": cfg.enable_dropout, }, in_axes=( nn.broadcast, @@ -276,8 +292,8 @@ def __call__(self, nn.broadcast, ), length=cfg.num_decoder_layers, - metadata_params={nn.PARTITION_NAME: 'layers'}, - )(config=cfg, mesh=mesh, name='layers', quant=self.quant)( + metadata_params={nn.PARTITION_NAME: "layers"}, + )(config=cfg, mesh=mesh, name="layers", quant=self.quant)( y, decoder_segment_ids, decoder_positions, @@ -286,8 +302,7 @@ def __call__(self, ) else: for lyr in range(cfg.num_decoder_layers): - y = BlockLayer(config=cfg, mesh=mesh, name=f'layers_{lyr}', - quant=self.quant)( + y = BlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( y, decoder_segment_ids, decoder_positions, @@ -296,15 +311,13 @@ def __call__(self, ) y = self.get_norm_layer()( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='decoder_norm', - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=('embed',), - )(y) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic - ) + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="decoder_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("embed",), + )(y) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) # [batch, length, emb_dim] -> [batch, length, vocab_size] if cfg.logits_via_embedding: @@ -318,16 +331,19 @@ def __call__(self, cfg.vocab_size, weight_dtype=cfg.weight_dtype, dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=('embed', 'vocab'), - name='logits_dense')(y) # We do not quantize the logits matmul. - logits = nn.with_logical_constraint( - logits, ('activation_batch', 'activation_length', 'activation_vocab')) + kernel_axes=("embed", "vocab"), + name="logits_dense", + )( + y + ) # We do not quantize the logits matmul. 
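
(For reference: the remat_policy strings above select jax.checkpoint_policies that tell jax.remat which named activations to keep. A minimal sketch of the same mechanism outside MaxText, with a toy function and the "query_proj" tag standing in for the real projections; everything below is illustrative and not part of this patch:

    import functools
    import jax
    import jax.numpy as jnp

    # Keep only activations tagged "query_proj"; recompute everything else
    # on the backward pass, trading compute for activation memory.
    policy = jax.checkpoint_policies.save_only_these_names("query_proj")

    @functools.partial(jax.checkpoint, policy=policy)
    def block(x, w):
      q = jax.ad_checkpoint.checkpoint_name(x @ w, "query_proj")  # tag the matmul output
      return jnp.tanh(q)

Under jax.grad, q is saved as a residual and the rest of block is recomputed, which is how the save_dot_except_mlpwi, save_dot_except_mlp, and save_qkv_proj policies above behave for their tagged projections.)
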
+ logits = nn.with_logical_constraint(logits, ("activation_batch", "activation_length", "activation_vocab")) logits = logits.astype(jnp.float32) return logits class Transformer(nn.Module): """An decoder-only Transformer model.""" + # Make new attributes required, so that all Transformer dependencies (train, decode, compile, etc) will error instead of silently use defaults. # pylint: disable=attribute-defined-outside-init config: Config @@ -345,14 +361,11 @@ def setup(self): dtype=cfg.dtype, attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability embedding_init=nn.initializers.normal(stddev=1.0), - name='token_embedder', + name="token_embedder", config=cfg, ) - self.decoder = Decoder( - config=cfg, shared_embedding=self.shared_embedding, - mesh=mesh, quant=self.quant - ) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) def __call__( self, @@ -360,14 +373,15 @@ def __call__( decoder_positions, decoder_segment_ids=None, enable_dropout=True, - model_mode=common_types.MODEL_MODE_TRAIN + model_mode=common_types.MODEL_MODE_TRAIN, ): """Applies Transformer decoder-branch on encoded-input and target.""" if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: raise ValueError( - f'During autoregressive decoding we assume the tokens are in the active sequence' - f' which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}.') + f"During autoregressive decoding we assume the tokens are in the active sequence" + f" which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}." + ) logits = self.decoder( decoder_input_tokens=decoder_input_tokens, diff --git a/MaxText/layers/normalizations.py b/MaxText/layers/normalizations.py index 6d451d4fe..862c586c9 100644 --- a/MaxText/layers/normalizations.py +++ b/MaxText/layers/normalizations.py @@ -26,6 +26,7 @@ class RMSNorm(nn.Module): """RMS normalization.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -40,7 +41,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) scale = self.param( - 'scale', + "scale", nn.with_logical_partitioning(self.scale_init, self.kernel_axes), (features,), self.weight_dtype, diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index c6450ea17..dba7658bd 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -27,117 +27,119 @@ MAX_INT8 = 127.5 + @dataclass class Quantization: - """Base class for quantization configurations""" + """Base class for quantization configurations""" - def dot_general_cls(self): - """ Placeholder for dot_general implementation in subclasses. """ - pass + def dot_general_cls(self): + """Placeholder for dot_general implementation in subclasses.""" + pass @dataclass class AqtQuantization: - """ Configures AQT quantization github.com/google/aqt. """ + """Configures AQT quantization github.com/google/aqt.""" + quant_dg: aqt_config.DotGeneral quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN def dot_general_cls(self): - """ Returns dot_general configured with aqt params. 
""" - aqt_dg_cls = functools.partial( - aqt_flax.AqtDotGeneral, - self.quant_dg, - rhs_quant_mode=self.quant_mode - ) + """Returns dot_general configured with aqt params.""" + aqt_dg_cls = functools.partial(aqt_flax.AqtDotGeneral, self.quant_dg, rhs_quant_mode=self.quant_mode) return aqt_dg_cls def einsum(self): - """ Returns einsum configured with aqt params """ - aqt_einsum = functools.partial(aqt_flax.AqtEinsum( - cfg=self.quant_dg, - lhs_quant_mode=self.quant_mode - ) - ) + """Returns einsum configured with aqt params""" + aqt_einsum = functools.partial(aqt_flax.AqtEinsum(cfg=self.quant_dg, lhs_quant_mode=self.quant_mode)) return aqt_einsum + @dataclass class Fp8Quantization(Quantization): - """ Configures Fp8 quantization for NVIDIA GPUs""" + """Configures Fp8 quantization for NVIDIA GPUs""" + quant_mode = "train" def dot_general_cls(self): - """ Returns dot_general configured with aqt params. """ + """Returns dot_general configured with aqt params.""" return nn.Fp8DotGeneralOp + def _get_quant_config(config): """Set quantization params based on user configuration.""" - if not config.quantization or config.quantization == '': + if not config.quantization or config.quantization == "": return None elif config.quantization == "int8": if config.quantization_local_shard_count == 0: drhs_bits = None drhs_accumulator_dtype = None - drhs_local_aqt=None + drhs_local_aqt = None else: drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 drhs_local_aqt = aqt_config.LocalAqt(config.quantization_local_shard_count) return aqt_config.config_v3( - fwd_bits=8, - dlhs_bits=8, - drhs_bits=drhs_bits, - rng_type='jax.uniform', - dlhs_local_aqt=None, - drhs_local_aqt=drhs_local_aqt, - fwd_accumulator_dtype=jnp.int32, - dlhs_accumulator_dtype=jnp.int32, - drhs_accumulator_dtype=drhs_accumulator_dtype, + fwd_bits=8, + dlhs_bits=8, + drhs_bits=drhs_bits, + rng_type="jax.uniform", + dlhs_local_aqt=None, + drhs_local_aqt=drhs_local_aqt, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=drhs_accumulator_dtype, ) elif config.quantization == "fp8": return "fp8" else: - raise ValueError(f'Invalid value configured for quantization {config.quantization}.') + raise ValueError(f"Invalid value configured for quantization {config.quantization}.") + def in_convert_mode(quant): return quant and (quant.quant_mode == aqt_flax.QuantMode.CONVERT) + def in_serve_mode(quant): return quant and (quant.quant_mode == aqt_flax.QuantMode.SERVE) -def get_quant_mode(quant_mode_str: str = 'train'): - """ Set quant mode.""" - if quant_mode_str == 'train': + +def get_quant_mode(quant_mode_str: str = "train"): + """Set quant mode.""" + if quant_mode_str == "train": return aqt_flax.QuantMode.TRAIN - elif quant_mode_str == 'serve': + elif quant_mode_str == "serve": return aqt_flax.QuantMode.SERVE - elif quant_mode_str == 'convert': + elif quant_mode_str == "convert": return aqt_flax.QuantMode.CONVERT else: - raise ValueError(f'Invalid quantization mode {quant_mode_str}.') + raise ValueError(f"Invalid quantization mode {quant_mode_str}.") return None -def configure_quantization(config: Config, quant_mode_str: str = 'train'): - """ Configure quantization based on user config and quant mode.""" + +def configure_quantization(config: Config, quant_mode_str: str = "train"): + """Configure quantization based on user config and quant mode.""" quant_cfg = _get_quant_config(config) if quant_cfg: if quant_cfg == "fp8": - return Fp8Quantization() + return Fp8Quantization() quant_mode = 
get_quant_mode(quant_mode_str) return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode) return None + def _get_aqt_key_paths(aqt_vars): - """ Generate a list of paths which have aqt state """ + """Generate a list of paths which have aqt state""" aqt_tree_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars) aqt_key_paths = [] for k, _ in aqt_tree_flat: pruned_keys = [] for d in list(k): - if 'AqtDotGeneral' in d.key: - pruned_keys.append(jax.tree_util.DictKey(key='kernel')) + if "AqtDotGeneral" in d.key: + pruned_keys.append(jax.tree_util.DictKey(key="kernel")) break else: - assert 'Aqt' not in d.key, f"Unexpected Aqt op {d.key} in {k}." + assert "Aqt" not in d.key, f"Unexpected Aqt op {d.key} in {k}." pruned_keys.append(d) aqt_key_paths.append(tuple(pruned_keys)) return aqt_key_paths @@ -153,16 +155,19 @@ def remove_quantized_params(params, aqt_vars): tree_flat[i] = v return tree_unflatten(tree_struct, tree_flat) + def configure_kv_quantization(config: Config): - """ Configure kv quantization based on user config.""" + """Configure kv quantization based on user config.""" return False if not config.quantize_kvcache else True + def quantize_kv(kv: Array): """Quantize key/values stored in kvcache.""" scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True) value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale))) return value, scale -def unquantize_kv(value: Array, scale:Array, dtype:jnp.dtype): + +def unquantize_kv(value: Array, scale: Array, dtype: jnp.dtype): """Unquantize key/values stored in kvcache.""" return value.astype(dtype) * scale / MAX_INT8 diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index b9a7aefe4..97a456764 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" r"""Convert weights from a Llama or Mistral model to a MaxText one. 
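
(For reference: the quantize_kv / unquantize_kv helpers in the quantizations.py hunk above implement symmetric per-position int8 quantization, with MAX_INT8 = 127.5 as the half-step scale. A small round-trip sketch; the values are illustrative:

    import jax.numpy as jnp

    MAX_INT8 = 127.5
    kv = jnp.array([[0.5, -1.0, 2.0]], dtype=jnp.float32)
    scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True)   # [[2.0]]
    value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale)))    # 0.5 -> 32, -1.0 -> -64
    restored = value.astype(jnp.float32) * scale / MAX_INT8

Each element of restored is within roughly scale / (2 * MAX_INT8) of the original, the precision given up for storing the KV cache at one byte per value.)
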
@@ -42,54 +42,55 @@ import sys import os -jax.config.update('jax_platform_name', 'cpu') +jax.config.update("jax_platform_name", "cpu") + def permute_to_match_maxtext_rope(arr): evens = arr[..., ::2] odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1) + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) MODEL_PARAMS_DICT = { - 'llama2-70b': { - 'num_layers': 80, - 'num_heads': 64, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-70b": { + "num_layers": 80, + "num_heads": 64, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, }, - 'llama2-13b': { - 'num_layers': 40, - 'num_heads': 40, - 'num_kv_heads': 40, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-13b": { + "num_layers": 40, + "num_heads": 40, + "num_kv_heads": 40, + "dims_per_head": 128, + "vocab": 32000, }, - 'llama2-7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 32, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 32, + "dims_per_head": 128, + "vocab": 32000, }, - 'mistral-7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, - 'base_emb_dim': 4096, - 'base_mlp_dim': 14336, + "mistral-7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, + "base_emb_dim": 4096, + "base_mlp_dim": 14336, }, - 'mixtral-8x7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, - 'base_emb_dim': 4096, - 'base_mlp_dim': 14336, - 'num_experts': 8, + "mixtral-8x7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, + "base_emb_dim": 4096, + "base_mlp_dim": 14336, + "num_experts": 8, }, } @@ -108,256 +109,213 @@ def convert(base_model_path, maxtext_model_path, model_size): """ """Convert model to maxtext.""" model_params = MODEL_PARAMS_DICT[model_size] - base_num_decoder_layers = model_params['num_layers'] - base_num_query_heads = model_params['num_heads'] - head_dim = model_params['dims_per_head'] - base_num_kv_heads = model_params['num_kv_heads'] - vocab_size = model_params['vocab'] - num_experts = model_params['num_experts'] if 'num_experts' in model_params else None - - print(f'Loading the base model from {base_model_path}') + base_num_decoder_layers = model_params["num_layers"] + base_num_query_heads = model_params["num_heads"] + head_dim = model_params["dims_per_head"] + base_num_kv_heads = model_params["num_kv_heads"] + vocab_size = model_params["vocab"] + num_experts = model_params["num_experts"] if "num_experts" in model_params else None + + print(f"Loading the base model from {base_model_path}") # Skip any hidden files for checkpoints - ckpt_paths = sorted(pathlib.Path(base_model_path).glob('[!.]*.pth')) + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth")) pytorch_vars = {} for i, ckpt_path in enumerate(ckpt_paths): - print(f'Loading checkpoint {i+1} of {len(ckpt_paths)} ...') - checkpoint = torch.load(ckpt_path, map_location='cpu') - pytorch_vars[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint + print(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") + checkpoint = torch.load(ckpt_path, map_location="cpu") + pytorch_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint pytorch_vars = [pytorch_vars[i] for i in sorted(list(pytorch_vars.keys()))] - layer_key = 'gate' if num_experts else 'mlp' + layer_key = "gate" if num_experts 
else "mlp" jax_weights = { - 'decoder': { - 'layers': { + "decoder": { + "layers": { layer_key: {}, - 'pre_self_attention_layer_norm': {}, - 'post_self_attention_layer_norm': {}, - 'self_attention': {}, + "pre_self_attention_layer_norm": {}, + "post_self_attention_layer_norm": {}, + "self_attention": {}, }, - 'decoder_norm': { - 'scale': pytorch_vars[0]['norm.weight'].type(torch.float16).numpy() + "decoder_norm": {"scale": pytorch_vars[0]["norm.weight"].type(torch.float16).numpy()}, + "logits_dense": { + "kernel": np.concatenate( + [var["output.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose()[:, :vocab_size] }, - 'logits_dense': { - 'kernel': np.concatenate([var['output.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose()[:, :vocab_size] - } }, - 'token_embedder': { - 'embedding': np.concatenate([var['tok_embeddings.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1)[:vocab_size, :] - - } - - } - - layer_weight = { - 'pre_self_attention_layer_norm': { - 'scale': [] + "token_embedder": { + "embedding": np.concatenate( + [var["tok_embeddings.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=1 + )[:vocab_size, :] }, - 'post_self_attention_layer_norm': { - 'scale': [] - } } + layer_weight = {"pre_self_attention_layer_norm": {"scale": []}, "post_self_attention_layer_norm": {"scale": []}} + if num_experts is None: - layer_weight['mlp'] = { - 'wi_0': { - 'kernel': [] - }, - 'wi_1': { - 'kernel': [] - }, - 'wo': { - 'kernel': [] - }, + layer_weight["mlp"] = { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, } else: - layer_weight['gate'] = { - 'kernel': [] - } + layer_weight["gate"] = {"kernel": []} for k in range(num_experts): - jax_weights['decoder']['layers'][f'mlp_{k}'] = {} - layer_weight[f'mlp_{k}'] = { - 'wi_0': { - 'kernel': [] - }, - 'wi_1': { - 'kernel': [] - }, - 'wo': { - 'kernel': [] - }, + jax_weights["decoder"]["layers"][f"mlp_{k}"] = {} + layer_weight[f"mlp_{k}"] = { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, } self_attention = { - 'query': { - 'kernel': [] - }, - 'key': { - 'kernel': [] - }, - 'value': { - 'kernel': [] - }, - 'out': { - 'kernel': [] - }, + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, } for layer_idx in range(base_num_decoder_layers): - wq = np.concatenate([var[f'layers.{layer_idx}.attention.wq.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wk = np.concatenate([var[f'layers.{layer_idx}.attention.wk.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wv = np.concatenate([var[f'layers.{layer_idx}.attention.wv.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - - wq = np.reshape(wq, [base_num_query_heads * head_dim, - base_num_query_heads, head_dim]) - wk = np.reshape(wk, [base_num_query_heads * head_dim, - base_num_kv_heads, head_dim]) - wv = np.reshape(wv, [base_num_query_heads * head_dim, - base_num_kv_heads, head_dim]) + wq = np.concatenate( + [var[f"layers.{layer_idx}.attention.wq.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wk = np.concatenate( + [var[f"layers.{layer_idx}.attention.wk.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wv = np.concatenate( + [var[f"layers.{layer_idx}.attention.wv.weight"].type(torch.float16).numpy() for var in 
pytorch_vars], axis=0 + ).transpose() + + wq = np.reshape(wq, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) + wk = np.reshape(wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim]) + wv = np.reshape(wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim]) wq = permute_to_match_maxtext_rope(wq) wk = permute_to_match_maxtext_rope(wk) w_post = np.concatenate( - [ - var[f'layers.{layer_idx}.attention.wo.weight'].type( - torch.float16).numpy() - for var in pytorch_vars - ], + [var[f"layers.{layer_idx}.attention.wo.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=1, ) - w_post = np.reshape( - w_post, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) - - self_attention['query']['kernel'].append(wq) - self_attention['key']['kernel'].append(wk) - self_attention['value']['kernel'].append(wv) - self_attention['out']['kernel'].append(w_post) - pre_self_attention_layernorm = pytorch_vars[0][f'layers.{layer_idx}.attention_norm.weight'].type( - torch.float16).numpy() - post_self_attention_layernorm = pytorch_vars[0][f'layers.{layer_idx}.ffn_norm.weight'].type( - torch.float16).numpy() - layer_weight['pre_self_attention_layer_norm']['scale'].append( - pre_self_attention_layernorm) - layer_weight['post_self_attention_layer_norm']['scale'].append( - post_self_attention_layernorm) + w_post = np.reshape(w_post, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) + + self_attention["query"]["kernel"].append(wq) + self_attention["key"]["kernel"].append(wk) + self_attention["value"]["kernel"].append(wv) + self_attention["out"]["kernel"].append(w_post) + pre_self_attention_layernorm = pytorch_vars[0][f"layers.{layer_idx}.attention_norm.weight"].type(torch.float16).numpy() + post_self_attention_layernorm = pytorch_vars[0][f"layers.{layer_idx}.ffn_norm.weight"].type(torch.float16).numpy() + layer_weight["pre_self_attention_layer_norm"]["scale"].append(pre_self_attention_layernorm) + layer_weight["post_self_attention_layer_norm"]["scale"].append(post_self_attention_layernorm) if num_experts is None: - wi_0 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w1.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wi_1 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w3.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wo = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w2.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1).transpose() - layer_weight['mlp']['wi_0']['kernel'].append(wi_0) - layer_weight['mlp']['wi_1']['kernel'].append(wi_1) - layer_weight['mlp']['wo']['kernel'].append(wo) + wi_0 = np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.w1.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wi_1 = np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.w3.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + wo = np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.w2.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=1 + ).transpose() + layer_weight["mlp"]["wi_0"]["kernel"].append(wi_0) + layer_weight["mlp"]["wi_1"]["kernel"].append(wi_1) + layer_weight["mlp"]["wo"]["kernel"].append(wo) else: - gate = np.concatenate([var[f'layers.{layer_idx}.feed_forward.gate.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - layer_weight['gate']['kernel'].append(gate) + gate = 
np.concatenate( + [var[f"layers.{layer_idx}.feed_forward.gate.weight"].type(torch.float16).numpy() for var in pytorch_vars], axis=0 + ).transpose() + layer_weight["gate"]["kernel"].append(gate) for k in range(num_experts): - wi_0 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w1.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wi_1 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w3.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wo = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w2.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1).transpose() - layer_weight[f'mlp_{k}']['wi_0']['kernel'].append(wi_0) - layer_weight[f'mlp_{k}']['wi_1']['kernel'].append(wi_1) - layer_weight[f'mlp_{k}']['wo']['kernel'].append(wo) - - self_attention['query']['kernel'] = np.array( - self_attention['query']['kernel']) - self_attention['key']['kernel'] = np.array(self_attention['key']['kernel']) - self_attention['value']['kernel'] = np.array( - self_attention['value']['kernel']) - self_attention['out']['kernel'] = np.array(self_attention['out']['kernel']) - self_attention['query']['kernel'] = np.transpose( - self_attention['query']['kernel'], axes=(1, 0, 2, 3)) - self_attention['key']['kernel'] = np.transpose( - self_attention['key']['kernel'], axes=(1, 0, 2, 3)) - self_attention['value']['kernel'] = np.transpose( - self_attention['value']['kernel'], axes=(1, 0, 2, 3)) + wi_0 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w1.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wi_1 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w3.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wo = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w2.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=1, + ).transpose() + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"].append(wi_0) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"].append(wi_1) + layer_weight[f"mlp_{k}"]["wo"]["kernel"].append(wo) + + self_attention["query"]["kernel"] = np.array(self_attention["query"]["kernel"]) + self_attention["key"]["kernel"] = np.array(self_attention["key"]["kernel"]) + self_attention["value"]["kernel"] = np.array(self_attention["value"]["kernel"]) + self_attention["out"]["kernel"] = np.array(self_attention["out"]["kernel"]) + self_attention["query"]["kernel"] = np.transpose(self_attention["query"]["kernel"], axes=(1, 0, 2, 3)) + self_attention["key"]["kernel"] = np.transpose(self_attention["key"]["kernel"], axes=(1, 0, 2, 3)) + self_attention["value"]["kernel"] = np.transpose(self_attention["value"]["kernel"], axes=(1, 0, 2, 3)) # layers, base_num_query_heads * head_dim, base_num_query_heads, head_dim => # base_num_query_heads, layers,head_dim, base_num_query_heads * head_dim - self_attention['out']['kernel'] = np.transpose( - self_attention['out']['kernel'], axes=(2, 0, 3, 1)) + self_attention["out"]["kernel"] = np.transpose(self_attention["out"]["kernel"], axes=(2, 0, 3, 1)) # scale the query weights - self_attention['query']['kernel'] = self_attention['query']['kernel'] / \ - np.sqrt(head_dim) + self_attention["query"]["kernel"] = self_attention["query"]["kernel"] / np.sqrt(head_dim) - jax_weights['decoder']['layers']['self_attention'] = self_attention + jax_weights["decoder"]["layers"]["self_attention"] 
= self_attention # self attention layer norm and swap the layer index - layer_weight['pre_self_attention_layer_norm']['scale'] = np.array( - layer_weight['pre_self_attention_layer_norm']['scale']) - layer_weight['post_self_attention_layer_norm']['scale'] = np.array( - layer_weight['post_self_attention_layer_norm']['scale']) - layer_weight['pre_self_attention_layer_norm']['scale'] = np.transpose( - layer_weight['pre_self_attention_layer_norm']['scale'], - axes=(1, 0)) - layer_weight['post_self_attention_layer_norm']['scale'] = np.transpose( - layer_weight['post_self_attention_layer_norm']['scale'], - axes=(1, 0)) - - jax_weights['decoder']['layers']['pre_self_attention_layer_norm'] = layer_weight['pre_self_attention_layer_norm'] - jax_weights['decoder']['layers']['post_self_attention_layer_norm'] = layer_weight['post_self_attention_layer_norm'] + layer_weight["pre_self_attention_layer_norm"]["scale"] = np.array(layer_weight["pre_self_attention_layer_norm"]["scale"]) + layer_weight["post_self_attention_layer_norm"]["scale"] = np.array(layer_weight["post_self_attention_layer_norm"]["scale"]) + layer_weight["pre_self_attention_layer_norm"]["scale"] = np.transpose( + layer_weight["pre_self_attention_layer_norm"]["scale"], axes=(1, 0) + ) + layer_weight["post_self_attention_layer_norm"]["scale"] = np.transpose( + layer_weight["post_self_attention_layer_norm"]["scale"], axes=(1, 0) + ) + + jax_weights["decoder"]["layers"]["pre_self_attention_layer_norm"] = layer_weight["pre_self_attention_layer_norm"] + jax_weights["decoder"]["layers"]["post_self_attention_layer_norm"] = layer_weight["post_self_attention_layer_norm"] if num_experts is None: - layer_weight['mlp']['wi_0']['kernel'] = np.array( - layer_weight['mlp']['wi_0']['kernel']) - layer_weight['mlp']['wi_1']['kernel'] = np.array( - layer_weight['mlp']['wi_1']['kernel']) - layer_weight['mlp']['wo']['kernel'] = np.array( - layer_weight['mlp']['wo']['kernel']) + layer_weight["mlp"]["wi_0"]["kernel"] = np.array(layer_weight["mlp"]["wi_0"]["kernel"]) + layer_weight["mlp"]["wi_1"]["kernel"] = np.array(layer_weight["mlp"]["wi_1"]["kernel"]) + layer_weight["mlp"]["wo"]["kernel"] = np.array(layer_weight["mlp"]["wo"]["kernel"]) # swap the layer index - layer_weight['mlp']['wi_0']['kernel'] = np.transpose( - layer_weight['mlp']['wi_0']['kernel'], axes=(1, 0, 2)) - layer_weight['mlp']['wi_1']['kernel'] = np.transpose( - layer_weight['mlp']['wi_1']['kernel'], axes=(1, 0, 2)) - layer_weight['mlp']['wo']['kernel'] = np.transpose( - layer_weight['mlp']['wo']['kernel'], axes=(1, 0, 2)) - - jax_weights['decoder']['layers']['mlp'] = layer_weight['mlp'] + layer_weight["mlp"]["wi_0"]["kernel"] = np.transpose(layer_weight["mlp"]["wi_0"]["kernel"], axes=(1, 0, 2)) + layer_weight["mlp"]["wi_1"]["kernel"] = np.transpose(layer_weight["mlp"]["wi_1"]["kernel"], axes=(1, 0, 2)) + layer_weight["mlp"]["wo"]["kernel"] = np.transpose(layer_weight["mlp"]["wo"]["kernel"], axes=(1, 0, 2)) + + jax_weights["decoder"]["layers"]["mlp"] = layer_weight["mlp"] else: - layer_weight['gate']['kernel'] = np.array(layer_weight['gate']['kernel']) - layer_weight['gate']['kernel'] = np.transpose( - layer_weight['gate']['kernel'], axes=(1, 0, 2)) - jax_weights['decoder']['layers']['gate'] = layer_weight['gate'] + layer_weight["gate"]["kernel"] = np.array(layer_weight["gate"]["kernel"]) + layer_weight["gate"]["kernel"] = np.transpose(layer_weight["gate"]["kernel"], axes=(1, 0, 2)) + jax_weights["decoder"]["layers"]["gate"] = layer_weight["gate"] for k in range(num_experts): - 
layer_weight[f'mlp_{k}']['wi_0']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wi_0']['kernel']) - layer_weight[f'mlp_{k}']['wi_1']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wi_1']['kernel']) - layer_weight[f'mlp_{k}']['wo']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wo']['kernel']) + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"]) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"]) + layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.array(layer_weight[f"mlp_{k}"]["wo"]["kernel"]) # swap the layer index - layer_weight[f'mlp_{k}']['wi_0']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wi_0']['kernel'], axes=(1, 0, 2)) - layer_weight[f'mlp_{k}']['wi_1']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wi_1']['kernel'], axes=(1, 0, 2)) - layer_weight[f'mlp_{k}']['wo']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wo']['kernel'], axes=(1, 0, 2)) + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_0"]["kernel"], axes=(1, 0, 2)) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wi_1"]["kernel"], axes=(1, 0, 2)) + layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.transpose(layer_weight[f"mlp_{k}"]["wo"]["kernel"], axes=(1, 0, 2)) - jax_weights['decoder']['layers'][f'mlp_{k}'] = layer_weight[f'mlp_{k}'] + jax_weights["decoder"]["layers"][f"mlp_{k}"] = layer_weight[f"mlp_{k}"] mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis") - s1=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec("checkpoint_sharding_axis")) #shards first axis - s2=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec(None,"checkpoint_sharding_axis")) #shards second axis - s3=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec(None)) #no sharding + s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis")) # shards first axis + s2 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis")) # shards second axis + s3 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) # no sharding def checkpoint_device_put(arr): - if arr.shape[0]%SIMULATED_CPU_DEVICES_COUNT==0: + if arr.shape[0] % SIMULATED_CPU_DEVICES_COUNT == 0: print("sharding first axis") return jax.device_put(arr, device=s1) - elif len(arr.shape)>1 and arr.shape[1]%SIMULATED_CPU_DEVICES_COUNT==0: + elif len(arr.shape) > 1 and arr.shape[1] % SIMULATED_CPU_DEVICES_COUNT == 0: print("sharding second axis") return jax.device_put(arr, device=s2) else: @@ -374,41 +332,33 @@ def checkpoint_device_put(arr): save_interval_steps = 1 checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - maxtext_model_path, - enable_checkpointing, - async_checkpointing, - save_interval_steps + maxtext_model_path, enable_checkpointing, async_checkpointing, save_interval_steps ) state_new = train_state.TrainState( - step=0, - apply_fn=None, - params={'params': jax_weights}, - tx=None, # type: ignore - opt_state={} + step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore ) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new): - max_logging.log( - f"saved a checkpoint at step {step_number_to_save_new_ckpt}") + max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}") # Upon preemption, exit when and only when all ongoing saves are complete. 
if checkpoint_manager.reached_preemption(0): checkpoint_manager.wait_until_finished() sys.exit() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--base-model-path', type=str, required=True) - parser.add_argument('--maxtext-model-path', type=str, required=True) - parser.add_argument('--model-size', type=str, required=True) + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--maxtext-model-path", type=str, required=True) + parser.add_argument("--model-size", type=str, required=True) args = parser.parse_args() if args.model_size not in MODEL_PARAMS_DICT: raise NotImplementedError - os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}' + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}" convert(args.base_model_path, args.maxtext_model_path, args.model_size) diff --git a/MaxText/max_logging.py b/MaxText/max_logging.py index 61d984f7b..23f7cc5cf 100644 --- a/MaxText/max_logging.py +++ b/MaxText/max_logging.py @@ -1,20 +1,21 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Stub for logging utilities. Right now just meant to avoid raw prints""" + def log(user_str): - print(user_str, flush = True) + print(user_str, flush=True) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 92b789bb4..2e218c48a 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Common Max Utils needed by multiple modules""" import checkpointing @@ -44,17 +44,19 @@ from google.cloud import storage + def find_nans_and_infs(pytree): def finder(x): return jnp.any(jnp.isinf(x) | jnp.isnan(x)) + bad_pytree = jax.tree_map(finder, pytree) return jax.tree_util.tree_flatten(bad_pytree) + def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jax.tree_util.tree_reduce( - lambda x, y: x + jax.numpy.sum(jax.numpy.square(y)), x, initializer=0.0 - ) ** 0.5 + return jax.tree_util.tree_reduce(lambda x, y: x + jax.numpy.sum(jax.numpy.square(y)), x, initializer=0.0) ** 0.5 + def calculate_num_params_from_pytree(params): params_sizes = jax.tree_util.tree_map(jax.numpy.size, params) @@ -68,70 +70,78 @@ def calculate_leaf_params_per_chip(arr): shard = arr.addressable_shards[0] return np.prod(shard.data.shape) - params_sizes_per_chip = jax.tree_util.tree_map( - calculate_leaf_params_per_chip, params) - total_parameters_per_chip = jax.tree_util.tree_reduce( - lambda x, y: x + y, params_sizes_per_chip) + params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params) + total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip) return total_parameters_per_chip def calculate_bytes_from_pytree(params): - params_bytes = jax.tree_util.tree_map(lambda x : x.nbytes, params) + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) return total_bytes + def summarize_size_from_pytree(params): num_params = calculate_num_params_from_pytree(params) num_bytes = calculate_bytes_from_pytree(params) - return num_params, num_bytes, num_bytes/num_params + return num_params, num_bytes, num_bytes / num_params + def activate_profiler(config, optional_postfix=""): if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0): output_path = os.path.join(config.tensorboard_dir, optional_postfix) jax.profiler.start_trace(output_path) + def deactivate_profiler(config): if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0): jax.profiler.stop_trace() + def initialize_summary_writer(config): return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None + def close_summary_writer(summary_writer): if jax.process_index() == 0: summary_writer.close() + def _prepare_metrics_for_json(metrics, step, run_name): """Converts metric dictionary into json supported types (e.g. 
float)""" metrics_dict = {} - for val in metrics['scalar']: - metrics_dict[val] = float(metrics['scalar'][val]) - metrics_dict['step'] = float(step) - metrics_dict['run_name'] = run_name + for val in metrics["scalar"]: + metrics_dict[val] = float(metrics["scalar"][val]) + metrics_dict["step"] = float(step) + metrics_dict["run_name"] = run_name return metrics_dict + def write_metrics_locally(metrics, step, config, file): """Writes metrics locally for testing""" if step == 0: file.truncate(0) metrics_dict = _prepare_metrics_for_json(metrics, step, config.run_name) - file.write(str(json.dumps(metrics_dict))+'\n') + file.write(str(json.dumps(metrics_dict)) + "\n") if step == config.steps - 1: file.close() + def add_config_to_summary_writer(config, summary_writer): """Writes config params to tensorboard""" if jax.process_index() == 0: for key, value in config.get_keys().items(): add_text_to_summary_writer(key, str(value), summary_writer) + def add_text_to_summary_writer(key, value, summary_writer): """Writes given key-value pair to tensorboard as text/summary""" if jax.process_index() == 0: summary_writer.add_text(key, value) + def write_metrics_for_gcs(metrics, step, config, running_metrics): """Writes metrics to gcs""" metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) @@ -139,18 +149,19 @@ def write_metrics_for_gcs(metrics, step, config, running_metrics): if (step + 1) % config.log_period == 0 or step == config.steps - 1: start_step = (step // config.log_period) * config.log_period metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" - with open(metrics_filename, 'w', encoding="utf8") as metrics_for_gcs: + with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: for metrics_step in running_metrics: - metrics_for_gcs.write(str(json.dumps(metrics_step))+'\n') + metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n") metrics_for_gcs.close() - gcs_filename=os.path.join(config.metrics_dir, metrics_filename) + gcs_filename = os.path.join(config.metrics_dir, metrics_filename) max_logging.log(f"Moving file {metrics_filename} to GCS...") upload_blob(gcs_filename, metrics_filename) max_logging.log(f"File {metrics_filename} moved successfully!") - running_metrics = [] # reset running_metrics to empty list + running_metrics = [] # reset running_metrics to empty list return running_metrics + def write_config_raw_keys_for_gcs(raw_keys): """Writes config raw keys to GCS""" if not raw_keys["save_config_to_gcs"] or jax.process_index() != 0: @@ -159,21 +170,23 @@ def write_config_raw_keys_for_gcs(raw_keys): raw_keys_dict = dict(raw_keys) filename = "config.yml" - with open(filename, 'w', encoding="utf8") as config_for_gcs: + with open(filename, "w", encoding="utf8") as config_for_gcs: yaml.dump(raw_keys_dict, config_for_gcs) config_for_gcs.close() - gcs_filename=os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], filename) + gcs_filename = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], filename) max_logging.log(f"Moving file {filename} to GCS...") upload_blob(gcs_filename, filename) max_logging.log(f"File {filename} moved successfully!") + def parse_gcs_bucket_and_prefix(destination_gcs_name): path_parts = destination_gcs_name.replace("gs://", "").split("/") bucket = path_parts.pop(0) key = "/".join(path_parts) return bucket, key + def upload_blob(destination_gcs_name, source_file_name): """Uploads a file to a GCS location""" bucket_name, prefix_name = parse_gcs_bucket_and_prefix(destination_gcs_name) @@ 
-182,16 +195,18 @@ def upload_blob(destination_gcs_name, source_file_name): blob = bucket.blob(prefix_name) blob.upload_from_filename(source_file_name) + def maybe_initialize_jax_distributed_system(raw_keys): - """ The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of - indirection in MaxText to avoid breaking the call sites unnecessarily. + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. - Currently jax.distributed.initialize() fully works as expected! + Currently jax.distributed.initialize() fully works as expected! - For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ - if (raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] - and raw_keys["compile_topology_num_slices"]==-1) or raw_keys["hardware"]=='gpu_multiprocess': + if ( + raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1 + ) or raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") jax.distributed.initialize() max_logging.log("Jax distributed system initialized!") @@ -204,6 +219,7 @@ def maybe_initialize_jax_distributed_system(raw_keys): initialize_jax_for_cpu() max_logging.log("Jax distributed system initialized on CPUs!") + def initialize_jax_for_gpu(): """Jax distributed initialize for GPUs.""" if os.environ.get("JAX_COORDINATOR_IP") is not None: @@ -212,14 +228,15 @@ def initialize_jax_for_gpu(): jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", num_processes=int(os.getenv("NNODES")), - process_id=int(os.getenv("NODE_RANK"))) + process_id=int(os.getenv("NODE_RANK")), + ) max_logging.log(f"JAX global devices: {jax.devices()}") + def initialize_jax_for_cpu(): - """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready. - """ + """Jax distributed initialize for CPUs. 
Includes retries until the coordinator is ready."""
   coordinator_ip_address = get_coordinator_ip_address()
-  coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK
+  coordinator_address = coordinator_ip_address + ":1234"  # JAX coordinator port used in XPK
   # Env variables to be set in XPK or otherwise
   job_index = int(os.environ.get("JOB_INDEX"))
   job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX"))
@@ -227,17 +244,20 @@
   pid = job_index * processes_in_job + job_completion_index
   max_logging.log(f" Jax process id is {pid} ")
   # Explicit initialize is needed only for CPUs
-  jax.distributed.initialize(coordinator_address=coordinator_address,
-                             process_id=pid,
-                             num_processes=int(os.environ.get("JAX_PROCESS_COUNT")))
+  jax.distributed.initialize(
+      coordinator_address=coordinator_address, process_id=pid, num_processes=int(os.environ.get("JAX_PROCESS_COUNT"))
+  )
+
 def is_cpu_backend(raw_keys):
   """Determine whether Maxtext is intended to run on a CPU backend."""
-  return raw_keys["hardware"] == 'cpu'
+  return raw_keys["hardware"] == "cpu"
+
 def is_gpu_backend(raw_keys):
   """Determine whether Maxtext is intended to run on a GPU backend."""
-  return raw_keys["hardware"] == 'gpu'
+  return raw_keys["hardware"] == "gpu"
+
 def get_coordinator_ip_address():
   """Get coordinator IP Address with retries"""
@@ -260,48 +280,66 @@ def get_coordinator_ip_address():
   max_logging.log(f"Coordinator IP address: {coordinator_ip_address}")
   return coordinator_ip_address
+
 def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type):
   """Evaluates unspecified DCN/ICI parallelism values"""
   if -1 in parallelism_vals:
-    assert parallelism_vals.count(-1) == 1, f"Found unspecified values (-1) for more than one {parallelism_type}\
+    assert (
+        parallelism_vals.count(-1) == 1
+    ), f"Found unspecified values (-1) for more than one {parallelism_type}\
       parallelism axis. At most one axis can be unspecified."
-    determined_val = target_product/np.product(parallelism_vals)*-1
+    determined_val = target_product / np.product(parallelism_vals) * -1
-    assert determined_val >= 1 and determined_val.is_integer, f"Unspecified value unable to be determined with the given\
+    assert (
+        determined_val >= 1 and determined_val.is_integer()
+    ), f"Unspecified value unable to be determined with the given\
       {parallelism_type} parallelism values"
     parallelism_vals[parallelism_vals.index(-1)] = int(determined_val)
-  target_type = "slices" if parallelism_type == 'DCN' else "devices per slice"
+  target_type = "slices" if parallelism_type == "DCN" else "devices per slice"
-  assert np.product(parallelism_vals) == target_product, f"Number of {target_type} {target_product} does not match\
+  assert (
+      np.product(parallelism_vals) == target_product
+  ), f"Number of {target_type} {target_product} does not match\
       the product of the {parallelism_type} parallelism {np.product(parallelism_vals)}"
   return parallelism_vals
+
 def create_device_mesh(config, devices=None):
-  """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas """
+  """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas"""
   if devices is None:
     devices = jax.devices()
   num_devices = len(devices)
   num_slices = config.num_slices
-  num_devices_per_slice = num_devices//num_slices
+  num_devices_per_slice = num_devices // num_slices
   multi_slice_env = num_slices > 1
-  dcn_parallelism = [config.dcn_data_parallelism, config.dcn_fsdp_parallelism,
-                     config.dcn_fsdp_transpose_parallelism, config.dcn_sequence_parallelism,
-                     config.dcn_tensor_parallelism, config.dcn_autoregressive_parallelism]
-  ici_parallelism = [config.ici_data_parallelism, config.ici_fsdp_parallelism,
-                     config.ici_fsdp_transpose_parallelism, config.ici_sequence_parallelism,
-                     config.ici_tensor_parallelism, config.ici_autoregressive_parallelism]
+  dcn_parallelism = [
+      config.dcn_data_parallelism,
+      config.dcn_fsdp_parallelism,
+      config.dcn_fsdp_transpose_parallelism,
+      config.dcn_sequence_parallelism,
+      config.dcn_tensor_parallelism,
+      config.dcn_autoregressive_parallelism,
+  ]
+  ici_parallelism = [
+      config.ici_data_parallelism,
+      config.ici_fsdp_parallelism,
+      config.ici_fsdp_transpose_parallelism,
+      config.ici_sequence_parallelism,
+      config.ici_tensor_parallelism,
+      config.ici_autoregressive_parallelism,
+  ]
   # Find possible unspecified parallelisms
-  ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, 'ICI')
+  ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
   if multi_slice_env:
-    dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, 'DCN')
+    dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN")
     mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices)
   else:
    mesh = mesh_utils.create_device_mesh(ici_parallelism, devices)
@@ -310,41 +348,35 @@
   return mesh

-def unbox_logicallypartioned(
-    boxed_pytree):
-  """ Unboxes the flax.LogicallyPartitioned pieces
-    Args:
-      boxed_pytree: a pytree that includes LogicallyPartitioned
-      leaves.
-    Returns:
-      a pytree where all all LogicallyPartitioned leaves have been unboxed.
+def unbox_logicallypartioned(boxed_pytree):
+  """Unboxes the flax.LogicallyPartitioned pieces
+
+  Args:
+    boxed_pytree: a pytree that includes LogicallyPartitioned
+    leaves.
+  Returns:
+    a pytree where all LogicallyPartitioned leaves have been unboxed. 
""" - return jax.tree_util.tree_map(lambda x: x.unbox() if \ - isinstance(x, flax.linen.spmd.LogicallyPartitioned) \ - else x, boxed_pytree, \ - is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned)) + return jax.tree_util.tree_map( + lambda x: x.unbox() if isinstance(x, flax.linen.spmd.LogicallyPartitioned) else x, + boxed_pytree, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + def init_decode_state(apply_fn, params): """Init train state with null opt state for decode.""" - state = train_state.TrainState( - step=0, - apply_fn=apply_fn, - params=params, - tx=None, # type: ignore - opt_state={} - ) + state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state + def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create( - apply_fn=apply_fn, - params=params, - tx=tx - ) + state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state + def init_initial_state(model, tx, config, is_training, key): """ We pass in "static" objects like model, tx, config as JAX compares them by @@ -353,34 +385,37 @@ def init_initial_state(model, tx, config, is_training, key): Args: model, tx, config, is_training, key """ - input_shape = ( - config.global_batch_size_to_load, - config.max_target_length + input_shape = (config.global_batch_size_to_load, config.max_target_length) + model_vars = model.init( + {"params": key, "dropout": key, "aqt": key}, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(input_shape, dtype=jnp.int32), ) - model_vars = model.init({'params': key, 'dropout': key, 'aqt': key}, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(input_shape, dtype=jnp.int32)) if is_training: return init_training_state(model.apply, model_vars, tx) return init_decode_state(model.apply, model_vars) + def load_decode_model_vars(model, config, rng, mesh): state, _ = setup_decode_state(model, config, rng, mesh, None) return state.params + def setup_decode_state(model, config, rng, mesh, checkpoint_manager): is_training = False - state, state_mesh_annotations, _ = setup_initial_state(model, None, None, config, - rng, mesh, checkpoint_manager, - is_training) + state, state_mesh_annotations, _ = setup_initial_state( + model, None, None, config, rng, mesh, checkpoint_manager, is_training + ) return state, state_mesh_annotations + def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): is_training = True return setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training) + def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training=True): - """ We initialize the model and optimizer state, and optionally load from a + """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
Args: @@ -397,33 +432,31 @@ def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_ state_mesh_annotations: the mesh annotations for the train state """ - unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(model, tx, config, - rng, mesh, is_training) + unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( + model, tx, config, rng, mesh, is_training + ) # Initialization with nn_partitioning.axis_rules(config.logical_axis_rules): - restored, raw_params = checkpointing.load_state_if_possible(checkpoint_manager, - data_iterator, - config.load_parameters_path, - config.load_full_state_path, - unboxed_abstract_state, - config.enable_single_replica_ckpt_restoring, - config.dataset_type, - ) + restored, raw_params = checkpointing.load_state_if_possible( + checkpoint_manager, + data_iterator, + config.load_parameters_path, + config.load_full_state_path, + unboxed_abstract_state, + config.enable_single_replica_ckpt_restoring, + config.dataset_type, + ) if restored: - if 'iter' in restored and restored['iter'] is not None: - data_iterator.local_iterator = restored['iter'] - state = restored['items'] + if "iter" in restored and restored["iter"] is not None: + data_iterator.local_iterator = restored["iter"] + state = restored["items"] else: init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings - )(rng) - if raw_params: # If we loaded a partial state, we need to merge it. - state = state.replace(params = raw_params) + state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)(rng) + if raw_params: # If we loaded a partial state, we need to merge it. + state = state.replace(params=raw_params) state = unbox_logicallypartioned(state) return state, state_mesh_annotations, data_iterator @@ -432,6 +465,7 @@ def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_ # Learning Rate Schedule # ----------------------------------------------------------------------------- + def create_learning_rate_schedule(config): """Creates a warmup and cosine decay learning rate schedule: We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 @@ -441,12 +475,14 @@ def create_learning_rate_schedule(config): 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. The zero learning rate section can be used to more accurately measure the fully trained model's performance. 
""" + def make_cos_schedule(init_lr, final_lr, len_steps): def schedule(step): pct = (step) / len_steps - a = 0.5 * (jnp.cos(jnp.pi*pct) + 1) + a = 0.5 * (jnp.cos(jnp.pi * pct) + 1) lr = init_lr * a + final_lr * (1 - a) return lr + return schedule lr = config.learning_rate @@ -456,19 +492,15 @@ def schedule(step): cos_steps = config.learning_rate_schedule_steps - warmup_steps constant_zero_steps = config.steps - config.learning_rate_schedule_steps - warmup_schedule = optax.linear_schedule( - init_value=0.0, - end_value=lr, - transition_steps=warmup_steps - ) + warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps) constant_schedule = optax.constant_schedule(0.0) pieces = [warmup_schedule, cos_schedule] - boundaries=[ - warmup_steps, - warmup_steps + cos_steps, - ] + boundaries = [ + warmup_steps, + warmup_steps + cos_steps, + ] if constant_zero_steps > 0: pieces.append(constant_schedule) @@ -480,8 +512,7 @@ def schedule(step): # Cross entropy implementation is taken from original T5X codebase: # https://github.com/google-research/t5x/blob/ace831eea1e2742b4299cd1a9af7e4f302038351/t5x/losses.py#L25-L101 @jax.custom_vjp -def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, - z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: +def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: """Computes cross entropy loss with stable custom gradient. Computes a stabilized-gradient version of: -jnp.sum(targets * nn.log_softmax(logits), axis=-1) @@ -511,12 +542,11 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, def _cross_entropy_with_logits_fwd( - logits: jnp.ndarray, - targets: jnp.ndarray, - z_loss: float = 0.0 -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], - Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0 +) -> Tuple[ + Tuple[jnp.ndarray, jnp.ndarray], + Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], +]: """Forward-mode of `cross_entropy_with_logits`.""" max_logit = logits.max(axis=-1, keepdims=True) shifted = logits - max_logit @@ -528,32 +558,40 @@ def _cross_entropy_with_logits_fwd( log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) total_z_loss = z_loss * jax.lax.square(log_z) loss += total_z_loss - return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp, #pytype: disable=bad-return-type #jax-ndarray - log_softmax, log_z) + return (loss, total_z_loss), ( + logits, + targets, + z_loss, + exp_shifted, + sum_exp, # pytype: disable=bad-return-type #jax-ndarray + log_softmax, + log_z, + ) def _cross_entropy_with_logits_bwd( - res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray] + res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], + g: Tuple[jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Backward-mode of `cross_entropy_with_logits`.""" g = g[0] # Ignore z_loss component as that is only used for logging. logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res # z-loss term adds the (2 * z_loss * log_z) factor. 
- deriv = ( - jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - - targets) + deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets g_logits = jnp.expand_dims(g, axis=-1) * deriv g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax - return (jnp.asarray(g_logits, - logits.dtype), jnp.asarray(g_targets, targets.dtype), - jnp.array(0.0)) # sets z-loss coeff gradient to 0 + return ( + jnp.asarray(g_logits, logits.dtype), + jnp.asarray(g_targets, targets.dtype), + jnp.array(0.0), + ) # sets z-loss coeff gradient to 0 + + +cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) -cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, - _cross_entropy_with_logits_bwd) def get_abstract_state(model, tx, config, rng, mesh, is_training=True): - """ Get a shaped abstraction of the state (including optimizer)""" + """Get a shaped abstraction of the state (including optimizer)""" init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -561,14 +599,9 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): state_logical_annotations = nn.get_partition_spec(abstract_state) - state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, - config.logical_axis_rules) + state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) - abstract_sharded_state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings - ).eval_shape(rng) + abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape(rng) unboxed_abstract_sharded_state = unbox_logicallypartioned(abstract_sharded_state) # Initialization @@ -576,24 +609,23 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) return unboxed_abstract_sharded_state, state_mesh_annotations, state_mesh_shardings + def get_kv_cache_annotations(model, config, rng, mesh): - """ Get a shaped abstraction of the state (including optimizer)""" + """Get a shaped abstraction of the state (including optimizer)""" def init_kv_cache(model, config): - input_shape = ( - config.global_batch_size_to_load, - config.max_prefill_predict_length - ) + input_shape = (config.global_batch_size_to_load, config.max_prefill_predict_length) - model_vars = model.init({'params': rng, 'dropout': rng, 'aqt': rng}, - jnp.ones(input_shape), - jnp.ones(input_shape), - model_mode=common_types.MODEL_MODE_PREFILL) - return model_vars['cache'] + model_vars = model.init( + {"params": rng, "dropout": rng, "aqt": rng}, + jnp.ones(input_shape), + jnp.ones(input_shape), + model_mode=common_types.MODEL_MODE_PREFILL, + ) + return model_vars["cache"] with nn_partitioning.axis_rules(config.logical_axis_rules): - init_kv_cache_partial = functools.partial(init_kv_cache, model, - config) + init_kv_cache_partial = functools.partial(init_kv_cache, model, config) abstract_state = jax.eval_shape(init_kv_cache_partial) state_logical_annotations = nn.get_partition_spec(abstract_state) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): @@ -604,17 +636,19 @@ def init_kv_cache(model, config): def print_pytree_shape(print_str, ptree): print("\n") print(print_str) - print(jax.tree_util.tree_map(lambda x : x.shape, ptree)) + 
print(jax.tree_util.tree_map(lambda x: x.shape, ptree)) + def print_model_vars(print_str, model_vars): for k in model_vars: - print(f'{print_str} key{k}:') - print(f'\t {model_vars[k]}') + print(f"{print_str} key{k}:") + print(f"\t {model_vars[k]}") + def get_project(): completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) - project_outputs = completed_command.stdout.decode().strip().split('\n') - if len(project_outputs) < 1 or project_outputs[-1]=='': + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '") return None return project_outputs[-1] diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index 275ffac47..ed38929e2 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' Implementation of Engine API for MaxText ''' +"""Implementation of Engine API for MaxText""" import functools from typing import Any, Optional, Tuple @@ -39,10 +39,10 @@ Params = Any - @struct.dataclass class DecodeState: """The inputs into a generation step.""" + prefill_cache: jax.Array generate_cache: jax.Array generate_cache_index: int @@ -66,7 +66,7 @@ def __init__(self, config): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - self.model = models.Transformer(config, mesh = self._mesh, quant=quant) + self.model = models.Transformer(config, mesh=self._mesh, quant=quant) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -76,47 +76,48 @@ def __init__(self, config): self.state_mesh_annotations = None def load_params(self, *args, **kwargs) -> Params: - ''' Load Parameters, typically from GCS ''' + """Load Parameters, typically from GCS""" # pylint: disable=unused-argument - state, self.state_mesh_annotations = max_utils.setup_decode_state( - self.model, self.config, self.rng, self._mesh, None + state, self.state_mesh_annotations = max_utils.setup_decode_state(self.model, self.config, self.rng, self._mesh, None) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params ) - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - state.params) self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh) - self.kv_cache_shardings = jax.tree_map(lambda x : jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations) + self.kv_cache_shardings = jax.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations) if not self.model.quant: - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - state.params) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params + ) return state.params else: - self.model.quant.quant_mode = quantizations.get_quant_mode('convert') + self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @jax.jit def model_apply(_p, _rng): return self.model.apply( - _p | {"aqt": {}}, - jnp.ones( (1, self.config.max_prefill_predict_length), dtype=jnp.int32), - jnp.ones( 
(1, self.config.max_prefill_predict_length), dtype=jnp.int32), - decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': _rng}, - mutable=True + _p | {"aqt": {}}, + jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + jnp.ones((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": _rng}, + mutable=True, ) _, new_vars = model_apply(state.params, self.rng) params = {} - params['aqt'] = new_vars['aqt'] + params["aqt"] = new_vars["aqt"] # Remove param values which have corresponding qtensors in aqt to save memory. - params['params'] = quantizations.remove_quantized_params(state.params['params'], new_vars['aqt']) + params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"]) - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - params) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params + ) - self.model.quant.quant_mode = quantizations.get_quant_mode('serve') + self.model.quant.quant_mode = quantizations.get_quant_mode("serve") return params @functools.partial(jax.jit, static_argnums=(0,)) @@ -143,7 +144,7 @@ def prefill( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") - input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) zero_to_n = jnp.arange(0, padded_tokens.shape[0]) @@ -153,45 +154,52 @@ def prefill( with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): flat_logits, new_vars = self.model.apply( - params, - input_tokens, - positions, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': self.rng}, - mutable=["cache"] + params, + input_tokens, + positions, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": self.rng}, + mutable=["cache"], ) - next_pos = jnp.full((1,1), true_length, dtype = jnp.int32) - generated_tokens = jnp.zeros((1,1), dtype = jnp.int32) - selected_logits = jax.lax.dynamic_slice(flat_logits, (0, true_length-1,0), - (flat_logits.shape[0], 1, flat_logits.shape[2])) + next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32) + generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32) + selected_logits = jax.lax.dynamic_slice( + flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2]) + ) selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding) - return {"logits" : selected_logits, "cache" : new_vars['cache'], - "next_pos" : next_pos, "generated_tokens" : generated_tokens} + return { + "logits": selected_logits, + "cache": new_vars["cache"], + "next_pos": next_pos, + "generated_tokens": generated_tokens, + } @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,)) - def generate( - self, params: Params, decode_state: DecodeState - ) -> Tuple[DecodeState, engine_api.ResultTokens]: - '''Run one generate step''' - previous_logits 
= decode_state['logits'] - - new_token = inference_utils.sampling(previous_logits, self.rng, self.config.decode_sampling_strategy, - topk=self.config.decode_sampling_top_k, - nucleus_topp=self.config.decode_sampling_nucleus_p, - temperature=self.config.decode_sampling_temperature) + def generate(self, params: Params, decode_state: DecodeState) -> Tuple[DecodeState, engine_api.ResultTokens]: + """Run one generate step""" + previous_logits = decode_state["logits"] + + new_token = inference_utils.sampling( + previous_logits, + self.rng, + self.config.decode_sampling_strategy, + topk=self.config.decode_sampling_top_k, + nucleus_topp=self.config.decode_sampling_nucleus_p, + temperature=self.config.decode_sampling_temperature, + ) with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): out_logits, new_vars = self.model.apply( - params | { 'cache': decode_state['cache']}, - new_token, - decode_state['next_pos'], - enable_dropout=False, - model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'params': self.rng}, - mutable=['cache'] + params | {"cache": decode_state["cache"]}, + new_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": self.rng}, + mutable=["cache"], ) all_valid = jnp.ones(new_token.shape, dtype=jnp.int8) @@ -212,35 +220,46 @@ def generate( out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) - return {"logits" : out_logits, "cache" : new_cache, - "next_pos" : decode_state["next_pos"]+1, "generated_tokens" : decode_state["generated_tokens"]+1}, result - - @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(1, 2,)) + return { + "logits": out_logits, + "cache": new_cache, + "next_pos": decode_state["next_pos"] + 1, + "generated_tokens": decode_state["generated_tokens"] + 1, + }, result + + @functools.partial( + jax.jit, + static_argnums=(0,), + donate_argnums=( + 1, + 2, + ), + ) def insert( self, prefix: Prefix, decode_state: DecodeState, slot: int, ) -> DecodeState: - ''' Insert into KV cache ''' + """Insert into KV cache""" unboxed_prefix = max_utils.unbox_logicallypartioned(prefix) def copy(path, partial_cache, full_cache, annotations): path_key = path[-1].key - if path_key in ['cache_ar_index', 'cached_ar_key', 'cached_ar_value', 'cached_ar_key_scale', 'cached_ar_value_scale']: - return full_cache # we don't even zero these out because we can mask them out. + if path_key in ["cache_ar_index", "cached_ar_key", "cached_ar_value", "cached_ar_key_scale", "cached_ar_value_scale"]: + return full_cache # we don't even zero these out because we can mask them out. 
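      # Editor's sketch (illustrative, not part of the patch): every other prefill
      # buffer is written into the target batch slot with a dynamic update along the
      # annotated "cache_batch" axis, conceptually
      #   full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
      # while segment-id buffers are first zeroed so stale tokens cannot be attended to,
      # and the autoregressive buffers above are skipped because masking hides them.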
batch_idx = annotations.index("cache_batch") if "cache_batch" in annotations else -1
       if batch_idx < 0:
         raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}")
 
-      if path_key == 'cache_ar_segment_id':
+      if path_key == "cache_ar_segment_id":
         ### goal: zero this out in case there is existing data
         s = list(full_cache.shape)
         s[batch_idx] = 1
         zeros = jnp.zeros(tuple(s), dtype=jnp.int32)
         return jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx)
-      elif path_key == 'cache_prefill_segment_id':
+      elif path_key == "cache_prefill_segment_id":
         s = list(full_cache.shape)
         s[batch_idx] = 1
         zeros = jnp.zeros(tuple(s), dtype=jnp.int32)
@@ -249,31 +268,39 @@ def copy(path, partial_cache, full_cache, annotations):
         ## copy prefill cache
         full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
         return full_cache
-      elif path_key in ['cached_prefill_key', 'cached_prefill_value',
-                        'cached_prefill_key_scale', 'cached_prefill_value_scale']:
+      elif path_key in [
+          "cached_prefill_key",
+          "cached_prefill_value",
+          "cached_prefill_key_scale",
+          "cached_prefill_value_scale",
+      ]:
         return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
       else:
         raise ValueError(f"We don't have a strategy for inserting {path_key}")
 
-    inserted_cache = jax.tree_util.tree_map_with_path(copy, unboxed_prefix['cache'], decode_state['cache'],
-                                                      self.kv_cache_annotations_named)
-    inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state['logits'], unboxed_prefix['logits'], slot, 0)
-    inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state['next_pos'], unboxed_prefix['next_pos'], slot, 0)
-    inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(decode_state['generated_tokens'],
-                                                                    unboxed_prefix['generated_tokens'], slot, 0)
+    inserted_cache = jax.tree_util.tree_map_with_path(
+        copy, unboxed_prefix["cache"], decode_state["cache"], self.kv_cache_annotations_named
+    )
+    inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
+    inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
+    inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
+        decode_state["generated_tokens"], unboxed_prefix["generated_tokens"], slot, 0
+    )
 
     inserted_logits = jax.lax.with_sharding_constraint(inserted_logits, self.replicated_sharding)
    inserted_generated_tokens = jax.lax.with_sharding_constraint(inserted_generated_tokens, self.replicated_sharding)
    inserted_next_pos = jax.lax.with_sharding_constraint(inserted_next_pos, self.replicated_sharding)
    inserted_cache = jax.lax.with_sharding_constraint(inserted_cache, self.kv_cache_shardings)
-    return {'logits' : inserted_logits, 'cache' : inserted_cache,
-            'next_pos' : inserted_next_pos, 'generated_tokens' : inserted_generated_tokens }
+    return {
+        "logits": inserted_logits,
+        "cache": inserted_cache,
+        "next_pos": inserted_next_pos,
+        "generated_tokens": inserted_generated_tokens,
+    }
 
   def get_prefix_destination_sharding(self) -> Any:
-    return jax.sharding.NamedSharding(
-        mesh=self.mesh, spec=jax.sharding.PartitionSpec()
-    )
+    return jax.sharding.NamedSharding(mesh=self.mesh, spec=jax.sharding.PartitionSpec())
 
   def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
     """Return a protobuf of tokenizer info, callable from Py or C++."""
@@ -281,28 +308,32 @@ def
init_decode_state(self, *args, **kwargs) -> DecodeState: """Initialises any state which a generation step transforms.""" + # pylint: disable=unused-argument def init(abstract_params): - x = jnp.ones( (int(self.config.per_device_batch_size * jax.device_count()), self.config.max_prefill_predict_length), - dtype=jnp.int32) + x = jnp.ones( + (int(self.config.per_device_batch_size * jax.device_count()), self.config.max_prefill_predict_length), + dtype=jnp.int32, + ) _, cache = self.model.apply( - abstract_params, - x, - x, - decoder_segment_ids=jnp.zeros(x.shape, dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': self.rng}, - mutable=["cache"] + abstract_params, + x, + x, + decoder_segment_ids=jnp.zeros(x.shape, dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": self.rng}, + mutable=["cache"], ) next_pos = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32) generated_tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32) - return {"logits" : jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)), - "cache" : cache["cache"], - "next_pos" : next_pos, - "generated_tokens" : generated_tokens - } + return { + "logits": jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)), + "cache": cache["cache"], + "next_pos": next_pos, + "generated_tokens": generated_tokens, + } with nn_partitioning.axis_rules(self.config.logical_axis_rules): abstract_outputs = jax.eval_shape(init, self.abstract_params) @@ -311,18 +342,20 @@ def init(abstract_params): with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): mesh_annotations = nn.logical_to_mesh(logical_annotations) - shardings = jax.tree_map(lambda mesh_annotation : jax.sharding.NamedSharding(self._mesh, mesh_annotation), - mesh_annotations) + shardings = jax.tree_map( + lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations + ) - @functools.partial(jax.jit, out_shardings = shardings) + @functools.partial(jax.jit, out_shardings=shardings) def initialize(): - return jax.tree_map( lambda x : jnp.zeros(x.shape, x.dtype), abstract_outputs) + return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs) - cache = initialize()['cache'] + cache = initialize()["cache"] def is_lp(k): return isinstance(k, flax.linen.spmd.LogicallyPartitioned) - self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x : tuple(x.names), cache, is_leaf=is_lp) + + self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x: tuple(x.names), cache, is_leaf=is_lp) del cache zeroed = max_utils.unbox_logicallypartioned(initialize()) return zeroed diff --git a/MaxText/maxengine_config.py b/MaxText/maxengine_config.py index da96c06f9..967b0eb52 100644 --- a/MaxText/maxengine_config.py +++ b/MaxText/maxengine_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
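# Editor's sketch (illustrative, not part of the patch): a minimal way to drive the
# MaxEngine defined in maxengine.py above, assuming a populated `config` and
# already-tokenized `padded_tokens` / `true_length` (both names are hypothetical):
#
#   engine = maxengine.MaxEngine(config)
#   params = engine.load_params()
#   prefix = engine.prefill(params=params, padded_tokens=padded_tokens, true_length=true_length)
#   decode_state = engine.init_decode_state()
#   decode_state = engine.insert(prefix, decode_state, slot=0)
#   for _ in range(config.max_target_length - config.max_prefill_predict_length):
#       decode_state, result_tokens = engine.generate(params, decode_state)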
-'''Configure MaxText For JetStream''' +"""Configure MaxText For JetStream""" import functools import jax @@ -22,26 +22,23 @@ from jetstream.engine import engine_api import maxengine + def create_maxengine(devices: config_lib.Devices, config: Any) -> engine_api.Engine: del devices return maxengine.MaxEngine(config) def get_server_config(config_str: str, config: Any) -> Type[config_lib.ServerConfig]: - ''' Gets the Server Config Required by JetStream ''' + """Gets the Server Config Required by JetStream""" match config_str: - case 'MaxtextInterleavedServer': + case "MaxtextInterleavedServer": server_config = config_lib.ServerConfig( - prefill_slices = (), - generate_slices = (), - interleaved_slices = ('tpu='+str(jax.device_count()),), - prefill_engine_create_fns = (), - generate_engine_create_fns = (), - interleaved_engine_create_fns = (functools.partial( - create_maxengine, - config=config - ), - ) + prefill_slices=(), + generate_slices=(), + interleaved_slices=("tpu=" + str(jax.device_count()),), + prefill_engine_create_fns=(), + generate_engine_create_fns=(), + interleaved_engine_create_fns=(functools.partial(create_maxengine, config=config),), ) case _: raise NotImplementedError diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py index 3cf46c8a5..39fcde33d 100644 --- a/MaxText/maxengine_server.py +++ b/MaxText/maxengine_server.py @@ -19,7 +19,7 @@ import sys import pyconfig -import maxengine_config +import maxengine_config from jetstream.core import server_lib # _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -36,7 +36,7 @@ def main(config): # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() - server_config = maxengine_config.get_server_config('MaxtextInterleavedServer', config) + server_config = maxengine_config.get_server_config("MaxtextInterleavedServer", config) # We separate credential from run so that we can unit test it with # local credentials. # TODO: Add grpc credentials for OSS. @@ -49,8 +49,8 @@ def main(config): jetstream_server.wait_for_termination() -if __name__ == '__main__': - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') +if __name__ == "__main__": + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(sys.argv) cfg = pyconfig.config diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 6b7a81b03..b41085a18 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting to MaxText. """ @@ -28,45 +28,45 @@ from input_pipeline import input_pipeline_interface - def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config): - """ Get the shardings (both state and data) for train_step """ + """Get the shardings (both state and data) for train_step""" functional_train = get_functional_train_step(train_step, model, config) functional_train.__name__ = "train_step" data_pspec = P(*config.data_sharding) - state_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) - data_sharding = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng - out_shardings = (state_mesh_shardings, None) # State, metrics - static_argnums = () # We partial out the static argnums of model and config - donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. + state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shardings = (state_mesh_shardings, None) # State, metrics + static_argnums = () # We partial out the static argnums of model and config + donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
   return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums
 
+
 def get_functional_train_step(train_step, model, config):
   return functools.partial(train_step, model, config)
 
+
 def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config):
-  """ Get the shardings (both state and data) for eval_step """
+  """Get the shardings (both state and data) for eval_step"""
   functional_eval = get_functional_eval_step(eval_step, model, config)
   functional_eval.__name__ = "eval_step"
   data_pspec = P(*config.data_sharding)
-  state_mesh_shardings = jax.tree_map(
-      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
-  data_sharding = jax.tree_map(
-      lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
-  in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
-  out_shardings = None # metrics
-  static_argnums = () # We partial out the static argnums of model, config
-  donate_argnums = () # state will be kept instead of being donated in eval_step
+  state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  in_shardings = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
+  out_shardings = None  # metrics
+  static_argnums = ()  # We partial out the static argnums of model, config
+  donate_argnums = ()  # state will be kept instead of being donated in eval_step
   return functional_eval, in_shardings, out_shardings, static_argnums, donate_argnums
 
+
 def get_functional_eval_step(eval_step, model, config):
   return functools.partial(eval_step, model, config)
 
+
 def load_compiled(config, partial_train, state):
-  """ # Loading a serialized compiled train step function."""
+  """Loads a serialized compiled train step function."""
+
   # Currently partial_train and state are needed to reconstruct
   # input/output shapes to construct the in_trees and out_trees for load API
   # Parker is working on serializing these
@@ -90,40 +90,58 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
   p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
   return p_train_step
 
+
 # https://arxiv.org/pdf/2204.02311.pdf Appendix B
 def calculate_tflops_training_per_device(num_model_parameters, config, log=True):
-  """ Calculate training TFLOP"""
-  learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size \
-      / 10**12
-  noncasual_attention_flops = 12 * config.num_query_heads * config.num_decoder_layers * config.head_dim \
-      * config.max_target_length**2 * config.per_device_batch_size / 10**12
-  causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention
+  """Calculate training TFLOPs per device"""
+  learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size / 10**12
+  noncasual_attention_flops = (
+      12
+      * config.num_query_heads
+      * config.num_decoder_layers
+      * config.head_dim
+      * config.max_target_length**2
+      * config.per_device_batch_size
+      / 10**12
+  )
+  causal_attention_tflops = noncasual_attention_flops / 2  # due to causality in attention
   total_tflops = learnable_weight_tflops + causal_attention_tflops
 
   if log:
-    print('Per train step:\n',
-          f'Total TFLOPs: {total_tflops:.2f} \n',
-          f'split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops',
-          f'and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops')
+    print(
+        "Per train step:\n",
+        f"Total TFLOPs: {total_tflops:.2f} \n",
+        f"split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops",
+        f"and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops",
+    )
 
   return total_tflops, learnable_weight_tflops, causal_attention_tflops
 
+
 # https://arxiv.org/pdf/2204.02311.pdf Appendix B
 def calculate_tflops_prefill(num_model_parameters, prefill_length, config, log=True):
-  """ Calculate training TFLOP"""
-  learnable_weight_tflops = 2 * num_model_parameters * prefill_length \
-      / 10**12
-  noncasual_attention_flops = 4 * config.num_query_heads * config.num_decoder_layers * config.head_dim \
-      * prefill_length**2 * config.per_device_batch_size / 10**12
-  causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention
+  """Calculate prefill TFLOPs"""
+  learnable_weight_tflops = 2 * num_model_parameters * prefill_length / 10**12
+  noncasual_attention_flops = (
+      4
+      * config.num_query_heads
+      * config.num_decoder_layers
+      * config.head_dim
+      * prefill_length**2
+      * config.per_device_batch_size
+      / 10**12
+  )
+  causal_attention_tflops = noncasual_attention_flops / 2  # due to causality in attention
   total_tflops = learnable_weight_tflops + causal_attention_tflops
   if log:
-    print('Per prefill step: \n',
-          f'\tTotal TFLOPs: {total_tflops:.2f} \n',
-          f'\t\tLearnable weight TFLOPs: {learnable_weight_tflops} ',
-          f'({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n',
-          f'\t\tCausal attention TFLOPs: {causal_attention_tflops} ',
-          f'({100 * causal_attention_tflops/total_tflops:.2f})% of Total')
+    print(
+        "Per prefill step: \n",
+        f"\tTotal TFLOPs: {total_tflops:.2f} \n",
+        f"\t\tLearnable weight TFLOPs: {learnable_weight_tflops} ",
+        f"({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n",
+        f"\t\tCausal attention TFLOPs: {causal_attention_tflops} ",
+        f"({100 * causal_attention_tflops/total_tflops:.2f})% of Total",
+    )
   return total_tflops, learnable_weight_tflops, causal_attention_tflops
 
@@ -144,22 +162,15 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01):
     bool: True if the majority of parameters are sufficiently sharded
   """
   total_num_params = max_utils.calculate_num_params_from_pytree(params)
-  product_num_devices_for_weight_sharding = 1
-  for axis in ['fsdp', 'fsdp_transpose', 'sequence', 'tensor']:
+  product_num_devices_for_weight_sharding = 1
+  for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor"]:
     product_num_devices_for_weight_sharding *= mesh.shape[axis]
-  total_num_params_per_chip = (
-      max_utils.calculate_total_params_per_chip(
-          params)
+  total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params)
+  perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
+  assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
+      "Number of parameters per chip must not be less than in the ideal sharded "
+      "scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes."
+  )
-  perfectly_sharded_params_per_chip = (
-      total_num_params / product_num_devices_for_weight_sharding
+  assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance, (
+      f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters."
+  )
-  assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
-      'Number of parameters per chip must not be less than in the ideal sharded '
-      'scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes.'
- ) - assert ( - total_num_params_per_chip/perfectly_sharded_params_per_chip - 1 < tolerance - ), (f'Number of unsharded parameters exceeds tolerance {tolerance * 100}% ' - 'of total parameters.') - diff --git a/MaxText/multihost_dataloading.py b/MaxText/multihost_dataloading.py index 8c9088961..fca337691 100644 --- a/MaxText/multihost_dataloading.py +++ b/MaxText/multihost_dataloading.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=unused-import """SPMD Multihost Dataloading Utilities. @@ -36,6 +36,7 @@ import max_logging + def _build_global_shape_and_sharding( local_shape: tuple[int, ...], global_mesh: Mesh ) -> tuple[tuple[int, ...], NamedSharding]: @@ -47,27 +48,23 @@ def _build_global_shape_and_sharding( def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: - """ Put local sharded array into local devices - """ + """Put local sharded array into local devices""" global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) try: local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) except ValueError as array_split_error: raise ValueError( - f"Unable to put to devices shape {array.shape} with " - f"local device count {len(global_mesh.local_devices)} " - f"at {jtu.keystr(path)}" + f"Unable to put to devices shape {array.shape} with " + f"local device count {len(global_mesh.local_devices)} " + f"at {jtu.keystr(path)}" ) from array_split_error local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices) return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) - -def get_next_batch_sharded( - local_iterator: Iterator, global_mesh: Mesh -) -> jax.Array: +def get_next_batch_sharded(local_iterator: Iterator, global_mesh: Mesh) -> jax.Array: """Splits the host loaded data equally over all devices.""" SLEEP_TIME = 10 @@ -88,13 +85,14 @@ def get_next_batch_sharded( if not loaded_data_success: local_data = next(local_iterator) - input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh = global_mesh), local_data) + input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh=global_mesh), local_data) return input_gdas class MultiHostDataLoadIterator: """fold get_next_batch_sharded into a iterator class""" + def __init__(self, dataloader: Union[tf.data.Dataset, grain.DataLoader], global_mesh: Mesh): self.global_mesh = 
global_mesh
     self.dataloader = dataloader
diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py
index e2a23abda..63fcc42b1 100644
--- a/MaxText/optimizers.py
+++ b/MaxText/optimizers.py
@@ -1,18 +1,18 @@
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-       https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-  """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 # pylint: disable=bare-except, consider-using-generator, ungrouped-imports
 """Utils that are only interesting to MaxText. """
@@ -29,25 +29,26 @@ def get_optimizer(config, learning_rate_schedule):
   if config.opt_type == "adamw":
     # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
     return optax.adamw(
-      learning_rate_schedule,
-      b1=config.adam_b1,
-      b2=config.adam_b2,
-      eps=config.adam_eps,
-      eps_root=config.adam_eps_root,
-      weight_decay=config.adam_weight_decay,
+        learning_rate_schedule,
+        b1=config.adam_b1,
+        b2=config.adam_b2,
+        eps=config.adam_eps,
+        eps_root=config.adam_eps_root,
+        weight_decay=config.adam_weight_decay,
     )
   elif config.opt_type == "adam_pax":
     return adam_pax(
-      learning_rate_schedule,
-      beta1=config.adam_b1,
-      beta2=config.adam_b2,
-      epsilon=config.adam_eps,
-      epsilon_root=config.adam_eps_root,
-      weight_decay=config.adam_weight_decay,
+        learning_rate_schedule,
+        beta1=config.adam_b1,
+        beta2=config.adam_b2,
+        epsilon=config.adam_eps,
+        epsilon_root=config.adam_eps_root,
+        weight_decay=config.adam_weight_decay,
     )
   else:
     raise ValueError(f"{config.opt_type=} is not supported.")
 
+
 def adam_pax(
     learning_rate_fn: optax.Schedule,
     beta1: float,
@@ -55,7 +56,7 @@ def adam_pax(
     epsilon: float,
     epsilon_root: float,
     weight_decay: float,
-  ) -> optax.GradientTransformation:
+) -> optax.GradientTransformation:
   """Standard Adam optimizer that supports weight decay.
 
   Follows the implementation in pax/praxis sharded_adam
@@ -77,8 +78,7 @@ def adam_pax(
   """
 
   def init_fn(params):
-    mu = jax.tree_util.tree_map(  # First moment
-        jnp.zeros_like, params)
+    mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # First moment
     nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
     return optax.ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
 
@@ -102,8 +102,8 @@ def bias_corrected_decay(step: jnp.int32, decay: float):
     Returns:
       Bias corrected decay.
     """
-    t = step.astype(jnp.float32) + 1.
-    return decay * (1. - jnp.power(decay, t - 1.)) / (1.
- jnp.power(decay, t)) + t = step.astype(jnp.float32) + 1.0 + return decay * (1.0 - jnp.power(decay, t - 1.0)) / (1.0 - jnp.power(decay, t)) def update_fn(updates, state, params=None): # Sanitize updates just in case. @@ -112,6 +112,7 @@ def update_fn(updates, state, params=None): count = state.count class _slot_opt_state: + def __init__(self, mu, nu): self.mu = mu self.nu = nu @@ -133,8 +134,7 @@ def _update_momentum(update, mu, nu): mu = jax.tree_map(lambda x: x.mu, updated_moments) nu = jax.tree_map(lambda x: x.nu, updated_moments) - updates = jax.tree_map( - lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu) + updates = jax.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu) if weight_decay > 0: updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index e19af5456..204b32463 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring, bare-except, consider-using-generator from collections import OrderedDict @@ -35,9 +35,11 @@ # YAML attribute to specify inheritance. _BASE_CONFIG_ATTR = "base_config" + def yaml_key_to_env_key(s: str) -> str: return _MAX_PREFIX + s.upper() + def string_to_bool(s: str) -> bool: if s.lower() == "true": return True @@ -45,60 +47,80 @@ def string_to_bool(s: str) -> bool: return False raise ValueError(f"Can't convert {s} to bool") -_yaml_types_to_parser = {str : str, int : int, float : float, bool : string_to_bool} + +_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool} + def validate_attention_type(s: str) -> None: - valid_attention_types = ('autoselected', 'dot_product', 'flash', 'cudnn_flash_te') - if s not in valid_attention_types: # currently supported attention - raise ValueError( - "Invalid attention type was passed. Valid options ", valid_attention_types - ) + valid_attention_types = ("autoselected", "dot_product", "flash", "cudnn_flash_te") + if s not in valid_attention_types: # currently supported attention + raise ValueError("Invalid attention type was passed. 
Valid options ", valid_attention_types) + def validate_keys(keys): - validate_attention_type(keys['attention']) + validate_attention_type(keys["attention"]) + + assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[ + "enable_checkpointing" + ], "You must set enable_checkpointing to load a checkpoint" + assert ( + keys["load_parameters_path"] == "" or keys["load_full_state_path"] == "" + ), "At most one of `load_parameters_path` or `load_full_state_path` should be set" - assert ((keys["load_parameters_path"]=="" and keys["load_full_state_path"]=="") or - keys["enable_checkpointing"]), "You must set enable_checkpointing to load a checkpoint" - assert keys["load_parameters_path"]=="" or keys["load_full_state_path"]=="",\ - "At most one of `load_parameters_path` or `load_full_state_path` should be set" def validate_model_name(s: str) -> bool: + """Validate provided model name.""" # currently supported models - valid_model_names = ('default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'mistral-7b', - 'mixtral-8x7b', 'gemma-7b','gemma-2b', - 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k') + valid_model_names = ( + "default", + "llama2-7b", + "llama2-13b", + "llama2-70b", + "mistral-7b", + "mixtral-8x7b", + "gemma-7b", + "gemma-2b", + "gpt3-175b", + "gpt3-22b", + "gpt3-6b", + "gpt3-52k", + ) if s not in valid_model_names: - raise ValueError( - "Invalid model name was passed. Valid options ", valid_model_names - ) + raise ValueError("Invalid model name was passed. Valid options ", valid_model_names) + def validate_no_keys_overwritten_twice(keys1: list[str], keys2: list[str]): overwritten_keys = [k for k in keys1 if k in keys2] if overwritten_keys: raise ValueError( f"Keys {overwritten_keys} are overwritten from both the model" - " and the environment/command line. This isn't allowed.") + " and the environment/command line. This isn't allowed." 
+ ) + _config = None config = None + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}") -def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any],list[Any]]: + +def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]: return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l -class _HyperParameters(): + +class _HyperParameters: # pylint: disable=missing-class-docstring def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]): for environment_var in os.environ: - if environment_var[:len(_MAX_PREFIX)] == _MAX_PREFIX: - proposed_key = environment_var[len(_MAX_PREFIX):].lower() + if environment_var[: len(_MAX_PREFIX)] == _MAX_PREFIX: + proposed_key = environment_var[len(_MAX_PREFIX) :].lower() if proposed_key not in raw_data_from_yaml: raise ValueError(f"We received env `{environment_var}` but it doesn't match a key, so it is assumed a mistake.") - if not environment_var[len(_MAX_PREFIX):].isupper(): + if not environment_var[len(_MAX_PREFIX) :].isupper(): raise ValueError(f"We received env `{environment_var}` but it isn't all uppercase.") def _load_kwargs(self, argv: list[str], **kwargs): @@ -107,21 +129,17 @@ def _load_kwargs(self, argv: list[str], **kwargs): return args_dict def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]: - ''' Update model config from environment and command line - ''' + """Update model config from environment and command line""" raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs) updated_keys = [] for k in raw_data_from_cmd_line: if k not in raw_data_from_yaml: - raise ValueError( - f"Key {k} was passed at the command line but isn't in config." - ) + raise ValueError(f"Key {k} was passed at the command line but isn't in config.") for k in raw_data_from_yaml: if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ: - raise ValueError( - f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.") + raise ValueError(f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.") if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ: raw_keys[k] = raw_data_from_yaml[k] @@ -133,8 +151,9 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, else: new_proposal = os.environ.get(yaml_key_to_env_key(k)) - if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and \ - (type(raw_data_from_yaml[k]) not in _yaml_types_to_parser): + if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and ( + type(raw_data_from_yaml[k]) not in _yaml_types_to_parser + ): raise ValueError( f"For key '{k}', type {type(raw_data_from_yaml[k])} not in {_yaml_types_to_parser.keys()}, can't pass" " at the CLI or ENV" @@ -148,8 +167,7 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, new_proposal ) # take the command line value, but type it like the config value. 
except ValueError as e: - raise ValueError( - f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'") from e + raise ValueError(f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'") from e return updated_keys @@ -163,14 +181,10 @@ def _load_config(self, config_name: str) -> dict[str, Any]: if _BASE_CONFIG_ATTR in raw_data_from_yaml: parent_config_filename = raw_data_from_yaml[_BASE_CONFIG_ATTR] if not os.path.isabs(parent_config_filename): - loaded_parent_config_filename = os.path.join( - os.path.dirname(config_name), parent_config_filename - ) + loaded_parent_config_filename = os.path.join(os.path.dirname(config_name), parent_config_filename) if not os.path.isfile(loaded_parent_config_filename): dir_path = os.path.dirname(os.path.realpath(__file__)) - loaded_parent_config_filename = os.path.join( - dir_path, f"configs/{parent_config_filename}" - ) + loaded_parent_config_filename = os.path.join(dir_path, f"configs/{parent_config_filename}") else: loaded_parent_config_filename = parent_config_filename @@ -181,15 +195,14 @@ def _load_config(self, config_name: str) -> dict[str, Any]: return raw_data_from_yaml def __init__(self, argv: list[str], **kwargs): - config_name : str = argv[1] + config_name: str = argv[1] raw_data_from_yaml = self._load_config(config_name) self._validate_env_variables(raw_data_from_yaml) raw_keys = OrderedDict() keys_from_env_and_command_line = self._update_from_env_and_command_line(raw_keys, raw_data_from_yaml, argv, **kwargs) - max_logging.log( - f"Updating keys from env and command line: {keys_from_env_and_command_line}") + max_logging.log(f"Updating keys from env and command line: {keys_from_env_and_command_line}") keys_from_model = _HyperParameters.update_model_vars(argv[1], raw_keys, config_name) max_logging.log(f"Updating keys from model: {keys_from_model}") validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model) @@ -197,24 +210,24 @@ def __init__(self, argv: list[str], **kwargs): # We initialize the jax distributed system here because it must be done before device backend is initialized. 
max_utils.maybe_initialize_jax_distributed_system(raw_keys) - if raw_keys['jax_cache_dir']: - compilation_cache.set_cache_dir(os.path.expanduser(raw_keys['jax_cache_dir'])) + if raw_keys["jax_cache_dir"]: + compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"])) - if raw_keys['model_name'] == "gpt3-175b": + if raw_keys["model_name"] == "gpt3-175b": _HyperParameters.configure_gpt3_task(raw_keys) _HyperParameters.user_init(raw_keys) self.keys = raw_keys - keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension + keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension keys.sort() for k in keys: max_logging.log(f"Config param {k}: {raw_keys[k]}") @staticmethod def user_init(raw_keys): - '''Transformations between the config data and configs used at runtime''' + """Transformations between the config data and configs used at runtime""" if raw_keys["run_name"] == "": - raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default + raw_keys["run_name"] = os.environ.get("JOBSET_NAME") # using XPK default run_name = raw_keys["run_name"] base_output_directory = raw_keys["base_output_directory"] if run_name: @@ -222,22 +235,21 @@ def user_init(raw_keys): raw_keys["checkpoint_dir"] = os.path.join(base_output_directory, run_name, "checkpoints", "") raw_keys["metrics_dir"] = os.path.join(base_output_directory, run_name, "metrics", "") - if raw_keys["learning_rate_schedule_steps"]==-1: + if raw_keys["learning_rate_schedule_steps"] == -1: raw_keys["learning_rate_schedule_steps"] = raw_keys["steps"] - if raw_keys["steps"]==-1: + if raw_keys["steps"] == -1: raw_keys["steps"] = raw_keys["learning_rate_schedule_steps"] - emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys['global_parameter_scale']) - raw_keys['emb_dim'] = 2**emb_scale * raw_keys['base_emb_dim'] - raw_keys['num_query_heads'] = 2**num_head_scale * raw_keys['base_num_query_heads'] - raw_keys['num_kv_heads'] = 2**num_head_scale * raw_keys['base_num_kv_heads'] - raw_keys['mlp_dim'] = 2**mlp_dim_scale * raw_keys['base_mlp_dim'] - raw_keys['num_decoder_layers'] = 2**layer_scale * raw_keys['base_num_decoder_layers'] + emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys["global_parameter_scale"]) + raw_keys["emb_dim"] = 2**emb_scale * raw_keys["base_emb_dim"] + raw_keys["num_query_heads"] = 2**num_head_scale * raw_keys["base_num_query_heads"] + raw_keys["num_kv_heads"] = 2**num_head_scale * raw_keys["base_num_kv_heads"] + raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] + raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"] - raw_keys['global_batch_size_to_load'], raw_keys['global_batch_size_to_train_on'] = \ - calculate_global_batch_sizes(raw_keys) - raw_keys['num_slices'] = get_num_slices(raw_keys) - raw_keys['quantization_local_shard_count'] = get_quantization_local_shard_count(raw_keys) + raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = calculate_global_batch_sizes(raw_keys) + raw_keys["num_slices"] = get_num_slices(raw_keys) + raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) print_system_information() @@ -253,73 +265,72 @@ def user_init(raw_keys): @staticmethod def configure_gpt3_task(raw_keys): - '''dynamically configure gpt3 task based on training rules''' + """dynamically configure gpt3 task based on training rules""" # follow 
https://github.com/google/paxml/blob/19db52eed85ae0d2365339b83a97cd0b873bbf73/paxml/tasks/lm/params/c4.py#L280 # according to training_rules of mlperf gpt3 training global_batch_size = calculate_global_batch_sizes(raw_keys)[1] if global_batch_size <= 3584: - raw_keys['learning_rate'] = 2e-5 + raw_keys["learning_rate"] = 2e-5 else: - raw_keys['learning_rate'] = 3e-5 + raw_keys["learning_rate"] = 3e-5 warmup_steps = math.ceil(265.0 * 1536 / global_batch_size - 1e-6) decay_end_step = math.ceil(108600.0 * 1536 / global_batch_size - 1e-6) - raw_keys['learning_rate_schedule_steps'] = decay_end_step - raw_keys['warmup_steps_fraction'] = warmup_steps / decay_end_step + raw_keys["learning_rate_schedule_steps"] = decay_end_step + raw_keys["warmup_steps_fraction"] = warmup_steps / decay_end_step global_batch_size_to_train_on = calculate_global_batch_sizes(raw_keys)[1] - raw_keys['eval_interval'] = math.ceil(24567 / global_batch_size_to_train_on) + raw_keys["eval_interval"] = math.ceil(24567 / global_batch_size_to_train_on) @staticmethod - def update_model_vars(base_config_path, raw_keys, config_name : str): - ''' Update model config variables - ''' - validate_model_name(raw_keys['model_name']) + def update_model_vars(base_config_path, raw_keys, config_name: str): + """Update model config variables""" + validate_model_name(raw_keys["model_name"]) max_logging.log(f"Running Model: {raw_keys['model_name']}") updated_keys = [] - if raw_keys['model_name'] != 'default': - model_name = raw_keys['model_name'] + if raw_keys["model_name"] != "default": + model_name = raw_keys["model_name"] # First look at the model configs next to the base_config_path, and # fallback to the python codebase if the config cannot be found. - file_path = os.path.join( - os.path.dirname(base_config_path), f"models/{model_name}.yml" - ) + file_path = os.path.join(os.path.dirname(base_config_path), f"models/{model_name}.yml") if not os.path.isfile(file_path): dir_path = os.path.dirname(os.path.realpath(__file__)) file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml") - with open(file_path, 'r', encoding="utf-8") as file: + with open(file_path, "r", encoding="utf-8") as file: model_vars = yaml.safe_load(file) updated_keys = list(model_vars.keys()) raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name) return updated_keys -def validate_and_update_keys(raw_keys, model_keys, config_name : str): - ''' Validate and update model specific config keys - ''' + +def validate_and_update_keys(raw_keys, model_keys, config_name: str): + """Validate and update model specific config keys""" max_logging.log("Updating following parameters in config\n") for k in model_keys: max_logging.log(f"{k}: {model_keys[k]}") if k not in raw_keys: - raise ValueError(f'Key {k} does not exist in config {config_name}.') + raise ValueError(f"Key {k} does not exist in config {config_name}.") elif not isinstance(raw_keys[k], type(model_keys[k])): - raise ValueError(f'Type of key:{k} does not match with {type(model_keys[k])}') + raise ValueError(f"Type of key:{k} does not match with {type(model_keys[k])}") else: raw_keys[k] = model_keys[k] return raw_keys + def get_individual_scales(scale): - '''Choose appropriate scales for individual dimensions based on global scale + """Choose appropriate scales for individual dimensions based on global scale We choose to rotate between doubling: num_head and mlp_dim embed_dim num_layers Any one of these steps is not a perfect doubling, although going through a cycle - of three is a near perfect 8x 
scaling except for the linear -> softmax -> output step''' - + of three is a near perfect 8x scaling except for the linear -> softmax -> output step""" log_2_scale = math.floor((math.log2(scale))) if 2**log_2_scale != scale: - raise ValueError("Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " - "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim.") + raise ValueError( + "Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " + "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim." + ) base_scale, rem = divmod(log_2_scale, 3) num_head_scale = base_scale + int(rem > 0) mlp_dim_scale = num_head_scale @@ -327,10 +338,11 @@ def get_individual_scales(scale): layer_scale = base_scale return emb_scale, num_head_scale, mlp_dim_scale, layer_scale + def calculate_global_batch_sizes(raw_keys): - """ Calculates target global batch size from target devices and per_device_batch""" - per_device_batch_size = raw_keys['per_device_batch_size'] - expansion_factor_real_data = raw_keys['expansion_factor_real_data'] + """Calculates target global batch size from target devices and per_device_batch""" + per_device_batch_size = raw_keys["per_device_batch_size"] + expansion_factor_real_data = raw_keys["expansion_factor_real_data"] num_devices = get_num_target_devices(raw_keys) if per_device_batch_size < 1.0: # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 @@ -347,17 +359,19 @@ def calculate_global_batch_sizes(raw_keys): global_batch_size_to_train_on = int(num_devices * per_device_batch_size) return global_batch_size_to_load, global_batch_size_to_train_on + def get_num_target_devices(raw_keys): - compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys.get('compile_topology', "")) + compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys.get("compile_topology", "")) if compile_topology is not None: devices_per_slice = compile_topology.devices_per_slice - return int(devices_per_slice * raw_keys['compile_topology_num_slices']) + return int(devices_per_slice * raw_keys["compile_topology_num_slices"]) else: return len(jax.devices()) + def get_num_slices(raw_keys): - if int(raw_keys['compile_topology_num_slices']) > 0: - return raw_keys['compile_topology_num_slices'] + if int(raw_keys["compile_topology_num_slices"]) > 0: + return raw_keys["compile_topology_num_slices"] else: devices = jax.devices() try: @@ -365,13 +379,16 @@ def get_num_slices(raw_keys): except: return 1 + def get_quantization_local_shard_count(raw_keys): - if raw_keys['quantization_local_shard_count'] == -1: - return raw_keys['num_slices'] + if raw_keys["quantization_local_shard_count"] == -1: + return raw_keys["num_slices"] else: - return raw_keys['quantization_local_shard_count'] + return raw_keys["quantization_local_shard_count"] + + +class HyperParameters: # pylint: disable=missing-class-docstring -class HyperParameters(): # pylint: disable=missing-class-docstring def __init__(self): pass @@ -386,11 +403,13 @@ def __setattr__(self, attr, value): def get_keys(self): return _config.keys + def initialize(argv, **kwargs): global _config, config _config = _HyperParameters(argv, **kwargs) config = HyperParameters() + if __name__ == "__main__": initialize(sys.argv) print(config.steps) diff --git a/MaxText/sequence_packing.py b/MaxText/sequence_packing.py index 
36bccde0c..d8ba28082 100644
--- a/MaxText/sequence_packing.py
+++ b/MaxText/sequence_packing.py
@@ -1,18 +1,18 @@
 """
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 """Packed Sequence Op."""
@@ -23,9 +23,9 @@
 AUTOTUNE = tf.data.experimental.AUTOTUNE
-def pack_dataset(dataset: tf.data.Dataset,
-                 key2length: Union[int, Dict[str, int]],
-                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
+def pack_dataset(
+    dataset: tf.data.Dataset, key2length: Union[int, Dict[str, int]], keys: Optional[List[str]] = None
+) -> tf.data.Dataset:
   """Creates a 'packed' version of a dataset on-the-fly.
   Adapted from the mesh-tf implementation.
   This is meant to replace the irritation of having to create a separate
@@ -66,29 +66,28 @@ def pack_dataset(dataset: tf.data.Dataset,
     keys = list(shapes.keys())
   for k in keys:
     if k not in shapes:
-      raise ValueError(f"""Key {k} not found in dataset. Available keys are
-                       {shapes.keys()}""")
+      raise ValueError(
+          f"""Key {k} not found in dataset. Available keys are
+          {shapes.keys()}"""
+      )
     if not shapes[k].is_compatible_with(tf.TensorShape([None])):
-      raise ValueError('Tensors to be packed must be one-dimensional.')
+      raise ValueError("Tensors to be packed must be one-dimensional.")
   # make sure that the length dictionary contains all keys as well as the
   # keys suffixed by "_segmentation" and "_position"
   if isinstance(key2length, int):
     key2length = {k: key2length for k in keys}
   for k in keys:
-    for suffix in ['_segmentation', '_position']:
+    for suffix in ["_segmentation", "_position"]:
       key2length[k + suffix] = key2length[k]
   # trim to length
-  dataset = dataset.map(
-      lambda x: {k: x[k][:key2length[k]] for k in keys},
-      num_parallel_calls=AUTOTUNE)
+  dataset = dataset.map(lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE)
   # Setting batch_size=length ensures that the concatenated sequences (if they
   # have length >=1) are sufficient to fill at least one packed example.
   batch_size = max(key2length.values())
   # We pad with a negative value instead of the default 0 because 0 is a
   # valid token for some tokenizers, e.g., for representing an unknown value
-  dataset = dataset.padded_batch(
-      batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1)
+  dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1)
   dataset = _pack_with_tf_ops(dataset, keys, key2length)
   # Set the Tensor shapes correctly since they get lost in the process.
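# A minimal usage sketch of pack_dataset on a toy dataset, assuming a single
# hypothetical key "inputs" and key2length=8 (the values and the import path are
# illustrative, not part of the patch). Packing concatenates short examples into
# one fixed-length row; "_segmentation" holds the 1-based id of each token's
# source example (0 marks padding) and "_position" holds the token's offset
# within its source example, so example boundaries survive the concatenation.
import tensorflow as tf
from sequence_packing import pack_dataset

toy = tf.data.Dataset.from_generator(
    lambda: iter([{"inputs": [10, 11, 12]}, {"inputs": [20, 21]}]),
    output_signature={"inputs": tf.TensorSpec(shape=[None], dtype=tf.int32)},
)
packed = next(iter(pack_dataset(toy, key2length=8)))
# packed["inputs"]              -> [10 11 12 20 21  0  0  0]
# packed["inputs_segmentation"] -> [ 1  1  1  2  2  0  0  0]
# packed["inputs_position"]     -> [ 0  1  2  0  1  0  0  0]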
@@ -98,8 +97,7 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int]) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: @@ -112,16 +110,14 @@ def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], empty_example = {} for k in keys: empty_example[k] = tf.zeros([0], dtype=tf.int32) - empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) + empty_example[k + "_position"] = tf.zeros([0], dtype=tf.int32) keys_etc = empty_example.keys() def write_packed_example(partial, outputs): new_partial = empty_example.copy() new_outputs = {} for k in keys_etc: - new_outputs[k] = outputs[k].write( - outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + new_outputs[k] = outputs[k].write(outputs[k].size(), tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) return new_partial, new_outputs def map_fn(x): @@ -138,10 +134,8 @@ def map_fn(x): dynamic_batch_size = tf.shape(x[keys[0]])[0] outputs = {} for k in keys: - outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) - outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + outputs[k] = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + outputs[k + "_position"] = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) def body_fn(i, partial, outputs): """Body function for while_loop. 
@@ -157,13 +151,10 @@ def body_fn(i, partial, outputs): for k in keys: val = tf.cast(x[k][i], tf.int32) # We consider only the valid tokens i.e., token_id != -1 - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, -1), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, -1), tf.int32))] one_example[k] = val for k in keys: - can_append = tf.logical_and( - can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + can_append = tf.logical_and(can_append, tf.less_equal(tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) def false_fn(): return write_packed_example(partial, outputs) @@ -174,12 +165,10 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) - new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + new_partial[k + "_position"] = tf.concat([partial[k + "_position"], tf.range(new_seq_len)], 0) partial = new_partial return i + 1, partial, outputs @@ -193,14 +182,14 @@ def true_fn(): {k: tf.TensorShape([None]) for k in keys_etc}, {k: tf.TensorShape(None) for k in keys_etc}, ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + "_segmentation"] = tf.cumsum(tf.cast(tf.equal(packed[k + "_position"], 0), tf.int32), axis=1) * tf.cast( + tf.not_equal(packed[k], 0), tf.int32 + ) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py index 7e2bb2c1c..fcfb9631d 100644 --- a/MaxText/standalone_checkpointer.py +++ b/MaxText/standalone_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Standalone checkpointer - only saves and restores checkpoints at regular intervals, accesses storage needs.""" @@ -41,6 +41,7 @@ Transformer = models.Transformer + def checkpoint_loop(config, state=None): """Main Checkpointing loop. Saves checkpoints. @@ -50,33 +51,29 @@ def checkpoint_loop(config, state=None): ckpt_path: Returns: """ - init_rng, _ , checkpoint_manager, mesh, model, _, tx = setup_mesh_and_model(config) + init_rng, _, checkpoint_manager, mesh, model, _, tx = setup_mesh_and_model(config) - unboxed_abstract_state, _, _ = max_utils.get_abstract_state(model, tx, - config, init_rng, mesh, is_training=True) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() with nn_partitioning.axis_rules(config.logical_axis_rules): - state, _ = checkpointing.load_state_if_possible(checkpoint_manager, - None, - config.load_parameters_path, - config.load_full_state_path, - unboxed_abstract_state) + state, _ = checkpointing.load_state_if_possible( + checkpoint_manager, None, config.load_parameters_path, config.load_full_state_path, unboxed_abstract_state + ) if state: - state = state['items'] + state = state["items"] jax.block_until_ready(state) checkpoint_load_end = datetime.datetime.now() - if state is not None: # Checkpoint was available for restore + if state is not None: # Checkpoint was available for restore if jax.process_index() == 0: max_logging.log(f"STANDALONE CHECKPOINTER : Checkpoint restored in : {checkpoint_load_end - checkpoint_load_start}") - else: # Checkpoint was unavailable, state needs to be initialized - state, _, _ = max_utils.setup_training_state(model, None, - tx, config, init_rng, mesh, checkpoint_manager) + else: # Checkpoint was unavailable, state needs to be initialized + state, _, _ = max_utils.setup_training_state(model, None, tx, config, init_rng, mesh, checkpoint_manager) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() @@ -90,6 +87,7 @@ def checkpoint_loop(config, state=None): return state + def add_entropy_to_checkpoint(state): """Introduce randomness in checkpoints. This is useful to simulate real checkpoints, without training. Args: @@ -98,14 +96,17 @@ def add_entropy_to_checkpoint(state): state: Returns state with entropy added to the optimizer state. 
""" opt_0 = state.opt_state[0] - opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda x: - jax.random.normal(create_random_keys(x), shape=x.shape), state.params)) - opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda x: - jax.random.normal(create_random_keys(x), shape=x.shape), state.params)) + opt_0 = opt_0._replace( + mu=jax.tree_util.tree_map(lambda x: jax.random.normal(create_random_keys(x), shape=x.shape), state.params) + ) + opt_0 = opt_0._replace( + nu=jax.tree_util.tree_map(lambda x: jax.random.normal(create_random_keys(x), shape=x.shape), state.params) + ) new_opt = [opt_0] + list(state.opt_state[1:]) state = state.replace(opt_state=new_opt) return state + def create_random_keys(x): """Create random keys to help alter the checkpoint state. Args: @@ -115,8 +116,9 @@ def create_random_keys(x): """ return random.PRNGKey(int(jnp.sum(jnp.abs(x)))) + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_cpu_enable_gloo_collectives', True) + jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index 4bb13e4dc..8d484d19c 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """ Standalone data loader - only loads data for each training step, accesses storage needs.""" @@ -32,7 +32,7 @@ def data_load_loop(config, state=None): """Main data loader loop. - Loads batches of data for each training step. + Loads batches of data for each training step. 
""" _, _, _, _, _, _, _, data_iterator, _, state = setup_train_loop(config) @@ -43,14 +43,14 @@ def data_load_loop(config, state=None): example_batch = load_next_batch(data_iterator, example_batch, config) jax.block_until_ready(example_batch) first_end = datetime.datetime.now() - time_to_load_first_batch = first_end-start + time_to_load_first_batch = first_end - start if jax.process_index() == 0: max_logging.log(f"STANDALONE DATALOADER : First step completed in {time_to_load_first_batch} seconds, on host 0") - for _ in np.arange(start_step+1, config.steps): + for _ in np.arange(start_step + 1, config.steps): example_batch = load_next_batch(data_iterator, example_batch, config) - jax.block_until_ready(example_batch) # wait until the last batch is read + jax.block_until_ready(example_batch) # wait until the last batch is read end = datetime.datetime.now() if jax.process_index() == 0: max_logging.log(f"STANDALONE DATALOADER : {config.steps} batches loaded in {end-start} seconds, on host 0") @@ -58,7 +58,7 @@ def data_load_loop(config, state=None): def main(argv: Sequence[str]) -> None: - jax.config.update('jax_cpu_enable_gloo_collectives', True) + jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config @@ -70,6 +70,5 @@ def main(argv: Sequence[str]) -> None: data_load_loop(config) - if __name__ == "__main__": app.run(main) diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index e69c2306d..2a4e2ab97 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -36,10 +36,18 @@ class AttentionTest(unittest.TestCase): - """Test for the Attention """ + """Test for the Attention""" + def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size = 1.0, run_name='test', enable_checkpointing=False, max_target_length=128, max_prefill_predict_length=16 ) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + max_target_length=128, + max_prefill_predict_length=16, + ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(0) @@ -62,20 +70,17 @@ def setUp(self): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "dot_product", + attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) self._attention_as_mha_generic_variable = self._attention_as_mha_generic.init( - {'params': self.rng, 'aqt': self.rng}, - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length)), + {"params": self.rng, "aqt": self.rng}, + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)), + jnp.ones((self.global_batch_size, self.max_target_length)), ) def get_data(self, dtype): @@ -86,7 +91,9 @@ def get_data(self, dtype): ) decoder_segment_ids = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, 4) - decoder_positions = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, self.max_target_length) + decoder_positions = jax.random.randint( + self.rng, 
(self.global_batch_size, self.max_target_length), 0, self.max_target_length + ) return lnx, decoder_segment_ids, decoder_positions @@ -97,23 +104,22 @@ def get_structured_data(self, dtype): dtype=dtype, ) - decoder_positions = jnp.stack([ - jnp.arange(self.max_target_length, dtype=jnp.int32) - for _ in range(self.global_batch_size) - ]) + decoder_positions = jnp.stack( + [jnp.arange(self.max_target_length, dtype=jnp.int32) for _ in range(self.global_batch_size)] + ) - decoder_segment_ids = jax.numpy.zeros((self.global_batch_size, self.max_target_length))\ - + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + decoder_segment_ids = ( + jax.numpy.zeros((self.global_batch_size, self.max_target_length)) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) return lnx, decoder_segment_ids, decoder_positions - + @pytest.mark.tpu def test_autoregression(self): prefill_length = self.cfg.max_prefill_predict_length decode_total_length = self.cfg.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - self.dtype) - + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) + mha_full = self._attention_as_mha_generic.apply( self._attention_as_mha_generic_variable, lnx, @@ -122,13 +128,13 @@ def test_autoregression(self): inputs_positions=decoder_positions, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) - + lnx_prefill = lnx[:, 0:prefill_length, :] decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - + mha_prefill, output_cache = self._attention_as_mha_generic.apply( self._attention_as_mha_generic_variable, lnx_prefill, @@ -137,40 +143,32 @@ def test_autoregression(self): inputs_positions=decoder_positions_prefill, deterministic=True, model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'aqt': self.rng}, - mutable=["cache"] + rngs={"aqt": self.rng}, + mutable=["cache"], ) self.assertTrue( - jax.numpy.allclose( - mha_prefill, mha_full[:,:prefill_length,:], rtol=1e-02, atol=1e-02, equal_nan=False - ) + jax.numpy.allclose(mha_prefill, mha_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) ) for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx:idx+1, :] - decoder_positions_idx = decoder_positions[:, idx:idx+1] + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] self._attention_as_mha_generic_variable.update(output_cache) mha_idx, output_cache = self._attention_as_mha_generic.apply( - self._attention_as_mha_generic_variable, - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'aqt': self.rng}, - mutable=["cache"] + self._attention_as_mha_generic_variable, + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], ) - mha_full_this_idx = mha_full[:,idx:idx+1,:] - self.assertTrue( - mha_full_this_idx.shape == mha_idx.shape - ) - self.assertTrue( - jax.numpy.allclose( - mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False - ) - ) + mha_full_this_idx = mha_full[:, idx : idx + 1, :] + self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) + self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) @pytest.mark.tpu def 
test_tpu_kernel_attention_mha(self):
@@ -187,8 +185,7 @@ def test_tpu_kernel_attention_mqa(self):
   def tpu_kernel_attention_helper(self, num_kv_heads):
     """Test equivalence between dot_product and TPU-accelerated attention"""
-    lnx, decoder_segment_ids, decoder_positions = self.get_data(
-        self.dtype)
+    lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype)
     attention_as_mha_generic = Attention(
         config=self.cfg,
@@ -198,20 +195,17 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         max_target_length=self.max_target_length,
         max_prefill_predict_length=self.cfg.max_prefill_predict_length,
         mesh=self.mesh,
-        attention_kernel = "dot_product",
+        attention_kernel="dot_product",
         dtype=self.dtype,
         dropout_rate=self.cfg.dropout_rate,
-        name='self_attention',
+        name="self_attention",
     )
     attention_as_mha_generic_variable = attention_as_mha_generic.init(
-        {'params': self.rng, 'aqt': self.rng},
-        jnp.ones(
-            (self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones(
-            (self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones(
-            (self.global_batch_size, self.max_target_length)),
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length)),
     )
     mha_generic_output = attention_as_mha_generic.apply(
@@ -222,7 +216,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         inputs_positions=decoder_segment_ids,
         deterministic=True,
         model_mode=common_types.MODEL_MODE_TRAIN,
-        rngs={'aqt': self.rng},
+        rngs={"aqt": self.rng},
     )
     attention_as_mha_flash = Attention(
@@ -233,20 +227,17 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         max_target_length=self.max_target_length,
         max_prefill_predict_length=self.cfg.max_prefill_predict_length,
         mesh=self.mesh,
-        attention_kernel = "flash",
+        attention_kernel="flash",
         dtype=self.dtype,
         dropout_rate=self.cfg.dropout_rate,
-        name='self_attention',
+        name="self_attention",
     )
     attention_as_mha_flash_variable = attention_as_mha_flash.init(
-        {'params': self.rng, 'aqt': self.rng},
-        jnp.ones(
-            (self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones(
-            (self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones(
-            (self.global_batch_size, self.max_target_length)),
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length)),
    )
     mha_generic_flash_output = attention_as_mha_flash.apply(
@@ -257,14 +248,13 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
         inputs_positions=decoder_segment_ids,
         deterministic=True,
         model_mode=common_types.MODEL_MODE_TRAIN,
-        rngs={'aqt': self.rng},
+        rngs={"aqt": self.rng},
     )
     self.assertTrue(
-        jax.numpy.allclose(
-            mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False
-        )
+        jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False)
     )
-if __name__ == '__main__':
+
+if __name__ == "__main__":
   unittest.main()
diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py
index 4650824bf..7dc4246d7 100644
--- a/MaxText/tests/gpt3_test.py
+++ b/MaxText/tests/gpt3_test.py
@@ -1,19 +1,18 @@
-
 """
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 """ Tests for GPT3 """
 import sys
@@ -38,11 +37,12 @@
 def init_random_model_vars(model, rng, example_batch):
   """initialize random model vars."""
   model_vars = model.init(
-      {'params': rng, 'aqt': rng},
-      example_batch['inputs'],
-      example_batch['inputs_position'],
+      {"params": rng, "aqt": rng},
+      example_batch["inputs"],
+      example_batch["inputs_position"],
       enable_dropout=False,
   )
+
   def _replace_initialization(key, value):
     keystr = jax.tree_util.keystr(key)
     # replace zero initializer to ensure strong test cases
@@ -57,14 +57,15 @@ def _replace_initialization(key, value):
 class GPT3(unittest.TestCase):
   """numerical tests for GPT3."""
+
   def setUp(self):
     super().setUp()
     pyconfig.initialize(
-        [sys.argv[0], 'configs/base.yml'],
-        run_name='test',
-        enable_checkpointing=False,
-        model_name='gpt3-52k',
-        dtype='float32',
+        [sys.argv[0], "configs/base.yml"],
+        run_name="test",
+        enable_checkpointing=False,
+        model_name="gpt3-52k",
+        dtype="float32",
     )
     self.cfg = pyconfig.config
@@ -73,14 +74,14 @@ def setUp(self):
     devices_array = max_utils.create_device_mesh(self.cfg)
     mesh = Mesh(devices_array, self.cfg.mesh_axes)
     quant = quantizations.configure_quantization(self.cfg)
-    self.model = models.Transformer(config = self.cfg, mesh = mesh, quant = quant)
+    self.model = models.Transformer(config=self.cfg, mesh=mesh, quant=quant)
     self.example_batch = {
-        'inputs': jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32),
-        'inputs_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32),
-        'inputs_segmentation': jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32),
-        'targets': jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32),
-        'targets_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32),
-        'targets_segmentation': jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32),
+        "inputs": jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32),
+        "inputs_position": jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32),
+        "inputs_segmentation": jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32),
+        "targets": jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32),
+        "targets_position": jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32),
+        "targets_segmentation": jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32),
     }
     self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch)
@@ -91,21 +92,20 @@ def test_logits_numerically(self):
     # paxml applies padding in mlp layer
     # while maxtext implementation applies padding in attention mask instead
     # the two implementations are equivalent in valid non-padding tokens
-    per_example_xent_truth =
jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.]], dtype=jnp.float32) - logits, _ = self.model.apply(self.model_vars, - self.example_batch['inputs'], - self.example_batch['inputs_position'], - decoder_segment_ids=self.example_batch['inputs_segmentation'], - enable_dropout=self.cfg.enable_dropout, - rngs={'dropout': self.rng, 'aqt': self.rng}, mutable='intermediates') - - one_hot_targets = jax.nn.one_hot(self.example_batch['targets'], self.cfg.vocab_size) + per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.0]], dtype=jnp.float32) + logits, _ = self.model.apply( + self.model_vars, + self.example_batch["inputs"], + self.example_batch["inputs_position"], + decoder_segment_ids=self.example_batch["inputs_segmentation"], + enable_dropout=self.cfg.enable_dropout, + rngs={"dropout": self.rng, "aqt": self.rng}, + mutable="intermediates", + ) + + one_hot_targets = jax.nn.one_hot(self.example_batch["targets"], self.cfg.vocab_size) per_example_xent = -jnp.sum(jax.nn.log_softmax(logits) * one_hot_targets, axis=-1, dtype=jnp.float32) # Mask out paddings at the end of each example. - per_example_xent = per_example_xent * (self.example_batch['targets_segmentation'] != 0) + per_example_xent = per_example_xent * (self.example_batch["targets_segmentation"] != 0) - self.assertTrue( - jax.numpy.allclose( - per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03 - ) - ) + self.assertTrue(jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03)) diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 8af271f53..36167f67b 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import subprocess @@ -27,117 +27,130 @@ from input_pipeline import _grain_data_processing from input_pipeline import input_pipeline_interface + class GrainDataProcessingTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - exit_code = subprocess.call(['bash','../setup_gcsfuse.sh', - 'DATASET_GCS_BUCKET=maxtext-dataset', - 'MOUNT_PATH=/tmp/gcsfuse']) - if exit_code != 0: - raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") - - def setUp(self): - super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], - per_device_batch_size=1, - run_name='test', - mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "/tmp/gcsfuse", - tokenizer_path = "../assets/tokenizer", - enable_checkpointing=False, - dataset_type="c4-array_record", - dataset_name='array-record/c4/en/3.0.1', - eval_dataset_name='array-record/c4/en/3.0.1') - self.config = pyconfig.config - self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) - self.train_ds, self.eval_ds = self._get_datasets() - self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() - - def _get_datasets(self): - print("Sharding dataset in ", jax.process_count(), " shards") - train_ds, eval_ds = _grain_data_processing.get_datasets( - config=self.config) - return train_ds, eval_ds - - def _get_preprocessed_datasets(self): - process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) - train_iter, eval_iter, test_iter = _grain_data_processing.preprocess_dataset( - self.config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - global_mesh = self.mesh, - train_ds = self.train_ds, eval_ds = self.eval_ds, - vocab_path=self.config.tokenizer_path) - return train_iter, eval_iter, test_iter - - def test_train_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - # For training we pack multiple short examples in one example. - # *_position and *_segmentation indicate the boundaries. 
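# A worked instance of the shape assertion in this test, assuming a
# hypothetical 4-device host and max_target_length=2048 (both illustrative):
#   expected_shape = [jax.device_count(), config.max_target_length]  # -> [4, 2048]
# Every packed field in the batch ("inputs", "inputs_position",
# "inputs_segmentation", "targets", "targets_position", "targets_segmentation")
# must share this [batch, sequence_length] shape.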
- batch = next(self.train_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - def test_eval_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - batch = next(self.eval_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - - def test_predict_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - batch = next(self.predict_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - def test_batch_determinism(self): - batch1 = next(self.train_iter) - self.train_ds, _ = self._get_datasets() - train_iter, _, _= self._get_preprocessed_datasets() - batch2 = next(train_iter) - self.assertTrue((batch1['inputs']==batch2['inputs']).all()) - self.assertTrue((batch1['targets']==batch2['targets']).all()) - self.assertTrue((batch1['inputs_segmentation']==batch2['inputs_segmentation']).all()) - self.assertTrue((batch1['targets_segmentation']==batch2['targets_segmentation']).all()) - self.assertTrue((batch1['inputs_position']==batch2['inputs_position']).all()) - self.assertTrue((batch1['targets_position']==batch2['targets_position']).all()) - - def test_for_loop_repeatable(self): - def get_first_batch(iterator): - batch = None - for batch in iterator: - break - return batch - - eval_batch1 = get_first_batch(self.eval_iter) - eval_batch2 = get_first_batch(self.eval_iter) - self.assertTrue((eval_batch1['inputs']==eval_batch2['inputs']).all()) - self.assertTrue((eval_batch1['targets']==eval_batch2['targets']).all()) - - -if __name__ == '__main__': + + @classmethod + def setUpClass(cls): + super().setUpClass() + exit_code = subprocess.call( + ["bash", "../setup_gcsfuse.sh", "DATASET_GCS_BUCKET=maxtext-dataset", "MOUNT_PATH=/tmp/gcsfuse"] + ) + if exit_code != 0: + raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") + + def setUp(self): + super().setUp() + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="/tmp/gcsfuse", + tokenizer_path="../assets/tokenizer", + enable_checkpointing=False, + dataset_type="c4-array_record", + dataset_name="array-record/c4/en/3.0.1", + eval_dataset_name="array-record/c4/en/3.0.1", + ) + self.config = pyconfig.config + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.train_ds, self.eval_ds = self._get_datasets() + self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() + + def _get_datasets(self): + print("Sharding dataset in ", jax.process_count(), " shards") + train_ds, eval_ds = _grain_data_processing.get_datasets(config=self.config) + return train_ds, 
eval_ds + + def _get_preprocessed_datasets(self): + process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) + train_iter, eval_iter, test_iter = _grain_data_processing.preprocess_dataset( + self.config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=self.mesh, + train_ds=self.train_ds, + eval_ds=self.eval_ds, + vocab_path=self.config.tokenizer_path, + ) + return train_iter, eval_iter, test_iter + + def test_train_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + # For training we pack multiple short examples in one example. + # *_position and *_segmentation indicate the boundaries. + batch = next(self.train_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_eval_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.eval_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_predict_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.predict_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_batch_determinism(self): + batch1 = next(self.train_iter) + self.train_ds, _ = self._get_datasets() + train_iter, _, _ = self._get_preprocessed_datasets() + batch2 = next(train_iter) + self.assertTrue((batch1["inputs"] == batch2["inputs"]).all()) + self.assertTrue((batch1["targets"] == batch2["targets"]).all()) + self.assertTrue((batch1["inputs_segmentation"] == batch2["inputs_segmentation"]).all()) + self.assertTrue((batch1["targets_segmentation"] == batch2["targets_segmentation"]).all()) + self.assertTrue((batch1["inputs_position"] == batch2["inputs_position"]).all()) + self.assertTrue((batch1["targets_position"] == batch2["targets_position"]).all()) + + def test_for_loop_repeatable(self): + def get_first_batch(iterator): + batch = None + for batch in iterator: + break + return batch + + eval_batch1 = get_first_batch(self.eval_iter) + eval_batch2 = get_first_batch(self.eval_iter) + self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all()) + self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all()) + + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index 274739a7f..b305562c2 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -22,18 +22,20 @@ class Inference_Microbenchmark(unittest.TestCase): + @pytest.mark.tpu def test(self): - pyconfig.initialize([None, - "configs/tpu_smoke_test.yml", - "tokenizer_path=../assets/tokenizer.llama2", - 
"ici_autoregressive_parallelism=-1", - "ici_fsdp_parallelism=1", - "max_prefill_predict_length=1024", - "max_target_length=2048", - "scan_layers=false", - "weight_dtype=bfloat16", - ]) + pyconfig.initialize([ + None, + "configs/tpu_smoke_test.yml", + "tokenizer_path=../assets/tokenizer.llama2", + "ici_autoregressive_parallelism=-1", + "ici_fsdp_parallelism=1", + "max_prefill_predict_length=1024", + "max_target_length=2048", + "scan_layers=false", + "weight_dtype=bfloat16", + ]) inference_microbenchmark_main(pyconfig.config) diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py index 38e1a7c62..6d7b7827c 100644 --- a/MaxText/tests/llama_test.py +++ b/MaxText/tests/llama_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+"""
 """ Tests for Llama """
 import jax
@@ -31,13 +31,10 @@
 """
-def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32
-) -> jnp.ndarray:
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
   """Calculate the frequencies"""
-  freqs = 1.0 / (
-      theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim)
-  )
+  freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
   t = np.arange(end)  # type: ignore
   freqs = np.outer(t, freqs).astype(dtype)  # type: ignore
   sin, cos = np.sin(freqs), np.cos(freqs)
@@ -45,14 +42,13 @@ def precompute_freqs_cis(
   return jnp.asarray(freqs_cis)
-
 def apply_rotary_emb(
-  xq: jnp.ndarray,
-  xk: jnp.ndarray,
-  freqs_cis: jnp.ndarray,
-  dtype: jnp.dtype = jnp.bfloat16,
+    xq: jnp.ndarray,
+    xk: jnp.ndarray,
+    freqs_cis: jnp.ndarray,
+    dtype: jnp.dtype = jnp.bfloat16,
 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
-  """ Apply the computed Rotary Postional Embedding"""
+  """Apply the computed Rotary Positional Embedding"""
   reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
   reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
@@ -60,29 +56,26 @@ def apply_rotary_emb(
   xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
   # add head dim
-  freqs_cis = jnp.reshape(
-      freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:])
-  )
+  freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
   xq_out = xq_ * freqs_cis
-  xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(
-      *xq_out.shape[:-1], -1
-  )
+  xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
   xk_out = xk_ * freqs_cis
-  xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(
-      *xk_out.shape[:-1], -1
-  )
+  xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
   return xq_out.astype(dtype), xk_out.astype(dtype)
+
 def permute_to_match_maxtext_rope(arr):
   evens = arr[..., ::2]
   odds = arr[..., 1::2]
-  return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1)
+  return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
+
 class RoPETest(unittest.TestCase):
-  """Test for the RoPE implementation """
+  """Test for the RoPE implementation"""
+
   def test_rope(self):
     dim_per_head = 128
     seq_len = 8
@@ -93,24 +86,28 @@ def test_rope(self):
     # Calculate RoPE embeddings from Sea-Snell implementation
     freqs_cis = precompute_freqs_cis(dim_per_head, seq_len * 2)
-    freqs_cis = jnp.take(
-        freqs_cis, jnp.arange(seq_len, dtype=np.int32)[None, :], axis=0
-    )
+    freqs_cis = jnp.take(freqs_cis, jnp.arange(seq_len, dtype=np.int32)[None, :], axis=0)
-    llama_output = apply_rotary_emb(
-        jnp.asarray(x_q), jnp.asarray(x_k), freqs_cis
-    )
+    llama_output = apply_rotary_emb(jnp.asarray(x_q), jnp.asarray(x_k), freqs_cis)
     seq_length = x_q.shape[1]
     position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
     # Calculate RoPE embeddings from MaxText implementation
-    query_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_q), position = position)
-    key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position)
+    query_proj = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(
+        permute_to_match_maxtext_rope(x_q), position=position
+    )
+    key_proj =
embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(permute_to_match_maxtext_rope(x_k), position=position) # Compare results - self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) - self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False + ) + ) + self.assertTrue( + jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False) + ) def test_scaling_rope(self): dim_per_head = 128 @@ -121,15 +118,15 @@ def test_scaling_rope(self): position = jnp.arange(seq_len, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings and then scale - query_proj_1 = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(x_q, position = position) - query_proj_1 = query_proj_1 * (dim_per_head ** -0.5) + query_proj_1 = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(x_q, position=position) + query_proj_1 = query_proj_1 * (dim_per_head**-0.5) # scale first and then apply RoPE - query_proj_2 = x_q * (dim_per_head ** -0.5) - query_proj_2 = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(query_proj_2, position=position) + query_proj_2 = x_q * (dim_per_head**-0.5) + query_proj_2 = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)(query_proj_2, position=position) self.assertTrue(jax.numpy.allclose(query_proj_2, query_proj_1, rtol=1e-01, atol=1e-04, equal_nan=False)) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/max_utils_test.py b/MaxText/tests/max_utils_test.py index c4c1480dd..061446024 100644 --- a/MaxText/tests/max_utils_test.py +++ b/MaxText/tests/max_utils_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
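The permute_to_match_maxtext_rope helper above bridges two RoPE layouts: the reference implementation rotates interleaved (even, odd) feature pairs, while MaxText's RotaryEmbedding expects the even and odd features stored as two contiguous halves. A toy NumPy illustration of that reordering, not part of the test itself:

import numpy as np

x = np.arange(8)                             # stand-in feature axis [0 1 2 3 4 5 6 7]
halves = np.concatenate((x[::2], x[1::2]))   # evens first, then odds
print(halves)                                # [0 2 4 6 1 3 5 7]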
+""" """ Tests for the common Max Utils """ import jax @@ -30,55 +30,53 @@ Transformer = models.Transformer + class MaxUtilsSummaryStats(unittest.TestCase): """Tests for the summary stats functions in max_utils.py""" + def test_l2norm_pytree(self): - x = {'a': jax.numpy.array([0, 2, 0]), 'b': jax.numpy.array([0, 3, 6])} + x = {"a": jax.numpy.array([0, 2, 0]), "b": jax.numpy.array([0, 3, 6])} pytree_l2_norm = max_utils.l2norm_pytree(x) self.assertTrue(jax.numpy.allclose(pytree_l2_norm, 7, rtol=1e-05, atol=1e-08, equal_nan=False)) + class MaxUtilsInitState(unittest.TestCase): """Tests initialization of training and decode states in max_utils.py""" + def setUp(self): self.model = nn.Dense(features=5) self.key1, self.key2 = random.split(random.key(0)) - self.input = random.normal(self.key1, (10,)) # Dummy input data + self.input = random.normal(self.key1, (10,)) # Dummy input data self.params = self.model.init(self.key2, self.input) self.output = self.model.apply(self.params, self.input) self.tx = optax.adam(learning_rate=0.001) def test_calculate_num_params_from_pytree(self): example_tree = [ - [1, 'a', object()], - (1, (2, 3), ()), - [1, {'k1': 2, 'k2': (3, 4)}, 5], - {'a': 2, 'b': (2, 3)}, - jnp.array([1, 2, 3]), - ] + [1, "a", object()], + (1, (2, 3), ()), + [1, {"k1": 2, "k2": (3, 4)}, 5], + {"a": 2, "b": (2, 3)}, + jnp.array([1, 2, 3]), + ] self.assertEqual(max_utils.calculate_num_params_from_pytree(example_tree), 17) # Model params self.assertEqual(max_utils.calculate_num_params_from_pytree(self.params), 55) def test_init_train_state(self): state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.params, - tx=None, # type: ignore - opt_state={} + step=0, apply_fn=self.model.apply, params=self.params, tx=None, opt_state={} # type: ignore ) self.assertEqual(state.tx, None) self.assertEqual(state.step, 0) self.assertEqual(state.opt_state, {}) self.assertEqual(state.apply_fn, self.model.apply) - self.assertEqual(max_utils.calculate_num_params_from_pytree(state.params), - max_utils.calculate_num_params_from_pytree(self.params)) - + self.assertEqual( + max_utils.calculate_num_params_from_pytree(state.params), max_utils.calculate_num_params_from_pytree(self.params) + ) def test_init_decode_state(self): - decode_state = max_utils.init_decode_state( - self.model.apply, self.params - ) + decode_state = max_utils.init_decode_state(self.model.apply, self.params) self.assertEqual(decode_state.apply_fn, self.model.apply) output = decode_state.apply_fn(self.params, self.input) self.assertEqual(output.tolist(), self.output.tolist()) @@ -86,8 +84,8 @@ def test_init_decode_state(self): self.assertEqual(decode_state.opt_state, {}) self.assertEqual(decode_state.step, 0) self.assertEqual( - max_utils.calculate_num_params_from_pytree(decode_state.params), - max_utils.calculate_num_params_from_pytree(self.params) + max_utils.calculate_num_params_from_pytree(decode_state.params), + max_utils.calculate_num_params_from_pytree(self.params), ) def test_init_training_state(self): @@ -96,24 +94,24 @@ def test_init_training_state(self): self.assertEqual(state.tx, self.tx) self.assertNotEqual(state.opt_state, {}) self.assertEqual( - max_utils.calculate_num_params_from_pytree(state.params), - max_utils.calculate_num_params_from_pytree(self.params) + max_utils.calculate_num_params_from_pytree(state.params), max_utils.calculate_num_params_from_pytree(self.params) ) + class ModelWithMultipleCollections(nn.Module): - """ - A simple model that has variables in multiple collections - "params" and 
"special_variables" - """ - def setup(self): - self.dense = nn.Dense(4) - self.kernel = self.variable( - "special_variables", "my_first_kernel", lambda: jnp.ones((4, 5)) - ) - - def __call__(self, x, y): - x = self.dense(x) - x = x @ self.kernel.value - return x + """ + A simple model that has variables in multiple collections - "params" and "special_variables" + """ + + def setup(self): + self.dense = nn.Dense(4) + self.kernel = self.variable("special_variables", "my_first_kernel", lambda: jnp.ones((4, 5))) + + def __call__(self, x, y): + x = self.dense(x) + x = x @ self.kernel.value + return x + class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase): @@ -122,8 +120,7 @@ def setUp(self): self.config = pyconfig.config self.model = ModelWithMultipleCollections() self.key1, self.key2, self.key3 = random.split(random.key(0), num=3) - self.input = random.normal(self.key1, - (self.config.global_batch_size_to_load, self.config.max_target_length)) + self.input = random.normal(self.key1, (self.config.global_batch_size_to_load, self.config.max_target_length)) self.params = self.model.init(self.key2, self.input, self.input) self.tx = optax.adam(learning_rate=0.001) @@ -137,19 +134,16 @@ def _test_init_initial_state_driver(self, is_training): self.assertIsNone(state_under_test.tx) self.assertEqual(state_under_test.opt_state, {}) self.assertEqual( - max_utils.calculate_num_params_from_pytree(state_under_test.params), - max_utils.calculate_num_params_from_pytree(self.params) - ) - self.assertEqual( - len(self.params), - len(state_under_test.params) + max_utils.calculate_num_params_from_pytree(state_under_test.params), + max_utils.calculate_num_params_from_pytree(self.params), ) + self.assertEqual(len(self.params), len(state_under_test.params)) self.assertIn("special_variables", state_under_test.params) self.assertIn("params", state_under_test.params) - + def test_initial_train_state(self): self._test_init_initial_state_driver(True) - + def test_initial_decode_state(self): self._test_init_initial_state_driver(False) @@ -167,28 +161,26 @@ def setUp(self): def test_setup_decode_state(self): rng = random.PRNGKey(0) - state, _ = max_utils.setup_decode_state( - self.model, self.config, rng, self.mesh, None) + state, _ = max_utils.setup_decode_state(self.model, self.config, rng, self.mesh, None) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) - state, _, _ = max_utils.setup_initial_state( - self.model, None, tx, self.config, rng, self.mesh, None) + state, _, _ = max_utils.setup_initial_state(self.model, None, tx, self.config, rng, self.mesh, None) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) + class MaxUtilsT5XCrossEntropy(unittest.TestCase): """Tests for the cross entropy functions in max_utils.py""" + def test_t5x_cross_entropy(self): # Generate random targets and logits key = jax.random.PRNGKey(0) - targets = jax.random.randint(key, shape=(48, 2048), - dtype=jax.numpy.int32, minval=1, maxval=10) - logits = jax.random.uniform(key, shape=(48, 2048, 4096), - dtype=jax.numpy.float32) + targets = jax.random.randint(key, shape=(48, 2048), dtype=jax.numpy.int32, minval=1, maxval=10) + logits = jax.random.uniform(key, shape=(48, 2048, 4096), dtype=jax.numpy.float32) # Calculate xent from optax implementation optax_xent = optax.softmax_cross_entropy_with_integer_labels(logits, targets) @@ -196,10 +188,11 @@ def test_t5x_cross_entropy(self): # Calculate xent 
from custom T5X implementation one_hot_targets = jax.nn.one_hot(targets, 4096) t5x_xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - t5x_xent = nn.with_logical_constraint(t5x_xent, ('activation_batch', 'activation_length')) + t5x_xent = nn.with_logical_constraint(t5x_xent, ("activation_batch", "activation_length")) # Compare results self.assertTrue(jax.numpy.allclose(optax_xent, t5x_xent, rtol=1e-05, atol=1e-08, equal_nan=False)) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py index 5ab89406e..8b3bcc318 100644 --- a/MaxText/tests/model_test.py +++ b/MaxText/tests/model_test.py @@ -30,35 +30,41 @@ from layers import quantizations Mesh = jax.sharding.Mesh -MAX_PREFILL_PREDICT_LENGTH = 4 +MAX_PREFILL_PREDICT_LENGTH = 4 + class TestModel(unittest.TestCase): - """Test the Whole Model """ + """Test the Whole Model""" + def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size = 1.0, run_name='test', - enable_checkpointing=False, base_num_decoder_layers=2, attention="dot_product", - max_target_length=16, base_emb_dim=256, base_num_query_heads=2, base_num_kv_heads=2, max_prefill_predict_length=4) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + base_num_decoder_layers=2, + attention="dot_product", + max_target_length=16, + base_emb_dim=256, + base_num_query_heads=2, + base_num_kv_heads=2, + max_prefill_predict_length=4, + ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(0) def get_data(self): s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length) - ids = jax.random.randint( - self.rng, - s, - 0, - self.cfg.vocab_size - ) + ids = jax.random.randint(self.rng, s, 0, self.cfg.vocab_size) decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR - decoder_positions = jnp.stack([ - jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) - for _ in range(self.cfg.global_batch_size_to_train_on) - ]) + decoder_positions = jnp.stack( + [jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) for _ in range(self.cfg.global_batch_size_to_train_on)] + ) return ids, decoder_segment_ids, decoder_positions - + @pytest.mark.tpu def test_train_vs_prefill_and_autoregress(self): PREFILL_RANGE = MAX_PREFILL_PREDICT_LENGTH @@ -66,68 +72,59 @@ def test_train_vs_prefill_and_autoregress(self): devices_array = max_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) quant = quantizations.configure_quantization(self.cfg) - model = models.Transformer(config = self.cfg, mesh = mesh, quant=quant) + model = models.Transformer(config=self.cfg, mesh=mesh, quant=quant) ids, decoder_segment_ids, decoder_positions = self.get_data() transformer_vars = model.init( - {'params': self.rng, 'aqt': self.rng}, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False + {"params": self.rng, "aqt": self.rng}, ids, decoder_positions, decoder_segment_ids, enable_dropout=False ) full_train_logits = model.apply( - transformer_vars, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - model_mode = common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng} + transformer_vars, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_TRAIN, + rngs={"aqt": self.rng}, ) partial_prefill_logits, partial_cache = 
model.apply( - transformer_vars, - ids[:, :PREFILL_RANGE], - decoder_positions[:, :PREFILL_RANGE], - decoder_segment_ids=decoder_segment_ids[:, :PREFILL_RANGE], - enable_dropout=False, - model_mode = common_types.MODEL_MODE_PREFILL, - rngs={'aqt': self.rng}, - mutable=["cache"], + transformer_vars, + ids[:, :PREFILL_RANGE], + decoder_positions[:, :PREFILL_RANGE], + decoder_segment_ids=decoder_segment_ids[:, :PREFILL_RANGE], + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"aqt": self.rng}, + mutable=["cache"], ) self.assertTrue( jax.numpy.allclose( - full_train_logits[:,:PREFILL_RANGE,:], partial_prefill_logits, rtol=1e-01, atol=1e-01, equal_nan=False + full_train_logits[:, :PREFILL_RANGE, :], partial_prefill_logits, rtol=1e-01, atol=1e-01, equal_nan=False ) ) for idx in range(PREFILL_RANGE, self.cfg.max_target_length): - ids_idx = ids[:, idx:idx+1] - decoder_positions_idx = decoder_positions[:, idx:idx+1] + ids_idx = ids[:, idx : idx + 1] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] transformer_vars.update(partial_cache) ar_logits, partial_cache = model.apply( - transformer_vars, - ids_idx, - decoder_positions_idx, - enable_dropout=False, - model_mode = common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'aqt': self.rng}, - mutable=["cache"], + transformer_vars, + ids_idx, + decoder_positions_idx, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], ) - full_train_logits_idx = full_train_logits[:,idx:idx+1,:] - self.assertTrue( - full_train_logits_idx.shape == ar_logits.shape - ) - self.assertTrue( - jax.numpy.allclose( - full_train_logits_idx, ar_logits, rtol=1e-01, atol=1e-01, equal_nan=False - ) - ) + full_train_logits_idx = full_train_logits[:, idx : idx + 1, :] + self.assertTrue(full_train_logits_idx.shape == ar_logits.shape) + self.assertTrue(jax.numpy.allclose(full_train_logits_idx, ar_logits, rtol=1e-01, atol=1e-01, equal_nan=False)) + -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index e620c2e1e..ba289c040 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
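test_train_vs_prefill_and_autoregress above pins down one invariant: logits from a single full-sequence forward pass must match a prefill over the first P positions followed by one-token autoregressive steps that extend the KV cache. A schematic of just the comparison skeleton, with a random array standing in for the model's [batch, seq, vocab] logits:

import numpy as np

T, P = 8, 4
full = np.random.rand(1, T, 16)       # pretend logits from the full pass
prefill = full[:, :P, :]              # what the prefill pass should produce
np.testing.assert_allclose(full[:, :P, :], prefill)
for t in range(P, T):                 # what each KV-cached decode step should produce
  step = full[:, t : t + 1, :]
  np.testing.assert_allclose(full[:, t : t + 1, :], step)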
+""" # pylint: disable=missing-module-docstring, missing-function-docstring import sys @@ -35,29 +35,32 @@ class MultihostDataloadingTest(unittest.TestCase): def setUp(self): super().setUp() batch_size = 4 - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "gs://maxtext-dataset/", - enable_checkpointing=False) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="gs://maxtext-dataset/", + enable_checkpointing=False, + ) config = pyconfig.config global_data_shape = PartitionSpec(batch_size, config.max_target_length) - data_sharding = ('data',) + data_sharding = ("data",) mesh_shape_1d = (len(jax.devices()),) self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes) - data_axes = PartitionSpec('data',) + data_axes = PartitionSpec( + "data", + ) # creating 2 batches of data - global_data = np.arange(np.prod(global_data_shape)*2).reshape((batch_size * 2, config.max_target_length)) + global_data = np.arange(np.prod(global_data_shape) * 2).reshape((batch_size * 2, config.max_target_length)) dataset = tf.data.Dataset.from_tensor_slices(global_data) dataset = dataset.repeat() dataset = dataset.batch(batch_size) - self.multihost_gen = ( - multihost_dataloading.MultiHostDataLoadIterator( - dataset, self.mesh - ) - ) + self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh) @pytest.mark.tpu def test_batch_sharded_data_pipeline(self): @@ -66,5 +69,5 @@ def test_batch_sharded_data_pipeline(self): self.assertTrue(not np.array_equal(first_batch, sec_batch, equal_nan=True)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/profiler_test.py b/MaxText/tests/profiler_test.py index 095073008..1c2cf3780 100644 --- a/MaxText/tests/profiler_test.py +++ b/MaxText/tests/profiler_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Profiler tests for TPUs.""" import glob @@ -29,9 +29,9 @@ class TpuJAXTest(unittest.TestCase): def _get_session_snapshot(self): """Gets a session snapshot of current session. 
assume only one session."""
-    profile_plugin_root ="tensorboard/plugins/profile"
+    profile_plugin_root = "tensorboard/plugins/profile"
-    # The session exists under a director whose name is time-dependent.
+    # The session exists under a directory whose name is time-dependent.
-    profile_session_glob = os.path.join(profile_plugin_root, '*', '*.xplane.pb')
+    profile_session_glob = os.path.join(profile_plugin_root, "*", "*.xplane.pb")
     return glob.glob(profile_session_glob)

   def test_xplane_is_present(self):
@@ -40,47 +40,42 @@ def test_overview_page(self):
     xspace_filenames = self._get_session_snapshot()
-    result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames,
-                                                     'overview_page^', {})
+    result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "overview_page^", {})
     result = json.loads(result)
     run_environment = result[2]
-    self.assertEqual(run_environment['p']['host_count'], '1')
-    self.assertRegex(run_environment['p']['device_type'], 'TPU.*')
+    self.assertEqual(run_environment["p"]["host_count"], "1")
+    self.assertRegex(run_environment["p"]["device_type"], "TPU.*")

   def test_op_profile(self):
     xspace_filenames = self._get_session_snapshot()
-    result, _ = raw_to_tool_data.xspace_to_tool_data(
-        xspace_filenames, 'op_profile^', {}
-    )
+    result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "op_profile^", {})
     result = json.loads(result)
-    self.assertIn('byCategory', result)
-    self.assertIn('metrics', result['byCategory'])
-    overall_metrics = result['byCategory']['metrics']
-    self.assertIn('flops', overall_metrics)
-    self.assertIn('bandwidthUtils', overall_metrics)
-    self.assertGreater(overall_metrics['flops'], 0)
+    self.assertIn("byCategory", result)
+    self.assertIn("metrics", result["byCategory"])
+    overall_metrics = result["byCategory"]["metrics"]
+    self.assertIn("flops", overall_metrics)
+    self.assertIn("bandwidthUtils", overall_metrics)
+    self.assertGreater(overall_metrics["flops"], 0)

   def test_device_trace_contains_threads(self):
     xspace_filenames = self._get_session_snapshot()
-    result, _ = raw_to_tool_data.xspace_to_tool_data(
-        xspace_filenames, 'trace_viewer^', {}
-    )
+    result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "trace_viewer^", {})
     result = json.loads(result)
     thread_names = []
-    for event in result['traceEvents']:
-      if 'name' in event and event['name'] == 'thread_name':
-        thread_names.append((event['args']['name']))
-    expected_threads = [
-        'TensorFlow Name Scope',
-        'TensorFlow Ops',
-        'XLA Modules',
-        'XLA Ops',
-        'XLA TraceMe',
-        'Steps',
-    ]
+    for event in result["traceEvents"]:
+      if "name" in event and event["name"] == "thread_name":
+        thread_names.append((event["args"]["name"]))
+    expected_threads = [
+        "TensorFlow Name Scope",
+        "TensorFlow Ops",
+        "XLA Modules",
+        "XLA Ops",
+        "XLA TraceMe",
+        "Steps",
+    ]
     # Ensure that thread_names contains at least all expected threads.
-    self.assertEqual(set(expected_threads)-set(thread_names), set())
+    self.assertEqual(set(expected_threads) - set(thread_names), set())


-if __name__ == '__main__':
+if __name__ == "__main__":
   unittest.main()
diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py
index 7f35ffb0f..ef1b54df4 100644
--- a/MaxText/tests/quantizations_test.py
+++ b/MaxText/tests/quantizations_test.py
@@ -1,18 +1,18 @@
 """
- Copyright 2024 Google LLC
+Copyright 2024 Google LLC

- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the quantizations """ from jax import numpy as jnp @@ -23,8 +23,10 @@ from layers import quantizations import unittest + class QuantTestModule(nn.Module): """Test module for einsum.""" + quantization: quantizations.AqtQuantization @nn.compact @@ -36,23 +38,25 @@ def __call__(self, inputs): einsum = self.quantization.einsum() dot_general_cls = self.quantization.dot_general_cls() dot_general = dot_general_cls() - res_einsum = einsum('bc,ab->ac', inputs, identity) + res_einsum = einsum("bc,ab->ac", inputs, identity) res_dg = dot_general(inputs, inputs, (((), ()), ((), ())), precision=None) return res_einsum, res_dg -def _configure_quantization(quant_str="", mode_str='train'): + +def _configure_quantization(quant_str="", mode_str="train"): pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False, quantization=quant_str) config = pyconfig.config quant = quantizations.configure_quantization(config, mode_str) return quant + def _apply(quant_str=""): quant = _configure_quantization(quant_str) test_module = QuantTestModule(quant) rng = random.PRNGKey(0) - variables = test_module.init({'params': rng}, jnp.ones((2, 2))) + variables = test_module.init({"params": rng}, jnp.ones((2, 2))) inputs = jnp.ones((2, 2)) - res_einsum, res_dg = test_module.apply(variables, inputs, rngs={'params': random.PRNGKey(0)}) + res_einsum, res_dg = test_module.apply(variables, inputs, rngs={"params": random.PRNGKey(0)}) return inputs, res_einsum, res_dg @@ -60,11 +64,10 @@ class QuantizationTest(unittest.TestCase): """Tests for quantization.""" def test_in_quant_mode(self): - quant = _configure_quantization(quant_str="int8", mode_str='convert') + quant = _configure_quantization(quant_str="int8", mode_str="convert") self.assertTrue(quantizations.in_convert_mode(quant)) self.assertFalse(quantizations.in_serve_mode(quant)) - def test_configure_quantization_is_null(self): for quant_mode in ["train", "serve", "convert"]: quant = _configure_quantization(quant_str="", mode_str=quant_mode) @@ -88,37 +91,52 @@ def test_aqt_quantization(self): self.assertTrue(jnp.greater(jnp.max(inputs), jnp.max(res_einsum))) self.assertEqual(res_einsum.dtype, np.dtype(np.float32)) self.assertTrue(jnp.greater(jnp.max(inputs), jnp.max(res_dg[0][0]))) - #self.assertEqual(res_dg.dtype, np.dtype(np.float32)) + # self.assertEqual(res_dg.dtype, np.dtype(np.float32)) def test_remove_quantized_params(self): _params = { - 'decoder': { - 'decoder_norm': {'scale': 1.0}, - 'layers': { - 'mlp': {'wi_0': {'kernel': 1.0}, 'wi_1': {'kernel': 1.0}, 'wo': {'kernel': 1.0}}, - 'self_attention': {'key': 
{'kernel': 1.0},}}, - 'logits_dense': {'kernel': 1.0}}, - } - _aqt_vars = { - 'decoder': { - 'layers': { - 'mlp': { - 'wi_0': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}}, - 'wi_1': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}}, - 'wo': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}} + "decoder": { + "decoder_norm": {"scale": 1.0}, + "layers": { + "mlp": {"wi_0": {"kernel": 1.0}, "wi_1": {"kernel": 1.0}, "wo": {"kernel": 1.0}}, + "self_attention": { + "key": {"kernel": 1.0}, + }, }, - 'self_attention': {'key': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0}},}}}} - } + "logits_dense": {"kernel": 1.0}, + }, + } + _aqt_vars = { + "decoder": { + "layers": { + "mlp": { + "wi_0": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}}, + "wi_1": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}}, + "wo": {"AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}}, + }, + "self_attention": { + "key": { + "AqtDotGeneral_0": {"qrhs": {"scale": 1.0, "_value": 1.0}}, + } + }, + } + } + } _expected = { - 'decoder': { - 'decoder_norm': {'scale': 1.0}, - 'layers': { - 'mlp': {'wi_0': {'kernel': {}}, 'wi_1': {'kernel': {}}, 'wo': {'kernel': {}}}, - 'self_attention': {'key': {'kernel': {}},}}, - 'logits_dense': {'kernel': 1.0},} - } + "decoder": { + "decoder_norm": {"scale": 1.0}, + "layers": { + "mlp": {"wi_0": {"kernel": {}}, "wi_1": {"kernel": {}}, "wo": {"kernel": {}}}, + "self_attention": { + "key": {"kernel": {}}, + }, + }, + "logits_dense": {"kernel": 1.0}, + } + } result = quantizations.remove_quantized_params(_params, _aqt_vars) self.assertEqual(_expected, result) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index 8b51246b2..d9befd1e8 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the standalone_checkpointer.py """ import unittest @@ -25,35 +25,67 @@ class Standalone_DL_CKPT(unittest.TestCase): - """Tests for standalone_checkpointer.py, checkpoint and restore. 
""" + """Tests for standalone_checkpointer.py, checkpoint and restore.""" def _get_random_test_name(self, test_name): now = datetime.now() date_time = now.strftime("_%Y-%m-%d-%H-%M_") - random_string = ''.join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(6)) + random_string = "".join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(6)) random_run_name = test_name + date_time + random_string return random_run_name @pytest.mark.tpu def test_standalone_dataloader(self): random_run_name = self._get_random_test_name("standalone_dataloader") - sdl_main((None, "configs/base.yml", "run_name="+random_run_name, "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", "steps=100", "enable_checkpointing=false", - "tokenizer_path=../assets/tokenizer.llama2")) # need to pass relative path to tokenizer + sdl_main(( + None, + "configs/base.yml", + "run_name=" + random_run_name, + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "steps=100", + "enable_checkpointing=false", + "tokenizer_path=../assets/tokenizer.llama2", + )) # need to pass relative path to tokenizer @pytest.mark.tpu def test_standalone_checkpointer(self): random_run_name = self._get_random_test_name("standalone_checkpointer") # checkpoint at 50 - sckpt_main((None, "configs/base.yml", f"run_name={random_run_name}", "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset","base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", - "base_mlp_dim=128", "base_num_decoder_layers=2", "steps=60", "enable_checkpointing=True", - "checkpoint_period=50", "async_checkpointing=False")) + sckpt_main(( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=60", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + )) # restore at 50 and checkpoint at 100 - sckpt_main((None, "configs/base.yml", f"run_name={random_run_name}", "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset","base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", - "base_mlp_dim=128", "base_num_decoder_layers=2", "steps=110", "enable_checkpointing=True", - "checkpoint_period=50", "async_checkpointing=False")) + sckpt_main(( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=110", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + )) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py index 1c998fe82..098334b15 100644 --- a/MaxText/tests/tfds_data_processing_test.py +++ b/MaxText/tests/tfds_data_processing_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring, missing-function-docstring import os @@ -29,26 +29,32 @@ from input_pipeline import _tfds_data_processing from input_pipeline import input_pipeline_interface + class TfdsDataProcessingTest(unittest.TestCase): def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "gs://maxtext-dataset/", - tokenizer_path = "../assets/tokenizer", - enable_checkpointing=False) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="gs://maxtext-dataset/", + tokenizer_path="../assets/tokenizer", + enable_checkpointing=False, + ) os.environ["TFDS_DATA_DIR"] = pyconfig.config.dataset_path self.config = pyconfig.config self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) self.read_config = tfds.ReadConfig( - shuffle_seed = self.config.data_shuffle_seed, + shuffle_seed=self.config.data_shuffle_seed, ) self.read_config.add_tfds_id = True - + self.train_ds, self.eval_ds = self._get_datasets() self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() @@ -56,10 +62,11 @@ def _get_datasets(self): print("Sharding dataset in ", jax.process_count(), " shards") process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) train_ds, eval_ds = _tfds_data_processing.get_datasets( - config=self.config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - read_config = self.read_config) + config=self.config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + read_config=self.read_config, + ) return train_ds, eval_ds def _get_preprocessed_datasets(self): @@ -67,9 +74,8 @@ def _get_preprocessed_datasets(self): mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), self.config.mesh_axes) sp_tokenizer = input_pipeline_interface.get_tokenizer(self.config.tokenizer_path) train_iter, eval_iter, test_iter = _tfds_data_processing.preprocess_dataset( - 
self.config, - mesh, - self.train_ds, self.eval_ds, sp_tokenizer) + self.config, mesh, self.train_ds, self.eval_ds, sp_tokenizer + ) return train_iter, eval_iter, test_iter def test_train_ds(self): @@ -77,33 +83,39 @@ def test_train_ds(self): # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. batch = next(self.train_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) def test_eval_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] batch = next(self.eval_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "targets": expected_shape, + }, + ) def test_predict_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] batch = next(self.predict_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "targets": expected_shape, + }, + ) def test_ds_determinism(self): train_ds1 = self.train_ds.batch(64) @@ -113,20 +125,19 @@ def test_ds_determinism(self): train_ds = train_ds.batch(64) train_ds2 = next(train_ds.as_numpy_iterator()) - self.assertCountEqual(train_ds1['tfds_id'], train_ds2['tfds_id']) - + self.assertCountEqual(train_ds1["tfds_id"], train_ds2["tfds_id"]) def test_batch_determinism(self): batch1 = next(self.train_iter) self.train_ds, _ = self._get_datasets() - train_iter2, _, _= self._get_preprocessed_datasets() + train_iter2, _, _ = self._get_preprocessed_datasets() batch2 = next(train_iter2) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs'], batch2['inputs']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets'], batch2['targets']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs_segmentation'], batch2['inputs_segmentation']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets_segmentation'], batch2['targets_segmentation']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs_position'], batch2['inputs_position']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets_position'], batch2['targets_position']))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs"], batch2["inputs"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["targets"], batch2["targets"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs_segmentation"], batch2["inputs_segmentation"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["targets_segmentation"], batch2["targets_segmentation"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs_position"], batch2["inputs_position"]))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["targets_position"], batch2["targets_position"]))) def test_for_loop_repeatable(self): def get_first_batch(iterator): @@ 
-137,10 +148,9 @@ def get_first_batch(iterator): eval_batch1 = get_first_batch(self.eval_iter) eval_batch2 = get_first_batch(self.eval_iter) - self.assertTrue((eval_batch1['inputs']==eval_batch2['inputs']).all()) - self.assertTrue((eval_batch1['targets']==eval_batch2['targets']).all()) + self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all()) + self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index 6797f6ee3..c24cd2786 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -1,17 +1,17 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
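The determinism tests above hinge on seeded shuffling: rebuilding the tf.data pipeline with the same data_shuffle_seed must reproduce identical batches. A minimal sketch of that property on toy data (it assumes only that TensorFlow is importable):

import tensorflow as tf

def make_pipeline(seed):
  # Fresh pipeline each call; the same seed yields the same shuffle order.
  return tf.data.Dataset.range(100).shuffle(buffer_size=100, seed=seed).batch(10)

first = next(iter(make_pipeline(seed=0))).numpy()
second = next(iter(make_pipeline(seed=0))).numpy()
assert (first == second).all()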
""" """ Tests for tokenizer @@ -30,41 +30,43 @@ class TokenizerTest(unittest.TestCase): @classmethod def setUpClass(cls): - dataset_name = 'c4/en:3.0.1' - dataset_path = 'gs://maxtext-dataset' + dataset_name = "c4/en:3.0.1" + dataset_path = "gs://maxtext-dataset" cls.vocab_size = 32_768 cls.max_corpus_chars = 10_000_000 - assets_path = 'tests' - vocab_model_name = 'test_tokenizer' + assets_path = "tests" + vocab_model_name = "test_tokenizer" cls.tokenizer_path = os.path.join(assets_path, vocab_model_name) os.environ["TFDS_DATA_DIR"] = dataset_path read_config = tfds.ReadConfig( - shuffle_seed = 0, + shuffle_seed=0, ) train_ds_builder = tfds.builder(dataset_name) - cls.dataset = train_ds_builder.as_dataset(split='train', read_config=read_config, shuffle_files=True) - train_tokenizer.train_tokenizer(cls.dataset, - assets_path=assets_path, - vocab_path=cls.tokenizer_path, - vocab_size=cls.vocab_size, - max_corpus_chars=cls.max_corpus_chars) + cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) + train_tokenizer.train_tokenizer( + cls.dataset, + assets_path=assets_path, + vocab_path=cls.tokenizer_path, + vocab_size=cls.vocab_size, + max_corpus_chars=cls.max_corpus_chars, + ) @classmethod def tearDownClass(cls): os.remove(cls.tokenizer_path) def test_tokenize(self): - source_tokenizer = tokenizer.load_tokenizer('../assets/tokenizer') + source_tokenizer = tokenizer.load_tokenizer("../assets/tokenizer") test_tokenizer = tokenizer.load_tokenizer(self.tokenizer_path) - text = 'This is a test' + text = "This is a test" self.assertTrue((np.asarray(source_tokenizer.tokenize(text)) & np.asarray(test_tokenizer.tokenize(text))).all()) def test_detokenize(self): - source_tokenizer = tokenizer.load_tokenizer('../assets/tokenizer') + source_tokenizer = tokenizer.load_tokenizer("../assets/tokenizer") test_tokenizer = tokenizer.load_tokenizer(self.tokenizer_path) - tokens = [66,12,10,698,2] - self.assertEqual(np.asarray(source_tokenizer.detokenize(tokens)),np.asarray(test_tokenizer.detokenize(tokens))) + tokens = [66, 12, 10, 698, 2] + self.assertEqual(np.asarray(source_tokenizer.detokenize(tokens)), np.asarray(test_tokenizer.detokenize(tokens))) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 3bf46e753..cc64aea91 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the common Max Utils """ import unittest @@ -26,73 +26,158 @@ class TrainCompile(unittest.TestCase): @pytest.mark.tpu def test_save_compiled_v4(self): - compiled_trainstep_file='/tmp/test_compiled_v4.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v4-8", "compile_topology_num_slices=1", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v4.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v4-8", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_save_compiled_v5e(self): - compiled_trainstep_file='/tmp/test_compiled_v5e.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-16", "compile_topology_num_slices=1", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-16", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_minimal_offloaded_v5e(self): - compiled_trainstep_file='/tmp/test_compiled_v5e_offload.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=1", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=minimal_offloaded", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_compiled_v5e_offload.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=minimal_offloaded", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_save_compiled_v5p_two_slices(self): - compiled_trainstep_file='/tmp/test_compiled_v5p_two_slices.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-8", "compile_topology_num_slices=2", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v5p_two_slices.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=2", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_sequence_parallelism(self): - compiled_trainstep_file='/tmp/test_compiled.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "use_iota_embed=true", "compile_topology_num_slices=1", - 
"ici_sequence_parallelism=16", "global_parameter_scale=32", "per_device_batch_size=0.0625", "max_target_length=65536")) + compiled_trainstep_file = "/tmp/test_compiled.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "ici_sequence_parallelism=16", + "global_parameter_scale=32", + "per_device_batch_size=0.0625", + "max_target_length=65536", + )) @pytest.mark.tpu def test_remat_save_dot_except_mlpwi(self): - compiled_trainstep_file='/tmp/test_remat_save_dot_except_mlpwi.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.125", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_dot_except_mlpwi", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlpwi.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.125", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlpwi", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_save_dot_except_mlp(self): - compiled_trainstep_file='/tmp/test_remat_save_dot_except_mlp.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.25", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_dot_except_mlp", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlp.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.25", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlp", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_save_qkv_proj(self): - compiled_trainstep_file='/tmp/test_remat_save_qkv_proj.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.375", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_qkv_proj", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_qkv_proj.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.375", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + 
"remat_policy=save_qkv_proj", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_full(self): - compiled_trainstep_file='/tmp/test_remat_full.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=1", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=full", - "use_iota_embed=true", "global_parameter_scale=128")) \ No newline at end of file + compiled_trainstep_file = "/tmp/test_remat_full.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=full", + "use_iota_embed=true", + "global_parameter_scale=128", + )) diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 5efc5ebeb..a05e0fb6e 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Smoke test for int8""" import os @@ -20,18 +20,32 @@ from train import main as train_main from absl.testing import absltest + class Train(unittest.TestCase): """Smoke test for int8 G3 only""" def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([None, "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", "dataset_type=synthetic", "steps=10", - "enable_checkpointing=False", "quantization=int8"]) - -if __name__ == '__main__': + train_main([ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + "quantization=int8", + ]) + + +if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index 8cd41fb33..b3046fd35 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Smoke test """ import os @@ -20,18 +20,31 @@ from train import main as train_main from absl.testing import absltest + class Train(unittest.TestCase): """Smoke test G3 only""" def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([None, "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", "dataset_type=synthetic", "steps=10", - "enable_checkpointing=False"]) - -if __name__ == '__main__': + train_main([ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + ]) + + +if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/weight_dtypes_test.py b/MaxText/tests/weight_dtypes_test.py index 49829d43c..579e23310 100644 --- a/MaxText/tests/weight_dtypes_test.py +++ b/MaxText/tests/weight_dtypes_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Test that all weights are expected dtype (default float32) """ import unittest @@ -30,37 +30,36 @@ Transformer = models.Transformer -class WeightDtypes(unittest.TestCase): - """Test that all weights are expected dtype (default float32) """ - - def get_weights(self, argv): - """ Gets model weights """ - - # Setup necessary inputs to build a model state - pyconfig.initialize(argv) - config = pyconfig.config - quant = quantizations.configure_quantization(config) - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - model = Transformer(config, mesh, quant=quant) - learning_rate_schedule = max_utils.create_learning_rate_schedule(config) - tx = optimizers.get_optimizer(config, learning_rate_schedule) - _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - - abstract_state, _ , _ = max_utils.get_abstract_state(model, tx, config, example_rng, mesh) - return abstract_state.params - - def assert_weights_are_dtype(self, weights, expected_dtype): - jax.tree_util.tree_map_with_path(lambda x,y: self.assertEqual(y.dtype, expected_dtype), weights) - - def test_default_float32(self): - argv = [None, "configs/base.yml", "enable_checkpointing=False"] - weights = self.get_weights(argv) - self.assert_weights_are_dtype(weights, jnp.float32) - - def test_set_bf16(self): - argv = [None, "configs/base.yml", "enable_checkpointing=False", "weight_dtype=bfloat16"] - weights = self.get_weights(argv) - self.assert_weights_are_dtype(weights, jnp.bfloat16) - +class WeightDtypes(unittest.TestCase): + """Test that all weights are expected dtype (default float32)""" + + def get_weights(self, argv): + """Gets model weights""" + + # Setup necessary inputs to build a model state + pyconfig.initialize(argv) + config = pyconfig.config + quant = quantizations.configure_quantization(config) + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + model = Transformer(config, mesh, quant=quant) + learning_rate_schedule = max_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) + _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) + + abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, example_rng, mesh) + return abstract_state.params + + def assert_weights_are_dtype(self, weights, expected_dtype): + jax.tree_util.tree_map_with_path(lambda x, y: self.assertEqual(y.dtype, expected_dtype), weights) + + def test_default_float32(self): + argv = [None, "configs/base.yml", "enable_checkpointing=False"] + weights = self.get_weights(argv) + self.assert_weights_are_dtype(weights, jnp.float32) + + def test_set_bf16(self): + argv = [None, "configs/base.yml", "enable_checkpointing=False", "weight_dtype=bfloat16"] + weights = self.get_weights(argv) + self.assert_weights_are_dtype(weights, jnp.bfloat16) diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py index 9e29b1e68..0e7c374ad 100644 --- a/MaxText/tokenizer.py +++ b/MaxText/tokenizer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-  """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 
 """Provides op for tokenizing a dataset."""
 
@@ -29,34 +29,29 @@
 
 Features = Dict[str, tf.Tensor]
 
-
-def _load_sentencepiece_tokenizer(tokenizer_path: str,
-                                  add_bos: bool = False,
-                                  add_eos: bool = True,
-                                  reverse: bool = False):
+def _load_sentencepiece_tokenizer(tokenizer_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False):
   """Load a tf-text SentencePiece tokenizer from given model filepath."""
   max_logging.log(f"Tokenizer path: {tokenizer_path}")
-  with tf.io.gfile.GFile(tokenizer_path, 'rb') as model_fp:
+  with tf.io.gfile.GFile(tokenizer_path, "rb") as model_fp:
     sp_model = model_fp.read()
-  sp_tokenizer = tftxt.SentencepieceTokenizer(
-      model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
+  sp_tokenizer = tftxt.SentencepieceTokenizer(model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
   return sp_tokenizer
 
+
 def load_tokenizer(tokenizer_path: str, add_bos=False, add_eos=True):
   """Loads the tokenizer at `tokenizer_path` or trains one from `dataset`."""
   try:
     sp_tokenizer = _load_sentencepiece_tokenizer(tokenizer_path, add_bos, add_eos)
     return sp_tokenizer
   except (tf.errors.NotFoundError, tf.errors.InvalidArgumentError):
-    logging.info('SentencePiece vocab not found, Run train_tokenizer.py')
+    logging.info("SentencePiece vocab not found; run train_tokenizer.py")
     return None
 
 
 @dataclasses.dataclass
 class TokenizeOp:
-
   sp_tokenizer: Any
-  data_keys: Iterable[str] = ('inputs', 'targets')
+  data_keys: Iterable[str] = ("inputs", "targets")
 
   def __call__(self, features: Features) -> Features:
     for k in self.data_keys:
diff --git a/MaxText/train.py b/MaxText/train.py
index 4a1966945..be67419cc 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -1,18 +1,18 @@
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 
 # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
 """Training loop and Decoding of the model."""
@@ -64,56 +64,52 @@
 Transformer = models.Transformer
 EPS = 1e-8
 
+
 def validate_train_config(config):
-  """ Validates the configuration is set correctly for train.py"""
+  """Validates the configuration is set correctly for train.py"""
   assert config.run_name, "Erroring out, need a real run_name"
-  if not config.dataset_path.startswith('gs://'):
+  if not config.dataset_path.startswith("gs://"):
     max_logging.log("WARNING: 'dataset_path' might be pointing to your local file system")
-  if not config.base_output_directory.startswith('gs://'):
+  if not config.base_output_directory.startswith("gs://"):
     max_logging.log("WARNING: 'base_output_directory' might be pointing to your local file system")
   assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
 
-
 def get_first_step(state):
-  with jax.spmd_mode('allow_all'):
+  with jax.spmd_mode("allow_all"):
     return int(state.step)
 
 
 def load_next_batch(train_iter, example_batch, config):
-  """Loads the next batch. Can keep reusing the same batch for performance reasons """
+  """Loads the next batch. Can keep reusing the same batch for performance reasons"""
   if config.reuse_example_batch and example_batch is not None:
     return example_batch
   else:
     return next(train_iter)
 
+
 def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
   """Records scalar metrics to be written to tensorboard"""
-  metrics['scalar'].update({
-      'perf/step_time_seconds': step_time_delta.total_seconds()
-  })
-  metrics['scalar'].update({
-      'perf/per_device_tflops' : per_device_tflops
-  })
-  metrics['scalar'].update({
-      'perf/per_device_tflops_per_sec':
-          per_device_tflops /
-          step_time_delta.total_seconds()
-  })
-  metrics['scalar'].update({'learning/current_learning_rate': lr })
+  metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()})
+  metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops})
+  metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
+  metrics["scalar"].update({"learning/current_learning_rate": lr})
+
 
 _buffered_step = None
 _buffered_metrics = None
+
+
 def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
   """Entry point for all metrics writing in Train's Main.
-     TODO: would be better as a Class in the future (that initialized all state!)
+  TODO: would be better as a Class in the future (that initialized all state!)
 
-     To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
-     onto the last metrics and step and only publish when we receive a new metrics and step.
-     The logic is that this ensures that Jax is able to queues train_steps and we
-     don't block when turning "lazy" Jax arrays into real Python numbers.
+  To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
+  onto the last metrics and step and only publish when we receive a new metrics and step.
+  The logic is that this ensures that Jax is able to queue train_steps and we
+  don't block when turning "lazy" Jax arrays into real Python numbers.
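  Editor's note: the double-buffering described above, in miniature. A sketch
  with illustrative names only, not the exact MaxText internals:

    _buffered = None
    def publish(new_item):
      global _buffered
      if _buffered is not None:
        really_write(_buffered)  # safe now: new_item arriving means the prior step finished
      _buffered = new_item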
""" global _buffered_step, _buffered_metrics @@ -131,68 +127,72 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step _buffered_step = step _buffered_metrics = metrics + def write_metrics_to_tensorboard(writer, metrics, step, config): - """ Writes metrics to tensorboard""" - with jax.spmd_mode('allow_all'): + """Writes metrics to tensorboard""" + with jax.spmd_mode("allow_all"): if jax.process_index() == 0: - for metric_name in metrics.get("scalar",[]): + for metric_name in metrics.get("scalar", []): writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) - for metric_name in metrics.get("scalars",[]): + for metric_name in metrics.get("scalars", []): writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) full_log = step % config.log_period == 0 - max_logging.log(f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " - f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " - f"loss: {metrics['scalar']['learning/loss']:.3f}") + max_logging.log( + f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " + f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " + f"loss: {metrics['scalar']['learning/loss']:.3f}" + ) if full_log and jax.process_index() == 0: - max_logging.log( - f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'" - ) + max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") writer.flush() -def save_checkpoint(checkpoint_manager, step, state, dataset_type='c4', data_iterator=None): + +def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None): """Wrapper for saving checkpoint""" - if dataset_type == 'c4-array_record': + if dataset_type == "c4-array_record": return checkpoint_manager.save( - step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeSave(item=state), - iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator) - ) - ) + step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeSave(item=state), + iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator), + ), + ) else: return checkpoint_manager.save( - step, - args=orbax.checkpoint.args.Composite( - items=orbax.checkpoint.args.PyTreeSave(item=state) - )) + step, args=orbax.checkpoint.args.Composite(items=orbax.checkpoint.args.PyTreeSave(item=state)) + ) + # ----------------------------------------------------------------------------- # Top-level Functions # ----------------------------------------------------------------------------- + def record_activation_metrics(output_metrics, intermediate_outputs, config): - """ Adds the activation metrics to the metrics dict""" + """Adds the activation metrics to the metrics dict""" if config.scan_layers: - metrics_dict = intermediate_outputs['intermediates']['decoder']['decoder'] + metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"] for layer_num in range(config.num_decoder_layers): - output_metrics['scalar'][f'activ_fraction_zero/layer_{layer_num:03d}'] = \ - metrics_dict["activation_fraction_zero"][0][layer_num] - output_metrics['scalar'][f'activ_mean/layer_{layer_num:03d}'] = metrics_dict["activation_mean"][0][layer_num] - output_metrics['scalar'][f'activ_stdev/layer_{layer_num:03d}'] = metrics_dict["activation_stdev"][0][layer_num] + output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = 
metrics_dict["activation_fraction_zero"][0][ + layer_num + ] + output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num] + output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num] else: for layer_num in range(config.num_decoder_layers): - layer = intermediate_outputs['intermediates']['decoder'][f'layers_{layer_num}'] - output_metrics['scalar'][f'activ_fraction_zero/layer_{layer_num:03d}'] = layer["activation_fraction_zero"][0] - output_metrics['scalar'][f'activ_mean/layer_{layer_num:03d}'] = layer["activation_mean"][0] - output_metrics['scalar'][f'activ_stdev/layer_{layer_num:03d}'] = layer["activation_stdev"][0] + layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"] + output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0] + output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0] + output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0] + def loss_fn(model, config, data, dropout_rng, params, is_train=True): - '''loss_fn for both train and eval. + """loss_fn for both train and eval. Args: model: A nn.Module @@ -205,33 +205,36 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): Returns: loss: average loss aux: a dictionary including intermediate_outputs, total_loss, and total_weights - ''' + """ # inputs, targets, segments, positions = apply_args rng1, aqt_rng = jax.random.split(dropout_rng) # decimate proportion of data when per_device_batch_size<1 if is_train: for k, v in data.items(): - data[k] = v[:config.global_batch_size_to_train_on,:] - - logits, intermediate_outputs = model.apply(params, - data['inputs'], - data['inputs_position'], - decoder_segment_ids=data['inputs_segmentation'], - enable_dropout=config.enable_dropout if is_train else False, - rngs={'dropout': rng1, 'params': aqt_rng}, mutable='intermediates') - one_hot_targets = jax.nn.one_hot(data['targets'], config.vocab_size) + data[k] = v[: config.global_batch_size_to_train_on, :] + + logits, intermediate_outputs = model.apply( + params, + data["inputs"], + data["inputs_position"], + decoder_segment_ids=data["inputs_segmentation"], + enable_dropout=config.enable_dropout if is_train else False, + rngs={"dropout": rng1, "params": aqt_rng}, + mutable="intermediates", + ) + one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - xent = nn.with_logical_constraint(xent, ('activation_batch', 'activation_length')) + xent = nn.with_logical_constraint(xent, ("activation_batch", "activation_length")) # Mask out paddings at the end of each example. 
- xent = xent * (data['targets_segmentation'] != 0) + xent = xent * (data["targets_segmentation"] != 0) total_loss = jnp.sum(xent) - total_weights = jnp.sum(data['targets_segmentation'] != 0) + total_weights = jnp.sum(data["targets_segmentation"] != 0) loss = total_loss / (total_weights + EPS) aux = { - 'intermediate_outputs': intermediate_outputs, - 'total_loss': total_loss, - 'total_weights': total_weights, + "intermediate_outputs": intermediate_outputs, + "total_loss": total_loss, + "total_weights": total_weights, } return loss, aux @@ -254,42 +257,50 @@ def train_step(model, config, state, data, dropout_rng): train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True) grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True) (loss, aux), raw_grads = grad_fn(state.params) - intermediate_outputs = aux['intermediate_outputs'] + intermediate_outputs = aux["intermediate_outputs"] if config.gradient_clipping_threshold > 0: grads, _ = optax.clip_by_global_norm(config.gradient_clipping_threshold).update(raw_grads, state, None) else: grads = raw_grads new_state = state.apply_gradients(grads=grads) - metrics = {'scalar': {'learning/loss': loss, 'learning/grad_norm': max_utils.l2norm_pytree(grads), - 'learning/raw_grad_norm': max_utils.l2norm_pytree(raw_grads), - 'learning/param_norm': max_utils.l2norm_pytree(new_state.params)}, 'scalars': {}} + metrics = { + "scalar": { + "learning/loss": loss, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + "learning/param_norm": max_utils.l2norm_pytree(new_state.params), + }, + "scalars": {}, + } if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) return new_state, metrics + def eval_step(model, config, state, data, dropout_rng): """eval_step no backprop and new state compared with train_step.""" eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False) loss, aux = eval_loss_fn(state.params) - total_loss = aux['total_loss'] - total_weights = aux['total_weights'] - metrics = {'scalar': - {'evaluation/loss': loss, - 'evaluation/total_loss': total_loss, - 'evaluation/total_weights': total_weights}} + total_loss = aux["total_loss"] + total_weights = aux["total_weights"] + metrics = { + "scalar": {"evaluation/loss": loss, "evaluation/total_loss": total_loss, "evaluation/total_weights": total_weights} + } return metrics + def create_goodput_recorder(config): if config.enable_goodput_recording: - logger_name = f'goodput_{config.run_name}' + logger_name = f"goodput_{config.run_name}" recorder = goodput.GoodputRecorder(config.run_name, logger_name, jax.process_index() == 0) return recorder return None + def record_goodput(recorder, config, step=None, job_start=False, job_end=False): if recorder and config.enable_goodput_recording: if job_start and step is None: @@ -299,8 +310,9 @@ def record_goodput(recorder, config, step=None, job_start=False, job_end=False): if step is not None: recorder.record_step_start_time(step) + def setup_mesh_and_model(config): - """ Set up the mesh and the model for training + """Set up the mesh and the model for training Args: config @@ -336,8 +348,9 @@ def setup_mesh_and_model(config): tx = optimizers.get_optimizer(config, learning_rate_schedule) return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx + def setup_train_loop(config): - """ Set up prerequisites for the training loop - + """Set up prerequisites for the training 
loop - checkpoint_manager, PRNG keys, Mesh, Model and optimizer. Set up data iterator and tokenizer, initialize the model. @@ -358,13 +371,24 @@ def setup_train_loop(config): init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config) data_iterator, eval_data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh) - state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator, - tx, config, init_rng, mesh, checkpoint_manager) + state, state_mesh_annotations, data_iterator = max_utils.setup_training_state( + model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + ) maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh) - return ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model, - mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state) + return ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) def train_loop(config, state=None): @@ -379,27 +403,36 @@ def train_loop(config, state=None): recorder = create_goodput_recorder(config) record_goodput(recorder, config, job_start=True) - ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model, - mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state) = setup_train_loop(config) + ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) = setup_train_loop(config) # pylint: disable=line-too-long - functional_train, in_shard_train, out_shard_train, static_argnums_train, donate_argnums_train = maxtext_utils.get_functional_train_with_signature( - train_step, - mesh, - state_mesh_annotations, - model, - config - ) + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config) if eval_data_iterator: # pylint: disable=line-too-long - functional_eval, in_shard_eval, out_shard_eval, static_argnums_eval, donate_argnums_eval = maxtext_utils.get_functional_eval_with_signature( - eval_step, - mesh, - state_mesh_annotations, - model, - config - ) - + ( + functional_eval, + in_shard_eval, + out_shard_eval, + static_argnums_eval, + donate_argnums_eval, + ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config) num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") @@ -411,31 +444,33 @@ def train_loop(config, state=None): max_utils.add_config_to_summary_writer(config, writer) # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit - if config.compiled_trainstep_file != '': + if config.compiled_trainstep_file != "": print("Loading the compiled function...", flush=True) # Need to pass train signature and state to determine i/o shapes of train_state for now. 
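# Editor's aside — how such a pre-compiled train step is typically produced and
# restored with JAX's ahead-of-time (AOT) APIs. A hedged sketch: the exact
# serialize/deserialize signatures are assumptions beyond what this diff shows
# (the serialize side appears in train_compile.py later in this patch, and the
# project's own loader is maxtext_utils.load_compiled):
#
#   from jax.experimental.serialize_executable import serialize, deserialize_and_load
#   compiled = jax.jit(step_fn).lower(example_args).compile()    # shapes fixed at lowering
#   payload, in_tree, out_tree = serialize(compiled)             # picklable triple
#   ...
#   restored = deserialize_and_load(payload, in_tree, out_tree)  # needs matching topology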
p_train_step = maxtext_utils.load_compiled(config, functional_train, state) print("Loaded compiled function!", flush=True) else: p_train_step = jax.jit( - functional_train, - in_shardings=in_shard_train, - out_shardings=out_shard_train, - static_argnums=static_argnums_train, - donate_argnums=donate_argnums_train) + functional_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + static_argnums=static_argnums_train, + donate_argnums=donate_argnums_train, + ) if eval_data_iterator: p_eval_step = jax.jit( - functional_eval, - in_shardings=in_shard_eval, - out_shardings=out_shard_eval, - static_argnums=static_argnums_eval, - donate_argnums=donate_argnums_eval) + functional_eval, + in_shardings=in_shard_eval, + out_shardings=out_shard_eval, + static_argnums=static_argnums_eval, + donate_argnums=donate_argnums_eval, + ) - local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None + local_metrics_file = open(config.metrics_file, "a", encoding="utf8") if config.metrics_file else None running_gcs_metrics = [] if config.gcs_metrics else None - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training first_profiling_step = start_step + config.skip_first_n_steps_for_profiler if config.enable_profiler and first_profiling_step >= config.steps: raise ValueError("Profiling requested but initial profiling step set past training final step") @@ -453,12 +488,10 @@ def train_loop(config, state=None): nextrng = jax.jit(jax.random.fold_in)(init_rng, step) record_goodput(recorder, config, step=step) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - state, metrics = p_train_step( - state, example_batch, nextrng - ) + state, metrics = p_train_step(state, example_batch, nextrng) new_time = datetime.datetime.now() - record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step)) + record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step)) last_step_completion = new_time if checkpoint_manager is not None: @@ -474,15 +507,13 @@ def train_loop(config, state=None): if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0: assert eval_data_iterator - cumulative_eval_metrics = {"total_loss": 0., "total_weights": 0.} + cumulative_eval_metrics = {"total_loss": 0.0, "total_weights": 0.0} for eval_batch in eval_data_iterator: with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step( - state, eval_batch, nextrng - ) - cumulative_eval_metrics['total_loss'] += float(eval_metrics['scalar']['evaluation/total_loss']) - cumulative_eval_metrics['total_weights'] += float(eval_metrics['scalar']['evaluation/total_weights']) - eval_loss = cumulative_eval_metrics['total_loss'] / (cumulative_eval_metrics['total_weights'] + EPS) + eval_metrics = p_eval_step(state, eval_batch, nextrng) + cumulative_eval_metrics["total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"]) + cumulative_eval_metrics["total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"]) + eval_loss = cumulative_eval_metrics["total_loss"] / (cumulative_eval_metrics["total_weights"] + EPS) max_logging.log(f"average loss after {step=}: {eval_loss=}, total_weights={cumulative_eval_metrics['total_weights']}") if eval_loss <= config.target_eval_loss: max_logging.log(f"Early stop and exit loop after reaching 
{config.target_eval_loss=}") @@ -494,15 +525,16 @@ def train_loop(config, state=None): if checkpoint_manager is not None: checkpoint_manager.wait_until_finished() - write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics + write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics max_utils.close_summary_writer(writer) record_goodput(recorder, config, job_end=True) return state + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" pyconfig.initialize(argv) config = pyconfig.config validate_train_config(config) @@ -512,13 +544,16 @@ def main(argv: Sequence[str]) -> None: vertex_tensorboard_manager.configure_vertex_tensorboard(config) debug_config = debug_configuration.DebugConfig( - stack_trace_config = stack_trace_configuration.StackTraceConfig( - collect_stack_trace = config.collect_stack_trace, - stack_trace_to_cloud = config.stack_trace_to_cloud, - stack_trace_interval_seconds = config.stack_trace_interval_seconds)) + stack_trace_config=stack_trace_configuration.StackTraceConfig( + collect_stack_trace=config.collect_stack_trace, + stack_trace_to_cloud=config.stack_trace_to_cloud, + stack_trace_interval_seconds=config.stack_trace_interval_seconds, + ) + ) diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) with diagnostic.diagnose(diagnostic_config): train_loop(config) + if __name__ == "__main__": app.run(main) diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py index 55384c13e..43789ea3f 100644 --- a/MaxText/train_compile.py +++ b/MaxText/train_compile.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step @@ -45,14 +45,15 @@ def validate_config(config): - """ Validates the config is is setup correctly to compile, returning a useful error message if not. """ - assert config.compile_topology != '',\ - "You must pass your desired target hardware in compile_topology, e.g. 
compile_topology=v5e-256" - assert config.compile_topology_num_slices > 0,\ - "You must set compile_topology_num_slices to a positive integer" + """Validates the config is is setup correctly to compile, returning a useful error message if not.""" + assert ( + config.compile_topology != "" + ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256" + assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" + def get_topology_mesh(config): - """ Get the target hardware devices, and create configured mesh with them """ + """Get the target hardware devices, and create configured mesh with them""" target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) topology_devices = get_topology_desc( platform=target_hardware.platform, @@ -65,8 +66,9 @@ def get_topology_mesh(config): topology_mesh = Mesh(topology_device_mesh, config.mesh_axes) return topology_mesh + def get_shaped_inputs(topology_mesh, config): - """ Get shaped abstractions of inputs to train_step: state, batch and rng """ + """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizier to get shaped versions of the state quant = quantizations.configure_quantization(config) model = Transformer(config, topology_mesh, quant=quant) @@ -79,7 +81,7 @@ def get_shaped_inputs(topology_mesh, config): shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) # Shaped state - abstract_state, state_mesh_annotations, _ = max_utils.get_abstract_state(model, tx, config, example_rng, topology_mesh) + abstract_state, state_mesh_annotations, _ = max_utils.get_abstract_state(model, tx, config, example_rng, topology_mesh) # Shaped batch shaped_batch = input_pipeline_interface.get_shaped_batch(config) @@ -89,29 +91,40 @@ def get_shaped_inputs(topology_mesh, config): return shaped_train_args, shaped_train_kwargs, state_mesh_annotations, model -def jit_and_compile(func, func_input_args, func_input_kwargs, mesh, in_shardings, - out_shardings, static_argnums, donate_argnums, logical_axis_rules): - """ Jit, lower, and compile func.""" +def jit_and_compile( + func, + func_input_args, + func_input_kwargs, + mesh, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + logical_axis_rules, +): + """Jit, lower, and compile func.""" with mesh, logical_axis_rules: jitted = jax.jit( - func, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - donate_argnums=donate_argnums + func, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + donate_argnums=donate_argnums, ) lowered = jitted.lower(*func_input_args, **func_input_kwargs) compiled = lowered.compile() return compiled + def save_compiled(compiled, save_name): - """ Serialize and save the compiled function. 
""" + """Serialize and save the compiled function.""" serialized, _, _ = serialize(compiled) with open(save_name, "wb") as f: pickle.dump(serialized, f) + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") print("Starting train_compile.py...", flush=True) # Parse and validate configuration @@ -127,30 +140,26 @@ def main(argv: Sequence[str]) -> None: # Get function to compile and shardings func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = maxtext_utils.get_functional_train_with_signature( - train.train_step, - topology_mesh, - state_mesh_annotations, - model, - config + train.train_step, topology_mesh, state_mesh_annotations, model, config ) # Compile print("Jitting and compiling train step...", flush=True) compiled = jit_and_compile( - func_to_compile, - shaped_train_args, - shaped_train_kwargs, - topology_mesh, - in_shard, - out_shard, - static_argnums, - donate_argnums, - nn_partitioning.axis_rules(config.logical_axis_rules) + func_to_compile, + shaped_train_args, + shaped_train_kwargs, + topology_mesh, + in_shard, + out_shard, + static_argnums, + donate_argnums, + nn_partitioning.axis_rules(config.logical_axis_rules), ) print("Jitting and compilation complete!", flush=True) # Serialize and save the compiled object - if config.compiled_trainstep_file != '': + if config.compiled_trainstep_file != "": print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") diff --git a/MaxText/train_tokenizer.py b/MaxText/train_tokenizer.py index dab4d99be..03c0d5269 100644 --- a/MaxText/train_tokenizer.py +++ b/MaxText/train_tokenizer.py @@ -1,14 +1,14 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" """ Train tokenizer @@ -29,28 +29,15 @@ from sentencepiece import SentencePieceTrainer -_DATASET_PATH = flags.DEFINE_string( - 'dataset_path', None, 'Path to the dataset', required=True -) -_DATASET_NAME = flags.DEFINE_string( - 'dataset_name', None, 'Name to the dataset', required=True -) -_VOCAB_SIZE = flags.DEFINE_integer('vocab_size', 32_768, 'Vocab size') -_MAX_CORPUS_CHARS = flags.DEFINE_integer( - 'max_corpus_chars', 10_000_000, 'Max corpus chars' -) -_ASSETS_PATH = flags.DEFINE_string( - 'assets_path', 'assets', 'Name to the dataset' -) -_VOCAB_MODEL_NAME = flags.DEFINE_string( - 'vocab_model_name', 'tokenizer', 'Name to the dataset' -) - -def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('text',) -) -> Tuple[str, int]: +_DATASET_PATH = flags.DEFINE_string("dataset_path", None, "Path to the dataset", required=True) +_DATASET_NAME = flags.DEFINE_string("dataset_name", None, "Name to the dataset", required=True) +_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size") +_MAX_CORPUS_CHARS = flags.DEFINE_integer("max_corpus_chars", 10_000_000, "Max corpus chars") +_ASSETS_PATH = flags.DEFINE_string("assets_path", "assets", "Name to the dataset") +_VOCAB_MODEL_NAME = flags.DEFINE_string("vocab_model_name", "tokenizer", "Name to the dataset") + + +def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. Args: dataset: tf.dataset containing string-data. @@ -61,25 +48,27 @@ def _dump_chars_to_textfile( """ char_count = 0 ds_iter = dataset.as_numpy_iterator() - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + with tempfile.NamedTemporaryFile(delete=False, prefix="/tmp/ds_chars") as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: - line = example[k] + b'\n' + line = example[k] + b"\n" char_count += len(line) outfp.write(line) return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - assets_path: str, - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('text',)): + +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + assets_path: str, + model_path: str, + model_type: str = "unigram", + character_coverage: float = 1.0, + data_keys=("text",), +): """Train SentencePiece tokenizer from subset of tf dataset. Args: dataset: tf.dataset @@ -94,65 +83,69 @@ def _train_sentencepiece(dataset: tf.data.Dataset, Returns: path to the trained sentencepiece vocabulary model. 
""" - if model_path.startswith('gs://'): + if model_path.startswith("gs://"): abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) abs_assets_path = os.path.abspath(os.path.expanduser(assets_path)) - fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) - with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + fname, _ = _dump_chars_to_textfile(dataset, maxchars=maxchars, data_keys=data_keys) + with tempfile.NamedTemporaryFile(delete=False, prefix="/tmp/sp_tmp") as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ - f'--input={fname}', f'--vocab_size={vocab_size}', - f'--character_coverage={character_coverage}', - f'--model_prefix={model_fp.name}', f'--model_type={model_type}' + argstr = " ".join([ + f"--input={fname}", + f"--vocab_size={vocab_size}", + f"--character_coverage={character_coverage}", + f"--model_prefix={model_fp.name}", + f"--model_type={model_type}", ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address # create and fill delays. - copy_rename_path = abs_model_path + '.rntmp' - if not model_path.startswith('gs://'): + copy_rename_path = abs_model_path + ".rntmp" + if not model_path.startswith("gs://"): tf.io.gfile.makedirs(abs_assets_path) - tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) + tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) - logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) + logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path) else: while not tf.io.gfile.exists(abs_model_path): time.sleep(1) time.sleep(1) return abs_model_path -def train_tokenizer(dataset: tf.data.Dataset, - *, - assets_path: str, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str] = ('text',)): + +def train_tokenizer( + dataset: tf.data.Dataset, + *, + assets_path: str, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str] = ("text",), +): """tokenizer training function""" - logging.info('SentencePiece vocab not found, building one from data.') + logging.info("SentencePiece vocab not found, building one from data.") vocab_path = _train_sentencepiece( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, assets_path=assets_path, model_path=vocab_path, - data_keys=data_keys) - logging.info('Model saved at %s', vocab_path) + data_keys=data_keys, + ) + logging.info("Model saved at %s", vocab_path) def main(argv): del argv - os.environ['TFDS_DATA_DIR'] = _DATASET_PATH.value + os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value read_config = tfds.ReadConfig( - shuffle_seed = 0, + shuffle_seed=0, ) train_ds_builder = tfds.builder(_DATASET_NAME.value) - train_ds = train_ds_builder.as_dataset(split='train', read_config=read_config, shuffle_files=True) + train_ds = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) train_tokenizer( train_ds, assets_path=_ASSETS_PATH.value, @@ -162,5 +155,5 @@ def main(argv): ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/MaxText/vertex_tensorboard.py b/MaxText/vertex_tensorboard.py index 1cb438db5..9a106c32b 100644 --- a/MaxText/vertex_tensorboard.py +++ b/MaxText/vertex_tensorboard.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC 
+Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Utilities for Tensorboard in Vertex AI.""" @@ -40,7 +40,7 @@ def __del__(self): def setup(self): """Creates Tensorboard instance and Experiment in Vertex AI. - + Returns: URL to view Vertex Tensorboard created in Google Cloud Project. """ @@ -54,19 +54,21 @@ def setup(self): # Create Vertex Tensorboard instance vertex_tensorboard_name = os.environ.get("TENSORBOARD_NAME") - instance_id = tensorboard.create_instance(project=vertex_tensorboard_project, - location=vertex_tensorboard_region, - tensorboard_name=vertex_tensorboard_name) + instance_id = tensorboard.create_instance( + project=vertex_tensorboard_project, location=vertex_tensorboard_region, tensorboard_name=vertex_tensorboard_name + ) # Failed to create Vertex Tensorboard instance if instance_id is None: return None # Create Vertex Experiment vertex_experiment_name = os.environ.get("EXPERIMENT_NAME") - _, tensorboard_url = tensorboard.create_experiment(project=vertex_tensorboard_project, - location=vertex_tensorboard_region, - experiment_name=vertex_experiment_name, - tensorboard_name=vertex_tensorboard_name) + _, tensorboard_url = tensorboard.create_experiment( + project=vertex_tensorboard_project, + location=vertex_tensorboard_region, + experiment_name=vertex_experiment_name, + tensorboard_name=vertex_tensorboard_name, + ) return tensorboard_url def upload_data(self, tensorboard_dir): @@ -84,18 +86,22 @@ def upload_data(self, tensorboard_dir): max_logging.log("Vertex Tensorboard configurations are not set. Data will not be uploaded to Vertex AI.") self.uploader_flag = False - max_logging.log(f"Data will be uploaded to Vertex Tensorboard instance: {tensorboard_name} " - f"and Experiment: {experiment_name} in {tensorboard_region}.") - uploader.start_upload_to_tensorboard(project=tensorboard_project, - location=tensorboard_region, - experiment_name=experiment_name, - tensorboard_name=tensorboard_name, - logdir=tensorboard_dir) + max_logging.log( + f"Data will be uploaded to Vertex Tensorboard instance: {tensorboard_name} " + f"and Experiment: {experiment_name} in {tensorboard_region}." 
+        )
+        uploader.start_upload_to_tensorboard(
+            project=tensorboard_project,
+            location=tensorboard_region,
+            experiment_name=experiment_name,
+            tensorboard_name=tensorboard_name,
+            logdir=tensorboard_dir,
+        )
         self.uploader_flag = True
 
   def configure_vertex_tensorboard(self, config):
     """Creates Vertex Tensorboard and start thread to upload data to Vertex Tensorboard."""
-    if jax.process_index()==0:
+    if jax.process_index() == 0:
       if not os.environ.get("TENSORBOARD_PROJECT"):
         if not config.vertex_tensorboard_project:
           os.environ["TENSORBOARD_PROJECT"] = max_utils.get_project()
@@ -112,11 +118,11 @@ def configure_vertex_tensorboard(self, config):
       if not os.environ.get("EXPERIMENT_NAME"):
         os.environ["EXPERIMENT_NAME"] = config.run_name
 
-    if config.use_vertex_tensorboard: # running MaxText on GCE
+    if config.use_vertex_tensorboard:  # running MaxText on GCE
       tensorboard_url = self.setup()
       if tensorboard_url is None:
         raise ValueError("Unable to create Tensorboard and Experiment in Vertex AI.")
       max_logging.log(f"View your Vertex AI Tensorboard at: {tensorboard_url}")
       self.upload_data(config.tensorboard_dir)
-    elif os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): # running MaxText via XPK
+    elif os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):  # running MaxText via XPK
       self.upload_data(config.tensorboard_dir)
diff --git a/code_style.sh b/code_style.sh
new file mode 100644
index 000000000..588ef70fa
--- /dev/null
+++ b/code_style.sh
@@ -0,0 +1,33 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Clean up Python code using Pylint & Pyink
+# Googlers: please run `sudo apt install pipx; pipx install pylint --force; pipx install pyink==23.10.0` in advance
+
+set -e
+
+FOLDERS_TO_FORMAT=("MaxText" "pedagogical_examples")
+LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2)
+
+for folder in "${FOLDERS_TO_FORMAT[@]}"
+do
+  pyink "$folder" --pyink-indentation=2 --line-length=${LINE_LENGTH}
+done
+
+for folder in "${FOLDERS_TO_FORMAT[@]}"
+do
+  pylint "./$folder"
+done
+
+echo "Successfully cleaned up all code."
diff --git a/pedagogical_examples/non_spmd.py b/pedagogical_examples/non_spmd.py
index 743bd4138..9918cbec3 100644
--- a/pedagogical_examples/non_spmd.py
+++ b/pedagogical_examples/non_spmd.py
@@ -1,22 +1,22 @@
 #!/usr/bin/python3
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-  """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 
-'''
+"""
 This program demonstrates embarrassingly parallelizable non-SPMD computations in Jax,
 in this case by having each process_index run its own computation.
 The same approach can be extended for non-embarrassingly parallelizable computations.
@@ -24,7 +24,7 @@
 then using a `host_local_array_to_global_array` to reshard into a new global array.
 An important limitation of this approach is that we cannot overlap communication and computation
 between the different kernel calls.
-'''
+"""
 
 import jax
@@ -34,23 +34,21 @@
 
 import numpy as np
 
-
-
 # Notice this is jax.local_devices(), not jax.devices(). Hence each process (on TPUVMs, each VM) will run separate programs
 # on its mesh.
 mesh = Mesh(np.array(jax.local_devices()), ["data"])
 sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
 
 idx = jax.process_index()
 
+
 # Example step depends on idx which is different on each program
 def example_step():
-  return idx * jax.numpy.ones((idx+1))
+  return idx * jax.numpy.ones((idx + 1))
+
 
 jit_func = jax.jit(
-      example_step,
-      out_shardings=sharding,
-  )
+    example_step,
+    out_shardings=sharding,
+)
 
 print(f"{idx=} -> {jit_func()=}")
-
-
diff --git a/pedagogical_examples/shardings.py b/pedagogical_examples/shardings.py
index 85198c3ee..912266667 100644
--- a/pedagogical_examples/shardings.py
+++ b/pedagogical_examples/shardings.py
@@ -1,22 +1,22 @@
 #!/usr/bin/python3
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" -'''This script is used to measure the performance of different sharding schemes on TPU.''' +"""This script is used to measure the performance of different sharding schemes on TPU.""" from absl import app from absl import flags @@ -32,79 +32,67 @@ from typing import Sequence parser = argparse.ArgumentParser( - description="Experiment different sharding techniques with a simple NN.\ + description="Experiment different sharding techniques with a simple NN.\ Ensure 1) The product of dcn dimensions == number of slices \ 2) product of ici dimension = number of devices per slice" - ) +) parser.add_argument( - "--profiler_path", "-p", + "--profiler_path", + "-p", required=False, default="", help="Path to the profiler where the script will write to.", - type=str -) -parser.add_argument( - "--embedding_dimension", "-d", - required=False, - default=2048, - type=int -) -parser.add_argument( - "--batch_size", "-b", - required=False, - default=131072, - type=int -) -parser.add_argument( - "--num_layers", "-n", - required=False, - default=4, - type=int + type=str, ) +parser.add_argument("--embedding_dimension", "-d", required=False, default=2048, type=int) +parser.add_argument("--batch_size", "-b", required=False, default=131072, type=int) +parser.add_argument("--num_layers", "-n", required=False, default=4, type=int) parser.add_argument( - "--dcn_data_parallelism", "-dd", - help="N-way Data Parallelism across slices", - required=False, - default=1, - type=int + "--dcn_data_parallelism", "-dd", help="N-way Data Parallelism across slices", required=False, default=1, type=int ) parser.add_argument( - "--dcn_fsdp_parallelism", "-df", + "--dcn_fsdp_parallelism", + "-df", help="Fsdp parallelism across slices that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--dcn_tensor_parallelism", "-dt", + "--dcn_tensor_parallelism", + "-dt", help="Tensor parallelism across slices that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--ici_data_parallelism", "-id", + "--ici_data_parallelism", + "-id", help="Data parallelism within each slice that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--ici_fsdp_parallelism", "-if", + "--ici_fsdp_parallelism", + "-if", help="Number of shards for Fsdp Parallelism within each slice.", required=False, default=4, - type=int + type=int, ) parser.add_argument( - "--ici_tensor_parallelism", "-it", + "--ici_tensor_parallelism", + "-it", help="Number of shards for Tensor Parallelism within each slice.", required=False, default=1, - type=int + type=int, ) args = parser.parse_args() + def main(_argv: Sequence[str]) -> None: def activate_profiler(profiler_path): if profiler_path: @@ -115,16 +103,16 @@ def deactivate_profiler(profiler_path): if profiler_path: jax.profiler.stop_trace() - def simple_timeit(f, tries = 5, verbose = True): - '''Simple utility to time a function for multiple runs''' + def simple_timeit(f, tries=5, verbose=True): + """Simple utility to time a function for multiple runs""" outcomes = [] - f() #warm it up! + f() # warm it up! 
for _ in range(tries): s = datetime.datetime.now() f() e = datetime.datetime.now() - outcomes.append((e-s).total_seconds()) - average_time = sum(outcomes)/len(outcomes) + outcomes.append((e - s).total_seconds()) + average_time = sum(outcomes) / len(outcomes) if verbose: print(f"average time: {average_time}, timings (seconds) {outcomes}") return average_time @@ -138,15 +126,18 @@ def simple_timeit(f, tries = 5, verbose = True): assert len(devices) > 1, "You must have at least two devices" # Assert that we have correct inputs of sharding that fit the number of chips - assert np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices, f"Number of devices {num_devices} \ + assert ( + np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices + ), f"Number of devices {num_devices} \ does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}" - multi_slice_env = hasattr(jax.devices()[0], 'slice_index') + multi_slice_env = hasattr(jax.devices()[0], "slice_index") # Create device mesh if multi_slice_env: - assert args.dcn_data_parallelism == 1 + max(x.slice_index for x in jax.devices()), \ - f"Number of slices given {args.dcn_data_parallelism} \ + assert args.dcn_data_parallelism == 1 + max( + x.slice_index for x in jax.devices() + ), f"Number of slices given {args.dcn_data_parallelism} \ does not match the number fetched from jax devices {jax.devices()[0]}" devices_array = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism) else: @@ -156,27 +147,27 @@ def simple_timeit(f, tries = 5, verbose = True): mesh = Mesh(devices_array, ["data", "fsdp", "tensor"]) - data_sharding = PartitionSpec(("data", "fsdp"), "tensor") + data_sharding = PartitionSpec(("data", "fsdp"), "tensor") # We assume parameters are stored in a decreasing order of dimension size parameter_sharding = PartitionSpec("tensor", "fsdp") BATCH = len(jax.devices()) * args.batch_size D_EMB = args.embedding_dimension - D_FF = 4 * D_EMB + D_FF = 4 * D_EMB NUM_LAYERS = args.num_layers parameters = 2 * D_FF * D_EMB * NUM_LAYERS parameter_bytes = 2 * parameters - activation_bytes = 2 * ( BATCH * ( D_FF+D_EMB) ) * NUM_LAYERS + activation_bytes = 2 * (BATCH * (D_FF + D_EMB)) * NUM_LAYERS memory_bytes = parameter_bytes + activation_bytes print(f"total {memory_bytes/1e9} GB, parameters {parameter_bytes/1e9} GB, activations {activation_bytes/1e9} GB") def gen_layer(random_key): - keys = jax.random.split(random_key, num = 4) + keys = jax.random.split(random_key, num=4) return { - "EMB2FF" : 1e-4 * jax.random.normal( keys[0], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), - "FF2EMB" : 1e-4 * jax.random.normal( keys[1], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), + "EMB2FF": 1e-4 * jax.random.normal(keys[0], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), + "FF2EMB": 1e-4 * jax.random.normal(keys[1], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), } def gen_layers(random_key): @@ -187,8 +178,7 @@ def gen_layers(random_key): return tuple(layers) def gen_data(random_key): - return jax.random.uniform(random_key, (BATCH, D_EMB), dtype=jax.numpy.bfloat16 ) - + return jax.random.uniform(random_key, (BATCH, D_EMB), dtype=jax.numpy.bfloat16) def multiply_layer(in_act, in_layer): with jax.named_scope("M1"): @@ -210,7 +200,7 @@ def multiply_layers(in_act, in_layers): return x, in_layers def multiply_layers_with_loss(in_act, in_layers): - x, _ = multiply_layers(in_act, in_layers) + x, _ = multiply_layers(in_act, in_layers) return jax.numpy.sum(x) multiply_layers_and_grad = 
jax.value_and_grad(multiply_layers_with_loss, argnums=[1]) @@ -220,39 +210,27 @@ def training_step(in_act, in_layers): out_layers = jax.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0]) return out_layers - print("finished includes ", flush = True) + print("finished includes ", flush=True) replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding) - parameter_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) - data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) jit_func = jax.jit( - training_step, - in_shardings=(replicated_sharding, parameter_mesh_shardings), - out_shardings=data_pspec_shardings, - ) + training_step, + in_shardings=(replicated_sharding, parameter_mesh_shardings), + out_shardings=data_pspec_shardings, + ) - data_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding) + data_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding) - jit_gen_data = jax.jit( - gen_data, - in_shardings=None, - out_shardings=data_mesh_shardings - ) + jit_gen_data = jax.jit(gen_data, in_shardings=None, out_shardings=data_mesh_shardings) - parameter_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) - jit_gen_layers = jax.jit( - gen_layers, - in_shardings=None, - out_shardings=parameter_mesh_shardings - ) + jit_gen_layers = jax.jit(gen_layers, in_shardings=None, out_shardings=parameter_mesh_shardings) # starting the profiler outside `with` statement, # will call it right before the computation once b/301309635 is resolved @@ -261,14 +239,16 @@ def training_step(in_act, in_layers): key = jax.random.PRNGKey(0) presharded_X = jax.block_until_ready(jit_gen_data(key)) presharded_layers = jax.block_until_ready(jit_gen_layers(key)) - TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices()) - time = simple_timeit(lambda : jax.block_until_ready(jit_func(presharded_X, presharded_layers))) - print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush = True) + TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices()) + time = simple_timeit(lambda: jax.block_until_ready(jit_func(presharded_X, presharded_layers))) + print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush=True) deactivate_profiler(args.profiler_path) + def parse_flags(argv): return parser.parse_args(argv[1:]) + if __name__ == "__main__": flags.FLAGS.mark_as_parsed() app.run(main, flags_parser=parse_flags) diff --git a/pedagogical_examples/shmap_collective_matmul.py b/pedagogical_examples/shmap_collective_matmul.py index fe8c38be9..de80fcf77 100644 --- a/pedagogical_examples/shmap_collective_matmul.py +++ b/pedagogical_examples/shmap_collective_matmul.py @@ -1,24 +1,25 @@ #!/usr/bin/python3 """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" -'''This script is an example collective matmul.''' +"""This script is an example collective matmul.""" import os + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" import numpy as np @@ -27,7 +28,6 @@ import jax.numpy as jnp from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P -from jax.experimental import mesh_utils from jax.sharding import Mesh from jax.experimental.shard_map import shard_map @@ -50,51 +50,57 @@ import string import datetime -def simple_timeit(f, *args, tries = 10, trace_base_dir = None, task = None): - '''Simple utility to time a function for multiple runs''' - assert task is not None - trace_name = f"t_{task}_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) +def simple_timeit(f, *args, tries=10, trace_base_dir=None, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + + if trace_base_dir: + trace_dir = f"{trace_base_dir}/{trace_name}" + else: + trace_dir = None - if trace_base_dir: - trace_dir = f"{trace_base_dir}/{trace_name}" - else: - trace_dir = None + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! + if trace_dir: + jax.profiler.start_trace(trace_dir) - outcomes_ms = [] - jax.block_until_ready(f(*args)) #warm it up! 
- if trace_dir: - jax.profiler.start_trace(trace_dir) + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + if trace_dir: + jax.profiler.stop_trace() - for _ in range(tries): - s = datetime.datetime.now() - jax.block_until_ready(f(*args)) - e = datetime.datetime.now() - outcomes_ms.append(1000*(e-s).total_seconds()) - if trace_dir: - jax.profiler.stop_trace() + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}") + return average_time_ms - average_time_ms = sum(outcomes_ms)/len(outcomes_ms) - print(f"{task}: average time milliseconds: {average_time_ms:.2f}") - return average_time_ms # gen data def gen_data_fn(): - key = jax.random.PRNGKey(np.random.randint(0, 256)) - activations = jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16) - weights = jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16) - return activations, weights + key = jax.random.PRNGKey(np.random.randint(0, 256)) + activations = jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16) # pylint: disable=redefined-outer-name + weights = jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16) # pylint: disable=redefined-outer-name + return activations, weights + data_fn = pjit( gen_data_fn, out_shardings=(P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None), P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None)), ) -def matmul(activations, weights): - return jnp.einsum("bsE,Ehd->bshd", activations, weights) + +def matmul(activations, weights): # pylint: disable=redefined-outer-name + return jnp.einsum("bsE,Ehd->bshd", activations, weights) + jit_matmul = pjit(matmul, out_shardings=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None)) + @partial( shard_map, mesh=global_mesh, @@ -105,30 +111,48 @@ def matmul(activations, weights): out_specs=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None), check_rep=False, ) -def collective_matmul(activations, weights): - print(f"sh_map {activations.shape=} {weights.shape=}") - - axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS) - axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS) - # The current sequence chunk - chunk_size = activations.shape[1] - mid_chunk = chunk_size // 2 - # create accum buffer - accum = jnp.zeros( - ( - activations.shape[0], - activations.shape[1] * axis_size, - weights.shape[-2], - weights.shape[-1], - ), - dtype=activations.dtype, - ) +def collective_matmul(activations, weights): # pylint: disable=redefined-outer-name + """Collective matrix multiply""" + print(f"sh_map {activations.shape=} {weights.shape=}") + + axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS) + axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS) + # The current sequence chunk + chunk_size = activations.shape[1] + mid_chunk = chunk_size // 2 + # create accum buffer + accum = jnp.zeros( + ( + activations.shape[0], + activations.shape[1] * axis_size, + weights.shape[-2], + weights.shape[-1], + ), + dtype=activations.dtype, + ) + + # compute first chunk + update = jnp.einsum("bsE,Ehd->bshd", activations, weights) + update_index = (0, axis_index * chunk_size, 0, 0) + accum = jax.lax.dynamic_update_slice(accum, update, update_index) + activation_forward, activation_backward = jnp.split(activations, 2, axis=1) + activation_forward = jax.lax.ppermute( + activation_forward, + axis_name=MESH_TENSOR_AXIS, + perm=[(j, (j + 1) % 
axis_size) for j in range(axis_size)], + ) + activation_backward = jax.lax.ppermute( + activation_backward, + axis_name=MESH_TENSOR_AXIS, + perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], + ) + + # split activations into chunks and send + def scanned_call(i, carrys): + accum, activation_forward, activation_backward = carrys + update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights) + update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights) - # compute first chunk - update = jnp.einsum("bsE,Ehd->bshd", activations, weights) - update_index = (0, axis_index * chunk_size, 0, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - activation_forward, activation_backward = jnp.split(activations, 2, axis=1) activation_forward = jax.lax.ppermute( activation_forward, axis_name=MESH_TENSOR_AXIS, @@ -140,70 +164,46 @@ def collective_matmul(activations, weights): perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], ) - # split activations into chunks and send - def scanned_call(i, carrys): - accum, activation_forward, activation_backward = carrys - update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights) - update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights) - - activation_forward = jax.lax.ppermute( - activation_forward, - axis_name=MESH_TENSOR_AXIS, - perm=[(j, (j + 1) % axis_size) for j in range(axis_size)], - ) - activation_backward = jax.lax.ppermute( - activation_backward, - axis_name=MESH_TENSOR_AXIS, - perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], - ) - - forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size - backward_update_index = ((axis_index + i + 1) % axis_size) * chunk_size + mid_chunk - - accum = jax.lax.dynamic_update_slice(accum, update_forward, (0, forward_update_index, 0, 0)) - accum = jax.lax.dynamic_update_slice(accum, update_backward, (0, backward_update_index, 0, 0)) - return (accum, activation_forward, activation_backward) - - print(f"{accum.shape=}") - - accum, _, _ = jax.lax.fori_loop( - 0, (axis_size - 1), scanned_call, (accum, activation_forward, activation_backward) - ) - return accum - -with global_mesh: - activations, weights = data_fn() + forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size + backward_update_index = ((axis_index + i + 1) % axis_size) * chunk_size + mid_chunk - jax.block_until_ready(activations) - jax.block_until_ready(weights) + accum = jax.lax.dynamic_update_slice(accum, update_forward, (0, forward_update_index, 0, 0)) + accum = jax.lax.dynamic_update_slice(accum, update_backward, (0, backward_update_index, 0, 0)) + return (accum, activation_forward, activation_backward) - @jax.jit - def run_naive(_activations, _weights): - with jax.named_scope("naive_matmul"): - outputs = jit_matmul(_activations, _weights) - return outputs + print(f"{accum.shape=}") - @jax.jit - def run_collective(_activations, _weights): - with jax.named_scope("collective_matmul"): - manual_outputs = jax.jit(collective_matmul)(_activations, _weights) - return manual_outputs + accum, _, _ = jax.lax.fori_loop(0, (axis_size - 1), scanned_call, (accum, activation_forward, activation_backward)) + return accum +with global_mesh: + activations, weights = data_fn() - - naive_outputs = run_naive(activations, weights) - collective_outputs = run_collective(activations, weights) + jax.block_until_ready(activations) + jax.block_until_ready(weights) - print(f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}") - 
print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}") - print(f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}") - print(f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}") + @jax.jit + def run_naive(_activations, _weights): + with jax.named_scope("naive_matmul"): + outputs = jit_matmul(_activations, _weights) + return outputs + @jax.jit + def run_collective(_activations, _weights): + with jax.named_scope("collective_matmul"): + manual_outputs = jax.jit(collective_matmul)(_activations, _weights) + return manual_outputs + naive_outputs = run_naive(activations, weights) + collective_outputs = run_collective(activations, weights) - assert jnp.allclose(naive_outputs, collective_outputs), "Two algorithms should match but don't" + print(f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}") + print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}") + print(f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}") + print(f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}") - simple_timeit(run_naive, activations, weights, task = "naive") - simple_timeit(run_collective, activations, weights, task = "collective") + assert jnp.allclose(naive_outputs, collective_outputs), "Two algorithms should match but don't" + simple_timeit(run_naive, activations, weights, task="naive") + simple_timeit(run_collective, activations, weights, task="collective") diff --git a/pylintrc b/pylintrc index 4897238f0..4ebb6438a 100644 --- a/pylintrc +++ b/pylintrc @@ -29,6 +29,8 @@ jobs=4 # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no +# check python modules in the dir recursively +recursive=y [MESSAGES CONTROL]