diff --git a/.github/workflows/CPUTests.yaml b/.github/workflows/CPUTests.yaml
new file mode 100644
index 000000000..a6cd25a1e
--- /dev/null
+++ b/.github/workflows/CPUTests.yaml
@@ -0,0 +1,38 @@
+name: Linter
+
+on:
+  push:
+    branches:
+      - '**'
+
+jobs:
+  cpu:
+    name: "CPU tests"
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-20.04]
+        python-version: ['3.10']
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install Dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint pyink pytype==2024.2.27
+    - name: Typecheck the code with pytype
+      run: |
+        pytype --jobs auto --disable import-error MaxText/
+    - name: Analysing the code with pylint in MaxText/
+      run: |
+        pylint MaxText/ && \
+        echo 'MaxText PyLint check successful' || { echo \
+        'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
+    - name: Analysing the code with pylint in pedagogical_examples/
+      run: |
+        pylint pedagogical_examples/ && \
+        echo 'PyLint check on pedagogical_examples/ is successful' || { echo \
+        'PyLint check has failed. Please run bash code_style.sh to fix issues'; exit 20; }
diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index a2feaf336..52156c130 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -27,31 +27,6 @@ on:
     - cron: '0 */2 * * *'
 
 jobs:
-  cpu:
-    name: "CPU test"
-    runs-on: ${{ matrix.os }}
-    strategy:
-      matrix:
-        os: [ubuntu-20.04]
-        python-version: ['3.10']
-    steps:
-    - uses: actions/checkout@v3
-    - name: setup python
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install Dependencies
-      run: |
-        pip install pytype==2024.2.27
-        pip install pylint
-    - name: Typecheck the code with pytype
-      run: |
-        pytype --jobs auto --disable import-error MaxText/
-    - name: Analysing the code with pylint
-      run: |
-        pylint MaxText/
-
-
   # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'gpu' job
   tpu:
     strategy:
@@ -98,7 +73,7 @@ jobs:
     - name: Test int8_training
       run: |
        docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
-        'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false' 
+        'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false'
     - name: Test fp8_training
       run: |
        docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
@@ -123,7 +98,7 @@
       run: |
        docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
        'python3 pedagogical_examples/shmap_collective_matmul.py'
-
+
   # IF YOU MODIFY THIS, YOU SHOULD ALSO ADD CORRESPONDING MODIFICATIONS TO 'tpu' job
   gpu:
     strategy:
diff --git a/MaxText/__init__.py b/MaxText/__init__.py
index 83f918e62..c133d2d71 100644
--- a/MaxText/__init__.py
+++ b/MaxText/__init__.py
@@ -1,15 +1,15 @@
 """
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
diff --git a/MaxText/accelerator_to_spec_map.py b/MaxText/accelerator_to_spec_map.py
index fa6b64d0a..22d3da5dd 100644
--- a/MaxText/accelerator_to_spec_map.py
+++ b/MaxText/accelerator_to_spec_map.py
@@ -1,18 +1,18 @@
 """
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" """ Static map of TPU names such as v4-8 to properties such as chip layout.""" @@ -22,358 +22,349 @@ from dataclasses import dataclass + @dataclass class SystemCharacteristics: platform: str topology_name: str - chip_config_name: str # 'megacore' or 'default' + chip_config_name: str # 'megacore' or 'default' chips_per_host_bounds: tuple devices_per_slice: int + UserFacingNameToSystemCharacteristics = { # v5e - 'v5e-16': SystemCharacteristics( - 'tpu', 'v5e:4x4', 'default', (2, 2, 1), 16 - ), - 'v5e-32': SystemCharacteristics( - 'tpu', 'v5e:4x8', 'default', (2, 2, 1), 32 - ), - 'v5e-64': SystemCharacteristics( - 'tpu', 'v5e:8x8', 'default', (2, 2, 1), 64 + "v5e-16": SystemCharacteristics("tpu", "v5e:4x4", "default", (2, 2, 1), 16), + "v5e-32": SystemCharacteristics("tpu", "v5e:4x8", "default", (2, 2, 1), 32), + "v5e-64": SystemCharacteristics("tpu", "v5e:8x8", "default", (2, 2, 1), 64), + "v5e-128": SystemCharacteristics( + "tpu", "v5e:8x16", "default", (2, 2, 1), 128 ), - 'v5e-128': SystemCharacteristics( - 'tpu', 'v5e:8x16', 'default', (2, 2, 1), 128 - ), - 'v5e-256': SystemCharacteristics( - 'tpu', 'v5e:16x16', 'default', (2, 2, 1), 256 + "v5e-256": SystemCharacteristics( + "tpu", "v5e:16x16", "default", (2, 2, 1), 256 ), # v4 - 'v4-8': SystemCharacteristics( - 'tpu', 'v4:2x2x1', 'megacore', (2, 2, 1), 4 - ), - 'v4-16': SystemCharacteristics( - 'tpu', 'v4:2x2x2', 'megacore', (2, 2, 1), 8 - ), - 'v4-32': SystemCharacteristics( - 'tpu', 'v4:2x2x4', 'megacore', (2, 2, 1), 16 + "v4-8": SystemCharacteristics("tpu", "v4:2x2x1", "megacore", (2, 2, 1), 4), + "v4-16": SystemCharacteristics("tpu", "v4:2x2x2", "megacore", (2, 2, 1), 8), + "v4-32": SystemCharacteristics( + "tpu", "v4:2x2x4", "megacore", (2, 2, 1), 16 ), - 'v4-64': SystemCharacteristics( - 'tpu', 'v4:2x4x4', 'megacore', (2, 2, 1), 32 + "v4-64": SystemCharacteristics( + "tpu", "v4:2x4x4", "megacore", (2, 2, 1), 32 ), - 'v4-128': SystemCharacteristics( - 'tpu', 'v4:4x4x4', 'megacore', (2, 2, 1), 64 + "v4-128": SystemCharacteristics( + "tpu", "v4:4x4x4", "megacore", (2, 2, 1), 64 ), - 'v4-256': SystemCharacteristics( - 'tpu', 'v4:4x4x8', 'megacore', (2, 2, 1), 128 + "v4-256": SystemCharacteristics( + "tpu", "v4:4x4x8", "megacore", (2, 2, 1), 128 ), - 'v4-384': SystemCharacteristics( - 'tpu', 'v4:4x4x12', 'megacore', (2, 2, 1), 192 + "v4-384": SystemCharacteristics( + "tpu", "v4:4x4x12", "megacore", (2, 2, 1), 192 ), - 'v4-512': SystemCharacteristics( - 'tpu', 'v4:4x8x8', 'megacore', (2, 2, 1), 256 + "v4-512": SystemCharacteristics( + "tpu", "v4:4x8x8", "megacore", (2, 2, 1), 256 ), - 'v4-1024': SystemCharacteristics( - 'tpu', 'v4:8x8x8', 'megacore', (2, 2, 1), 512 + "v4-1024": SystemCharacteristics( + "tpu", "v4:8x8x8", "megacore", (2, 2, 1), 512 ), - 'v4-1536': SystemCharacteristics( - 'tpu', 'v4:8x8x12','megacore', (2, 2, 1), 768 + "v4-1536": SystemCharacteristics( + "tpu", "v4:8x8x12", "megacore", (2, 2, 1), 768 ), - 'v4-2048': SystemCharacteristics( - 'tpu', 'v4:8x8x16','megacore', (2, 2, 1), 1024 + "v4-2048": SystemCharacteristics( + "tpu", "v4:8x8x16", "megacore", (2, 2, 1), 1024 ), - 'v4-4096': SystemCharacteristics( - 'tpu', 'v4:8x16x16', 'megacore', (2, 2, 1), 2048 + "v4-4096": SystemCharacteristics( + "tpu", "v4:8x16x16", "megacore", (2, 2, 1), 2048 ), # v5p - 'v5p-8': SystemCharacteristics( - 'tpu', 'v5:2x2x1', 'megacore', (2, 2, 1), 4 + "v5p-8": SystemCharacteristics("tpu", "v5:2x2x1", "megacore", (2, 2, 1), 4), + "v5p-16": SystemCharacteristics( + "tpu", "v5:2x2x2", "megacore", (2, 2, 1), 8 ), - 'v5p-16': 
-        'tpu', 'v5:2x2x2', 'megacore', (2, 2, 1), 8
+    "v5p-32": SystemCharacteristics(
+        "tpu", "v5:2x2x4", "megacore", (2, 2, 1), 16
    ),
-    'v5p-32': SystemCharacteristics(
-        'tpu', 'v5:2x2x4', 'megacore', (2, 2, 1), 16
+    "v5p-64": SystemCharacteristics(
+        "tpu", "v5:2x4x4", "megacore", (2, 2, 1), 32
    ),
-    'v5p-64': SystemCharacteristics(
-        'tpu', 'v5:2x4x4', 'megacore', (2, 2, 1), 32
+    "v5p-128": SystemCharacteristics(
+        "tpu", "v5:4x4x4", "megacore", (2, 2, 1), 64
    ),
-    'v5p-128': SystemCharacteristics(
-        'tpu', 'v5:4x4x4', 'megacore', (2, 2, 1), 64
+    "v5p-256": SystemCharacteristics(
+        "tpu", "v5:4x4x8", "megacore", (2, 2, 1), 128
    ),
-    'v5p-256': SystemCharacteristics(
-        'tpu', 'v5:4x4x8', 'megacore', (2, 2, 1), 128
+    "v5p-384": SystemCharacteristics(
+        "tpu", "v5:4x4x12", "megacore", (2, 2, 1), 192
    ),
-    'v5p-384': SystemCharacteristics(
-        'tpu', 'v5:4x4x12', 'megacore', (2, 2, 1), 192
+    "v5p-512": SystemCharacteristics(
+        "tpu", "v5:4x8x8", "megacore", (2, 2, 1), 256
    ),
-    'v5p-512': SystemCharacteristics(
-        'tpu', 'v5:4x8x8', 'megacore', (2, 2, 1), 256
+    "v5p-640": SystemCharacteristics(
+        "tpu", "v5:4x4x20", "megacore", (2, 2, 1), 320
    ),
-    'v5p-640': SystemCharacteristics(
-        'tpu', 'v5:4x4x20', 'megacore', (2, 2, 1), 320
+    "v5p-768": SystemCharacteristics(
+        "tpu", "v5:4x8x12", "megacore", (2, 2, 1), 384
    ),
-    'v5p-768': SystemCharacteristics(
-        'tpu', 'v5:4x8x12', 'megacore', (2, 2, 1), 384
+    "v5p-896": SystemCharacteristics(
+        "tpu", "v5:4x4x28", "megacore", (2, 2, 1), 448
    ),
-    'v5p-896': SystemCharacteristics(
-        'tpu', 'v5:4x4x28', 'megacore', (2, 2, 1), 448
+    "v5p-1024": SystemCharacteristics(
+        "tpu", "v5:8x8x8", "megacore", (2, 2, 1), 512
    ),
-    'v5p-1024': SystemCharacteristics(
-        'tpu', 'v5:8x8x8', 'megacore', (2, 2, 1), 512
+    "v5p-1152": SystemCharacteristics(
+        "tpu", "v5:4x12x12", "megacore", (2, 2, 1), 576
    ),
-    'v5p-1152': SystemCharacteristics(
-        'tpu', 'v5:4x12x12', 'megacore', (2, 2, 1), 576
+    "v5p-1280": SystemCharacteristics(
+        "tpu", "v5:4x8x20", "megacore", (2, 2, 1), 640
    ),
-    'v5p-1280': SystemCharacteristics(
-        'tpu', 'v5:4x8x20', 'megacore', (2, 2, 1), 640
+    "v5p-1408": SystemCharacteristics(
+        "tpu", "v5:4x4x44", "megacore", (2, 2, 1), 704
    ),
-    'v5p-1408': SystemCharacteristics(
-        'tpu', 'v5:4x4x44', 'megacore', (2, 2, 1), 704
+    "v5p-1536": SystemCharacteristics(
+        "tpu", "v5:8x8x12", "megacore", (2, 2, 1), 768
    ),
-    'v5p-1536': SystemCharacteristics(
-        'tpu', 'v5:8x8x12', 'megacore', (2, 2, 1), 768
+    "v5p-1664": SystemCharacteristics(
+        "tpu", "v5:4x4x52", "megacore", (2, 2, 1), 832
    ),
-    'v5p-1664': SystemCharacteristics(
-        'tpu', 'v5:4x4x52', 'megacore', (2, 2, 1), 832
+    "v5p-1792": SystemCharacteristics(
+        "tpu", "v5:4x8x28", "megacore", (2, 2, 1), 896
    ),
-    'v5p-1792': SystemCharacteristics(
-        'tpu', 'v5:4x8x28', 'megacore', (2, 2, 1), 896
+    "v5p-1920": SystemCharacteristics(
+        "tpu", "v5:4x12x20", "megacore", (2, 2, 1), 960
    ),
-    'v5p-1920': SystemCharacteristics(
-        'tpu', 'v5:4x12x20', 'megacore', (2, 2, 1), 960
+    "v5p-2048": SystemCharacteristics(
+        "tpu", "v5:8x8x16", "megacore", (2, 2, 1), 1024
    ),
-    'v5p-2048': SystemCharacteristics(
-        'tpu', 'v5:8x8x16', 'megacore', (2, 2, 1), 1024
+    "v5p-2176": SystemCharacteristics(
+        "tpu", "v5:4x4x68", "megacore", (2, 2, 1), 1088
    ),
-    'v5p-2176': SystemCharacteristics(
-        'tpu', 'v5:4x4x68', 'megacore', (2, 2, 1), 1088
+    "v5p-2304": SystemCharacteristics(
+        "tpu", "v5:8x12x12", "megacore", (2, 2, 1), 1152
    ),
-    'v5p-2304': SystemCharacteristics(
-        'tpu', 'v5:8x12x12', 'megacore', (2, 2, 1), 1152
+    "v5p-2432": SystemCharacteristics(
SystemCharacteristics( + "tpu", "v5:4x4x76", "megacore", (2, 2, 1), 1216 ), - 'v5p-2432': SystemCharacteristics( - 'tpu', 'v5:4x4x76', 'megacore', (2, 2, 1), 1216 + "v5p-2560": SystemCharacteristics( + "tpu", "v5:8x8x20", "megacore", (2, 2, 1), 1280 ), - 'v5p-2560': SystemCharacteristics( - 'tpu', 'v5:8x8x20', 'megacore', (2, 2, 1), 1280 + "v5p-2688": SystemCharacteristics( + "tpu", "v5:4x12x28", "megacore", (2, 2, 1), 1344 ), - 'v5p-2688': SystemCharacteristics( - 'tpu', 'v5:4x12x28', 'megacore', (2, 2, 1), 1344 + "v5p-2816": SystemCharacteristics( + "tpu", "v5:4x8x44", "megacore", (2, 2, 1), 1408 ), - 'v5p-2816': SystemCharacteristics( - 'tpu', 'v5:4x8x44', 'megacore', (2, 2, 1), 1408 + "v5p-2944": SystemCharacteristics( + "tpu", "v5:4x4x92", "megacore", (2, 2, 1), 1472 ), - 'v5p-2944': SystemCharacteristics( - 'tpu', 'v5:4x4x92', 'megacore', (2, 2, 1), 1472 + "v5p-3072": SystemCharacteristics( + "tpu", "v5:8x12x16", "megacore", (2, 2, 1), 1536 ), - 'v5p-3072': SystemCharacteristics( - 'tpu', 'v5:8x12x16', 'megacore', (2, 2, 1), 1536 + "v5p-3200": SystemCharacteristics( + "tpu", "v5:4x20x20", "megacore", (2, 2, 1), 1600 ), - 'v5p-3200': SystemCharacteristics( - 'tpu', 'v5:4x20x20', 'megacore', (2, 2, 1), 1600 + "v5p-3328": SystemCharacteristics( + "tpu", "v5:4x8x52", "megacore", (2, 2, 1), 1664 ), - 'v5p-3328': SystemCharacteristics( - 'tpu', 'v5:4x8x52', 'megacore', (2, 2, 1), 1664 + "v5p-3456": SystemCharacteristics( + "tpu", "v5:12x12x12", "megacore", (2, 2, 1), 1728 ), - 'v5p-3456': SystemCharacteristics( - 'tpu', 'v5:12x12x12', 'megacore', (2, 2, 1), 1728 + "v5p-3584": SystemCharacteristics( + "tpu", "v5:8x8x28", "megacore", (2, 2, 1), 1792 ), - 'v5p-3584': SystemCharacteristics( - 'tpu', 'v5:8x8x28', 'megacore', (2, 2, 1), 1792 + "v5p-3712": SystemCharacteristics( + "tpu", "v5:4x4x116", "megacore", (2, 2, 1), 1856 ), - 'v5p-3712': SystemCharacteristics( - 'tpu', 'v5:4x4x116', 'megacore', (2, 2, 1), 1856 + "v5p-3840": SystemCharacteristics( + "tpu", "v5:8x12x20", "megacore", (2, 2, 1), 1920 ), - 'v5p-3840': SystemCharacteristics( - 'tpu', 'v5:8x12x20', 'megacore', (2, 2, 1), 1920 + "v5p-3968": SystemCharacteristics( + "tpu", "v5:4x4x124", "megacore", (2, 2, 1), 1984 ), - 'v5p-3968': SystemCharacteristics( - 'tpu', 'v5:4x4x124', 'megacore', (2, 2, 1), 1984 + "v5p-4096": SystemCharacteristics( + "tpu", "v5:8x16x16", "megacore", (2, 2, 1), 2048 ), - 'v5p-4096': SystemCharacteristics( - 'tpu', 'v5:8x16x16', 'megacore', (2, 2, 1), 2048 + "v5p-4224": SystemCharacteristics( + "tpu", "v5:4x12x44", "megacore", (2, 2, 1), 2112 ), - 'v5p-4224': SystemCharacteristics( - 'tpu', 'v5:4x12x44', 'megacore', (2, 2, 1), 2112 + "v5p-4352": SystemCharacteristics( + "tpu", "v5:4x8x68", "megacore", (2, 2, 1), 2176 ), - 'v5p-4352': SystemCharacteristics( - 'tpu', 'v5:4x8x68', 'megacore', (2, 2, 1), 2176 + "v5p-4480": SystemCharacteristics( + "tpu", "v5:4x20x28", "megacore", (2, 2, 1), 2240 ), - 'v5p-4480': SystemCharacteristics( - 'tpu', 'v5:4x20x28', 'megacore', (2, 2, 1), 2240 + "v5p-4608": SystemCharacteristics( + "tpu", "v5:12x12x16", "megacore", (2, 2, 1), 2304 ), - 'v5p-4608': SystemCharacteristics( - 'tpu', 'v5:12x12x16', 'megacore', (2, 2, 1), 2304 + "v5p-4736": SystemCharacteristics( + "tpu", "v5:4x4x148", "megacore", (2, 2, 1), 2368 ), - 'v5p-4736': SystemCharacteristics( - 'tpu', 'v5:4x4x148', 'megacore', (2, 2, 1), 2368 + "v5p-4864": SystemCharacteristics( + "tpu", "v5:4x8x76", "megacore", (2, 2, 1), 2432 ), - 'v5p-4864': SystemCharacteristics( - 'tpu', 'v5:4x8x76', 'megacore', (2, 2, 1), 2432 
+ "v5p-4992": SystemCharacteristics( + "tpu", "v5:4x12x52", "megacore", (2, 2, 1), 2496 ), - 'v5p-4992': SystemCharacteristics( - 'tpu', 'v5:4x12x52', 'megacore', (2, 2, 1), 2496 + "v5p-5120": SystemCharacteristics( + "tpu", "v5:8x16x20", "megacore", (2, 2, 1), 2560 ), - 'v5p-5120': SystemCharacteristics( - 'tpu', 'v5:8x16x20', 'megacore', (2, 2, 1), 2560 + "v5p-5248": SystemCharacteristics( + "tpu", "v5:4x4x164", "megacore", (2, 2, 1), 2624 ), - 'v5p-5248': SystemCharacteristics( - 'tpu', 'v5:4x4x164', 'megacore', (2, 2, 1), 2624 + "v5p-5376": SystemCharacteristics( + "tpu", "v5:8x12x28", "megacore", (2, 2, 1), 2688 ), - 'v5p-5376': SystemCharacteristics( - 'tpu', 'v5:8x12x28', 'megacore', (2, 2, 1), 2688 + "v5p-5504": SystemCharacteristics( + "tpu", "v5:4x4x172", "megacore", (2, 2, 1), 2752 ), - 'v5p-5504': SystemCharacteristics( - 'tpu', 'v5:4x4x172', 'megacore', (2, 2, 1), 2752 + "v5p-5632": SystemCharacteristics( + "tpu", "v5:8x8x44", "megacore", (2, 2, 1), 2816 ), - 'v5p-5632': SystemCharacteristics( - 'tpu', 'v5:8x8x44', 'megacore', (2, 2, 1), 2816 + "v5p-5760": SystemCharacteristics( + "tpu", "v5:12x12x20", "megacore", (2, 2, 1), 2880 ), - 'v5p-5760': SystemCharacteristics( - 'tpu', 'v5:12x12x20', 'megacore', (2, 2, 1), 2880 + "v5p-5888": SystemCharacteristics( + "tpu", "v5:4x8x92", "megacore", (2, 2, 1), 2944 ), - 'v5p-5888': SystemCharacteristics( - 'tpu', 'v5:4x8x92', 'megacore', (2, 2, 1), 2944 + "v5p-6016": SystemCharacteristics( + "tpu", "v5:4x4x188", "megacore", (2, 2, 1), 3008 ), - 'v5p-6016': SystemCharacteristics( - 'tpu', 'v5:4x4x188', 'megacore', (2, 2, 1), 3008 + "v5p-6144": SystemCharacteristics( + "tpu", "v5:12x16x16", "megacore", (2, 2, 1), 3072 ), - 'v5p-6144': SystemCharacteristics( - 'tpu', 'v5:12x16x16', 'megacore', (2, 2, 1), 3072 + "v5p-6272": SystemCharacteristics( + "tpu", "v5:4x28x28", "megacore", (2, 2, 1), 3136 ), - 'v5p-6272': SystemCharacteristics( - 'tpu', 'v5:4x28x28', 'megacore', (2, 2, 1), 3136 + "v5p-6400": SystemCharacteristics( + "tpu", "v5:8x20x20", "megacore", (2, 2, 1), 3200 ), - 'v5p-6400': SystemCharacteristics( - 'tpu', 'v5:8x20x20', 'megacore', (2, 2, 1), 3200 + "v5p-6528": SystemCharacteristics( + "tpu", "v5:4x12x68", "megacore", (2, 2, 1), 3264 ), - 'v5p-6528': SystemCharacteristics( - 'tpu', 'v5:4x12x68', 'megacore', (2, 2, 1), 3264 + "v5p-6656": SystemCharacteristics( + "tpu", "v5:8x8x52", "megacore", (2, 2, 1), 3328 ), - 'v5p-6656': SystemCharacteristics( - 'tpu', 'v5:8x8x52', 'megacore', (2, 2, 1), 3328 + "v5p-6784": SystemCharacteristics( + "tpu", "v5:4x4x212", "megacore", (2, 2, 1), 3392 ), - 'v5p-6784': SystemCharacteristics( - 'tpu', 'v5:4x4x212', 'megacore', (2, 2, 1), 3392 + "v5p-6912": SystemCharacteristics( + "tpu", "v5:12x12x24", "megacore", (2, 2, 1), 3456 ), - 'v5p-6912': SystemCharacteristics( - 'tpu', 'v5:12x12x24', 'megacore', (2, 2, 1), 3456 + "v5p-7040": SystemCharacteristics( + "tpu", "v5:4x20x44", "megacore", (2, 2, 1), 3520 ), - 'v5p-7040': SystemCharacteristics( - 'tpu', 'v5:4x20x44', 'megacore', (2, 2, 1), 3520 + "v5p-7168": SystemCharacteristics( + "tpu", "v5:8x16x28", "megacore", (2, 2, 1), 3584 ), - 'v5p-7168': SystemCharacteristics( - 'tpu', 'v5:8x16x28', 'megacore', (2, 2, 1), 3584 + "v5p-7296": SystemCharacteristics( + "tpu", "v5:4x12x76", "megacore", (2, 2, 1), 3648 ), - 'v5p-7296': SystemCharacteristics( - 'tpu', 'v5:4x12x76', 'megacore', (2, 2, 1), 3648 + "v5p-7424": SystemCharacteristics( + "tpu", "v5:4x8x116", "megacore", (2, 2, 1), 3712 ), - 'v5p-7424': SystemCharacteristics( - 'tpu', 'v5:4x8x116', 
+    "v5p-7552": SystemCharacteristics(
+        "tpu", "v5:4x4x236", "megacore", (2, 2, 1), 3776
    ),
-    'v5p-7552': SystemCharacteristics(
-        'tpu', 'v5:4x4x236', 'megacore', (2, 2, 1), 3776
+    "v5p-7680": SystemCharacteristics(
+        "tpu", "v5:12x16x20", "megacore", (2, 2, 1), 3840
    ),
-    'v5p-7680': SystemCharacteristics(
-        'tpu', 'v5:12x16x20', 'megacore', (2, 2, 1), 3840
+    "v5p-7808": SystemCharacteristics(
+        "tpu", "v5:4x4x244", "megacore", (2, 2, 1), 3904
    ),
-    'v5p-7808': SystemCharacteristics(
-        'tpu', 'v5:4x4x244', 'megacore', (2, 2, 1), 3904
+    "v5p-7936": SystemCharacteristics(
+        "tpu", "v5:4x8x124", "megacore", (2, 2, 1), 3968
    ),
-    'v5p-7936': SystemCharacteristics(
-        'tpu', 'v5:4x8x124', 'megacore', (2, 2, 1), 3968
+    "v5p-8064": SystemCharacteristics(
+        "tpu", "v5:12x12x28", "megacore", (2, 2, 1), 4032
    ),
-    'v5p-8064': SystemCharacteristics(
-        'tpu', 'v5:12x12x28', 'megacore', (2, 2, 1), 4032
+    "v5p-8192": SystemCharacteristics(
+        "tpu", "v5:16x16x16", "megacore", (2, 2, 1), 4096
    ),
-    'v5p-8192': SystemCharacteristics(
-        'tpu', 'v5:16x16x16', 'megacore', (2, 2, 1), 4096
+    "v5p-8320": SystemCharacteristics(
+        "tpu", "v5:4x20x52", "megacore", (2, 2, 1), 4160
    ),
-    'v5p-8320': SystemCharacteristics(
-        'tpu', 'v5:4x20x52', 'megacore', (2, 2, 1), 4160
+    "v5p-8448": SystemCharacteristics(
+        "tpu", "v5:8x12x44", "megacore", (2, 2, 1), 4224
    ),
-    'v5p-8448': SystemCharacteristics(
-        'tpu', 'v5:8x12x44', 'megacore', (2, 2, 1), 4224
+    "v5p-8704": SystemCharacteristics(
+        "tpu", "v5:8x8x68", "megacore", (2, 2, 1), 4352
    ),
-    'v5p-8704': SystemCharacteristics(
-        'tpu', 'v5:8x8x68', 'megacore', (2, 2, 1), 4352
+    "v5p-8832": SystemCharacteristics(
+        "tpu", "v5:4x12x92", "megacore", (2, 2, 1), 4416
    ),
-    'v5p-8832': SystemCharacteristics(
-        'tpu', 'v5:4x12x92', 'megacore', (2, 2, 1), 4416
+    "v5p-8960": SystemCharacteristics(
+        "tpu", "v5:8x20x28", "megacore", (2, 2, 1), 4480
    ),
-    'v5p-8960': SystemCharacteristics(
-        'tpu', 'v5:8x20x28', 'megacore', (2, 2, 1), 4480
+    "v5p-9216": SystemCharacteristics(
+        "tpu", "v5:12x16x24", "megacore", (2, 2, 1), 4608
    ),
-    'v5p-9216': SystemCharacteristics(
-        'tpu', 'v5:12x16x24', 'megacore', (2, 2, 1), 4608
+    "v5p-9472": SystemCharacteristics(
+        "tpu", "v5:4x8x148", "megacore", (2, 2, 1), 4736
    ),
-    'v5p-9472': SystemCharacteristics(
-        'tpu', 'v5:4x8x148', 'megacore', (2, 2, 1), 4736
+    "v5p-9600": SystemCharacteristics(
+        "tpu", "v5:12x20x20", "megacore", (2, 2, 1), 4800
    ),
-    'v5p-9600': SystemCharacteristics(
-        'tpu', 'v5:12x20x20', 'megacore', (2, 2, 1), 4800
+    "v5p-9728": SystemCharacteristics(
+        "tpu", "v5:8x8x76", "megacore", (2, 2, 1), 4864
    ),
-    'v5p-9728': SystemCharacteristics(
-        'tpu', 'v5:8x8x76', 'megacore', (2, 2, 1), 4864
+    "v5p-9856": SystemCharacteristics(
+        "tpu", "v5:4x28x44", "megacore", (2, 2, 1), 4928
    ),
-    'v5p-9856': SystemCharacteristics(
-        'tpu', 'v5:4x28x44', 'megacore', (2, 2, 1), 4928
+    "v5p-9984": SystemCharacteristics(
+        "tpu", "v5:8x12x52", "megacore", (2, 2, 1), 4992
    ),
-    'v5p-9984': SystemCharacteristics(
-        'tpu', 'v5:8x12x52', 'megacore', (2, 2, 1), 4992
+    "v5p-10240": SystemCharacteristics(
+        "tpu", "v5:16x16x20", "megacore", (2, 2, 1), 5120
    ),
-    'v5p-10240': SystemCharacteristics(
-        'tpu', 'v5:16x16x20', 'megacore', (2, 2, 1), 5120
+    "v5p-10368": SystemCharacteristics(
+        "tpu", "v5:12x12x36", "megacore", (2, 2, 1), 5184
    ),
-    'v5p-10368': SystemCharacteristics(
-        'tpu', 'v5:12x12x36', 'megacore', (2, 2, 1), 5184
+    "v5p-10496": SystemCharacteristics(
+        "tpu", "v5:4x8x164", "megacore", (2, 2, 1), 5248
    ),
-    'v5p-10496': SystemCharacteristics(
-        'tpu', 'v5:4x8x164', 'megacore', (2, 2, 1), 5248
+    "v5p-10752": SystemCharacteristics(
+        "tpu", "v5:12x16x28", "megacore", (2, 2, 1), 5376
    ),
-    'v5p-10752': SystemCharacteristics(
-        'tpu', 'v5:12x16x28', 'megacore', (2, 2, 1), 5376
+    "v5p-10880": SystemCharacteristics(
+        "tpu", "v5:4x20x68", "megacore", (2, 2, 1), 5440
    ),
-    'v5p-10880': SystemCharacteristics(
-        'tpu', 'v5:4x20x68', 'megacore', (2, 2, 1), 5440
+    "v5p-11008": SystemCharacteristics(
+        "tpu", "v5:4x8x172", "megacore", (2, 2, 1), 5504
    ),
-    'v5p-11008': SystemCharacteristics(
-        'tpu', 'v5:4x8x172', 'megacore', (2, 2, 1), 5504
+    "v5p-11136": SystemCharacteristics(
+        "tpu", "v5:4x12x116", "megacore", (2, 2, 1), 5568
    ),
-    'v5p-11136': SystemCharacteristics(
-        'tpu', 'v5:4x12x116', 'megacore', (2, 2, 1), 5568
+    "v5p-11264": SystemCharacteristics(
+        "tpu", "v5:8x16x44", "megacore", (2, 2, 1), 5632
    ),
-    'v5p-11264': SystemCharacteristics(
-        'tpu', 'v5:8x16x44', 'megacore', (2, 2, 1), 5632
+    "v5p-11520": SystemCharacteristics(
+        "tpu", "v5:12x20x24", "megacore", (2, 2, 1), 5760
    ),
-    'v5p-11520': SystemCharacteristics(
-        'tpu', 'v5:12x20x24', 'megacore', (2, 2, 1), 5760
+    "v5p-11648": SystemCharacteristics(
+        "tpu", "v5:4x28x52", "megacore", (2, 2, 1), 5824
    ),
-    'v5p-11648': SystemCharacteristics(
-        'tpu', 'v5:4x28x52', 'megacore', (2, 2, 1), 5824
+    "v5p-11776": SystemCharacteristics(
+        "tpu", "v5:8x8x92", "megacore", (2, 2, 1), 5888
    ),
-    'v5p-11776': SystemCharacteristics(
-        'tpu', 'v5:8x8x92', 'megacore', (2, 2, 1), 5888
+    "v5p-11904": SystemCharacteristics(
+        "tpu", "v5:4x12x124", "megacore", (2, 2, 1), 5952
    ),
-    'v5p-11904': SystemCharacteristics(
-        'tpu', 'v5:4x12x124', 'megacore', (2, 2, 1), 5952
+    "v5p-12032": SystemCharacteristics(
+        "tpu", "v5:4x8x188", "megacore", (2, 2, 1), 6016
    ),
-    'v5p-12032': SystemCharacteristics(
-        'tpu', 'v5:4x8x188', 'megacore', (2, 2, 1), 6016
+    "v5p-12160": SystemCharacteristics(
+        "tpu", "v5:4x20x76", "megacore", (2, 2, 1), 6080
    ),
-    'v5p-12160': SystemCharacteristics(
-        'tpu', 'v5:4x20x76', 'megacore', (2, 2, 1), 6080
+    "v5p-12288": SystemCharacteristics(
+        "tpu", "v5:16x16x24", "megacore", (2, 2, 1), 6144
    ),
-    'v5p-12288': SystemCharacteristics(
-        'tpu', 'v5:16x16x24', 'megacore', (2, 2, 1), 6144
+    "v5p-13824": SystemCharacteristics(
+        "tpu", "v5:12x24x24", "megacore", (2, 2, 1), 6912
    ),
-    'v5p-13824': SystemCharacteristics(
-        'tpu', 'v5:12x24x24', 'megacore', (2, 2, 1), 6912
-    ),
-    'v5p-17920': SystemCharacteristics(
-        'tpu', 'v5:16x20x28', 'megacore', (2, 2, 1), 8960
+    "v5p-17920": SystemCharacteristics(
+        "tpu", "v5:16x20x28", "megacore", (2, 2, 1), 8960
    ),
 }
 
+
 def get_system_characteristics(user_facing_name):
   return UserFacingNameToSystemCharacteristics.get(user_facing_name)
diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py
index bd229cc91..5edb63430 100644
--- a/MaxText/checkpointing.py
+++ b/MaxText/checkpointing.py
@@ -1,18 +1,18 @@
 """
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
 
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0
 
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 
 """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
@@ -28,12 +28,13 @@
 from multihost_dataloading import MultiHostDataLoadIterator
 from flax.training import train_state
 
+
 def create_orbax_checkpoint_manager(
     checkpoint_dir: str,
     enable_checkpointing: bool,
     use_async: bool,
     save_interval_steps: int,
-    dataset_type: Optional[str] = 'c4'
+    dataset_type: Optional[str] = "c4",
 ):
   """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
   if not enable_checkpointing:
@@ -42,19 +43,19 @@ def create_orbax_checkpoint_manager(
   max_logging.log("Creating checkpoint manager...")
   p = epath.Path(checkpoint_dir)
-  if dataset_type=='c4-array_record':
-    item_names = ('items', 'iter')
+  if dataset_type == "c4-array_record":
+    item_names = ("items", "iter")
   else:
-    item_names = ('items',)
+    item_names = ("items",)
 
   mngr = CheckpointManager(
       p,
-      item_names = item_names,
-      options = CheckpointManagerOptions(
+      item_names=item_names,
+      options=CheckpointManagerOptions(
          create=True,
          save_interval_steps=save_interval_steps,
          enable_async_checkpointing=use_async,
-      )
+      ),
   )
   max_logging.log("Checkpoint manager created!")
   return mngr
@@ -82,19 +83,19 @@ def _replica_devices(device_array: np.ndarray, replica_axis_idx: int):
     devices inside the replica that current host is in
   """
   idx = _find_idx(device_array, replica_axis_idx)
-  replica_result = np.take(device_array,
-                           idx,
-                           axis=replica_axis_idx)
+  replica_result = np.take(device_array, idx, axis=replica_axis_idx)
   return np.expand_dims(replica_result, axis=replica_axis_idx)
 
 
-def load_state_if_possible(checkpoint_manager: CheckpointManager,
-                           data_iterator: Union[MultiHostDataLoadIterator, None],
-                           load_parameters_from_path: str,
-                           load_full_state_from_path: str,
-                           abstract_unboxed_pre_state: train_state.TrainState,
-                           enable_single_replica_ckpt_restoring: Optional[bool] = False,
-                           dataset_type: Optional[str] = 'c4'):
+def load_state_if_possible(
+    checkpoint_manager: CheckpointManager,
+    data_iterator: Union[MultiHostDataLoadIterator, None],
+    load_parameters_from_path: str,
+    load_full_state_from_path: str,
+    abstract_unboxed_pre_state: train_state.TrainState,
+    enable_single_replica_ckpt_restoring: Optional[bool] = False,
+    dataset_type: Optional[str] = "c4",
+):
  """Loads TrainState as possible from the inputs.
 
  Args:
@@ -117,61 +118,81 @@ def load_state_if_possible(checkpoint_manager: CheckpointManager,
  """
 
  if checkpoint_manager is not None:
-    max_logging.log("checkpoint manager exists so trying to load this run's existing checkpoint")
+    max_logging.log(
+        "checkpoint manager exists so trying to load this run's existing checkpoint"
+    )
    latest_step = checkpoint_manager.latest_step()
    if latest_step is not None:
-      max_logging.log(f"restoring from this run's directory latest step \
-        {latest_step}")
+      max_logging.log(
+          f"restoring from this run's directory latest step \
+          {latest_step}"
+      )
 
-    def map_to_pspec(data):
+    def map_to_pspec(data):
      pspec = data.sharding.spec
      mesh = data.sharding.mesh
      if not enable_single_replica_ckpt_restoring:
-        return orbax.checkpoint.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
+        return orbax.checkpoint.type_handlers.ArrayRestoreArgs(
+            mesh=mesh, mesh_axes=pspec
+        )
      orbax.checkpoint.type_handlers.register_type_handler(
-        jax.Array,
-        orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(),
-        override=True)
+          jax.Array,
+          orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(),
+          override=True,
+      )
      orbax.checkpoint.type_handlers.register_type_handler(
-        jax.Array,
-        orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(),
-        override=True)
+          jax.Array,
+          orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(),
+          override=True,
+      )
 
      replica_axis_index = 0  # for maxtext data is the first dimension
      replica_devices = _replica_devices(mesh.devices, replica_axis_index)
      replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
-      single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)
+      single_replica_sharding = jax.sharding.NamedSharding(
+          replica_mesh, pspec
+      )
 
      return orbax.checkpoint.type_handlers.SingleReplicaArrayRestoreArgs(
-        sharding=jax.sharding.NamedSharding(mesh, pspec),
-        single_replica_sharding=single_replica_sharding,
-        replica_axis_index=replica_axis_index,
-        global_shape=data.shape,
-        dtype=data.dtype,
-      )
-
-    restore_args = jax.tree_util.tree_map(map_to_pspec,
-                                          abstract_unboxed_pre_state,
-                                          )
-    if dataset_type == 'c4-array_record' and data_iterator is not None:
-      return checkpoint_manager.restore(
-        latest_step,
-        args=orbax.checkpoint.args.Composite(
-          items=orbax.checkpoint.args.PyTreeRestore(
-            item=abstract_unboxed_pre_state,
-            restore_args=restore_args),
-          iter=grain.PyGrainCheckpointRestore(data_iterator.local_iterator))
-      ), None
+          sharding=jax.sharding.NamedSharding(mesh, pspec),
+          single_replica_sharding=single_replica_sharding,
+          replica_axis_index=replica_axis_index,
+          global_shape=data.shape,
+          dtype=data.dtype,
+      )
+
+    restore_args = jax.tree_util.tree_map(
+        map_to_pspec,
+        abstract_unboxed_pre_state,
+    )
+    if dataset_type == "c4-array_record" and data_iterator is not None:
+      return (
+          checkpoint_manager.restore(
+              latest_step,
+              args=orbax.checkpoint.args.Composite(
+                  items=orbax.checkpoint.args.PyTreeRestore(
+                      item=abstract_unboxed_pre_state,
+                      restore_args=restore_args,
+                  ),
+                  iter=grain.PyGrainCheckpointRestore(
+                      data_iterator.local_iterator
+                  ),
+              ),
+          ),
+          None,
+      )
    else:
      return (
-        checkpoint_manager.restore(
-          latest_step,
-          args=orbax.checkpoint.args.Composite(
-            items=orbax.checkpoint.args.PyTreeRestore(
-              item=abstract_unboxed_pre_state,
-              restore_args=restore_args)
+          checkpoint_manager.restore(
+              latest_step,
+              args=orbax.checkpoint.args.Composite(
+                  items=orbax.checkpoint.args.PyTreeRestore(
+                      item=abstract_unboxed_pre_state,
+                      restore_args=restore_args,
+                  )
+              ),
          ),
-        ),
-        None)
+          None,
+      )
 
  if load_parameters_from_path != "":
    max_logging.log(f"restoring params from {load_parameters_from_path=}")
@@ -181,17 +202,26 @@ def map_to_pspec(data):
    # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
    # memory, we instead specify here that we are just restoring the params field of the checkpoint
    # (which itself may be a dictionary containing a key named 'params').
-    restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(abstract_unboxed_pre_state.params)
-    restored = ckptr.restore(p, item = {'params': abstract_unboxed_pre_state.params}, transforms={},
-                             restore_args = {'params': restore_args})
-    return None, restored['params']
+    restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args(
+        abstract_unboxed_pre_state.params
+    )
+    restored = ckptr.restore(
+        p,
+        item={"params": abstract_unboxed_pre_state.params},
+        transforms={},
+        restore_args={"params": restore_args},
+    )
+    return None, restored["params"]
  elif load_full_state_from_path != "":
    max_logging.log(f"restoring full state from {load_full_state_from_path=}")
    p = epath.Path(load_full_state_from_path)
    ckptr = orbax.checkpoint.StandardCheckpointer()
-    restored = ckptr.restore(p, args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state))
-    return {'items': restored}, None
+    restored = ckptr.restore(
+        p,
+        args=orbax.checkpoint.args.StandardRestore(abstract_unboxed_pre_state),
+    )
+    return {"items": restored}, None
  else:
    max_logging.log("No existing checkpoints found, not restoring checkpoint.")
diff --git a/MaxText/common_types.py b/MaxText/common_types.py
index a2e0f389b..2961104f3 100644
--- a/MaxText/common_types.py
+++ b/MaxText/common_types.py
@@ -32,13 +32,13 @@
 
 AxisNames = tuple[str, ...]
 
-BATCH = 'activation_batch'
-LENGTH = 'activation_length'
-HEAD = 'activation_heads'
-D_KV = 'activation_kv'
-
-MODEL_MODE_AUTOREGRESSIVE = 'autoregressive'
-MODEL_MODE_PREFILL = 'prefill'
-MODEL_MODE_TRAIN = 'train'
+BATCH = "activation_batch"
+LENGTH = "activation_length"
+HEAD = "activation_heads"
+D_KV = "activation_kv"
+
+MODEL_MODE_AUTOREGRESSIVE = "autoregressive"
+MODEL_MODE_PREFILL = "prefill"
+MODEL_MODE_TRAIN = "train"
 
 DECODING_ACTIVE_SEQUENCE_INDICATOR = 1
diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py
index a3cf1fd39..3787baac4 100644
--- a/MaxText/convert_gemma_chkpt.py
+++ b/MaxText/convert_gemma_chkpt.py
@@ -1,15 +1,15 @@
 """
- Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Copyright 2023 Google LLC
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    https://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 # pylint: disable=line-too-long
 """ Convert orbax Gemma checkpoint to MaxText compatible checkpoint. """
@@ -18,7 +18,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-jax.config.update('jax_platform_name', 'cpu')
+
+jax.config.update("jax_platform_name", "cpu")
 import argparse
 import copy
 from flax.training import train_state
@@ -35,24 +36,26 @@
 
 Params = dict[str, Any]
 
+
 def nest_params(params: Params) -> Params:
  """Nests params as a dict of dicts rather than a flat dict."""
  nested_params = {}
  for path, param in params.items():
-    *path, leaf = path.split('/')
+    *path, leaf = path.split("/")
    subdict = nested_params
    for key in path:
      subdict = subdict.setdefault(key, {})
    subdict[leaf] = param
  return nested_params
 
+
 def main(raw_args=None) -> None:
  parser = argparse.ArgumentParser()
-  parser.add_argument('--base_model_path', type=str, required=True)
-  parser.add_argument('--maxtext_model_path', type=str, required=True)
-  parser.add_argument('--model_size', type=str, required=True)
+  parser.add_argument("--base_model_path", type=str, required=True)
+  parser.add_argument("--maxtext_model_path", type=str, required=True)
+  parser.add_argument("--model_size", type=str, required=True)
  args = parser.parse_args(raw_args)
-  if args.model_size not in ('2b','7b'):
+  if args.model_size not in ("2b", "7b"):
    raise NotImplementedError
 
  print("Loading checkpoint")
@@ -60,19 +63,21 @@ def main(raw_args=None) -> None:
  params = checkpointer.restore(args.base_model_path)
  params = nest_params(params)
  num_layers = (
-      max((
-          int(k.split('_')[1])
-          for k in params['transformer'].keys()
-          if 'layer_' in k
-      ))
+      max(
+          (
+              int(k.split("_")[1])
+              for k in params["transformer"].keys()
+              if "layer_" in k
+          )
+      )
      + 1
  )
-  hidden_dim, embed_dim = (
-      params['transformer']['layer_0']['mlp']['linear']['w'].shape
-  )
-  num_heads, head_dim, _ = (
-      params['transformer']['layer_0']['attn']['attn_vec_einsum']['w'].shape
-  )
+  hidden_dim, embed_dim = params["transformer"]["layer_0"]["mlp"]["linear"][
+      "w"
+  ].shape
+  num_heads, head_dim, _ = params["transformer"]["layer_0"]["attn"][
+      "attn_vec_einsum"
+  ]["w"].shape
  print("Model configurations from checkpoint")
  print(f"num_layers: {num_layers}")
  print(f"hidden_dim: {hidden_dim}")
@@ -81,109 +86,148 @@ def main(raw_args=None) -> None:
  print(f"head_dim: {head_dim}")
 
  jax_weights = {
-      'decoder': {
-          'decoder_norm': {
-              'scale': params['transformer']['final_norm']['scale'] + 1
-          },
+      "decoder": {
+          "decoder_norm": {
+              "scale": params["transformer"]["final_norm"]["scale"] + 1
+          },
+      },
+      "token_embedder": {
+          "embedding": params["transformer"]["embedder"]["input_embedding"]
+          * jnp.sqrt(embed_dim)
      },
-      'token_embedder':{
-          'embedding': params['transformer']['embedder']['input_embedding'] * jnp.sqrt(embed_dim)
-      }
-  }
  }
 
  self_attention = dict({
-      'query': {
-          'kernel' : []
-      },
-      'key': {
-          'kernel' : []
-      },
-      'value': {
-          'kernel' : []
-      },
-      'out': {
-          'kernel' : []
-      },
+      "query": {"kernel": []},
+      "key": {"kernel": []},
+      "value": {"kernel": []},
+      "out": {"kernel": []},
  })
 
  layer_weight = dict({
-      'mlp': {
-          'wi_0': {
-              'kernel' : []
-          },
-          'wi_1': {
-              'kernel' : []
-          },
-          'wo': {
-              'kernel' : []
-          },
-      },
-      'pre_self_attention_norm': {
-          'scale': []
-      },
-      'pre_ffw_norm': {
-          'scale': []
-      },
+      "mlp": {
+          "wi_0": {"kernel": []},
+          "wi_1": {"kernel": []},
+          "wo": {"kernel": []},
+      },
+      "pre_self_attention_norm": {"scale": []},
+      "pre_ffw_norm": {"scale": []},
  })
 
  for layer_idx in range(num_layers):
-    in_layer_name = 'layer_' + str(layer_idx)
+    in_layer_name = "layer_" + str(layer_idx)
 
    # attention block
-    if args.model_size == '2b': # MQA
-      self_attention['query']['kernel'].append(params['transformer'][in_layer_name]['attn']['q_einsum']['w'].transpose((1, 0, 2)) * head_dim**-0.5)
-      self_attention['key']['kernel'].append(params['transformer'][in_layer_name]['attn']['kv_einsum']['w'][0].transpose((1, 0, 2)))
-      self_attention['value']['kernel'].append(params['transformer'][in_layer_name]['attn']['kv_einsum']['w'][1].transpose((1, 0, 2)))
+    if args.model_size == "2b":  # MQA
+      self_attention["query"]["kernel"].append(
+          params["transformer"][in_layer_name]["attn"]["q_einsum"][
+              "w"
+          ].transpose((1, 0, 2))
+          * head_dim**-0.5
+      )
+      self_attention["key"]["kernel"].append(
+          params["transformer"][in_layer_name]["attn"]["kv_einsum"]["w"][
+              0
+          ].transpose((1, 0, 2))
+      )
+      self_attention["value"]["kernel"].append(
+          params["transformer"][in_layer_name]["attn"]["kv_einsum"]["w"][
+              1
+          ].transpose((1, 0, 2))
+      )
    else:
-      self_attention['query']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][0].transpose((1, 0, 2)) * head_dim**-0.5)
-      self_attention['key']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][1].transpose((1, 0, 2)))
-      self_attention['value']['kernel'].append(params['transformer'][in_layer_name]['attn']['qkv_einsum']['w'][2].transpose((1, 0, 2)))
-    self_attention['out']['kernel'].append(params['transformer'][in_layer_name]['attn']['attn_vec_einsum']['w'])
+      self_attention["query"]["kernel"].append(
+          params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][
+              0
+          ].transpose((1, 0, 2))
+          * head_dim**-0.5
+      )
+      self_attention["key"]["kernel"].append(
+          params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][
+              1
+          ].transpose((1, 0, 2))
+      )
+      self_attention["value"]["kernel"].append(
+          params["transformer"][in_layer_name]["attn"]["qkv_einsum"]["w"][
+              2
+          ].transpose((1, 0, 2))
+      )
+    self_attention["out"]["kernel"].append(
+        params["transformer"][in_layer_name]["attn"]["attn_vec_einsum"]["w"]
+    )
 
    # mlp
-    layer_weight['mlp']['wi_0']['kernel'].append(params['transformer'][in_layer_name]['mlp']['gating_einsum']['w'][0])
-    layer_weight['mlp']['wi_1']['kernel'].append(params['transformer'][in_layer_name]['mlp']['gating_einsum']['w'][1])
-    layer_weight['mlp']['wo']['kernel'].append(params['transformer'][in_layer_name]['mlp']['linear']['w'])
-    layer_weight['pre_self_attention_norm']['scale'].append(params['transformer'][in_layer_name]['pre_attention_norm']['scale'] + 1)
-    layer_weight['pre_ffw_norm']['scale'].append(params['transformer'][in_layer_name]['pre_ffw_norm']['scale'] + 1)
-
-  self_attention['query']['kernel'] = np.array(self_attention['query']['kernel']).transpose((1, 0, 2, 3))
-  self_attention['key']['kernel'] = np.array(self_attention['key']['kernel']).transpose((1, 0, 2, 3))
-  self_attention['value']['kernel'] = np.array(self_attention['value']['kernel']).transpose((1, 0, 2, 3))
-  self_attention['out']['kernel'] = np.array(self_attention['out']['kernel']).transpose((1, 0, 2, 3))
-
-  layer_weight['mlp']['wi_0']['kernel'] = np.array(layer_weight['mlp']['wi_0']['kernel']).transpose((1, 0, 2))
-  layer_weight['mlp']['wi_1']['kernel'] = np.array(layer_weight['mlp']['wi_1']['kernel']).transpose((1, 0, 2))
-  layer_weight['mlp']['wo']['kernel'] = np.array(layer_weight['mlp']['wo']['kernel']).transpose((1, 0, 2))
-  layer_weight['pre_self_attention_norm']['scale'] = np.array(layer_weight['pre_self_attention_norm']['scale']).transpose((1, 0))
-  layer_weight['pre_ffw_norm']['scale'] = np.array(layer_weight['pre_ffw_norm']['scale']).transpose((1, 0))
-
-  layer_weight['self_attention'] = copy.deepcopy(self_attention)
-  jax_weights['decoder']['layers'] = copy.deepcopy(layer_weight)
+    layer_weight["mlp"]["wi_0"]["kernel"].append(
+        params["transformer"][in_layer_name]["mlp"]["gating_einsum"]["w"][0]
+    )
+    layer_weight["mlp"]["wi_1"]["kernel"].append(
+        params["transformer"][in_layer_name]["mlp"]["gating_einsum"]["w"][1]
+    )
+    layer_weight["mlp"]["wo"]["kernel"].append(
+        params["transformer"][in_layer_name]["mlp"]["linear"]["w"]
+    )
+    layer_weight["pre_self_attention_norm"]["scale"].append(
+        params["transformer"][in_layer_name]["pre_attention_norm"]["scale"] + 1
+    )
+    layer_weight["pre_ffw_norm"]["scale"].append(
+        params["transformer"][in_layer_name]["pre_ffw_norm"]["scale"] + 1
+    )
+
+  self_attention["query"]["kernel"] = np.array(
+      self_attention["query"]["kernel"]
+  ).transpose((1, 0, 2, 3))
+  self_attention["key"]["kernel"] = np.array(
+      self_attention["key"]["kernel"]
+  ).transpose((1, 0, 2, 3))
+  self_attention["value"]["kernel"] = np.array(
+      self_attention["value"]["kernel"]
+  ).transpose((1, 0, 2, 3))
+  self_attention["out"]["kernel"] = np.array(
+      self_attention["out"]["kernel"]
+  ).transpose((1, 0, 2, 3))
+
+  layer_weight["mlp"]["wi_0"]["kernel"] = np.array(
+      layer_weight["mlp"]["wi_0"]["kernel"]
+  ).transpose((1, 0, 2))
+  layer_weight["mlp"]["wi_1"]["kernel"] = np.array(
+      layer_weight["mlp"]["wi_1"]["kernel"]
+  ).transpose((1, 0, 2))
+  layer_weight["mlp"]["wo"]["kernel"] = np.array(
+      layer_weight["mlp"]["wo"]["kernel"]
+  ).transpose((1, 0, 2))
+  layer_weight["pre_self_attention_norm"]["scale"] = np.array(
+      layer_weight["pre_self_attention_norm"]["scale"]
+  ).transpose((1, 0))
+  layer_weight["pre_ffw_norm"]["scale"] = np.array(
+      layer_weight["pre_ffw_norm"]["scale"]
+  ).transpose((1, 0))
+
+  layer_weight["self_attention"] = copy.deepcopy(self_attention)
+  jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight)
 
  jax_weights = jax.tree_map(jnp.array, jax_weights)
 
+
  def astype_fn(x):
    if isinstance(x, jnp.ndarray):
      return x.astype(jnp.bfloat16)
    else:
      return x
-  jax_weights = jax.tree_map(astype_fn, jax_weights)
 
-  enable_checkpointing=True
-  async_checkpointing=False
-  save_interval_steps=1
+  jax_weights = jax.tree_map(astype_fn, jax_weights)
+
+  enable_checkpointing = True
+  async_checkpointing = False
+  save_interval_steps = 1
 
  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
      args.maxtext_model_path,
      enable_checkpointing,
      async_checkpointing,
-      save_interval_steps
+      save_interval_steps,
  )
 
  state_new = train_state.TrainState(
-    step=0,
-    apply_fn=None,
-    params={'params': jax_weights},
-    tx=None, # type: ignore
-    opt_state={}
+      step=0,
+      apply_fn=None,
+      params={"params": jax_weights},
+      tx=None,  # type: ignore
+      opt_state={},
  )
 
  if checkpoint_manager is not None:
@@ -194,5 +238,6 @@ def astype_fn(x):
    checkpoint_manager.wait_until_finished()
    sys.exit()
 
+
 if __name__ == "__main__":
  main()
diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py
index 78d43c47b..26f10339a 100644
--- a/MaxText/convert_gpt3_ckpt_from_paxml.py
+++ b/MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -1,15 +1,15 @@
 """
- Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Copyright 2023 Google LLC
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    https://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
 # pylint: disable=line-too-long
 """Convert weights from a paxml gpt3 model to a MaxText one.
@@ -53,6 +53,7 @@
 from train import save_checkpoint
 import argparse
 
+
 def fmt_size(num_bytes: int) -> str:
  assert num_bytes > 0
  for unit in ["B", "KiB", "MiB", "GiB"]:
@@ -61,28 +62,36 @@ def fmt_size(num_bytes: int) -> str:
    num_bytes /= 1024.0
  return f"{num_bytes:.2f} {unit}"
 
+
 def check_memory():
  """print out cpu/tpu memory."""
  cpu_bytes = Process().memory_info().rss
  max_logging.log(f"cpu memory: {fmt_size(cpu_bytes)}")
  for d in jax.local_devices():
    stats = d.memory_stats()
-    used = stats['bytes_in_use']
-    limit = stats['bytes_limit']
-    max_logging.log(f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")
+    used = stats["bytes_in_use"]
+    limit = stats["bytes_limit"]
+    max_logging.log(
+        f"tpu memory: Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}"
+    )
 
 
-def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name):
+def convert(
+    paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
+):
  """convert ckpt."""
  base_args = [
-    '', 'MaxText/configs/base.yml', # base arg
-    'per_device_batch_size=1',
-    'ici_fsdp_parallelism=-1', 'ici_tensor_parallelism=1',
-    f'model_name={maxtext_model_name}',
-    f'run_name={run_name}', f'base_output_directory={base_output_directory}',
-    'checkpoint_period=1',
-    'async_checkpointing=false',
+      "",
+      "MaxText/configs/base.yml",  # base arg
+      "per_device_batch_size=1",
+      "ici_fsdp_parallelism=-1",
+      "ici_tensor_parallelism=1",
+      f"model_name={maxtext_model_name}",
+      f"run_name={run_name}",
+      f"base_output_directory={base_output_directory}",
+      "checkpoint_period=1",
+      "async_checkpointing=false",
  ]
  pyconfig.initialize(base_args)
  cfg = pyconfig.config
@@ -96,45 +105,110 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
  tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
-    cfg.checkpoint_dir,
-    cfg.enable_checkpointing,
-    cfg.async_checkpointing,
-    cfg.checkpoint_period,
+      cfg.checkpoint_dir,
+      cfg.enable_checkpointing,
+      cfg.async_checkpointing,
+      cfg.checkpoint_period,
  )
 
-  state, _, _ = max_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
+  state, _, _ = max_utils.setup_training_state(
+      model, None, tx, cfg, init_rng, mesh, checkpoint_manager
+  )
  max_logging.log("start")
  check_memory()
 
  # maxtext keystr: (paxml keystr, transform_fn)
  keystr_map = {
-      "['token_embedder']['embedding']": (".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T),
(".params.lm.softmax.logits_ffn.linear.w", lambda x: x.T), - "['decoder']['position_embedder']['embedding']": (".params.lm.position_emb.emb_var", None), - "['decoder']['layers']['pre_self_attention_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['pre_self_attention_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['query']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['query']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,0], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['key']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['key']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,1], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['value']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['value']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x[:,2], 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis])), - "['decoder']['layers']['self_attention']['qkv_proj']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['self_attention']['out']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1])), - "['decoder']['layers']['self_attention']['out']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wi']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wi']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wo']['kernel']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), - "['decoder']['layers']['mlp']['wo']['bias']": (".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", lambda x: np.moveaxis(x, 0, cfg.param_scan_axis)), 
- "['decoder']['decoder_norm']['scale']": (".params.lm.final_ln.scale", lambda x: x.T), - "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), + "['token_embedder']['embedding']": ( + ".params.lm.softmax.logits_ffn.linear.w", + lambda x: x.T, + ), + "['decoder']['position_embedder']['embedding']": ( + ".params.lm.position_emb.emb_var", + None, + ), + "['decoder']['layers']['pre_self_attention_norm']['scale']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.scale", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['pre_self_attention_norm']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.layer_norm.bias", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['query']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 0], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['query']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 0], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['key']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 1], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['key']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 1], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['value']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x[:, 2], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['value']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x[:, 2], 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['qkv_proj']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.w", + lambda x: np.moveaxis(x, [2, 0], [0, cfg.param_scan_axis]), + ), + "['decoder']['layers']['self_attention']['qkv_proj']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.combined_qkv.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['self_attention']['out']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.w", + lambda x: np.moveaxis(x, [0, 1], [cfg.param_scan_axis, -1]), + ), + "['decoder']['layers']['self_attention']['out']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.self_attention.post.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['mlp_layer_norm']['scale']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.scale", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['mlp_layer_norm']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.layer_norm.bias", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wi']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.linear.w", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wi']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer1.bias.b", + lambda x: np.moveaxis(x, 0, 
cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wo']['kernel']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.linear.w", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['layers']['mlp']['wo']['bias']": ( + ".params.lm.transformer.repeat.sub.x_layers_0.ff_layer.ffn_layer2.bias.b", + lambda x: np.moveaxis(x, 0, cfg.param_scan_axis), + ), + "['decoder']['decoder_norm']['scale']": ( + ".params.lm.final_ln.scale", + lambda x: x.T, + ), + "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), } state_map = { - ".step": ("step", None), - ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), + ".step": ("step", None), + ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), } def get_layer_prefix(keystr_pax): @@ -148,12 +222,21 @@ def get_layer_prefix(keystr_pax): for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items(): # model variable - state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) + state_map[f".params['params']{keystr_maxtext}"] = ( + f"mdl_vars{keystr_pax}", + transform_fn, + ) prefix_pax_opt_state = get_layer_prefix(keystr_pax) # first momentum in optimizer state - state_map[f".opt_state.mu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", transform_fn) + state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) # second momentum in optimizer state - state_map[f".opt_state.nu['params']{keystr_maxtext}"] = (f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", transform_fn) + state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) def verify_fn(key_path, _): keystr = jax.tree_util.keystr(key_path) @@ -161,19 +244,19 @@ def verify_fn(key_path, _): jax.tree_util.tree_map_with_path(verify_fn, state) - memory_metrics = {'max_cpu_bytes': 0} + memory_metrics = {"max_cpu_bytes": 0} - bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://"):].split('/', 1) + bucket_name, paxml_ckpt_prefix = paxml_ckpt_path[len("gs://") :].split("/", 1) def map_fn(key_path, value): key_path_str = jax.tree_util.keystr(key_path) file_path, transform_fn = state_map[key_path_str] full_path = os.path.join(paxml_ckpt_prefix, file_path) - spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}} - spec['kvstore'] = { - 'bucket': bucket_name, - 'driver': 'gcs', - 'path': full_path, + spec = {"driver": "zarr", "metadata_key": ".zarray", "kvstore": {}} + spec["kvstore"] = { + "bucket": bucket_name, + "driver": "gcs", + "path": full_path, } arr = ts.open(ts.Spec(spec), open=True).result().read().result() @@ -184,15 +267,21 @@ def map_fn(key_path, value): shape = value.shape sharding = value.sharding result = jax.make_array_from_single_device_arrays( - shape, - sharding, - [jax.device_put(np.array(arr[index]), d) - for d, index in sharding.addressable_devices_indices_map(shape).items()], + shape, + sharding, + [ + jax.device_put(np.array(arr[index]), d) + for d, index in sharding.addressable_devices_indices_map( + shape + ).items() + ], ) # log peak cpu memory cpu_bytes = Process().memory_info().rss - memory_metrics["max_cpu_bytes"] = max(cpu_bytes, memory_metrics["max_cpu_bytes"]) + memory_metrics["max_cpu_bytes"] = max( + cpu_bytes, memory_metrics["max_cpu_bytes"] + ) # collect cpu memory back asap arr = None @@ -213,21 +302,38 @@ def map_fn(key_path, value): 
checkpoint_manager.wait_until_finished() sys.exit() - max_logging.log(f"Peak cpu memory in a single process: {fmt_size(memory_metrics['max_cpu_bytes'])}") + max_logging.log( + f"Peak cpu memory in a single process: {fmt_size(memory_metrics['max_cpu_bytes'])}" + ) max_logging.log("checkpoint converted and saved successfully.") + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--paxml-ckpt-path', - type=str, - default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000", - required=True) - parser.add_argument('--maxtext-model-name', choices=['gpt3-175b', 'gpt3-52k'], type=str, required=True) - parser.add_argument('--base-output-directory', type=str, required=True) - parser.add_argument('--run-name', type=str, required=True) + parser.add_argument( + "--paxml-ckpt-path", + type=str, + default="gs://mlperf-llm-public2/gpt3_spmd1x64x24_tpuv4-3072_v84_20221101/checkpoints/checkpoint_00004000", + required=True, + ) + parser.add_argument( + "--maxtext-model-name", + choices=["gpt3-175b", "gpt3-52k"], + type=str, + required=True, + ) + parser.add_argument("--base-output-directory", type=str, required=True) + parser.add_argument("--run-name", type=str, required=True) args = parser.parse_args() if not args.paxml_ckpt_path.startswith("gs://"): - raise ValueError("--paxml-ckpt-path should be a gcs path starting with gs://") + raise ValueError( + "--paxml-ckpt-path should be a gcs path starting with gs://" + ) - convert(args.paxml_ckpt_path, args.maxtext_model_name, args.base_output_directory, args.run_name) + convert( + args.paxml_ckpt_path, + args.maxtext_model_name, + args.base_output_directory, + args.run_name, + ) diff --git a/MaxText/decode.py b/MaxText/decode.py index 15606b18e..e3a5d62ce 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
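The map_fn conversion above reads each Paxml tensor straight out of GCS with TensorStore before resharding it onto devices. A minimal, self-contained sketch of that read pattern follows; the bucket name and array path are made-up placeholders, while the real paths come from joining state_map entries onto the checkpoint prefix:

import tensorstore as ts

# Hypothetical location of one zarr-encoded tensor inside a Paxml checkpoint.
spec = {
    "driver": "zarr",
    "metadata_key": ".zarray",
    "kvstore": {
        "bucket": "my-bucket",          # placeholder bucket
        "driver": "gcs",
        "path": "ckpt/mdl_vars.params.lm.final_ln.scale",  # placeholder path
    },
}

# ts.open and read() both return futures; .result() blocks until each is done.
arr = ts.open(ts.Spec(spec), open=True).result().read().result()
print(arr.shape, arr.dtype)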
-'''CLI Utility for Running Inference on a Single Stream''' +"""CLI Utility for Running Inference on a Single Stream""" import jax @@ -32,42 +32,54 @@ def main(config): metadata = engine.get_tokenizer() vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) tokenizer = vocab.tokenizer - tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, - prefill_lengths=[config.max_prefill_predict_length]) - assert tokens.size <= config.max_prefill_predict_length, "can't take too many tokens" - assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet" + tokens, true_length = token_utils.tokenize_and_pad( + text, + vocab, + is_bos=True, + prefill_lengths=[config.max_prefill_predict_length], + ) + assert ( + tokens.size <= config.max_prefill_predict_length + ), "can't take too many tokens" + assert ( + config.quantization != "fp8" + ), "fp8 on NVIDIA GPUs is not supported in decode.py yet" prefill_result = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) - slot=0 + slot = 0 decode_state = engine.init_decode_state() - decode_state = engine.insert( - prefill_result, decode_state, slot=slot - ) + decode_state = engine.insert(prefill_result, decode_state, slot=slot) steps = range(config.max_prefill_predict_length, config.max_target_length) sampled_tokens_list = [] for _ in steps: - decode_state, sampled_tokens = engine.generate( - params, decode_state - ) + decode_state, sampled_tokens = engine.generate(params, decode_state) sampled_tokens_list.append(sampled_tokens) - results = [sampled_tokens.get_result_at_slot(slot).tokens.item() for sampled_tokens in sampled_tokens_list] + results = [ + sampled_tokens.get_result_at_slot(slot).tokens.item() + for sampled_tokens in sampled_tokens_list + ] output = tokenizer.detokenize(results) print(f"Input `{text}` -> `{output}`") if config.autoregressive_decode_assert != "": - assert output==config.autoregressive_decode_assert, \ - f"generated text mismatch {output=} {config.autoregressive_decode_assert=}" + assert ( + output == config.autoregressive_decode_assert + ), f"generated text mismatch {output=} {config.autoregressive_decode_assert=}" + def validate_config(config): - assert config.load_full_state_path == "", "Decode doesn't operate on full states! Convert to parameter checkpoint first."\ - "Using generate_param_only_checkpoint." + assert config.load_full_state_path == "", ( + "Decode doesn't operate on full states! Convert to a parameter checkpoint first " + "using generate_param_only_checkpoint." + ) + if __name__ == "__main__": - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(sys.argv) cfg = pyconfig.config diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py index cda4273ca..5a08089bd 100644 --- a/MaxText/generate_param_only_checkpoint.py +++ b/MaxText/generate_param_only_checkpoint.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state. @@ -40,35 +40,56 @@ Transformer = models.Transformer -def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): - """ If input layers are scanned, and force_unroll is set, - return modify training_state and train_state_annotations to be "unrolled". - Otherwise do nothing.""" + +def _possibly_unroll_params( + config, training_state, training_state_annotations, mesh +): + """If input layers are scanned, and force_unroll is set, + modify training_state and training_state_annotations in place to be "unrolled". + Otherwise do nothing.""" if not config.scan_layers or not config.force_unroll: return - training_state_layers = training_state.params['params']['decoder']['layers'] - training_state_annotations_layers = training_state_annotations.params['params']['decoder']['layers'] + training_state_layers = training_state.params["params"]["decoder"]["layers"] + training_state_annotations_layers = training_state_annotations.params[ + "params" + ]["decoder"]["layers"] def new_pspec(x): - return jax.sharding.PartitionSpec(*x[0:config.param_scan_axis] + x[config.param_scan_axis+1:]) + return jax.sharding.PartitionSpec( + *x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :] + ) - new_per_layer_state_annotation = jax.tree_map(new_pspec, training_state_annotations_layers) - new_per_layer_state_sharding = jax.tree_map(lambda x : jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation) + new_per_layer_state_annotation = jax.tree_map( + new_pspec, training_state_annotations_layers + ) + new_per_layer_state_sharding = jax.tree_map( + lambda x: jax.sharding.NamedSharding(mesh, x), + new_per_layer_state_annotation, + ) for i in range(config.num_decoder_layers): + def slice_ith(input_layers): - return jax.tree_map(lambda x : jax.numpy.take(x, i, axis = config.param_scan_axis), input_layers) + return jax.tree_map( + lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), + input_layers, + ) + + new_layer = jax.jit(slice_ith, out_shardings=new_per_layer_state_sharding)( + training_state_layers + ) - new_layer = jax.jit(slice_ith, out_shardings = new_per_layer_state_sharding)(training_state_layers) + training_state.params["params"]["decoder"][f"layers_{i}"] = new_layer + training_state_annotations.params["params"]["decoder"][ + f"layers_{i}" + ] = new_per_layer_state_annotation - training_state.params['params']['decoder'][f'layers_{i}'] = new_layer - training_state_annotations.params['params']['decoder'][f'layers_{i}'] = new_per_layer_state_annotation + del training_state.params["params"]["decoder"]["layers"] + del
training_state_annotations.params["params"]["decoder"]["layers"] - del training_state.params['params']['decoder']['layers'] - del training_state_annotations.params['params']['decoder']['layers'] + jax.tree_map(lambda x: x.delete(), training_state_layers) - jax.tree_map(lambda x : x.delete(), training_state_layers) def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" @@ -79,21 +100,27 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): learning_rate_schedule = max_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) state, state_mesh_notations, _ = max_utils.setup_training_state( - model, None, tx, config, rng, mesh, checkpoint_manager + model, None, tx, config, rng, mesh, checkpoint_manager ) num_params = max_utils.calculate_num_params_from_pytree(state.params) - max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") + max_logging.log( + f"Number of model params in input checkpoint={num_params/1e9:.3f} billion" + ) return state, state_mesh_notations + def _save_decode_checkpoint(config, state, checkpoint_manager): """Generate checkpoint for decode from the training_state.""" - with jax.spmd_mode('allow_all'): - decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x : x.astype(jax.numpy.bfloat16), state.params)) + with jax.spmd_mode("allow_all"): + decode_state = max_utils.init_decode_state( + None, jax.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params) + ) if checkpoint_manager is not None: if save_checkpoint(checkpoint_manager, 0, decode_state): max_logging.log(f"saved a decode checkpoint at {config.checkpoint_dir}") checkpoint_manager.wait_until_finished() + def generate_decode_checkpoint(config): """ Generate a decode checkpoint from a given training checkpoint.
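The unrolling performed by _possibly_unroll_params is easier to see on a toy pytree: a parameter stacked along the scan axis is sliced into one entry per layer. The shapes, axis, and tree below are invented for the example:

import jax
import jax.numpy as jnp

param_scan_axis = 1  # assumed scan axis for this illustration
num_layers = 4
# One stacked kernel with the layer dimension on the scan axis.
stacked = {"kernel": jnp.ones((8, num_layers, 16))}

# Slice layer i out of every leaf; i=i pins the loop variable in the lambda.
unrolled = {
    f"layers_{i}": jax.tree_map(
        lambda x, i=i: jnp.take(x, i, axis=param_scan_axis), stacked
    )
    for i in range(num_layers)
}
print(unrolled["layers_0"]["kernel"].shape)  # (8, 16): the layer axis is gone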
@@ -119,16 +146,26 @@ def generate_decode_checkpoint(config): config.checkpoint_period, ) # Read training state from config.load_full_state_path - max_logging.log(f"Read training checkpoint from: {config.load_full_state_path}") - training_state, training_state_annotations = _read_train_checkpoint(config, checkpoint_manager, mesh) - assert training_state.opt_state != {}, "missing opt_state in training checkpoint" + max_logging.log( + f"Read training checkpoint from: {config.load_full_state_path}" + ) + training_state, training_state_annotations = _read_train_checkpoint( + config, checkpoint_manager, mesh + ) + assert ( + training_state.opt_state != {} + ), "missing opt_state in training checkpoint" - _possibly_unroll_params(config, training_state, training_state_annotations, mesh) + _possibly_unroll_params( + config, training_state, training_state_annotations, mesh + ) # Save decode state to config's checkpoint directory at step 0 max_logging.log(f"Save decode checkpoint at: {config.checkpoint_dir}") _save_decode_checkpoint(config, training_state, checkpoint_manager) - max_logging.log(f"Successfully generated decode checkpoint at: {config.checkpoint_dir}0/items") + max_logging.log( + f"Successfully generated decode checkpoint at: {config.checkpoint_dir}0/items" + ) return True diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py index ec46b137f..d2b64d782 100644 --- a/MaxText/inference_microbenchmark.py +++ b/MaxText/inference_microbenchmark.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Inference microbenchmark for prefill and autoregressive steps.""" import datetime @@ -28,24 +28,41 @@ def summarize_pytree_data(params, name="Params"): - """ Generate basic metrics of a given Pytree.
""" - num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params) + """Generate basic metrics of a given Pytree.""" + ( + num_params, + total_param_size, + avg_param_size, + ) = max_utils.summarize_size_from_pytree(params) num_params_in_billions = num_params / 1e9 total_param_size_in_gb = total_param_size / 1e9 - print(f"{name} stats: \n" - f"\tTotal number of params: {num_params_in_billions:.3f} billion \n" - f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n" - f"\tAvg size: {avg_param_size:.3f} bytes\n") + print( + f"{name} stats: \n" + f"\tTotal number of params: {num_params_in_billions:.3f} billion \n" + f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n" + f"\tAvg size: {avg_param_size:.3f} bytes\n" + ) return num_params, total_param_size, avg_param_size -def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""): - """ Inner loop for benchmarking prefill step. """ +def prefill_benchmark_loop( + config, + engine, + decode_state, + params, + tokens, + true_length, + iters, + profile_name="", +): + """Inner loop for benchmarking prefill step.""" max_utils.activate_profiler(config, profile_name) start = datetime.datetime.now() for i in range(iters): slot = int(i % (jax.device_count() * config.per_device_batch_size)) - prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) + prefill_result = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) decode_state = engine.insert(prefill_result, decode_state, slot=slot) jax.block_until_ready(decode_state) end = datetime.datetime.now() @@ -53,38 +70,71 @@ def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_le return (end - start).total_seconds(), decode_state -def prefill_benchmark(config, engine, params, decode_state, tokens, true_length, - iters=100, profile_name="", num_model_params=None): - """ Handles init, warmup, running prefill benchmark, and printing results. 
""" +def prefill_benchmark( + config, + engine, + params, + decode_state, + tokens, + true_length, + iters=100, + profile_name="", + num_model_params=None, +): + """Handles init, warmup, running prefill benchmark, and printing results.""" if num_model_params is None: num_model_params, _, _ = summarize_pytree_data(params, name="Params") - prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) + prefill_result = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) decode_state = engine.insert(prefill_result, decode_state, slot=0) jax.block_until_ready(decode_state) - prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length) + prefill_result = engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) decode_state = engine.insert(prefill_result, decode_state, slot=0) jax.block_until_ready(decode_state) print(f"Prefill results for length {tokens.size}:\n") - profile_name = f"prefill_{tokens.size}" if profile_name == "" else profile_name - time_in_s, decode_state = prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, - profile_name=profile_name) + profile_name = ( + f"prefill_{tokens.size}" if profile_name == "" else profile_name + ) + time_in_s, decode_state = prefill_benchmark_loop( + config, + engine, + decode_state, + params, + tokens, + true_length, + iters, + profile_name=profile_name, + ) prefill_average_ms = 1000 * time_in_s / iters - total_prefill_tflops, _, _ = maxtext_utils.calculate_tflops_prefill(num_model_params, tokens.size, config) - tflops_per_sec_per_device = total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000. - print(f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n" - f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n" - f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n") - result_dict = {"prefill_time_in_ms": prefill_average_ms, - "prefill_total_tflops": total_prefill_tflops, - "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device} + total_prefill_tflops, _, _ = maxtext_utils.calculate_tflops_prefill( + num_model_params, tokens.size, config + ) + tflops_per_sec_per_device = ( + total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000.0 + ) + print( + f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n" + f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n" + f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n" + ) + result_dict = { + "prefill_time_in_ms": prefill_average_ms, + "prefill_total_tflops": total_prefill_tflops, + "prefill_tflops_per_sec_per_device": tflops_per_sec_per_device, + } return result_dict, decode_state -def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=""): - """ Inner loop for benchmarking ar step. """ +def ar_benchmark_loop( + config, engine, decode_state, params, iters, profile_name="" +): + """Inner loop for benchmarking ar step.""" max_utils.activate_profiler(config, profile_name) start = datetime.datetime.now() for _ in range(iters): @@ -95,10 +145,21 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name= return (end - start).total_seconds(), decode_state -def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100): - """ Handles init, warmup, running ar benchmark, and printing results. 
""" +def ar_benchmark( + config, + engine, + params, + decode_state, + cache_size=None, + model_size=None, + profile_name="", + iters=100, +): + """Handles init, warmup, running ar benchmark, and printing results.""" if cache_size is None: - _, cache_size, _ = summarize_pytree_data(decode_state['cache'], name="Cache") + _, cache_size, _ = summarize_pytree_data( + decode_state["cache"], name="Cache" + ) if model_size is None: _, model_size, _ = summarize_pytree_data(params, name="Params") global_batch_size = jax.device_count() * config.per_device_batch_size @@ -110,40 +171,56 @@ def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_si jax.block_until_ready(decode_state) profile_name = "autoregress" if profile_name == "" else profile_name - time_in_s, decode_state = ar_benchmark_loop(config, engine, decode_state, params, profile_name=profile_name, iters=iters) + time_in_s, decode_state = ar_benchmark_loop( + config, + engine, + decode_state, + params, + profile_name=profile_name, + iters=iters, + ) seconds_per_step = time_in_s / iters - ar_average_ms = seconds_per_step*1000 - total_throughput = jax.device_count() * config.per_device_batch_size / seconds_per_step + ar_average_ms = seconds_per_step * 1000 + total_throughput = ( + jax.device_count() * config.per_device_batch_size / seconds_per_step + ) GB_per_step_per_device = (model_size + cache_size) / 1e9 / jax.device_count() - bw_per_device = GB_per_step_per_device/seconds_per_step - print(f"AutoRegressive results:\n" - f"\tAR step average time: {ar_average_ms:.3f}ms\n" - f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n" - f"\tAR global batch size: {global_batch_size}\n" - f"\tAR throughput: {total_throughput:.3f} tokens/second\n" - f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n") - - - result_dict = {"ar_step_in_ms": ar_average_ms, - "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size, - "ar_global_batch_size": global_batch_size, - "ar_total_throughput_tokens_per_second": total_throughput, - "ar_device_bandwidth_GB_per_second": bw_per_device} + bw_per_device = GB_per_step_per_device / seconds_per_step + print( + f"AutoRegressive results:\n" + f"\tAR step average time: {ar_average_ms:.3f}ms\n" + f"\tAR step average time per seq: {ar_average_ms/global_batch_size:.3f}ms\n" + f"\tAR global batch size: {global_batch_size}\n" + f"\tAR throughput: {total_throughput:.3f} tokens/second\n" + f"\tAR memory bandwidth per device: {bw_per_device:.3f} GB/s\n\n\n" + ) + + result_dict = { + "ar_step_in_ms": ar_average_ms, + "ar_step_in_ms_per_seq": ar_average_ms / global_batch_size, + "ar_global_batch_size": global_batch_size, + "ar_total_throughput_tokens_per_second": total_throughput, + "ar_device_bandwidth_GB_per_second": bw_per_device, + } return result_dict, decode_state -def collate_results(config, results, model_size, cache_size, num_model_params, incl_config=False): - """ Adds model/cache size info and optionally config info to results. 
""" +def collate_results( + config, results, model_size, cache_size, num_model_params, incl_config=False +): + """Adds model/cache size info and optionally config info to results.""" results["sizes"] = { - "Model_size_in_GB": model_size / 1e9, - "cache_size_in_GB": cache_size / 1e9, - "model_params_in_billions": num_model_params / 1e9, + "Model_size_in_GB": model_size / 1e9, + "cache_size_in_GB": cache_size / 1e9, + "model_params_in_billions": num_model_params / 1e9, } if incl_config: results["config"] = {} for k, v in dict(config.get_keys()).items(): - results["config"][k] = str(v) if k == "dtype" else v # json fails with original dtype + results["config"][k] = ( + str(v) if k == "dtype" else v + ) # json fails with original dtype return results @@ -159,7 +236,9 @@ def print_results_for_analyze(results): prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3) print("\nFor usage in analyze_sharegpt.py :") print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}") - print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}") + print( + f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}" + ) def main(config): @@ -172,20 +251,40 @@ def main(config): vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids) decode_state = engine.init_decode_state() - _, cache_size, _ = summarize_pytree_data(decode_state['cache'], name="Cache") + _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache") num_model_params, model_size, _ = summarize_pytree_data(params, name="Model") benchmark_results = {"Prefill": {}} benchmark_results["AutoRegressive"], decode_state = ar_benchmark( - config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size) + config, + engine, + params, + decode_state, + iters=benchmark_loop_iters, + cache_size=cache_size, + model_size=model_size, + ) for prefill_length in prefill_lengths: tokens, true_length = token_utils.tokenize_and_pad( - text, vocab, is_bos=True, prefill_lengths=[prefill_length]) - benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark( - config, engine, params, decode_state, tokens, true_length, - iters=benchmark_loop_iters, num_model_params=num_model_params) - - results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params) + text, vocab, is_bos=True, prefill_lengths=[prefill_length] + ) + ( + benchmark_results["Prefill"][prefill_length], + decode_state, + ) = prefill_benchmark( + config, + engine, + params, + decode_state, + tokens, + true_length, + iters=benchmark_loop_iters, + num_model_params=num_model_params, + ) + + results = collate_results( + config, benchmark_results, model_size, cache_size, num_model_params + ) write_results(results, filename="") print_results_for_analyze(results) diff --git a/MaxText/inference_scratch/analyze_sharegpt.py b/MaxText/inference_scratch/analyze_sharegpt.py index 47d8715cf..efc95009e 100644 --- a/MaxText/inference_scratch/analyze_sharegpt.py +++ b/MaxText/inference_scratch/analyze_sharegpt.py @@ -15,18 +15,27 @@ import argparse import json -PREFILL_BUCKET_SIZE_TO_MS = {64: 9.174, 128: 11.087, 256: 18.468, 512: 29.128, 1024: 58.386} +PREFILL_BUCKET_SIZE_TO_MS = { + 64: 9.174, + 128: 11.087, + 256: 18.468, + 512: 29.128, + 1024: 58.386, +} SYSTEM_TIME_PER_DECODE_TOKEN_MS = 0.32591875 MAX_INPUT_TOKENS = 1024 MAX_OUTPUT_TOKENS = 1024 + def next_power_of_2(x): - return 1 if x == 0 else 2**(x - 1).bit_length() 
+ return 1 if x == 0 else 2 ** (x - 1).bit_length() + def tokens_in_input_str(s): - return_val = int(1.3 * len(s.split())) + return_val = int(1.3 * len(s.split())) return return_val + def get_prefill_and_generate_times(filename=""): if filename == "": return PREFILL_BUCKET_SIZE_TO_MS, SYSTEM_TIME_PER_DECODE_TOKEN_MS @@ -37,31 +46,50 @@ def get_prefill_and_generate_times(filename=""): for k, v in microbenchmark_results["Prefill"].items(): prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3) - return prefill_bucket_size_to_ms, microbenchmark_results['AutoRegressive']['ar_step_in_ms_per_seq'] + return ( + prefill_bucket_size_to_ms, + microbenchmark_results["AutoRegressive"]["ar_step_in_ms_per_seq"], + ) + def get_conversations_from_file(filename, max_input_tokens, max_output_tokens): convo_token_numbers = [] - with open(filename, 'r') as f: + with open(filename, "r") as f: loaded_share_gpt = json.load(f) for example in loaded_share_gpt: - if len(example['conversations']) < 2: + if len(example["conversations"]) < 2: continue - num_input_tokens = tokens_in_input_str(example['conversations'][0]['value']) - num_output_tokens = tokens_in_input_str(example['conversations'][1]['value']) + num_input_tokens = tokens_in_input_str(example["conversations"][0]["value"]) + num_output_tokens = tokens_in_input_str( + example["conversations"][1]["value"] + ) convo_token_numbers.append((num_input_tokens, num_output_tokens)) num_convos = len(convo_token_numbers) - kept_convos = [c for c in convo_token_numbers if c[0] <= max_input_tokens and c[1] <= max_output_tokens] + kept_convos = [ + c + for c in convo_token_numbers + if c[0] <= max_input_tokens and c[1] <= max_output_tokens + ] mean_input = sum(c[0] for c in kept_convos) / len(kept_convos) mean_output = sum(c[1] for c in kept_convos) / len(kept_convos) - print(f"Kept {len(kept_convos)} of {num_convos} total convos. {len(100*kept_convos)/num_convos:.3f}%") - print(f"Out of kept convos, mean input tokens: {mean_input:.3f}, mean output tokens: {mean_output:.3f}") + print( + f"Kept {len(kept_convos)} of {num_convos} total convos. 
{100 * len(kept_convos) / num_convos:.3f}%" + ) + print( + f"Out of kept convos, mean input tokens: {mean_input:.3f}, mean output tokens: {mean_output:.3f}" + ) return kept_convos -def compute_times(convos, prefill_bucket_size_to_ms, system_time_per_decode_token_ms, verbose=False): +def compute_times( + convos, + prefill_bucket_size_to_ms, + system_time_per_decode_token_ms, + verbose=False, +): total_prefill_system_ms = 0 total_generate_system_ms = 0 for convo in convos: @@ -72,15 +100,19 @@ def compute_times(convos, prefill_bucket_size_to_ms, system_time_per_decode_toke total_prefill_system_ms += prefill_system_ms total_generate_system_ms += generate_system_ms if verbose: - print(f"{convo} {bucket}, {prefill_system_ms:.2f}, {generate_system_ms:.2f}") + print( + f"{convo} {bucket}, {prefill_system_ms:.2f}, {generate_system_ms:.2f}" + ) total_prefill_time_seconds = total_prefill_system_ms / 1000 total_generate_time_seconds = total_generate_system_ms / 1000 total_time_s = total_prefill_time_seconds + total_generate_time_seconds - print(f"\nTotal time {total_time_s:.3f} seconds: " - f"\n\tPrefill time: {total_prefill_time_seconds:.3f} seconds" - f"\n\tGenerate time: {total_generate_time_seconds:.3f} seconds") + print( + f"\nTotal time {total_time_s:.3f} seconds: " + f"\n\tPrefill time: {total_prefill_time_seconds:.3f} seconds" + f"\n\tGenerate time: {total_generate_time_seconds:.3f} seconds" + ) return total_time_s, total_prefill_time_seconds, total_generate_time_seconds @@ -92,17 +124,31 @@ def get_num_tokens_in_convos(convos): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('convo_file', type=str, - help='a json file containing conversations') - parser.add_argument('-t', '--mb_timing_file', type=str, default="", - help='a json file containing microbenchmark timing results') - parser.add_argument('-v', '--verbose', action="store_true") + parser.add_argument( + "convo_file", type=str, help="a json file containing conversations" + ) + parser.add_argument( + "-t", + "--mb_timing_file", + type=str, + default="", + help="a json file containing microbenchmark timing results", + ) + parser.add_argument("-v", "--verbose", action="store_true") args = parser.parse_args() - convos = get_conversations_from_file(args.convo_file, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS) + convos = get_conversations_from_file( + args.convo_file, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS + ) total_input_tokens, total_output_tokens = get_num_tokens_in_convos(convos) - prefill_time_ms_buckets, generate_time_ms = get_prefill_and_generate_times(filename=args.mb_timing_file) - total_time_seconds, _, _ = compute_times(convos, prefill_time_ms_buckets, generate_time_ms, args.verbose) - - print(f"Output {total_output_tokens} tokens in {total_time_seconds:.3f} seconds " - f"= {total_output_tokens/total_time_seconds:.3f} out tok/s") + prefill_time_ms_buckets, generate_time_ms = get_prefill_and_generate_times( + filename=args.mb_timing_file + ) + total_time_seconds, _, _ = compute_times( + convos, prefill_time_ms_buckets, generate_time_ms, args.verbose + ) + + print( + f"Output {total_output_tokens} tokens in {total_time_seconds:.3f} seconds " + f"= {total_output_tokens/total_time_seconds:.3f} out tok/s" + ) diff --git a/MaxText/inference_utils.py b/MaxText/inference_utils.py index 786ecaae7..4d0116422 100644 --- a/MaxText/inference_utils.py +++ b/MaxText/inference_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax import jax.numpy as jnp @@ -26,14 +26,15 @@ Inspired by a Google-internal implementation, Global Vision Transformer. """ + def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): - """ - logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], before logit - rng: rng key to use - algorithm: string representing supported algorithms - topk: restricting to topk logits before sampling - nucleus_topp: restricting to p probability mass before sampling - temperature: temperature parameter for scaling probability + """ + logits: unnormalized logits to sample, shaped [YOUR_LEADING_DIMS, Vocab], before the softmax + rng: rng key to use + algorithm: string representing supported algorithms + topk: restricting to topk logits before sampling + nucleus_topp: restricting to p probability mass before sampling + temperature: temperature parameter for scaling probability """ if algorithm == "greedy": return jnp.argmax(logits, axis=-1) @@ -46,36 +47,44 @@ def sampling(logits, rng, algorithm, topk=0, nucleus_topp=0, temperature=1.0): else: raise ValueError(f"Sampling {algorithm=} not supported!") + def sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng): """Restrict sampling to the top logits with cumulative probability >= nucleus_topp.
- - The nucleus sampling method is proposed in the paper `The Curious Case of - Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` - + + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + """ if nucleus_topp < 0: - raise ValueError("Can't apply nucleus with parameter {nucleus_topp=} less zero") + raise ValueError( + f"Can't apply nucleus sampling with parameter {nucleus_topp=} less than zero" + ) logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending sorted_cum_probs = jnp.cumsum( - jax.nn.softmax(logits_sorted, axis=-1), axis=-1) # get cumsum probs + jax.nn.softmax(logits_sorted, axis=-1), axis=-1 + ) # get cumsum probs cutoff_index = jnp.sum( - sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True) # find cutoff index + sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True + ) # find cutoff index cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) - logits = jnp.where(logits < cutoff_logit, - jnp.full_like(logits, NEG_INF), logits) + logits = jnp.where( + logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits + ) return jax.random.categorical(rng, logits / temperature) + def sample_topk_logits(logits, topk, temperature, rng): - """ Restricting sampling to the best k logits. """ + """Restrict sampling to the best k logits.""" if topk <= 0: - raise ValueError("Can't apply algorithm topk with parameter {topk=} less than or equal to zero") + raise ValueError( + f"Can't apply algorithm topk with parameter {topk=} less than or equal to zero" + ) topk_logits, topk_idxs = jax.lax.top_k(logits, topk) topk_token = jnp.expand_dims( - jax.random.categorical(rng, topk_logits/temperature).astype(jnp.int32), - axis=-1) + jax.random.categorical(rng, topk_logits / temperature).astype(jnp.int32), + axis=-1, + ) sampled_tokens = jnp.squeeze( - jnp.take_along_axis(topk_idxs, topk_token, axis=-1), - axis=-1).astype(jnp.int32) + jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1 + ).astype(jnp.int32) return sampled_tokens - - diff --git a/MaxText/input_pipeline/_grain_data_processing.py b/MaxText/input_pipeline/_grain_data_processing.py index 92c90854d..a9d04c9f7 100644 --- a/MaxText/input_pipeline/_grain_data_processing.py +++ b/MaxText/input_pipeline/_grain_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+""" """Input pipeline using Grain.""" @@ -30,31 +30,41 @@ import multihost_dataloading -def get_datasets( - config: ml_collections.ConfigDict -): + +def get_datasets(config: ml_collections.ConfigDict): """Load dataset from array_record files for using with grain""" data_dir = os.path.join(config.dataset_path, config.dataset_name) - train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)] + train_files = [ + data_dir + "/" + f + for f in os.listdir(data_dir) + if re.match(r".*train.*", f) + ] train_ds = grain.ArrayRecordDataSource(train_files) if config.eval_dataset_name: - eval_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(rf'.*{config.eval_split}.*', f)] + eval_files = [ + data_dir + "/" + f + for f in os.listdir(data_dir) + if re.match(rf".*{config.eval_split}.*", f) + ] eval_ds = grain.ArrayRecordDataSource(eval_files) else: eval_ds = train_ds return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, - global_mesh, - train_ds, eval_ds, - vocab_path: Optional[str] = None, - data_shuffle_seed = 0, - add_bos = True, - add_eos = True - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, + global_mesh, + train_ds, + eval_ds, + vocab_path: Optional[str] = None, + data_shuffle_seed=0, + add_bos=True, + add_eos=True, +): """Use grain to pre-process the dataset and return iterators""" # Set global batch size. global_batch_size_to_load = config.global_batch_size_to_load @@ -78,7 +88,8 @@ def preprocess_dataset(config: ml_collections.ConfigDict, num_epochs=1, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) eval_iter = preprocessing_pipeline( eval_ds, @@ -93,7 +104,8 @@ def preprocess_dataset(config: ml_collections.ConfigDict, shuffle=config.enable_data_shuffling, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) predict_iter = preprocessing_pipeline( eval_ds, @@ -108,72 +120,88 @@ def preprocess_dataset(config: ml_collections.ConfigDict, shuffle=config.enable_data_shuffling, pack_examples=True, max_length=config.max_target_length, - data_shuffle_seed=data_shuffle_seed) + data_shuffle_seed=data_shuffle_seed, + ) return train_iter, eval_iter, predict_iter + def preprocessing_pipeline( - dataset, - vocab_path, - add_bos: bool, - add_eos: bool, - grain_worker_count: int, - batch_size: int, - global_mesh, - dataloading_host_index, - dataloading_host_count, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - max_length: int = 512, - shift: bool = True, - drop_remainder: bool = True, - data_shuffle_seed = 0, + dataset, + vocab_path, + add_bos: bool, + add_eos: bool, + grain_worker_count: int, + batch_size: int, + global_mesh, + dataloading_host_index, + dataloading_host_count, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + data_shuffle_seed=0, ): """Apply grain operations to preprocess the given dataset.""" assert ( - batch_size % global_mesh.size == 0 - ), 'Batch size should be divisible number of global devices.' + batch_size % global_mesh.size == 0 + ), "Batch size should be divisible number of global devices." 
operations = [] operations.append(_grain_operations.ParseFeatures()) operations.append(_grain_operations.NormalizeFeatures()) - operations.append(_grain_tokenizer.TokenizeAndTrim(["inputs","targets"], - max_length, vocab_path, - add_bos, add_eos)) + operations.append( + _grain_tokenizer.TokenizeAndTrim( + ["inputs", "targets"], max_length, vocab_path, add_bos, add_eos + ) + ) # Pack and Batch examples. if pack_examples: - operations.append(grain.experimental.PackAndBatchOperation( - batch_size=batch_size // jax.process_count(), - length_struct={'inputs':max_length,'targets':max_length})) + operations.append( + grain.experimental.PackAndBatchOperation( + batch_size=batch_size // jax.process_count(), + length_struct={"inputs": max_length, "targets": max_length}, + ) + ) operations.append(_grain_operations.ReformatPacking()) else: operations.append(_grain_operations.PadToMaxLength(max_length)) - operations.append(grain.Batch(batch_size=batch_size // jax.process_count(), drop_remainder=drop_remainder)) + operations.append( + grain.Batch( + batch_size=batch_size // jax.process_count(), + drop_remainder=drop_remainder, + ) + ) # Shift inputs for teacher-forced training if shift: operations.append(_grain_operations.ShiftData(axis=1)) index_sampler = grain.IndexSampler( - num_records=len(dataset), - num_epochs = num_epochs, - shard_options=grain.ShardOptions( - shard_index = dataloading_host_index, shard_count = dataloading_host_count, drop_remainder = True - ), - shuffle = shuffle, - seed = data_shuffle_seed + num_records=len(dataset), + num_epochs=num_epochs, + shard_options=grain.ShardOptions( + shard_index=dataloading_host_index, + shard_count=dataloading_host_count, + drop_remainder=True, + ), + shuffle=shuffle, + seed=data_shuffle_seed, ) dataloader = grain.DataLoader( - data_source = dataset, - operations = operations, - sampler = index_sampler, + data_source=dataset, + operations=operations, + sampler=index_sampler, worker_count=grain_worker_count, ) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataloader, global_mesh) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + dataloader, global_mesh + ) # Return multi-host jax.Array prep iterator return multihost_gen diff --git a/MaxText/input_pipeline/_grain_operations.py b/MaxText/input_pipeline/_grain_operations.py index 25508376b..3d48f7765 100644 --- a/MaxText/input_pipeline/_grain_operations.py +++ b/MaxText/input_pipeline/_grain_operations.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" """Operations used by Grain""" @@ -21,60 +21,73 @@ import grain.python as grain import numpy as np import tensorflow as tf + Features = Dict[str, tf.Tensor] + @dataclasses.dataclass class ParseFeatures(grain.MapTransform): """Parse serialized example""" + def map(self, features): def _parse(example): parsed = tf.io.parse_example( - example, { - 'text': tf.io.FixedLenFeature(shape=(), dtype=tf.string) - }) + example, {"text": tf.io.FixedLenFeature(shape=(), dtype=tf.string)} + ) return parsed + return _parse(features) @dataclasses.dataclass class NormalizeFeatures(grain.MapTransform): """Normalize text feature keys.""" + def map(self, features): return { - 'inputs':features['text'].numpy().decode(), - 'targets': features['text'].numpy().decode() + "inputs": features["text"].numpy().decode(), + "targets": features["text"].numpy().decode(), } @dataclasses.dataclass class ReformatPacking(grain.MapTransform): """Reformat packing outputs.""" + def map(self, data): - return{ - 'inputs':data[0]['inputs'], - 'targets':data[0]['targets'], - 'inputs_segmentation':data[1]['inputs'], - 'targets_segmentation':data[1]['targets'], - 'inputs_position':data[2]['inputs'], - 'targets_position':data[2]['targets'], + return { + "inputs": data[0]["inputs"], + "targets": data[0]["targets"], + "inputs_segmentation": data[1]["inputs"], + "targets_segmentation": data[1]["targets"], + "inputs_position": data[2]["inputs"], + "targets_position": data[2]["targets"], } @dataclasses.dataclass class PadToMaxLength(grain.MapTransform): - """Pads each input to the specified length""" + """Pads each input to the specified length""" + def __init__(self, max_length): self.max_length = max_length + def map(self, data): """map to each element""" + def _pad(x, max_length): pad_amount = max(max_length - x.shape[0], 0) pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1) return np.pad(x, pad_amount) - data['inputs_segmentation'] = np.ones(data['inputs'].shape, dtype = np.int32) - data['inputs_position'] = np.arange(data['inputs'].shape[0], dtype = np.int32) - data['targets_segmentation'] = np.ones(data['targets'].shape, dtype = np.int32) - data['targets_position'] = np.arange(data['targets'].shape[0], dtype = np.int32) + + data["inputs_segmentation"] = np.ones(data["inputs"].shape, dtype=np.int32) + data["inputs_position"] = np.arange(data["inputs"].shape[0], dtype=np.int32) + data["targets_segmentation"] = np.ones( + data["targets"].shape, dtype=np.int32 + ) + data["targets_position"] = np.arange( + data["targets"].shape[0], dtype=np.int32 + ) for key, _ in data.items(): data[key] = _pad(data[key], self.max_length) return data @@ -84,33 +97,38 @@ def shift_right(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - slices = [slice(None),] * len(x.shape) + slices = [ + slice(None), + ] * len(x.shape) slices[axis] = slice(0, -1) padded = np.pad( - x, - pad_widths, - mode='constant', - constant_values=x.dtype.type(0) - ) + x, pad_widths, mode="constant", constant_values=x.dtype.type(0) + ) return padded[tuple(slices)] + def shift_and_refine(x, axis=1): """Shift inputs, set segmentation to 0 when target element is 0. 
Replace EOS by 0 for packed inputs.""" - x['inputs'] = shift_right(x['inputs'], axis=axis) - targets_nonzero = x['targets'] != 0 - x['inputs_segmentation'] *= targets_nonzero - x['targets_segmentation'] *= targets_nonzero + x["inputs"] = shift_right(x["inputs"], axis=axis) + targets_nonzero = x["targets"] != 0 + x["inputs_segmentation"] *= targets_nonzero + x["targets_segmentation"] *= targets_nonzero # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. - x['inputs'] *= x['inputs_segmentation'] == shift_right(x['inputs_segmentation'], axis=axis) + x["inputs"] *= x["inputs_segmentation"] == shift_right( + x["inputs_segmentation"], axis=axis + ) return x + @dataclasses.dataclass class ShiftData(grain.MapTransform): """Shift inputs and refine annotations.""" - def __init__(self, axis = 1): + + def __init__(self, axis=1): self.axis = axis + def map(self, data): return shift_and_refine(data, axis=self.axis) diff --git a/MaxText/input_pipeline/_grain_tokenizer.py b/MaxText/input_pipeline/_grain_tokenizer.py index e9436799a..94a2fe22f 100644 --- a/MaxText/input_pipeline/_grain_tokenizer.py +++ b/MaxText/input_pipeline/_grain_tokenizer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Tokenize Op used by Grain""" @@ -24,9 +24,11 @@ import grain.python as grain import numpy as np + @dataclasses.dataclass class TokenizeAndTrim(grain.MapTransform): """Tokenize and trim features to sequence length.""" + # pylint: disable=attribute-defined-outside-init feature_names: str | Sequence[str] sequence_length: int | Sequence[int] @@ -58,7 +60,7 @@ def map(self, features: dict[str, Any]) -> dict[str, Any]: token_ids = [self._processor.bos_id()] + token_ids if self.add_eos: - token_ids = token_ids[:sequence_length-1] + token_ids = token_ids[: sequence_length - 1] token_ids = token_ids + [self._processor.eos_id()] else: token_ids = token_ids[:sequence_length] diff --git a/MaxText/input_pipeline/_tfds_data_processing.py b/MaxText/input_pipeline/_tfds_data_processing.py index 506e7bd87..ec64eb54a 100644 --- a/MaxText/input_pipeline/_tfds_data_processing.py +++ b/MaxText/input_pipeline/_tfds_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline for a LM1B dataset.""" @@ -34,17 +34,21 @@ # Right-shifting token inputs for teacher-forced training. # ----------------------------------------------------------------------------- + def shift_right_tf(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) - slices = [slice(None),] * len(x.shape) + slices = [ + slice(None), + ] * len(x.shape) slices[axis] = slice(0, -1) padded = tf.pad( x, tf.constant(pad_widths), - mode='constant', - constant_values=tf.constant(0, x.dtype)) + mode="constant", + constant_values=tf.constant(0, x.dtype), + ) return padded[tuple(slices)] @@ -59,41 +63,42 @@ def shift_inputs_tf(x, segment_ids=None, axis=1): ) return shifted + def shift_data(x, axis=0, segmented=True): - segment_ids = x['inputs_segmentation'] if segmented else None - x['inputs'] = shift_inputs_tf(x['inputs'], segment_ids=segment_ids, axis=axis) + segment_ids = x["inputs_segmentation"] if segmented else None + x["inputs"] = shift_inputs_tf(x["inputs"], segment_ids=segment_ids, axis=axis) return x + def shift_data_by_truncation(x): - x['inputs'] = x['inputs'][:-1] - x['targets'] = x['targets'][1:] + x["inputs"] = x["inputs"][:-1] + x["targets"] = x["targets"][1:] return x def normalize_features(ds): """Normalize text feature keys.""" + def _normalize_features(features): - features['inputs'] = features.pop('text') - features['targets'] = features['inputs'] + features["inputs"] = features.pop("text") + features["targets"] = features["inputs"] return features - return ds.map( - _normalize_features, - num_parallel_calls=AUTOTUNE) + return ds.map(_normalize_features, num_parallel_calls=AUTOTUNE) + def length_trim(ds, max_len): - """"Trim to Max length""" + """Trim to max length.""" + def _trim_fn(features): - if tf.shape(features['inputs'])[0] > max_len: - features['inputs'] = features['inputs'][:max_len] - if tf.shape(features['targets'])[0] > max_len: - features['targets'] = features['targets'][:max_len] + if tf.shape(features["inputs"])[0] > max_len: + features["inputs"] = features["inputs"][:max_len] + if tf.shape(features["targets"])[0] > max_len: + features["targets"] = features["targets"][:max_len] return features - return ds.map( - _trim_fn, - num_parallel_calls=AUTOTUNE - ) + return ds.map(_trim_fn, num_parallel_calls=AUTOTUNE) + # ----------------------------------------------------------------------------- # Main dataset preparation.
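The right shift used for teacher forcing (shift_right in _grain_operations.py above, shift_right_tf here) is easiest to see on a tiny array; the token values below are invented, and the helper mirrors the document's own function in NumPy:

import numpy as np

def shift_right(x, axis=1):
    """Pad one zero at the front of `axis` and drop the last element."""
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[axis] = (1, 0)
    slices = [slice(None)] * len(x.shape)
    slices[axis] = slice(0, -1)
    return np.pad(x, pad_widths, mode="constant")[tuple(slices)]

tokens = np.array([[11, 12, 13, 14]])  # one sequence of made-up token ids
print(shift_right(tokens))  # [[ 0 11 12 13]]: position t sees inputs up to t-1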
@@ -101,88 +106,98 @@ def _trim_fn(features): def preprocessing_pipeline( - dataset, - batch_size: int, - global_mesh, - shuffle: bool, - num_epochs: Optional[int] = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - shift: bool = True, - drop_remainder: bool = True, - prefetch_size = tf.data.experimental.AUTOTUNE, - data_shuffle_seed = 0, + dataset, + batch_size: int, + global_mesh, + shuffle: bool, + num_epochs: Optional[int] = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + shift: bool = True, + drop_remainder: bool = True, + prefetch_size=tf.data.experimental.AUTOTUNE, + data_shuffle_seed=0, ): """Shuffle and batch/pack the given dataset.""" def truncate_to_max_allowable_length(x, max_length): - x['inputs'] = x['inputs'][:max_length] - x['targets'] = x['targets'][:max_length] + x["inputs"] = x["inputs"][:max_length] + x["targets"] = x["targets"][:max_length] return x - if max_length > 0: # We can take up to max_length+1 because there would be truncation by 1 token # for both inputs and targets - dataset = dataset.map(lambda x: truncate_to_max_allowable_length(x, max_length+1)) + dataset = dataset.map( + lambda x: truncate_to_max_allowable_length(x, max_length + 1) + ) # Shuffle and repeat. if shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, seed = data_shuffle_seed) + dataset = dataset.shuffle(shuffle_buffer_size, seed=data_shuffle_seed) dataset = dataset.repeat(num_epochs) - # Shift inputs for teacher-forced training if shift: dataset = dataset.map( - shift_data_by_truncation, - num_parallel_calls=tf.data.AUTOTUNE, - deterministic=True) + shift_data_by_truncation, + num_parallel_calls=tf.data.AUTOTUNE, + deterministic=True, + ) # Perform greedy sequence packing if pack_examples: dataset = sequence_packing.pack_dataset(dataset, max_length) assert ( - batch_size % global_mesh.size == 0 - ), 'Batch size should be divisible number of global devices.' + batch_size % global_mesh.size == 0 + ), "Batch size should be divisible by the number of global devices." # Batch examples. if pack_examples: - dataset = dataset.batch(batch_size // jax.process_count(), drop_remainder=drop_remainder) + dataset = dataset.batch( + batch_size // jax.process_count(), drop_remainder=drop_remainder + ) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size // jax.process_count(), - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=drop_remainder) + padded_shapes={"inputs": max_length, "targets": max_length}, + padding_values={"inputs": 0, "targets": 0}, + drop_remainder=drop_remainder, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) - multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, global_mesh) + multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + dataset, global_mesh + ) # Return multi-host jax.Array prep iterator return multihost_gen def get_datasets( - config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, - read_config = None, + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, + read_config=None, ): """Load and return dataset of batched examples for use during training.""" # Training dataset.
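Aside: the batching arithmetic enforced in preprocessing_pipeline, as a minimal sketch with hypothetical values; the real numbers come from config and the JAX runtime.

# Hypothetical stand-ins for batch_size, global_mesh.size, jax.process_count().
global_batch_size = 512
global_device_count = 64
process_count = 8
# The assert above requires an even split across all global devices...
assert global_batch_size % global_device_count == 0
# ...and each host then batches only its local share of the examples.
per_host_batch = global_batch_size // process_count  # 64 examples per host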
train_ds_builder = tfds.builder(config.dataset_name) # train_data = get_raw_dataset(train_ds_builder, 'train') - train_ds = train_ds_builder.as_dataset(split='train', - read_config = read_config, - shuffle_files=config.enable_data_shuffling) + train_ds = train_ds_builder.as_dataset( + split="train", + read_config=read_config, + shuffle_files=config.enable_data_shuffling, + ) # shard the dataset as soon as it is loaded - train_ds = train_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) + train_ds = train_ds.shard( + num_shards=dataloading_host_count, index=dataloading_host_index + ) train_ds = normalize_features(train_ds) # Evaluation dataset. @@ -191,25 +206,33 @@ def get_datasets( else: eval_ds_builder = train_ds_builder # eval_data = get_raw_dataset(eval_ds_builder, config.eval_split) - eval_ds = eval_ds_builder.as_dataset(split=config.eval_split, - read_config = read_config, - shuffle_files=False) - eval_ds = eval_ds.shard(num_shards = jax.process_count(), index = jax.process_index()) + eval_ds = eval_ds_builder.as_dataset( + split=config.eval_split, read_config=read_config, shuffle_files=False + ) + eval_ds = eval_ds.shard( + num_shards=jax.process_count(), index=jax.process_index() + ) eval_ds = normalize_features(eval_ds) return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - global_mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed = 0, - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + global_mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=0, +): """Pre-process the dataset and return iterators""" # Tokenize data. train_ds = train_ds.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) eval_ds = eval_ds.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) # Set global batch size. 
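Aside: get_datasets shards the stream immediately after loading; a minimal runnable sketch of the same tf.data pattern, with a toy dataset and hypothetical host counts.

import tensorflow as tf

ds = tf.data.Dataset.range(8)  # toy stand-in for the loaded TFDS dataset
# Each data-loading host keeps every k-th record, so together the hosts
# cover the dataset exactly once with no overlap.
dataloading_host_count, dataloading_host_index = 2, 0
local_ds = ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
print(list(local_ds.as_numpy_iterator()))  # [0, 2, 4, 6]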
global_batch_size_to_load = config.global_batch_size_to_load @@ -220,9 +243,10 @@ def preprocess_dataset(config: ml_collections.ConfigDict, eval_batch_size = global_batch_size_to_load def filter_keys(record): - return {'inputs': record['inputs'], 'targets': record['targets']} - train_ds = train_ds.map(filter_keys,num_parallel_calls=tf.data.AUTOTUNE) - eval_ds = eval_ds.map(filter_keys,num_parallel_calls=tf.data.AUTOTUNE) + return {"inputs": record["inputs"], "targets": record["targets"]} + + train_ds = train_ds.map(filter_keys, num_parallel_calls=tf.data.AUTOTUNE) + eval_ds = eval_ds.map(filter_keys, num_parallel_calls=tf.data.AUTOTUNE) train_iter = preprocessing_pipeline( train_ds, @@ -233,7 +257,8 @@ def filter_keys(record): pack_examples=True, max_length=config.max_target_length, shift=True, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) eval_iter = preprocessing_pipeline( eval_ds, @@ -244,7 +269,8 @@ def filter_keys(record): max_length=config.max_target_length, shift=False, drop_remainder=False, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) predict_iter = preprocessing_pipeline( eval_ds, @@ -255,6 +281,7 @@ def filter_keys(record): max_length=config.max_target_length, shift=False, drop_remainder=False, - data_shuffle_seed = data_shuffle_seed,) + data_shuffle_seed=data_shuffle_seed, + ) return train_iter, eval_iter, predict_iter diff --git a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index c5a6d8d06..5b9769aad 100644 --- a/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline for gpt3 c4 mlperf dataset.""" @@ -35,6 +35,7 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE + # data processing functions: # _shift_left_and_pad, rekey, reduce_concat_tokens and split_tokens_to_targets_length # Adapted from: @@ -58,8 +59,10 @@ def _shift_left_and_pad(tensor, pad_val): v = v[0] return v + def rekey(ds, key_map=None): """normalization with key mapping""" + def _rekey(x, key_map=None): """Replace the feature keys according to the mapping in `key_map`. 
For example, if the dataset returns examples of the format: @@ -76,19 +79,20 @@ def _rekey(x, key_map=None): """ if key_map: return { - new_key: x[old_key] - for new_key, old_key in key_map.items() if old_key + new_key: x[old_key] for new_key, old_key in key_map.items() if old_key } return x return ds.map( - functools.partial(_rekey, key_map=key_map), - num_parallel_calls=AUTOTUNE) + functools.partial(_rekey, key_map=key_map), num_parallel_calls=AUTOTUNE + ) + -def reduce_concat_tokens(dataset, - feature_key='targets', - batch_size=128, - ): +def reduce_concat_tokens( + dataset, + feature_key="targets", + batch_size=128, +): """Token-preprocessor to concatenate multiple unrelated documents. If we want to generate examples of exactly the right length (to avoid wasting space on padding), then we use this function, followed by @@ -101,8 +105,10 @@ def reduce_concat_tokens(dataset, a dataset """ dataset = dataset.map( - lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE) + lambda x: {feature_key: x[feature_key]}, num_parallel_calls=AUTOTUNE + ) dataset = dataset.padded_batch(batch_size, padded_shapes={feature_key: [-1]}) + def _my_fn(x): tokens = tf.reshape(x[feature_key], [-1]) # strip padding @@ -111,10 +117,12 @@ def _my_fn(x): return dataset.map(_my_fn, num_parallel_calls=AUTOTUNE) -def split_tokens(dataset, - max_tokens_per_segment=128, - feature_key='targets', - ): + +def split_tokens( + dataset, + max_tokens_per_segment=128, + feature_key="targets", +): """Split examples into multiple examples each. The intended use case is to break up long examples for use in unsupervised transfer-learning. @@ -127,6 +135,7 @@ def split_tokens(dataset, Returns: a dataset """ + def _split_tokens(x): """Split one token sequence into multiple sequences.""" tokens = x[feature_key] @@ -136,8 +145,11 @@ def _split_tokens(x): # Pad to a multiple of length, then use tf.reshape to split up the tokens # into num_segments segments each of the given length. num_segments = tf.cast( - tf.math.ceil(tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32)), - tf.int32) + tf.math.ceil( + tf.cast(n_tokens, tf.float32) / tf.cast(length, tf.float32) + ), + tf.int32, + ) padding = num_segments * length - tf.size(tokens) tokens = tf.pad(tokens, [[0, padding]]) return tf.reshape(tokens, [-1, length]) @@ -149,19 +161,25 @@ def _strip_padding(x): dataset = dataset.filter(lambda x: tf.not_equal(tf.size(x[feature_key]), 0)) dataset = dataset.map(_split_tokens, num_parallel_calls=AUTOTUNE) dataset = dataset.unbatch() - return dataset.map( - _strip_padding, num_parallel_calls=AUTOTUNE) + return dataset.map(_strip_padding, num_parallel_calls=AUTOTUNE) + def split_tokens_to_targets_length(dataset, sequence_length): return split_tokens(dataset, max_tokens_per_segment=sequence_length) -def _pad_to_batch_size(ds: tf.data.Dataset, batch_size: int, num_examples: Optional[int] = None,) -> tf.data.Dataset: + +def _pad_to_batch_size( + ds: tf.data.Dataset, + batch_size: int, + num_examples: Optional[int] = None, +) -> tf.data.Dataset: """Pad unevenly distributed eval data in each shard with new entries to multiples of batch size.""" # local_num represents the total number of examples in the eval dataset. if num_examples: local_num = num_examples else: + def _get_num_examples(ds: tf.data.Dataset) -> int: # Iterate one-by-one instead of len(list(...)) to reduce peak memory.
num_examples = 0 @@ -174,60 +192,85 @@ def _get_num_examples(ds: tf.data.Dataset) -> int: local_num_batches = (local_num + batch_size - 1) // batch_size # Find the max number of batches required across all Jax processes. num_batches_all = multihost_utils.process_allgather( - jnp.array([local_num_batches]), tiled=False) + jnp.array([local_num_batches]), tiled=False + ) num_batches = np.max(num_batches_all) pad_num = num_batches * batch_size - local_num assert pad_num >= 0 print( - f'Eval data has {local_num} local entries, padding now with ' - f'{pad_num} extra entries to get {num_batches} batches.') + f"Eval data has {local_num} local entries, padding now with " + f"{pad_num} extra entries to get {num_batches} batches." + ) + # Repeat a random example to make the last batch full. def _add_pad(x): - x['targets_segmentation'] *= 0 + x["targets_segmentation"] *= 0 return x + pad_ds = ds.take(1).map(_add_pad).repeat(pad_num) return ds.concatenate(pad_ds) + def get_datasets( - config: ml_collections.ConfigDict, - dataloading_host_index, - dataloading_host_count, + config: ml_collections.ConfigDict, + dataloading_host_index, + dataloading_host_count, ): """Load and return dataset of batched examples for use during training.""" # Training dataset. read_config = tfds.ReadConfig( - shuffle_seed = config.data_shuffle_seed, - ) + shuffle_seed=config.data_shuffle_seed, + ) train_ds_builder = tfds.builder(config.dataset_name) - train_ds = train_ds_builder.as_dataset(split='train2', read_config=read_config, shuffle_files=config.enable_data_shuffling) + train_ds = train_ds_builder.as_dataset( + split="train2", + read_config=read_config, + shuffle_files=config.enable_data_shuffling, + ) eval_ds_builder = tfds.builder(config.eval_dataset_name) - eval_ds = eval_ds_builder.as_dataset(split='validation_tokenized_5662seqs', read_config=read_config, shuffle_files=False) + eval_ds = eval_ds_builder.as_dataset( + split="validation_tokenized_5662seqs", + read_config=read_config, + shuffle_files=False, + ) # shard the dataset as soon as it is loaded - train_ds = train_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) - train_ds = rekey(train_ds, {'inputs': None, 'targets': 'text'}) - - eval_ds = eval_ds.shard(num_shards = dataloading_host_count, index = dataloading_host_index) + train_ds = train_ds.shard( + num_shards=dataloading_host_count, index=dataloading_host_index + ) + train_ds = rekey(train_ds, {"inputs": None, "targets": "text"}) + + eval_ds = eval_ds.shard( + num_shards=dataloading_host_count, index=dataloading_host_index + ) # note: the validation_tokenized_5662seqs split is pre-tokenized, already concatenated via reduce_concat_tokens, and split to target_length, # mainly to avoid eval sequences changing depending on the number of hosts - eval_ds = rekey(eval_ds, {'inputs': None, 'targets': 'ids'}) + eval_ds = rekey(eval_ds, {"inputs": None, "targets": "ids"}) return train_ds, eval_ds -def preprocess_dataset(config: ml_collections.ConfigDict, - global_mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed: int = 0, - shuffle_buffer_size: int = 128, - ): + +def preprocess_dataset( + config: ml_collections.ConfigDict, + global_mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed: int = 0, + shuffle_buffer_size: int = 128, +): """Pre-process the dataset and return iterators for mlperf training.""" # tokenize train_ds = train_ds.map( - tokenizer.TokenizeOp(sp_tokenizer, data_keys=('targets',)), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer, data_keys=("targets",)), +
num_parallel_calls=AUTOTUNE, + ) - train_ds = reduce_concat_tokens(train_ds, feature_key='targets', batch_size=4096) + train_ds = reduce_concat_tokens( + train_ds, feature_key="targets", batch_size=4096 + ) train_ds = split_tokens_to_targets_length(train_ds, config.max_target_length) train_ds = train_ds.shuffle(shuffle_buffer_size, seed=data_shuffle_seed) @@ -241,8 +284,10 @@ def format_fn(x, eos_id: int = 1, pad_id: int = 0): x["inputs_position"] = x["targets_position"] x["targets"] = _shift_left_and_pad(x["targets"], eos_id) x["inputs_segmentation"] = tf.where( - tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), - x["targets_segmentation"], 0) + tf.logical_and(x["targets"] != eos_id, x["targets"] != pad_id), + x["targets_segmentation"], + 0, + ) x["targets_segmentation"] = x["inputs_segmentation"] return x @@ -257,12 +302,16 @@ def format_fn(x, eos_id: int = 1, pad_id: int = 0): else: eval_batch_size = global_batch_size_to_load - train_ds = train_ds.batch(global_batch_size_to_load // jax.process_count(), drop_remainder=True) + train_ds = train_ds.batch( + global_batch_size_to_load // jax.process_count(), drop_remainder=True + ) # ensure array split in an equal division for each device # pad zeros up to the same batch_size among all processes eval_ds = _pad_to_batch_size(eval_ds, eval_batch_size // jax.process_count()) - eval_ds = eval_ds.batch(eval_batch_size // jax.process_count(), drop_remainder=False) + eval_ds = eval_ds.batch( + eval_batch_size // jax.process_count(), drop_remainder=False + ) # We are running eval over exactly one epoch. # We explicitly cache the entire epoch (in memory) to ensure that it is the # same across different iterations. @@ -271,8 +320,12 @@ def format_fn(x, eos_id: int = 1, pad_id: int = 0): train_ds = train_ds.prefetch(AUTOTUNE) eval_ds = eval_ds.prefetch(AUTOTUNE) - train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(train_ds, global_mesh) - eval_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(eval_ds, global_mesh) + train_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + train_ds, global_mesh + ) + eval_multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + eval_ds, global_mesh + ) # Return multi-host jax.Array prep iterator return train_multihost_gen, eval_multihost_gen diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py index 58227feee..f88d1c2bc 100644 --- a/MaxText/input_pipeline/input_pipeline_interface.py +++ b/MaxText/input_pipeline/input_pipeline_interface.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Input pipeline""" @@ -26,79 +26,97 @@ from input_pipeline import _tfds_data_processing_c4_mlperf import tokenizer + def get_tokenizer(tokenizer_path, add_bos=True, add_eos=True): # Load tokenizer - sp_tokenizer = tokenizer.load_tokenizer(tokenizer_path=tokenizer_path, - add_bos=add_bos, - add_eos=add_eos) + sp_tokenizer = tokenizer.load_tokenizer( + tokenizer_path=tokenizer_path, add_bos=add_bos, add_eos=add_eos + ) return sp_tokenizer -def make_c4_mlperf_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for customized C4 dataset for mlperf gpt3 training.""" + +def make_c4_mlperf_train_iterator_and_tokenizer( + config, mesh, add_bos, add_eos, process_indices +): + """Make train iterator and tokenizer for customized C4 dataset for mlperf gpt3 training.""" train_ds, eval_ds = _tfds_data_processing_c4_mlperf.get_datasets( - config=config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), + config=config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), ) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, eval_iter = _tfds_data_processing_c4_mlperf.preprocess_dataset( - config, - mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed=config.data_shuffle_seed + config, + mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=config.data_shuffle_seed, ) return train_iter, eval_iter, sp_tokenizer -def make_c4_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for C4 dataset""" + +def make_c4_train_iterator_and_tokenizer( + config, mesh, add_bos, add_eos, process_indices +): + """Make train iterator and tokenizer for C4 dataset""" read_config = tfds.ReadConfig( - shuffle_seed = config.data_shuffle_seed, + shuffle_seed=config.data_shuffle_seed, ) train_ds, eval_ds = _tfds_data_processing.get_datasets( - config=config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - read_config = read_config, + config=config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + read_config=read_config, ) sp_tokenizer = get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, _, _ = _tfds_data_processing.preprocess_dataset( - config, - mesh, - train_ds, eval_ds, sp_tokenizer, - data_shuffle_seed = config.data_shuffle_seed, + config, + mesh, + train_ds, + eval_ds, + sp_tokenizer, + data_shuffle_seed=config.data_shuffle_seed, ) return train_iter, None, sp_tokenizer -def make_grain_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices): - """ Make train iterator and tokenizer for C4 dataset""" - train_ds, eval_ds = _grain_data_processing.get_datasets( - config=config - ) + +def make_grain_train_iterator_and_tokenizer( + config, mesh, add_bos, add_eos, process_indices +): + """Make train iterator and tokenizer for C4 dataset""" + train_ds, eval_ds = _grain_data_processing.get_datasets(config=config) sp_tokenizer = 
get_tokenizer(config.tokenizer_path, add_bos, add_eos) train_iter, _, _ = _grain_data_processing.preprocess_dataset( - config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - global_mesh = mesh, - train_ds = train_ds, eval_ds = eval_ds, - vocab_path=config.tokenizer_path, - data_shuffle_seed = config.data_shuffle_seed, - add_bos = add_bos, - add_eos = add_eos + config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=mesh, + train_ds=train_ds, + eval_ds=eval_ds, + vocab_path=config.tokenizer_path, + data_shuffle_seed=config.data_shuffle_seed, + add_bos=add_bos, + add_eos=add_eos, ) return train_iter, None, sp_tokenizer -class SyntheticDataIterator(): + +class SyntheticDataIterator: """Creates a synthetic data iterator for performance testing.""" + def __init__(self, config, mesh): self.mesh = mesh self.config = config data_pspec = P(*config.data_sharding) data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - self.data_generator = jax.jit(SyntheticDataIterator.raw_generate_synthetic_data, + lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec + ) + self.data_generator = jax.jit( + SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, - static_argnums=0) + static_argnums=0, + ) def __iter__(self): return self @@ -111,31 +129,48 @@ def __next__(self): @staticmethod def raw_generate_synthetic_data(config): """Generates a single batch of synthetic data.""" output = {} - output['inputs'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['inputs_position'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['inputs_segmentation'] = jax.numpy.ones( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['targets'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['targets_position'] = jax.numpy.zeros( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) - output['targets_segmentation'] = jax.numpy.ones( (config.global_batch_size_to_load, config.max_target_length), - dtype=jax.numpy.int32) + output["inputs"] = jax.numpy.zeros( + (config.global_batch_size_to_load, config.max_target_length), + dtype=jax.numpy.int32, + ) + output["inputs_position"] = jax.numpy.zeros( + (config.global_batch_size_to_load, config.max_target_length), + dtype=jax.numpy.int32, + ) + output["inputs_segmentation"] = jax.numpy.ones( + (config.global_batch_size_to_load, config.max_target_length), + dtype=jax.numpy.int32, + ) + output["targets"] = jax.numpy.zeros( + (config.global_batch_size_to_load, config.max_target_length), + dtype=jax.numpy.int32, + ) + output["targets_position"] = jax.numpy.zeros( + (config.global_batch_size_to_load, config.max_target_length), + dtype=jax.numpy.int32, + ) + output["targets_segmentation"] = jax.numpy.ones( + (config.global_batch_size_to_load, config.max_target_length), + dtype=jax.numpy.int32, + ) return output -class BadSyntheticDataIterator(): + +class BadSyntheticDataIterator: """Creates a bad synthetic data iterator for loading on a subset of hosts.""" + def __init__(self, config, mesh): self.mesh = mesh self.config = config data_pspec = P(*config.data_sharding) data_pspec_shardings = jax.tree_map( -
lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - self.data_generator = jax.jit(BadSyntheticDataIterator.get_bad_synthetic_data, + lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec + ) + self.data_generator = jax.jit( + BadSyntheticDataIterator.get_bad_synthetic_data, out_shardings=data_pspec_shardings, - static_argnums=0) + static_argnums=0, + ) def __iter__(self): return self @@ -146,27 +181,47 @@ def __next__(self): @staticmethod def get_bad_synthetic_data(config): - """fill negative value in synthetic data """ + """Fill synthetic data with negative values.""" output = {} - output['inputs'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['inputs_position'] = jax.numpy.full((config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['inputs_segmentation'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets_position'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) - output['targets_segmentation'] = jax.numpy.full( (config.global_batch_size_to_load, - config.max_target_length), -1, dtype=jax.numpy.int32) + output["inputs"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), + -1, + dtype=jax.numpy.int32, + ) + output["inputs_position"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), + -1, + dtype=jax.numpy.int32, + ) + output["inputs_segmentation"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), + -1, + dtype=jax.numpy.int32, + ) + output["targets"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), + -1, + dtype=jax.numpy.int32, + ) + output["targets_position"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), + -1, + dtype=jax.numpy.int32, + ) + output["targets_segmentation"] = jax.numpy.full( + (config.global_batch_size_to_load, config.max_target_length), + -1, + dtype=jax.numpy.int32, + ) return output + def get_process_loading_real_data(config, mesh): - """ Get list of processes loading data from GCS when expansion_factor_real_data != -1 - """ + """Get list of processes loading data from GCS when expansion_factor_real_data != -1""" sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) - devices_indices_map = sharding.devices_indices_map((config.global_batch_size_to_load, config.max_target_length)) + devices_indices_map = sharding.devices_indices_map( + (config.global_batch_size_to_load, config.max_target_length) + ) batch_cutoff = config.global_batch_size_to_train_on process_loading_real_data = set() for p, indices in devices_indices_map.items(): @@ -174,39 +229,81 @@ def get_process_loading_real_data(config, mesh): process_loading_real_data.add(p.process_index) return list(process_loading_real_data) + def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos): process_indices = get_process_loading_real_data(config, mesh) - print(len(process_indices),"hosts out of",jax.process_count(),"are loading real data") - if config.expansion_factor_real_data != -1: # assert number of hosts loading real data - assert len(process_indices) == jax.process_count() // config.expansion_factor_real_data + print( +
len(process_indices), + "hosts out of", + jax.process_count(), + "are loading real data", + ) + if ( + config.expansion_factor_real_data != -1 + ): # assert number of hosts loading real data + assert ( + len(process_indices) + == jax.process_count() // config.expansion_factor_real_data + ) if jax.process_index() in process_indices: if config.dataset_type == "c4": - return make_c4_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices) + return make_c4_train_iterator_and_tokenizer( + config, mesh, add_bos, add_eos, process_indices + ) elif config.dataset_type == "c4-array_record": - return make_grain_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos, process_indices) + return make_grain_train_iterator_and_tokenizer( + config, mesh, add_bos, add_eos, process_indices + ) elif config.dataset_type == "c4_mlperf": print("Overwrite both add_bos and add_eos to False") - return make_c4_mlperf_train_iterator_and_tokenizer(config, mesh, add_bos=False, add_eos=False, process_indices = process_indices) + return make_c4_mlperf_train_iterator_and_tokenizer( + config, + mesh, + add_bos=False, + add_eos=False, + process_indices=process_indices, + ) else: - return BadSyntheticDataIterator(config, mesh), None, get_tokenizer(config.tokenizer_path, add_bos, add_eos) + return ( + BadSyntheticDataIterator(config, mesh), + None, + get_tokenizer(config.tokenizer_path, add_bos, add_eos), + ) + -def create_data_iterator_with_tokenizer(config, mesh, add_bos = True, add_eos = True): +def create_data_iterator_with_tokenizer( + config, mesh, add_bos=True, add_eos=True +): if config.dataset_type == "synthetic": - return SyntheticDataIterator(config, mesh), None, get_tokenizer(config.tokenizer_path, add_bos, add_eos) + return ( + SyntheticDataIterator(config, mesh), + None, + get_tokenizer(config.tokenizer_path, add_bos, add_eos), + ) elif config.dataset_type in ("c4", "c4-array_record", "c4_mlperf"): - return make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos) + return make_mixed_train_iterator_and_tokenizer( + config, mesh, add_bos, add_eos + ) else: assert False, "dataset type not implemented" + def get_shaped_batch(config): - """ Return the shape of the batch - this is what eval_shape would return for the - output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078.""" + """Return the shape of the batch - this is what eval_shape would return for the + output of create_data_iterator_with_tokenizer, but eval_shape doesn't work, see b/306901078. 
+ """ batch_shape = (config.global_batch_size_to_load, config.max_target_length) shaped_batch = {} - shaped_batch['inputs'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['inputs_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['inputs_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets_position'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) - shaped_batch['targets_segmentation'] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs_position"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["inputs_segmentation"] = jax.ShapeDtypeStruct( + batch_shape, jnp.int32 + ) + shaped_batch["targets"] = jax.ShapeDtypeStruct(batch_shape, jnp.int32) + shaped_batch["targets_position"] = jax.ShapeDtypeStruct( + batch_shape, jnp.int32 + ) + shaped_batch["targets_segmentation"] = jax.ShapeDtypeStruct( + batch_shape, jnp.int32 + ) return shaped_batch diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 8ea5e7bd0..9d200d66c 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -59,7 +59,8 @@ shard_map = shard_map.shard_map dynamic_vector_slice_in_dim = jax.vmap( - lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None) +) # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes # pytype: disable=attribute-error @@ -87,12 +88,16 @@ def apply_mask_to_logits(logits: Array, mask: Array): Returns: Masked logits. """ - return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) + return jnp.where( + (mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE + ) + def _maybe_aqt_einsum(quant: Quant): """Maybe overwrite dot general with aqt_dot_general.""" return jnp.einsum if quant is None else quant.einsum() + class AttentionOp(nn.Module): mesh: Mesh attention_kernel: str @@ -100,44 +105,42 @@ class AttentionOp(nn.Module): num_query_heads: int num_kv_heads: int float32_qk_product: bool = False - max_prefill_predict_length: int = -1 + max_prefill_predict_length: int = -1 float32_logits: bool = False flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) - dropout_rate: float = 0. + dropout_rate: float = 0.0 dtype: DType = jnp.float32 quant: Optional[Quant] = None quantize_kvcache: bool = False def check_attention_inputs( - self, - query: Array, - key: Array, - value: Array) -> None: + self, query: Array, key: Array, value: Array + ) -> None: """Check attention inputs.""" - assert key.ndim == value.ndim, 'k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert key.shape[-2] == value.shape[-2], ('k, v num_kv_heads must match.') - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + assert key.ndim == value.ndim, "k, v must have same rank." + assert ( + query.shape[:-3] == key.shape[:-3] == value.shape[:-3] + ), "q, k, v batch dims must match." + assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." + assert key.shape[-3] == value.shape[-3], "k, v lengths must match." + assert query.shape[-1] == key.shape[-1], "q, k depths must match." # Following Pallas MHA Flash Attention Reference. 
# https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py # This mask models (1) separate sequences (decoder_segment_ids) and (2) causality def generate_attention_mask( - self, - query, - key, - decoder_segment_ids: Array | None, - model_mode: str + self, query, key, decoder_segment_ids: Array | None, model_mode: str ) -> Array | None: mask = None if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - mask = decoder_segment_ids[:, None, None, None, :] == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + mask = ( + decoder_segment_ids[:, None, None, None, :] + == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) elif decoder_segment_ids is not None: mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :] - mask = mask[:, None, None,:, :] + mask = mask[:, None, None, :, :] causal_mask = None # We enforce causality except for AUTOREGRESSION @@ -158,54 +161,83 @@ def generate_attention_mask( else: output_mask = None - return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None + return ( + jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) + if output_mask is not None + else None + ) - def apply_attention(self, + def apply_attention( + self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, - model_mode: str): + model_mode: str, + ): self.check_attention_inputs(query, key, value) length = query.shape[-3] - if self.attention_kernel == 'dot_product' or\ - (self.attention_kernel == 'autoselected' and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) or\ - (self.attention_kernel == 'autoselected' and length < 128): - return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode) - elif self.attention_kernel == 'flash' or\ - self.attention_kernel == 'autoselected': + if ( + self.attention_kernel == "dot_product" + or ( + self.attention_kernel == "autoselected" + and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE + ) + or (self.attention_kernel == "autoselected" and length < 128) + ): + return self.apply_attention_dot( + query, key, value, decoder_segment_ids, model_mode + ) + elif ( + self.attention_kernel == "flash" + or self.attention_kernel == "autoselected" + ): if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") - return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None - elif self.attention_kernel == 'cudnn_flash_te': + raise ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) + return ( + self.tpu_flash_attention(query, key, value, decoder_segment_ids), + None, + None, + ) + elif self.attention_kernel == "cudnn_flash_te": if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: - raise ValueError("""Decode not supported with flash attention. - Use `dot_product` instead.""") - return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None + raise ValueError( + """Decode not supported with flash attention. 
+ Use `dot_product` instead.""" + ) + return ( + self.cudnn_flash_attention( + query, key, value, decoder_segment_ids, model_mode + ), + None, + None, + ) else: - raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.') + raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") def tpu_flash_attention( - self, - query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None) -> Array: + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + ) -> Array: """TPU Flash Attention.""" # Transpose to ('batch', 'heads', 'length', 'kv') query = jnp.transpose(query, axes=(0, 2, 1, 3)) key = jnp.transpose(key, axes=(0, 2, 1, 3)) value = jnp.transpose(value, axes=(0, 2, 1, 3)) - if decoder_segment_ids is not None: decoder_segment_ids = splash_attention_kernel.SegmentIds( decoder_segment_ids, decoder_segment_ids ) axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) segment_axis_names = nn.logical_to_mesh_axes( - (BATCH, 'activation_length_no_heads') + (BATCH, "activation_length_no_heads") ) @functools.partial( @@ -223,76 +255,88 @@ def tpu_flash_attention( def wrap_flash_attention(query, key, value, decoder_segment_ids): if decoder_segment_ids is not None: assert ( - query.shape[2] - == decoder_segment_ids.q.shape[1] - ), 'Sharding along sequence dimension not allowed in tpu kernel attention' + query.shape[2] == decoder_segment_ids.q.shape[1] + ), "Sharding along sequence dimension not allowed in tpu kernel attention" block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(512, query.shape[2]), - block_kv_compute=min(512, key.shape[2]), - block_kv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_kv_dkv=min(512, key.shape[2]), - block_kv_dkv_compute=min(512, query.shape[2]), - block_q_dq=min(512, query.shape[2]), - block_kv_dq=min(512, query.shape[2]), + block_q=min(512, query.shape[2]), + block_kv_compute=min(512, key.shape[2]), + block_kv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_kv_dkv=min(512, key.shape[2]), + block_kv_dkv_compute=min(512, query.shape[2]), + block_q_dq=min(512, query.shape[2]), + block_kv_dq=min(512, query.shape[2]), ) - masks = [splash_attention_mask.CausalMask( shape=(query.shape[2],query.shape[2])) for i in range(query.shape[1])] + masks = [ + splash_attention_mask.CausalMask( + shape=(query.shape[2], query.shape[2]) + ) + for i in range(query.shape[1]) + ] multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) - splash_kernel = splash_attention_kernel.make_splash_mha(mask = multi_head_mask, - head_shards = 1, - q_seq_shards = 1, - block_sizes = block_sizes) - - return jax.vmap(splash_kernel)(query,key,value, segment_ids = decoder_segment_ids) - - devices_in_data_fsdp = self.mesh.shape['data'] * self.mesh.shape['fsdp'] + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, + q_seq_shards=1, + block_sizes=block_sizes, + ) + + return jax.vmap(splash_kernel)( + query, key, value, segment_ids=decoder_segment_ids + ) + + devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( - 'Batch dimension should be shardable among the devices in data and fsdp' - ' axis' + "Batch dimension should be shardable among the devices in data and fsdp" + " axis" ) x = wrap_flash_attention(query, key, value, decoder_segment_ids) x = jnp.transpose(x, axes=(0, 2, 1, 3)) return x - + def cudnn_flash_attention( - self, 
- query: Array, - key: Array, - value: Array, - decoder_segment_ids: Array | None, - model_mode: str = common_types.MODEL_MODE_TRAIN, - ) -> Array: + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + model_mode: str = common_types.MODEL_MODE_TRAIN, + ) -> Array: """CUDNN Flash Attention with Transformer Engine. - 1. Stable API, supports GQA - 2. Supports head_dim till 128; head_dim=256 support will be added soon + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon """ # These imports are only meant to work in a GPU build. - from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error _, _, _, head_dim = query.shape # pylint: disable=unused-variable - #generate attn_mask - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - - dpa_layer = DotProductAttention(head_dim=head_dim, - num_attention_heads=self.num_query_heads, - num_gqa_groups=self.num_kv_heads, - attn_mask_type='causal', # 'causal' or 'padding' - attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' - attention_dropout=self.dropout_rate, - dropout_rng_name='aqt', - dtype=self.dtype, - float32_logits=self.float32_logits, - qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' - scale_factor=1.0/math.sqrt(head_dim), - transpose_batch_sequence=False) + # generate attn_mask + attn_mask = self.generate_attention_mask( + query, key, decoder_segment_ids, model_mode + ) + + dpa_layer = DotProductAttention( + head_dim=head_dim, + num_attention_heads=self.num_query_heads, + num_gqa_groups=self.num_kv_heads, + attn_mask_type="causal", # 'causal' or 'padding' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=1.0 / math.sqrt(head_dim), + transpose_batch_sequence=False, + ) return dpa_layer(query, key, value, mask=attn_mask) - def compute_local_attention(self, - attn_weights: Array, - value: Array) -> tuple[Array, Array, Array]: - """Computes the attention of a local subset of the kv cache. + def compute_local_attention( + self, attn_weights: Array, value: Array + ) -> tuple[Array, Array, Array]: + """Computes the attention of a local subset of the kv cache. 
Local attention results will need to be combined with any other local attentions and normalized Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py @@ -314,25 +358,33 @@ def compute_local_attention(self, local_sum = jnp.moveaxis(local_sum, -2, 1) local_max = jnp.moveaxis(local_max, -2, 1) - local_max = jnp.reshape(local_max, - (local_max.shape[0], - local_max.shape[1], - local_max.shape[2] * local_max.shape[3], - 1)) - local_sum = jnp.reshape(local_sum, - (local_sum.shape[0], - local_sum.shape[1], - local_sum.shape[2] * local_sum.shape[3], - 1)) + local_max = jnp.reshape( + local_max, + ( + local_max.shape[0], + local_max.shape[1], + local_max.shape[2] * local_max.shape[3], + 1, + ), + ) + local_sum = jnp.reshape( + local_sum, + ( + local_sum.shape[0], + local_sum.shape[1], + local_sum.shape[2] * local_sum.shape[3], + 1, + ), + ) local_out = self.wv_product(local_exps, value) return local_out, local_max, local_sum def apply_attention_dot( self, - query: Array, - key: Array, - value: Array, + query: Array, + key: Array, + value: Array, decoder_segment_ids: Array | None, model_mode: str = common_types.MODEL_MODE_TRAIN, ): @@ -347,7 +399,9 @@ def apply_attention_dot( # Cast the softmax computation to float32 for model stability. if self.float32_logits: attn_weights = attn_weights.astype(jnp.float32) - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = self.generate_attention_mask( + query, key, decoder_segment_ids, model_mode + ) if attn_mask is not None: attn_weights = apply_mask_to_logits(attn_weights, attn_mask) return self.compute_local_attention(attn_weights, value) @@ -364,28 +418,24 @@ def qk_product(self, query: Array, key: Array) -> Array: Returns: results in shape [b, n_kv, n // n_kv, t, s]. """ - b, t, n, d = query.shape + b, t, n, d = query.shape n_kv = key.shape[-2] assert n_kv == self.num_kv_heads query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) - result = jnp.einsum('btkgd,bskd->bkgts', query, key) - return result # (4, 8, 1, 1, 6) - + result = jnp.einsum("btkgd,bskd->bkgts", query, key) + return result # (4, 8, 1, 1, 6) - def wv_product( - self, - attn_weights: Array, - value: Array) -> Array: + def wv_product(self, attn_weights: Array, value: Array) -> Array: """weighted value product. Args: - attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. + attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. value: Value projection, in shape of [batch_size, v_len, num_kv_heads, kv_dim]. Returns: result in shape [batch_size, q_len, num_kv_heads * group_size, kv_dim] """ - out = jnp.einsum('bkgts,bskd->btkgd', attn_weights, value) + out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value) b, t, n_kv, g, d = out.shape result = jnp.reshape(out, (b, t, n_kv * g, d)) return result @@ -399,8 +449,7 @@ def revert_kvlen_axis(self, kv): Returns: reshaped kv as [b, ..., s, n, d] """ - return jax.numpy.moveaxis(kv, (0,1,2,3), (1,2,0,3)) - + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (1, 2, 0, 3)) def move_kvlen_axis(self, kv): """Move key/value length axis to the end. @@ -411,7 +460,7 @@ def move_kvlen_axis(self, kv): Returns: reshaped kv as [b, ..., n, d, s] """ - return jax.numpy.moveaxis(kv, (0,1,2,3), (2,0,1,3)) + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3)) def cached_kv_shape(self, kv_shape): """Cached KV shape.
@@ -430,26 +479,63 @@ def cached_kv_shape(self, kv_shape): def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 - kv_cache_layout = ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv', ) - cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size) - cached_key = self.variable('cache', 'cached_prefill_key', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_value = self.variable('cache', 'cached_prefill_value', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_segment_id = self.variable('cache', 'cache_prefill_segment_id', - nn.with_logical_partitioning(jnp.zeros, ('cache_batch', 'cache_sequence')), - (cache_logical_shape[0], self.max_prefill_predict_length), jnp.int32) + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) + cache_logical_shape = ( + batch, + self.max_prefill_predict_length, + heads, + kv_head_size, + ) + cached_key = self.variable( + "cache", + "cached_prefill_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_prefill_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_prefill_segment_id", + nn.with_logical_partitioning( + jnp.zeros, ("cache_batch", "cache_sequence") + ), + (cache_logical_shape[0], self.max_prefill_predict_length), + jnp.int32, + ) if self.quantize_kvcache: - cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1) - cached_key_scale_var = self.variable('cache', 'cached_prefill_key_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) - cached_value_scale_var = self.variable('cache', 'cached_prefill_value_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) + cache_logical_shape_scale = ( + batch, + self.max_prefill_predict_length, + heads, + 1, + ) + cached_key_scale_var = self.variable( + "cache", + "cached_prefill_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_prefill_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) else: cached_key_scale_var = None cached_value_scale_var = None @@ -461,85 +547,130 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache): dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 cache_length = self.max_target_length - self.max_prefill_predict_length - kv_cache_layout = ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv', ) + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) cache_logical_shape = (batch, cache_length, heads, kv_head_size) - cached_key = self.variable('cache', 'cached_ar_key', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_value = 
self.variable('cache', 'cached_ar_value', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape), dtype) - cached_segment_id = self.variable('cache', 'cache_ar_segment_id', - nn.with_logical_partitioning(jnp.zeros, ('cache_batch', 'cache_sequence')), - (cache_logical_shape[0], cache_length), jnp.int32) + cached_key = self.variable( + "cache", + "cached_ar_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_ar_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_ar_segment_id", + nn.with_logical_partitioning( + jnp.zeros, ("cache_batch", "cache_sequence") + ), + (cache_logical_shape[0], cache_length), + jnp.int32, + ) if self.quantize_kvcache: cache_logical_shape_scale = (batch, cache_length, heads, 1) - cached_key_scale_var = self.variable('cache', 'cached_ar_key_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) - cached_value_scale_var = self.variable('cache', 'cached_ar_value_scale', - nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), - self.cached_kv_shape(cache_logical_shape_scale), jnp.bfloat16) + cached_key_scale_var = self.variable( + "cache", + "cached_ar_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_ar_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) else: cached_key_scale_var = None cached_value_scale_var = None - cache_index = self.variable('cache', 'cache_ar_index', - nn.with_logical_partitioning(jnp.zeros, ()), - (1,), jnp.int32) + cache_index = self.variable( + "cache", + "cache_ar_index", + nn.with_logical_partitioning(jnp.zeros, ()), + (1,), + jnp.int32, + ) key_vars = (cached_key, cached_key_scale_var) value_vars = (cached_value, cached_value_scale_var) return key_vars, value_vars, cached_segment_id, cache_index - def kv_cache_prefill(self, - key: Array, - value: Array, - decoder_segment_ids: Array, - ): - """In prefill mode, we zero out the existing cache, run the computation and - prepare the cache as necessary. - - Args: - key: in shape [b, s, n, d]. - value: in shape [b, s, n, d]. - decoder_segment_ids: [b, s] -- marking segment ids for tokens - - Returns: - key, value, decoder_segment_id. - - """ - batch, sequence, heads, kv_head_size = key.shape - assert key.dtype == value.dtype, "Key and Value Dtypes should match." 
- - cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(batch, heads, kv_head_size, self.quantize_kvcache) - self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now - - key_shaped_for_cache = self.move_kvlen_axis(key) - value_shaped_for_cache = self.move_kvlen_axis(value) - - if self.quantize_kvcache: - key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache) - value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache) - cached_prefill_key_var[1].value = key_scale - cached_prefill_value_var[1].value = value_scale - - cached_prefill_key_var[0].value = key_shaped_for_cache - cached_prefill_value_var[0].value = value_shaped_for_cache + def kv_cache_prefill( + self, + key: Array, + value: Array, + decoder_segment_ids: Array, + ): + """In prefill mode, we zero out the existing cache, run the computation and + prepare the cache as necessary. - if decoder_segment_ids is not None: - cached_prefill_segment_id.value = decoder_segment_ids + Args: + key: in shape [b, s, n, d]. + value: in shape [b, s, n, d]. + decoder_segment_ids: [b, s] -- marking segment ids for tokens - return key, value, decoder_segment_ids - + Returns: + key, value, decoder_segment_id. + + """ + batch, sequence, heads, kv_head_size = key.shape + assert key.dtype == value.dtype, "Key and Value Dtypes should match." + + ( + cached_prefill_key_var, + cached_prefill_value_var, + cached_prefill_segment_id, + ) = self._get_prefill_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + self._get_ar_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) # initialize it now + + key_shaped_for_cache = self.move_kvlen_axis(key) + value_shaped_for_cache = self.move_kvlen_axis(value) + + if self.quantize_kvcache: + key_shaped_for_cache, key_scale = quantizations.quantize_kv( + key_shaped_for_cache + ) + value_shaped_for_cache, value_scale = quantizations.quantize_kv( + value_shaped_for_cache + ) + cached_prefill_key_var[1].value = key_scale + cached_prefill_value_var[1].value = value_scale - def update_ar_key_value(self, - one_token_key: Array, - one_token_value: Array, - cached_key_vars: tuple[nn.Variable, nn.Variable|None], - cached_value_vars: tuple[nn.Variable, nn.Variable|None], - one_hot_indices: Array) -> tuple[Array, Array]: - """Adds a single token's results to the ar kv cache + cached_prefill_key_var[0].value = key_shaped_for_cache + cached_prefill_value_var[0].value = value_shaped_for_cache + + if decoder_segment_ids is not None: + cached_prefill_segment_id.value = decoder_segment_ids + + return key, value, decoder_segment_ids + + def update_ar_key_value( + self, + one_token_key: Array, + one_token_value: Array, + cached_key_vars: tuple[nn.Variable, nn.Variable | None], + cached_value_vars: tuple[nn.Variable, nn.Variable | None], + one_hot_indices: Array, + ) -> tuple[Array, Array]: + """Adds a single token's results to the ar kv cache Args: one_token_key (Array): Key of one token to add to the cache @@ -561,30 +692,69 @@ def update_ar_key_value(self, one_token_value = self.move_kvlen_axis(one_token_value) if self.quantize_kvcache: - one_token_key, one_token_key_scale = quantizations.quantize_kv(one_token_key) - one_token_value, one_token_value_scale = quantizations.quantize_kv(one_token_value) + one_token_key, one_token_key_scale = quantizations.quantize_kv( + one_token_key + ) + one_token_value, one_token_value_scale = quantizations.quantize_kv( + one_token_value + ) 
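# Aside (illustrative, not part of the diff): quantizations.quantize_kv and
# unquantize_kv live in MaxText's quantizations module; the sketch below is
# an assumed minimal form of symmetric int8 quantization with a per-slice
# scale, shown only to make the cache-update math here concrete.
import jax.numpy as jnp

def quantize_kv_sketch(kv):
  # Scale each trailing slice so its largest magnitude maps to 127.
  scale = jnp.maximum(jnp.max(jnp.abs(kv), axis=-1, keepdims=True), 1e-8) / 127.0
  return jnp.round(kv / scale).astype(jnp.int8), scale

def unquantize_kv_sketch(kv_int8, scale, dtype):
  # Rescale to approximately recover the original values in the compute dtype.
  return (kv_int8.astype(jnp.float32) * scale).astype(dtype)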
one_hot_indices = one_hot_indices.astype(int) - ar_key = cached_key_var.value - ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0) - ar_key = nn.with_logical_constraint(ar_key, ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv',)) + ar_key = cached_key_var.value + ar_key = jax.lax.dynamic_update_index_in_dim( + ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0 + ) + ar_key = nn.with_logical_constraint( + ar_key, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) cached_key_var.value = ar_key ar_value = cached_value_var.value - ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0) - ar_value = nn.with_logical_constraint(ar_value, ('cache_sequence', 'cache_heads', 'cache_batch', 'cache_kv',)) + ar_value = jax.lax.dynamic_update_index_in_dim( + ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0 + ) + ar_value = nn.with_logical_constraint( + ar_value, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) cached_value_var.value = ar_value if self.quantize_kvcache: - ar_key_scale = jax.lax.dynamic_update_index_in_dim(cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0) - ar_value_scale = jax.lax.dynamic_update_index_in_dim(cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0) + ar_key_scale = jax.lax.dynamic_update_index_in_dim( + cached_key_scale_var.value, + one_token_key_scale, + jnp.squeeze(one_hot_indices), + 0, + ) + ar_value_scale = jax.lax.dynamic_update_index_in_dim( + cached_value_scale_var.value, + one_token_value_scale, + jnp.squeeze(one_hot_indices), + 0, + ) cached_key_scale_var.value = ar_key_scale cached_value_scale_var.value = ar_value_scale - ar_key = quantizations.unquantize_kv(cached_key_var.value, cached_key_scale_var.value, one_token_key.dtype) - ar_value = quantizations.unquantize_kv(cached_value_var.value, cached_value_scale_var.value, one_token_value.dtype) + ar_key = quantizations.unquantize_kv( + cached_key_var.value, cached_key_scale_var.value, one_token_key.dtype + ) + ar_value = quantizations.unquantize_kv( + cached_value_var.value, + cached_value_scale_var.value, + one_token_value.dtype, + ) # Move the keys and values back to their original shapes. return self.revert_kvlen_axis(ar_key), self.revert_kvlen_axis(ar_value) @@ -594,57 +764,90 @@ def prefill_cache_var_model_var(self, cache_var, target_dtype): return self.revert_kvlen_axis(cache_var[0].value) else: raw_cache, quant_scale = cache_var - raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype) + raw_cache_unquantized = quantizations.unquantize_kv( + raw_cache.value, quant_scale.value, target_dtype + ) return self.revert_kvlen_axis(raw_cache_unquantized) - - - def kv_cache_autoregressive(self, - key: Array, - value: Array, - ): - """In autoregressive mode, we update the cache for this entry and - then return the full cache. + def kv_cache_autoregressive( + self, + key: Array, + value: Array, + ): + """In autoregressive mode, we update the cache for this entry and + then return the full cache. - Args: - key: in shape [b, 1, n, d]. - value: in shape [b, 1, n, d].
+ decoder_segment_ids: [b, 1] -- marking segment ids for tokens - Returns: - tuple of (key, value, segment_id) for both prefill and ar cache, - Raises: - ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. - """ - batch, sequence, heads, kv_head_size = key.shape - if sequence != 1: - raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") - is_initialized = self.has_variable('cache', 'cache_ar_index') - if not is_initialized: - raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") + Returns: + tuple of (key, value, segment_id) for both prefill and ar cache, + Raises: + ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. + """ + batch, sequence, heads, kv_head_size = key.shape + if sequence != 1: + raise ValueError( + f"Sequence length should be 1 during autoregression, got {sequence=}" + ) + is_initialized = self.has_variable("cache", "cache_ar_index") + if not is_initialized: + raise ValueError( + "Error, we can't do autoregression if we haven't seeded the KV Cache." + ) - cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) + ( + cached_ar_key_var, + cached_ar_value_var, + cached_ar_segment_id, + cache_ar_index, + ) = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) - key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) - value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) + key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) + value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) - ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value) - active_indicator = jnp.zeros((batch, 1), dtype = jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR - cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim(cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1) - cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length) + ar_key, ar_value = self.update_ar_key_value( + key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value + ) + active_indicator = ( + jnp.zeros((batch, 1), dtype=jnp.int32) + + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) + cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim( + cached_ar_segment_id.value, + active_indicator, + jnp.squeeze(cache_ar_index.value), + 1, + ) + cache_ar_index.value = jnp.mod( + cache_ar_index.value + 1, + self.max_target_length - self.max_prefill_predict_length, + ) - # Prep and return both prefill and ar caches - cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(self.max_target_length, heads, kv_head_size, self.quantize_kvcache) + # Prep and return both prefill and ar caches + ( + cached_prefill_key_var, + cached_prefill_value_var, + cached_prefill_segment_id, + ) = self._get_prefill_cache( + self.max_target_length, heads, kv_head_size, self.quantize_kvcache + ) - cached_prefill = self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), cached_prefill_segment_id.value - return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) + cached_prefill = ( + self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), + 
self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), + cached_prefill_segment_id.value, + ) + return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) def kv_cache( self, key: Array, value: Array, decoder_segment_ids: Array, - model_mode: str + model_mode: str, ) -> tuple: """KV cache takes the current state and updates the state accordingly. @@ -664,8 +867,9 @@ def kv_cache( """ if key.shape != value.shape: - raise ValueError(f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}") - + raise ValueError( + f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}" + ) if model_mode == common_types.MODEL_MODE_TRAIN: return (key, value, decoder_segment_ids), None @@ -675,12 +879,8 @@ def kv_cache( return self.kv_cache_autoregressive(key, value) else: raise ValueError(f"Model Mode isn't supported! {model_mode=}") - - - def normalize_attention(self, - local_outs, - local_maxes, - local_sums): + + def normalize_attention(self, local_outs, local_maxes, local_sums): """Normalize across multiple localized attentions Args: local_outs (list): List of unnormalized outputs entries for each local attention local_maxes (list): List of max exponentials entries for each local attention local_sums (list): List of exponential sum entries for each local attention Returns: - Array: Combined attention that has been normalized + Array: Combined attention that has been normalized """ # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py global_max = functools.reduce(jnp.maximum, local_maxes) - global_sum = sum([ - jnp.exp(local_max - global_max) * local_sum - for (local_sum, local_max) in zip(local_sums, local_maxes) - ]) + global_sum = sum( + [ + jnp.exp(local_max - global_max) * local_sum + for (local_sum, local_max) in zip(local_sums, local_maxes) + ] + ) attn_out = 0 for local_max, local_out in zip(local_maxes, local_outs): @@ -704,17 +906,22 @@ def normalize_attention(self, attn_out += local_normalizer * local_out return attn_out - @nn.compact def __call__(self, query, key, value, decoder_segment_ids, model_mode): - prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode) + prefill_kv_cache, ar_kv_cache = self.kv_cache( + key, value, decoder_segment_ids, model_mode + ) - prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention( - query=query, - key=prefill_kv_cache[0], - value=prefill_kv_cache[1], - decoder_segment_ids=prefill_kv_cache[2], - model_mode=model_mode, + ( + prefill_unnormalized_output, + prefill_exponentials_max, + prefill_exponentials_sum, + ) = self.apply_attention( + query=query, + key=prefill_kv_cache[0], + value=prefill_kv_cache[1], + decoder_segment_ids=prefill_kv_cache[2], + model_mode=model_mode, ) # Return the "prefill" cache if it is actually the combined prefill+ar kv cache @@ -723,42 +930,48 @@ def __call__(self, query, key, value, decoder_segment_ids, model_mode): return prefill_unnormalized_output / prefill_exponentials_sum return prefill_unnormalized_output - ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention( - query=query, - key=ar_kv_cache[0], - value=ar_kv_cache[1], - decoder_segment_ids=ar_kv_cache[2], - model_mode=model_mode, + ( + ar_unnormalized_output, + ar_exponentials_max, + ar_exponentials_sum, + ) = self.apply_attention( + query=query, + key=ar_kv_cache[0], + value=ar_kv_cache[1], + decoder_segment_ids=ar_kv_cache[2], + model_mode=model_mode, ) unnormalized_outputs = [prefill_unnormalized_output, ar_unnormalized_output]
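# ---- [Editor's note: hedged sketch, not part of the diff] ----
# normalize_attention above merges the prefill pass and the autoregressive
# pass with the standard log-sum-exp recombination: each partial attention
# returns an unnormalized output (exp(scores - local_max) @ V), its local
# max, and its local exponential sum; the pieces are rescaled by
# exp(local_max - global_max) before the final division. A tiny
# self-contained check (toy scores/values assumed) that this equals softmax
# over the concatenated scores:
import functools
import jax
import jax.numpy as jnp


def partial_attention(scores, values):
  local_max = jnp.max(scores)
  exps = jnp.exp(scores - local_max)
  return exps @ values, local_max, jnp.sum(exps)


def combine(local_outs, local_maxes, local_sums):
  # Same algebra as normalize_attention above.
  global_max = functools.reduce(jnp.maximum, local_maxes)
  global_sum = sum(
      jnp.exp(m - global_max) * s for s, m in zip(local_sums, local_maxes)
  )
  return sum(
      (jnp.exp(m - global_max) / global_sum) * o
      for m, o in zip(local_maxes, local_outs)
  )


scores, values = jnp.array([1.0, 3.0, 0.5, 2.0]), jnp.eye(4)
full = jax.nn.softmax(scores) @ values
halves = [partial_attention(scores[:2], values[:2]),
          partial_attention(scores[2:], values[2:])]
merged = combine(*zip(*halves))
assert jnp.allclose(full, merged, atol=1e-6)
# ---- [End editor's note] ----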
exponentials_maxes = [prefill_exponentials_max, ar_exponentials_max] exponentials_sums = [prefill_exponentials_sum, ar_exponentials_sum] - return self.normalize_attention(unnormalized_outputs, exponentials_maxes, exponentials_sums) + return self.normalize_attention( + unnormalized_outputs, exponentials_maxes, exponentials_sums + ) class Attention(nn.Module): - """ Generic Attention. - - Attributes: - num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - num_kv_heads: number of kv attention heads. - head_dim: dimension of each head. - mesh: Mesh, device mesh - attention_kernel: str, guidance on if we should use an attention kernel - dtype: the dtype of the computation. - weight_dtype: the dtype of the weights. - max_target_length: maximum target length - max_prefill_predict_length: size of the maximum prefill - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid - numerical issues with bfloat16. - float32_logits: bool, if True then cast logits to float32 before softmax to avoid - numerical issues with bfloat16. - quant: Quant, stores quantization parameters, defaults to None implying no quantization. - quantize_kvcache: bool, quantize the kv cache. + """Generic Attention. + + Attributes: + num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + num_kv_heads: number of kv attention heads. + head_dim: dimension of each head. + mesh: Mesh, device mesh + attention_kernel: str, guidance on if we should use an attention kernel + dtype: the dtype of the computation. + weight_dtype: the dtype of the weights. + max_target_length: maximum target length + max_prefill_predict_length: size of the maximum prefill + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + quant: Quant, stores quantization parameters, defaults to None implying no quantization. + quantize_kvcache: bool, quantize the kv cache. """ config: Config @@ -771,14 +984,13 @@ class Attention(nn.Module): dtype: DType = jnp.float32 weight_dtype: DType = jnp.float32 max_prefill_predict_length: int = -1 - dropout_rate: float = 0. - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") float32_qk_product: bool = False # computes logits in float32 for stability. float32_logits: bool = False # cast logits in float32 for stability. quant: Optional[Quant] = None quantize_kvcache: bool = False - query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) @@ -791,19 +1003,21 @@ def query_projection(self, inputs_q: Array) -> Array: # 1/sqrt(depth_kq)! This is folded into the initializers of the # linear transformations, which is equivalent under Adafactor. 
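# ---- [Editor's note: hedged sketch, not part of the diff] ----
# The comment above notes that the customary 1/sqrt(head_dim) query scaling
# is folded into the query kernel initializer rather than applied to
# activations. Because the projection is linear, scaling the kernel once at
# init time yields the same logits as scaling the projected query on every
# step (toy shapes assumed):
import jax
import jax.numpy as jnp

d_model, head_dim = 8, 4
w = jax.random.normal(jax.random.PRNGKey(0), (d_model, head_dim))
x = jax.random.normal(jax.random.PRNGKey(1), (2, d_model))
k = jax.random.normal(jax.random.PRNGKey(2), (2, head_dim))

logits_folded = jnp.sum((x @ (w / jnp.sqrt(head_dim))) * k, axis=-1)
logits_scaled = jnp.sum(((x @ w) / jnp.sqrt(head_dim)) * k, axis=-1)
assert jnp.allclose(logits_folded, logits_scaled, atol=1e-5)
# ---- [End editor's note] ----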
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + def query_init(*args): - #pylint: disable=no-value-for-parameter + # pylint: disable=no-value-for-parameter return self.kernel_init(*args) / depth_scaling query_proj = DenseGeneral( - features=(self.num_query_heads, self.head_dim), - axis=-1, - kernel_init=query_init, - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='query', - quant=self.quant)(inputs_q) + features=(self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=query_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="query", + quant=self.quant, + )(inputs_q) return query_proj def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: @@ -818,66 +1032,75 @@ def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: Projection of key or value, in shape of `[batch, kv_length, head_dim]`. """ if self.num_kv_heads == -1: - raise ValueError('num_kv_heads is not defined.') + raise ValueError("num_kv_heads is not defined.") if self.num_query_heads % self.num_kv_heads != 0: - raise ValueError('Invaid num_kv_heads for GQA.') + raise ValueError("Invalid num_kv_heads for GQA.") kv_proj = DenseGeneral( features=(self.num_kv_heads, self.head_dim), axis=-1, kernel_init=self.kernel_init, - kernel_axes=('embed', 'heads', 'kv'), + kernel_axes=("embed", "heads", "kv"), dtype=self.dtype, weight_dtype=self.weight_dtype, name=proj_name, - quant=self.quant)(inputs_kv) + quant=self.quant, + )(inputs_kv) return kv_proj def qkv_projection(self, inputs: Array, proj_name: str): - """ Fused QKV projection""" + """Fused QKV projection""" qkv_proj = DenseGeneral( - features=(3, self.num_query_heads, self.head_dim), - axis = -1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'qkv', 'heads', 'kv'), + features=(3, self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), dtype=self.dtype, weight_dtype=self.weight_dtype, name=proj_name, - quant=self.quant)(inputs) - qkv_proj = checkpoint_name(qkv_proj, 'qkv_proj') - query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...]
+ quant=self.quant, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = ( + qkv_proj[:, :, 0, ...], + qkv_proj[:, :, 1, ...], + qkv_proj[:, :, 2, ...], + ) return query, key, value def out_projection(self, output_dim: int, out: Array) -> Array: out_proj = DenseGeneral( - features=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='out', - quant=self.quant)(out) + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + )(out) return out_proj def key_rotary(self, key: Array, inputs_positions: Array): """Apply Rotary Embedding to key.""" - key = RotaryEmbedding( - embedding_dims=self.head_dim, - name='key_rotary')(inputs=key, position=inputs_positions) + key = RotaryEmbedding(embedding_dims=self.head_dim, name="key_rotary")( + inputs=key, position=inputs_positions + ) return key @nn.compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - inputs_positions: Array, - decoder_segment_ids: Array | None = None, - *, - model_mode: str = common_types.MODEL_MODE_TRAIN, - deterministic: bool = False): + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + inputs_positions: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): """Applies Attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, @@ -902,38 +1125,40 @@ def __call__(self, """ # apply projection. if self.config.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: query = self.query_projection(inputs_q) - key = self.kv_projection(inputs_kv, proj_name='key') - value = self.kv_projection(inputs_kv, proj_name='value') + key = self.kv_projection(inputs_kv, proj_name="key") + value = self.kv_projection(inputs_kv, proj_name="value") # apply ROPE - query = RotaryEmbedding( - embedding_dims=self.head_dim, name='query_rotary' - )(inputs=query, position=inputs_positions) + query = RotaryEmbedding(embedding_dims=self.head_dim, name="query_rotary")( + inputs=query, position=inputs_positions + ) key = self.key_rotary(key, inputs_positions) # annotate with sharding constraint. 
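# ---- [Editor's note: hedged sketch, not part of the diff] ----
# RotaryEmbedding, applied to the query above and to the key via key_rotary,
# implements standard RoPE: each half-pair of the head dimension is rotated
# by a position-dependent angle drawn from geometrically spaced wavelengths.
# A simplified stand-alone version (assumed [batch, seq, heads, dim] layout
# and the module's 10_000 max wavelength):
import jax.numpy as jnp


def rope_sketch(x, position, max_wavelength=10_000):
  dim = x.shape[-1]                               # must be even
  fraction = 2 * jnp.arange(0, dim // 2) / dim
  timescale = max_wavelength**fraction            # [dim // 2]
  angle = position[:, :, None, None] / timescale  # [b, s, 1, dim // 2]
  sin, cos = jnp.sin(angle), jnp.cos(angle)
  first, second = jnp.split(x, 2, axis=-1)
  # Rotate each (first, second) pair by its position-dependent angle.
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1
  )
# ---- [End editor's note] ----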
query = nn.with_logical_constraint(query, self.query_axis_names) - query = checkpoint_name(query, 'query_proj') + query = checkpoint_name(query, "query_proj") key = nn.with_logical_constraint(key, self.key_axis_names) - key = checkpoint_name(key, 'key_proj') + key = checkpoint_name(key, "key_proj") value = nn.with_logical_constraint(value, self.value_axis_names) - value = checkpoint_name(value, 'value_proj') - - attention_op = AttentionOp(mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - quantize_kvcache=self.quantize_kvcache, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - dropout_rate = self.dropout_rate, - dtype=self.dtype) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.quantize_kvcache, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + ) out = attention_op(query, key, value, decoder_segment_ids, model_mode) @@ -941,5 +1166,5 @@ def __call__(self, # apply output projection, output dim is set to the input dim. out = self.out_projection(inputs_q.shape[-1], out) - out = checkpoint_name(out, 'out_proj') + out = checkpoint_name(out, "out_proj") return out diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 6c954941c..8e605231e 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -32,6 +32,7 @@ _MAX_WAVELENGTH = 10_000 + class Embed(nn.Module): """A parameterized function from integers [0, n) to d-dimensional vectors. @@ -53,8 +54,8 @@ class Embed(nn.Module): def setup(self): self.embedding = self.param( - 'embedding', - with_logical_partitioning(self.embedding_init, ('vocab', 'embed')), + "embedding", + with_logical_partitioning(self.embedding_init, ("vocab", "embed")), (self.num_embeddings, self.features), self.config.weight_dtype, ) @@ -73,7 +74,7 @@ def __call__(self, inputs: Array) -> Array: if self.cast_input_dtype: inputs = inputs.astype(self.cast_input_dtype) if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') + raise ValueError("Input type must be an integer or unsigned integer.") if cfg.use_iota_embed: iota = lax.iota(jnp.int32, self.num_embeddings) @@ -82,7 +83,7 @@ def __call__(self, inputs: Array) -> Array: else: output = jnp.asarray(self.embedding, self.dtype)[inputs] output = nn.with_logical_constraint( - output, ('activation_batch', 'activation_length', 'activation_embed') + output, ("activation_batch", "activation_length", "activation_embed") ) return output @@ -123,7 +124,7 @@ class RotaryEmbedding(nn.Module): def setup(self) -> None: if self.embedding_dims % 2: raise ValueError( - 'Embedding dim for rotary position embedding must be a multiple of 2.' + "Embedding dim for rotary position embedding must be a multiple of 2." 
) def __call__( @@ -148,13 +149,13 @@ def __call__( assert position is not None if len(inputs.shape) != 4: raise ValueError( - 'Input is assumed to be a rank 4 tensor of shape' - '[batch, sequence, heads, dims].' + "Input is assumed to be a rank 4 tensor of shape " + "[batch, sequence, heads, dims]." ) if self.embedding_dims != inputs.shape[3]: raise ValueError( - 'The embedding dims of the rotary position embedding' - 'must match the hidden dimension of the inputs.' + "The embedding dims of the rotary position embedding " + "must match the hidden dimension of the inputs." ) half_embedding_dim = self.embedding_dims // 2 fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims @@ -195,7 +196,9 @@ def __call__( position = position[:, :, jnp.newaxis] inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :] scaled_time = position * inv_timescales - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = -1) + signal = jnp.concatenate( + [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1 + ) # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]]) position_embedding = signal.astype(jnp.float32) return input_embedding + position_embedding diff --git a/MaxText/layers/gemma.py b/MaxText/layers/gemma.py index cbbadf7bc..37ded1908 100644 --- a/MaxText/layers/gemma.py +++ b/MaxText/layers/gemma.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+""" from flax import linen as nn import common_types @@ -52,69 +52,78 @@ # Decoder and Model definitions class GemmaDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs, ("activation_batch", "activation_length", "activation_embed") + ) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_norm', - kernel_axes=('embed',))(inputs) + name="pre_self_attention_norm", + kernel_axes=("embed",), + )(inputs) lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx, ("activation_batch", "activation_length", "activation_embed") + ) attention_layer = Attention( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - float32_qk_product = True, - float32_logits = True, - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + float32_qk_product=True, + float32_logits=True, + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) attention_lnx = nn.with_logical_constraint( attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + ("activation_batch", "activation_length", "activation_embed"), + ) attention_lnx += inputs residual = attention_lnx attn_output = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_ffw_norm', - kernel_axes=('embed',))(attention_lnx) + name="pre_ffw_norm", + kernel_axes=("embed",), + )(attention_lnx) # MLP block. 
mlp_lnx = MlpBlock( @@ -123,12 +132,12 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(attn_output, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") ) next_layer_addition = mlp_lnx + residual @@ -140,15 +149,15 @@ def __call__(self, layer_output = next_layer_addition_dropped_out layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/gpt3.py b/MaxText/layers/gpt3.py index 518ec2912..ae72cb190 100644 --- a/MaxText/layers/gpt3.py +++ b/MaxText/layers/gpt3.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -55,12 +55,14 @@ Quant = quantizations.AqtQuantization -#----------------------------------------- +# ----------------------------------------- # The Normalization Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3LayerNorm(nn.Module): """GPT3 Layer normalization operating on the last axis of the input data.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -82,10 +84,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: features = x.shape[-1] scale = self.param( - 'scale', + "scale", nn.with_logical_partitioning(self.scale_init, self.kernel_axes), (features,), - self.weight_dtype + self.weight_dtype, ) scale = jnp.asarray(scale, self.dtype) @@ -93,40 +95,43 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: if self.use_bias: bias = self.param( - 'bias', - nn.with_logical_partitioning(initializers.default_bias_init, self.kernel_axes), - (features,), - self.weight_dtype, + "bias", + nn.with_logical_partitioning( + initializers.default_bias_init, self.kernel_axes + ), + (features,), + self.weight_dtype, ) bias = jnp.asarray(bias, self.dtype) output += bias return output -#----------------------------------------- +# ----------------------------------------- # The Attention Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3MultiHeadAttention(nn.Module): """Multi-head attention in gpt3. - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - head_dim: dimension of each head. - max_target_length: maximum length of output - max_prefill_predict_length: size of the maximum prefill - mesh: device mesh - dtype: the dtype of the computation. - dropout_rate: dropout rate - kernel_init: initializer for the kernel of the Dense layers. - float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid - numerical issues with bfloat16. - float32_logits: bool, if True then cast logits to float32 before softmax to avoid - numerical issues with bfloat16. - fused_qkv: whether to fuse query, key and value into one projection. - quant: Quant, stores quantization config, defaults to None implying no quantization. - use_bias: whether to add bias in linear transformation. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + max_target_length: maximum length of output + max_prefill_predict_length: size of the maximum prefill + mesh: device mesh + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + fused_qkv: whether to fuse query, key and value into one projection. + quant: Quant, stores quantization config, defaults to None implying no quantization. + use_bias: whether to add bias in linear transformation. """ config: Config @@ -138,8 +143,8 @@ class Gpt3MultiHeadAttention(nn.Module): attention_kernel: str dtype: DType = jnp.float32 weight_dtype: DType = jnp.float32 - dropout_rate: float = 0. 
- kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'normal') + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") float32_qk_product: bool = False # computes logits in float32 for stability. float32_logits: bool = True # cast logits in float32 for stability. fused_qkv: bool = True @@ -152,88 +157,96 @@ class Gpt3MultiHeadAttention(nn.Module): out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) def qkv_projection(self, inputs: Array, proj_name: str): - """ Fused QKV projection""" + """Fused QKV projection""" qkv_proj = DenseGeneral( - features=(3, self.num_heads, self.head_dim), - axis = -1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'qkv', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - )(inputs) - qkv_proj = checkpoint_name(qkv_proj, 'qkv_proj') - query, key, value = qkv_proj[:,:,0,...], qkv_proj[:,:,1,...], qkv_proj[:,:,2,...] + features=(3, self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = ( + qkv_proj[:, :, 0, ...], + qkv_proj[:, :, 1, ...], + qkv_proj[:, :, 2, ...], + ) return query, key, value def projection(self, inputs: Array, proj_name: str) -> Array: """individual projection for one of q, k and v.""" proj = DenseGeneral( - features=(self.num_heads, self.head_dim), - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'heads', 'kv'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name=proj_name, - quant=self.quant, - use_bias=self.use_bias, - )(inputs) + features=(self.num_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) return proj def out_projection(self, output_dim: int, out: Array) -> Array: """output projection""" out_proj = DenseGeneral( - features=output_dim, - axis=(-2, -1), - kernel_init=self.kernel_init, - kernel_axes=('heads', 'kv', 'embed'), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - name='out', - quant=self.quant, - use_bias=self.use_bias, - )(out) + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + use_bias=self.use_bias, + )(out) return out_proj @nn.compact - def __call__(self, - inputs_q: Array, - decoder_segment_ids: Array | None = None, - *, - model_mode: str = common_types.MODEL_MODE_TRAIN, - deterministic: bool = False): + def __call__( + self, + inputs_q: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): if self.fused_qkv: - query, key, value = self.qkv_projection(inputs_q, proj_name='qkv_proj') + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") else: - query = self.projection(inputs_q, proj_name='query') - key = self.projection(inputs_q, proj_name='key') - value = self.projection(inputs_q, proj_name='value') + query = self.projection(inputs_q, proj_name="query") + key = self.projection(inputs_q, proj_name="key") + value = self.projection(inputs_q, 
proj_name="value") depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) query /= depth_scaling # annotate with sharding constraint. query = nn.with_logical_constraint(query, self.query_axis_names) - query = checkpoint_name(query, 'query_proj') + query = checkpoint_name(query, "query_proj") key = nn.with_logical_constraint(key, self.key_axis_names) - key = checkpoint_name(key, 'key_proj') + key = checkpoint_name(key, "key_proj") value = nn.with_logical_constraint(value, self.value_axis_names) - value = checkpoint_name(value, 'value_proj') - - attention_op = AttentionOp(mesh=self.mesh, - attention_kernel=self.attention_kernel, - max_target_length=self.max_target_length, - float32_qk_product=self.float32_qk_product, - float32_logits=self.float32_logits, - quant=self.quant, - quantize_kvcache=self.config.quantize_kvcache, - num_query_heads=self.num_heads, - num_kv_heads=self.num_heads, - dtype=self.dtype) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.config.quantize_kvcache, + num_query_heads=self.num_heads, + num_kv_heads=self.num_heads, + dtype=self.dtype, + ) out = attention_op(query, key, value, decoder_segment_ids, model_mode) @@ -241,76 +254,84 @@ def __call__(self, # apply output projection, output dim is set to the input dim. out = self.out_projection(inputs_q.shape[-1], out) - out = checkpoint_name(out, 'out_proj') + out = checkpoint_name(out, "out_proj") return out -#----------------------------------------- +# ----------------------------------------- # The Decoder Layer specific for GPT3 -#----------------------------------------- +# ----------------------------------------- + class Gpt3DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) - + inputs, ("activation_batch", "activation_length", "activation_embed") + ) lnx_layer_norm = Gpt3LayerNorm( dtype=cfg.dtype, - name='pre_self_attention_norm', - kernel_axes=('embed',), + name="pre_self_attention_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, reductions_in_fp32=False, use_bias=True, - ) + ) lnx = lnx_layer_norm(inputs) lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx, ("activation_batch", "activation_length", "activation_embed") + ) # Self-attention block - assert cfg.num_query_heads == cfg.num_kv_heads, \ - f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" + assert ( + cfg.num_query_heads == cfg.num_kv_heads + ), f"{cfg.num_query_heads=} should be the same as {cfg.num_kv_heads=} in gpt3" attention_layer = Gpt3MultiHeadAttention( - config=cfg, - num_heads=cfg.num_query_heads, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - 
attention_kernel=cfg.attention, - mesh=mesh, - dropout_rate=cfg.dropout_rate, - name='self_attention', - fused_qkv=cfg.fused_qkv, - use_bias=True, - quant=self.quant) + config=cfg, + num_heads=cfg.num_query_heads, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dropout_rate=cfg.dropout_rate, + name="self_attention", + fused_qkv=cfg.fused_qkv, + use_bias=True, + quant=self.quant, + ) attention_lnx = attention_layer( - lnx, - decoder_segment_ids=decoder_segment_ids, - model_mode=model_mode, - deterministic=deterministic) + lnx, + decoder_segment_ids=decoder_segment_ids, + model_mode=model_mode, + deterministic=deterministic, + ) attention_lnx = nn.with_logical_constraint( attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + ("activation_batch", "activation_length", "activation_embed"), + ) attention_lnx += inputs # MLP block. @@ -320,33 +341,33 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", use_bias=True, use_pre_norm=True, config=cfg, quant=self.quant, )(attention_lnx, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") ) layer_output = attention_lnx + mlp_lnx - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + layer_output, deterministic=deterministic + ) layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/initializers.py b/MaxText/layers/initializers.py index 6f0bb9c23..83fd1ff36 100644 --- a/MaxText/layers/initializers.py +++ b/MaxText/layers/initializers.py @@ -32,7 +32,7 @@ ] default_embed_init = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 + 1.0, "fan_in", "normal", out_axis=0 ) default_bias_init = jax.nn.initializers.constant(0.0) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 4cf3d1939..8b4ed401b 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -40,18 +40,22 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization + def _convert_to_activation_function( - fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: + fn_or_string: Union[str, Callable[..., Any]] +) -> Callable[..., Any]: """Convert a string to an activation function.""" - if fn_or_string == 'linear': + if fn_or_string == "linear": return lambda x: x elif isinstance(fn_or_string, str): return getattr(nn, fn_or_string) elif callable(fn_or_string): return fn_or_string else: - raise 
ValueError(f"""Don't know how to convert {fn_or_string} - to an activation function""") + raise ValueError( + f"""Don't know how to convert {fn_or_string} + to an activation function""" + ) def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: @@ -83,7 +87,7 @@ class DenseGeneral(nn.Module): axis: Union[Iterable[int], int] = -1 weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") kernel_axes: Tuple[str, ...] = () quant: Optional[Quant] = None use_bias: bool = False @@ -106,7 +110,8 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): dot_general_cls = self.quant.dot_general_cls() dot_general = dot_general_cls() return dot_general( - inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) + inputs, kernel, ((axis, contract_ind), ((), ())), precision=None + ) features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -123,12 +128,12 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): kernel = jnp.zeros(kernel_shape) else: kernel = self.param( - 'kernel', - nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), - kernel_shape, - self.weight_dtype, - kernel_in_axis, - kernel_out_axis, + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, ) kernel = jnp.asarray(kernel, self.dtype) @@ -136,9 +141,12 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): output = compute_dot_general(inputs, kernel, axis, contract_ind) if self.use_bias: - bias_axes, bias_shape = self.kernel_axes[-len(features):], kernel_shape[-len(features):] + bias_axes, bias_shape = ( + self.kernel_axes[-len(features) :], + kernel_shape[-len(features) :], + ) bias = self.param( - 'bias', + "bias", nn.with_logical_partitioning(bias_init, bias_axes), bias_shape, self.weight_dtype, @@ -167,8 +175,8 @@ class MlpBlock(nn.Module): config: Config intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable[..., Any]]] = ('relu',) - kernel_init: NdInitializer = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + activations: Sequence[Union[str, Callable[..., Any]]] = ("relu",) + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") intermediate_dropout_rate: float = 0.1 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -181,9 +189,14 @@ def get_norm_layer(self): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 - return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=self.use_bias) + + return functools.partial( + gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=self.use_bias + ) else: - raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + raise ValueError( + f"Incorrect decoder_block name {self.config.decoder_block=}" + ) @nn.compact def __call__(self, inputs, decode: bool = False, deterministic: bool = False): @@ -192,39 +205,39 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): if self.use_pre_norm: inputs = self.get_norm_layer()( - name='mlp_layer_norm', - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - kernel_axes=('embed',), - epsilon=cfg.normalization_layer_epsilon, - )(inputs) + name="mlp_layer_norm", + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + kernel_axes=("embed",), + 
epsilon=cfg.normalization_layer_epsilon, + )(inputs) # Iterate over specified MLP input activation functions. # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. activations = [] if cfg.fused_mlp: x = DenseGeneral( - (len(self.activations), self.intermediate_dim), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - kernel_init=self.kernel_init, - kernel_axes=('embed', 'num_activations', 'mlp'), - name='wi', - quant=self.quant, - use_bias=self.use_bias, + (len(self.activations), self.intermediate_dim), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "num_activations", "mlp"), + name="wi", + quant=self.quant, + use_bias=self.use_bias, )(inputs) for idx, act_fn in enumerate(self.activations): - y = _convert_to_activation_function(act_fn)(x[:,:,idx,...]) + y = _convert_to_activation_function(act_fn)(x[:, :, idx, ...]) activations.append(y) else: for idx, act_fn in enumerate(self.activations): - dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" x = DenseGeneral( self.intermediate_dim, dtype=self.dtype, weight_dtype=self.weight_dtype, kernel_init=self.kernel_init, - kernel_axes=('embed', 'mlp'), + kernel_axes=("embed", "mlp"), name=dense_name, quant=self.quant, use_bias=self.use_bias, @@ -234,26 +247,26 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False): # Take elementwise product of above intermediate activations. x = functools.reduce(operator.mul, activations) - x = checkpoint_name(x, 'mlpwi') + x = checkpoint_name(x, "mlpwi") # Apply dropout and final dense output projection. x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( x, deterministic=deterministic ) # Broadcast along length. 
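# ---- [Editor's note: hedged sketch, not part of the diff] ----
# The activation loop above implements (optionally gated) MLPs: with
# cfg.mlp_activations = ('gelu', 'linear'), the two 'wi' branches are
# multiplied elementwise (the functools.reduce(operator.mul, ...) above),
# i.e. a gated-GELU unit. Simplified with explicit toy weight matrices
# (the names wi_0/wi_1/wo mirror the DenseGeneral layer names):
import jax
import jax.numpy as jnp


def gated_gelu_mlp_sketch(x, wi_0, wi_1, wo):
  gate = jax.nn.gelu(x @ wi_0)  # 'gelu' branch
  up = x @ wi_1                 # 'linear' branch
  return (gate * up) @ wo       # elementwise gate, then output projection
# ---- [End editor's note] ----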
x = nn.with_logical_constraint( - x, ('activation_batch', 'activation_length', 'activation_mlp') + x, ("activation_batch", "activation_length", "activation_mlp") ) output = DenseGeneral( inputs.shape[-1], dtype=self.dtype, weight_dtype=self.weight_dtype, kernel_init=self.kernel_init, - kernel_axes=('mlp', 'embed'), - name='wo', + kernel_axes=("mlp", "embed"), + name="wo", quant=self.quant, use_bias=self.use_bias, )(x) - output = checkpoint_name(output, 'mlpwo') + output = checkpoint_name(output, "mlpwo") return output @@ -278,38 +291,42 @@ class MoeBlock(nn.Module): @nn.compact def __call__(self, inputs, deterministic: bool = False): gate_logits = DenseGeneral( - self.num_experts, - dtype=self.dtype, - kernel_init=self.kernel_init, - kernel_axes=self.kernel_axes, - name='gate', - quant=self.quant,)(inputs) + self.num_experts, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + name="gate", + quant=self.quant, + )(inputs) weights, selected_experts = lax.top_k(gate_logits, self.num_experts_per_tok) weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) mlp_lnx = jnp.zeros_like(inputs) weights = weights.astype(self.dtype) mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") + ) # TODO(ranran): have a better solution to remove the loop here for k in range(self.num_experts): - weights_exp = jnp.sum(jnp.multiply(selected_experts==k, weights), axis=-1) - mlp_lnx_exp = MlpBlock( + weights_exp = jnp.sum( + jnp.multiply(selected_experts == k, weights), axis=-1 + ) + mlp_lnx_exp = MlpBlock( intermediate_dim=self.config.mlp_dim, activations=self.config.mlp_activations, intermediate_dropout_rate=self.config.dropout_rate, dtype=self.dtype, weight_dtype=self.weight_dtype, - name=f'mlp_{k}', + name=f"mlp_{k}", config=self.config, - )(inputs, deterministic=deterministic) + )(inputs, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ('activation_batch', 'activation_length', 'activation_embed') - ) - mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp - mlp_lnx += mlp_lnx_exp + mlp_lnx_exp = nn.with_logical_constraint( + mlp_lnx_exp, + ("activation_batch", "activation_length", "activation_embed"), + ) + mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp + mlp_lnx += mlp_lnx_exp return mlp_lnx diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index bc0fd1861..1f9134f34 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -43,86 +43,92 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization -#----------------------------------------- +# ----------------------------------------- # The Decoder Layer specific for Llama2 -#----------------------------------------- +# ----------------------------------------- class LlamaDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) - + inputs, ("activation_batch", "activation_length", "activation_embed") + ) lnx_rms = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_layer_norm', - kernel_axes=('embed',), + name="pre_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - ) + ) lnx = lnx_rms(inputs) lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx, ("activation_batch", "activation_length", "activation_embed") + ) # Self-attention block attention_layer = Attention( - config = cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) attention_lnx = nn.with_logical_constraint( attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + ("activation_batch", "activation_length", "activation_embed"), + ) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='post_self_attention_layer_norm', - kernel_axes=('embed',), + name="post_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) + 
)(intermediate_inputs) hidden_states = nn.with_logical_constraint( - hidden_states, - ('activation_batch', 'activation_length', 'activation_embed') - ) + hidden_states, + ("activation_batch", "activation_length", "activation_embed"), + ) # MLP block. mlp_lnx = linears.MlpBlock( @@ -131,32 +137,31 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(hidden_states, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") ) - layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + layer_output, deterministic=deterministic + ) layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 2992cf5ab..e61a79f8b 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Transformer model definition.""" # pylint: disable=arguments-differ @@ -51,153 +51,169 @@ class MistralDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: models.Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs, ("activation_batch", "activation_length", "activation_embed") + ) lnx_rms = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_layer_norm', - kernel_axes=('embed',), - epsilon=cfg.normalization_layer_epsilon - ) + name="pre_self_attention_layer_norm", + kernel_axes=("embed",), + epsilon=cfg.normalization_layer_epsilon, + ) lnx = lnx_rms(inputs) lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx, ("activation_batch", "activation_length", "activation_embed") + ) # Self-attention block attention_layer = Attention( - config = cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) attention_lnx = nn.with_logical_constraint( attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + ("activation_batch", "activation_length", "activation_embed"), + ) intermediate_inputs = inputs + attention_lnx # Fully Connected hidden_states = models.RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='post_self_attention_layer_norm', - kernel_axes=('embed',), + name="post_self_attention_layer_norm", + kernel_axes=("embed",), epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, ('activation_batch', 'activation_length', 'activation_embed')) + )(intermediate_inputs) + hidden_states = nn.with_logical_constraint( + hidden_states, + ("activation_batch", "activation_length", "activation_embed"), + ) if cfg.num_experts > 1: - # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR. 
- - # mlp_lnx = linears.MoeBlock( - # config=cfg, - # num_experts=cfg.num_experts, - # num_experts_per_tok=cfg.num_experts_per_tok, - # kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), - # kernel_axes=('embed', 'mlp'), - # dtype=cfg.dtype, - # )(hidden_states, deterministic=deterministic) - - gate_logits = linears.DenseGeneral( - cfg.num_experts, - weight_dtype=cfg.weight_dtype, - dtype=cfg.dtype, - kernel_init=initializers.nd_dense_init( - 1.0, 'fan_in', 'truncated_normal'), - kernel_axes=('embed', 'mlp'), - name="gate", - quant=self.quant, - )(hidden_states) - weights, selected_experts = jax.lax.top_k(gate_logits, cfg.num_experts_per_tok) - weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) - mlp_lnx = jnp.zeros_like(hidden_states) - weights = weights.astype(cfg.dtype) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') - ) + # TODO(ranran): currently, this MoeBlock does not work as expected, and plan to fix it in coming PR. + + # mlp_lnx = linears.MoeBlock( + # config=cfg, + # num_experts=cfg.num_experts, + # num_experts_per_tok=cfg.num_experts_per_tok, + # kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'), + # kernel_axes=('embed', 'mlp'), + # dtype=cfg.dtype, + # )(hidden_states, deterministic=deterministic) + + gate_logits = linears.DenseGeneral( + cfg.num_experts, + weight_dtype=cfg.weight_dtype, + dtype=cfg.dtype, + kernel_init=initializers.nd_dense_init( + 1.0, "fan_in", "truncated_normal" + ), + kernel_axes=("embed", "mlp"), + name="gate", + quant=self.quant, + )(hidden_states) + weights, selected_experts = jax.lax.top_k( + gate_logits, cfg.num_experts_per_tok + ) + weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1) + mlp_lnx = jnp.zeros_like(hidden_states) + weights = weights.astype(cfg.dtype) + mlp_lnx = nn.with_logical_constraint( + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") + ) - # TODO(ranran): have a better solution to remove the loop here - for k in range(cfg.num_experts): - weights_exp = jnp.sum(jnp.multiply( - selected_experts == k, weights), axis=-1) - mlp_lnx_exp = linears.MlpBlock( - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name=f'mlp_{k}', - config=cfg, - )(hidden_states, deterministic=deterministic) - mlp_lnx_exp = nn.with_logical_constraint( - mlp_lnx_exp, ('activation_batch', 'activation_length', 'activation_embed') - ) - mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp - mlp_lnx += mlp_lnx_exp - else: - mlp_lnx = linears.MlpBlock( + # TODO(ranran): have a better solution to remove the loop here + for k in range(cfg.num_experts): + weights_exp = jnp.sum( + jnp.multiply(selected_experts == k, weights), axis=-1 + ) + mlp_lnx_exp = linears.MlpBlock( intermediate_dim=cfg.mlp_dim, activations=cfg.mlp_activations, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name=f"mlp_{k}", config=cfg, )(hidden_states, deterministic=deterministic) - mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + mlp_lnx_exp = nn.with_logical_constraint( + mlp_lnx_exp, + ("activation_batch", "activation_length", "activation_embed"), ) + mlp_lnx_exp = weights_exp[:, :, None] * mlp_lnx_exp + mlp_lnx += mlp_lnx_exp + else: + mlp_lnx = linears.MlpBlock( + intermediate_dim=cfg.mlp_dim, + 
activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + )(hidden_states, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint( + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") + ) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - layer_output, deterministic=deterministic) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + layer_output, deterministic=deterministic + ) layer_output = nn.with_logical_constraint( - layer_output, ('activation_batch', 'activation_length', 'activation_embed'), + layer_output, + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 571a77a95..ca1f92309 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -41,69 +41,76 @@ PositionalEmbedding = embeddings.PositionalEmbedding Quant = quantizations.AqtQuantization -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ # The network: Decoder & Transformer Definitions -#------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ class DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + config: Config mesh: Mesh quant: Optional[Quant] = None @nn.compact - def __call__(self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ): + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): cfg = self.config mesh = self.mesh inputs = nn.with_logical_constraint( - inputs, ('activation_batch', 'activation_length', 'activation_embed')) + inputs, ("activation_batch", "activation_length", "activation_embed") + ) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] lnx = RMSNorm( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='pre_self_attention_norm', + name="pre_self_attention_norm", epsilon=cfg.normalization_layer_epsilon, - kernel_axes=('embed',))(inputs) + kernel_axes=("embed",), + )(inputs) lnx = nn.with_logical_constraint( - lnx, ('activation_batch', 'activation_length', 'activation_embed')) + lnx, ("activation_batch", "activation_length", "activation_embed") + ) attention_layer = Attention( - config = self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name='self_attention', - quant=self.quant, - quantize_kvcache=cfg.quantize_kvcache) - + config=self.config, + 
num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) attention_lnx = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode) + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) attention_lnx = nn.with_logical_constraint( attention_lnx, - ('activation_batch', 'activation_length', 'activation_embed')) + ("activation_batch", "activation_length", "activation_embed"), + ) # MLP block. mlp_lnx = linears.MlpBlock( @@ -112,12 +119,12 @@ def __call__(self, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name='mlp', + name="mlp", config=cfg, quant=self.quant, )(lnx, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint( - mlp_lnx, ('activation_batch', 'activation_length', 'activation_embed') + mlp_lnx, ("activation_batch", "activation_length", "activation_embed") ) next_layer_addition = mlp_lnx + attention_lnx @@ -129,15 +136,15 @@ def __call__(self, layer_output = next_layer_addition_dropped_out + inputs layer_output = nn.with_logical_constraint( layer_output, - ('activation_batch', 'activation_length', 'activation_embed'), + ("activation_batch", "activation_length", "activation_embed"), ) if cfg.record_internal_nn_metrics: - self.sow('intermediates', 'activation_mean', jnp.mean(layer_output)) - self.sow('intermediates', 'activation_stdev', jnp.std(layer_output)) + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( - 'intermediates', - 'activation_fraction_zero', + "intermediates", + "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) @@ -146,6 +153,7 @@ def __call__(self, class Decoder(nn.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" + config: Config shared_embedding: nn.Module mesh: Mesh @@ -156,93 +164,128 @@ def get_decoder_layer(self): return DecoderLayer elif self.config.decoder_block == "llama2": from layers import llama2 + return llama2.LlamaDecoderLayer elif self.config.decoder_block == "mistral": # TODO(ranran): update to Mistral with sliding window attention from layers import mistral + return mistral.MistralDecoderLayer elif self.config.decoder_block == "gemma": from layers import gemma + return gemma.GemmaDecoderLayer elif self.config.decoder_block == "gpt3": from layers import gpt3 + return gpt3.Gpt3DecoderLayer else: - raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + raise ValueError( + f"Incorrect decoder_block name {self.config.decoder_block=}" + ) def get_norm_layer(self): if self.config.decoder_block in ("default", "llama2", "mistral", "gemma"): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 - return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=True) + + return functools.partial( + gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=True + ) else: - raise ValueError(f"Incorrect decoder_block name 
{self.config.decoder_block=}") + raise ValueError( + f"Incorrect decoder_block name {self.config.decoder_block=}" + ) @nn.compact - def __call__(self, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=common_types.MODEL_MODE_TRAIN, - ): + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=common_types.MODEL_MODE_TRAIN, + ): cfg = self.config mesh = self.mesh assert decoder_input_tokens.ndim == 2 # [batch, len] # [batch, length] -> [batch, length, emb_dim] - y = self.shared_embedding(decoder_input_tokens.astype('int32')) - y = nn.Dropout( - rate=cfg.dropout_rate, broadcast_dims=(-2,))( - y, deterministic=deterministic) + y = self.shared_embedding(decoder_input_tokens.astype("int32")) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + y, deterministic=deterministic + ) y = y.astype(cfg.dtype) if cfg.use_untrainable_positional_embedding: - y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) + y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) if cfg.trainable_position_size > 0: y += Embed( - num_embeddings=cfg.trainable_position_size, - features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name='position_embedder', - config=cfg)(decoder_positions) + num_embeddings=cfg.trainable_position_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name="position_embedder", + config=cfg, + )(decoder_positions) BlockLayer = self.get_decoder_layer() - if cfg.remat_policy != 'none': - if cfg.remat_policy == 'minimal': + if cfg.remat_policy != "none": + if cfg.remat_policy == "minimal": policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - elif cfg.remat_policy == 'save_dot_except_mlpwi': + elif cfg.remat_policy == "save_dot_except_mlpwi": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', 'out_proj', 'mlpwo', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", ) - elif cfg.remat_policy == 'save_dot_except_mlp': + elif cfg.remat_policy == "save_dot_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', 'out_proj', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", ) - elif cfg.remat_policy == 'save_qkv_proj': + elif cfg.remat_policy == "save_qkv_proj": policy = jax.checkpoint_policies.save_only_these_names( - 'query_proj', 'value_proj', 'key_proj', 'qkv_proj', + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", ) - elif cfg.remat_policy == 'qkv_proj_offloaded': + elif cfg.remat_policy == "qkv_proj_offloaded": policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=['query_proj', 'value_proj', 'key_proj'], - offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == 'minimal_offloaded': - policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == 'minimal_flash': + names_which_can_be_saved=[], + names_which_can_be_offloaded=[ + "query_proj", + "value_proj", + "key_proj", + ], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": + policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( + offload_src="device", 
offload_dst="pinned_host" + ) + elif cfg.remat_policy == "minimal_flash": policy = jax.checkpoint_policies.save_from_both_policies( - jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - jax.checkpoint_policies.save_only_these_names('context',), + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + jax.checkpoint_policies.save_only_these_names( + "context", + ), ) else: assert ( - cfg.remat_policy == 'full' - ), 'Remat policy needs to be on list of remat policies' + cfg.remat_policy == "full" + ), "Remat policy needs to be on list of remat policies" policy = None BlockLayer = nn.remat( # pylint: disable=invalid-name BlockLayer, @@ -251,7 +294,7 @@ def __call__(self, static_argnums=(-1, -2, -3, -4, -5), ) if cfg.scan_layers: - initializing = self.is_mutable_collection('params') + initializing = self.is_mutable_collection("params") params_spec = ( cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) ) @@ -259,15 +302,15 @@ def __call__(self, y, _ = nn.scan( BlockLayer, variable_axes={ - 'params': params_spec, - 'cache': cache_spec, - 'intermediates': 0, - 'aqt':0, - '_overwrite_with_gradient': 0, + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, }, split_rngs={ - 'params': True, - 'dropout': cfg.enable_dropout, + "params": True, + "dropout": cfg.enable_dropout, }, in_axes=( nn.broadcast, @@ -276,8 +319,8 @@ def __call__(self, nn.broadcast, ), length=cfg.num_decoder_layers, - metadata_params={nn.PARTITION_NAME: 'layers'}, - )(config=cfg, mesh=mesh, name='layers', quant=self.quant)( + metadata_params={nn.PARTITION_NAME: "layers"}, + )(config=cfg, mesh=mesh, name="layers", quant=self.quant)( y, decoder_segment_ids, decoder_positions, @@ -286,8 +329,9 @@ def __call__(self, ) else: for lyr in range(cfg.num_decoder_layers): - y = BlockLayer(config=cfg, mesh=mesh, name=f'layers_{lyr}', - quant=self.quant)( + y = BlockLayer( + config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant + )( y, decoder_segment_ids, decoder_positions, @@ -296,12 +340,12 @@ def __call__(self, ) y = self.get_norm_layer()( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name='decoder_norm', - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=('embed',), - )(y) + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="decoder_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("embed",), + )(y) y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( y, deterministic=deterministic ) @@ -317,17 +361,24 @@ def __call__(self, logits = linears.DenseGeneral( cfg.vocab_size, weight_dtype=cfg.weight_dtype, - dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=('embed', 'vocab'), - name='logits_dense')(y) # We do not quantize the logits matmul. + dtype=jnp.float32 + if cfg.logits_dot_in_fp32 + else cfg.dtype, # for logit training stability + kernel_axes=("embed", "vocab"), + name="logits_dense", + )( + y + ) # We do not quantize the logits matmul. logits = nn.with_logical_constraint( - logits, ('activation_batch', 'activation_length', 'activation_vocab')) + logits, ("activation_batch", "activation_length", "activation_vocab") + ) logits = logits.astype(jnp.float32) return logits class Transformer(nn.Module): """An decoder-only Transformer model.""" + # Make new attributes required, so that all Transformer dependencies (train, decode, compile, etc) will error instead of silently use defaults. 
# pylint: disable=attribute-defined-outside-init config: Config @@ -343,15 +394,19 @@ def setup(self): num_embeddings=cfg.vocab_size, features=cfg.emb_dim, dtype=cfg.dtype, - attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability + attend_dtype=jnp.float32 + if cfg.logits_dot_in_fp32 + else cfg.dtype, # for logit training stability embedding_init=nn.initializers.normal(stddev=1.0), - name='token_embedder', + name="token_embedder", config=cfg, ) self.decoder = Decoder( - config=cfg, shared_embedding=self.shared_embedding, - mesh=mesh, quant=self.quant + config=cfg, + shared_embedding=self.shared_embedding, + mesh=mesh, + quant=self.quant, ) def __call__( @@ -360,14 +415,18 @@ def __call__( decoder_positions, decoder_segment_ids=None, enable_dropout=True, - model_mode=common_types.MODEL_MODE_TRAIN + model_mode=common_types.MODEL_MODE_TRAIN, ): """Applies Transformer decoder-branch on encoded-input and target.""" - if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + if ( + decoder_segment_ids is not None + and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE + ): raise ValueError( - f'During autoregressive decoding we assume the tokens are in the active sequence' - f' which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}.') + f"During autoregressive decoding we assume the tokens are in the active sequence" + f" which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}." + ) logits = self.decoder( decoder_input_tokens=decoder_input_tokens, diff --git a/MaxText/layers/normalizations.py b/MaxText/layers/normalizations.py index 6d451d4fe..862c586c9 100644 --- a/MaxText/layers/normalizations.py +++ b/MaxText/layers/normalizations.py @@ -26,6 +26,7 @@ class RMSNorm(nn.Module): """RMS normalization.""" + epsilon: float = 1e-6 dtype: Any = jnp.float32 weight_dtype: Any = jnp.float32 @@ -40,7 +41,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) scale = self.param( - 'scale', + "scale", nn.with_logical_partitioning(self.scale_init, self.kernel_axes), (features,), self.weight_dtype, diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index c6450ea17..60e079fdd 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -27,117 +27,127 @@ MAX_INT8 = 127.5 + @dataclass class Quantization: - """Base class for quantization configurations""" + """Base class for quantization configurations""" - def dot_general_cls(self): - """ Placeholder for dot_general implementation in subclasses. """ - pass + def dot_general_cls(self): + """Placeholder for dot_general implementation in subclasses.""" + pass @dataclass class AqtQuantization: - """ Configures AQT quantization github.com/google/aqt. """ + """Configures AQT quantization github.com/google/aqt.""" + quant_dg: aqt_config.DotGeneral quant_mode: aqt_flax.QuantMode = aqt_flax.QuantMode.TRAIN def dot_general_cls(self): - """ Returns dot_general configured with aqt params. 
""" + """Returns dot_general configured with aqt params.""" aqt_dg_cls = functools.partial( - aqt_flax.AqtDotGeneral, - self.quant_dg, - rhs_quant_mode=self.quant_mode - ) + aqt_flax.AqtDotGeneral, self.quant_dg, rhs_quant_mode=self.quant_mode + ) return aqt_dg_cls def einsum(self): - """ Returns einsum configured with aqt params """ - aqt_einsum = functools.partial(aqt_flax.AqtEinsum( - cfg=self.quant_dg, - lhs_quant_mode=self.quant_mode - ) + """Returns einsum configured with aqt params""" + aqt_einsum = functools.partial( + aqt_flax.AqtEinsum(cfg=self.quant_dg, lhs_quant_mode=self.quant_mode) ) return aqt_einsum + @dataclass class Fp8Quantization(Quantization): - """ Configures Fp8 quantization for NVIDIA GPUs""" + """Configures Fp8 quantization for NVIDIA GPUs""" + quant_mode = "train" def dot_general_cls(self): - """ Returns dot_general configured with aqt params. """ + """Returns dot_general configured with aqt params.""" return nn.Fp8DotGeneralOp + def _get_quant_config(config): """Set quantization params based on user configuration.""" - if not config.quantization or config.quantization == '': + if not config.quantization or config.quantization == "": return None elif config.quantization == "int8": if config.quantization_local_shard_count == 0: drhs_bits = None drhs_accumulator_dtype = None - drhs_local_aqt=None + drhs_local_aqt = None else: drhs_bits = 8 drhs_accumulator_dtype = jnp.int32 - drhs_local_aqt = aqt_config.LocalAqt(config.quantization_local_shard_count) + drhs_local_aqt = aqt_config.LocalAqt( + config.quantization_local_shard_count + ) return aqt_config.config_v3( - fwd_bits=8, - dlhs_bits=8, - drhs_bits=drhs_bits, - rng_type='jax.uniform', - dlhs_local_aqt=None, - drhs_local_aqt=drhs_local_aqt, - fwd_accumulator_dtype=jnp.int32, - dlhs_accumulator_dtype=jnp.int32, - drhs_accumulator_dtype=drhs_accumulator_dtype, + fwd_bits=8, + dlhs_bits=8, + drhs_bits=drhs_bits, + rng_type="jax.uniform", + dlhs_local_aqt=None, + drhs_local_aqt=drhs_local_aqt, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=drhs_accumulator_dtype, ) elif config.quantization == "fp8": return "fp8" else: - raise ValueError(f'Invalid value configured for quantization {config.quantization}.') + raise ValueError( + f"Invalid value configured for quantization {config.quantization}." 
+ ) + def in_convert_mode(quant): return quant and (quant.quant_mode == aqt_flax.QuantMode.CONVERT) + def in_serve_mode(quant): return quant and (quant.quant_mode == aqt_flax.QuantMode.SERVE) -def get_quant_mode(quant_mode_str: str = 'train'): - """ Set quant mode.""" - if quant_mode_str == 'train': + +def get_quant_mode(quant_mode_str: str = "train"): + """Set quant mode.""" + if quant_mode_str == "train": return aqt_flax.QuantMode.TRAIN - elif quant_mode_str == 'serve': + elif quant_mode_str == "serve": return aqt_flax.QuantMode.SERVE - elif quant_mode_str == 'convert': + elif quant_mode_str == "convert": return aqt_flax.QuantMode.CONVERT else: - raise ValueError(f'Invalid quantization mode {quant_mode_str}.') + raise ValueError(f"Invalid quantization mode {quant_mode_str}.") return None -def configure_quantization(config: Config, quant_mode_str: str = 'train'): - """ Configure quantization based on user config and quant mode.""" + +def configure_quantization(config: Config, quant_mode_str: str = "train"): + """Configure quantization based on user config and quant mode.""" quant_cfg = _get_quant_config(config) if quant_cfg: if quant_cfg == "fp8": - return Fp8Quantization() + return Fp8Quantization() quant_mode = get_quant_mode(quant_mode_str) return AqtQuantization(quant_dg=quant_cfg, quant_mode=quant_mode) return None + def _get_aqt_key_paths(aqt_vars): - """ Generate a list of paths which have aqt state """ + """Generate a list of paths which have aqt state""" aqt_tree_flat, _ = jax.tree_util.tree_flatten_with_path(aqt_vars) aqt_key_paths = [] for k, _ in aqt_tree_flat: pruned_keys = [] for d in list(k): - if 'AqtDotGeneral' in d.key: - pruned_keys.append(jax.tree_util.DictKey(key='kernel')) + if "AqtDotGeneral" in d.key: + pruned_keys.append(jax.tree_util.DictKey(key="kernel")) break else: - assert 'Aqt' not in d.key, f"Unexpected Aqt op {d.key} in {k}." + assert "Aqt" not in d.key, f"Unexpected Aqt op {d.key} in {k}." pruned_keys.append(d) aqt_key_paths.append(tuple(pruned_keys)) return aqt_key_paths @@ -153,16 +163,19 @@ def remove_quantized_params(params, aqt_vars): tree_flat[i] = v return tree_unflatten(tree_struct, tree_flat) + def configure_kv_quantization(config: Config): - """ Configure kv quantization based on user config.""" + """Configure kv quantization based on user config.""" return False if not config.quantize_kvcache else True + def quantize_kv(kv: Array): """Quantize key/values stored in kvcache.""" scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True) value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale))) return value, scale -def unquantize_kv(value: Array, scale:Array, dtype:jnp.dtype): + +def unquantize_kv(value: Array, scale: Array, dtype: jnp.dtype): """Unquantize key/values stored in kvcache.""" return value.astype(dtype) * scale / MAX_INT8 diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index b9a7aefe4..b053ae464 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
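# Round-trip of the quantize_kv / unquantize_kv helpers above: symmetric
# per-row int8 quantization of the KV cache with one float scale per row.
import jax.numpy as jnp

MAX_INT8 = 127.5

kv = jnp.array([[0.5, -1.0, 0.25]])
scale = jnp.max(jnp.abs(kv), axis=-1, keepdims=True)
q = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale)))
approx = q.astype(jnp.float32) * scale / MAX_INT8  # ~kv, within int8 error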
- See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" r"""Convert weights from a Llama or Mistral model to a MaxText one. @@ -42,54 +42,55 @@ import sys import os -jax.config.update('jax_platform_name', 'cpu') +jax.config.update("jax_platform_name", "cpu") + def permute_to_match_maxtext_rope(arr): evens = arr[..., ::2] odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1) + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) MODEL_PARAMS_DICT = { - 'llama2-70b': { - 'num_layers': 80, - 'num_heads': 64, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-70b": { + "num_layers": 80, + "num_heads": 64, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, }, - 'llama2-13b': { - 'num_layers': 40, - 'num_heads': 40, - 'num_kv_heads': 40, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-13b": { + "num_layers": 40, + "num_heads": 40, + "num_kv_heads": 40, + "dims_per_head": 128, + "vocab": 32000, }, - 'llama2-7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 32, - 'dims_per_head': 128, - 'vocab': 32000, + "llama2-7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 32, + "dims_per_head": 128, + "vocab": 32000, }, - 'mistral-7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, - 'base_emb_dim': 4096, - 'base_mlp_dim': 14336, + "mistral-7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, + "base_emb_dim": 4096, + "base_mlp_dim": 14336, }, - 'mixtral-8x7b': { - 'num_layers': 32, - 'num_heads': 32, - 'num_kv_heads': 8, - 'dims_per_head': 128, - 'vocab': 32000, - 'base_emb_dim': 4096, - 'base_mlp_dim': 14336, - 'num_experts': 8, + "mixtral-8x7b": { + "num_layers": 32, + "num_heads": 32, + "num_kv_heads": 8, + "dims_per_head": 128, + "vocab": 32000, + "base_emb_dim": 4096, + "base_mlp_dim": 14336, + "num_experts": 8, }, } @@ -108,256 +109,358 @@ def convert(base_model_path, maxtext_model_path, model_size): """ """Convert model to maxtext.""" model_params = MODEL_PARAMS_DICT[model_size] - base_num_decoder_layers = model_params['num_layers'] - base_num_query_heads = model_params['num_heads'] - head_dim = model_params['dims_per_head'] - base_num_kv_heads = model_params['num_kv_heads'] - vocab_size = model_params['vocab'] - num_experts = model_params['num_experts'] if 'num_experts' in model_params else None - - print(f'Loading the base model from {base_model_path}') + base_num_decoder_layers = model_params["num_layers"] + base_num_query_heads = model_params["num_heads"] + head_dim = model_params["dims_per_head"] + base_num_kv_heads = model_params["num_kv_heads"] + vocab_size = model_params["vocab"] + num_experts = ( + model_params["num_experts"] if "num_experts" in model_params else None + ) + + print(f"Loading the base model from {base_model_path}") # Skip any hidden files for checkpoints - 
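# What permute_to_match_maxtext_rope above does to the feature axis: move
# even-indexed features first and odd-indexed features second, converting an
# interleaved RoPE layout into the split-halves layout MaxText expects.
import numpy as np

w = np.arange(8)
permuted = np.concatenate((w[::2], w[1::2]), axis=-1)
print(permuted)  # [0 2 4 6 1 3 5 7]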
ckpt_paths = sorted(pathlib.Path(base_model_path).glob('[!.]*.pth')) + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth")) pytorch_vars = {} for i, ckpt_path in enumerate(ckpt_paths): - print(f'Loading checkpoint {i+1} of {len(ckpt_paths)} ...') - checkpoint = torch.load(ckpt_path, map_location='cpu') - pytorch_vars[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint + print(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") + checkpoint = torch.load(ckpt_path, map_location="cpu") + pytorch_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint pytorch_vars = [pytorch_vars[i] for i in sorted(list(pytorch_vars.keys()))] - layer_key = 'gate' if num_experts else 'mlp' + layer_key = "gate" if num_experts else "mlp" jax_weights = { - 'decoder': { - 'layers': { + "decoder": { + "layers": { layer_key: {}, - 'pre_self_attention_layer_norm': {}, - 'post_self_attention_layer_norm': {}, - 'self_attention': {}, + "pre_self_attention_layer_norm": {}, + "post_self_attention_layer_norm": {}, + "self_attention": {}, }, - 'decoder_norm': { - 'scale': pytorch_vars[0]['norm.weight'].type(torch.float16).numpy() + "decoder_norm": { + "scale": pytorch_vars[0]["norm.weight"] + .type(torch.float16) + .numpy() + }, + "logits_dense": { + "kernel": np.concatenate( + [ + var["output.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose()[:, :vocab_size] }, - 'logits_dense': { - 'kernel': np.concatenate([var['output.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose()[:, :vocab_size] - } }, - 'token_embedder': { - 'embedding': np.concatenate([var['tok_embeddings.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1)[:vocab_size, :] - - } - + "token_embedder": { + "embedding": np.concatenate( + [ + var["tok_embeddings.weight"].type(torch.float16).numpy() + for var in pytorch_vars + ], + axis=1, + )[:vocab_size, :] + }, } layer_weight = { - 'pre_self_attention_layer_norm': { - 'scale': [] - }, - 'post_self_attention_layer_norm': { - 'scale': [] - } + "pre_self_attention_layer_norm": {"scale": []}, + "post_self_attention_layer_norm": {"scale": []}, } if num_experts is None: - layer_weight['mlp'] = { - 'wi_0': { - 'kernel': [] - }, - 'wi_1': { - 'kernel': [] - }, - 'wo': { - 'kernel': [] - }, + layer_weight["mlp"] = { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, } else: - layer_weight['gate'] = { - 'kernel': [] - } + layer_weight["gate"] = {"kernel": []} for k in range(num_experts): - jax_weights['decoder']['layers'][f'mlp_{k}'] = {} - layer_weight[f'mlp_{k}'] = { - 'wi_0': { - 'kernel': [] - }, - 'wi_1': { - 'kernel': [] - }, - 'wo': { - 'kernel': [] - }, + jax_weights["decoder"]["layers"][f"mlp_{k}"] = {} + layer_weight[f"mlp_{k}"] = { + "wi_0": {"kernel": []}, + "wi_1": {"kernel": []}, + "wo": {"kernel": []}, } self_attention = { - 'query': { - 'kernel': [] - }, - 'key': { - 'kernel': [] - }, - 'value': { - 'kernel': [] - }, - 'out': { - 'kernel': [] - }, + "query": {"kernel": []}, + "key": {"kernel": []}, + "value": {"kernel": []}, + "out": {"kernel": []}, } for layer_idx in range(base_num_decoder_layers): - wq = np.concatenate([var[f'layers.{layer_idx}.attention.wq.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wk = np.concatenate([var[f'layers.{layer_idx}.attention.wk.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wv = 
np.concatenate([var[f'layers.{layer_idx}.attention.wv.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - - wq = np.reshape(wq, [base_num_query_heads * head_dim, - base_num_query_heads, head_dim]) - wk = np.reshape(wk, [base_num_query_heads * head_dim, - base_num_kv_heads, head_dim]) - wv = np.reshape(wv, [base_num_query_heads * head_dim, - base_num_kv_heads, head_dim]) + wq = np.concatenate( + [ + var[f"layers.{layer_idx}.attention.wq.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wk = np.concatenate( + [ + var[f"layers.{layer_idx}.attention.wk.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wv = np.concatenate( + [ + var[f"layers.{layer_idx}.attention.wv.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + + wq = np.reshape( + wq, [base_num_query_heads * head_dim, base_num_query_heads, head_dim] + ) + wk = np.reshape( + wk, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim] + ) + wv = np.reshape( + wv, [base_num_query_heads * head_dim, base_num_kv_heads, head_dim] + ) wq = permute_to_match_maxtext_rope(wq) wk = permute_to_match_maxtext_rope(wk) w_post = np.concatenate( [ - var[f'layers.{layer_idx}.attention.wo.weight'].type( - torch.float16).numpy() + var[f"layers.{layer_idx}.attention.wo.weight"] + .type(torch.float16) + .numpy() for var in pytorch_vars ], axis=1, ) w_post = np.reshape( - w_post, [base_num_query_heads * head_dim, base_num_query_heads, head_dim]) - - self_attention['query']['kernel'].append(wq) - self_attention['key']['kernel'].append(wk) - self_attention['value']['kernel'].append(wv) - self_attention['out']['kernel'].append(w_post) - pre_self_attention_layernorm = pytorch_vars[0][f'layers.{layer_idx}.attention_norm.weight'].type( - torch.float16).numpy() - post_self_attention_layernorm = pytorch_vars[0][f'layers.{layer_idx}.ffn_norm.weight'].type( - torch.float16).numpy() - layer_weight['pre_self_attention_layer_norm']['scale'].append( - pre_self_attention_layernorm) - layer_weight['post_self_attention_layer_norm']['scale'].append( - post_self_attention_layernorm) + w_post, + [base_num_query_heads * head_dim, base_num_query_heads, head_dim], + ) + + self_attention["query"]["kernel"].append(wq) + self_attention["key"]["kernel"].append(wk) + self_attention["value"]["kernel"].append(wv) + self_attention["out"]["kernel"].append(w_post) + pre_self_attention_layernorm = ( + pytorch_vars[0][f"layers.{layer_idx}.attention_norm.weight"] + .type(torch.float16) + .numpy() + ) + post_self_attention_layernorm = ( + pytorch_vars[0][f"layers.{layer_idx}.ffn_norm.weight"] + .type(torch.float16) + .numpy() + ) + layer_weight["pre_self_attention_layer_norm"]["scale"].append( + pre_self_attention_layernorm + ) + layer_weight["post_self_attention_layer_norm"]["scale"].append( + post_self_attention_layernorm + ) if num_experts is None: - wi_0 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w1.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wi_1 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w3.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wo = np.concatenate([var[f'layers.{layer_idx}.feed_forward.w2.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1).transpose() - layer_weight['mlp']['wi_0']['kernel'].append(wi_0) - layer_weight['mlp']['wi_1']['kernel'].append(wi_1) - 
layer_weight['mlp']['wo']['kernel'].append(wo) + wi_0 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.w1.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wi_1 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.w3.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wo = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.w2.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=1, + ).transpose() + layer_weight["mlp"]["wi_0"]["kernel"].append(wi_0) + layer_weight["mlp"]["wi_1"]["kernel"].append(wi_1) + layer_weight["mlp"]["wo"]["kernel"].append(wo) else: - gate = np.concatenate([var[f'layers.{layer_idx}.feed_forward.gate.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - layer_weight['gate']['kernel'].append(gate) + gate = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.gate.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + layer_weight["gate"]["kernel"].append(gate) for k in range(num_experts): - wi_0 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w1.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wi_1 = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w3.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=0).transpose() - wo = np.concatenate([var[f'layers.{layer_idx}.feed_forward.experts.{k}.w2.weight'].type(torch.float16).numpy() - for var in pytorch_vars], axis=1).transpose() - layer_weight[f'mlp_{k}']['wi_0']['kernel'].append(wi_0) - layer_weight[f'mlp_{k}']['wi_1']['kernel'].append(wi_1) - layer_weight[f'mlp_{k}']['wo']['kernel'].append(wo) - - self_attention['query']['kernel'] = np.array( - self_attention['query']['kernel']) - self_attention['key']['kernel'] = np.array(self_attention['key']['kernel']) - self_attention['value']['kernel'] = np.array( - self_attention['value']['kernel']) - self_attention['out']['kernel'] = np.array(self_attention['out']['kernel']) - self_attention['query']['kernel'] = np.transpose( - self_attention['query']['kernel'], axes=(1, 0, 2, 3)) - self_attention['key']['kernel'] = np.transpose( - self_attention['key']['kernel'], axes=(1, 0, 2, 3)) - self_attention['value']['kernel'] = np.transpose( - self_attention['value']['kernel'], axes=(1, 0, 2, 3)) + wi_0 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w1.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wi_1 = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w3.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=0, + ).transpose() + wo = np.concatenate( + [ + var[f"layers.{layer_idx}.feed_forward.experts.{k}.w2.weight"] + .type(torch.float16) + .numpy() + for var in pytorch_vars + ], + axis=1, + ).transpose() + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"].append(wi_0) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"].append(wi_1) + layer_weight[f"mlp_{k}"]["wo"]["kernel"].append(wo) + + self_attention["query"]["kernel"] = np.array( + self_attention["query"]["kernel"] + ) + self_attention["key"]["kernel"] = np.array(self_attention["key"]["kernel"]) + self_attention["value"]["kernel"] = np.array( + self_attention["value"]["kernel"] + ) + self_attention["out"]["kernel"] = np.array(self_attention["out"]["kernel"]) + 
self_attention["query"]["kernel"] = np.transpose( + self_attention["query"]["kernel"], axes=(1, 0, 2, 3) + ) + self_attention["key"]["kernel"] = np.transpose( + self_attention["key"]["kernel"], axes=(1, 0, 2, 3) + ) + self_attention["value"]["kernel"] = np.transpose( + self_attention["value"]["kernel"], axes=(1, 0, 2, 3) + ) # layers, base_num_query_heads * head_dim, base_num_query_heads, head_dim => # base_num_query_heads, layers,head_dim, base_num_query_heads * head_dim - self_attention['out']['kernel'] = np.transpose( - self_attention['out']['kernel'], axes=(2, 0, 3, 1)) + self_attention["out"]["kernel"] = np.transpose( + self_attention["out"]["kernel"], axes=(2, 0, 3, 1) + ) # scale the query weights - self_attention['query']['kernel'] = self_attention['query']['kernel'] / \ - np.sqrt(head_dim) + self_attention["query"]["kernel"] = self_attention["query"][ + "kernel" + ] / np.sqrt(head_dim) - jax_weights['decoder']['layers']['self_attention'] = self_attention + jax_weights["decoder"]["layers"]["self_attention"] = self_attention # self attention layer norm and swap the layer index - layer_weight['pre_self_attention_layer_norm']['scale'] = np.array( - layer_weight['pre_self_attention_layer_norm']['scale']) - layer_weight['post_self_attention_layer_norm']['scale'] = np.array( - layer_weight['post_self_attention_layer_norm']['scale']) - layer_weight['pre_self_attention_layer_norm']['scale'] = np.transpose( - layer_weight['pre_self_attention_layer_norm']['scale'], - axes=(1, 0)) - layer_weight['post_self_attention_layer_norm']['scale'] = np.transpose( - layer_weight['post_self_attention_layer_norm']['scale'], - axes=(1, 0)) - - jax_weights['decoder']['layers']['pre_self_attention_layer_norm'] = layer_weight['pre_self_attention_layer_norm'] - jax_weights['decoder']['layers']['post_self_attention_layer_norm'] = layer_weight['post_self_attention_layer_norm'] + layer_weight["pre_self_attention_layer_norm"]["scale"] = np.array( + layer_weight["pre_self_attention_layer_norm"]["scale"] + ) + layer_weight["post_self_attention_layer_norm"]["scale"] = np.array( + layer_weight["post_self_attention_layer_norm"]["scale"] + ) + layer_weight["pre_self_attention_layer_norm"]["scale"] = np.transpose( + layer_weight["pre_self_attention_layer_norm"]["scale"], axes=(1, 0) + ) + layer_weight["post_self_attention_layer_norm"]["scale"] = np.transpose( + layer_weight["post_self_attention_layer_norm"]["scale"], axes=(1, 0) + ) + + jax_weights["decoder"]["layers"][ + "pre_self_attention_layer_norm" + ] = layer_weight["pre_self_attention_layer_norm"] + jax_weights["decoder"]["layers"][ + "post_self_attention_layer_norm" + ] = layer_weight["post_self_attention_layer_norm"] if num_experts is None: - layer_weight['mlp']['wi_0']['kernel'] = np.array( - layer_weight['mlp']['wi_0']['kernel']) - layer_weight['mlp']['wi_1']['kernel'] = np.array( - layer_weight['mlp']['wi_1']['kernel']) - layer_weight['mlp']['wo']['kernel'] = np.array( - layer_weight['mlp']['wo']['kernel']) + layer_weight["mlp"]["wi_0"]["kernel"] = np.array( + layer_weight["mlp"]["wi_0"]["kernel"] + ) + layer_weight["mlp"]["wi_1"]["kernel"] = np.array( + layer_weight["mlp"]["wi_1"]["kernel"] + ) + layer_weight["mlp"]["wo"]["kernel"] = np.array( + layer_weight["mlp"]["wo"]["kernel"] + ) # swap the layer index - layer_weight['mlp']['wi_0']['kernel'] = np.transpose( - layer_weight['mlp']['wi_0']['kernel'], axes=(1, 0, 2)) - layer_weight['mlp']['wi_1']['kernel'] = np.transpose( - layer_weight['mlp']['wi_1']['kernel'], axes=(1, 0, 2)) - 
layer_weight['mlp']['wo']['kernel'] = np.transpose( - layer_weight['mlp']['wo']['kernel'], axes=(1, 0, 2)) - - jax_weights['decoder']['layers']['mlp'] = layer_weight['mlp'] + layer_weight["mlp"]["wi_0"]["kernel"] = np.transpose( + layer_weight["mlp"]["wi_0"]["kernel"], axes=(1, 0, 2) + ) + layer_weight["mlp"]["wi_1"]["kernel"] = np.transpose( + layer_weight["mlp"]["wi_1"]["kernel"], axes=(1, 0, 2) + ) + layer_weight["mlp"]["wo"]["kernel"] = np.transpose( + layer_weight["mlp"]["wo"]["kernel"], axes=(1, 0, 2) + ) + + jax_weights["decoder"]["layers"]["mlp"] = layer_weight["mlp"] else: - layer_weight['gate']['kernel'] = np.array(layer_weight['gate']['kernel']) - layer_weight['gate']['kernel'] = np.transpose( - layer_weight['gate']['kernel'], axes=(1, 0, 2)) - jax_weights['decoder']['layers']['gate'] = layer_weight['gate'] + layer_weight["gate"]["kernel"] = np.array(layer_weight["gate"]["kernel"]) + layer_weight["gate"]["kernel"] = np.transpose( + layer_weight["gate"]["kernel"], axes=(1, 0, 2) + ) + jax_weights["decoder"]["layers"]["gate"] = layer_weight["gate"] for k in range(num_experts): - layer_weight[f'mlp_{k}']['wi_0']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wi_0']['kernel']) - layer_weight[f'mlp_{k}']['wi_1']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wi_1']['kernel']) - layer_weight[f'mlp_{k}']['wo']['kernel'] = np.array( - layer_weight[f'mlp_{k}']['wo']['kernel']) + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.array( + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] + ) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.array( + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] + ) + layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.array( + layer_weight[f"mlp_{k}"]["wo"]["kernel"] + ) # swap the layer index - layer_weight[f'mlp_{k}']['wi_0']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wi_0']['kernel'], axes=(1, 0, 2)) - layer_weight[f'mlp_{k}']['wi_1']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wi_1']['kernel'], axes=(1, 0, 2)) - layer_weight[f'mlp_{k}']['wo']['kernel'] = np.transpose( - layer_weight[f'mlp_{k}']['wo']['kernel'], axes=(1, 0, 2)) - - jax_weights['decoder']['layers'][f'mlp_{k}'] = layer_weight[f'mlp_{k}'] + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"] = np.transpose( + layer_weight[f"mlp_{k}"]["wi_0"]["kernel"], axes=(1, 0, 2) + ) + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"] = np.transpose( + layer_weight[f"mlp_{k}"]["wi_1"]["kernel"], axes=(1, 0, 2) + ) + layer_weight[f"mlp_{k}"]["wo"]["kernel"] = np.transpose( + layer_weight[f"mlp_{k}"]["wo"]["kernel"], axes=(1, 0, 2) + ) + + jax_weights["decoder"]["layers"][f"mlp_{k}"] = layer_weight[f"mlp_{k}"] mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis") - s1=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec("checkpoint_sharding_axis")) #shards first axis - s2=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec(None,"checkpoint_sharding_axis")) #shards second axis - s3=jax.sharding.NamedSharding(mesh,jax.sharding.PartitionSpec(None)) #no sharding + s1 = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis") + ) # shards first axis + s2 = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis") + ) # shards second axis + s3 = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(None) + ) # no sharding def checkpoint_device_put(arr): - if arr.shape[0]%SIMULATED_CPU_DEVICES_COUNT==0: + if arr.shape[0] % SIMULATED_CPU_DEVICES_COUNT == 0: print("sharding first axis") return jax.device_put(arr, 
device=s1) - elif len(arr.shape)>1 and arr.shape[1]%SIMULATED_CPU_DEVICES_COUNT==0: + elif len(arr.shape) > 1 and arr.shape[1] % SIMULATED_CPU_DEVICES_COUNT == 0: print("sharding second axis") return jax.device_put(arr, device=s2) else: @@ -377,38 +480,43 @@ def checkpoint_device_put(arr): maxtext_model_path, enable_checkpointing, async_checkpointing, - save_interval_steps + save_interval_steps, ) state_new = train_state.TrainState( step=0, apply_fn=None, - params={'params': jax_weights}, + params={"params": jax_weights}, tx=None, # type: ignore - opt_state={} + opt_state={}, ) if checkpoint_manager is not None: - if save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new): + if save_checkpoint( + checkpoint_manager, step_number_to_save_new_ckpt, state_new + ): max_logging.log( - f"saved a checkpoint at step {step_number_to_save_new_ckpt}") + f"saved a checkpoint at step {step_number_to_save_new_ckpt}" + ) # Upon preemption, exit when and only when all ongoing saves are complete. if checkpoint_manager.reached_preemption(0): checkpoint_manager.wait_until_finished() sys.exit() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--base-model-path', type=str, required=True) - parser.add_argument('--maxtext-model-path', type=str, required=True) - parser.add_argument('--model-size', type=str, required=True) + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--maxtext-model-path", type=str, required=True) + parser.add_argument("--model-size", type=str, required=True) args = parser.parse_args() if args.model_size not in MODEL_PARAMS_DICT: raise NotImplementedError - os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}' + os.environ[ + "XLA_FLAGS" + ] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}" convert(args.base_model_path, args.maxtext_model_path, args.model_size) diff --git a/MaxText/max_logging.py b/MaxText/max_logging.py index 61d984f7b..23f7cc5cf 100644 --- a/MaxText/max_logging.py +++ b/MaxText/max_logging.py @@ -1,20 +1,21 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Stub for logging utilities. 
Right now just meant to avoid raw prints""" + def log(user_str): - print(user_str, flush = True) + print(user_str, flush=True) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 92b789bb4..7854e8c65 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Common Max Utils needed by multiple modules""" import checkpointing @@ -44,17 +44,26 @@ from google.cloud import storage + def find_nans_and_infs(pytree): def finder(x): return jnp.any(jnp.isinf(x) | jnp.isnan(x)) + bad_pytree = jax.tree_map(finder, pytree) return jax.tree_util.tree_flatten(bad_pytree) + def l2norm_pytree(x): """L2 norm of a pytree of arrays.""" - return jax.tree_util.tree_reduce( - lambda x, y: x + jax.numpy.sum(jax.numpy.square(y)), x, initializer=0.0 - ) ** 0.5 + return ( + jax.tree_util.tree_reduce( + lambda x, y: x + jax.numpy.sum(jax.numpy.square(y)), + x, + initializer=0.0, + ) + ** 0.5 + ) + def calculate_num_params_from_pytree(params): params_sizes = jax.tree_util.tree_map(jax.numpy.size, params) @@ -69,69 +78,89 @@ def calculate_leaf_params_per_chip(arr): return np.prod(shard.data.shape) params_sizes_per_chip = jax.tree_util.tree_map( - calculate_leaf_params_per_chip, params) + calculate_leaf_params_per_chip, params + ) total_parameters_per_chip = jax.tree_util.tree_reduce( - lambda x, y: x + y, params_sizes_per_chip) + lambda x, y: x + y, params_sizes_per_chip + ) return total_parameters_per_chip def calculate_bytes_from_pytree(params): - params_bytes = jax.tree_util.tree_map(lambda x : x.nbytes, params) + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) return total_bytes + def summarize_size_from_pytree(params): num_params = calculate_num_params_from_pytree(params) num_bytes = calculate_bytes_from_pytree(params) - return num_params, num_bytes, num_bytes/num_params + return num_params, num_bytes, num_bytes / num_params + def activate_profiler(config, optional_postfix=""): - if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0): + if config.enable_profiler and ( + config.upload_all_profiler_results or jax.process_index() == 0 + ): output_path = os.path.join(config.tensorboard_dir, optional_postfix) jax.profiler.start_trace(output_path) + def deactivate_profiler(config): - if config.enable_profiler and 
(config.upload_all_profiler_results or jax.process_index() == 0): + if config.enable_profiler and ( + config.upload_all_profiler_results or jax.process_index() == 0 + ): jax.profiler.stop_trace() + def initialize_summary_writer(config): - return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None + return ( + writer.SummaryWriter(config.tensorboard_dir) + if jax.process_index() == 0 + else None + ) + def close_summary_writer(summary_writer): if jax.process_index() == 0: summary_writer.close() + def _prepare_metrics_for_json(metrics, step, run_name): """Converts metric dictionary into json supported types (e.g. float)""" metrics_dict = {} - for val in metrics['scalar']: - metrics_dict[val] = float(metrics['scalar'][val]) - metrics_dict['step'] = float(step) - metrics_dict['run_name'] = run_name + for val in metrics["scalar"]: + metrics_dict[val] = float(metrics["scalar"][val]) + metrics_dict["step"] = float(step) + metrics_dict["run_name"] = run_name return metrics_dict + def write_metrics_locally(metrics, step, config, file): """Writes metrics locally for testing""" if step == 0: file.truncate(0) metrics_dict = _prepare_metrics_for_json(metrics, step, config.run_name) - file.write(str(json.dumps(metrics_dict))+'\n') + file.write(str(json.dumps(metrics_dict)) + "\n") if step == config.steps - 1: file.close() + def add_config_to_summary_writer(config, summary_writer): """Writes config params to tensorboard""" if jax.process_index() == 0: for key, value in config.get_keys().items(): add_text_to_summary_writer(key, str(value), summary_writer) + def add_text_to_summary_writer(key, value, summary_writer): """Writes given key-value pair to tensorboard as text/summary""" if jax.process_index() == 0: summary_writer.add_text(key, value) + def write_metrics_for_gcs(metrics, step, config, running_metrics): """Writes metrics to gcs""" metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) @@ -139,18 +168,19 @@ def write_metrics_for_gcs(metrics, step, config, running_metrics): if (step + 1) % config.log_period == 0 or step == config.steps - 1: start_step = (step // config.log_period) * config.log_period metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" - with open(metrics_filename, 'w', encoding="utf8") as metrics_for_gcs: + with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: for metrics_step in running_metrics: - metrics_for_gcs.write(str(json.dumps(metrics_step))+'\n') + metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n") metrics_for_gcs.close() - gcs_filename=os.path.join(config.metrics_dir, metrics_filename) + gcs_filename = os.path.join(config.metrics_dir, metrics_filename) max_logging.log(f"Moving file {metrics_filename} to GCS...") upload_blob(gcs_filename, metrics_filename) max_logging.log(f"File {metrics_filename} moved successfully!") - running_metrics = [] # reset running_metrics to empty list + running_metrics = [] # reset running_metrics to empty list return running_metrics + def write_config_raw_keys_for_gcs(raw_keys): """Writes config raw keys to GCS""" if not raw_keys["save_config_to_gcs"] or jax.process_index() != 0: @@ -159,21 +189,25 @@ def write_config_raw_keys_for_gcs(raw_keys): raw_keys_dict = dict(raw_keys) filename = "config.yml" - with open(filename, 'w', encoding="utf8") as config_for_gcs: + with open(filename, "w", encoding="utf8") as config_for_gcs: yaml.dump(raw_keys_dict, config_for_gcs) config_for_gcs.close() - gcs_filename=os.path.join(raw_keys["base_output_directory"], 
raw_keys["run_name"], filename) + gcs_filename = os.path.join( + raw_keys["base_output_directory"], raw_keys["run_name"], filename + ) max_logging.log(f"Moving file {filename} to GCS...") upload_blob(gcs_filename, filename) max_logging.log(f"File {filename} moved successfully!") + def parse_gcs_bucket_and_prefix(destination_gcs_name): path_parts = destination_gcs_name.replace("gs://", "").split("/") bucket = path_parts.pop(0) key = "/".join(path_parts) return bucket, key + def upload_blob(destination_gcs_name, source_file_name): """Uploads a file to a GCS location""" bucket_name, prefix_name = parse_gcs_bucket_and_prefix(destination_gcs_name) @@ -182,28 +216,37 @@ def upload_blob(destination_gcs_name, source_file_name): blob = bucket.blob(prefix_name) blob.upload_from_filename(source_file_name) + def maybe_initialize_jax_distributed_system(raw_keys): - """ The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of - indirection in MaxText to avoid breaking the call sites unnecessarily. + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. - Currently jax.distributed.initialize() fully works as expected! + Currently jax.distributed.initialize() fully works as expected! - For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ - if (raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] - and raw_keys["compile_topology_num_slices"]==-1) or raw_keys["hardware"]=='gpu_multiprocess': + if ( + raw_keys["enable_checkpointing"] + and raw_keys["async_checkpointing"] + and raw_keys["compile_topology_num_slices"] == -1 + ) or raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") jax.distributed.initialize() max_logging.log("Jax distributed system initialized!") elif is_gpu_backend(raw_keys): - max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") + max_logging.log( + "Attempting to initialize the jax distributed system for GPU backend..." + ) initialize_jax_for_gpu() max_logging.log("Jax distributed system initialized on GPU!") elif is_cpu_backend(raw_keys): - max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + max_logging.log( + "Attempting to initialize the jax distributed system for CPU backend..." + ) initialize_jax_for_cpu() max_logging.log("Jax distributed system initialized on CPUs!") + def initialize_jax_for_gpu(): """Jax distributed initialize for GPUs.""" if os.environ.get("JAX_COORDINATOR_IP") is not None: @@ -212,14 +255,17 @@ def initialize_jax_for_gpu(): jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", num_processes=int(os.getenv("NNODES")), - process_id=int(os.getenv("NODE_RANK"))) + process_id=int(os.getenv("NODE_RANK")), + ) max_logging.log(f"JAX global devices: {jax.devices()}") + def initialize_jax_for_cpu(): - """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready. - """ + """Jax distributed initialize for CPUs. 
Includes retries until the coordinator is ready.""" coordinator_ip_address = get_coordinator_ip_address() - coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + coordinator_address = ( + coordinator_ip_address + ":1234" + ) # JAX coordinator port used in XPK # Env variables to be set in XPK or otherwise job_index = int(os.environ.get("JOB_INDEX")) job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) @@ -227,17 +273,22 @@ pid = job_index * processes_in_job + job_completion_index max_logging.log(f" Jax process id is {pid} ") # Explicit initialize is needed only for CPUs - jax.distributed.initialize(coordinator_address=coordinator_address, - process_id=pid, - num_processes=int(os.environ.get("JAX_PROCESS_COUNT"))) + jax.distributed.initialize( + coordinator_address=coordinator_address, + process_id=pid, + num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), + ) + def is_cpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a CPU backend.""" - return raw_keys["hardware"] == 'cpu' + return raw_keys["hardware"] == "cpu" + def is_gpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a GPU backend.""" - return raw_keys["hardware"] == 'gpu' + return raw_keys["hardware"] == "gpu" + def get_coordinator_ip_address(): """Get coordinator IP Address with retries""" @@ -260,49 +311,75 @@ max_logging.log(f"Coordinator IP address: {coordinator_ip_address}") return coordinator_ip_address -def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type): + +def fill_unspecified_mesh_axes( + parallelism_vals, target_product, parallelism_type +): """Evaluates unspecified DCN/ICI parallelism values""" if -1 in parallelism_vals: - assert parallelism_vals.count(-1) == 1, f"Found unspecified values (-1) for more than one {parallelism_type}\ + assert ( + parallelism_vals.count(-1) == 1 + ), f"Found unspecified values (-1) for more than one {parallelism_type}\ parallelism axis. At most one axis can be unspecified." - determined_val = target_product/np.product(parallelism_vals)*-1 + determined_val = target_product / np.product(parallelism_vals) * -1 - assert determined_val >= 1 and determined_val.is_integer, f"Unspecified value unable to be determined with the given\ + assert ( + determined_val >= 1 and determined_val.is_integer() + ), f"Unspecified value unable to be determined with the given\ {parallelism_type} parallelism values" parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) - target_type = "slices" if parallelism_type == 'DCN' else "devices per slice" + target_type = "slices" if parallelism_type == "DCN" else "devices per slice" - assert np.product(parallelism_vals) == target_product, f"Number of {target_type} {target_product} does not match\ + assert ( + np.product(parallelism_vals) == target_product + ), f"Number of {target_type} {target_product} does not match\ the product of the {parallelism_type} parallelism {np.product(parallelism_vals)}" return parallelism_vals + def create_device_mesh(config, devices=None): - """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas """ + """Creates a device mesh with each slice in its own data parallel group. 
If there is only one slice, uses two replicas""" if devices is None: devices = jax.devices() num_devices = len(devices) num_slices = config.num_slices - num_devices_per_slice = num_devices//num_slices + num_devices_per_slice = num_devices // num_slices multi_slice_env = num_slices > 1 - dcn_parallelism = [config.dcn_data_parallelism, config.dcn_fsdp_parallelism, - config.dcn_fsdp_transpose_parallelism, config.dcn_sequence_parallelism, - config.dcn_tensor_parallelism, config.dcn_autoregressive_parallelism] - ici_parallelism = [config.ici_data_parallelism, config.ici_fsdp_parallelism, - config.ici_fsdp_transpose_parallelism, config.ici_sequence_parallelism, - config.ici_tensor_parallelism, config.ici_autoregressive_parallelism] + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_fsdp_transpose_parallelism, + config.dcn_sequence_parallelism, + config.dcn_tensor_parallelism, + config.dcn_autoregressive_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_fsdp_transpose_parallelism, + config.ici_sequence_parallelism, + config.ici_tensor_parallelism, + config.ici_autoregressive_parallelism, + ] # Find possible unspecified parallelisms - ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, 'ICI') + ici_parallelism = fill_unspecified_mesh_axes( + ici_parallelism, num_devices_per_slice, "ICI" + ) if multi_slice_env: - dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, 'DCN') - mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) + dcn_parallelism = fill_unspecified_mesh_axes( + dcn_parallelism, num_slices, "DCN" + ) + mesh = mesh_utils.create_hybrid_device_mesh( + ici_parallelism, dcn_parallelism, devices + ) else: mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) @@ -310,41 +387,43 @@ return mesh -def unbox_logicallypartioned( - boxed_pytree): - """ Unboxes the flax.LogicallyPartitioned pieces - Args: - boxed_pytree: a pytree that includes LogicallyPartitioned - leaves. - Returns: - a pytree where all all LogicallyPartitioned leaves have been unboxed. +def unbox_logicallypartioned(boxed_pytree): + """Unboxes the flax.LogicallyPartitioned pieces + + Args: + boxed_pytree: a pytree that includes LogicallyPartitioned + leaves. + Returns: + a pytree where all LogicallyPartitioned leaves have been unboxed. 
""" - return jax.tree_util.tree_map(lambda x: x.unbox() if \ - isinstance(x, flax.linen.spmd.LogicallyPartitioned) \ - else x, boxed_pytree, \ - is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned)) + return jax.tree_util.tree_map( + lambda x: x.unbox() + if isinstance(x, flax.linen.spmd.LogicallyPartitioned) + else x, + boxed_pytree, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + def init_decode_state(apply_fn, params): """Init train state with null opt state for decode.""" state = train_state.TrainState( - step=0, - apply_fn=apply_fn, - params=params, - tx=None, # type: ignore - opt_state={} - ) + step=0, + apply_fn=apply_fn, + params=params, + tx=None, # type: ignore + opt_state={}, + ) return state + def init_training_state(apply_fn, params, tx): """Init train state for training.""" - state = train_state.TrainState.create( - apply_fn=apply_fn, - params=params, - tx=tx - ) + state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state + def init_initial_state(model, tx, config, is_training, key): """ We pass in "static" objects like model, tx, config as JAX compares them by
Args: @@ -397,33 +499,39 @@ def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_ state_mesh_annotations: the mesh annotations for the train state """ - unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state(model, tx, config, - rng, mesh, is_training) + ( + unboxed_abstract_state, + state_mesh_annotations, + state_mesh_shardings, + ) = get_abstract_state(model, tx, config, rng, mesh, is_training) # Initialization with nn_partitioning.axis_rules(config.logical_axis_rules): - restored, raw_params = checkpointing.load_state_if_possible(checkpoint_manager, - data_iterator, - config.load_parameters_path, - config.load_full_state_path, - unboxed_abstract_state, - config.enable_single_replica_ckpt_restoring, - config.dataset_type, - ) + restored, raw_params = checkpointing.load_state_if_possible( + checkpoint_manager, + data_iterator, + config.load_parameters_path, + config.load_full_state_path, + unboxed_abstract_state, + config.enable_single_replica_ckpt_restoring, + config.dataset_type, + ) if restored: - if 'iter' in restored and restored['iter'] is not None: - data_iterator.local_iterator = restored['iter'] - state = restored['items'] + if "iter" in restored and restored["iter"] is not None: + data_iterator.local_iterator = restored["iter"] + state = restored["items"] else: - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + init_state_partial = functools.partial( + init_initial_state, model, tx, config, is_training + ) state = jax.jit( init_state_partial, in_shardings=None, - out_shardings=state_mesh_shardings + out_shardings=state_mesh_shardings, )(rng) - if raw_params: # If we loaded a partial state, we need to merge it. - state = state.replace(params = raw_params) + if raw_params: # If we loaded a partial state, we need to merge it. + state = state.replace(params=raw_params) state = unbox_logicallypartioned(state) return state, state_mesh_annotations, data_iterator @@ -432,6 +540,7 @@ def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_ # Learning Rate Schedule # ----------------------------------------------------------------------------- + def create_learning_rate_schedule(config): """Creates a warmup and cosine decay learning rate schedule: We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 @@ -441,34 +550,36 @@ def create_learning_rate_schedule(config): 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. The zero learning rate section can be used to more accurately measure the fully trained model's performance. 
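For example (illustrative numbers only, not config defaults): with learning_rate_schedule_steps=1000, warmup_steps_fraction=0.1 and steps=1200, steps 0-100 warm up linearly from 0 to learning_rate, steps 100-1000 follow the cosine from learning_rate down to cosine_learning_rate_final_fraction * learning_rate, and the remaining 200 steps run at a learning rate of 0.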
""" + def make_cos_schedule(init_lr, final_lr, len_steps): def schedule(step): pct = (step) / len_steps - a = 0.5 * (jnp.cos(jnp.pi*pct) + 1) + a = 0.5 * (jnp.cos(jnp.pi * pct) + 1) lr = init_lr * a + final_lr * (1 - a) return lr + return schedule lr = config.learning_rate cos_final_lr = lr * config.cosine_learning_rate_final_fraction - warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction) + warmup_steps = int( + config.learning_rate_schedule_steps * config.warmup_steps_fraction + ) cos_steps = config.learning_rate_schedule_steps - warmup_steps constant_zero_steps = config.steps - config.learning_rate_schedule_steps warmup_schedule = optax.linear_schedule( - init_value=0.0, - end_value=lr, - transition_steps=warmup_steps + init_value=0.0, end_value=lr, transition_steps=warmup_steps ) cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps) constant_schedule = optax.constant_schedule(0.0) pieces = [warmup_schedule, cos_schedule] - boundaries=[ - warmup_steps, - warmup_steps + cos_steps, - ] + boundaries = [ + warmup_steps, + warmup_steps + cos_steps, + ] if constant_zero_steps > 0: pieces.append(constant_schedule) @@ -480,8 +591,9 @@ def schedule(step): # Cross entropy implementation is taken from original T5X codebase: # https://github.com/google-research/t5x/blob/ace831eea1e2742b4299cd1a9af7e4f302038351/t5x/losses.py#L25-L101 @jax.custom_vjp -def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, - z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: +def cross_entropy_with_logits( + logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float +) -> Tuple[jnp.ndarray, jnp.ndarray]: """Computes cross entropy loss with stable custom gradient. Computes a stabilized-gradient version of: -jnp.sum(targets * nn.log_softmax(logits), axis=-1) @@ -511,12 +623,19 @@ def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, def _cross_entropy_with_logits_fwd( - logits: jnp.ndarray, - targets: jnp.ndarray, - z_loss: float = 0.0 -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], - Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, jnp.ndarray, jnp.ndarray]]: + logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0 +) -> Tuple[ + Tuple[jnp.ndarray, jnp.ndarray], + Tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + ], +]: """Forward-mode of `cross_entropy_with_logits`.""" max_logit = logits.max(axis=-1, keepdims=True) shifted = logits - max_logit @@ -528,72 +647,102 @@ def _cross_entropy_with_logits_fwd( log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) total_z_loss = z_loss * jax.lax.square(log_z) loss += total_z_loss - return (loss, total_z_loss), (logits, targets, z_loss, exp_shifted, sum_exp, #pytype: disable=bad-return-type #jax-ndarray - log_softmax, log_z) + return (loss, total_z_loss), ( + logits, + targets, + z_loss, + exp_shifted, + sum_exp, # pytype: disable=bad-return-type #jax-ndarray + log_softmax, + log_z, + ) def _cross_entropy_with_logits_bwd( - res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, - jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray] + res: Tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + ], + g: Tuple[jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Backward-mode of `cross_entropy_with_logits`.""" g = g[0] # Ignore z_loss component as that is only used for logging. 
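  # Derivation note (editorial sketch, not in the original file; assumes targets sum to 1):
  #   loss = -sum(targets * log_softmax) + z_loss * log_z**2, so
  #   d(loss)/d(logits) = softmax * (1 + 2 * z_loss * log_z) - targets,
  # where softmax = exp_shifted / sum_exp; this is exactly the `deriv` computed below.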
logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res # z-loss term adds the (2 * z_loss * log_z) factor. deriv = ( - jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - - targets) + jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp + - targets + ) g_logits = jnp.expand_dims(g, axis=-1) * deriv g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax - return (jnp.asarray(g_logits, - logits.dtype), jnp.asarray(g_targets, targets.dtype), - jnp.array(0.0)) # sets z-loss coeff gradient to 0 + return ( + jnp.asarray(g_logits, logits.dtype), + jnp.asarray(g_targets, targets.dtype), + jnp.array(0.0), + ) # sets z-loss coeff gradient to 0 + + +cross_entropy_with_logits.defvjp( + _cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd +) -cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, - _cross_entropy_with_logits_bwd) def get_abstract_state(model, tx, config, rng, mesh, is_training=True): - """ Get a shaped abstraction of the state (including optimizer)""" - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + """Get a shaped abstraction of the state (including optimizer)""" + init_state_partial = functools.partial( + init_initial_state, model, tx, config, is_training + ) with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial, rng) state_logical_annotations = nn.get_partition_spec(abstract_state) - state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, - config.logical_axis_rules) + state_mesh_shardings = nn.logical_to_mesh_sharding( + state_logical_annotations, mesh, config.logical_axis_rules + ) abstract_sharded_state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings + init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings ).eval_shape(rng) - unboxed_abstract_sharded_state = unbox_logicallypartioned(abstract_sharded_state) + unboxed_abstract_sharded_state = unbox_logicallypartioned( + abstract_sharded_state + ) # Initialization with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) - return unboxed_abstract_sharded_state, state_mesh_annotations, state_mesh_shardings + return ( + unboxed_abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + def get_kv_cache_annotations(model, config, rng, mesh): - """ Get a shaped abstraction of the state (including optimizer)""" + """Get a shaped abstraction of the KV cache""" def init_kv_cache(model, config): input_shape = ( - config.global_batch_size_to_load, - config.max_prefill_predict_length + config.global_batch_size_to_load, + config.max_prefill_predict_length, ) - model_vars = model.init({'params': rng, 'dropout': rng, 'aqt': rng}, - jnp.ones(input_shape), - jnp.ones(input_shape), - model_mode=common_types.MODEL_MODE_PREFILL) - return model_vars['cache'] + model_vars = model.init( + {"params": rng, "dropout": rng, "aqt": rng}, + jnp.ones(input_shape), + jnp.ones(input_shape), + model_mode=common_types.MODEL_MODE_PREFILL, + ) + return model_vars["cache"] with nn_partitioning.axis_rules(config.logical_axis_rules): - init_kv_cache_partial = functools.partial(init_kv_cache, model, - config) + init_kv_cache_partial = functools.partial(init_kv_cache, model, config) abstract_state = jax.eval_shape(init_kv_cache_partial) state_logical_annotations = 
nn.get_partition_spec(abstract_state) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): @@ -604,17 +753,23 @@ def init_kv_cache(model, config): def print_pytree_shape(print_str, ptree): print("\n") print(print_str) - print(jax.tree_util.tree_map(lambda x : x.shape, ptree)) + print(jax.tree_util.tree_map(lambda x: x.shape, ptree)) + def print_model_vars(print_str, model_vars): for k in model_vars: - print(f'{print_str} key{k}:') - print(f'\t {model_vars[k]}') + print(f"{print_str} key{k}:") + print(f"\t {model_vars[k]}") + def get_project(): - completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) - project_outputs = completed_command.stdout.decode().strip().split('\n') - if len(project_outputs) < 1 or project_outputs[-1]=='': - max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '") + completed_command = subprocess.run( + ["gcloud", "config", "get", "project"], check=True, capture_output=True + ) + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": + max_logging.log( + "You must specify config.vertex_tensorboard_project or set 'gcloud config set project '" + ) return None return project_outputs[-1] diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index 275ffac47..d8120e577 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -''' Implementation of Engine API for MaxText ''' +"""Implementation of Engine API for MaxText""" import functools from typing import Any, Optional, Tuple @@ -39,10 +39,10 @@ Params = Any - @struct.dataclass class DecodeState: """The inputs into a generation step.""" + prefill_cache: jax.Array generate_cache: jax.Array generate_cache_index: int @@ -66,7 +66,7 @@ def __init__(self, config): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - self.model = models.Transformer(config, mesh = self._mesh, quant=quant) + self.model = models.Transformer(config, mesh=self._mesh, quant=quant) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -76,47 +76,72 @@ def __init__(self, config): self.state_mesh_annotations = None def load_params(self, *args, **kwargs) -> Params: - ''' Load Parameters, typically from GCS ''' + """Load Parameters, typically from GCS""" # pylint: disable=unused-argument state, self.state_mesh_annotations = max_utils.setup_decode_state( - self.model, self.config, self.rng, self._mesh, None + self.model, self.config, self.rng, self._mesh, None + ) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct( + shape=x.shape, dtype=x.dtype, sharding=x.sharding + ), + state.params, + ) + self.kv_cache_annotations = max_utils.get_kv_cache_annotations( + self.model, self.config, self.rng, self._mesh + ) + self.kv_cache_shardings = jax.tree_map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.kv_cache_annotations, ) - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - state.params) - self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh) - self.kv_cache_shardings = jax.tree_map(lambda x : jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations) if not self.model.quant: - 
self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - state.params) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct( + shape=x.shape, dtype=x.dtype, sharding=x.sharding + ), + state.params, + ) return state.params else: - self.model.quant.quant_mode = quantizations.get_quant_mode('convert') + self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @jax.jit def model_apply(_p, _rng): return self.model.apply( - _p | {"aqt": {}}, - jnp.ones( (1, self.config.max_prefill_predict_length), dtype=jnp.int32), - jnp.ones( (1, self.config.max_prefill_predict_length), dtype=jnp.int32), - decoder_segment_ids=jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32), - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': _rng}, - mutable=True + _p | {"aqt": {}}, + jnp.ones( + (1, self.config.max_prefill_predict_length), dtype=jnp.int32 + ), + jnp.ones( + (1, self.config.max_prefill_predict_length), dtype=jnp.int32 + ), + decoder_segment_ids=jnp.zeros( + (1, self.config.max_prefill_predict_length), dtype=jnp.int32 + ), + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": _rng}, + mutable=True, ) _, new_vars = model_apply(state.params, self.rng) params = {} - params['aqt'] = new_vars['aqt'] + params["aqt"] = new_vars["aqt"] # Remove param values which have corresponding qtensors in aqt to save memory. - params['params'] = quantizations.remove_quantized_params(state.params['params'], new_vars['aqt']) + params["params"] = quantizations.remove_quantized_params( + state.params["params"], new_vars["aqt"] + ) - self.abstract_params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), - params) + self.abstract_params = jax.tree_map( + lambda x: jax.ShapeDtypeStruct( + shape=x.shape, dtype=x.dtype, sharding=x.sharding + ), + params, + ) - self.model.quant.quant_mode = quantizations.get_quant_mode('serve') + self.model.quant.quant_mode = quantizations.get_quant_mode("serve") return params @functools.partial(jax.jit, static_argnums=(0,)) @@ -143,61 +168,78 @@ def prefill( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") - input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) zero_to_n = jnp.arange(0, padded_tokens.shape[0]) ones_to_keep = zero_to_n < true_length - one_d_output = ones_to_keep * common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + one_d_output = ( + ones_to_keep * common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) sequence_indicator = jnp.expand_dims(one_d_output, 0) with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): flat_logits, new_vars = self.model.apply( - params, - input_tokens, - positions, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': self.rng}, - mutable=["cache"] + params, + input_tokens, + positions, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": self.rng}, + mutable=["cache"], ) - next_pos = jnp.full((1,1), true_length, dtype = jnp.int32) - generated_tokens = jnp.zeros((1,1), dtype = jnp.int32) - selected_logits = jax.lax.dynamic_slice(flat_logits, (0, true_length-1,0), - 
(flat_logits.shape[0], 1, flat_logits.shape[2])) - selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding) - return {"logits" : selected_logits, "cache" : new_vars['cache'], - "next_pos" : next_pos, "generated_tokens" : generated_tokens} + next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32) + generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32) + selected_logits = jax.lax.dynamic_slice( + flat_logits, + (0, true_length - 1, 0), + (flat_logits.shape[0], 1, flat_logits.shape[2]), + ) + selected_logits = jax.lax.with_sharding_constraint( + selected_logits, self.replicated_sharding + ) + return { + "logits": selected_logits, + "cache": new_vars["cache"], + "next_pos": next_pos, + "generated_tokens": generated_tokens, + } @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,)) def generate( self, params: Params, decode_state: DecodeState ) -> Tuple[DecodeState, engine_api.ResultTokens]: - '''Run one generate step''' - previous_logits = decode_state['logits'] - - new_token = inference_utils.sampling(previous_logits, self.rng, self.config.decode_sampling_strategy, - topk=self.config.decode_sampling_top_k, - nucleus_topp=self.config.decode_sampling_nucleus_p, - temperature=self.config.decode_sampling_temperature) + """Run one generate step""" + previous_logits = decode_state["logits"] + + new_token = inference_utils.sampling( + previous_logits, + self.rng, + self.config.decode_sampling_strategy, + topk=self.config.decode_sampling_top_k, + nucleus_topp=self.config.decode_sampling_nucleus_p, + temperature=self.config.decode_sampling_temperature, + ) with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): out_logits, new_vars = self.model.apply( - params | { 'cache': decode_state['cache']}, - new_token, - decode_state['next_pos'], - enable_dropout=False, - model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'params': self.rng}, - mutable=['cache'] + params | {"cache": decode_state["cache"]}, + new_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": self.rng}, + mutable=["cache"], ) all_valid = jnp.ones(new_token.shape, dtype=jnp.int8) result = engine_api.ResultTokens( - data=jnp.concatenate((new_token, all_valid, decode_state["generated_tokens"]), axis=1), + data=jnp.concatenate( + (new_token, all_valid, decode_state["generated_tokens"]), axis=1 + ), # Tokens are shape [batch, speculations], so when we concatenate # tokens, validity and length along their index 1 dimension then they # occupy 0:speculations. 
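To make the layout comment above concrete, here is a minimal standalone sketch with made-up values (batch of 2, speculations = 1; dtypes simplified to int32, whereas the real validity mask is int8):

import jax.numpy as jnp

new_token = jnp.array([[17], [42]], dtype=jnp.int32)       # sampled ids, [batch, speculations]
all_valid = jnp.ones((2, 1), dtype=jnp.int32)              # validity flags
generated_tokens = jnp.array([[3], [3]], dtype=jnp.int32)  # tokens generated so far

data = jnp.concatenate((new_token, all_valid, generated_tokens), axis=1)
# data[:, 0:1] -> tokens, data[:, 1:2] -> validity, data[:, 2] -> per-slot length:
# [[17  1  3]
#  [42  1  3]]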
@@ -209,66 +251,129 @@ def generate( samples_per_slot=1, ) - out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) - new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) - - return {"logits" : out_logits, "cache" : new_cache, - "next_pos" : decode_state["next_pos"]+1, "generated_tokens" : decode_state["generated_tokens"]+1}, result + out_logits = jax.lax.with_sharding_constraint( + out_logits, self.replicated_sharding + ) + new_cache = jax.lax.with_sharding_constraint( + new_vars["cache"], self.kv_cache_shardings + ) - @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(1, 2,)) + return { + "logits": out_logits, + "cache": new_cache, + "next_pos": decode_state["next_pos"] + 1, + "generated_tokens": decode_state["generated_tokens"] + 1, + }, result + + @functools.partial( + jax.jit, + static_argnums=(0,), + donate_argnums=( + 1, + 2, + ), + ) def insert( self, prefix: Prefix, decode_state: DecodeState, slot: int, ) -> DecodeState: - ''' Insert into KV cache ''' + """Insert into KV cache""" unboxed_prefix = max_utils.unbox_logicallypartioned(prefix) def copy(path, partial_cache, full_cache, annotations): path_key = path[-1].key - if path_key in ['cache_ar_index', 'cached_ar_key', 'cached_ar_value', 'cached_ar_key_scale', 'cached_ar_value_scale']: - return full_cache # we don't even zero these out because we can mask them out. - - batch_idx = annotations.index("cache_batch") if "cache_batch" in annotations else -1 + if path_key in [ + "cache_ar_index", + "cached_ar_key", + "cached_ar_value", + "cached_ar_key_scale", + "cached_ar_value_scale", + ]: + return full_cache # we don't even zero these out because we can mask them out. + + batch_idx = ( + annotations.index("cache_batch") + if "cache_batch" in annotations + else -1 + ) if batch_idx < 0: - raise ValueError(f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}") + raise ValueError( + f"Batch index {batch_idx=} shouldn't be less than zero for {path_key}, got {annotations=}" + ) - if path_key == 'cache_ar_segment_id': + if path_key == "cache_ar_segment_id": ### goal: zero this out in case there is existing data s = list(full_cache.shape) s[batch_idx] = 1 zeros = jnp.zeros(tuple(s), dtype=jnp.int32) - return jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx) - elif path_key == 'cache_prefill_segment_id': + return jax.lax.dynamic_update_index_in_dim( + full_cache, zeros, slot, batch_idx + ) + elif path_key == "cache_prefill_segment_id": s = list(full_cache.shape) s[batch_idx] = 1 zeros = jnp.zeros(tuple(s), dtype=jnp.int32) ## zero out in case prefill cache is too small to cover - full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx) + full_cache = jax.lax.dynamic_update_index_in_dim( + full_cache, zeros, slot, batch_idx + ) ## copy prefill cache - full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx) + full_cache = jax.lax.dynamic_update_index_in_dim( + full_cache, partial_cache, 
slot, batch_idx + ) else: raise ValueError(f"We don't have a strategy for inserting {path_key}") - inserted_cache = jax.tree_util.tree_map_with_path(copy, unboxed_prefix['cache'], decode_state['cache'], - self.kv_cache_annotations_named) - inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state['logits'], unboxed_prefix['logits'], slot, 0) - inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state['next_pos'], unboxed_prefix['next_pos'], slot, 0) - inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(decode_state['generated_tokens'], - unboxed_prefix['generated_tokens'], slot, 0) + inserted_cache = jax.tree_util.tree_map_with_path( + copy, + unboxed_prefix["cache"], + decode_state["cache"], + self.kv_cache_annotations_named, + ) + inserted_logits = jax.lax.dynamic_update_index_in_dim( + decode_state["logits"], unboxed_prefix["logits"], slot, 0 + ) + inserted_next_pos = jax.lax.dynamic_update_index_in_dim( + decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0 + ) + inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim( + decode_state["generated_tokens"], + unboxed_prefix["generated_tokens"], + slot, + 0, + ) - inserted_logits = jax.lax.with_sharding_constraint(inserted_logits, self.replicated_sharding) - inserted_generated_tokens = jax.lax.with_sharding_constraint(inserted_generated_tokens, self.replicated_sharding) - inserted_next_pos = jax.lax.with_sharding_constraint(inserted_next_pos, self.replicated_sharding) - inserted_cache = jax.lax.with_sharding_constraint(inserted_cache, self.kv_cache_shardings) + inserted_logits = jax.lax.with_sharding_constraint( + inserted_logits, self.replicated_sharding + ) + inserted_generated_tokens = jax.lax.with_sharding_constraint( + inserted_generated_tokens, self.replicated_sharding + ) + inserted_next_pos = jax.lax.with_sharding_constraint( + inserted_next_pos, self.replicated_sharding + ) + inserted_cache = jax.lax.with_sharding_constraint( + inserted_cache, self.kv_cache_shardings + ) - return {'logits' : inserted_logits, 'cache' : inserted_cache, - 'next_pos' : inserted_next_pos, 'generated_tokens' : inserted_generated_tokens } + return { + "logits": inserted_logits, + "cache": inserted_cache, + "next_pos": inserted_next_pos, + "generated_tokens": inserted_generated_tokens, + } def get_prefix_destination_sharding(self) -> Any: return jax.sharding.NamedSharding( @@ -277,32 +382,52 @@ def get_prefix_destination_sharding(self) -> Any: def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters: """Return a protobuf of tokenizer info, callable from Py or C++.""" - return tokenizer_pb2.TokenizerParameters(path=self.config.tokenizer_path, extra_ids=0) + return tokenizer_pb2.TokenizerParameters( + path=self.config.tokenizer_path, extra_ids=0 + ) def init_decode_state(self, *args, **kwargs) -> DecodeState: """Initialises any state which a generation step transforms.""" + # pylint: disable=unused-argument def init(abstract_params): - x = jnp.ones( (int(self.config.per_device_batch_size * jax.device_count()), self.config.max_prefill_predict_length), - dtype=jnp.int32) + x = jnp.ones( + ( + int(self.config.per_device_batch_size * jax.device_count()), + self.config.max_prefill_predict_length, + ), + dtype=jnp.int32, + ) _, cache = self.model.apply( - abstract_params, - x, - x, - decoder_segment_ids=jnp.zeros(x.shape, dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, - enable_dropout=False, - model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'params': self.rng}, - mutable=["cache"] + 
abstract_params, + x, + x, + decoder_segment_ids=jnp.zeros(x.shape, dtype=jnp.int32) + + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"params": self.rng}, + mutable=["cache"], ) - next_pos = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32) - generated_tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32) - return {"logits" : jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)), - "cache" : cache["cache"], - "next_pos" : next_pos, - "generated_tokens" : generated_tokens - } + next_pos = jnp.zeros( + (int(self.config.per_device_batch_size * jax.device_count()), 1), + dtype=jnp.int32, + ) + generated_tokens = jnp.zeros( + (int(self.config.per_device_batch_size * jax.device_count()), 1), + dtype=jnp.int32, + ) + return { + "logits": jnp.zeros(( + int(self.config.per_device_batch_size * jax.device_count()), + 1, + self.config.vocab_size, + )), + "cache": cache["cache"], + "next_pos": next_pos, + "generated_tokens": generated_tokens, + } with nn_partitioning.axis_rules(self.config.logical_axis_rules): abstract_outputs = jax.eval_shape(init, self.abstract_params) @@ -311,18 +436,27 @@ def init(abstract_params): with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): mesh_annotations = nn.logical_to_mesh(logical_annotations) - shardings = jax.tree_map(lambda mesh_annotation : jax.sharding.NamedSharding(self._mesh, mesh_annotation), - mesh_annotations) + shardings = jax.tree_map( + lambda mesh_annotation: jax.sharding.NamedSharding( + self._mesh, mesh_annotation + ), + mesh_annotations, + ) - @functools.partial(jax.jit, out_shardings = shardings) + @functools.partial(jax.jit, out_shardings=shardings) def initialize(): - return jax.tree_map( lambda x : jnp.zeros(x.shape, x.dtype), abstract_outputs) + return jax.tree_map( + lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs + ) - cache = initialize()['cache'] + cache = initialize()["cache"] def is_lp(k): return isinstance(k, flax.linen.spmd.LogicallyPartitioned) - self.kv_cache_annotations_named = jax.tree_util.tree_map(lambda x : tuple(x.names), cache, is_leaf=is_lp) + + self.kv_cache_annotations_named = jax.tree_util.tree_map( + lambda x: tuple(x.names), cache, is_leaf=is_lp + ) del cache zeroed = max_utils.unbox_logicallypartioned(initialize()) return zeroed diff --git a/MaxText/maxengine_config.py b/MaxText/maxengine_config.py index da96c06f9..b2862fd89 100644 --- a/MaxText/maxengine_config.py +++ b/MaxText/maxengine_config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
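An aside on the allocate-by-shape pattern that init_decode_state above relies on: shapes are first computed abstractly with jax.eval_shape, and the buffers are then materialized by a jitted zero-initializer so they are created directly with the desired shardings. A minimal sketch under hypothetical shapes (single device, so out_shardings is omitted):

import jax
import jax.numpy as jnp

def make_state():  # stand-in for the real cache/state initializer
  return {
      "logits": jnp.zeros((8, 1, 32000)),
      "next_pos": jnp.zeros((8, 1), dtype=jnp.int32),
  }

abstract = jax.eval_shape(make_state)  # ShapeDtypeStructs only; nothing allocated yet

@jax.jit  # on a real mesh: functools.partial(jax.jit, out_shardings=shardings)
def initialize():
  return jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract)

state = initialize()  # buffers created on device by the jitted computation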
-'''Configure MaxText For JetStream''' +"""Configure MaxText For JetStream""" import functools import jax @@ -22,26 +22,29 @@ from jetstream.engine import engine_api import maxengine -def create_maxengine(devices: config_lib.Devices, config: Any) -> engine_api.Engine: + +def create_maxengine( + devices: config_lib.Devices, config: Any +) -> engine_api.Engine: del devices return maxengine.MaxEngine(config) -def get_server_config(config_str: str, config: Any) -> Type[config_lib.ServerConfig]: - ''' Gets the Server Config Required by JetStream ''' +def get_server_config( + config_str: str, config: Any +) -> Type[config_lib.ServerConfig]: + """Gets the Server Config Required by JetStream""" match config_str: - case 'MaxtextInterleavedServer': + case "MaxtextInterleavedServer": server_config = config_lib.ServerConfig( - prefill_slices = (), - generate_slices = (), - interleaved_slices = ('tpu='+str(jax.device_count()),), - prefill_engine_create_fns = (), - generate_engine_create_fns = (), - interleaved_engine_create_fns = (functools.partial( - create_maxengine, - config=config + prefill_slices=(), + generate_slices=(), + interleaved_slices=("tpu=" + str(jax.device_count()),), + prefill_engine_create_fns=(), + generate_engine_create_fns=(), + interleaved_engine_create_fns=( + functools.partial(create_maxengine, config=config), ), - ) ) case _: raise NotImplementedError diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py index 3cf46c8a5..0972714f3 100644 --- a/MaxText/maxengine_server.py +++ b/MaxText/maxengine_server.py @@ -19,7 +19,7 @@ import sys import pyconfig -import maxengine_config +import maxengine_config from jetstream.core import server_lib # _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -36,7 +36,9 @@ def main(config): # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() - server_config = maxengine_config.get_server_config('MaxtextInterleavedServer', config) + server_config = maxengine_config.get_server_config( + "MaxtextInterleavedServer", config + ) # We separate credential from run so that we can unit test it with # local credentials. # TODO: Add grpc credentials for OSS. @@ -49,8 +51,8 @@ def main(config): jetstream_server.wait_for_termination() -if __name__ == '__main__': - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') +if __name__ == "__main__": + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(sys.argv) cfg = pyconfig.config diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 6b7a81b03..f72f0c311 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting to MaxText. """ @@ -28,45 +28,77 @@ from input_pipeline import input_pipeline_interface - -def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config): - """ Get the shardings (both state and data) for train_step """ +def get_functional_train_with_signature( + train_step, mesh, state_mesh_annotations, model, config +): + """Get the shardings (both state and data) for train_step""" functional_train = get_functional_train_step(train_step, model, config) functional_train.__name__ = "train_step" data_pspec = P(*config.data_sharding) state_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations + ) data_sharding = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng - out_shardings = (state_mesh_shardings, None) # State, metrics - static_argnums = () # We partial out the static argnums of model and config - donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. - return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums + lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec + ) + in_shardings = ( + state_mesh_shardings, + data_sharding, + None, + ) # State, batch, rng + out_shardings = (state_mesh_shardings, None) # State, metrics + static_argnums = () # We partial out the static argnums of model and config + donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
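+  # Illustrative usage (hypothetical call site, not asserted by this change):
+  # these values are meant to be handed straight to jax.jit, roughly
+  #   p_train_step = jax.jit(functional_train, in_shardings=in_shardings,
+  #                          out_shardings=out_shardings,
+  #                          static_argnums=static_argnums,
+  #                          donate_argnums=donate_argnums)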
+ return ( + functional_train, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + ) + def get_functional_train_step(train_step, model, config): return functools.partial(train_step, model, config) -def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config): - """ Get the shardings (both state and data) for eval_step """ + +def get_functional_eval_with_signature( + eval_step, mesh, state_mesh_annotations, model, config +): + """Get the shardings (both state and data) for eval_step""" functional_eval = get_functional_eval_step(eval_step, model, config) functional_eval.__name__ = "eval_step" data_pspec = P(*config.data_sharding) state_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations + ) data_sharding = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec + ) - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng - out_shardings = None # metrics - static_argnums = () # We partial out the static argnums of model, config - donate_argnums = () # state will be kept instead of being donated in eval_step - return functional_eval, in_shardings, out_shardings, static_argnums, donate_argnums + in_shardings = ( + state_mesh_shardings, + data_sharding, + None, + ) # State, batch, rng + out_shardings = None # metrics + static_argnums = () # We partial out the static argnums of model, config + donate_argnums = () # state will be kept instead of being donated in eval_step + return ( + functional_eval, + in_shardings, + out_shardings, + static_argnums, + donate_argnums, + ) + def get_functional_eval_step(eval_step, model, config): return functools.partial(eval_step, model, config) + def load_compiled(config, partial_train, state): - """ # Loading a serialized compiled train step function.""" + """# Loading a serialized compiled train step function.""" + # Currently partial_train and state are needed to reconstruct # input/output shapes to construct the in_trees and out_trees for load API # Parker is working on serializing these def load_serialized_compiled(save_name): return serialized_compiled def get_train_input_output_trees(func, input_args, input_kwargs): - _, in_tree_recreated = jax.tree_util.tree_flatten((input_args, input_kwargs)) + _, in_tree_recreated = jax.tree_util.tree_flatten( + (input_args, input_kwargs) + ) out_shaped = jax.eval_shape(func, *input_args, **input_kwargs) _, out_tree_recreated = jax.tree_util.tree_flatten(out_shaped) return in_tree_recreated, out_tree_recreated @@ -86,44 +120,78 @@ def get_train_input_output_trees(func, input_args, input_kwargs): example_rng = jax.random.PRNGKey(0) shaped_input_args = (state, shaped_batch, example_rng) shaped_input_kwargs = {} - in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs) + in_tree, out_tree = get_train_input_output_trees( + partial_train, shaped_input_args, shaped_input_kwargs + ) p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree) return p_train_step + # https://arxiv.org/pdf/2204.02311.pdf Appendix B -def calculate_tflops_training_per_device(num_model_parameters, config, log=True): - """ Calculate training TFLOP""" - learnable_weight_tflops = 6 * num_model_parameters * config.max_target_length * config.per_device_batch_size \ - / 10**12 - 
noncasual_attention_flops = 12 * config.num_query_heads * config.num_decoder_layers * config.head_dim \ - * config.max_target_length**2 * config.per_device_batch_size / 10**12 - causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention +def calculate_tflops_training_per_device( + num_model_parameters, config, log=True +): + """Calculate training TFLOPs""" + learnable_weight_tflops = ( + 6 + * num_model_parameters + * config.max_target_length + * config.per_device_batch_size + / 10**12 + ) + noncausal_attention_flops = ( + 12 + * config.num_query_heads + * config.num_decoder_layers + * config.head_dim + * config.max_target_length**2 + * config.per_device_batch_size + / 10**12 + ) + causal_attention_tflops = ( + noncausal_attention_flops / 2 + ) # due to causality in attention total_tflops = learnable_weight_tflops + causal_attention_tflops if log: - print('Per train step:\n', - f'Total TFLOPs: {total_tflops:.2f} \n', - f'split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops', - f'and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops') + print( + "Per train step:\n", + f"Total TFLOPs: {total_tflops:.2f} \n", + f"split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops", + f"and {100 * causal_attention_tflops/total_tflops:.2f}% attention flops", + ) return total_tflops, learnable_weight_tflops, causal_attention_tflops + # https://arxiv.org/pdf/2204.02311.pdf Appendix B -def calculate_tflops_prefill(num_model_parameters, prefill_length, config, log=True): - """ Calculate training TFLOP""" - learnable_weight_tflops = 2 * num_model_parameters * prefill_length \ - / 10**12 - noncasual_attention_flops = 4 * config.num_query_heads * config.num_decoder_layers * config.head_dim \ - * prefill_length**2 * config.per_device_batch_size / 10**12 - causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention +def calculate_tflops_prefill( + num_model_parameters, prefill_length, config, log=True +): + """Calculate prefill TFLOPs""" + learnable_weight_tflops = 2 * num_model_parameters * prefill_length / 10**12 + noncausal_attention_flops = ( + 4 + * config.num_query_heads + * config.num_decoder_layers + * config.head_dim + * prefill_length**2 + * config.per_device_batch_size + / 10**12 + ) + causal_attention_tflops = ( + noncausal_attention_flops / 2 + ) # due to causality in attention total_tflops = learnable_weight_tflops + causal_attention_tflops if log: - print('Per prefill step: \n', - f'\tTotal TFLOPs: {total_tflops:.2f} \n', - f'\t\tLearnable weight TFLOPs: {learnable_weight_tflops} ', - f'({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n', - f'\t\tCausal attention TFLOPs: {causal_attention_tflops} ', - f'({100 * causal_attention_tflops/total_tflops:.2f})% of Total') + print( + "Per prefill step: \n", + f"\tTotal TFLOPs: {total_tflops:.2f} \n", + f"\t\tLearnable weight TFLOPs: {learnable_weight_tflops} ", + f"({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n", + f"\t\tCausal attention TFLOPs: {causal_attention_tflops} ", + f"({100 * causal_attention_tflops/total_tflops:.2f})% of Total", + ) return total_tflops, learnable_weight_tflops, causal_attention_tflops @@ -144,22 +212,21 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01): bool: True if the majority of parameters are sufficiently sharded """ total_num_params = max_utils.calculate_num_params_from_pytree(params) - product_num_devices_for_weight_sharding = 1 - for axis in 
['fsdp', 'fsdp_transpose', 'sequence', 'tensor']: + product_num_devices_for_weight_sharding = 1 + for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor"]: product_num_devices_for_weight_sharding *= mesh.shape[axis] - total_num_params_per_chip = ( - max_utils.calculate_total_params_per_chip( - params) - ) + total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) perfectly_sharded_params_per_chip = ( - total_num_params / product_num_devices_for_weight_sharding + total_num_params / product_num_devices_for_weight_sharding ) assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( - 'Number of parameters per chip must not be less than in the ideal sharded ' - 'scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes.' - ) + "Number of parameters per chip must not be less than in the ideal sharded " + "scenario accross `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes." + ) assert ( - total_num_params_per_chip/perfectly_sharded_params_per_chip - 1 < tolerance - ), (f'Number of unsharded parameters exceeds tolerance {tolerance * 100}% ' - 'of total parameters.') - + total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 + < tolerance + ), ( + f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " + "of total parameters." + ) diff --git a/MaxText/multihost_dataloading.py b/MaxText/multihost_dataloading.py index 8c9088961..9fbf0c645 100644 --- a/MaxText/multihost_dataloading.py +++ b/MaxText/multihost_dataloading.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=unused-import """SPMD Multihost Dataloading Utilities. 
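For intuition, here is a back-of-the-envelope check of the Appendix-B training-FLOPs estimate reformatted above. This is a minimal sketch: the model shape is hypothetical (roughly a 7B-parameter decoder) and none of these values come from a config in this change.

# Hypothetical 7B-class model; illustrative values only.
num_params = 7e9     # num_model_parameters
seq_len = 2048       # config.max_target_length
batch = 1            # config.per_device_batch_size
heads, layers, head_dim = 32, 32, 128

# 6 * P * L * B matmul FLOPs for the learnable weights, in TFLOPs.
learnable_weight_tflops = 6 * num_params * seq_len * batch / 10**12   # ~86.0
# Quadratic-in-L attention term, halved because the mask is causal.
noncausal = 12 * heads * layers * head_dim * seq_len**2 * batch / 10**12
causal_attention_tflops = noncausal / 2                               # ~3.3
print(learnable_weight_tflops + causal_attention_tflops)              # ~89.3 TFLOPs per device per step

Note how the attention term is small relative to the weight term at this sequence length; it grows quadratically with max_target_length while the weight term grows only linearly.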
@@ -36,6 +36,7 @@ import max_logging + def _build_global_shape_and_sharding( local_shape: tuple[int, ...], global_mesh: Mesh ) -> tuple[tuple[int, ...], NamedSharding]: @@ -47,26 +48,32 @@ def _build_global_shape_and_sharding( def _form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array: - """ Put local sharded array into local devices - """ - global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh) + """Put local sharded array into local devices""" + global_shape, sharding = _build_global_shape_and_sharding( + np.shape(array), global_mesh + ) try: - local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0) + local_device_arrays = np.split( + array, len(global_mesh.local_devices), axis=0 + ) except ValueError as array_split_error: raise ValueError( - f"Unable to put to devices shape {array.shape} with " - f"local device count {len(global_mesh.local_devices)} " - f"at {jtu.keystr(path)}" + f"Unable to put to devices shape {array.shape} with " + f"local device count {len(global_mesh.local_devices)} " + f"at {jtu.keystr(path)}" ) from array_split_error - local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices) - return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers) - + local_device_buffers = jax.device_put( + local_device_arrays, global_mesh.local_devices + ) + return jax.make_array_from_single_device_arrays( + global_shape, sharding, local_device_buffers + ) def get_next_batch_sharded( - local_iterator: Iterator, global_mesh: Mesh + local_iterator: Iterator, global_mesh: Mesh ) -> jax.Array: """Splits the host loaded data equally over all devices.""" @@ -88,14 +95,21 @@ def get_next_batch_sharded( if not loaded_data_success: local_data = next(local_iterator) - input_gdas = jtu.tree_map_with_path(partial(_form_global_array, global_mesh = global_mesh), local_data) + input_gdas = jtu.tree_map_with_path( + partial(_form_global_array, global_mesh=global_mesh), local_data + ) return input_gdas class MultiHostDataLoadIterator: """fold get_next_batch_sharded into a iterator class""" - def __init__(self, dataloader: Union[tf.data.Dataset, grain.DataLoader], global_mesh: Mesh): + + def __init__( + self, + dataloader: Union[tf.data.Dataset, grain.DataLoader], + global_mesh: Mesh, + ): self.global_mesh = global_mesh self.dataloader = dataloader if isinstance(self.dataloader, tf.data.Dataset): @@ -103,7 +117,9 @@ def __init__(self, dataloader: Union[tf.data.Dataset, grain.DataLoader], global_ elif isinstance(self.dataloader, grain.DataLoader): self.local_iterator = iter(self.dataloader) else: - raise ValueError("Type error: dataloader should be either tf.data.Dataset or grain.DataLoader.") + raise ValueError( + "Type error: dataloader should be either tf.data.Dataset or grain.DataLoader." + ) def reset(self): if isinstance(self.dataloader, tf.data.Dataset): @@ -111,7 +127,9 @@ def reset(self): elif isinstance(self.dataloader, grain.DataLoader): self.local_iterator = iter(self.dataloader) else: - raise ValueError("Type error: dataloader should be either tf.data.Dataset or grain.DataLoader.") + raise ValueError( + "Type error: dataloader should be either tf.data.Dataset or grain.DataLoader." 
+ ) def __iter__(self): self.reset() diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py index e2a23abda..9b8300dfd 100644 --- a/MaxText/optimizers.py +++ b/MaxText/optimizers.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator, ungrouped-imports """Utils that are only interesting to MaxText. """ @@ -29,25 +29,26 @@ def get_optimizer(config, learning_rate_schedule): if config.opt_type == "adamw": # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 return optax.adamw( - learning_rate_schedule, - b1=config.adam_b1, - b2=config.adam_b2, - eps=config.adam_eps, - eps_root=config.adam_eps_root, - weight_decay=config.adam_weight_decay, + learning_rate_schedule, + b1=config.adam_b1, + b2=config.adam_b2, + eps=config.adam_eps, + eps_root=config.adam_eps_root, + weight_decay=config.adam_weight_decay, ) elif config.opt_type == "adam_pax": return adam_pax( - learning_rate_schedule, - beta1=config.adam_b1, - beta2=config.adam_b2, - epsilon=config.adam_eps, - epsilon_root=config.adam_eps_root, - weight_decay=config.adam_weight_decay, + learning_rate_schedule, + beta1=config.adam_b1, + beta2=config.adam_b2, + epsilon=config.adam_eps, + epsilon_root=config.adam_eps_root, + weight_decay=config.adam_weight_decay, ) else: raise ValueError(f"{config.opt_type=} is not a supported.") + def adam_pax( learning_rate_fn: optax.Schedule, beta1: float, @@ -55,7 +56,7 @@ def adam_pax( epsilon: float, epsilon_root: float, weight_decay: float, - ) -> optax.GradientTransformation: +) -> optax.GradientTransformation: """Standard Adam optimizer that supports weight decay. Follows the implemenation in pax/praxis sharded_adam @@ -77,8 +78,7 @@ def adam_pax( """ def init_fn(params): - mu = jax.tree_util.tree_map( # First moment - jnp.zeros_like, params) + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment return optax.ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) @@ -102,8 +102,10 @@ def bias_corrected_decay(step: jnp.int32, decay: float): Returns: Bias corrected decay. """ - t = step.astype(jnp.float32) + 1. - return decay * (1. - jnp.power(decay, t - 1.)) / (1. 
- jnp.power(decay, t)) + t = step.astype(jnp.float32) + 1.0 + return ( + decay * (1.0 - jnp.power(decay, t - 1.0)) / (1.0 - jnp.power(decay, t)) + ) def update_fn(updates, state, params=None): # Sanitize updates just in case. @@ -112,6 +114,7 @@ def update_fn(updates, state, params=None): count = state.count class _slot_opt_state: + def __init__(self, mu, nu): self.mu = mu self.nu = nu @@ -128,13 +131,16 @@ def _update_momentum(update, mu, nu): nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu return _slot_opt_state(mu=mu, nu=nu) - updated_moments = jax.tree_map(_update_momentum, updates, state.mu, state.nu) + updated_moments = jax.tree_map( + _update_momentum, updates, state.mu, state.nu + ) mu = jax.tree_map(lambda x: x.mu, updated_moments) nu = jax.tree_map(lambda x: x.nu, updated_moments) updates = jax.tree_map( - lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu) + lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu + ) if weight_decay > 0: updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index e19af5456..dc532b337 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring, bare-except, consider-using-generator from collections import OrderedDict @@ -35,9 +35,11 @@ # YAML attribute to specify inheritance. _BASE_CONFIG_ATTR = "base_config" + def yaml_key_to_env_key(s: str) -> str: return _MAX_PREFIX + s.upper() + def string_to_bool(s: str) -> bool: if s.lower() == "true": return True @@ -45,70 +47,108 @@ def string_to_bool(s: str) -> bool: return False raise ValueError(f"Can't convert {s} to bool") -_yaml_types_to_parser = {str : str, int : int, float : float, bool : string_to_bool} + +_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool} + def validate_attention_type(s: str) -> None: - valid_attention_types = ('autoselected', 'dot_product', 'flash', 'cudnn_flash_te') - if s not in valid_attention_types: # currently supported attention + valid_attention_types = ( + "autoselected", + "dot_product", + "flash", + "cudnn_flash_te", + ) + if s not in valid_attention_types: # currently supported attention raise ValueError( - "Invalid attention type was passed. Valid options ", valid_attention_types + "Invalid attention type was passed. 
Valid options ", + valid_attention_types, ) + def validate_keys(keys): - validate_attention_type(keys['attention']) + validate_attention_type(keys["attention"]) + + assert ( + keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "" + ) or keys[ + "enable_checkpointing" + ], "You must set enable_checkpointing to load a checkpoint" + assert ( + keys["load_parameters_path"] == "" or keys["load_full_state_path"] == "" + ), "At most one of `load_parameters_path` or `load_full_state_path` should be set" - assert ((keys["load_parameters_path"]=="" and keys["load_full_state_path"]=="") or - keys["enable_checkpointing"]), "You must set enable_checkpointing to load a checkpoint" - assert keys["load_parameters_path"]=="" or keys["load_full_state_path"]=="",\ - "At most one of `load_parameters_path` or `load_full_state_path` should be set" def validate_model_name(s: str) -> bool: # currently supported models - valid_model_names = ('default', 'llama2-7b', 'llama2-13b', 'llama2-70b', 'mistral-7b', - 'mixtral-8x7b', 'gemma-7b','gemma-2b', - 'gpt3-175b', 'gpt3-22b', 'gpt3-6b', 'gpt3-52k') + valid_model_names = ( + "default", + "llama2-7b", + "llama2-13b", + "llama2-70b", + "mistral-7b", + "mixtral-8x7b", + "gemma-7b", + "gemma-2b", + "gpt3-175b", + "gpt3-22b", + "gpt3-6b", + "gpt3-52k", + ) if s not in valid_model_names: raise ValueError( - "Invalid model name was passed. Valid options ", valid_model_names + "Invalid model name was passed. Valid options ", valid_model_names ) + def validate_no_keys_overwritten_twice(keys1: list[str], keys2: list[str]): overwritten_keys = [k for k in keys1 if k in keys2] if overwritten_keys: raise ValueError( f"Keys {overwritten_keys} are overwritten from both the model" - " and the environment/command line. This isn't allowed.") + " and the environment/command line. This isn't allowed." + ) + _config = None config = None + def print_system_information(): max_logging.log(f"System Information: Jax Version: {jax.__version__}") max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") - max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}") + max_logging.log( + f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}" + ) + -def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any],list[Any]]: +def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]: return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l -class _HyperParameters(): + +class _HyperParameters: # pylint: disable=missing-class-docstring def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]): for environment_var in os.environ: - if environment_var[:len(_MAX_PREFIX)] == _MAX_PREFIX: - proposed_key = environment_var[len(_MAX_PREFIX):].lower() + if environment_var[: len(_MAX_PREFIX)] == _MAX_PREFIX: + proposed_key = environment_var[len(_MAX_PREFIX) :].lower() if proposed_key not in raw_data_from_yaml: - raise ValueError(f"We received env `{environment_var}` but it doesn't match a key, so it is assumed a mistake.") - if not environment_var[len(_MAX_PREFIX):].isupper(): - raise ValueError(f"We received env `{environment_var}` but it isn't all uppercase.") + raise ValueError( + f"We received env `{environment_var}` but it doesn't match a key, so it is assumed a mistake." + ) + if not environment_var[len(_MAX_PREFIX) :].isupper(): + raise ValueError( + f"We received env `{environment_var}` but it isn't all uppercase." 
+ ) def _load_kwargs(self, argv: list[str], **kwargs): args_dict = dict(a.split("=", 1) for a in argv[2:]) args_dict.update(kwargs) return args_dict - def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]: - ''' Update model config from environment and command line - ''' + def _update_from_env_and_command_line( + self, raw_keys, raw_data_from_yaml, argv, **kwargs + ) -> list[str]: + """Update model config from environment and command line""" raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs) updated_keys = [] @@ -121,9 +161,13 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, for k in raw_data_from_yaml: if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ: raise ValueError( - f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.") + f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed." + ) - if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ: + if ( + not k in raw_data_from_cmd_line + and not yaml_key_to_env_key(k) in os.environ + ): raw_keys[k] = raw_data_from_yaml[k] continue @@ -133,8 +177,9 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, else: new_proposal = os.environ.get(yaml_key_to_env_key(k)) - if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and \ - (type(raw_data_from_yaml[k]) not in _yaml_types_to_parser): + if (not isinstance(new_proposal, type(raw_data_from_yaml[k]))) and ( + type(raw_data_from_yaml[k]) not in _yaml_types_to_parser + ): raise ValueError( f"For key '{k}', type {type(raw_data_from_yaml[k])} not in {_yaml_types_to_parser.keys()}, can't pass" " at the CLI or ENV" @@ -149,7 +194,8 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, ) # take the command line value, but type it like the config value. except ValueError as e: raise ValueError( - f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'") from e + f"Couldn't parse value from CLI or ENV '{new_proposal}' for key '{k}'" + ) from e return updated_keys @@ -181,63 +227,93 @@ def _load_config(self, config_name: str) -> dict[str, Any]: return raw_data_from_yaml def __init__(self, argv: list[str], **kwargs): - config_name : str = argv[1] + config_name: str = argv[1] raw_data_from_yaml = self._load_config(config_name) self._validate_env_variables(raw_data_from_yaml) raw_keys = OrderedDict() - keys_from_env_and_command_line = self._update_from_env_and_command_line(raw_keys, raw_data_from_yaml, argv, **kwargs) + keys_from_env_and_command_line = self._update_from_env_and_command_line( + raw_keys, raw_data_from_yaml, argv, **kwargs + ) max_logging.log( - f"Updating keys from env and command line: {keys_from_env_and_command_line}") - keys_from_model = _HyperParameters.update_model_vars(argv[1], raw_keys, config_name) + f"Updating keys from env and command line: {keys_from_env_and_command_line}" + ) + keys_from_model = _HyperParameters.update_model_vars( + argv[1], raw_keys, config_name + ) max_logging.log(f"Updating keys from model: {keys_from_model}") - validate_no_keys_overwritten_twice(keys_from_env_and_command_line, keys_from_model) + validate_no_keys_overwritten_twice( + keys_from_env_and_command_line, keys_from_model + ) # We initialize the jax distributed system here because it must be done before device backend is initialized. 
max_utils.maybe_initialize_jax_distributed_system(raw_keys) - if raw_keys['jax_cache_dir']: - compilation_cache.set_cache_dir(os.path.expanduser(raw_keys['jax_cache_dir'])) + if raw_keys["jax_cache_dir"]: + compilation_cache.set_cache_dir( + os.path.expanduser(raw_keys["jax_cache_dir"]) + ) - if raw_keys['model_name'] == "gpt3-175b": + if raw_keys["model_name"] == "gpt3-175b": _HyperParameters.configure_gpt3_task(raw_keys) _HyperParameters.user_init(raw_keys) self.keys = raw_keys - keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension + keys = [k for k in raw_keys] # pylint: disable=unnecessary-comprehension keys.sort() for k in keys: max_logging.log(f"Config param {k}: {raw_keys[k]}") @staticmethod def user_init(raw_keys): - '''Transformations between the config data and configs used at runtime''' + """Transformations between the config data and configs used at runtime""" if raw_keys["run_name"] == "": - raw_keys["run_name"] = os.environ.get("JOBSET_NAME") #using XPK default + raw_keys["run_name"] = os.environ.get("JOBSET_NAME") # using XPK default run_name = raw_keys["run_name"] base_output_directory = raw_keys["base_output_directory"] if run_name: - raw_keys["tensorboard_dir"] = os.path.join(base_output_directory, run_name, "tensorboard", "") - raw_keys["checkpoint_dir"] = os.path.join(base_output_directory, run_name, "checkpoints", "") - raw_keys["metrics_dir"] = os.path.join(base_output_directory, run_name, "metrics", "") + raw_keys["tensorboard_dir"] = os.path.join( + base_output_directory, run_name, "tensorboard", "" + ) + raw_keys["checkpoint_dir"] = os.path.join( + base_output_directory, run_name, "checkpoints", "" + ) + raw_keys["metrics_dir"] = os.path.join( + base_output_directory, run_name, "metrics", "" + ) - if raw_keys["learning_rate_schedule_steps"]==-1: + if raw_keys["learning_rate_schedule_steps"] == -1: raw_keys["learning_rate_schedule_steps"] = raw_keys["steps"] - if raw_keys["steps"]==-1: + if raw_keys["steps"] == -1: raw_keys["steps"] = raw_keys["learning_rate_schedule_steps"] - emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(raw_keys['global_parameter_scale']) - raw_keys['emb_dim'] = 2**emb_scale * raw_keys['base_emb_dim'] - raw_keys['num_query_heads'] = 2**num_head_scale * raw_keys['base_num_query_heads'] - raw_keys['num_kv_heads'] = 2**num_head_scale * raw_keys['base_num_kv_heads'] - raw_keys['mlp_dim'] = 2**mlp_dim_scale * raw_keys['base_mlp_dim'] - raw_keys['num_decoder_layers'] = 2**layer_scale * raw_keys['base_num_decoder_layers'] + ( + emb_scale, + num_head_scale, + mlp_dim_scale, + layer_scale, + ) = get_individual_scales(raw_keys["global_parameter_scale"]) + raw_keys["emb_dim"] = 2**emb_scale * raw_keys["base_emb_dim"] + raw_keys["num_query_heads"] = ( + 2**num_head_scale * raw_keys["base_num_query_heads"] + ) + raw_keys["num_kv_heads"] = ( + 2**num_head_scale * raw_keys["base_num_kv_heads"] + ) + raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] + raw_keys["num_decoder_layers"] = ( + 2**layer_scale * raw_keys["base_num_decoder_layers"] + ) - raw_keys['global_batch_size_to_load'], raw_keys['global_batch_size_to_train_on'] = \ - calculate_global_batch_sizes(raw_keys) - raw_keys['num_slices'] = get_num_slices(raw_keys) - raw_keys['quantization_local_shard_count'] = get_quantization_local_shard_count(raw_keys) + ( + raw_keys["global_batch_size_to_load"], + raw_keys["global_batch_size_to_train_on"], + ) = calculate_global_batch_sizes(raw_keys) + raw_keys["num_slices"] = get_num_slices(raw_keys) + 
raw_keys[ + "quantization_local_shard_count" + ] = get_quantization_local_shard_count(raw_keys) print_system_information() @@ -246,38 +322,39 @@ def user_init(raw_keys): # Type conversions raw_keys["dtype"] = jax.numpy.dtype(raw_keys["dtype"]) - raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) + raw_keys["logical_axis_rules"] = _lists_to_tuples( + raw_keys["logical_axis_rules"] + ) raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) validate_keys(raw_keys) @staticmethod def configure_gpt3_task(raw_keys): - '''dynamically configure gpt3 task based on training rules''' + """dynamically configure gpt3 task based on training rules""" # follow https://github.com/google/paxml/blob/19db52eed85ae0d2365339b83a97cd0b873bbf73/paxml/tasks/lm/params/c4.py#L280 # according to training_rules of mlperf gpt3 training global_batch_size = calculate_global_batch_sizes(raw_keys)[1] if global_batch_size <= 3584: - raw_keys['learning_rate'] = 2e-5 + raw_keys["learning_rate"] = 2e-5 else: - raw_keys['learning_rate'] = 3e-5 + raw_keys["learning_rate"] = 3e-5 warmup_steps = math.ceil(265.0 * 1536 / global_batch_size - 1e-6) decay_end_step = math.ceil(108600.0 * 1536 / global_batch_size - 1e-6) - raw_keys['learning_rate_schedule_steps'] = decay_end_step - raw_keys['warmup_steps_fraction'] = warmup_steps / decay_end_step + raw_keys["learning_rate_schedule_steps"] = decay_end_step + raw_keys["warmup_steps_fraction"] = warmup_steps / decay_end_step global_batch_size_to_train_on = calculate_global_batch_sizes(raw_keys)[1] - raw_keys['eval_interval'] = math.ceil(24567 / global_batch_size_to_train_on) + raw_keys["eval_interval"] = math.ceil(24567 / global_batch_size_to_train_on) @staticmethod - def update_model_vars(base_config_path, raw_keys, config_name : str): - ''' Update model config variables - ''' - validate_model_name(raw_keys['model_name']) + def update_model_vars(base_config_path, raw_keys, config_name: str): + """Update model config variables""" + validate_model_name(raw_keys["model_name"]) max_logging.log(f"Running Model: {raw_keys['model_name']}") updated_keys = [] - if raw_keys['model_name'] != 'default': - model_name = raw_keys['model_name'] + if raw_keys["model_name"] != "default": + model_name = raw_keys["model_name"] # First look at the model configs next to the base_config_path, and # fallback to the python codebase if the config cannot be found. 
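      # e.g. a hypothetical model_name="llama2-7b" would resolve to "configs/models/llama2-7b.yml"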
file_path = os.path.join( @@ -286,40 +363,45 @@ def update_model_vars(base_config_path, raw_keys, config_name : str): if not os.path.isfile(file_path): dir_path = os.path.dirname(os.path.realpath(__file__)) file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml") - with open(file_path, 'r', encoding="utf-8") as file: + with open(file_path, "r", encoding="utf-8") as file: model_vars = yaml.safe_load(file) updated_keys = list(model_vars.keys()) raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name) return updated_keys -def validate_and_update_keys(raw_keys, model_keys, config_name : str): - ''' Validate and update model specific config keys - ''' + +def validate_and_update_keys(raw_keys, model_keys, config_name: str): + """Validate and update model specific config keys""" max_logging.log("Updating following parameters in config\n") for k in model_keys: max_logging.log(f"{k}: {model_keys[k]}") if k not in raw_keys: - raise ValueError(f'Key {k} does not exist in config {config_name}.') + raise ValueError(f"Key {k} does not exist in config {config_name}.") elif not isinstance(raw_keys[k], type(model_keys[k])): - raise ValueError(f'Type of key:{k} does not match with {type(model_keys[k])}') + raise ValueError( + f"Type of key:{k} does not match with {type(model_keys[k])}" + ) else: raw_keys[k] = model_keys[k] return raw_keys + def get_individual_scales(scale): - '''Choose appropriate scales for individual dimensions based on global scale + """Choose appropriate scales for individual dimensions based on global scale We choose to rotate between doubling: num_head and mlp_dim embed_dim num_layers Any one of these steps is not a perfect doubling, although going through a cycle - of three is a near perfect 8x scaling except for the linear -> softmax -> output step''' - + of three is a near perfect 8x scaling except for the linear -> softmax -> output step + """ log_2_scale = math.floor((math.log2(scale))) if 2**log_2_scale != scale: - raise ValueError("Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " - "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim.") + raise ValueError( + "Global parameter scale should be a power of 2. If you want finer grained control of the model sizes " + "then you can explicitly set base_embed_dim, base_num_heads, base_mlp_dim, base_num_decoder_layers and/or head_dim." 
+ ) base_scale, rem = divmod(log_2_scale, 3) num_head_scale = base_scale + int(rem > 0) mlp_dim_scale = num_head_scale @@ -327,10 +409,11 @@ def get_individual_scales(scale): layer_scale = base_scale return emb_scale, num_head_scale, mlp_dim_scale, layer_scale + def calculate_global_batch_sizes(raw_keys): - """ Calculates target global batch size from target devices and per_device_batch""" - per_device_batch_size = raw_keys['per_device_batch_size'] - expansion_factor_real_data = raw_keys['expansion_factor_real_data'] + """Calculates target global batch size from target devices and per_device_batch""" + per_device_batch_size = raw_keys["per_device_batch_size"] + expansion_factor_real_data = raw_keys["expansion_factor_real_data"] num_devices = get_num_target_devices(raw_keys) if per_device_batch_size < 1.0: # For per_device_batch_size<1, we load the data as if per_device_batch_size=1 @@ -340,24 +423,30 @@ def calculate_global_batch_sizes(raw_keys): global_batch_size_to_load = num_devices else: if expansion_factor_real_data != -1: - global_batch_size_to_load = int(num_devices * per_device_batch_size * expansion_factor_real_data) + global_batch_size_to_load = int( + num_devices * per_device_batch_size * expansion_factor_real_data + ) else: global_batch_size_to_load = int(num_devices * per_device_batch_size) global_batch_size_to_train_on = int(num_devices * per_device_batch_size) return global_batch_size_to_load, global_batch_size_to_train_on + def get_num_target_devices(raw_keys): - compile_topology = accelerator_to_spec_map.get_system_characteristics(raw_keys.get('compile_topology', "")) + compile_topology = accelerator_to_spec_map.get_system_characteristics( + raw_keys.get("compile_topology", "") + ) if compile_topology is not None: devices_per_slice = compile_topology.devices_per_slice - return int(devices_per_slice * raw_keys['compile_topology_num_slices']) + return int(devices_per_slice * raw_keys["compile_topology_num_slices"]) else: return len(jax.devices()) + def get_num_slices(raw_keys): - if int(raw_keys['compile_topology_num_slices']) > 0: - return raw_keys['compile_topology_num_slices'] + if int(raw_keys["compile_topology_num_slices"]) > 0: + return raw_keys["compile_topology_num_slices"] else: devices = jax.devices() try: @@ -365,13 +454,16 @@ def get_num_slices(raw_keys): except: return 1 + def get_quantization_local_shard_count(raw_keys): - if raw_keys['quantization_local_shard_count'] == -1: - return raw_keys['num_slices'] + if raw_keys["quantization_local_shard_count"] == -1: + return raw_keys["num_slices"] else: - return raw_keys['quantization_local_shard_count'] + return raw_keys["quantization_local_shard_count"] + + +class HyperParameters: # pylint: disable=missing-class-docstring -class HyperParameters(): # pylint: disable=missing-class-docstring def __init__(self): pass @@ -386,11 +478,13 @@ def __setattr__(self, attr, value): def get_keys(self): return _config.keys + def initialize(argv, **kwargs): global _config, config _config = _HyperParameters(argv, **kwargs) config = HyperParameters() + if __name__ == "__main__": initialize(sys.argv) print(config.steps) diff --git a/MaxText/sequence_packing.py b/MaxText/sequence_packing.py index 36bccde0c..0b54c9d42 100644 --- a/MaxText/sequence_packing.py +++ b/MaxText/sequence_packing.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Packed Sequence Op.""" @@ -23,9 +23,11 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate @@ -66,29 +68,33 @@ def pack_dataset(dataset: tf.data.Dataset, keys = list(shapes.keys()) for k in keys: if k not in shapes: - raise ValueError(f"""Key {k} not found in dataset. Available keys are - {shapes.keys()}""") + raise ValueError( + f"""Key {k} not found in dataset. Available keys are + {shapes.keys()}""" + ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): - raise ValueError('Tensors to be packed must be one-dimensional.') + raise ValueError("Tensors to be packed must be one-dimensional.") # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" if isinstance(key2length, int): key2length = {k: key2length for k in keys} for k in keys: - for suffix in ['_segmentation', '_position']: + for suffix in ["_segmentation", "_position"]: key2length[k + suffix] = key2length[k] # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) # We pad with a negative value instead of the default 0 because 0 is a # valid token for some tokenizers for e.g., representing unknown value dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1) + batch_size, padded_shapes={k: [-1] for k in keys}, padding_values=-1 + ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. 
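Before the while-loop internals below, a minimal usage sketch of pack_dataset as reformatted here. The feature name "inputs", the toy data, and the packed length 8 are illustrative assumptions, as is the flat import (it presumes MaxText/ is on sys.path):

import numpy as np
import tensorflow as tf
import sequence_packing  # assumes MaxText/ is on sys.path

# Three variable-length examples with a single 1-D int32 feature.
examples = [[1, 2], [3, 4, 5], [6]]
ds = tf.data.Dataset.from_generator(
    lambda: ({"inputs": np.array(e, dtype=np.int32)} for e in examples),
    output_signature={"inputs": tf.TensorSpec(shape=[None], dtype=tf.int32)},
)
packed = sequence_packing.pack_dataset(ds, key2length=8)
for batch in packed:
  # Alongside the packed "inputs", pack_dataset adds "inputs_segmentation"
  # (which source example each token came from; 0 marks padding) and
  # "inputs_position" (the token's offset within its source example).
  print(batch["inputs"], batch["inputs_segmentation"], batch["inputs_position"])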
@@ -98,8 +104,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: @@ -112,7 +119,7 @@ def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], empty_example = {} for k in keys: empty_example[k] = tf.zeros([0], dtype=tf.int32) - empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) + empty_example[k + "_position"] = tf.zeros([0], dtype=tf.int32) keys_etc = empty_example.keys() def write_packed_example(partial, outputs): @@ -121,7 +128,8 @@ def write_packed_example(partial, outputs): for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -139,9 +147,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) - outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) + outputs[k + "_position"] = tf.TensorArray( + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -157,13 +167,15 @@ def body_fn(i, partial, outputs): for k in keys: val = tf.cast(x[k][i], tf.int32) # We consider only the valid tokens i.e., token_id != -1 - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, -1), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, -1), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -174,12 +186,12 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) - new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + new_partial[k + "_position"] = tf.concat( + [partial[k + "_position"], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs @@ -193,14 +205,14 @@ def true_fn(): {k: tf.TensorShape([None]) for k in keys_etc}, {k: tf.TensorShape(None) for k in keys_etc}, ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + "_segmentation"] = tf.cumsum( + tf.cast(tf.equal(packed[k + "_position"], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, 
num_parallel_calls=AUTOTUNE) diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py index 7e2bb2c1c..8257c776e 100644 --- a/MaxText/standalone_checkpointer.py +++ b/MaxText/standalone_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Standalone checkpointer - only saves and restores checkpoints at regular intervals, accesses storage needs.""" @@ -41,6 +41,7 @@ Transformer = models.Transformer + def checkpoint_loop(config, state=None): """Main Checkpointing loop. Saves checkpoints. @@ -50,46 +51,59 @@ def checkpoint_loop(config, state=None): ckpt_path: Returns: """ - init_rng, _ , checkpoint_manager, mesh, model, _, tx = setup_mesh_and_model(config) + init_rng, _, checkpoint_manager, mesh, model, _, tx = setup_mesh_and_model( + config + ) - unboxed_abstract_state, _, _ = max_utils.get_abstract_state(model, tx, - config, init_rng, mesh, is_training=True) + unboxed_abstract_state, _, _ = max_utils.get_abstract_state( + model, tx, config, init_rng, mesh, is_training=True + ) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() with nn_partitioning.axis_rules(config.logical_axis_rules): - state, _ = checkpointing.load_state_if_possible(checkpoint_manager, - None, - config.load_parameters_path, - config.load_full_state_path, - unboxed_abstract_state) + state, _ = checkpointing.load_state_if_possible( + checkpoint_manager, + None, + config.load_parameters_path, + config.load_full_state_path, + unboxed_abstract_state, + ) if state: - state = state['items'] + state = state["items"] jax.block_until_ready(state) checkpoint_load_end = datetime.datetime.now() - if state is not None: # Checkpoint was available for restore + if state is not None: # Checkpoint was available for restore if jax.process_index() == 0: - max_logging.log(f"STANDALONE CHECKPOINTER : Checkpoint restored in : {checkpoint_load_end - checkpoint_load_start}") - else: # Checkpoint was unavailable, state needs to be initialized - state, _, _ = max_utils.setup_training_state(model, None, - tx, config, init_rng, mesh, checkpoint_manager) + max_logging.log( + f"STANDALONE CHECKPOINTER : Checkpoint restored in : {checkpoint_load_end - checkpoint_load_start}" + ) + else: # Checkpoint was 
unavailable, state needs to be initialized + state, _, _ = max_utils.setup_training_state( + model, None, tx, config, init_rng, mesh, checkpoint_manager + ) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() # A barrier to sync all hosts before starting to save checkpoint - jax.experimental.multihost_utils.sync_global_devices("Barrier before save") + jax.experimental.multihost_utils.sync_global_devices( + "Barrier before save" + ) if save_checkpoint(checkpoint_manager, step, state): checkpoint_manager.wait_until_finished() end_time = datetime.datetime.now() if jax.process_index() == 0: - max_logging.log(f"STANDALONE CHECKPOINTER : Checkpoint saved in {end_time - start_time} ,step {step}, on host 0") + max_logging.log( + f"STANDALONE CHECKPOINTER : Checkpoint saved in {end_time - start_time} ,step {step}, on host 0" + ) return state + def add_entropy_to_checkpoint(state): """Introduce randomness in checkpoints. This is useful to simulate real checkpoints, without training. Args: @@ -98,14 +112,23 @@ def add_entropy_to_checkpoint(state): state: Returns state with entropy added to the optimizer state. """ opt_0 = state.opt_state[0] - opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda x: - jax.random.normal(create_random_keys(x), shape=x.shape), state.params)) - opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda x: - jax.random.normal(create_random_keys(x), shape=x.shape), state.params)) + opt_0 = opt_0._replace( + mu=jax.tree_util.tree_map( + lambda x: jax.random.normal(create_random_keys(x), shape=x.shape), + state.params, + ) + ) + opt_0 = opt_0._replace( + nu=jax.tree_util.tree_map( + lambda x: jax.random.normal(create_random_keys(x), shape=x.shape), + state.params, + ) + ) new_opt = [opt_0] + list(state.opt_state[1:]) state = state.replace(opt_state=new_opt) return state + def create_random_keys(x): """Create random keys to help alter the checkpoint state. Args: @@ -115,8 +138,9 @@ def create_random_keys(x): """ return random.PRNGKey(int(jnp.sum(jnp.abs(x)))) + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_cpu_enable_gloo_collectives', True) + jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index 4bb13e4dc..bf6e8b0ba 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -1,15 +1,15 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """ Standalone data loader - only loads data for each training step, accesses storage needs.""" @@ -32,7 +32,7 @@ def data_load_loop(config, state=None): """Main data loader loop. - Loads batches of data for each training step. + Loads batches of data for each training step. """ _, _, _, _, _, _, _, data_iterator, _, state = setup_train_loop(config) @@ -43,22 +43,26 @@ def data_load_loop(config, state=None): example_batch = load_next_batch(data_iterator, example_batch, config) jax.block_until_ready(example_batch) first_end = datetime.datetime.now() - time_to_load_first_batch = first_end-start + time_to_load_first_batch = first_end - start if jax.process_index() == 0: - max_logging.log(f"STANDALONE DATALOADER : First step completed in {time_to_load_first_batch} seconds, on host 0") + max_logging.log( + f"STANDALONE DATALOADER : First step completed in {time_to_load_first_batch} seconds, on host 0" + ) - for _ in np.arange(start_step+1, config.steps): + for _ in np.arange(start_step + 1, config.steps): example_batch = load_next_batch(data_iterator, example_batch, config) - jax.block_until_ready(example_batch) # wait until the last batch is read + jax.block_until_ready(example_batch) # wait until the last batch is read end = datetime.datetime.now() if jax.process_index() == 0: - max_logging.log(f"STANDALONE DATALOADER : {config.steps} batches loaded in {end-start} seconds, on host 0") + max_logging.log( + f"STANDALONE DATALOADER : {config.steps} batches loaded in {end-start} seconds, on host 0" + ) return state def main(argv: Sequence[str]) -> None: - jax.config.update('jax_cpu_enable_gloo_collectives', True) + jax.config.update("jax_cpu_enable_gloo_collectives", True) os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" pyconfig.initialize(argv) config = pyconfig.config @@ -70,6 +74,5 @@ def main(argv: Sequence[str]) -> None: data_load_loop(config) - if __name__ == "__main__": app.run(main) diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index e69c2306d..c03bd8fc9 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -36,10 +36,18 @@ class AttentionTest(unittest.TestCase): - """Test for the Attention """ + """Test for the Attention""" + def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size = 1.0, run_name='test', enable_checkpointing=False, max_target_length=128, max_prefill_predict_length=16 ) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + max_target_length=128, + max_prefill_predict_length=16, + ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(0) @@ -62,20 +70,23 @@ def setUp(self): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "dot_product", + attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) - 
self._attention_as_mha_generic_variable = self._attention_as_mha_generic.init( - {'params': self.rng, 'aqt': self.rng}, - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length)), + self._attention_as_mha_generic_variable = ( + self._attention_as_mha_generic.init( + {"params": self.rng, "aqt": self.rng}, + jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ), + jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ), + jnp.ones((self.global_batch_size, self.max_target_length)), + ) ) def get_data(self, dtype): @@ -85,8 +96,15 @@ def get_data(self, dtype): dtype=dtype, ) - decoder_segment_ids = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, 4) - decoder_positions = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, self.max_target_length) + decoder_segment_ids = jax.random.randint( + self.rng, (self.global_batch_size, self.max_target_length), 0, 4 + ) + decoder_positions = jax.random.randint( + self.rng, + (self.global_batch_size, self.max_target_length), + 0, + self.max_target_length, + ) return lnx, decoder_segment_ids, decoder_positions @@ -97,23 +115,28 @@ def get_structured_data(self, dtype): dtype=dtype, ) - decoder_positions = jnp.stack([ - jnp.arange(self.max_target_length, dtype=jnp.int32) - for _ in range(self.global_batch_size) - ]) + decoder_positions = jnp.stack( + [ + jnp.arange(self.max_target_length, dtype=jnp.int32) + for _ in range(self.global_batch_size) + ] + ) - decoder_segment_ids = jax.numpy.zeros((self.global_batch_size, self.max_target_length))\ - + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + decoder_segment_ids = ( + jax.numpy.zeros((self.global_batch_size, self.max_target_length)) + + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) return lnx, decoder_segment_ids, decoder_positions - + @pytest.mark.tpu def test_autoregression(self): prefill_length = self.cfg.max_prefill_predict_length decode_total_length = self.cfg.max_target_length lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - self.dtype) - + self.dtype + ) + mha_full = self._attention_as_mha_generic.apply( self._attention_as_mha_generic_variable, lnx, @@ -122,13 +145,13 @@ def test_autoregression(self): inputs_positions=decoder_positions, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) - + lnx_prefill = lnx[:, 0:prefill_length, :] decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - + mha_prefill, output_cache = self._attention_as_mha_generic.apply( self._attention_as_mha_generic_variable, lnx_prefill, @@ -137,39 +160,45 @@ def test_autoregression(self): inputs_positions=decoder_positions_prefill, deterministic=True, model_mode=common_types.MODEL_MODE_PREFILL, - rngs={'aqt': self.rng}, - mutable=["cache"] + rngs={"aqt": self.rng}, + mutable=["cache"], ) self.assertTrue( jax.numpy.allclose( - mha_prefill, mha_full[:,:prefill_length,:], rtol=1e-02, atol=1e-02, equal_nan=False + mha_prefill, + mha_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, ) ) for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx:idx+1, :] - decoder_positions_idx = decoder_positions[:, idx:idx+1] + lnx_idx = lnx[:, idx : 
idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] self._attention_as_mha_generic_variable.update(output_cache) mha_idx, output_cache = self._attention_as_mha_generic.apply( - self._attention_as_mha_generic_variable, - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'aqt': self.rng}, - mutable=["cache"] + self._attention_as_mha_generic_variable, + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], ) - mha_full_this_idx = mha_full[:,idx:idx+1,:] - self.assertTrue( - mha_full_this_idx.shape == mha_idx.shape - ) + mha_full_this_idx = mha_full[:, idx : idx + 1, :] + self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) self.assertTrue( - jax.numpy.allclose( - mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False - ) + jax.numpy.allclose( + mha_full_this_idx, + mha_idx, + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) ) @pytest.mark.tpu @@ -187,8 +216,7 @@ def test_tpu_kernel_attention_mqa(self): def tpu_kernel_attention_helper(self, num_kv_heads): """Test equalvant between dot_product and TPU accelerated""" - lnx, decoder_segment_ids, decoder_positions = self.get_data( - self.dtype) + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) attention_as_mha_generic = Attention( config=self.cfg, @@ -198,20 +226,21 @@ def tpu_kernel_attention_helper(self, num_kv_heads): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "dot_product", + attention_kernel="dot_product", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) attention_as_mha_generic_variable = attention_as_mha_generic.init( - {'params': self.rng, 'aqt': self.rng}, + {"params": self.rng, "aqt": self.rng}, jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), + (self.global_batch_size, self.max_target_length, self.embed_dim) + ), jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), - jnp.ones( - (self.global_batch_size, self.max_target_length)), + (self.global_batch_size, self.max_target_length, self.embed_dim) + ), + jnp.ones((self.global_batch_size, self.max_target_length)), ) mha_generic_output = attention_as_mha_generic.apply( @@ -222,7 +251,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): inputs_positions=decoder_segment_ids, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) attention_as_mha_flash = Attention( @@ -233,20 +262,21 @@ def tpu_kernel_attention_helper(self, num_kv_heads): max_target_length=self.max_target_length, max_prefill_predict_length=self.cfg.max_prefill_predict_length, mesh=self.mesh, - attention_kernel = "flash", + attention_kernel="flash", dtype=self.dtype, dropout_rate=self.cfg.dropout_rate, - name='self_attention', + name="self_attention", ) attention_as_mha_flash_variable = attention_as_mha_flash.init( - {'params': self.rng, 'aqt': self.rng}, - jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), + {"params": self.rng, "aqt": self.rng}, jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim)), + (self.global_batch_size, self.max_target_length, self.embed_dim) + ), jnp.ones( - (self.global_batch_size, 
self.max_target_length)), + (self.global_batch_size, self.max_target_length, self.embed_dim) + ), + jnp.ones((self.global_batch_size, self.max_target_length)), ) mha_generic_flash_output = attention_as_mha_flash.apply( @@ -257,14 +287,19 @@ def tpu_kernel_attention_helper(self, num_kv_heads): inputs_positions=decoder_segment_ids, deterministic=True, model_mode=common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng}, + rngs={"aqt": self.rng}, ) self.assertTrue( jax.numpy.allclose( - mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False + mha_generic_output, + mha_generic_flash_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, ) ) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py index 4650824bf..d99494db9 100644 --- a/MaxText/tests/gpt3_test.py +++ b/MaxText/tests/gpt3_test.py @@ -1,19 +1,18 @@ - """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Tests for GPT3 """ import sys @@ -38,33 +37,39 @@ def init_random_model_vars(model, rng, example_batch): """initialze random model vars.""" model_vars = model.init( - {'params': rng, 'aqt': rng}, - example_batch['inputs'], - example_batch['inputs_position'], + {"params": rng, "aqt": rng}, + example_batch["inputs"], + example_batch["inputs_position"], enable_dropout=False, ) + def _replace_initialization(key, value): keystr = jax.tree_util.keystr(key) # replace zero initializer to ensure strong test cases # including Gpt3LayerNorm scale, Gpt3LayerNorm bias, and DenseGeneral bias if "scale" in keystr or "bias" in keystr: - value = jax.nn.initializers.normal(1.0)(rng, value.shape, dtype=value.dtype) + value = jax.nn.initializers.normal(1.0)( + rng, value.shape, dtype=value.dtype + ) return value - model_vars = jax.tree_util.tree_map_with_path(_replace_initialization, model_vars) + model_vars = jax.tree_util.tree_map_with_path( + _replace_initialization, model_vars + ) return model_vars class GPT3(unittest.TestCase): """numerical tests for GPT3.""" + def setUp(self): super().setUp() pyconfig.initialize( - [sys.argv[0], 'configs/base.yml'], - run_name='test', - enable_checkpointing=False, - model_name='gpt3-52k', - dtype='float32', + [sys.argv[0], "configs/base.yml"], + run_name="test", + enable_checkpointing=False, + model_name="gpt3-52k", + dtype="float32", ) self.cfg = pyconfig.config @@ -73,16 +78,18 @@ def setUp(self): devices_array = max_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) quant = quantizations.configure_quantization(self.cfg) - self.model = models.Transformer(config = self.cfg, mesh = mesh, quant = quant) + self.model = models.Transformer(config=self.cfg, mesh=mesh, quant=quant) self.example_batch = { - 'inputs': jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32), - 'inputs_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), - 'inputs_segmentation': jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32), - 'targets': jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32), - 'targets_position': jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), - 'targets_segmentation': jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32), + "inputs": jnp.array([[11, 12, 13, 14, 15]], dtype=jnp.int32), + "inputs_position": jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), + "inputs_segmentation": jnp.array([[1, 1, 1, 1, 1]], dtype=jnp.int32), + "targets": jnp.array([[12, 13, 14, 15, 1]], dtype=jnp.int32), + "targets_position": jnp.array([[0, 1, 2, 3, 4]], dtype=jnp.int32), + "targets_segmentation": jnp.array([[1, 1, 1, 1, 0]], dtype=jnp.int32), } - self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch) + self.model_vars = init_random_model_vars( + self.model, self.rng, self.example_batch + ) @pytest.mark.tpu def test_logits_numerically(self): @@ -91,18 +98,29 @@ def test_logits_numerically(self): # paxml applies padding in mlp layer # while maxtext implementaiton applies padding in attention mask instead # the two implementation are equivalent in valid non-padding tokens - per_example_xent_truth = jnp.array([[31.976467, 25.806253, 17.311134, 45.362663, 0.]], dtype=jnp.float32) - logits, _ = self.model.apply(self.model_vars, - self.example_batch['inputs'], - self.example_batch['inputs_position'], - decoder_segment_ids=self.example_batch['inputs_segmentation'], - enable_dropout=self.cfg.enable_dropout, - rngs={'dropout': self.rng, 'aqt': self.rng}, mutable='intermediates') - - one_hot_targets = jax.nn.one_hot(self.example_batch['targets'], 
self.cfg.vocab_size) - per_example_xent = -jnp.sum(jax.nn.log_softmax(logits) * one_hot_targets, axis=-1, dtype=jnp.float32) + per_example_xent_truth = jnp.array( + [[31.976467, 25.806253, 17.311134, 45.362663, 0.0]], dtype=jnp.float32 + ) + logits, _ = self.model.apply( + self.model_vars, + self.example_batch["inputs"], + self.example_batch["inputs_position"], + decoder_segment_ids=self.example_batch["inputs_segmentation"], + enable_dropout=self.cfg.enable_dropout, + rngs={"dropout": self.rng, "aqt": self.rng}, + mutable="intermediates", + ) + + one_hot_targets = jax.nn.one_hot( + self.example_batch["targets"], self.cfg.vocab_size + ) + per_example_xent = -jnp.sum( + jax.nn.log_softmax(logits) * one_hot_targets, axis=-1, dtype=jnp.float32 + ) # Mask out paddings at the end of each example. - per_example_xent = per_example_xent * (self.example_batch['targets_segmentation'] != 0) + per_example_xent = per_example_xent * ( + self.example_batch["targets_segmentation"] != 0 + ) self.assertTrue( jax.numpy.allclose( diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index 8af271f53..08b35d746 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import subprocess @@ -27,117 +27,155 @@ from input_pipeline import _grain_data_processing from input_pipeline import input_pipeline_interface + class GrainDataProcessingTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - super().setUpClass() - exit_code = subprocess.call(['bash','../setup_gcsfuse.sh', - 'DATASET_GCS_BUCKET=maxtext-dataset', - 'MOUNT_PATH=/tmp/gcsfuse']) - if exit_code != 0: - raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") - - def setUp(self): - super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], - per_device_batch_size=1, - run_name='test', - mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "/tmp/gcsfuse", - tokenizer_path = "../assets/tokenizer", - enable_checkpointing=False, - dataset_type="c4-array_record", - dataset_name='array-record/c4/en/3.0.1', - eval_dataset_name='array-record/c4/en/3.0.1') - self.config = pyconfig.config - self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) - self.train_ds, self.eval_ds = self._get_datasets() - self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() - - def _get_datasets(self): - print("Sharding dataset in ", jax.process_count(), " shards") - train_ds, eval_ds = _grain_data_processing.get_datasets( - config=self.config) - return train_ds, eval_ds - - def _get_preprocessed_datasets(self): - process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) - train_iter, eval_iter, test_iter = _grain_data_processing.preprocess_dataset( - self.config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - global_mesh = self.mesh, - train_ds = self.train_ds, eval_ds = self.eval_ds, - vocab_path=self.config.tokenizer_path) - return train_iter, eval_iter, test_iter - - def test_train_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - # For training we pack multiple short examples in one example. - # *_position and *_segmentation indicate the boundaries. 
- batch = next(self.train_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - def test_eval_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - batch = next(self.eval_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - - def test_predict_ds(self): - expected_shape = [jax.device_count(), self.config.max_target_length] - batch = next(self.predict_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - - def test_batch_determinism(self): - batch1 = next(self.train_iter) - self.train_ds, _ = self._get_datasets() - train_iter, _, _= self._get_preprocessed_datasets() - batch2 = next(train_iter) - self.assertTrue((batch1['inputs']==batch2['inputs']).all()) - self.assertTrue((batch1['targets']==batch2['targets']).all()) - self.assertTrue((batch1['inputs_segmentation']==batch2['inputs_segmentation']).all()) - self.assertTrue((batch1['targets_segmentation']==batch2['targets_segmentation']).all()) - self.assertTrue((batch1['inputs_position']==batch2['inputs_position']).all()) - self.assertTrue((batch1['targets_position']==batch2['targets_position']).all()) - - def test_for_loop_repeatable(self): - def get_first_batch(iterator): - batch = None - for batch in iterator: - break - return batch - - eval_batch1 = get_first_batch(self.eval_iter) - eval_batch2 = get_first_batch(self.eval_iter) - self.assertTrue((eval_batch1['inputs']==eval_batch2['inputs']).all()) - self.assertTrue((eval_batch1['targets']==eval_batch2['targets']).all()) - - -if __name__ == '__main__': + + @classmethod + def setUpClass(cls): + super().setUpClass() + exit_code = subprocess.call([ + "bash", + "../setup_gcsfuse.sh", + "DATASET_GCS_BUCKET=maxtext-dataset", + "MOUNT_PATH=/tmp/gcsfuse", + ]) + if exit_code != 0: + raise ValueError( + f"Running setup_gcsfuse.sh failed with exit code: {exit_code}" + ) + + def setUp(self): + super().setUp() + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="/tmp/gcsfuse", + tokenizer_path="../assets/tokenizer", + enable_checkpointing=False, + dataset_type="c4-array_record", + dataset_name="array-record/c4/en/3.0.1", + eval_dataset_name="array-record/c4/en/3.0.1", + ) + self.config = pyconfig.config + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh( + mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes + ) + self.train_ds, self.eval_ds = self._get_datasets() + ( + self.train_iter, + self.eval_iter, + self.predict_iter, + ) = self._get_preprocessed_datasets() + + def _get_datasets(self): + print("Sharding dataset in ", jax.process_count(), " shards") + train_ds, eval_ds = 
_grain_data_processing.get_datasets(config=self.config) + return train_ds, eval_ds + + def _get_preprocessed_datasets(self): + process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config, self.mesh + ) + ( + train_iter, + eval_iter, + test_iter, + ) = _grain_data_processing.preprocess_dataset( + self.config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + global_mesh=self.mesh, + train_ds=self.train_ds, + eval_ds=self.eval_ds, + vocab_path=self.config.tokenizer_path, + ) + return train_iter, eval_iter, test_iter + + def test_train_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + # For training we pack multiple short examples in one example. + # *_position and *_segmentation indicate the boundaries. + batch = next(self.train_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_eval_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.eval_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_predict_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.predict_iter) + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + def test_batch_determinism(self): + batch1 = next(self.train_iter) + self.train_ds, _ = self._get_datasets() + train_iter, _, _ = self._get_preprocessed_datasets() + batch2 = next(train_iter) + self.assertTrue((batch1["inputs"] == batch2["inputs"]).all()) + self.assertTrue((batch1["targets"] == batch2["targets"]).all()) + self.assertTrue( + (batch1["inputs_segmentation"] == batch2["inputs_segmentation"]).all() + ) + self.assertTrue( + (batch1["targets_segmentation"] == batch2["targets_segmentation"]).all() + ) + self.assertTrue( + (batch1["inputs_position"] == batch2["inputs_position"]).all() + ) + self.assertTrue( + (batch1["targets_position"] == batch2["targets_position"]).all() + ) + + def test_for_loop_repeatable(self): + def get_first_batch(iterator): + batch = None + for batch in iterator: + break + return batch + + eval_batch1 = get_first_batch(self.eval_iter) + eval_batch2 = get_first_batch(self.eval_iter) + self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all()) + self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all()) + + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index 274739a7f..b305562c2 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -22,18 +22,20 @@ class Inference_Microbenchmark(unittest.TestCase): + @pytest.mark.tpu def test(self): - 
pyconfig.initialize([None, - "configs/tpu_smoke_test.yml", - "tokenizer_path=../assets/tokenizer.llama2", - "ici_autoregressive_parallelism=-1", - "ici_fsdp_parallelism=1", - "max_prefill_predict_length=1024", - "max_target_length=2048", - "scan_layers=false", - "weight_dtype=bfloat16", - ]) + pyconfig.initialize([ + None, + "configs/tpu_smoke_test.yml", + "tokenizer_path=../assets/tokenizer.llama2", + "ici_autoregressive_parallelism=-1", + "ici_fsdp_parallelism=1", + "max_prefill_predict_length=1024", + "max_target_length=2048", + "scan_layers=false", + "weight_dtype=bfloat16", + ]) inference_microbenchmark_main(pyconfig.config) diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py index 38e1a7c62..68388e944 100644 --- a/MaxText/tests/llama_test.py +++ b/MaxText/tests/llama_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Tests for Llama """ import jax @@ -31,6 +31,7 @@ """ + def precompute_freqs_cis( dim: int, end: int, theta: float = 10000.0, dtype: jnp.dtype = jnp.float32 ) -> jnp.ndarray: @@ -45,14 +46,13 @@ def precompute_freqs_cis( return jnp.asarray(freqs_cis) - def apply_rotary_emb( - xq: jnp.ndarray, - xk: jnp.ndarray, - freqs_cis: jnp.ndarray, - dtype: jnp.dtype = jnp.bfloat16, + xq: jnp.ndarray, + xk: jnp.ndarray, + freqs_cis: jnp.ndarray, + dtype: jnp.dtype = jnp.bfloat16, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ Apply the computed Rotary Postional Embedding""" + """Apply the computed Rotary Postional Embedding""" reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2) reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2) @@ -76,13 +76,16 @@ def apply_rotary_emb( return xq_out.astype(dtype), xk_out.astype(dtype) + def permute_to_match_maxtext_rope(arr): evens = arr[..., ::2] odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim-1) + return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) + class RoPETest(unittest.TestCase): - """Test for the RoPE implementation """ + """Test for the RoPE implementation""" + def test_rope(self): dim_per_head = 128 seq_len = 8 @@ -105,12 +108,32 @@ def test_rope(self): position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings from MaxText implementation - query_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_q), position = position) - key_proj = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(permute_to_match_maxtext_rope(x_k), position = position) + query_proj = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)( + permute_to_match_maxtext_rope(x_q), position=position + ) + key_proj = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)( + permute_to_match_maxtext_rope(x_k), position=position + ) # Compare results - self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[0]), query_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) - self.assertTrue(jax.numpy.allclose(permute_to_match_maxtext_rope(llama_output[1]), key_proj, rtol=1e-01, atol=1e-04, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + permute_to_match_maxtext_rope(llama_output[0]), + query_proj, + rtol=1e-01, + atol=1e-04, + equal_nan=False, + ) + ) + self.assertTrue( + jax.numpy.allclose( + permute_to_match_maxtext_rope(llama_output[1]), + key_proj, + rtol=1e-01, + atol=1e-04, + equal_nan=False, + ) + ) def test_scaling_rope(self): dim_per_head = 128 @@ -121,15 +144,23 @@ def test_scaling_rope(self): position = jnp.arange(seq_len, dtype=jnp.float32)[jnp.newaxis, :] # Calculate RoPE embeddings and then scale - query_proj_1 = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(x_q, position = position) - query_proj_1 = query_proj_1 * (dim_per_head ** -0.5) + query_proj_1 = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)( + x_q, position=position + ) + query_proj_1 = query_proj_1 * (dim_per_head**-0.5) # scale first and then apply RoPE - query_proj_2 = x_q * (dim_per_head ** -0.5) - query_proj_2 = embeddings.RotaryEmbedding(embedding_dims = dim_per_head)(query_proj_2, position=position) + query_proj_2 = x_q * (dim_per_head**-0.5) + query_proj_2 = embeddings.RotaryEmbedding(embedding_dims=dim_per_head)( + query_proj_2, position=position + ) - self.assertTrue(jax.numpy.allclose(query_proj_2, query_proj_1, rtol=1e-01, atol=1e-04, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( 
+ query_proj_2, query_proj_1, rtol=1e-01, atol=1e-04, equal_nan=False + ) + ) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/max_utils_test.py b/MaxText/tests/max_utils_test.py index c4c1480dd..5402400dd 100644 --- a/MaxText/tests/max_utils_test.py +++ b/MaxText/tests/max_utils_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the common Max Utils """ import jax @@ -30,55 +30,66 @@ Transformer = models.Transformer + class MaxUtilsSummaryStats(unittest.TestCase): """Tests for the summary stats functions in max_utils.py""" + def test_l2norm_pytree(self): - x = {'a': jax.numpy.array([0, 2, 0]), 'b': jax.numpy.array([0, 3, 6])} + x = {"a": jax.numpy.array([0, 2, 0]), "b": jax.numpy.array([0, 3, 6])} pytree_l2_norm = max_utils.l2norm_pytree(x) - self.assertTrue(jax.numpy.allclose(pytree_l2_norm, 7, rtol=1e-05, atol=1e-08, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + pytree_l2_norm, 7, rtol=1e-05, atol=1e-08, equal_nan=False + ) + ) + class MaxUtilsInitState(unittest.TestCase): """Tests initialization of training and decode states in max_utils.py""" + def setUp(self): self.model = nn.Dense(features=5) self.key1, self.key2 = random.split(random.key(0)) - self.input = random.normal(self.key1, (10,)) # Dummy input data + self.input = random.normal(self.key1, (10,)) # Dummy input data self.params = self.model.init(self.key2, self.input) self.output = self.model.apply(self.params, self.input) self.tx = optax.adam(learning_rate=0.001) def test_calculate_num_params_from_pytree(self): example_tree = [ - [1, 'a', object()], - (1, (2, 3), ()), - [1, {'k1': 2, 'k2': (3, 4)}, 5], - {'a': 2, 'b': (2, 3)}, - jnp.array([1, 2, 3]), - ] - self.assertEqual(max_utils.calculate_num_params_from_pytree(example_tree), 17) + [1, "a", object()], + (1, (2, 3), ()), + [1, {"k1": 2, "k2": (3, 4)}, 5], + {"a": 2, "b": (2, 3)}, + jnp.array([1, 2, 3]), + ] + self.assertEqual( + max_utils.calculate_num_params_from_pytree(example_tree), 17 + ) # Model params - self.assertEqual(max_utils.calculate_num_params_from_pytree(self.params), 55) + self.assertEqual( + max_utils.calculate_num_params_from_pytree(self.params), 55 + ) def test_init_train_state(self): state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.params, - tx=None, # type: ignore - opt_state={} + step=0, + apply_fn=self.model.apply, + 
params=self.params, + tx=None, # type: ignore + opt_state={}, ) self.assertEqual(state.tx, None) self.assertEqual(state.step, 0) self.assertEqual(state.opt_state, {}) self.assertEqual(state.apply_fn, self.model.apply) - self.assertEqual(max_utils.calculate_num_params_from_pytree(state.params), - max_utils.calculate_num_params_from_pytree(self.params)) - + self.assertEqual( + max_utils.calculate_num_params_from_pytree(state.params), + max_utils.calculate_num_params_from_pytree(self.params), + ) def test_init_decode_state(self): - decode_state = max_utils.init_decode_state( - self.model.apply, self.params - ) + decode_state = max_utils.init_decode_state(self.model.apply, self.params) self.assertEqual(decode_state.apply_fn, self.model.apply) output = decode_state.apply_fn(self.params, self.input) self.assertEqual(output.tolist(), self.output.tolist()) @@ -86,34 +97,39 @@ def test_init_decode_state(self): self.assertEqual(decode_state.opt_state, {}) self.assertEqual(decode_state.step, 0) self.assertEqual( - max_utils.calculate_num_params_from_pytree(decode_state.params), - max_utils.calculate_num_params_from_pytree(self.params) + max_utils.calculate_num_params_from_pytree(decode_state.params), + max_utils.calculate_num_params_from_pytree(self.params), ) def test_init_training_state(self): - state = max_utils.init_training_state(self.model.apply, self.params, self.tx) + state = max_utils.init_training_state( + self.model.apply, self.params, self.tx + ) self.assertEqual(state.apply_fn, self.model.apply) self.assertEqual(state.tx, self.tx) self.assertNotEqual(state.opt_state, {}) self.assertEqual( - max_utils.calculate_num_params_from_pytree(state.params), - max_utils.calculate_num_params_from_pytree(self.params) + max_utils.calculate_num_params_from_pytree(state.params), + max_utils.calculate_num_params_from_pytree(self.params), ) + class ModelWithMultipleCollections(nn.Module): - """ - A simple model that has variables in multiple collections - "params" and "special_variables" - """ - def setup(self): - self.dense = nn.Dense(4) - self.kernel = self.variable( + """ + A simple model that has variables in multiple collections - "params" and "special_variables" + """ + + def setup(self): + self.dense = nn.Dense(4) + self.kernel = self.variable( "special_variables", "my_first_kernel", lambda: jnp.ones((4, 5)) - ) - - def __call__(self, x, y): - x = self.dense(x) - x = x @ self.kernel.value - return x + ) + + def __call__(self, x, y): + x = self.dense(x) + x = x @ self.kernel.value + return x + class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase): @@ -122,13 +138,17 @@ def setUp(self): self.config = pyconfig.config self.model = ModelWithMultipleCollections() self.key1, self.key2, self.key3 = random.split(random.key(0), num=3) - self.input = random.normal(self.key1, - (self.config.global_batch_size_to_load, self.config.max_target_length)) + self.input = random.normal( + self.key1, + (self.config.global_batch_size_to_load, self.config.max_target_length), + ) self.params = self.model.init(self.key2, self.input, self.input) self.tx = optax.adam(learning_rate=0.001) def _test_init_initial_state_driver(self, is_training): - state_under_test = max_utils.init_initial_state(self.model, self.tx, self.config, is_training, self.key3) + state_under_test = max_utils.init_initial_state( + self.model, self.tx, self.config, is_training, self.key3 + ) self.assertEqual(state_under_test.apply_fn, self.model.apply) if is_training: self.assertEqual(state_under_test.tx, self.tx) @@ -137,19 +157,16 @@ def 
_test_init_initial_state_driver(self, is_training): self.assertIsNone(state_under_test.tx) self.assertEqual(state_under_test.opt_state, {}) self.assertEqual( - max_utils.calculate_num_params_from_pytree(state_under_test.params), - max_utils.calculate_num_params_from_pytree(self.params) - ) - self.assertEqual( - len(self.params), - len(state_under_test.params) + max_utils.calculate_num_params_from_pytree(state_under_test.params), + max_utils.calculate_num_params_from_pytree(self.params), ) + self.assertEqual(len(self.params), len(state_under_test.params)) self.assertIn("special_variables", state_under_test.params) self.assertIn("params", state_under_test.params) - + def test_initial_train_state(self): self._test_init_initial_state_driver(True) - + def test_initial_decode_state(self): self._test_init_initial_state_driver(False) @@ -168,7 +185,8 @@ def setUp(self): def test_setup_decode_state(self): rng = random.PRNGKey(0) state, _ = max_utils.setup_decode_state( - self.model, self.config, rng, self.mesh, None) + self.model, self.config, rng, self.mesh, None + ) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) @@ -176,30 +194,46 @@ def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) state, _, _ = max_utils.setup_initial_state( - self.model, None, tx, self.config, rng, self.mesh, None) + self.model, None, tx, self.config, rng, self.mesh, None + ) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) + class MaxUtilsT5XCrossEntropy(unittest.TestCase): """Tests for the cross entropy functions in max_utils.py""" + def test_t5x_cross_entropy(self): # Generate random targets and logits key = jax.random.PRNGKey(0) - targets = jax.random.randint(key, shape=(48, 2048), - dtype=jax.numpy.int32, minval=1, maxval=10) - logits = jax.random.uniform(key, shape=(48, 2048, 4096), - dtype=jax.numpy.float32) + targets = jax.random.randint( + key, shape=(48, 2048), dtype=jax.numpy.int32, minval=1, maxval=10 + ) + logits = jax.random.uniform( + key, shape=(48, 2048, 4096), dtype=jax.numpy.float32 + ) # Calculate xent from optax implementation - optax_xent = optax.softmax_cross_entropy_with_integer_labels(logits, targets) + optax_xent = optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ) # Calculate xent from custom T5X implementation one_hot_targets = jax.nn.one_hot(targets, 4096) - t5x_xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - t5x_xent = nn.with_logical_constraint(t5x_xent, ('activation_batch', 'activation_length')) + t5x_xent, _ = max_utils.cross_entropy_with_logits( + logits, one_hot_targets, 0.0 + ) + t5x_xent = nn.with_logical_constraint( + t5x_xent, ("activation_batch", "activation_length") + ) # Compare results - self.assertTrue(jax.numpy.allclose(optax_xent, t5x_xent, rtol=1e-05, atol=1e-08, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + optax_xent, t5x_xent, rtol=1e-05, atol=1e-08, equal_nan=False + ) + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py index 5ab89406e..4c53976a0 100644 --- a/MaxText/tests/model_test.py +++ b/MaxText/tests/model_test.py @@ -30,35 +30,46 @@ from layers import quantizations Mesh = jax.sharding.Mesh -MAX_PREFILL_PREDICT_LENGTH = 4 +MAX_PREFILL_PREDICT_LENGTH = 4 + class TestModel(unittest.TestCase): - """Test the Whole Model """ + """Test the Whole Model""" + def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 
'configs/base.yml'], per_device_batch_size = 1.0, run_name='test', - enable_checkpointing=False, base_num_decoder_layers=2, attention="dot_product", - max_target_length=16, base_emb_dim=256, base_num_query_heads=2, base_num_kv_heads=2, max_prefill_predict_length=4) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + base_num_decoder_layers=2, + attention="dot_product", + max_target_length=16, + base_emb_dim=256, + base_num_query_heads=2, + base_num_kv_heads=2, + max_prefill_predict_length=4, + ) self.cfg = pyconfig.config self.rng = jax.random.PRNGKey(0) def get_data(self): s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length) - ids = jax.random.randint( - self.rng, - s, - 0, - self.cfg.vocab_size - ) + ids = jax.random.randint(self.rng, s, 0, self.cfg.vocab_size) - decoder_segment_ids = jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR - decoder_positions = jnp.stack([ - jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) - for _ in range(self.cfg.global_batch_size_to_train_on) - ]) + decoder_segment_ids = ( + jax.numpy.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + ) + decoder_positions = jnp.stack( + [ + jnp.arange(self.cfg.max_target_length, dtype=jnp.int32) + for _ in range(self.cfg.global_batch_size_to_train_on) + ] + ) return ids, decoder_segment_ids, decoder_positions - + @pytest.mark.tpu def test_train_vs_prefill_and_autoregress(self): PREFILL_RANGE = MAX_PREFILL_PREDICT_LENGTH @@ -66,68 +77,75 @@ def test_train_vs_prefill_and_autoregress(self): devices_array = max_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) quant = quantizations.configure_quantization(self.cfg) - model = models.Transformer(config = self.cfg, mesh = mesh, quant=quant) + model = models.Transformer(config=self.cfg, mesh=mesh, quant=quant) ids, decoder_segment_ids, decoder_positions = self.get_data() transformer_vars = model.init( - {'params': self.rng, 'aqt': self.rng}, + {"params": self.rng, "aqt": self.rng}, ids, decoder_positions, decoder_segment_ids, - enable_dropout=False + enable_dropout=False, ) full_train_logits = model.apply( - transformer_vars, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - model_mode = common_types.MODEL_MODE_TRAIN, - rngs={'aqt': self.rng} + transformer_vars, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_TRAIN, + rngs={"aqt": self.rng}, ) partial_prefill_logits, partial_cache = model.apply( - transformer_vars, - ids[:, :PREFILL_RANGE], - decoder_positions[:, :PREFILL_RANGE], - decoder_segment_ids=decoder_segment_ids[:, :PREFILL_RANGE], - enable_dropout=False, - model_mode = common_types.MODEL_MODE_PREFILL, - rngs={'aqt': self.rng}, - mutable=["cache"], + transformer_vars, + ids[:, :PREFILL_RANGE], + decoder_positions[:, :PREFILL_RANGE], + decoder_segment_ids=decoder_segment_ids[:, :PREFILL_RANGE], + enable_dropout=False, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"aqt": self.rng}, + mutable=["cache"], ) self.assertTrue( jax.numpy.allclose( - full_train_logits[:,:PREFILL_RANGE,:], partial_prefill_logits, rtol=1e-01, atol=1e-01, equal_nan=False + full_train_logits[:, :PREFILL_RANGE, :], + partial_prefill_logits, + rtol=1e-01, + atol=1e-01, + equal_nan=False, ) ) for idx in range(PREFILL_RANGE, self.cfg.max_target_length): - ids_idx = ids[:, idx:idx+1] - decoder_positions_idx = decoder_positions[:, 
idx:idx+1] + ids_idx = ids[:, idx : idx + 1] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] transformer_vars.update(partial_cache) ar_logits, partial_cache = model.apply( - transformer_vars, - ids_idx, - decoder_positions_idx, - enable_dropout=False, - model_mode = common_types.MODEL_MODE_AUTOREGRESSIVE, - rngs={'aqt': self.rng}, - mutable=["cache"], + transformer_vars, + ids_idx, + decoder_positions_idx, + enable_dropout=False, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], ) - full_train_logits_idx = full_train_logits[:,idx:idx+1,:] - self.assertTrue( - full_train_logits_idx.shape == ar_logits.shape - ) + full_train_logits_idx = full_train_logits[:, idx : idx + 1, :] + self.assertTrue(full_train_logits_idx.shape == ar_logits.shape) self.assertTrue( - jax.numpy.allclose( - full_train_logits_idx, ar_logits, rtol=1e-01, atol=1e-01, equal_nan=False - ) + jax.numpy.allclose( + full_train_logits_idx, + ar_logits, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ) ) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index e620c2e1e..0d3d3a7c2 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" # pylint: disable=missing-module-docstring, missing-function-docstring import sys @@ -35,28 +35,37 @@ class MultihostDataloadingTest(unittest.TestCase): def setUp(self): super().setUp() batch_size = 4 - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "gs://maxtext-dataset/", - enable_checkpointing=False) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="gs://maxtext-dataset/", + enable_checkpointing=False, + ) config = pyconfig.config global_data_shape = PartitionSpec(batch_size, config.max_target_length) - data_sharding = ('data',) + data_sharding = ("data",) mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes) - data_axes = PartitionSpec('data',) + self.mesh = Mesh( + mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes + ) + data_axes = PartitionSpec( + "data", + ) # creating 2 batches of data - global_data = np.arange(np.prod(global_data_shape)*2).reshape((batch_size * 2, config.max_target_length)) + global_data = np.arange(np.prod(global_data_shape) * 2).reshape( + (batch_size * 2, config.max_target_length) + ) dataset = tf.data.Dataset.from_tensor_slices(global_data) dataset = dataset.repeat() dataset = dataset.batch(batch_size) - self.multihost_gen = ( - multihost_dataloading.MultiHostDataLoadIterator( - dataset, self.mesh - ) + self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator( + dataset, self.mesh ) @pytest.mark.tpu @@ -66,5 +75,5 @@ def test_batch_sharded_data_pipeline(self): self.assertTrue(not np.array_equal(first_batch, sec_batch, equal_nan=True)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/profiler_test.py b/MaxText/tests/profiler_test.py index 095073008..4c76e12fa 100644 --- a/MaxText/tests/profiler_test.py +++ b/MaxText/tests/profiler_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """Profiler tests for TPUs.""" import glob @@ -29,9 +29,9 @@ class TpuJAXTest(unittest.TestCase): def _get_session_snapshot(self): """Gets a session snapshot of current session. assume only one session.""" - profile_plugin_root ="tensorboard/plugins/profile" + profile_plugin_root = "tensorboard/plugins/profile" # The session exists under a director whose name is time-dependent. - profile_session_glob = os.path.join(profile_plugin_root, '*', '*.xplane.pb') + profile_session_glob = os.path.join(profile_plugin_root, "*", "*.xplane.pb") return glob.glob(profile_session_glob) def test_xplane_is_present(self): @@ -40,47 +40,48 @@ def test_xplane_is_present(self): def test_overview_page(self): xspace_filenames = self._get_session_snapshot() - result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, - 'overview_page^', {}) + result, _ = raw_to_tool_data.xspace_to_tool_data( + xspace_filenames, "overview_page^", {} + ) result = json.loads(result) run_environment = result[2] - self.assertEqual(run_environment['p']['host_count'], '1') - self.assertRegex(run_environment['p']['device_type'], 'TPU.*') + self.assertEqual(run_environment["p"]["host_count"], "1") + self.assertRegex(run_environment["p"]["device_type"], "TPU.*") def test_op_profile(self): xspace_filenames = self._get_session_snapshot() result, _ = raw_to_tool_data.xspace_to_tool_data( - xspace_filenames, 'op_profile^', {} + xspace_filenames, "op_profile^", {} ) result = json.loads(result) - self.assertIn('byCategory', result) - self.assertIn('metrics', result['byCategory']) - overall_metrics = result['byCategory']['metrics'] - self.assertIn('flops', overall_metrics) - self.assertIn('bandwidthUtils', overall_metrics) - self.assertGreater(overall_metrics['flops'], 0) + self.assertIn("byCategory", result) + self.assertIn("metrics", result["byCategory"]) + overall_metrics = result["byCategory"]["metrics"] + self.assertIn("flops", overall_metrics) + self.assertIn("bandwidthUtils", overall_metrics) + self.assertGreater(overall_metrics["flops"], 0) def test_device_trace_contains_threads(self): xspace_filenames = self._get_session_snapshot() result, _ = raw_to_tool_data.xspace_to_tool_data( - xspace_filenames, 'trace_viewer^', {} + xspace_filenames, "trace_viewer^", {} ) result = json.loads(result) thread_names = [] - for event in result['traceEvents']: - if 'name' in event and event['name'] == 'thread_name': - thread_names.append((event['args']['name'])) - expected_threads = [ - 'TensorFlow Name Scope', - 'TensorFlow Ops', - 'XLA Modules', - 'XLA Ops', - 'XLA TraceMe', - 'Steps', - ] + for event in result["traceEvents"]: + if "name" in event and event["name"] == "thread_name": + thread_names.append((event["args"]["name"])) + expected_threads = [ + "TensorFlow Name Scope", + "TensorFlow Ops", + "XLA Modules", + "XLA Ops", + "XLA TraceMe", + "Steps", + ] # Ensure that thread_names contains at least all expected threads. - self.assertEqual(set(expected_threads)-set(thread_names), set()) + self.assertEqual(set(expected_threads) - set(thread_names), set()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py index 7f35ffb0f..ad6e55d49 100644 --- a/MaxText/tests/quantizations_test.py +++ b/MaxText/tests/quantizations_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the quantizations """ from jax import numpy as jnp @@ -23,8 +23,10 @@ from layers import quantizations import unittest + class QuantTestModule(nn.Module): """Test module for einsum.""" + quantization: quantizations.AqtQuantization @nn.compact @@ -36,23 +38,31 @@ def __call__(self, inputs): einsum = self.quantization.einsum() dot_general_cls = self.quantization.dot_general_cls() dot_general = dot_general_cls() - res_einsum = einsum('bc,ab->ac', inputs, identity) + res_einsum = einsum("bc,ab->ac", inputs, identity) res_dg = dot_general(inputs, inputs, (((), ()), ((), ())), precision=None) return res_einsum, res_dg -def _configure_quantization(quant_str="", mode_str='train'): - pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False, quantization=quant_str) + +def _configure_quantization(quant_str="", mode_str="train"): + pyconfig.initialize( + [None, "configs/base.yml"], + enable_checkpointing=False, + quantization=quant_str, + ) config = pyconfig.config quant = quantizations.configure_quantization(config, mode_str) return quant + def _apply(quant_str=""): quant = _configure_quantization(quant_str) test_module = QuantTestModule(quant) rng = random.PRNGKey(0) - variables = test_module.init({'params': rng}, jnp.ones((2, 2))) + variables = test_module.init({"params": rng}, jnp.ones((2, 2))) inputs = jnp.ones((2, 2)) - res_einsum, res_dg = test_module.apply(variables, inputs, rngs={'params': random.PRNGKey(0)}) + res_einsum, res_dg = test_module.apply( + variables, inputs, rngs={"params": random.PRNGKey(0)} + ) return inputs, res_einsum, res_dg @@ -60,11 +70,10 @@ class QuantizationTest(unittest.TestCase): """Tests for quantization.""" def test_in_quant_mode(self): - quant = _configure_quantization(quant_str="int8", mode_str='convert') + quant = _configure_quantization(quant_str="int8", mode_str="convert") self.assertTrue(quantizations.in_convert_mode(quant)) self.assertFalse(quantizations.in_serve_mode(quant)) - def test_configure_quantization_is_null(self): for quant_mode in ["train", "serve", "convert"]: quant = _configure_quantization(quant_str="", mode_str=quant_mode) @@ -88,37 +97,74 @@ def test_aqt_quantization(self): self.assertTrue(jnp.greater(jnp.max(inputs), jnp.max(res_einsum))) self.assertEqual(res_einsum.dtype, np.dtype(np.float32)) self.assertTrue(jnp.greater(jnp.max(inputs), jnp.max(res_dg[0][0]))) - #self.assertEqual(res_dg.dtype, np.dtype(np.float32)) + # self.assertEqual(res_dg.dtype, np.dtype(np.float32)) def test_remove_quantized_params(self): _params = { - 'decoder': { - 'decoder_norm': {'scale': 1.0}, - 'layers': { 
- 'mlp': {'wi_0': {'kernel': 1.0}, 'wi_1': {'kernel': 1.0}, 'wo': {'kernel': 1.0}}, - 'self_attention': {'key': {'kernel': 1.0},}}, - 'logits_dense': {'kernel': 1.0}}, - } - _aqt_vars = { - 'decoder': { - 'layers': { - 'mlp': { - 'wi_0': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}}, - 'wi_1': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}}, - 'wo': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0 }}} + "decoder": { + "decoder_norm": {"scale": 1.0}, + "layers": { + "mlp": { + "wi_0": {"kernel": 1.0}, + "wi_1": {"kernel": 1.0}, + "wo": {"kernel": 1.0}, + }, + "self_attention": { + "key": {"kernel": 1.0}, + }, }, - 'self_attention': {'key': {'AqtDotGeneral_0': {'qrhs': {'scale': 1.0, '_value': 1.0}},}}}} - } + "logits_dense": {"kernel": 1.0}, + }, + } + _aqt_vars = { + "decoder": { + "layers": { + "mlp": { + "wi_0": { + "AqtDotGeneral_0": { + "qrhs": {"scale": 1.0, "_value": 1.0} + } + }, + "wi_1": { + "AqtDotGeneral_0": { + "qrhs": {"scale": 1.0, "_value": 1.0} + } + }, + "wo": { + "AqtDotGeneral_0": { + "qrhs": {"scale": 1.0, "_value": 1.0} + } + }, + }, + "self_attention": { + "key": { + "AqtDotGeneral_0": { + "qrhs": {"scale": 1.0, "_value": 1.0} + }, + } + }, + } + } + } _expected = { - 'decoder': { - 'decoder_norm': {'scale': 1.0}, - 'layers': { - 'mlp': {'wi_0': {'kernel': {}}, 'wi_1': {'kernel': {}}, 'wo': {'kernel': {}}}, - 'self_attention': {'key': {'kernel': {}},}}, - 'logits_dense': {'kernel': 1.0},} - } + "decoder": { + "decoder_norm": {"scale": 1.0}, + "layers": { + "mlp": { + "wi_0": {"kernel": {}}, + "wi_1": {"kernel": {}}, + "wo": {"kernel": {}}, + }, + "self_attention": { + "key": {"kernel": {}}, + }, + }, + "logits_dense": {"kernel": 1.0}, + } + } result = quantizations.remove_quantized_params(_params, _aqt_vars) self.assertEqual(_expected, result) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index 8b51246b2..59b2379df 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the standalone_checkpointer.py """ import unittest @@ -25,35 +25,70 @@ class Standalone_DL_CKPT(unittest.TestCase): - """Tests for standalone_checkpointer.py, checkpoint and restore. 
""" + """Tests for standalone_checkpointer.py, checkpoint and restore.""" def _get_random_test_name(self, test_name): now = datetime.now() date_time = now.strftime("_%Y-%m-%d-%H-%M_") - random_string = ''.join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(6)) + random_string = "".join( + random.SystemRandom().choice(string.ascii_uppercase + string.digits) + for _ in range(6) + ) random_run_name = test_name + date_time + random_string return random_run_name @pytest.mark.tpu def test_standalone_dataloader(self): random_run_name = self._get_random_test_name("standalone_dataloader") - sdl_main((None, "configs/base.yml", "run_name="+random_run_name, "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", "steps=100", "enable_checkpointing=false", - "tokenizer_path=../assets/tokenizer.llama2")) # need to pass relative path to tokenizer + sdl_main(( + None, + "configs/base.yml", + "run_name=" + random_run_name, + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "steps=100", + "enable_checkpointing=false", + "tokenizer_path=../assets/tokenizer.llama2", + )) # need to pass relative path to tokenizer @pytest.mark.tpu def test_standalone_checkpointer(self): random_run_name = self._get_random_test_name("standalone_checkpointer") # checkpoint at 50 - sckpt_main((None, "configs/base.yml", f"run_name={random_run_name}", "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset","base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", - "base_mlp_dim=128", "base_num_decoder_layers=2", "steps=60", "enable_checkpointing=True", - "checkpoint_period=50", "async_checkpointing=False")) + sckpt_main(( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=60", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + )) # restore at 50 and checkpoint at 100 - sckpt_main((None, "configs/base.yml", f"run_name={random_run_name}", "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset","base_emb_dim=128", "base_num_query_heads=4", "base_num_kv_heads=4", - "base_mlp_dim=128", "base_num_decoder_layers=2", "steps=110", "enable_checkpointing=True", - "checkpoint_period=50", "async_checkpointing=False")) + sckpt_main(( + None, + "configs/base.yml", + f"run_name={random_run_name}", + "base_output_directory=gs://runner-maxtext-logs", + "dataset_path=gs://maxtext-dataset", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=110", + "enable_checkpointing=True", + "checkpoint_period=50", + "async_checkpointing=False", + )) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py index 1c998fe82..7871b8c7f 100644 --- a/MaxText/tests/tfds_data_processing_test.py +++ b/MaxText/tests/tfds_data_processing_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring, missing-function-docstring import os @@ -29,47 +29,65 @@ from input_pipeline import _tfds_data_processing from input_pipeline import input_pipeline_interface + class TfdsDataProcessingTest(unittest.TestCase): def setUp(self): super().setUp() - pyconfig.initialize([sys.argv[0], 'configs/base.yml'], per_device_batch_size=1, run_name='test', mesh_axes = ['data'], - logical_axis_rules = [['batch', 'data']], - data_sharding = ['data'], - base_output_directory = "gs://max-experiments/", - dataset_path = "gs://maxtext-dataset/", - tokenizer_path = "../assets/tokenizer", - enable_checkpointing=False) + pyconfig.initialize( + [sys.argv[0], "configs/base.yml"], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + dataset_path="gs://maxtext-dataset/", + tokenizer_path="../assets/tokenizer", + enable_checkpointing=False, + ) os.environ["TFDS_DATA_DIR"] = pyconfig.config.dataset_path self.config = pyconfig.config self.mesh_shape_1d = (len(jax.devices()),) - self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.mesh = Mesh( + mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes + ) self.read_config = tfds.ReadConfig( - shuffle_seed = self.config.data_shuffle_seed, + shuffle_seed=self.config.data_shuffle_seed, ) self.read_config.add_tfds_id = True - + self.train_ds, self.eval_ds = self._get_datasets() - self.train_iter, self.eval_iter, self.predict_iter = self._get_preprocessed_datasets() + ( + self.train_iter, + self.eval_iter, + self.predict_iter, + ) = self._get_preprocessed_datasets() def _get_datasets(self): print("Sharding dataset in ", jax.process_count(), " shards") - process_indices = input_pipeline_interface.get_process_loading_real_data(self.config, self.mesh) + process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config, self.mesh + ) train_ds, eval_ds = _tfds_data_processing.get_datasets( - config=self.config, - dataloading_host_index = process_indices.index(jax.process_index()), - dataloading_host_count = len(process_indices), - read_config = self.read_config) + config=self.config, + dataloading_host_index=process_indices.index(jax.process_index()), + dataloading_host_count=len(process_indices), + read_config=self.read_config, + ) return train_ds, eval_ds def _get_preprocessed_datasets(self): mesh_shape_1d = (len(jax.devices()),) - mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), 
self.config.mesh_axes) - sp_tokenizer = input_pipeline_interface.get_tokenizer(self.config.tokenizer_path) + mesh = Mesh( + mesh_utils.create_device_mesh(mesh_shape_1d), self.config.mesh_axes + ) + sp_tokenizer = input_pipeline_interface.get_tokenizer( + self.config.tokenizer_path + ) train_iter, eval_iter, test_iter = _tfds_data_processing.preprocess_dataset( - self.config, - mesh, - self.train_ds, self.eval_ds, sp_tokenizer) + self.config, mesh, self.train_ds, self.eval_ds, sp_tokenizer + ) return train_iter, eval_iter, test_iter def test_train_ds(self): @@ -77,33 +95,39 @@ def test_train_ds(self): # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. batch = next(self.train_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) def test_eval_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] batch = next(self.eval_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "targets": expected_shape, + }, + ) def test_predict_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] batch = next(self.predict_iter) - self.assertEqual({k: list(v.shape) for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) - + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "targets": expected_shape, + }, + ) def test_ds_determinism(self): train_ds1 = self.train_ds.batch(64) @@ -113,20 +137,41 @@ def test_ds_determinism(self): train_ds = train_ds.batch(64) train_ds2 = next(train_ds.as_numpy_iterator()) - self.assertCountEqual(train_ds1['tfds_id'], train_ds2['tfds_id']) - + self.assertCountEqual(train_ds1["tfds_id"], train_ds2["tfds_id"]) def test_batch_determinism(self): batch1 = next(self.train_iter) self.train_ds, _ = self._get_datasets() - train_iter2, _, _= self._get_preprocessed_datasets() + train_iter2, _, _ = self._get_preprocessed_datasets() batch2 = next(train_iter2) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs'], batch2['inputs']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets'], batch2['targets']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs_segmentation'], batch2['inputs_segmentation']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets_segmentation'], batch2['targets_segmentation']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['inputs_position'], batch2['inputs_position']))) - self.assertTrue(tf.reduce_all(tf.equal(batch1['targets_position'], batch2['targets_position']))) + self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs"], batch2["inputs"]))) + self.assertTrue( + tf.reduce_all(tf.equal(batch1["targets"], batch2["targets"])) + ) + self.assertTrue( + tf.reduce_all( + tf.equal( + batch1["inputs_segmentation"], batch2["inputs_segmentation"] 
+ ) + ) + ) + self.assertTrue( + tf.reduce_all( + tf.equal( + batch1["targets_segmentation"], batch2["targets_segmentation"] + ) + ) + ) + self.assertTrue( + tf.reduce_all( + tf.equal(batch1["inputs_position"], batch2["inputs_position"]) + ) + ) + self.assertTrue( + tf.reduce_all( + tf.equal(batch1["targets_position"], batch2["targets_position"]) + ) + ) def test_for_loop_repeatable(self): def get_first_batch(iterator): @@ -137,10 +182,9 @@ def get_first_batch(iterator): eval_batch1 = get_first_batch(self.eval_iter) eval_batch2 = get_first_batch(self.eval_iter) - self.assertTrue((eval_batch1['inputs']==eval_batch2['inputs']).all()) - self.assertTrue((eval_batch1['targets']==eval_batch2['targets']).all()) + self.assertTrue((eval_batch1["inputs"] == eval_batch2["inputs"]).all()) + self.assertTrue((eval_batch1["targets"] == eval_batch2["targets"]).all()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() - diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index 6797f6ee3..7b0847fb9 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -1,17 +1,17 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" """ Tests for tokenizer @@ -30,41 +30,53 @@ class TokenizerTest(unittest.TestCase): @classmethod def setUpClass(cls): - dataset_name = 'c4/en:3.0.1' - dataset_path = 'gs://maxtext-dataset' + dataset_name = "c4/en:3.0.1" + dataset_path = "gs://maxtext-dataset" cls.vocab_size = 32_768 cls.max_corpus_chars = 10_000_000 - assets_path = 'tests' - vocab_model_name = 'test_tokenizer' + assets_path = "tests" + vocab_model_name = "test_tokenizer" cls.tokenizer_path = os.path.join(assets_path, vocab_model_name) os.environ["TFDS_DATA_DIR"] = dataset_path read_config = tfds.ReadConfig( - shuffle_seed = 0, + shuffle_seed=0, ) train_ds_builder = tfds.builder(dataset_name) - cls.dataset = train_ds_builder.as_dataset(split='train', read_config=read_config, shuffle_files=True) - train_tokenizer.train_tokenizer(cls.dataset, - assets_path=assets_path, - vocab_path=cls.tokenizer_path, - vocab_size=cls.vocab_size, - max_corpus_chars=cls.max_corpus_chars) + cls.dataset = train_ds_builder.as_dataset( + split="train", read_config=read_config, shuffle_files=True + ) + train_tokenizer.train_tokenizer( + cls.dataset, + assets_path=assets_path, + vocab_path=cls.tokenizer_path, + vocab_size=cls.vocab_size, + max_corpus_chars=cls.max_corpus_chars, + ) @classmethod def tearDownClass(cls): os.remove(cls.tokenizer_path) def test_tokenize(self): - source_tokenizer = tokenizer.load_tokenizer('../assets/tokenizer') + source_tokenizer = tokenizer.load_tokenizer("../assets/tokenizer") test_tokenizer = tokenizer.load_tokenizer(self.tokenizer_path) - text = 'This is a test' - self.assertTrue((np.asarray(source_tokenizer.tokenize(text)) & np.asarray(test_tokenizer.tokenize(text))).all()) + text = "This is a test" + self.assertTrue( + ( + np.asarray(source_tokenizer.tokenize(text)) + & np.asarray(test_tokenizer.tokenize(text)) + ).all() + ) def test_detokenize(self): - source_tokenizer = tokenizer.load_tokenizer('../assets/tokenizer') + source_tokenizer = tokenizer.load_tokenizer("../assets/tokenizer") test_tokenizer = tokenizer.load_tokenizer(self.tokenizer_path) - tokens = [66,12,10,698,2] - self.assertEqual(np.asarray(source_tokenizer.detokenize(tokens)),np.asarray(test_tokenizer.detokenize(tokens))) + tokens = [66, 12, 10, 698, 2] + self.assertEqual( + np.asarray(source_tokenizer.detokenize(tokens)), + np.asarray(test_tokenizer.detokenize(tokens)), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 3bf46e753..cc64aea91 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Tests for the common Max Utils """ import unittest @@ -26,73 +26,158 @@ class TrainCompile(unittest.TestCase): @pytest.mark.tpu def test_save_compiled_v4(self): - compiled_trainstep_file='/tmp/test_compiled_v4.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v4-8", "compile_topology_num_slices=1", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v4.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v4-8", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_save_compiled_v5e(self): - compiled_trainstep_file='/tmp/test_compiled_v5e.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-16", "compile_topology_num_slices=1", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-16", + "compile_topology_num_slices=1", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_minimal_offloaded_v5e(self): - compiled_trainstep_file='/tmp/test_compiled_v5e_offload.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=1", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=minimal_offloaded", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_compiled_v5e_offload.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=minimal_offloaded", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_save_compiled_v5p_two_slices(self): - compiled_trainstep_file='/tmp/test_compiled_v5p_two_slices.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-8", "compile_topology_num_slices=2", "base_emb_dim=256", "base_mlp_dim=256", - "base_num_decoder_layers=2")) + compiled_trainstep_file = "/tmp/test_compiled_v5p_two_slices.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-8", + "compile_topology_num_slices=2", + "base_emb_dim=256", + "base_mlp_dim=256", + "base_num_decoder_layers=2", + )) @pytest.mark.tpu def test_sequence_parallelism(self): - 
compiled_trainstep_file='/tmp/test_compiled.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "use_iota_embed=true", "compile_topology_num_slices=1", - "ici_sequence_parallelism=16", "global_parameter_scale=32", "per_device_batch_size=0.0625", "max_target_length=65536")) + compiled_trainstep_file = "/tmp/test_compiled.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "ici_sequence_parallelism=16", + "global_parameter_scale=32", + "per_device_batch_size=0.0625", + "max_target_length=65536", + )) @pytest.mark.tpu def test_remat_save_dot_except_mlpwi(self): - compiled_trainstep_file='/tmp/test_remat_save_dot_except_mlpwi.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.125", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_dot_except_mlpwi", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlpwi.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.125", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlpwi", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_save_dot_except_mlp(self): - compiled_trainstep_file='/tmp/test_remat_save_dot_except_mlp.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.25", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_dot_except_mlp", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlp.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.25", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_dot_except_mlp", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_save_qkv_proj(self): - compiled_trainstep_file='/tmp/test_remat_save_qkv_proj.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=0.375", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=save_qkv_proj", - "use_iota_embed=true", "global_parameter_scale=128")) + compiled_trainstep_file = "/tmp/test_remat_save_qkv_proj.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + 
"compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=0.375", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=save_qkv_proj", + "use_iota_embed=true", + "global_parameter_scale=128", + )) @pytest.mark.tpu def test_remat_full(self): - compiled_trainstep_file='/tmp/test_remat_full.pickle' - train_compile_main((None, "configs/base.yml", f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5e-256", "compile_topology_num_slices=1", "per_device_batch_size=1", "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=16", "max_target_length=2048", - "fused_qkv=true", "fused_mlp=true", "remat_policy=full", - "use_iota_embed=true", "global_parameter_scale=128")) \ No newline at end of file + compiled_trainstep_file = "/tmp/test_remat_full.pickle" + train_compile_main(( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5e-256", + "compile_topology_num_slices=1", + "per_device_batch_size=1", + "ici_fsdp_parallelism=16", + "ici_tensor_parallelism=16", + "max_target_length=2048", + "fused_qkv=true", + "fused_mlp=true", + "remat_policy=full", + "use_iota_embed=true", + "global_parameter_scale=128", + )) diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 5efc5ebeb..a05e0fb6e 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Smoke test for int8""" import os @@ -20,18 +20,32 @@ from train import main as train_main from absl.testing import absltest + class Train(unittest.TestCase): """Smoke test for int8 G3 only""" def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([None, "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", "dataset_type=synthetic", "steps=10", - "enable_checkpointing=False", "quantization=int8"]) - -if __name__ == '__main__': + train_main([ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + "quantization=int8", + ]) + + +if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index 8cd41fb33..b3046fd35 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Smoke test """ import os @@ -20,18 +20,31 @@ from train import main as train_main from absl.testing import absltest + class Train(unittest.TestCase): """Smoke test G3 only""" def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - train_main([None, "third_party/py/maxtext/configs/base.yml", - f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", - r"dataset_path=gs://maxtext-dataset", - "base_emb_dim=8", "base_num_query_heads=4", "base_num_kv_heads=4", "base_mlp_dim=32", - "base_num_decoder_layers=8", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", "dataset_type=synthetic", "steps=10", - "enable_checkpointing=False"]) - -if __name__ == '__main__': + train_main([ + None, + "third_party/py/maxtext/configs/base.yml", + f"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "base_emb_dim=8", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=8", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "dataset_type=synthetic", + "steps=10", + "enable_checkpointing=False", + ]) + + +if __name__ == "__main__": absltest.main() diff --git a/MaxText/tests/weight_dtypes_test.py b/MaxText/tests/weight_dtypes_test.py index 49829d43c..908010a8f 100644 --- a/MaxText/tests/weight_dtypes_test.py +++ b/MaxText/tests/weight_dtypes_test.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" """ Test that all weights are expected dtype (default float32) """ import unittest @@ -30,37 +30,45 @@ Transformer = models.Transformer -class WeightDtypes(unittest.TestCase): - """Test that all weights are expected dtype (default float32) """ - - def get_weights(self, argv): - """ Gets model weights """ - - # Setup necessary inputs to build a model state - pyconfig.initialize(argv) - config = pyconfig.config - quant = quantizations.configure_quantization(config) - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - model = Transformer(config, mesh, quant=quant) - learning_rate_schedule = max_utils.create_learning_rate_schedule(config) - tx = optimizers.get_optimizer(config, learning_rate_schedule) - _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - - abstract_state, _ , _ = max_utils.get_abstract_state(model, tx, config, example_rng, mesh) - return abstract_state.params - - def assert_weights_are_dtype(self, weights, expected_dtype): - jax.tree_util.tree_map_with_path(lambda x,y: self.assertEqual(y.dtype, expected_dtype), weights) - - def test_default_float32(self): - argv = [None, "configs/base.yml", "enable_checkpointing=False"] - weights = self.get_weights(argv) - self.assert_weights_are_dtype(weights, jnp.float32) - - def test_set_bf16(self): - argv = [None, "configs/base.yml", "enable_checkpointing=False", "weight_dtype=bfloat16"] - weights = self.get_weights(argv) - self.assert_weights_are_dtype(weights, jnp.bfloat16) - +class WeightDtypes(unittest.TestCase): + """Test that all weights are expected dtype (default float32)""" + + def get_weights(self, argv): + """Gets model weights""" + + # Setup necessary inputs to build a model state + pyconfig.initialize(argv) + config = pyconfig.config + quant = quantizations.configure_quantization(config) + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + model = Transformer(config, mesh, quant=quant) + learning_rate_schedule = max_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) + _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) + + abstract_state, _, _ = max_utils.get_abstract_state( + model, tx, config, example_rng, mesh + ) + return abstract_state.params + + def assert_weights_are_dtype(self, weights, expected_dtype): + jax.tree_util.tree_map_with_path( + lambda x, y: self.assertEqual(y.dtype, expected_dtype), weights + ) + + def test_default_float32(self): + argv = [None, "configs/base.yml", "enable_checkpointing=False"] + weights = self.get_weights(argv) + self.assert_weights_are_dtype(weights, jnp.float32) + + def test_set_bf16(self): + argv = [ + None, + "configs/base.yml", + "enable_checkpointing=False", + "weight_dtype=bfloat16", + ] + weights = self.get_weights(argv) + self.assert_weights_are_dtype(weights, jnp.bfloat16) diff --git a/MaxText/tokenizer.py b/MaxText/tokenizer.py index 9e29b1e68..f1b9d6098 100644 --- a/MaxText/tokenizer.py +++ b/MaxText/tokenizer.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at

-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0

-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
-  """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""

 """Provides op for tokenizing a dataset."""

@@ -29,34 +29,38 @@

 Features = Dict[str, tf.Tensor]

-
-def _load_sentencepiece_tokenizer(tokenizer_path: str,
-                                  add_bos: bool = False,
-                                  add_eos: bool = True,
-                                  reverse: bool = False):
+def _load_sentencepiece_tokenizer(
+    tokenizer_path: str,
+    add_bos: bool = False,
+    add_eos: bool = True,
+    reverse: bool = False,
+):
   """Load a tf-text SentencePiece tokenizer from given model filepath."""
   max_logging.log(f"Tokenizer path: {tokenizer_path}")
-  with tf.io.gfile.GFile(tokenizer_path, 'rb') as model_fp:
+  with tf.io.gfile.GFile(tokenizer_path, "rb") as model_fp:
     sp_model = model_fp.read()
   sp_tokenizer = tftxt.SentencepieceTokenizer(
-      model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
+      model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse
+  )
   return sp_tokenizer

+
 def load_tokenizer(tokenizer_path: str, add_bos=False, add_eos=True):
   """Loads the tokenizer at `tokenizer_path` or trains one from `dataset`."""
   try:
-    sp_tokenizer = _load_sentencepiece_tokenizer(tokenizer_path, add_bos, add_eos)
+    sp_tokenizer = _load_sentencepiece_tokenizer(
+        tokenizer_path, add_bos, add_eos
+    )
     return sp_tokenizer
   except (tf.errors.NotFoundError, tf.errors.InvalidArgumentError):
-    logging.info('SentencePiece vocab not found, Run train_tokenizer.py')
+    logging.info("SentencePiece vocab not found; run train_tokenizer.py")
     return None


 @dataclasses.dataclass
 class TokenizeOp:
   sp_tokenizer: Any
-  data_keys: Iterable[str] = ('inputs', 'targets')
+  data_keys: Iterable[str] = ("inputs", "targets")

   def __call__(self, features: Features) -> Features:
     for k in self.data_keys:
diff --git a/MaxText/train.py b/MaxText/train.py
index 4a1966945..0cc919622 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -1,18 +1,18 @@
 """
-  Copyright 2023 Google LLC
+Copyright 2023 Google LLC

-  Licensed under the Apache License, Version 2.0 (the "License");
-  you may not use this file except in compliance with the License.
-  You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at

-      https://www.apache.org/licenses/LICENSE-2.0
+    https://www.apache.org/licenses/LICENSE-2.0

-  Unless required by applicable law or agreed to in writing, software
-  distributed under the License is distributed on an "AS IS" BASIS,
-  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-  See the License for the specific language governing permissions and
-  limitations under the License.
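For context, TokenizeOp above is designed to be mapped over a tf.data.Dataset of string features. A usage sketch, assuming the tokenizer path used by the tests and a hypothetical dataset variable:

import tensorflow as tf

sp_tokenizer = load_tokenizer("../assets/tokenizer")
tokenize_op = TokenizeOp(sp_tokenizer=sp_tokenizer)
# features_ds is assumed to yield {"inputs": tf.string, "targets": tf.string}:
# tokenized_ds = features_ds.map(tokenize_op, num_parallel_calls=tf.data.AUTOTUNE)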
- """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports """Training loop and Decoding of the model.""" @@ -64,87 +64,109 @@ Transformer = models.Transformer EPS = 1e-8 + def validate_train_config(config): - """ Validates the configuration is set correctly for train.py""" + """Validates the configuration is set correctly for train.py""" assert config.run_name, "Erroring out, need a real run_name" - if not config.dataset_path.startswith('gs://'): - max_logging.log("WARNING: 'dataset_path' might be pointing your local file system") - if not config.base_output_directory.startswith('gs://'): - max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") - assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive interger." - + if not config.dataset_path.startswith("gs://"): + max_logging.log( + "WARNING: 'dataset_path' might be pointing your local file system" + ) + if not config.base_output_directory.startswith("gs://"): + max_logging.log( + "WARNING: 'base_output_directory' might be pointing your local file system" + ) + assert ( + config.steps > 0 + ), "You must set steps or learning_rate_schedule_steps to a positive interger." def get_first_step(state): - with jax.spmd_mode('allow_all'): + with jax.spmd_mode("allow_all"): return int(state.step) def load_next_batch(train_iter, example_batch, config): - """Loads the next batch. Can keep reusing the same batch for performance reasons """ + """Loads the next batch. Can keep reusing the same batch for performance reasons""" if config.reuse_example_batch and example_batch is not None: return example_batch else: return next(train_iter) + def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr): """Records scalar metrics to be written to tensorboard""" - metrics['scalar'].update({ - 'perf/step_time_seconds': step_time_delta.total_seconds() - }) - metrics['scalar'].update({ - 'perf/per_device_tflops' : per_device_tflops - }) - metrics['scalar'].update({ - 'perf/per_device_tflops_per_sec': - per_device_tflops / - step_time_delta.total_seconds() - }) - metrics['scalar'].update({'learning/current_learning_rate': lr }) + metrics["scalar"].update( + {"perf/step_time_seconds": step_time_delta.total_seconds()} + ) + metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops}) + metrics["scalar"].update( + { + "perf/per_device_tflops_per_sec": per_device_tflops + / step_time_delta.total_seconds() + } + ) + metrics["scalar"].update({"learning/current_learning_rate": lr}) + _buffered_step = None _buffered_metrics = None -def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config): + + +def write_metrics( + writer, local_metrics_file, running_gcs_metrics, metrics, step, config +): """Entry point for all metrics writing in Train's Main. - TODO: would be better as a Class in the future (that initialized all state!) + TODO: would be better as a Class in the future (that initialized all state!) - To avoid introducing an unnecessary dependency, we "double buffer" -- we hold - onto the last metrics and step and only publish when we receive a new metrics and step. 
-     The logic is that this ensures that Jax is able to queues train_steps and we
-     don't block when turning "lazy" Jax arrays into real Python numbers.
+  To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
+  onto the last metrics and step and only publish when we receive a new metrics and step.
+  The logic is that this ensures that Jax is able to queue train_steps and we
+  don't block when turning "lazy" Jax arrays into real Python numbers.
   """
   global _buffered_step, _buffered_metrics
   if _buffered_metrics is not None:
     if _buffered_step is None:
       raise ValueError(f"When writing metrics, {_buffered_step=} was none")
-    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
+    write_metrics_to_tensorboard(
+        writer, _buffered_metrics, _buffered_step, config
+    )
     if config.metrics_file:
-      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
+      max_utils.write_metrics_locally(
+          _buffered_metrics, _buffered_step, config, local_metrics_file
+      )
     if config.gcs_metrics and jax.process_index() == 0:
-      running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
+      running_gcs_metrics = max_utils.write_metrics_for_gcs(
+          _buffered_metrics, _buffered_step, config, running_gcs_metrics
+      )

   _buffered_step = step
   _buffered_metrics = metrics

+
 def write_metrics_to_tensorboard(writer, metrics, step, config):
-  """ Writes metrics to tensorboard"""
-  with jax.spmd_mode('allow_all'):
+  """Writes metrics to tensorboard"""
+  with jax.spmd_mode("allow_all"):
     if jax.process_index() == 0:
-      for metric_name in metrics.get("scalar",[]):
-        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
-      for metric_name in metrics.get("scalars",[]):
+      for metric_name in metrics.get("scalar", []):
+        writer.add_scalar(
+            metric_name, np.array(metrics["scalar"][metric_name]), step
+        )
+      for metric_name in metrics.get("scalars", []):
         writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

       full_log = step % config.log_period == 0

-      max_logging.log(f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
-                      f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
-                      f"loss: {metrics['scalar']['learning/loss']:.3f}")
+      max_logging.log(
+          f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
+          f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
+          f"loss: {metrics['scalar']['learning/loss']:.3f}"
+      )

       if full_log and jax.process_index() == 0:
         max_logging.log(
@@ -152,47 +174,67 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
         )
         writer.flush()

-def save_checkpoint(checkpoint_manager, step, state, dataset_type='c4', data_iterator=None):
+
+def save_checkpoint(
+    checkpoint_manager, step, state, dataset_type="c4", data_iterator=None
+):
   """Wrapper for saving checkpoint"""
-  if dataset_type == 'c4-array_record':
+  if dataset_type == "c4-array_record":
     return checkpoint_manager.save(
-        step,
-        args=orbax.checkpoint.args.Composite(
-            items=orbax.checkpoint.args.PyTreeSave(item=state),
-            iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator)
-        )
-    )
+        step,
+        args=orbax.checkpoint.args.Composite(
+            items=orbax.checkpoint.args.PyTreeSave(item=state),
+            iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator),
+        ),
+    )
   else:
     return checkpoint_manager.save(
-        step,
-        args=orbax.checkpoint.args.Composite(
-            items=orbax.checkpoint.args.PyTreeSave(item=state)
-        ))
+        step,
+        args=orbax.checkpoint.args.Composite(
+            items=orbax.checkpoint.args.PyTreeSave(item=state)
+        ),
+    )
+

 # -----------------------------------------------------------------------------
 # Top-level Functions
 # -----------------------------------------------------------------------------

+
 def record_activation_metrics(output_metrics, intermediate_outputs, config):
-  """ Adds the activation metrics to the metrics dict"""
+  """Adds the activation metrics to the metrics dict"""
   if config.scan_layers:
-    metrics_dict = intermediate_outputs['intermediates']['decoder']['decoder']
+    metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"]
     for layer_num in range(config.num_decoder_layers):
-      output_metrics['scalar'][f'activ_fraction_zero/layer_{layer_num:03d}'] = \
-          metrics_dict["activation_fraction_zero"][0][layer_num]
-      output_metrics['scalar'][f'activ_mean/layer_{layer_num:03d}'] = metrics_dict["activation_mean"][0][layer_num]
-      output_metrics['scalar'][f'activ_stdev/layer_{layer_num:03d}'] = metrics_dict["activation_stdev"][0][layer_num]
+      output_metrics["scalar"][
+          f"activ_fraction_zero/layer_{layer_num:03d}"
+      ] = metrics_dict["activation_fraction_zero"][0][layer_num]
+      output_metrics["scalar"][
+          f"activ_mean/layer_{layer_num:03d}"
+      ] = metrics_dict["activation_mean"][0][layer_num]
+      output_metrics["scalar"][
+          f"activ_stdev/layer_{layer_num:03d}"
+      ] = metrics_dict["activation_stdev"][0][layer_num]
   else:
     for layer_num in range(config.num_decoder_layers):
-      layer = intermediate_outputs['intermediates']['decoder'][f'layers_{layer_num}']
-      output_metrics['scalar'][f'activ_fraction_zero/layer_{layer_num:03d}'] = layer["activation_fraction_zero"][0]
-      output_metrics['scalar'][f'activ_mean/layer_{layer_num:03d}'] = layer["activation_mean"][0]
-      output_metrics['scalar'][f'activ_stdev/layer_{layer_num:03d}'] = layer["activation_stdev"][0]
+      layer = intermediate_outputs["intermediates"]["decoder"][
+          f"layers_{layer_num}"
+      ]
+      output_metrics["scalar"][
+          f"activ_fraction_zero/layer_{layer_num:03d}"
+      ] = layer["activation_fraction_zero"][0]
+      output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer[
+          "activation_mean"
+      ][0]
+      output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer[
+          "activation_stdev"
+      ][0]

+
 def loss_fn(model, config, data, dropout_rng, params, is_train=True):
-  '''loss_fn for both train and eval.
+  """loss_fn for both train and eval.
Args: model: A nn.Module @@ -205,33 +247,38 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): Returns: loss: average loss aux: a dictionary including intermediate_outputs, total_loss, and total_weights - ''' + """ # inputs, targets, segments, positions = apply_args rng1, aqt_rng = jax.random.split(dropout_rng) # decimate proportion of data when per_device_batch_size<1 if is_train: for k, v in data.items(): - data[k] = v[:config.global_batch_size_to_train_on,:] - - logits, intermediate_outputs = model.apply(params, - data['inputs'], - data['inputs_position'], - decoder_segment_ids=data['inputs_segmentation'], - enable_dropout=config.enable_dropout if is_train else False, - rngs={'dropout': rng1, 'params': aqt_rng}, mutable='intermediates') - one_hot_targets = jax.nn.one_hot(data['targets'], config.vocab_size) + data[k] = v[: config.global_batch_size_to_train_on, :] + + logits, intermediate_outputs = model.apply( + params, + data["inputs"], + data["inputs_position"], + decoder_segment_ids=data["inputs_segmentation"], + enable_dropout=config.enable_dropout if is_train else False, + rngs={"dropout": rng1, "params": aqt_rng}, + mutable="intermediates", + ) + one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) - xent = nn.with_logical_constraint(xent, ('activation_batch', 'activation_length')) + xent = nn.with_logical_constraint( + xent, ("activation_batch", "activation_length") + ) # Mask out paddings at the end of each example. - xent = xent * (data['targets_segmentation'] != 0) + xent = xent * (data["targets_segmentation"] != 0) total_loss = jnp.sum(xent) - total_weights = jnp.sum(data['targets_segmentation'] != 0) + total_weights = jnp.sum(data["targets_segmentation"] != 0) loss = total_loss / (total_weights + EPS) aux = { - 'intermediate_outputs': intermediate_outputs, - 'total_loss': total_loss, - 'total_weights': total_weights, + "intermediate_outputs": intermediate_outputs, + "total_loss": total_loss, + "total_weights": total_weights, } return loss, aux @@ -251,45 +298,65 @@ def train_step(model, config, state, data, dropout_rng): rng2: A new rng key that can be used in future calls. 
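The padding arithmetic in loss_fn above reduces to a masked mean: positions whose targets_segmentation is 0 are padding and are dropped from both the summed loss and the token count. A toy restatement (the real code additionally applies sharding constraints and uses max_utils.cross_entropy_with_logits):

import jax.numpy as jnp

EPS = 1e-8  # same epsilon as above, guards against zero total_weights

def masked_mean_loss(xent, targets_segmentation):
  mask = targets_segmentation != 0          # 0 marks padded positions
  total_loss = jnp.sum(xent * mask)
  total_weights = jnp.sum(mask)
  return total_loss / (total_weights + EPS)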
""" - train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True) + train_loss_fn = functools.partial( + loss_fn, model, config, data, dropout_rng, is_train=True + ) grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True) (loss, aux), raw_grads = grad_fn(state.params) - intermediate_outputs = aux['intermediate_outputs'] + intermediate_outputs = aux["intermediate_outputs"] if config.gradient_clipping_threshold > 0: - grads, _ = optax.clip_by_global_norm(config.gradient_clipping_threshold).update(raw_grads, state, None) + grads, _ = optax.clip_by_global_norm( + config.gradient_clipping_threshold + ).update(raw_grads, state, None) else: grads = raw_grads new_state = state.apply_gradients(grads=grads) - metrics = {'scalar': {'learning/loss': loss, 'learning/grad_norm': max_utils.l2norm_pytree(grads), - 'learning/raw_grad_norm': max_utils.l2norm_pytree(raw_grads), - 'learning/param_norm': max_utils.l2norm_pytree(new_state.params)}, 'scalars': {}} + metrics = { + "scalar": { + "learning/loss": loss, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + "learning/param_norm": max_utils.l2norm_pytree(new_state.params), + }, + "scalars": {}, + } if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) return new_state, metrics + def eval_step(model, config, state, data, dropout_rng): """eval_step no backprop and new state compared with train_step.""" - eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False) + eval_loss_fn = functools.partial( + loss_fn, model, config, data, dropout_rng, is_train=False + ) loss, aux = eval_loss_fn(state.params) - total_loss = aux['total_loss'] - total_weights = aux['total_weights'] - metrics = {'scalar': - {'evaluation/loss': loss, - 'evaluation/total_loss': total_loss, - 'evaluation/total_weights': total_weights}} + total_loss = aux["total_loss"] + total_weights = aux["total_weights"] + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": total_loss, + "evaluation/total_weights": total_weights, + } + } return metrics + def create_goodput_recorder(config): if config.enable_goodput_recording: - logger_name = f'goodput_{config.run_name}' - recorder = goodput.GoodputRecorder(config.run_name, logger_name, jax.process_index() == 0) + logger_name = f"goodput_{config.run_name}" + recorder = goodput.GoodputRecorder( + config.run_name, logger_name, jax.process_index() == 0 + ) return recorder return None + def record_goodput(recorder, config, step=None, job_start=False, job_end=False): if recorder and config.enable_goodput_recording: if job_start and step is None: @@ -299,8 +366,9 @@ def record_goodput(recorder, config, step=None, job_start=False, job_end=False): if step is not None: recorder.record_step_start_time(step) + def setup_mesh_and_model(config): - """ Set up the mesh and the model for training + """Set up the mesh and the model for training Args: config @@ -334,10 +402,19 @@ def setup_mesh_and_model(config): model = Transformer(config, mesh, quant=quant) learning_rate_schedule = max_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx + return ( + init_rng, + writer, + checkpoint_manager, + mesh, + model, + learning_rate_schedule, + tx, + ) + def setup_train_loop(config): - """ Set up prerequisites for the training loop 
- + """Set up prerequisites for the training loop - checkpoint_manager, PRNG keys, Mesh, Model and optimizer. Set up data iterator and tokenizer, initialize the model. @@ -355,16 +432,37 @@ def setup_train_loop(config): data_iterator: state: the initialized train state """ - init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config) - data_iterator, eval_data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh) + ( + init_rng, + writer, + checkpoint_manager, + mesh, + model, + learning_rate_schedule, + tx, + ) = setup_mesh_and_model(config) + data_iterator, eval_data_iterator, _ = create_data_iterator_with_tokenizer( + config, mesh + ) - state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator, - tx, config, init_rng, mesh, checkpoint_manager) + state, state_mesh_annotations, data_iterator = max_utils.setup_training_state( + model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + ) maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh) - return ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model, - mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state) + return ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) def train_loop(config, state=None): @@ -379,67 +477,100 @@ def train_loop(config, state=None): recorder = create_goodput_recorder(config) record_goodput(recorder, config, job_start=True) - ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model, - mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state) = setup_train_loop(config) + ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) = setup_train_loop(config) # pylint: disable=line-too-long - functional_train, in_shard_train, out_shard_train, static_argnums_train, donate_argnums_train = maxtext_utils.get_functional_train_with_signature( - train_step, - mesh, - state_mesh_annotations, - model, - config + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature( + train_step, mesh, state_mesh_annotations, model, config ) if eval_data_iterator: # pylint: disable=line-too-long - functional_eval, in_shard_eval, out_shard_eval, static_argnums_eval, donate_argnums_eval = maxtext_utils.get_functional_eval_with_signature( - eval_step, - mesh, - state_mesh_annotations, - model, - config + ( + functional_eval, + in_shard_eval, + out_shard_eval, + static_argnums_eval, + donate_argnums_eval, + ) = maxtext_utils.get_functional_eval_with_signature( + eval_step, mesh, state_mesh_annotations, model, config ) - - num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) + num_model_parameters = max_utils.calculate_num_params_from_pytree( + state.params + ) max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") - per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(num_model_parameters, config) + per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device( + num_model_parameters, config + ) # Write train config params, num model params, and XLA flags to tensorboard - max_utils.add_text_to_summary_writer("num_model_parameters", 
str(num_model_parameters), writer) - max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer) + max_utils.add_text_to_summary_writer( + "num_model_parameters", str(num_model_parameters), writer + ) + max_utils.add_text_to_summary_writer( + "libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer + ) max_utils.add_config_to_summary_writer(config, writer) # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit - if config.compiled_trainstep_file != '': + if config.compiled_trainstep_file != "": print("Loading the compiled function...", flush=True) # Need to pass train signature and state to determine i/o shapes of train_state for now. p_train_step = maxtext_utils.load_compiled(config, functional_train, state) print("Loaded compiled function!", flush=True) else: p_train_step = jax.jit( - functional_train, - in_shardings=in_shard_train, - out_shardings=out_shard_train, - static_argnums=static_argnums_train, - donate_argnums=donate_argnums_train) + functional_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + static_argnums=static_argnums_train, + donate_argnums=donate_argnums_train, + ) if eval_data_iterator: p_eval_step = jax.jit( - functional_eval, - in_shardings=in_shard_eval, - out_shardings=out_shard_eval, - static_argnums=static_argnums_eval, - donate_argnums=donate_argnums_eval) + functional_eval, + in_shardings=in_shard_eval, + out_shardings=out_shard_eval, + static_argnums=static_argnums_eval, + donate_argnums=donate_argnums_eval, + ) - local_metrics_file = open(config.metrics_file, 'a', encoding="utf8") if config.metrics_file else None + local_metrics_file = ( + open(config.metrics_file, "a", encoding="utf8") + if config.metrics_file + else None + ) running_gcs_metrics = [] if config.gcs_metrics else None - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training first_profiling_step = start_step + config.skip_first_n_steps_for_profiler if config.enable_profiler and first_profiling_step >= config.steps: - raise ValueError("Profiling requested but initial profiling step set past training final step") - last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1) + raise ValueError( + "Profiling requested but initial profiling step set past training final step" + ) + last_profiling_step = np.clip( + first_profiling_step + config.profiler_steps - 1, + first_profiling_step, + config.steps - 1, + ) example_batch = None last_step_completion = datetime.datetime.now() @@ -453,16 +584,21 @@ def train_loop(config, state=None): nextrng = jax.jit(jax.random.fold_in)(init_rng, step) record_goodput(recorder, config, step=step) with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - state, metrics = p_train_step( - state, example_batch, nextrng - ) + state, metrics = p_train_step(state, example_batch, nextrng) new_time = datetime.datetime.now() - record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step)) + record_scalar_metrics( + metrics, + new_time - last_step_completion, + per_device_tflops, + learning_rate_schedule(step), + ) last_step_completion = new_time if checkpoint_manager is not None: - if save_checkpoint(checkpoint_manager, step, state, config.dataset_type, data_iterator): + if save_checkpoint( + checkpoint_manager, step, state, config.dataset_type, 
data_iterator + ): max_logging.log(f"saved a checkpoint at step {step}") # Upon preemption, exit when and only when all ongoing saves are complete. @@ -470,22 +606,36 @@ def train_loop(config, state=None): checkpoint_manager.wait_until_finished() sys.exit() - write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config) + write_metrics( + writer, local_metrics_file, running_gcs_metrics, metrics, step, config + ) - if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0: + if ( + config.eval_interval > 0 + and step > start_step + and step % config.eval_interval == 0 + ): assert eval_data_iterator - cumulative_eval_metrics = {"total_loss": 0., "total_weights": 0.} + cumulative_eval_metrics = {"total_loss": 0.0, "total_weights": 0.0} for eval_batch in eval_data_iterator: with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step( - state, eval_batch, nextrng - ) - cumulative_eval_metrics['total_loss'] += float(eval_metrics['scalar']['evaluation/total_loss']) - cumulative_eval_metrics['total_weights'] += float(eval_metrics['scalar']['evaluation/total_weights']) - eval_loss = cumulative_eval_metrics['total_loss'] / (cumulative_eval_metrics['total_weights'] + EPS) - max_logging.log(f"average loss after {step=}: {eval_loss=}, total_weights={cumulative_eval_metrics['total_weights']}") + eval_metrics = p_eval_step(state, eval_batch, nextrng) + cumulative_eval_metrics["total_loss"] += float( + eval_metrics["scalar"]["evaluation/total_loss"] + ) + cumulative_eval_metrics["total_weights"] += float( + eval_metrics["scalar"]["evaluation/total_weights"] + ) + eval_loss = cumulative_eval_metrics["total_loss"] / ( + cumulative_eval_metrics["total_weights"] + EPS + ) + max_logging.log( + f"average loss after {step=}: {eval_loss=}, total_weights={cumulative_eval_metrics['total_weights']}" + ) if eval_loss <= config.target_eval_loss: - max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}") + max_logging.log( + f"Early stop and exit loop after reaching {config.target_eval_loss=}" + ) max_utils.deactivate_profiler(config) break @@ -494,31 +644,47 @@ def train_loop(config, state=None): if checkpoint_manager is not None: checkpoint_manager.wait_until_finished() - write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics + write_metrics( + writer, + local_metrics_file, + running_gcs_metrics, + metrics, + config.steps - 1, + config, + ) # final step metrics max_utils.close_summary_writer(writer) record_goodput(recorder, config, job_end=True) return state + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) pyconfig.initialize(argv) config = pyconfig.config validate_train_config(config) os.environ["TFDS_DATA_DIR"] = config.dataset_path vertex_tensorboard_manager = VertexTensorboardManager() - if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): + if config.use_vertex_tensorboard or os.environ.get( + "UPLOAD_DATA_TO_TENSORBOARD" + ): vertex_tensorboard_manager.configure_vertex_tensorboard(config) debug_config = 
debug_configuration.DebugConfig(
- stack_trace_config = stack_trace_configuration.StackTraceConfig(
- collect_stack_trace = config.collect_stack_trace,
- stack_trace_to_cloud = config.stack_trace_to_cloud,
- stack_trace_interval_seconds = config.stack_trace_interval_seconds))
+ stack_trace_config=stack_trace_configuration.StackTraceConfig(
+ collect_stack_trace=config.collect_stack_trace,
+ stack_trace_to_cloud=config.stack_trace_to_cloud,
+ stack_trace_interval_seconds=config.stack_trace_interval_seconds,
+ )
+ )
diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
with diagnostic.diagnose(diagnostic_config):
train_loop(config)
+
if __name__ == "__main__":
app.run(main)
diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py
index 55384c13e..4d20ba9d1 100644
--- a/MaxText/train_compile.py
+++ b/MaxText/train_compile.py
@@ -1,18 +1,18 @@
"""
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
"""
Save a Cross Ahead of Time Compiled (XAOT) version of train.py's train step
@@ -45,15 +45,20 @@
def validate_config(config):
- """ Validates the config is is setup correctly to compile, returning a useful error message if not. """
- assert config.compile_topology != '',\
- "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256"
- assert config.compile_topology_num_slices > 0,\
- "You must set compile_topology_num_slices to a positive integer"
+ """Validates that the config is set up correctly to compile, returning a useful error message if not."""
+ assert (
+ config.compile_topology != ""
+ ), "You must pass your desired target hardware in compile_topology, e.g. compile_topology=v5e-256"
+ assert (
+ config.compile_topology_num_slices > 0
+ ), "You must set compile_topology_num_slices to a positive integer"
+
def get_topology_mesh(config):
- """ Get the target hardware devices, and create configured mesh with them """
- target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
+ """Get the target hardware devices and create a configured mesh with them."""
+ target_hardware = accelerator_to_spec_map.get_system_characteristics(
+ config.compile_topology
+ )
topology_devices = get_topology_desc(
platform=target_hardware.platform,
topology_name=target_hardware.topology_name,
@@ -65,8 +70,9 @@ def get_topology_mesh(config):
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes)
return topology_mesh
+
def get_shaped_inputs(topology_mesh, config):
- """ Get shaped abstractions of inputs to train_step: state, batch and rng """
+ """Get shaped abstractions of inputs to train_step: state, batch and rng"""
# Construct the model and optimizer to get shaped versions of the state
quant = quantizations.configure_quantization(config)
model = Transformer(config, topology_mesh, quant=quant)
@@ -79,7 +85,9 @@ def get_shaped_inputs(topology_mesh, config):
shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype)
# Shaped state
- abstract_state, state_mesh_annotations, _ = max_utils.get_abstract_state(model, tx, config, example_rng, topology_mesh)
+ abstract_state, state_mesh_annotations, _ = max_utils.get_abstract_state(
+ model, tx, config, example_rng, topology_mesh
+ )
# Shaped batch
shaped_batch = input_pipeline_interface.get_shaped_batch(config)
@@ -89,29 +97,40 @@
return shaped_train_args, shaped_train_kwargs, state_mesh_annotations, model
-def jit_and_compile(func, func_input_args, func_input_kwargs, mesh, in_shardings,
- out_shardings, static_argnums, donate_argnums, logical_axis_rules):
- """ Jit, lower, and compile func."""
+def jit_and_compile(
+ func,
+ func_input_args,
+ func_input_kwargs,
+ mesh,
+ in_shardings,
+ out_shardings,
+ static_argnums,
+ donate_argnums,
+ logical_axis_rules,
+):
+ """Jit, lower, and compile func."""
with mesh, logical_axis_rules:
jitted = jax.jit(
- func,
- in_shardings=in_shardings,
- out_shardings=out_shardings,
- static_argnums=static_argnums,
- donate_argnums=donate_argnums
+ func,
+ in_shardings=in_shardings,
+ out_shardings=out_shardings,
+ static_argnums=static_argnums,
+ donate_argnums=donate_argnums,
)
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
compiled = lowered.compile()
return compiled
+
def save_compiled(compiled, save_name):
- """ Serialize and save the compiled function.
""" + """Serialize and save the compiled function.""" serialized, _, _ = serialize(compiled) with open(save_name, "wb") as f: pickle.dump(serialized, f) + def main(argv: Sequence[str]) -> None: - jax.config.update('jax_default_prng_impl', 'unsafe_rbg') + jax.config.update("jax_default_prng_impl", "unsafe_rbg") print("Starting train_compile.py...", flush=True) # Parse and validate configuration @@ -123,37 +142,46 @@ def main(argv: Sequence[str]) -> None: topology_mesh = get_topology_mesh(config) # Get shaped inputs - shaped_train_args, shaped_train_kwargs, state_mesh_annotations, model = get_shaped_inputs(topology_mesh, config) + ( + shaped_train_args, + shaped_train_kwargs, + state_mesh_annotations, + model, + ) = get_shaped_inputs(topology_mesh, config) # Get function to compile and shardings - func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = maxtext_utils.get_functional_train_with_signature( - train.train_step, - topology_mesh, - state_mesh_annotations, - model, - config + ( + func_to_compile, + in_shard, + out_shard, + static_argnums, + donate_argnums, + ) = maxtext_utils.get_functional_train_with_signature( + train.train_step, topology_mesh, state_mesh_annotations, model, config ) # Compile print("Jitting and compiling train step...", flush=True) compiled = jit_and_compile( - func_to_compile, - shaped_train_args, - shaped_train_kwargs, - topology_mesh, - in_shard, - out_shard, - static_argnums, - donate_argnums, - nn_partitioning.axis_rules(config.logical_axis_rules) + func_to_compile, + shaped_train_args, + shaped_train_kwargs, + topology_mesh, + in_shard, + out_shard, + static_argnums, + donate_argnums, + nn_partitioning.axis_rules(config.logical_axis_rules), ) print("Jitting and compilation complete!", flush=True) # Serialize and save the compiled object - if config.compiled_trainstep_file != '': + if config.compiled_trainstep_file != "": print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) - print(f"Successfully saved compiled object as {config.compiled_trainstep_file}") + print( + f"Successfully saved compiled object as {config.compiled_trainstep_file}" + ) print("Finished train_compile.py successfully!", flush=True) print(f"Cost analysis: {compiled.cost_analysis()}") print(f"Memory analysis: {compiled.memory_analysis()}") diff --git a/MaxText/train_tokenizer.py b/MaxText/train_tokenizer.py index dab4d99be..c8ed7e19f 100644 --- a/MaxText/train_tokenizer.py +++ b/MaxText/train_tokenizer.py @@ -1,14 +1,14 @@ """ - Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Copyright 2023 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
"""
""" Train tokenizer
@@ -30,26 +30,25 @@
from sentencepiece import SentencePieceTrainer
_DATASET_PATH = flags.DEFINE_string(
- 'dataset_path', None, 'Path to the dataset', required=True
+ "dataset_path", None, "Path to the dataset", required=True
)
_DATASET_NAME = flags.DEFINE_string(
- 'dataset_name', None, 'Name to the dataset', required=True
+ "dataset_name", None, "Name of the dataset", required=True
)
-_VOCAB_SIZE = flags.DEFINE_integer('vocab_size', 32_768, 'Vocab size')
+_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
_MAX_CORPUS_CHARS = flags.DEFINE_integer(
- 'max_corpus_chars', 10_000_000, 'Max corpus chars'
+ "max_corpus_chars", 10_000_000, "Max corpus chars"
)
_ASSETS_PATH = flags.DEFINE_string(
- 'assets_path', 'assets', 'Name to the dataset'
+ "assets_path", "assets", "Path to the assets directory"
)
_VOCAB_MODEL_NAME = flags.DEFINE_string(
- 'vocab_model_name', 'tokenizer', 'Name to the dataset'
+ "vocab_model_name", "tokenizer", "Name of the vocab model"
)
+
def _dump_chars_to_textfile(
- dataset: tf.data.Dataset,
- maxchars: int = int(1e7),
- data_keys=('text',)
+ dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)
) -> Tuple[str, int]:
"""Write part of a TFDS sentence dataset to lines in a text file.
Args:
@@ -62,24 +61,28 @@ def _dump_chars_to_textfile(
char_count = 0
ds_iter = dataset.as_numpy_iterator()
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/ds_chars') as outfp:
+ delete=False, prefix="/tmp/ds_chars"
+ ) as outfp:
while char_count < maxchars:
example = next(ds_iter)
for k in data_keys:
- line = example[k] + b'\n'
+ line = example[k] + b"\n"
char_count += len(line)
outfp.write(line)
return outfp.name, char_count
-def _train_sentencepiece(dataset: tf.data.Dataset,
- *,
- vocab_size: int,
- maxchars: int = int(1e7),
- assets_path: str,
- model_path: str,
- model_type: str = 'unigram',
- character_coverage: float = 1.0,
- data_keys=('text',)):
+
+def _train_sentencepiece(
+ dataset: tf.data.Dataset,
+ *,
+ vocab_size: int,
+ maxchars: int = int(1e7),
+ assets_path: str,
+ model_path: str,
+ model_type: str = "unigram",
+ character_coverage: float = 1.0,
+ data_keys=("text",),
+):
"""Train SentencePiece tokenizer from subset of tf dataset.
Args:
dataset: tf.dataset
@@ -94,65 +97,75 @@ def _train_sentencepiece(dataset: tf.data.Dataset,
Returns:
path to the trained sentencepiece vocabulary model.
""" - if model_path.startswith('gs://'): + if model_path.startswith("gs://"): abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) abs_assets_path = os.path.abspath(os.path.expanduser(assets_path)) fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) + dataset, maxchars=maxchars, data_keys=data_keys + ) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + delete=False, prefix="/tmp/sp_tmp" + ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join([ - f'--input={fname}', f'--vocab_size={vocab_size}', - f'--character_coverage={character_coverage}', - f'--model_prefix={model_fp.name}', f'--model_type={model_type}' + argstr = " ".join([ + f"--input={fname}", + f"--vocab_size={vocab_size}", + f"--character_coverage={character_coverage}", + f"--model_prefix={model_fp.name}", + f"--model_type={model_type}", ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address # create and fill delays. - copy_rename_path = abs_model_path + '.rntmp' - if not model_path.startswith('gs://'): + copy_rename_path = abs_model_path + ".rntmp" + if not model_path.startswith("gs://"): tf.io.gfile.makedirs(abs_assets_path) - tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) + tf.io.gfile.copy(model_fp.name + ".model", copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) - logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) + logging.info("copied %s to %s", model_fp.name + ".model", abs_model_path) else: while not tf.io.gfile.exists(abs_model_path): time.sleep(1) time.sleep(1) return abs_model_path -def train_tokenizer(dataset: tf.data.Dataset, - *, - assets_path: str, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: Tuple[str] = ('text',)): + +def train_tokenizer( + dataset: tf.data.Dataset, + *, + assets_path: str, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: Tuple[str] = ("text",), +): """tokenizer training function""" - logging.info('SentencePiece vocab not found, building one from data.') + logging.info("SentencePiece vocab not found, building one from data.") vocab_path = _train_sentencepiece( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, assets_path=assets_path, model_path=vocab_path, - data_keys=data_keys) - logging.info('Model saved at %s', vocab_path) + data_keys=data_keys, + ) + logging.info("Model saved at %s", vocab_path) def main(argv): del argv - os.environ['TFDS_DATA_DIR'] = _DATASET_PATH.value + os.environ["TFDS_DATA_DIR"] = _DATASET_PATH.value read_config = tfds.ReadConfig( - shuffle_seed = 0, + shuffle_seed=0, ) train_ds_builder = tfds.builder(_DATASET_NAME.value) - train_ds = train_ds_builder.as_dataset(split='train', read_config=read_config, shuffle_files=True) + train_ds = train_ds_builder.as_dataset( + split="train", read_config=read_config, shuffle_files=True + ) train_tokenizer( train_ds, assets_path=_ASSETS_PATH.value, @@ -162,5 +175,5 @@ def main(argv): ) -if __name__ == '__main__': +if __name__ == "__main__": app.run(main) diff --git a/MaxText/vertex_tensorboard.py b/MaxText/vertex_tensorboard.py index 1cb438db5..18b225a51 100644 --- a/MaxText/vertex_tensorboard.py +++ b/MaxText/vertex_tensorboard.py @@ -1,18 +1,18 @@ """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache 
License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Utilities for Tensorboard in Vertex AI.""" @@ -40,7 +40,7 @@ def __del__(self): def setup(self): """Creates Tensorboard instance and Experiment in Vertex AI. - + Returns: URL to view Vertex Tensorboard created in Google Cloud Project. """ @@ -49,24 +49,30 @@ def setup(self): vertex_tensorboard_project = os.environ.get("TENSORBOARD_PROJECT") vertex_tensorboard_region = os.environ.get("TENSORBOARD_REGION") if not vertex_tensorboard_project or not vertex_tensorboard_region: - max_logging.log("Either config.vertex_tensorboard_project or config.vertex_tensorboard_region is not set.") + max_logging.log( + "Either config.vertex_tensorboard_project or config.vertex_tensorboard_region is not set." + ) return None # Create Vertex Tensorboard instance vertex_tensorboard_name = os.environ.get("TENSORBOARD_NAME") - instance_id = tensorboard.create_instance(project=vertex_tensorboard_project, - location=vertex_tensorboard_region, - tensorboard_name=vertex_tensorboard_name) + instance_id = tensorboard.create_instance( + project=vertex_tensorboard_project, + location=vertex_tensorboard_region, + tensorboard_name=vertex_tensorboard_name, + ) # Failed to create Vertex Tensorboard instance if instance_id is None: return None # Create Vertex Experiment vertex_experiment_name = os.environ.get("EXPERIMENT_NAME") - _, tensorboard_url = tensorboard.create_experiment(project=vertex_tensorboard_project, - location=vertex_tensorboard_region, - experiment_name=vertex_experiment_name, - tensorboard_name=vertex_tensorboard_name) + _, tensorboard_url = tensorboard.create_experiment( + project=vertex_tensorboard_project, + location=vertex_tensorboard_region, + experiment_name=vertex_experiment_name, + tensorboard_name=vertex_tensorboard_name, + ) return tensorboard_url def upload_data(self, tensorboard_dir): @@ -80,22 +86,33 @@ def upload_data(self, tensorboard_dir): tensorboard_name = os.environ.get("TENSORBOARD_NAME") experiment_name = os.environ.get("EXPERIMENT_NAME") - if not tensorboard_project or not tensorboard_region or not tensorboard_name or not experiment_name: - max_logging.log("Vertex Tensorboard configurations are not set. Data will not be uploaded to Vertex AI.") + if ( + not tensorboard_project + or not tensorboard_region + or not tensorboard_name + or not experiment_name + ): + max_logging.log( + "Vertex Tensorboard configurations are not set. Data will not be uploaded to Vertex AI." 
+ )
self.uploader_flag = False
- max_logging.log(f"Data will be uploaded to Vertex Tensorboard instance: {tensorboard_name} "
- f"and Experiment: {experiment_name} in {tensorboard_region}.")
- uploader.start_upload_to_tensorboard(project=tensorboard_project,
- location=tensorboard_region,
- experiment_name=experiment_name,
- tensorboard_name=tensorboard_name,
- logdir=tensorboard_dir)
+ max_logging.log(
+ f"Data will be uploaded to Vertex Tensorboard instance: {tensorboard_name} "
+ f"and Experiment: {experiment_name} in {tensorboard_region}."
+ )
+ uploader.start_upload_to_tensorboard(
+ project=tensorboard_project,
+ location=tensorboard_region,
+ experiment_name=experiment_name,
+ tensorboard_name=tensorboard_name,
+ logdir=tensorboard_dir,
+ )
self.uploader_flag = True

def configure_vertex_tensorboard(self, config):
"""Creates Vertex Tensorboard and starts a thread to upload data to Vertex Tensorboard."""
- if jax.process_index()==0:
+ if jax.process_index() == 0:
if not os.environ.get("TENSORBOARD_PROJECT"):
if not config.vertex_tensorboard_project:
os.environ["TENSORBOARD_PROJECT"] = max_utils.get_project()
@@ -107,16 +124,24 @@
if not os.environ.get("TENSORBOARD_NAME"):
vertex_tensorboard_project = os.environ.get("TENSORBOARD_PROJECT")
- os.environ["TENSORBOARD_NAME"] = f"{vertex_tensorboard_project}-tb-instance"
+ os.environ[
+ "TENSORBOARD_NAME"
+ ] = f"{vertex_tensorboard_project}-tb-instance"
if not os.environ.get("EXPERIMENT_NAME"):
os.environ["EXPERIMENT_NAME"] = config.run_name
- if config.use_vertex_tensorboard: # running MaxText on GCE
+ if config.use_vertex_tensorboard:  # running MaxText on GCE
tensorboard_url = self.setup()
if tensorboard_url is None:
- raise ValueError("Unable to create Tensorboard and Experiment in Vertex AI.")
- max_logging.log(f"View your Vertex AI Tensorboard at: {tensorboard_url}")
+ raise ValueError(
+ "Unable to create Tensorboard and Experiment in Vertex AI."
+ )
+ max_logging.log(
+ f"View your Vertex AI Tensorboard at: {tensorboard_url}"
+ )
self.upload_data(config.tensorboard_dir)
- elif os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): # running MaxText via XPK
+ elif os.environ.get(
+ "UPLOAD_DATA_TO_TENSORBOARD"
+ ):  # running MaxText via XPK
self.upload_data(config.tensorboard_dir)
diff --git a/code_style.sh b/code_style.sh
new file mode 100644
index 000000000..6dc251d2e
--- /dev/null
+++ b/code_style.sh
@@ -0,0 +1,32 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Clean up Python code using Pylint & Pyink
+# Googlers: please run `sudo apt install pipx; pipx install pylint --force; pipx install pyink==23.10.0` in advance
+
+set -e
+
+FOLDERS_TO_FORMAT=("MaxText" "pedagogical_examples")
+
+for folder in "${FOLDERS_TO_FORMAT[@]}"
+do
+  pyink "$folder" --pyink-indentation=2 --line-length=80
+done
+
+for folder in "${FOLDERS_TO_FORMAT[@]}"
+do
+  pylint "./$folder" --recursive=y --check-quote-consistency=y
+done
+
+echo "Successfully cleaned up all code."
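The new code_style.sh above is what produced the bulk of this diff. As a rough, hypothetical sketch of what `pyink "$folder" --pyink-indentation=2 --line-length=80` does to a long single-quoted call (the function and argument names below are illustrative only, not from this repository):

# Before formatting: one ~97-column line using single quotes
result = some_function(first_argument, second_argument, third_argument, keyword_argument='value')

# After pyink (approximately): double quotes, call wrapped to fit 80 columns,
# continuation indented at twice the 2-space block indent
result = some_function(
    first_argument, second_argument, third_argument, keyword_argument="value"
)

This single transformation accounts for most hunks in the diff: quote normalization, 2-space indentation, and exploded call sites with trailing commas.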
diff --git a/pedagogical_examples/non_spmd.py b/pedagogical_examples/non_spmd.py
index 743bd4138..9918cbec3 100644
--- a/pedagogical_examples/non_spmd.py
+++ b/pedagogical_examples/non_spmd.py
@@ -1,22 +1,22 @@
#!/usr/bin/python3
"""
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
-'''
+"""
This program demonstrates embarrassingly parallelizable non-SPMD computations in Jax, in this case by having each
process_index run its own computation. The same approach can be extended for non-embarrassingly parallelizable computations.
@@ -24,7 +24,7 @@
then using a `host_local_array_to_global_array` to reshard into a new global array. An important limitation of
this approach is that we cannot overlap communication and computation between the different kernel calls.
-'''
+"""
import jax
@@ -34,23 +34,21 @@
import numpy as np
-
-
# Notice this is jax.local_devices(), not jax.devices(). Hence each process (on TPUVMs, each VM) will run separate programs
# on its mesh.
mesh = Mesh(np.array(jax.local_devices()), ["data"])
sharding = jax.sharding.NamedSharding(mesh, PartitionSpec(None))
idx = jax.process_index()
+
# Example step depends on idx which is different on each program
def example_step():
- return idx * jax.numpy.ones((idx+1))
+ return idx * jax.numpy.ones((idx + 1))
+
jit_func = jax.jit(
- example_step,
- out_shardings=sharding,
- )
+ example_step,
+ out_shardings=sharding,
+)
print(f"{idx=} -> {jit_func()=}")
-
-
diff --git a/pedagogical_examples/shardings.py b/pedagogical_examples/shardings.py
index 85198c3ee..94d104ddf 100644
--- a/pedagogical_examples/shardings.py
+++ b/pedagogical_examples/shardings.py
@@ -1,22 +1,22 @@
#!/usr/bin/python3
"""
- Copyright 2023 Google LLC
+Copyright 2023 Google LLC
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" -'''This script is used to measure the performance of different sharding schemes on TPU.''' +"""This script is used to measure the performance of different sharding schemes on TPU.""" from absl import app from absl import flags @@ -32,79 +32,76 @@ from typing import Sequence parser = argparse.ArgumentParser( - description="Experiment different sharding techniques with a simple NN.\ + description="Experiment different sharding techniques with a simple NN.\ Ensure 1) The product of dcn dimensions == number of slices \ 2) product of ici dimension = number of devices per slice" - ) +) parser.add_argument( - "--profiler_path", "-p", + "--profiler_path", + "-p", required=False, default="", help="Path to the profiler where the script will write to.", - type=str + type=str, ) parser.add_argument( - "--embedding_dimension", "-d", - required=False, - default=2048, - type=int + "--embedding_dimension", "-d", required=False, default=2048, type=int ) parser.add_argument( - "--batch_size", "-b", - required=False, - default=131072, - type=int + "--batch_size", "-b", required=False, default=131072, type=int ) +parser.add_argument("--num_layers", "-n", required=False, default=4, type=int) parser.add_argument( - "--num_layers", "-n", - required=False, - default=4, - type=int -) -parser.add_argument( - "--dcn_data_parallelism", "-dd", + "--dcn_data_parallelism", + "-dd", help="N-way Data Parallelism across slices", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--dcn_fsdp_parallelism", "-df", + "--dcn_fsdp_parallelism", + "-df", help="Fsdp parallelism across slices that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--dcn_tensor_parallelism", "-dt", + "--dcn_tensor_parallelism", + "-dt", help="Tensor parallelism across slices that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--ici_data_parallelism", "-id", + "--ici_data_parallelism", + "-id", help="Data parallelism within each slice that is expected to be 1 in most cases", required=False, default=1, - type=int + type=int, ) parser.add_argument( - "--ici_fsdp_parallelism", "-if", + "--ici_fsdp_parallelism", + "-if", help="Number of shards for Fsdp Parallelism within each slice.", required=False, default=4, - type=int + type=int, ) parser.add_argument( - "--ici_tensor_parallelism", "-it", + "--ici_tensor_parallelism", + "-it", help="Number of shards for Tensor Parallelism within each slice.", required=False, default=1, - type=int + type=int, ) args = parser.parse_args() + def main(_argv: Sequence[str]) -> None: def activate_profiler(profiler_path): if profiler_path: @@ -115,22 +112,30 @@ def deactivate_profiler(profiler_path): if profiler_path: jax.profiler.stop_trace() - def simple_timeit(f, tries = 5, verbose = True): - '''Simple utility to time a function for multiple runs''' + def simple_timeit(f, tries=5, verbose=True): + """Simple utility to time a function for multiple runs""" outcomes = [] - f() #warm it up! + f() # warm it up! 
for _ in range(tries): s = datetime.datetime.now() f() e = datetime.datetime.now() - outcomes.append((e-s).total_seconds()) - average_time = sum(outcomes)/len(outcomes) + outcomes.append((e - s).total_seconds()) + average_time = sum(outcomes) / len(outcomes) if verbose: print(f"average time: {average_time}, timings (seconds) {outcomes}") return average_time - dcn_parallelism = [args.dcn_data_parallelism, args.dcn_fsdp_parallelism, args.dcn_tensor_parallelism] - ici_parallelism = [args.ici_data_parallelism, args.ici_fsdp_parallelism, args.ici_tensor_parallelism] + dcn_parallelism = [ + args.dcn_data_parallelism, + args.dcn_fsdp_parallelism, + args.dcn_tensor_parallelism, + ] + ici_parallelism = [ + args.ici_data_parallelism, + args.ici_fsdp_parallelism, + args.ici_tensor_parallelism, + ] devices = jax.devices() num_devices = len(devices) @@ -138,17 +143,22 @@ def simple_timeit(f, tries = 5, verbose = True): assert len(devices) > 1, "You must have at least two devices" # Assert that we have correct inputs of sharding that fit the number of chips - assert np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices, f"Number of devices {num_devices} \ + assert ( + np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices + ), f"Number of devices {num_devices} \ does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}" - multi_slice_env = hasattr(jax.devices()[0], 'slice_index') + multi_slice_env = hasattr(jax.devices()[0], "slice_index") # Create device mesh if multi_slice_env: - assert args.dcn_data_parallelism == 1 + max(x.slice_index for x in jax.devices()), \ - f"Number of slices given {args.dcn_data_parallelism} \ + assert args.dcn_data_parallelism == 1 + max( + x.slice_index for x in jax.devices() + ), f"Number of slices given {args.dcn_data_parallelism} \ does not match the number fetched from jax devices {jax.devices()[0]}" - devices_array = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism) + devices_array = mesh_utils.create_hybrid_device_mesh( + ici_parallelism, dcn_parallelism + ) else: devices_array = mesh_utils.create_device_mesh(ici_parallelism) @@ -156,27 +166,31 @@ def simple_timeit(f, tries = 5, verbose = True): mesh = Mesh(devices_array, ["data", "fsdp", "tensor"]) - data_sharding = PartitionSpec(("data", "fsdp"), "tensor") + data_sharding = PartitionSpec(("data", "fsdp"), "tensor") # We assume parameters are stored in a decreasing order of dimension size parameter_sharding = PartitionSpec("tensor", "fsdp") BATCH = len(jax.devices()) * args.batch_size D_EMB = args.embedding_dimension - D_FF = 4 * D_EMB + D_FF = 4 * D_EMB NUM_LAYERS = args.num_layers parameters = 2 * D_FF * D_EMB * NUM_LAYERS parameter_bytes = 2 * parameters - activation_bytes = 2 * ( BATCH * ( D_FF+D_EMB) ) * NUM_LAYERS + activation_bytes = 2 * (BATCH * (D_FF + D_EMB)) * NUM_LAYERS memory_bytes = parameter_bytes + activation_bytes - print(f"total {memory_bytes/1e9} GB, parameters {parameter_bytes/1e9} GB, activations {activation_bytes/1e9} GB") + print( + f"total {memory_bytes/1e9} GB, parameters {parameter_bytes/1e9} GB, activations {activation_bytes/1e9} GB" + ) def gen_layer(random_key): - keys = jax.random.split(random_key, num = 4) + keys = jax.random.split(random_key, num=4) return { - "EMB2FF" : 1e-4 * jax.random.normal( keys[0], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), - "FF2EMB" : 1e-4 * jax.random.normal( keys[1], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), + "EMB2FF": 1e-4 + * jax.random.normal(keys[0], 
(D_FF, D_EMB), dtype=jax.numpy.bfloat16), + "FF2EMB": 1e-4 + * jax.random.normal(keys[1], (D_FF, D_EMB), dtype=jax.numpy.bfloat16), } def gen_layers(random_key): @@ -187,8 +201,9 @@ def gen_layers(random_key): return tuple(layers) def gen_data(random_key): - return jax.random.uniform(random_key, (BATCH, D_EMB), dtype=jax.numpy.bfloat16 ) - + return jax.random.uniform( + random_key, (BATCH, D_EMB), dtype=jax.numpy.bfloat16 + ) def multiply_layer(in_act, in_layer): with jax.named_scope("M1"): @@ -210,49 +225,53 @@ def multiply_layers(in_act, in_layers): return x, in_layers def multiply_layers_with_loss(in_act, in_layers): - x, _ = multiply_layers(in_act, in_layers) + x, _ = multiply_layers(in_act, in_layers) return jax.numpy.sum(x) - multiply_layers_and_grad = jax.value_and_grad(multiply_layers_with_loss, argnums=[1]) + multiply_layers_and_grad = jax.value_and_grad( + multiply_layers_with_loss, argnums=[1] + ) def training_step(in_act, in_layers): _, grad_layers = multiply_layers_and_grad(in_act, in_layers) - out_layers = jax.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0]) + out_layers = jax.tree_map( + lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0] + ) return out_layers - print("finished includes ", flush = True) + print("finished includes ", flush=True) replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding) parameter_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding + ) data_pspec_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding + ) jit_func = jax.jit( - training_step, - in_shardings=(replicated_sharding, parameter_mesh_shardings), - out_shardings=data_pspec_shardings, - ) + training_step, + in_shardings=(replicated_sharding, parameter_mesh_shardings), + out_shardings=data_pspec_shardings, + ) data_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding) + lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding + ) jit_gen_data = jax.jit( - gen_data, - in_shardings=None, - out_shardings=data_mesh_shardings - ) + gen_data, in_shardings=None, out_shardings=data_mesh_shardings + ) parameter_mesh_shardings = jax.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding) + lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding + ) jit_gen_layers = jax.jit( - gen_layers, - in_shardings=None, - out_shardings=parameter_mesh_shardings - ) + gen_layers, in_shardings=None, out_shardings=parameter_mesh_shardings + ) # starting the profiler outside `with` statement, # will call it right before the computation once b/301309635 is resolved @@ -261,14 +280,21 @@ def training_step(in_act, in_layers): key = jax.random.PRNGKey(0) presharded_X = jax.block_until_ready(jit_gen_data(key)) presharded_layers = jax.block_until_ready(jit_gen_layers(key)) - TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices()) - time = simple_timeit(lambda : jax.block_until_ready(jit_func(presharded_X, presharded_layers))) - print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush = True) + TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices()) + time = simple_timeit( + lambda: jax.block_until_ready(jit_func(presharded_X, presharded_layers)) + ) + print( + f"time is {time} seconds, 
TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", + flush=True, + ) deactivate_profiler(args.profiler_path) + def parse_flags(argv): return parser.parse_args(argv[1:]) + if __name__ == "__main__": flags.FLAGS.mark_as_parsed() app.run(main, flags_parser=parse_flags) diff --git a/pedagogical_examples/shmap_collective_matmul.py b/pedagogical_examples/shmap_collective_matmul.py index fe8c38be9..c872a5cee 100644 --- a/pedagogical_examples/shmap_collective_matmul.py +++ b/pedagogical_examples/shmap_collective_matmul.py @@ -1,24 +1,25 @@ #!/usr/bin/python3 """ - Copyright 2023 Google LLC +Copyright 2023 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" -'''This script is an example collective matmul.''' +"""This script is an example collective matmul.""" import os + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" import numpy as np @@ -50,50 +51,67 @@ import string import datetime -def simple_timeit(f, *args, tries = 10, trace_base_dir = None, task = None): - '''Simple utility to time a function for multiple runs''' - assert task is not None - trace_name = f"t_{task}_" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) +def simple_timeit(f, *args, tries=10, trace_base_dir=None, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(10) + ) + + if trace_base_dir: + trace_dir = f"{trace_base_dir}/{trace_name}" + else: + trace_dir = None - if trace_base_dir: - trace_dir = f"{trace_base_dir}/{trace_name}" - else: - trace_dir = None + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! + if trace_dir: + jax.profiler.start_trace(trace_dir) - outcomes_ms = [] - jax.block_until_ready(f(*args)) #warm it up! 
- if trace_dir: - jax.profiler.start_trace(trace_dir) + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + if trace_dir: + jax.profiler.stop_trace() - for _ in range(tries): - s = datetime.datetime.now() - jax.block_until_ready(f(*args)) - e = datetime.datetime.now() - outcomes_ms.append(1000*(e-s).total_seconds()) - if trace_dir: - jax.profiler.stop_trace() + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}") + return average_time_ms - average_time_ms = sum(outcomes_ms)/len(outcomes_ms) - print(f"{task}: average time milliseconds: {average_time_ms:.2f}") - return average_time_ms # gen data def gen_data_fn(): - key = jax.random.PRNGKey(np.random.randint(0, 256)) - activations = jax.random.normal(key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16) - weights = jax.random.normal(key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16) - return activations, weights + key = jax.random.PRNGKey(np.random.randint(0, 256)) + activations = jax.random.normal( + key, shape=(batch_size, seq_len, emb_dim), dtype=jnp.bfloat16 + ) + weights = jax.random.normal( + key, shape=(emb_dim, n_heads, head_dim), dtype=jnp.bfloat16 + ) + return activations, weights + data_fn = pjit( gen_data_fn, - out_shardings=(P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None), P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None)), + out_shardings=( + P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None), + P(MESH_FSDP_AXIS, MESH_TENSOR_AXIS, None), + ), ) + def matmul(activations, weights): - return jnp.einsum("bsE,Ehd->bshd", activations, weights) + return jnp.einsum("bsE,Ehd->bshd", activations, weights) + + +jit_matmul = pjit( + matmul, out_shardings=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None) +) -jit_matmul = pjit(matmul, out_shardings=P(MESH_FSDP_AXIS, None, MESH_TENSOR_AXIS, None)) @partial( shard_map, @@ -106,29 +124,46 @@ def matmul(activations, weights): check_rep=False, ) def collective_matmul(activations, weights): - print(f"sh_map {activations.shape=} {weights.shape=}") - - axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS) - axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS) - # The current sequence chunk - chunk_size = activations.shape[1] - mid_chunk = chunk_size // 2 - # create accum buffer - accum = jnp.zeros( - ( - activations.shape[0], - activations.shape[1] * axis_size, - weights.shape[-2], - weights.shape[-1], - ), - dtype=activations.dtype, - ) + print(f"sh_map {activations.shape=} {weights.shape=}") + + axis_size = jax.lax.psum(1, axis_name=MESH_TENSOR_AXIS) + axis_index = jax.lax.axis_index(axis_name=MESH_TENSOR_AXIS) + # The current sequence chunk + chunk_size = activations.shape[1] + mid_chunk = chunk_size // 2 + # create accum buffer + accum = jnp.zeros( + ( + activations.shape[0], + activations.shape[1] * axis_size, + weights.shape[-2], + weights.shape[-1], + ), + dtype=activations.dtype, + ) + + # compute first chunk + update = jnp.einsum("bsE,Ehd->bshd", activations, weights) + update_index = (0, axis_index * chunk_size, 0, 0) + accum = jax.lax.dynamic_update_slice(accum, update, update_index) + activation_forward, activation_backward = jnp.split(activations, 2, axis=1) + activation_forward = jax.lax.ppermute( + activation_forward, + axis_name=MESH_TENSOR_AXIS, + perm=[(j, (j + 1) % axis_size) for j in range(axis_size)], + ) + activation_backward = jax.lax.ppermute( + activation_backward, + 
axis_name=MESH_TENSOR_AXIS, + perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], + ) + + # split activations into chunks and send + def scanned_call(i, carrys): + accum, activation_forward, activation_backward = carrys + update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights) + update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights) - # compute first chunk - update = jnp.einsum("bsE,Ehd->bshd", activations, weights) - update_index = (0, axis_index * chunk_size, 0, 0) - accum = jax.lax.dynamic_update_slice(accum, update, update_index) - activation_forward, activation_backward = jnp.split(activations, 2, axis=1) activation_forward = jax.lax.ppermute( activation_forward, axis_name=MESH_TENSOR_AXIS, @@ -140,70 +175,65 @@ def collective_matmul(activations, weights): perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], ) - # split activations into chunks and send - def scanned_call(i, carrys): - accum, activation_forward, activation_backward = carrys - update_forward = jnp.einsum("bsE,Ehd->bshd", activation_forward, weights) - update_backward = jnp.einsum("bsE,Ehd->bshd", activation_backward, weights) - - activation_forward = jax.lax.ppermute( - activation_forward, - axis_name=MESH_TENSOR_AXIS, - perm=[(j, (j + 1) % axis_size) for j in range(axis_size)], - ) - activation_backward = jax.lax.ppermute( - activation_backward, - axis_name=MESH_TENSOR_AXIS, - perm=[(j, (j - 1) % axis_size) for j in range(axis_size)], - ) - - forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size - backward_update_index = ((axis_index + i + 1) % axis_size) * chunk_size + mid_chunk - - accum = jax.lax.dynamic_update_slice(accum, update_forward, (0, forward_update_index, 0, 0)) - accum = jax.lax.dynamic_update_slice(accum, update_backward, (0, backward_update_index, 0, 0)) - return (accum, activation_forward, activation_backward) - - print(f"{accum.shape=}") - - accum, _, _ = jax.lax.fori_loop( - 0, (axis_size - 1), scanned_call, (accum, activation_forward, activation_backward) - ) - return accum - -with global_mesh: - activations, weights = data_fn() - - jax.block_until_ready(activations) - jax.block_until_ready(weights) - - @jax.jit - def run_naive(_activations, _weights): - with jax.named_scope("naive_matmul"): - outputs = jit_matmul(_activations, _weights) - return outputs - - @jax.jit - def run_collective(_activations, _weights): - with jax.named_scope("collective_matmul"): - manual_outputs = jax.jit(collective_matmul)(_activations, _weights) - return manual_outputs - - - - - naive_outputs = run_naive(activations, weights) - collective_outputs = run_collective(activations, weights) - - print(f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}") - print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}") - print(f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}") - print(f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}") + forward_update_index = ((axis_index - i - 1) % axis_size) * chunk_size + backward_update_index = ( + (axis_index + i + 1) % axis_size + ) * chunk_size + mid_chunk + accum = jax.lax.dynamic_update_slice( + accum, update_forward, (0, forward_update_index, 0, 0) + ) + accum = jax.lax.dynamic_update_slice( + accum, update_backward, (0, backward_update_index, 0, 0) + ) + return (accum, activation_forward, activation_backward) + print(f"{accum.shape=}") - assert jnp.allclose(naive_outputs, 
collective_outputs), "Two algorithms should match but don't" + accum, _, _ = jax.lax.fori_loop( + 0, + (axis_size - 1), + scanned_call, + (accum, activation_forward, activation_backward), + ) + return accum - simple_timeit(run_naive, activations, weights, task = "naive") - simple_timeit(run_collective, activations, weights, task = "collective") +with global_mesh: + activations, weights = data_fn() + + jax.block_until_ready(activations) + jax.block_until_ready(weights) + + @jax.jit + def run_naive(_activations, _weights): + with jax.named_scope("naive_matmul"): + outputs = jit_matmul(_activations, _weights) + return outputs + + @jax.jit + def run_collective(_activations, _weights): + with jax.named_scope("collective_matmul"): + manual_outputs = jax.jit(collective_matmul)(_activations, _weights) + return manual_outputs + + naive_outputs = run_naive(activations, weights) + collective_outputs = run_collective(activations, weights) + + print( + f"input {activations.shape=} {activations.addressable_shards[0].data.shape=}" + ) + print(f"input {weights.shape=} {weights.addressable_shards[0].data.shape=}") + print( + f"naive_outputs {naive_outputs.shape=} {naive_outputs.addressable_shards[0].data.shape=}" + ) + print( + f"collective_outputs {collective_outputs.shape=} {collective_outputs.addressable_shards[0].data.shape=}" + ) + + assert jnp.allclose( + naive_outputs, collective_outputs + ), "Two algorithms should match but don't" + + simple_timeit(run_naive, activations, weights, task="naive") + simple_timeit(run_collective, activations, weights, task="collective") diff --git a/pylintrc b/pylintrc index 4897238f0..59412c3a1 100644 --- a/pylintrc +++ b/pylintrc @@ -29,6 +29,11 @@ jobs=4 # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no +# check python modules in the dir recursively +recursive=y + +# score threshold under which it fails +fail-under=9 [MESSAGES CONTROL] @@ -249,7 +254,10 @@ generated-members= [FORMAT] # Maximum number of characters on a single line. -max-line-length=125 +max-line-length=80 + +# Check whether character used as a quote delimiter is used inconsistently +check-quote-consistency=y # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt # lines made too long by directives to pytype.
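To illustrate the tightened pylintrc with a hypothetical snippet (not part of this change): check-quote-consistency=y makes pylint report W1405 (inconsistent-quotes) when a quote delimiter diverges from the file's dominant style, max-line-length=80 replaces the previous 125-column limit (violations surface as C0301, line-too-long), and fail-under=9 turns any run scoring below 9.0 into a failing exit status, which the new CI jobs rely on.

# Checked against the new pylintrc: this module predominantly uses double
# quotes, so pylint flags the single-quoted string below.
GREETING = "hello"
FAREWELL = 'goodbye'  # W1405 inconsistent-quotes; prefer "goodbye"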