diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml new file mode 100644 index 000000000..3a660dec9 --- /dev/null +++ b/.github/workflows/build_packages.yml @@ -0,0 +1,158 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: Build packages + +on: + workflow_dispatch: + schedule: + # Runs at 11:00 AM UTC, which is 3:00 AM PST (UTC-8) + - cron: '0 11 * * *' + +permissions: + contents: read + +jobs: + # Note: metadata generation could happen in a separate trigger/schedule + # workflow. For cross platform builds, it's useful to just generate the + # metadata on Linux and pass that to later jobs using artifacts. + setup_metadata: + if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} + runs-on: ubuntu-24.04 + outputs: + shark_package_version: ${{ steps.version.outputs.shark_package_version }} + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + submodules: false + - name: Setup Python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: 3.12 + cache: "pip" + - name: Install Python packages + run: | + pip install packaging + pip freeze + - name: Generate release candidate versions + id: version_rc + run: | + version_suffix="$(printf 'rc%(%Y%m%d)T')" + echo "version_suffix=${version_suffix}" >> $GITHUB_ENV + sharktank_package_version=$(python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} sharktank) + shortfin_package_version=$(python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} shortfin) + sharkai_package_version=$(python3 build_tools/python_deploy/compute_common_version.py -rc --version-suffix=${version_suffix} --write-json) + - name: Upload version_local.json + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: version_local + path: | + sharktank/version_local.json + shortfin/version_local.json + shark-ai/version_local.json + + build_packages: + name: "${{ matrix.package }} :: ${{ matrix.platform }} :: ${{ matrix.python-version }}" + runs-on: ${{ matrix.runs-on }} + permissions: + contents: write + needs: [setup_metadata] + strategy: + fail-fast: false + matrix: + include: + # Ubuntu packages. 
+ - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shark-ai + python-version: cp311-cp311 # Ignored (generic wheel), set for workflow naming + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: sharktank + python-version: cp311-cp311 # Ignored (generic wheel), set for workflow naming + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shortfin + python-version: cp310-cp310 + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shortfin + python-version: cp311-cp311 + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shortfin + python-version: cp312-cp312 + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shortfin + python-version: cp313-cp313 + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shortfin + python-version: cp313-cp313t + + # TODO(#130): macOS platform + # TODO(#130): Windows platform + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: "c" # Windows can hit path length limits, so use a short path. + submodules: false + + - name: Download version_local.json + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + with: + name: version_local + path: ./c/ + merge-multiple: true + + - name: Build shark-ai (Linux x86_64) + if: "matrix.package == 'shark-ai' && matrix.platform == 'linux-x86_64'" + env: + OUTPUT_DIR: "${{ github.workspace }}/bindist" + run: | + [ -e ./bindist/* ] && rm ./bindist/* + ./c/build_tools/python_deploy/write_requirements.py --version-suffix=${version_suffix} + ./c/shark-ai/build_tools/build_linux_package.sh + + - name: Build sharktank (Linux x86_64) + if: "matrix.package == 'sharktank' && matrix.platform == 'linux-x86_64'" + env: + OUTPUT_DIR: "${{ github.workspace }}/bindist" + run: | + [ -e ./bindist/* ] && rm ./bindist/* + ./c/sharktank/build_tools/build_linux_package.sh + + - name: Build shortfin (Linux x86_64, ${{ matrix.python-version }}) + if: "matrix.package == 'shortfin' && matrix.platform == 'linux-x86_64'" + env: + OUTPUT_DIR: "${{ github.workspace }}/bindist" + OVERRIDE_PYTHON_VERSIONS: "${{ matrix.python-version }}" + run: | + [ -e ./bindist/* ] && rm ./bindist/* + ./c/shortfin/build_tools/build_linux_package.sh + + - name: Upload python wheels + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + if-no-files-found: error + name: snapshot-${{ matrix.package }}-${{ matrix.platform }}-${{ matrix.python-version }} + path: bindist + + - name: Release python wheels + uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 + with: + artifacts: bindist/*.whl + tag: "dev-wheels" + name: "dev-wheels" + body: "Automatic snapshot release of shark-ai python wheels." + removeArtifacts: false + allowUpdates: true + replacesArtifacts: true + makeLatest: false diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml new file mode 100644 index 000000000..644066094 --- /dev/null +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -0,0 +1,93 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: Llama Benchmarking Tests + +on: + workflow_dispatch: + schedule: + # Weekdays at 4:00 AM UTC = 9:00 PM PST. 
+ - cron: "0 4 * * 1-5" + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_llama_large: + if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} + name: "Llama Benchmarking Tests" + strategy: + matrix: + version: [3.11] + fail-fast: false + runs-on: llama-mi300x-1 + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv + steps: + - name: Get Current Date + id: date + run: echo "::set-output name=date::$(date +'%Y-%m-%d')" + + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-turbine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + + # Test with nightly releases, not what iree-turbine uses. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run llama tests + run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/llm/llama/benchmark/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/benchmark + destination_dir: ./llm/llama/benchmark + keep_files: true + + - name: Upload llama executable files + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: llama-files + path: ${{ github.workspace }}/${{ steps.date.outputs.date }} diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml new file mode 100644 index 000000000..a8c315ec8 --- /dev/null +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -0,0 +1,85 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: Llama Benchmarking 8B Tests + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_llama_quick: + name: "Llama Benchmarking 8B Tests" + strategy: + matrix: + version: [3.11] + fail-fast: false + runs-on: llama-mi300x-1 + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + VENV_DIR: ${{ github.workspace }}/.venv + steps: + - name: Get Current Date + id: date + run: echo "::set-output name=date::$(date +'%Y-%m-%d')" + + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-turbine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + + # Test with nightly releases, not what iree-turbine uses. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run llama 8b f16 decomposed test + run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-quick-llama-test + + - name: Upload llama executable files + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + with: + name: llama-files + path: ${{ github.workspace }}/${{ steps.date.outputs.date }} diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml new file mode 100644 index 000000000..355ffcf8b --- /dev/null +++ b/.github/workflows/ci-sdxl.yaml @@ -0,0 +1,110 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - shortfin - SDXL + +on: + workflow_dispatch: + pull_request: + paths: + - '.github/workflows/ci-sdxl.yaml' + - 'shortfin/**' + push: + branches: + - main + paths: + - '.github/workflows/ci-sdxl.yaml' + - 'shortfin/**' + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). 
The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +env: + IREE_REPO_DIR: ${{ github.workspace }}/iree + LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ + +jobs: + build-and-test: + name: Build and test + runs-on: mi300-sdxl-kernel + + steps: + - name: Install dependencies + run: | + if dpkg -s cmake &>/dev/null; then + echo 'cmake is installed' + else + sudo apt install cmake -y + fi + if dpkg -s ninja-build &>/dev/null; then + echo 'ninja is installed' + else + sudo apt install ninja-build -y + fi + + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + submodules: false + + - name: Checkout IREE repo + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: iree-org/iree + path: ${{ env.IREE_REPO_DIR }} + submodules: false + ref: iree-3.0.0rc20241118 + + - name: Initialize IREE submodules + working-directory: ${{ env.IREE_REPO_DIR }} + run : | + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ + + - name: Setup Python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: "3.12" + cache: "pip" + - name: Install Python packages + # TODO: Switch to `pip install -r requirements.txt -e shortfin/`. + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + pip install -r requirements-tests.txt + pip install -r requirements-iree-compiler.txt + pip freeze + + - name: Build shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + mkdir build + cmake -GNinja \ + -S. \ + -Bbuild \ + -DCMAKE_C_COMPILER=clang-18 \ + -DCMAKE_CXX_COMPILER=clang++-18 \ + -DSHORTFIN_BUNDLE_DEPS=ON \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + cmake --build build --target all + pip install -v -e build/ + + - name: Test shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + ctest --timeout 30 --output-on-failure --test-dir build + HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu diff --git a/.github/workflows/test.yaml b/.github/workflows/ci-sglang-benchmark.yml similarity index 54% rename from .github/workflows/test.yaml rename to .github/workflows/ci-sglang-benchmark.yml index 6eb519717..504e7e5e3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -1,10 +1,16 @@ -name: Integration Tests +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: SGLang Llama Benchmarking Tests on: workflow_dispatch: schedule: - # Weekdays at 13:00 UTC = 05:00 PST / 06:00 PDT. - - cron: "5 4 * * 1-5" + # Weekdays at 4:00 AM UTC = 9:00 PM PST. + - cron: "0 4 * * 1-5" concurrency: # A PR number if a pull request and otherwise the commit hash.
This cancels @@ -15,31 +21,35 @@ concurrency: cancel-in-progress: true jobs: - test_llama: - name: "Integration Tests - llama" + sglang_bench_serve: + if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} + name: "SGLang Serving Benchmark Tests" strategy: matrix: version: [3.11] - os: [ubuntu-latest, windows-latest] fail-fast: false - runs-on: ${{matrix.os}} + runs-on: llama-mi300x-3 defaults: run: shell: bash env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" steps: + - name: Get Current Date + id: date + run: echo "::set-output name=date::$(date +'%Y-%m-%d')" + - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} @@ -53,42 +63,27 @@ jobs: # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ # Try with the latest nightly releases, not what iree-turbine pins. # We could also pin to a known working or stable version. # This should eventually stabilize. Do the best we can for now. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-compiler \ - iree-runtime \ + iree-base-compiler==3.0.0rc20241118 \ + iree-base-runtime==3.0.0rc20241118 \ "numpy<2.0" - - name: Run llama test - run: ./build_tools/integration_tests/llama_export_compile_serve.sh - - test_punet: - name: "Integration Tests - punet" - runs-on: nodai-amdgpu-mi250-x86-64 - env: - VENV_DIR: ${{ github.workspace }}/.venv - steps: - - name: "Checkout Code" - uses: actions/checkout@v3 - - - name: "Setup Python venv" - run: python3 -m venv ${VENV_DIR} + - name: Install SGLang + run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" - - name: Install pip deps - run: | - source ${VENV_DIR}/bin/activate - python -m pip install --no-compile --upgrade pip - pip install --no-compile -r pytorch-rocm-requirements.txt - pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" - pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ + - name: Launch Shortfin Server + run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html - - name: Run punet tests - run: | - source ${VENV_DIR}/bin/activate - pytest -v sharktank/ -m model_punet + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/sglang + destination_dir: ./llm/sglang + keep_files: true diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml new file mode 100644 index 000000000..1c382617d --- /dev/null +++ 
b/.github/workflows/ci-sglang-integration-tests.yml @@ -0,0 +1,84 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: SGLang Llama Integration Tests + +on: + workflow_dispatch: + schedule: + # Run periodically, every 4 hours. This is ran periodically with the + # intent of catching regressions early, and allowing for those + # regressions to be easily triaged to a small subset of commits. + - cron: '0 */4 * * *' + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + sglang_bench_serve: + name: "SGLang Integration Tests" + strategy: + matrix: + version: [3.11] + fail-fast: false + runs-on: llama-mi300x-3 + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: Get Current Date + id: date + run: echo "::set-output name=date::$(date +'%Y-%m-%d')" + + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ + + # Use newest possible releases to be able to track commits that may + # cause errors. + pip install -f https://iree.dev/pip-release-links.html --upgrade \ + iree-base-compiler \ + iree-base-runtime \ + "numpy<2.0" + + - name: Install SGLang + run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" + + - name: Install sentence_transformers + run: pip install sentence_transformers + + - name: Run Integration Tests + run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci-shark-ai.yml similarity index 58% rename from .github/workflows/ci.yaml rename to .github/workflows/ci-shark-ai.yml index b2d302a8e..bf8007e65 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci-shark-ai.yml @@ -1,4 +1,10 @@ -name: CI +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - shark-ai on: workflow_dispatch: @@ -16,14 +22,13 @@ concurrency: cancel-in-progress: true jobs: - test: - name: "Unit Tests and Type Checking" + test_shortfin_llm_server: + name: "Integration Tests - Shortfin LLM Server" strategy: matrix: version: [3.11] - os: [ubuntu-latest, windows-latest] fail-fast: false - runs-on: ${{matrix.os}} + runs-on: nodai-amdgpu-mi250-x86-64 defaults: run: shell: bash @@ -32,15 +37,15 @@ jobs: steps: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} @@ -53,22 +58,18 @@ jobs: # from non default locations first. Installing the PyTorch CPU # wheels saves multiple minutes and a lot of bandwidth on runner setup. pip install --no-compile -r pytorch-cpu-requirements.txt - pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ - - name: Run sharktank tests - if: ${{ !cancelled() }} - run: | - pytest -n 4 sharktank/ + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - - name: Run shortfin tests - if: ${{ !cancelled() }} - run: | - pytest -n 4 shortfin/ + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime - # TODO: Enable type checking of sharktank - - name: MyPy Type Checking Shortfin - if: ${{ !cancelled() }} - run: | - (cd shortfin && mypy) + - name: Run LLM Integration Tests + run: pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml new file mode 100644 index 000000000..1d3960b43 --- /dev/null +++ b/.github/workflows/ci-sharktank.yml @@ -0,0 +1,131 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - sharktank + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. 
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test: + name: "Unit Tests and Type Checking" + strategy: + matrix: + version: [3.11] + os: [ubuntu-latest, windows-latest] + fail-fast: false + runs-on: ${{matrix.os}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Update to the latest iree packages. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + - name: Run sharktank tests + if: ${{ !cancelled() }} + run: | + pytest -n 4 sharktank/ + + + test_with_data: + name: "Data-dependent Tests" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + HF_HOME: "/data/huggingface" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. 
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run tests + run: | + pytest \ + --with-t5-data \ + sharktank/tests/models/t5/t5_test.py diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index 5de7d4182..81b920e31 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -1,11 +1,23 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + name: CI - Tuner on: workflow_dispatch: pull_request: + paths: + - '.github/workflows/ci-tuner.yml' + - 'tuner/**' push: branches: - main + paths: + - '.github/workflows/ci-tuner.yml' + - 'tuner/**' concurrency: group: ${{ github.workflow }}-${{ github.event.number || github.sha }} @@ -20,10 +32,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.10.12' @@ -37,8 +49,11 @@ jobs: pip install -r tuner/requirements-tuner.txt python -m pip install \ --find-links https://iree.dev/pip-release-links.html \ - --upgrade \ - iree-compiler iree-runtime + --upgrade --pre \ + iree-base-compiler iree-base-runtime - name: Run tuner tests run: pytest tuner/ + + - name: Run mypy type checker + run: mypy tuner/tuner diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml new file mode 100644 index 000000000..0164b6cdc --- /dev/null +++ b/.github/workflows/ci_eval.yaml @@ -0,0 +1,143 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - sharktank perplexity + +on: + workflow_dispatch: + schedule: + # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT. + - cron: "0 7 * * 1-5" + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. 
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_perplexity_iree: + if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} + timeout-minutes: 1000 + name: "Perplexity-IREE" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. 
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run perplexity test with IREE + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/perplexity/iree_perplexity + destination_dir: ./llm/llama/perplexity/iree_perplexity + keep_files: true + + test_perplexity_torch: + if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }} + timeout-minutes: 1000 + name: "Perplexity-Torch" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + - name: Run perplexity test with Torch + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/llama/perplexity/torch_perplexity + destination_dir: ./llm/llama/perplexity/torch_perplexity + keep_files: true diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml new file mode 100644 index 000000000..4622f5c57 --- /dev/null +++ b/.github/workflows/ci_eval_short.yaml @@ -0,0 +1,77 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - sharktank perplexity short + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test_perplexity_iree: + name: "Llama3.1 8B FP16" + strategy: + matrix: + version: [3.11] + runs-on: [llama-mi300x-3] + fail-fast: false + runs-on: ${{matrix.runs-on}} + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }} + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }} + + - name: Install sharktank deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ + + # Install latest iree-tubrine. + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + + # Try with the latest IREE nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler \ + iree-base-runtime + + - name: Run perplexity test with vmfb + run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 793058570..afeca11a6 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -4,125 +4,122 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - libshortfin +name: CI - shortfin on: workflow_dispatch: pull_request: + paths: + - '.github/workflows/ci_linux_x64-libshortfin.yml' + - 'shortfin/**' push: branches: - main paths: - '.github/workflows/ci_linux_x64-libshortfin.yml' - - 'libshortfin/**' + - 'shortfin/**' permissions: contents: read +concurrency: + # A PR number if a pull request and otherwise the commit hash. 
This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + env: IREE_REPO_DIR: ${{ github.workspace }}/iree - LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/ + LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ jobs: build-and-test: name: Build and test runs-on: ubuntu-24.04 + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] steps: - name: Install dependencies run: | sudo apt update sudo apt install clang lld cmake ninja-build - sudo apt install libspdlog-dev libxtensor-dev - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: false - name: Checkout IREE repo - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: candidate-20240904.1006 + ref: iree-3.0.0rc20241118 - name: Initalize IREE submodules + working-directory: ${{ env.IREE_REPO_DIR }} run : | - cd ${{ env.IREE_REPO_DIR }} git submodule update --init --depth 1 -- third_party/benchmark git submodule update --init --depth 1 -- third_party/cpuinfo/ git submodule update --init --depth 1 -- third_party/flatcc git submodule update --init --depth 1 -- third_party/googletest git submodule update --init --depth 1 -- third_party/hip-build-deps/ - - name: Build IREE runtime - run: | - mkdir ${{ env.IREE_REPO_DIR }}/build - cd ${{ env.IREE_REPO_DIR }}/build - cmake -GNinja \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DIREE_ENABLE_LLD=ON \ - -DIREE_ERROR_ON_MISSING_SUBMODULES=OFF \ - -DIREE_HAL_DRIVER_DEFAULTS=OFF \ - -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \ - -DIREE_HAL_DRIVER_LOCAL_TASK=ON \ - -DIREE_HAL_DRIVER_HIP=ON \ - -DIREE_BUILD_COMPILER=OFF \ - -DIREE_BUILD_SAMPLES=OFF \ - -DIREE_BUILD_TESTS=OFF \ - .. - cmake --build . --target all - - - name: Setup Python - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: - python-version: "3.12" + python-version: ${{ matrix.python-version }} cache: "pip" - name: Install Python packages - # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`. + # TODO: Switch to `pip install -r requirements.txt -e shortfin/`. + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt + pip install -r requirements-tests.txt + pip install -r requirements-iree-compiler.txt pip freeze - - name: Build libshortfin (full) + - name: Build shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - mkdir ${{ env.LIBSHORTFIN_DIR }}/build - cd ${{ env.LIBSHORTFIN_DIR }}/build + mkdir build cmake -GNinja \ + -S. \ + -Bbuild \ -DCMAKE_C_COMPILER=clang-18 \ -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ -DSHORTFIN_BUNDLE_DEPS=ON \ - -DCMAKE_PREFIX_PATH=${{ env.IREE_REPO_DIR }}/build/lib/cmake/IREE \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - .. - cmake --build . --target all - pip install -v -e . 
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + cmake --build build --target all + pip install -v -e build/ - - name: Test libshortfin (full) + - name: Test shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - cd ${{ env.LIBSHORTFIN_DIR }}/build - ctest --timeout 30 --output-on-failure - cd ${{ env.LIBSHORTFIN_DIR }} - pytest -s -v -m "not requires_amd_gpu" + ctest --timeout 30 --output-on-failure --test-dir build + pytest -s - - name: Build libshortfin (host-only) + - name: Build shortfin (host-only) + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | - mkdir ${{ env.LIBSHORTFIN_DIR }}/build-host-only - cd ${{ env.LIBSHORTFIN_DIR }}/build-host-only + mkdir build-host-only # In this configuration, also build static+dynamic in order to verify # that path structurally works. cmake -GNinja \ + -S. \ + -Bbuild-host-only \ -DCMAKE_C_COMPILER=clang-18 \ -DCMAKE_CXX_COMPILER=clang++-18 \ -DCMAKE_LINKER_TYPE=LLD \ - -DCMAKE_PREFIX_PATH=${{ env.IREE_REPO_DIR }}/build/lib/cmake/IREE \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ -DSHORTFIN_HAVE_AMDGPU=OFF \ -DSHORTFIN_BUILD_STATIC=ON \ - -DSHORTFIN_BUILD_DYNAMIC=ON \ - .. - cmake --build . --target all + -DSHORTFIN_BUILD_DYNAMIC=ON + cmake --build build-host-only --target all diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index c998fbc77..42de8f0f6 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -4,29 +4,40 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - libshortfin - ASan +name: CI - shortfin - ASan on: workflow_dispatch: pull_request: + paths: + - '.github/workflows/ci_linux_x64_asan-libshortfin.yml' + - 'shortfin/**' push: branches: - main paths: - '.github/workflows/ci_linux_x64_asan-libshortfin.yml' - - 'libshortfin/**' + - 'shortfin/**' permissions: contents: read +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. 
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + env: PYENV_ROOT: ${{ github.workspace }}/pyenv - PYENV_REF: 9ecd803bffaffb949fbdd8c70cb086227f6a3202 # v2.4.10 - PYTHON_VER: 3.12.3 + PYENV_REF: 96b3fb2fc3bee85650cb22e2cb06c83c24509a6d # v2.4.17 + PYTHON_VER: 3.12.7 CACHE_ASAN_VER: 2 CACHE_DEPS_VER: 1 IREE_SOURCE_DIR: ${{ github.workspace }}/iree - LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/ + LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ jobs: setup-python-asan: @@ -40,7 +51,7 @@ jobs: steps: - name: Cache Python ASan id: cache-python-asan - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ${{ env.PYENV_ROOT }} key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }} @@ -55,7 +66,7 @@ jobs: - name: Checkout pyenv if: steps.cache-python-asan.outputs.cache-hit != 'true' - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: pyenv/pyenv ref: ${{ env.PYENV_REF }} @@ -63,8 +74,8 @@ jobs: - name: Install pyenv & Python if: steps.cache-python-asan.outputs.cache-hit != 'true' + working-directory: ${{ env.PYENV_ROOT }} run: | - cd ${{ env.PYENV_ROOT }} src/configure && make -C src export PATH=${{ env.PYENV_ROOT }}/bin:$PATH && eval "$(pyenv init -)" CC=clang-18 CXX=clang++-18 LDFLAGS="-lstdc++" PYTHON_CONFIGURE_OPTS="--with-address-sanitizer" pyenv install -v ${{ env.PYTHON_VER }} @@ -72,13 +83,15 @@ jobs: build-and-test: - name: Build and test libshortfin + name: Build and test shortfin needs: [setup-python-asan] runs-on: ubuntu-24.04 env: - # TODO(#151): Don't ignore ODR violations - ASAN_OPTIONS: detect_odr_violation=0 - LSAN_OPTIONS: suppressions=${{ github.workspace }}/libshortfin/build_tools/python_lsan_suppressions.txt + # We can't count on being leak free in general (i.e. pip, etc) so disable + # leak checker by default. Here we suppress any ASAN features needed to + # pass the build. Test configuration is done specially just for that step. 
+ ASAN_OPTIONS: detect_leaks=0,detect_odr_violation=0 + LSAN_OPTIONS: suppressions=${{ github.workspace }}/shortfin/build_tools/python_lsan_suppressions.txt steps: - name: Install dependencies run: | @@ -86,21 +99,21 @@ jobs: sudo apt install clang lld cmake ninja-build - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: false - name: Checkout IREE repo - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: iree-org/iree path: ${{ env.IREE_SOURCE_DIR }} submodules: false - ref: candidate-20240904.1006 + ref: iree-3.0.0rc20241118 - name: Initalize IREE submodules + working-directory: ${{ env.IREE_SOURCE_DIR }} run : | - cd ${{ env.IREE_SOURCE_DIR }} git submodule update --init --depth 1 -- third_party/benchmark git submodule update --init --depth 1 -- third_party/cpuinfo/ git submodule update --init --depth 1 -- third_party/flatcc @@ -112,7 +125,7 @@ jobs: uses: actions/cache/restore@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 with: path: ${{ env.PYENV_ROOT }} - key: ${{ runner.os }}-python-deps-${{ hashFiles('libshortfin/requirements-tests.txt') }}-v${{ env.CACHE_DEPS_VER }} + key: ${{ runner.os }}-python-deps-${{ hashFiles('shortfin/requirements-tests.txt') }}-v${{ env.CACHE_DEPS_VER }} - name: Restore Python ASan cache id: cache-python-asan @@ -128,10 +141,11 @@ jobs: - name: Install Python dependencies if: steps.cache-python-deps-restore.outputs.cache-hit != 'true' + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | eval "$(pyenv init -)" - pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt - pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-iree-compiler.txt + pip install -r requirements-tests.txt + pip install -r requirements-iree-compiler.txt pip freeze - name: Save Python dependencies cache @@ -142,35 +156,22 @@ jobs: path: ${{ env.PYENV_ROOT }} key: ${{ steps.cache-python-deps-restore.outputs.cache-primary-key }} - - name: Build libshortfin + - name: Build shortfin + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | eval "$(pyenv init -)" - mkdir ${{ env.LIBSHORTFIN_DIR }}/build - cd ${{ env.LIBSHORTFIN_DIR }}/build - cmake -GNinja \ - -DCMAKE_BUILD_TYPE=Debug \ - -DCMAKE_C_COMPILER=clang-18 \ - -DCMAKE_CXX_COMPILER=clang++-18 \ - -DCMAKE_LINKER_TYPE=LLD \ - -DSHORTFIN_BUNDLE_DEPS=ON \ - -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_SOURCE_DIR }} \ - -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ - -DSHORTFIN_ENABLE_ASAN=ON \ - .. - cmake --build . --target all + SHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_SOURCE_DIR }}" \ + SHORTFIN_ENABLE_ASAN=ON \ + SHORTFIN_DEV_MODE=ON \ + SHORTFIN_RUN_CTESTS=ON \ pip install -v -e . 
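For reference, this ASan build step can be approximated outside of CI. A minimal local sketch, assuming the shortfin setup script honors the SHORTFIN_* environment variables exactly as the step above implies, that an ASan-instrumented Python (such as the pyenv build from the setup-python-asan job) is active, and that the commands run from the shortfin/ directory of a repo checkout; the IREE source path is illustrative:

# Local approximation of the CI ASan build and test steps above (paths are illustrative).
cd shortfin
export ASAN_OPTIONS=detect_leaks=0,detect_odr_violation=0
export LSAN_OPTIONS=suppressions=$PWD/build_tools/python_lsan_suppressions.txt
SHORTFIN_IREE_SOURCE_DIR="$HOME/src/iree" \
SHORTFIN_ENABLE_ASAN=ON \
SHORTFIN_DEV_MODE=ON \
SHORTFIN_RUN_CTESTS=ON \
pip install -v -e .
# Mirror the narrower ASAN_OPTIONS used by the pytest step when running tests.
ASAN_OPTIONS=detect_odr_violation=0 pytest -s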
- - name: Run ctest - if: ${{ !cancelled() }} - env: - CTEST_OUTPUT_ON_FAILURE: 1 - run: | - cd ${{ env.LIBSHORTFIN_DIR }}/build - ctest --timeout 30 --output-on-failure - - name: Run pytest if: ${{ !cancelled() }} + env: + # TODO(#151): Don't ignore ODR violations + ASAN_OPTIONS: detect_odr_violation=0 + working-directory: ${{ env.LIBSHORTFIN_DIR }} run: | eval "$(pyenv init -)" - cd ${{ env.LIBSHORTFIN_DIR }} - pytest -m "not requires_amd_gpu" -s + pytest -s diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml new file mode 100644 index 000000000..0e0e1db2a --- /dev/null +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -0,0 +1,103 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - shortfin - Python 3.13 Free-threaded + +on: + workflow_dispatch: + pull_request: + paths: + - '.github/workflows/ci_linux_x64-libshortfin.yml' + - 'shortfin/**' + + push: + branches: + - main + paths: + - '.github/workflows/ci_linux_x64-libshortfin.yml' + - 'shortfin/**' + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +env: + IREE_REPO_DIR: ${{ github.workspace }}/iree + LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ + +jobs: + build-and-test: + name: Build and test + runs-on: ubuntu-24.04 + + steps: + - name: Install dependencies + run: | + sudo apt update + sudo apt install clang lld cmake ninja-build + + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + submodules: false + + - name: Checkout IREE repo + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: iree-org/iree + path: ${{ env.IREE_REPO_DIR }} + submodules: false + ref: iree-3.0.0rc20241118 + + - name: Initalize IREE submodules + working-directory: ${{ env.IREE_REPO_DIR }} + run : | + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ + + - name: Setup Python + uses: deadsnakes/action@e640ac8743173a67cca4d7d77cd837e514bf98e8 # v3.2.0 + with: + python-version: "3.13-dev" + nogil : true + - name: Install Python packages + run: | + pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests-nogil.txt + pip freeze + + - name: Build shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + mkdir build + cmake -GNinja \ + -S. 
\ + -Bbuild \ + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_C_COMPILER=clang-18 \ + -DCMAKE_CXX_COMPILER=clang++-18 \ + -DCMAKE_LINKER_TYPE=LLD \ + -DSHORTFIN_BUNDLE_DEPS=ON \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + cmake --build build --target all + pip install -v -e build/ + + - name: Run shortfin Python tests (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd + # TODO: Enable further tests and switch to + # pytest -s diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml new file mode 100644 index 000000000..00873c432 --- /dev/null +++ b/.github/workflows/ci_windows_x64-libshortfin.yml @@ -0,0 +1,98 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - shortfin - Windows + +on: + workflow_dispatch: + pull_request: + paths: + - '.github/workflows/ci_windows_x64-libshortfin.yml' + - 'shortfin/**' + push: + branches: + - main + paths: + - '.github/workflows/ci_windows_x64-libshortfin.yml' + - 'shortfin/**' + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +env: + IREE_REPO_DIR: ${{ github.workspace }}/iree + LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/ + +jobs: + build-and-test: + name: Build and test + runs-on: windows-2022 + + steps: + - name: Configure MSVC + uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0 + + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + submodules: false + + - name: Checkout IREE repo + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: iree-org/iree + path: ${{ env.IREE_REPO_DIR }} + submodules: false + ref: iree-3.0.0rc20241118 + + - name: Initalize IREE submodules + working-directory: ${{ env.IREE_REPO_DIR }} + run : | + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ + + - name: Setup Python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: "3.12" + cache: "pip" + - name: Install Python packages + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + pip install -r requirements-tests.txt + pip install -r requirements-iree-compiler.txt + pip freeze + + - name: Build shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + shell: bash + run: | + mkdir build + cmake -GNinja \ + -S. 
\ + -Bbuild \ + -DSHORTFIN_BUNDLE_DEPS=ON \ + -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON + cmake --build build --target all + pip install -v -e build/ + + - name: Test shortfin (full) + working-directory: ${{ env.LIBSHORTFIN_DIR }} + run: | + ctest --timeout 30 --output-on-failure --test-dir build + pytest -s diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 2b11178bf..8ec1e8d55 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -9,6 +9,6 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 diff --git a/.gitignore b/.gitignore index b766af2d2..bdb0b5387 100644 --- a/.gitignore +++ b/.gitignore @@ -8,12 +8,17 @@ *.suo *.user +# Common IREE source/build paths +iree/ +iree-build/ + # macOS files .DS_Store # CMake artifacts build/ build-*/ +_build/ # Python __pycache__ @@ -25,8 +30,20 @@ wheelhouse *.whl *.venv +# Local-only config options +version_local.json + #Model artifacts *.pt *.safetensors *.gguf *.vmfb +genfiles/ +*.zip +tmp/ + +# Known inference result blobs +*output*.png + +# Log files. +*.log diff --git a/README.md b/README.md index 7ff8d0126..44f1e6113 100644 --- a/README.md +++ b/README.md @@ -1,65 +1,71 @@ -# SHARK Modeling and Serving Libraries - -**WARNING: This is an early preview that is in progress. It is not ready for -general use.** +# shark-ai: SHARK Modeling and Serving Libraries +![GitHub License](https://img.shields.io/github/license/nod-ai/shark-ai) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) +## SHARK Users + +If you're looking to use SHARK check out our [User Guide](docs/user_guide.md). For developers continue to read on. + + + +## Sub-projects + +### [`shortfin/`](./shortfin/) + + + +[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) -## Development Getting Started +The shortfin sub-project is SHARK's high performance inference library and +serving engine. -Use this as a guide to get started developing the project using pinned, -pre-release dependencies. You are welcome to deviate as you see fit, but -these canonical directions mirror what the CI does. +* API documentation for shortfin is available on + [readthedocs](https://shortfin.readthedocs.io/en/latest/). -### Setup a venv +### [`sharktank/`](./sharktank/) -We recommend setting up a virtual environment (venv). The project is configured -to ignore `.venv` directories, and editors like VSCode pick them up by default. 
+[![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml?query=event%3Apush) -``` -python -m venv --prompt sharktank .venv -source .venv/bin/activate -``` +The SHARK Tank sub-project contains a collection of model recipes and +conversion tools to produce inference-optimized programs. -### Install PyTorch for Your System +> [!WARNING] +> SHARK Tank is still under development. Experienced users may want to try it +> out, but we currently recommend most users download pre-exported or +> pre-compiled model files for serving with shortfin. -If no explicit action is taken, the default PyTorch version will be installed. -This will give you a current CUDA-based version. Install a different variant -by doing so explicitly first: + -*CPU:* +* See the [SHARK Tank Programming Guide](./docs/programming_guide.md) for + information about core concepts, the development model, dataset management, + and more. +* See [Direct Quantization with SHARK Tank](./docs/quantization.md) + for information about quantization support. -``` -pip install -r pytorch-cpu-requirements.txt -``` +### [`tuner/`](./tuner/) -*ROCM:* +[![CI - Tuner](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml?query=event%3Apush) -``` -pip install -r pytorch-rocm-requirements.txt -``` +The Tuner sub-project assists with tuning program performance by searching for +optimal parameter configurations to use during model compilation. -### Install Development Packages +> [!WARNING] +> SHARK Tuner is still in early development. Interested users may want +> to try it out, but the tuner is not ready for general use yet. Check out +> [the readme](tuner/README.md) for more details. -``` -# Clone and install editable iree-turbine dep in deps/ -pip install -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" +## Support matrix -# Install editable local projects. -pip install -r requirements.txt -e sharktank/ shortfin/ -``` + -### Running Tests +### Models -``` -pytest sharktank -pytest shortfin -``` +Model name | Model recipes | Serving apps +---------- | ------------- | ------------ +SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/sd) +llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/llm) -### Optional: Pre-commits and developer settings +## SHARK Developers -This project is set up to use the `pre-commit` tooling. To install it in -your local repo, run: `pre-commit install`. After this point, when making -commits locally, hooks will run. See https://pre-commit.com/ +If you're looking to develop SHARK, check out our [Developer Guide](docs/developer_guide.md). diff --git a/app_tests/__init__.py b/app_tests/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/benchmark_tests/__init__.py b/app_tests/benchmark_tests/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/benchmark_tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/conftest.py new file mode 100644 index 000000000..cc354b7eb --- /dev/null +++ b/app_tests/benchmark_tests/llm/conftest.py @@ -0,0 +1,46 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import os +import pytest +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from integration_tests.llm.utils import compile_model, export_paged_llm_v1 + + +@pytest.fixture(scope="module") +def pre_process_model(request, tmp_path_factory): + tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test") + + model_path = request.param["model_path"] + settings = request.param["settings"] + batch_sizes = request.param["batch_sizes"] + + mlir_path = tmp_dir / "model.mlir" + config_path = tmp_dir / "config.json" + vmfb_path = tmp_dir / "model.vmfb" + + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) + + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "transformer_block_count": 32, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + with open(config_path, "w") as file: + json.dump(config, file) + + compile_model(mlir_path, vmfb_path, settings) + + return tmp_dir diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py new file mode 100644 index 000000000..0de775795 --- /dev/null +++ b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py @@ -0,0 +1,122 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import logging +import multiprocessing +import os +from pathlib import Path +import pytest +import time +from unittest.mock import patch + +pytest.importorskip("sglang") +from sglang import bench_serving + +from utils import SGLangBenchmarkArgs + +from integration_tests.llm.utils import ( + find_available_port, + start_llm_server, +) + +logger = logging.getLogger("__name__") + +device_settings = { + "device_flags": [ + "--iree-hal-target-backends=rocm", + "--iree-hip-target=gfx942", + ], + "device": "hip", +} + +# TODO: Download on demand instead of assuming files exist at this path +MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa") +TOKENIZER_DIR = Path("/data/llama3.1/8b/") + + +def log_jsonl_result(file_path): + with open(file_path, "r") as file: + json_string = file.readline().strip() + + json_data = json.loads(json_string) + for key, val in json_data.items(): + logger.info(f"{key.upper()}: {val}") + + +@pytest.mark.parametrize( + "request_rate", + [1, 2, 4, 8, 16, 32], +) +@pytest.mark.parametrize( + "pre_process_model", + [ + ( + { + "model_path": MODEL_PATH, + "settings": device_settings, + "batch_sizes": [1, 4], + } + ) + ], + indirect=True, +) +def test_sglang_benchmark_server(request_rate, pre_process_model): + # TODO: Remove when multi-device is fixed + os.environ["ROCR_VISIBLE_DEVICES"] = "1" + + tmp_dir = pre_process_model + + config_path = tmp_dir / "config.json" + vmfb_path = tmp_dir / "model.vmfb" + tokenizer_path = TOKENIZER_DIR / "tokenizer.json" + + # Start shortfin llm server + port = find_available_port() + server_process = start_llm_server( + port, + tokenizer_path, + config_path, + vmfb_path, + MODEL_PATH, + device_settings, + timeout=30, + ) + + # Run and collect SGLang Serving Benchmark + benchmark_args = SGLangBenchmarkArgs( + backend="shortfin", + num_prompt=10, + base_url=f"http://localhost:{port}", + tokenizer=TOKENIZER_DIR, + request_rate=request_rate, + ) + output_file = ( + tmp_dir + / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl" + ) + benchmark_args.output_file = output_file + + logger.info("Running SGLang Benchmark with the following args:") + logger.info(benchmark_args) + try: + start = time.time() + with patch.object(bench_serving, "print", side_effect=logger.info): + benchmark_process = multiprocessing.Process( + target=bench_serving.run_benchmark, + args=(benchmark_args.as_namespace(),), + ) + benchmark_process.start() + benchmark_process.join() + + logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds") + logger.info("======== RESULTS ========") + log_jsonl_result(benchmark_args.output_file) + except Exception as e: + logger.info(e) + + server_process.terminate() + server_process.wait() diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/utils.py new file mode 100644 index 000000000..55b01da04 --- /dev/null +++ b/app_tests/benchmark_tests/llm/utils.py @@ -0,0 +1,56 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from argparse import Namespace +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class SGLangBenchmarkArgs: + base_url: str + num_prompt: int + request_rate: int + tokenizer: str | Path + + seed: int = 1 + extra_request_body: str | None = None + output_file: str | Path | None = None + port: int = 8000 + backend: str = "shortfin" + + def as_namespace(self) -> Namespace: + return Namespace( + num_prompts=self.num_prompt, + base_url=self.base_url, + tokenizer=str(self.tokenizer), + request_rate=self.request_rate, + backend=self.backend, + output_file=self.output_file, + seed=self.seed, + extra_request_body=self.extra_request_body, + port=8000, + model=None, + dataset_name="sharegpt", + random_input_len=None, + random_output_len=None, + random_range_ratio=0.0, + dataset_path="", + sharegpt_output_len=None, + multi=False, + disable_tqdm=False, + disable_stream=False, + disable_ignore_eos=False, + ) + + def __repr__(self): + return ( + f"Backend: {self.backend}\n" + f"Base URL: {self.base_url}\n" + f"Num Prompt: {self.num_prompt}\n" + f"Tokenizer: {self.tokenizer}\n" + f"Request Rate: {self.request_rate}" + ) diff --git a/app_tests/integration_tests/__init__.py b/app_tests/integration_tests/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/integration_tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/integration_tests/llm/__init__.py b/app_tests/integration_tests/llm/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/integration_tests/llm/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/integration_tests/llm/sglang/__init__.py b/app_tests/integration_tests/llm/sglang/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/integration_tests/llm/sglang/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/integration_tests/llm/sglang/conftest.py b/app_tests/integration_tests/llm/sglang/conftest.py new file mode 100644 index 000000000..8543708da --- /dev/null +++ b/app_tests/integration_tests/llm/sglang/conftest.py @@ -0,0 +1,123 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import logging +import os +import pytest + +from ..utils import ( + find_available_port, + start_llm_server, + download_with_hf_datasets, + export_paged_llm_v1, + compile_model, +) + +pytest.importorskip("sglang") +import sglang as sgl +from sglang.lang.chat_template import get_chat_template + +pytest.importorskip("sentence_transformers") +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def register_shortfin_backend(available_port): + backend = sgl.Shortfin( + chat_template=get_chat_template("llama-3-instruct"), + base_url=f"http://localhost:{available_port}", + ) + sgl.set_default_backend(backend) + + +@pytest.fixture(scope="module") +def pre_process_model(request, tmp_path_factory): + device_settings = request.param["device_settings"] + tmp_dir = tmp_path_factory.mktemp("sglang_integration_tests") + + # Download model + model_params_path = tmp_dir / "meta-llama-3.1-8b-instruct.f16.gguf" + download_with_hf_datasets(tmp_dir, "llama3_8B_fp16") + + # Export to mlir + mlir_path = tmp_dir / "model.mlir" + config_path = tmp_dir / "config.json" + batch_sizes = [1, 4] + export_paged_llm_v1( + mlir_path, + config_path, + model_params_path, + batch_sizes, + ) + + # Compile Model + vmfb_path = tmp_dir / "model.vmfb" + compile_model( + mlir_path, + vmfb_path, + device_settings, + ) + + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": [1, 4], + "decode_batch_sizes": [1, 4], + "transformer_block_count": 32, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + config_path = tmp_dir / "config.json" + with open(config_path, "w") as f: + json.dump(config, f) + + return tmp_dir + + +@pytest.fixture(scope="module") +def available_port(): + return find_available_port() + + +@pytest.fixture(scope="module") +def start_server(request, pre_process_model, available_port): + os.environ["ROCR_VISIBLE_DEVICES"] = "1" + device_settings = request.param["device_settings"] + + export_dir = pre_process_model + + tokenizer_path = export_dir / "tokenizer.json" + model_params_path = export_dir / "meta-llama-3.1-8b-instruct.f16.gguf" + vmfb_path = export_dir / "model.vmfb" + config_path = export_dir / "config.json" + + logger.info("Starting server...") + server_process = start_llm_server( + available_port, + tokenizer_path, + config_path, + vmfb_path, + model_params_path, + device_settings, + timeout=30, + ) + logger.info("Server started") + + yield server_process + + server_process.terminate() + server_process.wait() + + +@pytest.fixture(scope="module") +def load_comparison_model(): + model = SentenceTransformer("all-MiniLM-L6-v2") + return model diff --git a/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py b/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py new file mode 100644 index 000000000..efab14ea7 --- /dev/null +++ b/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py @@ -0,0 +1,309 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import re +import pytest + +from ..utils import ( + AccuracyValidationException, +) + +pytest.importorskip("sglang") +import sglang as sgl +from sglang.lang.chat_template import get_chat_template + +pytest.importorskip("sentence_transformers") +from sentence_transformers import SentenceTransformer, util + +logger = logging.getLogger(__name__) + +DEVICE_SETTINGS = { + "device_flags": [ + "--iree-hal-target-backends=rocm", + "--iree-hip-target=gfx942", + ], + "device": "hip", +} + +ACCEPTED_THRESHOLD = 0.8 + + +def compute_similarity(model: SentenceTransformer, sentence_1: str, sentence_2: str): + embeddings = model.encode([sentence_1, sentence_2]) + return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item() + + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=50, temperature=1.0)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=50, temperature=1.0)) + + +@sgl.function +def tip_suggestion(s): + s += ( + "Here are two tips for staying healthy: " + "1. Balanced Diet. 2. Regular Exercise.\n\n" + ) + + forks = s.fork(2) + for i, f in enumerate(forks): + f += f"Now, expand tip {i+1} into a paragraph:\n" + f += sgl.gen(f"detailed_tip", max_tokens=50, temperature=1.0) + + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") + + +@pytest.mark.parametrize( + "pre_process_model,start_server", + [ + ( + {"device_settings": DEVICE_SETTINGS}, + {"device_settings": DEVICE_SETTINGS}, + ) + ], + indirect=True, +) +def test_multi_turn_qa(load_comparison_model, start_server, register_shortfin_backend): + model = load_comparison_model + + question_1 = "Name the capital city of the USA." + question_2 = "The Smithsonian is in this location." + + answer_1 = "The capital city of the United States of America is Washington, D.C. (short for District of Columbia).assistant\n\nWould you like to know more about Washington, D.C. or is there something else I can help you with?" + answer_2 = "The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. 
It was founded in 1846 and is named after British scientist James Smithson, who left a bequest to" + + logger.info("Testing multi-turn Q&A run...") + state = multi_turn_question.run( + question_1=question_1, + question_2=question_2, + ) + messages = state.messages() + logger.info("Received messages from multi-turn call.") + + assert messages[0] == { + "role": "user", + "content": question_1, + } + assert messages[1]["role"] == "assistant" + + logger.info("Computing similarity between first question and first answer...") + first_q_answer = messages[1]["content"] + score = compute_similarity(model, answer_1, first_q_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {answer_1} and {first_q_answer}:\n SCORE: {score}" + ) + logger.info("Similarity passed") + + assert messages[2] == { + "role": "user", + "content": question_2, + } + assert messages[3]["role"] == "assistant" + + logger.info("Testing similarity between second question and second answer...") + second_q_answer = messages[3]["content"] + score = compute_similarity(model, answer_2, second_q_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {answer_2} and {second_q_answer}:\n SCORE: {score}" + ) + logger.info("Similarity passed.") + + +@pytest.mark.parametrize( + "pre_process_model,start_server", + [ + ( + {"device_settings": DEVICE_SETTINGS}, + {"device_settings": DEVICE_SETTINGS}, + ) + ], + indirect=True, +) +def test_stream_multi_turn_qa( + load_comparison_model, start_server, register_shortfin_backend +): + def clean_message(message: str): + """Remove chat tags from message before comparison. + + Args: + message (str): Message to clean. + + Returns: + str: Message without tags (i.e. <|start_header_id|>) + """ + pattern = r"<\|.*?\|>" + return re.sub(pattern, "", message) + + model = load_comparison_model + question_1 = "Name the capital city of the USA." + question_2 = "The Smithsonian is in this location." + expected_answer_1 = "The capital city of the United States of America is Washington, D.C. (short for District of Columbia).assistant\n\nWould you like to know more about Washington, D.C. or is there something else I can help you with?" + expected_answer_2 = "The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. 
It was founded in 1846 and is named after British scientist James Smithson, who left a bequest to" + + logger.info("Testing multi-turn Q&A run w/ stream...") + state = multi_turn_question.run( + question_1=question_1, + question_2=question_2, + stream=True, + ) + messages = "" + for chunk in state.text_iter(): + messages += chunk + logger.info("Received messages from multi-turn call.") + + logger.info("Computing similarity between expectation and result") + expected_result = f"user: {question_1}\nassistant: {expected_answer_1}\nuser: {question_2}\nassistant: {expected_answer_2}" + cleaned_messages = clean_message(messages) + score = compute_similarity(model, cleaned_messages, expected_result) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {expected_result} and {messages}:\n SCORE: {score}" + ) + logger.info("Similarity passed.") + + +@pytest.mark.parametrize( + "pre_process_model,start_server", + [ + ( + {"device_settings": DEVICE_SETTINGS}, + {"device_settings": DEVICE_SETTINGS}, + ) + ], + indirect=True, +) +def test_batch_multi_turn_qa( + load_comparison_model, start_server, register_shortfin_backend +): + model = load_comparison_model + + question_1_1 = "Name the capital city of the USA." + question_1_2 = "The Smithsonian is in this location." + expected_answer_1_1 = "The capital city of the United States of America is Washington, D.C. (short for District of Columbia).assistant\n\nWould you like to know more about Washington, D.C. or is there something else I can help you with?" + expected_answer_1_2 = "The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. It was founded in 1846 and is named after British scientist James Smithson, who left a bequest to" + + question_2_1 = "Name the largest city in the USA." + question_2_2 = "The Empire State Building is in this location." + expected_answer_2_1 = "The largest city in the USA is New York City, with a population of over 8.4 million people, according to the United States Census Bureau (2020 estimates).assistant\n\nHowever, I should note that the largest city in the" + expected_answer_2_2 = "That's correct, the iconic Empire State Building is located in Midtown Manhattan, New York City. 
It's one of the most recognizable landmarks in the world and a symbol of the city's grandeur and history.assistant\n\nAnd, by" + + logger.info("Testing batch multi-turn Q&A run...") + states = multi_turn_question.run_batch( + [ + { + "question_1": question_1_1, + "question_2": question_1_2, + }, + { + "question_1": question_2_1, + "question_2": question_2_2, + }, + ] + ) + + first_qa = states[0] + second_qa = states[1] + + first_qa_messages = first_qa.messages() + second_qa_messages = second_qa.messages() + + logger.info("Testing first batch of messages...") + assert first_qa_messages[0] == { + "role": "user", + "content": question_1_1, + } + + assert first_qa_messages[1]["role"] == "assistant" + first_answer = first_qa_messages[1]["content"] + expected_answer = expected_answer_1_1 + score = compute_similarity(model, expected_answer, first_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {expected_answer} and {first_answer}:\n SCORE: {score}" + ) + + assert first_qa_messages[2] == { + "role": "user", + "content": question_1_2, + } + first_qa_messages[3]["role"] = "assistant" + second_answer = first_qa_messages[3]["content"] + expected_answer = expected_answer_1_2 + score = compute_similarity(model, expected_answer, second_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {expected_answer} and {second_answer}:\n SCORE: {score}" + ) + logger.info("First batch passed.") + + logger.info("Testing second batch of messages...") + assert second_qa_messages[0] == { + "role": "user", + "content": question_2_1, + } + + assert second_qa_messages[1]["role"] == "assistant" + first_answer = second_qa_messages[1]["content"] + expected_answer = expected_answer_2_1 + score = compute_similarity(model, expected_answer, first_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {expected_answer} and {first_answer}:\n SCORE: {score}" + ) + + assert second_qa_messages[2] == { + "role": "user", + "content": question_2_2, + } + second_qa_messages[3]["role"] = "assistant" + second_answer = second_qa_messages[3]["content"] + expected_answer = expected_answer_2_2 + score = compute_similarity(model, expected_answer, second_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {expected_answer} and {second_answer}:\n SCORE: {score}" + ) + logger.info("Second batch passed.") + + +@pytest.mark.parametrize( + "pre_process_model,start_server", + [ + ( + {"device_settings": DEVICE_SETTINGS}, + {"device_settings": DEVICE_SETTINGS}, + ) + ], + indirect=True, +) +def test_fork(load_comparison_model, start_server, register_shortfin_backend): + model = load_comparison_model + + logger.info("Testing fork...") + state = tip_suggestion.run() + result = state.text() + logger.info("Fork response received.") + + logger.info("Computing similarity...") + expected_answer = """Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise. + Tip 1:A balanced diet is essential for maintaining good health. It involves consuming a variety of foods from different food groups, including fruits, vegetables, whole grains, lean proteins, and healthy fats. A balanced diet provides the body with the necessary nutrients, vitamins, and + Tip 2:Regular exercise is essential for maintaining a healthy body. It helps to improve cardiovascular health, increase strength and flexibility, and boost the immune system. 
Regular physical activity can also reduce the risk of chronic diseases such as heart disease, diabetes, and certain types of cancer + In summary, a balanced diet and regular exercise are two of the most important tips for staying healthy. By following these tips, you can maintain a healthy body and reduce the risk of chronic diseases. + """ + score = compute_similarity(model, result, expected_answer) + if not score > ACCEPTED_THRESHOLD: + raise AccuracyValidationException( + f"Accuracy error between {expected_answer} and {result}:\n SCORE: {score}" + ) + logger.info("Similarity passed.") diff --git a/app_tests/integration_tests/llm/shortfin/__init__.py b/app_tests/integration_tests/llm/shortfin/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/app_tests/integration_tests/llm/shortfin/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py new file mode 100644 index 000000000..0d40119c7 --- /dev/null +++ b/app_tests/integration_tests/llm/shortfin/conftest.py @@ -0,0 +1,140 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import logging +import os +from pathlib import Path +import pytest +import shutil + +pytest.importorskip("transformers") +from ..utils import ( + download_huggingface_model, + download_tokenizer, + export_paged_llm_v1, + compile_model, + find_available_port, + start_llm_server, + start_log_group, + end_log_group, +) + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def model_test_dir(request, tmp_path_factory): + """Prepare model artifacts for starting the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - repo_id (str): The Hugging Face repo ID. + - model_file (str): The model file to download. + - tokenizer_id (str): The tokenizer ID to download. + - settings (dict): The settings for sharktank export. + - batch_sizes (list): The batch sizes to use for the model. + tmp_path_factory (TempPathFactory): Temp dir to save artifacts to. + + Yields: + Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir. + """ + logger.info( + "Preparing model artifacts..." 
+ start_log_group("Preparing model artifacts") + ) + + repo_id = request.param["repo_id"] + model_file = request.param["model_file"] + tokenizer_id = request.param["tokenizer_id"] + settings = request.param["settings"] + batch_sizes = request.param["batch_sizes"] + + tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") + hf_home = os.environ.get("HF_HOME", None) + hf_home = Path(hf_home) if hf_home is not None else tmp_dir + try: + # Download model if it doesn't exist + model_path = hf_home / model_file + download_huggingface_model(hf_home, repo_id, model_file) + + # Set up tokenizer if it doesn't exist + download_tokenizer(hf_home, tokenizer_id) + + # Export model + mlir_path = tmp_dir / "model.mlir" + config_path = tmp_dir / "config.json" + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) + + # Compile model + vmfb_path = tmp_dir / "model.vmfb" + compile_model(mlir_path, vmfb_path, settings) + + # Write config + edited_config_path = tmp_dir / "edited_config.json" + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 2048, + "attn_head_count": 32, + "attn_head_dim": 100, + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "transformer_block_count": 26, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + logger.info(f"Saving edited config to: {edited_config_path}\n") + logger.info(f"Config: {json.dumps(config, indent=2)}") + with open(edited_config_path, "w") as f: + json.dump(config, f) + logger.info("Model artifacts setup successfully" + end_log_group()) + yield hf_home, tmp_dir + finally: + shutil.rmtree(tmp_dir) + + +@pytest.fixture(scope="module") +def available_port(): + return find_available_port() + + +@pytest.fixture(scope="module") +def llm_server(request, model_test_dir, available_port): + """Start the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - model_file (str): The model file to download. + - settings (dict): The settings for starting the server. + model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir. + available_port (int): The available port to start the server on. + + Yields: + subprocess.Popen: The server process that was started. + """ + logger.info("Starting LLM server..." + start_log_group("Starting LLM server")) + hf_home, tmp_dir = model_test_dir + model_file = request.param["model_file"] + settings = request.param["settings"] + + tokenizer_path = hf_home / "tokenizer.json" + config_path = tmp_dir / "edited_config.json" + vmfb_path = tmp_dir / "model.vmfb" + parameters_path = hf_home / model_file + + # Start llm server + server_process = start_llm_server( + available_port, + tokenizer_path, + config_path, + vmfb_path, + parameters_path, + settings, + ) + logger.info("LLM server started!" + end_log_group()) + yield server_process + # Teardown: kill the server + server_process.terminate() + server_process.wait() diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py new file mode 100644 index 000000000..c4da9e4eb --- /dev/null +++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py @@ -0,0 +1,104 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import os +import pytest +import requests +import uuid + +from ..utils import AccuracyValidationException, start_log_group, end_log_group + +logger = logging.getLogger(__name__) + +CPU_SETTINGS = { + "device_flags": [ + "-iree-hal-target-backends=llvm-cpu", + "--iree-llvmcpu-target-cpu=host", + ], + "device": "local-task", +} +IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100") +gpu_settings = { + "device_flags": [ + "-iree-hal-target-backends=rocm", + f"--iree-hip-target={IREE_HIP_TARGET}", + ], + "device": "hip", +} + + +def do_generate(prompt, port): + logger.info("Generating request...") + headers = {"Content-Type": "application/json"} + # Create a GenerateReqInput-like structure + data = { + "text": prompt, + "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7}, + "rid": uuid.uuid4().hex, + "return_logprob": False, + "logprob_start_len": -1, + "top_logprobs_num": 0, + "return_text_in_logprobs": False, + "stream": False, + } + logger.info("Prompt text:") + logger.info(data["text"]) + BASE_URL = f"http://localhost:{port}" + response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data) + logger.info(f"Generate endpoint status code: {response.status_code}") + if response.status_code == 200: + logger.info("Generated text:") + data = response.text + assert data.startswith("data: ") + data = data[6:] + assert data.endswith("\n\n") + data = data[:-2] + return data + else: + response.raise_for_status() + + +@pytest.mark.parametrize( + "model_test_dir,llm_server", + [ + ( + { + "repo_id": "SlyEcho/open_llama_3b_v2_gguf", + "model_file": "open-llama-3b-v2-f16.gguf", + "tokenizer_id": "openlm-research/open_llama_3b_v2", + "settings": CPU_SETTINGS, + "batch_sizes": [1, 4], + }, + {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS}, + ) + ], + indirect=True, +) +def test_llm_server(llm_server, available_port): + # Here you would typically make requests to your server + # and assert on the responses + assert llm_server.poll() is None + PROMPT = "1 2 3 4 5 " + expected_output_prefix = "6 7 8" + logger.info( + "Sending HTTP Generation Request" + + start_log_group("Sending HTTP Generation Request") + ) + output = do_generate(PROMPT, available_port) + # log to GITHUB_STEP_SUMMARY if we are in a GitHub Action + if "GITHUB_ACTION" in os.environ: + with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: + # log prompt + f.write("LLM results:\n") + f.write(f"- llm_prompt:`{PROMPT}`\n") + f.write(f"- llm_output:`{output}`\n") + logger.info(output) + if not output.startswith(expected_output_prefix): + raise AccuracyValidationException( + f"Expected '{output}' to start with '{expected_output_prefix}'" + ) + logger.info("HTTP Generation Request Successful" + end_log_group()) diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py new file mode 100644 index 000000000..05712039e --- /dev/null +++ b/app_tests/integration_tests/llm/utils.py @@ -0,0 +1,220 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import multiprocessing +import os +from pathlib import Path +import subprocess +import sys +import time + +import requests +from transformers import AutoTokenizer + +logger = logging.getLogger("__name__") + + +class AccuracyValidationException(RuntimeError): + pass + + +def download_huggingface_model(local_dir, repo_id, model_file): + model_path = local_dir / model_file + logger.info(f"Preparing model_path: {model_path}..") + if not os.path.exists(model_path): + logger.info(f"Downloading model {repo_id} {model_file} from Hugging Face...") + subprocess.run( + f"huggingface-cli download --local-dir {local_dir} {repo_id} {model_file}", + shell=True, + check=True, + ) + logger.info(f"Model downloaded to {model_path}") + else: + logger.info("Using cached model") + + +def download_with_hf_datasets(local_dir: Path | str, model_name: str): + """Download a model using `sharktank.utils.hf_datasets` script. + + Args: + local_dir (Path | str): Local directory to download model to. + model_name (str): Name of model to download. + """ + if isinstance(local_dir, Path): + local_dir = str(local_dir) + + logger.info(f"Download model {model_name} with `hf_datasets` to {local_dir}...") + subprocess.run( + [ + "python", + "-m", + "sharktank.utils.hf_datasets", + model_name, + "--local-dir", + local_dir, + ], + check=True, + ) + logger.info(f"Model {model_name} successfully downloaded.") + + +def download_tokenizer(local_dir, tokenizer_id): + # Set up tokenizer if it doesn't exist + tokenizer_path = local_dir / "tokenizer.json" + logger.info(f"Preparing tokenizer_path: {tokenizer_path}...") + if not os.path.exists(tokenizer_path): + logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...") + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_id, + ) + tokenizer.save_pretrained(local_dir) + logger.info(f"Tokenizer saved to {tokenizer_path}") + else: + logger.info("Using cached tokenizer") + + +def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes): + bs_string = ",".join(map(str, batch_sizes)) + logger.info( + "Exporting model with following settings:\n" + f" MLIR Path: {mlir_path}\n" + f" Config Path: {config_path}\n" + f" Batch Sizes: {bs_string}" + ) + subprocess.run( + [ + "python", + "-m", + "sharktank.examples.export_paged_llm_v1", + f"--{model_path.suffix.strip('.')}-file={model_path}", + f"--output-mlir={mlir_path}", + f"--output-config={config_path}", + f"--bs={bs_string}", + ], + check=True, + ) + logger.info(f"Model successfully exported to {mlir_path}") + + +def compile_model(mlir_path, vmfb_path, device_settings): + logger.info(f"Compiling model to {vmfb_path}") + subprocess.run( + [ + "iree-compile", + mlir_path, + "-o", + vmfb_path, + ] + + device_settings["device_flags"], + check=True, + ) + logger.info(f"Model successfully compiled to {vmfb_path}") + + +def find_available_port(): + import socket + from contextlib import closing + + logger.info(f"Finding available port...") + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + logger.info(f"Found available port: {port}") + return port + + +def wait_for_server(url, timeout=10): + logger.info(f"Waiting for server to start at {url}...") + start = time.time() + while time.time() - start < timeout: + try: + requests.get(f"{url}/health") + logger.info("Server successfully started") + return + except 
requests.exceptions.ConnectionError: + time.sleep(1) + raise TimeoutError(f"Server did not start within {timeout} seconds") + + +def _start_llm_server_args( + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + port, +): + return [ + sys.executable, + "-m", + "shortfin_apps.llm.server", + f"--tokenizer_json={tokenizer_path}", + f"--model_config={model_config_path}", + f"--vmfb={vmfb_path}", + f"--parameters={parameters_path}", + f"--device={settings['device']}", + f"--port={port}", + ] + + +def start_llm_server( + port, + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + timeout=10, + multi=False, +): + logger.info("Starting LLM server...") + if multi: + server_process = multiprocessing.Process( + target=subprocess.Popen( + _start_llm_server_args( + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + port, + ), + ) + ) + server_process.start() + + else: + # Start the server + server_process = subprocess.Popen( + _start_llm_server_args( + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + port, + ) + ) + logger.info("Process started... waiting for server") + # Wait for server to start + wait_for_server(f"http://localhost:{port}", timeout) + return server_process + + +def start_log_group(headline): + # check if we are in github ci + if os.environ.get("GITHUB_ACTIONS") == "true": + return f"\n::group::{headline}" + return "" + + +def end_log_group(): + # check if we are in github ci + if os.environ.get("GITHUB_ACTIONS") == "true": + return "\n::endgroup::" + return "" diff --git a/build_tools/integration_tests/llama_export_compile_serve.sh b/build_tools/integration_tests/llama_export_compile_serve.sh deleted file mode 100755 index edd54b688..000000000 --- a/build_tools/integration_tests/llama_export_compile_serve.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/bin/bash -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -set -xeuo pipefail - -# Assume that the environment is already set up: -# * Python venv set up with requirements, sharktank, and shortfin -# * iree-compile and iree-run-module on $PATH -# * authenticated with `huggingface-cli login` - -# Input variables. -# Default model: https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf -# Default tokenizer: https://huggingface.co/openlm-research/open_llama_3b_v2 -TEMP_DIR="${TEMP_DIR:-/tmp/sharktank/llama}" -HUGGING_FACE_MODEL_NAME="${HUGGING_FACE_MODEL_NAME:-SlyEcho/open_llama_3b_v2_gguf}" -HUGGING_FACE_MODEL_FILE="${HUGGING_FACE_MODEL_FILE:-open-llama-3b-v2-f16.gguf}" -HUGGING_FACE_TOKENIZER="${HUGGING_FACE_TOKENIZER:-openlm-research/open_llama_3b_v2}" - -# Derived variables. 
-LOCAL_GGUF_FILE="${TEMP_DIR}/${HUGGING_FACE_MODEL_FILE}" -LOCAL_MLIR_FILE="${TEMP_DIR}/model.mlir" -LOCAL_CONFIG_FILE="${TEMP_DIR}/config.json" -LOCAL_VMFB_FILE="${TEMP_DIR}/model.vmfb" - -mkdir -p ${TEMP_DIR} - -huggingface-cli download --local-dir ${TEMP_DIR} ${HUGGING_FACE_MODEL_NAME} ${HUGGING_FACE_MODEL_FILE} - -python -m sharktank.examples.export_paged_llm_v1 \ - --gguf-file="${LOCAL_GGUF_FILE}" \ - --output-mlir="${LOCAL_MLIR_FILE}" \ - --output-config="${LOCAL_CONFIG_FILE}" - -iree-compile "${LOCAL_MLIR_FILE}" \ - --iree-hal-target-backends=llvm-cpu \ - --iree-llvmcpu-target-cpu-features=host \ - -o ${LOCAL_VMFB_FILE} - -python -m shortfin.llm.impl.service_v1_cli \ - --tokenizer="${HUGGING_FACE_TOKENIZER}" \ - --config="${LOCAL_CONFIG_FILE}" \ - --vmfb="${LOCAL_VMFB_FILE}" \ - --gguf="${LOCAL_GGUF_FILE}" diff --git a/build_tools/python_deploy/README.md b/build_tools/python_deploy/README.md new file mode 100644 index 000000000..d36545a9c --- /dev/null +++ b/build_tools/python_deploy/README.md @@ -0,0 +1,48 @@ +# Python Deployment + +These scripts assist with building Python packages and pushing them to +[PyPI (the Python Package Index)](https://pypi.org/). See also + +* The Python Packaging User Guide: + +## Overview + +See comments in scripts for canonical usage. This page includes additional +notes. + +### Package building + +These scripts build packages: + +* [`/shark-ai/build_tools/build_linux_package.sh`](/shark-ai/build_tools/build_linux_package.sh) +* [`/sharktank/build_tools/build_linux_package.sh`](/sharktank/build_tools/build_linux_package.sh) +* [`/shortfin/build_tools/build_linux_package.sh`](/shortfin/build_tools/build_linux_package.sh) + +### Version management + +These scripts handle versioning across packages, including considerations like +major, minor, and patch levels (`X.Y.Z`), as well as suffixes like +`rc20241107`: + +* [`compute_common_version.py`](./compute_common_version.py) +* [`compute_local_version.py`](./compute_local_version.py) +* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py) +* [`write_requirements.py`](./write_requirements.py) + +### PyPI deployment + +These scripts handle promoting nightly releases packages to stable and pushing +to PyPI: + +* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py) +* [`pypi_deploy.sh`](./pypi_deploy.sh) + +Both of these scripts expect to have the dependencies from +[`requirements-pypi-deploy.txt`](./requirements-pypi-deploy.txt) installed. +This can be easily managed by using a Python virtual environment: + +```bash +python -m venv .venv +source .venv/bin/activate +python -m pip install -r ./requirements-pypi-deploy.txt +``` diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py new file mode 100755 index 000000000..aa193bcc1 --- /dev/null +++ b/build_tools/python_deploy/compute_common_version.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This scripts grabs the `X.Y.Z[.dev]` version identifier from the +# 'sharktank' and 'shortfin' version files and computes the version +# for the meta 'shark-ai' package. 
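+#
+# For example (illustrative values, not read from this repository): if
+# sharktank reports base version 3.0.0 and shortfin reports 2.9.1, the
+# common version is 3.0.0; with `-rc` and a date of 20241119 the script
+# would emit 3.0.0rc20241119.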
+# +# Usage: +# ./compute_common_version.py --stable-release --write-json +# cat ../../shark-ai/version_local.json + +import argparse +from pathlib import Path +import json +from datetime import datetime +import sys + +from packaging.version import Version + + +parser = argparse.ArgumentParser() +parser.add_argument("--write-json", action="store_true") +parser.add_argument("--version-suffix", action="store", type=str) + +release_type = parser.add_mutually_exclusive_group() +release_type.add_argument("-stable", "--stable-release", action="store_true") # default +release_type.add_argument("-rc", "--nightly-release", action="store_true") + + +args = parser.parse_args() + +if not (args.stable_release or args.nightly_release): + parser.print_usage(sys.stderr) + sys.stderr.write("error: A release type is required\n") + sys.exit(1) + +if args.stable_release and args.version_suffix: + sys.stderr.write("error: A version suffix is only supported for stable releases\n") + sys.exit(1) + +THIS_DIR = Path(__file__).parent.resolve() +REPO_ROOT = THIS_DIR.parent.parent + +VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version.json" +VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version.json" +VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +def write_version_info(): + with open(VERSION_FILE_LOCAL, "w") as f: + json.dump(version_local, f, indent=2) + f.write("\n") + + +sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) +SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") +SHARKTANK_BASE_VERSION = Version(SHARKTANK_PACKAGE_VERSION).base_version + +shortfin_version = load_version_info(VERSION_FILE_SHORTFIN) +SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version") +SHORTFIN_BASE_VERSION = Version(SHORTFIN_PACKAGE_VERSION).base_version + +if SHARKTANK_BASE_VERSION > SHORTFIN_BASE_VERSION: + COMMON_VERSION = SHARKTANK_BASE_VERSION +else: + COMMON_VERSION = SHORTFIN_BASE_VERSION + +if args.nightly_release: + if args.version_suffix: + VERSION_SUFFIX = args.version_suffix + else: + VERSION_SUFFIX = "rc" + datetime.today().strftime("%Y%m%d") + + COMMON_VERSION += VERSION_SUFFIX + +if args.write_json: + version_local = {"package-version": COMMON_VERSION} + write_version_info() + +print(COMMON_VERSION) diff --git a/build_tools/python_deploy/compute_local_version.py b/build_tools/python_deploy/compute_local_version.py new file mode 100755 index 000000000..0465fa443 --- /dev/null +++ b/build_tools/python_deploy/compute_local_version.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This scripts grabs the X.Y.Z[.dev]` version identifier from a +# `version.json` and writes the corresponding +# `X.Y.ZrcYYYYMMDD` version identifier to `version_local.json`. 
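+#
+# For example (illustrative): given a version.json containing
+#   {"package-version": "3.0.0.dev"}
+# a run on 2024-11-19 without --version-suffix writes a version_local.json
+# containing
+#   {"package-version": "3.0.0rc20241119"}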
+ +import argparse +from pathlib import Path +import json +from datetime import datetime + +from packaging.version import Version + + +parser = argparse.ArgumentParser() +parser.add_argument("path", type=Path) +parser.add_argument("--version-suffix", action="store", type=str) +args = parser.parse_args() + +VERSION_FILE = args.path / "version.json" +VERSION_FILE_LOCAL = args.path / "version_local.json" + + +def load_version_info(): + with open(VERSION_FILE, "rt") as f: + return json.load(f) + + +def write_version_info(): + with open(VERSION_FILE_LOCAL, "w") as f: + json.dump(version_local, f, indent=2) + f.write("\n") + + +version_info = load_version_info() + +if args.version_suffix: + VERSION_SUFFIX = args.version_suffix +else: + VERSION_SUFFIX = "rc" + datetime.today().strftime("%Y%m%d") + +PACKAGE_VERSION = version_info.get("package-version") +PACKAGE_BASE_VERSION = Version(PACKAGE_VERSION).base_version +PACKAGE_LOCAL_VERSION = PACKAGE_BASE_VERSION + VERSION_SUFFIX + +version_local = {"package-version": PACKAGE_LOCAL_VERSION} + +write_version_info() + +print(PACKAGE_LOCAL_VERSION) diff --git a/build_tools/python_deploy/promote_whl_from_rc_to_final.py b/build_tools/python_deploy/promote_whl_from_rc_to_final.py new file mode 100755 index 000000000..061dd933b --- /dev/null +++ b/build_tools/python_deploy/promote_whl_from_rc_to_final.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This scripts takes a file like 'sharktank-2.9.0rc20241110-py3-none-any.whl' +# with embedded version '2.9.0rc20241110' as input and then drops the +# 'rcYYYYMMDD' suffix from both the embedded version and file name. 
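+#
+# For example (illustrative): from 'sharktank-2.9.0rc20241110-py3-none-any.whl'
+# a new wheel 'sharktank-2.9.0-py3-none-any.whl' is produced with its embedded
+# version set to '2.9.0'; the original wheel is removed only when
+# --delete-old-wheel is passed.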
+# +# Typical usage: +# pip install -r requirements-pypi-deploy.txt +# ./promote_whl_from_rc_to_final.py /path/to/file.whl --delete-old-wheel + +import argparse +from change_wheel_version import change_wheel_version +from packaging.version import Version +from pathlib import Path +from pkginfo import Wheel + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_file", + help="Path to the input .whl file to promote", + type=Path, + ) + parser.add_argument( + "--delete-old-wheel", + help="Deletes the original wheel after successfully promoting it", + action="store_true", + default=False, + ) + return parser.parse_args() + + +def main(args): + original_wheel_path = args.input_file + print(f"Promoting whl from rc to final: '{original_wheel_path}'") + + original_wheel = Wheel(original_wheel_path) + original_version = Version(original_wheel.version) + base_version = original_version.base_version + print( + f" Original wheel version is '{original_version}' with base '{base_version}'" + ) + + if str(base_version) == str(original_version): + print(" Version is already a release version, skipping") + return + + print(f" Changing to base version: '{base_version}'") + new_wheel_path = change_wheel_version(original_wheel_path, str(base_version), None) + print(f" New wheel path is '{new_wheel_path}'") + + new_wheel = Wheel(new_wheel_path) + new_version = Version(new_wheel.version) + print(f" New wheel version is '{new_version}'") + + if args.delete_old_wheel: + print(" Deleting original wheel") + original_wheel_path.unlink() + + +if __name__ == "__main__": + main(parse_arguments()) diff --git a/build_tools/python_deploy/pypi_deploy.sh b/build_tools/python_deploy/pypi_deploy.sh new file mode 100755 index 000000000..63f123ac0 --- /dev/null +++ b/build_tools/python_deploy/pypi_deploy.sh @@ -0,0 +1,126 @@ +#!/bin/bash + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This script promotes Python packages from nightly releases to PyPI. +# +# Prerequisites: +# * You will need to have PyPI credentials set up. See +# https://packaging.python.org/en/latest/tutorials/packaging-projects/#uploading-the-distribution-archives +# * Install requirements, e.g. in a Python virtual environment (venv): +# `pip install -r requirements-pypi-deploy.txt` +# * Install python3.13t and install pip. On Ubuntu: +# ```bash +# sudo add-apt-repository ppa:deadsnakes +# sudo apt-get update +# sudo apt-get install python3.13-nogil +# python3.13t -m ensurepip --upgrade +# ``` +# * Choose a release candidate to promote from +# https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels +# +# Usage: +# ./pypi_deploy.sh 2.9.0rc20241108 + +set -euo pipefail + +RELEASE="$1" + +SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")"; +REPO_ROOT="$(cd "$SCRIPT_DIR"/../../ && pwd)" +TMPDIR="$(mktemp --directory --tmpdir shark_platform_pypi_wheels.XXXXX)" +ASSETS_PAGE="https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels" + +# TODO: rewrite in Python? + +function download_wheels() { + echo "" + echo "Downloading wheels for '${RELEASE}'..." 
+ + # sharktank + python -m pip download sharktank==${RELEASE} \ + --no-deps --python-version 3.11 -f ${ASSETS_PAGE} + + # shortfin + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.11 -f ${ASSETS_PAGE} + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.12 -f ${ASSETS_PAGE} + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.13 -f ${ASSETS_PAGE} + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.13 -f ${ASSETS_PAGE} + # TODO: fetch 3.13t using the same `python` somehow + # * https://pip.pypa.io/en/stable/cli/pip_download/ + # * https://py-free-threading.github.io/installing_cpython/ + # * https://pip.pypa.io/en/stable/installation/ + python3.13t -m pip download shortfin==${RELEASE} --no-deps -f ${ASSETS_PAGE} + + # TODO: shark-ai meta package when it is published to nightlies + + echo "" + echo "Downloaded wheels:" + ls +} + +function edit_release_versions() { + echo "" + echo "Editing release versions..." + for file in * + do + ${SCRIPT_DIR}/promote_whl_from_rc_to_final.py ${file} --delete-old-wheel + done + + echo "Edited wheels:" + ls +} + +function upload_wheels() { + # TODO: list packages that would be uploaded, pause, prompt to continue + echo "" + echo "Uploading wheels:" + ls + twine upload --verbose * +} + +function build_shark_ai_meta_package() { + # TODO: download meta package from nightly releases instead of this + # Be aware that nightly releases pin other dependencies via the + # generated `requirements.txt` compared to stable releases. + echo "" + + # TODO: rework `write_requirements.py` to use the versions from the downloaded whls? + echo "Computing local versions for sharktank and shortfin..." + ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/sharktank + ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/shortfin + + echo "Computing common version for shark-ai meta package..." + ${SCRIPT_DIR}/compute_common_version.py --stable-release --write-json + + echo "Writing requirements for shark-ai meta package..." + ${SCRIPT_DIR}/write_requirements.py + + echo "Building shark-ai meta package..." + ${REPO_ROOT}/shark-ai/build_tools/build_linux_package.sh + + # TODO: This is error-prone. We only want to publish the whl for this release. + # Copy instead? Specify exact file name? Clear directory before building? + mv ${REPO_ROOT}/shark-ai/build_tools/wheelhouse/* . +} + +function main() { + echo "Changing into ${TMPDIR}" + cd "${TMPDIR}" + # TODO: check_requirements (using pip) + + download_wheels + edit_release_versions + build_shark_ai_meta_package + upload_wheels +} + +main diff --git a/build_tools/python_deploy/requirements-pypi-deploy.txt b/build_tools/python_deploy/requirements-pypi-deploy.txt new file mode 100644 index 000000000..dcc32d47a --- /dev/null +++ b/build_tools/python_deploy/requirements-pypi-deploy.txt @@ -0,0 +1,4 @@ +change_wheel_version +packaging +pkginfo +twine diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py new file mode 100755 index 000000000..224e01fd0 --- /dev/null +++ b/build_tools/python_deploy/write_requirements.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This script writes the `packaging/shark-ai/requirements.txt` file and pins +# the versions of the dependencies accordingly. For nightly releases, +# * sharktank +# * shortfin +# get pinned to the corresponding nightly version. The IREE packages are +# unpinned. For stable releases, +# * iree-base-compiler +# * iree-base-runtime +# * iree-turbine +# * sharktank +# * shortfin +# get pinned to the corresponding `X.Y.*` version. + +import argparse +from pathlib import Path +import json + +from packaging.version import Version + + +parser = argparse.ArgumentParser() +parser.add_argument("--version-suffix", action="store", type=str) + +args = parser.parse_args() + + +THIS_DIR = Path(__file__).parent.resolve() +REPO_ROOT = THIS_DIR.parent.parent + +VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_local.json" +VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_local.json" +VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" +REQUIREMENTS_TXT = REPO_ROOT / "shark-ai/requirements.txt" + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +def write_requirements(requirements): + with open(REQUIREMENTS_TXT, "w") as f: + f.write("%s\n" % requirements) + + +metapackage_version = load_version_info(VERSION_FILE_LOCAL) +PACKAGE_VERSION = metapackage_version.get("package-version") + +# sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) +# SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") + +shortfin_version = load_version_info(VERSION_FILE_SHORTFIN) +SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version") + +stable_packages_list = ["iree-base-compiler", "iree-base-runtime", "iree-turbine"] + +if Version(PACKAGE_VERSION).is_prerelease: + requirements = "" + for package in stable_packages_list: + requirements += package + "\n" + # TODO: Include sharktank as a dependencies of future releases + # requirements = ( + # "sharktank==" + # + Version(SHARKTANK_PACKAGE_VERSION).base_version + # + args.version_suffix + # + "\n" + # ) + requirements += ( + "shortfin==" + + Version(SHORTFIN_PACKAGE_VERSION).base_version + + args.version_suffix + ) + + write_requirements(requirements) + +else: + MAJOR_VERSION = Version(PACKAGE_VERSION).major + MINOR_VERSION = Version(PACKAGE_VERSION).minor + + STABLE_VERSION_TO_PIN = str(MAJOR_VERSION) + "." + str(MINOR_VERSION) + ".*" + + requirements = "" + for package in stable_packages_list: + requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n" + # TODO: Include sharktank as a dependencies of future releases + # requirements += ( + # "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n" + # ) + requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version + + write_requirements(requirements) diff --git a/docs/developer_guide.md b/docs/developer_guide.md new file mode 100644 index 000000000..73aee61f7 --- /dev/null +++ b/docs/developer_guide.md @@ -0,0 +1,118 @@ +# SHARK Developer Guide + +Each sub-project has its own developer guide. If you would like to work across +projects, these instructions should help you get started: + + +### Install Dependencies + +Install shortfin dependencies +```bash +sudo apt update && sudo apt install -y clang lld +``` + +### Prepare your python environment + +Install: + +```bash +sudo apt install python-is-python3 python3-venv python3-dev +``` + +
+ + Or, alternatively, use `pyenv` to manage a separate python installation for more control over its version: + + +The following instructions are taken from pyenv's guide here: https://github.com/pyenv/pyenv?tab=readme-ov-file#a-getting-pyenv + +First, install pyenv and its dependencies. + +```bash +sudo apt update; sudo apt install build-essential libssl-dev zlib1g-dev \ +libbz2-dev libreadline-dev libsqlite3-dev curl git \ +libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev +curl https://pyenv.run | bash +``` + +Then, make pyenv available by adding the below to your `~/.bashrc`: + +```bash +export PYENV_ROOT="$HOME/.pyenv" +command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" +eval "$(pyenv init -)" +``` + +Finally, install a pyenv-managed version of python + +```bash +pyenv install 3.12 # or whichever python version you'd like +pyenv local 3.12 +``` + +Now, your python, pip, and venv should be managed by pyenv instead. + +
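+
+Before moving on, it can be worth double-checking which interpreter is actually
+first on your `PATH`. The snippet below is only a suggested sanity check; the
+`(3, 11)` floor is an assumption based on the Python versions recommended
+elsewhere in these docs, so adjust it to whichever version you installed.
+
+```python
+# Print the active interpreter and fail loudly if it is older than expected.
+import sys
+
+print(sys.executable)
+print(sys.version)
+assert sys.version_info >= (3, 11), "Unexpectedly old Python on PATH"
+```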
+ +### Setup a venv + +We recommend setting up a Python +[virtual environment (venv)](https://docs.python.org/3/library/venv.html). +The project is configured to ignore `.venv` directories, and editors like +VSCode pick them up by default. + +```bash +python -m venv .venv +source .venv/bin/activate +``` + +### Install PyTorch for your system + +If no explicit action is taken, the default PyTorch version will be installed. +This will give you a current CUDA-based version, which takes longer to download +and includes other dependencies that SHARK does not require. To install a +different variant, run one of these commands first: + +* *CPU:* + + ```bash + pip install -r pytorch-cpu-requirements.txt + ``` + +* *ROCM:* + + ```bash + pip install -r pytorch-rocm-requirements.txt + ``` + +* *Other:* see instructions at . + +### Install development packages + +```bash +# Install editable local projects. +pip install -r requirements.txt -e sharktank/ shortfin/ + +# Optionally clone and install the latest editable iree-turbine dep in deps/, +# along with nightly versions of iree-base-compiler and iree-base-runtime. +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" +``` + +See also: [nightly_releases.md](nightly_releases.md). + +### Running tests + +```bash +pip install -r shortfin/requirements-tests.txt +pytest sharktank +pytest shortfin +pytest app_tests/integration_tests +``` + +### Optional: pre-commits and developer settings + +This project is set up to use the `pre-commit` tooling. To install it in +your local repo, run: `pre-commit install`. After this point, when making +commits locally, hooks will run. See https://pre-commit.com/ diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index 64137956d..ddc0cb3bb 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -1,7 +1,7 @@ # Model cookbook -Note: These are early notes and commands that the sharktank team is using and -will turn into proper docs later. +Note: These are early notes and commands that the shark-ai team is using +and will turn into proper docs later. ## Diagrams @@ -165,18 +165,13 @@ tokenizer_config.json: 100%|█████████████████ Setup (from [README.md](../README.md)): -* TODO: this could be replaced with `pip install iree-turbine` or - `pip install sharktank` at some point. For now these are dev packages. - ```bash # Setup venv. python -m venv --prompt sharktank .venv source .venv/bin/activate -# Install requirements. +# (Optional) Install PyTorch for CPU only, to save on download time. pip install -r pytorch-cpu-requirements.txt -pip install -f https://iree.dev/pip-release-links.html --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" # Install local projects. 
pip install -r requirements.txt -e sharktank/ shortfin/ @@ -257,6 +252,21 @@ iree-run-module \ --parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf ``` +## Evaluation pipeline + +Run perplexity test: + +```bash +pytest sharktank/tests/evaluate/perplexity_test.py --longrun +``` + +Run perplexity for a new model: +```bash +python -m sharktank.evaluate.perplexity \ + --gguf-file=llama8b_f16.gguf \ + --tokenizer-config-json=tokenizer_config.json +``` + ## Generating data for llama models ```bash diff --git a/docs/nightly_releases.md b/docs/nightly_releases.md new file mode 100644 index 000000000..f41374445 --- /dev/null +++ b/docs/nightly_releases.md @@ -0,0 +1,213 @@ +# Nightly releases + +Nightly releases are uploaded to +https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels. + +* The "expanded_assets" version of a release page is compatible with the + `-f, --find-links ` options of `pip install` + ([docs here](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-f)). + For the "dev-wheels" release above, that page is: + +* These releases are generated using + [`.github/workflows/build_package.yml`](../.github/workflows/build_packages.yml) +* That workflow runs the + [`sharktank/build_tools/build_linux_package.sh`](../sharktank/build_tools/build_linux_package.sh) + and +[`shortfin/build_tools/build_linux_package.sh`](../shortfin/build_tools/build_linux_package.sh) + scripts +* Workflow history can be viewed at + + +## Prerequisites + +### Operating system + +Currently we only officially support Linux with published packages. Windows and +macOS support is possible, but may need additional setup, code changes, and +source builds. + +### Python + +You will need a recent version of Python. + +* As of Nov 1, 2024, sharktank is compatible with Python 3.11. See + https://github.com/nod-ai/shark-ai/issues/349 for Python 3.12 support. +* As of Nov 4, 2024, shortfin publishes packages for Python 3.11, 3.12, 3.13, + and 3.13t + +For example, to install Python 3.11 on Ubuntu: + +```bash +sudo apt install python3.11 python3.11-dev python3.11-venv + +which python3.11 +# /usr/bin/python3.11 +``` + +> [!TIP] +> Manage multiple Python versions using `pyenv` +> (), or the +> [Python Launcher for Windows](https://docs.python.org/3/using/windows.html#python-launcher-for-windows) +> on Windows. + +## Quickstart - sharktank + +```bash +# Set up a virtual environment to isolate packages from other envs. +python3.11 -m venv 3.11.venv +source 3.11.venv/bin/activate + +# Install 'sharktank' package from nightly releases. +pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels + +# Test the installation. +python -c "from sharktank import ops; print('Sanity check passed')" + +# Deactivate the virtual environment when done. +deactivate +``` + +## Quickstart - shortfin + +```bash +# Set up a virtual environment to isolate packages from other envs. +python3.11 -m venv 3.11.venv +source 3.11.venv/bin/activate + +# Install 'shortfin' package from nightly releases. +pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels + +# Test the installation. +python -c "import shortfin as sf; print('Sanity check passed')" + +# Deactivate the virtual environment when done. 
+deactivate +``` + +## Installing newer versions of dependencies + +To install the `iree-turbine` package from the latest source: + +```bash +pip install --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" +``` + +To install the `iree-base-compiler` and `iree-base-runtime` packages from +nightly releases: + +```bash +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime +``` + +To install all three packages together: + +```bash +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" +``` + +## Switching between stable and nightly channels + +The [`shark-ai` package on PyPI](https://pypi.org/project/shark-ai/) is a +meta-package that pins specific stable versions of each package that share +at least their major and minor versions: + +```bash +pip install shark-ai==2.9.1 + +pip freeze +# ... +# iree-base-compiler==2.9.0 +# iree-base-runtime==2.9.0 +# iree-turbine==2.9.0 +# ... +# shark-ai==2.9.1 +# shortfin==2.9.1 +# ... +``` + +If you attempt to update any individual package outside of those supported +versions, pip will log an error but continue anyway: + +```bash +pip install --upgrade --pre \ + -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels \ + shortfin==3.0.0rc20241118 + +# Looking in links: https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +# Collecting shortfin==3.0.0rc20241118 +# Downloading https://github.com/nod-ai/shark-ai/releases/download/dev-wheels/shortfin-3.0.0rc20241118-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 24.3 MB/s eta 0:00:00 +# Installing collected packages: shortfin +# Attempting uninstall: shortfin +# Found existing installation: shortfin 2.9.1 +# Uninstalling shortfin-2.9.1: +# Successfully uninstalled shortfin-2.9.1 +# ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. +# shark-ai 2.9.1 requires shortfin==2.9.1, but you have shortfin 3.0.0rc20241118 which is incompatible. +# Successfully installed shortfin-3.0.0rc20241118 + +pip freeze +# ... +# shark-ai==2.9.1 +# shortfin==3.0.0rc20241118 +# ... +``` + +Installing the `shark-ai` package again should get back to aligned versions: + +```bash +pip install shark-ai==2.9.1 +# ... +# Installing collected packages: shortfin +# Attempting uninstall: shortfin +# Found existing installation: shortfin 3.0.0rc20241118 +# Uninstalling shortfin-3.0.0rc20241118: +# Successfully uninstalled shortfin-3.0.0rc20241118 +# Successfully installed shortfin-2.9.1 + +pip freeze +# ... +# shark-ai==2.9.1 +# shortfin==2.9.1 +# ... +``` + +You can also uninstall the `shark-ai` package to bypass the error and take full +control of package versions yourself: + +```bash +pip uninstall shark-ai + +pip freeze +# ... +# (note: no shark-ai package) +# shortfin==2.9.1 +# ... 
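+
+# With the shark-ai meta package uninstalled, pip no longer reports a version
+# conflict when a nightly shortfin build replaces the stable one: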
+ +pip install --upgrade --pre \ + -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels \ + shortfin==3.0.0rc20241118 + +# Looking in links: https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +# Collecting shortfin==3.0.0rc20241118 +# Using cached https://github.com/nod-ai/shark-ai/releases/download/dev-wheels/shortfin-3.0.0rc20241118-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB) +# Installing collected packages: shortfin +# Attempting uninstall: shortfin +# Found existing installation: shortfin 2.9.1 +# Uninstalling shortfin-2.9.1: +# Successfully uninstalled shortfin-2.9.1 +# Successfully installed shortfin-3.0.0rc20241118 + +pip freeze +# ... +# (note: no shark-ai package) +# shortfin==3.0.0rc20241118 +# ... +``` + +If you ever get stuck, consider creating a fresh +[virtual environment](https://docs.python.org/3/library/venv.html). diff --git a/docs/quantization.md b/docs/quantization.md index 0563e8108..25bfc9f8d 100644 --- a/docs/quantization.md +++ b/docs/quantization.md @@ -4,10 +4,10 @@ author: Stella Laurenzo date: June 30, 2024 --- -# Direct Quantization with sharktank +# Direct Quantization with SHARK Tank As a toolkit for building and adapting PyTorch based models for deployment, -sharktank provides rich quantization support. By targeting the +SHARK Tank provides rich quantization support. By targeting the [IREE compiler](https://github.com/iree-org/iree) for optimizations, we can strike a balance with our quantization setup that: @@ -36,7 +36,7 @@ supports these indirect schemes -- effectively using compiler transformations under the covers to do opaque model transformations that mirror a subset of what is exposed directly to the user in the rest of this document. -As an alternative, when developing sharktank and bringing up the initial +As an alternative, when developing SHARK Tank and bringing up the initial models, we wanted something more flexible, easier to debug/extend, and less laden with needing to lowest common denominator something for everyone in order to fit into fixed-function op sets that are very expensive to change. @@ -63,12 +63,12 @@ amount of Python code implementing direct math and packing schemes. drop-in replacements for subsets of the functionality available in stock PyTorch modules like `Linear` and `Conv2D`. 2. Types/Ops: The `nn.Module` implementations we provide are built in terms - of sharktank custom - [`InferenceTensor`](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/types/tensors.py#L153) - and [polymorphic functional ops library](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/ops/signatures.py). + of SHARK Tank custom + [`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/tensors.py#L153) + and [polymorphic functional ops library](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py). 3. Op specializations for optimized subsets of op type signatures and features (for example, [an optimized affine quantized linear specialization for - supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/ops/qlinear_impls.py)). + supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py)). (TODO: good place for a diagram) @@ -78,18 +78,18 @@ amount of Python code implementing direct math and packing schemes. 
Available modules that support direct quantization (TODO: refactor to use torch "Module" terminology and naming schemes consistently): -* [`LinearLayer`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/layers/linear.py) -* [convolution layers](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/layers/conv.py) +* [`LinearLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/linear.py) +* [convolution layers](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/conv.py) Note that most sharktank modules extend -[`ThetaLayer`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/layers/base.py#L63), +[`ThetaLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/base.py#L63), which calls for a bit of explanation. Traditional PyTorch Modules directly instantiate their backing parameters in their constructor. For dataset-heavy and polymorphic implementations like we commonly see in quantization and distribution, however, it can be beneficial to separate these concerns. The `ThetaLayer` simply takes a -[`Theta` object](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/theta.py#L74), +[`Theta` object](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L74), which is a tree-structured bag of native `torch.Tensor` or `InferenceTensor` instances, and it adopts the tensors in the bag as its own vs creating them. For those familiar with the concept, this is a form of dependency-injection @@ -114,7 +114,7 @@ tree to a specific Module instance. We've already met the `Theta` object above, which holds a tree of something called an -[`InferenceTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153). +[`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153). Now we describe what this is. Note that presently, `InferenceTensor` is not a `torch.Tensor` but its own `ABC` type that: @@ -140,11 +140,11 @@ pipelines. There is a growing list of `InferenceTensor` sub-types, many of which are related to quantization: -* [`PrimitiveTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286): +* [`PrimitiveTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286): A simple composition of a single `torch.Tensor`. This is often used interchangeably with a `torch.Tensor` but is present for completeness of the type hierarchy and to be able to type select on. -* [`QuantizedTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372): +* [`QuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372): Abstract base class of all quantized tensors, providing two primary operations: * `unpack`: Accesses the backing `QuantizedLayout` of the tensor, which is @@ -154,12 +154,12 @@ related to quantization: layout, this explodes it into a canonical representation of individual tensors which can be algebraically implemented individually/generically). 
-* [`PlanarQuantizedTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): +* [`PlanarQuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): Concrete implementation for all non-packed quantized tensors that can be losslessly represented by a layout based on individual tensor components. All `QuantizedTensor` instances can be converted to a `PlanarQuantizedTensor`. -* [`QuantizerTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): +* [`QuantizerTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): (note the "r" in the name) An abstract `InferenceTensor` that exposes a `quantize(torch.Tensor | InferenceTensor) -> QuantizedTensor` operation used to transform an arbitrary tensor to a quantized form. There are a handful @@ -178,7 +178,7 @@ manipulate tensor contents via `QuantizedLayout`, but we haven't yet defined that. The *Tensor types are structural and exist to give identity, but the `QuantizedLayout` is where the "magic happens". -[`QuantizedLayout`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44) +[`QuantizedLayout`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44) is an `ABC`, supporting: * Serialization/interop with parameter archives. @@ -193,7 +193,7 @@ is an `ABC`, supporting: There are a number of implementations, as every quantization scheme typically needs at least one concrete `QuantizedLayout`. Simple schemes like affine quantization can be fully defined in terms of a single -[`TensorScaledLayout`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/layouts.py#L43). +[`TensorScaledLayout`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/layouts.py#L43). Whereas packed schemes like we find in inference engines like GGML and XNNPACK optimally require both a packed layout and a planar layout. @@ -224,7 +224,7 @@ interpreting/transforming using their natively defined forms. Previously, we found a rich type system defining all manner of layouts and quantization schemes, but what can be done with it? That is where the sharktank functional op library comes in. These -[logical ops](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/signatures.py) +[logical ops](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py) provide the building blocks to implement built-in and custom `nn.Module` implementations operating on `InferenceTensor` (and torch.Tensor) types. @@ -239,12 +239,12 @@ implementation at any needed level of granularity: structures and preserve it when computing (when combined with a fusing compiler, this alone provides decent fallback implementations for a variety of "weight compression" oriented techniques). See - [some examples](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/custom_impls.py#L51). + [some examples](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/custom_impls.py#L51). * Pure-Torch decompositions for algebraic techniques like affine quantization (when combined with a fusing compiler, this alone is sufficient for optimization). 
See - [qlinear](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/qlinear_impls.py) and - [qconv](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/qconv_impls.py) + [qlinear](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py) and + [qconv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qconv_impls.py) implementations of actual affine quantized decompositions. * Completely custom packed/optimized implementation. These can be written to activate on any level of detail of the type hierarchy. The implementation @@ -277,11 +277,11 @@ is everything). We're just starting to exploit some of this as the PyTorch level. Some examples: * Something as simple as a humble runtime -[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/iree.py#L52) -* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/_jinja_test_ops.py#L28) - (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/templates/test_add_jinja.mlir)). -* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py) - (see backing [jinja template](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)). +[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/iree.py#L52) +* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/_jinja_test_ops.py#L28) + (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/templates/test_add_jinja.mlir)). +* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py) + (see backing [jinja template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)). * DSL based [like this fused attention kernel](https://github.com/iree-org/iree-turbine/blob/main/tests/kernel/fused_attention_test.py#L20) (note that in this case, the DSL exports to the unerlying IR-based registration mechanism used in the previous examples). @@ -292,8 +292,8 @@ Since all of these types of custom kernels are just defined with simple Python tooling, they are really fast to iterate on. The linalg based kernels specifically tend to be highly portable, and we don't hesitate to write one of those when we need something specific that PyTorch doesn't provide out of the box -(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py) -([template](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))). +(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py) +([template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))). 

 ## Dataset transformation

@@ -307,7 +307,7 @@ We take a practical approach to this, writing implementation specific
 converters where needed, and taking advantage of industry-standard consolidation
 points where available (like GGUF) in order to cover a wider surface area.

-Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/theta.py#L263),
+Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L263),
 which combines some set of hyper-parameters with a root `Theta` object
 (typically representing the layer-tree of frozen tensors). Datasets can be
 losslessly persisted to IREE IRPA files, which can then be loaded by either
@@ -321,9 +321,9 @@ transform, shard, etc.

 See some examples:

-* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) :
+* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) :
  Creating a `Dataset` object from an HF diffusers safetensors file and config.json.
-* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) :
+* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) :
  Creates a quantized `Dataset` by combining:

  * HF diffusers `config.json`
diff --git a/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md
new file mode 100644
index 000000000..1ce2d1e8d
--- /dev/null
+++ b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md
@@ -0,0 +1,242 @@
+# LLama 8b GPU Instructions on MI300X
+
+**NOTE: This was run on the `mi300x-3` system**
+
+## Setup
+
+We will use an example with `llama_8b_f16_decomposed` in order to describe the
+process of exporting a model for use in the shortfin llm server with an MI300 GPU.
+
+### Pre-Requisites
+
+- Python >= 3.11 is recommended for this flow
+  - You can check out [pyenv](https://github.com/pyenv/pyenv) as a good tool
+    to be able to manage multiple versions of python on the same system.
+
+### Setting Up Environment
+
+Follow the `Development Getting Started` docs
+[here](https://github.com/nod-ai/shark-ai/blob/main/README.md#development-getting-started)
+to set up your environment for development.
+
+### Define a directory for export files
+
+Create a new directory for us to export files like `model.mlir`, `model.vmfb`, etc.
+
+```bash
+mkdir $PWD/export
+export EXPORT_DIR=$PWD/export
+```
+
+### Define environment variables
+
+Define the following environment variables to make running this example a bit easier:
+
+#### Model/Tokenizer vars
+
+This example uses the `llama8b_f16.irpa` and `tokenizer.json` files that are
+pre-existing on the MI300X-3 system.
+You may need to change the paths for your own system.
+ +```bash +export MODEL_PARAMS_PATH=/data/llama3.1/8b/llama8b_f16.irpa # Path to existing .irpa file, may need to change w/ system +export TOKENIZER_PATH=/data/llama3.1/8b/tokenizer.json # Path to existing tokenizer.json, may need to change w/ system +``` + +#### General env vars + +The following env vars can be copy + pasted directly: + +```bash +export MLIR_PATH=$EXPORT_DIR/model.mlir # Path to export model.mlir file +export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json # Path to export config.json file +export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json # Path to export config.json file +export VMFB_PATH=$EXPORT_DIR/model.vmfb # Path to export model.vmfb file +export BS=1,4 # Batch size for kvcache +export ROCR_VISIBLE_DEVICES=1 # NOTE: This is temporary, until multi-device is fixed +``` + +### Export to MLIR + +We will now use the `sharktank.examples.export_paged_llm_v1` script to export +our model to `.mlir` format. + +```bash +python -m sharktank.examples.export_paged_llm_v1 \ + --irpa-file=$MODEL_PARAMS_PATH \ + --output-mlir=$MLIR_PATH \ + --output-config=$OUTPUT_CONFIG_PATH \ + --bs=$BS +``` + +## Compiling to `.vmfb` + +Now that we have generated a `model.mlir` file, we can compile it to `.vmfb` +format, which is required for running the `shortfin` LLM server. + +We will use the [iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile) +tool for compiling our model. + +### Compile for MI300 + +**NOTE: This command is specific to MI300 GPUs. +For other `--iree-hip-target` GPU options, +look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)** + +```bash +iree-compile $MLIR_PATH \ + --iree-hal-target-backends=rocm \ + --iree-hip-target=gfx942 \ + -o $VMFB_PATH +``` + +## Write an edited config + +We need to write a config for our model with a slightly edited structure +to run with shortfin. This will work for the example in our docs. +You may need to modify some of the parameters for a specific model. + +### Write edited config + +```bash +cat > $EDITED_CONFIG_PATH << EOF +{ + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": [ + $BS + ], + "decode_batch_sizes": [ + $BS + ], + "transformer_block_count": 32, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256 + } +} +EOF +``` + +## Running the `shortfin` LLM server + +We should now have all of the files that we need to run the shortfin LLM server. + +Verify that you have the following in your specified directory ($EXPORT_DIR): + +```bash +ls $EXPORT_DIR +``` + +- edited_config.json +- model.vmfb + +### Launch server: + +#### Set the target device + + + +#### Run the shortfin server + +Run the following command to launch the Shortfin LLM Server in the background: + +> **Note** +> By default, our server will start at `http://localhost:8000`. +> You can specify the `--host` and/or `--port` arguments, to run at a different address. +> +> If you receive an error similar to the following: +> +> `[errno 98] address already in use` +> +> Then, you can confirm the port is in use with `ss -ntl | grep 8000` +> and either kill the process running at that port, +> or start the shortfin server at a different port. 
+ +```bash +python -m shortfin_apps.llm.server \ + --tokenizer_json=$TOKENIZER_PATH \ + --model_config=$EDITED_CONFIG_PATH \ + --vmfb=$VMFB_PATH \ + --parameters=$MODEL_PARAMS_PATH \ + --device=hip > shortfin_llm_server.log 2>&1 & +shortfin_process=$! +``` + +You can verify your command has launched successfully when you see the following + logs outputted to terminal: + +```bash +cat shortfin_llm_server.log +``` + +#### Expected output + +```text +[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete. +[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +## Verify server + +### Client script + +We can test the LLM server, by running our client script: + +```bash +python shortfin/python/shortfin_apps/llm/client.py --port 8000 +``` + +### Simple request + +Or by sending a simple request: + +### Open python shell + +```bash +python +``` + +### Send request + +```python +import requests + +import os + +port = 8000 # Change if running at a different port + +generate_url = f"http://localhost:{port}/generate" + +def generation_request(): + payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}} + try: + resp = requests.post(generate_url, json=payload) + resp.raise_for_status() # Raises an HTTPError for bad responses + print(resp.text) + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + +generation_request() +``` + +After you receive the request, you can exit the python shell: + +```bash +quit() +``` + +## Cleanup + +When done, you can kill the shortfin_llm_server by killing the process: + +```bash +kill -9 $shortfin_process +``` diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md new file mode 100644 index 000000000..4a8423bc8 --- /dev/null +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -0,0 +1,278 @@ +# LLama 8b GPU instructions on MI300X + +## Setup + +We will use an example with `llama_8b_f16` in order to describe the +process of exporting a model for use in the shortfin llm server with an +MI300 GPU. + +### Pre-Requisites + +- Python >= 3.11 is recommended for this flow + - You can check out [pyenv](https://github.com/pyenv/pyenv) + as a good tool to be able to manage multiple versions of python + on the same system. + +### Create virtual environment + +To start, create a new virtual environment: + +```bash +python -m venv --prompt shark-ai .venv +source .venv/bin/activate +``` + +### Install `shark-ai` + +You can install either the `latest stable` version of `shark-ai` +or the `nightly` version: + +#### Stable + +```bash +pip install shark-ai +``` + +#### Nightly + +```bash +pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +``` + +#### Install dataclasses-json + + + +```bash +pip install dataclasses-json +``` + +### Define a directory for export files + +Create a new directory for us to export files like +`model.mlir`, `model.vmfb`, etc. + +```bash +mkdir $PWD/export +export EXPORT_DIR=$PWD/export +``` + +### Download llama3_8b_fp16.gguf + +We will use the `hf_datasets` module in `sharktank` to download a +LLama3.1 8b f16 model. 
+ +```bash +python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR +``` + +### Define environment variables + +Define the following environment variables to make running +this example a bit easier: + +#### Model/Tokenizer vars + +This example uses the `llama8b_f16.gguf` and `tokenizer.json` files +that were downloaded in the previous step. + +```bash +export MODEL_PARAMS_PATH=$EXPORT_DIR/llama3.1-8b/llama8b_f16.gguf +export TOKENIZER_PATH=$EXPORT_DIR/llama3.1-8b/tokenizer.json +``` + +#### General env vars + +The following env vars can be copy + pasted directly: + +```bash +# Path to export model.mlir file +export MLIR_PATH=$EXPORT_DIR/model.mlir +# Path to export config.json file +export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json +# Path to export edited_config.json file +export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json +# Path to export model.vmfb file +export VMFB_PATH=$EXPORT_DIR/model.vmfb +# Batch size for kvcache +export BS=1,4 +# NOTE: This is temporary, until multi-device is fixed +export ROCR_VISIBLE_DEVICES=1 +``` + +## Export to MLIR + +We will now use the `sharktank.examples.export_paged_llm_v1` script +to export our model to `.mlir` format. + +```bash +python -m sharktank.examples.export_paged_llm_v1 \ + --irpa-file=$MODEL_PARAMS_PATH \ + --output-mlir=$MLIR_PATH \ + --output-config=$OUTPUT_CONFIG_PATH \ + --bs=$BS +``` + +## Compiling to `.vmfb` + +Now that we have generated a `model.mlir` file, +we can compile it to `.vmfb` format, which is required for running +the `shortfin` LLM server. + +We will use the +[iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile) +tool for compiling our model. + +### Compile for MI300 + +**NOTE: This command is specific to MI300 GPUs. +For other `--iree-hip-target` GPU options, +look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)** + +```bash +iree-compile $MLIR_PATH \ + --iree-hal-target-backends=rocm \ + --iree-hip-target=gfx942 \ + -o $VMFB_PATH +``` + +## Write an edited config + +We need to write a config for our model with a slightly edited structure +to run with shortfin. This will work for the example in our docs. +You may need to modify some of the parameters for a specific model. + +### Write edited config + +```bash +cat > $EDITED_CONFIG_PATH << EOF +{ + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": [ + $BS + ], + "decode_batch_sizes": [ + $BS + ], + "transformer_block_count": 32, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256 + } +} +EOF +``` + +## Running the `shortfin` LLM server + +We should now have all of the files that we need to run the shortfin LLM server. + +Verify that you have the following in your specified directory ($EXPORT_DIR): + +```bash +ls $EXPORT_DIR +``` + +- edited_config.json +- model.vmfb + +### Launch server: + + + +#### Run the shortfin server + +Now that we are finished with setup, we can start the Shortfin LLM Server. + +Run the following command to launch the Shortfin LLM Server in the background: + +> **Note** +> By default, our server will start at `http://localhost:8000`. +> You can specify the `--host` and/or `--port` arguments, to run at a different address. 
+> +> If you receive an error similar to the following: +> +> `[errno 98] address already in use` +> +> Then, you can confirm the port is in use with `ss -ntl | grep 8000` +> and either kill the process running at that port, +> or start the shortfin server at a different port. + +```bash +python -m shortfin_apps.llm.server \ + --tokenizer_json=$TOKENIZER_PATH \ + --model_config=$EDITED_CONFIG_PATH \ + --vmfb=$VMFB_PATH \ + --parameters=$MODEL_PARAMS_PATH \ + --device=hip > shortfin_llm_server.log 2>&1 & +shortfin_process=$! +``` + +You can verify your command has launched successfully +when you see the following logs outputted to terminal: + +```bash +cat shortfin_llm_server.log +``` + +#### Expected output + +```text +[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete. +[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +## Verify server + +We can now verify our LLM server by sending a simple request: + +### Open python shell + +```bash +python +``` + +### Send request + +```python +import requests + +import os + +port = 8000 # Change if running on a different port + +generate_url = f"http://localhost:{port}/generate" + +def generation_request(): + payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}} + try: + resp = requests.post(generate_url, json=payload) + resp.raise_for_status() # Raises an HTTPError for bad responses + print(resp.text) + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + +generation_request() +``` + +After you receive the request, you can exit the python shell: + +```bash +quit() +``` + +## Cleanup + +When done, you can kill the shortfin_llm_server by killing the process: + +```bash +kill -9 $shortfin_process +``` diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md new file mode 100644 index 000000000..b63861a56 --- /dev/null +++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md @@ -0,0 +1,254 @@ +# Using `shortfin` with `sglang` + +This doc includes basic steps for hooking up sglang with a running Shortfin server. 
+
+## Current Support Status
+
+| Feature | Description | Enabled | Reference |
+| ----------- | ----------- | ---------- | ------------ |
+| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) |
+| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) |
+| `run_batch` | Run batch of disjoint requests with continuous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) |
+| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) |
+| `choices` | Given set of choices, generate response based on best log probs | ❌ | [Choices Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) |
+| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) |
+| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) |
+
+## Prerequisites
+
+For this tutorial, you will need to meet the following prerequisites:
+
+### Software
+
+- Python >= 3.11
+  - You can check out [pyenv](https://github.com/pyenv/pyenv)
+    as a good tool to be able to manage multiple versions of python
+    on the same system.
+- A running `shortfin` LLM server as described [below](#installstart-shortfin-llm-server)
+  - We will use the shortfin server as the `backend` to generate completions
+    from SGLang's `frontend language`. In this tutorial, you can think of
+    `sglang` as the client and `shortfin` as the server.
+
+### Hardware
+
+- This tutorial is designed to run on an [AMD MI300X GPU](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html)
+
+## Install/Start `shortfin` LLM server
+
+Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/e2e_llama8b_mi300x.md)
+to export a model with `sharktank` and start a `shortfin` LLM server
+with that model.
+
+## Install sglang
+
+### Install sglang inside of virtual environment
+
+Currently, we have our SGLang integration located at this [forked repo](https://github.com/nod-ai/sglang).
+We can use pip to install it in the same virtual environment that we used
+to start our Shortfin LLM Server.
+ +```bash +pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" +``` + +## Getting started + +You can verify the installation/setup through the following examples: + +- [Multi-Turn Q&A Example](#multi-turn-qa-example) +- [Fork Example](#fork-example) +- [Benchmark Shortfin](#bench-mark-shortfin-w-sglang-bench_serving-script) + +## Multi-Turn Q&A example + +Now that we have sglang installed, we can run an example to show a multi-turn +Q&A flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): + +### Open python interpreter + +```bash +python +``` + +### Run example + +You can copy and paste the following example into your interpreter: + +```python +import sglang as sgl + +from sglang.lang.chat_template import get_chat_template + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000", ) # Change base_url if running at different address + +sgl.set_default_backend(backend) + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + +state = multi_turn_question.run(question_1="Name the capital city of the USA.", question_2="The Smithsonian is in this location.") + +for m in state.messages(): + print(m["role"], m["content"]) +``` + +### Shortfin example output + +You should see an output similar to this: + +```text +========== single ========== + +user : Name the capital city of the USA +assistant : The capital city of the United States of America is Washington, D.C. (short for District of Columbia). +user : The Smithsonian is in this location. +assistant : The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. +``` + +## Fork example + +Now that we have sglang installed, we can run an example to show a `fork` +flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html): + +### Open python interpreter + +```bash +python +``` + +### Run example + +You can copy and paste the following example into your interpreter: + +```python +import sglang as sgl + +from sglang.lang.chat_template import get_chat_template + +backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000") # Change base_url if running at different address + +sgl.set_default_backend(backend) + +@sgl.function +def tip_suggestion(s): + s += ( + "Here are two tips for staying healthy: " + "1. Balanced Diet. 2. Regular Exercise.\n\n" + ) + forks = s.fork(2) + for i, f in enumerate(forks): + f += f"Now, expand tip {i+1} into a paragraph:\n" + f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") + +state = tip_suggestion.run() + +print(state.text()) +``` + +### Shortfin example output + +You should see an output similar to this: + +```text +Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise. + +Tip 1:A balanced diet is important for maintaining good health. It should +include a variety of foods from all the major food groups, such as fruits, +vegetables, grains, proteins, and dairy. Eating a balanced diet can help +prevent chronic diseases such as heart disease, diabetes, and obesity. 
+ +Now, expand tip 2 into a paragraph: +Regular exercise is also important for maintaining good health. It can help +improve cardiovascular health, strengthen muscles and bones, and reduce the +risk of chronic diseases. Exercise can also help improve mental health by +reducing stress and anxiety. It is recommended that adults get at least 150 +minutes of moderate-intensity exercise or 75 minutes of vigorous-intensity +exercise per week. + +Now, combine the two paragraphs into a single paragraph: +A balanced diet and regular exercise are both important for maintaining good +health. A balanced diet should include a variety of foods from all the major +food groups, such as fruits, vegetables, grains, proteins, and dairy. +Eating a balanced diet can help prevent chronic diseases such as heart disease, +diabetes, and obesity. Regular exercise is also important for maintaining good +health. It can help improve cardiovascular health, strengthen muscles and bones, +and reduce the risk of chronic diseases. Exercise can also help improve mental +health by reducing stress and anxiety. It is recommended that + +Tip 2:Regular exercise is important for maintaining a healthy body and mind. +It can help improve cardiovascular health, strengthen muscles and bones, +and reduce the risk of chronic diseases such as diabetes and heart disease. +Additionally, exercise has been shown to improve mood, reduce stress, +and increase overall well-being. It is recommended that adults engage in +at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of +vigorous-intensity aerobic activity per week, as well as strength training +exercises at least two days per week. + +In summary, a balanced diet and regular exercise are both essential for +maintaining good health. A balanced diet should include a variety of foods from +all the major food groups, while regular exercise can help improve +cardiovascular health, strengthen muscles and bones, reduce the risk of +chronic diseases, and improve mental health. It is recommended that adults +engage in at least 150 minutes of moderate-intensity aerobic activity or +75 minutes of vigorous-intensity aerobic activity per week, +as well as strength training exercises at least two days per week. 
+``` + +## Benchmark shortfin w/ sglang `bench_serving` script + +We can obtain benchmarking metrics using the `bench_serving` script +provided by SGLang: + +**NOTE: Change `--base-url` if running at a different address** + +```bash +python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer /path/to/tokenizer/dir --request-rate 1 +``` + +There are some more metrics captured, but the most relevant are the following: + +- E2E Latency +- TTFT (Time to First Token) +- TPOT (Time per Output Token) +- ITL (Inter-Token Latency) +- Request Throughput +- Benchmark Duration + +When complete, you should see an output similar to this: + +```text +============ Serving Benchmark Result ============ +Backend: shortfin +Traffic request rate: 1.0 +Successful requests: 10 +Benchmark duration (s): 427.91 +Total input tokens: 1960 +Total generated tokens: 2774 +Total generated tokens (retokenized): 63 +Request throughput (req/s): 0.02 +Input token throughput (tok/s): 4.58 +Output token throughput (tok/s): 6.48 +----------------End-to-End Latency---------------- +Mean E2E Latency (ms): 416268.77 +Median E2E Latency (ms): 417159.14 +---------------Time to First Token---------------- +Mean TTFT (ms): 292404.29 +Median TTFT (ms): 365989.01 +P99 TTFT (ms): 367325.63 +-----Time per Output Token (excl. 1st token)------ +Mean TPOT (ms): 1359.41 +Median TPOT (ms): 163.96 +P99 TPOT (ms): 6316.12 +---------------Inter-token Latency---------------- +Mean ITL (ms): 2238.99 +Median ITL (ms): 958.75 +P99 ITL (ms): 2719.50 +================================================== +``` diff --git a/docs/user_guide.md b/docs/user_guide.md new file mode 100644 index 000000000..a0415eb63 --- /dev/null +++ b/docs/user_guide.md @@ -0,0 +1,116 @@ +# SHARK User Guide + +These instructions cover the usage of the latest stable release of SHARK. For a more bleeding edge release please install the [nightly releases](nightly_releases.md). + +## Prerequisites + +Our current user guide requires that you have: +- Access to a computer with an installed AMD Instinct™ MI300x Series Accelerator +- Installed a compatible version of Linux and ROCm on the computer (see the [ROCm compatability matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html)) + +## Set up Environment + +This section will help you install Python and set up a Python environment with venv. + +Officially we support Python versions: 3.11, 3.12, 3.13 + +The rest of this guide assumes you are using Python 3.11. + +### Install Python +To install Python 3.11 on Ubuntu: + +```bash +sudo apt install python3.11 python3.11-dev python3.11-venv + +which python3.11 +# /usr/bin/python3.11 +``` + +### Create a Python Environment + +Setup your Python environment with the following commands: + +```bash +# Set up a virtual environment to isolate packages from other envs. +python3.11 -m venv 3.11.venv +source 3.11.venv/bin/activate + +# Optional: faster installation of torch with just CPU support. +# See other options at https://pytorch.org/get-started/locally/ +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +``` + +## Install SHARK and its dependencies + +```bash +pip install shark-ai[apps] +``` + +> [!TIP] +> To switch from the stable release channel to the nightly release channel, +> see [`nightly_releases.md`](./nightly_releases.md). + +### Test the installation. 
+
+```
+python -m shortfin_apps.sd.server --help
+```
+
+## Quickstart
+
+### Run the SDXL Server
+
+Run the [SDXL Server](../shortfin/python/shortfin_apps/sd/README.md#Start-SDXL-Server)
+
+### Run the SDXL Client
+
+```
+python -m shortfin_apps.sd.simple_client --interactive
+```
+
+Congratulations! At this point you can experiment with the server and client as your usage requires.
+
+### Note: Server implementation scope
+
+The SDXL server's implementation does not account for extremely large client batches. Normally, for heavy workloads, services would be composed under a load balancer to ensure each service is fed with requests optimally. For most cases outside of large-scale deployments, the server's internal batching/load balancing is sufficient.
+
+### Update flags
+
+Please see `--help` for both the server and client for full usage instructions. Here's a quick snapshot.
+
+#### Update server options:
+
+| Flags | Options |
+|---|---|
+|--host HOST | |
+|--port PORT | server port |
+|--root-path ROOT_PATH | |
+|--timeout-keep-alive | |
+|--device | local-task,hip,amdgpu (only amdgpu is supported in this release) |
+|--target | gfx942,gfx1100 (only gfx942 is supported in this release) |
+|--device_ids | |
+|--tokenizers | |
+|--model_config | |
+| --workers_per_device | |
+| --fibers_per_device | |
+| --isolation | per_fiber, per_call, none |
+| --show_progress | |
+| --trace_execution | |
+| --amdgpu_async_allocations | |
+| --splat | |
+| --build_preference | compile,precompiled |
+| --compile_flags | |
+| --flagfile FLAGFILE | |
+| --artifacts_dir ARTIFACTS_DIR | Where to store cached artifacts from the Cloud |
+
+#### Update client with different options:
+
+| Flags | Options |
+|---|---|
+|--file | |
+|--reps | |
+|--save | Whether to save images generated by the server |
+|--outputdir| output directory to store images generated by SDXL |
+|--steps | |
+|--interactive | |
+|--port| port to interact with the server |
diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt deleted file mode 100644 index 86d0267a5..000000000 --- a/libshortfin/CMakeLists.txt +++ /dev/null @@ -1,191 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-cmake_minimum_required(VERSION 3.28)
-
-if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR)
-  message(
-    FATAL_ERROR
-      "Do not build in-source. Please remove CMakeCache.txt and the CMakeFiles/ directory. Then build out-of-source."
-  )
-endif()
-
-project(
-  "libshortfin"
-  VERSION 0.9
-  LANGUAGES C CXX)
-
-include(CMakeDependentOption)
-
-set(SOVERSION 1)
-
-set(CMAKE_C_STANDARD 11)
-set(CMAKE_CXX_STANDARD 20)
-# https://discourse.cmake.org/t/cmake-3-28-cmake-cxx-compiler-clang-scan-deps-notfound-not-found/9244/3
-set(CMAKE_CXX_SCAN_FOR_MODULES 0)
-
-# Problems with linking libfmt without PIC.
-# Turn on PIC on non windows targets.
-if(NOT WIN32)
-  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
-endif()
-
-# build options
-option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF)
-option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON)
-option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON)
-
-set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source")
-
-# Options for building static or dynamic libraries.
-option(SHORTFIN_BUILD_STATIC "Builds static libraries" OFF) -option(SHORTFIN_BUILD_DYNAMIC "Builds dynamic libraries" ON) -cmake_dependent_option(SHORTFIN_LINK_DYNAMIC "Links internal binaries against static libshortfin.a" ON "SHORTFIN_BUILD_DYNAMIC" OFF) -if(NOT SHORTFIN_BUILD_STATIC AND NOT SHORTFIN_BUILD_DYNAMIC) - message(FATAL_ERROR "One of SHORTFIN_BUILD_STATIC or SHORTFIN_BUILD_DYNAMIC must be ON") -endif() -message(STATUS "Shortfin build static = ${SHORTFIN_BUILD_STATIC}, dynamic = ${SHORTFIN_BUILD_DYNAMIC}") -if(SHORTFIN_LINK_DYNAMIC) - message(STATUS "Dynamic linking to shortfin") - set(SHORTFIN_LINK_LIBRARY_NAME "shortfin") -else() - message(STATUS "Static linking to shortfin-static") - set(SHORTFIN_LINK_LIBRARY_NAME "shortfin-static") -endif() - -# Enabling ASAN. Note that this will work best if building in a completely -# bundled fashion and with an ASAN rigged CPython. Otherwise, various LD_PRELOAD -# hacks are needed. This is merely a develope convenience: people are more -# than welcome to set flags themselves. -option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF) -if(SHORTFIN_ENABLE_ASAN) - add_compile_options(-fsanitize=address) - add_link_options(-fsanitize=address) - - # Enable more ASAN checks. - add_compile_definitions(IREE_SANITIZER_ADDRESS) -endif() - -option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON) -message(STATUS "libshortfin supported systems:") -if(SHORTFIN_SYSTEMS_AMDGPU) - message(STATUS " - AMD GPU") - add_compile_definitions("SHORTFIN_HAVE_AMDGPU") -endif() -message(STATUS " - Host") - -include(FetchContent) - -# Includes. -list(APPEND CMAKE_MODULE_PATH - ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/ -) -include(shortfin_library) - -# Dependencies. - -if(SHORTFIN_BUNDLE_DEPS) - ## fmt - FetchContent_Declare( - fmt - GIT_REPOSITORY https://github.com/fmtlib/fmt.git - GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog) - ) - - ## spdlog - # We build fmt from source instead, because we also use fmt. - set(SPDLOG_FMT_EXTERNAL ON) - FetchContent_Declare( - spdlog - GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1 - ) - - ## xtl: required for xtensor - FetchContent_Declare( - xtl - GIT_REPOSITORY https://github.com/xtensor-stack/xtl.git - GIT_TAG a7c1c5444dfc57f76620391af4c94785ff82c8d6 # v0.7.7 - ) - - ## xtensor - FetchContent_Declare( - xtensor - GIT_REPOSITORY https://github.com/xtensor-stack/xtensor.git - GIT_TAG 3634f2ded19e0cf38208c8b86cea9e1d7c8e397d # v0.25.0 - ) - - FetchContent_MakeAvailable(fmt spdlog xtl xtensor) -else() - find_package(spdlog) - find_package(xtensor) -endif() - -## iree runtime - -if (NOT SHORTFIN_IREE_SOURCE_DIR AND SHORTFIN_BUNDLE_DEPS) - FetchContent_Declare( - iree - GIT_REPOSITORY https://github.com/iree-org/iree.git - GIT_TAG candidate-20240904.1006 - # TODO: We shouldn't have to pull googletest when we are not building tests. - # This needs to be fixed with IREE. - GIT_SUBMODULES "third_party/benchmark third_party/cpuinfo third_party/flatcc third_party/hip-build-deps third_party/googletest" - GIT_SHALLOW TRUE - ) - FetchContent_GetProperties(iree) - if(NOT iree_POPULATED) - FetchContent_Populate(iree) - endif() - set(SHORTFIN_IREE_SOURCE_DIR ${iree_SOURCE_DIR}) -endif() - -if(SHORTFIN_IREE_SOURCE_DIR) - set(IREE_BUILD_COMPILER OFF) - set(IREE_BUILD_TESTS OFF) - set(IREE_BUILD_SAMPLES OFF) - # Disable missing submodules error because we are only building the runtime. 
- set(IREE_ERROR_ON_MISSING_SUBMODULES OFF) - # Only enable local_sync/local_task/hip drivers for now. - set(IREE_HAL_DRIVER_DEFAULTS OFF) - set(IREE_HAL_DRIVER_LOCAL_SYNC ON) - set(IREE_HAL_DRIVER_LOCAL_TASK ON) - if(SHORTFIN_SYSTEMS_AMDGPU) - set(IREE_HAL_DRIVER_HIP ON) - endif() - add_subdirectory(${SHORTFIN_IREE_SOURCE_DIR} shortfin_iree SYSTEM EXCLUDE_FROM_ALL) -else() - # Try to find iree using find_package - find_package(IREERuntime) -endif() - -# tests - -if(SHORTFIN_BUILD_TESTS) - if (NOT SHORTFIN_IREE_SOURCE_DIR) - # For now we use gtest shipped alongside with IREE. - FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip - ) - # For Windows: Prevent overriding the parent project's compiler/linker settings - set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) - FetchContent_MakeAvailable(googletest) - endif() - include(GoogleTest) - enable_testing() -endif() - - -add_subdirectory(src) - -if(SHORTFIN_BUILD_PYTHON_BINDINGS) - find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) - add_subdirectory(bindings/python) - set(SHORTFIN_PYTHON_CPP_PREBUILT "TRUE") # See setup.py. - configure_file(setup.py setup.py @ONLY) - configure_file(pyproject.toml pyproject.toml COPYONLY) -endif() diff --git a/libshortfin/README.md b/libshortfin/README.md deleted file mode 100644 index 435dcbb87..000000000 --- a/libshortfin/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# libshortfin - SHARK C++ inference library - -## Dev Builds - -Library dependencies: - -* [spdlog](https://github.com/gabime/spdlog) -* [xtensor](https://github.com/xtensor-stack/xtensor) -* [iree runtime](https://github.com/iree-org/iree) - -On recent Ubuntu, the primary dependencies can be satisfied via: - -``` -apt install libspdlog-dev libxtensor-dev -``` - -CMake must be told how to find the IREE runtime, either from a distribution -tarball, or local build/install dir. For a local build directory, pass: - -``` -# Assumes that the iree-build directory is adjacent to this repo. --DCMAKE_PREFIX_PATH=$(pwd)/../../iree-build/lib/cmake/IREE -``` - -One liner recommended CMake command (note that CMAKE_LINKER_TYPE requires -cmake>=3.29): - -``` -cmake -GNinja -S. -Bbuild \ - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_LINKER_TYPE=LLD \ - -DCMAKE_PREFIX_PATH=$(pwd)/../../iree-build/lib/cmake/IREE -``` - -## Building Python Bindings - -If using a Python based development flow, there are two options: - -1. `pip install -v .` to build and install the library (TODO: Not yet implemented). -2. Build with cmake and `-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON` and then - from the `build/` directory, run `pip install -v -e .` to create an - editable install that will update as you build the C++ project. - -If predominantly developing with a C++ based flow, the second option is -recommended. Your python install should track any source file changes or -builds without further interaction. Re-installing will be necessary if package -structure changes significantly. - -## Running Tests - -The project uses a combination of ctest for native C++ tests and pytest. Much -of the functionality is only tested via the Python tests, using the -`_shortfin.lib` internal implementation directly. In order to run these tests, -you must have installed the Python package as per the above steps. - -Which style of test is used is pragmatic and geared at achieving good test -coverage with a minimum of duplication. 
Since it is often much more expensive -to build native tests of complicated flows, many things are only tested via -Python. This does not preclude having other language bindings later, but it -does mean that the C++ core of the library must always be built with the -Python bindings to test the most behavior. Given the target of the project, -this is not considered to be a significant issue. - -# Production Library Building - -In order to build a production library, additional build steps are typically -recommended: - -* Compile all deps with the same compiler/linker for LTO compatibility -* Provide library dependencies manually and compile them with LTO -* Compile dependencies with `-fvisibility=hidden` -* Enable LTO builds of libshortfin -* Set flags to enable symbol versioning diff --git a/libshortfin/bindings/python/CMakeLists.txt b/libshortfin/bindings/python/CMakeLists.txt deleted file mode 100644 index f1e999783..000000000 --- a/libshortfin/bindings/python/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception - -# libshortfin publishes multiple python packages: - _shortfin: Trampoline -# __init__.py which looks at environment variables to load an appropriate native -# library. - _shortfin_default.lib: Native library as a default, uninstrumented -# build. - _shortfin_tracing.lib: Native library with tracing enabled (TODO). - -# Others. - -# nanobind -FetchContent_Declare( - nanobind - GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG 9641bb7151f04120013b812789b3ebdfa7e7324f # 2.1.0 -) -FetchContent_MakeAvailable(nanobind) - -nanobind_add_module(shortfin_python_extension NB_STATIC LTO - array_binding.cc - lib_ext.cc -) - -set_target_properties(shortfin_python_extension - PROPERTIES OUTPUT_NAME "_shortfin_default/lib") - -target_link_libraries(shortfin_python_extension - PRIVATE ${SHORTFIN_LINK_LIBRARY_NAME} -) - -nanobind_add_stub( - shortfin_python_extension_stub - MODULE _shortfin_default.lib - OUTPUT _shortfin_default/lib.pyi - DEPENDS shortfin_python_extension -) diff --git a/libshortfin/bindings/python/array_binding.cc b/libshortfin/bindings/python/array_binding.cc deleted file mode 100644 index 608aa4944..000000000 --- a/libshortfin/bindings/python/array_binding.cc +++ /dev/null @@ -1,297 +0,0 @@ -// Copyright 2024 Advanced Micro Devices, Inc. -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "./lib_ext.h" -#include "./utils.h" -#include "shortfin/array/api.h" - -using namespace shortfin::array; - -namespace shortfin::python { - -namespace { -static const char DOCSTRING_ARRAY_COPY_FROM[] = - R"(Copy contents from a source array to this array. - -Equivalent to `dest_array.storage.copy_from(source_array.storage)`. -)"; - -static const char DOCSTRING_ARRAY_COPY_TO[] = - R"(Copy contents this array to a destination array. - -Equivalent to `dest_array.storage.copy_from(source_array.storage)`. -)"; - -static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value. - -Equivalent to `array.storage.fill(pattern)`. -)"; - -static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents. - -Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`. 
-The returned object is a context manager that will close on exit. - -Assigning `storage.data = array.array("f", [1.0])` will copy that raw data -from the source object using the buffer protocol. The source data must be -less than or equal to the length of the storage object. Note that the entire -storage is mapped as write-only/discardable, and writing less than the storage -bytes leaves any unwritten contents in an undefined state. - -As with `map`, this will only work on buffers that are host visible, which -includes all host buffers and device buffers created with the necessary access. -)"; - -static const char DOCSTRING_STORAGE_COPY_FROM[] = - R"(Copy contents from a source storage to this array. - -This operation executes asynchronously and the effect will only be visible -once the execution scope has been synced to the point of mutation. -)"; - -static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value. - -Takes as argument any value that can be interpreted as a buffer with the Python -buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly -with the pattern. - -This operation executes asynchronously and the effect will only be visible -once the execution scope has been synced to the point of mutation. -)"; - -static const char DOCSTRING_STORAGE_MAP[] = - R"(Create a mapping of the buffer contents in host memory. - -Support kwargs of: - -read: Enables read access to the mapped memory. -write: Enables write access to the mapped memory and will flush upon close - (for non-unified memory systems). -discard: Indicates that the entire memory map should be treated as if it will - be overwritten. Initial contents will be undefined. - -Mapping memory for access from the host requires a compatible buffer that has -been created with host visibility (which includes host buffers). - -The returned mapping object is a context manager that will close/flush on -exit. Alternatively, the `close()` method can be invoked explicitly. -)"; - -// Does in-place creation of a mapping object and stores a pointer to the -// contained array::mapping C++ object. 
-py::object CreateMappingObject(mapping **out_cpp_mapping) { - py::object py_mapping = py::inst_alloc(py::type()); - mapping *cpp_mapping = py::inst_ptr(py_mapping); - new (cpp_mapping) mapping(); - py::inst_mark_ready(py_mapping); - *out_cpp_mapping = cpp_mapping; - return py_mapping; -} - -} // namespace - -void BindArray(py::module_ &m) { - py::class_(m, "DType") - .def_prop_ro("is_boolean", &DType::is_boolean) - .def_prop_ro("is_integer", &DType::is_integer) - .def_prop_ro("is_float", &DType::is_float) - .def_prop_ro("is_complex", &DType::is_complex) - .def_prop_ro("bit_count", &DType::bit_count) - .def_prop_ro("is_byte_aligned", &DType::is_byte_aligned) - .def_prop_ro("dense_byte_count", &DType::dense_byte_count) - .def("is_integer_bitwidth", &DType::is_integer_bitwidth) - .def(py::self == py::self) - .def("__repr__", &DType::name); - -#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident(); -#include "shortfin/array/dtypes.inl" -#undef SHORTFIN_DTYPE_HANDLE - - // storage - py::class_(m, "storage") - .def("__sfinv_marshal__", - [](device_array *self, py::capsule inv_capsule, int barrier) { - auto *inv = - static_cast(inv_capsule.data()); - static_cast(self) - ->AddAsInvocationArgument( - inv, static_cast(barrier)); - }) - .def_static( - "allocate_host", - [](local::ScopedDevice &device, iree_device_size_t allocation_size) { - return storage::allocate_host(device, allocation_size); - }, - py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) - .def_static( - "allocate_device", - [](local::ScopedDevice &device, iree_device_size_t allocation_size) { - return storage::allocate_device(device, allocation_size); - }, - py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) - .def( - "fill", - [](storage &self, py::handle buffer) { - Py_buffer py_view; - int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. - if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { - throw py::python_error(); - } - PyBufferReleaser py_view_releaser(py_view); - self.fill(py_view.buf, py_view.len); - }, - py::arg("pattern"), DOCSTRING_STORAGE_FILL) - .def( - "copy_from", [](storage &self, storage &src) { self.copy_from(src); }, - py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM) - .def( - "map", - [](storage &self, bool read, bool write, bool discard) { - int access = 0; - if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; - if (write) access |= IREE_HAL_MEMORY_ACCESS_WRITE; - if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD; - if (!access) { - throw std::invalid_argument( - "One of the access flags must be set"); - } - mapping *cpp_mapping = nullptr; - py::object py_mapping = CreateMappingObject(&cpp_mapping); - self.map_explicit( - *cpp_mapping, - static_cast(access)); - return py_mapping; - }, - py::kw_only(), py::arg("read") = false, py::arg("write") = false, - py::arg("discard") = false, DOCSTRING_STORAGE_MAP) - // The 'data' prop is a short-hand for accessing the backing storage - // in a one-shot manner (as for reading or writing). Getting the attribute - // will map for read and return a memory view (equiv to map(read=True)). - // On write, it will accept an object implementing the buffer protocol - // and write/discard the backing storage. 
- .def_prop_rw( - "data", - [](storage &self) { - mapping *cpp_mapping = nullptr; - py::object py_mapping = CreateMappingObject(&cpp_mapping); - *cpp_mapping = self.map_read(); - return py_mapping; - }, - [](storage &self, py::handle buffer_obj) { - PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE); - auto dest_data = self.map_write_discard(); - if (src_info.view().len > dest_data.size()) { - throw std::invalid_argument( - fmt::format("Cannot write {} bytes into buffer of {} bytes", - src_info.view().len, dest_data.size())); - } - std::memcpy(dest_data.data(), src_info.view().buf, - src_info.view().len); - }, - DOCSTRING_STORAGE_DATA) - .def(py::self == py::self) - .def("__repr__", &storage::to_s); - - // mapping - auto mapping_class = py::class_(m, "mapping"); - mapping_class.def("close", &mapping::reset) - .def_prop_ro("valid", [](mapping &self) -> bool { return self; }) - .def("__enter__", [](py::object self_obj) { return self_obj; }) - .def( - "__exit__", - [](mapping &self, py::handle exc_type, py::handle exc_value, - py::handle exc_tb) { self.reset(); }, - py::arg("exc_type").none(), py::arg("exc_value").none(), - py::arg("exc_tb").none()); - struct MappingBufferHandler { - int operator()(mapping &self, Py_buffer *view, int flags) { - view->buf = self.data(); - view->len = self.size(); - view->readonly = !self.writable(); - view->itemsize = 1; - view->format = (char *)"B"; // Byte - view->ndim = 1; - view->shape = nullptr; - view->strides = nullptr; - view->suboffsets = nullptr; - view->internal = nullptr; - return 0; - } - }; - BindBufferProtocol(mapping_class); - - // base_array and subclasses - py::class_(m, "base_array") - .def_prop_ro("dtype", &base_array::dtype) - .def_prop_ro("shape", &base_array::shape); - py::class_(m, "device_array") - .def("__init__", [](py::args, py::kwargs) {}) - .def_static("__new__", - [](py::handle py_type, class storage storage, - std::span shape, DType dtype) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/storage.scope(), storage, shape, - dtype); - }) - .def_static("__new__", - [](py::handle py_type, local::ScopedDevice &device, - std::span shape, DType dtype) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/device.scope(), - device_array::for_device(device, shape, dtype)); - }) - .def("__sfinv_marshal__", - [](device_array *self, py::capsule inv_capsule, int barrier) { - auto *inv = - static_cast(inv_capsule.data()); - static_cast(self) - ->AddAsInvocationArgument( - inv, static_cast(barrier)); - }) - .def_static("for_device", - [](local::ScopedDevice &device, std::span shape, - DType dtype) { - return custom_new_keep_alive( - py::type(), /*keep_alive=*/device.scope(), - device_array::for_device(device, shape, dtype)); - }) - .def_static("for_host", - [](local::ScopedDevice &device, std::span shape, - DType dtype) { - return custom_new_keep_alive( - py::type(), /*keep_alive=*/device.scope(), - device_array::for_host(device, shape, dtype)); - }) - .def("for_transfer", - [](device_array &self) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/self.device().scope(), self.for_transfer()); - }) - .def_prop_ro("device", &device_array::device, - py::rv_policy::reference_internal) - .def_prop_ro("storage", &device_array::storage, - py::rv_policy::reference_internal) - - .def( - "fill", - [](py::handle_t self, py::handle buffer) { - self.attr("storage").attr("fill")(buffer); - }, - py::arg("pattern"), DOCSTRING_ARRAY_FILL) - .def("copy_from", &device_array::copy_from, py::arg("source_array"), - 
DOCSTRING_ARRAY_COPY_FROM) - .def("copy_to", &device_array::copy_to, py::arg("dest_array"), - DOCSTRING_ARRAY_COPY_TO) - .def("__repr__", &device_array::to_s) - .def("__str__", [](device_array &self) -> std::string { - auto contents = self.contents_to_s(); - if (!contents) return "<>"; - return *contents; - }); -} - -} // namespace shortfin::python diff --git a/libshortfin/build_tools/cmake/shortfin_library.cmake b/libshortfin/build_tools/cmake/shortfin_library.cmake deleted file mode 100644 index bae9fe115..000000000 --- a/libshortfin/build_tools/cmake/shortfin_library.cmake +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception - -set(SHORTFIN_DEFAULT_COPTS - # General clang and GCC options application to C and C++. - $<$: - -Wall - -Werror - > - - # General MSVC options applicable to C and C++. - $<$: - > -) - -function(shortfin_public_library) - cmake_parse_arguments( - _RULE - "" - "NAME" - "COMPONENTS" - ${ARGN} - ) - if(SHORTFIN_BUILD_STATIC) - # Static library. - shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS}) - add_library("${_RULE_NAME}-static" STATIC) - target_link_libraries( - "${_RULE_NAME}-static" PUBLIC ${_STATIC_COMPONENTS} - ) - endif() - - if(SHORTFIN_BUILD_DYNAMIC) - # Dylib library. - shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS}) - add_library("${_RULE_NAME}" SHARED) - target_compile_definitions("${_RULE_NAME}" INTERFACE _SHORTFIN_USING_DYLIB) - target_link_libraries( - "${_RULE_NAME}" PUBLIC ${_DYLIB_COMPONENTS} - ) - endif() -endfunction() - -function(shortfin_cc_component) - cmake_parse_arguments( - _RULE - "" - "NAME" - "HDRS;SRCS;DEPS;COMPONENTS" - ${ARGN} - ) - if(SHORTFIN_BUILD_STATIC) - # Static object library. - set(_STATIC_OBJECTS_NAME "${_RULE_NAME}.objects") - shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS}) - add_library(${_STATIC_OBJECTS_NAME} OBJECT) - target_sources(${_STATIC_OBJECTS_NAME} - PRIVATE - ${_RULE_SRCS} - ${_RULE_HDRS} - ) - target_compile_options(${_STATIC_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS}) - target_link_libraries(${_STATIC_OBJECTS_NAME} - PUBLIC - _shortfin_defs - ${_STATIC_COMPONENTS} - ${_RULE_DEPS} - ) - endif() - - if(SHORTFIN_BUILD_DYNAMIC) - set(CMAKE_POSITION_INDEPENDENT_CODE ON) - set(_DYLIB_OBJECTS_NAME "${_RULE_NAME}.dylib.objects") - shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS}) - # Dylib object library. 
- add_library(${_DYLIB_OBJECTS_NAME} OBJECT) - target_sources(${_DYLIB_OBJECTS_NAME} - PRIVATE - ${_RULE_SRCS} - ${_RULE_HDRS} - ) - target_compile_options(${_DYLIB_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS}) - target_link_libraries(${_DYLIB_OBJECTS_NAME} - PUBLIC - _shortfin_defs - ${_DYLIB_COMPONENTS} - ${_RULE_DEPS} - ) - set_target_properties( - ${_DYLIB_OBJECTS_NAME} PROPERTIES - CXX_VISIBILITY_PRESET hidden - C_VISIBILITY_PRESET hidden - VISIBILITY_INLINES_HIDDEN ON - ) - target_compile_definitions(${_DYLIB_OBJECTS_NAME} - PRIVATE _SHORTFIN_BUILDING_DYLIB) - endif() -endfunction() - -function(shortfin_components_to_static_libs out_static_libs) - set(_LIBS ${ARGN}) - list(TRANSFORM _LIBS APPEND ".objects") - set(${out_static_libs} ${_LIBS} PARENT_SCOPE) -endfunction() - -function(shortfin_components_to_dynamic_libs out_dynamic_libs) - set(_LIBS ${ARGN}) - list(TRANSFORM _LIBS APPEND ".dylib.objects") - set(${out_dynamic_libs} "${_LIBS}" PARENT_SCOPE) -endfunction() - -function(shortfin_gtest_test) - cmake_parse_arguments( - _RULE - "" - "NAME" - "SRCS;DEPS" - ${ARGN} - ) - - if(NOT SHORTFIN_BUILD_TESTS) - return() - endif() - - add_executable(${_RULE_NAME} ${_RULE_SRCS}) - target_link_libraries(${_RULE_NAME} PRIVATE - ${_RULE_DEPS} - ${SHORTFIN_LINK_LIBRARY_NAME} - GTest::gmock - GTest::gtest_main - ) - gtest_discover_tests(${_RULE_NAME}) -endfunction() diff --git a/libshortfin/build_tools/python_lsan_suppressions.txt b/libshortfin/build_tools/python_lsan_suppressions.txt deleted file mode 100644 index cc768d575..000000000 --- a/libshortfin/build_tools/python_lsan_suppressions.txt +++ /dev/null @@ -1,2 +0,0 @@ -leak:PyUnicode_New -leak:_PyUnicodeWriter_Finish diff --git a/libshortfin/pyproject.toml b/libshortfin/pyproject.toml deleted file mode 100644 index f1b34c64a..000000000 --- a/libshortfin/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[build-system] -requires = [ - "cmake>=3.29", - "setuptools>=61.0", - "wheel", - "ninja", -] -build-backend = "setuptools.build_meta" - -[tool.pytest.ini_options] -addopts = [ - "-ra", - "--import-mode=importlib", -] -markers = [ - "requires_amd_gpu: tests that require and AMD GPU (deselect with '-m \"not requires_amd_gpu\"')", -] -testpaths = [ - "tests", -] diff --git a/libshortfin/requirements-iree-compiler.txt b/libshortfin/requirements-iree-compiler.txt deleted file mode 100644 index 34228ccf4..000000000 --- a/libshortfin/requirements-iree-compiler.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Keep in sync with IREE_REF in CI and GIT_TAG in CMakeLists.txt --f https://iree.dev/pip-release-links.html -iree-compiler==20240904.1006 diff --git a/libshortfin/setup.py b/libshortfin/setup.py deleted file mode 100644 index b28b8d114..000000000 --- a/libshortfin/setup.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from distutils.core import setup, Extension -import sys -import shutil -import subprocess -import os -from pathlib import Path -from distutils.command.build import build as _build -from setuptools.command.build_ext import build_ext as _build_ext -from setuptools.command.build_py import build_py as _build_py - - -# This file can be generated into the build directory to allow an arbitrary -# CMake built version of the project to be installed into a venv for development. 
-# This can be detected if the CPP_PREBUILT global contains the string -# "TRUE", which will be the case if generated. -CPP_PREBUILT = "@SHORTFIN_PYTHON_CPP_PREBUILT@" -CPP_PREBUILT_SOURCE_DIR = "@libshortfin_SOURCE_DIR@" -CPP_PREBUILT_BINARY_DIR = "@libshortfin_BINARY_DIR@" - -SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) - - -def is_cpp_prebuilt(): - return CPP_PREBUILT == "TRUE" - - -if is_cpp_prebuilt(): - print("setup.py running in pre-built mode:", file=sys.stderr) - SOURCE_DIR = Path(CPP_PREBUILT_SOURCE_DIR) - BINARY_DIR = Path(CPP_PREBUILT_BINARY_DIR) -else: - print("setup.py running in cmake build mode:", file=sys.stderr) - # setup.py is in the source directory. - SOURCE_DIR = Path(SETUPPY_DIR) - BINARY_DIR = Path(os.path.join(SETUPPY_DIR, "build", "b")) - -print(f" SOURCE_DIR = {SOURCE_DIR}", file=sys.stderr) -print(f" BINARY_DIR = {BINARY_DIR}", file=sys.stderr) - -# Due to a quirk of setuptools, that package_dir map must only contain -# paths relative to the directory containing setup.py. Why? No one knows. -REL_SOURCE_DIR = SOURCE_DIR.relative_to(SETUPPY_DIR, walk_up=True) -REL_BINARY_DIR = BINARY_DIR.relative_to(SETUPPY_DIR, walk_up=True) - - -class CMakeExtension(Extension): - def __init__(self, name, sourcedir=""): - Extension.__init__(self, name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) - - -class CustomBuild(_build): - def run(self): - self.run_command("build_py") - self.run_command("build_ext") - self.run_command("build_scripts") - - -class NoopBuildExtension(_build_ext): - def build_extension(self, ext): - ... - - def copy_extensions_to_source(self, *args, **kwargs): - ... - - -def maybe_nuke_cmake_cache(cmake_build_dir): - # From run to run under pip, we can end up with different paths to ninja, - # which isn't great and will confuse cmake. Detect if the location of - # ninja changes and force a cache flush. - ninja_path = "" - try: - import ninja - except ModuleNotFoundError: - pass - else: - ninja_path = ninja.__file__ - expected_stamp_contents = f"{sys.executable}\n{ninja_path}" - - # In order to speed things up on CI and not rebuild everything, we nuke - # the CMakeCache.txt file if the path to the Python interpreter changed. - # Ideally, CMake would let us reconfigure this dynamically... but it does - # not (and gets very confused). - PYTHON_STAMP_FILE = os.path.join(cmake_build_dir, "python_stamp.txt") - if os.path.exists(PYTHON_STAMP_FILE): - with open(PYTHON_STAMP_FILE, "rt") as f: - actual_stamp_contents = f.read() - if actual_stamp_contents == expected_stamp_contents: - # All good. - return - - # Mismatch or not found. Clean it. - cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt") - if os.path.exists(cmake_cache_file): - print("Removing CMakeCache.txt because Python version changed", file=sys.stderr) - os.remove(cmake_cache_file) - - # And write. - with open(PYTHON_STAMP_FILE, "wt") as f: - f.write(expected_stamp_contents) - - -class CMakeBuildPy(_build_py): - def run(self): - # The super-class handles the pure python build. - super().run() - - # Build using cmake if not in prebuild mode. - if not is_cpp_prebuilt(): - - # Build extension using cmake. - print("*****************************", file=sys.stderr) - print("* Building libshortfin *", file=sys.stderr) - print("*****************************", file=sys.stderr) - - cfg = os.getenv("SHORTFIN_CMAKE_BUILD_TYPE", "Release") - - CMAKE_BUILD_DIR = BINARY_DIR - - # Configure CMake. 
- os.makedirs(BINARY_DIR, exist_ok=True) - maybe_nuke_cmake_cache(CMAKE_BUILD_DIR) - print(f"CMake build dir: {CMAKE_BUILD_DIR}", file=sys.stderr) - cmake_args = [ - "-GNinja", - "--log-level=VERBOSE", - "-DSHORTFIN_BUNDLE_DEPS=ON", - f"-DCMAKE_BUILD_TYPE={cfg}", - "-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON", - # TODO: This shouldn't be hardcoded... but shortfin doesn't - # compile without it. - "-DCMAKE_C_COMPILER=clang", - "-DCMAKE_CXX_COMPILER=clang++", - ] - - # Only do a from-scratch configure if not already configured. - cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt") - if not os.path.exists(cmake_cache_file): - print(f"Configuring with: {cmake_args}", file=sys.stderr) - subprocess.check_call( - ["cmake", SOURCE_DIR] + cmake_args, cwd=CMAKE_BUILD_DIR - ) - else: - print(f"Not re-configing (already configured)", file=sys.stderr) - - # Build. - subprocess.check_call(["cmake", "--build", "."], cwd=CMAKE_BUILD_DIR) - print("Build complete.", file=sys.stderr) - - # We only take _shortfin_default from the build. - target_dir = os.path.join( - os.path.abspath(self.build_lib), "_shortfin_default" - ) - print(f"Building in target: {target_dir}", file=sys.stderr) - os.makedirs(target_dir, exist_ok=True) - print("Copying build to target.", file=sys.stderr) - if os.path.exists(target_dir): - shutil.rmtree(target_dir) - shutil.copytree( - os.path.join( - CMAKE_BUILD_DIR, - "bindings", - "python", - "_shortfin_default", - ), - target_dir, - symlinks=False, - ) - - -PYTHON_SOURCE_DIR = REL_SOURCE_DIR / "bindings" / "python" -PYTHON_BINARY_DIR = REL_BINARY_DIR / "bindings" / "python" - -# We need some directories to exist before setup. -def populate_built_package(abs_dir): - """Makes sure that a directory and __init__.py exist. - - This needs to unfortunately happen before any of the build process - takes place so that setuptools can plan what needs to be built. - We do this for any built packages (vs pure source packages). - """ - os.makedirs(abs_dir, exist_ok=True) - with open(os.path.join(abs_dir, "__init__.py"), "wt"): - pass - - -populate_built_package(os.path.join(PYTHON_BINARY_DIR / "_shortfin_default")) - -setup( - name="shortfin", - version="0.9", - description="Shortfin native library implementation", - author="SHARK Authors", - packages=[ - "_shortfin", - "_shortfin_default", - # TODO: Conditionally map additional native library variants. - "shortfin", - ], - zip_safe=False, - package_dir={ - "_shortfin": str(PYTHON_SOURCE_DIR / "_shortfin"), - "_shortfin_default": str(PYTHON_BINARY_DIR / "_shortfin_default"), - # TODO: Conditionally map additional native library variants. - "shortfin": str(PYTHON_SOURCE_DIR / "shortfin"), - }, - ext_modules=[ - CMakeExtension("_shortfin_default.lib") - # TODO: Conditionally map additional native library variants. - ], - cmdclass={ - "build": CustomBuild, - "build_ext": NoopBuildExtension, - "build_py": CMakeBuildPy, - }, -) diff --git a/libshortfin/src/CMakeLists.txt b/libshortfin/src/CMakeLists.txt deleted file mode 100644 index 51dd5f87f..000000000 --- a/libshortfin/src/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception - -add_subdirectory(shortfin) - -# Common definitions exported from both static and dynamic libraries. 
-add_library(_shortfin_defs INTERFACE) -target_include_directories( - _shortfin_defs INTERFACE $ - $) - - -set(_INIT_INTERNAL_DEPS) -if(SHORTFIN_SYSTEMS_AMDGPU) - list(APPEND _INIT_INTERNAL_DEPS shortfin_systems_amdgpu) -endif() - -shortfin_public_library( - NAME - shortfin - COMPONENTS - shortfin_array - shortfin_local - shortfin_support - shortfin_systems_host - ${_INIT_INTERNAL_DEPS} -) - -set_target_properties(shortfin PROPERTIES VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} SOVERSION ${SOVERSION}) diff --git a/libshortfin/src/shortfin/CMakeLists.txt b/libshortfin/src/shortfin/CMakeLists.txt deleted file mode 100644 index 1bea0003b..000000000 --- a/libshortfin/src/shortfin/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception - -add_subdirectory(array) -add_subdirectory(local) -add_subdirectory(support) diff --git a/libshortfin/src/shortfin/local/systems/CMakeLists.txt b/libshortfin/src/shortfin/local/systems/CMakeLists.txt deleted file mode 100644 index effd15204..000000000 --- a/libshortfin/src/shortfin/local/systems/CMakeLists.txt +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception - -shortfin_cc_component( - NAME - shortfin_systems_host - HDRS - host.h - SRCS - host.cc - COMPONENTS - shortfin_local - shortfin_support - DEPS - iree_hal_drivers_local_task_task_driver - iree_hal_local_executable_loader - iree_hal_local_executable_plugin - iree_hal_local_executable_plugin_manager - iree_hal_local_loaders_registration_registration - iree_hal_local_local - iree_task_api - iree_task_task -) - -shortfin_cc_component( - NAME - shortfin_systems_amdgpu - HDRS - amdgpu.h - SRCS - amdgpu.cc - COMPONENTS - shortfin_local - shortfin_support - DEPS - iree_hal_drivers_hip_hip -) diff --git a/libshortfin/src/shortfin/local/systems/amdgpu.cc b/libshortfin/src/shortfin/local/systems/amdgpu.cc deleted file mode 100644 index b12cb1a5f..000000000 --- a/libshortfin/src/shortfin/local/systems/amdgpu.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 Advanced Micro Devices, Inc. -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "shortfin/local/systems/amdgpu.h" - -#include "shortfin/support/logging.h" - -namespace shortfin::local::systems { - -namespace { -const std::string_view SYSTEM_DEVICE_CLASS = "amdgpu"; -const std::string_view LOGICAL_DEVICE_CLASS = "gpu"; -const std::string_view HAL_DRIVER_PREFIX = "hip"; -} // namespace - -AMDGPUSystemBuilder::AMDGPUSystemBuilder(iree_allocator_t host_allocator) - : HostCPUSystemBuilder(host_allocator) { - InitializeDefaultSettings(); - iree_hal_hip_device_params_initialize(&default_device_params_); -} - -AMDGPUSystemBuilder::~AMDGPUSystemBuilder() = default; - -void AMDGPUSystemBuilder::InitializeDefaultSettings() { - char *raw_dylib_path_env_cstr = std::getenv("IREE_HIP_DYLIB_PATH"); - if (raw_dylib_path_env_cstr) { - std::string_view rest(raw_dylib_path_env_cstr); - for (;;) { - auto pos = rest.find(';'); - if (pos == std::string_view::npos) { - hip_lib_search_paths.emplace_back(rest); - break; - } - std::string_view first = rest.substr(0, pos); - rest = rest.substr(pos + 1); - hip_lib_search_paths.emplace_back(first); - } - } -} - -void AMDGPUSystemBuilder::Enumerate() { - if (hip_hal_driver_) return; - - iree_hal_hip_driver_options_t driver_options; - iree_hal_hip_driver_options_initialize(&driver_options); - - // Search path. - std::vector hip_lib_search_path_sv; - hip_lib_search_path_sv.resize(hip_lib_search_paths.size()); - for (size_t i = 0; i < hip_lib_search_paths.size(); ++i) { - hip_lib_search_path_sv[i].data = hip_lib_search_paths[i].data(); - hip_lib_search_path_sv[i].size = hip_lib_search_paths[i].size(); - } - driver_options.hip_lib_search_paths = hip_lib_search_path_sv.data(); - driver_options.hip_lib_search_path_count = hip_lib_search_path_sv.size(); - - SHORTFIN_THROW_IF_ERROR(iree_hal_hip_driver_create( - IREE_SV("hip"), &driver_options, &default_device_params_, - host_allocator(), hip_hal_driver_.for_output())); - - // Get available devices and filter into visible_devices_. - iree_host_size_t available_devices_count = 0; - iree::allocated_ptr raw_available_devices( - host_allocator()); - SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices( - hip_hal_driver_, host_allocator(), &available_devices_count, - raw_available_devices.for_output())); - for (iree_host_size_t i = 0; i < available_devices_count; ++i) { - iree_hal_device_info_t *info = &raw_available_devices.get()[i]; - // TODO: Filter based on visibility list. - visible_devices_.push_back(*info); - logging::info("Enumerated visible AMDGPU device: {} ({})", - to_string_view(visible_devices_.back().path), - to_string_view(visible_devices_.back().name)); - } -} - -SystemPtr AMDGPUSystemBuilder::CreateSystem() { - auto lsys = std::make_shared(host_allocator()); - Enumerate(); - // TODO: Real NUMA awareness. - lsys->InitializeNodes(1); - lsys->InitializeHalDriver(SYSTEM_DEVICE_CLASS, hip_hal_driver_); - - // Initialize all visible GPU devices. 
- for (size_t i = 0; i < visible_devices_.size(); ++i) { - auto &it = visible_devices_[i]; - iree::hal_device_ptr device; - SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id( - hip_hal_driver_, it.device_id, 0, nullptr, host_allocator(), - device.for_output())); - lsys->InitializeHalDevice(std::make_unique( - DeviceAddress( - /*system_device_class=*/SYSTEM_DEVICE_CLASS, - /*logical_device_class=*/LOGICAL_DEVICE_CLASS, - /*hal_driver_prefix=*/HAL_DRIVER_PREFIX, - /*instance_ordinal=*/i, - /*queue_ordinal=*/0, - /*instance_topology_address=*/{0}), - std::move(device), /*node_affinity=*/0, - /*node_locked=*/false)); - } - - // Initialize CPU devices if requested. - if (cpu_devices_enabled) { - // Delegate to the HostCPUSystemConfig to configure CPU devices. - // This will need to become more complicated and should happen after - // GPU configuration when mating NUMA nodes, etc. - InitializeHostCPUDefaults(); - auto *driver = InitializeHostCPUDriver(*lsys); - InitializeHostCPUDevices(*lsys, driver); - } - - lsys->FinishInitialization(); - return lsys; -} - -} // namespace shortfin::local::systems diff --git a/libshortfin/src/shortfin/local/systems/amdgpu.h b/libshortfin/src/shortfin/local/systems/amdgpu.h deleted file mode 100644 index b1a182a8d..000000000 --- a/libshortfin/src/shortfin/local/systems/amdgpu.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2024 Advanced Micro Devices, Inc. -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H -#define SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H - -#include - -#include "iree/hal/drivers/hip/api.h" -#include "shortfin/local/system.h" -#include "shortfin/local/systems/host.h" -#include "shortfin/support/api.h" -#include "shortfin/support/iree_helpers.h" - -namespace shortfin::local::systems { - -// AMD GPU device subclass. -class SHORTFIN_API AMDGPUDevice : public Device { - public: - using Device::Device; -}; - -// System configuration for some subset of AMD GPUs connected to the local -// system. Note that this inherits from HostCPUSystemBuilder, allowing joint -// configuration of a heterogenous CPU/GPU system. Depending on the specific -// system, this can involve more than simple starting CPU drivers: datacenter -// GPU systems have specific NUMA configurations that need to be mated. -class SHORTFIN_API AMDGPUSystemBuilder : public HostCPUSystemBuilder { - public: - AMDGPUSystemBuilder(iree_allocator_t host_allocator); - AMDGPUSystemBuilder() : AMDGPUSystemBuilder(iree_allocator_system()) {} - ~AMDGPUSystemBuilder(); - - // Triggers driver setup and initial device enumeration. No-op if already - // done. - void Enumerate(); - - SystemPtr CreateSystem() override; - - // Settings. - bool cpu_devices_enabled = false; - - // See iree_hal_hip_driver_options_t::hip_lib_search_paths. Each is either - // a directory or "file:" prefixed path to a specific HIP dynamic library. - // This is typically libamdhip64.so or amdhip64.dll. - // If the environment variable IREE_HIP_DYLIB_PATH is present, then it is - // split on ';' and each entry added here (for compatibility with IREE - // tools). - // Changing these paths after enumeration has no effect. - std::vector hip_lib_search_paths; - - private: - void InitializeDefaultSettings(); - - // Valid at construction time. - iree_hal_hip_device_params_t default_device_params_; - - // Valid post enumeration. 
- iree::hal_driver_ptr hip_hal_driver_; - std::vector visible_devices_; -}; - -} // namespace shortfin::local::systems - -#endif // SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H diff --git a/libshortfin/src/shortfin/local/systems/host.cc b/libshortfin/src/shortfin/local/systems/host.cc deleted file mode 100644 index 311c13398..000000000 --- a/libshortfin/src/shortfin/local/systems/host.cc +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2024 Advanced Micro Devices, Inc. -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "shortfin/local/systems/host.h" - -#include "iree/hal/local/loaders/registration/init.h" -#include "shortfin/support/iree_helpers.h" -#include "shortfin/support/logging.h" - -namespace shortfin::local::systems { - -namespace { -const std::string_view SYSTEM_DEVICE_CLASS = "host-cpu"; -const std::string_view LOGICAL_DEVICE_CLASS = "cpu"; -const std::string_view HAL_DRIVER_PREFIX = "local"; -} // namespace - -// -------------------------------------------------------------------------- // -// HostCPUSystemBuilder -// -------------------------------------------------------------------------- // - -HostCPUSystemBuilder::Deps::Deps(iree_allocator_t host_allocator) { - iree_task_executor_options_initialize(&task_executor_options); - iree_hal_task_device_params_initialize(&task_params); - iree_task_topology_initialize(&task_topology_options); -} - -HostCPUSystemBuilder::Deps::~Deps() { - iree_task_topology_deinitialize(&task_topology_options); - for (iree_host_size_t i = 0; i < loader_count; ++i) { - iree_hal_executable_loader_release(loaders[i]); - } - if (device_allocator) { - iree_hal_allocator_release(device_allocator); - } - if (executor) { - iree_task_executor_release(executor); - } - if (plugin_manager) { - iree_hal_executable_plugin_manager_release(plugin_manager); - } -} - -HostCPUSystemBuilder::HostCPUSystemBuilder(iree_allocator_t host_allocator) - : HostSystemBuilder(host_allocator), host_cpu_deps_(host_allocator) {} - -HostCPUSystemBuilder::~HostCPUSystemBuilder() = default; - -void HostCPUSystemBuilder::InitializeHostCPUDefaults() { - // Give it a default device allocator. - if (!host_cpu_deps_.device_allocator) { - SHORTFIN_THROW_IF_ERROR(iree_hal_allocator_create_heap( - iree_make_cstring_view("local"), host_allocator(), host_allocator(), - &host_cpu_deps_.device_allocator)); - } - - // And loaders. - if (host_cpu_deps_.loader_count == 0) { - SHORTFIN_THROW_IF_ERROR(iree_hal_create_all_available_executable_loaders( - /*plugin_manager=*/nullptr, IREE_ARRAYSIZE(host_cpu_deps_.loaders), - &host_cpu_deps_.loader_count, host_cpu_deps_.loaders, - host_allocator())); - } -} - -SystemPtr HostCPUSystemBuilder::CreateSystem() { - auto lsys = std::make_shared(host_allocator()); - // TODO: Real NUMA awareness. - lsys->InitializeNodes(1); - InitializeHostCPUDefaults(); - auto *driver = InitializeHostCPUDriver(*lsys); - InitializeHostCPUDevices(*lsys, driver); - lsys->FinishInitialization(); - return lsys; -} - -iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) { - // TODO: Kill these flag variants in favor of settings on the config - // object. - SHORTFIN_THROW_IF_ERROR(iree_task_executor_options_initialize_from_flags( - &host_cpu_deps_.task_executor_options)); - // TODO: Do something smarter than pinning to NUMA node 0. 
- SHORTFIN_THROW_IF_ERROR(iree_task_topology_initialize_from_flags( - /*node_id=*/0, &host_cpu_deps_.task_topology_options)); - - SHORTFIN_THROW_IF_ERROR( - iree_task_executor_create(host_cpu_deps_.task_executor_options, - &host_cpu_deps_.task_topology_options, - host_allocator(), &host_cpu_deps_.executor)); - - // Create the driver and save it in the System. - iree::hal_driver_ptr driver; - iree_hal_driver_t *unowned_driver; - SHORTFIN_THROW_IF_ERROR(iree_hal_task_driver_create( - /*identifier=*/ - { - .data = HAL_DRIVER_PREFIX.data(), - .size = HAL_DRIVER_PREFIX.size(), - }, - &host_cpu_deps_.task_params, /*queue_count=*/1, &host_cpu_deps_.executor, - host_cpu_deps_.loader_count, host_cpu_deps_.loaders, - host_cpu_deps_.device_allocator, host_allocator(), driver.for_output())); - unowned_driver = driver.get(); - lsys.InitializeHalDriver(SYSTEM_DEVICE_CLASS, std::move(driver)); - return unowned_driver; -} - -void HostCPUSystemBuilder::InitializeHostCPUDevices(System &lsys, - iree_hal_driver_t *driver) { - iree_host_size_t device_info_count = 0; - iree::allocated_ptr device_infos(host_allocator()); - SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices( - driver, host_allocator(), &device_info_count, &device_infos)); - - for (iree_host_size_t i = 0; i < device_info_count; ++i) { - iree::hal_device_ptr device; - iree_hal_device_info_t *it = &device_infos.get()[i]; - SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id( - driver, it->device_id, 0, nullptr, host_allocator(), - device.for_output())); - lsys.InitializeHalDevice(std::make_unique( - DeviceAddress( - /*system_device_class=*/SYSTEM_DEVICE_CLASS, - /*logical_device_class=*/LOGICAL_DEVICE_CLASS, - /*hal_driver_prefix=*/HAL_DRIVER_PREFIX, - /*instance_ordinal=*/i, - /*queue_ordinal=*/0, - /*instance_topology_address=*/{0}), - /*hal_device=*/std::move(device), - /*node_affinity=*/0, - /*node_locked=*/false)); - } -} - -} // namespace shortfin::local::systems diff --git a/libshortfin/tests/amdgpu_system_test.py b/libshortfin/tests/amdgpu_system_test.py deleted file mode 100644 index cf04a7caf..000000000 --- a/libshortfin/tests/amdgpu_system_test.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest - - -@pytest.mark.requires_amd_gpu -def test_create_amd_gpu_system(): - from _shortfin import lib as sfl - - sc = sfl.local.amdgpu.SystemBuilder() - ls = sc.create_system() - print(f"LOCAL SYSTEM:", ls) - for device_name in ls.device_names: - print(f" DEVICE: {device_name} = {ls.device(device_name)}") - - print(ls.devices) - print("Shutting down") - ls.shutdown() diff --git a/libshortfin/tests/api/array_test.py b/libshortfin/tests/api/array_test.py deleted file mode 100644 index 729f091d8..000000000 --- a/libshortfin/tests/api/array_test.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import array -import pytest - -import shortfin as sf -import shortfin.array as sfnp - - -@pytest.fixture -def lsys(): - sc = sf.host.CPUSystemBuilder() - lsys = sc.create_system() - yield lsys - lsys.shutdown() - - -@pytest.fixture -def scope(lsys): - return lsys.create_scope() - - -@pytest.fixture -def device(scope): - return scope.device(0) - - -def test_storage_constructor(lsys, device): - async def main(): - s = sfnp.storage.allocate_host(device, 8) - s.fill(b"\0\1\2\3") - await device - ary = sfnp.device_array(s, [2, 4], sfnp.uint8) - assert ary.dtype == sfnp.uint8 - assert ary.shape == [2, 4] - assert str(ary) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}" - assert ary.device == device - assert ary.storage == s - - lsys.run(main()) - - -def test_device_constructor(lsys, device): - async def main(): - ary = sfnp.device_array(device, [2, 4], sfnp.uint8) - ary.storage.fill(b"\0\1\2\3") - await device - assert ary.dtype == sfnp.uint8 - assert ary.shape == [2, 4] - assert str(ary) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}" - assert ary.device == device - - lsys.run(main()) - - -def test_fill_copy_from_for_transfer(lsys, device): - async def main(): - src = sfnp.device_array(device, [2, 4], sfnp.uint8) - src.fill(b"\0\1\2\3") - dst = src.for_transfer() - dst.copy_from(src) - await device - assert str(dst) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}" - - lsys.run(main()) - - -def test_fill_copy_to_for_transfer(lsys, device): - async def main(): - src = sfnp.device_array(device, [2, 4], sfnp.uint8) - src.fill(b"\0\1\2\3") - dst = src.for_transfer() - src.copy_to(dst) - await device - assert str(dst) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}" - - lsys.run(main()) - - -def test_shape_overflow(lsys, device): - async def main(): - s = sfnp.storage.allocate_host(device, 4) - _ = sfnp.device_array(s, [2, 4], sfnp.uint8) - - with pytest.raises( - ValueError, match="Array storage requires at least 8 bytes but has only 4" - ): - lsys.run(main()) - - -@pytest.mark.parametrize( - "dtype,code,py_value,expected_str", - [ - (sfnp.int8, "b", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), - (sfnp.int16, "h", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), - (sfnp.int32, "i", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), - ( - sfnp.float32, - "f", - 42.0, - "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", - ), - ( - sfnp.float64, - "d", - 42.0, - "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", - ), - ], -) -def test_xtensor_types(scope, dtype, code, py_value, expected_str): - ary = sfnp.device_array.for_host(scope.device(0), [2, 4], dtype) - ary.storage.data = array.array(code, [py_value] * 8) - s = str(ary) - print("__str__ =", s) - assert expected_str == s, f"Expected '{expected_str}' == '{s}'" - r = repr(ary) - print("__repr__ =", r) - assert expected_str in r, f"Expected '{expected_str}' in '{r}'" diff --git a/libshortfin/tests/local_scope_test.py b/libshortfin/tests/local_scope_test.py deleted file mode 100644 index 379f13d4b..000000000 --- a/libshortfin/tests/local_scope_test.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest -import time - -from _shortfin import lib as sfl - - -@pytest.fixture -def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - ls = sc.create_system() - yield ls - ls.shutdown() - - -@pytest.fixture -def scope(lsys): - # TODO: Should adopt the main thread. - worker = lsys.create_worker("main") - return lsys.create_scope(worker) - - -def test_raw_device_access(scope): - first_name = scope.device_names[0] - assert first_name == "cpu0" - first_device = scope.raw_device(0) # By index - assert isinstance(first_device, sfl.local.host.HostCPUDevice) - assert first_device is scope.raw_device(first_name) # By name - print(first_device) - devices = scope.raw_devices - named_devices = scope.named_devices - assert first_name in named_devices - assert devices[0] is named_devices[first_name] - assert devices[0] is first_device - with pytest.raises(ValueError): - scope.raw_device("cpu1") - with pytest.raises(ValueError): - scope.raw_device(1) - - -def test_devices_collection_access(scope): - # # Access via devices pseudo collection. - first_device = scope.raw_device(0) - assert scope.devices.cpu0.raw_device is first_device - assert scope.devices[0].raw_device is first_device - assert scope.devices["cpu0"].raw_device is first_device - assert len(scope.devices) == 1 - with pytest.raises(ValueError): - scope.devices.cpu1 - with pytest.raises(ValueError): - scope.devices[1] - - -def test_device_affinity_repr(scope): - assert ( - repr(sfl.local.DeviceAffinity(scope.raw_device(0))) - == "DeviceAffinity(host-cpu:0:0@0[0x1])" - ) - assert repr(sfl.local.DeviceAffinity()) == "DeviceAffinity(ANY)" - - -def test_device_affinity_resolve(scope): - # TODO: Need a scope with multiple devices to test errors. - print(scope.device(0, "cpu0", scope.raw_device(0))) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..e736fe3bd --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +# Used for managing pre-commit flows. +pre-commit + +# Type checking +mypy==1.8.0 +types-requests==2.31.0.20240125 + +# Testing +pytest==8.0.0 +pytest-xdist==3.5.0 diff --git a/requirements.txt b/requirements.txt index 6a5cccc92..cc2edf876 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,4 @@ -# Runtime deps. -gguf==0.6.0 -numpy==1.26.3 -onnx==1.15.0 - -# Model deps. -huggingface-hub==0.22.2 -transformers==4.40.0 -sentencepiece==0.2.0 - -# It is expected that you have installed a PyTorch version/variant specific -# to your needs, so we only include a minimum version spec. -# TODO: Use a versioned release once 2.3.0 drops. -torch>=2.3.0.dev1 - -# Used for managing pre-commit flows. -pre-commit - -# Type checking -mypy==1.8.0 -types-requests==2.31.0.20240125 - -# Testing -parameterized -pytest==8.0.0 -pytest-xdist==3.5.0 +-r sharktank/requirements.txt +-r sharktank/requirements-tests.txt +-r shortfin/requirements-tests.txt +-r requirements-dev.txt diff --git a/shark-ai/.gitignore b/shark-ai/.gitignore new file mode 100644 index 000000000..80bf001b8 --- /dev/null +++ b/shark-ai/.gitignore @@ -0,0 +1,2 @@ +# Local-only config options +requirements.txt diff --git a/shark-ai/README.md b/shark-ai/README.md new file mode 100644 index 000000000..0bb1abafd --- /dev/null +++ b/shark-ai/README.md @@ -0,0 +1,3 @@ +# SHARK AI meta package + +Meta package to install `shortfin` and compatible IREE packages. 
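+
+As a rough usage sketch (assuming a working Python environment with `pip`; the
+`[apps]` extra mirrors the optional dependency declared in `pyproject.toml` and
+the install command shown in `docs/user_guide.md`):
+
+```bash
+# Install the meta package; the optional [apps] extra pulls in shortfin[apps].
+python -m pip install shark-ai[apps]
+
+# Sanity-check that the shortfin runtime package is importable.
+python -c "import shortfin; print(shortfin.__name__)"
+```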
diff --git a/shark-ai/build_tools/build_linux_package.sh b/shark-ai/build_tools/build_linux_package.sh new file mode 100755 index 000000000..d16f339b1 --- /dev/null +++ b/shark-ai/build_tools/build_linux_package.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# build_linux_package.sh +# +# Builds shark-ai Python package for Linux. +# +# Usage: +# ./build_tools/build_linux_package.sh + +set -xeu -o errtrace + +THIS_DIR="$(cd $(dirname $0) && pwd)" +REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)" +OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}" + +python -m pip wheel --disable-pip-version-check --no-deps -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shark-ai" + +wheel_output="$(echo "${OUTPUT_DIR}/shark_ai-"*".whl")" +ls "${wheel_output}" diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml new file mode 100644 index 000000000..3f7e4a1da --- /dev/null +++ b/shark-ai/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "shark-ai" +authors = [ + {name = "SHARK Authors"}, +] +description = "SHARK AI meta package" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", +] +# Version is set via the `setup.py` and requirements are set via files below. +dynamic = ["version", "dependencies"] + +[project.urls] +Repository = "https://github.com/nod-ai/shark-ai" + +[project.optional-dependencies] +onnx = [ + "iree-base-compiler[onnx]", +] +apps = [ + "shortfin[apps]", +] + +[tool.setuptools] +packages = [] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} diff --git a/shark-ai/setup.py b/shark-ai/setup.py new file mode 100644 index 000000000..5ceac55bd --- /dev/null +++ b/shark-ai/setup.py @@ -0,0 +1,33 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import os +from pathlib import Path + +from setuptools import setup + +THIS_DIR = Path(__file__).parent.resolve() + +# Setup and get version information. +# The `version_local.json` is generated by calling: +# `build_tools/python_deploy/compute_common_version.py -stable --write-json` +VERSION_FILE_LOCAL = os.path.join(THIS_DIR, "version_local.json") + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +version_info = load_version_info(VERSION_FILE_LOCAL) + +PACKAGE_VERSION = version_info.get("package-version") +print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") + +setup( + version=f"{PACKAGE_VERSION}", +) diff --git a/sharktank/.gitignore b/sharktank/.gitignore new file mode 100644 index 000000000..000e575d5 --- /dev/null +++ b/sharktank/.gitignore @@ -0,0 +1,2 @@ +# Local-only config options +version_info_rc.json diff --git a/sharktank/README.md b/sharktank/README.md index 9dd239eba..7770595ed 100644 --- a/sharktank/README.md +++ b/sharktank/README.md @@ -10,6 +10,10 @@ This sub-project is a work in progress. 
It is intended to be a repository of layers, model recipes, and conversion tools from popular LLM quantization tooling. +## Project Status + +[![CI - Perplexity](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml/badge.svg?branch=main&event=schedule)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml) + ## Examples The repository will ultimately grow a curated set of models and tools for @@ -40,3 +44,30 @@ python -m sharktank.examples.export_paged_llm_v1 \ ```shell python -m sharktank.tools.dump_gguf --hf-dataset=open_llama_3b_v2_f16_gguf ``` + +## Package Python Release Builds + +* To build wheels for Linux: + + ```bash + ./build_tools/build_linux_package.sh + ``` + + That should produce + `build_tools/wheelhouse/sharktank-{X.Y.Z}.dev0-py3-none-any.whl`, which can + then be installed with + + ```bash + python3 -m pip install build_tools/wheelhouse/sharktank-{X.Y.Z}.dev0-py3-none-any.whl + ``` + +* To build a wheel for your host OS/arch manually: + + ```bash + # Build sharktank.*.whl into the dist/ directory + # e.g. `sharktank-3.0.0.dev0-py3-none-any.whl` + python3 -m pip wheel -v -w dist . + + # Install the built wheel. + python3 -m pip install dist/*.whl + ``` diff --git a/sharktank/build_tools/build_linux_package.sh b/sharktank/build_tools/build_linux_package.sh new file mode 100755 index 000000000..3ed3a7fbb --- /dev/null +++ b/sharktank/build_tools/build_linux_package.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# build_linux_package.sh +# +# Builds sharktank Python package for Linux. +# +# Note: requires a modern Python (3.12+ seems to work). Troubleshooting help: +# * https://stackoverflow.com/a/77284076 +# * https://stackoverflow.com/a/77364602 +# Older versions like 3.10 don't include the package name and set as UNKNOWN? +# * Might just need some local packages updated? +# +# Usage: +# ./build_tools/build_linux_package.sh + +set -xeu -o errtrace + +THIS_DIR="$(cd $(dirname $0) && pwd)" +REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)" +OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}" + +python -m pip wheel --disable-pip-version-check --no-deps -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/sharktank" + +wheel_output="$(echo "${OUTPUT_DIR}/sharktank-"*".whl")" +ls "${wheel_output}" diff --git a/sharktank/conftest.py b/sharktank/conftest.py index fd9c47447..475f386be 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -5,6 +5,9 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from pathlib import Path +import pytest +from pytest import FixtureRequest +from typing import Optional, Any # Tests under each top-level directory will get a mark. @@ -24,3 +27,312 @@ def pytest_collection_modifyitems(items, config): mark = TLD_MARKS.get(tld) if mark: item.add_marker(mark) + + +def pytest_addoption(parser): + parser.addoption( + "--mlir", + type=Path, + default=None, + help="Path to exported MLIR program. If not specified a temporary file will be used.", + ) + parser.addoption( + "--module", + type=Path, + default=None, + help="Path to exported IREE module. If not specified a temporary file will be used.", + ) + parser.addoption( + "--parameters", + type=Path, + default=None, + help="Exported model parameters. 
If not specified a temporary file will be used.", + ) + parser.addoption( + "--prefix", + type=str, + default=None, + help=( + "Path prefix for test artifacts. " + "Other arguments may override this for specific values." + ), + ) + parser.addoption( + "--caching", + action="store_true", + default=False, + help="Load cached results if present instead of recomputing.", + ) + + parser.addoption( + "--longrun", + action="store_true", + dest="longrun", + default=False, + help="Enable long tests", + ) + + parser.addoption( + "--run-quick-llama-test", + action="store_true", + dest="run-quick-llama-test", + default=False, + help="Enable llama 8b f16 decomposed benchmarking test", + ) + + parser.addoption( + "--run-nightly-llama-tests", + action="store_true", + dest="run-nightly-llama-tests", + default=False, + help="Enable all llama benchmarking tests", + ) + + parser.addoption( + "--with-t5-data", + action="store_true", + default=False, + help=( + "Enable tests that use T5 data like models that is not a part of the source " + "code. The user is expected to provide the data" + ), + ) + + # TODO: Remove all hardcoded paths in CI tests + parser.addoption( + "--llama3-8b-tokenizer-path", + type=Path, + action="store", + help="Llama3.1 8b tokenizer path, defaults to 30F CI system path", + ) + + parser.addoption( + "--llama3-8b-f16-model-path", + type=Path, + action="store", + help="Llama3.1 8b model path, defaults to 30F CI system path", + ) + + parser.addoption( + "--llama3-8b-fp8-model-path", + type=Path, + action="store", + default=None, + help="Llama3.1 8b fp8 model path", + ) + + parser.addoption( + "--llama3-405b-tokenizer-path", + type=Path, + action="store", + help="Llama3.1 405b tokenizer path, defaults to 30F CI system path", + ) + + parser.addoption( + "--llama3-405b-f16-model-path", + type=Path, + action="store", + help="Llama3.1 405b model path, defaults to 30F CI system path", + ) + + parser.addoption( + "--llama3-405b-fp8-model-path", + type=Path, + action="store", + default=None, + help="Llama3.1 405b fp8 model path", + ) + + # To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py. + # https://github.com/ggerganov/llama.cpp/blob/9abe9eeae98b11fa93b82632b264126a010225ff/convert_hf_to_gguf.py + # E.g. 
+ # git lfs install + # git clone https://huggingface.co/google/t5-v1_1-small + # convert_hf_to_gguf.py \ + # --outfile t5-v1_1-small.gguf \ + # --outtype=f32 \ + # t5-v1_1-small + parser.addoption( + "--google-t5-v1-1-small-fp32-model-path", + type=Path, + default="/data/t5/small/google__t5-v1_1-small_fp32.gguf", + help="Google T5 v1.1 small fp32 model path", + ) + parser.addoption( + "--google-t5-v1-1-xxl-fp32-model-path", + type=Path, + default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf", + help="Google T5 v1.1 XXL fp32 model path", + ) + + parser.addoption( + "--baseline-perplexity-scores", + type=Path, + action="store", + default="sharktank/tests/evaluate/baseline_perplexity_scores.json", + help="Llama3.1 8B & 405B model baseline perplexity scores", + ) + + parser.addoption( + "--iree-device", + type=str, + action="store", + help="List an IREE device from iree-run-module --list_devices", + ) + + parser.addoption( + "--iree-hip-target", + action="store", + help="Specify the iree-hip target version (e.g., gfx942)", + ) + + parser.addoption( + "--iree-hal-target-backends", + action="store", + help="Specify the iree-hal target backend (e.g., rocm)", + ) + + parser.addoption( + "--tensor-parallelism-size", + action="store", + type=int, + default=1, + help="Number of devices for tensor parallel sharding", + ) + + parser.addoption( + "--bs", + action="store", + type=int, + default=4, + help="Batch size for mlir export", + ) + + +def set_fixture_from_cli_option( + request: FixtureRequest, + cli_option_name: str, + class_attribute_name: Optional[str] = None, +) -> Optional[Any]: + res = request.config.getoption(cli_option_name) + if request.cls is None: + return res + else: + if class_attribute_name is None: + class_attribute_name = cli_option_name + setattr(request.cls, class_attribute_name, res) + + +@pytest.fixture(scope="class") +def mlir_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "mlir", "mlir_path") + + +@pytest.fixture(scope="class") +def module_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "module", "module_path") + + +@pytest.fixture(scope="class") +def parameters_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "parameters", "parameters_path") + + +@pytest.fixture(scope="class") +def path_prefix(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option(request, "prefix", "path_prefix") + + +@pytest.fixture(scope="class") +def caching(request: FixtureRequest) -> Optional[bool]: + return set_fixture_from_cli_option(request, "caching") + + +@pytest.fixture(scope="class") +def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option( + request, "tensor_parallelism_size", "tensor_parallelism_size" + ) + + +@pytest.fixture(scope="class") +def baseline_perplexity_scores(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option( + request, "baseline_perplexity_scores", "baseline_perplexity_scores" + ) + + +@pytest.fixture(scope="class") +def batch_size(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option(request, "bs", "batch_size") + + +@pytest.fixture(scope="class") +def get_model_artifacts(request: FixtureRequest): + model_path = {} + model_path["llama3_8b_tokenizer_path"] = set_fixture_from_cli_option( + request, "--llama3-8b-tokenizer-path", "llama3_8b_tokenizer" + ) + model_path["llama3_8b_f16_model_path"] = 
set_fixture_from_cli_option( + request, "--llama3-8b-f16-model-path", "llama3_8b_f16_model" + ) + model_path["llama3_8b_fp8_model_path"] = set_fixture_from_cli_option( + request, "--llama3-8b-fp8-model-path", "llama3_8b_fp8_model" + ) + model_path["llama3_405b_tokenizer_path"] = set_fixture_from_cli_option( + request, "--llama3-405b-tokenizer-path", "llama3_405b_tokenizer" + ) + model_path["llama3_405b_f16_model_path"] = set_fixture_from_cli_option( + request, "--llama3-405b-f16-model-path", "llama3_405b_f16_model" + ) + model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option( + request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model" + ) + model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option( + request, + "--google-t5-v1-1-small-fp32-model-path", + "google__t5_v1_1_small_fp32_model", + ) + model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option( + request, + "--google-t5-v1-1-xxl-fp32-model-path", + "google__t5_v1_1_xxl_fp32_model", + ) + return model_path + + +@pytest.fixture(scope="class") +def get_iree_flags(request: FixtureRequest): + model_path = {} + model_path["iree_device"] = set_fixture_from_cli_option( + request, "--iree-device", "iree_device" + ) + model_path["iree_hip_target"] = set_fixture_from_cli_option( + request, "--iree-hip-target", "iree_hip_target" + ) + model_path["iree_hal_target_backends"] = set_fixture_from_cli_option( + request, "--iree-hal-target-backends", "iree_hal_target_backends" + ) + + +# The following three functions allow us to add a "XFail Reason" column to the html reports for each test +@pytest.hookimpl(optionalhook=True) +def pytest_html_results_table_header(cells): + cells.insert(2, "XFail Reason") + + +@pytest.hookimpl(optionalhook=True) +def pytest_html_results_table_row(report, cells): + if hasattr(report, "wasxfail"): + cells.insert(2, f"{report.wasxfail}") + else: + cells.insert(2, f"") + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + report = outcome.get_result() + + if report.when == "call" and hasattr(item, "wasxfail"): + report.wasxfail = item.wasxfail diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml index bc5203c86..01cad409b 100644 --- a/sharktank/pyproject.toml +++ b/sharktank/pyproject.toml @@ -2,6 +2,40 @@ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" +[project] +name = "sharktank" +authors = [ + {name = "SHARK Authors"}, +] +description = "SHARK layers and inference models for genai" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", +] + +# Version is set via the `setup.py` and requirements are set via files below. 
+dynamic = ["version", "dependencies", "optional-dependencies"] + +[project.urls] +Repository = "https://github.com/nod-ai/shark-ai" + +[tool.setuptools.packages.find] +where = ["."] +include = ["sharktank*"] +namespaces = true + +[tool.setuptools.package-data] +sharktank = ["py.typed", "kernels/templates/*.mlir"] + +[tool.setuptools.dynamic.dependencies] +file = ["requirements.txt"] + +[tool.setuptools.dynamic.optional-dependencies] +testing = {file = ["requirements-tests.txt"]} + [tool.pytest.ini_options] addopts = [ "-ra", diff --git a/sharktank/requirements-dev-turbine.txt b/sharktank/requirements-dev-turbine.txt new file mode 100644 index 000000000..0d0dc7619 --- /dev/null +++ b/sharktank/requirements-dev-turbine.txt @@ -0,0 +1 @@ +-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" diff --git a/sharktank/requirements-tests.txt b/sharktank/requirements-tests.txt new file mode 100644 index 000000000..d5b4b0c0e --- /dev/null +++ b/sharktank/requirements-tests.txt @@ -0,0 +1,4 @@ +datasets==3.0.0 +parameterized +pytest==8.0.0 +pytest-html diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 6b21f239f..6b533d977 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -1 +1,21 @@ -gguf +iree-turbine + +# Runtime deps. +gguf==0.6.0 +numpy<2.0 + +# Needed for newer gguf versions (TODO: remove when gguf package includes this) +# sentencepiece>=0.1.98,<=0.2.0 + +# Model deps. +huggingface-hub==0.22.2 +transformers==4.40.0 +datasets + +# It is expected that you have installed a PyTorch version/variant specific +# to your needs, so we only include a minimum version spec. +torch>=2.3.0 + +# Serving deps. +fastapi==0.112.2 +uvicorn==0.30.6 diff --git a/sharktank/setup.py b/sharktank/setup.py index b8caf9e7d..182f94abc 100644 --- a/sharktank/setup.py +++ b/sharktank/setup.py @@ -6,102 +6,31 @@ import json import os -import distutils.command.build from pathlib import Path -from setuptools import find_namespace_packages, setup # type: ignore +from setuptools import setup -THIS_DIR = Path(__file__).resolve().parent -REPO_DIR = THIS_DIR.parent -VERSION_INFO_FILE = REPO_DIR / "version_info.json" +SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) +# Setup and get version information. +VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json") +VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json") -with open( - os.path.join( - THIS_DIR, - "README.md", - ), - "rt", -) as f: - README = f.read() - -def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: +def load_version_info(version_file): + with open(version_file, "rt") as f: return json.load(f) -version_info = load_version_info() -PACKAGE_VERSION = version_info["package-version"] - -packages = find_namespace_packages( - include=[ - "sharktank", - "sharktank.*", - ], -) - -print("Found packages:", packages) - -# Lookup version pins from requirements files. -requirement_pins = {} - - -def load_requirement_pins(requirements_file: Path): - with open(requirements_file, "rt") as f: - lines = f.readlines() - pin_pairs = [line.strip().split("==") for line in lines if "==" in line] - requirement_pins.update(dict(pin_pairs)) - - -load_requirement_pins(REPO_DIR / "requirements.txt") - - -def get_version_spec(dep: str): - if dep in requirement_pins: - return f">={requirement_pins[dep]}" - else: - return "" - - -# Override build command so that we can build into _python_build -# instead of the default "build". 
This avoids collisions with -# typical CMake incantations, which can produce all kinds of -# hilarity (like including the contents of the build/lib directory). -class BuildCommand(distutils.command.build.build): - def initialize_options(self): - distutils.command.build.build.initialize_options(self) - self.build_base = "_python_build" +try: + version_info = load_version_info(VERSION_FILE_LOCAL) +except FileNotFoundError: + print("version_local.json not found. Default to dev build") + version_info = load_version_info(VERSION_FILE) +PACKAGE_VERSION = version_info.get("package-version") +print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") setup( - name=f"sharktank", version=f"{PACKAGE_VERSION}", - author="SHARK Authors", - author_email="stella@nod.ai", - description="SHARK layers and inference models for genai", - long_description=README, - long_description_content_type="text/markdown", - url="https://github.com/nod-ai/sharktank", - license="Apache-2.0", - classifiers=[ - "Development Status :: 3 - Alpha", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - ], - packages=packages, - include_package_data=True, - package_data={ - "sharktank": ["py.typed", "kernels/templates/*.mlir"], - }, - install_requires=[ - "shark-turbine", - ], - extras_require={ - "testing": [ - f"pytest{get_version_spec('pytest')}", - f"pytest-xdist{get_version_spec('pytest-xdist')}", - ], - }, - cmdclass={"build": BuildCommand}, ) diff --git a/sharktank/sharktank/evaluate/README.md b/sharktank/sharktank/evaluate/README.md new file mode 100644 index 000000000..beb0281cd --- /dev/null +++ b/sharktank/sharktank/evaluate/README.md @@ -0,0 +1,40 @@ +# LLM Evaluation Pipeline + +## Setup +Setup SHARK Platform's Evaluation Pipeline + +``` +pip install -r sharktank/requirements-tests.txt +``` + +### Perplexity + +Perplexity score measures the ability of a language model to predict the next token in a sequence. A lower score indicates that a model has higher certainty in it's predictions. Perplexity acts as an intrinsic evaluation metric that measures the model quality, independent of any downstream task. + +In SHARK-Platform, we use perplexity to track code regressions and quality loss across quantized models (with FP16 as baseline). We use 100 prompts randomly selected from the Wikitext-2 test set and calculate the mean perplexities shown below. These numbers are neither comparable between models with different tokenizers nor with other projects due to varying implementations. + +* Test perplexity for Llama3.1 8B (FP16) model: + +```bash +pytest sharktank/tests/evaluate/perplexity_test.py --longrun +``` + +* Calculate perplexity for a new model: + +```bash +python -m sharktank.evaluate.perplexity \ + --gguf-file=llama3_70b_f16.gguf \ + --tokenizer-config-json=tokenizer_config.json +``` + +### Perplexity Scoreboard + +| CPU | GPU | +|:-------------: |:----------:| +| AMD EPYC 9554 | MI300X | + +#### LLaMA 3.1 + +|Models |Model size (GB) |Torch score |IREE score | +|:----------------------|:---------------|:-------------|:-------------| +|8B FP16 TP1 decomposed |16.07 |14.930181 |14.991893 | diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py new file mode 100644 index 000000000..6060eb91b --- /dev/null +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -0,0 +1,492 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys +import logging +import json +import time +import random +import re +from datetime import timedelta +from tqdm import tqdm + +import numpy as np + +from datasets import load_dataset + +import torch +from torch.nn import CrossEntropyLoss + +from sharktank.models.llama.llama import * +from sharktank.models.mixtral.mixtral import * +from sharktank.models.grok.grok import * + +from ..models.llama.sharding import shard_theta + +from sharktank.layers import * +from sharktank.types import * + +from sharktank.utils import cli +from sharktank.utils.vmfb_runner import * +from sharktank.utils.load_llm import * +from sharktank.utils.create_cache import * +from sharktank.utils.export_artifacts import * + +log_levels = { + "info": logging.INFO, + "debug": logging.DEBUG, +} +logger = logging.getLogger("eval") + +logger.setLevel(log_levels["info"]) + +logger.root.handlers[0].setFormatter( + logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") +) + +__all__ = ["Perplexity", "run_perplexity"] + + +class Perplexity: + """ + Perplexity (PPL) is one of the most common metrics for evaluating language models. + It is defined as the exponentiated average negative log-likelihood of a sequence, + calculated with exponent base `e`. + + For more information, see https://huggingface.co/docs/transformers/perplexity + """ + + def __init__( + self, + torch_device, + iree_device, + iree_hip_target, + iree_hal_target_backends, + kv_cache_type, + tensor_parallelism_size, + attention_kernel, + ): + self.torch_device = torch_device + self.iree_device = iree_device + self.iree_hip_target = iree_hip_target + self.iree_hal_target_backends = iree_hal_target_backends + self.kv_cache_type = kv_cache_type + self.activation_dtype = torch.float16 + self.attention_dtype = torch.float16 + self.tensor_parallelism_size = tensor_parallelism_size + self.attention_kernel = attention_kernel + + def timeit(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) + + func_name = func.__name__ + if func_name == "get_perplexity": + func_name = f"Calculate perplexity" + elif func_name == "compile_model": + func_name = f"Export & compile" + logger.info(f" {func_name}: {time_taken}") + return result + + return wrapper + + def print_token_comparison(self, i): + if i <= self.max_prompt_length: + batch_predicted_token_id = [[i[-1]] for i in self.batch.results] + batch_predicted_token = self.generator.tokenizer.decode( + batch_predicted_token_id + ) + logger.debug(f"Predicted:") + logger.debug(f"{batch_predicted_token}") + logger.debug(f"{batch_predicted_token_id}") + + expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist() + expected_token = self.generator.tokenizer.decode(expected_token_id) + logger.debug(f"Expected:") + logger.debug(f"{expected_token}") + logger.debug(f"{expected_token_id}") + + @timeit + def compile_model(self, weight_path_str): + self.weight_path_str = weight_path_str + + logger.info(f" 
Compiling: {self.weight_path_str}") + + export_artifacts = ExportArtifacts( + irpa_path=self.weight_path_str, + batch_size=self.bs, + iree_hip_target=self.iree_hip_target, + iree_hal_target_backends=self.iree_hal_target_backends, + attention_kernel=self.attention_kernel, + tensor_parallelism_size=self.tensor_parallelism_size, + ) + vmfb_path = export_artifacts.get_artifacts() + return vmfb_path + + @timeit + def load_model(self, weight_path, tokenizer, vmfb_path): + + self.config = LlamaModelConfig( + hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), + block_seq_stride=16, + kv_cache_type=self.kv_cache_type, + device=self.torch_device, + activation_dtype=self.activation_dtype, + attention_dtype=self.attention_dtype, + tensor_parallelism_size=self.tensor_parallelism_size, + ) + + if self.config.tensor_parallelism_size > 1: + weight_path.root_theta = shard_theta(weight_path.root_theta, self.config) + + theta = weight_path.root_theta + + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) + else: + model = PagedMixtralModelV1(theta, self.config) + else: + model = PagedLlamaModelV1(theta, self.config) + + self.generator = TorchGenerator(model, tokenizer) + + self.runner = vmfbRunner( + device=self.iree_device, + vmfb_path=vmfb_path, + external_weight_path=self.weight_path_str, + ) + + self.haldevice = self.runner.config.device + + @timeit + def get_prompts(self, num_prompts): + test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ + "text" + ] + num_test_prompts = 219 + + random.seed(0) + test_prompts = random.sample(test_prompts, num_test_prompts) + + # Ignore prompts that are: empty, less than 20 tokens or a title. + test_prompts = [ + s.replace("\n", "").rstrip() + for s in test_prompts + if s != "" and len(s.split()) >= 20 and s.count("=") < 2 + ][0:num_prompts] + + self.test_prompts = test_prompts + + self.bs = len(test_prompts) + + logger.info(f" Batch size: {self.bs}") + + @timeit + def prefill_vmfb(self, token_batch, i): + + seq_block_ids = self.batch.pad_block_ids() + prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"]( + token_batch, + self.batch.seq_lens, + seq_block_ids, + self.cache_state, + ) + + prefill_logits = torch.tensor(prefill_logits[:, :, :]) + + tokens = torch.tensor( + self.generator.model.extract_tokens_from_logits( + prefill_logits, self.batch.seq_lens + ) + ).unsqueeze(1) + self.batch.add_result_token(tokens) + + self.print_token_comparison(i) + return prefill_logits + + def decode_vmfb(self, token_batch, i): + logger.debug("Decode:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + logger.debug(f"{token_batch.tolist()}") + + start_positions = self.batch.seq_lens.clone() + self.batch.seq_lens.add_(1) + self.batch.allocate_seq_block_ids() + seq_block_ids = self.batch.pad_block_ids() + + decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"]( + token_batch, + self.batch.seq_lens, + start_positions, + seq_block_ids, + self.cache_state, + ) + + decode_logits = torch.tensor(decode_logits[:, :, :]) + + tokens = torch.tensor( + self.generator.model.extract_tokens_from_logits( + decode_logits, [1] * self.bs + ), + device=self.generator.model.device, + ).unsqueeze(1) + self.batch.add_result_token(tokens) + self.print_token_comparison(i) + return decode_logits + + @timeit + def get_logits(self, page_cache_size): + + is_first_token = True + start = 0 + for i in tqdm( + range(start, 
self.max_prompt_length - 1), + mininterval=300, + desc="eval: Calculating logits", + ): + logger.debug(f"Iteration: {i}") + + if is_first_token: + + token_batch = self.token_ids[:, : i + 1] + + logger.debug(f"Prefill:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + + token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( + token_ids=token_batch.tolist(), + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + logger.debug(f"{token_batch}") + + token_batch = torch.tensor(token_batch, device=self.torch_device) + self.seq_lens_batch = torch.tensor( + seq_lens_batch, device=self.torch_device + ) + + self.batch = self.generator.begin_eval_batch( + token_batch=token_batch, + seq_lens_batch=self.seq_lens_batch, + bs=self.bs, + page_cache_size=page_cache_size, + ) + + self.cache_state = ireert.asdevicearray( + self.haldevice, self.batch.cache_state[0].to("cpu").numpy() + ) + + prefill_logits = self.prefill_vmfb(token_batch, i) + self.out_logits = prefill_logits[:, -1:, :] + + is_first_token = False + + else: + token_batch = self.token_ids[:, i : i + 1] + + decode_logits = self.decode_vmfb(token_batch, i) + self.out_logits = torch.cat((self.out_logits, decode_logits), 1) + + pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1] + + self.pad_logits = torch.zeros( + self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2] + ) + + self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to( + self.torch_device + ) + + @timeit + def compute_perplexity(self): + loss_fct = CrossEntropyLoss(reduction="none") + + ## perplexity = e ^ (sum(losses) / num_tokenized_tokens) + crossentropy_loss = ( + loss_fct(self.out_logits.transpose(1, 2), self.token_ids) + * self.attention_mask + ).sum(1) + crossentropy_loss = torch.tensor(crossentropy_loss.tolist()) + perplexity_batch = torch.exp( + crossentropy_loss / self.attention_mask.sum(1) + ).tolist() + + perplexity_batch = [round(ppl, 6) for ppl in perplexity_batch] + + return { + "perplexities": perplexity_batch, + "mean_perplexity": round(np.mean(perplexity_batch), 6), + } + + @timeit + def get_perplexity(self): + + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.torch_device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.torch_device) + ) + + self.get_logits(page_cache_size=self.page_cache_size) + + self.out_logits = self.out_logits[..., :-1, :].contiguous() + self.token_ids = self.token_ids[..., 1:].contiguous() + self.attention_mask = self.attention_mask[..., 1:].contiguous() + + logger.debug(f"Final Logits shape: {self.out_logits.shape}") + logger.debug(f"Token ids: {self.token_ids}, \n{self.token_ids.shape}") + logger.debug( + f"Mask shape: {self.attention_mask}, \n{self.attention_mask.shape}" + ) + + assert self.token_ids.shape == self.out_logits.shape[0:2] + + return self.compute_perplexity() + + +def run_perplexity( + weight_path, + weight_path_str, + tokenizer, + torch_device, + iree_device, + iree_hip_target, + 
iree_hal_target_backends, + kv_cache_type, + tensor_parallelism_size, + attention_kernel, + num_prompts, +): + start = time.time() + perplexity = Perplexity( + torch_device=torch_device, + iree_device=iree_device, + iree_hip_target=iree_hip_target, + iree_hal_target_backends=iree_hal_target_backends, + kv_cache_type=kv_cache_type, + tensor_parallelism_size=tensor_parallelism_size, + attention_kernel=attention_kernel, + ) + + perplexity.get_prompts(num_prompts=num_prompts) + + vmfb_path = perplexity.compile_model(weight_path_str) + perplexity.load_model(weight_path, tokenizer, vmfb_path) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") + + return ppl + + +def main(argv): + parser = cli.create_parser() + parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") + parser.add_argument("--torch-device", help="Torch device (or default)") + parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')") + parser.add_argument( + "--iree-hip-target", + action="store", + default="gfx942", + help="Specify the iree-hip target version (e.g., gfx942)", + ) + parser.add_argument( + "--iree-hal-target-backends", + action="store", + default="rocm", + help="Specify the iree-hal target backends (e.g., rocm)", + ) + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch_sdpa"], + ) + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="Number of devices for tensor parallel sharding", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) + + cli.add_tokenizer_options(parser) + cli.add_input_dataset_options(parser) + args = cli.parse(parser, args=argv) + + torch_device = torch.device(args.torch_device) if args.torch_device else None + iree_device = args.iree_device + kv_cache_type = args.kv_cache_type + weight_path = cli.get_input_dataset(args) + tokenizer = cli.get_tokenizer(args) + weight_path_str = str(args.irpa_file) + + ppl = run_perplexity( + weight_path=weight_path, + weight_path_str=weight_path_str, + tokenizer=tokenizer, + torch_device=torch_device, + iree_device=iree_device, + iree_hip_target=args.iree_hip_target, + iree_hal_target_backends=args.iree_hal_target_backends, + kv_cache_type=kv_cache_type, + tensor_parallelism_size=args.tensor_parallelism_size, + attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, + ) + + logger.info(f"\n{json.dumps(ppl, indent=2)}") + return ppl + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/sharktank/sharktank/evaluate/perplexity_prefill.py b/sharktank/sharktank/evaluate/perplexity_prefill.py new file mode 100644 index 000000000..2bb785801 --- /dev/null +++ b/sharktank/sharktank/evaluate/perplexity_prefill.py @@ -0,0 +1,276 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys +import logging +import time +from datetime import timedelta + +import json +import numpy as np +from tqdm import tqdm + +import torch +from torch.nn import CrossEntropyLoss + +from sharktank.layers import * +from sharktank.types import * + +from sharktank.models.llama.llama import * +from sharktank.models.mixtral.mixtral import * +from sharktank.models.grok.grok import * + +from sharktank.utils import cli +from sharktank.utils.load_llm import * + +log_levels = { + "info": logging.INFO, + "debug": logging.DEBUG, +} +logger = logging.getLogger("eval") + +logger.setLevel(log_levels["debug"]) + +logger.root.handlers[0].setFormatter( + logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") +) + +__all__ = ["Perplexity", "run_perplexity"] + + +class Perplexity: + """ + Perplexity (PPL) is one of the most common metrics for evaluating language models. + It is defined as the exponentiated average negative log-likelihood of a sequence, + calculated with exponent base `e`. + + For more information, see https://huggingface.co/docs/transformers/perplexity + """ + + def __init__( + self, + prompts: list, + device, + kv_cache_type, + ): + self.prompts = prompts + self.add_start_token = False + self.batch_size = 16 + self.bs = len(prompts) + self.device = device + self.kv_cache_type = kv_cache_type + + def timeit(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + seconds = end - start + time_taken = abs(timedelta(seconds=round(seconds))) + + if seconds < 1: + time_taken = f" {seconds * 1000} ms" + + func_name = func.__name__ + if func_name == "get_perplexity": + func_name = "Total time" + logger.info(f" {func_name}: {time_taken}") + return result + + return wrapper + + def print_token_comparison(self, i): + if i <= self.max_prompt_length: + batch_predicted_token_id = [[i[-1]] for i in self.batch.results] + batch_predicted_token = self.generator.tokenizer.decode( + batch_predicted_token_id + ) + logger.debug(f"Predicted:") + logger.debug(f"{batch_predicted_token}") + logger.debug(f"{batch_predicted_token_id}") + + expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist() + expected_token = self.generator.tokenizer.decode(expected_token_id) + logger.debug(f"Expected:") + logger.debug(f"{expected_token}") + logger.debug(f"{expected_token_id}") + + @timeit + def load_model(self, dataset, tokenizer): + + theta = dataset.root_theta + + config = LlamaModelConfig( + hp=configs.LlamaHParams.from_gguf_props(dataset.properties), + block_seq_stride=16, + kv_cache_type=self.kv_cache_type, + device=self.device, + activation_dtype=torch.float32, + attention_dtype=torch.float32, + ) + + if config.hp.expert_count: + if config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, config) + else: + model = PagedMixtralModelV1(theta, config) + else: + model = PagedLlamaModelV1(theta, config) + + self.generator = TorchGenerator(model, tokenizer) + + @timeit + def get_logits(self): + + token_ids, seq_lens = self.generator.tokenizer.encode( + self.prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + add_start_token=self.add_start_token, + ) + + logger.info(f" Prompts:") + for idx, prompt in enumerate(self.prompts): + logger.info(f" Prompt {idx} - {prompt.encode()}\n{token_ids[idx]}") + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.device) + self.attention_mask = ( + (self.token_ids != 
0).int().detach().clone().to(self.device) + ) + + is_first_token = True + for i in tqdm( + range(0, self.max_prompt_length - 1), + desc="eval: Calculating logits", + ): + token_batch = self.token_ids[:, : i + 1] + logger.debug(f"Prefill:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + + token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( + token_ids=token_batch.tolist(), + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + token_batch = torch.tensor(token_batch, device=self.device) + seq_lens_batch = torch.tensor(seq_lens_batch, device=self.device) + + self.batch = self.generator.begin_eval_batch( + token_batch=token_batch, + seq_lens_batch=seq_lens_batch, + bs=self.bs, + ) + + self.cache_state = self.batch.prefill() + self.print_token_comparison(i) + + if is_first_token: + self.out_logits = self.batch.prefill_logits[:, 0:1, :] + is_first_token = False + else: + self.out_logits = torch.cat( + (self.out_logits, self.batch.prefill_logits[:, 0:1, :]), 1 + ) + + pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1] + + self.pad_logits = torch.zeros( + self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2] + ) + + self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to( + self.device + ) + + @timeit + def compute_perplexity(self): + loss_fct = CrossEntropyLoss(reduction="none") + + ## perplexity = e ^ (sum(losses) / num_tokenized_tokens) + crossentropy_loss = ( + loss_fct(self.out_logits.transpose(1, 2), self.token_ids) + * self.attention_mask + ).sum(1) + crossentropy_loss = torch.tensor(crossentropy_loss.tolist()) + perplexity_batch = torch.exp( + crossentropy_loss / self.attention_mask.sum(1) + ).tolist() + + return { + "perplexities": perplexity_batch, + "mean_perplexity": np.mean(perplexity_batch), + } + + @timeit + def get_perplexity(self): + + self.get_logits() + + self.out_logits = self.out_logits[..., :-1, :].contiguous() + self.token_ids = self.token_ids[..., 1:].contiguous() + self.attention_mask = self.attention_mask[..., 1:].contiguous() + + assert self.token_ids.shape == self.out_logits.shape[0:2] + + logger.debug(f"Logits shape: {self.out_logits.shape}") + logger.debug(f"Token ids: {self.token_ids}, {self.token_ids.shape}") + logger.debug( + f"Logits shape: {self.attention_mask}, {self.attention_mask.shape}" + ) + + return self.compute_perplexity() + + +def run_perplexity( + prompts: list[str], + dataset, + tokenizer, + device, + kv_cache_type, +): + perplexity = Perplexity(prompts=prompts, device=device, kv_cache_type=kv_cache_type) + + perplexity.load_model(dataset, tokenizer) + ppl = perplexity.get_perplexity() + + return ppl + + +def main(argv): + parser = cli.create_parser() + parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") + parser.add_argument("--device", help="Torch device (or default)") + + cli.add_input_dataset_options(parser) + cli.add_tokenizer_options(parser) + args = cli.parse(parser, args=argv) + + device = torch.device(args.device) if args.device else None + kv_cache_type = args.kv_cache_type + dataset = cli.get_input_dataset(args) + tokenizer = cli.get_tokenizer(args) + + prompt_path = "sharktank/evaluate/data/eval_prompts.txt" + with open(prompt_path, "r") as f: + input_texts = f.read().splitlines() + + ppl = run_perplexity( + prompts=input_texts[0:1], + dataset=dataset, + tokenizer=tokenizer, + device=device, + kv_cache_type=kv_cache_type, + ) + + logger.info(f"\n{json.dumps(ppl, indent=2)}") + return 
ppl + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py new file mode 100644 index 000000000..258e8c9a0 --- /dev/null +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -0,0 +1,370 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys +import logging +import time +import random +import re +from datetime import timedelta +import json +import numpy as np +from tqdm import tqdm + +from datasets import load_dataset + +import torch +from torch.nn import CrossEntropyLoss + +from sharktank.layers import * +from sharktank.types import * + +from sharktank.models.llama.llama import * +from sharktank.models.mixtral.mixtral import * +from sharktank.models.grok.grok import * + +from ..models.llama.sharding import shard_theta + +from sharktank.utils import cli +from sharktank.utils.load_llm import * + +log_levels = { + "info": logging.INFO, + "debug": logging.DEBUG, +} +logger = logging.getLogger("eval") + +logger.setLevel(log_levels["info"]) + +logger.root.handlers[0].setFormatter( + logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") +) + +__all__ = ["Perplexity_torch", "run_perplexity_torch"] + + +class Perplexity_torch: + """ + Perplexity (PPL) is one of the most common metrics for evaluating language models. + It is defined as the exponentiated average negative log-likelihood of a sequence, + calculated with exponent base `e`. + + For more information, see https://huggingface.co/docs/transformers/perplexity + """ + + def __init__( + self, + device, + kv_cache_type, + ): + self.device = device + self.kv_cache_type = kv_cache_type + self.activation_dtype = torch.float32 + self.attention_dtype = torch.float32 + + def timeit(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) + + func_name = func.__name__ + if func_name == "get_perplexity": + func_name = "Calculate perplexity" + logger.info(f" {func_name}: {time_taken}") + return result + + return wrapper + + def print_token_comparison(self, i): + if i <= self.max_prompt_length: + batch_predicted_token_id = [[i[-1]] for i in self.batch.results] + batch_predicted_token = self.generator.tokenizer.decode( + batch_predicted_token_id + ) + logger.debug(f"Predicted:") + logger.debug(f"{batch_predicted_token}") + logger.debug(f"{batch_predicted_token_id}") + + expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist() + expected_token = self.generator.tokenizer.decode(expected_token_id) + logger.debug(f"Expected:") + logger.debug(f"{expected_token}") + logger.debug(f"{expected_token_id}") + + @timeit + def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel): + + self.config = LlamaModelConfig( + hp=configs.LlamaHParams.from_gguf_props(dataset.properties), + block_seq_stride=16, + 
kv_cache_type=self.kv_cache_type, + device=self.device, + activation_dtype=self.activation_dtype, + attention_dtype=self.attention_dtype, + tensor_parallelism_size=tensor_parallelism_size, + ) + + if self.config.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, self.config) + + theta = dataset.root_theta + + if self.config.hp.expert_count: + if self.config.hp.model_arch == "grok": + model = PagedGrokModelV1(theta, self.config) + else: + model = PagedMixtralModelV1(theta, self.config) + else: + model = PagedLlamaModelV1(theta, self.config) + + self.generator = TorchGenerator(model, tokenizer) + + @timeit + def get_prompts(self, num_prompts): + + test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ + "text" + ] + + num_test_prompts = 219 + + random.seed(0) + test_prompts = random.sample(test_prompts, num_test_prompts) + + # Ignore prompts that are: empty, less than 20 tokens or a title. + test_prompts = [ + s.replace("\n", "").rstrip() + for s in test_prompts + if s != "" and len(s.split()) >= 20 and s.count("=") < 2 + ][0:num_prompts] + + self.test_prompts = test_prompts + + self.bs = len(test_prompts) + + logger.info(f" Batch size: {self.bs}") + + @timeit + def get_logits(self, page_cache_size): + + is_first_token = True + start = 0 + for i in tqdm( + range(start, self.max_prompt_length - 1), + mininterval=300, + desc="eval: Calculating logits", + ): + logger.debug(f"Iteration: {i}") + + if is_first_token: + + token_batch = self.token_ids[:, : i + 1] + logger.debug(f"Prefill:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + + token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( + token_ids=token_batch.tolist(), + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + logger.debug(f"{token_batch}") + + token_batch = torch.tensor(token_batch, device=self.device) + seq_lens_batch = torch.tensor(seq_lens_batch, device=self.device) + + self.batch = self.generator.begin_eval_batch( + token_batch=token_batch, + seq_lens_batch=seq_lens_batch, + bs=self.bs, + page_cache_size=page_cache_size, + ) + + self.batch.prefill() + self.out_logits = self.batch.prefill_logits[:, 0:1, :] + is_first_token = False + + self.print_token_comparison(i) + + else: + token_batch = self.token_ids[:, i : i + 1] + + logger.debug("Decode:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + logger.debug(f"{token_batch.tolist()}") + + self.batch.decode(token_batch=token_batch) + self.out_logits = torch.cat( + (self.out_logits, self.batch.decode_logits), 1 + ) + + self.print_token_comparison(i) + + pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1] + + self.pad_logits = torch.zeros( + self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2] + ) + + self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to( + self.device + ) + + @timeit + def compute_perplexity(self): + loss_fct = CrossEntropyLoss(reduction="none") + + ## perplexity = e ^ (sum(losses) / num_tokenized_tokens) + crossentropy_loss = ( + loss_fct(self.out_logits.transpose(1, 2), self.token_ids) + * self.attention_mask + ).sum(1) + crossentropy_loss = torch.tensor(crossentropy_loss.tolist()) + perplexity_batch = torch.exp( + crossentropy_loss / self.attention_mask.sum(1) + ).tolist() + + perplexity_batch = [round(ppl, 6) for ppl in perplexity_batch] + + return { + "perplexities": perplexity_batch, + "mean_perplexity": 
round(np.mean(perplexity_batch), 6), + } + + @timeit + def get_perplexity(self): + + token_ids, seq_lens = self.generator.tokenizer.encode( + self.test_prompts, + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + self.page_cache_size = ( + len(token_ids[0]) // self.config.block_seq_stride + ) * self.bs + 1 + + logger.debug(f" Prompts for Evaluation:") + for idx, prompt in enumerate(self.test_prompts): + logger.debug( + f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n" + ) + + self.max_prompt_length = max(seq_lens) + + self.token_ids = torch.tensor(token_ids, device=self.device) + self.attention_mask = ( + (self.token_ids != 0).int().detach().clone().to(self.device) + ) + + self.get_logits(page_cache_size=self.page_cache_size) + + self.out_logits = self.out_logits[..., :-1, :].contiguous() + self.token_ids = self.token_ids[..., 1:].contiguous() + self.attention_mask = self.attention_mask[..., 1:].contiguous() + + logger.debug(f"Final Logits shape: {self.out_logits.shape}") + logger.debug(f"Token ids: {self.token_ids}, \n{self.token_ids.shape}") + logger.debug( + f"Mask shape: {self.attention_mask}, \n{self.attention_mask.shape}" + ) + + assert self.token_ids.shape == self.out_logits.shape[0:2] + + return self.compute_perplexity() + + +def run_perplexity_torch( + dataset, + tokenizer, + device, + kv_cache_type, + tensor_parallelism_size, + attention_kernel, + num_prompts, +): + start = time.time() + + perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type) + perplexity.get_prompts(num_prompts=num_prompts) + perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) + ppl = perplexity.get_perplexity() + + end = time.time() + total_time = round(end - start, 2) + if total_time < 60: + total_time = str(total_time) + " secs" + else: + total_time = str(round(total_time / 60, 2)) + " mins" + logger.info(f" Total time taken: {total_time}") + + return ppl + + +def main(argv): + parser = cli.create_parser() + parser.add_argument("--kv-cache-type", default="paged", help="KV cache type") + parser.add_argument("--device", help="Torch device (or default)") + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch_sdpa"], + ) + + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="Number of devices for tensor parallel sharding.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=100, + help="Number of prompts for perplexity test", + ) + + cli.add_input_dataset_options(parser) + cli.add_tokenizer_options(parser) + args = cli.parse(parser, args=argv) + + device = torch.device(args.device) if args.device else None + kv_cache_type = args.kv_cache_type + dataset = cli.get_input_dataset(args) + tokenizer = cli.get_tokenizer(args) + + ppl = run_perplexity_torch( + dataset=dataset, + tokenizer=tokenizer, + device=device, + kv_cache_type=kv_cache_type, + tensor_parallelism_size=args.tensor_parallelism_size, + attention_kernel=args.attention_kernel, + num_prompts=args.num_prompts, + ) + + logger.info(f"\n{json.dumps(ppl, indent=2)}") + return ppl + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 5439cc38b..6dd9785c3 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -9,14 +9,18 @@ import json import torch -from 
shark_turbine.aot import * +from iree.turbine.aot import * from sharktank.layers import * from sharktank.types import * +from sharktank.utils.math import ceildiv # TODO: Should be using a base class with the protocol supported. from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 +from ..models.llama.sharding import shard_theta from ..models.mixtral.mixtral import * +from ..models.grok.grok import * +from .. import ops def main(): @@ -32,7 +36,7 @@ def main(): parser.add_argument( "--output-config", help="Output file path for exported config file", - default="tmp/batch_llama_v1.json", + default="/tmp/batch_llama_v1.json", ) parser.add_argument( "--bs", @@ -45,16 +49,40 @@ def main(): help="Include verbose logging", action="store_true", ) + parser.add_argument( + "--strict", + help="Enables strictness during export", + action="store_true", + ) + cli.add_quantization_options(parser) + cli.add_model_options(parser) args = cli.parse(parser) + dataset_type = cli.get_input_data_files(args) + dataset_type = "irpa" if "irpa" in dataset_type else "gguf" dataset = cli.get_input_dataset(args) - hp = configs.LlamaHParams.from_gguf_props(dataset.properties) - llama_config = LlamaModelConfig(hp) - llama_config.static_tables = False # Rely on the compiler for hoisting tables. - llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" + tensor_parallelism_size = ( + dataset.properties["tensor_parallelism_size"] + if "tensor_parallelism_size" in dataset.properties + else 1 + ) + + llama_config = LlamaModelConfig( + hp, + tensor_parallelism_size=tensor_parallelism_size, + use_hf=False, + static_tables=False, # Rely on the compiler for hoisting tables. + kv_cache_type="direct" if args.bs == [1] else "paged", + attention_kernel=args.attention_kernel, + ) + llama_config.fake_quant = args.fake_quant + if llama_config.hp.expert_count: - model = PagedMixtralModelV1(dataset.root_theta, llama_config) + if llama_config.hp.model_arch == "grok": + model = PagedGrokModelV1(dataset.root_theta, llama_config) + else: + model = PagedMixtralModelV1(dataset.root_theta, llama_config) else: model = PagedLlamaModelV1(dataset.root_theta, llama_config) @@ -80,86 +108,157 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): fxb = FxProgramsBuilder(model) - def generate_batch_prefill(bs: int): - tokens = torch.empty(bs, 64, dtype=torch.int64) - seq_lens = torch.empty(bs, dtype=torch.int64) - seq_block_ids = torch.empty(bs, 4, dtype=torch.int64) - block_dim = torch.export.Dim( - "block", max=(hp.context_length - 1) // llama_config.block_seq_stride - ) - sl_dim = llama_config.block_seq_stride * block_dim - + def setup_cache(model, shard_count): if model.config.kv_cache_type == "paged": cache_state = model.cache.allocate( page_count=hp.context_length // llama_config.block_seq_stride ) page_dim = torch.export.Dim("page") - cache_state_dynamic_shapes = [{0: page_dim}] + + dynamic_shapes = [{0: page_dim}] + unpacked = cache_state + arg_affinities = {} + shard_dim = None + + # Need to unpacke that state when sharded + if llama_config.tensor_parallelism_size > 1: + shard_dim = cache_state[0].shard_dim + + unpacked = [[shard._data for shard in cs.shards] for cs in cache_state] + dynamic_shapes = [ + [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes + ] + + for i in range(llama_config.tensor_parallelism_size): + arg_affinities[i] = DeviceAffinity(str(i)) + + return unpacked, shard_dim, dynamic_shapes, arg_affinities + elif model.config.kv_cache_type == "direct": 
cache_state = model.cache.allocate(bs=1) # Direct cache dimensions: # 2 * transformer_block_count of... # [bs, seq_length, attn_head_count, attn_head_dim] - cache_state_dynamic_shapes = (2 * hp.block_count) * [{}] + dynamic_shapes = [None] + arg_affinities = {} + shard_dim = None + return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities else: raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") + def repack_cache(cache, shard_dim): + return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache] + + def generate_batch_prefill(bs: int): + # torch.export.Dim would make min at least 2 + block_dim_min = 2 + block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1 + block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max) + sl_dim = llama_config.block_seq_stride * block_dim + seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64) + tokens = torch.empty( + bs, + seq_block_ids.shape[1] * llama_config.block_seq_stride, + dtype=torch.int64, + ) + seq_lens = torch.empty(bs, dtype=torch.int64) + + cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache( + model, llama_config.tensor_parallelism_size + ) + + if llama_config.tensor_parallelism_size > 1: + # We need to offset the indices for the cache + arg_affinities = {key + 3: arg_affinities[key] for key in arg_affinities} + + for i in range(3): + arg_affinities[i] = DeviceAffinity("0") + dynamic_shapes = { "tokens": {1: sl_dim}, "seq_lens": {}, "seq_block_ids": {1: block_dim}, - "cache_state": cache_state_dynamic_shapes, + "cs": cache_dynamic_shapes, } print(f"Exporting prefill_bs{bs}") @fxb.export_program( name=f"prefill_bs{bs}", - args=(tokens, seq_lens, seq_block_ids, cache_state), + args=(tokens, seq_lens, seq_block_ids, cache), dynamic_shapes=dynamic_shapes, + strict=args.strict, + arg_device=arg_affinities, ) - def _(model, tokens, seq_lens, seq_block_ids, cache_state): + def _(model, tokens, seq_lens, seq_block_ids, cs): + if ( + model.config.tensor_parallelism_size == 1 + and model.config.kv_cache_type == "direct" + ): + cache_tensors = torch.unbind(cs) + else: + cache_tensors = cs + sl = tokens.shape[1] input_mask = model.input_mask(seq_lens, sl) attention_mask = model.attention_mask(input_mask) + + if llama_config.tensor_parallelism_size != 1: + shard_count = llama_config.tensor_parallelism_size + + tokens = ops.replicate(tokens, count=shard_count) + attention_mask = ops.replicate(attention_mask, count=shard_count) + seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + + cache_tensors = repack_cache(cs, cache_shard_dim) + logits = model.prefill( tokens, attention_mask=attention_mask, seq_block_ids=seq_block_ids, - cache_state=cache_state, + cache_state=cache_tensors, ) + + if llama_config.tensor_parallelism_size != 1: + logits = ops.unshard(logits) + return logits def generate_batch_decode(bs: int): - tokens = torch.ones(bs, 1, dtype=torch.int64) - seq_lens = torch.ones(bs, dtype=torch.int64) - start_positions = torch.ones(bs, dtype=torch.int64) - seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64) - block_dim = torch.export.Dim( - "block", max=(hp.context_length - 1) // llama_config.block_seq_stride + # torch.export.Dim would make min at least 2 + block_dim_min = 2 + block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1 + block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max) + tokens = torch.empty( + bs, + 1, + dtype=torch.int64, ) + seq_lens = torch.empty(bs, 
dtype=torch.int64) + start_positions = torch.ones(bs, dtype=torch.int64) + seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64) - if model.config.kv_cache_type == "paged": - cache_state = model.cache.allocate( - page_count=hp.context_length // llama_config.block_seq_stride - ) - page_dim = torch.export.Dim("page") - cache_state_dynamic_shapes = [{0: page_dim}] - elif model.config.kv_cache_type == "direct": - cache_state = model.cache.allocate(bs=1) - # Direct cache dimensions: - # 2 * transformer_block_count of... - # [bs, seq_length, attn_head_count, attn_head_dim] - cache_state_dynamic_shapes = (2 * hp.block_count) * [{}] - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") + ( + cache_state, + cache_shard_dim, + cache_dynamic_shapes, + arg_affinities, + ) = setup_cache(model, llama_config.tensor_parallelism_size) + + if llama_config.tensor_parallelism_size > 1: + # We need to offset the indices for the cache + arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities} + + # Inputs have default affinity 0 + for i in range(4): + arg_affinities[i] = DeviceAffinity("0") dynamic_shapes = { "tokens": {}, "seq_lens": {}, "start_positions": {}, "seq_block_ids": {1: block_dim}, - "cache_state": cache_state_dynamic_shapes, + "cache_state": cache_dynamic_shapes, } print(f"Exporting decode_bs{bs}") @@ -174,6 +273,8 @@ def generate_batch_decode(bs: int): cache_state, ), dynamic_shapes=dynamic_shapes, + strict=args.strict, + arg_device=arg_affinities, ) def _( model, @@ -187,6 +288,17 @@ def _( seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride ) attention_mask = model.decode_attention_mask(input_mask) + + if llama_config.tensor_parallelism_size != 1: + shard_count = llama_config.tensor_parallelism_size + + tokens = ops.replicate(tokens, count=shard_count) + attention_mask = ops.replicate(attention_mask, count=shard_count) + start_positions = ops.replicate(start_positions, count=shard_count) + seq_block_ids = ops.replicate(seq_block_ids, count=shard_count) + + cache_state = repack_cache(cache_state, cache_shard_dim) + logits = model.decode( tokens, attention_mask=attention_mask, @@ -194,12 +306,17 @@ def _( seq_block_ids=seq_block_ids, cache_state=cache_state, ) + + if llama_config.tensor_parallelism_size != 1: + logits = ops.unshard(logits) + return logits bsizes = [] for bs in args.bs: generate_batch_prefill(bs) - generate_batch_decode(bs) + if not args.skip_decode: + generate_batch_decode(bs) bsizes.append(bs) config = generate_params_json(hp, bsizes, bsizes) print("GENERATED!") @@ -209,7 +326,7 @@ def _( print(f"EXPORT {name}:\n{ep}") print("Exporting") - output = export(fxb) + output = export(fxb, import_symbolic_shape_expressions=True) print(f"Saving to '{args.output_mlir}'") output.save_mlir(args.output_mlir) json.dump(config, open(args.output_config, "w")) diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index c415b2ce1..b30acc026 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -8,6 +8,8 @@ from typing import Optional +from safetensors import safe_open + import math import sys @@ -16,11 +18,15 @@ from ..layers import * from ..types import * +from ..ops import replicate, unshard + # TODO: Should be using a base class with the protocol supported. 
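Note on the affinity bookkeeping in the prefill/decode exports above: when the sharded KV cache is unpacked into one flat tensor argument per shard, the per-shard affinities have to be shifted past the leading inputs (tokens, seq_lens, start_positions, seq_block_ids for decode), which default to device 0. A minimal sketch of that index arithmetic, assuming only the DeviceAffinity class that iree.turbine.aot already provides; the helper name is illustrative and not part of this change.

from iree.turbine.aot import DeviceAffinity

def offset_cache_affinities(cache_affinities: dict, num_leading_inputs: int) -> dict:
    # Shift each unpacked cache shard past the leading flat arguments...
    shifted = {idx + num_leading_inputs: aff for idx, aff in cache_affinities.items()}
    # ...and pin the leading inputs themselves to device 0.
    for i in range(num_leading_inputs):
        shifted[i] = DeviceAffinity("0")
    return shifted

# Two cache shards behind the four decode inputs land in flat argument slots 4 and 5.
decode_affinities = offset_cache_affinities(
    {0: DeviceAffinity("0"), 1: DeviceAffinity("1")}, num_leading_inputs=4
)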
from ..models.mixtral.mixtral import * +from ..models.grok.grok import * from ..models.llama.llama import * +from ..models.llama.sharding import shard_theta from ..utils.debugging import trace_tensor -from ..utils.tokenizer import InferenceTokenizer, load_tokenizer +from ..utils.tokenizer import InferenceTokenizer class TorchGenerator: @@ -38,9 +44,9 @@ def __init__( self.tokenizer = tokenizer if model.cache.is_paged: self.shared_cache_state = model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: self.shared_cache_state = None - self.free_pages = list(range(1, 128)) self.end_token = end_token @property @@ -51,6 +57,7 @@ def begin_batch(self, prompts: list[str]): token_ids, seq_lens = self.tokenizer.encode( prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride ) + token_ids = torch.tensor(token_ids, device=self.model.device) seq_lens = torch.tensor(seq_lens, device=self.model.device) if self.shared_cache_state is not None: @@ -145,13 +152,23 @@ def prefill(self): trace_tensor("prefill.token_ids", self.token_ids) trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor) trace_tensor("prefill.attention_mask", attention_mask) + + token_ids = self.token_ids + if model.config.tensor_parallelism_size != 1: + tp = model.config.tensor_parallelism_size + token_ids = replicate(token_ids, tp) + attention_mask = replicate(attention_mask, tp) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) + logits = model.prefill( - self.token_ids, + token_ids, attention_mask=attention_mask, seq_block_ids=seq_block_ids_tensor, cache_state=self.cache_state, ) + logits = unshard(logits) + # TODO: Generalize the sampling and don't make it swap on/off cpu. # TODO: Normalize the output of extract_tokens_from_logits into # tensor [bs, 1]. @@ -179,6 +196,14 @@ def decode(self): trace_tensor("decode.start_positions", start_positions) trace_tensor("decode.seq_block_ids", seq_block_ids_tensor) trace_tensor("decode.attention_mask", decode_attention_mask) + + if model.config.tensor_parallelism_size != 1: + tp = model.config.tensor_parallelism_size + self.next_tokens = replicate(self.next_tokens, tp) + start_positions = replicate(start_positions, tp) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) + decode_attention_mask = replicate(decode_attention_mask, tp) + logits = model.decode( self.next_tokens, attention_mask=decode_attention_mask, @@ -186,6 +211,8 @@ def decode(self): seq_block_ids=seq_block_ids_tensor, cache_state=self.cache_state, ) + + logits = unshard(logits) trace_tensor("decode.logits", logits) # TODO: Normalize the output of extract_tokens_from_logits into # tensor [bs, 1]. 
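The generator change above sizes the free-page list from the allocated page count instead of a hard-coded 128. A hypothetical sketch of that bookkeeping (the class and policy below are illustrative, not the library API): page 0 is skipped, matching range(1, page_cache_size), and each sequence claims ceil(seq_len / block_seq_stride) pages from the pool.

import math

class PagePool:
    def __init__(self, page_count: int):
        # Mirror free_pages = list(range(1, page_cache_size)): page 0 is never handed out.
        self.free_pages = list(range(1, page_count))

    def allocate(self, seq_len: int, block_seq_stride: int) -> list[int]:
        needed = math.ceil(seq_len / block_seq_stride)
        if needed > len(self.free_pages):
            raise RuntimeError("KV cache page pool exhausted")
        return [self.free_pages.pop() for _ in range(needed)]

pool = PagePool(page_count=256)
print(pool.allocate(seq_len=37, block_seq_stride=16))  # three page ids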
@@ -218,17 +245,28 @@ def main(): help="DType to use for activations in the model", default="float32", ) + parser.add_argument( + "--use-hf", + action="store_true", + default=False, + ) + parser.add_argument( + "--tensor-parallelism-size", + type=int, + default=1, + help="How many devices are involved for tensor parallel sharding.", + ) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) + cli.add_quantization_options(parser) + cli.add_model_options(parser) args = cli.parse(parser) - device = torch.device(args.device) if args.device else None activation_dtype = getattr(torch, args.activation_dtype) assert isinstance(activation_dtype, torch.dtype) dataset = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) prompts = args.prompt - config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, @@ -236,13 +274,22 @@ def main(): device=device, activation_dtype=activation_dtype, attention_dtype=activation_dtype, + attention_kernel=args.attention_kernel, + use_hf=args.use_hf, + tensor_parallelism_size=args.tensor_parallelism_size, + fake_quant=args.fake_quant, ) + if config.tensor_parallelism_size > 1: + dataset.root_theta = shard_theta(dataset.root_theta, config) if config.hp.expert_count: - model = PagedMixtralModelV1(dataset.root_theta, config) + if config.hp.model_arch == "grok": + model = PagedGrokModelV1(dataset.root_theta, config) + else: + model = PagedMixtralModelV1(dataset.root_theta, config) else: model = PagedLlamaModelV1(dataset.root_theta, config) - + if args.save_intermediates_path: from ..utils.patching import SaveModuleResultTensorsPatch @@ -272,6 +319,7 @@ def main(): ) print(f":: Result tokens: {batch.results}") batch.print_current_results() + counter += 1 if __name__ == "__main__": diff --git a/sharktank/sharktank/examples/sharding/export_ffn_net.py b/sharktank/sharktank/examples/sharding/export_ffn_net.py index f80b9a2ac..f261a92e1 100644 --- a/sharktank/sharktank/examples/sharding/export_ffn_net.py +++ b/sharktank/sharktank/examples/sharding/export_ffn_net.py @@ -50,6 +50,7 @@ def forward(self, x: torch.Tensor): ffn_gate_weight = self.theta.tensor("ffn_gate", "weight") ffn_up_weight = self.theta.tensor("ffn_up", "weight") ffn_down_weight = self.theta.tensor("ffn_down", "weight") + x = ops.replicate(x, count=ffn_gate_weight.shard_count) ffn_gate = ops.elementwise( torch.nn.functional.silu, ops.linear(x, ffn_gate_weight) ) @@ -89,7 +90,7 @@ def main(raw_args=None): ds = Dataset.load(args.output_irpa_file) mdl = ShardedFFN(ds.root_theta) - from shark_turbine import aot + from iree.turbine import aot example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16) ep = torch.export.export(mdl, (example_arg,)) diff --git a/sharktank/sharktank/examples/sharding/export_gemm.py b/sharktank/sharktank/examples/sharding/export_gemm.py index 7a4322e38..9744a6d82 100644 --- a/sharktank/sharktank/examples/sharding/export_gemm.py +++ b/sharktank/sharktank/examples/sharding/export_gemm.py @@ -4,7 +4,7 @@ import torch from torch import Tensor from sharktank import ops -from shark_turbine import aot +from iree.turbine import aot def export_gemm( diff --git a/sharktank/sharktank/examples/sharding/shard_llm_dataset.py b/sharktank/sharktank/examples/sharding/shard_llm_dataset.py index 0921a2c83..91e88d071 100644 --- a/sharktank/sharktank/examples/sharding/shard_llm_dataset.py +++ b/sharktank/sharktank/examples/sharding/shard_llm_dataset.py @@ -10,7 +10,8 @@ weights of an LLM by converting the RHS of all eligible 
layers to a sharded form. """ -from ...transforms.dataset import MmtRHSShardingTransform +from ...models.llama.sharding import shard_theta +from ...layers import LlamaHParams, LlamaModelConfig from ...types import * @@ -21,16 +22,30 @@ def main(raw_args=None): cli.add_input_dataset_options(parser) cli.add_output_dataset_options(parser) parser.add_argument( - "--num-shards", type=int, required=True, help="Number of shards to split" + "--tensor-parallelism-size", + type=int, + required=True, + help="Number of shards to split", ) args = cli.parse(parser, args=raw_args) dataset = cli.get_input_dataset(args) - tr = MmtRHSShardingTransform( - r"^blk\.[0-9]+\.(attn_k|attn_q|attn_v|ffn_gate|ffn_up|ffn_down)\.weight$", - num_shards=8, + if args.output_irpa_file is None: + raise RuntimeError(f"Need file destination for IRPA file") + + if args.tensor_parallelism_size < 2: + raise RuntimeError( + f"Expect sharding greater than 1 found {args.tensor_parallelism_size}" + ) + + hp = LlamaHParams.from_gguf_props(dataset.properties) + llama_config = LlamaModelConfig( + hp, tensor_parallelism_size=args.tensor_parallelism_size ) - dataset.transform(tr) + sharded_theta = shard_theta(dataset.root_theta, llama_config) + sharded_theta.rename_tensors_to_paths() + dataset.root_theta = sharded_theta + dataset.properties["tensor_parallelism_size"] = args.tensor_parallelism_size dataset.save(args.output_irpa_file, io_report_callback=print) diff --git a/sharktank/sharktank/export.py b/sharktank/sharktank/export.py new file mode 100644 index 000000000..0a1c6940d --- /dev/null +++ b/sharktank/sharktank/export.py @@ -0,0 +1,174 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable, Any +import torch +from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder +from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten +from .types.tensors import ShardedTensor +from torch.utils._pytree import PyTree, _is_leaf +import functools + + +def flatten_signature( + *sample_args: list[PyTree], **sample_kwargs: dict[str, PyTree] +) -> Callable[[Callable], Any]: + """Decorator that flattens the signature of a function using PyTorch's type + registration. + It will flatten the same way torch PyTorch does, returning a function that accepts + and returns a flat list of torch.Tensor. + The decorator requires sample arguments of the unflattened function. 
+ + ``` + @flatten_signature( + { + "a1": SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0), + "a2": torch.tensor([2]), + }, + [DefaultPrimitiveTensor(data=torch.tensor([3]))] + ) + def f(a, b): + return a["a1"], b + ``` + + will result in a function with signature + + ``` + ( + torch.Tensor of size 1, + torch.Tensor of size 2, + torch.Tensor of size 3, + ) -> ( + torch.Tensor of size 1, + torch.Tensor of size 2, + ) + ``` + """ + flat_sample_args, args_tree_spec = tree_flatten(sample_args) + n_args = len(flat_sample_args) + kwargs_tree_spec = tree_structure(sample_kwargs) + + def _decorator(f: Callable) -> Callable: + def _wrapper(*flat_args: list[Any]) -> list[Any]: + unflattended_args = tree_unflatten(flat_args[:n_args], args_tree_spec) + unflattended_kwargs = tree_unflatten(flat_args[n_args:], kwargs_tree_spec) + return tree_flatten(f(*unflattended_args, **unflattended_kwargs))[0] + + return _wrapper + + return _decorator + + +def get_argument_flat_device_affinities( + *args: list[PyTree], **kwargs: dict[str, PyTree] +) -> dict[int, DeviceAffinity]: + """Return the flat device affinities for unflattened arguments. + ShardedTensor types have their device affinities assigned. + All other arguments are left unassigned. + + ``` + get_argument_flat_device_affinities( + torch.Tensor([1]), + [ReplicatedTensor(ts=[torch.tensor([2]), torch.tensor([3])])] + ) + ``` + returns + ``` + { + 1: DeviceAffinity("0"), + 2: DeviceAffinity("1"), + } + ``` + """ + + def is_leaf(v: PyTree) -> bool: + if isinstance(v, ShardedTensor): + return True + # TODO: It is sad _is_leaf is private. Find a way not use it. + from torch.utils._pytree import _is_leaf + + return _is_leaf(v) + + # flattened up to a sharded tensor. + flat_args_up_to_sharded_tensor = tree_flatten((args, kwargs), is_leaf=is_leaf)[0] + nested_device_affinities: list[list[DeviceAffinity | None]] = [ + [DeviceAffinity(f"{shard_idx}") for shard_idx in range(len(arg.shards))] + if isinstance(arg, ShardedTensor) + else [None] + for arg in flat_args_up_to_sharded_tensor + ] + flat_device_affinities: list[DeviceAffinity | None] = [ + affinity + for affinity_list in nested_device_affinities + for affinity in affinity_list + ] + return { + arg_idx: affinity + for arg_idx, affinity in enumerate(flat_device_affinities) + if affinity is not None + } + + +def export( + f: Callable | None = None, + fx_builder: FxProgramsBuilder | None = None, + args: tuple[PyTree] | None = None, + kwargs: dict[PyTree] | None = None, + arg_device: dict[int, DeviceAffinity] | None = None, + *transitive_args, + **transitive_kwargs, +) -> torch.export.ExportedProgram: + """Wrapper around FxProgramsBuilder.export_program that handles + the sharktank custom tensor types. + + If `arg_device` is not specified it will extract the affinities + from the passed `args`. + `arg_device` must pass the affinities for the flattened arguments. + These are those that correspond to torch.Tensor. + For example a sharded tensor with 2 shards would result in 2 arguments in the MLIR + signature.""" + + if f is None: + return functools.partial( + export, + fx_builder=fx_builder, + args=args, + kwargs=kwargs, + arg_device=arg_device, + *transitive_args, + **transitive_kwargs, + ) + + if args is None: + args = [] + if kwargs is None: + kwargs = {} + if arg_device is None: + arg_device = get_argument_flat_device_affinities(*args, **kwargs) + flat_args = tree_flatten((args, kwargs))[0] + if fx_builder is not None: + # Flatten the signature of the function. 
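A standalone illustration of the pytree behavior that flatten_signature builds on (plain torch utilities only, not the library code): nested args/kwargs flatten to a list of tensor leaves plus a spec that restores the structure, which is what lets each ShardedTensor shard become its own flat MLIR argument.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

nested = ({"a1": torch.tensor([1]), "a2": torch.tensor([2])}, [torch.tensor([3])])
leaves, spec = tree_flatten(nested)
print(len(leaves))  # 3 flat tensor leaves
restored = tree_unflatten(leaves, spec)
assert torch.equal(restored[0]["a1"], torch.tensor([1]))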
+ # Technically this is done during export, but we want the signature to match + # the flat device affinities. + def module_fn_with_flat_signature(module, *flat_args): + @flatten_signature(*args, **kwargs) + def flat_fn(*args, **kwargs): + return f(module, *args, **kwargs) + + return flat_fn(*flat_args) + + amended_kwargs = dict(**transitive_kwargs) + if "name" not in amended_kwargs or amended_kwargs["name"] is None: + amended_kwargs["name"] = f.__name__ + return fx_builder.export_program( + module_fn_with_flat_signature, + *transitive_args, + args=flat_args, + arg_device=arg_device, + **amended_kwargs, + ) + + assert False, "TODO: implement the case when not using an FxProgramsBuilder" diff --git a/sharktank/sharktank/export_layer/export_kv_cache.py b/sharktank/sharktank/export_layer/export_kv_cache.py new file mode 100644 index 000000000..09f0a1c15 --- /dev/null +++ b/sharktank/sharktank/export_layer/export_kv_cache.py @@ -0,0 +1,125 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +import torch.nn.functional as F + +from iree.turbine.aot import * + +from sharktank.types import SplitPrimitiveTensor +from sharktank.ops import reshard_split, replicate +from sharktank.layers.kv_cache import PagedKVCache +from ..utils import cli + + +def main(): + parser = cli.create_parser() + parser.add_argument( + "--output-mlir", + help="Output file path for exported MLIR file", + default="/tmp/kv_cache.mlir", + ) + parser.add_argument( + "--batch-size", + "-bs", + help="Batch size to generate, e.g. `4` or `2`", + type=lambda arg: int(arg), + default="2", + ) + parser.add_argument( + "--sharding", + help="Sharding level of kv-cache", + type=lambda arg: int(arg), + default="1", + ) + parser.add_argument( + "--verbose", + "-v", + help="Include verbose logging", + action="store_true", + ) + parser.add_argument( + "--strict", + help="Enables strictness during export", + action="store_true", + ) + + args = cli.parse(parser) + + bs = args.batch_size + + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + page_count = bs * seq_length // block_seq_stride + write_seq_length = seq_length - 4 + + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + shard_count=args.sharding, + dtype=torch.float32, + device=None, + ) + + alloc = cache.allocate(page_count=page_count) + allocation = alloc + + model = torch.nn.Module() + fxb = FxProgramsBuilder(model) + + page_ids = torch.empty(bs, seq_length // block_seq_stride, dtype=torch.int64) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + partition_0 = torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ) + + if args.sharding > 1: + partition_0 = reshard_split(partition_0, dim=2, count=args.sharding).shards + allocation = allocation[0].shards + + argument_device_affinities = {} + for i in range(args.sharding): + argument_device_affinities[i] = DeviceAffinity(f"{i}") + argument_device_affinities[i + args.sharding] = DeviceAffinity(f"{i}") + + @fxb.export_program( + name="write", + args=(allocation, partition_0, write_page_ids), + strict=False, + argument_device_affinities=argument_device_affinities, + ) + def _(model, 
state, partition_0, write_page_ids: torch.Tensor) -> torch.Tensor: + old_state = state + if args.sharding > 1: + state = [SplitPrimitiveTensor(ts=state, shard_dim=alloc[0].shard_dim)] + partition_0 = SplitPrimitiveTensor(ts=partition_0, shard_dim=2) + write_page_ids = replicate(write_page_ids, count=args.sharding) + cache.write( + state, + cache_partitions=[partition_0, partition_0], + transformer_block_index=1, + page_ids=write_page_ids, + ) + return state + + if args.verbose: + for name, ep in fxb.programs.items(): + print(f"EXPORT {name}:\n{ep}") + + print("Exporting") + output = export(fxb) + print(f"Saving to '{args.output_mlir}'") + output.save_mlir(args.output_mlir) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py new file mode 100644 index 000000000..af4ed51d2 --- /dev/null +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -0,0 +1,83 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +import torch.nn.functional as F + +from iree.turbine.aot import * + +from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch +from sharktank.layers.mixture_of_experts_block import MoeBlock +from ..utils import cli + + +def main(): + parser = cli.create_parser() + parser.add_argument( + "--output-mlir", + help="Output file path for exported MLIR file", + default="/tmp/batch_llama_v1.mlir", + ) + parser.add_argument( + "--batch-size", + "-bs", + help="Batch size to generate, e.g. `4` or `2`", + type=lambda arg: int(arg), + default="2", + ) + parser.add_argument( + "--verbose", + "-v", + help="Include verbose logging", + action="store_true", + ) + parser.add_argument( + "--strict", + help="Enables strictness during export", + action="store_true", + ) + parser.add_argument( + "--use-gelu", + help="Enable to use gelu for moe activation", + action="store_true", + ) + + args = cli.parse(parser) + + bs = args.batch_size + + model = MoeBlock( + theta=make_moe_block_theta()("blk.0"), + expert_count=8, + expert_used_count=2, + rms_epsilon=1e-5, + moe_activation=F.gelu if args.use_gelu else F.silu, + ) + fxb = FxProgramsBuilder(model) + input = make_rand_torch((bs, 32, 6144)) + + @fxb.export_program(name="prefill_moe", args=(input,)) + def _(model, input: torch.Tensor) -> torch.Tensor: + return model(input) + + input = make_rand_torch((bs, 1, 6144)) + + @fxb.export_program(name="decode_moe", args=(input,)) + def _(model, input: torch.Tensor) -> torch.Tensor: + return model(input) + + if args.verbose: + for name, ep in fxb.programs.items(): + print(f"EXPORT {name}:\n{ep}") + + print("Exporting") + output = export(fxb) + print(f"Saving to '{args.output_mlir}'") + output.save_mlir(args.output_mlir) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py new file mode 100644 index 000000000..cb28371bb --- /dev/null +++ b/sharktank/sharktank/export_layer/export_paged_attention.py @@ -0,0 +1,424 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Export support for the PagedLLMV1 protocol of models.""" + +import json +import torch + +from typing import Optional + +import torch.nn.functional as F + +from iree.turbine.aot import * + +from sharktank.layers import * +from sharktank.types import * + +from sharktank.models.llama.testing import * +from sharktank.layers import causal_llm + +from sharktank.utils.create_cache import * + +# TODO: Should be using a base class with the protocol supported. +from ..models.llama.llama import LlamaModelConfig, PagedLlamaAttentionBlock + + +def paged_attention( + attention_block: PagedLlamaAttentionBlock, + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + is_causal: bool, + seq_block_ids: torch.Tensor, + start_positions: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_state: list[torch.Tensor] = None, + xk_temp: Optional[torch.Tensor] = None, + xv_temp: Optional[torch.Tensor] = None, +): + + bs, batch_seq_len, _, _ = xq.shape + + # Full sequence length. + kv_seq_len = seq_block_ids.shape[1] * attention_block.cache.block_seq_stride + + if attention_block.cache.is_paged: + xk, xv = attention_block.transact_cache_paged( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + elif attention_block.cache.is_direct: + xk, xv = attention_block.transact_cache_direct( + xk_cache_update=xk, + xv_cache_update=xv, + start_positions=start_positions, + kv_seq_len=kv_seq_len, + cache_state=cache_state, + ) + else: + raise NotImplementedError(f"Unsupported KV cache type: {type(cache)}") + + # Expand kv heads for GQA. + gqa_n_rep = attention_block.head_count // attention_block.head_count_kv + assert gqa_n_rep > 0 + if gqa_n_rep > 1: + + def repeat_kv(x: torch.Tensor) -> torch.Tensor: + bs, slen, n_kv_heads, head_dim = x.shape + return ( + x.unsqueeze(-2) + .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim) + ) + + xk = repeat_kv(xk) + xv = repeat_kv(xv) + + # Transpose into [bs, heads, sl, dim] + xq = xq.transpose(1, 2) + keys = xk.transpose(1, 2) + values = xv.transpose(1, 2) + attention_mask = None + attn_output = F.scaled_dot_product_attention( + xq, keys, values, attn_mask=attention_mask, is_causal=is_causal + ) + attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1) + return attn_output + + +def run_llama( + model: PagedLlamaAttentionBlock, + config: LlamaModelConfig, + phase: str, + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + # [1, 1, batch_seq_len, batch_seq_len] + attention_mask: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + # [bs] of starting positions + start_positions: Optional[torch.Tensor] = None, +): + + if phase == "decode": + bs, _, _, _ = xq.shape + + # Allocate per-block temporary K/V tensors. These temporaries hold + # one block's K/V state for the maximum context length. 
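For the GQA expansion performed in paged_attention above, a self-contained shape check (plain torch, mirroring the inline repeat_kv helper): each KV head is repeated head_count // head_count_kv times so K and V match the query head count before scaled_dot_product_attention.

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # [bs, sl, n_kv_heads, head_dim] -> [bs, sl, n_kv_heads * n_rep, head_dim]
    bs, slen, n_kv_heads, head_dim = x.shape
    return (
        x.unsqueeze(-2)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

k = torch.randn(2, 5, 8, 16)        # 8 KV heads
print(repeat_kv(k, 32 // 8).shape)  # torch.Size([2, 5, 32, 16]) for 32 query heads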
+ xk_temp = torch.empty( + [ + bs, + config.hp.context_length, + config.hp.attention_head_count_kv, + config.hp.attn_head_dim, + ], + dtype=config.activation_dtype, + device=config.device, + ) + xv_temp = torch.empty( + [ + bs, + config.hp.context_length, + config.hp.attention_head_count_kv, + config.hp.attn_head_dim, + ], + dtype=config.activation_dtype, + device=config.device, + ) + elif phase == "prefill": + xk_temp = None + xv_temp = None + else: + raise ValueError("'phase' argument needs to be either 'prefill' or 'decode'") + + h = paged_attention( + model, + xq=xq, + xk=xk, + xv=xv, + is_causal=config.is_causal, + start_positions=start_positions, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + + return h + + +def main(): + from ..utils import cli + + parser = cli.create_parser() + # cli.add_input_dataset_options(parser) + parser.add_argument( + "--output-mlir", + help="Output file path for exported MLIR file", + default="/tmp/sharktank/artifacts/paged_llama.mlir", + ) + parser.add_argument( + "--output-config", + help="Output file path for exported config file", + default="/tmp/sharktank/artifacts/paged_llama.json", + ) + parser.add_argument( + "--bs", + help="Comma-separated batch size(s) to generate, e.g. `4` or `2,4`", + type=lambda arg: [int(bs) for bs in arg.split(",")], + default="4", + ) + parser.add_argument( + "--verbose", + help="Include verbose logging", + action="store_true", + ) + + parser.add_argument( + "--is-causal", + help="Enable Causal attention", + action="store_true", + ) + # TODO: move this to CLI to enable re-use with eager + parser.add_argument( + "--attention_kernel", + help="decomposed/torch", + default="decomposed", + ) + + args = cli.parse(parser) + + # dataset = cli.get_input_dataset(args) + # hp = configs.LlamaHParams.from_gguf_props(dataset.properties) + + hp = configs.LlamaHParams( + context_length=4096, + embedding_length=4096, + block_count=1, + feed_forward_length=11008, + attn_head_dim=128, + rope_dimension_count=128, + attention_head_count=32, + attention_layer_norm_rms_epsilon=9.999999747378752e-06, + attention_head_count_kv=32, + model_arch="llama", + ) + + llama_config = LlamaModelConfig(hp) + llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" + llama_config.bs = args.bs + llama_config.is_causal = args.is_causal + + attention_block_theta = make_attention_block_theta( + feature_dim=llama_config.hp.attention_head_count + * llama_config.hp.attn_head_dim, + ffn_dim=llama_config.hp.feed_forward_length, + dtype=llama_config.attention_dtype, + ) + + causal_model = causal_llm.BaseCausalLMModel( + attention_block_theta, context_length=llama_config.hp.context_length + ) + + model = PagedLlamaAttentionBlock( + theta=attention_block_theta, + block_index=0, + cache=create_kv_cache(llama_config), + head_count=llama_config.hp.attention_head_count, + head_dim=llama_config.hp.attn_head_dim, + head_count_kv=llama_config.hp.attention_head_count_kv, + rms_epsilon=llama_config.hp.attention_layer_norm_rms_epsilon, + attention_kernel=args.attention_kernel, + ) + + def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): + return { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": hp.context_length, + "attn_head_count": hp.attention_head_count, + "attn_head_dim": hp.attn_head_dim, + "prefill_batch_sizes": prefill_bs, + "decode_batch_sizes": decode_bs, + "transformer_block_count": hp.block_count, + "block_seq_stride": 
llama_config.block_seq_stride, + } + + fxb = FxProgramsBuilder(model) + + def generate_batch_prefill(bs: int): + tokens = torch.empty(bs, 64, dtype=torch.int64) + seq_lens = torch.empty(bs, dtype=torch.int64) + seq_block_ids = torch.empty(bs, 4, dtype=torch.int64) + block_dim = torch.export.Dim( + "block", max=(hp.context_length - 1) // llama_config.block_seq_stride + ) + sl_dim = llama_config.block_seq_stride * block_dim + + if llama_config.kv_cache_type == "paged": + cache_state = model.cache.allocate( + page_count=hp.context_length // llama_config.block_seq_stride + ) + page_dim = torch.export.Dim("page") + cache_state_dynamic_shapes = [{0: page_dim}] + elif llama_config.kv_cache_type == "direct": + cache_state = model.cache.allocate(bs=1) + # Direct cache dimensions: + # 2 * transformer_block_count of... + # [bs, seq_length, attn_head_count, attn_head_dim] + cache_state_dynamic_shapes = (2 * hp.block_count) * [{}] + else: + raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") + + dynamic_shapes = { + "tokens": {1: sl_dim}, + "seq_lens": {}, + "seq_block_ids": {1: block_dim}, + "cache_state": cache_state_dynamic_shapes, + } + + q = torch.zeros((bs, 64, 32, 128), dtype=torch.float16) + k = torch.zeros((bs, 64, 32, 128), dtype=torch.float16) + v = torch.zeros((bs, 64, 32, 128), dtype=torch.float16) + + print(f"Exporting prefill_bs{bs}") + example_args = (q, k, v, seq_lens, seq_block_ids, cache_state) + + @fxb.export_program( + name=f"prefill_bs{bs}", + args=example_args, + ) + def _(model, q, k, v, seq_lens, seq_block_ids, cache_state): + + if llama_config.is_causal: + attention_mask = None + else: + sl = tokens.shape[1] + input_mask = causal_model.input_mask(seq_lens, sl) + attention_mask = causal_model.attention_mask(input_mask) + + h = run_llama( + model=model, + config=llama_config, + phase="prefill", + xq=q, + xk=k, + xv=v, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + return h + + def generate_batch_decode(bs: int): + tokens = torch.ones(bs, 1, dtype=torch.int64) + seq_lens = torch.ones(bs, dtype=torch.int64) + start_positions = torch.ones(bs, dtype=torch.int64) + seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64) + block_dim = torch.export.Dim( + "block", max=(hp.context_length - 1) // llama_config.block_seq_stride + ) + + if llama_config.kv_cache_type == "paged": + cache_state = model.cache.allocate( + page_count=hp.context_length // llama_config.block_seq_stride + ) + page_dim = torch.export.Dim("page") + cache_state_dynamic_shapes = [{0: page_dim}] + elif llama_config.kv_cache_type == "direct": + cache_state = model.cache.allocate(bs=1) + # Direct cache dimensions: + # 2 * transformer_block_count of... 
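On the masking choice in these export programs: with --is-causal the attention kernel derives the causal structure itself and attention_mask stays None; otherwise a padding-aware mask is built from seq_lens. A hedged, illustrative sketch of such a mask (boolean True where attention is blocked; this helper is not the library's input_mask/attention_mask API).

import torch

def make_prefill_block_mask(seq_lens: torch.Tensor, seq_len: int) -> torch.Tensor:
    positions = torch.arange(seq_len)
    padding = positions[None, :] >= seq_lens[:, None]   # [bs, sl]: past each row's length
    causal = positions[None, :] > positions[:, None]    # [sl, sl]: above the diagonal
    return padding[:, None, None, :] | causal[None, None, :, :]  # [bs, 1, sl, sl]

mask = make_prefill_block_mask(torch.tensor([3, 5]), seq_len=5)
print(mask.shape)  # torch.Size([2, 1, 5, 5])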
+ # [bs, seq_length, attn_head_count, attn_head_dim] + cache_state_dynamic_shapes = (2 * hp.block_count) * [{}] + else: + raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") + + dynamic_shapes = { + "tokens": {}, + "seq_lens": {}, + "start_positions": {}, + "seq_block_ids": {1: block_dim}, + "cache_state": cache_state_dynamic_shapes, + } + + q = torch.zeros((bs, 1, 32, 128), dtype=torch.float16) + k = torch.zeros((bs, 1, 32, 128), dtype=torch.float16) + v = torch.zeros((bs, 1, 32, 128), dtype=torch.float16) + + print(f"Exporting decode_bs{bs}") + example_args = (q, k, v, seq_lens, start_positions, seq_block_ids, cache_state) + + @fxb.export_program( + name=f"decode_bs{bs}", + args=example_args, + ) + def _( + model, + q, + k, + v, + seq_lens, + start_positions, + seq_block_ids, + cache_state, + ): + + if llama_config.is_causal: + attention_mask = None + else: + input_mask = causal_model.input_mask( + seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride + ) + attention_mask = causal_model.decode_attention_mask(input_mask) + + h = run_llama( + model=model, + config=llama_config, + phase="decode", + xq=q, + xk=k, + xv=v, + attention_mask=attention_mask, + start_positions=start_positions, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + + return h + + bsizes = [] + for bs in llama_config.bs: + generate_batch_prefill(bs) + generate_batch_decode(bs) + bsizes.append(bs) + + if args.verbose: + for name, ep in fxb.programs.items(): + print(f"EXPORT {name}:\n{ep}") + + config = generate_params_json(hp, bsizes, bsizes) + print("GENERATED!") + + print("Exporting") + output = export(fxb) + print(f"Saving to '{args.output_mlir}'") + output.save_mlir(args.output_mlir) + json.dump(config, open(args.output_config, "w")) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py index 308e20ef4..445f44852 100644 --- a/sharktank/sharktank/kernels/__init__.py +++ b/sharktank/sharktank/kernels/__init__.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .attention import * +from .einsum_2args_q4 import * from .mmtfp import * from .mmt_block_scaled_offset_q4 import * from .mmt_block_scaled_q8 import * @@ -13,3 +14,4 @@ from .conv_2d_nchw_fchw import * from .pooling_nchw_sum import * from .base import * +from .bitcast import * diff --git a/sharktank/sharktank/kernels/base.py b/sharktank/sharktank/kernels/base.py index 8c99c81d9..ce792b525 100644 --- a/sharktank/sharktank/kernels/base.py +++ b/sharktank/sharktank/kernels/base.py @@ -12,7 +12,7 @@ from jinja2 import Environment, PackageLoader, select_autoescape -from shark_turbine.support.ir_imports import ( +from iree.turbine.support.ir_imports import ( FlatSymbolRefAttr, FunctionType, IrType, @@ -24,7 +24,7 @@ Value, ) -from shark_turbine.runtime.op_reg import ( +from iree.turbine.runtime.op_reg import ( def_library, CustomOp, KernelBuilder, @@ -32,7 +32,7 @@ TensorArg, ) -from shark_turbine.transforms.merger import Merger +from iree.turbine.transforms.merger import Merger from ..utils.logging import get_logger diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py index 11a6b5fc2..21f9e9ed4 100644 --- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py +++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py @@ -8,6 +8,8 @@ import torch +from iree.compiler.ir import IntegerType + __all__ = [ "batch_matmul_transpose_b", ] @@ 
-80,7 +82,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): spec_sig = f"L{a_ident}_R{b_ident}" template_file = "batch_matmul_transpose_b.mlir" target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}" - + cst_zero = "0" if IntegerType.isinstance(accum_type) else "0." # Template params. c_asm_type = f"tensor<{'x'.join('?' if d is None else str(d) for d in result_desc.spec_dims)}x{accum_type}>" @@ -93,5 +95,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): b_asm_type=b_asm_type, c_asm_type=c_asm_type, dtype=str(accum_type), + cst_zero=cst_zero, ) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/bitcast.py b/sharktank/sharktank/kernels/bitcast.py new file mode 100644 index 000000000..66850008f --- /dev/null +++ b/sharktank/sharktank/kernels/bitcast.py @@ -0,0 +1,138 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from sharktank.kernels.base import * + +import torch + +from iree.turbine.support.ir_imports import ( + ComplexType, + F16Type, + F32Type, + RankedTensorType, + ShapedType, + Value, + flow_d, + tensor_d, +) + +from iree.turbine.runtime.op_reg import ( + CustomOp, + KernelBuilder, + KernelSelection, +) + +__all__ = [ + "bitcast_to_complex", + "bitcast_to_real", +] + +_ftype_to_ctype_table = { + torch.float16: torch.complex32, + torch.float32: torch.complex64, +} + +_ctype_to_ftype_table = { + torch.complex32: torch.float16, + torch.complex64: torch.float32, +} + +_type_to_irtype_table = { + torch.float16: lambda: F16Type.get(), + torch.float32: lambda: F32Type.get(), + torch.complex32: lambda: ComplexType.get(F16Type.get()), + torch.complex64: lambda: ComplexType.get(F32Type.get()), +} + + +@CustomOp.register(library=LIBRARY) +class bitcast_to_complex(CustomOp): + + signature = "bitcast_to_complex(Tensor q) -> (Tensor)" + + def select(self, ksel: KernelSelection): + ta = ksel.arg_tensor(0) + + torch._check(ta.t.dtype in _ftype_to_ctype_table) + torch._check(isinstance(ta.t.shape[-1], int)) + + new_shape = [i for i in ta.t.shape] + new_shape[-1] = new_shape[-1] // 2 + + ctype = _ftype_to_ctype_table[ta.t.dtype] + ret = ksel.return_new_tensor(new_shape, dtype=ctype) + specialize_all_known_dims(ta) + specialize_all_known_dims(ret) + + def eager_execute(self, tensor): + return torch.view_as_complex(tensor.unflatten(-1, (-1, 2))) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + t = kb.arg_bindings[0] + result_desc = ksel.result_descs[0] + result_shape = [ + d if isinstance(d, int) else RankedTensorType.get_dynamic_size() + for d in result_desc.t.shape + ] + + dynamic_dims: list[Value] = [] + _append_dynamic_dims(kb, dynamic_dims, t) + + c64 = _type_to_irtype_table[result_desc.t.dtype]() + rtt = RankedTensorType.get(result_shape, c64) + result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result + kb.yield_results(result) + + +@CustomOp.register(library=LIBRARY) +class bitcast_to_real(CustomOp): + + signature = "bitcast_to_real(Tensor q) -> (Tensor)" + + def select(self, ksel: KernelSelection): + ta = ksel.arg_tensor(0) + + torch._check(ta.t.dtype in _ctype_to_ftype_table) + torch._check(isinstance(ta.t.shape[-1], int)) + + new_shape = [i for i in ta.t.shape] + new_shape[-1] = new_shape[-1] * 2 + + ftype = _ctype_to_ftype_table[ta.t.dtype] + ret = 
ksel.return_new_tensor(new_shape, dtype=ftype) + specialize_all_known_dims(ta) + specialize_all_known_dims(ret) + + def eager_execute(self, tensor): + return torch.view_as_real(tensor).flatten(-2, -1) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + t = kb.arg_bindings[0] + result_desc = ksel.result_descs[0] + result_shape = [ + d if isinstance(d, int) else RankedTensorType.get_dynamic_size() + for d in result_desc.t.shape + ] + + dynamic_dims: list[Value] = [] + _append_dynamic_dims(kb, dynamic_dims, t) + + ftype = _type_to_irtype_table[result_desc.t.dtype]() + rtt = RankedTensorType.get(result_shape, ftype) + result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result + kb.yield_results(result) + + +################################################################################ +# Emission utilities +################################################################################ + + +def _append_dynamic_dims(kb: KernelBuilder, dynamic_dims: list[Value], tensor: Value): + rtt = RankedTensorType(tensor.type) + for i in range(rtt.rank): + if rtt.is_dynamic_dim(i): + dynamic_dims.append(tensor_d.dim(tensor, kb.constant_index(i))) diff --git a/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py b/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py index 9ada3b099..529511e02 100644 --- a/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py +++ b/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py @@ -23,6 +23,8 @@ (torch.int16, torch.int16, "torch.int16"): torch.int16, (torch.int16, torch.int16, "torch.int32"): torch.int32, # Legal fp types. + (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16, + (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32, (torch.float16, torch.float16, "torch.float16"): torch.float16, (torch.float16, torch.float16, "torch.float32"): torch.float32, (torch.float32, torch.float32, "torch.float32"): torch.float32, @@ -33,6 +35,7 @@ torch.int8: "i8", torch.int16: "i16", torch.int32: "i32", + torch.float8_e4m3fnuz: "f8E4M3FNUZ", torch.float16: "f16", torch.float32: "f32", } diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py new file mode 100644 index 000000000..76d8ad61c --- /dev/null +++ b/sharktank/sharktank/kernels/einsum_2args_q4.py @@ -0,0 +1,259 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .base import * + +import torch + +__all__ = [ + "einsum_2args_q4", +] + + +def einsum_util(einsum_str): + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + es_set = set(es_out) + es_set = es_set.union(es_in0) + es_set = es_set.union(es_in1) + size = len(es_set) + imap = dict() + lmap = dict() + for i in range(len(es_out)): + imap[i] = es_out[i] + lmap[es_out[i]] = i + count = len(es_out) + for c in es_set: + if c not in lmap: + imap[count] = c + lmap[c] = count + count += 1 + + assert count == len(es_set) + + in0_idx = [lmap[i] for i in es_in0] + in1_idx = [lmap[i] for i in es_in1] + out_idx = [lmap[i] for i in es_out] + + input_idx_str = ", ".join(["d" + str(i) for i in range(size)]) + in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx]) + in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx]) + out_idx_str = ", ".join(["d" + str(i) for i in out_idx]) + + iterators = ", ".join( + ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)] + ) + + affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>" + affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>" + affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>" + + indexing_maps = f"""{affine_map_in0}, + {affine_map_in1}, + {affine_map_out} +""" + + out_dyn_dim_size_str = "" + for c in es_out: + if c in es_in0: + out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + "," + elif c in es_in1: + if es_in1.find(c) == len(es_in1) - 1: + out_dyn_dim_size_str += "%b_unblocked_dim," + else: + out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + "," + else: + raise Exception("Invalid einsum string") + out_dyn_dim_size_str = out_dyn_dim_size_str[:-1] + return ( + (in0_idx, in1_idx, out_idx), + iterators, + indexing_maps, + out_dyn_dim_size_str, + ) + + +@CustomOp.register(library=LIBRARY) +class einsum_2args_q4(CustomOp): + """Einsum that takes two tensor inputs and returns one tensor. + + The first input is expected to be a normal tensor. 
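A hedged Python sketch of the block-scaled Q4 layout this kernel consumes (the helper and the nibble order below are illustrative; the kernel itself reinterprets qs with a tensor bitcast rather than unpacking in Python): each byte packs two 4-bit codes, and every code in a block is dequantized as d * q + m.

import torch

def dequant_q4(qs: torch.Tensor, d: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    # qs: [..., K // BS, BS // 2] uint8; d, m: [..., K // BS, 1] float.
    lo = (qs & 0x0F).to(d.dtype)
    hi = (qs >> 4).to(d.dtype)
    q = torch.stack((lo, hi), dim=-1).flatten(-2)  # [..., K // BS, BS]; nibble order assumed
    return (q * d + m).flatten(-2)                 # [..., K]

qs = torch.randint(0, 256, (4, 8, 16), dtype=torch.uint8)  # N=4, K=256, BS=32
d = torch.rand(4, 8, 1)
m = torch.rand(4, 8, 1)
print(dequant_q4(qs, d, m).shape)  # torch.Size([4, 256])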
+ + The second input corresponds to the BlockScaledLayout and operates on planar `d` + and `qs` tensors as specified there: + + * `d`: `[..., K // BLOCK_SIZE, 1]` + * `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8) + * `m`: `[..., K // BLOCK_SIZE, 1]` + """ + + signature = ( + "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)" + ) + + def select(self, ksel: KernelSelection): + a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k + d_desc = ksel.arg_tensor(1) # Shape [N, K // BLOCK_SIZE, 1] + qs_desc = ksel.arg_tensor(2) # Shape [N, K // BLOCK_SIZE, BLOCK_SIZE // 2] + m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1] + einsum_str = ksel.attr_str(4).v + + # a arg + a_dims = a_desc.t.shape + torch._check( + a_desc.t.dtype.is_floating_point, + lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})", + ) + + # qs arg + *qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape + block_size = qs_bs_div_2 * 2 + + # d arg + *d_dims, d_group0, d_one = d_desc.t.shape + torch._check( + d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims), + lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})", + ) + + # m arg + *m_dims, m_group0, m_one = m_desc.t.shape + torch._check( + m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims), + lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})", + ) + # einsum_str + torch._check( + einsum_str.count(",") == 1 and einsum_str.count("->") == 1, + lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')", + ) + + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + es_set = set(es_out) + + shp = qs_desc.t.shape + b_dims = list(shp[:-2]) + [shp[-2] * block_size] + torch._check( + len(es_in0) == len(a_desc.t.shape) + and len(es_in1) + == len(qs_desc.t.shape) + - 1, # The quantized shape is larger until the blocks are collapsed + lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", + ) + torch._check( + len(es_in0) == len(set(es_in0)) + and len(es_in1) == len(set(es_in1)) + and len(es_in0) != 0 + and len(es_in1) != 0, + lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')", + ) + + # Check corresponding dimensions match + for i in range(len(es_in0)): + a_dim = a_dims[i] + c = es_in0[i] + pos = es_in1.find(c) + if pos >= 0: + b_dim = b_dims[pos] + torch._check( + a_dim == b_dim, + lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})", + ) + + # Determine the output shape by referencing corresponding input shapes + out_dims = [] + for c in es_out: + pos0 = es_in0.find(c) + pos1 = es_in1.find(c) + a_dim = a_dims[pos0] + b_dim = b_dims[pos1] + if pos0 >= 0: + out_dims.append(a_dim) + elif pos1 >= 0: + out_dims.append(b_dim) + else: + torch._check( + False, + lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')", + ) + + # Specialize on BS + qs_desc.specialize_dims(-1) + d_desc.specialize_dims(-1) + m_desc.specialize_dims(-1) + + # Shape batch..., m, n + c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + a = kb.arg_value(0) + a_tensor_type = RankedTensorType(a.type) + d = kb.arg_value(1) + d_tensor_type 
= RankedTensorType(d.type) + qs = kb.arg_value(2) + qs_tensor_type = RankedTensorType(qs.type) + einsum_str = ksel.arg_descs[4].v + # einsum_str = "mek,menk->men" + + es_in, es_out = einsum_str.split("->") + es_in0, es_in1 = es_in.split(",") + + es_name = "_".join([es_in0, es_in1, es_out]) + + ( + (es_0, es_1, es_2), + einsum_iterators, + einsum_indexing_maps, + oddss, + ) = einsum_util(einsum_str) + + rank1 = len(es_1) + dequant_iterators = ", ".join( + ['"parallel"' for i in range(rank1 + 1)] + ) # rank + 1 because of the group dimensions + input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)]) + broadcast_idx_str = ", ".join( + ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)] + ) + affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>" + affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>" + dequant_indexing_maps = f"""{affine_map_broadcast}, + {affine_map_broadcast}, + {affine_map_parallel}, + {affine_map_parallel}""" + + size_str = "x".join("?" for i in range(rank1 - 2)) + + rank = a_tensor_type.rank + *n_dims, group0, bs_i8 = qs_tensor_type.shape + bs = bs_i8 * 2 # 2 nibbles per byte. + group = group0 * bs + a_type_str = str(a_tensor_type.element_type) + scale_type_str = str(d_tensor_type.element_type) + + template_file = "einsum_2args_q4.mlir" + target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}" + + target_function = inline_template_function( + kb, + template_file, + target_function_name, + bs=bs, + bs_i8=bs_i8, + a_type=a_type_str, + scale_type=scale_type_str, + dequant_indexing_maps=dequant_indexing_maps, + dequant_iterator_types=dequant_iterators, + einsum_indexing_maps=einsum_indexing_maps, + einsum_iterator_types=einsum_iterators, + es_name=es_name, + a_size=len(es_in0), + b_size=len(es_in1), + c_size=len(es_out), + out_dyn_dim_size_str=oddss, + ) + kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py index 2ed171115..0c8a61f32 100644 --- a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py +++ b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py @@ -37,28 +37,33 @@ def select(self, ksel: KernelSelection): m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1] # a arg - *batch_dims, a_m, a_k = a_desc.t.shape + *a_batch_dims, a_m, a_k = a_desc.t.shape torch._check( a_desc.t.dtype.is_floating_point, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})", ) torch._check( - len(batch_dims) == 1, + len(a_batch_dims) == 1, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})", ) # qs arg - qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape + *qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape torch._check( - len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k, + ( + len(qs_batch_dims) == 0 + or len(qs_batch_dims) == 1 + and qs_batch_dims == a_batch_dims + ) + and (qs_group0 * qs_bs_div_2 * 2) == a_k, lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})", ) block_size = qs_bs_div_2 * 2 # d arg - d_n, d_group0, d_one, *rest = d_desc.t.shape + *d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape torch._check( - len(rest) == 0 + d_batch_dims == qs_batch_dims and (d_group0 * block_size) == a_k and d_one == 1 and d_n == qs_n, @@ -66,9 +71,9 @@ def select(self, ksel: 
KernelSelection): ) # m arg - m_n, m_group0, m_one, *rest = m_desc.t.shape + *m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape torch._check( - len(rest) == 0 + m_batch_dims == qs_batch_dims and (m_group0 * block_size) == a_k and m_one == 1 and m_n == qs_n, @@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection): # Specialize on K, N, BS a_desc.specialize_dims(-1) - qs_desc.specialize_all_dims() - d_desc.specialize_all_dims() - m_desc.specialize_all_dims() + if len(qs_batch_dims) == 0: + qs_desc.specialize_all_dims() + d_desc.specialize_all_dims() + m_desc.specialize_all_dims() + else: + qs_desc.specialize_dims(1, 2, 3) + d_desc.specialize_dims(1, 2, 3) + m_desc.specialize_dims(1, 2, 3) # Shape batch..., m, n - c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype) + c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype) c_desc.specialize_dims(-1) def generate(self, ksel: KernelSelection, kb: KernelBuilder): @@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): rank = a_tensor_type.rank k = a_tensor_type.get_dim_size(rank - 1) - n, group0, bs_i8 = qs_tensor_type.shape + *qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape + batched_rhs = len(qs_batch_dims) == 1 bs = bs_i8 * 2 # 2 nibbles per byte. a_type_str = str(a_tensor_type.element_type) scale_type_str = str(d_tensor_type.element_type) template_file = "mmt_block_scaled_offset_q4_unsigned.mlir" - target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}" + target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}" target_function = inline_template_function( kb, @@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder): group0=group0, a_type=a_type_str, scale_type=scale_type_str, + batched_rhs=batched_rhs, ) kb.yield_results(*call_function(target_function, *kb.arg_bindings)) diff --git a/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir b/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir index 908ca1c7f..056f72af3 100644 --- a/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir +++ b/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir @@ -15,7 +15,7 @@ module { util.func private @sharktank_batch_matmul_transpose_b_{{spec_sig}}( %a: !a_tensor_type, %b: !b_tensor_type) -> !c_tensor_type { - %zero = arith.constant 0: !dtype + %zero = arith.constant {{cst_zero}}: !dtype %c0 = arith.constant 0: index %c1 = arith.constant 1: index %batch_dim = tensor.dim %a, %c0 : !a_tensor_type // b, m, k diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir new file mode 100644 index 000000000..47ca6b331 --- /dev/null +++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir @@ -0,0 +1,104 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +{% set accum_type = "f32" %} + +!lowp_type = i4 +!a_type = {{a_type}} +!scale_type = {{scale_type}} +!accum_type = {{accum_type}} +!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type> +!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8> +!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type> +!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type> +!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type> +!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type> +!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type> +!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type> +!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type> + +module { + +util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}( + %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) + -> !c_tensor_type { + %debug = tensor.empty() : tensor<1xf32> + %zero = arith.constant 0.0: !accum_type + {% for i in range(a_size) %} + %k{{i}} = arith.constant {{i}} : index + {% endfor %} + {% for i in range(a_size, b_size) %} + %k{{i}} = arith.constant {{i}} : index + {% endfor %} + {% for i in range(a_size) %} + %a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type + {% endfor %} + {% for i in range(b_size) %} + %b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type + {% endfor %} + %bs = arith.constant {{bs}} : index + %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index + + //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} + + // Dequantize. + %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type + %b_grouped_dequant = linalg.generic { + indexing_maps = [ + {{dequant_indexing_maps}}], + iterator_types = [{{dequant_iterator_types}}] } + ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type) + outs(%b_grouped : !b_grouped_tensor_type) { + ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type): + %q_element_ext = arith.extui %q_element : !lowp_type to i32 + %q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type + {% if scale_type == a_type %} + %q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type + %q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type + {% else %} + %d_element_ext = arith.extf %d_element : !scale_type to !a_type + %m_element_ext = arith.extf %m_element : !scale_type to !a_type + %q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type + %q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type + {% endif %} + linalg.yield %q_element_offset : !a_type + } -> !b_grouped_tensor_type + + // Collapse %b to the same unblocked structure. 
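
Note on the dequantization block in this template: the generic computes b = d * q + m per group, and the collapse that follows merges the (group, block) pair back into the unblocked reduction dimension. A hedged PyTorch sketch of the same math, not the template's actual lowering; the i8-to-i4 bitcast is modeled by splitting nibbles, and the low-nibble-first order is an assumption:

import torch

def dequant_q4_block_scaled_offset(qs_raw: torch.Tensor, d: torch.Tensor, m: torch.Tensor):
    # qs_raw: [..., group0, bs // 2] packed uint8, two 4-bit quants per byte
    # d, m:   [..., group0, 1] per-group scale and offset
    lo = qs_raw & 0x0F                              # assumed low-nibble-first packing
    hi = (qs_raw >> 4) & 0x0F
    q = torch.stack([lo, hi], dim=-1).flatten(-2)   # [..., group0, bs]
    b_grouped = q.to(d.dtype) * d + m               # dequantize: scale, then offset
    return b_grouped.flatten(-2)                    # collapse (group0, bs) -> unblocked K
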
+ %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type + + // Einsum + %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type + %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type + %result = linalg.generic { + indexing_maps = [ + {{einsum_indexing_maps}}], + iterator_types = [{{einsum_iterator_types}}] } + ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type) + outs(%result_fill : !accum_tensor_type) { + ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type): + %bmm_mul = arith.mulf %a_element, %b_element : !a_type + {% if accum_type == a_type %} + %bmm_accum = arith.addf %bmm_mul, %out : !a_type + {% else %} + %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type + %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type + {% endif %} + linalg.yield %bmm_accum : !accum_type + } -> !accum_tensor_type + + // Cast. + %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type + %result_cast = linalg.copy + ins(%result : !accum_tensor_type) + outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type + + //iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type] + util.return %result_cast : !c_tensor_type +} + +} diff --git a/sharktank/sharktank/kernels/templates/flash_attention.mlir b/sharktank/sharktank/kernels/templates/flash_attention.mlir index db76db84f..15d75c372 100644 --- a/sharktank/sharktank/kernels/templates/flash_attention.mlir +++ b/sharktank/sharktank/kernels/templates/flash_attention.mlir @@ -7,6 +7,7 @@ !q_type = tensor !k_type = tensor !v_type = tensor +!trans_v_type = tensor !o_type = tensor !o_dyn_type = tensor !s_type = tensor<{{scale_type}}> @@ -32,15 +33,22 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_ %scale = tensor.extract %s[] : !s_type + %init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type + %transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2] + %empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type %empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type %atten = iree_linalg_ext.attention {indexing_maps = [ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>, - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]} - ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) -> !o_type + affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]} + ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + } -> !o_type util.return %atten : !o_type } } diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir index a7f3138cb..afe2928c0 100644 --- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir +++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir @@ -12,17 +12,25 @@ 
!accum_type = {{accum_type}} !a_tensor_type = tensor !aexp_tensor_type = tensor +{% if batched_rhs %} +!qs_raw_tensor_type = tensor +!qs_tensor_type = tensor +!d_tensor_type = tensor +!m_tensor_type = tensor +!b_grouped_tensor_type = tensor +{% else %} !qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8> !qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type> !d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type> !m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type> +!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type> +{% endif %} !accum_tensor_type = tensor !c_tensor_type = tensor -!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type> module { -util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}( +util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}( %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type) -> !c_tensor_type { %zero = arith.constant 0.0: !accum_type @@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ %m_dim = tensor.dim %a, %c1 : !a_tensor_type // Cast qs_raw from i8 to lowp type. +{% if batched_rhs %} + %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim } + %b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type +{% else %} %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type + %b_grouped = tensor.empty() : !b_grouped_tensor_type +{% endif %} // Dequantize. - %b_grouped = tensor.empty() : !b_grouped_tensor_type %b_grouped_dequant = linalg.generic { +{% if batched_rhs %} + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] } +{% else %} indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"] } +{% endif %} ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type) outs(%b_grouped : !b_grouped_tensor_type) { ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type): @@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_ indexing_maps = [ // d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r) affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>, - affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] } ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type) diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index 181544763..fd56ec872 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -16,6 +16,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock from .ffn_block import FFN from .ffn_moe_block import FFNMOE -from .mixture_of_experts_block import SparseMoeBlock +from .mixture_of_experts_block import MoeBlock -from . 
import configs +from .configs import * diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 3d91683fe..8ace77981 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -33,12 +33,14 @@ def __init__( device: Optional[torch.device] = None, activation_dtype: torch.dtype = torch.float32, attention_dtype: torch.dtype = torch.float32, + fake_quant: bool = True, ): super().__init__(theta) self.device = device self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype self.context_length = context_length + self.fake_quant = fake_quant if static_tables: self.register_buffer( @@ -89,16 +91,15 @@ def input_mask( masked. """ range_vector = torch.arange(0, batch_seqlen, 1, device=self.device) - matrix = torch.unsqueeze(seq_lens, dim=-1) + matrix = seq_lens.unsqueeze(dim=-1) mask = range_vector >= matrix return mask def decode_attention_mask(self, boolean_input_mask: torch.Tensor): dtype = self.attention_dtype - numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype) - numeric_mask.masked_fill_( - boolean_input_mask, self._maximally_negative_value(dtype) - ) + numeric_mask = torch.where( + boolean_input_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device) def attention_mask( @@ -127,9 +128,10 @@ def attention_mask( dtype = self.attention_dtype _, batch_seq_len = input_mask.shape causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len] - boolean_mask = causal_mask + input_mask[:, None, None, :] - numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype) - numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype)) + boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :]) + numeric_mask = torch.where( + boolean_mask, self._maximally_negative_value(dtype), 0 + ).to(dtype) return numeric_mask.to(self.device) def extract_tokens_from_logits( diff --git a/sharktank/sharktank/layers/configs/__init__.py b/sharktank/sharktank/layers/configs/__init__.py index 21336d1d2..c5d75c602 100644 --- a/sharktank/sharktank/layers/configs/__init__.py +++ b/sharktank/sharktank/layers/configs/__init__.py @@ -4,4 +4,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .llm_configs import LlamaHParams +from .llm_configs import * diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index ab3a582e4..35a2ee570 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -14,12 +14,11 @@ (and indeed, can bootstrap these off of GGUF files). """ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Optional - import torch -__all__ = ["LlamaHParams"] +__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"] @dataclass @@ -29,51 +28,79 @@ class LlamaHParams: Comments are only provided if they differ from this source. 
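
Regarding the causal_llm.py changes above: the masks are now built functionally with torch.where instead of in-place masked_fill_, which is friendlier to tracing and export. A hedged sketch of the decode mask construction, where torch.finfo(dtype).min stands in for the model's _maximally_negative_value helper:

import torch

def decode_attention_mask(boolean_input_mask: torch.Tensor, dtype=torch.float16):
    # True marks positions to mask out; produce an additive mask of shape [bs, 1, 1, sl].
    neg = torch.tensor(torch.finfo(dtype).min)
    numeric_mask = torch.where(boolean_input_mask, neg, torch.tensor(0.0)).to(dtype)
    return numeric_mask.unsqueeze(1).unsqueeze(1)
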
""" + model_arch: str context_length: int embedding_length: int block_count: int feed_forward_length: int - rope_dimension_count: int - rope_freq_base: float attention_head_count: int attn_head_dim: int attention_layer_norm_rms_epsilon: float attention_head_count_kv: int - expert_count: int - expert_used_count: int + rope_dimension_count: Optional[int] = None + rope_freq_base: Optional[float] = None + expert_count: Optional[int] = None + expert_used_count: Optional[int] = None @staticmethod def from_gguf_props(p: dict[str, Any]): + name_prefix = p.get("general.architecture", "llama") default_expert_count = 0 default_expert_used_count = 0 default_rope_freq_base = 10000.0 - attention_head_count = _int_prop(p, "llama.attention.head_count") + default_rope_dimension_count = 128 + attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count") + rope_dimension_count = _optional_int_prop( + p, f"{name_prefix}.rope.dimension_count", default_rope_dimension_count + ) return LlamaHParams( - context_length=_int_prop(p, "llama.context_length"), - embedding_length=_int_prop(p, "llama.embedding_length"), - block_count=_int_prop(p, "llama.block_count"), - feed_forward_length=_int_prop(p, "llama.feed_forward_length"), - attn_head_dim=_int_prop(p, "llama.rope.dimension_count"), - rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"), + model_arch=name_prefix, + context_length=_int_prop(p, f"{name_prefix}.context_length"), + embedding_length=_int_prop(p, f"{name_prefix}.embedding_length"), + block_count=_int_prop(p, f"{name_prefix}.block_count"), + feed_forward_length=_int_prop(p, f"{name_prefix}.feed_forward_length"), attention_head_count=attention_head_count, attention_layer_norm_rms_epsilon=_float_prop( - p, "llama.attention.layer_norm_rms_epsilon" + p, f"{name_prefix}.attention.layer_norm_rms_epsilon" ), attention_head_count_kv=_optional_int_prop( - p, "llama.attention.head_count_kv", attention_head_count + p, f"{name_prefix}.attention.head_count_kv", attention_head_count ), + attn_head_dim=rope_dimension_count, + rope_dimension_count=rope_dimension_count, rope_freq_base=_optional_float_prop( - p, "llama.rope.freq_base", default_rope_freq_base + p, f"{name_prefix}.rope.freq_base", default_rope_freq_base ), expert_count=_optional_int_prop( - p, "llama.expert_count", default_expert_count + p, f"{name_prefix}.expert_count", default_expert_count ), expert_used_count=_optional_int_prop( - p, "llama.expert_used_count", default_expert_used_count + p, f"{name_prefix}.expert_used_count", default_expert_used_count ), ) + def to_gguf_props(self) -> dict[str, Any]: + res = { + "general.architecture": self.model_arch, + f"{self.model_arch}.context_length": self.context_length, + f"{self.model_arch}.embedding_length": self.embedding_length, + f"{self.model_arch}.block_count": self.block_count, + f"{self.model_arch}.feed_forward_length": self.feed_forward_length, + f"{self.model_arch}.attention.head_count": self.attention_head_count, + f"{self.model_arch}.attention.layer_norm_rms_epsilon": self.attention_layer_norm_rms_epsilon, + f"{self.model_arch}.attention.head_count_kv": self.attention_head_count_kv, + } + if self.rope_dimension_count is not None: + res[f"{self.model_arch}.rope.dimension_count"] = self.rope_dimension_count + if self.rope_freq_base is not None: + res[f"{self.model_arch}.rope.freq_base"] = self.rope_freq_base + if self.expert_count is not None: + res[f"{self.model_arch}.expert_count"] = self.expert_count + if self.expert_used_count is not None: + 
res[f"{self.model_arch}.expert_used_count"] = self.expert_used_count + return res + def _float_prop(p: dict[str, Any], name: str) -> float: try: @@ -107,3 +134,122 @@ def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int: return int(value) except ValueError as e: raise ValueError(f"Property '{name}' expected to be an int and was not") from e + + +@dataclass +class LlamaModelConfig: + hp: LlamaHParams + + # Block sequence stride for a paged KV cache. This must divide evenly + # into the context length. + block_seq_stride: int = 16 + + # Either "paged" or "direct". + kv_cache_type: str = "paged" + + # The device on which to place intermediate state. + device: Optional[torch.device] = None + + # Dtype to use for general FP activations not otherwise configured. + activation_dtype: torch.dtype = torch.float16 + + # Dtype to use for attention. + attention_dtype: torch.dtype = torch.float16 + + # fake quant determines the mode the Layer Thetas operate w.r.t quantized tensors. + fake_quant: bool = True + + # How many devices are involved for tensor parallel sharding. + # If greater than 1, the model will expect sharded model parameters and function + # arguments. + tensor_parallelism_size: int = 1 + + # Which attention kernel to use. + attention_kernel: str = "decomposed" + + # Indicates if running with HuggingFace implementation and ensures + # numerical equivalency to HuggingFace's LLaMa if true (by modifying + # rotary embedding). + use_hf: bool = False + + # If true, then the model may pre-initialize certain tables during + # init. This can be better for eager execution but when capturing a program, + # it is often better to preserve the calculation explicitly and rely on + # the compiler to transform it to an initialization time step. This can + # be the difference of many gigabytes of static data being embedded in + # the program and not. 
+ static_tables: bool = True + + +@dataclass +class T5Config: + return_dict: bool = True + output_hidden_states: bool = False + output_attentions: bool = False + is_encoder_decoder: bool = True + is_decoder: bool = False + vocab_size: int = 32128 + context_length: int = 512 + d_model: int = 512 + d_kv: int = 64 + d_ff: int = 2048 + num_layers: int = 6 + num_decoder_layers: int = 6 + num_heads: int = 8 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + layer_norm_epsilon: float = 1e-6 + feed_forward_proj: str = "relu" + is_gated_act: bool = field(init=False) + activation_dtype: torch.dtype = torch.float32 + dense_act_fn: str = field(init=False) + use_cache: bool = True + pad_token_id: int = 0 + eos_token_id: int = 1 + decoder_start_token_id: int = 0 + context_length_padding_block_size: int = 16 + + def __post_init__(self): + self.is_gated_act = self.feed_forward_proj.startswith("gated-") + self.dense_act_fn = ( + self.feed_forward_proj.split("-")[1] + if "-" in self.feed_forward_proj + else self.feed_forward_proj + ) + if self.dense_act_fn == "gelu": + self.dense_act_fn = "gelu_new" + + @staticmethod + def from_gguf_properties(properties: dict[str, Any], **kwargs): + assert properties["general.architecture"] == "t5" + assert ( + properties["t5.attention.layer_norm_epsilon"] + == properties["t5.attention.layer_norm_rms_epsilon"] + ) + + gguf_to_config_names_map = { + "t5.context_length": ["context_length"], + "t5.embedding_length": ["d_model"], + "t5.feed_forward_length": ["d_ff"], + "t5.block_count": ["num_layers", "num_decoder_layers"], + "t5.attention.head_count": ["num_heads"], + "t5.attention.key_length": ["d_kv"], + "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"], + "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"], + "t5.decoder_start_token_id": ["decoder_start_token_id"], + "tokenizer.ggml.eos_token_id": ["eos_token_id"], + "tokenizer.ggml.padding_token_id": ["pad_token_id"], + } + all_kwargs = {"vocab_size": None, "feed_forward_proj": None} + all_kwargs.update( + { + config_name: properties[gguf_name] + for gguf_name, config_names in gguf_to_config_names_map.items() + for config_name in config_names + } + ) + if "tokenizer.ggml.tokens" in properties: + all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"]) + all_kwargs.update(kwargs) + + return T5Config(**all_kwargs) diff --git a/sharktank/sharktank/layers/ffn_block.py b/sharktank/sharktank/layers/ffn_block.py index 420b06893..cde603b98 100644 --- a/sharktank/sharktank/layers/ffn_block.py +++ b/sharktank/sharktank/layers/ffn_block.py @@ -4,10 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional +from typing import Optional, Callable import torch import torch.nn.functional as F +from .. 
import ops +from ..types import AnyTensor from .base import Theta, ThetaLayer from .linear import LinearLayer @@ -20,18 +22,29 @@ class FFN(ThetaLayer): def __init__( self, theta: Theta, + is_gated: bool = True, + activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu, ): super().__init__(theta) - self.add_module("ffn_gate", LinearLayer(theta("ffn_gate"))) + self.is_gated = is_gated + self.activation_fn = activation_fn + if self.is_gated: + self.add_module("ffn_gate", LinearLayer(theta("ffn_gate"))) self.add_module("ffn_up", LinearLayer(theta("ffn_up"))) self.add_module("ffn_down", LinearLayer(theta("ffn_down"))) def forward( self, - h: torch.Tensor, - ): - ffn_gate = F.silu(self.ffn_gate(h)) - ffn_up = self.ffn_up(h) - ffn_down = self.ffn_down(ffn_gate * ffn_up) - return ffn_down + h: AnyTensor, + ) -> AnyTensor: + if self.is_gated: + ffn_gate = ops.elementwise(self.activation_fn, self.ffn_gate(h)) + ffn_up = self.ffn_up(h) + ffn_down = self.ffn_down(ffn_gate * ffn_up) + return ffn_down + else: + h = self.ffn_up(h) + h = ops.elementwise(self.activation_fn, h) + h = self.ffn_down(h) + return h diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index d833882f5..2adb9464f 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -12,12 +12,68 @@ from .base import ThetaLayer from .linear import LinearLayer from ..types import Theta, DefaultPrimitiveTensor +from ..ops import einsum_2args, elementwise __all__ = [ "FFNMOE", + "PreGatherFFNMOE", ] +class PreGatherFFNMOE(ThetaLayer): + def __init__( + self, + theta: Theta, + activation=F.silu, + ): + + super().__init__(theta) + + self.ffn_gate = theta.tensor("ffn_gate_exps", "weight") + self.ffn_up = theta.tensor("ffn_up_exps", "weight") + self.ffn_down = theta.tensor("ffn_down_exps", "weight") + self.activation = activation + + def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): + inputs = inputs[:, :] + weights = weights[experts, :, :] + matmul = einsum_2args(inputs, weights, einstring) + return matmul + + def bigger_mmg(self, inputs, weights, experts): + inputs = inputs[:, :] + weights = weights[experts, :, :] + matmul = einsum_2args(inputs, weights, "mek,menk->men") + return matmul + + def one_hot_matmul(self, inputs, weights, experts): + matmul = einsum_2args(inputs, weights, "mk,bnk->bmn") + # Post mix the experts + oh = ( + torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8) + .transpose(0, 1) + .to(torch.float32) + ) + output = einsum_2args(oh, matmul, "bm,bmn->mn") + return output + + def forward( + self, + h: torch.Tensor, + experts: torch.Tensor, + expert_gate: torch.Tensor, + ): + ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts) + ffn_gate = elementwise(self.activation, ffn_gate) + + ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) + ffn_down = self.pre_matmul_gather( + ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men" + ) + ffn_down = einsum_2args(expert_gate, ffn_down, "me,men->men") + return torch.sum(ffn_down, dim=1) + + class FFNMOE(ThetaLayer): def __init__( self, diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 47b465bd2..c73b7a8f4 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -11,7 +11,7 @@ and dims floating around everywhere. 
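
The PreGatherFFNMOE block above gathers each token's top-k expert weights and contracts them with a single einsum rather than looping over experts. A hedged sketch of the pre_matmul_gather pattern in plain PyTorch:

import torch

def pre_matmul_gather(inputs: torch.Tensor, weights: torch.Tensor, experts: torch.Tensor):
    # inputs:  [tokens, K]           token activations
    # weights: [num_experts, N, K]   stacked per-expert weights (e.g. ffn_up_exps)
    # experts: [tokens, top_k]       expert ids selected per token
    gathered = weights[experts]                        # [tokens, top_k, N, K]
    return torch.einsum("mk,menk->men", inputs, gathered)

# Gating then reduces over the top_k dimension, mirroring forward() above:
# out = torch.einsum("me,men->men", expert_gate, ffn_down).sum(dim=1)
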
""" -from typing import Optional +from typing import Optional, Union, List import abc import math @@ -19,6 +19,8 @@ import torch from ..utils.debugging import trace_tensor +from ..types import SplitPrimitiveTensor, ReplicatedTensor +from .. import ops __all__ = [ "BaseKVCache", @@ -90,6 +92,7 @@ def __init__( attn_head_count: int, attn_head_dim: int, seq_length: int, + shard_count: int = 1, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): @@ -98,6 +101,7 @@ def __init__( self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.seq_length = seq_length + self.shard_count = shard_count self.device = device self.dtype = dtype @@ -111,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]: Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim] """ - return [ + allocations = [ torch.empty( - [bs, self.seq_length, self.attn_head_count, self.attn_head_dim], + [ + bs, + self.seq_length, + self.attn_head_count, + self.attn_head_dim, + ], dtype=self.dtype, device=self.device, ) for _ in range(2 * self.transformer_block_count) ] + if self.shard_count == 1: + return allocations + + return [ + ops.reshard_split(allocation, dim=2, count=self.shard_count) + for allocation in allocations + ] + + def read( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + transformer_block_index: int, + seq_len: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Reads cache partitions from the page table for the given page_ids. + + Args: + state: State struct as returned from allocate(). + read_into_partitions: List of cache partitions to read into in-place. + transformer_block_index: The index of the transformer block accessing + the cache. + page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids + to access. + + Returns a tuple of cache partitions (i.e. k and v caches for the transformer + block), linearized. Note that this reference approach to reading by + materializing linearly may not be terribly efficient unless if the + compiler can fuse the gather. + """ + read_count = len(read_into_partitions) + reads = [] + for i in range(read_count): + reads.append( + state[transformer_block_index * read_count + i][:, :seq_len, :, :] + ) + + return tuple(reads) + + def write_timestep( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + # List of [bs, 1, attn_head_count, attn_head_dim] + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + # [bs] + seq_positions: Union[torch.Tensor, ReplicatedTensor], + # [bs, max_seqlen // block_pos_stride] + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes a single batched timestep across all cache partitions. + + Note that this internally loops over the batch size, which cannot be + dynamic. 
+ """ + bs, _, _, _ = cache_partitions[0].shape + update_count = len(cache_partitions) + + for b in range(bs): + row_index = torch.tensor([b], dtype=torch.int64) + row_start_pos = seq_positions[row_index].unsqueeze(0) + + for i, update in enumerate(cache_partitions): + cache = state[transformer_block_index * update_count + i] + cache.index_put_((row_index, row_start_pos), update[row_index, 0]) + + def write( + self, + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + *, + transformer_block_index: int, + page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, + ): + """Writes cache partitions from a linear layout to the page table. + + This is the inverse of the linear read. The same caveat applies if the + in-place scatter cannot be fused. + """ + update_count = len(cache_partitions) + + for idx, update_src in enumerate(cache_partitions): + cache_dest = state[transformer_block_index * update_count + idx] + _, batch_seq_len, _, _ = update_src.shape + cache_dest[:, :batch_seq_len, :, :] = update_src + class PagedKVCache(BaseKVCache): """Implementation of a KV cache on top of a 'page table'. @@ -138,6 +236,11 @@ class PagedKVCache(BaseKVCache): Note that the internal page structure matches the organization of the model, allowing contiguous individual local reads and writes at a sub-block granularity if indexing deeply into the structure. + + When `shard_count > 1`, it would split the `attn_head_count` dimension. + The page slab is a 1D sharded split tensor. + It is reinterpreted as a 6D tensor, by working around the lack of sharded + block-cyclic sharded tensor type. """ def __init__( @@ -150,30 +253,59 @@ def __init__( block_seq_stride: int = 16, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + shard_count: int = 1, ): self.transformer_block_count = transformer_block_count self.attn_head_count = attn_head_count self.attn_head_dim = attn_head_dim self.cache_partition_count = cache_partition_count self.block_seq_stride = block_seq_stride + self.shard_count = shard_count + if attn_head_count % shard_count != 0: + raise ValueError( + f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}." + ) # Some derived values based on attributes. self.sub_page_dims = [ self.transformer_block_count, self.cache_partition_count, self.block_seq_stride, - self.attn_head_count, + self.attn_head_count // self.shard_count, self.attn_head_dim, ] self.page_slab_flat_dim = math.prod(self.sub_page_dims) self.device = device self.dtype = dtype - def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor: + def unflatten_page_table( + self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]] + ) -> Union[torch.Tensor, SplitPrimitiveTensor]: """Unflattens the 2D page table to a 6D tensor.""" assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}" page_slab = state[0] - return page_slab.reshape( + if self.shard_count == 1: + assert not isinstance(page_slab, SplitPrimitiveTensor) + return page_slab.unflatten(1, self.sub_page_dims) + else: + assert self.shard_count == page_slab.shard_count + shards = [ + shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=4) + + def shard_state( + self, state: List[torch.Tensor] + ) -> List[Union[torch.Tensor, SplitPrimitiveTensor]]: + """Shard an unsharded state. + We can't just split the slab on the sub page dims. 
+ First it needs to be reinterpreted into the actual shape. + The split the head dimension, then flatten each shard. + This is a work-around for the lack of block-cyclic sharded tensor type.""" + if self.shard_count == 1: + return state + + page_table = state[0].reshape( [ -1, self.transformer_block_count, @@ -183,30 +315,47 @@ def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor: self.attn_head_dim, ] ) + sharded_page_table = ops.reshard_split( + page_table, dim=4, count=self.shard_count + ) + shards = [ + ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards + ] + flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1) + return [flat_sharded_page_table] @property def pad_sequence_stride(self) -> int: return self.block_seq_stride - def allocate(self, page_count: int) -> list[torch.Tensor]: + def allocate( + self, page_count: int + ) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]: """Allocates tensor state for a page table for the given capacity in pages. """ - return [ + shards = [ torch.empty( [page_count, self.page_slab_flat_dim], dtype=self.dtype, device=self.device, ) + for _ in range(self.shard_count) ] + if self.shard_count == 1: + return shards + + return [SplitPrimitiveTensor(ts=shards, shard_dim=1)] + def read( self, - state: list[torch.Tensor], + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, - read_into_partitions: list[torch.Tensor], + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, - page_ids: torch.Tensor, + seq_len: int, + page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Reads cache partitions from the page table for the given page_ids. @@ -231,7 +380,7 @@ def read( bs, block_seq_len, self.block_seq_stride, - self.attn_head_count, + self.attn_head_count // self.shard_count, self.attn_head_dim, ] @@ -249,7 +398,9 @@ def read( transformer_block_index * transformer_block_stride ) - def read_cache_partition(index: int, into_partition: torch.Tensor): + def read_cache_partition( + index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor] + ): subblock_ids = ( (base_subblock_ids + index) if index > 0 else base_subblock_ids ) @@ -262,7 +413,7 @@ def read_cache_partition(index: int, into_partition: torch.Tensor): # a linear list. # TODO: Can be rewritten into inplace with out= on index_select. selected = ( - torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1)) + ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1)) .unflatten(0, blocked_shape[0:2]) .flatten(1, 2) ) @@ -272,17 +423,19 @@ def read_cache_partition(index: int, into_partition: torch.Tensor): for index, read_into_partition in enumerate(read_into_partitions): read_cache_partition(index, read_into_partition) + return tuple([p[:, :seq_len, :] for p in read_into_partitions]) + def write_timestep( self, - state: list[torch.Tensor], + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], # List of [bs, 1, attn_head_count, attn_head_dim] - cache_partitions: list[torch.Tensor], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, transformer_block_index: int, # [bs] - seq_positions: torch.Tensor, + seq_positions: Union[torch.Tensor, ReplicatedTensor], # [bs, max_seqlen // block_pos_stride] - page_ids: torch.Tensor, + page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Writes a single batched timestep across all cache partitions. 
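
The rewritten paged write_timestep below replaces the per-sequence Python loop with one batched index_put_ into the page table, indexed as [page, transformer_block, partition, offset]. A hedged sketch of how those index tensors are assembled:

import torch

def timestep_indices(seq_positions, page_ids, block_seq_stride, block_index, partition_count):
    # seq_positions: [bs] absolute position being written this step
    # page_ids:      [bs, max_seqlen // block_seq_stride] page-table rows per sequence
    bs = seq_positions.shape[0]
    page_index = seq_positions // block_seq_stride                  # logical block per row
    page_id = torch.gather(page_ids, 1, page_index.unsqueeze(1))    # [bs, 1] physical page
    page_offset = (seq_positions % block_seq_stride).unsqueeze(1)   # [bs, 1] slot in page

    page_id = page_id.repeat(1, partition_count)                    # one column per partition (k, v)
    page_offset = page_offset.repeat(1, partition_count)
    partitions = torch.arange(partition_count).unsqueeze(0).repeat(bs, 1)
    block = torch.full((bs, partition_count), block_index)
    return page_id, block, partitions, page_offset

The cache partitions themselves are concatenated along dim=1 so a single index_put_ writes both K and V.
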
@@ -293,29 +446,41 @@ def write_timestep( page_table = self.unflatten_page_table(state) # 6D bs, *_ = seq_positions.shape assert len(cache_partitions) == self.cache_partition_count - for i in range(bs): - position = seq_positions[i] - # TODO: Let's clamp to the allowable range so that we don't need - # an assert. - page_id = page_ids[i, :].index_select(0, position // self.block_seq_stride) - page_offset = position % self.block_seq_stride - for partition_index in range(self.cache_partition_count): - cache_partition = cache_partitions[partition_index] - indices = ( - page_id, - torch.tensor([transformer_block_index], device=device), - torch.tensor([partition_index], device=device), - page_offset.unsqueeze(0), - ) - page_table.index_put_(indices=indices, values=cache_partition[i, 0]) + + partition_count = len(cache_partitions) + + # [bs, partitions, atten_head_count, attn_head_dim] + cache_partitions = ops.cat(cache_partitions, dim=1) + + # [bs, 1] + page_index = seq_positions // self.block_seq_stride + + page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1)) + page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1) + + # [1, partitions] + partitions = torch.arange(0, self.cache_partition_count).unsqueeze(0) + + # [bs, partitions] + page_id = page_id.repeat(1, partition_count) + transformer_block = torch.full( + (bs, partition_count), transformer_block_index, device=device + ) + page_offset = page_offset.repeat(1, partition_count) + partitions = partitions.repeat(bs, 1) + + indices = (page_id, transformer_block, partitions, page_offset) + page_table.index_put_(indices=indices, values=cache_partitions) + + return def write( self, - state: list[torch.Tensor], - cache_partitions: list[torch.Tensor], + state: list[Union[torch.Tensor, SplitPrimitiveTensor]], + cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, transformer_block_index: int, - page_ids: torch.Tensor, + page_ids: Union[torch.Tensor, ReplicatedTensor], ): """Writes cache partitions from a linear layout to the page table. @@ -348,21 +513,21 @@ def write( transformer_block_index * transformer_block_stride ) - def write_cache_partition(index: int, part: torch.Tensor): - part_block_view = part.reshape(blocked_shape) + part_block_views = [] + subblock_ids_kv = [] + for index, partition in enumerate(cache_partitions): + part_block_view = partition.unflatten( + 1, (block_seq_len, self.block_seq_stride) + ) + part_block_view = part_block_view.flatten(0, 1) + part_block_views.append(part_block_view) + subblock_ids = ( (base_subblock_ids + index) if index > 0 else base_subblock_ids - ) - # TODO: Potentially clamp all page 0 indices to the mask value. - # Or even better, require that the ids are replicated such that access is - # legal. - # Now for each of the k/v attn_block_ids, which have been adjusted to - # index into the sub-pages, we flatten to do a linear index_select - # copy of the sub-blocks by collapsing the first two dims so we have - # a linear list. 
- subblock_table.index_copy_( - 0, subblock_ids.flatten(0, 1), part_block_view.flatten(0, 1) - ) + ).flatten(0, 1) + subblock_ids_kv.append(subblock_ids) - for index, partition in enumerate(cache_partitions): - write_cache_partition(index, partition) + subblock_ids = ops.cat(subblock_ids_kv) + part_block_view = ops.cat(part_block_views, dim=0) + + subblock_table.index_copy_(0, subblock_ids, part_block_view) diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index 86e43d715..b679dccde 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -7,16 +7,15 @@ from typing import Optional import torch - from .. import ops from .base import Theta, ThetaLayer -from ..types.layout_utils import saturate_cast from ..types import ( DynamicScaledQuantizer, QuantizedTensor, QuantizerTensor, StaticScaledQuantizer, TensorScaledLayout, + PlanarQuantizedTensor, ) __all__ = [ @@ -31,6 +30,10 @@ class LinearLayer(ThetaLayer): if premul_input is not None: x = x * premul_input matmul(x, weight.T) + bias + + fake_quant exists to allow export without adding dequant ops. + when fake_quant is True, the op will in quant dequant fashion. + When false, it will keep quantized types. ``` """ @@ -40,11 +43,13 @@ def __init__( *, weight_name: str = "weight", bias_name: str = "bias", + fake_quant: bool = True, ): super().__init__(theta) self._simulate_native_quant = True self.weight = self.theta_tensor(weight_name) self.bias = None + self.fake_quant = fake_quant if bias_name in self.theta.keys: self.bias = self.theta_tensor(bias_name) @@ -54,26 +59,36 @@ def __init__( self.qdq_input: Optional[QuantizedTensor] = theta.optional_tensor("qdq_input") if self.q_input is not None and self.qdq_input is not None: raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input") + self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output") def forward(self, x): weight = self.weight bias = self.bias q_input = self.q_input qdq_input = self.qdq_input - + qdq_output = self.qdq_output if self.premul_input is not None: x = ops.elementwise(torch.mul, x, self.premul_input) if q_input is not None: x = q_input.quantize(x) - elif qdq_input is not None: + if self.fake_quant: + x = x.unpack().dequant() + elif qdq_input is not None and self.fake_quant: x = qdq_input.quantize(x).unpack().dequant() y = ops.linear(x, weight, bias) # Unconditionally dequantize. - # TODO: Support a q_output specifier that signals the layer to let - # the QuantizedTensor escape. - if isinstance(y, QuantizedTensor): + if isinstance(y, QuantizedTensor) and not self.fake_quant: y = y.unpack().dequant() + # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. + # We can truncate to fp16 in iree, so we do a cast here + # to account for this in the IR. This is may not be the right + # level to do this, but for now its here. 
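
To summarize the fake_quant switch in LinearLayer.forward above, a hedged sketch of the input-side behavior; names follow the layer's theta tensors, and this is not the layer's literal code path:

def quantize_linear_input(x, q_input, fake_quant: bool):
    # q_input is the optional quantizer tensor from theta ("q_input").
    if q_input is None:
        return x
    qx = q_input.quantize(x)
    # fake_quant: immediately dequantize so downstream ops stay in float while the
    # quantization parameters are still exercised (quant-dequant, export friendly).
    # Real quant: keep the quantized tensor and let ops.linear consume it directly.
    return qx.unpack().dequant() if fake_quant else qx
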
+ if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz: + y = ops.to(y, torch.float16) + return y + if qdq_output is not None and self.fake_quant: + y = qdq_output.quantize(y).unpack().dequant() return y diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index 7be8c7366..0cdb5d713 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -6,8 +6,6 @@ from typing import Optional -import math - import torch import torch.nn.functional as F @@ -110,7 +108,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + self.head_dim + ) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index 09bf491b7..ddce16c55 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -13,14 +13,14 @@ from .base import Theta, ThetaLayer from .linear import LinearLayer from .norm import RMSNormLayer -from .ffn_moe_block import FFNMOE +from .ffn_moe_block import FFNMOE, PreGatherFFNMOE __all__ = [ - "SparseMoeBlock", + "MoeBlock", ] -class SparseMoeBlock(ThetaLayer): +class MoeBlock(ThetaLayer): """ This implementation considers MoE operations as block-sparse operations to support imbalanced token assignments to experts. @@ -34,9 +34,13 @@ def __init__( expert_count: int, expert_used_count: int, rms_epsilon: float, + moe_activation=F.silu, ): super().__init__(theta) + self.expert_count = expert_count + self.expert_used_count = expert_used_count + # Add router gate self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) @@ -45,13 +49,17 @@ def __init__( "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) ) - # Add expert_count x FFN - self.experts = nn.ModuleList( - [FFNMOE(theta, expert_idx=i) for i in range(expert_count)] - ) + # Add optional FFN output norm layer + if theta.optional_tensor("layer_output_norm") is not None: + self.add_module( + "layer_output_norm", + RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), + ) + else: + self.add_module("layer_output_norm", torch.nn.Identity()) - self.expert_count = expert_count - self.expert_used_count = expert_used_count + # Add expert_count x FFN + self.experts = PreGatherFFNMOE(theta, activation=moe_activation) def forward( self, @@ -67,42 +75,16 @@ def forward( router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # Select top k experts from router weights - router_weights, top_k_experts = torch.topk( + expert_gate, top_k_experts = torch.topk( router_weights, self.expert_used_count, dim=-1 ) - router_weights /= router_weights.sum(dim=-1, keepdim=True) - router_weights = router_weights.to(ffn_input.dtype) - moe_output = torch.zeros( - (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype - ) - - # Create an expert mask by one hot encoding the selected top k experts - # used to index which expert is to be invoked for each token - # expert_mask: (expert_count, expert_used_count, sequence_length) - expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute( - 2, 1, 0 - ) - - # Iterate over all experts in the model - for expert_idx in 
range(self.expert_count): - expert_layer = self.experts[expert_idx] - top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx]) - - # Given the hidden states, index the tokens assigned to this expert - # and calculate the current expert's hidden state and weigh the - # output expert hidden states by the router weights, based on the - # appropriate tokens - current_expert_tokens = ffn_input[None, token_idx] + expert_gate /= expert_gate.sum(dim=-1, keepdim=True) + expert_gate = expert_gate.to(ffn_input.dtype) - current_expert = ( - expert_layer(current_expert_tokens) - * router_weights[token_idx, top_k_expert_idx, None] - ) - - current_expert = current_expert.reshape(-1, feature_dim) - - moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) + moe_output = self.experts(ffn_input, top_k_experts, expert_gate) moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) + moe_output = self.layer_output_norm(moe_output) + return h + moe_output diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py index d062f1ffb..4fa08050a 100644 --- a/sharktank/sharktank/layers/norm.py +++ b/sharktank/sharktank/layers/norm.py @@ -33,9 +33,9 @@ def __init__( def forward(self, x: torch.Tensor): orig_dtype = x.dtype - x = x.to(self.dtype) + x = ops.to(x, self.dtype) norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon) # Will automatically upcast to the dtype of the weight, which is # often in higher precision. Downcast back to expected. - norm = norm.to(orig_dtype) + norm = ops.to(norm, orig_dtype) return norm diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 41f3a20cc..22647bf49 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -10,12 +10,13 @@ import torch import torch.nn.functional as F - +from ..types import QuantizerTensor from .base import Theta, ThetaLayer from .linear import LinearLayer from .norm import RMSNormLayer from .rotary_embedding import RotaryEmbeddingLayer from .kv_cache import PagedKVCache +from .. 
import ops __all__ = [ "PagedLlamaAttentionBlock", @@ -36,24 +37,54 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_hf: bool = False, - + attention_kernel: str = "decomposed", + attention_scale: Optional[float] = None, + softcap: Optional[float] = None, + fake_quant: Optional[bool] = True, ): super().__init__(theta) - self.add_module( - "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) - ) - self.add_module("attn_q", LinearLayer(theta("attn_q"))) - self.add_module("attn_k", LinearLayer(theta("attn_k"))) - self.add_module("attn_v", LinearLayer(theta("attn_v"))) - self.add_module("attn_output", LinearLayer(theta("attn_output"))) self.block_index = block_index self.cache = cache self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv - self.use_hf = use_hf + self.attention_kernel = attention_kernel + self.attention_scale = attention_scale + self.softcap = softcap + self.fake_quant = fake_quant + + self.add_module( + "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) + ) + self.add_module( + "attn_q", LinearLayer(theta("attn_q"), fake_quant=self.fake_quant) + ) + self.add_module( + "attn_k", LinearLayer(theta("attn_k"), fake_quant=self.fake_quant) + ) + self.add_module( + "attn_v", LinearLayer(theta("attn_v"), fake_quant=self.fake_quant) + ) + self.add_module( + "attn_output", LinearLayer(theta("attn_output"), fake_quant=self.fake_quant) + ) + self.cache_quantizer = None + if "kv_cache" in theta.keys: + self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor( + "kv_cache.quantizer" + ) + + if theta.optional_tensor("attn_output_norm") is None: + self.add_module( + "attn_output_norm", + torch.nn.Identity(), + ) + else: + self.add_module( + "attn_output_norm", + RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon), + ) def forward( self, @@ -88,36 +119,48 @@ def forward( # Fast path to start_index based embedding lookup if available. # Falls back to a slower position based index lookup. if start_index is not None: - xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) + xq = embedding.forward(xt=xq, start_index=start_index) + xk = embedding.forward(xt=xk, start_index=start_index) else: - xq, xk = embedding.apply_batched_mask( - xq=xq, xk=xk, mask=embedding_batch_mask - ) + xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask) + xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask) # Full sequence length. 
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride - if self.cache.is_paged: - xk, xv = self.transact_cache_paged( - xk_cache_update=xk, - xv_cache_update=xv, - seq_block_ids=seq_block_ids, - kv_seq_len=kv_seq_len, - start_positions=start_positions, - cache_state=cache_state, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - elif self.cache.is_direct: - xk, xv = self.transact_cache_direct( - xk_cache_update=xk, - xv_cache_update=xv, - start_positions=start_positions, - kv_seq_len=kv_seq_len, - cache_state=cache_state, - ) - else: - raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}") + # Used by fp8_e4m3fnuz model + if self.cache_quantizer is not None: + # For fake quant, store the fp16 qdq value in the cache + if self.fake_quant: + xk = ( + self.cache_quantizer.quantize(xk) + .unpack() + .dequant() + .to(torch.float16) + ) + xv = ( + self.cache_quantizer.quantize(xv) + .unpack() + .dequant() + .to(torch.float16) + ) + # For real quant, store the quantized fp8 value in the cache + else: + # TODO: this seems like a bastardization of our quantized tensor api + # Probably want to add support for using quantized tensors more directly + xk = self.cache_quantizer.quantize(xk).unpack().qs + xv = self.cache_quantizer.quantize(xv).unpack().qs + + xk, xv = self.transact_cache( + xk_cache_update=xk, + xv_cache_update=xv, + seq_block_ids=seq_block_ids, + kv_seq_len=kv_seq_len, + start_positions=start_positions, + cache_state=cache_state, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) # Expand kv heads for GQA. gqa_n_rep = self.head_count // self.head_count_kv @@ -126,87 +169,88 @@ def forward( def repeat_kv(x: torch.Tensor) -> torch.Tensor: bs, slen, n_kv_heads, head_dim = x.shape - return ( - x.unsqueeze(-2) - .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim) - ) + unsq = x.unsqueeze(-2) + exp = ops.expand(unsq, (bs, slen, n_kv_heads, gqa_n_rep, head_dim)) + return exp.flatten(2, 3) xk = repeat_kv(xk) xv = repeat_kv(xv) + # Fake quant is already dequantized when stored in the cache. + if self.cache_quantizer and not self.fake_quant: + xk = self.cache_quantizer.dequantize_raw_tensor( + xk, torch.float16, name="xk_deq" + ) + xv = self.cache_quantizer.dequantize_raw_tensor( + xv, torch.float16, name="xv_deq" + ) # Transpose into [bs, heads, sl, dim] xq = xq.transpose(1, 2) keys = xk.transpose(1, 2) values = xv.transpose(1, 2) - # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - self.assert_not_nan(attn_weights) + if self.attention_kernel == "decomposed": + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) + if self.attention_scale is None: + attn_weights = attn_weights / math.sqrt(self.head_dim) + else: + attn_weights = attn_weights * self.attention_scale - # Apply attention mask. - self.trace_tensor("attn_weights", attn_weights, values=False) - if attention_mask is not None: - # self.trace_tensor("attn_mask", attention_mask) - attn_weights = attn_weights + attention_mask + # Flash attention. + if self.softcap is not None: + attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap) - attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq) - attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim) - attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1) + self.assert_not_nan(attn_weights) + + # Apply attention mask. 
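
For reference, the GQA head expansion above now uses unsqueeze + expand + flatten instead of a reshape; a plain-PyTorch sketch of the same transformation:

import torch

def repeat_kv(x: torch.Tensor, gqa_n_rep: int) -> torch.Tensor:
    # x: [bs, sl, n_kv_heads, head_dim] -> [bs, sl, n_kv_heads * gqa_n_rep, head_dim]
    if gqa_n_rep == 1:
        return x
    bs, sl, n_kv_heads, head_dim = x.shape
    x = x.unsqueeze(-2).expand(bs, sl, n_kv_heads, gqa_n_rep, head_dim)
    return x.flatten(2, 3)
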
+ self.trace_tensor("attn_weights", attn_weights, values=False) + if attention_mask is not None: + # self.trace_tensor("attn_mask", attention_mask) + attn_weights = attn_weights + attention_mask + + attn_weights = ops.softmax( + ops.to(attn_weights, dtype=torch.float32), dim=-1 + ) + attn_weights = ops.to(attn_weights, dtype=xq.dtype) + attn_output = ops.matmul( + attn_weights, values + ) # (bs, heads, slen, head_dim) + else: + is_causal = True + attention_mask = None + attn_output = ops.scaled_dot_product_attention( + q=xq, # [bs, ..., sl, dim] + k=keys, # [bs, ..., sl, dim] + v=values, # [bs, ..., sl, dim] + a=attention_mask, # [bs, ..., sl, sl] + is_causal=is_causal, # assumes causal masking when true + scale=None, # defaults to 1/sqrt(dim) + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(2, 3) # Project. attn_output = self.attn_output(attn_output) + attn_output = self.attn_output_norm(attn_output) - # Remainder of the block. h = h + attn_output - return h - def transact_cache_direct( - self, - *, - cache_state: list[torch.Tensor], - xk_cache_update: torch.Tensor, - xv_cache_update: torch.Tensor, - kv_seq_len: int, - start_positions: Optional[torch.Tensor] = None, - ): - bs, batch_seq_len, _, _ = xk_cache_update.shape - cache_k = cache_state[self.block_index * 2] - cache_v = cache_state[self.block_index * 2 + 1] - - if start_positions is None: - # Prefill. Write the entire cache. - cache_k[:, :batch_seq_len] = xk_cache_update - cache_v[:, :batch_seq_len] = xv_cache_update - return xk_cache_update, xv_cache_update - else: - # Decode. Write a single timestep. - # TODO: This needs to be reworked with index ops. - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - max_start_pos = 0 - for row_index in range(bs): - row_start_pos = start_positions[row_index].item() - max_start_pos = max(row_start_pos, max_start_pos) - cache_k[row_index, row_start_pos] = xk_cache_update[row_index, 0] - cache_v[row_index, row_start_pos] = xv_cache_update[row_index, 0] - return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len] - - def transact_cache_paged( + def transact_cache( self, *, xk_cache_update: torch.Tensor, xv_cache_update: torch.Tensor, cache_state: list[torch.Tensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, + seq_block_ids: Optional[torch.Tensor], kv_seq_len: int, start_positions: Optional[torch.Tensor] = None, xk_temp: Optional[torch.Tensor] = None, xv_temp: Optional[torch.Tensor] = None, ): - cache = self.cache.paged + cache = self.cache # Manage the cache. if start_positions is None: # Prefill: Write the entire cache. @@ -217,46 +261,45 @@ def transact_cache_paged( page_ids=seq_block_ids, ) return xk_cache_update, xv_cache_update - else: - # Decode at ragged start positions. - # We need to initialize/read the K/V from the cache for the whole - # sequence. Note that at this point, it is possible to fork and - # use a memory efficient attention kernel that can do indirect - # reads, skipping this materialization. This path is taken for - # a decode step. - assert xk_temp is not None and xv_temp is not None - assert xk_cache_update.shape[1] == 1 - assert xv_cache_update.shape[1] == 1 - assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride - - # Write our one updated cache row into the cache. 
- cache.write_timestep( - cache_state, - cache_partitions=[ - xk_cache_update, - xv_cache_update, - ], - transformer_block_index=self.block_index, - seq_positions=start_positions + 1, - page_ids=seq_block_ids, - ) - # Restore from the cache. - cache.read( - cache_state, - read_into_partitions=[ - xk_temp[:, 0:kv_seq_len, ...], - xv_temp[:, 0:kv_seq_len, ...], - ], - transformer_block_index=self.block_index, - page_ids=seq_block_ids, - ) + # Decode at ragged start positions. + # We need to initialize/read the K/V from the cache for the whole + # sequence. Note that at this point, it is possible to fork and + # use a memory efficient attention kernel that can do indirect + # reads, skipping this materialization. This path is taken for + # a decode step. + assert xk_temp is not None and xv_temp is not None + assert xk_cache_update.shape[1] == 1 + assert xv_cache_update.shape[1] == 1 + assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride + + # Write our one updated cache row into the cache. + cache.write_timestep( + cache_state, + cache_partitions=[ + xk_cache_update, + xv_cache_update, + ], + transformer_block_index=self.block_index, + seq_positions=start_positions, + page_ids=seq_block_ids, + ) + + # Restore from the cache. + xk, xv = cache.read( + cache_state, + read_into_partitions=[ + xk_temp[:, 0:kv_seq_len, ...], + xv_temp[:, 0:kv_seq_len, ...], + ], + transformer_block_index=self.block_index, + page_ids=seq_block_ids, + seq_len=kv_seq_len, + ) - # For computation, we create a subview of the xk/xv tensors to have - # a sequence length covering the blocked size. This must include - # the newly added row (the caller is responsible for ensuring that - # every block has at least one row left). We'll compute on this - # ragged view and use an appropriate mask. - xk = xk_temp[:, 0:kv_seq_len, ...] - xv = xv_temp[:, 0:kv_seq_len, ...] - return xk, xv + # For computation, we create a subview of the xk/xv tensors to have + # a sequence length covering the blocked size. This must include + # the newly added row (the caller is responsible for ensuring that + # every block has at least one row left). We'll compute on this + # ragged view and use an appropriate mask. + return xk, xv diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index eadd9e6b0..0664a9a46 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -4,11 +4,14 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional +from collections import namedtuple +from typing import Optional, Union import torch from .base import BaseLayer +from .. import ops +from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor class RotaryEmbeddingLayer(BaseLayer): @@ -19,35 +22,80 @@ def __init__( *, rope_dimension_count: int, max_seqlen: int, - rope_freq_base: Optional[float] = 10000.0, + rope_freq_base: Optional[float], device: Optional[torch.device] = None, use_hf: bool = False, - static_tables: bool = True, + static_tables: bool = False, + use_table: bool = True, + tensor_parallelism_size: int = 1, ): super().__init__() - # Force static_tables until compiler limitations are solved. 
- # See https://github.com/nod-ai/sharktank/issues/156 - static_tables = True self.device = device self.rope_dimension_count = rope_dimension_count self.max_seqlen = max_seqlen self.use_hf = use_hf - self.rope_freq_base = rope_freq_base + self.static_tables = static_tables + self.use_table = use_table + + self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 + self.tensor_parallelism_size = tensor_parallelism_size if static_tables: - self.register_buffer( - "static_rotary_embed_table", self._create_rotary_embed_table() + ops.module_register_buffer( + self, "static_rotary_embed_table", self._create_rotary_embed_table() ) else: self.static_rotary_embed_table = None @property def rotary_embed_table(self): - if self.static_rotary_embed_table is None: + if self.use_table: + if self.static_tables: + return self.static_rotary_embed_table return self._create_rotary_embed_table() + + return None + + def forward( + self, + *, + xt: Union[torch.Tensor, SplitPrimitiveTensor], + start_index: int, + ): + if isinstance(xt, SplitPrimitiveTensor): + rotary_shards = [None] * xt.shard_count + if self.rotary_embed_table is not None: + assert ( + isinstance(self.rotary_embed_table, ReplicatedTensor) + and xt.shard_count == self.rotary_embed_table.shard_count + ) + rotary_shards = [ + unbox_tensor(shard) for shard in self.rotary_embed_table.shards + ] + + xt_shards = [ + self.forward_unsharded( + xt=unbox_tensor(xt_shard), + start_index=start_index, + rotary_embed_table=rotary_shard, + ) + for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) + ] + xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + return xt else: - return self.static_rotary_embed_table + return self.forward_unsharded( + xt=xt, + start_index=start_index, + rotary_embed_table=self.rotary_embed_table, + ) - def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int): + def forward_unsharded( + self, + *, + xt: torch.Tensor, + start_index: int, + rotary_embed_table: Optional[torch.Tensor], + ): # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim @@ -89,57 +137,33 @@ def create_ordering_tensor(dim): return order_tensor if self.use_hf: - xq = xq[..., create_interleaved_tensor(xq.shape[-1])] - xk = xk[..., create_interleaved_tensor(xq.shape[-1])] - - xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) - _, sl, _, dim = xq_.shape + xt = xt[..., create_interleaved_tensor(xt.shape[-1])] + xt_ = xt + _, sl, _, _ = xt_.shape # Offset the table based on starting position. 
- freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :] - assert freqs_cis.shape[-1] == dim + if self.use_table: + freqs_cis = rotary_embed_table[start_index : start_index + sl, :] + freqs_cis = freqs_cis[None, 0:sl, None, :] + else: + freqs_cis = torch.arange(sl, device=xt.device) + start_index + freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :] + assert ( - freqs_cis.shape[0] >= sl + freqs_cis.shape[1] >= sl ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})" - broadcast_freqs_cis = freqs_cis[None, 0:sl, None, :] + xt_ = ops.view_as_complex(xt_) + xt_ = xt_ * freqs_cis + xt_out = ops.view_as_real(xt_) if self.use_hf: - xq_out = torch.view_as_real( - self.complex_multiply(xq_, broadcast_freqs_cis) - ).flatten(3) - xk_out = torch.view_as_real( - self.complex_multiply(xk_, broadcast_freqs_cis) - ).flatten(3) - - xq_out = xq_out[..., create_ordering_tensor(xq_out.shape[-1])] - xk_out = xk_out[..., create_ordering_tensor(xq_out.shape[-1])] - - return xq_out.type_as(xq), xk_out.type_as(xk) - - xq_out = torch.view_as_real(xq_ * broadcast_freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * broadcast_freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])] - def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Function for elementwise-multiplication of two complex torch tensors. - Functionally similar to a*b, but numerically accurate for HuggingFace - LLaMa implementation. - - Args: - a: First torch tensor operand - b: Second torch tensor operand - Returns: - Tensor of same size to a, b whose elements is product of corresponding - elements in a, b - """ - return torch.complex( - a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real - ) + return ops.to(xt_out, xt.dtype) def compute_batch_mask( - self, start_positions: torch.Tensor, batch_seq_len: int + self, start_positions: Union[torch.Tensor, ReplicatedTensor], batch_seq_len: int ) -> torch.Tensor: """Computes a mask for a batch that can be repeatedly applied. @@ -156,15 +180,46 @@ def compute_batch_mask( ) + start_positions.unsqueeze(1) # Broadcast lookup to [b, ...]. self.trace_tensor("rope.positions_seq", positions_seq) - freqs_cis = self.rotary_embed_table[positions_seq] + + if self.use_table: + freqs_cis = self.rotary_embed_table[positions_seq] + else: + shape = positions_seq.shape + if isinstance(positions_seq, ReplicatedTensor): + ts = [ + self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape) + for s in positions_seq.shards + ] + freqs_cis = ReplicatedTensor(ts=ts) + else: + freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) + freqs_cis = freqs_cis.unflatten(0, shape) # Unsqueeze a unit dim for attention heads. 
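# A small sketch of the rotation that forward_unsharded and the batched-mask path
# above both build on: consecutive feature pairs are viewed as complex numbers and
# multiplied by unit-magnitude frequencies, rotating each pair by a position- and
# frequency-dependent angle. Plain torch stands in for the sharktank ops and the
# sizes are arbitrary.
import torch

sl, heads, dim = 4, 1, 8
positions = torch.arange(sl)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(positions.float(), inv_freq)         # (sl, dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit-norm complex64

xt = torch.randn(1, sl, heads, dim)
xt_c = torch.view_as_complex(xt.reshape(1, sl, heads, dim // 2, 2))
rotated = torch.view_as_real(xt_c * freqs_cis[None, :, None, :]).flatten(3)

# Rotation by unit-magnitude factors preserves the norm of the features.
assert torch.allclose(rotated.norm(dim=-1), xt.norm(dim=-1), atol=1e-5)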
broadcast_freqs_cis = freqs_cis.unsqueeze(2) return broadcast_freqs_cis def apply_batched_mask( - self, *, xq: torch.Tensor, xk: torch.Tensor, mask: torch.Tensor + self, + *, + xt: Union[torch.Tensor, SplitPrimitiveTensor], + mask: Union[torch.Tensor, ReplicatedTensor], ): + if not isinstance(xt, SplitPrimitiveTensor): + return self.apply_batched_mask_unsharded(xt=xt, mask=mask) + + assert isinstance(mask, ReplicatedTensor) and mask.shard_count == xt.shard_count + xt_shards = [ + self.apply_batched_mask_unsharded( + xt=unbox_tensor(xt_shard), + mask=unbox_tensor(mask_shard), + ) + for xt_shard, mask_shard in zip(xt.shards, mask.shards) + ] + xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + return xt + + def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): """Applies the embedding to a ragged batch of queries and keys. This does a more complicated indexing operation for cases when the each @@ -174,30 +229,32 @@ def apply_batched_mask( """ # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim - xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) - _, sl, _, dim = xq_.shape + xt_ = ops.view_as_complex(xt) + xt_ = xt_ * mask + xt_out = ops.view_as_real(xt_) - xq_out = torch.view_as_real(xq_ * mask).flatten(3) - xk_out = torch.view_as_real(xk_ * mask).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + return xt_out.type_as(xt) - def _create_rotary_embed_table( - self, - ): + def _compute_rotary_embed_table(self, t): dim = self.rope_dimension_count - max_seqlen = self.max_seqlen freqs = 1.0 / ( - self.rope_freq_base - ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim) + self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) ) - t = torch.arange(max_seqlen, device=freqs.device) freqs = torch.outer(t, freqs).float() - freqs_cis = ( - torch.complex(torch.cos(freqs), torch.sin(freqs)) - if self.use_hf - else torch.polar(torch.ones_like(freqs), freqs) - ) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + complex = torch.complex(cos, sin) + return complex + + def _create_rotary_embed_table(self): + t = torch.arange(self.max_seqlen, device=self.device) + freqs_cis = self._compute_rotary_embed_table(t) + return self._replicate(freqs_cis) + + def _replicate(self, t): + if self.tensor_parallelism_size > 1: + # Replicate across all devices, the data is not a lot and the computation is cheap. + t = ops.replicate(t, self.tensor_parallelism_size) - return freqs_cis + return t diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py new file mode 100644 index 000000000..fb330aadd --- /dev/null +++ b/sharktank/sharktank/layers/testing.py @@ -0,0 +1,45 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from ..types.theta import Theta +from ..types.tensors import DefaultPrimitiveTensor +from ..utils.testing import make_rand_torch + + +def make_llama_attention_block_theta( + *, + head_count: int, + head_count_kv: int, + head_dim: int, + embedding_length: int, + dtype: torch.dtype | None = None, +) -> Theta: + return Theta( + { + "attn_q.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (head_count * head_dim, embedding_length), dtype=dtype + ) + ), + "attn_k.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (head_count_kv * head_dim, embedding_length), dtype=dtype + ) + ), + "attn_v.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (head_count_kv * head_dim, embedding_length), dtype=dtype + ) + ), + "attn_output.weight": DefaultPrimitiveTensor( + data=make_rand_torch((embedding_length, embedding_length), dtype=dtype) + ), + "attn_norm.weight": DefaultPrimitiveTensor( + data=make_rand_torch((embedding_length), dtype=dtype) + ), + } + ) diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py new file mode 100644 index 000000000..077e4e064 --- /dev/null +++ b/sharktank/sharktank/models/grok/grok.py @@ -0,0 +1,234 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from ...layers import * +from ...utils.create_cache import * +from ...types import Theta + +torch.set_printoptions(profile="full") + +__all__ = [ + "PagedGrokModelV1", +] + +################################################################################ +# Models +################################################################################ + + +class PagedGrokModelV1(BaseCausalLMModel): + """Grok model with a paged KV cache and supporting variable sequence + length batched inference. + + As both the caching and batching setup is complicated, this model variant + is modular, intending to be instantiated and used in an overall assembly + vs trying to provide one-stop methods that do everything. + + The inference procedure is typically: + + 1. Initialize the PagedKVCache state tensors. + 2. Generate an input mask given a vector of sequence lengths. + 3. Generate an attention mask from the input mask. + 4. Allocate a block mapping table. + 5. Invoke prefill() with a batch of sequences. + 6. Extract tokens from batched logits. + 7. Iteratively invoke decode() for as long as there are sequences needing + to be serviced. + + Various samplers and schedulers can be interleaved throughout.
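# A quick usage sketch of the make_llama_attention_block_theta helper defined above,
# before moving on to the Grok model proper. The helper and the Theta accessors are
# taken from this change; only the concrete sizes below are made up.
import torch

theta = make_llama_attention_block_theta(
    head_count=8, head_count_kv=2, head_dim=16, embedding_length=128, dtype=torch.float32
)
q_weight = theta("attn_q").tensor("weight").as_torch()
assert q_weight.shape == (8 * 16, 128)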
+ """ + + def __init__(self, theta: Theta, config: LlamaModelConfig): + hp = config.hp + super().__init__( + theta, + context_length=config.hp.context_length, + device=config.device, + activation_dtype=config.activation_dtype, + attention_dtype=config.attention_dtype, + ) + self.config = config + self.hp = hp + self.cache = create_kv_cache(self.config) + self.activation_dtype = config.activation_dtype + self.add_module( + "token_embedding", + TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + ) + self.add_module( + "attention_embedding", + RotaryEmbeddingLayer( + rope_dimension_count=hp.rope_dimension_count, + rope_freq_base=hp.rope_freq_base, + max_seqlen=hp.context_length, + device=self.device, + use_hf=True, + ), + ) + self.add_module( + "output_norm", + RMSNormLayer( + theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon + ), + ) + self.add_module("output_lm_head", LinearLayer(theta("output"))) + + self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() + + for n in range(hp.block_count): + self.attn_blocks.append( + PagedLlamaAttentionBlock( + theta("blk", n), + block_index=n, + cache=self.cache, + head_count=hp.attention_head_count, + head_dim=hp.attn_head_dim, + head_count_kv=hp.attention_head_count_kv, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864 + ) + ) + self.moe_blocks.append( + MoeBlock( + theta("blk", n), + expert_count=hp.expert_count, + expert_used_count=hp.expert_used_count, + rms_epsilon=hp.attention_layer_norm_rms_epsilon, + moe_activation=F.gelu, + ) + ) + + def prefill( + self, + # [bs, batch_seq_len] + tokens: torch.Tensor, + *, + # [1, 1, batch_seq_len, batch_seq_len] + attention_mask: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self._assert_device(tokens) + self._assert_device(attention_mask, dtype=self.activation_dtype) + self._assert_device(seq_block_ids) + self._assert_device(*cache_state, dtype=self.activation_dtype) + h = self.token_embedding(tokens) + h *= math.sqrt(h.shape[-1]) + self.trace_tensor("grok.token_embedding", h) + + # Iterate over attention blocks. + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): + if block_idx == 0: + self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) + + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + logits = logits / math.sqrt(3.0) + return logits + + def decode( + self, + # [bs, 1] + tokens: torch.Tensor, + *, + # [bs, 1, 1, batch_seq_len] + attention_mask: torch.Tensor, + # [bs] of starting positions + start_positions: torch.Tensor, + # [bs, batch_seq_len // block_seq_stride] + seq_block_ids: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self._assert_device(tokens) + self._assert_device(attention_mask, dtype=self.activation_dtype) + self._assert_device(start_positions) + self._assert_device(*cache_state, dtype=self.activation_dtype) + bs, _ = tokens.shape + # Precompute a position based mask for computing rope embeddings + # as it is the same for all blocks. 
+ embedding_batch_mask = self.attention_embedding.compute_batch_mask( + start_positions, batch_seq_len=1 + ) + self.trace_tensor("grok.embedding_batch_mask", embedding_batch_mask) + + # Allocate per-block temporary K/V tensors. These temporaries hold + # one block's K/V state for the maximum context length. + xk_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + xv_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + + h = self.token_embedding(tokens) + h *= math.sqrt(h.shape[-1]) + self.trace_tensor("grok.token_embedding", h) + + # Iterate over attention blocks. + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): + if block_idx == 0: + self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) + + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + + h = self.output_norm(h) + logits = self.output_lm_head(h) + logits = logits / math.sqrt(3.0) + return logits diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 8266872ad..0a9a6f1c3 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -7,7 +7,7 @@ from typing import Optional from dataclasses import dataclass -import math +from typing import Union import torch import torch.nn as nn @@ -15,77 +15,13 @@ from ...layers import * from ...types import * +from ...utils.create_cache import * from ... import ops __all__ = [ - "LlamaModelConfig", "PagedLlamaModelV1", ] -################################################################################ -# Config -################################################################################ - - -@dataclass -class LlamaModelConfig: - hp: configs.LlamaHParams - - # Block sequence stride for a paged KV cache. This must divide evenly - # into the context length. - block_seq_stride: int = 16 - - # Either "paged" or "direct". - kv_cache_type: str = "paged" - - # The device on which to place intermediate state. - device: Optional[torch.device] = None - - # Dtype to use for general FP activations not otherwise configured. - activation_dtype: torch.dtype = torch.float16 - - # Dtype to use for attention. - attention_dtype: torch.dtype = torch.float16 - - # Indicates if running with HuggingFace implementation and ensures - # numerical equivalency to HuggingFace's LLaMa if true (by modifying - # rotary embedding). - use_hf: bool = False - - # If true, then the model may pre-initialize certain tables during - # init. This can be better for eager execution but when capturing a program, - # it is often better to preserve the calculation explicitly and rely on - # the compiler to transform it to an initialization time step. This can - # be the difference of many gigabytes of static data being embedded in - # the program and not. 
- static_tables: bool = True - - def create_kv_cache(self) -> BaseKVCache: - hp = self.hp - if self.kv_cache_type == "direct": - return DirectKVCache( - block_seq_stride=self.block_seq_stride, - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - seq_length=hp.context_length, - device=self.device, - dtype=self.attention_dtype, - ) - elif self.kv_cache_type == "paged": - return PagedKVCache( - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - cache_partition_count=2, # One for each of K/V. - block_seq_stride=self.block_seq_stride, - device=self.device, - dtype=self.attention_dtype, - ) - else: - raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") - - ################################################################################ # Models ################################################################################ @@ -111,6 +47,19 @@ class PagedLlamaModelV1(BaseCausalLMModel): to be serviced. Various samplers and schedulers can be interleaved throughout. + + In the case of tensor sharding (config.tensor_parallelism_size > 1) the model's KV + cache head dimension is sharded. + The number of KV cache heads must be divisible by the parallelism size. + With this sharding approach the KV cache is not replicated across devices. + The cache is split across the devices while the indexing logic/computation is + replicated. + All other arguments aside from the cache state are replicated. + After the attention we all-reduce. + The first fully connected layer is split along the parallel dimension. + This forces the reduction dimension of the second FC layer to be split. + We return the unreduced tensor. The user is free to reduce it to obtain the + unsharded result or chain it with other tensor-parallel operations.
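# A numeric sketch of the tensor-parallel layout the docstring above describes, using
# plain torch only. The first projection is split along its output (parallel)
# dimension, so the second projection sees a split reduction dimension and every
# "device" produces a partial result that still needs an all-reduce (a plain sum
# here). The two matmuls are a simplified stand-in for the FFN, and the sizes and
# tensor_parallelism_size are made up.
import torch

tp = 2
head_count_kv, embed, ffn = 8, 32, 64
assert head_count_kv % tp == 0  # KV cache heads must divide evenly across devices.

x = torch.randn(3, embed)
w1 = torch.randn(ffn, embed)  # first FC: split along rows (the parallel dimension)
w2 = torch.randn(embed, ffn)  # second FC: split along columns (the reduction dimension)

partials = [
    (x @ w1_s.T) @ w2_s.T
    for w1_s, w2_s in zip(w1.chunk(tp, dim=0), w2.chunk(tp, dim=1))
]
unreduced = partials     # what the model returns, one partial result per device
reduced = sum(partials)  # the all-reduce the caller is free to apply or defer

assert torch.allclose(reduced, (x @ w1.T) @ w2.T, atol=1e-4)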
""" def __init__(self, theta: Theta, config: LlamaModelConfig): @@ -122,12 +71,14 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): device=config.device, activation_dtype=config.activation_dtype, attention_dtype=config.attention_dtype, + fake_quant=config.fake_quant, ) self.config = config self.hp = hp - self.cache = config.create_kv_cache() + self.cache = create_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.use_hf = config.use_hf + self.attention_kernel = config.attention_kernel self.add_module( "token_embedding", @@ -142,6 +93,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): device=self.device, use_hf=self.use_hf, static_tables=config.static_tables, + tensor_parallelism_size=config.tensor_parallelism_size, ), ) self.add_module( @@ -151,7 +103,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ), ) self.add_module("output_lm_head", LinearLayer(theta("output"))) - self.attn_blocks = nn.ModuleList( [ AttentionFFNBlock( @@ -162,7 +113,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_hf=self.use_hf, + attention_kernel=self.attention_kernel, + fake_quant=self.fake_quant, ) for n in range(hp.block_count) ] @@ -171,18 +123,19 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): def prefill( self, # [bs, batch_seq_len] - tokens: torch.Tensor, + tokens: Union[torch.Tensor, ReplicatedTensor], *, # [1, 1, batch_seq_len, batch_seq_len] - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, ReplicatedTensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - cache_state: list[torch.Tensor], + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], ): self._assert_device(tokens) self._assert_device(attention_mask, dtype=self.activation_dtype) self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) + h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) @@ -207,20 +160,34 @@ def prefill( def decode( self, # [bs, 1] - tokens: torch.Tensor, + tokens: Union[torch.Tensor, ReplicatedTensor], *, # [bs, 1, 1, batch_seq_len] - attention_mask: torch.Tensor, + attention_mask: Union[torch.Tensor, ReplicatedTensor], # [bs] of starting positions - start_positions: torch.Tensor, + start_positions: Union[torch.Tensor, ReplicatedTensor], # [bs, batch_seq_len // block_seq_stride] - seq_block_ids: torch.Tensor, - cache_state: list[torch.Tensor], + seq_block_ids: Union[torch.Tensor, ReplicatedTensor], + cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]], ): + assert len(tokens.shape) == 2 + assert len(attention_mask.shape) == 4 + assert len(start_positions.shape) == 1 + assert len(seq_block_ids.shape) == 2 + assert tokens.shape[0] == attention_mask.shape[0] + assert tokens.shape[0] == start_positions.shape[0] + assert tokens.shape[0] == seq_block_ids.shape[0] + assert tokens.shape[1] == 1 + assert attention_mask.shape[1] == 1 and attention_mask.shape[2] == 1 + assert ( + attention_mask.shape[3] + == seq_block_ids.shape[1] * self.config.block_seq_stride + ) self._assert_device(tokens) self._assert_device(attention_mask, dtype=self.activation_dtype) self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) + bs, _ = tokens.shape # Precompute a position based mask for computing rope 
embeddings # as it is the same for all blocks. @@ -231,26 +198,48 @@ def decode( # Allocate per-block temporary K/V tensors. These temporaries hold # one block's K/V state for the maximum context length. - xk_temp = torch.empty( - [ - bs, - self.context_length, - self.hp.attention_head_count_kv, - self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) - xv_temp = torch.empty( - [ + if self.config.tensor_parallelism_size == 1: + xk_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + xv_temp = torch.empty( + [ + bs, + self.context_length, + self.hp.attention_head_count_kv, + self.hp.attn_head_dim, + ], + dtype=self.config.activation_dtype, + device=self.device, + ) + else: + shard_size = [ bs, self.context_length, - self.hp.attention_head_count_kv, + self.hp.attention_head_count_kv // self.config.tensor_parallelism_size, self.hp.attn_head_dim, - ], - dtype=self.config.activation_dtype, - device=self.device, - ) + ] + xk_temp_shard = [ + torch.empty( + shard_size, dtype=self.config.activation_dtype, device=self.device + ) + for _ in range(self.config.tensor_parallelism_size) + ] + xv_temp_shard = [ + torch.empty( + shard_size, dtype=self.config.activation_dtype, device=self.device + ) + for _ in range(self.config.tensor_parallelism_size) + ] + xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2) + xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2) h = self.token_embedding(tokens) self.trace_tensor("llama.token_embedding", h) @@ -296,7 +285,8 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_hf: bool = False, + attention_kernel: str = "decomposed", + fake_quant: bool = True, ): super().__init__(theta) self.add_module( @@ -309,7 +299,8 @@ def __init__( head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, - use_hf=use_hf, + attention_kernel=attention_kernel, + fake_quant=fake_quant, ), ) self.add_module( @@ -324,7 +315,7 @@ def __init__( def forward( self, - h: torch.Tensor, + h: Union[torch.Tensor, ReplicatedTensor], *, embedding: RotaryEmbeddingLayer, # [bs, batch_seq_len // block_seq_stride] @@ -349,6 +340,7 @@ def forward( xk_temp=xk_temp, xv_temp=xv_temp, ) + # Feed forward network. ffn_input = self.ffn_norm(h) ffn_down = self.ffn(ffn_input) diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index 74ed9e8e0..9f77daa40 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math import torch import torch.nn as nn @@ -230,7 +229,9 @@ def forward( values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + self.head_dim + ) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/models/llama/sharding.py b/sharktank/sharktank/models/llama/sharding.py new file mode 100644 index 000000000..3715a3923 --- /dev/null +++ b/sharktank/sharktank/models/llama/sharding.py @@ -0,0 +1,114 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Specifications describing how the Llama model is sharded.""" + +from ...types.sharding import * +from ...types import Theta +from ... import ops + + +class PagedLlamaAttentionBlockSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + # The size of this is the token embedding length, which is not a memory + # space concern if replicated even for all attention blocks. + "attn_norm": RmsNormReplicatedSharding( + self.shard_count + ).theta_sharding(), + "attn_q": LinearSplitParallelWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "attn_k": LinearSplitParallelWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "attn_v": LinearSplitParallelWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "attn_output": LinearSplitReductionDimSharding( + shard_count=self.shard_count + ).theta_sharding(), + } + ) + + +class AttentionFFNBlockSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + result = PagedLlamaAttentionBlockSharding(self.shard_count).theta_sharding() + result.update(FFNSharding(self.shard_count).theta_sharding()) + result.update( + { + # The size of this is the token embedding length, which is not a memory + # space concern if replicated. + "ffn_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding() + } + ) + return result + + +class LlamaSharding(ThetaLayerSharding): + """Shards the input channel and output channels of the convolutions.""" + + def __init__(self, shard_count: int, attention_block_count: int): + super().__init__() + self.shard_count = shard_count + self.attention_block_count = attention_block_count + + def theta_sharding(self) -> ThetaSharding: + result = ThetaSharding( + { + # Replicate the vocabulary. For llama 1-3 this will require 0.5 GiB. + # For devices with large memory this may be an acceptable tradeoff where + # we save on communication by not all-gathering the result afterwards. + # The computation is just indexing and replication is not a concern. + # Alternatively, we can try splitting the index dimension, + # this would require custom logic for indexing partitioning and gathering. 
+ "token_embd": TokenEmbeddingLayerReplicatedSharding( + self.shard_count + ).theta_sharding(), + "rope_freqs": Ignore(), + "output_norm": RmsNormReplicatedSharding( + self.shard_count + ).theta_sharding(), + "output": LinearSplitReductionDimSharding( + self.shard_count + ).theta_sharding(), + } + ) + result.update( + { + "blk": ThetaSharding( + { + f"{i}": AttentionFFNBlockSharding( + self.shard_count + ).theta_sharding() + for i in range(self.attention_block_count) + } + ) + } + ) + return result + + +def shard_theta( + theta: Theta, config: "sharktank.models.llama.llama.LlamaModelConfig" +) -> Theta: + return ops.reshard( + theta, + LlamaSharding( + shard_count=config.tensor_parallelism_size, + attention_block_count=config.hp.block_count, + ), + ) diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py index b63fd5d07..079602b28 100644 --- a/sharktank/sharktank/models/llama/testing.py +++ b/sharktank/sharktank/models/llama/testing.py @@ -10,12 +10,11 @@ from ...types.tensors import * from ...types.theta import Theta - - -# Range of torch.rand() is [0,1) -# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values -def make_rand_torch(shape, dtype=torch.float32): - return torch.rand(shape, dtype=dtype) * 2 - 1 +from typing import Optional +from .llama import LlamaModelConfig +import torch +from ...utils.testing import make_rand_torch +from ...layers.testing import make_llama_attention_block_theta def make_attention_block_theta( @@ -56,11 +55,54 @@ def make_attention_block_theta( ) +def make_attention_block_ffn_theta_v2( + *, + head_count: int, + head_count_kv: int, + head_dim: int, + embedding_length: int, + feed_forward_length: int, + dtype: torch.dtype | None = None, +) -> Theta: + attention_theta = make_llama_attention_block_theta( + head_count=head_count, + head_count_kv=head_count_kv, + head_dim=head_dim, + embedding_length=embedding_length, + dtype=dtype, + ) + ffn_theta = Theta( + { + "ffn_norm.weight": DefaultPrimitiveTensor( + data=make_rand_torch((head_count * head_dim), dtype=dtype) + ), + "ffn_gate.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (feed_forward_length, embedding_length), dtype=dtype + ) + ), + "ffn_up.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (feed_forward_length, embedding_length), dtype=dtype + ) + ), + "ffn_down.weight": DefaultPrimitiveTensor( + data=make_rand_torch( + (embedding_length, feed_forward_length), dtype=dtype + ) + ), + } + ) + res_dict = attention_theta.tree + res_dict.update(ffn_theta.tree) + return Theta(res_dict) + + def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta: return Theta( { "blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor( - data=make_rand_torch((feature_dim, ffn_dim)) + data=make_rand_torch((num_experts, ffn_dim)) ), "blk.0.ffn_norm.weight": DefaultPrimitiveTensor( data=make_rand_torch((ffn_dim)) @@ -69,13 +111,41 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta data=make_rand_torch((ffn_dim)) ), "blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor( - data=make_rand_torch((8, feature_dim * num_experts, ffn_dim)) + data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)) ), "blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor( - data=make_rand_torch((8, feature_dim * num_experts, ffn_dim)) + data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)) ), "blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor( - data=make_rand_torch((8, ffn_dim, feature_dim * 
num_experts)) + data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)) ), } ) + + +def make_random_llama_theta( + config: LlamaModelConfig, vocab_size: int, dtype: Optional[torch.dtype] = None +) -> Theta: + res = { + "token_embd.weight": DefaultPrimitiveTensor( + data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype) + ) + } + for i in range(config.hp.block_count): + res[f"blk.{i}"] = make_attention_block_ffn_theta_v2( + head_count=config.hp.attention_head_count, + head_count_kv=config.hp.attention_head_count_kv, + head_dim=config.hp.attn_head_dim, + embedding_length=config.hp.embedding_length, + feed_forward_length=config.hp.feed_forward_length, + dtype=dtype, + ).tree + + res[f"output.weight"] = DefaultPrimitiveTensor( + data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype) + ) + res[f"output_norm.weight"] = DefaultPrimitiveTensor( + data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype) + ) + + return Theta(res) diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py new file mode 100644 index 000000000..052593748 --- /dev/null +++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py @@ -0,0 +1,371 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Imports quark pre-processed weights and quantization config into a +Dataset of the gguf format. + +Usage: + python -m sharktank.models.llama.tools.import_quark_dataset \ + --params=llama2-7b-fp8.safetensors --output-irpa-file=new.irpa \ + --config-json=../llama2/config.json + +""" +from typing import Optional + +from safetensors.torch import save_file +import json +from pathlib import Path +import safetensors +import sys +import torch + +from sharktank.types import * +from sharktank.layers.configs.llm_configs import ( + _int_prop, + _float_prop, + _optional_int_prop, + _int_prop, +) + + +def _load_json(p: Path): + print(f"Loading {p}") + with open(p, "rb") as f: + return json.load(f) + + +def _get_dataset_props(config_json_struct) -> dict: + # Separate meta parameters (prefixed with _) from hparams. 
+ meta_params = {k: v for k, v in config_json_struct.items() if k.startswith("_")} + hparams = {k: v for k, v in config_json_struct.items() if not k.startswith("_")} + return { + "meta": meta_params, + "hparams": hparams, + } + + +def _load_theta(st_source) -> Theta: + tensors = [ + DefaultPrimitiveTensor(name=name, data=st_source.get_tensor(name)) + for name in st_source.keys() + ] + return Theta(tensors) + + +def as_torch_or_none(tensor: Optional[InferenceTensor]) -> Optional[torch.Tensor]: + if tensor is None: + return None + return tensor.as_torch() + + +def hf_to_gguf(layer_name: str) -> str: + assert layer_name.startswith("model.layers") + mapping = { + "input_layernorm": "attn_norm", + "self_attn.q_proj": "attn_q", + "self_attn.k_proj": "attn_k", + "self_attn.v_proj": "attn_v", + "self_attn.o_proj": "attn_output", + "post_attention_layernorm": "ffn_norm", + "mlp.gate_proj": "ffn_gate", + "mlp.up_proj": "ffn_up", + "mlp.down_proj": "ffn_down", + } + + # Split the input string + parts = layer_name.split(".") + + # Extract the numerical value and the key to be mapped + numerical_value = parts[2] # The part after "models.layers" and its number + key_to_map = ".".join(parts[3:]) + + # Map the key + if key_to_map in mapping: + mapped_value = mapping[key_to_map] + else: + raise ValueError(f"Mapping for '{key_to_map}' not found.") + + # Construct the output string + output_str = f"blk.{numerical_value}.{mapped_value}" + return output_str + + +def apply_per_layer_quant( + root_theta: Theta, + layer_name: str, + updated_tensors: dict[str, InferenceTensor], + n_head: int, + split_sizes: list[int], +): + """Take the quantization parameters and hf weights from the imported Theta + and create InferenceTensors out of them, converting their names to gguf format + in the process. + """ + + layer_theta = root_theta(layer_name) + + weight_quant_scale = layer_theta.tensor("weight_scale").as_torch() + + weight = layer_theta.tensor("weight").as_torch() + + # It looks dumb but, this step is required for numerical correctness against quark. 
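# Before continuing with apply_per_layer_quant, a quick check of the hf_to_gguf
# mapping above. The layer names are hypothetical, but the expected outputs follow
# directly from the mapping table and the blk.<index>.<name> output format:
assert hf_to_gguf("model.layers.0.self_attn.q_proj") == "blk.0.attn_q"
assert hf_to_gguf("model.layers.17.mlp.down_proj") == "blk.17.ffn_down"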
+ # weight = weight.view(torch.float8_e4m3fn) + weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16) + + weight_quant_zero_point = layer_theta.optional_tensor("weight_zero_point") + if weight_quant_zero_point == None: + weight_quant_zero_point = torch.zeros(1, dtype=torch.float32) + else: + weight_quant_zero_point = weight_quant_zero_point.as_torch() + input_quant_scale = as_torch_or_none(layer_theta.optional_tensor("input_scale")) + output_quant_scale = as_torch_or_none(layer_theta.optional_tensor("output_scale")) + + if weight_quant_scale is None: + print("weight quant scale not found for layer ", layer_name) + return + + layer_parent = ".".join(layer_name.split(".")[:-1]) + + def quantize_weight( + weight_name: str, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: Optional[torch.Tensor], + ): + # Our scale is the reciprocal of the quark scale + # We multiply scale by two to account for diff between fnuz and fn + weight_quantizer = StaticScaledQuantizer( + scale=1.0 / (weight_scale * 2.0), + reciprocal_scale=(weight_scale * 2.0), + offset=None + if (weight_zp is None or torch.count_nonzero(weight_zp) == 0) + else weight_zp, + dtype=torch.float8_e4m3fnuz, + ) + weight_quant = weight_quantizer.quantize(weight, name=weight_name) + updated_tensors[weight_quant.name] = weight_quant + + if "qkv" in layer_name: + # The qkv layer is fused in the quark model, decompose back into individual q, k , and v weights + q_weight, k_weight, v_weight = torch.split(weight, split_sizes) + q_weight = ( + q_weight.reshape( + n_head, 2, q_weight.shape[0] // n_head // 2, *q_weight.shape[1:] + ) + .swapaxes(1, 2) + .reshape(q_weight.shape) + ) + k_weight = ( + k_weight.reshape( + n_head, 2, k_weight.shape[0] // n_head // 2, *k_weight.shape[1:] + ) + .swapaxes(1, 2) + .reshape(k_weight.shape) + ) + q_name = hf_to_gguf(layer_parent + ".q_proj") + k_name = hf_to_gguf(layer_parent + ".k_proj") + v_name = hf_to_gguf(layer_parent + ".v_proj") + quantize_weight( + q_name + ".weight", q_weight, weight_quant_scale, weight_quant_zero_point + ) + quantize_weight( + k_name + ".weight", k_weight, weight_quant_scale, weight_quant_zero_point + ) + quantize_weight( + v_name + ".weight", v_weight, weight_quant_scale, weight_quant_zero_point + ) + # The output and input quantizers are duplicated for each of the q, k, and v weights + names = [f"{i}.qdq_output" for i in [q_name, k_name, v_name]] + for name in names: + updated_tensors[name] = StaticScaledQuantizer( + name=name, + scale=1.0 / (output_quant_scale * 2.0), + reciprocal_scale=output_quant_scale * 2.0, + dtype=torch.float8_e4m3fnuz, + ) + names = [f"{i}.q_input" for i in [q_name, k_name, v_name]] + for name in names: + updated_tensors[name] = StaticScaledQuantizer( + name=name, + scale=1.0 / (input_quant_scale * 2.0), + reciprocal_scale=input_quant_scale * 2.0, + dtype=torch.float8_e4m3fnuz, + ) + # Remove the updated tensors from the original tree. 
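# An illustrative round-trip of the scale convention used in quantize_weight above,
# assuming quark dequantizes as w ~= w_q * weight_scale while StaticScaledQuantizer
# multiplies by `scale` to quantize and by `reciprocal_scale` to dequantize. The
# doubled scale shrinks the quantized values to suit float8_e4m3fnuz's smaller range
# relative to float8_e4m3fn; the numbers below are made up and kept in float32 so the
# round trip recovers w to float precision.
import torch

w = torch.tensor([0.5, -1.25, 3.0])
weight_scale = torch.tensor(0.02)   # hypothetical quark-provided scale
scale = 1.0 / (weight_scale * 2.0)  # what quantize_weight passes to the quantizer
reciprocal_scale = weight_scale * 2.0

q = w * scale                       # half the magnitude of an e4m3fn quantization
recovered = q * reciprocal_scale
assert torch.allclose(recovered, w)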
+ root_theta.pop(layer_parent + ".q_proj") + root_theta.pop(layer_parent + ".k_proj") + root_theta.pop(layer_parent + ".v_proj") + root_theta.pop(layer_name) + + else: + new_layer_name = hf_to_gguf(layer_name) + quantize_weight( + new_layer_name + ".weight", + weight, + weight_quant_scale, + weight_quant_zero_point, + ) + # we explicitly provide the reciprocal scale because converting from float16 to float8 after doing 1/scale results in significant numerical differences + if input_quant_scale is not None: + updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer( + name=new_layer_name + ".q_input", + scale=1.0 / (input_quant_scale * 2.0), + reciprocal_scale=input_quant_scale * 2.0, + dtype=torch.float8_e4m3fnuz, + ) + if output_quant_scale is not None: + updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer( + name=new_layer_name + ".qdq_output", + scale=1.0 / output_quant_scale, + reciprocal_scale=output_quant_scale, + dtype=torch.float8_e4m3fnuz, + ) + + # Remove the updated tensor from the original tree. + root_theta.pop(layer_name) + + +def convert_hf_hparams_to_gguf(hf_hparams: dict[str, any]) -> dict[str, any]: + hp = hf_hparams["hparams"] + attention_head_count = _int_prop(hp, "num_attention_heads") + attn_head_dim = int( + _int_prop(hp, "hidden_size") // _int_prop(hp, "num_attention_heads") + ) + + return { + "llama.context_length": _int_prop(hp, "max_position_embeddings"), + "llama.embedding_length": _int_prop(hp, "hidden_size"), + "llama.block_count": _int_prop(hp, "num_hidden_layers"), + "llama.feed_forward_length": _int_prop(hp, "intermediate_size"), + "llama.rope.dimension_count": attn_head_dim, + "llama.attention.head_count": attention_head_count, + "llama.attention.layer_norm_rms_epsilon": _float_prop(hp, "rms_norm_eps"), + "llama.attention.head_count_kv": _optional_int_prop( + hp, "num_key_value_heads", attention_head_count + ), + } + + +def update_norm_layer( + quant_theta: Theta, layer_name: str, updated_tensors: dict[str, InferenceTensor] +): + """Convert layernames for non quantized tensors and add them to the updated_tensors dict""" + for sub in ["input_layernorm", "post_attention_layernorm"]: + sub_name = layer_name + "." 
+ sub + new_name = hf_to_gguf(sub_name) + ".weight" + single_replace(quant_theta, sub_name, new_name, updated_tensors) + kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch() + layer_idx = layer_name.split(".")[-1] + new_name = f"blk.{layer_idx}.kv_cache" + updated_tensors[new_name] = StaticScaledQuantizer( + name=new_name + ".quantizer", + scale=1.0 / (kv_cache_scale * 2.0), + reciprocal_scale=kv_cache_scale * 2.0, + dtype=torch.float8_e4m3fnuz, + ) + + +def single_replace( + quant_theta: Theta, + layer_name: str, + gguf_name: str, + updated_tensors: dict[str, InferenceTensor], +): + data = quant_theta(layer_name).tensor("weight").as_torch() + if data.dtype == torch.bfloat16: + data = data.to(torch.float32) + updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data) + + +def main(argv): + from sharktank.utils import cli + + parser = cli.create_parser() + cli.add_output_dataset_options(parser) + parser.add_argument( + "--config-json", type=Path, required=True, help="Path to the config.json file" + ) + parser.add_argument( + "--params", + type=Path, + default=Path("params.safetensors"), + help="Parameter file name, relative to config.json", + ) + parser.add_argument( + "--model-base", + type=str, + default="7b", + help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b, 70b is also supported.", + choices=["7b", "70b"], + ) + args = cli.parse(parser, args=argv) + + config_json_path: Path = args.config_json + params_path: Path = args.params + # TODO: find a way to get this programatically so we don't have to flag for it + split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024] + num_layers = 32 if args.model_base == "7b" else 80 + + # Construct the pre-transform dataset. + dataset_props = _get_dataset_props(_load_json(config_json_path)) + with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: + quant_theta = _load_theta(st) + ds = Dataset(dataset_props, quant_theta) + + # Convert hyperparams to gguf format + updated_properties = convert_hf_hparams_to_gguf(ds.properties) + + head_count = (updated_properties["llama.attention.head_count"],) + + updated_tensors: dict[str, InferenceTensor] = {} + model_layers = [f"model.layers.{i}" for i in range(num_layers)] + + sub_layers = [ + "mlp.gate_proj", + "mlp.down_proj", + "mlp.up_proj", + "self_attn.o_proj", + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ] + for layer in model_layers: + for sub in sub_layers: + layer_name = layer + "." + sub + apply_per_layer_quant( + quant_theta, + layer_name, + updated_tensors, + n_head=head_count[0], + split_sizes=split_sizes, + ) + + # Update the non quantized weights (norm layers) + for layer_idx in model_layers: + update_norm_layer( + quant_theta, + layer_idx, + updated_tensors, + ) + + # The stragglers + stragglers = [ + ("model.embed_tokens", "token_embd.weight"), + ("model.norm", "output_norm.weight"), + ("lm_head", "output.weight"), + ] + for layer, new_name in stragglers: + single_replace(quant_theta, layer, new_name, updated_tensors) + + new_theta = Theta(updated_tensors) + # Make a new Dataset from the updated properties and tensors. 
+ new_ds = Dataset(updated_properties, new_theta) + + new_ds.save(args.output_irpa_file, io_report_callback=print) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index 5a179e5b9..e2995dfde 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -13,65 +13,15 @@ from ...layers import * +from ...utils.create_cache import * from ...types import Theta torch.set_printoptions(profile="full") __all__ = [ - "LlamaModelConfig", "PagedMixtralModelV1", ] -################################################################################ -# Config -################################################################################ - - -@dataclass -class LlamaModelConfig: - hp: configs.LlamaHParams - - # Block sequence stride for a paged KV cache. This must divide evenly - # into the context length. - block_seq_stride: int = 16 - - # Either "paged" or "direct". - kv_cache_type: str = "paged" - - # The device on which to place intermediate state. - device: Optional[torch.device] = None - - # Dtype to use for general FP activations not otherwise configured. - activation_dtype: torch.dtype = torch.float16 - - # Dtype to use for attention. - attention_dtype: torch.dtype = torch.float16 - - def create_kv_cache(self) -> BaseKVCache: - hp = self.hp - if self.kv_cache_type == "direct": - return DirectKVCache( - block_seq_stride=self.block_seq_stride, - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - seq_length=hp.context_length, - device=self.device, - dtype=self.attention_dtype, - ) - elif self.kv_cache_type == "paged": - return PagedKVCache( - transformer_block_count=hp.block_count, - attn_head_count=hp.attention_head_count_kv, - attn_head_dim=hp.attn_head_dim, - cache_partition_count=2, # One for each of K/V. - block_seq_stride=self.block_seq_stride, - device=self.device, - dtype=self.attention_dtype, - ) - else: - raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}") - ################################################################################ # Models @@ -111,7 +61,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) self.config = config self.hp = hp - self.cache = config.create_kv_cache() + self.cache = create_kv_cache(self.config) self.activation_dtype = config.activation_dtype self.add_module( "token_embedding", @@ -135,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -148,8 +99,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( - SparseMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, @@ -176,25 +127,26 @@ def prefill( self.trace_tensor("mixtral.token_embedding", h) # Iterate over attention blocks. 
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - embedding=self.attention_embedding, - start_index=0, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) @@ -252,28 +204,29 @@ def decode( self.trace_tensor("mixtral.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py index 392f60a25..70a9b9cf8 100644 --- a/sharktank/sharktank/models/mixtral/mixtral_ref.py +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -78,8 +79,8 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( - SparseMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, @@ -130,28 +131,29 @@ def forward( block_count = len(self.attn_blocks) // 2 # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count) # Iterate over attention + MoE blocks. 
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): # print("block_idx, block", block_idx, block) if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "LlamaAttentionBlock": - attn_block_idx = block_idx // 2 - block_cache_k = local_kv_cache[attn_block_idx] - block_cache_v = local_kv_cache[block_count + attn_block_idx] - h = block( - h, - cache_k=block_cache_k, - cache_v=block_cache_v, - start_index=start_index, - attention_mask=attention_mask, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + attn_block_idx = block_idx // 2 + block_cache_k = local_kv_cache[attn_block_idx] + block_cache_v = local_kv_cache[block_count + attn_block_idx] + h = attn_block( + h, + cache_k=block_cache_k, + cache_v=block_cache_v, + start_index=start_index, + attention_mask=attention_mask, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/sharktank/models/punet/sharding.py b/sharktank/sharktank/models/punet/sharding.py index 22237b8bb..a827b7942 100644 --- a/sharktank/sharktank/models/punet/sharding.py +++ b/sharktank/sharktank/models/punet/sharding.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Specifications describing how block/layers of punet are sharded.""" +"""Specifications describing how blocks/layers of punet are sharded.""" from ...types.sharding import * @@ -31,7 +31,7 @@ def theta_sharding(self) -> ThetaSharding: "conv2": Conv2DSplitOutputChannelSharding( shard_count=self.shard_count ).theta_sharding(), - "time_emb_proj": LinearReplicatedInputSplitWeightAndBiasSharding( + "time_emb_proj": LinearSplitParallelWeightAndBiasSharding( shard_count=self.shard_count ).theta_sharding(), "conv_shortcut": Conv2DSplitOutputChannelSharding( diff --git a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py index fa862105c..22ca1f591 100644 --- a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py +++ b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py @@ -133,8 +133,19 @@ def _get_json_tensor( # for signed arithmetic. input_zp = _get_json_tensor("input_zp", dtype=None) if input_zp is not None: - assert torch.count_nonzero(input_zp) == 0 - + assert torch.count_nonzero(input_zp.float()) == 0 + + # Currently, there seems to be no standardization in `quant_params.json` for fields in every layer + # across different quantization schemes (int8, fp8). int8 quantization was the first end-to-end tested + # quantization scheme, so some of the defaults assume it.
+ quantization_type = ( + qp.get("input_zp_dtype") + if qp.get("input_zp_dtype") is not None + else "torch.int8" + ) + quantization_dtype = tensors.serialized_name_to_dtype( + quantization_type.split(".")[-1] + ) if output_scale is not None: output_quantizer = StaticScaledQuantizer( name=f"{layer_name}.q_output", @@ -175,10 +186,12 @@ def quantize_weight( weight_quantizer = StaticScaledQuantizer( scale=1.0 / weight_scale, reciprocal_scale=weight_scale, - offset=None - if (weight_zp is None or torch.count_nonzero(weight_zp) == 0) - else weight_zp, - dtype=torch.int8, + offset=( + None + if (weight_zp is None or torch.count_nonzero(weight_zp) == 0) + else weight_zp + ), + dtype=quantization_dtype, ) weight_quant = weight_quantizer.quantize(weight, name=weight_name) updated_tensors[weight_quant.name] = weight_quant @@ -195,7 +208,7 @@ def quantize_bias( bias_scale = 1.0 / (input_scale * weight_scale) bias_quantizer = StaticScaledQuantizer( scale=bias_scale, - dtype=torch.int32, + dtype=torch.int32 if quantization_dtype == torch.int8 else torch.float16, disable_saturate=True, ) bias_quant = bias_quantizer.quantize(bias, name=bias_name) @@ -227,7 +240,7 @@ def quantize_bias( f"{layer_name}.to_v.weight", weight_v, weight_scale_v, weight_zp_v ) updated_tensors[weight.name] = None - if bias is not None: + if bias is not None and quantization_dtype == torch.int8: bias_k, bias_v = bias.as_torch().chunk(2, dim=0) quantize_bias( f"{layer_name}.to_k.bias", bias_k, input_scale, weight_scale_k @@ -259,7 +272,7 @@ def quantize_bias( f"{layer_name}.to_v.weight", weight_v, weight_scale_v, weight_zp_v ) updated_tensors[weight.name] = None - if bias is not None: + if bias is not None and quantization_dtype == torch.int8: bias_q, bias_k, bias_v = bias.as_torch().chunk(3, dim=0) quantize_bias( f"{layer_name}.to_q.bias", bias_q, input_scale, weight_scale_q @@ -284,7 +297,7 @@ def quantize_bias( name=f"{layer_name}.q_input", scale=1.0 / input_scale, reciprocal_scale=input_scale, - dtype=torch.int8, + dtype=quantization_dtype, ) updated_tensors[input_quantizer.name] = input_quantizer diff --git a/sharktank/sharktank/models/punet/tools/run_punet.py b/sharktank/sharktank/models/punet/tools/run_punet.py index ace279a3b..b2ad58d9d 100644 --- a/sharktank/sharktank/models/punet/tools/run_punet.py +++ b/sharktank/sharktank/models/punet/tools/run_punet.py @@ -9,7 +9,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from ..model import Unet2DConditionModel, ClassifierFreeGuidanceUnetModel from ....utils.patching import SaveModuleResultTensorsPatch diff --git a/sharktank/sharktank/models/t5/__init__.py b/sharktank/sharktank/models/t5/__init__.py new file mode 100644 index 000000000..7c7e76704 --- /dev/null +++ b/sharktank/sharktank/models/t5/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .t5 import * +from .export import * diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py new file mode 100644 index 000000000..7bd5eef3d --- /dev/null +++ b/sharktank/sharktank/models/t5/export.py @@ -0,0 +1,97 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Union +from pathlib import Path +import torch + +from .t5 import T5Config, T5Encoder +from ...types import Dataset +from iree.turbine.aot import FxProgramsBuilder, export + +__all__ = [ + "export_encoder_mlir", + "export_encoder_iree_parameters", + "prune_decoder_parameters", +] + + +def export_encoder_mlir( + model: Union[T5Encoder, Path, str], + batch_sizes: list[int], + mlir_output_path: str, +): + """ + Args: + model: either the torch module or path to GGUF/IRPA. + """ + if isinstance(model, (Path, str)): + dataset = Dataset.load(model) + config = T5Config.from_gguf_properties( + dataset.properties, + # TODO: add this property to our HuggingFace-to-GGUF conversion script. + # We currently use llama.cpp's converter and it can not make a distinction + # between T5 V1 and V1.1. + # V1 uses ReLU and V1.1 uses gated GeLU. + feed_forward_proj="gated-gelu", + ) + model = T5Encoder(theta=dataset.root_theta, config=config) + + fxb = FxProgramsBuilder(model) + + for batch_size in batch_sizes: + sample_inputs = model.sample_inputs(batch_size) + + context_length_dim_idx = 1 + assert ( + sample_inputs["input_ids"].shape[context_length_dim_idx] + % config.context_length_padding_block_size + == 0 + ) + context_length_block_dim_max = ( + sample_inputs["input_ids"].shape[context_length_dim_idx] + // config.context_length_padding_block_size + ) + context_length_block_dim = torch.export.Dim( + "block", max=context_length_block_dim_max + ) + context_length_dim = ( + config.context_length_padding_block_size * context_length_block_dim + ) + dynamic_shapes = {"input_ids": {context_length_dim_idx: context_length_dim}} + + @fxb.export_program( + name=f"forward_bs{batch_size}", + args=tuple(sample_inputs.values()), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + def _( + model, + input_ids, + ): + return model(input_ids) + + output = export(fxb, import_symbolic_shape_expressions=True) + output.save_mlir(mlir_output_path) + + +def prune_decoder_parameters(dataset: Dataset): + # Remove decoder tensors/parameters if present. + try: + del dataset.root_theta.tree["dec"] + except KeyError: + pass + try: + del dataset.properties["t5.decoder_start_token_id"] + except KeyError: + pass + + +def export_encoder_iree_parameters(model_path: str, output_path: str): + dataset = Dataset.load(model_path) + prune_decoder_parameters(dataset) + dataset.save(output_path) diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py new file mode 100644 index 000000000..4ae9108d5 --- /dev/null +++ b/sharktank/sharktank/models/t5/t5.py @@ -0,0 +1,1095 @@ +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""T5 LLM adapted from transformers +https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/t5/modeling_t5.py +""" + +from typing import Any, Optional, Tuple +from dataclasses import dataclass, field +import math +import torch +from torch import nn +import copy +import logging +import warnings +from collections import OrderedDict + +from ...layers import ( + BaseLayer, + RMSNormLayer, + TokenEmbeddingLayer, + LinearLayer, +) +from ... 
import ops +from ...types.theta import Theta +from ...types.tensors import AnyTensor +from ...layers import FFN, T5Config + +__all__ = [ + "T5Config", + "T5LayerFF", + "T5Attention", + "T5SelfAttention", + "T5CrossAttention", + "T5Block", + "T5Stack", + "T5Encoder", +] + +logger = logging.getLogger(__name__) + + +ACT2FN = { + "gelu": nn.functional.gelu, + "gelu_new": ops.gelu_tanh_approximation, + "relu": nn.functional.relu, +} + + +class T5LayerFF(nn.Module): + def __init__( + self, + theta: Theta, + is_gated_act: bool, + dense_act_fn: str, + layer_norm_epsilon: float, + activation_dtype: torch.dtype, + ): + super().__init__() + self.dense_activation_dense = FFN( + theta=theta, is_gated=is_gated_act, activation_fn=ACT2FN[dense_act_fn] + ) + + self.layer_norm = RMSNormLayer( + theta=theta("ffn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + ) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.dense_activation_dense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(BaseLayer): + def __init__( + self, + theta: Theta, + is_decoder: bool, + relative_attention_num_buckets: int, + relative_attention_max_distance: int, + d_model: int, + d_kv: int, + num_heads: int, + activation_dtype: torch.dtype, + has_relative_attention_bias: bool = False, + ): + super().__init__() + self.is_decoder = is_decoder + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.d_model = d_model + self.key_value_proj_dim = d_kv + self.n_heads = num_heads + self.has_relative_attention_bias = has_relative_attention_bias + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.activation_dtype = activation_dtype + + self.q = LinearLayer(theta("attn_q")) + self.k = LinearLayer(theta("attn_k")) + self.v = LinearLayer(theta("attn_v")) + self.o = LinearLayer(theta("attn_o")) + + if self.has_relative_attention_bias: + self.relative_attention_bias = TokenEmbeddingLayer( + theta("attn_rel_b"), dtype=activation_dtype + ) + self.pruned_heads = set() + + def prune_heads(self, heads): + # See transformers implementation + raise NotImplementedError() + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += ( + ops.to(relative_position > 0, dtype=torch.long) * num_buckets + ) + relative_position = ops.elementwise(torch.abs, relative_position) + else: + relative_position = -ops.elementwise( + torch.min, relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + ops.elementwise(torch.log, relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = ops.elementwise( + torch.min, + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += ops.elementwise( + torch.where, is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" + ) + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = ops.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = ops.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = ops.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + ops.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + 
(present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5SelfAttention(BaseLayer): + def __init__( + self, + theta: Theta, + is_decoder: bool, + relative_attention_num_buckets: int, + relative_attention_max_distance: int, + d_model: int, + d_kv: int, + num_heads: int, + layer_norm_epsilon: float, + activation_dtype: torch.dtype, + has_relative_attention_bias: bool = False, + ): + super().__init__() + self.attention = T5Attention( + theta=theta, + is_decoder=is_decoder, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + activation_dtype=activation_dtype, + has_relative_attention_bias=has_relative_attention_bias, + ) + self.layer_norm = RMSNormLayer( + theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + attention_output[0] + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5CrossAttention(nn.Module): + def __init__( + self, + theta: Theta, + is_decoder: bool, + relative_attention_num_buckets: int, + relative_attention_max_distance: int, + d_model: int, + d_kv: int, + num_heads: int, + layer_norm_epsilon: float, + activation_dtype: torch.dtype, + ): + super().__init__() + self.enc_dec_attention = T5Attention( + theta=theta, + is_decoder=is_decoder, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + activation_dtype=activation_dtype, + has_relative_attention_bias=False, + ) + self.layer_norm = RMSNormLayer( + theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype + ) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.enc_dec_attention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + attention_output[0] + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__( + self, + theta: Theta, + is_decoder: bool, + relative_attention_num_buckets: int, + relative_attention_max_distance: int, + d_model: int, + d_kv: int, + num_heads: int, + layer_norm_epsilon: float, + is_gated_act: bool, + dense_act_fn: str, + activation_dtype: torch.dtype, + has_relative_attention_bias=False, + ): + super().__init__() + 
self.is_decoder = is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5SelfAttention( + theta=theta, + is_decoder=is_decoder, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + layer_norm_epsilon=layer_norm_epsilon, + activation_dtype=activation_dtype, + has_relative_attention_bias=has_relative_attention_bias, + ) + ) + if self.is_decoder: + self.layer.append( + T5CrossAttention( + theta=theta, + is_decoder=is_decoder, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + activation_dtype=activation_dtype, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + self.layer.append( + T5LayerFF( + theta=theta, + is_gated_act=is_gated_act, + dense_act_fn=dense_act_fn, + layer_norm_epsilon=layer_norm_epsilon, + activation_dtype=activation_dtype, + ) + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = ops.elementwise( + torch.where, + ops.elementwise(torch.isinf, hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = ops.elementwise( + torch.clamp, hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. 
Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = ops.elementwise( + torch.where, + ops.elementwise(torch.isinf, hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = ops.elementwise( + torch.clamp, hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = ops.elementwise( + torch.where, + ops.elementwise(torch.isinf, hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = ops.elementwise( + torch.clamp, hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5Stack(BaseLayer): + def __init__(self, theta: Theta, config: T5Config, embed_tokens=None): + super().__init__() + + self.embed_tokens = embed_tokens + self.config = config + self.is_decoder = config.is_decoder + theta_prefix = "dec" if config.is_decoder else "enc" + + self.block = torch.nn.ModuleList( + [ + T5Block( + theta=theta(f"{theta_prefix}.blk.{i}"), + is_decoder=config.is_decoder, + relative_attention_num_buckets=config.relative_attention_num_buckets, + relative_attention_max_distance=config.relative_attention_max_distance, + d_model=config.d_model, + d_kv=config.d_kv, + num_heads=config.num_heads, + layer_norm_epsilon=config.layer_norm_epsilon, + is_gated_act=config.is_gated_act, + dense_act_fn=config.dense_act_fn, + activation_dtype=config.activation_dtype, + has_relative_attention_bias=bool(i == 0), + ) + for i in range(config.num_layers) + ] + ) + self.add_module( + "final_layer_norm", + RMSNormLayer( + theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon + ), + ) + + dtypes = set(tensor.dtype for tensor in theta.flatten().values()) + assert len(dtypes) == 1 + self.dtype = dtypes.pop() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + @staticmethod + def create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device=None + ): + if device is not None: + warnings.warn( + "The 
`device` argument is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + else: + device = attention_mask.device + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + return extended_attention_mask + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int], + device: torch.device = None, + dtype: torch.dtype = None, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ( + T5Stack.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
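+        # For example (float16, illustrative): an attention_mask row of [1, 1, 0]
+        # becomes an additive mask of [0.0, 0.0, -65504.0], so masked positions are
+        # pushed to the dtype minimum before the softmax.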
+ extended_attention_mask = ops.to( + extended_attention_mask, dtype=dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo( + dtype + ).min + return extended_attention_mask + + def get_head_mask( + self, + head_mask: Optional[torch.Tensor], + num_hidden_layers: int, + is_attention_chunked: bool = False, + ) -> torch.Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + is_attention_chunked (`bool`, *optional*, defaults to `False`): + Whether or not the attentions scores are computed by chunks or not. + + Returns: + `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = ( + head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) + ) # We can specify head_mask for each layer + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = head_mask.to( + dtype=self.dtype + ) # switch to float if need + fp16 compatibility + return head_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError( + "You have to initialize the model with valid token embeddings" + ) + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + 
mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + if not self.is_decoder: + raise ValueError( + f"`use_cache` can only be set to `True` if {self} is used as a decoder" + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = inputs_embeds + + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions 
else 3 + ] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return OrderedDict( + (k, v) + for k, v in [ + ("last_hidden_state", hidden_states), + ("past_key_values", present_key_value_states), + ("hidden_states", all_hidden_states), + ("attentions", all_attentions), + ("cross_attentions", all_cross_attentions), + ] + if v is not None + ) + + +class T5Encoder(BaseLayer): + def __init__(self, theta: Theta, config: T5Config): + super().__init__() + self.add_module( + "token_embedding", + TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype), + ) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack( + theta=theta, config=encoder_config, embed_tokens=self.token_embedding + ) + + @property + def config(self): + return self.encoder.config + + def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]: + return OrderedDict( + [ + ( + "input_ids", + torch.empty( + size=[batch_size, self.config.context_length], dtype=torch.long + ), + ) + ] + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> tuple[torch.FloatTensor]: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/sharktank/sharktank/ops/_registry.py b/sharktank/sharktank/ops/_registry.py index 3c8ea0aed..f067c047b 100644 --- a/sharktank/sharktank/ops/_registry.py +++ b/sharktank/sharktank/ops/_registry.py @@ -18,6 +18,7 @@ __all__ = [ "AllOfExprs", + "AllOfExprsVariadic", "AllOfType", "AnyOfType", "IsOfType", @@ -65,7 +66,8 @@ def __call__(self, *args: type) -> bool: class AllOfExprs(BoolTypeExpr): - """Returns True if all types match their respective boolean type expression. + """Returns True if all type arguments match their respective boolean type + expression. ```python # True. int == int and str in (float, str). @@ -87,6 +89,38 @@ def expr(*types: type): super().__init__(expr) +class AllOfExprsVariadic(BoolTypeExpr): + """Returns True if all type arguments match their respective boolean type + expression and any remaining trailing arguments match the last type expression. + + ```python + # True. int == int + # str in (float, str). + # float in (float, str). + AllOfExprsVariadic(IsOfType(int), IsOfType(float, str))(int, str, float) + + # False. str is not in (int, float). 
+ AllOfExprsVariadic(IsOfType(int), IsOfType(int, float))(int, float, str, int) + ``` + """ + + def __init__(self, *exprs: BoolTypeExpr): + if len(exprs) == 0: + raise ValueError("At least one expression is required.") + self._exprs = list(exprs) + + def expr(*types: type): + if len(types) < len(self._exprs): + return False + exprs = self._exprs + if len(types) > len(exprs): + # pad with the trailing expression. + exprs = exprs + ([exprs[-1]] * (len(types) - len(self._exprs))) + return all([e(t) for e, t in zip(exprs, types)]) + + super().__init__(expr) + + class AllOfType(BoolTypeExpr): """Returns True if all of the types are from a set of types. diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py index fe6ae27b1..9acc7c562 100644 --- a/sharktank/sharktank/ops/custom_impls.py +++ b/sharktank/sharktank/ops/custom_impls.py @@ -5,20 +5,24 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch + from torch import Tensor, dtype +from typing import Union + import torch.nn.functional as F from ..kernels import ( + einsum_2args_q4, mmt_block_scaled_offset_q4_unsigned, mmt_block_scaled_q8, - mmtfp, mmt_super_block_scaled_offset_q4_unsigned, + bitcast_to_complex, + bitcast_to_real, ) from ..types import ( BlockScaledLayout, BlockScaledI4Layout, - InferenceTensor, PrimitiveTensor, QuantizedTensor, SuperBlockOffsetScaled_4_6_Layout, @@ -29,7 +33,7 @@ # Fused FP matmul. -# Disabled: See https://github.com/nod-ai/sharktank/issues/44 +# Disabled: See https://github.com/nod-ai/shark-ai/issues/44 # @matmul.override(Tensor, Tensor) # def matmul_mmtfp_tensor_tensor(lhs, rhs, *, transpose_rhs: bool): # lhs = unbox_tensor(lhs) @@ -44,6 +48,18 @@ # return mmtfp(lhs, rhs) +# Einsum + + +@einsum_2args.override(Tensor, QuantizedTensor) +def einsum_2args_QuantizedTensor(input0, input1, einsum_str): + unpacked = input1.unpack() + layout = input1.layout_type + if not isinstance(unpacked, BlockScaledI4Layout): + return NotImplemented + return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str) + + # Quantized Matmul @@ -110,3 +126,15 @@ def matmul_generic_tensor_super_block_offset_scaled_4_6_i4( sb_mins_low, rhs_unpacked.qs_bit_packed, ) + + +@view_as_complex.override(Union[Tensor, PrimitiveTensor]) +def view_as_complex(t): + t = unbox_tensor(t) + return bitcast_to_complex(t) + + +@view_as_real.override(Union[Tensor, PrimitiveTensor]) +def view_as_real(t): + t = unbox_tensor(t) + return bitcast_to_real(t) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 72f3db711..d117ada23 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -7,17 +7,24 @@ # This file contains overrides of the standard ops for normal torch and # generic primitive/quantized types. 
-from typing import Optional, List, Sequence, Union
+from typing import Optional, List, Sequence, Union, Tuple

 import torch
 from torch import Tensor, dtype
 import torch.nn.functional as F
 from numbers import Number

-from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor
+from ..types import (
+    PrimitiveTensor,
+    QuantizedTensor,
+    InferenceTensor,
+    PlanarQuantizedTensor,
+    BlockScaledI4Layout,
+)
 from ..types.tensors import unbox_tensor, AnyTensor
-from ._registry import AllOfType, AllOfExprs, IsOfType
+from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
 from .signatures import *
+import iree.turbine.ops.iree


 @cat.override(AllOfType(Tensor, PrimitiveTensor))
@@ -61,11 +68,64 @@ def conv2d_default(
 conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
 conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)

+# Einsum
+def mk_menk_men(inputs, weights):
+    # batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k
+    inputs = inputs.unsqueeze(1)
+    weights_shape = weights.shape
+    weights = weights.view(
+        weights_shape[0], weights_shape[1] * weights_shape[2], weights_shape[3]
+    )
+    result = matmul(inputs, weights, transpose_rhs=True)
+    result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+    return result
+
+
+def mek_menk_men(inputs, weights):
+    # batch dims: me, lhs pdims: none, lhs rdims: k, rhs pdims: n, rhs rdims: k
+    inputs_shape = inputs.shape
+    inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, inputs_shape[2])
+    weights_shape = weights.shape
+    weights = weights.view(
+        weights_shape[0] * weights_shape[1], weights_shape[2], weights_shape[3]
+    )
+    result = matmul(inputs, weights, transpose_rhs=True)
+    result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+    return result
+
+
+def me_men_men(inputs, weights):
+    # batch dims: me, lhs pdims: none, lhs rdims: none, rhs pdims: n, rhs rdims: none
+    inputs_shape = inputs.shape
+    inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, 1)
+    weights_shape = weights.shape
+    weights = weights.view(weights_shape[0] * weights_shape[1], weights_shape[2], 1)
+    result = matmul(inputs, weights, transpose_rhs=True)
+    result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+    return result
+
+
+@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor, QuantizedTensor))
+def einsum_2args(input0, input1, einsum_str):
+    # Special optimized einsum kernels that lower to batch matmul
+    if einsum_str == "mk,menk->men":
+        return mk_menk_men(input0, input1)
+    elif einsum_str == "mek,menk->men":
+        return mek_menk_men(input0, input1)
+    elif einsum_str == "me,men->men":
+        return me_men_men(input0, input1)
+    # Default non-QuantizedTensor einsum
+    if not isinstance(input1, QuantizedTensor):
+        return torch.einsum(einsum_str, unbox_tensor(input0), unbox_tensor(input1))
+    # Fallback to other kernels
+    return NotImplemented
+
+
 # Elementwise
 @elementwise.override(Tensor)
-def elementwise_unary(operator, x):
+def elementwise_unary(operator, x, *args, **kwargs):
     x = unbox_tensor(x)
-    return operator(x)
+    return operator(x, *args, **kwargs)


 @elementwise.override(
@@ -73,11 +133,27 @@ def elementwise_unary(operator, x):
         IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number)
     )
 )
-def elementwise_binary(operator, x, y):
+def elementwise_binary(operator, x, y, *args, **kwargs):
     x = unbox_tensor(x)
     if isinstance(y, PrimitiveTensor):
         y = unbox_tensor(y)
-    return operator(x, y)
+    return operator(x, y, *args, **kwargs)
+
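A minimal usage sketch for the widened `elementwise` overrides (the unary form now forwards extra positional/keyword arguments, and a ternary form follows below). Assumes `sharktank.ops` is importable; the tensor values are illustrative and not part of this change. This mirrors the `torch.clamp`/`torch.where` pattern used by the T5 model code earlier in this diff:

import torch
from sharktank import ops

t = torch.randn(2, 3)
# Unary operator with keyword arguments forwarded through the override.
clamped = ops.elementwise(torch.clamp, t, min=-1.0, max=1.0)
# Ternary dispatch: a tensor condition plus two tensor (or number) operands.
selected = ops.elementwise(torch.where, t > 0, t, torch.zeros_like(t))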
+
+@elementwise.override(
+    AllOfExprs(
+        IsOfType(Tensor, PrimitiveTensor),
+        IsOfType(Tensor, PrimitiveTensor, Number),
+        IsOfType(Tensor, PrimitiveTensor, Number),
+    )
+)
+def elementwise_ternary(operator, x, y, z, *args, **kwargs):
+    x = unbox_tensor(x)
+    if isinstance(y, PrimitiveTensor):
+        y = unbox_tensor(y)
+    if isinstance(z, PrimitiveTensor):
+        z = unbox_tensor(z)
+    return operator(x, y, z, *args, **kwargs)


 # Embedding Lookup
@@ -99,6 +175,49 @@ def equal_default(a, b) -> bool:
     return torch.equal(unbox_tensor(a), unbox_tensor(b))


+@expand.override(Tensor)
+def expand_default(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
+    return unbox_tensor(tensor).expand(*shape)
+
+
+@flatten.override(Tensor)
+def flatten_default(
+    input: Union[Tensor, PrimitiveTensor], start_dim: int, end_dim: int
+) -> Tensor:
+    return torch.flatten(unbox_tensor(input), start_dim, end_dim)
+
+
+@gather.override(Tensor, Tensor)
+def gather_default(
+    input: Union[Tensor, PrimitiveTensor],
+    dim: int,
+    index: Union[Tensor, PrimitiveTensor],
+) -> Tensor:
+    return torch.gather(unbox_tensor(input), dim, unbox_tensor(index))
+
+
+@get_index.override(AllOfType(Tensor, PrimitiveTensor))
+def get_index_default(tensor, key):
+    return unbox_tensor(tensor).__getitem__(key)
+
+
+@get_index.override(QuantizedTensor)
+def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice):
+    unpacked = tensor.unpack()
+    if isinstance(unpacked, BlockScaledI4Layout):
+        mul = 2
+    else:
+        return NotImplemented
+    new_d = unpacked._d[key]
+    new_qs = unpacked._qs[key]
+    new_m = unpacked.m[key] if unpacked.m is not None else None
+    dims = new_qs.shape
+    dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
+    layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
+    return PlanarQuantizedTensor(shape=dims, layout=layout)
+
+
 @gemm.override(AllOfType(Tensor, InferenceTensor))
 def gemm(
     a: AnyTensor,
@@ -133,6 +252,37 @@ def group_norm_affine_default(input, weight, bias, *, num_groups, eps):
     return F.group_norm(input, num_groups=num_groups, weight=weight, bias=bias, eps=eps)


+@index_copy_.override(Tensor, Tensor, Tensor)
+def index_copy__default(
+    inout: Union[Tensor, PrimitiveTensor],
+    dim: int,
+    index: Union[Tensor, PrimitiveTensor],
+    tensor: Union[Tensor, PrimitiveTensor],
+) -> Union[Tensor, PrimitiveTensor]:
+    unbox_tensor(inout).index_copy_(dim, unbox_tensor(index), unbox_tensor(tensor))
+    return inout
+
+
+@index_put_.override(AllOfType(Tensor, PrimitiveTensor))
+def index_put__default(
+    inout: Union[Tensor, PrimitiveTensor],
+    indices: Tuple[Union[Tensor, PrimitiveTensor]],
+    values: Union[Tensor, PrimitiveTensor],
+) -> Union[Tensor, PrimitiveTensor]:
+    indices = tuple(unbox_tensor(index) for index in indices)
+    unbox_tensor(inout).index_put_(indices, unbox_tensor(values))
+    return inout
+
+
+@index_select.override(Tensor, Tensor)
+def index_select_default(
+    tensor: Union[Tensor, PrimitiveTensor],
+    dim: int,
+    index: Union[Tensor, PrimitiveTensor],
+) -> Union[Tensor, PrimitiveTensor]:
+    return torch.index_select(unbox_tensor(tensor), dim, unbox_tensor(index))
+
+
 @interpolate.override(Tensor)
 def interpolate_default(
     input: Tensor,
@@ -187,15 +337,21 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
     lhs = unbox_tensor(lhs)
     rhs = unbox_tensor(rhs)
     if transpose_rhs:
-        rhs = rhs.T
-    return torch.matmul(lhs, rhs.to(lhs.dtype))
+        rhs = rhs.mT
+    rhs = rhs.to(lhs.dtype)
+
+    if len(lhs.shape) > 2 and len(rhs.shape) < 3:
+        bdims = lhs.shape[:-1]
+        lhs = torch.flatten(lhs, 0, -2)
+        mm = torch.matmul(lhs,
rhs) + return torch.unflatten(mm, 0, bdims) + + return torch.matmul(lhs, rhs) # Scaled dot product attention -@scaled_dot_product_attention.override( - Tensor, Tensor, Tensor, Optional[Tensor], auto_dequant=True -) -def scaled_dot_product_attention(q, k, v, a) -> Tensor: +@scaled_dot_product_attention.override(Tensor, Tensor, Tensor, None) +def scaled_dot_product_attention_torch(q, k, v, a, is_causal, scale) -> Tensor: q = unbox_tensor(q) k = unbox_tensor(k) v = unbox_tensor(v) @@ -204,18 +360,41 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor: # TODO: plumb dropout and is_causal through ops return torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=a, dropout_p=0.0, is_causal=False + q, k, v, attn_mask=a, dropout_p=0.0, is_causal=is_causal, scale=scale ) +@mean.override(Tensor) +def mean_default( + x: Tensor, dim: Union[int, List[int]], keepdim: bool, *, dtype: torch.dtype +) -> None: + return torch.mean(unbox_tensor(x), dim=dim, keepdim=keepdim, dtype=dtype) + + +@module_register_buffer.override(torch.nn.Module, Tensor) +def module_register_buffer_default( + module: torch.nn.Module, name: str, tensor: Union[Tensor, InferenceTensor] +) -> None: + return module.register_buffer(name, unbox_tensor(tensor)) + + +@repeat.override(Tensor) +def repeat_default(input: Union[Tensor, PrimitiveTensor], *sizes: List[int]) -> Tensor: + return unbox_tensor(input).repeat(*sizes) + + +@reshape.override(Tensor) +def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> Tensor: + return torch.reshape(unbox_tensor(input), shape) + + # RMS norm -@rms_norm.override(Tensor, Tensor) +@rms_norm.override(AllOfType(Tensor, InferenceTensor)) def rms_norm_default(x, weight, *, epsilon: float) -> Tensor: - x = unbox_tensor(x) - weight = unbox_tensor(weight) variance = x.pow(2).mean(-1, keepdim=True) - output = x * torch.rsqrt(variance + epsilon) - output = output * weight + output = x * elementwise(torch.rsqrt, variance + epsilon) + # The cast here is to match the hf implementation, affects numerics + output = elementwise(torch.mul, weight, to(output, weight.dtype)) return output @@ -234,6 +413,34 @@ def permute(tensor: Tensor, dims: List[int]): return torch.permute(torch_tensor, dims) +@softmax.override(Tensor) +def softmax_default( + tensor: Union[Tensor, PrimitiveTensor], + dim: Optional[int], + dtype: Optional[torch.dtype], +) -> Tensor: + return F.softmax(unbox_tensor(tensor), dim=dim, dtype=dtype) + + +@to.override(Tensor) +def to_default(tensor: Tensor, *args, **kwargs): + return unbox_tensor(tensor).to(*args, **kwargs) + + +@transfer_to_logical_device.override(Tensor) +def transfer_to_logical_device_default(tensor: Tensor, ordinal: int): + return iree.turbine.ops.iree.transfer_to_logical_device( + f"{ordinal}", unbox_tensor(tensor) + ) + + +@transpose.override(Tensor) +def transpose_default( + tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int +) -> Tensor: + return torch.transpose(unbox_tensor(tensor), dim0, dim1) + + # Sharded default impls (do nothing). 
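The reworked `matmul_default` above switches the transpose to `.mT` and flattens the leading batch dimensions of `lhs` when `rhs` is rank 2. A small shape check of that path using illustrative sizes (plain torch, no sharktank dependency):

import torch

lhs = torch.randn(2, 3, 4)
rhs = torch.randn(5, 4)  # the transpose_rhs=True case
bdims = lhs.shape[:-1]
flat = torch.flatten(lhs, 0, -2)                # (6, 4)
out = torch.unflatten(flat @ rhs.mT, 0, bdims)  # (2, 3, 5)
assert torch.allclose(out, torch.matmul(lhs, rhs.mT))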
@@ -245,3 +452,46 @@ def sharded_cat_unsharded(maybe_sharded):
 @sharded_sum.override(Tensor)
 def sharded_sum_unsharded(maybe_sharded):
     return unbox_tensor(maybe_sharded)
+
+
+@unflatten.override(Tensor)
+def unflatten_default(
+    input: Union[Tensor, PrimitiveTensor], dim: int, sizes: Tuple[int]
+) -> Tensor:
+    return torch.unflatten(unbox_tensor(input), dim, sizes)
+
+
+@unsqueeze.override(Tensor)
+def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tensor:
+    return torch.unsqueeze(unbox_tensor(tensor), dim)
+
+
+@view.override(Tensor)
+def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor:
+    return unbox_tensor(tensor).view(*shape)
+
+
+@view.override(QuantizedTensor)
+def view_QuantizedTensor(tensor: QuantizedTensor, shape):
+    unpacked = tensor.unpack()
+    if not isinstance(unpacked, BlockScaledI4Layout):
+        return NotImplemented
+    bs = 16
+    shape = list(shape)
+    new_d = unpacked._d.view(shape[:-1] + [shape[-1] // 32, 1])
+    qs_shape = shape[:-1] + [shape[-1] // 32, 16]
+    new_qs = unpacked._qs.view(qs_shape)
+    new_m = (
+        unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1])
+        if unpacked.m is not None
+        else None
+    )
+    layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m)
+    return PlanarQuantizedTensor(shape=shape, layout=layout)
+
+
+@view_as_complex.override(Tensor)
+def view_as_complex_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
+    return torch.view_as_complex(unbox_tensor(tensor))
+
+
+@view_as_real.override(Tensor)
+def view_as_real_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
+    return torch.view_as_real(unbox_tensor(tensor))
diff --git a/sharktank/sharktank/ops/qconv_impls.py b/sharktank/sharktank/ops/qconv_impls.py
index af1199976..add64a605 100644
--- a/sharktank/sharktank/ops/qconv_impls.py
+++ b/sharktank/sharktank/ops/qconv_impls.py
@@ -12,6 +12,7 @@
 import warnings

 import torch
+import torch.nn.functional as F

 from sharktank import kernels

@@ -31,7 +32,7 @@
 )


-def qconv2d_tensor_scaled_integer(
+def qconv2d_tensor_scaled(
     input: QuantizedTensor,
     weight: QuantizedTensor,
     bias: Optional[AnyTensor] = None,
@@ -59,12 +60,16 @@ def qconv2d_tensor_scaled_integer(
     input_layout: TensorScaledLayout = input.unpack()
     weight_layout: TensorScaledLayout = weight.unpack()

-    # Only handle integer quantizations.
+    # Handle integer and fp8 quantizations.
     if (
         input_layout.qs.dtype.is_floating_point
         or weight_layout.qs.dtype.is_floating_point
     ):
-        return NotImplemented
+        if (
+            input_layout.qs.dtype != torch.float8_e4m3fnuz
+            or weight_layout.qs.dtype != torch.float8_e4m3fnuz
+        ):
+            return NotImplemented

     # Bias is both optional and may either be quantized or fp.
     bias_qs = None
@@ -86,6 +91,8 @@ def qconv2d_tensor_scaled_integer(
     # Alias components (d=scale, qs=quantized samples, m=offset).
if accum_dtype is None: accum_dtype = torch.int32 + if weight_layout.qs.dtype.is_floating_point: + accum_dtype = torch.float32 input_d = input_layout.d input_dtype = input_layout.dtype input_qs = input_layout.qs @@ -113,8 +120,8 @@ def qconv2d_tensor_scaled_integer( padding = _expand_int_to_2_tuple(padding) dilation = _expand_int_to_2_tuple(dilation) extended_padding_list = [item for item in padding for _ in range(2)] - padded_input = _pad_last_2d(input_qs, extended_padding_list) - y_qs = _invoke_int32_conv2d( + padded_input = F.pad(input_qs, pad=extended_padding_list) + y_qs = _invoke_conv2d_kernel( padded_input, weight_qs, bias_qs.to(accum_dtype) if bias_qs is not None else None, @@ -145,7 +152,7 @@ def qconv2d_tensor_scaled_integer( weight_offset_fix = torch.sum( padded_input, dim=1, keepdim=True, dtype=accum_dtype ) - weight_offset_fix = _invoke_int32_pooling_sum( + weight_offset_fix = _invoke_pooling_sum_kernel( weight_offset_fix, [weight_qs.shape[2], weight_qs.shape[3]], stride, @@ -188,13 +195,11 @@ def qconv2d_tensor_scaled_integer( return y -conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled_integer) -conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)( - qconv2d_tensor_scaled_integer -) +conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled) +conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qconv2d_tensor_scaled) -def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype): +def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype): """Does a low level invocation of a conv2d integer kernel on an explicitly padded input. This presumes that the stride/padding/dilation have already been normalized @@ -233,7 +238,7 @@ def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype): return y_qs -def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dtype): +def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dtype): """Invokes either a custom integer pooling sum or the built-in fp avg_pool2d kernel on an explicitly padded input. """ @@ -254,27 +259,6 @@ def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dty return output -def _pad_last_2d(input_tensor, pad_width): - # pad_width should be in the format [pad_left, pad_right, pad_top, pad_bottom] - pad_left, pad_right, pad_top, pad_bottom = pad_width - batch_size, channels, height, width = input_tensor.shape - - # Create a new tensor with the desired padded size filled with zeros - padded_height = height + pad_top + pad_bottom - padded_width = width + pad_left + pad_right - padded_tensor = torch.zeros( - (batch_size, channels, padded_height, padded_width), - dtype=input_tensor.dtype, - device=input_tensor.device, - ) - - # Copy the values from the input tensor to the appropriate location in the padded tensor - padded_tensor[ - :, :, pad_top : pad_top + height, pad_left : pad_left + width - ] = input_tensor - return padded_tensor - - def _flatten_input_scale_offset_channels(d, m): """Flattens either a 4d or 0d scale/offset as [N, C, H, W] to 1D. 
diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index 0a381d613..b66d3be1d 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -28,7 +28,7 @@ from sharktank import kernels -def qlinear_tensor_scaled_integer( +def qlinear_tensor_scaled( x: QuantizedTensor, weight: QuantizedTensor, bias: Optional[AnyTensor], @@ -48,9 +48,15 @@ def qlinear_tensor_scaled_integer( x_layout: TensorScaledLayout = x.unpack() weight_layout: TensorScaledLayout = weight.unpack() - # Only handle integer quantizations. + # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - return NotImplemented + if x_layout.qs.dtype == torch.float8_e4m3fnuz: + # assume quark + return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to( + torch.float16 + ) + else: + return NotImplemented # Bias. quantized_bias_accum = False @@ -65,6 +71,8 @@ def qlinear_tensor_scaled_integer( # Alias components (d=scale, qs=quantized samples, m=offset) if accum_dtype is None: accum_dtype = torch.int32 + if weight_layout.qs.dtype.is_floating_point: + accum_dtype = torch.float32 x_d = x_layout.d x_dtype = x_layout.dtype x_qs = x_layout.qs @@ -86,7 +94,7 @@ def qlinear_tensor_scaled_integer( # TODO: Handle permutation that we have a kernel for. # Fall back to automatic fusion based on integer, high precision matmul. - y_qs = _invoke_int32_mmt(x_qs, weight_qs, accum_dtype=accum_dtype) + y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype) # Offset correction. By applying the offset correction in post, it is # set up to fuse with its consumer, which is already doing additional @@ -143,10 +151,8 @@ def qlinear_tensor_scaled_integer( # Overrload for both bias and no bias. -linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled_integer) -linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)( - qlinear_tensor_scaled_integer -) +linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled) +linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qlinear_tensor_scaled) def linear_quantized_weight( @@ -166,19 +172,30 @@ def linear_quantized_weight( linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight) -def _invoke_int32_mmt(lhs, rhs, *, accum_dtype): +def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype): if debugging.flags.use_custom_iree_kernels: # The custom kernel requires that the lhs and rhs be the same # rank. Broadcast the rhs to match. lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) + # If input to the kernel is 2D, expand the tensor to add the batch + # dimension. + if lhs_rank == 2: + lhs_size = [1] + list(lhs.shape) + lhs = lhs.unsqueeze(0).expand(lhs_size) + lhs_rank = len(lhs.shape) if rhs_rank < lhs_rank: assert (rhs_rank + 1) == lhs_rank rhs_size = [lhs.shape[0]] + list(rhs.shape) rhs = rhs.unsqueeze(0).expand(rhs_size) + rhs_rank = len(rhs.shape) y_qs = kernels.batch_matmul_transpose_b( lhs.to(accum_dtype), rhs.to(accum_dtype) ) + # Squeeze the batch dimension to maintain shape parity with other + # layers. + if len(y_qs.shape) > 2: + y_qs = y_qs.squeeze(0) else: # FP emulation. y_qs = torch.matmul( diff --git a/sharktank/sharktank/ops/shape.py b/sharktank/sharktank/ops/shape.py index 69683f84f..0616b97b6 100644 --- a/sharktank/sharktank/ops/shape.py +++ b/sharktank/sharktank/ops/shape.py @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Sequence +from typing import Sequence, Optional from ..types.tensors import AnyTensor @@ -53,3 +53,12 @@ def broadcast_dims( ranks = [len(shape) for shape in shaped_or_shape] broadcast_rank = max(ranks) return [dim + max(0, broadcast_rank - rank) for dim, rank in zip(dims, ranks)] + + +def unbroadcast_dim(dim: int, shapes: Sequence[Sequence[int]]) -> Optional[int]: + """Returns the dimension in `shapes[0]` such that it would correspond to `dim` + after broadcasting the shapes `shapes`.""" + ranks = [len(shape) for shape in shapes] + broadcast_rank = max(ranks) + res = dim - max(0, broadcast_rank - ranks[0]) + return None if res < 0 else res diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index ed639915e..07f466f5b 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -6,14 +6,17 @@ import torch from torch import Tensor -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Union, Any, Tuple import itertools from numbers import Number +import math +import functools from ..types import ( AnyTensor, DefaultPrimitiveTensor, InferenceTensor, + PrimitiveTensor, ReplicatedTensor, ShardedTensor, sharding, @@ -22,29 +25,67 @@ UnreducedTensor, ) from ..types.tensors import unbox_tensor -from ._registry import AllOfType +from ._registry import AllOfType, AllOfExprsVariadic, IsOfType from .signatures import * -from .shape import broadcast_dims +from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim +from ..utils import longest_equal_range @all_gather.override(SplitPrimitiveTensor) def all_gather_split( input: SplitPrimitiveTensor, *, dim: int | None ) -> ReplicatedTensor: - assert ( - dim is None - ), "gather dimension other than `input.shard_dim` is not supported." - # TODO: figure out how to avoid common sub-expression elimination to not - # merge all these into one. - # Even if we place each resulting shard inside of ReplicatedTensor on a - # distinct logical device with an explicit operation, CSE should still - # collapse them. - shards = [sharded_cat(input) for i in range(input.shard_count)] + dim = input.shard_dim if dim is None else dim + # For each device move the shards to it and do a concatenation. + # If we don't move first, common sub-expression elimination is free to collapse all + # concatenations into one and then copy to all devices, which is not what we want. + shards = [ + cat( + [ + shard if i == j else transfer_to_logical_device(shard, i) + for j, shard in enumerate(input.shards) + ], + dim=dim, + ) + for i in range(input.shard_count) + ] + return ReplicatedTensor(ts=shards) + + +@all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor)) +def all_reduce_split_or_unreduced( + input: Union[SplitPrimitiveTensor, UnreducedTensor], +) -> ReplicatedTensor: + # For each device move the shards to it and do a reduction. + # If we don't move first, common sub-expression elimination is free to collapse all + # reductions into one and then copy to all devices, which is not what we want. 
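+    # The reduction itself is a fold of elementwise torch.add over that shard list;
+    # only shards that live on other devices are transferred before being summed.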
+ shards = [ + functools.reduce( + lambda x, y: elementwise(torch.add, x, y), + [ + shard if i == j else transfer_to_logical_device(shard, i) + for j, shard in enumerate(input.shards) + ], + ) + for i in range(input.shard_count) + ] + return ReplicatedTensor(ts=shards) + + +@cat.override(AllOfType(ReplicatedTensor)) +def cat_replicated(tensors: Sequence[ReplicatedTensor], dim: int) -> ReplicatedTensor: + assert len(tensors) > 0 + shard_count = tensors[0].shard_count + assert all([t.shard_count == shard_count for t in tensors]) + + shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])] return ReplicatedTensor(ts=shards) @cat.override(AllOfType(SplitPrimitiveTensor)) -def cat_sharded(tensors: Sequence[SplitPrimitiveTensor], dim: int): +def cat_split( + tensors: Sequence[SplitPrimitiveTensor], dim: int +) -> SplitPrimitiveTensor: assert len(tensors) > 0 assert all( [ @@ -61,6 +102,7 @@ def cat_sharded(tensors: Sequence[SplitPrimitiveTensor], dim: int): return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) else: # TODO: implement efficient cat along split dim. + # This would probably result in doing the concatenation on one device. concatenated_unsharded = cat( [shard for t in tensors for shard in t.shards], dim ) @@ -88,7 +130,7 @@ def conv2d_all_split( input.is_replicated or input.shard_dim == 1 ), "Only sharding of input channel dimension is supported" assert ( - weight.shard_dim == 0 and bias.shard_dim == 0 + bias is None or weight.shard_dim == 0 and bias.shard_dim == 0 ), "Only sharding of output channel dimension is supported" # TODO: allow for implementation where we don't all-gather, but gather @@ -146,7 +188,7 @@ def conv2d_replicated_input_split_weight_and_bias( assert input.shard_count == weight.shard_count assert bias is None or weight.shard_count == bias.shard_count assert ( - weight.shard_dim == 0 and bias.shard_dim == 0 + bias is None or weight.shard_dim == 0 and bias.shard_dim == 0 ), "Only sharding of output channel dimension is supported" assert groups == 1 @@ -189,7 +231,8 @@ def conv2d_split_weight_and_bias( accum_dtype, ) -> SplitPrimitiveTensor: assert accum_dtype is None, "accum_dtype not supported" - assert weight.shard_count == bias.shard_count + if bias is not None: + assert weight.shard_count == bias.shard_count # Output channels dimension is split. 
if weight.shard_dim == 0 and groups == 1: @@ -225,49 +268,67 @@ def conv2d_split_weight_and_bias( @elementwise.override(ReplicatedTensor) -def replicated_elementwise_unary(operator, x: ReplicatedTensor): - partials = [operator(unbox_tensor(pt)) for pt in x.shards] +def replicated_elementwise_unary(operator, x: ReplicatedTensor, *args, **kwargs): + partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] return ReplicatedTensor(ts=partials) @elementwise.override(SplitPrimitiveTensor) -def split_elementwise_unary(operator, x: SplitPrimitiveTensor): - partials = [operator(unbox_tensor(pt)) for pt in x.shards] +def split_elementwise_unary(operator, x: SplitPrimitiveTensor, *args, **kwargs): + partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards] return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) +@elementwise.override(ReplicatedTensor, ReplicatedTensor) +def replicated_elementwise_binary( + operator, x: ReplicatedTensor, y: ReplicatedTensor, *args, **kwargs +): + assert x.shard_count == y.shard_count + shards = [ + operator(unbox_tensor(shard_x), unbox_tensor(shard_y), *args, **kwargs) + for shard_x, shard_y in zip(x.shards, y.shards) + ] + return ReplicatedTensor(ts=shards) + + @elementwise.override(SplitPrimitiveTensor, SplitPrimitiveTensor) def split_elementwise_binary( - operator, x: SplitPrimitiveTensor, y: SplitPrimitiveTensor + operator, x: SplitPrimitiveTensor, y: SplitPrimitiveTensor, *args, **kwargs ): assert x.shard_count == y.shard_count x_shard_dim, y_shard_dim = broadcast_dims([x.shard_dim, y.shard_dim], [x, y]) assert x_shard_dim == y_shard_dim pt_xs = [unbox_tensor(pt) for pt in x.shards] pt_ys = [unbox_tensor(pt) for pt in y.shards] - partials = [operator(pt_x, pt_y) for pt_x, pt_y in zip(pt_xs, pt_ys)] - return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + partials = [ + operator(pt_x, pt_y, *args, **kwargs) for pt_x, pt_y in zip(pt_xs, pt_ys) + ] + return SplitPrimitiveTensor( + shard_dim=x.shard_dim, + shape=torch.broadcast_shapes(x.shape, y.shape), + ts=partials, + ) @elementwise.override(SplitPrimitiveTensor, Number) def elementwise_binary_split_lhs_scalar_rhs( - operator, x: SplitPrimitiveTensor, y: Number + operator, x: SplitPrimitiveTensor, y: Number, *args, **kwargs ): pt_xs = [unbox_tensor(pt) for pt in x.shards] - partials = [operator(pt_x, y) for pt_x in pt_xs] + partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs] return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) @elementwise.override(SplitPrimitiveTensor, Tensor) def elementwise_binary_split_lhs_tensor_rhs( - operator, x: SplitPrimitiveTensor, y: Tensor + operator, x: SplitPrimitiveTensor, y: Tensor, *args, **kwargs ): - return elementwise(operator, x, replicate(y, count=x.shard_count)) + return elementwise(operator, x, reshard_like(y, like=x), *args, **kwargs) @elementwise.override(ReplicatedTensor, SplitPrimitiveTensor) def elementwise_binary_replicated_lhs_sharder_rhs( - operator, x: ReplicatedTensor, y: SplitPrimitiveTensor + operator, x: ReplicatedTensor, y: SplitPrimitiveTensor, *args, **kwargs ): if x.shard_count != y.shard_count: raise ValueError( @@ -276,20 +337,76 @@ def elementwise_binary_replicated_lhs_sharder_rhs( # A replicated tensor can be split with no cost. # It is natural to propagate the split instead of the replication. 
x_sharded = reshard_like(x, like=y) - return elementwise(operator, x_sharded, y) + return elementwise(operator, x_sharded, y, *args, **kwargs) @elementwise.override(SplitPrimitiveTensor, ReplicatedTensor) def elementwise_binary_split_lhs_replicated_rhs( - operator, x: SplitPrimitiveTensor, y: ReplicatedTensor + operator, x: SplitPrimitiveTensor, y: ReplicatedTensor, *args, **kwargs ): assert len(y.shape) > 0, "0-rank not supported" if x.shard_count != y.shard_count: raise ValueError( f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" ) + + shard_dim_in_res = broadcast_dim(x.shard_dim, [x.shape, y.shape]) + shard_dim_in_y = unbroadcast_dim(shard_dim_in_res, [y.shape, x.shape]) + is_shard_dim_broadcasted_in_y = ( + shard_dim_in_y is None or y.shape[shard_dim_in_y] == 1 + ) + if is_shard_dim_broadcasted_in_y: + shards = [ + elementwise(operator, x_shard, y_shard) + for x_shard, y_shard in zip(x.shards, y.shards) + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim_in_res) + y_sharded = reshard_like(y, like=x) - return elementwise(operator, x, y_sharded) + return elementwise(operator, x, y_sharded, *args, **kwargs) + + +@elementwise.override(ReplicatedTensor, UnreducedTensor) +def elementwise_binary_replicated_lhs_unreduced_rhs( + operator, x: ReplicatedTensor, y: UnreducedTensor, *args, **kwargs +): + if x.shard_count != y.shard_count: + raise ValueError( + f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" + ) + y_replicated = reshard_like(y, like=x) + return elementwise(operator, x, y_replicated, *args, **kwargs) + + +@elementwise.override(ReplicatedTensor, Tensor) +def elementwise_binary_replicated_lhs_unsharded_rhs( + operator, x: ReplicatedTensor, y: Tensor, *args, **kwargs +): + y_replicated = reshard_like(y, like=x) + return elementwise(operator, x, y_replicated, *args, **kwargs) + + +@elementwise.override(Tensor, ReplicatedTensor) +def elementwise_binary_replicated_lhs_unsharded_rhs( + operator, x: Tensor, y: ReplicatedTensor, *args, **kwargs +): + x_replicated = reshard_like(x, like=y) + return elementwise(operator, x_replicated, y, *args, **kwargs) + + +# Embedding Lookup +@embedding_lookup.override(ReplicatedTensor, ReplicatedTensor) +def embedding_lookup_default( + input: ReplicatedTensor, embedding_matrix: ReplicatedTensor, dtype: torch.dtype +): + assert input.shard_count == embedding_matrix.shard_count + shards = [ + embedding_lookup(input_shard, embedding_matrix_shard, dtype) + for input_shard, embedding_matrix_shard in zip( + input.shards, embedding_matrix.shards + ) + ] + return ReplicatedTensor(ts=shards) @equal.override(ReplicatedTensor) @@ -302,6 +419,77 @@ def equal_split(a: SplitPrimitiveTensor, b: AnyTensor) -> bool: return a.is_deep_equal(b) +@expand.override(SplitPrimitiveTensor) +def expand_split( + tensor: SplitPrimitiveTensor, shape: List[int] +) -> SplitPrimitiveTensor: + assert len(shape) == len(tensor.shape) + expanded_dims = [ + i + for i, (old_dim, new_dim) in enumerate(zip(tensor.shape, shape)) + if old_dim == 1 and new_dim != 1 + ] + assert ( + tensor.shard_dim not in expanded_dims + ), "Expanding a split dimension is not supported" + + def set_element(l: List, idx: int, el: Any) -> List: + l[idx] = el + return l + + shards = [ + expand( + shard, + set_element(list(shape), tensor.shard_dim, shard.shape[tensor.shard_dim]), + ) + for shard in tensor.shards + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@flatten.override(ReplicatedTensor) +def flatten_replicated( + 
input: ReplicatedTensor, start_dim: int, end_dim: int +) -> ReplicatedTensor: + shards = [shard.flatten(start_dim, end_dim) for shard in input.shards] + return ReplicatedTensor(ts=shards) + + +@flatten.override(SplitPrimitiveTensor) +def flatten_split( + input: SplitPrimitiveTensor, start_dim: int, end_dim: int +) -> SplitPrimitiveTensor: + end_dim_resolved = len(input.shape) - 1 if end_dim == -1 else end_dim + assert input.shard_dim <= start_dim or end_dim_resolved < input.shard_dim, ( + "Flattening of a sharded dimension that is not the leading dimension in the" + " flattening dimension range is not supported. This would result in a" + " block-cyclic sharding which is not implemented." + ) + assert ( + input.shard_dim != start_dim + or input.shape[input.shard_dim] % input.shard_count == 0 + ), "If the leading flattening dimension is the split dimension, its size must be divisible by the shard count." + shards = [shard.flatten(start_dim, end_dim) for shard in input.shards] + shard_dim = ( + input.shard_dim + if input.shard_dim <= start_dim + else input.shard_dim - (end_dim_resolved - start_dim) + ) + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + +@gather.override(ReplicatedTensor, ReplicatedTensor) +def gather_replicated( + input: ReplicatedTensor, dim: int, index: ReplicatedTensor +) -> Tensor: + assert input.shard_count == index.shard_count + shards = [ + gather(input_shard, dim, index_shard) + for input_shard, index_shard in zip(input.shards, index.shards) + ] + return ReplicatedTensor(ts=shards) + + @group_norm_affine.override( SplitPrimitiveTensor, SplitPrimitiveTensor, SplitPrimitiveTensor ) @@ -322,6 +510,78 @@ def shareded_group_norm_affine(input, weight, bias, *, num_groups, eps): return SplitPrimitiveTensor(shard_dim=1, ts=result_shards) +@index_copy_.override(SplitPrimitiveTensor, ReplicatedTensor, SplitPrimitiveTensor) +def index_copy__split_replicated_split( + inout: SplitPrimitiveTensor, + dim: int, + index: ReplicatedTensor, + tensor: SplitPrimitiveTensor, +) -> SplitPrimitiveTensor: + assert ( + inout.shard_count == index.shard_count + and inout.shard_count == tensor.shard_count + ) + assert inout.shard_dim == tensor.shard_dim + assert inout.shard_dim != dim + for inout_shard, index_shard, tensor_shard in zip( + inout.shards, index.shards, tensor.shards + ): + index_copy_(inout_shard, dim, index_shard, tensor_shard) + return inout + + +@index_put_.override( + AllOfExprsVariadic( + IsOfType(SplitPrimitiveTensor), + IsOfType(SplitPrimitiveTensor), + IsOfType(Tensor, PrimitiveTensor, ReplicatedTensor), + ) +) +def index_put__split( + inout: SplitPrimitiveTensor, + indices: Tuple[Union[Tensor, PrimitiveTensor, ReplicatedTensor]], + values: SplitPrimitiveTensor, +) -> SplitPrimitiveTensor: + # TODO: verify that the values split dimension is not being indexed or implement + # this case. 
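+    # Replicate every index tensor so each shard of `inout` can be updated in place
+    # with its matching shard of `values`; this relies on the split dimension not
+    # being among the indexed dimensions (see the TODO above).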
+ indices = [replicate(idx, count=inout.shard_count) for idx in indices] + for i, shard in enumerate(inout.shards): + shard_indices = [idx.shards[i] for idx in indices] + shard.index_put_(shard_indices, values.shards[i]) + return inout + + +@index_select.override(ReplicatedTensor, ReplicatedTensor) +def index_select_replicated( + tensor: ReplicatedTensor, + dim: int, + index: ReplicatedTensor, +) -> ReplicatedTensor: + assert tensor.shard_count == index.shard_count + shards = [ + index_select(tensor_shard, dim, index_shard) + for tensor_shard, index_shard in zip(tensor.shards, index.shards) + ] + return ReplicatedTensor(ts=shards) + + +@index_select.override(SplitPrimitiveTensor, ReplicatedTensor) +def index_select_split_replicated( + tensor: SplitPrimitiveTensor, + dim: int, + index: ReplicatedTensor, +) -> ReplicatedTensor: + assert tensor.shard_count == index.shard_count + assert ( + dim != tensor.shard_dim + ), "Indexing along the split dimension is not supported." + shards = [ + index_select(tensor_shard, dim, index_shard) + for tensor_shard, index_shard in zip(tensor.shards, index.shards) + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + @interpolate.override(ReplicatedTensor) def interpolate_replicated( input: ReplicatedTensor, @@ -416,13 +676,14 @@ def matmul_replicated_lhs_split_rhs( lhs: ReplicatedTensor, rhs: SplitPrimitiveTensor, *, transpose_rhs: bool ) -> SplitPrimitiveTensor | UnreducedTensor: assert lhs.shard_count == rhs.shard_count + assert len(rhs.shape) == 2 if transpose_rhs: return matmul(lhs, rhs.T) rhs_reduction_dim = 1 if rhs_reduction_dim != rhs.shard_dim: - lhs_reduction_dimension = 0 + lhs_reduction_dimension = len(lhs.shape) - 1 lhs_split = reshard_split( lhs, dim=lhs_reduction_dimension, count=lhs.shard_count ) @@ -496,30 +757,108 @@ def matmul_split( f"Cannot matmul split tensors of different shard_count: " f"({lhs.shard_count} vs {rhs.shard_count})" ) + if transpose_rhs: + return matmul(lhs, rhs.T) lhs_reduction_dim = len(lhs.shape) - 1 - rhs_reduction_dim = 1 if transpose_rhs else 0 + rhs_reduction_dim = len(rhs.shape) - 2 if len(rhs.shape) > 1 else len(rhs.shape) - 1 # The reduction dimension is split on both tensors. if lhs_reduction_dim == lhs.shard_dim and rhs_reduction_dim == rhs.shard_dim: partials = [ - matmul(partial_lhs, partial_rhs, transpose_rhs=transpose_rhs) + matmul(partial_lhs, partial_rhs) for partial_lhs, partial_rhs in zip(lhs.shards, rhs.shards) ] return UnreducedTensor(ts=partials) + is_batched_matmul = len(lhs.shape) > 2 or len(rhs.shape) > 2 + if ( + is_batched_matmul + and len(lhs.shape) == len(rhs.shape) + and lhs.shard_dim == rhs.shard_dim + ): + # The same batch dim is sharded for both arguments. + shards = [ + matmul(lhs_shard, rhs_shard) + for lhs_shard, rhs_shard in zip(lhs.shards, rhs.shards) + ] + return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + + # -1 for missing parallel dim. + lhs_parallel_dim = len(lhs.shape) - 2 + rhs_parallel_dim = len(rhs.shape) - 1 if len(rhs.shape) > 1 else -1 + # One parallel dimension is split for each tensor. - if lhs_reduction_dim != lhs.shard_dim and rhs_reduction_dim != rhs.shard_dim: - if transpose_rhs: - rhs = rhs.T + # Or lhs batch dim and rhs parallel dim are split. + if lhs.shard_dim <= lhs_parallel_dim and rhs_parallel_dim == rhs.shard_dim: # We gather along the rhs shard dim. # It is more natural to preserve the sharding axis of the input. 
- shards = [sharded_cat(matmul(lhs_shard, rhs)) for lhs_shard in lhs.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim) + # TODO: This assumes non-peered memory. We prepare the operands to be + # available on the required devices. + # We need to distinguish based on some config. + replicated_rhs = replicate(rhs, count=lhs.shard_count) + return matmul(lhs, replicated_rhs) assert False, "Sharding configuration not supported" +# Scaled dot product attention +@scaled_dot_product_attention.override( + SplitPrimitiveTensor, + SplitPrimitiveTensor, + SplitPrimitiveTensor, + Optional[ReplicatedTensor], +) +def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor: + if q.shard_count != k.shard_count or q.shard_count != v.shard_count: + raise ValueError("Incompatible number of shards for qkv") + + if a and q.shard_count != a.shard_count: + raise ValueError( + f"Incompatible number of shards for a ({a.shard_count}) should be ({q.shard_count})" + ) + + if q.shard_dim != k.shard_dim or q.shard_dim != v.shard_dim: + raise ValueError("Incompatible shard dim across qkv") + + if q.shard_dim > len(q.shards[0].shape) - 2: + raise ValueError("Sharding must occur as batch dimension") + + a_shards = [None] * q.shard_count + if a is not None: + a_shards = a.shards + + output_shards = [] + for q_s, k_s, v_s, a_s in zip(q.shards, k.shards, v.shards, a_shards): + o_s = scaled_dot_product_attention( + q_s, k_s, v_s, a_s, is_causal=is_causal, scale=scale + ) + output_shards.append(o_s) + + return SplitPrimitiveTensor(ts=output_shards, shard_dim=q.shard_dim) + + +@mean.override(ReplicatedTensor) +def mean_replicated( + x: ReplicatedTensor, + dim: Union[int, List[int]], + keepdim: bool, + *, + dtype: torch.dtype, +) -> None: + shards = [mean(shard, dim=dim, keepdim=keepdim, dtype=dtype) for shard in x.shards] + return ReplicatedTensor(ts=shards) + + +@module_register_buffer.override(torch.nn.Module, ShardedTensor) +def module_register_buffer_sharded( + module: torch.nn.Module, name: str, tensor: ShardedTensor +) -> None: + for i, shard in enumerate(tensor.shards): + module_register_buffer(module, f"{name}__shard__{i}", shard) + setattr(module, name, tensor) + + @permute.override(SplitPrimitiveTensor) def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] @@ -528,23 +867,60 @@ def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]): @permute.override(ReplicatedTensor) -def permute_split(tensor: ReplicatedTensor, dims: List[int]): +def permute_replicated(tensor: ReplicatedTensor, dims: List[int]): permuted_shards = [permute(shard, dims) for shard in tensor.shards] return ReplicatedTensor(ts=permuted_shards) +@repeat.override(ReplicatedTensor) +def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor: + shards = [repeat(shard, *sizes) for shard in input.shards] + return ReplicatedTensor(ts=shards) + + @replicate.override(ReplicatedTensor) def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor: if input.shard_count != count: raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") - assert input.shard_count == count return input +@replicate.override(SplitPrimitiveTensor) +def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor: + if input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") + return all_gather(input) + + 
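For the new batched branch of `matmul_split` above, where both operands are split along the same batch dimension, the per-shard matmuls are independent and concatenating their results reproduces the unsharded product. A self-contained check with plain torch tensors (shard count and shapes are arbitrary, for illustration only):

import torch

shard_count = 2
a = torch.randn(4, 5, 6)  # batch dim 0 plays the role of the split dimension
b = torch.randn(4, 6, 7)

a_shards = a.chunk(shard_count, dim=0)
b_shards = b.chunk(shard_count, dim=0)

# Matching batch shards multiply independently; concatenation restores the full result.
sharded = torch.cat([a_s @ b_s for a_s, b_s in zip(a_shards, b_shards)], dim=0)
assert torch.allclose(sharded, a @ b)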
+@replicate.override(UnreducedTensor) +def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor: + if input.shard_count != count: + raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})") + return all_reduce(input) + + @replicate.override(Tensor) def replicate_unsharded(input, *, count: int) -> ReplicatedTensor: torch_input = unbox_tensor(input) - return ReplicatedTensor(ts=torch_input, shard_count=count) + # If we have a torch input replicating we can assume we need to transfer: + torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)] + return ReplicatedTensor(ts=torch_inputs) + + +@reshape.override(SplitPrimitiveTensor) +def reshape_split( + tensor: SplitPrimitiveTensor, shape: List[int] +) -> SplitPrimitiveTensor: + if _reshape_get_single_split_dim(tensor.shape, shape) is not None: + return view(tensor, shape) + + flatten_dim_range = _reshape_get_flatten_dim_range(tensor.shape, shape) + if flatten_dim_range is not None: + return flatten(tensor, flatten_dim_range[0], flatten_dim_range[1] - 1) + + raise ValueError( + f"Unsupported reshaping of sharded split tensor of shape {tensor.shape} to shape {shape}" + ) @reshard.override(Tensor, sharding.Split) @@ -572,7 +948,13 @@ def make_value(input: Theta | InferenceTensor, spec) -> dict | InferenceTensor: result.name = input.name return result - return Theta({k: make_value(input(k), spec[k]) for k in input.keys}) + return Theta( + { + k: make_value(input(k), spec[k]) + for k in input.keys + if not isinstance(spec[k], sharding.Ignore) + } + ) @reshard.override(Theta, sharding.ThetaLayerSharding) @@ -691,13 +1073,23 @@ def reshard_like_split_to_split( return tensor -# Sharded sum. +@reshard_like.override(UnreducedTensor, ReplicatedTensor) +def reshard_like_unreduced_to_replicated( + tensor: UnreducedTensor, like: ReplicatedTensor +) -> ReplicatedTensor: + return replicate(tensor, count=like.shard_count) @sharded_cat.override(SplitPrimitiveTensor) -def sharded_cat_unsharded(maybe_sharded: SplitPrimitiveTensor): - shard_ts = [t.as_torch() for t in maybe_sharded.shards] - return torch.cat(shard_ts, dim=maybe_sharded.shard_dim) +def sharded_cat_unsharded(tensor: SplitPrimitiveTensor): + shard_ts = [ + transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else shard.as_torch() + for i, shard in enumerate(tensor.shards) + ] + return torch.cat(shard_ts, dim=tensor.shard_dim) + + +# Sharded sum. def _sharded_sum_sharded(tensor: ShardedTensor) -> Tensor: @@ -708,16 +1100,67 @@ def _sharded_sum_sharded(tensor: ShardedTensor) -> Tensor: @sharded_sum.override(SplitPrimitiveTensor) -def sharded_sum_split(maybe_sharded: SplitPrimitiveTensor): +def sharded_sum_split(maybe_sharded: SplitPrimitiveTensor) -> Tensor: # TODO: Should implement as an all reduce. return _sharded_sum_sharded(maybe_sharded) @sharded_sum.override(UnreducedTensor) -def sharded_sum_unreduced(maybe_sharded: UnreducedTensor): +def sharded_sum_unreduced(maybe_sharded: UnreducedTensor) -> Tensor: return _sharded_sum_sharded(maybe_sharded) +@softmax.override(SplitPrimitiveTensor) +def softmax_split( + tensor: SplitPrimitiveTensor, dim: Optional[int], dtype: Optional[torch.dtype] +) -> Tensor: + dim = dim if dim is None or dim >= 0 else len(tensor.shape) + dim + assert ( + dim is not None and dim != tensor.shard_dim + ), "Softmax along split dimension is not supported." 
+ shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards] + return SplitPrimitiveTensor( + ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape + ) + + +@to.override(ReplicatedTensor) +def to_replicated(tensor: ReplicatedTensor, *args, **kwargs): + shards = [to(shard, *args, **kwargs) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) + + +@to.override(SplitPrimitiveTensor) +def to_split(tensor: SplitPrimitiveTensor, *args, **kwargs): + shards = [to(shard, *args, **kwargs) for shard in tensor.shards] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@transpose.override(SplitPrimitiveTensor) +def transpose_split( + tensor: SplitPrimitiveTensor, dim0: int, dim1: int +) -> SplitPrimitiveTensor: + shards = [transpose(shard, dim0, dim1) for shard in tensor.shards] + shard_dim = tensor.shard_dim + if shard_dim == dim0: + shard_dim = dim1 + elif shard_dim == dim1: + shard_dim = dim0 + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + +@unflatten.override(SplitPrimitiveTensor) +def unflatten_split( + input: SplitPrimitiveTensor, dim: int, sizes: Tuple[int] +) -> SplitPrimitiveTensor: + assert dim != input.shard_dim, "Unflattening the split dimension is not supported." + shards = [unflatten(shard, dim, sizes) for shard in input.shards] + shard_dim = input.shard_dim + if dim < shard_dim: + shard_dim += len(sizes) - 1 + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + @unshard.override(ReplicatedTensor) def unshard_replicated(input: ReplicatedTensor) -> Tensor: return input.shards[0] @@ -728,6 +1171,160 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor: return sharded_cat(input) +@unshard.override(UnreducedTensor) +def unshard_unreduced(input: UnreducedTensor) -> Tensor: + shards = input.shards + shards = [ + shard if i == 0 else transfer_to_logical_device(shard, 0) + for i, shard in enumerate(shards) + ] + return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards) + + @unshard.override(Tensor) def unshard_unsharded(input: Tensor) -> Tensor: return input + + +def _reshape_get_flatten_dim_range( + from_shape: List[int], to_shape: List[int] +) -> Optional[Tuple[int, int]]: + """If a reshape would flatten a range of dimensions return that index range [begin, end). 
+ If the reshape is not of that kind return `None`.""" + flatten_start_len = _reshape_get_single_split_dim(to_shape, from_shape) + if flatten_start_len is None: + return None + start, length = flatten_start_len + return start, start + length + + +def _reshape_infer_dynamic_dim( + shape1: List[int], shape2: List[int] +) -> Tuple[List[int], List[int]]: + assert ( + len([d for d in list(shape1) + list(shape2) if d < 0]) <= 1 + ), "Only one dynamic dimension is allowed" + shape1_dynamic_dims = [i for i, d in enumerate(shape1) if d <= 0] + if len(shape1_dynamic_dims) > 0: + s2, s1 = _reshape_infer_dynamic_dim(shape2, shape1) + return s1, s2 + + shape2_dynamic_dims = [i for i, d in enumerate(shape2) if d <= 0] + if len(shape2_dynamic_dims) == 0: + return shape1, shape2 + shape2_dynamic_dim = shape2_dynamic_dims[0] + shape1_size = math.prod(shape1) + shape2_size_without_dynamic_dim = math.prod(d for d in shape2 if d > 0) + shape2_res = list(shape2) + assert shape1_size % shape2_size_without_dynamic_dim == 0 + shape2_res[shape2_dynamic_dim] = shape1_size // shape2_size_without_dynamic_dim + assert shape2_res[shape2_dynamic_dim] > 0 + return shape1, shape2_res + + +def _reshape_get_single_split_dim( + from_shape: List[int], to_shape: List[int] +) -> Optional[Tuple[int, int]]: + """If a reshape would split a single dimension, return its index and the length of the new dimensions. + If the reshape is not of that kind return `None`. + E.g. + _reshape_get_single_split_dim(from_shape=(2, 12, 5), to_shape=(2, 3, 4, 5)) + results in + (1, 2)""" + from_shape, to_shape = _reshape_infer_dynamic_dim(from_shape, to_shape) + + if len(to_shape) < len(from_shape): + return None + i = longest_equal_range(from_shape, to_shape) + split_dims_length = len(to_shape) - len(from_shape) + 1 + if i == len(from_shape): + return ( + i, + split_dims_length, + ) + j = len(to_shape) - longest_equal_range(reversed(from_shape), reversed(to_shape)) + assert i < j + expected_split_dim_size = math.prod(to_shape[i:j]) + if expected_split_dim_size == 1: + # 1's were inserted. 
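+        # Example: from_shape=(2, 5), to_shape=(2, 1, 1, 5) takes this branch and
+        # returns (1, 3), i.e. dim 1 of from_shape maps onto to_shape[1:4].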
+ return ( + i, + split_dims_length, + ) + if expected_split_dim_size != from_shape[i]: + return None + return ( + i, + split_dims_length, + ) + + +@unsqueeze.override(SplitPrimitiveTensor) +def unsqueeze_split(tensor: SplitPrimitiveTensor, dim: int) -> SplitPrimitiveTensor: + shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards] + shard_dim = tensor.shard_dim + dim_resolved = dim if dim >= 0 else dim + len(tensor.shape) + 1 + if shard_dim >= dim_resolved: + shard_dim += 1 + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + +@unsqueeze.override(ReplicatedTensor) +def unsqueeze_replicated(tensor: ReplicatedTensor, dim: int) -> SplitPrimitiveTensor: + shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) + + +@view.override(SplitPrimitiveTensor) +def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitiveTensor: + view_split_range = _reshape_get_single_split_dim(tensor.shape, shape) + if view_split_range is None: + raise ValueError( + "Only taking a tensor view where splitting a single dimension is supported" + ) + view_split_dim = view_split_range[0] + + if view_split_dim == tensor.shard_dim: + if tensor.shape[view_split_dim] % tensor.shard_count != 0: + raise ValueError( + "Only splitting a dimension that is multiple of the shard count is supported" + ) + if shape[view_split_dim] % tensor.shard_count != 0: + raise ValueError( + "The resulting leading splitting dimension must be multiple of the shard count" + ) + + shard_dim = tensor.shard_dim + if shard_dim > view_split_dim: + new_dims_count = len(shape) - len(tensor.shape) + shard_dim += new_dims_count + new_shard_shape = list(shape) + new_shard_shape[shard_dim] //= tensor.shard_count + shards = [view(shard, new_shard_shape) for shard in tensor.shards] + res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) + assert math.prod(res.shape) == math.prod(tensor.shape) + return res + + +@view_as_complex.override(SplitPrimitiveTensor) +def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: + shards = [view_as_complex(shard) for shard in tensor.shards] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@view_as_complex.override(ReplicatedTensor) +def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: + shards = [view_as_complex(shard) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) + + +@view_as_real.override(SplitPrimitiveTensor) +def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: + shards = [view_as_real(shard) for shard in tensor.shards] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@view_as_real.override(ReplicatedTensor) +def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: + shards = [view_as_real(shard) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 6350d3e0b..408f00ec7 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -11,34 +11,58 @@ import torch import numbers from torch import Tensor, dtype -from ..types import AnyTensor, ShardedTensor, Theta, sharding +from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor from numbers import Number +import math from ._registry import * __all__ = [ "all_gather", + "all_reduce", "cat", "conv2d", + "einsum_2args", "elementwise", 
"embedding_lookup", "equal", + "expand", + "flatten", + "gather", + "gelu_tanh_approximation", + "get_index", "gemm", "group_norm_affine", "layer_norm", + "index_copy_", + "index_put_", + "index_select", "interpolate", "linear", "matmul", + "mean", + "module_register_buffer", "permute", "rms_norm", + "repeat", "replicate", + "reshape", "reshard", "reshard_split", "reshard_like", "scaled_dot_product_attention", "sharded_cat", "sharded_sum", + "softmax", + "to", + "transfer_to_logical_device", + "transpose", + "unflatten", "unshard", + "unsqueeze", + "view", + "view_as_complex", + "view_as_real", ] IntOrSequenceInt = Union[int, Sequence[int]] @@ -46,6 +70,7 @@ @overridable def all_gather(maybe_sharded: AnyTensor, *, dim: int | None = None) -> AnyTensor: + "Gather/concatenate on all devices along dimension `dim`." ... @@ -62,6 +87,23 @@ def _all_gather_trampoline( d.fail(tensors) +@overridable +def all_reduce(tensor: AnyTensor) -> AnyTensor: + "Reduce on all devices." + ... + + +@all_reduce.trampoline +def _all_reduce_trampoline(d: SignatureDispatcher, tensor: AnyTensor): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor: ... @@ -132,16 +174,52 @@ def _conv2d_trampoline( @overridable -def elementwise(operator, *args: AnyTensor) -> AnyTensor: +def einsum_2args( + input0: AnyTensor, + input1: AnyTensor, + einsum_str: str, + *, + accum_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Executes a given Einstein summation notation string on the provided tensors. + + Equivalent to: + ``` + y = torch.einsum(einsum_str, input0, input1) + ``` + """ + raise NotImplementedError + + +@einsum_2args.trampoline +def _einsum_trampoline( + d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str +): + tensors = (input0, input1) + for override in d.find_overrides(tensors): + result = override(input0, input1, einsum_str) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def elementwise(operator, *args, **kwargs) -> AnyTensor: """Applies an elementwise operator against arguments.""" raise NotImplementedError @elementwise.trampoline -def _elementwise_trampoline(d: SignatureDispatcher, operator, *args: AnyTensor): - tensors = args +def _elementwise_trampoline(d: SignatureDispatcher, operator, *args, **kwargs): + tensors = [] + for a in args: + if isinstance(a, (Tensor, InferenceTensor)): + tensors.append(a) + else: + break for override in d.find_overrides(tensors): - result = override(operator, *args) + result = override(operator, *args, **kwargs) if result is not NotImplemented: return override, result else: @@ -212,6 +290,109 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): d.fail(tensors) +@overridable +def expand(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.expand""" + ... + + +@expand.trampoline +def _expand_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, shape: List[int] +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, shape) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def get_index( + tensor: AnyTensor, + key: slice, +) -> torch.Tensor: + """Indexes the tensor using the key. 
+ + Equivalent to: + ``` + out = tensor[key] + ``` + """ + raise NotImplementedError + + +@get_index.trampoline +def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, key) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor: + """See torch.flatten""" + ... + + +@flatten.trampoline +def _flatten_trampoline( + d: SignatureDispatcher, input: AnyTensor, start_dim: int = 0, end_dim: int = -1 +) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, start_dim, end_dim) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + +@overridable +def gather(input: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor: + """See torch.gather""" + ... + + +@gather.trampoline +def _gather_trampoline( + d: SignatureDispatcher, input: AnyTensor, dim: int, index: AnyTensor +) -> AnyTensor: + dispatch_args = ( + input, + index, + ) + for override in d.find_overrides(dispatch_args): + result = override(input, dim, index) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + +def gelu_tanh_approximation(input: AnyTensor) -> AnyTensor: + """Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + Approximation with tanh""" + return ( + 0.5 + * input + * ( + 1.0 + + elementwise( + torch.tanh, + math.sqrt(2.0 / math.pi) + * (input + 0.044715 * elementwise(torch.pow, input, 3.0)), + ) + ) + ) + + @overridable def gemm( a: AnyTensor, @@ -278,6 +459,75 @@ def _group_norm_affine_trampoline( d.fail(tensors) +@overridable +def index_copy_( + inout: AnyTensor, dim: int, index: AnyTensor, tensor: AnyTensor +) -> AnyTensor: + """See torch.Tensor.index_copy_""" + ... + + +@index_copy_.trampoline +def _index_copy__trampoline( + d: SignatureDispatcher, + inout: AnyTensor, + dim: int, + index: AnyTensor, + tensor: AnyTensor, +) -> AnyTensor: + tensors = (inout, index, tensor) + for override in d.find_overrides(tensors): + result = override(inout, dim, index, tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def index_put_( + inout: AnyTensor, indices: Tuple[AnyTensor], values: AnyTensor +) -> AnyTensor: + """See torch.Tensor.index_put_""" + ... + + +@index_put_.trampoline +def _index_put__trampoline( + d: SignatureDispatcher, + inout: AnyTensor, + indices: Tuple[AnyTensor], + values: AnyTensor, +) -> AnyTensor: + # We change the order for the variadic indices to be last. + tensors = (inout, values, *indices) + for override in d.find_overrides(tensors): + result = override(inout, indices, values) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def index_select(tensor: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor: + """See torch.Tensor.index_select""" + ... 
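The formula in `gelu_tanh_approximation` above is the standard tanh approximation of GELU, so it should agree with PyTorch's built-in variant. A quick standalone sanity check (assumes a PyTorch version, roughly 1.12+, where the `approximate="tanh"` keyword exists):

import math

import torch
import torch.nn.functional as F

x = torch.randn(1024)

# Same expression as gelu_tanh_approximation, written directly with torch ops.
manual = 0.5 * x * (
    1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3.0)))
)
assert torch.allclose(manual, F.gelu(x, approximate="tanh"), atol=1e-6)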
+ + +@index_select.trampoline +def _index_select_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, dim: int, index: AnyTensor +) -> AnyTensor: + tensors = (tensor, index) + for override in d.find_overrides(tensors): + result = override(tensor, dim, index) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def interpolate( input: AnyTensor, @@ -416,7 +666,6 @@ def _matmul_trampoline( d: SignatureDispatcher, lhs, rhs, *, transpose_rhs: bool = False ): tensors = (lhs, rhs) - assert isinstance(rhs, numbers.Number) or len(rhs.shape) == 2 for override in d.find_overrides(tensors): result = override(lhs, rhs, transpose_rhs=transpose_rhs) if result is not NotImplemented: @@ -444,6 +693,57 @@ def _permute_trampoline(d: SignatureDispatcher, tensor: AnyTensor, dims: List[in d.fail(tensors) +@overridable +def mean( + x: AnyTensor, + dim: Union[int, List[int]], + keepdim: bool = False, + *, + dtype: torch.dtype = None, +) -> AnyTensor: + """See torch.mean""" + raise NotImplementedError + + +@mean.trampoline +def _mean_trampoline( + d: SignatureDispatcher, + x: AnyTensor, + dim: Union[int, List[int]], + keepdim: bool = False, + *, + dtype: torch.dtype = None, +) -> AnyTensor: + tensors = (x,) + for override in d.find_overrides(tensors): + result = override(x, dim, keepdim, dtype=dtype) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def module_register_buffer( + module: torch.nn.Module, name: str, tensor: AnyTensor +) -> None: + """Register the tensor into the module. See torch.nn.Module.register_buffer.""" + ... + + +@module_register_buffer.trampoline +def _module_register_buffer_trampoline( + d: SignatureDispatcher, module: torch.nn.Module, name: str, tensor: AnyTensor +) -> None: + args = (module, tensor) + for override in d.find_overrides(args): + result = override(module, name, tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(args) + + @overridable def rms_norm(x: AnyTensor, weight: AnyTensor, *, epsilon: float) -> AnyTensor: """Computes the full, unbiased RMS normalization of an input.""" @@ -463,6 +763,25 @@ def _rms_norm_trampoline( d.fail(tensors) +@overridable +def repeat(input: AnyTensor, *sizes: List[int]) -> AnyTensor: + """See torch.Tensor.repeat""" + ... + + +@repeat.trampoline +def _repeat_trampoline( + d: SignatureDispatcher, input: AnyTensor, *sizes: List[int] +) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, *sizes) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def replicate(input: AnyTensor, count: int) -> ShardedTensor: """Replicate across devices. 
@@ -486,7 +805,7 @@ def _replicate_trampoline( @overridable def scaled_dot_product_attention( - q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor] + q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor], is_causal: bool ) -> AnyTensor: """Computes the scaled dot product attention using QKV.""" raise NotImplementedError @@ -499,16 +818,38 @@ def _scaled_dot_product_attention( k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor], + is_causal: bool = False, + scale: Optional[float] = None, ): tensors = (q, k, v, a) for override in d.find_overrides(tensors): - result = override(q, k, v, a) + result = override(q, k, v, a, is_causal=is_causal, scale=scale) if result is not NotImplemented: return override, result else: d.fail(tensors) +@overridable +def reshape(input: AnyTensor, shape: List[int]) -> AnyTensor: + """Returns a tensor with the same data and number of elements as input, but with + the specified shape. + See torch.reshape. + """ + ... + + +@reshape.trampoline +def _reshape_trampoline(d: SignatureDispatcher, input, shape) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, shape) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def reshard( input: AnyTensor | Theta, @@ -616,6 +957,104 @@ def _sharded_sum_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor): d.fail(tensors) +@overridable +def softmax( + tensor: AnyTensor, dim: Optional[int] = None, dtype: Optional[torch.dtype] = None +) -> AnyTensor: + """See torch.nn.functional.softmax""" + ... + + +@softmax.trampoline +def _softmax_trampoline( + d: SignatureDispatcher, + tensor: AnyTensor, + dim: Optional[int] = None, + dtype: Optional[torch.dtype] = None, +) -> AnyTensor: + dispatch_args = [tensor] + for override in d.find_overrides(dispatch_args): + result = override(tensor, dim=dim, dtype=dtype) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + +@overridable +def to(tensor: AnyTensor, *args, **kwargs) -> AnyTensor: + """See torch.Tensor.to""" + ... + + +@to.trampoline +def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs): + dispatch_args = [tensor] + for override in d.find_overrides(dispatch_args): + result = override(tensor, *args, **kwargs) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + +@overridable +def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor: + """Transfer the tensor to a device with ordinal `ordinal`.""" + ... + + +@transfer_to_logical_device.trampoline +def _transfer_to_logical_device_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, ordinal: int +): + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, ordinal) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def transpose(tensor: AnyTensor, dim0: int, dim1: int) -> AnyTensor: + """See torch.transpose""" + ... + + +@transpose.trampoline +def _transpose_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, dim0: int, dim1: int +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, dim0, dim1) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def unflatten(input: AnyTensor, dim: int, sizes: Tuple[int]) -> AnyTensor: + """See torch.unflatten""" + ... 
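Every op in this file follows the same overridable/trampoline/override protocol: the trampoline asks the dispatcher for candidate overrides, calls them in order, and accepts the first result that is not `NotImplemented`. The toy dispatcher below is a deliberately simplified, self-contained mock of that flow, not the real `SignatureDispatcher` from `_registry`; all names in it are invented for illustration:

# Toy stand-in for the dispatch pattern used by the trampolines in this module.
class ToyDispatcher:
    def __init__(self):
        self._overrides = []

    def override(self, *types):
        def decorator(fn):
            self._overrides.append((types, fn))
            return fn

        return decorator

    def __call__(self, *args):
        for types, fn in self._overrides:
            if len(args) >= len(types) and all(
                isinstance(a, t) for a, t in zip(args, types)
            ):
                result = fn(*args)
                if result is not NotImplemented:
                    return result
        raise NotImplementedError(f"No override matched {[type(a) for a in args]}")


toy_add = ToyDispatcher()


@toy_add.override(int, int)
def _add_ints(a, b):
    return a + b


@toy_add.override(str, str)
def _add_strs(a, b):
    return a + b


assert toy_add(1, 2) == 3
assert toy_add("a", "b") == "ab"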
+ + +@unflatten.trampoline +def _unflatten_trampoline( + d: SignatureDispatcher, input: AnyTensor, dim: int, sizes: Tuple[int] +) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, dim, sizes) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def unshard(tensor: AnyTensor) -> AnyTensor: """Return the tensor that has the same elements and shape, but is not sharded.""" @@ -631,3 +1070,75 @@ def _unshard_trampoline(d: SignatureDispatcher, tensor: AnyTensor): return override, result else: d.fail(tensors) + + +@overridable +def unsqueeze(tensor: AnyTensor, dim: int) -> AnyTensor: + """See torch.unsqueeze""" + ... + + +@unsqueeze.trampoline +def _unsqueeze_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, dim: int +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, dim) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def view(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.view""" + ... + + +@view.trampoline +def _view_trampoline( + d: SignatureDispatcher, tensor: AnyTensor, shape: List[int] +) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor, shape) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def view_as_complex(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.view_as_complex""" + ... + + +@view_as_complex.trampoline +def _view_as_complex_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def view_as_real(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.view_as_complex""" + ... 
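The `view_as_complex` / `view_as_real` signatures above wrap torch views that round-trip losslessly between a complex tensor and a real tensor whose last dimension has size 2. A minimal illustration with plain torch tensors (shapes arbitrary):

import torch

real = torch.randn(3, 4, 2)          # trailing dim of size 2 holds (real, imag) pairs
cplx = torch.view_as_complex(real)   # shape (3, 4), complex dtype

assert cplx.shape == (3, 4)
assert torch.equal(torch.view_as_real(cplx), real)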
+ + +@view_as_real.trampoline +def _view_as_real_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) diff --git a/shortfin/shortfin/__init__.py b/sharktank/sharktank/serving_poc/__init__.py similarity index 100% rename from shortfin/shortfin/__init__.py rename to sharktank/sharktank/serving_poc/__init__.py diff --git a/shortfin/shortfin/framework/logging.py b/sharktank/sharktank/serving_poc/framework/logging.py similarity index 95% rename from shortfin/shortfin/framework/logging.py rename to sharktank/sharktank/serving_poc/framework/logging.py index 4843ee358..fe5ffc069 100644 --- a/shortfin/shortfin/framework/logging.py +++ b/sharktank/sharktank/serving_poc/framework/logging.py @@ -24,7 +24,7 @@ def __init__(self): def _setup_logger(): - root_logger = logging.getLogger("shortfin") + root_logger = logging.getLogger("sharktank.serving_poc") root_logger.setLevel(logging.DEBUG) default_handler = logging.StreamHandler(sys.stderr) default_handler.flush = sys.stderr.flush diff --git a/shortfin/shortfin/framework/session.py b/sharktank/sharktank/serving_poc/framework/session.py similarity index 99% rename from shortfin/shortfin/framework/session.py rename to sharktank/sharktank/serving_poc/framework/session.py index c28b81f5b..28af0fd44 100644 --- a/shortfin/shortfin/framework/session.py +++ b/sharktank/sharktank/serving_poc/framework/session.py @@ -319,7 +319,7 @@ def __init__(self, session: DeviceSession, index: int = 0): self._semaphore = session.device.create_semaphore(0) self._step = 0 - def execute_sequential(self, command_buffers: list[HalCommandBuffer]): + def execute_sequential(self, command_buffer: HalCommandBuffer): """Executes a list of command buffers at the current step, advancing to the next. 
""" @@ -329,7 +329,7 @@ def execute_sequential(self, command_buffers: list[HalCommandBuffer]): self._step = next_step sem = self._semaphore self._device.queue_execute( - command_buffers, [(sem, current_step)], [(sem, next_step)] + command_buffer, [(sem, current_step)], [(sem, next_step)] ) def current_fence(self) -> HalFence: diff --git a/shortfin/shortfin/llm/__init__.py b/sharktank/sharktank/serving_poc/llm/__init__.py similarity index 100% rename from shortfin/shortfin/llm/__init__.py rename to sharktank/sharktank/serving_poc/llm/__init__.py diff --git a/shortfin/shortfin/llm/api/rest_server.py b/sharktank/sharktank/serving_poc/llm/api/rest_server.py similarity index 98% rename from shortfin/shortfin/llm/api/rest_server.py rename to sharktank/sharktank/serving_poc/llm/api/rest_server.py index 33810428b..67536173f 100644 --- a/shortfin/shortfin/llm/api/rest_server.py +++ b/sharktank/sharktank/serving_poc/llm/api/rest_server.py @@ -27,7 +27,7 @@ GenerateRequest, ) -logger = get_logger("shortfin.llm.api_server") +logger = get_logger("sharktank.serving_poc.llm.api_server") app = FastAPI() service: Optional[GenerateService] = None diff --git a/shortfin/shortfin/llm/attn_block_cache.py b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py similarity index 98% rename from shortfin/shortfin/llm/attn_block_cache.py rename to sharktank/sharktank/serving_poc/llm/attn_block_cache.py index a9443a1e3..a2299c67e 100644 --- a/shortfin/shortfin/llm/attn_block_cache.py +++ b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py @@ -21,7 +21,7 @@ from .config import human_size, CacheParams -logger = get_logger("shortfin.llm.cache") +logger = get_logger("sharktank.serving_poc.llm.cache") class AttnBlockCacheEntry: diff --git a/shortfin/shortfin/llm/config.py b/sharktank/sharktank/serving_poc/llm/config.py similarity index 100% rename from shortfin/shortfin/llm/config.py rename to sharktank/sharktank/serving_poc/llm/config.py diff --git a/shortfin/shortfin/llm/impl/service_v1.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py similarity index 99% rename from shortfin/shortfin/llm/impl/service_v1.py rename to sharktank/sharktank/serving_poc/llm/impl/service_v1.py index f3fd80a93..8ae0be637 100644 --- a/shortfin/shortfin/llm/impl/service_v1.py +++ b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py @@ -42,7 +42,7 @@ ) -logger = get_logger("shortfin.llm.impl.service_v1") +logger = get_logger("sharktank.serving_poc.llm.impl.service_v1") EXPECTED_CONCURRENCY = 10 @@ -340,7 +340,7 @@ async def prefill(self) -> TimelineGuarded[HalBufferView]: # Perform h2d transfers. cb.end() - work_queue.execute_sequential([cb]) + work_queue.execute_sequential(cb) # Inputs: # token_ids @@ -468,7 +468,7 @@ async def decode(self) -> TimelineGuarded[HalBufferView]: # Perform h2d transfers. 
cb.end() - work_queue.execute_sequential([cb]) + work_queue.execute_sequential(cb) # Inputs: # token_ids diff --git a/shortfin/shortfin/llm/impl/service_v1_cli.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py similarity index 92% rename from shortfin/shortfin/llm/impl/service_v1_cli.py rename to sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py index 9fc55b5ac..7895341c9 100644 --- a/shortfin/shortfin/llm/impl/service_v1_cli.py +++ b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py @@ -15,21 +15,21 @@ HalElementType, ) -from shortfin.framework.session import DeviceSession +from sharktank.serving_poc.framework.session import DeviceSession -from shortfin.llm.attn_block_cache import ( +from sharktank.serving_poc.llm.attn_block_cache import ( create_attn_block_cache_module, AttnBlockCache, ) -from shortfin.llm.config import ( +from sharktank.serving_poc.llm.config import ( CacheParams, ModelParams, ServiceParams, ) -from shortfin.llm.impl.service_v1 import GenerateServiceV1 -from shortfin.llm.service import GenerateRequest +from sharktank.serving_poc.llm.impl.service_v1 import GenerateServiceV1 +from sharktank.serving_poc.llm.service import GenerateRequest def setup(vmfb_path, config_path, gguf_path): diff --git a/shortfin/shortfin/llm/service.py b/sharktank/sharktank/serving_poc/llm/service.py similarity index 100% rename from shortfin/shortfin/llm/service.py rename to sharktank/sharktank/serving_poc/llm/service.py diff --git a/shortfin/shortfin/llm/testing/fake_v1_module.py b/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py similarity index 100% rename from shortfin/shortfin/llm/testing/fake_v1_module.py rename to sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py diff --git a/shortfin/shortfin/py.typed b/sharktank/sharktank/serving_poc/py.typed similarity index 100% rename from shortfin/shortfin/py.typed rename to sharktank/sharktank/serving_poc/py.typed diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py index ab383a14c..9a7dcf1ee 100644 --- a/sharktank/sharktank/types/gguf_interop/base.py +++ b/sharktank/sharktank/types/gguf_interop/base.py @@ -11,9 +11,9 @@ import numpy as np import torch -from gguf import GGUFReader, GGUFValueType +from gguf import GGUFReader, GGUFValueType, ReaderField -from shark_turbine.aot import ( +from iree.turbine.aot import ( ExternalTensorTrait, ) @@ -44,12 +44,26 @@ def _sanitize_scalar(scalar): return scalar +def _load_array(field: ReaderField) -> list: + if len(field.types) != 2: + raise ValueError(f"Unsupported array type {field.types}") + element_type = field.types[1] + if element_type == GGUFValueType.STRING: + return [ + str(bytes(field.parts[parts_index]), encoding="utf8") + for parts_index in field.data + ] + elif element_type in GGUFReader.gguf_scalar_to_np: + return [ + _sanitize_scalar(field.parts[parts_index][0]) for parts_index in field.data + ] + else: + raise ValueError(f"Unsupported array element type f{element_type}") + + def _load_properties(reader: GGUFReader) -> dict[str, Any]: - # TODO: Figure out what to do with tables. - tables: dict[str, Any] = {} properties: dict[str, Any] = { "schema": "GGUF", - # "tables": tables, } # Extract hyper-parameters. 
Adapted from gguf-dump.py @@ -60,8 +74,10 @@ def _load_properties(reader: GGUFReader) -> dict[str, Any]: properties[field.name] = str(bytes(field.parts[-1]), encoding="utf8") elif field.types[0] in reader.gguf_scalar_to_np: properties[field.name] = _sanitize_scalar(field.parts[-1][0]) + elif field.types[0] == GGUFValueType.ARRAY: + properties[field.name] = _load_array(field) else: - tables[field.name] = field.parts + raise ValueError(f"Invalid field type.") return properties diff --git a/sharktank/sharktank/types/layouts.py b/sharktank/sharktank/types/layouts.py index 54210da9b..586e4f673 100644 --- a/sharktank/sharktank/types/layouts.py +++ b/sharktank/sharktank/types/layouts.py @@ -22,8 +22,8 @@ register_quantized_layout, MetaDataValueType, QuantizedLayout, - _dtype_to_serialized_name, - _serialized_name_to_dtype, + dtype_to_serialized_name, + serialized_name_to_dtype, ) from .layout_utils import ( @@ -96,7 +96,7 @@ def create( m = planes.get("m") dtype_str = metadata.get("dtype") if dtype_str is not None: - dtype = _serialized_name_to_dtype(dtype_str) + dtype = serialized_name_to_dtype(dtype_str) else: # Backwards compat with old serialized. Emulate original behavior # before mixed precision. @@ -106,7 +106,7 @@ def create( @property def metadata(self) -> Optional[dict[str, MetaDataValueType]]: """Additional metadata needed to reconstruct a layout.""" - return {"dtype": _dtype_to_serialized_name(self._dtype)} + return {"dtype": dtype_to_serialized_name(self._dtype)} @property def planes(self) -> dict[str, torch.Tensor]: diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py index 75189bdf3..d3c093b85 100644 --- a/sharktank/sharktank/types/quantizers.py +++ b/sharktank/sharktank/types/quantizers.py @@ -38,8 +38,8 @@ QuantizedTensor, UnnamedTensorName, register_inference_tensor, - _serialized_name_to_dtype, - _dtype_to_serialized_name, + serialized_name_to_dtype, + dtype_to_serialized_name, ) __all__ = [ @@ -131,6 +131,25 @@ def __init__( else: assert len(self._scale.shape) == 0, "Expected per-tensor scale to be 0D" + def dequantize_raw_tensor( + self, t: torch.Tensor, to: torch.dtype, *, name: str + ) -> torch.Tensor: + return ( + PlanarQuantizedTensor( + shape=t.shape, + name=t.name, + layout=TensorScaledLayout( + shape=t.shape, + d=self._reciprocal_scale, + qs=t, + m=self.offset, + dtype=to, + ), + ) + .unpack() + .dequant() + ) + def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor: """Performs a quantizing transformation on t, returning a QuantizeTensor.""" shape = list(t.shape) @@ -139,14 +158,15 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor if axis is None: # Per tensor. 
if offset is None: + # Changed to t/reciprocal because narrow float types are garbage qs = saturate_cast( - t * self._scale, + t / self._reciprocal_scale, dtype=self.dtype, disable_saturate=self._disable_saturate, ) else: qs = saturate_cast( - t * self._scale + offset, + t / self._reciprocal_scale + offset, dtype=self.dtype, disable_saturate=self._disable_saturate, ) @@ -245,7 +265,7 @@ def create( raise IOError("Missing property") from e axis = int(extra_properties["axis"]) if "axis" in extra_properties else None disable_saturate = bool(extra_properties.get("disable_saturate")) - dtype = _serialized_name_to_dtype(dtype_name) + dtype = serialized_name_to_dtype(dtype_name) return cls( name=name, scale=scale, @@ -271,7 +291,7 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad scale_name = f"{self.name}:scale" rscale_name = f"{self.name}:rscale" offset_name = f"{self.name}:offset" - extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)} + extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)} if self._axis is not None: extra_properties["axis"] = self._axis if self._disable_saturate: @@ -387,7 +407,7 @@ def create( dtype_name = extra_properties["dtype"] except KeyError as e: raise IOError("Missing property") from e - dtype = _serialized_name_to_dtype(dtype_name) + dtype = serialized_name_to_dtype(dtype_name) return cls( name=name, dtype=dtype, @@ -399,7 +419,7 @@ def globals(self) -> dict[str, torch.Tensor]: def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata: """Adds this tensor to the global archive.""" - extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)} + extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)} raw_tensors = {} return InferenceTensorMetadata( self.serialized_name(), diff --git a/sharktank/sharktank/types/sharding.py b/sharktank/sharktank/types/sharding.py index a670cae33..81d2f31a5 100644 --- a/sharktank/sharktank/types/sharding.py +++ b/sharktank/sharktank/types/sharding.py @@ -18,7 +18,7 @@ def __init__(self): class TensorSharding(Sharding): - def __init__(self, *, shard_count: int): + def __init__(self, shard_count: int): super().__init__() self.shard_count = shard_count @@ -29,7 +29,7 @@ def __init__(self): class Replicated(TensorSharding): - def __init__(self, *, shard_count: int): + def __init__(self, shard_count: int): super().__init__(shard_count=shard_count) @@ -39,6 +39,17 @@ def __init__(self, *, shard_count: int, shard_dim: int): self.shard_dim = shard_dim +class Ignore(TensorSharding): + """When a theta is sharded, a tensor or a branch with this sharding type will be + ignored. + It will not appear in the resulting sharded theta. + This is not strictly a TensorSharding. It will terminate further traversal of a + branch of a theta tree as well.""" + + def __init__(self): + super().__init__(shard_count=0) + + class ThetaSharding(dict): """Sharding for each tensor in a theta. It is of type dict[str, "ThetaSharding" | TensorSharding]. 
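The StaticScaledQuantizer change above now quantizes with the stored reciprocal scale (t / self._reciprocal_scale), and the new dequantize_raw_tensor dequantizes with that same stored tensor. A minimal numeric sketch of the round trip under that reading, assuming a PyTorch build that provides torch.float8_e4m3fnuz and using a plain cast where the real code uses saturate_cast; values are illustrative only:

import torch

# d = reciprocal scale stored for dequantization (stands in for self._reciprocal_scale).
rscale = torch.tensor(7.3, dtype=torch.float32)
t = torch.tensor([1.5, -2.25, 0.75], dtype=torch.float32)

# Quantize with the exact tensor that dequantization will use, then cast narrow.
qs = (t / rscale).to(torch.float8_e4m3fnuz)

# Dequantize (no-offset case): qs * d, loosely mirroring TensorScaledLayout(d=rscale, qs=qs).
roundtrip = qs.to(torch.float32) * rscale

# The remaining error comes only from the narrow-float cast of qs; there is no separately
# rounded forward scale that could disagree with 1/rscale.
print(t, roundtrip)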
@@ -49,7 +60,15 @@ def __init__(self, *args, **kwargs): for k, v in d.items(): d[k] = tree.map_nodes( tree=v, - f=lambda x: x if isinstance(x, TensorSharding) else ThetaSharding(x), + f=lambda x: x + if isinstance( + x, + ( + TensorSharding, + ThetaSharding, + ), + ) + else ThetaSharding(x), ) super().__init__(d) @@ -89,6 +108,27 @@ def theta_sharding(self) -> ThetaSharding: ) +class FFNSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "ffn_gate": LinearSplitParallelWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "ffn_up": LinearSplitParallelWeightAndBiasSharding( + shard_count=self.shard_count + ).theta_sharding(), + "ffn_down": LinearSplitReductionDimSharding( + shard_count=self.shard_count + ).theta_sharding(), + } + ) + + class GroupNormSplitChannelSharding(ThetaLayerSharding): def __init__(self, shard_count: int): super().__init__() @@ -103,23 +143,67 @@ def theta_sharding(self) -> ThetaSharding: ) -class LinearReplicatedInputSplitWeightAndBiasSharding(ThetaLayerSharding): +class LinearLayerSharding(ThetaLayerSharding): + def __init__( + self, premul_input: TensorSharding, weight: TensorSharding, bias: TensorSharding + ): + super().__init__() + self.premul_input = premul_input + self.weight = weight + self.bias = bias + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "premul_input": self.premul_input, + "weight": self.weight, + "bias": self.bias, + } + ) + + +class LinearSplitParallelWeightAndBiasSharding(LinearLayerSharding): def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0): + """Split one parallel dimension for both the weight and bias. 
+ Since the weight is transposed before multiplying, the weight parallel + dimension is the same as the output(bias) dimension.""" + super().__init__( + premul_input=Replicated(shard_count=shard_count), + weight=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim), + bias=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim), + ) + + +class LinearSplitReductionDimSharding(LinearLayerSharding): + def __init__(self, shard_count: int): + super().__init__( + premul_input=Replicated(shard_count=shard_count), + weight=Split(shard_count=shard_count, shard_dim=1), + bias=Replicated(shard_count=shard_count), + ) + + +class RmsNormReplicatedSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): + super().__init__() + self.shard_count = shard_count + + def theta_sharding(self) -> ThetaSharding: + return ThetaSharding( + { + "weight": Replicated(shard_count=self.shard_count), + } + ) + + +class TokenEmbeddingLayerReplicatedSharding(ThetaLayerSharding): + def __init__(self, shard_count: int): super().__init__() self.shard_count = shard_count - self.weight_and_bias_spit_dim = weight_and_bias_spit_dim def theta_sharding(self) -> ThetaSharding: return ThetaSharding( { - "premul_input": Replicated(shard_count=self.shard_count), - "weight": Split( - shard_count=self.shard_count, - shard_dim=self.weight_and_bias_spit_dim, - ), - "bias": Split( - shard_count=self.shard_count, - shard_dim=self.weight_and_bias_spit_dim, - ), + "weight": Replicated(shard_count=self.shard_count), } ) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index eab7d5823..f870aa101 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -17,20 +17,22 @@ Tuple, ) from copy import deepcopy -from collections.abc import Collection -from numbers import Integral +from collections.abc import Collection, Sequence +from numbers import Integral, Number +import os from abc import ABC, abstractmethod from dataclasses import dataclass import torch from torch import Tensor -from torch.utils._pytree import register_pytree_node +from torch.utils._pytree import register_pytree_node, SequenceKey +import torch.utils._pytree from ..utils.math import ceildiv -from shark_turbine.aot import ( +from iree.turbine.aot import ( + DeviceTensorTrait, ExternalTensorTrait, ) -from shark_turbine.ops.iree import transfer_to_logical_device from ..utils import tree as tree_utils from ..utils.io import ShardedArchiveBuilder @@ -38,6 +40,7 @@ __all__ = [ "AnyTensor", "DefaultPrimitiveTensor", + "dtype_to_serialized_name", "flatten_tensor_tree", "InferenceTensor", "MetaDataValueType", @@ -47,12 +50,26 @@ "QuantizedTensor", "register_quantized_layout", "ReplicatedTensor", + "serialized_name_to_dtype", "ShardedTensor", "SplitPrimitiveTensor", + "torch_tree_flatten", "unbox_tensor", "UnreducedTensor", ] +if ( + "SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR" in os.environ + and os.environ["SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR"] != "0" +): + + def _tensor_debugger_friendly_repr(self: torch.Tensor): + """Override for the torch.Tensor.__repr__ so it does not take forever when the + debugger wants to query many/large tensors.""" + return f"Tensor({list(self.shape)}, {self.dtype})" + + Tensor.__repr__ = _tensor_debugger_friendly_repr + # JSON encodable value types. 
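For reference, a short sketch of what the FFNSharding spec added to sharding.py above is expected to expand into. The structure in the comments is inferred from the theta_sharding() implementations shown in the diff, and the import path is assumed to be sharktank.types.sharding:

from sharktank.types.sharding import FFNSharding

spec = FFNSharding(shard_count=2).theta_sharding()

# Expected layout, using the classes defined above:
#   spec["ffn_gate"]["weight"] -> Split(shard_count=2, shard_dim=0)   # parallel dim
#   spec["ffn_up"]["weight"]   -> Split(shard_count=2, shard_dim=0)   # parallel dim
#   spec["ffn_down"]["weight"] -> Split(shard_count=2, shard_dim=1)   # reduction dim
#   spec["ffn_down"]["bias"]   -> Replicated(shard_count=2)
#   every "premul_input"       -> Replicated(shard_count=2)
for layer_name, layer_spec in spec.items():
    print(layer_name, {k: type(v).__name__ for k, v in layer_spec.items()})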
MetaDataValueType = Union[int, bool, float, str] UnnamedTensorName = "" @@ -255,8 +272,15 @@ def transform_globals( return self._clone_with_globals(prev_globals) def to( - self, *, device: Optional[Union[str, torch.device]] = None + self, + *, + device: Optional[Union[str, torch.device]] = None, ) -> "InferenceTensor": + # TODO: reconcile with ops.to(...) and torch.Tensor.to(...). + # Do we always want to clone with globals? + # This makes our type inconsistent with torch tensors. + # If we use this to transform a theta we want to change the theta. + # If we want to use this in a computation we don't want to change the theta. return self.transform_globals( lambda d: {k: t.to(device=device) for k, t in d.items()} ) @@ -279,26 +303,144 @@ def T(self) -> "InferenceTensor": return permute(self, dims=dims) + @property + def dtype(self) -> torch.dtype: + raise NotImplementedError() + + def expand(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": + from ..ops import expand + + if all(isinstance(a, int) for a in args): + shape = args + else: + assert len(args) == 1 + shape = args[0] + return expand(self, shape) + + def flatten(self, start_dim: int = 0, end_dim: int = -1) -> "AnyTensor": + from ..ops import flatten + + return flatten(self, start_dim, end_dim) + + def index_copy_( + self, dim: int, index: "AnyTensor", tensor: "AnyTensor" + ) -> "InferenceTensor": + from ..ops import index_copy_ + + return index_copy_(self, dim, index, tensor) + + def index_put_( + self, indices: Tuple["AnyTensor"], values: "AnyTensor" + ) -> "InferenceTensor": + from ..ops import index_put_ + + return index_put_(self, indices, values) + + def index_select( + self, + dim: int, + index: "AnyTensor", + ) -> "InferenceTensor": + from ..ops import index_select + + return index_select(self, dim, index) + + def mean( + self, + dim: Union[int, List[int]], + keepdim: bool = False, + *, + dtype: torch.dtype = None, + ) -> "AnyTensor": + from ..ops import mean + + return mean(self, dim, keepdim, dtype=None) + + def pow(self, exponent: Union["AnyTensor", Number]) -> "AnyTensor": + from ..ops import elementwise + + return elementwise(torch.pow, self, exponent) + + def repeat(self, *sizes: List[int]) -> "AnyTensor": + from ..ops import repeat + + return repeat(self, *sizes) + + def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": + from ..ops import reshape + + if all(isinstance(a, int) for a in args): + shape = args + else: + assert len(args) == 1 + shape = args[0] + return reshape(self, shape) + + def transpose(self, dim0: int, dim1: int) -> "AnyTensor": + from ..ops import transpose + + return transpose(self, dim0, dim1) + + def unflatten(self, dim: int, sizes: Tuple[int]) -> "AnyTensor": + from ..ops import unflatten + + return unflatten(self, dim, sizes) + + def unsqueeze(self, dim: int) -> "AnyTensor": + from ..ops import unsqueeze + + return unsqueeze(self, dim) + + def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": + from ..ops import view + + if all(isinstance(a, int) or isinstance(a, torch.SymInt) for a in args): + shape = args + else: + assert len(args) == 1 + shape = args[0] + return view(self, shape) + def __add__(self, rhs): from ..ops import elementwise return elementwise(torch.add, self, rhs) def __radd__(self, lhs): - # Assumes commutative addition due to torch.elementwise not handling numbers on - # the lhs. + # Assumes commutative addition due to torch elementwise ops not handling + # numbers on the lhs. 
return self.__add__(lhs) + def __mod__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.remainder, self, rhs) + def __mul__(self, rhs): from ..ops import elementwise return elementwise(torch.mul, self, rhs) def __rmul__(self, lhs): - # Assumes commutative multiplication due to torch.elementwise not handling + # Assumes commutative multiplication due to torch elementwise ops not handling # numbers on the lhs. return self.__mul__(lhs) + def __truediv__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.true_divide, self, rhs) + + def __floordiv__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.floor_divide, self, rhs) + + def __getitem__(self, key): + from ..ops import get_index + + return get_index(self, key) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} @@ -338,6 +480,17 @@ def as_torch(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ ... + @property + def dtype(self) -> torch.dtype: + return self.as_torch().dtype + + def __setitem__(self, key, value: "AnyTensor"): + if not isinstance(key, list) and not isinstance(key, tuple): + key = (key,) + + key = [unbox_tensor(k) if isinstance(k, PrimitiveTensor) else k for k in key] + self.as_torch()[*key] = unbox_tensor(value) + @register_inference_tensor class DefaultPrimitiveTensor(PrimitiveTensor): @@ -391,7 +544,15 @@ def _clone_with_globals( return DefaultPrimitiveTensor(name=self.name, data=new_globals[self.name]) def __getitem__(self, key): - return self._data[key] + keys = [key] + if isinstance(key, tuple) or isinstance(key, list): + keys = key + + keys = [ + unbox_tensor(key) if isinstance(key, PrimitiveTensor) else key + for key in keys + ] + return self._data[*keys] def __repr__(self): return f"PrimitiveTensor({self.name}, {self.shape}, {self._data.dtype})" @@ -434,7 +595,9 @@ def to_planar(self) -> "PlanarQuantizedTensor": it should override this method to implement properly or raise NotImplementedError. """ - return PlanarQuantizedTensor(self.name, self.shape, self.unpack()) + return PlanarQuantizedTensor( + name=self.name, shape=self.shape, layout=self.unpack() + ) def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata: """By default all QuantizedTensors serialize as a generic PlanarQuantizedTensor. 
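The methods added to InferenceTensor above (expand, view, transpose, the arithmetic dunders, __getitem__, and friends) all defer to sharktank.ops, so wrapped tensors can be used with torch-like syntax. A small eager-mode sketch of that intent using DefaultPrimitiveTensor; shapes and values are illustrative, and the dispatch comments restate what the diff shows rather than documented behavior:

import torch
from sharktank.types import DefaultPrimitiveTensor

t = DefaultPrimitiveTensor(data=torch.arange(12, dtype=torch.float32).reshape(3, 4))

row = t[0]                   # DefaultPrimitiveTensor.__getitem__ indexes the backing torch.Tensor
doubled = t * 2.0            # __mul__ -> ops.elementwise(torch.mul, t, 2.0)
shifted = t + 1.0            # __add__ -> ops.elementwise(torch.add, t, 1.0)
swapped = t.transpose(0, 1)  # InferenceTensor.transpose -> ops.transpose(t, 0, 1)
print(t.dtype, row.shape, swapped.shape)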
@@ -601,6 +764,10 @@ def name(self, name: str): for i, shard in enumerate(self.shards): shard.name = f"{name}.shard.{i}" + @property + def dtype(self) -> torch.dtype: + return self.shards[0].dtype + @register_inference_tensor class ShardedTensorBase(ShardedTensor): @@ -619,13 +786,15 @@ def __init__( shape: Optional[list[int]], ): assert len(ts) > 0 - assert shard_dim is None or len(ts[0].shape) > shard_dim + assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim) super().__init__(name=name, shape=shape, shard_dim=shard_dim) self._shards: tuple[DefaultPrimitiveTensor] = tuple( DefaultPrimitiveTensor( name=f"{name}.shard.{i}", - data=transfer_to_logical_device(f"{i}", unbox_tensor(t)), + data=t, ) + if isinstance(t, torch.Tensor) + else t for i, t in enumerate(ts) ) @@ -695,6 +864,8 @@ def create( try: t = raw_tensors[t_name] ts.append(t) + # TODO: this should be changed to tracked device affinity + DeviceTensorTrait(i).set(t) except KeyError as e: raise IOError( f"Missing component tensor '{t_name}' in {raw_tensors.keys()}" @@ -745,6 +916,32 @@ def _is_full_slice(s: slice, dim_size: int) -> bool: ) +def _resolve_ellipsis_in_slicing(key: Tuple[Any], shape: Tuple[int]) -> Tuple[Any]: + """Example: + key = [1:2, ..., 0] + shape = [2, 3, 4, 5, 6] + Returns: + [1:2, :, :, :, 0]""" + num_ellipsis = len([k for k in key if k == Ellipsis]) + assert num_ellipsis <= 1, "Only one Ellipses is allowed." + if num_ellipsis <= 0: + return key + assert len(key) <= len( + shape + ), "Inserting trailing singleton dimensions is not supported." + dim = 0 + res = [] + for k in key: + if k == Ellipsis: + ellipsis_num_dims = len(shape) - len(key) + 1 + res.extend([slice(None)] * ellipsis_num_dims) + dim += ellipsis_num_dims + else: + dim += 1 + res.append(k) + return tuple(res) + + @register_inference_tensor class SplitPrimitiveTensor(ShardedTensorBase): """Sharded tensor split along a dimension into primitive tensors. @@ -770,8 +967,11 @@ def __init__( number of pieces. """ if isinstance(ts, torch.Tensor): + from ..ops import transfer_to_logical_device + assert shard_count is not None ts = ts.split(ceildiv(ts.shape[shard_dim], shard_count), dim=shard_dim) + ts = [transfer_to_logical_device(t, i) for i, t in enumerate(ts)] assert len(ts) == shard_count shard_count = None @@ -779,21 +979,23 @@ def __init__( assert len(ts) > 0 first_shape = ts[0].shape assert len(first_shape) > shard_dim - if shape is None: - # Compute the shape. - shape = list(first_shape) - shape[shard_dim] *= len(ts) - - # Assert the shape. - shard_dim_size = first_shape[shard_dim] - for t in ts[1:]: - assert ( - t.shape == first_shape - ), f"Shape mismatch for split tensors: {t.shape} vs {first_shape}" - shard_dim_size += t.shape[shard_dim] - assert ( - shard_dim_size == shape[shard_dim] - ), f"Sharding mismatch: Sharded dims do not cover the whole volume {shard_dim_size} vs {shape[shard_dim]}" + expected_shape = list(first_shape) + expected_shape[shard_dim] = sum([t.shape[shard_dim] for t in ts]) + if shape is not None: + shape = list(shape) + assert expected_shape == shape + else: + shape = expected_shape + + # Assert the shapes. + for i, t in enumerate(ts): + t_shape = list(t.shape) + assert len(shape) == len( + t_shape + ), f"Shape size mismatch tensor shard {i} with shape {t.shape}. Expected shape size {len(shape)}. Got {len(t_shape)}." 
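A quick worked example of the _resolve_ellipsis_in_slicing helper defined above, matching its docstring. Importing the module-private function directly is done here only for illustration:

from sharktank.types.tensors import _resolve_ellipsis_in_slicing

# key corresponds to the indexing expression t[1:2, ..., 0]
key = (slice(1, 2), Ellipsis, 0)
shape = (2, 3, 4, 5, 6)

# The single Ellipsis expands into enough full slices to cover the unnamed middle
# dimensions, so every dimension ends up indexed explicitly:
#   (slice(1, 2), slice(None), slice(None), slice(None), 0)
print(_resolve_ellipsis_in_slicing(key, shape))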
+ assert all( + s == t for i, (s, t) in enumerate(zip(shape, t_shape)) if i != shard_dim + ), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}" super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim) @@ -813,7 +1015,7 @@ def _is_slicing_split_dim(self, key): else: # Any other collection is a indexing only dimension 0. return self.shard_dim == 0 - if len(key) < self.shard_dim: + if len(key) <= self.shard_dim: return False if not isinstance(key[self.shard_dim], slice): return True @@ -834,19 +1036,52 @@ def _get_shard_slice(self, key): if len(key) <= self.shard_count: return key new_key = list(key) - new_key[self.shard_dim] = slice(None) + + if self.shard_dim < len(new_key): + new_key[self.shard_dim] = slice(None) return new_key def __getitem__(self, key): # TODO: implement all cases. - # Only slicing of non-split dimension is supported. + if not isinstance(key, Sequence): + key = (key,) + key = _resolve_ellipsis_in_slicing(key, self.shape) if self._is_slicing_split_dim(key): raise NotImplementedError( f"Slicing of the split dimension {self.shard_dim} is not supported." ) new_key = self._get_shard_slice(key) - shards = [shard[new_key] for shard in self.shards] - return SplitPrimitiveTensor(ts=shards, shard_dim=self.shard_dim) + + shards = [] + for i, shard in enumerate(self.shards): + shard_keys = [ + k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in new_key + ] + shards.append(shard[*shard_keys]) + + shard_dim = self.shard_dim + for i in range(min(shard_dim, len(key))): + if isinstance(key[i], Number) and key[i] >= 0: + # Rank reduction dimension before the split dim. + shard_dim -= 1 + + return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) + + def __setitem__(self, key, value): + assert isinstance(value, SplitPrimitiveTensor) + assert self.shard_count == value.shard_count + if not isinstance(key, Sequence): + key = (key,) + key = _resolve_ellipsis_in_slicing(key, self.shape) + if self._is_slicing_split_dim(key): + raise NotImplementedError( + f"Slicing of the split dimension {self.shard_dim} is not supported." + ) + for i, (shard, value_shard) in enumerate(zip(self.shards, value.shards)): + shard_keys = [ + k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in key + ] + shard[*shard_keys] = unbox_tensor(value_shard) @register_inference_tensor @@ -866,10 +1101,11 @@ def __init__( If `ts` is a tensor then `shard_count` must be provided and it, will be replicated that many times. 
""" - if isinstance(ts, torch.Tensor): assert shard_count is not None - ts = [ts] * shard_count + from ..ops import transfer_to_logical_device + + ts = [transfer_to_logical_device(ts, i) for i in range(shard_count)] shard_count = None assert shard_count is None @@ -884,8 +1120,10 @@ def __init__( self._shards: tuple[DefaultPrimitiveTensor] = tuple( DefaultPrimitiveTensor( name=f"{name}.shard.{i}", - data=transfer_to_logical_device(f"{i}", unbox_tensor(t)), + data=t, ) + if isinstance(t, torch.Tensor) + else t for i, t in enumerate(ts) ) @@ -936,13 +1174,35 @@ def create( ) -> "InferenceTensor": shard_count = int(extra_properties["shard_count"]) try: - ts = raw_tensors[""] + # We have to do this to avoid exporting as part of the `mlir` blob: + t = raw_tensors[""] + ts = [raw_tensors[""]] + for i in range(1, shard_count): + nt = deepcopy(t) + ts.append(nt) + + # TODO This should be changed to assigned affinities + for i in range(shard_count): + DeviceTensorTrait(i).set(ts[i]) + except KeyError as e: raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e - return cls(name=name, ts=ts, shard_count=shard_count) + return cls(name=name, ts=ts) def __getitem__(self, key): - shards = [shard[key] for shard in self.shards] + keys = [key] + if isinstance(key, tuple) or isinstance(key, list): + keys = key + + shards = [] + for i, shard in enumerate(self.shards): + shard_keys = [] + for k in keys: + if isinstance(k, ReplicatedTensor): + shard_keys.append(k.shards[i]) + else: + shard_keys.append(k) + shards.append(shard[*shard_keys]) return ReplicatedTensor(ts=shards) def __repr__(self): @@ -988,6 +1248,7 @@ def __init__( def flatten_tensor_tree( tree: tree_utils.Tree, ) -> Iterable[torch.Tensor | InferenceTensor]: + """Flatten up to our tensor types.""" return tree_utils.flatten( tree, is_leaf=lambda x: isinstance( @@ -1016,7 +1277,7 @@ def unbox_tensor(t: Any) -> Tensor: ######################################################################################## -def _dtype_to_serialized_name(dtype: torch.dtype) -> str: +def dtype_to_serialized_name(dtype: torch.dtype) -> str: try: return _DTYPE_TO_NAME[dtype] except KeyError as e: @@ -1025,7 +1286,7 @@ def _dtype_to_serialized_name(dtype: torch.dtype) -> str: ) from e -def _serialized_name_to_dtype(dtype_name: str) -> torch.dtype: +def serialized_name_to_dtype(dtype_name: str) -> torch.dtype: try: return _NAME_TO_DTYPE[dtype_name] except KeyError as e: @@ -1047,6 +1308,7 @@ def _serialized_name_to_dtype(dtype_name: str) -> torch.dtype: "int32": torch.int32, "int64": torch.int64, "bool": torch.bool, + "float8_e4m3fnuz": torch.float8_e4m3fnuz, } @@ -1097,10 +1359,16 @@ def unflatten_defult_primitive_tensor( return DefaultPrimitiveTensor(data=values_as_list[0], name=ctx["name"]) +def flatten_with_keys_default_primitive_tensor(t: DefaultPrimitiveTensor): + values, context = flatten_default_primitive_tensor(t) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + register_pytree_node( DefaultPrimitiveTensor, flatten_fn=flatten_default_primitive_tensor, unflatten_fn=unflatten_defult_primitive_tensor, + flatten_with_keys_fn=flatten_with_keys_default_primitive_tensor, ) @@ -1118,10 +1386,16 @@ def unflatten_split_primitive_tensor( ) +def flatten_with_keys_split_primitive_tensor(t: SplitPrimitiveTensor): + values, context = flatten_split_primitive_tensor(t) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + register_pytree_node( SplitPrimitiveTensor, flatten_fn=flatten_split_primitive_tensor, 
unflatten_fn=unflatten_split_primitive_tensor, + flatten_with_keys_fn=flatten_with_keys_split_primitive_tensor, ) @@ -1137,8 +1411,45 @@ def unflatten_replicated_tensor( return ReplicatedTensor(ts=list(values), name=ctx["name"]) +def flatten_with_keys_replicated_tensor(t: ReplicatedTensor): + values, context = flatten_replicated_tensor(t) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + register_pytree_node( ReplicatedTensor, flatten_fn=flatten_replicated_tensor, unflatten_fn=unflatten_replicated_tensor, + flatten_with_keys_fn=flatten_with_keys_replicated_tensor, +) + + +def flatten_unreduced_tensor( + t: UnreducedTensor, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + return list(t.shards), {"name": t.name} + + +def unflatten_unreduced_tensor( + values: Iterable[Any], ctx: torch.utils._pytree.Context +) -> UnreducedTensor: + return UnreducedTensor(ts=list(values), name=ctx["name"]) + + +def flatten_with_keys_unreduced_tensor(t: UnreducedTensor): + values, context = flatten_unreduced_tensor(t) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +register_pytree_node( + UnreducedTensor, + flatten_fn=flatten_unreduced_tensor, + unflatten_fn=unflatten_unreduced_tensor, + flatten_with_keys_fn=flatten_with_keys_unreduced_tensor, ) + + +def torch_tree_flatten(tree: tree_utils.Tree): + """Flatten a tree of tensors the same way they will be flattened during torch.export.export + if they are arguments or results of a function signature.""" + return torch.utils._pytree.tree_flatten(tree=tree) diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 975f54d24..29bc29bb8 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -15,7 +15,7 @@ import torch import torch.nn.functional as F -from shark_turbine.aot import ( +from iree.turbine.aot import ( ExternalTensorTrait, ParameterArchive, ParameterArchiveEntry, @@ -110,6 +110,18 @@ def transform(self, *transforms: InferenceTensorTransform) -> "Theta": def to(self, *, device: Optional[Union[str, torch.device]] = None) -> "Theta": return self.transform(InferenceTensorTransforms.to_device(device)) + def pop(self, *name_path: str | int) -> "Theta": + # prune a subtree from the tree and return it as a new Theta object + name_path = ".".join(_norm_name_path(name_path)) + flat = self.flatten() + accum = {} + key_list = list(flat.keys()) + for key in key_list: + if key.startswith(name_path): + accum[key] = flat.pop(key) + self._tree = flat_to_nested_dict(flat) + return Theta(flat_to_nested_dict(accum)) + def flatten(self) -> dict[str, InferenceTensor]: results = {} diff --git a/sharktank/sharktank/utils/__init__.py b/sharktank/sharktank/utils/__init__.py new file mode 100644 index 000000000..3651913ca --- /dev/null +++ b/sharktank/sharktank/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .misc import * diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 3a13b1d45..bc0b3b0b6 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -61,6 +61,29 @@ def add_output_dataset_options(parser: argparse.ArgumentParser): ) +def add_model_options(parser: argparse.ArgumentParser): + """Adds model config options not exclusive to export or eager""" + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch"], + ) + parser.add_argument( + "--skip-decode", + help="Enables prefill only, skips decode", + action="store_true", + ) + + +def add_quantization_options(parser: argparse.ArgumentParser): + parser.add_argument( + "--fake-quant", + action=argparse.BooleanOptionalAction, + help="whether or not to run/export the model in fake quant mode. Note, running eagerly without fake quant is dependent on torch types supporting operations. YMMV", + ) + + def add_tokenizer_options(parser: argparse.ArgumentParser): """Adds options for specifying a tokenizer. @@ -106,6 +129,8 @@ def get_input_dataset(args) -> Dataset: if "irpa" in data_files: return Dataset.load(data_files["irpa"], file_type="irpa") + raise ValueError(f'Dataset format unsupported. Must be "gguf" or "irpa".') + def get_tokenizer(args) -> tokenizer.InferenceTokenizer: """Gets a tokenizer based on arguments. diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py new file mode 100644 index 000000000..c1691c8a8 --- /dev/null +++ b/sharktank/sharktank/utils/create_cache.py @@ -0,0 +1,34 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..layers import * + + +def create_kv_cache(config: LlamaModelConfig) -> BaseKVCache: + hp = config.hp + if config.kv_cache_type == "direct": + return DirectKVCache( + block_seq_stride=config.block_seq_stride, + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + seq_length=hp.context_length, + device=config.device, + dtype=config.attention_dtype, + ) + elif config.kv_cache_type == "paged": + return PagedKVCache( + transformer_block_count=hp.block_count, + attn_head_count=hp.attention_head_count_kv, + attn_head_dim=hp.attn_head_dim, + cache_partition_count=2, # One for each of K/V. + block_seq_stride=config.block_seq_stride, + device=config.device, + dtype=config.attention_dtype, + shard_count=config.tensor_parallelism_size, + ) + else: + raise NotImplementedError(f"kv_cache_type = {config.kv_cache_type}") diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py new file mode 100644 index 000000000..c950a875a --- /dev/null +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -0,0 +1,333 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import subprocess +import logging +import time +import re +from pathlib import Path +from datetime import timedelta +from typing import List, Optional + +import iree.compiler as ireec + +logger = logging.getLogger("eval") + +logger.setLevel(logging.INFO) + +logger.root.handlers[0].setFormatter( + logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s") +) + + +class ExportMlirException(Exception): + """shark-ai export MLIR exception that preserves the command line and error output.""" + + def __init__(self, process: subprocess.CompletedProcess, cwd: str): + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) + super().__init__( + f"Error invoking export_paged_llama_v1.py\n" + f"Error code: {process.returncode}\n" + f"Stderr diagnostics:\n{errs}\n\n" + f"Invoked with:\n" + f" cd {cwd} && {process.args}\n\n" + ) + + +class IreeCompileException(Exception): + """Compiler exception that preserves the command line and error output.""" + + def __init__(self, process: subprocess.CompletedProcess, cwd: str): + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) + super().__init__( + f"Error invoking iree-compile\n" + f"Error code: {process.returncode}\n" + f"Stderr diagnostics:\n{errs}\n\n" + f"Invoked with:\n" + f" cd {cwd} && {process.args}\n\n" + ) + + +class IreeBenchmarkException(Exception): + """Runtime exception that preserves the command line and error output.""" + + def __init__(self, process: subprocess.CompletedProcess, cwd: str): + # iree-run-module sends output to both stdout and stderr + try: + errs = process.stderr.decode("utf-8") + except: + errs = str(process.stderr) + try: + outs = process.stdout.decode("utf-8") + except: + outs = str(process.stdout) + super().__init__( + f"Error invoking iree-benchmark-module\n" + f"Error code: {process.returncode}\n" + f"Stderr diagnostics:\n{errs}\n" + f"Stdout diagnostics:\n{outs}\n" + f"Run with:\n" + f" cd {cwd} && {process.args}\n\n" + ) + + +class ExportArtifacts: + def __init__( + self, + *, + irpa_path: str, + batch_size: int, + iree_hip_target: str, + iree_hal_target_backends: str, + attention_kernel: str, + tensor_parallelism_size: int, + ): + self.sharktank_dir = str( + Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent + ) + self.irpa_path = irpa_path + self.batch_size = batch_size + self.iree_hip_target = iree_hip_target + self.iree_hal_target_backends = iree_hal_target_backends + self.attention_kernel = attention_kernel + self.tensor_parallelism_size = tensor_parallelism_size + + def timeit(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + total_seconds = end - start + time_taken = abs(timedelta(seconds=total_seconds)) + hours, minutes, seconds = re.split(":", str(time_taken)) + + if total_seconds < 1: + time_taken = f" {round(total_seconds * 1000, 3)} ms" + elif total_seconds < 60: + time_taken = "{:.2f} secs".format(round(float(total_seconds), 2)) + else: + time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format( + int(hours), int(minutes), round(float(seconds), 2) + ) + + func_name = func.__name__ + logger.info(f" {func_name}: {time_taken}") + return result + + return wrapper + + @timeit + def shard_irpa_file( + self, + *, + irpa_file: str, + output_irpa: str, + ): + shard_irpa_args = [ + "python3", + "-m", + "sharktank.examples.sharding.shard_llm_dataset", + "--irpa-file", + irpa_file, 
+ "--output-irpa-file", + output_irpa, + "--tensor-parallelism-size", + str(self.tensor_parallelism_size), + ] + + cwd = self.sharktank_dir + cmd = subprocess.list2cmdline(shard_irpa_args) + + logger.info(f"Sharding irpa file:\n" f"cd {cwd} && {cmd}") + + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True) + if proc.returncode != 0: + logger.error( + f"Error sharding irpa file with shard_llm_dataset.py\n" + f"{proc.stdout+proc.stderr}" + ) + else: + logger.info(f"Sharded irpa file successfully:\n" f"{proc.stdout}") + + return proc.returncode + + @timeit + def export_to_mlir( + self, + *, + mlir_path: str, + json_path: str, + skip_decode: Optional[bool] = None, + ): + export_args = [ + "python3", + "-m", + "sharktank.examples.export_paged_llm_v1", + f"--irpa-file={self.irpa_path}", + f"--output-mlir={mlir_path}", + f"--output-config={json_path}", + f"--bs={str(self.batch_size)}", + ] + if skip_decode: + export_args.append("--skip-decode") + if self.attention_kernel in ["decomposed", "torch"]: + export_args.append("--attention-kernel") + export_args.append(self.attention_kernel) + + cwd = self.sharktank_dir + cmd = subprocess.list2cmdline(export_args) + + logger.info(f" Exporting mlir:\n" f"cd {cwd} && {cmd}") + + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True) + if proc.returncode != 0: + raise ExportMlirException(proc, cwd) + else: + logger.info(f" Exported to mlir successfully:\n" f"{proc.stdout}") + + return proc.returncode + + @timeit + def compile_to_vmfb( + self, + *, + mlir_path, + vmfb_path, + cwd, + hal_dump_path: Optional[Path] = None, + args: Optional[List[str]] = None, + ): + # TODO: Control flag to enable multiple backends + compile_args = [ + f"iree-compile", + f"{mlir_path}", + f"--iree-hip-target={self.iree_hip_target}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"-o={vmfb_path}", + ] + if self.tensor_parallelism_size > 1: + iree_hal_target_devices = [ + f"--iree-hal-target-device=hip[{i}]" + for i in range(self.tensor_parallelism_size) + ] + compile_args += iree_hal_target_devices + if hal_dump_path: + compile_args += [ + f"--iree-hal-dump-executable-files-to={hal_dump_path}/files" + ] + # Append optional arguments if provided + if args: + compile_args += args + cmd = subprocess.list2cmdline(compile_args) + + logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}") + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) + return_code = proc.returncode + if return_code != 0: + raise IreeCompileException(proc, cwd) + + def iree_benchmark_vmfb( + self, + *, + hip_device_id: str, + vmfb_name: str, + irpa_path: str, + args: List[str], + cwd: str | Path, + ): + """Runs a compiled program with the given args using `iree-benchmark-module`. + This assumes that the `iree-benchmark-module` command is available (usually via PATH). + Args: + vmfb_name: Name of the .vmfb file (relative to `cwd`). + args: List of arguments to pass to `iree-benchmark-module`. + cwd: Working directory to run the command within. (either string or Path works) + compile_cmd: Command used to compile the program, for inclusion in error messages. + Raises Exception if running fails for some reason. 
+ """ + benchmark_args = [] + if self.tensor_parallelism_size > 1: + base_irpa_path, _ = os.path.splitext(irpa_path) + rocr_visible_devices = [ + f"ROCR_VISIBLE_DEVICES={','.join(str(i) for i in range(self.tensor_parallelism_size))}" + ] + params = [f"--parameters=model={base_irpa_path}.irpa"] + params += [ + f"--parameters=model={base_irpa_path}.rank{i}.irpa" + for i in range(self.tensor_parallelism_size) + ] + devices = [ + f"--device=hip://{i}" for i in range(self.tensor_parallelism_size) + ] + else: + rocr_visible_devices = [f"ROCR_VISIBLE_DEVICES={hip_device_id}"] + params = [f"--parameters=model={irpa_path}"] + devices = [f"--device=hip://{hip_device_id}"] + benchmark_args += rocr_visible_devices + benchmark_args += [ + "iree-benchmark-module", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--device_allocator=caching", + f"--module={vmfb_name}", + ] + benchmark_args += params + benchmark_args += devices + benchmark_args += args + cmd = subprocess.list2cmdline(benchmark_args) + logger.info(f" Launching run command:\n" f"cd {cwd} && {cmd}") + proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) + return_code = proc.returncode + if return_code != 0: + raise IreeBenchmarkException(proc, cwd) + + def create_file(self, *, suffix, prefix): + file_path = Path(prefix).with_suffix(suffix) + f = open(file_path, "w") + return file_path + + def get_artifacts(self): + + self.dir_path = self.sharktank_dir + "/" + "tmp_perplexity_ci_artifacts/" + temp_dir = Path(self.dir_path) + temp_dir.mkdir(parents=True, exist_ok=True) + + model_name = ( + str(self.irpa_path).split("/")[-1].split(".")[0] + + "_" + + self.attention_kernel + ) + mlir_path = str( + self.create_file(suffix=".mlir", prefix=self.dir_path + model_name) + ) + json_path = str( + self.create_file(suffix=".json", prefix=self.dir_path + model_name) + ) + vmfb_path = str( + self.create_file(suffix=".vmfb", prefix=self.dir_path + model_name) + ) + + if self.attention_kernel == "decomposed": + returncode = self.export_to_mlir( + mlir_path=mlir_path, + json_path=json_path, + ) + + if returncode == 0: + self.compile_to_vmfb( + mlir_path=mlir_path, + vmfb_path=vmfb_path, + cwd=self.sharktank_dir, + ) + + return vmfb_path diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index e49623a8d..0562d5854 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -83,6 +83,23 @@ def alias_dataset(from_name: str, to_name: str): # Dataset definitions ################################################################################ +Dataset( + "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", + ( + RemoteFile( + "gguf", + "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF", + "meta-llama-3.1-8b-instruct.f16.gguf", + ), + RemoteFile( + "tokenizer_config.json", + "NousResearch/Meta-Llama-3-8B", + "tokenizer_config.json", + extra_filenames=["tokenizer.json"], + ), + ), +).alias_to("llama3_8B_fp16") + Dataset( "QuantFactory/Llama-3-8B_q4_1_gguf", ( @@ -247,6 +264,86 @@ def alias_dataset(from_name: str, to_name: str): ), ).alias_to("mixtral_8x7b_q8_0_gguf") +Dataset( + "amd-shark/llama3.1-8B", + ( + RemoteFile( + "gguf", + "amd-shark/llama-quant-models", + "llama3.1-8b/llama8b_f16.gguf", + ), + RemoteFile( + "tokenizer_config.json", + "amd-shark/llama-quant-models", + "llama3.1-8b/tokenizer_config.json", + extra_filenames=["llama3.1-8b/tokenizer.json"], + ), + ), +).alias_to("llama3_8B_f16") + +Dataset( + "amd-shark/llama2-7B", + ( + RemoteFile( + 
"gguf", + "amd-shark/llama-quant-models", + "llama2-7b/llama2_7b_f16.gguf", + ), + RemoteFile( + "tokenizer_config.json", + "amd-shark/llama-quant-models", + "llama2-7b/tokenizer_config.json", + extra_filenames=["llama2-7b/tokenizer.json"], + ), + ), +).alias_to("llama2_7B_f16") + +Dataset( + "google/t5-v1_1-small", + ( + RemoteFile( + "config", + "google/t5-v1_1-small", + "config.json", + extra_filenames=["generation_config.json", "special_tokens_map.json"], + ), + RemoteFile( + "tokenizer_config.json", + "google/t5-v1_1-small", + "tokenizer_config.json", + extra_filenames=["spiece.model"], + ), + RemoteFile( + "pytorch_model.bin", + "google/t5-v1_1-small", + "pytorch_model.bin", + ), + ), +) + +Dataset( + "google/t5-v1_1-xxl", + ( + RemoteFile( + "config", + "google/t5-v1_1-xxl", + "config.json", + extra_filenames=["generation_config.json", "special_tokens_map.json"], + ), + RemoteFile( + "tokenizer_config.json", + "google/t5-v1_1-xxl", + "tokenizer_config.json", + extra_filenames=["spiece.model"], + ), + RemoteFile( + "pytorch_model.bin", + "google/t5-v1_1-xxl", + "pytorch_model.bin", + ), + ), +) + ################################################################################ # Tool entrypoint ################################################################################ diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py index 62fd78f33..ac2480846 100644 --- a/sharktank/sharktank/utils/io.py +++ b/sharktank/sharktank/utils/io.py @@ -6,7 +6,7 @@ from pathlib import Path -from shark_turbine.aot import ( +from iree.turbine.aot import ( ParameterArchiveBuilder, ) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py new file mode 100644 index 000000000..d5976ec48 --- /dev/null +++ b/sharktank/sharktank/utils/iree.py @@ -0,0 +1,192 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.runtime +from typing import List, Tuple, Optional, Union +from pathlib import Path +import torch +import numpy as np +import collections.abc +from collections import OrderedDict +from ..types.tensors import ( + AnyTensor, + InferenceTensor, + ShardedTensor, + DefaultPrimitiveTensor, + unbox_tensor, + torch_tree_flatten, +) +from .tree import Tree + + +def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: + hal_driver = iree.runtime.get_driver(driver) + available_devices = hal_driver.query_available_devices() + if driver in ["local-task", "local-sync"]: + # Use the same actual device for all devices. 
+ return [ + hal_driver.create_device(available_devices[0]) for _ in range(device_count) + ] + else: + return [ + hal_driver.create_device(available_devices[i]) for i in range(device_count) + ] + + +def load_iree_module( + module_path: str, + devices: List[iree.runtime.HalDevice], + parameters_path: Optional[str] = None, +) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]: + """The VmContext and VmInstance need to outlive the VmModule and any device + buffers.""" + vm_instance = iree.runtime.VmInstance() + hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices) + modules = [hal_module] + if parameters_path is not None: + params_path = Path(parameters_path) + parameter_index = iree.runtime.ParameterIndex() + if len(devices) > 1: + # TODO: make IREE able to load the parameters from the top parameter file + # without having to specify the parameter file for each shard separately. + for i in range(len(devices)): + parameter_index.load( + file_path=str( + Path(params_path).with_suffix(f".rank{i}{params_path.suffix}") + ) + ) + else: + parameter_index.load(file_path=str(params_path)) + parameter_provider = parameter_index.create_provider(scope="model") + parameters_module = iree.runtime.create_io_parameters_module( + vm_instance, parameter_provider + ) + modules.append(parameters_module) + vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path)) + modules.append(vm_module) + vm_context = iree.runtime.VmContext(instance=vm_instance, modules=modules) + return vm_module, vm_context, vm_instance + + +def run_iree_module_function( + module: iree.runtime.VmModule, + vm_context: iree.runtime.VmContext, + args: List[iree.runtime.DeviceArray], + driver: str, + function_name: str = "main", + trace_path_prefix: Optional[str] = None, +) -> List[iree.runtime.DeviceArray]: + """Run IREE module function with optional tracing of arguments/results.""" + vm_function = module.lookup_function(function_name) + invoker = iree.runtime.FunctionInvoker( + vm_context=vm_context, + # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices. + # This works, but does not look right. + device=iree.runtime.get_device(driver, cache=False), + vm_function=vm_function, + ) + if trace_path_prefix is not None: + for i, arg in enumerate(args): + np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host()) + results = invoker(*args) + if isinstance(results, iree.runtime.DeviceArray): + results = (results,) + + if trace_path_prefix is not None: + for i, arg in enumerate(args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy", + arg.to_host(), + ) + for i, arg in enumerate(results): + np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host()) + return results + + +def prepare_iree_module_function_args( + args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice] +) -> List[iree.runtime.DeviceArray]: + """Flatten composite tensors into their parts and place them on devices. + Sharded tensors become a list of their shards while placing them onto their + corresponding device. + All unsharded tensors go on device 0. 
+ """ + res = [] + for arg in args: + if isinstance(arg, ShardedTensor): + assert len(devices) == len(arg.shards) + res.extend( + [ + prepare_iree_module_function_args([shard], [device])[0] + for shard, device in zip(arg.shards, devices) + ] + ) + elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)): + res.append( + iree.runtime.asdevicearray( + devices[0], unbox_tensor(arg).to("cpu").numpy() + ) + ) + else: + assert isinstance(arg, collections.abc.Sequence) + res.extend(prepare_iree_module_function_args(arg, devices)) + return res + + +def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]: + """Flatten a tree of arguments or results for an IREE call. + E.g. sharded tensors gets flattened into their shards.""" + return torch_tree_flatten(tree)[0] + + +def call_torch_module_function( + module: torch.nn.Module, + function_name: str, + kwargs: OrderedDict, + trace_path_prefix: Optional[str] = None, +): + """Call a torch module function with optional tracing. + For tracing the arguments/results are flattened to match IREE's signature.""" + assert isinstance( + kwargs, OrderedDict + ), "Make sure when flattening the order is preserved" + if trace_path_prefix is not None: + flat_args = flatten_for_iree_signature(kwargs) + for i, arg in enumerate(flat_args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + arg.to("cpu").numpy(), + ) + res = getattr(module, function_name)(**kwargs) + if trace_path_prefix is not None: + flat_args = flatten_for_iree_signature(kwargs) + for i, arg in enumerate(flat_args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + arg.to("cpu").numpy(), + ) + results = ( + (res,) + if isinstance( + res, + ( + torch.Tensor, + InferenceTensor, + ), + ) + else res + ) + flat_results = flatten_for_iree_signature(results) + for i, result in enumerate(flat_results): + np.save( + f"{trace_path_prefix}{function_name}_result{i}.npy", + result.to("cpu").numpy(), + ) + return res + + +def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]: + return [torch.tensor(tensor.to_host()) for tensor in tensors] diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py new file mode 100644 index 000000000..47d9f0244 --- /dev/null +++ b/sharktank/sharktank/utils/load_llm.py @@ -0,0 +1,211 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import math + +import torch + +from sharktank.layers import * +from sharktank.types import * +from sharktank.models.llama.llama import * + +from ..utils.debugging import trace_tensor +from ..utils.tokenizer import InferenceTokenizer + + +class TorchGenerator: + """Generator that runs directly on the Torch model.""" + + def __init__( + self, + model: PagedLlamaModelV1, + tokenizer: InferenceTokenizer, + # Need to look at the model more for this. 
+ end_token: int = 2, + ): + self.model = model + self.tokenizer = tokenizer + self.end_token = end_token + + @property + def block_seq_stride(self) -> int: + return self.model.cache.block_seq_stride + + def begin_batch( + self, prompts: list[str], add_start_token: bool, page_cache_size: int = 128 + ): + token_ids, seq_lens = self.tokenizer.encode( + prompts, + pad_to_multiple_of=self.model.cache.pad_sequence_stride, + add_start_token=add_start_token, + ) + token_ids = torch.tensor(token_ids, device=self.model.device) + seq_lens = torch.tensor(seq_lens, device=self.model.device) + + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) + else: + cache_state = self.model.cache.direct.allocate(bs=len(prompts)) + return Batch(self, token_ids, seq_lens, cache_state) + + def begin_eval_batch( + self, + token_batch: torch.tensor, + seq_lens_batch: torch.tensor, + bs: int, + page_cache_size: int = 128, + ): + if self.model.cache.is_paged: + cache_state = self.model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) + else: + cache_state = self.model.cache.direct.allocate(bs=bs) + return Batch(self, token_batch, seq_lens_batch, cache_state) + + def alloc_page(self) -> int: + if self.model.cache.is_direct: + # We don't allocate block ids for the direct cache. + return 0 + + return self.free_pages.pop() + + def release_page(self, index: int): + if self.model.cache.is_direct: + return + self.free_pages.append(index) + + +class Batch: + def __init__( + self, + parent: TorchGenerator, + token_ids: torch.Tensor, + seq_lens: torch.Tensor, + cache_state: list[torch.Tensor], + ): + self.bs = token_ids.shape[0] + # assert seq_lens.shape[0] == self.bs + self.parent = parent + self.token_ids = token_ids + self.seq_lens = seq_lens + self.cache_state = cache_state + self.results: list[list[int]] = [[] for _ in range(self.bs)] + self.done_result_indices: set[int] = set() + + # Assemble the batch. 
+ seq_stride = self.parent.block_seq_stride + self.seq_block_ids: list[list[int]] = [] + for seq_len in self.seq_lens: + blocks_needed = ( + int(math.ceil(seq_len / seq_stride)) if seq_stride > 0 else 0 + ) + row = [] + for _ in range(blocks_needed): + row.append(self.parent.alloc_page()) + self.seq_block_ids.append(row) + + @property + def done(self) -> bool: + return len(self.done_result_indices) == self.bs + + def detokenize(self) -> list[str]: + return self.parent.tokenizer.decode(self.results) + + def print_current_results(self): + results = self.detokenize() + for i, s in enumerate(results): + seq_len = int(self.seq_lens[i]) + print(f" {i}({len(self.results[i])}, {seq_len}): {s}") + + def add_result_token(self, tokens: torch.Tensor): + for i in range(self.bs): + token = tokens[i][0] + if token == self.parent.end_token: + self.done_result_indices.add(i) + if i in self.done_result_indices: + continue + token = int(tokens[i, 0]) + self.results[i].append(token) + + def allocate_seq_block_ids(self): + for i in range(self.bs): + sl = int(self.seq_lens[i]) + if (sl % self.parent.block_seq_stride) == 0: + needed_blocks = sl // self.parent.block_seq_stride + 1 + else: + needed_blocks = math.ceil(sl / self.parent.block_seq_stride) + block_ids_row = self.seq_block_ids[i] + while len(block_ids_row) < needed_blocks: + block_ids_row.append(self.parent.alloc_page()) + + def prefill(self): + model = self.parent.model + attention_mask = model.attention_mask( + model.input_mask(self.seq_lens, self.token_ids.shape[1]) + ) + seq_block_ids_tensor = self.pad_block_ids() + trace_tensor("prefill.token_ids", self.token_ids) + trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor) + trace_tensor("prefill.attention_mask", attention_mask) + self.prefill_logits = model.prefill( + self.token_ids, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids_tensor, + cache_state=self.cache_state, + ) + + # TODO: Generalize the sampling and don't make it swap on/off cpu. + # TODO: Normalize the output of extract_tokens_from_logits into + # tensor [bs, 1]. + tokens = torch.tensor( + model.extract_tokens_from_logits(self.prefill_logits, self.seq_lens) + ).unsqueeze(1) + self.add_result_token(tokens) + self.next_tokens = tokens.to(device=model.device) + + def decode(self, token_batch): + self.token_batch = token_batch + + model = self.parent.model + start_positions = self.seq_lens.clone() + self.seq_lens.add_(1) + self.allocate_seq_block_ids() + # TODO: Allocate more blocks on overflow. + seq_block_ids_tensor = self.pad_block_ids() + decode_attention_mask = model.decode_attention_mask( + model.input_mask( + self.seq_lens, + seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride, + ) + ) + trace_tensor("decode.token_ids", self.token_ids) + trace_tensor("decode.start_positions", start_positions) + trace_tensor("decode.seq_block_ids", seq_block_ids_tensor) + trace_tensor("decode.attention_mask", decode_attention_mask) + + self.decode_logits = model.decode( + self.token_batch, + attention_mask=decode_attention_mask, + start_positions=start_positions, + seq_block_ids=seq_block_ids_tensor, + cache_state=self.cache_state, + ) + + trace_tensor("decode.logits", self.decode_logits) + # # TODO: Normalize the output of extract_tokens_from_logits into + # # tensor [bs, 1]. 
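The TorchGenerator and Batch classes above are meant to be driven as a simple prefill-then-decode loop. A sketch under the assumption that model is an already constructed PagedLlamaModelV1 and tokenizer an InferenceTokenizer; both names are placeholders and the step cap is arbitrary:

from sharktank.utils.load_llm import TorchGenerator

generator = TorchGenerator(model, tokenizer)          # model/tokenizer assumed to exist
batch = generator.begin_batch(
    prompts=["The capital of France is"],
    add_start_token=True,
)
batch.prefill()                                       # fills the KV cache and emits the first tokens

for _ in range(64):                                   # arbitrary cap on generated tokens
    if batch.done:
        break
    batch.decode(batch.next_tokens)                   # one token per sequence per step

print(batch.detokenize()[0])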
+ tokens = torch.tensor( + model.extract_tokens_from_logits(self.decode_logits, [1] * self.bs), + device=self.parent.model.device, + ).unsqueeze(1) + self.add_result_token(tokens) + self.next_tokens = tokens + + def pad_block_ids(self) -> torch.Tensor: + max_length = max(len(r) for r in self.seq_block_ids) + rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids] + return torch.tensor(rows, device=self.parent.model.device) diff --git a/sharktank/sharktank/utils/logging.py b/sharktank/sharktank/utils/logging.py index 977462d86..3801f96cb 100644 --- a/sharktank/sharktank/utils/logging.py +++ b/sharktank/sharktank/utils/logging.py @@ -6,7 +6,7 @@ import logging -from shark_turbine.support.logging import get_logger +from iree.turbine.support.logging import get_logger transform_logger: logging.Logger = get_logger("sharktank.transforms") diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py index df47b5ae6..3f32ac952 100644 --- a/sharktank/sharktank/utils/math.py +++ b/sharktank/sharktank/utils/math.py @@ -4,6 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from numbers import Number + def ceildiv(a: int | float, b: int | float) -> int | float: return -(a // -b) + + +def round_up_to_multiple_of(x: Number, multiple: Number) -> Number: + return x + (-x % multiple) diff --git a/sharktank/sharktank/utils/misc.py b/sharktank/sharktank/utils/misc.py new file mode 100644 index 000000000..dd1ebf4d6 --- /dev/null +++ b/sharktank/sharktank/utils/misc.py @@ -0,0 +1,30 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Callable, List +from collections.abc import Iterable +from operator import eq + + +def longest_equal_range(l1: List[Any], l2: List[Any]) -> int: + """Find the longest range that is the same from the start of both lists. 
+ Returns the greatest `i` such that `l1[0:i] == l2[0:i]`.""" + for i, (a, b) in enumerate(zip(l1, l2)): + if a != b: + return i + return min(len(list(l1)), len(list(l2))) + + +def iterables_equal( + iterable1: Iterable, + iterable2: Iterable, + *, + elements_equal: Callable[[Any, Any], bool] | None = None +) -> bool: + elements_equal = elements_equal or eq + return all( + elements_equal(v1, v2) for v1, v2 in zip(iterable1, iterable2, strict=True) + ) diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py index 8494c1d9b..933bfd2b6 100644 --- a/sharktank/sharktank/utils/testing.py +++ b/sharktank/sharktank/utils/testing.py @@ -10,15 +10,26 @@ import shutil import tempfile import unittest +import torch +from typing import Any, Callable +from operator import eq +from collections.abc import Iterable +import gc from ..types import * +# Range of torch.rand() is [0,1) +# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values +def make_rand_torch(shape, dtype=torch.float32): + return torch.rand(shape, dtype=dtype) * 2 - 1 + class TempDirTestBase(unittest.TestCase): def setUp(self): self._temp_dir = Path(tempfile.mkdtemp(type(self).__qualname__)) def tearDown(self): + gc.collect() shutil.rmtree(self._temp_dir, ignore_errors=True) @@ -93,6 +104,30 @@ def get_best_torch_device() -> str: return "cpu" +def assert_dicts_equal( + dict1: dict, dict2: dict, *, values_equal: Callable[[Any, Any], bool] | None = None +) -> None: + values_equal = values_equal or eq + assert len(dict1) == len( + dict2 + ), f"Dictionaries not equal. {dict1} and {dict2} have different number of elements {len(dict1)} != {len(dict2)}" + for k, v1 in dict1.items(): + assert ( + k in dict2 + ), f"Dictionaries {dict1} and {dict2} not equal. Key {k} not found in {dict2}" + v2 = dict2[k] + assert values_equal( + v1, dict2[k] + ), f"Dictionaries {dict1} and {dict2} not equal for key {k}. 
Values {v1} and {v2} not equal" + + +def assert_equal( + a: Any, b: Any, *, equal: Callable[[Any, Any], bool] | None = None +) -> None: + equal = equal or eq + assert equal(a, b), f"{a} and {b} are not equal" + + def assert_golden_safetensors(actual_path, ref_path): """Asserts that actual and reference safetensors files are within tolerances.""" from safetensors import safe_open @@ -127,3 +162,36 @@ def print_stats(label, t): actual = actual_f.get_tensor(name) ref = ref_f.get_tensor(name) torch.testing.assert_close(actual, ref, msg=name) + + +def assert_iterables_equal( + iterable1: Iterable, + iterable2: Iterable, + *, + elements_equal: Callable[[Any, Any], bool] | None = None, +) -> None: + elements_equal = elements_equal or eq + for i, (v1, v2) in enumerate(zip(iterable1, iterable2, strict=True)): + assert elements_equal( + v1, v2 + ), f"Iterables not equal at index {i} for elements {v1} and {v2}" + + +SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP" + + +def skip(*decorator_args, **decorator_kwargs): + """Decorator to skip a test when SHARKTANK_TEST_SKIP env var is not set or != 0""" + + def decorator(test_item: Callable): + if SHARKTANK_TEST_SKIP_ENV_VAR not in os.environ: + should_skip = True + else: + should_skip = os.environ[SHARKTANK_TEST_SKIP_ENV_VAR] != "0" + + if should_skip: + return unittest.skip(*decorator_args, **decorator_kwargs)(test_item) + + return test_item + + return decorator diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index 29a57f958..b459c706a 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -22,34 +22,57 @@ class InferenceTokenizer(ABC): """Simple inference tokenizer.""" def encode( - self, texts: list[str], pad_to_multiple_of: int = 1, pad_token: int = 0 + self, + texts: list[str], + pad_to_multiple_of: int = 1, + add_start_token: bool = False, ) -> tuple[list[list[int]]]: """Encodes a list of texts into a padded list of tokens. Returns a list of list of tokens and a list of unpadded lengths. 
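
        All rows are right-padded with the pad token to the longest row length,
        rounded up to a multiple of `pad_to_multiple_of`. When `add_start_token`
        is True, the underlying tokenizer is asked to add its special start
        tokens.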
""" - raw_rows = self._encode(texts) + raw_rows = self._encode(texts, add_start_token) + raw_rows, lengths = self.pad_tokens( + token_ids=raw_rows, pad_to_multiple_of=pad_to_multiple_of + ) + return raw_rows, lengths + + def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None): + """Decodes a list of tokens.""" + if lens is not None: + tokens = list(tokens) + for i, row_length in enumerate(lens): + tokens[i] = tokens[i][0:row_length] + return self._decode(tokens) + + def get_prompt_lengths( + self, + token_ids: list[list[int]], + ): max_length = 0 lengths: list[int] = [] - for row in raw_rows: + for row in token_ids: lengths.append(len(row)) max_length = max(max_length, len(row)) + + return lengths, max_length + + def pad_tokens( + self, + token_ids: list[list[int]], + pad_to_multiple_of: int, + pad_token: int = 0, + ): + lengths, max_length = self.get_prompt_lengths(token_ids) if pad_to_multiple_of > 1: max_length = int( pad_to_multiple_of * math.ceil(max_length / pad_to_multiple_of) ) - for row in raw_rows: + for row in token_ids: pad_count = max_length - len(row) row.extend(pad_count * [pad_token]) - return raw_rows, lengths - def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None): - """Decodes a list of tokens.""" - if lens is not None: - tokens = list(tokens) - for i, row_length in enumerate(lens): - tokens[i] = tokens[i][0:row_length] - return self._decode(tokens) + return token_ids, lengths @abstractmethod def _encode(self, texts: list[str]) -> list[list[int]]: @@ -76,9 +99,10 @@ class _TransformersTokenizer(InferenceTokenizer): def __init__(self, t: AutoTokenizer): self._t = t - def _encode(self, texts: list[str]) -> list[list[int]]: + def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]: results = t.batch_encode_plus( texts, + add_special_tokens=add_start_token, padding=False, truncation=False, ) diff --git a/sharktank/sharktank/utils/vmfb_runner.py b/sharktank/sharktank/utils/vmfb_runner.py new file mode 100644 index 000000000..cdbf96c9d --- /dev/null +++ b/sharktank/sharktank/utils/vmfb_runner.py @@ -0,0 +1,82 @@ +from iree import runtime as ireert +from iree.runtime._binding import create_hal_driver + + +class vmfbRunner: + def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None): + + # If an extra plugin is requested, add a global flag to load the plugin + # and create the driver using the non-caching creation function, as + # the caching creation function may ignore the flag. + if extra_plugin: + ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") + haldriver = create_hal_driver(device) + + # No plugin requested: create the driver with the caching create + # function. 
+ else: + haldriver = ireert.get_driver(device) + if "://" in device: + try: + device_idx = int(device.split("://")[-1]) + device_uri = None + except: + device_idx = None + device_uri = device.split("://")[-1] + else: + device_idx = 0 + device_uri = None + if device_uri: + if not any(x in device for x in ["cpu", "task"]): + allocators = ["caching"] + haldevice = haldriver.create_device_by_uri( + device_uri, allocators=allocators + ) + else: + haldevice = haldriver.create_device_by_uri(device_uri) + else: + hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"] + if not any(x in device for x in ["cpu", "task"]): + allocators = ["caching"] + haldevice = haldriver.create_device( + hal_device_id, allocators=allocators + ) + else: + haldevice = haldriver.create_device(hal_device_id) + + self.config = ireert.Config(device=haldevice) + mods = [] + if not isinstance(vmfb_path, list): + vmfb_path = [vmfb_path] + for path in vmfb_path: + mods.append(ireert.VmModule.mmap(self.config.vm_instance, path)) + vm_modules = [ + *mods, + ireert.create_hal_module(self.config.vm_instance, self.config.device), + ] + + # TODO: Enable multiple weight files + if external_weight_path: + index = ireert.ParameterIndex() + if not isinstance(external_weight_path, list): + external_weight_path = [external_weight_path] + for i, path in enumerate(external_weight_path): + if path in ["", None]: + continue + index.load(path) + # TODO: extend scope + param_module = ireert.create_io_parameters_module( + self.config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(i, param_module) + del param_module + del index + + self.ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=self.config, + ) + + def unload(self): + self.ctx = None + self.config = None diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json new file mode 100644 index 000000000..24511b05f --- /dev/null +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -0,0 +1,318 @@ +{ + "llama3_8B_f16_decomposed": { + "perplexities": [ + 6.677369, + 21.807926, + 15.424338, + 17.332415, + 14.951956, + 7.913092, + 8.728321, + 22.425966, + 8.184698, + 20.977249, + 7.088408, + 14.574989, + 9.036912, + 7.277581, + 16.132208, + 6.685175, + 6.525683, + 7.080791, + 10.680925, + 9.034086, + 10.639015, + 41.102894, + 11.723896, + 64.305908, + 47.054577, + 19.9259, + 18.918842, + 13.842684, + 9.974381, + 5.919641, + 10.181265, + 23.609016, + 14.340417, + 9.712208, + 5.602878, + 14.088163, + 5.680599, + 17.377926, + 9.037231, + 8.305407, + 8.028031, + 17.744528, + 11.5076, + 3.936302, + 12.987297, + 10.371798, + 11.927772, + 21.387051, + 37.799526, + 25.67762, + 15.429109, + 13.923962, + 7.594806, + 10.983875, + 14.595965, + 11.022234, + 5.853358, + 15.609065, + 8.044486, + 14.389134, + 5.917565, + 6.892455, + 2.30309, + 15.974725, + 42.017342, + 8.022307, + 12.284297, + 10.018423, + 9.268936, + 10.680118, + 8.12535, + 21.550434, + 3.638689, + 15.345065, + 23.742884, + 14.288899, + 17.796623, + 16.515446, + 8.746647, + 12.922096, + 12.94269, + 13.574061, + 14.013302, + 10.76523, + 14.746032, + 28.208134, + 17.646687, + 9.848188, + 15.280471, + 15.621455, + 29.126505, + 12.302313, + 32.452534, + 31.192411, + 14.371797, + 17.490683, + 14.689407, + 15.284843, + 12.252508, + 16.460979 + ], + "mean_perplexity": 14.930181 + }, + + "llama3_405B_f16_decomposed": { + "perplexities": [ + 2.170036, + 8.014498, + 3.743922, + 10.629776, + 8.965701, + 2.884743, + 
2.886767, + 3.853816, + 2.73785, + 15.235562, + 2.65135, + 1.970936, + 5.08259, + 2.507602, + 7.571635, + 3.005182, + 1.904492, + 3.182651, + 6.249443, + 4.661795, + 12.68933, + 35.432453, + 5.50336, + 60.950359, + 18.433432, + 5.001391, + 4.814827, + 2.99482, + 2.697508, + 2.617349, + 2.359061, + 16.697233, + 2.145065, + 2.1207, + 2.496015, + 1.822896, + 4.671626, + 2.389186, + 2.701802, + 1.921128, + 2.236057, + 4.741998, + 4.946936, + 2.758695, + 2.446043, + 2.146302, + 8.72202, + 4.180647, + 11.449497, + 13.429152, + 3.72468, + 2.407385, + 3.592854, + 5.412414, + 3.189998, + 4.186216, + 1.642744, + 2.279058, + 1.855652, + 3.453852, + 1.436223, + 1.516955, + 1.716439, + 4.715765, + 21.48657, + 2.208737, + 6.420449, + 2.001433, + 2.400955, + 3.543744, + 3.054271, + 7.904545, + 1.950376, + 3.983746, + 6.28265, + 2.64157, + 5.473378, + 3.444444, + 1.926046, + 3.092915, + 3.996159, + 3.125222, + 1.718025, + 3.856093, + 3.041075, + 11.798485, + 14.881112, + 5.631516, + 4.407883, + 4.840533, + 21.351448, + 2.065821, + 6.658993, + 28.123312, + 1.673253, + 3.729975, + 5.336116, + 8.579758, + 2.979404, + 1.915619 + ], + "mean_perplexity": 6.060831 + }, + "llama3_8B_f16_decomposed_iree": { + "perplexities": [ + 6.651368, + 22.059452, + 15.392176, + 17.418619, + 15.206824, + 7.907998, + 8.829535, + 22.355659, + 8.29262, + 20.958277, + 7.167404, + 14.592677, + 9.060788, + 7.274667, + 16.238981, + 6.666115, + 6.535679, + 7.086256, + 10.676177, + 8.979206, + 10.597121, + 42.038162, + 11.70071, + 65.731316, + 47.42622, + 20.109543, + 18.897541, + 13.781085, + 9.99165, + 5.955308, + 10.175659, + 23.628405, + 14.306578, + 9.719462, + 5.594786, + 14.198979, + 5.711433, + 17.381332, + 9.058512, + 8.286205, + 8.016202, + 18.4515, + 11.600831, + 3.945074, + 13.000222, + 10.373363, + 12.237907, + 21.408463, + 37.858665, + 25.794065, + 15.489001, + 14.004895, + 7.625473, + 10.993184, + 14.698832, + 11.062652, + 5.855446, + 15.625135, + 8.052419, + 14.365479, + 5.927001, + 6.931933, + 2.3014, + 15.769623, + 40.843319, + 8.022024, + 12.544907, + 10.090073, + 9.304819, + 10.679907, + 8.136175, + 21.540607, + 3.736973, + 15.381804, + 24.21562, + 14.385005, + 17.791706, + 16.498833, + 8.753955, + 12.941816, + 12.887664, + 13.725715, + 13.994792, + 10.769128, + 14.734674, + 26.970015, + 17.811842, + 9.847188, + 15.124973, + 15.623392, + 29.147844, + 12.309229, + 32.15152, + 33.225769, + 14.426914, + 17.496277, + 14.7356, + 15.503921, + 12.336852, + 16.469248 + ], + "mean_perplexity": 14.991893 + } +} diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py new file mode 100644 index 000000000..d10d9f5db --- /dev/null +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -0,0 +1,327 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import pytest +import json +import numpy as np + +from sharktank.evaluate import perplexity_iree + +is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") +skipif_run_quick_llama_test = pytest.mark.skipif( + 'not config.getoption("run-nightly-llama-tests")', + reason="Run large tests if --run-nightly-llama-tests is passed", +) + + +@pytest.mark.usefixtures( + "get_model_artifacts", + "get_iree_flags", + "tensor_parallelism_size", + "baseline_perplexity_scores", + "batch_size", +) +@is_mi300x +class PerplexityTest(unittest.TestCase): + def setUp(self): + self.current_perplexity_all = {} + self.delta = 5e-1 + self.tensor_parallelism_size = 8 + with open(self.baseline_perplexity_scores, "r") as f: + self.baseline_perplexity = json.load(f) + + def test_llama3_8B_f16_decomposed(self): + + # Llama 3.1 8B decomposed + + model_name = "llama3_8B_f16_decomposed_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_8b_f16_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") + def test_llama3_8B_f16(self): + + # Llama 3.1 8B non-decomposed + + model_name = "llama3_8B_f16_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_8b_f16_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") + def test_llama3_8B_fp8_decomposed(self): + + # Llama 3.1 8B decomposed + + model_name = "llama3_8B_fp8_decomposed_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_8b_fp8_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + 
f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") + def test_llama3_8B_fp8(self): + + # Llama 3.1 8B non-decomposed + + model_name = "llama3_8B_fp8_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_8b_fp8_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size=1", + f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail( + reason="Sharding is unsupported", + ) + def test_llama3_405B_f16_decomposed(self): + + # Llama 3.1 405B decomposed + + model_name = "llama3_405B_f16_decomposed_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_405b_f16_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") + def test_llama3_405B_f16(self): + + # Llama 3.1 405B non-decomposed + + model_name = "llama3_405B_f16_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_405b_f16_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + 
f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") + def test_llama3_405B_fp8_decomposed(self): + + # Llama 3.1 405B decomposed + + model_name = "llama3_405B_fp8_decomposed_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_405b_fp8_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=decomposed", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error") + def test_llama3_405B_fp8(self): + + # Llama 3.1 405B non-decomposed + + model_name = "llama3_405B_fp8_iree" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_iree.main( + [ + f"--irpa-file={self.llama3_405b_fp8_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--iree-device={self.iree_device}", + f"--iree-hal-target-backends={self.iree_hal_target_backends}", + f"--iree-hip-target={self.iree_hip_target}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=torch_sdpa", + f"--num-prompts={self.batch_size}", + ] + ) + + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 + ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity + + self.assertAlmostEqual( + baseline_mean_perplexity, + current_mean_perplexity, + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/evaluate/perplexity_torch_test.py b/sharktank/tests/evaluate/perplexity_torch_test.py new file mode 100644 index 000000000..042132f20 --- /dev/null +++ b/sharktank/tests/evaluate/perplexity_torch_test.py @@ -0,0 +1,274 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import pytest +import json + +from sharktank.evaluate import perplexity_torch + +longrun = pytest.mark.skipif("not config.getoption('longrun')") + + +@pytest.mark.usefixtures( + "get_model_artifacts", "tensor_parallelism_size", "baseline_perplexity_scores" +) +class PerplexityTest(unittest.TestCase): + def setUp(self): + self.current_perplexity_all = {} + self.delta = 5e-1 + self.tensor_parallelism_size = 8 + with open(self.baseline_perplexity_scores, "r") as f: + self.baseline_perplexity = json.load(f) + + @longrun + def test_llama3_8B_f16_decomposed(self): + + # Llama 3.1 8B decomposed + + model_name = "llama3_8B_f16_decomposed" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_8b_f16_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Non-decomposed attention is not supported yet", + ) + @longrun + def test_llama3_8B_f16(self): + + # Llama 3.1 8B non-decomposed + + model_name = "llama3_8B_f16" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_8b_f16_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="FP8 model is unsupported", + ) + @longrun + def test_llama3_8B_fp8_decomposed(self): + + # Llama 3.1 8B decomposed + + model_name = "llama3_8B_fp8_decomposed" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_8b_fp8_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Non-decomposed attention is not supported yet", + ) + @longrun + def test_llama3_8B_fp8(self): + + # Llama 3.1 8B non-decomposed + + model_name = "llama3_8B_fp8" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_8b_fp8_model}", + f"--tokenizer-config-json={self.llama3_8b_tokenizer}", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by 
{perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Sharding needs to be fixed", + ) + @longrun + def test_llama3_405B_f16_decomposed(self): + + # Llama 3.1 405B decomposed + + model_name = "llama3_405B_f16_decomposed" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_405b_f16_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Non-decomposed attention is not supported yet", + ) + @longrun + def test_llama3_405B_f16(self): + + # Llama 3.1 405B non-decomposed + + model_name = "llama3_405B_f16" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_405b_f16_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="FP8 model is unsupported", + ) + @longrun + def test_llama3_405B_fp8_decomposed(self): + + # Llama 3.1 405B decomposed + + model_name = "llama3_405B_fp8_decomposed" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_405b_fp8_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + @pytest.mark.xfail( + reason="Non-decomposed attention is not supported yet", + ) + @longrun + def test_llama3_405B_fp8(self): + + # Llama 3.1 405B non-decomposed + + model_name = "llama3_405B_fp8" + baseline_perplexity = self.baseline_perplexity[model_name] + + current_perplexity = perplexity_torch.main( + [ + f"--irpa-file={self.llama3_405b_fp8_model}", + f"--tokenizer-config-json={self.llama3_405b_tokenizer}", + f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=torch_sdpa", + ] + ) + + perplexity_difference = ( + current_perplexity["mean_perplexity"] + - baseline_perplexity["mean_perplexity"] + ) + + self.assertAlmostEqual( + baseline_perplexity["mean_perplexity"], + current_perplexity["mean_perplexity"], + delta=self.delta, + msg=f"Current perplexity deviates baseline by {perplexity_difference}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/export_test.py b/sharktank/tests/export_test.py new file mode 100644 index 000000000..20b7de734 --- 
/dev/null +++ b/sharktank/tests/export_test.py @@ -0,0 +1,101 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from sharktank.types import ( + ReplicatedTensor, + SplitPrimitiveTensor, + DefaultPrimitiveTensor, + unbox_tensor, +) +from sharktank.export import ( + export, + flatten_signature, + get_argument_flat_device_affinities, +) +from sharktank import ops +from sharktank.utils.testing import ( + assert_equal, + assert_iterables_equal, + assert_dicts_equal, +) +from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder +from iree.turbine import aot +from unittest import TestCase +import torch + + +class ExportTest(TestCase): + def testFlattenSignature(self): + expected_a = [SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0)] + expected_b = {"element": DefaultPrimitiveTensor(data=torch.tensor([2]))} + expected_c = torch.tensor([3]) + + @flatten_signature(expected_a, expected_b, expected_c) + def f( + a: list[SplitPrimitiveTensor], + b: dict[str, DefaultPrimitiveTensor], + c: torch.Tensor, + ): + assert_iterables_equal(a, expected_a, elements_equal=ops.equal) + assert_dicts_equal(b, expected_b, values_equal=ops.equal) + assert_equal(c, expected_c, equal=ops.equal) + + f( + unbox_tensor(expected_a[0].shards[0]), + expected_b["element"].as_torch(), + expected_c, + ) + + def testGetFlatArgumentDeviceAffinities(self): + args = [ + { + "a": [ + SplitPrimitiveTensor( + ts=[torch.tensor([1]), torch.tensor([2])], shard_dim=0 + ) + ] + }, + torch.tensor([3]), + ReplicatedTensor(ts=[torch.tensor([4]), torch.tensor([5])]), + ] + affinities = get_argument_flat_device_affinities(*args) + expected_affinities = { + 0: DeviceAffinity("0"), + 1: DeviceAffinity("1"), + 3: DeviceAffinity("0"), + 4: DeviceAffinity("1"), + } + assert_dicts_equal(affinities, expected_affinities) + + def testExportWithArgumentDeviceAffinities(self): + args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]])) + + class Module(torch.nn.Module): + def f(self, a, b): + return a, b + + module = Module() + fxb = FxProgramsBuilder(module) + export( + Module.f, + fx_builder=fxb, + args=args, + strict=False, + ) + export_output = aot.export( + fxb, + ) + asm = str(export_output.mlir_module) + print(asm) + self.assertRegex( + asm, + expected_regex=( + "func.func @f\\(" + "%.+: !torch.vtensor<\\[1\\],si64> " + "{iree.abi.affinity = #hal.device.promise<@__device_0>}, " + "%.+: !torch.vtensor<\\[1,1\\],si64>\\)" + ), + ) diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py index 30cc2296c..208d54782 100644 --- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py +++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels diff --git a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py index b03293523..ff1430a1a 100644 --- a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py +++ b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py @@ -12,10 +12,10 @@ from parameterized import parameterized import torch +import torch.nn.functional as F -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels -from sharktank.ops.qconv_impls import _pad_last_2d class 
conv_2d_nchw_fchw_test(unittest.TestCase): @@ -36,7 +36,8 @@ def testBS32(self, input_dtype, output_dtype_name, atol, rtol): inputs = (torch.rand([2, 4, 64, 64]) * 64).to(input_dtype) padding = [1, 1] extended_list = [item for item in padding for _ in range(2)] - inputs_pad = _pad_last_2d(inputs, extended_list) + inputs_pad = F.pad(inputs, pad=extended_list) + weights = (torch.rand([8, 4, 3, 3]) * 64).to(input_dtype) bias = (torch.rand([8]) * 64).to(dtype=output_dtype) result = kernels.conv_2d_nchw_fchw( @@ -68,7 +69,7 @@ def forward(self, a, b, c): inputs = torch.rand([2, 320, 64, 64]) * 64 padding = [1, 1] extended_list = [item for item in padding for _ in range(2)] - inputs_pad = _pad_last_2d(inputs, extended_list) + inputs_pad = F.pad(inputs, pad=extended_list) ep = torch.export.export( mod, args=( diff --git a/sharktank/tests/kernels/einsum_q4_test.py b/sharktank/tests/kernels/einsum_q4_test.py new file mode 100644 index 000000000..5f037ba9a --- /dev/null +++ b/sharktank/tests/kernels/einsum_q4_test.py @@ -0,0 +1,141 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest +from parameterized import parameterized + +import torch + +from iree.turbine import aot +from sharktank import kernels +from sharktank.types import layout_utils + + +class einsum_2args_q4_test(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_mk_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 320], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "mk,menk->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_mek_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 4, 320], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "mek,menk->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. 
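+        # The reference path promotes the packed i4 blocks to i8 and rebuilds
+        # the dequantized weights as b = d * qs + m before a plain einsum.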
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + @parameterized.expand( + [ + (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3), + (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3), + (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3), + ] + ) + def test_basic_me_men_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol): + a = torch.rand([2, 4], dtype=a_dtype) / 256.0 + d = torch.rand([2, 4, 10, 1], dtype=d_dtype) / 256.0 + qs = (torch.rand([2, 4, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8) + m = torch.rand([2, 4, 10, 1], dtype=d_dtype) + 16.0 + einsum_string = "me,men->men" + result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string) + + # Dequantize and test with normal matmul. + # Tolerances are empirical and results are not expected to match exactly. + qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs) + b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(2) + ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype)) + torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol) + + def testExportDynamicDims(self): + class MyModule(torch.nn.Module): + def forward(self, a, d, qs, m): + return kernels.einsum_2args_q4(a, d, qs, m, "ij,jk->ik") + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + torch.rand([16, 320], dtype=torch.float32), + torch.rand([320, 2, 1], dtype=torch.float16), + (torch.rand([320, 2, 16], dtype=torch.float32) * 32).to(torch.uint8), + torch.rand([320, 2, 1], dtype=torch.float16), + ), + dynamic_shapes={ + "a": {}, + "d": {}, + "qs": {}, + "m": {}, + }, + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn("@sharktank_einsum_2args_q4_ij_jk_ik_32_f32", asm) + + def testExportStaticDims(self): + class MyModule(torch.nn.Module): + def forward(self, a, d, qs, m): + return kernels.einsum_2args_q4(a, d, qs, m, "mek,menk->men") + + mod = MyModule() + ep = torch.export.export( + mod, + args=( + torch.rand([4, 16, 320], dtype=torch.float32), + torch.rand([4, 16, 2, 10, 1], dtype=torch.float16), + (torch.rand([4, 16, 2, 10, 16], dtype=torch.float32) * 32).to( + torch.uint8 + ), + torch.rand([4, 16, 2, 10, 1], dtype=torch.float16), + ), + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn("@sharktank_einsum_2args_q4_mek_menk_men_32_f32", asm) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py b/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py index d9fc7370a..dca474446 100644 --- a/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py +++ b/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.types import layout_utils diff --git a/sharktank/tests/kernels/mmt_block_scaled_q8_test.py b/sharktank/tests/kernels/mmt_block_scaled_q8_test.py index f3fdf2ed9..08aa8d179 100644 --- a/sharktank/tests/kernels/mmt_block_scaled_q8_test.py +++ b/sharktank/tests/kernels/mmt_block_scaled_q8_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels diff --git 
a/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py b/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py index 41c04106d..01272553a 100644 --- a/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py +++ b/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels from sharktank.types import layout_utils diff --git a/sharktank/tests/kernels/mmtfp_test.py b/sharktank/tests/kernels/mmtfp_test.py index 281498f90..e2c36e4ac 100644 --- a/sharktank/tests/kernels/mmtfp_test.py +++ b/sharktank/tests/kernels/mmtfp_test.py @@ -13,7 +13,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels diff --git a/sharktank/tests/kernels/pooling_nchw_sum_test.py b/sharktank/tests/kernels/pooling_nchw_sum_test.py index 5c4e8ac0a..e512eb484 100644 --- a/sharktank/tests/kernels/pooling_nchw_sum_test.py +++ b/sharktank/tests/kernels/pooling_nchw_sum_test.py @@ -12,10 +12,10 @@ from parameterized import parameterized import torch +import torch.nn.functional as F -from shark_turbine import aot +from iree.turbine import aot from sharktank import kernels -from sharktank.ops.qconv_impls import _pad_last_2d class pooling_nchw_sum_test(unittest.TestCase): @@ -34,7 +34,7 @@ def testBS32(self, atol, rtol): a = (torch.randint(0, 100, (2, 1, 128, 128))).to(torch.float32) padding = [1, 1] extended_list = [item for item in padding for _ in range(2)] - inputs_pad = _pad_last_2d(a, extended_list) + inputs_pad = F.pad(a, pad=extended_list) weight_shape = [3, 3] stride = [1, 1] dilations = [1, 1] @@ -62,7 +62,7 @@ def forward(self, a): inputs = torch.rand([2, 1, 128, 128]) * 64 padding = [1, 1] extended_list = [item for item in padding for _ in range(2)] - inputs_pad = _pad_last_2d(inputs, extended_list) + inputs_pad = F.pad(inputs, pad=extended_list) ep = torch.export.export( mod, args=((inputs_pad).to(dtype),), diff --git a/sharktank/tests/layers/configs_test.py b/sharktank/tests/layers/configs_test.py new file mode 100644 index 000000000..024eb6bb6 --- /dev/null +++ b/sharktank/tests/layers/configs_test.py @@ -0,0 +1,27 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from sharktank.layers.configs import LlamaHParams + + +def test_llama_hp_params_to_from_gguf_props_roundtrip(): + params = LlamaHParams( + model_arch="llama", + context_length=1, + embedding_length=2, + block_count=3, + feed_forward_length=3, + rope_dimension_count=4, + rope_freq_base=5.0, + attention_head_count=6, + attn_head_dim=4, + attention_layer_norm_rms_epsilon=8.0, + attention_head_count_kv=9, + expert_count=10, + expert_used_count=11, + ) + roundtripped_params = LlamaHParams.from_gguf_props(params.to_gguf_props()) + assert params == roundtripped_params diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py new file mode 100644 index 000000000..65b42c986 --- /dev/null +++ b/sharktank/tests/layers/kv_cache_test.py @@ -0,0 +1,502 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest + +import torch + +from sharktank.ops import replicate, reshard_split, unshard +from sharktank.layers import * +from sharktank.types import * + + +def test_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32 + ) + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_direct(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + shard_count = 4 + cache = DirectKVCache( + block_seq_stride=4, + transformer_block_count=transformer_block_count, + 
attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + seq_length=seq_length, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + allocation = cache.allocate(bs=bs) + # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + write_seq_length = seq_length - 5 + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + write_twos = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1 + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + 1, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) + + +def test_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 4 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation] + + # Write a prefill in: + write_ones = torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32 + ) + write_twos = torch.full( + (bs, write_seq_length, attn_head_count, 
attn_head_dim), 2.0, dtype=torch.float32 + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + read_empty = [ + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(write_ones, read_back[0]) + torch.testing.assert_close(write_twos, read_back[1]) + + # Check the others are still zero: + for i in range(transformer_block_count): + if i == 1: + continue + read_ones = [ + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_ones = cache.read( + allocation, + read_into_partitions=read_ones, + transformer_block_index=i, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0)) + torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0)) + + # Write timestep + write_threes = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32 + ) + write_fours = torch.full( + (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32 + ) + write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64) + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + read_empty = [ + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + ] + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([write_ones, write_threes], dim=1) + check_concat_1 = torch.concat([write_twos, write_fours], dim=1) + + torch.testing.assert_close(check_concat_0, read_back[0]) + torch.testing.assert_close(check_concat_1, read_back[1]) + + +def test_sharded_paged(): + bs = 4 + seq_length = 24 + attn_head_count = 8 + attn_head_dim = 16 + transformer_block_count = 4 + block_seq_stride = 4 + shard_count = 4 + cache = PagedKVCache( + block_seq_stride=block_seq_stride, + transformer_block_count=transformer_block_count, + attn_head_count=attn_head_count, + attn_head_dim=attn_head_dim, + shard_count=shard_count, + dtype=torch.float32, + device=None, + ) + + write_seq_length = seq_length - 4 + page_count = bs * seq_length // block_seq_stride + page_ids = torch.arange(page_count, dtype=torch.int64) + page_ids = page_ids.view(bs, seq_length // block_seq_stride) + page_ids = replicate(page_ids, shard_count) + write_page_ids = page_ids[:, : write_seq_length // block_seq_stride] + + allocation = cache.allocate(page_count=page_count) + + # Write a prefill in: + write_ones = reshard_split( + torch.full( + (bs, write_seq_length, attn_head_count, attn_head_dim), + 1.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + write_twos = reshard_split( + torch.full( + (bs, 
write_seq_length, attn_head_count, attn_head_dim), + 2.0, + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + cache.write( + allocation, + cache_partitions=[write_ones, write_twos], + transformer_block_index=1, + page_ids=write_page_ids, + ) + + # Check the written values have updated: + empty_k = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.empty( + (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32 + ), + dim=2, + count=shard_count, + ) + + read_empty = [empty_k, empty_v] + + read_back = cache.read( + allocation, + read_into_partitions=read_empty, + transformer_block_index=1, + seq_len=write_seq_length, + page_ids=write_page_ids, + ) + torch.testing.assert_close(unshard(write_ones), unshard(read_back[0])) + torch.testing.assert_close(unshard(write_twos), unshard(read_back[1])) + + # Write timestep + write_threes = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_fours = reshard_split( + torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32), + dim=2, + count=shard_count, + ) + + write_pos = replicate( + torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count + ) + + cache.write_timestep( + allocation, + cache_partitions=[write_threes, write_fours], + transformer_block_index=1, + seq_positions=write_pos, + page_ids=page_ids, + ) + + empty_k = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + empty_v = reshard_split( + torch.zeros( + (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim), + dtype=torch.float32, + ), + dim=2, + count=shard_count, + ) + + read_back = cache.read( + allocation, + read_into_partitions=[empty_k, empty_v], + transformer_block_index=1, + seq_len=write_seq_length + 1, + page_ids=page_ids, + ) + + check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1) + check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1) + + torch.testing.assert_close(check_concat_0, unshard(read_back[0])) + torch.testing.assert_close(check_concat_1, unshard(read_back[1])) diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py index e2d038f72..ad657889d 100644 --- a/sharktank/tests/layers/linear_test.py +++ b/sharktank/tests/layers/linear_test.py @@ -84,7 +84,7 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self): bias_quant, ] ) - linear = LinearLayer(theta) + linear = LinearLayer(theta, fake_quant=False) output = linear(lhs) output_ref = torch.matmul(lhs, rhs.T) + bias diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py new file mode 100644 index 000000000..a55782329 --- /dev/null +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -0,0 +1,196 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest + +import torch + +from iree.turbine import aot +from sharktank.layers import ( + PagedLlamaAttentionBlock, + PagedKVCache, + RotaryEmbeddingLayer, +) +from sharktank.layers.testing import make_llama_attention_block_theta +from sharktank.types.tensors import DefaultPrimitiveTensor + + +class PagedLlamaAttentionBlockTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(12345) + self.transformer_block_count = 13 + self.block_index = 1 + self.shard_count = 3 + self.head_count_kv = 2 * self.shard_count + self.attention_head_count = 5 * self.head_count_kv + self.attention_head_dim = 11 * 2 + self.rms_epsilon = 0.01 + self.block_seq_stride = 17 + self.cache_partition_count = 2 + self.page_count = 23 + self.embedding_length = self.attention_head_count * self.attention_head_dim + self.rope_dimension_count = self.attention_head_dim + self.block_seqlen = 7 + self.max_seqlen = self.block_seq_stride * self.block_seqlen + self.rope_freq_base = None + self.batch_size = 3 + self.start_index = 0 + + def testExportDecomposed(self): + dtype = torch.float32 + + cache = PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.head_count_kv, + attn_head_dim=self.attention_head_dim, + cache_partition_count=self.cache_partition_count, + block_seq_stride=self.block_seq_stride, + dtype=dtype, + ) + + cache_state = cache.paged.allocate(self.page_count) + cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype) + + theta = make_llama_attention_block_theta( + head_count=self.attention_head_count, + head_count_kv=self.head_count_kv, + head_dim=self.attention_head_dim, + embedding_length=self.embedding_length, + ) + attn = PagedLlamaAttentionBlock( + theta=theta, + block_index=self.block_index, + cache=cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + attention_kernel="decomposed", + ) + + seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view( + self.batch_size, -1 + ) + + embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + ) + + class MyModule(torch.nn.Module): + def forward(self, h, seq_block_ids, cache_state): + return attn.forward( + h, + seq_block_ids=seq_block_ids, + embedding=embedding_module, + start_index=0, + cache_state=cache_state, + ) + + mod = MyModule() + h = torch.rand( + [ + self.batch_size, + self.max_seqlen, + self.attention_head_count * self.attention_head_dim, + ] + ) + mod.forward(h, seq_block_ids, cache_state) + ep = torch.export.export( + mod, + args=( + h, + seq_block_ids, + cache_state, + ), + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertNotIn("scaled_dot_product_attention", asm) + + def testExportNondecomposed(self): + dtype = torch.float32 + + cache = PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.head_count_kv, + attn_head_dim=self.attention_head_dim, + cache_partition_count=self.cache_partition_count, + block_seq_stride=self.block_seq_stride, + dtype=dtype, + ) + + cache_state = cache.paged.allocate(self.page_count) + cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype) + + theta = make_llama_attention_block_theta( + head_count=self.attention_head_count, + 
head_count_kv=self.head_count_kv, + head_dim=self.attention_head_dim, + embedding_length=self.embedding_length, + ) + attn = PagedLlamaAttentionBlock( + theta=theta, + block_index=self.block_index, + cache=cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + attention_kernel="torch", + ) + + seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view( + self.batch_size, -1 + ) + + embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + ) + + class MyModule(torch.nn.Module): + def forward(self, h, seq_block_ids, cache_state): + return attn.forward( + h, + seq_block_ids=seq_block_ids, + embedding=embedding_module, + start_index=0, + cache_state=cache_state, + ) + + mod = MyModule() + h = torch.rand( + [ + self.batch_size, + self.max_seqlen, + self.attention_head_count * self.attention_head_dim, + ] + ) + mod.forward(h, seq_block_ids, cache_state) + ep = torch.export.export( + mod, + args=( + h, + seq_block_ids, + cache_state, + ), + ) + output = aot.export(ep) + output.verify() + asm = str(output.mlir_module) + self.assertIn("torch.aten._scaled_dot_product_flash_attention_for_cpu", asm) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py new file mode 100644 index 000000000..9b29e5761 --- /dev/null +++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py @@ -0,0 +1,210 @@ +import unittest + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path +import tempfile +import torch +from iree.turbine import aot +from sharktank.models.punet.layers import Conv2DLayer +from sharktank import ops +from sharktank.types import ( + Dataset, + DefaultPrimitiveTensor, + Theta, + ShardedTensor, + SplitPrimitiveTensor, + unbox_tensor, +) +from sharktank.types.sharding import Conv2DSplitOutputChannelSharding +import iree.runtime +from typing import List, Optional +import os + +vm_context: iree.runtime.VmContext = None + + +def get_compiler_args(target_device_kind: str, shard_count: int) -> List[str]: + result = [ + f"--iree-hal-target-device={target_device_kind}[{i}]" + for i in range(shard_count) + ] + return result + + +def compile_iree_module( + export_output: aot.ExportOutput, module_path: str, shard_count: int +): + export_output.session.set_flags( + *get_compiler_args(target_device_kind="llvm-cpu", shard_count=shard_count) + ) + export_output.compile(save_to=module_path, target_backends=None) + + +# TODO: improve IREE's Python API to be more concise in a multi-device context. +# This run function should be way shorter. +def run_iree_module( + sharded_input_image: ShardedTensor, + module_path: str, + parameters_path: str, +) -> ShardedTensor: + shard_count = sharded_input_image.shard_count + hal_driver = iree.runtime.get_driver("local-task") + vm_instance = iree.runtime.VmInstance() + available_devices = hal_driver.query_available_devices() + # Use the same actual device for all devices. 
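+ # Backing every logical HAL device with the one available local-task (CPU) + # device keeps the test runnable on machines that do not have shard_count + # real devices; only distinct device handles are needed here.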
+ devices = [ + hal_driver.create_device(available_devices[0]) for _ in range(shard_count) + ] + hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices) + params_path = Path(parameters_path) + # TODO: make IREE able to load the parameters from the top parameter file + # without having to specify the parameter file for each shard separately. + parameter_index = iree.runtime.ParameterIndex() + for i in range(shard_count): + parameter_index.load( + file_path=str( + Path(params_path).with_suffix(f".rank{i}{params_path.suffix}") + ) + ) + parameter_provider = parameter_index.create_provider(scope="model") + parameters_module = iree.runtime.create_io_parameters_module( + vm_instance, parameter_provider + ) + + vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path)) + + # The context needs to be destroyed after the buffers, although + # it is not associated with them on the API level. + global vm_context + vm_context = iree.runtime.VmContext( + instance=vm_instance, modules=(hal_module, parameters_module, vm_module) + ) + module_input_args = [ + iree.runtime.asdevicearray( + devices[i], sharded_input_image.shards[i].as_torch().to("cpu").numpy() + ) + for i in range(shard_count) + ] + + vm_function = vm_module.lookup_function("main") + invoker = iree.runtime.FunctionInvoker( + vm_context=vm_context, + # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices. + # This works, but does not look right. + device=devices[0], + vm_function=vm_function, + ) + results = invoker(*module_input_args) + shards = [torch.tensor(tensor.to_host()) for tensor in results] + return SplitPrimitiveTensor(ts=shards, shard_dim=1) + + + def run_test_sharded_conv2d_with_iree( + mlir_path: Path, module_path: Path, parameters_path: Path, caching: bool + ): + torch.set_default_dtype(torch.float32) + torch.manual_seed(123456) + batches = 2 + in_channels = 6 + out_channels = 8 + height = 11 + width = 13 + kernel_height = 5 + kernel_width = 5 + shard_count = 2 + unsharded_theta = Theta( + { + "weight": DefaultPrimitiveTensor( + data=torch.rand( + out_channels, + in_channels, + kernel_height, + kernel_width, + ) + ), + } + ) + unsharded_theta.rename_tensors_to_paths() + + if not caching or not os.path.exists(parameters_path): + sharding_spec = Conv2DSplitOutputChannelSharding(shard_count=shard_count) + sharded_theta = ops.reshard(unsharded_theta, sharding_spec) + + # Roundtrip the dataset, which anchors the tensors as parameters to be loaded + # vs constants to be frozen (TODO: This is a bit wonky).
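+ # Saving the sharded dataset below produces one ".rank{i}" parameter file per + # shard (e.g. params.rank0.irpa), which run_iree_module() later loads into + # the IREE parameter index.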
+ sharded_dataset = Dataset({}, sharded_theta) + sharded_dataset.save(parameters_path) + + sharded_dataset = Dataset.load(parameters_path) + + input_image = torch.rand( + batches, + in_channels, + height, + width, + ) + + sharded_torch_module = Conv2DLayer(sharded_dataset.root_theta, padding=(0, 0)) + sharded_input_image = ops.reshard_split(input_image, dim=1, count=shard_count) + expected_result = sharded_torch_module(sharded_input_image) + + if not caching or not os.path.exists(module_path): + exported_module = aot.export( + sharded_torch_module, + args=(sharded_input_image,), + ) + exported_module.save_mlir(mlir_path) + + compile_iree_module( + export_output=exported_module, + module_path=module_path, + shard_count=shard_count, + ) + + actual_result = run_iree_module( + sharded_input_image=sharded_input_image, + module_path=module_path, + parameters_path=parameters_path, + ) + assert len(actual_result.shards) == len(expected_result.shards) + assert actual_result.shard_dim == expected_result.shard_dim + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) + + +def test_sharded_conv2d_with_iree( + mlir_path: Optional[Path], + module_path: Optional[Path], + parameters_path: Optional[Path], + caching: bool, +): + """Test sharding, exporting and running with IREE a 2D convolution layer.""" + + with tempfile.TemporaryDirectory( + # TODO: verify hypothesis and remove ignore_cleanup_errors=True after a fix. + # torch.export.export is spawning some processes that don't exit when the + # function returns, this causes some objects to not get destroyed, which + # in turn holds files params.rank0.irpa and params.rank1.irpa open. + ignore_cleanup_errors=True + ) as tmp_dir: + mlir_path = Path(tmp_dir) / "model.mlir" if mlir_path is None else mlir_path + module_path = ( + Path(tmp_dir) / "module.vmfb" if module_path is None else module_path + ) + parameters_path = ( + Path(tmp_dir) / "params.irpa" + if parameters_path is None + else parameters_path + ) + run_test_sharded_conv2d_with_iree( + mlir_path, module_path, parameters_path, caching + ) diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py new file mode 100644 index 000000000..d7b6a0b33 --- /dev/null +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -0,0 +1,236 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from sharktank.layers import PagedKVCache +import torch +from sharktank.utils import iterables_equal +from copy import deepcopy +from typing import List, Tuple +from sharktank import ops +from sharktank.types import SplitPrimitiveTensor + + +class ShardedPagedKVCacheTest(unittest.TestCase): + """Verify that the sharded paged KV cache behaves as the unsharded variant.""" + + def setUp(self): + torch.manual_seed(12345) + self.dtype = torch.float32 + torch.set_default_dtype(self.dtype) + self.shard_count = 3 + self.transformer_block_count = 5 + self.attn_head_count = self.shard_count * 7 + self.block_seq_stride = 19 + self.attn_head_dim = 17 + self.cache_partition_count = 2 + self.page_count = 23 + self.batch_size = 11 + self.block_seq_len = 2 + self.max_seq_len = self.block_seq_len * self.block_seq_stride + + self.cache = PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.attn_head_count, + block_seq_stride=self.block_seq_stride, + attn_head_dim=self.attn_head_dim, + cache_partition_count=self.cache_partition_count, + dtype=self.dtype, + ) + self.sharded_cache = PagedKVCache( + shard_count=self.shard_count, + transformer_block_count=self.transformer_block_count, + attn_head_count=self.attn_head_count, + block_seq_stride=self.block_seq_stride, + attn_head_dim=self.attn_head_dim, + cache_partition_count=self.cache_partition_count, + dtype=self.dtype, + ) + + def make_unsharded_and_sharded_equal_cache_states( + self, + ) -> Tuple[List[torch.Tensor], List[SplitPrimitiveTensor]]: + cache_state = self.cache.allocate(self.page_count) + cache_state[0] = torch.rand_like(cache_state[0]) + sharded_cache_state = self.sharded_cache.shard_state(deepcopy(cache_state)) + self.assert_equal_unsharded_and_sharded_cache_states( + cache_state, sharded_cache_state + ) + return cache_state, sharded_cache_state + + def assert_equal_unsharded_and_sharded_cache_states( + self, + cache_state: List[torch.Tensor], + sharded_cache_state: List[SplitPrimitiveTensor], + ): + sharded_state_as_unsharded = ops.unshard( + self.sharded_cache.unflatten_page_table(sharded_cache_state) + ).flatten(start_dim=1) + assert ops.equal( + cache_state[0], + sharded_state_as_unsharded, + ) + + def testAllocate(self): + cache_state = self.cache.allocate(self.page_count) + sharded_cache_state = self.sharded_cache.allocate(self.page_count) + assert len(cache_state) == 1 + assert len(sharded_cache_state) == 1 + assert iterables_equal(cache_state[0].shape, sharded_cache_state[0].shape) + assert sharded_cache_state[0].shard_dim == 1 + assert sharded_cache_state[0].shard_count == self.shard_count + + def testUnflattenPageTable(self): + cache_state = self.cache.allocate(self.page_count) + sharded_cache_state = self.sharded_cache.allocate(self.page_count) + + unflattened_cache_state = self.cache.unflatten_page_table(cache_state) + sharded_unflattened_cache_state = self.sharded_cache.unflatten_page_table( + sharded_cache_state + ) + assert iterables_equal( + unflattened_cache_state.shape, sharded_unflattened_cache_state.shape + ) + assert sharded_unflattened_cache_state.shard_dim == 4 + assert sharded_unflattened_cache_state.shard_count == self.shard_count + assert sharded_unflattened_cache_state.shape[0] == self.page_count + + def testRead(self): + ( + cache_state, + sharded_cache_state, + ) = self.make_unsharded_and_sharded_equal_cache_states() + + read_into_partitions_snapshot = [ + torch.rand( + self.batch_size, + 
self.block_seq_len * self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ) + for _ in range(self.cache_partition_count) + ] + read_into_partitions = deepcopy(read_into_partitions_snapshot) + transformer_block_index = 1 + page_ids = torch.randint( + low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len] + ).reshape([self.batch_size, self.block_seq_len]) + self.cache.read( + state=cache_state, + read_into_partitions=read_into_partitions, + transformer_block_index=transformer_block_index, + page_ids=page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, + ) + sharded_read_into_partitions = deepcopy( + [ + ops.reshard_split(t, dim=2, count=self.shard_count) + for t in read_into_partitions_snapshot + ] + ) + sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + self.sharded_cache.read( + state=sharded_cache_state, + read_into_partitions=sharded_read_into_partitions, + transformer_block_index=transformer_block_index, + page_ids=sharded_page_ids, + seq_len=self.block_seq_len * self.block_seq_stride, + ) + for unsharded, sharded in zip( + read_into_partitions, sharded_read_into_partitions + ): + assert ops.equal(unsharded, ops.unshard(sharded)) + + def testWriteTimestep(self): + ( + cache_state, + sharded_cache_state, + ) = self.make_unsharded_and_sharded_equal_cache_states() + + cache_partitions = [ + torch.rand( + self.batch_size, + 1, + self.attn_head_count, + self.attn_head_dim, + ) + for _ in range(self.cache_partition_count) + ] + transformer_block_index = 1 + seq_positions = torch.randint( + low=0, high=self.max_seq_len, size=[self.batch_size] + ) + page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape( + [self.batch_size, self.block_seq_len] + ) + self.cache.write_timestep( + state=cache_state, + cache_partitions=cache_partitions, + transformer_block_index=transformer_block_index, + seq_positions=seq_positions, + page_ids=page_ids, + ) + sharded_cache_partitions = deepcopy( + [ + ops.reshard_split(t, dim=2, count=self.shard_count) + for t in cache_partitions + ] + ) + sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + self.sharded_cache.write_timestep( + state=sharded_cache_state, + cache_partitions=sharded_cache_partitions, + transformer_block_index=transformer_block_index, + seq_positions=sharded_seq_positions, + page_ids=sharded_page_ids, + ) + self.assert_equal_unsharded_and_sharded_cache_states( + cache_state, sharded_cache_state + ) + + def testWrite(self): + ( + cache_state, + sharded_cache_state, + ) = self.make_unsharded_and_sharded_equal_cache_states() + + cache_partitions = [ + torch.rand( + self.batch_size, + self.block_seq_len * self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ) + for _ in range(self.cache_partition_count) + ] + transformer_block_index = 1 + assert self.batch_size * self.block_seq_len <= self.page_count + page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape( + [self.batch_size, self.block_seq_len] + ) + self.cache.write( + state=cache_state, + cache_partitions=cache_partitions, + transformer_block_index=transformer_block_index, + page_ids=page_ids, + ) + sharded_cache_partitions = deepcopy( + [ + ops.reshard_split(t, dim=2, count=self.shard_count) + for t in cache_partitions + ] + ) + sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + self.sharded_cache.write( + state=sharded_cache_state, + 
cache_partitions=sharded_cache_partitions, + transformer_block_index=transformer_block_index, + page_ids=sharded_page_ids, + ) + self.assert_equal_unsharded_and_sharded_cache_states( + cache_state, sharded_cache_state + ) diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py new file mode 100644 index 000000000..c94fd44ab --- /dev/null +++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py @@ -0,0 +1,163 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +from sharktank.layers import ( + PagedLlamaAttentionBlock, + PagedKVCache, + RotaryEmbeddingLayer, +) +from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch +from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding +from sharktank.types import SplitPrimitiveTensor, unbox_tensor +import torch +from sharktank import ops +from copy import deepcopy +import pytest + + +class ShardedPagedLlamaAttentionBlockTest(unittest.TestCase): + """Verify that the sharded Llama paged attention block behaves in PyTorch as the + unsharded variant.""" + + def setUp(self): + torch.manual_seed(12345) + self.transformer_block_count = 13 + self.block_index = 1 + self.shard_count = 3 + self.head_count_kv = 2 * self.shard_count + self.attention_head_count = 5 * self.head_count_kv + self.attention_head_dim = 11 * 2 + self.rms_epsilon = 0.01 + self.block_seq_stride = 17 + self.cache_partition_count = 2 + self.page_count = 23 + self.embedding_length = self.attention_head_count * self.attention_head_dim + self.rope_dimension_count = self.attention_head_dim + self.block_seqlen = 7 + self.max_seqlen = self.block_seq_stride * self.block_seqlen + self.rope_freq_base = None + self.batch_size = 3 + self.start_index = 0 + + def testSmallSizedLayerFp64(self): + self.runTestSmallSizedLayer(dtype=torch.float64) + + @pytest.mark.xfail( + reason="The accuracy seems low (atol=0.0018, rtol=0.5065)", + strict=True, + raises=AssertionError, + ) + def testSmallSizedLayerFp32(self): + self.runTestSmallSizedLayer(dtype=torch.float32) + + def runTestSmallSizedLayer(self, dtype: torch.dtype): + torch.set_default_dtype(dtype) + + def make_paged_kv_cache(shard_count: int) -> PagedKVCache: + return PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.head_count_kv, + attn_head_dim=self.attention_head_dim, + cache_partition_count=self.cache_partition_count, + block_seq_stride=self.block_seq_stride, + dtype=dtype, + shard_count=shard_count, + ) + + cache = make_paged_kv_cache(shard_count=1) + sharded_cache = make_paged_kv_cache(shard_count=self.shard_count) + + def make_unsharded_and_sharded_equal_cache_states() -> tuple[ + list[torch.Tensor], list[SplitPrimitiveTensor] + ]: + cache_state = cache.allocate(self.page_count) + cache_state[0] = make_rand_torch(cache_state[0].shape, dtype=dtype) + sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state)) + return cache_state, sharded_cache_state + + ( + cache_state, + sharded_cache_state, + ) = make_unsharded_and_sharded_equal_cache_states() + + input_tensor = make_rand_torch( + ( + self.batch_size, + self.max_seqlen, + self.attention_head_count * self.attention_head_dim, + ), + dtype=dtype, + ) + seq_block_ids = torch.arange(self.batch_size * 
self.block_seqlen).view( + self.batch_size, -1 + ) + embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + ) + + theta = make_llama_attention_block_theta( + head_count=self.attention_head_count, + head_count_kv=self.head_count_kv, + head_dim=self.attention_head_dim, + embedding_length=self.embedding_length, + ) + attention_block = PagedLlamaAttentionBlock( + theta=theta, + block_index=self.block_index, + cache=cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + ) + expected_result = attention_block( + input_tensor, + embedding=embedding_module, + seq_block_ids=seq_block_ids, + start_index=self.start_index, + cache_state=cache_state, + ) + + sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count) + sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count) + sharded_embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + rope_freq_base=self.rope_freq_base, + tensor_parallelism_size=self.shard_count, + ) + + theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count) + sharded_theta = ops.reshard(theta, theta_sharding) + sharded_attention_block = PagedLlamaAttentionBlock( + theta=sharded_theta, + block_index=self.block_index, + cache=sharded_cache, + head_count=self.attention_head_count, + head_dim=self.attention_head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + ) + sharded_result = sharded_attention_block( + sharded_input_tensor, + embedding=sharded_embedding_module, + seq_block_ids=sharded_seq_block_ids, + start_index=self.start_index, + cache_state=sharded_cache_state, + ) + + actual_result = unbox_tensor(ops.unshard(sharded_result)) + actual_cache_state = unbox_tensor( + ops.unshard( + sharded_cache.unflatten_page_table(sharded_cache_state) + ).flatten(start_dim=1) + ) + + torch.testing.assert_close(actual_result, expected_result) + torch.testing.assert_close(actual_cache_state, cache_state[0]) diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py new file mode 100644 index 000000000..f24b8313a --- /dev/null +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import torch + +from sharktank.layers import RotaryEmbeddingLayer +from sharktank import ops +from sharktank.types import ( + ShardedTensor, + SplitPrimitiveTensor, + unbox_tensor, +) + +import unittest +from typing import List, Optional +import os + + +def test_sharded_rotary_table(): + bs = 4 + rope_dims = 16 + heads = 8 + max_seqlen = 128 + rope_freq_base = None + + # First we setup and get the default rotary embedding layer + xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) + xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float) + default_layer = RotaryEmbeddingLayer( + rope_dimension_count=rope_dims, + max_seqlen=max_seqlen, + rope_freq_base=rope_freq_base, + ) + oq = default_layer(xt=xq, start_index=0) + ok = default_layer(xt=xk, start_index=0) + + # Then we can shard the same inputs and layer + xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4) + xk = SplitPrimitiveTensor(ts=xk, shard_dim=2, shard_count=4) + shard_layer = RotaryEmbeddingLayer( + rope_dimension_count=rope_dims, + max_seqlen=max_seqlen, + rope_freq_base=rope_freq_base, + tensor_parallelism_size=4, + ) + sq = shard_layer(xt=xq, start_index=0) + sk = shard_layer(xt=xk, start_index=0) + + # Gathering and unboxing should yield the same results + sq = ops.unshard(sq) + sk = ops.unshard(sk) + + torch.testing.assert_close(sq, oq) + torch.testing.assert_close(sk, ok) diff --git a/sharktank/tests/models/llama/README.md b/sharktank/tests/models/llama/README.md new file mode 100644 index 000000000..6adf38588 --- /dev/null +++ b/sharktank/tests/models/llama/README.md @@ -0,0 +1,14 @@ +# How to run Llama 3.1 Benchmarking Tests +In order to run Llama 3.1 8B F16 Decomposed test: +``` +pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s \ + --run-quick-test --iree-hip-target=gfx942 +``` + +In order to filter by test, use the -k option. If you +wanted to only run the Llama 3.1 70B F16 Decomposed test: +``` +pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s \ + --run-nightly-llama-tests --iree-hip-target=gfx942 \ + -k 'testBenchmark70B_f16_TP8_Decomposed' +``` diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py index bb5eb254d..211fab5a0 100644 --- a/sharktank/tests/models/llama/attention_test.py +++ b/sharktank/tests/models/llama/attention_test.py @@ -25,6 +25,7 @@ class AttentionBlockTest(unittest.TestCase): def test(self): + torch.manual_seed(123456) torch.set_default_dtype(torch.float32) block_index = 0 seq_len = 13 @@ -58,7 +59,7 @@ def test(self): head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, - use_hf=True, + attention_kernel="decomposed", ) attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=rope_dimension_count, @@ -148,7 +149,9 @@ def test(self): )[0] assert sharktank_output.shape == huggingface_output.shape - torch.testing.assert_close(sharktank_output, huggingface_output) + torch.testing.assert_close( + sharktank_output, huggingface_output, atol=1e-5, rtol=5e-2 + ) if __name__ == "__main__": diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py new file mode 100644 index 000000000..751615a85 --- /dev/null +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -0,0 +1,930 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from datetime import datetime +import os +import sys +import unittest +import pytest +import subprocess +from pathlib import Path +from typing import List +from sharktank.utils.export_artifacts import ( + ExportArtifacts, + ExportMlirException, + IreeBenchmarkException, + IreeCompileException, +) + +is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") +skipif_run_quick_llama_test = pytest.mark.skipif( + 'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")', + reason="Skipping largs tests when --run-quick-llama-test is set.", +) + + +@pytest.mark.usefixtures("get_iree_flags") +class BaseBenchmarkTest(unittest.TestCase): + directory_created = False + current_date = datetime.now() + dir_path_suffix = current_date.strftime("%Y-%m-%d") + cur_dir = os.path.dirname(os.path.abspath(__file__)) + models_dir = os.path.dirname(cur_dir) + tests_dir = os.path.dirname(models_dir) + sharktank_dir = os.path.dirname(tests_dir) + repo_root = os.path.dirname(sharktank_dir) + dir_path = Path(repo_root + "/" + dir_path_suffix) + + @classmethod + def setUpClass(cls): + """This method will be run once per class to create the directory.""" + if not cls.directory_created: + if not os.path.exists(cls.dir_path): + os.makedirs(cls.dir_path) + cls.directory_created = True + + def setUp(self): + self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0") + self.compile_args = [ + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-data-tiling=false", + "--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'", + "--iree-stream-resource-memory-model=discrete", + "--iree-hip-legacy-sync=false", + "--iree-hal-indirect-command-buffers=true", + "--iree-hal-memoization=true", + "--iree-opt-strip-assertions", + ] + + +@is_mi300x +class BenchmarkLlama3_1_8B(BaseBenchmarkTest): + def setUp(self): + super().setUp() + # TODO: add numpy files to Azure and download from it + self.artifacts_dir = Path("/data/llama3.1/weights/8b") + self.irpa_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.irpa" + self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_8b_fp8.irpa" + self.tensor_parallelism_size = 1 + self.dir_path_8b = self.dir_path / "llama-8b" + self.temp_dir_8b = Path(self.dir_path_8b) + self.temp_dir_8b.mkdir(parents=True, exist_ok=True) + self.llama8b_f16_decomposed_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama8b_f16_torch_sdpa_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="torch", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama8b_fp8_decomposed_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path_fp8), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama8b_fp8_torch_sdpa_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path_fp8), + batch_size=4, + 
iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="torch", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.prefill_args_f16 = self.artifacts_dir / "prefill_args" + self.prefill_args_bs4_128_in_tokens_f16 = ( + self.artifacts_dir / "prefill_args_bs4_128" + ) + self.decode_args_f16 = self.artifacts_dir / "decode_args" + self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8" + self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" + self.iree_run_prefill_args = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_f16}/tokens.npy", + f"--input=@{self.prefill_args_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_nondecomposed_args_fp16 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_f16}/tokens.npy", + f"--input=@{self.decode_args_f16}/seq_lens.npy", + f"--input=@{self.decode_args_f16}/start_positions.npy", + f"--input=@{self.decode_args_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_args_fp8 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_fp8}/tokens.npy", + f"--input=@{self.prefill_args_fp8}/seq_lens.npy", + f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy", + f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args_fp8 = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_fp8}/tokens.npy", + f"--input=@{self.decode_args_fp8}/seq_lens.npy", + f"--input=@{self.decode_args_fp8}/start_positions.npy", + f"--input=@{self.decode_args_fp8}/seq_block_ids.npy", + f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + + def testBenchmark8B_f16_Decomposed(self): + output_file_name = self.dir_path_8b / "f16_decomposed" + output_mlir = self.llama8b_f16_decomposed_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama8b_f16_decomposed_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_f16_decomposed_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + export_return_code = self.llama8b_f16_decomposed_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama8b_f16_decomposed_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @skipif_run_quick_llama_test + def 
testBenchmark8B_f16_Non_Decomposed_Prefill(self): + output_file_name = self.dir_path_8b / "f16_torch_prefill" + output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + skip_decode=True, + ) + self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_nondecomposed_args_fp16, + cwd=self.repo_root, + ) + + @skipif_run_quick_llama_test + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + def testBenchmark8B_f16_Non_Decomposed(self): + output_file_name = self.dir_path_8b / "f16_torch" + output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + def testBenchmark8B_fp8_Decomposed(self): + output_file_name = self.dir_path_8b / "fp8_decomposed" + output_mlir = self.llama8b_fp8_decomposed_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama8b_fp8_decomposed_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_fp8_decomposed_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + export_return_code = self.llama8b_fp8_decomposed_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama8b_fp8_decomposed_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama8b_fp8_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + 
irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama8b_fp8_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + def testBenchmark8B_fp8_Non_Decomposed(self): + output_file_name = self.dir_path_8b / "fp8_torch" + output_mlir = self.llama8b_fp8_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama8b_fp8_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_fp8_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + export_return_code = self.llama8b_fp8_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama8b_fp8_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + +@is_mi300x +@skipif_run_quick_llama_test +class BenchmarkLlama3_1_70B(BaseBenchmarkTest): + def setUp(self): + super().setUp() + # TODO: add numpy files to Azure and download from it + self.artifacts_dir = Path("/data/llama3.1/weights/70b") + self.irpa_path = self.artifacts_dir / "fp16/llama3.1_70b_f16.irpa" + self.irpa_path_fp8 = self.artifacts_dir / "f8/llama70b_fp8.irpa" + self.tensor_parallelism_size = 8 + self.dir_path_70b = self.dir_path / "llama-70b" + self.temp_dir_70b = Path(self.dir_path_70b) + self.temp_dir_70b.mkdir(parents=True, exist_ok=True) + self.llama70b_f16_decomposed_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama70b_f16_torch_sdpa_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="torch", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama70b_fp8_decomposed_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path_fp8), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama70b_fp8_torch_sdpa_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path_fp8), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="torch", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.prefill_args_f16 = self.artifacts_dir / "prefill_args" + self.prefill_args_bs4_128_in_tokens_f16 = ( + self.artifacts_dir / "prefill_args_bs4_128" + ) + self.decode_args_f16 = self.artifacts_dir / "decode_args" + self.prefill_args_fp8 = self.artifacts_dir / 
"prefill_args_fp8" + self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" + self.iree_run_prefill_args = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_f16}/tokens.npy", + f"--input=@{self.prefill_args_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_nondecomposed_args_fp16 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_f16}/tokens.npy", + f"--input=@{self.decode_args_f16}/seq_lens.npy", + f"--input=@{self.decode_args_f16}/start_positions.npy", + f"--input=@{self.decode_args_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_args_fp8 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_fp8}/tokens.npy", + f"--input=@{self.prefill_args_fp8}/seq_lens.npy", + f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy", + f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args_fp8 = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_fp8}/tokens.npy", + f"--input=@{self.decode_args_fp8}/seq_lens.npy", + f"--input=@{self.decode_args_fp8}/start_positions.npy", + f"--input=@{self.decode_args_fp8}/seq_block_ids.npy", + f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + + @pytest.mark.xfail( + reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException + ) + def testBenchmark70B_f16_TP8_Decomposed(self): + output_file_name = self.dir_path_70b / "f16_decomposed" + output_mlir = self.llama70b_f16_decomposed_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama70b_f16_decomposed_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama70b_f16_decomposed_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = ( + self.artifacts_dir + / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama70b_f16_decomposed_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama70b_f16_decomposed_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + def 
testBenchmark70B_f16_TP8_Non_Decomposed(self): + output_file_name = self.dir_path_70b / "f16_torch" + output_mlir = self.llama70b_f16_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama70b_f16_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama70b_f16_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + self.llama70b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + output_shard_file_name = ( + self.artifacts_dir + / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama70b_f16_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama70b_f16_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama70b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama70b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + def testBenchmark70B_fp8_TP8_Decomposed(self): + output_file_name = self.dir_path_70b / "fp8_decomposed" + output_mlir = self.llama70b_fp8_decomposed_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama70b_fp8_decomposed_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama70b_fp8_decomposed_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = ( + self.artifacts_dir + / f"f8/tp8/llama3.1_70b_fp8_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama70b_fp8_decomposed_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama70b_fp8_decomposed_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama70b_fp8_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama70b_fp8_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException) + def testBenchmark70B_fp8_TP8_Non_Decomposed(self): + output_file_name = self.dir_path_70b / "fp8_torch" + output_mlir = self.llama70b_fp8_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama70b_fp8_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = 
self.llama70b_fp8_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = ( + self.artifacts_dir + / f"f8/tp8/llama3.1_70b_f8_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama70b_fp8_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama70b_fp8_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + +@is_mi300x +@skipif_run_quick_llama_test +class BenchmarkLlama3_1_405B(BaseBenchmarkTest): + def setUp(self): + super().setUp() + # TODO: add numpy files to Azure and download from it + self.artifacts_dir = Path("/data/llama3.1/weights/405b") + self.irpa_path = self.artifacts_dir / "fp16/llama3.1_405b_fp16.irpa" + self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_405b_fp8.irpa" + self.tensor_parallelism_size = 8 + self.dir_path_405b = self.dir_path / "llama-405b" + self.temp_dir_405b = Path(self.dir_path_405b) + self.temp_dir_405b.mkdir(parents=True, exist_ok=True) + self.llama405b_f16_decomposed_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama405b_f16_torch_sdpa_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="torch", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama405b_fp8_decomposed_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path_fp8), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.llama405b_fp8_torch_sdpa_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path_fp8), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="torch", + tensor_parallelism_size=self.tensor_parallelism_size, + ) + self.prefill_args_f16 = self.artifacts_dir / "prefill_args" + self.prefill_args_bs4_128_in_tokens_f16 = ( + self.artifacts_dir / "prefill_args_bs4_128" + ) + self.decode_args_f16 = self.artifacts_dir / "decode_args" + self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8" + self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" + self.iree_run_prefill_args = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_f16}/tokens.npy", + f"--input=@{self.prefill_args_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_nondecomposed_args_fp16 = [ + "--function=prefill_bs4", + 
f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy", + f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_f16}/tokens.npy", + f"--input=@{self.decode_args_f16}/seq_lens.npy", + f"--input=@{self.decode_args_f16}/start_positions.npy", + f"--input=@{self.decode_args_f16}/seq_block_ids.npy", + f"--input=@{self.decode_args_f16}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_prefill_args_fp8 = [ + "--function=prefill_bs4", + f"--input=@{self.prefill_args_fp8}/tokens.npy", + f"--input=@{self.prefill_args_fp8}/seq_lens.npy", + f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy", + f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + self.iree_run_decode_args_fp8 = [ + "--function=decode_bs4", + f"--input=@{self.decode_args_fp8}/tokens.npy", + f"--input=@{self.decode_args_fp8}/seq_lens.npy", + f"--input=@{self.decode_args_fp8}/start_positions.npy", + f"--input=@{self.decode_args_fp8}/seq_block_ids.npy", + f"--input=@{self.decode_args_fp8}/cache_state_f16.npy", + "--benchmark_repetitions=3", + ] + + @pytest.mark.xfail( + reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException + ) + def testBenchmark405B_f16_TP8_Decomposed(self): + output_file_name = self.dir_path_405b / "f16_decomposed" + output_mlir = self.llama405b_f16_decomposed_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama405b_f16_decomposed_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama405b_f16_decomposed_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = ( + self.artifacts_dir + / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama405b_f16_decomposed_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama405b_f16_decomposed_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail( + reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException + ) + def testBenchmark405B_f16_TP8_Non_Decomposed(self): + output_file_name = self.dir_path_405b / "f16_torch" + output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama405b_f16_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama405b_f16_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + 
self.llama405b_f16_torch_sdpa_artifacts.attention_kernel = "torch" + output_shard_file_name = ( + self.artifacts_dir + / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama405b_f16_torch_sdpa_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + skip_decode=True, + ) + self.llama405b_f16_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail( + reason="KeyError in theta.py", strict=True, raises=ExportMlirException + ) + def testBenchmark405B_fp8_TP8_Decomposed(self): + output_file_name = self.dir_path_405b / "fp8_decomposed" + output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama405b_fp8_decomposed_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama405b_fp8_decomposed_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = ( + self.artifacts_dir + / f"f8/tp8/llama3.1_405b_f8_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama405b_fp8_decomposed_artifacts.export_to_mlir( + mlir_path=output_mlir, + json_path=output_json, + ) + self.llama405b_fp8_decomposed_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama405b_fp8_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama405b_fp8_decomposed_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + @pytest.mark.xfail( + reason="KeyError in theta.py", strict=True, raises=ExportMlirException + ) + def testBenchmark405B_fp8_TP8_Non_Decomposed(self): + output_file_name = self.dir_path_405b / "fp8_torch" + output_mlir = self.llama405b_fp8_torch_sdpa_artifacts.create_file( + suffix=".mlir", prefix=output_file_name + ) + output_json = self.llama405b_fp8_torch_sdpa_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama405b_fp8_torch_sdpa_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = ( + self.artifacts_dir + / f"f8/tp8/llama3.1_405b_f8_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + if output_shard_file_name.exists(): + self.irpa_path = output_shard_file_name + export_return_code = self.llama405b_fp8_torch_sdpa_artifacts.export_to_mlir( + 
mlir_path=output_mlir, + json_path=output_json, + ) + self.llama405b_fp8_torch_sdpa_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, + cwd=self.repo_root, + args=self.compile_args, + ) + # benchmark prefill + self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_prefill_args, + cwd=self.repo_root, + ) + # benchmark decode + self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb( + hip_device_id=self.hip_device_id, + vmfb_name=output_vmfb, + irpa_path=self.irpa_path_fp8, + args=self.iree_run_decode_args, + cwd=self.repo_root, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py index 9d36db2e2..a80575951 100644 --- a/sharktank/tests/models/llama/kv_cache_test.py +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -28,6 +28,7 @@ def setUp(self): self.block_seq_stride = 16 self.rms_epsilon = 1e-5 self.rope_dimension_count = 128 + self.rope_freq_base = 10000.0 self.max_seq_len = 4096 self.start_positions = torch.tensor([8]) self.bs = 1 @@ -58,6 +59,7 @@ def setUp(self): ) self.attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=self.rope_dimension_count, + rope_freq_base=self.rope_freq_base, max_seqlen=self.max_seq_len, device=self.device, use_hf=False, @@ -72,7 +74,6 @@ def setUp(self): head_dim=self.head_dim, head_count_kv=self.head_count_kv, rms_epsilon=self.rms_epsilon, - use_hf=False, ) for n in range(self.block_count) ] @@ -87,7 +88,6 @@ def setUp(self): head_dim=self.head_dim, head_count_kv=self.head_count_kv, rms_epsilon=self.rms_epsilon, - use_hf=False, ) for n in range(self.block_count) ] diff --git a/sharktank/tests/models/llama/llama_cpp_instructions.md b/sharktank/tests/models/llama/llama_cpp_instructions.md new file mode 100644 index 000000000..1ca6fa3cb --- /dev/null +++ b/sharktank/tests/models/llama/llama_cpp_instructions.md @@ -0,0 +1,19 @@ +### How to build llama.cpp logit comparison branch +``` +git clone https://github.com/aviator19941/llama.cpp.git +cd llama.cpp/ +git checkout llama_comparison +cmake -B build +cmake --build build --config Release +``` + +### How to run llama.cpp +``` +huggingface-cli download meta-llama/Meta-Llama-3.1-70B --local-dir /home/avsharma/Meta-Llama-3.1-70B +python convert_hf_to_gguf.py --outtype f16 --outfile Llama-3.1-70B-f16.gguf ../Meta-Llama-3.1-70B/ +``` + +To predict the prefill token, use `--n-predict 1` and to predict the first decode token, use `--n-predict 2`: +``` +./build/bin/llama-cli -m Llama-3.1-70B-f16.gguf -p "I believe the meaning of life is" --threads 1 --temp 0 --n-predict 1 --no-warmup +``` diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index e04ca11fd..9b3daabdf 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -8,25 +8,24 @@ from typing import List import torch -from shark_turbine.aot import * +from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import SparseMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from sharktank import ops -class SparseMoeBlockTest(unittest.TestCase): - @unittest.skip("Skip test until grok implementation") +class 
MoeBlockTest(unittest.TestCase): def test(self): - model = SparseMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, rms_epsilon=1e-5, ) fxb = FxProgramsBuilder(model) - input = make_rand_torch((2, 16, 6144)) + input = make_rand_torch((2, 32, 6144)) - @fxb.export_program(name="moe_block", args=(input,)) + @fxb.export_program(name="moe_block", args=(input,), strict=False) def _(model, input: torch.Tensor) -> torch.Tensor: return model(input) diff --git a/sharktank/tests/models/llama/prefill_tests.py b/sharktank/tests/models/llama/prefill_tests.py new file mode 100644 index 000000000..093ecdfc9 --- /dev/null +++ b/sharktank/tests/models/llama/prefill_tests.py @@ -0,0 +1,175 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from sharktank.examples.paged_llm_v1 import * +from sharktank.utils import tokenizer +from sharktank.utils import hf_datasets +import unittest +from pathlib import Path + + +class BaseLlamaTest(unittest.TestCase): + def setUp(self): + raise NotImplementedError("Subclasses should implement this method.") + + def createConfigModel(self, kv_cache_type): + return LlamaModelConfig( + hp=configs.LlamaHParams.from_gguf_props(self.dataset.properties), + block_seq_stride=16, + kv_cache_type=kv_cache_type, + device=self.device, + activation_dtype=self.activation_dtype, + attention_dtype=self.activation_dtype, + ) + + def runPrefill(self, *, kv_cache_type): + config = self.createConfigModel(kv_cache_type) + model = PagedLlamaModelV1(self.dataset.root_theta, config) + generator = TorchGenerator(model, self.tokenizer_config) + batch = generator.begin_batch(self.prompts) + attention_mask = model.attention_mask( + model.input_mask(batch.seq_lens, batch.token_ids.shape[1]) + ) + seq_block_ids_tensor = batch.pad_block_ids() + logits = batch.compute_prefill_logits( + model, + batch.token_ids, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids_tensor, + cache_state=batch.cache_state, + ) + batch.prefill() + + bs, *_ = logits.shape + assert len(batch.seq_lens) == bs + greedy_token_logit = 0.0 + step_logits = logits[0, batch.seq_lens[0] - 1] + greedy_token_logit = step_logits[torch.argmax(step_logits)] + + return batch.results, greedy_token_logit + + def comparePrefillResults( + self, + batch_results, + greedy_token_logit, + golden_prefill_token, + golden_prefill_token_logit, + ): + rtol = 3e-4 + atol = 4e-3 + assert batch_results == golden_prefill_token + torch.testing.assert_close( + greedy_token_logit, golden_prefill_token_logit, rtol=rtol, atol=atol + ) + + +class Llama7BTest(BaseLlamaTest): + def setUp(self): + default_arguments = { + "hf_dataset": "llama2_7B_f16", + "tokenizer-config-json": Path("./llama2-7b/tokenizer_config.json"), + "prompt": ["I believe the meaning of life is"], + "device": None, + "activation-dtype": "float32", + } + self.device = ( + torch.device(default_arguments["device"]) + if default_arguments["device"] + else None + ) + self.activation_dtype = getattr(torch, default_arguments["activation-dtype"]) + assert isinstance(self.activation_dtype, torch.dtype) + self.data_files = hf_datasets.get_dataset( + default_arguments["hf_dataset"] + ).download(local_dir=Path(".")) + self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf") + self.tokenizer_config = tokenizer.load_tokenizer( + 
default_arguments["tokenizer-config-json"].parent, + tokenizer_type="transformers", + ) + self.prompts = default_arguments["prompt"] + # token and logit determined by running llama.cpp (llama_cpp_instructions.md). + self.llama_cpp_7b_prefill_token = [[304]] + self.llama_cpp_7b_prefill_token_logit = torch.tensor(19.356606) + + def testPrefillPaged7B(self): + batch_results_paged, greedy_token_logit_paged = self.runPrefill( + kv_cache_type="paged" + ) + self.comparePrefillResults( + batch_results_paged, + greedy_token_logit_paged, + self.llama_cpp_7b_prefill_token, + self.llama_cpp_7b_prefill_token_logit, + ) + + def testPrefillDirect7B(self): + batch_results_direct, greedy_token_logit_direct = self.runPrefill( + kv_cache_type="direct" + ) + self.comparePrefillResults( + batch_results_direct, + greedy_token_logit_direct, + self.llama_cpp_7b_prefill_token, + self.llama_cpp_7b_prefill_token_logit, + ) + + +class Llama8BTest(BaseLlamaTest): + def setUp(self): + default_arguments = { + "hf_dataset": "llama3_8B_f16", + "tokenizer-config-json": Path("./llama3.1-8b/tokenizer_config.json"), + "prompt": ["I believe the meaning of life is"], + "device": None, + "activation-dtype": "float32", + } + self.device = ( + torch.device(default_arguments["device"]) + if default_arguments["device"] + else None + ) + self.activation_dtype = getattr(torch, default_arguments["activation-dtype"]) + assert isinstance(self.activation_dtype, torch.dtype) + self.data_files = hf_datasets.get_dataset( + default_arguments["hf_dataset"] + ).download(local_dir=Path(".")) + self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf") + self.tokenizer_config = tokenizer.load_tokenizer( + default_arguments["tokenizer-config-json"].parent, + tokenizer_type="transformers", + ) + self.prompts = default_arguments["prompt"] + # token and logit determined by running llama.cpp (llama_cpp_instructions.md). + self.llama_cpp_8b_prefill_token = [[311]] + self.llama_cpp_8b_prefill_token_logit = torch.tensor(15.612568) + + def testPrefillPaged8B(self): + batch_results_paged, greedy_token_logit_paged = self.runPrefill( + kv_cache_type="paged" + ) + self.comparePrefillResults( + batch_results_paged, + greedy_token_logit_paged, + self.llama_cpp_8b_prefill_token, + self.llama_cpp_8b_prefill_token_logit, + ) + + def testPrefillDirect8B(self): + batch_results_direct, greedy_token_logit_direct = self.runPrefill( + kv_cache_type="direct" + ) + self.comparePrefillResults( + batch_results_direct, + greedy_token_logit_direct, + self.llama_cpp_8b_prefill_token, + self.llama_cpp_8b_prefill_token_logit, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py new file mode 100644 index 000000000..386061731 --- /dev/null +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -0,0 +1,406 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import pytest +from typing import Any, List, Tuple, OrderedDict +from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 +import sharktank.ops as ops +from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor +from sharktank.models.llama.testing import make_random_llama_theta +from sharktank.utils.testing import skip +from sharktank.models.llama.sharding import shard_theta +from sharktank.layers.configs import LlamaHParams +from sharktank.utils.math import round_up_to_multiple_of +from sharktank.utils import iterables_equal +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + iree_to_torch, +) +from sharktank.export import export as sharktank_export +import tempfile +import torch +from copy import deepcopy +from iree.turbine.aot import FxProgramsBuilder, export +import iree.runtime +import numpy as np +import os + + +@pytest.mark.usefixtures("caching", "path_prefix") +class ShardedLlamaTest(unittest.TestCase): + def setUp(self): + torch.random.manual_seed(123456) + self.dtype = torch.float32 + torch.set_default_dtype(self.dtype) + self.batch_size = 3 + self.attention_head_count_kv = 4 + self.attention_head_count = self.attention_head_count_kv * 5 + self.vocabulary_size = 19 + self.rope_dimension_count = 7 * 2 + self.attn_head_dim = self.rope_dimension_count + self.block_seq_stride = 13 + self.cache_page_count = 11 + self.config = LlamaModelConfig( + hp=LlamaHParams( + context_length=self.block_seq_stride * 2, + embedding_length=self.attention_head_count * self.attn_head_dim, + block_count=3, + feed_forward_length=23, + rope_dimension_count=self.rope_dimension_count, + rope_freq_base=500000.0, + attention_head_count=self.attention_head_count, + attn_head_dim=self.attn_head_dim, + attention_layer_norm_rms_epsilon=0.01, + attention_head_count_kv=self.attention_head_count_kv, + expert_count=0, + expert_used_count=0, + model_arch="llama", + ), + block_seq_stride=self.block_seq_stride, + activation_dtype=self.dtype, + attention_dtype=self.dtype, + ) + self.sharded_config = deepcopy(self.config) + self.sharded_config.tensor_parallelism_size = 2 + self.theta = make_random_llama_theta( + config=self.config, + vocab_size=self.vocabulary_size, + ) + self.prefill_seq_lens = torch.tensor( + [14, 9, self.block_seq_stride - 1], dtype=torch.int64 + ) + + def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: + batch_seq_len = round_up_to_multiple_of( + int(torch.max(self.prefill_seq_lens)), model.cache.pad_sequence_stride + ) + token_ids = torch.randint( + low=0, + high=self.vocabulary_size, + size=[self.batch_size, batch_seq_len], + dtype=torch.int32, + ) + attention_mask = model.attention_mask( + model.input_mask(self.prefill_seq_lens, batch_seq_len) + ) + seq_block_ids = torch.arange( + self.batch_size * batch_seq_len // self.config.block_seq_stride + ).view(self.batch_size, -1) + cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) + cache_state = [torch.rand_like(cache_state[0])] + return OrderedDict( + [ + ("tokens", token_ids), + ("attention_mask", attention_mask), + ("seq_block_ids", seq_block_ids), + ("cache_state", cache_state), + ] + ) + + def make_equal_unsharded_and_sharded_prefill_args( + self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 + ) -> Tuple[OrderedDict[str, Any], 
OrderedDict[str, Any]]: + prefill_kwargs = self.make_prefill_args(model) + sharded_cache_state = sharded_model.cache.paged.allocate( + page_count=self.cache_page_count + ) + assert iterables_equal( + prefill_kwargs["cache_state"][0].shape, sharded_cache_state[0].shape + ) + sharded_prefill_kwargs = deepcopy(prefill_kwargs) + sharded_cache_state = sharded_model.cache.paged.shard_state( + sharded_prefill_kwargs["cache_state"] + ) + sharded_prefill_kwargs["cache_state"] = sharded_cache_state + + sharding = sharded_model.config.tensor_parallelism_size + for k in sharded_prefill_kwargs: + if k == "cache_state": + continue + sharded_prefill_kwargs[k] = ops.replicate( + sharded_prefill_kwargs[k], count=sharding + ) + + return prefill_kwargs, sharded_prefill_kwargs + + def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]: + start_positions = self.prefill_seq_lens.clone() + seq_lens = self.prefill_seq_lens + 1 + batch_seq_len = round_up_to_multiple_of( + int(torch.max(seq_lens)), model.cache.pad_sequence_stride + ) + decode_token_ids = torch.randint( + low=0, + high=self.vocabulary_size, + size=[self.batch_size, 1], + dtype=torch.int32, + ) + attention_mask = model.decode_attention_mask( + model.input_mask(seq_lens, batch_seq_len) + ) + seq_block_ids = torch.arange( + self.batch_size * batch_seq_len // self.config.block_seq_stride + ).view(self.batch_size, -1) + cache_state = model.cache.paged.allocate(page_count=self.cache_page_count) + cache_state = [torch.rand_like(cache_state[0])] + return OrderedDict( + [ + ("tokens", decode_token_ids), + ("attention_mask", attention_mask), + ("start_positions", start_positions), + ("seq_block_ids", seq_block_ids), + ("cache_state", cache_state), + ] + ) + + def make_equal_unsharded_and_sharded_decode_args( + self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1 + ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]: + decode_kwargs = self.make_decode_args(model) + sharded_decode_kwargs = deepcopy(decode_kwargs) + sharded_decode_kwargs["cache_state"] = sharded_model.cache.paged.shard_state( + sharded_decode_kwargs["cache_state"] + ) + + sharding = sharded_model.config.tensor_parallelism_size + for k in sharded_decode_kwargs: + if k == "cache_state": + continue + sharded_decode_kwargs[k] = ops.replicate( + sharded_decode_kwargs[k], count=sharding + ) + + return decode_kwargs, sharded_decode_kwargs + + def testCompareToySizedModelToUnsharded(self): + """Run a sharded variant of a toy model size and compare it against the + unsharded variant.""" + model = PagedLlamaModelV1(self.theta, self.config) + sharded_theta = shard_theta(self.theta, self.sharded_config) + sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config) + + # Verify prefill step. + ( + prefill_kwargs, + sharded_prefill_kwargs, + ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) + + expected_prefill_result = model.prefill(**prefill_kwargs) + sharded_prefill_result = sharded_model.prefill(**sharded_prefill_kwargs) + sharded_prefill_result = ops.unshard(sharded_prefill_result) + # The errors are quite high, but for float64 both errors drop to < 1e-12. + # The numerics are probably correct. 
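+ # Hence the relaxed absolute/relative tolerances on the fp32 comparisons below.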
+ torch.testing.assert_close( + sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2 + ) + expected_cache_state = prefill_kwargs["cache_state"][0] + actual_cache_state = ops.unshard( + sharded_model.cache.paged.unflatten_page_table( + sharded_prefill_kwargs["cache_state"] + ) + ).flatten(start_dim=1) + torch.testing.assert_close( + actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1 + ) + + # Verify decode step. + ( + decode_kwargs, + sharded_decode_kwargs, + ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) + expected_decode_result = model.decode(**decode_kwargs) + sharded_decode_result = sharded_model.decode(**sharded_decode_kwargs) + sharded_decode_result = ops.unshard(sharded_decode_result) + torch.testing.assert_close( + sharded_decode_result, expected_decode_result, atol=1e-4, rtol=1e-5 + ) + expected_decode_cache_state = decode_kwargs["cache_state"][0] + actual_decode_cache_state = ops.unshard( + sharded_model.cache.paged.unflatten_page_table( + sharded_decode_kwargs["cache_state"] + ) + ).flatten(start_dim=1) + # TODO: investigate why the Windows machine CI is producing a larger numerical + # error. + # The Ubuntu CI runs fine with default tolerances. + torch.testing.assert_close( + actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4 + ) + + @skip( + ( + "Before this does not crash at all we need " + "https://github.com/iree-org/iree/pull/18663 merged." + ) + ) + def testExportAndRunToySizedModelWithIree(self): + """Test exporting to MLIR and compiling with IREE the sharded Llama model. + Test numerical accuracy of the IREE module against PyTorch.""" + + if self.path_prefix is not None: + self.runTestExportAndRunToySizedModelWithIree( + path_prefix=self.path_prefix, dump_enabled=True + ) + else: + with tempfile.TemporaryDirectory() as temp_dir: + self.runTestExportAndRunToySizedModelWithIree( + path_prefix=f"{temp_dir}/", dump_enabled=False + ) + + def runTestExportAndRunToySizedModelWithIree( + self, path_prefix: str, dump_enabled: bool + ): + sharded_theta = shard_theta(self.theta, self.sharded_config) + sharded_theta.rename_tensors_to_paths() + sharded_dataset = Dataset({}, sharded_theta) + sharded_parameters_path = f"{path_prefix}parameters.irpa" + sharded_dataset.save(sharded_parameters_path) + sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) + iree_driver = "local-task" + + model = PagedLlamaModelV1(self.theta, self.config) + sharded_model = PagedLlamaModelV1( + sharded_dataset.root_theta, self.sharded_config + ) + ( + _, + sharded_prefill_kwargs, + ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) + ( + _, + sharded_decode_kwargs, + ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) + + iree_module_path = f"{path_prefix}program.vmfb" + if not self.caching or not os.path.exists(iree_module_path): + # Export and compile the IREE module. + sharded_fxb = FxProgramsBuilder(sharded_model) + + @sharktank_export( + fx_builder=sharded_fxb, + name="prefill", + kwargs=sharded_prefill_kwargs, + strict=False, + ) + def _(model, *args, **kwargs) -> torch.Tensor: + return model.prefill(*args, **kwargs) + + # TODO: remove strict=False when + # https://github.com/pytorch/pytorch/issues/136757 + # is resolved. 
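+ # Export the decode entry point into the same FX builder as prefill.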
+ @sharktank_export( + fx_builder=sharded_fxb, + name="decode", + kwargs=sharded_decode_kwargs, + strict=False, + ) + def _(model, *args, **kwargs) -> torch.Tensor: + return model.decode(*args, **kwargs) + + output = export(sharded_fxb) + if dump_enabled: + output.save_mlir(f"{path_prefix}program.mlir") + output.session.set_flags( + *[ + f"--iree-hal-target-device=llvm-cpu[{i}]" + for i in range(self.sharded_config.tensor_parallelism_size) + ] + ) + output.compile( + save_to=iree_module_path, + target_backends=None, + ) + + iree_devices = get_iree_devices( + driver=iree_driver, + device_count=self.sharded_config.tensor_parallelism_size, + ) + iree_module, vm_context, vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=sharded_parameters_path, + ) + + # Run prefill step. + prefill_iree_args = prepare_iree_module_function_args( + args=deepcopy(sharded_prefill_kwargs).values(), devices=iree_devices + ) + for i, arg in enumerate(prefill_iree_args): + np.save(f"{path_prefix}prefill_arg{i}.npy", arg.to_host()) + prefill_iree_result = run_iree_module_function( + args=prefill_iree_args, + function_name="prefill", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + prefill_iree_result = UnreducedTensor(ts=iree_to_torch(*prefill_iree_result)) + expected_prefill_result = call_torch_module_function( + module=sharded_model, + function_name="prefill", + kwargs=sharded_prefill_kwargs, + trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, + ) + prefill_iree_cache_state_shards = prefill_iree_args[ + -self.config.tensor_parallelism_size - 1 : + ] + prefill_iree_cache_state = SplitPrimitiveTensor( + ts=iree_to_torch(*prefill_iree_cache_state_shards), + shard_dim=sharded_prefill_kwargs["cache_state"][0].shard_dim, + ) + + # Run decode step. + decode_iree_args = prepare_iree_module_function_args( + args=deepcopy(sharded_decode_kwargs).values(), devices=iree_devices + ) + decode_iree_result = run_iree_module_function( + args=decode_iree_args, + function_name="decode", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + decode_iree_result = UnreducedTensor(ts=iree_to_torch(*decode_iree_result)) + expected_decode_result = call_torch_module_function( + module=sharded_model, + function_name="decode", + kwargs=sharded_decode_kwargs, + trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, + ) + decode_iree_cache_state_shards = decode_iree_args[ + -self.config.tensor_parallelism_size - 1 : + ] + decode_iree_cache_state = SplitPrimitiveTensor( + ts=iree_to_torch(*decode_iree_cache_state_shards), + shard_dim=sharded_decode_kwargs["cache_state"][0].shard_dim, + ) + + # Check IREE's numerical correctness against PyTorch. + # TODO: Although, not entirely wrong, investigate why this accuracy is that + # low for fp32 (atol=0.0011, rtol=0.013). 
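+ # Prefill/decode results and cache state are compared at default tolerances.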
+ torch.testing.assert_close( + ops.unshard(prefill_iree_result), + ops.unshard(expected_prefill_result), + ) + torch.testing.assert_close( + ops.unshard(prefill_iree_cache_state), + ops.unshard(sharded_prefill_kwargs["cache_state"][0]), + ) + torch.testing.assert_close( + ops.unshard(decode_iree_result), + ops.unshard(expected_decode_result), + ) + torch.testing.assert_close( + ops.unshard(decode_iree_cache_state), + ops.unshard(sharded_decode_kwargs["cache_state"][0]), + ) diff --git a/sharktank/tests/models/punet/conftest.py b/sharktank/tests/models/punet/conftest.py deleted file mode 100644 index 55364c4ef..000000000 --- a/sharktank/tests/models/punet/conftest.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import pytest -from pathlib import Path -from typing import Optional - - -def pytest_addoption(parser): - parser.addoption( - "--mlir", - type=Path, - default=None, - help="Path to exported MLIR program. If not specified a temporary file will be used.", - ) - parser.addoption( - "--module", - type=Path, - default=None, - help="Path to exported IREE module. If not specified a temporary file will be used.", - ) - parser.addoption( - "--parameters", - type=Path, - default=None, - help="Exported model parameters. If not specified a temporary file will be used.", - ) - parser.addoption( - "--caching", - action="store_true", - default=False, - help="Load cached results if present instead of recomputing.", - ) - - -@pytest.fixture(scope="session") -def mlir_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("mlir") - - -@pytest.fixture(scope="session") -def module_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("module") - - -@pytest.fixture(scope="session") -def parameters_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("parameters") - - -@pytest.fixture(scope="session") -def caching(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("caching") diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py index e2b602a7c..581584369 100644 --- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py +++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py @@ -10,7 +10,7 @@ import torch -from shark_turbine import aot +from iree.turbine import aot from sharktank.models.punet.testing import make_resnet_block_2d_theta from sharktank.models.punet.layers import ResnetBlock2D from sharktank.models.punet.sharding import ResnetBlock2DSplitOutputChannelsSharding @@ -19,6 +19,7 @@ import iree.runtime from typing import List, Optional import os +import pytest vm_context: iree.runtime.VmContext = None @@ -40,6 +41,8 @@ def compile_iree_module( export_output.compile(save_to=module_path, target_backends=None) +# TODO: improve IREE's Python API to be more concise in a multi-device context. +# This run function should be way shorter. 
def run_iree_module( sharded_input_image: ShardedTensor, sharded_input_time_emb: ShardedTensor, @@ -205,18 +208,26 @@ def run_test_sharded_resnet_block_with_iree( parameters_path=parameters_path, ) assert len(actual_result.shards) == len(expected_result.shards) - # TODO: reenable this check once numerical issues are resolved. - # for actual_shard, expected_shard in zip( - # actual_result.shards, expected_result.shards - # ): - # torch.testing.assert_close( - # unbox_tensor(actual_shard), unbox_tensor(expected_shard) - # ) + # TODO: reenable this test once numerical issues are resolved. + # The absolute accuracy is > 0.00042. Is this good enough? + # Maybe add a test with fp64, where if the accuracy is high would give us more + # confidence that fp32 is also OK. + for actual_shard, expected_shard in zip( + actual_result.shards, expected_result.shards + ): + torch.testing.assert_close( + unbox_tensor(actual_shard), unbox_tensor(expected_shard) + ) global vm_context del vm_context +@pytest.mark.xfail( + reason="Maybe numerical issues with low accuracy.", + strict=True, + raises=AssertionError, +) def test_sharded_resnet_block_with_iree( mlir_path: Optional[Path], module_path: Optional[Path], diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py new file mode 100644 index 000000000..076404e5d --- /dev/null +++ b/sharktank/tests/models/t5/t5_test.py @@ -0,0 +1,410 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from transformers.models.t5.modeling_t5 import ( + T5Attention as ReferenceT5Attention, + T5LayerSelfAttention as ReferenceT5LayerSelfAttention, + T5LayerFF as ReferenceT5LayerFF, +) +from transformers import ( + AutoTokenizer, + T5EncoderModel as ReferenceT5EncoderModel, + T5Config as ReferenceT5Config, +) +import os +from collections import OrderedDict +import pytest +import torch +from unittest import TestCase +from parameterized import parameterized +from sharktank.types import Theta, DefaultPrimitiveTensor, unbox_tensor, Dataset +from sharktank.models.t5 import ( + T5Attention, + T5SelfAttention, + T5Config, + T5Encoder, + T5LayerFF, + export_encoder_mlir, + export_encoder_iree_parameters, +) +from sharktank.utils.testing import make_rand_torch, TempDirTestBase +from sharktank.utils.hf_datasets import get_dataset +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +import iree.compiler + +with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')") + + +def make_random_mask(shape: tuple[int], dtype: torch.dtype): + mask = make_rand_torch(shape=shape, dtype=dtype) + mask = (mask >= 0).to(dtype=dtype) + return mask + + +test_prompts = [ + "Studies have been shown that owning a dog is good for you", + "The horse went into the river", + "We need at least one sentence long enough so that it spans more than one padding block which by default is of size 16.", + "Make the batch size 4", +] + + +@pytest.mark.usefixtures("get_model_artifacts") +class T5EncoderEagerTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace( + self, huggingface_repo_id: str + ): + get_dataset( + 
huggingface_repo_id, + ).download() + tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) + reference_model = ReferenceT5EncoderModel.from_pretrained(huggingface_repo_id) + reference_model.eval() + + input_ids = tokenizer( + test_prompts, + return_tensors="pt", + padding=True, + ).input_ids + + target_model_name = ( + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_fp32_model" + ) + target_model_path = getattr(self, target_model_name) + dataset = Dataset.load(target_model_path) + config = T5Config.from_gguf_properties( + dataset.properties, + feed_forward_proj="gated-gelu", + ) + model = T5Encoder(theta=dataset.root_theta, config=config) + model.eval() + + expected_outputs = reference_model(input_ids=input_ids) + actual_outputs = model(input_ids=input_ids) + torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + + @with_t5_data + def testV1_1SmallFp32CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-small") + + @with_t5_data + def testV1_1XxlFp32CompareTorchEagerAgainstHuggingFace(self): + self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-xxl") + + +@pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix") +class T5EncoderIreeTest(TempDirTestBase): + def setUp(self): + super().setUp() + if self.path_prefix is None: + self.path_prefix = f"{self._temp_dir}/" + + @parameterized.expand( + [ + "google/t5-v1_1-small", + "google/t5-v1_1-xxl", + ] + ) + @with_t5_data + def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str): + get_dataset( + huggingface_repo_id, + ).download() + tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id) + + huggingface_repo_id_as_path = ( + f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}" + ) + source_model_name = f"{huggingface_repo_id_as_path}_fp32_model" + source_model_path = getattr(self, source_model_name) + + dataset = Dataset.load(source_model_path) + config = T5Config.from_gguf_properties( + dataset.properties, + feed_forward_proj="gated-gelu", + ) + + input_ids = tokenizer( + test_prompts, + return_tensors="pt", + padding=True, + pad_to_multiple_of=config.context_length_padding_block_size, + ).input_ids + input_args = OrderedDict([("input_ids", input_ids)]) + batch_size = input_ids.shape[0] + + reference_model = T5Encoder(theta=dataset.root_theta, config=config) + reference_result = flatten_for_iree_signature( + call_torch_module_function( + module=reference_model, + function_name="forward", + kwargs=input_args, + trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_torch_", + ) + ) + + mlir_path = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.mlir" + if not self.caching or not os.path.exists(mlir_path): + export_encoder_mlir( + source_model_path, batch_sizes=[batch_size], mlir_output_path=mlir_path + ) + iree_module_path = ( + f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.vmfb" + ) + if not self.caching or not os.path.exists(iree_module_path): + iree.compiler.compile_file( + mlir_path, + output_file=iree_module_path, + extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"], + ) + + parameters_path = ( + f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.irpa" + ) + if not self.caching or not os.path.exists(parameters_path): + export_encoder_iree_parameters(source_model_path, parameters_path) + + iree_devices = get_iree_devices(driver="hip", device_count=1) + iree_module, iree_vm_context, 
iree_vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=parameters_path, + ) + iree_args = prepare_iree_module_function_args( + args=flatten_for_iree_signature(input_args), devices=iree_devices + ) + iree_result = iree_to_torch( + *run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name=f"forward_bs{batch_size}", + trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_iree_", + ) + ) + + torch.testing.assert_close( + reference_result, iree_result, atol=1e-4, rtol=2.0e-3 + ) + + +class T5AttentionTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + def testCompareAgainstTransformersFp32(self): + dtype = torch.float32 + batch_size = 19 + batch_seq_len = 23 + reference_config = ReferenceT5Config( + vocab_size=11, + d_model=13, + d_kv=7, + d_ff=3, + num_heads=2, + relative_attention_num_buckets=5, + relative_attention_max_distance=17, + dropout_rate=0.0, + ) + reference_model = ReferenceT5Attention( + reference_config, has_relative_attention_bias=True + ) + reference_model.eval() + + theta = Theta( + { + "attn_q.weight": DefaultPrimitiveTensor( + data=reference_model.q.weight.data + ), + "attn_k.weight": DefaultPrimitiveTensor( + data=reference_model.k.weight.data + ), + "attn_v.weight": DefaultPrimitiveTensor( + data=reference_model.v.weight.data + ), + "attn_o.weight": DefaultPrimitiveTensor( + data=reference_model.o.weight.data + ), + "attn_rel_b.weight": DefaultPrimitiveTensor( + data=reference_model.relative_attention_bias.weight.data + ), + } + ) + model = T5Attention( + theta=theta, + is_decoder=reference_config.is_decoder, + relative_attention_num_buckets=reference_config.relative_attention_num_buckets, + relative_attention_max_distance=reference_config.relative_attention_max_distance, + d_model=reference_config.d_model, + d_kv=reference_config.d_kv, + num_heads=reference_config.num_heads, + activation_dtype=dtype, + has_relative_attention_bias=True, + ) + model.eval() + + hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + ) + mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype) + expected_outputs = reference_model(hidden_states=hidden_states, mask=mask) + actual_outputs = model( + hidden_states=DefaultPrimitiveTensor(data=hidden_states), + mask=DefaultPrimitiveTensor(data=mask), + ) + torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + + def testCompareSelfAttentionAgainstTransformersFp32(self): + dtype = torch.float32 + batch_size = 19 + batch_seq_len = 23 + reference_config = ReferenceT5Config( + vocab_size=11, + d_model=13, + d_kv=7, + d_ff=3, + num_heads=2, + relative_attention_num_buckets=5, + relative_attention_max_distance=17, + dropout_rate=0.0, + layer_norm_epsilon=1e-6, + ) + reference_model = ReferenceT5LayerSelfAttention( + reference_config, has_relative_attention_bias=True + ) + reference_model.eval() + + theta = Theta( + { + "attn_q.weight": DefaultPrimitiveTensor( + data=reference_model.SelfAttention.q.weight.data + ), + "attn_k.weight": DefaultPrimitiveTensor( + data=reference_model.SelfAttention.k.weight.data + ), + "attn_v.weight": DefaultPrimitiveTensor( + data=reference_model.SelfAttention.v.weight.data + ), + "attn_o.weight": DefaultPrimitiveTensor( + data=reference_model.SelfAttention.o.weight.data + ), + "attn_rel_b.weight": DefaultPrimitiveTensor( + 
data=reference_model.SelfAttention.relative_attention_bias.weight.data + ), + "attn_norm.weight": DefaultPrimitiveTensor( + data=reference_model.layer_norm.weight.data + ), + } + ) + model = T5SelfAttention( + theta=theta, + is_decoder=reference_config.is_decoder, + relative_attention_num_buckets=reference_config.relative_attention_num_buckets, + relative_attention_max_distance=reference_config.relative_attention_max_distance, + d_model=reference_config.d_model, + d_kv=reference_config.d_kv, + num_heads=reference_config.num_heads, + activation_dtype=dtype, + layer_norm_epsilon=reference_config.layer_norm_epsilon, + has_relative_attention_bias=True, + ) + model.eval() + + hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + ) + mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype) + position_bias = make_rand_torch( + shape=[batch_size, reference_config.num_heads, batch_seq_len, batch_seq_len] + ) + expected_outputs = reference_model( + hidden_states=hidden_states, + attention_mask=mask, + position_bias=position_bias, + ) + actual_outputs = model( + hidden_states=DefaultPrimitiveTensor(data=hidden_states), + attention_mask=DefaultPrimitiveTensor(data=mask), + position_bias=DefaultPrimitiveTensor(data=position_bias), + ) + actual_outputs = [ + unbox_tensor(t) if t is not None else t for t in actual_outputs + ] + torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0) + + +class T5LayerFFTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + def testCompareAgainstTransformersFp32(self): + dtype = torch.float32 + batch_size = 19 + batch_seq_len = 23 + reference_config = ReferenceT5Config( + d_model=13, + d_ff=3, + dropout_rate=0.0, + layer_norm_epsilon=1e-6, + feed_forward_proj="gated-gelu", + ) + + reference_model = ReferenceT5LayerFF(reference_config) + reference_model.eval() + + theta = Theta( + { + "ffn_gate.weight": DefaultPrimitiveTensor( + data=reference_model.DenseReluDense.wi_0.weight + ), + "ffn_up.weight": DefaultPrimitiveTensor( + data=reference_model.DenseReluDense.wi_1.weight + ), + "ffn_down.weight": DefaultPrimitiveTensor( + data=reference_model.DenseReluDense.wo.weight + ), + "ffn_norm.weight": DefaultPrimitiveTensor( + data=reference_model.layer_norm.weight + ), + } + ) + model = T5LayerFF( + theta=theta, + is_gated_act=reference_config.is_gated_act, + dense_act_fn=reference_config.dense_act_fn, + layer_norm_epsilon=reference_config.layer_norm_epsilon, + activation_dtype=dtype, + ) + + hidden_states = make_rand_torch( + shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype + ) + + expected_output = reference_model( + hidden_states=hidden_states, + ) + actual_output = model( + hidden_states=DefaultPrimitiveTensor(data=hidden_states), + ) + torch.testing.assert_close(actual_output, expected_output, atol=1e-5, rtol=0) diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py index b282d03fe..d303f9f43 100644 --- a/sharktank/tests/ops/ops_test.py +++ b/sharktank/tests/ops/ops_test.py @@ -115,7 +115,7 @@ def testMatchFail(self): ): ops.matmul(1, 2) - @unittest.skip("https://github.com/nod-ai/sharktank/issues/44") + @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44") def testTorchImplTransposedRHS(self): ops._registry._test_enable_last_op_dispatch(True) t1 = torch.rand(32, 16, dtype=torch.float32) @@ -128,7 +128,7 @@ def testTorchImplTransposedRHS(self): 
ops.custom_impls.matmul_mmtfp_tensor_tensor, ) - @unittest.skip("https://github.com/nod-ai/sharktank/issues/44") + @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44") def testTorchImplNonTransposedRHS(self): ops._registry._test_enable_last_op_dispatch(True) t1 = torch.rand(32, 16, dtype=torch.float32) @@ -141,7 +141,7 @@ def testTorchImplNonTransposedRHS(self): ops.custom_impls.matmul_mmtfp_tensor_tensor, ) - @unittest.skip("https://github.com/nod-ai/sharktank/issues/44") + @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44") def testTorchImplTransposedPrimitiveRHS(self): ops._registry._test_enable_last_op_dispatch(True) t1 = torch.rand(32, 16, dtype=torch.float32) @@ -155,6 +155,15 @@ def testTorchImplTransposedPrimitiveRHS(self): ops.custom_impls.matmul_mmtfp_tensor_tensor, ) + def testTorchImplImplicitBatch(self): + ops._registry._test_enable_last_op_dispatch(True) + t1 = torch.rand(4, 32, 16, dtype=torch.float32) + t2 = torch.rand(48, 16, dtype=torch.float16) + t2_pt = DefaultPrimitiveTensor(data=t2) + result = ops.matmul(t1, t2_pt.T) + expected = torch.matmul(t1, t2.T.to(torch.float32)) + torch.testing.assert_close(result, expected) + def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self): ops._registry._test_enable_last_op_dispatch(True) a_dtype = torch.float32 diff --git a/sharktank/tests/ops/qconv_test.py b/sharktank/tests/ops/qconv_test.py index 4440202eb..97b0efd66 100644 --- a/sharktank/tests/ops/qconv_test.py +++ b/sharktank/tests/ops/qconv_test.py @@ -71,7 +71,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_NoBias(self): ) self.assertIs( ops._registry._test_get_last_op_dispatch(), - ops.qconv_impls.qconv2d_tensor_scaled_integer, + ops.qconv_impls.qconv2d_tensor_scaled, ) y_ref = torch.nn.functional.conv2d( input_q.unpack().dequant(), @@ -105,7 +105,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_FloatBias(self): y_actual = ops.conv2d(input_q, weight_q, bias, stride=1, padding=(1, 1)) self.assertIs( ops._registry._test_get_last_op_dispatch(), - ops.qconv_impls.qconv2d_tensor_scaled_integer, + ops.qconv_impls.qconv2d_tensor_scaled, ) y_ref = torch.nn.functional.conv2d( input_q.unpack().dequant(), @@ -147,7 +147,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_QuantizedBias(self): ) self.assertIs( ops._registry._test_get_last_op_dispatch(), - ops.qconv_impls.qconv2d_tensor_scaled_integer, + ops.qconv_impls.qconv2d_tensor_scaled, ) y_ref = torch.nn.functional.conv2d( input_q.unpack().dequant(), @@ -184,7 +184,7 @@ def testInputSymPerTensor_WeightSymPerTensor_NoBias(self): ) self.assertIs( ops._registry._test_get_last_op_dispatch(), - ops.qconv_impls.qconv2d_tensor_scaled_integer, + ops.qconv_impls.qconv2d_tensor_scaled, ) y_ref = torch.nn.functional.conv2d( input_q.unpack().dequant(), @@ -224,7 +224,7 @@ def testInputAsymPerChannel_WeightAsymPerChannel_NoBias(self): ) self.assertIs( ops._registry._test_get_last_op_dispatch(), - ops.qconv_impls.qconv2d_tensor_scaled_integer, + ops.qconv_impls.qconv2d_tensor_scaled, ) y_ref = torch.nn.functional.conv2d( input_q.unpack().dequant(), diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index 39078c5b9..e5efaa948 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -15,6 +15,40 @@ from sharktank.layers import Conv2DLayer +class AllGatherTest(unittest.TestCase): + def testAllGather(self): + shard_count = 3 + shard_shape = [3, 4] + shard_dim = 1 + shards = [ + torch.rand(shard_shape, dtype=torch.float32) for i in 
range(shard_count) + ] + expected_result = torch.cat(shards, dim=shard_dim) + + sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) + actual_result = ops.all_gather(sharded) + + for shard in actual_result.shards: + torch.testing.assert_close(shard.as_torch(), expected_result) + + +class AllReduceTest(unittest.TestCase): + def testAllReduce(self): + shard_count = 3 + shard_shape = [3, 4] + shard_dim = 1 + shards = [ + torch.rand(shard_shape, dtype=torch.float32) for i in range(shard_count) + ] + expected_result = torch.add(torch.add(shards[0], shards[1]), shards[2]) + + sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) + actual_result = ops.all_reduce(sharded) + + for shard in actual_result.shards: + torch.testing.assert_close(shard.as_torch(), expected_result) + + class CatTest(unittest.TestCase): def testCatSplitDim(self): """Concatenation along the sharded split dimension.""" @@ -50,21 +84,6 @@ def testCatNonSplitDim(self): class ConvTest(unittest.TestCase): - def testAllGather(self): - shard_count = 3 - shard_shape = [3, 4] - shard_dim = 1 - shards = [ - torch.rand(shard_shape, dtype=torch.float32) for i in range(shard_count) - ] - expected_result = torch.cat(shards, dim=shard_dim) - - sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) - actual_result = ops.all_gather(sharded) - - for shard in actual_result.shards: - torch.testing.assert_close(shard.as_torch(), expected_result) - def testConv2dShardedInputAndOutputChannelsOneGroup(self): batches = 2 in_channels = 6 @@ -337,6 +356,32 @@ def testNotEqualSharded(self): assert not ops.equal(b_sharded, a_sharded) +class FlattenTest(unittest.TestCase): + def testReplicated(self): + tensor = torch.rand(2, 3, 4, 5) + unsharded_expected_result = torch.flatten(tensor, start_dim=1, end_dim=2) + expected_result = ops.replicate(unsharded_expected_result, count=2) + sharded_tensor = ops.replicate(tensor, count=2) + actual_result = ops.flatten(sharded_tensor, start_dim=1, end_dim=2) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorFlattenNonSplitDim(self): + tensor = torch.rand(2, 3, 4, 5) + unsharded_expected_result = torch.flatten(tensor, start_dim=1, end_dim=2) + expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=3, count=2) + actual_result = ops.flatten(sharded_tensor, start_dim=1, end_dim=2) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorSplitDimIsLeadingFlattenDim(self): + tensor = torch.rand(3, 4, 5, 6) + unsharded_expected_result = torch.flatten(tensor, start_dim=1, end_dim=2) + expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=1, count=2) + actual_result = ops.flatten(sharded_tensor, start_dim=1, end_dim=2) + assert expected_result.is_deep_equal(actual_result) + + class GemmTest(unittest.TestCase): def testShardedParallelDim(self): a = torch.rand(4, 3) @@ -356,6 +401,47 @@ def testShardedParallelDim(self): torch.testing.assert_close(actual, expected) +class IndexCopyTest(unittest.TestCase): + def testSplitInPlace(self): + torch.set_default_dtype(torch.float32) + tensor = torch.rand(3, 4, 5, 6) + dim = 2 + source = torch.rand(3, 4, 2, 6) + index = torch.tensor([1, 3]) + expected_result = torch.index_copy(tensor, dim, index, source) + + split_dim = 1 + shard_count = 2 + sharded_tensor = ops.reshard_split(tensor, dim=split_dim, count=shard_count) + sharded_index = ops.replicate(index, 
count=shard_count) + sharded_source = ops.reshard_split(source, dim=split_dim, count=shard_count) + sharded_result = ops.index_copy_( + sharded_tensor, dim, sharded_index, sharded_source + ) + assert sharded_tensor is sharded_result + actual_result = ops.unshard(sharded_tensor) + assert ops.equal(actual_result, expected_result) + + +class IndexPutTest(unittest.TestCase): + def testSplitNonIndexDimInPlace(self): + torch.set_default_dtype(torch.float32) + tensor = torch.rand(3, 4, 5, 6) + indices = ( + torch.tensor([1, 2], dtype=torch.long), + torch.tensor([2, 3], dtype=torch.long), + ) + values = torch.rand(2, 5, 6) + expected_result = tensor.clone().index_put_(indices, values) + shard_count = 2 + sharded_tensor = ops.reshard_split(tensor.clone(), dim=3, count=shard_count) + sharded_values = ops.reshard_split(values, dim=2, count=shard_count) + sharded_result = ops.index_put_(sharded_tensor, indices, sharded_values) + assert sharded_tensor is sharded_result + actual_result = ops.unshard(sharded_tensor) + assert ops.equal(actual_result, expected_result) + + class InterpolateTest(unittest.TestCase): def testInterpolateSplitChannelDim(self): batches = 2 @@ -502,6 +588,60 @@ def testShardedPrimitiveTensorPermute(self): assert ops.equal(expected_result, result) +class AttentionTest(unittest.TestCase): + def testAttentionShardedBatch(self): + q = torch.rand(4, 32, 16, dtype=torch.float32) + k = torch.rand(4, 32, 16, dtype=torch.float32) + v = torch.rand(4, 32, 16, dtype=torch.float32) + + qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0)) + ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0)) + vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0)) + + expected_result = ops.scaled_dot_product_attention(q, k, v, a=None) + sharded_result = ops.scaled_dot_product_attention(qs, ks, vs, a=None) + unsharded_result = ops.sharded_cat(sharded_result) + torch.testing.assert_close(unsharded_result, expected_result) + + def testAttentionShardedBatchCausal(self): + q = torch.rand(4, 32, 16, dtype=torch.float32) + k = torch.rand(4, 32, 16, dtype=torch.float32) + v = torch.rand(4, 32, 16, dtype=torch.float32) + + qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0)) + ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0)) + vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0)) + + expected_result = ops.scaled_dot_product_attention( + q, k, v, a=None, is_causal=True + ) + sharded_result = ops.scaled_dot_product_attention( + qs, ks, vs, a=None, is_causal=True + ) + unsharded_result = ops.sharded_cat(sharded_result) + torch.testing.assert_close(unsharded_result, expected_result) + + def testAttentionShardedBatchMask(self): + q = torch.rand(4, 32, 16, dtype=torch.float32) + k = torch.rand(4, 32, 16, dtype=torch.float32) + v = torch.rand(4, 32, 16, dtype=torch.float32) + a = torch.rand(1, 32, 32, dtype=torch.float32) > 0.5 + + q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0)) + k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0)) + v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0)) + a_s = ReplicatedTensor(ts=a, shard_count=4) + + expected_result = ops.scaled_dot_product_attention( + q, k, v, a=a, is_causal=False + ) + sharded_result = ops.scaled_dot_product_attention( + q_s, k_s, v_s, a=a_s, is_causal=False + ) + unsharded_result = ops.sharded_cat(sharded_result) + torch.testing.assert_close(unsharded_result, expected_result) + + class MatmulTest(unittest.TestCase): def testTorchRHSColumnShardedTransposed(self): t1 = torch.rand(4, 32, 16, 
dtype=torch.float32) @@ -669,6 +809,21 @@ def compute(input, ffn_gate_weight, ffn_down_weight, ffn_up_weight): ) torch.testing.assert_close(Z_sharded, Z_ref) + def testSameSplitLhsAndRhsBatchDim(self): + a = torch.rand(3, 4, 5, 6) + b = torch.rand(3, 4, 6, 7) + shard_count = 2 + shard_dim = 1 + expected_result = torch.matmul(a, b) + sharded_a = ops.reshard_split(a, dim=shard_dim, count=shard_count) + sharded_b = ops.reshard_split(b, dim=shard_dim, count=shard_count) + sharded_result = ops.matmul(sharded_a, sharded_b) + assert isinstance(sharded_result, SplitPrimitiveTensor) + assert sharded_result.shard_count == shard_count + assert sharded_result.shard_dim == shard_dim + actual_result = unbox_tensor(ops.unshard(sharded_result)) + torch.testing.assert_close(actual_result, expected_result) + class ReplicateTest(unittest.TestCase): def testReplicateReplicated(self): @@ -685,6 +840,102 @@ def testReplicateUnsharded(self): expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count) assert expected_result.is_deep_equal(actual_result) + # Test that is a copy. + tensor[...] = torch.rand_like(tensor) + assert all(not ops.equal(tensor, shard) for shard in actual_result.shards) + + +class ReshapeTest(unittest.TestCase): + def testSplitTensorFlattenNonSplitDim(self): + tensor = torch.rand(2, 3, 4, 5) + new_shape = [2, 12, 5] + unsharded_expected_result = torch.reshape(tensor, new_shape) + expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=3, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorSplitDimIsLeadingFlattenDim(self): + tensor = torch.rand(3, 4, 5, 6) + new_shape = [3, 20, 6] + unsharded_expected_result = torch.reshape(tensor, new_shape) + expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=1, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorInsertSize1DimBeforeSplitDim(self): + tensor = torch.rand(4, 5, 6, 7) + new_shape = [4, 1, 5, 6, 7] + unsharded_expected_result = torch.reshape(tensor, new_shape) + shard_dim = 2 + expected_result = ops.reshard_split( + unsharded_expected_result, dim=shard_dim + 1, count=2 + ) + sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorInsertMultipleSize1DimsBeforeSplitDim(self): + tensor = torch.rand(4, 5, 6, 7) + new_shape = [4, 1, 1, 5, 6, 7] + unsharded_expected_result = torch.reshape(tensor, new_shape) + shard_dim = 2 + expected_result = ops.reshard_split( + unsharded_expected_result, dim=shard_dim + 2, count=2 + ) + sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorInsertMultipleSize1TrailingDimsNotRightAfterSplitDim(self): + tensor = torch.rand(4, 5, 6, 7) + new_shape = [4, 5, 6, 7, 1, 1] + unsharded_expected_result = torch.reshape(tensor, new_shape) + shard_dim = 2 + expected_result = ops.reshard_split( + unsharded_expected_result, dim=shard_dim, count=2 + ) + sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert 
expected_result.is_deep_equal(actual_result) + + def testSplitTensorUnflattenNonSplitDim(self): + tensor = torch.rand(3, 20, 6) + new_shape = [3, 4, 5, 6] + unsharded_expected_result = torch.reshape(tensor, new_shape) + expected_result = ops.reshard_split(unsharded_expected_result, dim=3, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=2, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorUnflattenTrailingNonSplitDim(self): + tensor = torch.rand(3, 4, 30) + new_shape = [3, 4, 5, 6] + unsharded_expected_result = torch.reshape(tensor, new_shape) + expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=1, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorUnflattenSplitDim(self): + tensor = torch.rand(3, 20, 6) + new_shape = [3, 4, 5, 6] + unsharded_expected_result = torch.reshape(tensor, new_shape) + expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=1, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + + def testSplitTensorUnflattenTrailingSplitDim(self): + tensor = torch.rand(2, 3, 20) + new_shape = [2, 3, 4, 5] + unsharded_expected_result = torch.reshape(tensor, new_shape) + expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2) + sharded_tensor = ops.reshard_split(tensor, dim=2, count=2) + actual_result = ops.reshape(sharded_tensor, new_shape) + assert expected_result.is_deep_equal(actual_result) + class ReshardSplitTest(unittest.TestCase): def testReshardReplicated(self): @@ -708,6 +959,11 @@ def testReshardUnsharded(self): ) assert expected_result.is_deep_equal(actual_result) + # Test that is a copy. + tensor[...] 
= torch.rand_like(tensor) + result_split2 = ops.reshard_split(tensor, dim=shard_dim, count=shard_count) + assert not ops.equal(actual_result, result_split2) + def testReshardSharded(self): tensor = torch.rand(4, 5, 6, dtype=torch.float32) shard_dim = 2 diff --git a/shortfin/tests/framework/device_session_test.py b/sharktank/tests/serving_poc/framework/device_session_test.py similarity index 96% rename from shortfin/tests/framework/device_session_test.py rename to sharktank/tests/serving_poc/framework/device_session_test.py index 7b4916eb4..5dfdd5f46 100644 --- a/shortfin/tests/framework/device_session_test.py +++ b/sharktank/tests/serving_poc/framework/device_session_test.py @@ -6,7 +6,7 @@ import pytest -from shortfin.framework.session import ( +from sharktank.serving_poc.framework.session import ( DeviceSession, ) diff --git a/shortfin/tests/llm/api_server_test.py b/sharktank/tests/serving_poc/llm/api_server_test.py similarity index 93% rename from shortfin/tests/llm/api_server_test.py rename to sharktank/tests/serving_poc/llm/api_server_test.py index fe4153dae..c2d2cc36a 100644 --- a/shortfin/tests/llm/api_server_test.py +++ b/sharktank/tests/serving_poc/llm/api_server_test.py @@ -39,7 +39,7 @@ def __init__(self, args): [ sys.executable, "-m", - "shortfin.llm.api.rest_server", + "sharktank.serving_poc.llm.api.rest_server", "--testing-mock-service", "--port=" + port, ] @@ -77,6 +77,11 @@ def __del__(self): @pytest.fixture(scope="session") def server(): + try: + import fastapi + import uvicorn + except ModuleNotFoundError as e: + pytest.skip(f"Skipping server test because deps are missing: {e}") runner = ServerRunner([]) yield runner diff --git a/shortfin/tests/llm/service_v1_test.py b/sharktank/tests/serving_poc/llm/service_v1_test.py similarity index 90% rename from shortfin/tests/llm/service_v1_test.py rename to sharktank/tests/serving_poc/llm/service_v1_test.py index 56472ecc2..c010e2034 100644 --- a/shortfin/tests/llm/service_v1_test.py +++ b/sharktank/tests/serving_poc/llm/service_v1_test.py @@ -10,28 +10,28 @@ HalElementType, ) -from shortfin.framework.session import DeviceSession -from shortfin.llm.config import ( +from sharktank.serving_poc.framework.session import DeviceSession +from sharktank.serving_poc.llm.config import ( CacheParams, ModelParams, ServiceParams, ) -from shortfin.llm.service import ( +from sharktank.serving_poc.llm.service import ( GenerateRequest, GenerateResponsePart, ) -from shortfin.llm.attn_block_cache import ( +from sharktank.serving_poc.llm.attn_block_cache import ( create_attn_block_cache_module, AttnBlockCache, ) -from shortfin.llm.impl.service_v1 import ( +from sharktank.serving_poc.llm.impl.service_v1 import ( GenerateServiceV1, ) -from shortfin.llm.testing.fake_v1_module import ( +from sharktank.serving_poc.llm.testing.fake_v1_module import ( create_fake_module, ) diff --git a/sharktank/tests/transforms/dataset_transforms_test.py b/sharktank/tests/transforms/dataset_transforms_test.py index 4a57b6229..928d1e4ac 100644 --- a/sharktank/tests/transforms/dataset_transforms_test.py +++ b/sharktank/tests/transforms/dataset_transforms_test.py @@ -19,8 +19,8 @@ from sharktank.utils.testing import MainRunnerTestBase -class MmtRHSShardingTransformTest(MainRunnerTestBase): - def testPrimitive(self): +class DatasetShardingTransformTest(MainRunnerTestBase): + def testShardLlmDataset(self): orig_pts = [ DefaultPrimitiveTensor( name="blk.1.attn_k.weight", data=torch.randn([32, 128]) @@ -28,9 +28,19 @@ def testPrimitive(self): DefaultPrimitiveTensor( 
name="blk.2.attn_q.weight", data=torch.randn([48, 64]) ), - DefaultPrimitiveTensor(name="other", data=torch.randn([2, 2])), ] - ds_orig = Dataset({}, Theta(orig_pts)) + ds_orig = Dataset( + { + "general.architecture": "llm", + "llm.attention.head_count": 1, + "llm.context_length": 2, + "llm.embedding_length": 3, + "llm.block_count": 4, + "llm.feed_forward_length": 5, + "llm.attention.layer_norm_rms_epsilon": 0.1, + }, + Theta(orig_pts), + ) input_path = self.save_dataset(ds_orig, "input") output_path = self.get_irpa_path("output") from sharktank.examples.sharding import shard_llm_dataset @@ -41,38 +51,38 @@ def testPrimitive(self): input_path, "--output-irpa-file", output_path, - "--num-shards", + "--tensor-parallelism-size", 8, ) ds_tran = Dataset.load(output_path, mmap=False) + ds_tran.properties["tensor_parallelism_size"] = 8 + # Verify. flat_sts = ds_tran.root_theta.flatten() - self.assertEqual(3, len(flat_sts)) + self.assertEqual(2, len(flat_sts)) st_1 = flat_sts["blk.1.attn_k.weight"] st_2 = flat_sts["blk.2.attn_q.weight"] - pt_3 = flat_sts["other"] self.assertIsInstance(st_1, SplitPrimitiveTensor) self.assertIsInstance(st_2, SplitPrimitiveTensor) - self.assertIsInstance(pt_3, DefaultPrimitiveTensor) self.assertListEqual(st_1.shape, [32, 128]) self.assertListEqual(st_2.shape, [48, 64]) # Verify component shapes for st_1. self.assertEqual(8, len(st_1.shards)) - self.assertTrue(all(pt.shape == [32, 16] for pt in st_1.shards)) + self.assertTrue(all(pt.shape == [4, 128] for pt in st_1.shards)) self.assertTrue( - all(list(pt.as_torch().shape) == [32, 16] for pt in st_1.shards) + all(list(pt.as_torch().shape) == [4, 128] for pt in st_1.shards) ) # Verify component shapes for st_2. self.assertEqual(8, len(st_2.shards)) - self.assertTrue(all(pt.shape == [48, 8] for pt in st_2.shards)) - self.assertTrue(all(list(pt.as_torch().shape) == [48, 8] for pt in st_2.shards)) + self.assertTrue(all(pt.shape == [6, 64] for pt in st_2.shards)) + self.assertTrue(all(list(pt.as_torch().shape) == [6, 64] for pt in st_2.shards)) # Verify contents for one shard for sanity. 
new_t = st_1.shards[0].as_torch() - torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(16, dim=1)[0]) + torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(4, dim=0)[0]) if __name__ == "__main__": diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py index 0d79785f6..4494eab2f 100644 --- a/sharktank/tests/types/dataset_test.py +++ b/sharktank/tests/types/dataset_test.py @@ -11,12 +11,12 @@ import torch -from shark_turbine.aot import ExternalTensorTrait +from iree.turbine.aot import ExternalTensorTrait from sharktank.types import * def _t(name: str, *dims: int): - return DefaultPrimitiveTensor(name=name, data=torch.empty(*dims)) + return DefaultPrimitiveTensor(name=name, data=torch.ones(*dims)) def _flat_t_dict(*ts): @@ -77,6 +77,22 @@ def testTransform(self): self.assertIsNot(pt1, pt2) torch.testing.assert_close(pt1, pt2) + def testPop(self): + t1 = Theta( + _flat_t_dict( + _t("a.b.c", 1, 2), + _t("a.c.d", 10, 11), + _t("a.b.3", 3, 4), + ) + ) + popped = t1.pop("a.b").flatten() + t1 = t1.flatten() + + self.assertIsNotNone("a.c.d", t1.keys()) + self.assertNotIn("a.b.c", t1.keys()) + self.assertNotIn("a.b.3", t1.keys()) + self.assertIn("a.b.3", popped.keys()) + class DatasetTest(unittest.TestCase): def setUp(self): diff --git a/sharktank/tests/types/quantizers_test.py b/sharktank/tests/types/quantizers_test.py index 787725e88..b712da06a 100644 --- a/sharktank/tests/types/quantizers_test.py +++ b/sharktank/tests/types/quantizers_test.py @@ -9,6 +9,7 @@ import torch from sharktank.types import * +from sharktank.types.layout_utils import saturate_cast from sharktank.utils.testing import TempDirTestBase @@ -164,6 +165,80 @@ def testQuantDequantf8fnuz(self): dequant_value = layout.dequant() torch.testing.assert_close(orig_value, dequant_value, atol=1e-1, rtol=1e-1) + def testQuarkF8Hell(self): + # we use hardcoded values here because they're representative of actual values from a quark model + scale = torch.tensor(0.0118, dtype=torch.float64) + orig = torch.tensor( + [ + -58, + -48, + -70, + 53, + -53, + 76, + -71, + -90, + 50, + 77, + 62, + -98, + 66, + -54, + 55, + -80, + -66, + -62, + -61, + -56, + 56, + -67, + 79, + -60, + -71, + 42, + 72, + -73, + 91, + 63, + 124, + -128, + ], + dtype=torch.int8, + ) + # mirrors dequant logic in quark and our importer + orig = orig.view(torch.float8_e4m3fn) + orig = (orig.to(torch.float64) * scale).to(torch.float16) + # Note that for fnuz we have to do scale*2 to account for the difference between types + # We specify the reciprocal scale explicitly to avoid adding more floating point error noise + fnuz = StaticScaledQuantizer( + name="dopoo", + scale=1.0 / (scale * 2), + reciprocal_scale=scale * 2, + offset=None, + dtype=torch.float8_e4m3fnuz, + ) + fn = StaticScaledQuantizer( + name="poodoo", + scale=1.0 / scale, + reciprocal_scale=scale, + offset=None, + dtype=torch.float8_e4m3fn, + ) + fnuz_quant = fnuz.quantize(orig) + fn_quant = fn.quantize(orig) + + dequant_fnuz = fnuz_quant.unpack().dequant() + dequant_fn = fn_quant.unpack().dequant() + + # redundant asserts for sanity + torch.testing.assert_close( + orig.to(torch.float16), dequant_fnuz, atol=1e-3, rtol=1e-3 + ) + torch.testing.assert_close( + orig.to(torch.float16), dequant_fn, atol=1e-3, rtol=1e-3 + ) + torch.testing.assert_close(dequant_fnuz, dequant_fn, atol=1e-3, rtol=1e-3) + if __name__ == "__main__": unittest.main() diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py index 8ee839b6e..4af4a513f 
100644 --- a/sharktank/tests/types/tensors_test.py +++ b/sharktank/tests/types/tensors_test.py @@ -62,10 +62,8 @@ def transform2(d): class ShardedTensorTest(unittest.TestCase): def testReplicatedTensorSaveLoad(self): - tensor = torch.rand([2, 3, 4], dtype=torch.float32) - replicated_tensor = ReplicatedTensor( - ts=tensor, shard_count=3, name="the_tensor" - ) + tensor = [torch.rand([2, 3, 4], dtype=torch.float32)] * 3 + replicated_tensor = ReplicatedTensor(ts=tensor, name="the_tensor") theta = Theta([replicated_tensor]) dataset = Dataset({}, theta) with tempfile.TemporaryDirectory() as tmp_dir: @@ -140,6 +138,34 @@ def testSplitTensorExtractSliceOfNonSplitDim(self): actual_result = ops.reshard_like(sharded_slice, expected_result) assert ops.equal(expected_result, actual_result) + def testSplitTensorExtractSliceWithEllipsis(self): + tensor = torch.rand([2, 3, 4, 5]) + sharded_tensor = ops.reshard_split(tensor, dim=2, count=2) + expected_result = tensor[0, ..., 1:3] + expected_sharded_result = ops.reshard_split(expected_result, dim=1, count=2) + actual_sharded_result = sharded_tensor[0, ..., 1:3] + assert ops.equal(actual_sharded_result, expected_sharded_result) + + def testSplitTensorInsertSliceOfAllDimsWithEllipsis(self): + dst = torch.rand([2, 3, 4]) + src = torch.rand([2, 3, 4]) + sharded_dst = ops.reshard_split(dst.clone(), dim=1, count=3) + sharded_src = ops.reshard_like(src, like=sharded_dst) + dst[...] = src + sharded_dst[...] = sharded_src + actual_result = ops.unshard(sharded_dst) + assert ops.equal(actual_result, dst) + + def testSplitTensorInsertSliceWithEllipsis(self): + dst = torch.rand([2, 3, 4, 5]) + src = torch.rand([3, 4, 2]) + sharded_dst = ops.reshard_split(dst.clone(), dim=2, count=2) + sharded_src = ops.reshard_split(src, dim=1, count=2) + dst[0, ..., 1:3] = src + sharded_dst[0, ..., 1:3] = sharded_src + actual_result = ops.unshard(sharded_dst) + assert ops.equal(actual_result, dst) + if __name__ == "__main__": unittest.main() diff --git a/sharktank/version.json b/sharktank/version.json new file mode 100644 index 000000000..9519501ae --- /dev/null +++ b/sharktank/version.json @@ -0,0 +1,3 @@ +{ + "package-version": "3.1.0.dev" +} diff --git a/libshortfin/.clang-format b/shortfin/.clang-format similarity index 100% rename from libshortfin/.clang-format rename to shortfin/.clang-format diff --git a/shortfin/.gitignore b/shortfin/.gitignore new file mode 100644 index 000000000..000e575d5 --- /dev/null +++ b/shortfin/.gitignore @@ -0,0 +1,2 @@ +# Local-only config options +version_info_rc.json diff --git a/shortfin/.readthedocs.yaml b/shortfin/.readthedocs.yaml new file mode 100644 index 000000000..582088587 --- /dev/null +++ b/shortfin/.readthedocs.yaml @@ -0,0 +1,16 @@ +version: "2" + +build: + os: "ubuntu-24.04" + tools: + python: "3.12" + jobs: + pre_build: + - python -m pip install -v shortfin/ + +python: + install: + - requirements: shortfin/docs/requirements.txt + +sphinx: + configuration: shortfin/docs/conf.py diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt new file mode 100644 index 000000000..f025eccfe --- /dev/null +++ b/shortfin/CMakeLists.txt @@ -0,0 +1,268 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +cmake_minimum_required(VERSION 3.29) + +if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR) + message( + FATAL_ERROR + "Do not build in-source. 
Please remove CMakeCache.txt and the CMakeFiles/ directory. Then build out-of-source." + ) +endif() + +# Get version number from file +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/version.json VERSION_JSON_STRING) +string(JSON PACKAGE_VERSION GET ${VERSION_JSON_STRING} package-version) +string(REGEX MATCH "(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*" BASE_VERSION ${PACKAGE_VERSION}) + +project( + "libshortfin" + VERSION ${BASE_VERSION} + LANGUAGES C CXX) + +include(CMakeDependentOption) + +set(SOVERSION 1) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) +# https://discourse.cmake.org/t/cmake-3-28-cmake-cxx-compiler-clang-scan-deps-notfound-not-found/9244/3 +set(CMAKE_CXX_SCAN_FOR_MODULES 0) +set(CMAKE_EXPORT_COMPILE_COMMANDS 1) + +# Problems with linking libfmt without PIC. +# Turn on PIC on non windows targets. +if(NOT WIN32) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) +endif() + +# Pins +set(SHORTFIN_IREE_GIT_TAG "iree-3.0.0rc20241118") + +# build options +option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF) +option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON) +option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON) +option(SHORTFIN_ENABLE_TRACING "Enable runtime tracing for iree and shortfin" OFF) +option(SHORTFIN_ENABLE_LTO "Enables LTO if supported" ON) + +set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source") + +# Options for building static or dynamic libraries. +# Default to dynamic linking, unless on Windows. +# TODO(#211): Unify the defaults once Windows dynamic linking issues are fixed. +set(SHORTFIN_BUILD_STATIC_DEFAULT OFF) +set(SHORTFIN_BUILD_DYNAMIC_DEFAULT ON) +if(WIN32) + set(SHORTFIN_BUILD_STATIC_DEFAULT ON) + set(SHORTFIN_BUILD_DYNAMIC_DEFAULT OFF) +endif() +option(SHORTFIN_BUILD_STATIC "Builds static libraries" ${SHORTFIN_BUILD_STATIC_DEFAULT}) +option(SHORTFIN_BUILD_DYNAMIC "Builds dynamic libraries" ${SHORTFIN_BUILD_DYNAMIC_DEFAULT}) +cmake_dependent_option(SHORTFIN_LINK_DYNAMIC "Links internal binaries against static libshortfin.a" ON "SHORTFIN_BUILD_DYNAMIC" OFF) +if(NOT SHORTFIN_BUILD_STATIC AND NOT SHORTFIN_BUILD_DYNAMIC) + message(FATAL_ERROR "One of SHORTFIN_BUILD_STATIC or SHORTFIN_BUILD_DYNAMIC must be ON") +endif() +message(STATUS "Shortfin build static = ${SHORTFIN_BUILD_STATIC}, dynamic = ${SHORTFIN_BUILD_DYNAMIC}") +if(SHORTFIN_LINK_DYNAMIC) + message(STATUS "Dynamic linking to shortfin") + set(SHORTFIN_LINK_LIBRARY_NAME "shortfin") +else() + message(STATUS "Static linking to shortfin-static") + set(SHORTFIN_LINK_LIBRARY_NAME "shortfin-static") +endif() + +# Includes. +list(APPEND CMAKE_MODULE_PATH + ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/ +) +include(shortfin_library) +include(CheckCXXCompilerFlag) +include(FetchContent) + +################################################################################ +# Toolchain features +################################################################################ + +if(SHORTFIN_ENABLE_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT SHORTFIN_LTO_SUPPORTED OUTPUT SHORTFIN_LTO_ERROR) + if(SHORTFIN_LTO_SUPPORTED) + message(STATUS "Shortfin LTO Enabled") + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "Could not enable LTO (not supported): ${SHORTFIN_LTO_ERROR}") + endif() +endif() + +# Enabling ASAN. Note that this will work best if building in a completely +# bundled fashion and with an ASAN rigged CPython. Otherwise, various LD_PRELOAD +# hacks are needed. 
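For concreteness, one possible shape of such an LD_PRELOAD workaround when running the Python tests against an ASAN build with a stock (non-ASAN) CPython is sketched below; the runtime library path varies by toolchain and the suppressions file is the `build_tools/python_lsan_suppressions.txt` added later in this diff, so treat this purely as an illustration:

```bash
# Illustration only: preload the ASAN runtime and point LeakSanitizer at the
# bundled suppressions list before running pytest.
LD_PRELOAD="$(gcc -print-file-name=libasan.so)" \
LSAN_OPTIONS="suppressions=build_tools/python_lsan_suppressions.txt" \
python -m pytest -s tests/
```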
This is merely a developer convenience: people are more
+# than welcome to set flags themselves.
+option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF)
+if(SHORTFIN_ENABLE_ASAN)
+  add_compile_options(-fsanitize=address)
+  add_link_options(-fsanitize=address)
+
+  # Enable more ASAN checks.
+  add_compile_definitions(IREE_SANITIZER_ADDRESS)
+endif()
+
+# Thread safety annotations: Enabled if the compiler supports it.
+check_cxx_compiler_flag("-Wthread-safety" SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS)
+if(SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS)
+  add_compile_options(-Wthread-safety)
+  add_compile_definitions(SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS)
+endif()
+
+option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON)
+message(STATUS "shortfin supported systems:")
+if(SHORTFIN_SYSTEMS_AMDGPU)
+  message(STATUS " - AMD GPU")
+endif()
+message(STATUS " - Host")
+
+################################################################################
+# Dependencies
+################################################################################
+
+if(SHORTFIN_BUNDLE_DEPS)
+  ## fmt
+  FetchContent_Declare(
+    fmt
+    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+    GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog)
+  )
+
+  ## spdlog
+  # We build fmt from source instead, because we also use fmt.
+  set(SPDLOG_FMT_EXTERNAL ON)
+  FetchContent_Declare(
+    spdlog
+    GIT_REPOSITORY https://github.com/gabime/spdlog.git
+    GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1
+  )
+
+  ## xtl: required for xtensor
+  FetchContent_Declare(
+    xtl
+    GIT_REPOSITORY https://github.com/xtensor-stack/xtl.git
+    GIT_TAG a7c1c5444dfc57f76620391af4c94785ff82c8d6 # v0.7.7
+  )
+
+  ## xtensor
+  FetchContent_Declare(
+    xtensor
+    GIT_REPOSITORY https://github.com/xtensor-stack/xtensor.git
+    GIT_TAG 3634f2ded19e0cf38208c8b86cea9e1d7c8e397d # v0.25.0
+  )
+
+  # In order to bundle libraries without conflict, we have to tweak settings.
+  shortfin_push_bundled_lib_options()
+  # Enable spdlog shared library options so we can export it.
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPDLOG_SHARED_LIB -Dspdlog_EXPORTS")
+  FetchContent_MakeAvailable(fmt spdlog xtl xtensor)
+  shortfin_pop_bundled_lib_options()
+else()
+  find_package(spdlog)
+  find_package(xtensor)
+endif()
+
+################################################################################
+# IREE
+################################################################################
+
+# Set IREE build flags.
+# We currently rely on IREE to have visible symbols in order to re-export
+# its API for further use.
+# TODO: Turn this back on and use explicit visibility control in the IREE
+# runtime and linker scripts.
+set(IREE_VISIBILITY_HIDDEN OFF)
+set(IREE_BUILD_COMPILER OFF)
+set(IREE_BUILD_TESTS OFF)
+set(IREE_BUILD_SAMPLES OFF)
+# Disable missing submodules error because we are only building the runtime.
+set(IREE_ERROR_ON_MISSING_SUBMODULES OFF)
+# Only enable local_sync/local_task/hip drivers for now.
+set(IREE_HAL_DRIVER_DEFAULTS OFF)
+set(IREE_HAL_DRIVER_LOCAL_SYNC ON)
+set(IREE_HAL_DRIVER_LOCAL_TASK ON)
+if(SHORTFIN_SYSTEMS_AMDGPU)
+  set(IREE_HAL_DRIVER_HIP ON)
+endif()
+if (SHORTFIN_ENABLE_TRACING)
+  set(IREE_ENABLE_RUNTIME_TRACING ON)
+  # When using shared libraries there are some issues that need to be
+  # explored more on static initialization order. Something is getting
+  # initialized and is emitting tracy events before tracy objects are
+  # initialized.
This can point to some shared library overloading allocation + # functions and making them emit tracy events, which are further used in + # some static allocation. See https://github.com/wolfpld/tracy/issues/196 + # for a similar issue and discussion. Using the workaround suggested in + # that issue for now. Note that this does not happen when using static + # libraries. + set(TRACY_DELAYED_INIT ON CACHE BOOL "Enable delayed init for tracy") +endif() + +# In order to bundle libraries without conflict, we have to tweak settings. +shortfin_push_bundled_lib_options() +if(SHORTFIN_IREE_SOURCE_DIR) + message(STATUS "Using existing IREE sources: ${SHORTFIN_IREE_SOURCE_DIR}") + add_subdirectory(${SHORTFIN_IREE_SOURCE_DIR} shortfin_iree SYSTEM EXCLUDE_FROM_ALL) +else() + message(STATUS "Fetching IREE sources from tag ${SHORTFIN_IREE_GIT_TAG}") + + # TODO: We shouldn't have to pull googletest when we are not building tests. + # This needs to be fixed with IREE. + set(IREE_SUBMODULES "third_party/benchmark third_party/cpuinfo third_party/flatcc third_party/hip-build-deps third_party/googletest") + if (SHORTFIN_ENABLE_TRACING) + set(IREE_SUBMODULES "${IREE_SUBMODULES} third_party/tracy") + endif() + FetchContent_Declare( + shortfin_iree + GIT_REPOSITORY https://github.com/iree-org/iree.git + GIT_TAG "${SHORTFIN_IREE_GIT_TAG}" + GIT_SUBMODULES ${IREE_SUBMODULES} + GIT_SHALLOW TRUE + SYSTEM + EXCLUDE_FROM_ALL + ) + FetchContent_GetProperties(shortfin_iree) + if(NOT shortfin_iree_POPULATED) + FetchContent_MakeAvailable(shortfin_iree) + endif() +endif() +shortfin_pop_bundled_lib_options() + +################################################################################ +# Tests +################################################################################ + +if(SHORTFIN_BUILD_TESTS) + if (NOT SHORTFIN_BUNDLE_DEPS AND NOT SHORTFIN_IREE_SOURCE_DIR) + # For now we use gtest shipped alongside with IREE. + FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip + ) + # For Windows: Prevent overriding the parent project's compiler/linker settings + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(googletest) + endif() + include(GoogleTest) + enable_testing() +endif() + + +add_subdirectory(src) + +if(SHORTFIN_BUILD_PYTHON_BINDINGS) + find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + add_subdirectory(python) + set(SHORTFIN_PYTHON_CPP_PREBUILT "TRUE") # See setup.py. + configure_file(setup.py setup.py @ONLY) + configure_file(pyproject.toml pyproject.toml COPYONLY) +endif() diff --git a/shortfin/README.md b/shortfin/README.md index a76fad8c7..6269ca702 100644 --- a/shortfin/README.md +++ b/shortfin/README.md @@ -1,15 +1,215 @@ -# SHARK Shortfin Serving Infrastructure +# shortfin - SHARK inference library and serving engine -**WARNING: This is an early preview that is in progress. It is not ready for -general use.** +The shortfin project is SHARK's open source, high performance inference library +and serving engine. Shortfin consists of these major components: -This sub-project contains components and infrastructure for serving various -forms of sharktank compiled models. Instead of coming with models, it defines -ABIs that compiled models should adhere to in order to be served. It then -allows them to be delivered as web endpoints via popular APIs. 
+* The "libshortfin" inference library written in C/C++ and built on + [IREE](https://github.com/iree-org/iree) +* Python bindings for the underlying inference library +* Example applications in + ['shortfin_apps'](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps) + built using the python bindings -As emulation can be the sincerest form of flattery, this project derives -substantial inspiration from vllm and the OpenAI APIs, emulating and -interopping with them where possible. It is intended to be the lightest -weight possible reference implementation for serving models with an -opinionated compiled form, built elsewhere in the project. +## Prerequisites + +* Python 3.11+ + +## Simple user installation + +Install the latest stable version: + +```bash +pip install shortfin +``` + +## Developer guides + +### Quick start: install local packages and run tests + +After cloning this repository, from the `shortfin/` directory: + +```bash +pip install -e . +``` + +Install test requirements: + +```bash +pip install -r requirements-tests.txt +``` + +Run tests: + +```bash +pytest -s tests/ +``` + +### Simple dev setup + +We recommend this development setup for core contributors: + +1. Check out this repository as a sibling to [IREE](https://github.com/iree-org/iree) + if you already have an IREE source checkout. Otherwise, a pinned version will + be downloaded for you +2. Ensure that `python --version` reads 3.11 or higher (3.12 preferred). +3. Run `./dev_me.py` to build and install the `shortfin` Python package with both + a tracing-enabled and default build. Run it again to do an incremental build + and delete the `build/` directory to start over +4. Run tests with `python -m pytest -s tests/` +5. Test optional features: + * `pip install iree-base-compiler` to run a small suite of model tests intended + to exercise the runtime (or use a [source build of IREE](https://iree.dev/building-from-source/getting-started/#using-the-python-bindings)). + * `pip install onnx` to run some more model tests that depend on downloading + ONNX models + * Run tests on devices other than the CPU with flags like: + `--system amdgpu --compile-flags="--iree-hal-target-backends=rocm --iree-hip-target=gfx1100"` + * Use the tracy instrumented runtime to collect execution traces: + `export SHORTFIN_PY_RUNTIME=tracy` + +Refer to the advanced build options below for other scenarios. + +### Advanced build options + +1. Native C++ build +2. Local Python release build +3. Package Python release build +4. Python dev build + +Prerequisites + +* A modern C/C++ compiler, such as clang 18 or gcc 12 +* A modern Python, such as Python 3.12 + +#### Native C++ builds + +```bash +cmake -GNinja -S. -Bbuild \ + -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_LINKER_TYPE=LLD +cmake --build build --target all +``` + +If Python bindings are enabled in this mode (`-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON`), +then `pip install -e build/` will install from the build dir (and support +build/continue). + +#### Package Python release builds + +* To build wheels for Linux using a manylinux Docker container: + + ```bash + sudo ./build_tools/build_linux_package.sh + ``` + +* To build a wheel for your host OS/arch manually: + + ```bash + # Build shortfin.*.whl into the dist/ directory + # e.g. `shortfin-0.9-cp312-cp312-linux_x86_64.whl` + python3 -m pip wheel -v -w dist . + + # Install the built wheel. 
+  python3 -m pip install dist/*.whl
+  ```
+
+#### Python dev builds
+
+```bash
+# Install build system pre-reqs (since we are building in dev mode, this
+# is not done for us). See the source of truth in pyproject.toml:
+pip install setuptools wheel
+
+# Optionally install cmake and ninja if you don't have them or need a newer
+# version. If doing heavy development in Python, it is strongly recommended
+# to install these natively on your system as it will make it easier to
+# switch Python interpreters and build options (and the launcher in debug/asan
+# builds of Python is much slower). See CMakeLists.txt for the minimum CMake
+# version, which is usually quite recent.
+pip install cmake ninja
+
+SHORTFIN_DEV_MODE=ON pip install --no-build-isolation -v -e .
+```
+
+Note that the `--no-build-isolation` flag is useful in development setups
+because it does not create an intermediate venv that will keep later
+invocations of cmake/ninja from working at the command line. If just doing
+a one-shot build, it can be omitted.
+
+Once built the first time, `cmake`, `ninja`, and `ctest` commands can be run
+directly from `build/cmake` and changes will apply directly to the next
+process launch.
+
+Several optional environment variables can be used with setup.py:
+
+* `SHORTFIN_CMAKE_BUILD_TYPE=Debug` : Sets the CMAKE_BUILD_TYPE. Defaults to
+  `Debug` for dev mode and `Release` otherwise.
+* `SHORTFIN_ENABLE_ASAN=ON` : Enables an ASAN build. Requires a Python runtime
+  setup that is ASAN clean (either by env vars to preload libraries or set
+  suppressions or a dev build of Python with ASAN enabled).
+* `SHORTFIN_IREE_SOURCE_DIR=$(pwd)/../../iree` : Builds against an existing
+  IREE source checkout instead of the pinned download.
+* `SHORTFIN_RUN_CTESTS=ON` : Runs `ctest` as part of the build. Useful for CI
+  as it uses the version of ctest installed in the pip venv.
+
+### Running tests
+
+The project uses a combination of ctest for native C++ tests and pytest. Much
+of the functionality is only tested via the Python tests, using the
+`_shortfin.lib` internal implementation directly. In order to run these tests,
+you must have installed the Python package as per the above steps.
+
+Which style of test is used is pragmatic and geared at achieving good test
+coverage with a minimum of duplication. Since it is often much more expensive
+to build native tests of complicated flows, many things are only tested via
+Python. This does not preclude having other language bindings later, but it
+does mean that the C++ core of the library must always be built with the
+Python bindings in order to test most of its behavior. Given the target of
+the project, this is not considered to be a significant issue.
+
+#### Python tests
+
+Run platform-independent tests only:
+
+```bash
+pytest tests/
+```
+
+Run tests including those for a specific platform (in this example, a gfx1100
+AMDGPU). Note that not all tests are system-aware yet and some may only run
+on the CPU:
+
+```bash
+pytest tests/ --system amdgpu \
+    --compile-flags="--iree-hal-target-backends=rocm --iree-hip-target=gfx1100"
+```
+
+## Production library building
+
+In order to build a production library, additional build steps are typically
+recommended:
+
+* Compile all deps with the same compiler/linker for LTO compatibility
+* Provide library dependencies manually and compile them with LTO
+* Compile dependencies with `-fvisibility=hidden`
+* Enable LTO builds of libshortfin
+* Set flags to enable symbol versioning
+
+## Miscellaneous build topics
+
+### Free-threaded Python
+
+Support for free-threaded Python builds (aka. "nogil") is in progress.
It +is currently being tested via CPython 3.13 with the `--disable-gil` option set. +There are multiple ways to acquire such an environment: + +* Generally, see the documentation at + +* If using `pyenv`: + + ```bash + # Install a free-threaded 3.13 version. + pyenv install 3.13t + + # Test (should print "False"). + pyenv shell 3.13t + python -c 'import sys; print(sys._is_gil_enabled())' + ``` diff --git a/shortfin/build_tools/build_linux_package.sh b/shortfin/build_tools/build_linux_package.sh new file mode 100755 index 000000000..9f7388119 --- /dev/null +++ b/shortfin/build_tools/build_linux_package.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# Copyright 2024 Advanced Micro Devices, Inc. +# Copyright 2022 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# build_linux_package.sh +# One stop build of shortfin Python packages for Linux. The Linux build is +# complicated because it has to be done via a manylinux docker container. +# +# Usage: +# Build everything (all python versions): +# sudo ./build_tools/build_linux_package.sh +# +# Build specific Python versions to custom directory: +# OVERRIDE_PYTHON_VERSIONS="cp312-cp312 cp313-cp313" \ +# OUTPUT_DIR="/tmp/wheelhouse" \ +# sudo -E ./build_tools/build_linux_package.sh +# +# Valid Python versions match a subdirectory under /opt/python in the docker +# image. Typically: +# cp312-cp312 cp313-cp313 +# +# Note that this script is meant to be run on CI and it will pollute both the +# output directory and in-tree build/ directories with docker created, root +# owned builds. Sorry - there is no good way around it. +# +# It can be run on a workstation but recommend using a git worktree dedicated +# to packaging to avoid stomping on development artifacts. +set -xeu -o errtrace + +THIS_DIR="$(cd $(dirname $0) && pwd)" +REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)" +SCRIPT_NAME="$(basename $0)" +ARCH="$(uname -m)" + +# TODO(#130): Update to manylinux_2_28, upstream or a fork +# * upstream uses a version of gcc that has build warnings/errors +# * https://github.com/nod-ai/base-docker-images is a bit out of date but can include a recent clang +# MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}" +MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux2014_${ARCH}:latest}" +PYTHON_VERSIONS="${OVERRIDE_PYTHON_VERSIONS:-cp311-cp311 cp312-cp312 cp313-cp313}" +OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}" + +function run_on_host() { + echo "Running on host" + echo "Launching docker image ${MANYLINUX_DOCKER_IMAGE}" + + # Canonicalize paths. + mkdir -p "${OUTPUT_DIR}" + OUTPUT_DIR="$(cd "${OUTPUT_DIR}" && pwd)" + echo "Outputting to ${OUTPUT_DIR}" + mkdir -p "${OUTPUT_DIR}" + docker run --rm \ + -v "${REPO_ROOT}:${REPO_ROOT}" \ + -v "${OUTPUT_DIR}:${OUTPUT_DIR}" \ + -e __MANYLINUX_BUILD_WHEELS_IN_DOCKER=1 \ + -e "OVERRIDE_PYTHON_VERSIONS=${PYTHON_VERSIONS}" \ + -e "OUTPUT_DIR=${OUTPUT_DIR}" \ + "${MANYLINUX_DOCKER_IMAGE}" \ + -- ${THIS_DIR}/${SCRIPT_NAME} + + echo "******************** BUILD COMPLETE ********************" + echo "Generated binaries:" + ls -l "${OUTPUT_DIR}" +} + +function run_in_docker() { + echo "Running in docker" + echo "Marking git safe.directory" + git config --global --add safe.directory '*' + + echo "Using python versions: ${PYTHON_VERSIONS}" + local orig_path="${PATH}" + + # Build phase. 
+ echo "******************** BUILDING PACKAGE ********************" + for python_version in ${PYTHON_VERSIONS}; do + python_dir="/opt/python/${python_version}" + if ! [ -x "${python_dir}/bin/python" ]; then + echo "ERROR: Could not find python: ${python_dir} (skipping)" + continue + fi + export PATH="${python_dir}/bin:${orig_path}" + echo ":::: Python version $(python --version)" + clean_wheels "shortfin" "${python_version}" + build_shortfin + run_audit_wheel "shortfin" "${python_version}" + done +} + +function build_shortfin() { + export SHORTFIN_ENABLE_TRACING=ON + python -m pip wheel --disable-pip-version-check -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shortfin" +} + +function run_audit_wheel() { + local wheel_basename="$1" + local python_version="$2" + # Force wildcard expansion here + generic_wheel="$(echo "${OUTPUT_DIR}/${wheel_basename}-"*"-${python_version}-linux_${ARCH}.whl")" + ls "${generic_wheel}" + echo ":::: Auditwheel ${generic_wheel}" + auditwheel repair -w "${OUTPUT_DIR}" "${generic_wheel}" + rm -v "${generic_wheel}" +} + +function clean_wheels() { + local wheel_basename="$1" + local python_version="$2" + echo ":::: Clean wheels ${wheel_basename} ${python_version}" + rm -f -v "${OUTPUT_DIR}/${wheel_basename}-"*"-${python_version}-"*".whl" +} + +# Trampoline to the docker container if running on the host. +if [ -z "${__MANYLINUX_BUILD_WHEELS_IN_DOCKER-}" ]; then + run_on_host "$@" +else + run_in_docker "$@" +fi diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake new file mode 100644 index 000000000..aaa97a6c1 --- /dev/null +++ b/shortfin/build_tools/cmake/shortfin_library.cmake @@ -0,0 +1,212 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(SHORTFIN_DEFAULT_COPTS + # General clang and GCC options application to C and C++. + $<$: + -Wall + -Werror + > + + # General MSVC options applicable to C and C++. + $<$: + > +) + +function(shortfin_public_library) + cmake_parse_arguments( + _RULE + "" + "NAME;LINUX_LD_SCRIPT" + "SRCS;COMPONENTS;USAGE_DEPS" + ${ARGN} + ) + + # Get usage requirements from requested USAGE_DEPS and forward them from the + # public library. This is because the contents of the libraries stop at the + # public library vs propagating to callers. So we must manually control + # the combined usage requirements of the aggregate library. + set(_usage_include_directories) + set(_usage_compile_definitions) + foreach(_usage_dep _shortfin_defs ${_RULE_USAGE_DEPS}) + get_target_property(_value ${_usage_dep} INTERFACE_INCLUDE_DIRECTORIES) + if(_value) + list(APPEND _usage_include_directories ${_value}) + endif() + get_target_property(_value ${_usage_dep} INTERFACE_COMPILE_DEFINITIONS) + if(_value) + list(APPEND _usage_compile_definitions ${_value}) + endif() + endforeach() + + # Useful for debugging include/definition issues. + # message(STATUS "Public library ${_RULE_NAME}: Includes = ${_usage_include_directories}") + # message(STATUS "Public library ${_RULE_NAME}: Definitions = ${_usage_compile_definitions}") + + if(SHORTFIN_BUILD_STATIC) + # Static library. 
+ shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS}) + add_library("${_RULE_NAME}-static" STATIC ${_RULE_SRCS}) + target_compile_definitions("${_RULE_NAME}-static" INTERFACE + _SHORTFIN_USING_DYLIB + ${_usage_compile_definitions} + ) + target_include_directories("${_RULE_NAME}-static" INTERFACE ${_usage_include_directories}) + target_link_libraries( + "${_RULE_NAME}-static" + PRIVATE ${_STATIC_COMPONENTS} + ) + endif() + + if(SHORTFIN_BUILD_DYNAMIC) + # Dylib library. + shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS}) + add_library("${_RULE_NAME}" SHARED ${_RULE_SRCS}) + target_compile_definitions("${_RULE_NAME}" INTERFACE + _SHORTFIN_USING_DYLIB + ${_usage_compile_definitions} + ) + target_include_directories("${_RULE_NAME}" INTERFACE ${_usage_include_directories}) + if(_RULE_LINUX_LD_SCRIPT) + target_link_options("${_RULE_NAME}" PRIVATE + "$<$:-Wl,--version-script=${_RULE_LINUX_LD_SCRIPT}>" + ) + endif() + target_link_libraries( + "${_RULE_NAME}" + PRIVATE ${_DYLIB_COMPONENTS} + ) + set_target_properties("${_RULE_NAME}" PROPERTIES + VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH} + SOVERSION ${SOVERSION} + ) + endif() +endfunction() + +function(shortfin_cc_component) + cmake_parse_arguments( + _RULE + "" + "NAME" + "HDRS;SRCS;DEFINES;DEPS;COMPONENTS" + ${ARGN} + ) + if(SHORTFIN_BUILD_STATIC) + # Static object library. + set(_STATIC_OBJECTS_NAME "${_RULE_NAME}.objects") + shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS}) + add_library(${_STATIC_OBJECTS_NAME} OBJECT) + target_sources(${_STATIC_OBJECTS_NAME} + PRIVATE + ${_RULE_SRCS} + ${_RULE_HDRS} + ) + target_compile_options(${_STATIC_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS}) + target_link_libraries(${_STATIC_OBJECTS_NAME} + PUBLIC + _shortfin_defs + ${_STATIC_COMPONENTS} + ${_RULE_DEPS} + ) + target_compile_definitions(${_STATIC_OBJECTS_NAME} PUBLIC ${_RULE_DEFINES}) + endif() + + if(SHORTFIN_BUILD_DYNAMIC) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + set(_DYLIB_OBJECTS_NAME "${_RULE_NAME}.dylib.objects") + shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS}) + # Dylib object library. + add_library(${_DYLIB_OBJECTS_NAME} OBJECT) + target_sources(${_DYLIB_OBJECTS_NAME} + PRIVATE + ${_RULE_SRCS} + ${_RULE_HDRS} + ) + target_compile_options(${_DYLIB_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS}) + target_link_libraries(${_DYLIB_OBJECTS_NAME} + PUBLIC + _shortfin_defs + ${_DYLIB_COMPONENTS} + ${_RULE_DEPS} + ) + set_target_properties( + ${_DYLIB_OBJECTS_NAME} PROPERTIES + CXX_VISIBILITY_PRESET hidden + C_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN ON + ) + target_compile_definitions(${_DYLIB_OBJECTS_NAME} + PRIVATE + _SHORTFIN_BUILDING_DYLIB + # Mate up with spdlog export settings since this is part of the + # library that is exporting these symbols. 
+ SPDLOG_SHARED_LIB + spdlog_EXPORTS + ) + target_compile_definitions(${_DYLIB_OBJECTS_NAME} PUBLIC ${_RULE_DEFINES}) + endif() +endfunction() + +function(shortfin_components_to_static_libs out_static_libs) + set(_LIBS ${ARGN}) + list(TRANSFORM _LIBS APPEND ".objects") + set(${out_static_libs} ${_LIBS} PARENT_SCOPE) +endfunction() + +function(shortfin_components_to_dynamic_libs out_dynamic_libs) + set(_LIBS ${ARGN}) + list(TRANSFORM _LIBS APPEND ".dylib.objects") + set(${out_dynamic_libs} "${_LIBS}" PARENT_SCOPE) +endfunction() + +function(shortfin_gtest_test) + cmake_parse_arguments( + _RULE + "" + "NAME" + "SRCS;DEPS" + ${ARGN} + ) + + if(NOT SHORTFIN_BUILD_TESTS) + return() + endif() + + add_executable(${_RULE_NAME} ${_RULE_SRCS}) + target_link_libraries(${_RULE_NAME} PRIVATE + ${_RULE_DEPS} + ${SHORTFIN_LINK_LIBRARY_NAME} + GTest::gmock + GTest::gtest_main + ) + gtest_discover_tests(${_RULE_NAME}) +endfunction() + + +# Make changes to the global compile flags and properties before including +# bundled deps. This configures various options aimed at making the bundled +# dependencies private. +# The effect can be undone with shortfin_pop_bundled_lib_options(). +# After this call, additional changes can be made to CMAKE_CXX_FLAGS as desired. +macro(shortfin_push_bundled_lib_options) + set(SHORTFIN_ORIG_CXX_VISIBILITY_PRESET "${CMAKE_CXX_VISIBILITY_PRESET}") + set(SHORTFIN_ORIG_C_VISIBILITY_PRESET "${CMAKE_C_VISIBILITY_PRESET}") + set(SHORTFIN_ORIG_VISIBILITY_INLINES_HIDDEN "${CMAKE_VISIBILITY_INLINES_HIDDEN}") + set(SHORTFIN_ORIG_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + + # Callers get a known state for visibility controls and can make changes from + # there. + set(CMAKE_C_VISIBILITY_PRESET "default") + set(CMAKE_CXX_VISIBILITY_PRESET "default") + set(CMAKE_VISIBILITY_INLINES_HIDDEN ON) +endmacro() + +macro(shortfin_pop_bundled_lib_options) + set(CMAKE_CXX_VISIBILITY_PRESET ${SHORTFIN_ORIG_CXX_VISIBILITY_PRESET}) + set(CMAKE_C_VISIBILITY_PRESET ${SHORTFIN_ORIG_C_VISIBILITY_PRESET}) + set(CMAKE_VISIBILITY_INLINES_HIDDEN ${SHORTFIN_ORIG_VISIBILITY_INLINES_HIDDEN}) + set(CMAKE_CXX_FLAGS "${SHORTFIN_ORIG_CXX_FLAGS}") +endmacro() diff --git a/shortfin/build_tools/python_lsan_suppressions.txt b/shortfin/build_tools/python_lsan_suppressions.txt new file mode 100644 index 000000000..f3ac58064 --- /dev/null +++ b/shortfin/build_tools/python_lsan_suppressions.txt @@ -0,0 +1,11 @@ +leak:PyUnicode_New +leak:_PyUnicodeWriter_PrepareInternal +leak:_PyUnicodeWriter_Finish +leak:numpy +leak:_mlir_libs +leak:google/_upb +leak:import_find_and_load +leak:pyo3::pyclass::create_type_object +leak:ufunc +leak:pydantic_core +leak:sentencepiece diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py new file mode 100755 index 000000000..8eacca274 --- /dev/null +++ b/shortfin/dev_me.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# dev_me.py +# +# This is an opinionated development environment setup procedure aimed at +# making core contributors on the same golden path. It is not the only way +# to develop this project. 
+# +# First time build usage: +# rm -Rf build # Start with a fresh build dir +# python dev_me.py [--cmake=/path/to/cmake] [--clang=/path/to/clang] \ +# [--iree=/path/to/iree] [--asan] [--build-type=Debug] \ +# [--no-tracing] +# +# Subsequent build: +# ./dev_me.py +# +# This will perform an editable install into the used python with both +# default and tracing packages installed. After the initial build, ninja +# can be invoked directly under build/cmake/default or build/cmake/tracy. +# This can be done automatically just by running dev_me.py in a tree with +# an existing build directory. +# +# By default, if there is an iree source dir adjacent to this parent repository, +# that will be used (so you can just directly edit IREE runtime code and build. +# Otherwise, the shortfin build will download a pinned IREE source tree. + +import argparse +import importlib +import os +from pathlib import Path +import re +import subprocess +import shutil +import sys +import sysconfig + +try: + from packaging.version import Version +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"'packaging' package not installed and required: Install with:\n" + f" {sys.executable} -m pip install packaging" + ) + + +CMAKE_REQUIRED_VERSION = Version("3.29") +PYTHON_REQUIRED_VERSION = Version("3.12") +CLANG_REQUIRED_VERSION = Version("16") +SETUPTOOLS_REQUIRED_VERSION = Version("61.0") + + +class EnvInfo: + def __init__(self, args): + self.this_dir = Path(__file__).resolve().parent + self.python_exe = sys.executable + self.python_version = Version(".".join(str(v) for v in sys.version_info[1:2])) + self.debug = bool(sysconfig.get_config_var("Py_DEBUG")) + self.asan = "-fsanitize=address" in sysconfig.get_config_var("PY_LDFLAGS") + self.gil_disabled = bool(sysconfig.get_config_var("Py_GIL_DISABLED")) + self.cmake_exe, self.cmake_version = self.find_cmake(args) + self.ninja_exe = shutil.which("ninja") + self.clang_exe, self.clang_version = self.find_clang(args) + self.iree_dir = self.find_iree(args) + self.setuptools_version = self.find_package_version("setuptools") + self.wheel_version = self.find_package_version("wheel") + + self.configured_dirs = [] + self.add_configured(self.this_dir / "build" / "cmake" / "default") + self.add_configured(self.this_dir / "build" / "cmake" / "tracy") + + def add_configured(self, path: Path): + probe = path / "CMakeCache.txt" + if probe.resolve().exists(): + self.configured_dirs.append(path) + + def find_cmake(self, args): + paths = [] + if args.cmake: + paths.append(str(args.cmake)) + else: + default_cmake = shutil.which("cmake") + if default_cmake: + paths.append(default_cmake) + for cmake_path in paths: + try: + cmake_output = subprocess.check_output( + [cmake_path, "--version"] + ).decode() + except: + continue + if m := re.search("cmake version (.+)", cmake_output): + return cmake_path, Version(m.group(1)) + return None, None + + def find_clang(self, args): + if args.clang: + clang_exe = args.clang + else: + clang_exe = shutil.which("clang") + if not clang_exe: + return None, None + try: + clang_output = subprocess.check_output([clang_exe, "--version"]).decode() + except: + return None, None + if m := re.search(r"clang version ([0-9\.]+)", clang_output): + return clang_exe, Version(m.group(1)) + return None, None + + def find_iree(self, args): + iree_dir = args.iree + if not iree_dir: + # See if a sibling iree directory exists. 
+ iree_dir = self.this_dir.parent.parent / "iree" + if (iree_dir / "CMakeLists.txt").exists(): + return str(iree_dir) + if not iree_dir.exists(): + print(f"ERROR: --iree={iree_dir} directory does not exist") + sys.exit(1) + return str(iree_dir) + + def find_package_version(self, package_name: str) -> Version | None: + try: + m = importlib.import_module(package_name) + except ModuleNotFoundError: + return None + return Version(m.__version__) + + def check_prereqs(self, args): + if self.cmake_version is None or self.cmake_version < CMAKE_REQUIRED_VERSION: + print( + f"ERROR: cmake not found or of an insufficient version: {self.cmake_exe}" + ) + print(f" Required: {CMAKE_REQUIRED_VERSION}, Found: {self.cmake_version}") + print(f" Configure explicitly with --cmake=") + sys.exit(1) + if self.python_version < PYTHON_REQUIRED_VERSION: + print(f"ERROR: python version too old: {self.python_exe}") + print( + f" Required: {PYTHON_REQUIRED_VERSION}, Found: {self.python_version}" + ) + sys.exit(1) + if self.clang_exe and self.clang_version < CLANG_REQUIRED_VERSION: + print(f"WARNING: clang version too old: {self.clang_exe}") + print(f" REQUIRED: {CLANG_REQUIRED_VERSION}, Found {self.clang_version}") + elif not self.clang_exe: + print(f"WARNING: Building the project with clang is highly recommended") + print(f" (pass --clang= to select clang)") + + if args.asan and not self.asan: + print( + f"ERROR: An ASAN build was requested but python was not built with ASAN support" + ) + sys.exit(1) + + if ( + self.setuptools_version is None + or self.setuptools_version < SETUPTOOLS_REQUIRED_VERSION + ): + print( + f"ERROR: 'setuptools' packaging is not installed or too old. " + f"Found {self.setuptools_version}, Need {SETUPTOOLS_REQUIRED_VERSION}" + ) + sys.exit(1) + if self.wheel_version is None: + print(f"'wheel' package is not installed") + sys.exit(1) + + def __repr__(self): + report = [ + f"python: {self.python_exe}", + f"debug: {self.debug}", + f"asan: {self.asan}", + f"gil_disabled: {self.gil_disabled}", + f"cmake: {self.cmake_exe} ({self.cmake_version})", + f"ninja: {self.ninja_exe}", + f"clang: {self.clang_exe} ({self.clang_version})", + f"iree: {self.iree_dir}", + f"setuptools: {self.setuptools_version}", + f"wheel: {self.wheel_version}", + ] + return "\n".join(report) + + +def main(argv: list[str]): + parser = argparse.ArgumentParser( + prog="shortfin dev", description="Shortfin dev setup helper" + ) + parser.add_argument("--cmake", type=Path, help="CMake path") + parser.add_argument("--clang", type=Path, help="Clang path") + parser.add_argument("--iree", type=Path, help="Path to IREE source checkout") + parser.add_argument("--asan", action="store_true", help="Build with ASAN support") + parser.add_argument( + "--no-tracing", action="store_true", help="Disable IREE tracing build" + ) + parser.add_argument( + "--build-type", default="Debug", help="CMake build type (default Debug)" + ) + args = parser.parse_args(argv) + env_info = EnvInfo(args) + + if env_info.configured_dirs: + print("First time configure...") + build_mode(env_info) + else: + configure_mode(env_info, args) + + +def configure_mode(env_info: EnvInfo, args): + print("Environment info:") + print(env_info) + env_info.check_prereqs(args) + + env_vars = { + "SHORTFIN_DEV_MODE": "ON", + "SHORTFIN_CMAKE_BUILD_TYPE": args.build_type, + "SHORTFIN_ENABLE_ASAN": "ON" if args.asan else "OFF", + "SHORTFIN_CMAKE": env_info.cmake_exe, + } + if env_info.iree_dir: + env_vars["SHORTFIN_IREE_SOURCE_DIR"] = env_info.iree_dir + if env_info.clang_exe: + 
env_vars["CC"] = env_info.clang_exe + env_vars["CXX"] = f"{env_info.clang_exe}++" + env_vars["CMAKE_LINKER_TYPE"] = "LLD" + env_vars["SHORTFIN_ENABLE_TRACING"] = "OFF" if args.no_tracing else "ON" + + print("Executing setup:") + setup_args = [ + env_info.python_exe, + "-m", + "pip", + "install", + "--no-build-isolation", + "-v", + "-e", + str(env_info.this_dir), + ] + print(f"{' '.join('='.join(str(kv)) for kv in env_vars.items())} \\") + print(f" {' '.join(setup_args)}") + actual_env_vars = dict(os.environ) + actual_env_vars.update(env_vars) + subprocess.check_call(setup_args, cwd=env_info.this_dir, env=actual_env_vars) + print("You are now DEV'd!") + + +def build_mode(env_info: EnvInfo): + print("Building") + for build_dir in env_info.configured_dirs: + subprocess.check_call([env_info.cmake_exe, "--build", str(build_dir)]) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/shortfin/docs/README.md b/shortfin/docs/README.md new file mode 100644 index 000000000..3d993628a --- /dev/null +++ b/shortfin/docs/README.md @@ -0,0 +1,20 @@ +# Python API Docs + +Documentation for the Python API is build with Sphinx under this directory. + +## Building docs + +The Python modules will be automatically imported if installed or if the build +is located at `../build`, relative to this file. + +### Install dependencies + +```shell +python3 -m pip install -r requirements.txt +``` + +### Build the docs + +```shell +sphinx-build -b html . _build +``` diff --git a/shortfin/docs/conf.py b/shortfin/docs/conf.py new file mode 100644 index 000000000..62a49b026 --- /dev/null +++ b/shortfin/docs/conf.py @@ -0,0 +1,43 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import os +import sys + +try: + import _shortfin_default +except ImportError: + sys.path.insert(0, os.path.abspath("../build/python/")) + import _shortfin_default + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "shortfin" +copyright = "2024, Advanced Micro Devices, Inc" +author = "shortfin Authors" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", +] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] diff --git a/shortfin/docs/index.rst b/shortfin/docs/index.rst new file mode 100644 index 000000000..3176543ea --- /dev/null +++ b/shortfin/docs/index.rst @@ -0,0 +1,27 @@ +.. Copyright 2024 Advanced Micro Devices, Inc. + +.. shortfin documentation master file, created by + sphinx-quickstart on Fri Sep 6 16:31:45 2024. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. 
+ +Welcome to shortfin's documentation! +======================================= + +.. toctree:: + :maxdepth: 2 + :caption: Contents + +.. toctree:: + :maxdepth: 2 + :caption: Reference + + reference + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/shortfin/docs/reference.rst b/shortfin/docs/reference.rst new file mode 100644 index 000000000..8fc96be22 --- /dev/null +++ b/shortfin/docs/reference.rst @@ -0,0 +1,64 @@ +.. Copyright 2024 Advanced Micro Devices, Inc. + +.. py:module:: _shortfin_default.lib + +.. _reference: + +API Reference +============= + +Array +-------------- +.. automodule:: _shortfin_default.lib.array +.. autoclass:: DType +.. autoclass:: storage + :members: +.. autoclass:: base_array +.. autoclass:: device_array + :members: +.. autoclass:: RandomGenerator + +.. autofunction:: _shortfin_default.lib.array.fill_randn +.. autofunction:: _shortfin_default.lib.array.argmax + +Local +-------------- + +.. automodule:: _shortfin_default.lib.local + +.. autoclass:: SystemBuilder +.. autoclass:: System +.. autoclass:: Node +.. autoclass:: Device +.. autoclass:: DeviceAffinity +.. autoclass:: Program +.. autoclass:: ProgramFunction + :members: +.. autoclass:: ProgramModule +.. autoclass:: ProgramInvocation +.. autoclass:: Fiber +.. autoclass:: ScopedDevice +.. autoclass:: Worker +.. autoclass:: Process +.. autoclass:: CompletionEvent +.. autoclass:: Message +.. autoclass:: Queue +.. autoclass:: QueueWriter +.. autoclass:: QueueReader +.. autoclass:: Future +.. autoclass:: VoidFuture +.. autoclass:: MessageFuture + + +AMD GPU +^^^^^^^ +.. automodule:: _shortfin_default.lib.local.amdgpu +.. autoclass:: SystemBuilder + :members: +.. autoclass:: AMDGPUDevice + +Host +^^^^^^^ +.. automodule:: _shortfin_default.lib.local.host +.. autoclass:: CPUSystemBuilder +.. 
autoclass:: HostCPUDevice diff --git a/shortfin/docs/requirements.txt b/shortfin/docs/requirements.txt new file mode 100644 index 000000000..1aef75db6 --- /dev/null +++ b/shortfin/docs/requirements.txt @@ -0,0 +1,2 @@ +sphinx==7.4.7 +sphinx_rtd_theme==2.0.0 diff --git a/libshortfin/examples/python/async/basic_asyncio.py b/shortfin/examples/python/async/basic_asyncio.py similarity index 100% rename from libshortfin/examples/python/async/basic_asyncio.py rename to shortfin/examples/python/async/basic_asyncio.py diff --git a/libshortfin/examples/python/async/device_sync.py b/shortfin/examples/python/async/device_sync.py similarity index 86% rename from libshortfin/examples/python/async/device_sync.py rename to shortfin/examples/python/async/device_sync.py index b88049ae0..2c1a8248c 100644 --- a/libshortfin/examples/python/async/device_sync.py +++ b/shortfin/examples/python/async/device_sync.py @@ -14,7 +14,7 @@ class MyProcess(sf.Process): async def run(self): - device = self.scope.device(0) + device = self.fiber.device(0) ary1 = snp.device_array(device, [32, 1, 4], snp.int32) ary1.storage.fill(array.array("i", [0])) print(f"[pid:{self.pid}] ARY1:", ary1) @@ -27,11 +27,11 @@ async def run(self): async def main(): worker = lsys.create_worker("main") - scope = lsys.create_scope(worker) + fiber = lsys.create_fiber(worker) print("+++ Launching process") await asyncio.gather( - MyProcess(scope=scope).launch(), - MyProcess(scope=scope).launch(), + MyProcess(fiber=fiber).launch(), + MyProcess(fiber=fiber).launch(), ) print("--- Process terminated") diff --git a/libshortfin/examples/python/async/process.py b/shortfin/examples/python/async/process.py similarity index 82% rename from libshortfin/examples/python/async/process.py rename to shortfin/examples/python/async/process.py index dcf539516..96db63052 100644 --- a/libshortfin/examples/python/async/process.py +++ b/shortfin/examples/python/async/process.py @@ -32,7 +32,7 @@ async def run(self): processes = [] if self.arg < 10: await asyncio.sleep(0.1) - processes.append(MyProcess(self.arg + 1, scope=self.scope).launch()) + processes.append(MyProcess(self.arg + 1, fiber=self.fiber).launch()) await asyncio.gather(*processes) print(f"[pid:{self.pid}] Goodbye async:", self.arg, self) tick_total() @@ -41,14 +41,14 @@ async def run(self): async def main(): def create_worker(i): worker = lsys.create_worker(f"main-{i}") - return lsys.create_scope(worker) + return lsys.create_fiber(worker) workers = [create_worker(i) for i in range(3)] processes = [] for i in range(10): - processes.append(MyProcess(i, scope=workers[i % len(workers)]).launch()) - processes.append(MyProcess(i * 100, scope=workers[i % len(workers)]).launch()) - processes.append(MyProcess(i * 1000, scope=workers[i % len(workers)]).launch()) + processes.append(MyProcess(i, fiber=workers[i % len(workers)]).launch()) + processes.append(MyProcess(i * 100, fiber=workers[i % len(workers)]).launch()) + processes.append(MyProcess(i * 1000, fiber=workers[i % len(workers)]).launch()) await asyncio.sleep(0.1) print("<
>") diff --git a/libshortfin/examples/python/async/queue.py b/shortfin/examples/python/async/queue.py similarity index 85% rename from libshortfin/examples/python/async/queue.py rename to shortfin/examples/python/async/queue.py index dabc2c7f0..90f9dc592 100644 --- a/libshortfin/examples/python/async/queue.py +++ b/shortfin/examples/python/async/queue.py @@ -55,17 +55,19 @@ async def run(self): async def main(): - queue = lsys.create_queue("infeed") - main_scope = lsys.create_scope() + queue = lsys.create_queue() + main_fiber = lsys.create_fiber() + # TODO: Also test named queues. + # queue = lsys.create_queue("infeed") w1 = lsys.create_worker("w1") - w1_scope = lsys.create_scope(w1) + w1_fiber = lsys.create_fiber(w1) await asyncio.gather( - WriterProcess(queue, scope=main_scope).launch(), + WriterProcess(queue, fiber=main_fiber).launch(), # By having a reader on the main worker and a separate worker, # we test both intra and inter worker future resolution, which # take different paths internally. - ReaderProcess(queue, scope=main_scope).launch(), - ReaderProcess(queue, scope=w1_scope).launch(), + ReaderProcess(queue, fiber=main_fiber).launch(), + ReaderProcess(queue, fiber=w1_fiber).launch(), ) diff --git a/shortfin/examples/python/enumerate_devices.py b/shortfin/examples/python/enumerate_devices.py new file mode 100644 index 000000000..2c1cc8203 --- /dev/null +++ b/shortfin/examples/python/enumerate_devices.py @@ -0,0 +1,32 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +r"""Simple test program that enumerates devices available. + +Run with SystemBuilder keyword args on the command line like:: + + python examples/python/enumerate_devices.py \ + system_type=amdgpu amdgpu_logical_devices_per_physical_device=4 + +""" + +import sys + +import shortfin as sf + + +def main(): + args = [arg.split("=", maxsplit=1) for arg in sys.argv[1:]] + arg_dict = {k: v for k, v in args} + print(f"Creating system with args: {arg_dict}") + builder = sf.SystemBuilder(**arg_dict) + with builder.create_system() as ls: + for device in ls.devices: + print(device) + + +if __name__ == "__main__": + main() diff --git a/libshortfin/examples/python/fastapi/server.py b/shortfin/examples/python/fastapi/server.py similarity index 97% rename from libshortfin/examples/python/fastapi/server.py rename to shortfin/examples/python/fastapi/server.py index 639cf2d00..8807cf4c6 100644 --- a/libshortfin/examples/python/fastapi/server.py +++ b/shortfin/examples/python/fastapi/server.py @@ -67,9 +67,10 @@ async def run(self): ) await asyncio.sleep(0.01) responder.stream_part(None) - except Exception as e: - responder.close_with_error() + except: traceback.print_exc() + finally: + responder.ensure_response() @asynccontextmanager diff --git a/libshortfin/examples/python/mobilenet_server/.gitignore b/shortfin/examples/python/mobilenet_server/.gitignore similarity index 100% rename from libshortfin/examples/python/mobilenet_server/.gitignore rename to shortfin/examples/python/mobilenet_server/.gitignore diff --git a/libshortfin/examples/python/mobilenet_server/build_model.sh b/shortfin/examples/python/mobilenet_server/build_model.sh similarity index 87% rename from libshortfin/examples/python/mobilenet_server/build_model.sh rename to shortfin/examples/python/mobilenet_server/build_model.sh index 6011072c9..26ae2fad2 100755 --- 
a/libshortfin/examples/python/mobilenet_server/build_model.sh +++ b/shortfin/examples/python/mobilenet_server/build_model.sh @@ -39,5 +39,11 @@ echo "Import onnx model" python -m iree.compiler.tools.import_onnx $onnx_upgrade_path -o $mlir_path echo "Compile onnx model" -python -m iree.compiler.tools.scripts.ireec \ - $mlir_path -o "$vmfb_path" --iree-input-type=onnx --iree-hal-target-backends=llvm-cpu +if [ -z "$@" ]; then + compile_flags="--iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host" +else + compile_flags="$@" +fi +echo "Using compile flags: $compile_flags" +python -m iree.compiler.tools.scripts.iree_compile \ + $mlir_path -o "$vmfb_path" --iree-input-type=onnx $compile_flags diff --git a/libshortfin/examples/python/mobilenet_server/inference_system.py b/shortfin/examples/python/mobilenet_server/inference_system.py similarity index 81% rename from libshortfin/examples/python/mobilenet_server/inference_system.py rename to shortfin/examples/python/mobilenet_server/inference_system.py index 068204f49..cc8b6dd1b 100644 --- a/libshortfin/examples/python/mobilenet_server/inference_system.py +++ b/shortfin/examples/python/mobilenet_server/inference_system.py @@ -26,7 +26,7 @@ def __init__(self, program, request_queue, **kwargs): super().__init__(**kwargs) self.main_function = program["module.torch-jit-export"] self.request_reader = request_queue.reader() - self.device = self.scope.device(0) + self.device = self.fiber.device(0) self.device_input = sfnp.device_array( self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32 ) @@ -40,15 +40,17 @@ async def run(self): # just writing to the backing storage is the best we have API # support for. Generally, APIs on storage should be mirrored onto # the array. - self.host_staging.storage.data = request.raw_image_data + # TODO: Easier to use API for writing into the storage + with self.host_staging.storage.map(write=True, discard=True) as m: + m.fill(request.raw_image_data) print("host_staging =", self.host_staging) self.device_input.copy_from(self.host_staging) # Simple call. Note that the await here is merely awaiting the # result being *available* (i.e. that the VM coroutine has # completed) but does not indicate that the result is ready. - (result1,) = await self.main_function(self.device_input) - (result2,) = await self.main_function(self.device_input) + (result1,) = await self.main_function(self.device_input, fiber=self.fiber) + (result2,) = await self.main_function(self.device_input, fiber=self.fiber) # TODO: Implement await on individual results. The accounting is # there but currently we can only await on the device itself. @@ -57,20 +59,20 @@ async def run(self): print("Result 2:", result2) # Explicit invocation object. - # inv = self.main_function.invocation(scope=self.scope) + # inv = self.main_function.invocation(fiber=self.fiber) # inv.add_arg(self.device_input) # results = await inv.invoke() # print("results:", results) # Multiple invocations in parallel. 
# all_results = await asyncio.gather( - # self.main_function(self.device_input, scope=self.scope), - # self.main_function(self.device_input, scope=self.scope), - # self.main_function(self.device_input, scope=self.scope), + # self.main_function(self.device_input, fiber=self.fiber), + # self.main_function(self.device_input, fiber=self.fiber), + # self.main_function(self.device_input, fiber=self.fiber), # ) # print("All results:", all_results) - # output = await self.scope.invoke(self.main_function, self.device_input) + # output = await self.fiber.invoke(self.main_function, self.device_input) # print("OUTPUT:", output) # read_back = self.device_input.for_transfer() # read_back.copy_from(self.device_input) @@ -88,14 +90,14 @@ def __init__(self, lsys: sf.System, home_dir: Path): print(f"Loaded: {self.program_module}") self.processes = [] - async def start_scope(self, scope): + async def start_fiber(self, fiber): # Note that currently, program load is synchronous. But we do it # in a task so we can await it in the future and let program loads # overlap. for _ in range(self.processes_per_worker): - program = sf.Program([self.program_module], scope=scope) + program = sf.Program([self.program_module], devices=fiber.raw_devices) self.processes.append( - InferenceProcess(program, self.request_queue, scope=scope).launch() + InferenceProcess(program, self.request_queue, fiber=fiber).launch() ) async def main(self): @@ -104,14 +106,14 @@ async def main(self): f"System created with {len(devices)} devices:\n " f"{' '.join(repr(d) for d in devices)}" ) - # We create a physical worker and initial scope for each device. + # We create a physical worker and initial fiber for each device. # This isn't a hard requirement and there are advantages to other # topologies. initializers = [] for device in devices: worker = self.lsys.create_worker(f"device-{device.name}") - scope = self.lsys.create_scope(worker, devices=[device]) - initializers.append(self.start_scope(scope)) + fiber = self.lsys.create_fiber(worker, devices=[device]) + initializers.append(self.start_fiber(fiber)) # Run all initializers in parallel. These launch inference processes. print("Waiting for initializers") @@ -142,7 +144,8 @@ def client(): # Done. 
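The per-device pattern above (one worker, one fiber, programs bound to the fiber's devices) can be condensed into a small sketch. This is illustrative only: it assumes `lsys` is an already-created shortfin System and `program_module` is a module that has already been loaded, and it reuses the `module.torch-jit-export` entry point named in the example above::

    import shortfin as sf

    async def start_on_device(lsys, program_module, device):
        # One worker and one fiber per device, as in Main.main() above.
        worker = lsys.create_worker(f"device-{device.name}")
        fiber = lsys.create_fiber(worker, devices=[device])
        # Programs are constructed against devices rather than a scope.
        program = sf.Program([program_module], devices=fiber.raw_devices)
        main_function = program["module.torch-jit-export"]
        # Invocations take the fiber explicitly:
        #   (result,) = await main_function(device_input, fiber=fiber)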
writer.close() - lsys = sf.host.CPUSystemBuilder().create_system() + sf.SystemBuilder.default_system_type = "hostcpu" + lsys = sf.SystemBuilder().create_system() main = Main(lsys, home_dir) lsys.init_worker.call_threadsafe(client) lsys.run(main.main()) diff --git a/libshortfin/examples/python/mobilenet_server/upgrade_onnx.py b/shortfin/examples/python/mobilenet_server/upgrade_onnx.py similarity index 100% rename from libshortfin/examples/python/mobilenet_server/upgrade_onnx.py rename to shortfin/examples/python/mobilenet_server/upgrade_onnx.py diff --git a/shortfin/mypy.ini b/shortfin/mypy.ini deleted file mode 100644 index d10567407..000000000 --- a/shortfin/mypy.ini +++ /dev/null @@ -1,5 +0,0 @@ -[mypy] - -explicit_package_bases = True -mypy_path = $MYPY_CONFIG_FILE_DIR -packages = shortfin.llm diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml index d347ed631..1abb49ef6 100644 --- a/shortfin/pyproject.toml +++ b/shortfin/pyproject.toml @@ -1,12 +1,54 @@ [build-system] -requires = ["setuptools", "wheel"] +requires = [ + "cmake>=3.29", + "setuptools>=61.0", + "wheel", + "ninja", + 'typing-extensions ; python_version == "3.10" ', +] build-backend = "setuptools.build_meta" +[project] +name = "shortfin" +authors = [ + {name = "SHARK Authors"}, +] +description = "SHARK inference library and serving engine" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">= 3.10" + +# Version is set via the `setup.py`. +dynamic = ["version"] + +[project.urls] +Repository = "https://github.com/nod-ai/shark-ai" +Documentation = "https://shortfin.readthedocs.io/en/latest/" + +[project.optional-dependencies] +apps = [ + "transformers", + "dataclasses-json", + "pillow", + "fastapi", + "uvicorn", + "aiohttp", +] + [tool.pytest.ini_options] -addopts = "-ra" +addopts = [ + "-ra", + "--import-mode=importlib", +] testpaths = [ "tests", ] -pythonpath = [ - ".", -] diff --git a/shortfin/python/CMakeLists.txt b/shortfin/python/CMakeLists.txt new file mode 100644 index 000000000..d125416af --- /dev/null +++ b/shortfin/python/CMakeLists.txt @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# shortfin publishes multiple python packages: - _shortfin: Trampoline +# __init__.py which looks at environment variables to load an appropriate native +# library. - _shortfin_default.lib: Native library as a default, uninstrumented +# build. - _shortfin_tracing.lib: Native library with tracing enabled. - +# Others. 
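As a usage sketch of the trampoline package described in the comment above (the variable and module names come from the `_shortfin/__init__.py` shown below; the tracing variant only imports if it was enabled in the build)::

    import os

    # Must be set before the first import of _shortfin / shortfin.
    os.environ["SHORTFIN_PY_RUNTIME"] = "tracy"   # or "default" (the fallback)
    from _shortfin import lib as sfl              # loads _shortfin_tracy.lib here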
+ +# nanobind +FetchContent_Declare( + nanobind + GIT_REPOSITORY https://github.com/wjakob/nanobind.git + GIT_TAG 784efa2a0358a4dc5432c74f5685ee026e20f2b6 # 2.2.0 +) +FetchContent_MakeAvailable(nanobind) + +nanobind_add_module(shortfin_python_extension + NB_STATIC LTO FREE_THREADED + array_binding.cc + array_host_ops.cc + lib_ext.cc +) + +if (SHORTFIN_ENABLE_TRACING) + set_target_properties(shortfin_python_extension + PROPERTIES OUTPUT_NAME "_shortfin_tracy/lib") +else() + set_target_properties(shortfin_python_extension + PROPERTIES OUTPUT_NAME "_shortfin_default/lib") +endif() + +target_link_libraries(shortfin_python_extension + PRIVATE ${SHORTFIN_LINK_LIBRARY_NAME} +) + +function(shortfin_python_stubs build_type) + nanobind_add_stub( + shortfin_python_extension_stub + MODULE _shortfin_${build_type}.lib + OUTPUT _shortfin_${build_type}/lib.pyi + DEPENDS shortfin_python_extension + ) + +endfunction() + +function(shortfin_python_stubs build_variant) + set(output_root "${CMAKE_CURRENT_BINARY_DIR}/_shortfin_${build_variant}") + file(MAKE_DIRECTORY ${output_root}) + nanobind_add_stub( + shortfin_python_extension_stub_lib_${build_variant} + MODULE _shortfin_${build_variant}.lib + OUTPUT ${output_root}/lib/__init__.pyi + DEPENDS shortfin_python_extension + ) + + nanobind_add_stub( + shortfin_python_extension_stub_array_${build_variant} + MODULE _shortfin_${build_variant}.lib.array + OUTPUT ${output_root}/lib/array.pyi + DEPENDS shortfin_python_extension + ) + + nanobind_add_stub( + shortfin_python_extension_stub_local_${build_variant} + MODULE _shortfin_${build_variant}.lib.local + OUTPUT ${output_root}/lib/local.pyi + DEPENDS shortfin_python_extension + ) +endfunction() + +if (SHORTFIN_ENABLE_TRACING) + shortfin_python_stubs(tracy) +else() + shortfin_python_stubs(default) +endif() diff --git a/shortfin/python/_shortfin/__init__.py b/shortfin/python/_shortfin/__init__.py new file mode 100644 index 000000000..9bfa3a497 --- /dev/null +++ b/shortfin/python/_shortfin/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# The proper way to import this package is via: +# from _shortfin import lib as sfl + +from typing import TYPE_CHECKING + +import os +import sys +import warnings + +if TYPE_CHECKING: + from _shortfin_default import lib +else: + variant = os.getenv("SHORTFIN_PY_RUNTIME", "default") + + if variant == "tracy": + try: + from _shortfin_tracy import lib + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Shortfin Tracy runtime requested via SHORTFIN_PY_RUNTIME but it is not enabled in this build" + ) + print("-- Using Tracy runtime (SHORTFIN_PY_RUNTIME=tracy)", file=sys.stderr) + else: + if variant != "default": + warnings.warn( + f"Unknown value for SHORTFIN_PY_RUNTIME env var ({variant}): Using default" + ) + from _shortfin_default import lib diff --git a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py b/shortfin/python/_shortfin/asyncio_bridge.py similarity index 74% rename from libshortfin/bindings/python/_shortfin/asyncio_bridge.py rename to shortfin/python/_shortfin/asyncio_bridge.py index 78aff49a6..28264e9e3 100644 --- a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py +++ b/shortfin/python/_shortfin/asyncio_bridge.py @@ -5,10 +5,19 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio +import inspect from . 
import lib as sfl +# Feature detect some versions where signatures changed. +if "context" in inspect.signature(asyncio.Task).parameters: + # Python > 3.10 + _ASYNCIO_TASK_HAS_CONTEXT = True +else: + _ASYNCIO_TASK_HAS_CONTEXT = False + + class PyWorkerEventLoop(asyncio.AbstractEventLoop): def __init__(self, worker: sfl.local.Worker): self._worker = worker @@ -17,8 +26,15 @@ def get_debug(self): # Requirement of asyncio. return False - def create_task(self, coro): - return asyncio.Task(coro, loop=self) + if _ASYNCIO_TASK_HAS_CONTEXT: + + def create_task(self, coro, *, name=None, context=None): + return asyncio.Task(coro, loop=self, name=name, context=context) + + else: + + def create_task(self, coro, *, name=None): + return asyncio.Task(coro, loop=self, name=name) def create_future(self): return asyncio.Future(loop=self) @@ -48,6 +64,13 @@ def call_later( w.delay_call(deadline, handle._sf_maybe_run) return handle + def call_at(self, when, callback, *args, context=None) -> asyncio.TimerHandle: + w = self._worker + deadline = int(when * 1e9) + handle = _TimerHandle(when, callback, args, self, context) + w.delay_call(deadline, handle._sf_maybe_run) + return handle + def call_exception_handler(self, context) -> None: # TODO: Should route this to the central exception handler. Should # also play with ergonomics of how the errors get reported in @@ -62,7 +85,7 @@ def call_exception_handler(self, context) -> None: def _timer_handle_cancelled(self, handle): # We don't do anything special: just skip it if it comes up. - pass + ... class _Handle(asyncio.Handle): diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc new file mode 100644 index 000000000..08a4071a8 --- /dev/null +++ b/shortfin/python/array_binding.cc @@ -0,0 +1,679 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./lib_ext.h" +#include "./utils.h" +#include "shortfin/array/api.h" +#include "shortfin/support/logging.h" + +using namespace shortfin::array; + +namespace shortfin::python { + +namespace { +static const char DOCSTRING_ARRAY_COPY_FROM[] = + R"(Copy contents from a source array to this array. + +Equivalent to `dest_array.storage.copy_from(source_array.storage)`. +)"; + +static const char DOCSTRING_ARRAY_COPY_TO[] = + R"(Copy contents of this array to a destination array. + +Equivalent to `dest_array.storage.copy_from(source_array.storage)`. +)"; + +static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value. + +Note that `fill` is asynchronous and may not be visible immediately. For immediate +manipulation of host visible arrays, assign to the `items` property or use +`map(discard=True)` to get a mapping object which can be used to directly +update the contents. + +Equivalent to `array.storage.fill(pattern)`. +)"; + +static const char DOCSTRING_ARRAY_ITEMS[] = + R"(Convenience shorthand for map(...).items)"; + +static const char DOCSTRING_ARRAY_MAP[] = + R"(Create a typed mapping of the buffer contents in host memory. + +Support kwargs of: + +| read: Enables read access to the mapped memory. +| write: Enables write access to the mapped memory and will flush upon close + (for non-unified memory systems). +| discard: Indicates that the entire memory map should be treated as if it will + be overwritten. Initial contents will be undefined. Implies `write=True`.
+ +Mapping memory for access from the host requires a compatible buffer that has +been created with host visibility (which includes host buffers). + +The returned mapping object is a context manager that will close/flush on +exit. Alternatively, the `close()` method can be invoked explicitly. + +See also `storage.map()` which functions similarly but does not allow access +to dtype specific functionality. +)"; + +static const char DOCSTRING_ARRAY_VIEW[] = + R"(Create a view of an array. + +Either integer indices or slices can be passed to the view() method to create +an aliased device_array that shares a subset of the storage. Only view() +organizations that result in a row-major, dense array are currently supported. +)"; + +static const char DOCSTRING_MAPPING_FILL[] = + R"(Fill the host mapping with a pattern. + +The pattern can either be an object implementing the buffer protocol or a Python +int/float if the mapping has a dtype. In this case, the numeric value will be +converted to the appropriate typed pattern. Only dtypes supported by the +array.array class are supported in this fashion. + +The pattern must evenly divide the mapping. + +Note that like all methods on a mapping, any changes are immediately visible +(whereas the `fill` method on the array and storage are async operations). +)"; + +static const char DOCSTRING_MAPPING_ITEMS[] = + R"(Access contents as a Python array. + +When reading this attribute, an array.array will be constructed with the +contents of the mapping. This supports a subset of element types (byte aligned +integers, floats and doubles) corresponding to Python types. + +On write, the mapping will be written with arbitrary Python types marshaled +via array.array into its contents. +)"; + +static const char DOCSTRING_STORAGE_COPY_FROM[] = + R"(Copy contents from a source storage to this array. + +This operation executes asynchronously and the effect will only be visible +once the execution fiber has been synced to the point of mutation. +)"; + +static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value. + +Takes as argument any value that can be interpreted as a buffer with the Python +buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly +with the pattern. + +This operation executes asynchronously and the effect will only be visible +once the execution fiber has been synced to the point of mutation. +)"; + +static const char DOCSTRING_STORAGE_MAP[] = + R"(Create a mapping of the buffer contents in host memory. + +Support kwargs of: + +| read: Enables read access to the mapped memory. +| write: Enables write access to the mapped memory and will flush upon close + (for non-unified memory systems). +| discard: Indicates that the entire memory map should be treated as if it will + be overwritten. Initial contents will be undefined. Implies `write=True`. + +Mapping memory for access from the host requires a compatible buffer that has +been created with host visibility (which includes host buffers). + +The returned mapping object is a context manager that will close/flush on +exit. Alternatively, the `close()` method can be invoked explicitly. + +See also `device_array.map()` which functions similarly but allows some +additional dtype specific accessors. 
+)"; + +device_array PyDeviceArrayView(device_array &array, py::args keys) { + size_t rank = array.shape().size(); + Dims c_offsets(rank, 0); + Dims c_sizes(array.shape_container()); + + if (keys.size() > rank) { + throw std::invalid_argument( + "Cannot create view into device_array greater than its rank"); + } + + for (size_t idx = 0; py::handle key : keys) { + if (py::isinstance(key)) { + // Slice key. + auto slice = py::cast(key); + auto [start, stop, step, length] = slice.compute(c_sizes[idx]); + if (step != 1) { + throw std::logic_error("view does not support strided slices"); + } + c_offsets[idx] = start; + c_sizes[idx] = length; + } else if (py::isinstance(key)) { + // Integer key. + c_offsets[idx] = py::cast(key); + c_sizes[idx] = 1; + } else { + throw std::invalid_argument( + "Args to view must either be integer indices or slices"); + } + idx += 1; + } + + return array.view(c_offsets, c_sizes); +} + +class Refs { + public: + std::unordered_map + element_type_array_type_code_table = + CreateElementTypeArrayTypeCodeTable(); + py::object array_array_ctor = py::module_::import_("array").attr("array"); + + private: + static std::unordered_map + CreateElementTypeArrayTypeCodeTable() { + std::unordered_map table; + // This is really gross. Python's array type codes are pegged to C types, + // which do not have portable sizes. We pick portablish things here and + // carp on mismatch. + auto add_type = [&](DType dt, const char *code, size_t size) { + if (dt.dense_byte_count() != size) { + throw std::invalid_argument( + fmt::format("Illegal native type size for dtype {}, type code {}. " + "Native size mismatch: {} vs {}", + dt.name(), code, dt.dense_byte_count(), size)); + } + table[dt] = py::str(code); + }; + + // See table at https://docs.python.org/3/library/array.html + add_type(DType::int8(), "b", sizeof(char)); + add_type(DType::sint8(), "b", sizeof(char)); + add_type(DType::uint8(), "B", sizeof(unsigned char)); + add_type(DType::int16(), "h", sizeof(signed short)); + add_type(DType::sint16(), "h", sizeof(signed short)); + add_type(DType::uint16(), "H", sizeof(unsigned short)); + add_type(DType::int32(), "i", sizeof(signed int)); + add_type(DType::sint32(), "i", sizeof(signed int)); + add_type(DType::uint32(), "I", sizeof(unsigned int)); + add_type(DType::int64(), "q", sizeof(signed long long)); + add_type(DType::sint64(), "q", sizeof(signed long long)); + add_type(DType::uint64(), "Q", sizeof(unsigned long long)); + add_type(DType::float16(), "H", sizeof(unsigned short)); + add_type(DType::float32(), "f", sizeof(float)); + add_type(DType::float64(), "d", sizeof(double)); + return table; + } +}; + +// Holder for a `mapping` class. Also holds additional metadata. 
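A small usage sketch of the mapping API documented above, assuming `ary` is a host-visible `device_array` holding at least four int32 elements (everything beyond the documented `map`, `fill`, and `items` calls is illustrative)::

    import array

    # Typed mapping taken from the array: scalar fill and items access work.
    with ary.map(write=True, discard=True) as m:
        m.fill(0)                # scalar pattern converted via the array dtype
        m.items = [1, 2, 3, 4]   # marshaled through array.array into the buffer

    # Untyped mapping taken from the storage: the pattern must be a buffer.
    with ary.storage.map(write=True, discard=True) as m:
        m.fill(array.array("i", [0]))

Unlike the asynchronous `fill` on arrays and storage, changes made through a mapping are visible immediately.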
+class PyMapping { + public: + PyMapping() = default; + class mapping &mapping() { return mapping_; } + std::optional dtype() { return dtype_; } + void set_dtype(DType dtype) { dtype_ = dtype; } + + void AssertValid() { + if (!mapping_) { + throw std::logic_error("Mapping has been closed"); + } + } + + void FillFromScalar(Refs *refs, py::handle value) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromScalar"); + if (!dtype()) { + throw std::invalid_argument( + "The `fill` method is only valid for typed mappings but " + "this does not have a dtype"); + } + auto &table = refs->element_type_array_type_code_table; + auto it = table.find(*dtype()); + if (it == table.end()) { + throw std::invalid_argument( + fmt::format("Python array.array type code not known for dtype " + "{}: Cannot access items", + dtype()->name())); + } + py::object pattern = + refs->array_array_ctor(it->second, py::make_tuple(value)); + FillFromBuffer(pattern); + } + + void FillFromBuffer(py::handle buffer) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromBuffer"); + Py_buffer py_view; + int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. + if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { + throw py::python_error(); + } + PyBufferReleaser py_view_releaser(py_view); + if (mapping().size() % py_view.len != 0) { + throw std::invalid_argument( + fmt::format("Cannot fill mapping of size {} with pattern of " + "size {} (it does not evenly divide)", + mapping().size(), py_view.len)); + } + + // Specialize by fundamental sizes for fast-path implementations. + if (py_view.len == 1) { + std::memset(mapping().data(), *static_cast(py_view.buf), + mapping().size()); + } else if (py_view.len == 2) { + uint16_t v = *static_cast(py_view.buf); + uint16_t *begin = + static_cast(static_cast((mapping().data()))); + std::fill(begin, begin + mapping().size() / sizeof(v), v); + } else if (py_view.len == 4) { + uint32_t v = *static_cast(py_view.buf); + uint32_t *begin = + static_cast(static_cast((mapping().data()))); + std::fill(begin, begin + mapping().size() / sizeof(v), v); + } else if (py_view.len == 8) { + uint64_t v = *static_cast(py_view.buf); + uint64_t *begin = + static_cast(static_cast((mapping().data()))); + std::fill(begin, begin + mapping().size() / sizeof(v), v); + } else { + // Slow path. 
+ uint8_t *begin = mapping().data(); + uint8_t *end = begin + mapping().size(); + while (begin < end) { + std::memcpy(begin, py_view.buf, py_view.len); + begin += py_view.len; + } + } + } + + py::object GetItems(py::handle self_obj, Refs *refs) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::GetItems"); + if (!dtype()) { + throw std::invalid_argument( + "The `items` property is only valid for typed mappings but " + "this does not have a dtype"); + } + AssertValid(); + auto &table = refs->element_type_array_type_code_table; + auto it = table.find(*dtype()); + if (it == table.end()) { + throw std::invalid_argument( + fmt::format("Python array.array type code not known for dtype " + "{}: Cannot access items", + dtype()->name())); + } + py::object py_bytes = py::steal(PyBytes_FromObject(self_obj.ptr())); + py::object items = refs->array_array_ctor(it->second, py_bytes); + return items; + } + + void SetItems(Refs *refs, py::handle initializer) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::SetItems"); + if (!dtype()) { + throw std::invalid_argument( + "The `items` property is only valid for typed mappings but " + "this does not have a dtype"); + } + AssertValid(); + auto &table = refs->element_type_array_type_code_table; + auto it = table.find(*dtype()); + if (it == table.end()) { + throw std::invalid_argument( + fmt::format("Python array.array type code not known for dtype " + "{}: Cannot access items", + dtype()->name())); + } + + py::object items = refs->array_array_ctor(it->second, initializer); + PyBufferRequest src_info(items, PyBUF_SIMPLE); + if (src_info.view().len > mapping().size()) { + throw std::invalid_argument( + fmt::format("Cannot write {} bytes into buffer of {} bytes", + src_info.view().len, mapping().size())); + } + std::memcpy(mapping().data(), src_info.view().buf, src_info.view().len); + } + + private: + class mapping mapping_; + std::optional dtype_; +}; + +// Does in-place creation of a mapping object and stores a pointer to the +// contained array::mapping C++ object. 
+py::object CreateMappingObject(PyMapping **out_cpp_mapping) { + py::object py_mapping = py::inst_alloc(py::type()); + PyMapping *cpp_mapping = py::inst_ptr(py_mapping); + new (cpp_mapping) mapping(); + py::inst_mark_ready(py_mapping); + *out_cpp_mapping = cpp_mapping; + return py_mapping; +} + +} // namespace + +void BindArray(py::module_ &m) { + auto refs = std::make_shared(); + + py::class_(m, "DType") + .def_prop_ro("name", &DType::name) + .def_prop_ro("is_boolean", &DType::is_boolean) + .def_prop_ro("is_integer", &DType::is_integer) + .def_prop_ro("is_float", &DType::is_float) + .def_prop_ro("is_complex", &DType::is_complex) + .def_prop_ro("bit_count", &DType::bit_count) + .def_prop_ro("is_byte_aligned", &DType::is_byte_aligned) + .def_prop_ro("dense_byte_count", &DType::dense_byte_count) + .def("is_integer_bitwidth", &DType::is_integer_bitwidth) + .def("compute_dense_nd_size", &DType::compute_dense_nd_size) + .def(py::self == py::self) + .def("__repr__", &DType::name); + +#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident(); +#include "shortfin/array/dtypes.inl" +#undef SHORTFIN_DTYPE_HANDLE + + // storage + py::class_(m, "storage") + .def("__sfinv_marshal__", + [](device_array *self, py::capsule inv_capsule, int barrier) { + auto *inv = + static_cast(inv_capsule.data()); + static_cast(self) + ->AddAsInvocationArgument( + inv, static_cast(barrier)); + }) + .def_static( + "allocate_host", + [](local::ScopedDevice &device, iree_device_size_t allocation_size) { + return storage::allocate_host(device, allocation_size); + }, + py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) + .def_static( + "allocate_device", + [](local::ScopedDevice &device, iree_device_size_t allocation_size) { + return storage::allocate_device(device, allocation_size); + }, + py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>()) + .def( + "fill", + [](storage &self, py::handle buffer) { + Py_buffer py_view; + int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. 
+ if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { + throw py::python_error(); + } + PyBufferReleaser py_view_releaser(py_view); + self.fill(py_view.buf, py_view.len); + }, + py::arg("pattern"), DOCSTRING_STORAGE_FILL) + .def( + "copy_from", [](storage &self, storage &src) { self.copy_from(src); }, + py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM) + .def( + "map", + [](storage &self, bool read, bool write, bool discard) { + SHORTFIN_TRACE_SCOPE_NAMED("PyStorage::map"); + int access = 0; + if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; + if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE; + if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD; + if (!access) { + throw std::invalid_argument( + "One of the access flags must be set"); + } + PyMapping *cpp_mapping = nullptr; + py::object py_mapping = CreateMappingObject(&cpp_mapping); + self.map_explicit( + cpp_mapping->mapping(), + static_cast(access)); + return py_mapping; + }, + py::kw_only(), py::arg("read") = false, py::arg("write") = false, + py::arg("discard") = false, DOCSTRING_STORAGE_MAP) + .def(py::self == py::self) + .def("__len__", &storage::byte_length) + .def("__repr__", &storage::to_s); + + // mapping + auto mapping_class = py::class_(m, "mapping"); + mapping_class.def("close", [](PyMapping &self) { self.mapping().reset(); }) + .def_prop_ro("valid", + [](PyMapping &self) -> bool { return self.mapping(); }) + .def("__enter__", [](py::object self_obj) { return self_obj; }) + .def( + "__exit__", + [](PyMapping &self, py::handle exc_type, py::handle exc_value, + py::handle exc_tb) { self.mapping().reset(); }, + py::arg("exc_type").none(), py::arg("exc_value").none(), + py::arg("exc_tb").none()) + .def( + "fill", + [refs](PyMapping &self, py::int_ value) { + self.AssertValid(); + self.FillFromScalar(refs.get(), value); + }, + py::arg("value"), DOCSTRING_MAPPING_FILL) + .def( + "fill", + [refs](PyMapping &self, py::float_ value) { + self.AssertValid(); + self.FillFromScalar(refs.get(), value); + }, + py::arg("value"), DOCSTRING_MAPPING_FILL) + .def( + "fill", + [](PyMapping &self, py::handle buffer) { + self.AssertValid(); + self.FillFromBuffer(buffer); + }, + py::arg("buffer"), DOCSTRING_MAPPING_FILL) + .def_prop_rw( + "items", + [refs](py::handle self_obj) { + PyMapping &self = py::cast(self_obj); + return self.GetItems(self_obj, refs.get()); + }, + [refs](PyMapping &self, py::handle initializer) { + self.SetItems(refs.get(), initializer); + }, + DOCSTRING_MAPPING_ITEMS); + + struct MappingBufferHandler { + int operator()(PyMapping &self, Py_buffer *view, int flags) { + view->buf = self.mapping().data(); + view->len = self.mapping().size(); + view->readonly = !self.mapping().writable(); + view->itemsize = 1; + view->format = (char *)"B"; // Byte + view->ndim = 1; + view->shape = nullptr; + view->strides = nullptr; + view->suboffsets = nullptr; + view->internal = nullptr; + return 0; + } + }; + BindBufferProtocol(mapping_class); + + // base_array and subclasses + py::class_(m, "base_array") + .def_prop_ro("dtype", &base_array::dtype) + .def_prop_ro("shape", &base_array::shape); + + py::class_(m, "device_array") + .def("__init__", [](py::args, py::kwargs) {}) + .def_static( + "__new__", + [](py::handle py_type, class storage storage, + std::span shape, DType dtype) { + return custom_new_keep_alive( + py_type, /*keep_alive=*/storage.fiber(), storage, shape, dtype); + }, + py::arg("cls"), py::arg("storage"), py::arg("shape"), + py::arg("dtype")) + .def_static( + "__new__", + [](py::handle py_type, 
local::ScopedDevice &device, + std::span shape, DType dtype) { + return custom_new_keep_alive( + py_type, /*keep_alive=*/device.fiber(), + device_array::for_device(device, shape, dtype)); + }, + py::arg("cls"), py::arg("device"), py::arg("shape"), py::arg("dtype")) + .def("__sfinv_marshal__", + [](device_array *self, py::capsule inv_capsule, int barrier) { + auto *inv = + static_cast(inv_capsule.data()); + static_cast(self) + ->AddAsInvocationArgument( + inv, static_cast(barrier)); + }) + .def_static( + "for_device", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/device.fiber(), + device_array::for_device(device, shape, dtype)); + }, + py::arg("device"), py::arg("shape"), py::arg("dtype")) + .def_static( + "for_host", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/device.fiber(), + device_array::for_host(device, shape, dtype)); + }, + py::arg("device"), py::arg("shape"), py::arg("dtype")) + .def("for_transfer", + [](device_array &self) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/self.device().fiber(), self.for_transfer()); + }) + .def_prop_ro("device", &device_array::device, + py::rv_policy::reference_internal) + .def_prop_ro("storage", &device_array::storage, + py::rv_policy::reference_internal) + .def( + "fill", + [](py::handle_t self, py::handle buffer) { + self.attr("storage").attr("fill")(buffer); + }, + py::arg("pattern"), DOCSTRING_ARRAY_FILL) + .def("copy_from", &device_array::copy_from, py::arg("source_array"), + DOCSTRING_ARRAY_COPY_FROM) + .def("copy_to", &device_array::copy_to, py::arg("dest_array"), + DOCSTRING_ARRAY_COPY_TO) + .def("view", PyDeviceArrayView, DOCSTRING_ARRAY_VIEW) + .def( + "map", + [](device_array &self, bool read, bool write, bool discard) { + SHORTFIN_TRACE_SCOPE_NAMED("PyArray::map"); + int access = 0; + if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; + if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE; + if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD; + if (!access) { + throw std::invalid_argument( + "One of the access flags must be set"); + } + PyMapping *cpp_mapping = nullptr; + py::object py_mapping = CreateMappingObject(&cpp_mapping); + cpp_mapping->set_dtype(self.dtype()); + self.storage().map_explicit( + cpp_mapping->mapping(), + static_cast(access)); + return py_mapping; + }, + py::kw_only(), py::arg("read") = false, py::arg("write") = false, + py::arg("discard") = false, DOCSTRING_ARRAY_MAP) + .def_prop_rw( + "items", + [refs](device_array &self) { + SHORTFIN_TRACE_SCOPE_NAMED("PyArray::items"); + PyMapping *mapping; + py::object mapping_obj = CreateMappingObject(&mapping); + mapping->set_dtype(self.dtype()); + self.storage().map_explicit( + mapping->mapping(), static_cast( + IREE_HAL_MEMORY_ACCESS_READ)); + return mapping->GetItems(mapping_obj, refs.get()); + }, + [refs](device_array &self, py::handle initializer) { + PyMapping mapping; + mapping.set_dtype(self.dtype()); + self.storage().map_explicit( + mapping.mapping(), static_cast( + IREE_HAL_MEMORY_ACCESS_READ)); + return mapping.SetItems(refs.get(), initializer); + }, + DOCSTRING_ARRAY_ITEMS) + .def_prop_ro( + "__array_interface__", + [refs](device_array &self) { + SHORTFIN_TRACE_SCOPE_NAMED("PyArray::__array_interface__"); + py::dict interface; + interface["version"] = 3; + interface["strides"] = py::none(); + + auto shape = self.shape(); + py::list shapeList; + for (size_t i = 0; i < 
shape.size(); ++i) { + shapeList.append(shape[i]); + } + py::tuple shape_tuple(py::cast(shapeList)); + interface["shape"] = shape_tuple; + + auto &table = refs->element_type_array_type_code_table; + auto it = table.find(self.dtype()); + if (it == table.end()) { + throw std::invalid_argument(fmt::format( + "Python array.array type code not known for dtype " + "{}: Cannot access items", + self.dtype().name())); + } + + auto typeString = py::str(); + if (it->first == DType::float16()) { + typeString = py::str("float16"); + } else { + typeString = py::str(it->second); + } + interface["typestr"] = typeString; + + PyMapping *mapping; + py::object mapping_obj = CreateMappingObject(&mapping); + mapping->set_dtype(self.dtype()); + self.storage().map_explicit( + mapping->mapping(), static_cast( + IREE_HAL_MEMORY_ACCESS_READ)); + auto items = mapping->GetItems(mapping_obj, refs.get()); + + // Obtain pointer to first element in items + Py_buffer buffer; + if (PyObject_GetBuffer(items.ptr(), &buffer, PyBUF_SIMPLE) != 0) { + throw std::runtime_error("Failed to get buffer from items"); + } + void *itemsPtr = buffer.buf; + interface["data"] = + py::make_tuple(reinterpret_cast(itemsPtr), false); + return interface; + }) + .def("__repr__", &device_array::to_s) + .def("__str__", [](device_array &self) -> std::string { + auto contents = self.contents_to_s(); + if (!contents) return "<>"; + return *contents; + }); + + BindArrayHostOps(m); +} + +} // namespace shortfin::python diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc new file mode 100644 index 000000000..3e2a8ebe3 --- /dev/null +++ b/shortfin/python/array_host_ops.cc @@ -0,0 +1,747 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "./lib_ext.h" +#include "./utils.h" +#include "shortfin/array/api.h" +#include "shortfin/support/logging.h" +#include "xtensor/xrandom.hpp" +#include "xtensor/xsort.hpp" +#include "xtl/xhalf_float.hpp" + +using namespace shortfin::array; + +namespace shortfin::python { + +namespace { + +static const char DOCSTRING_ARGMAX[] = + R"(Returns the indices of the maximum values along an axis. + +Implemented for dtypes: float16, float32. + +Args: + input: An input array. + axis: Axis along which to compute the argmax. Defaults to the last axis (note that the + numpy default is into the flattened array, which we do not support). + keepdims: Whether to preserve the reduced axis. If true, this will become a unit + dim. If false, it will be removed. + out: Array to write into. If specified, it must have an expected shape and + int64 dtype. + device_visible: Whether to make the result array visible to devices. Defaults to + False. + +Returns: + A device_array of dtype=int64, allocated on the host and only made device-visible if requested. +)"; + +static const char DOCSTRING_CONVERT[] = + R"(Does an elementwise conversion from one dtype to another. + +The same behavior exists for several conversion ops: + +* `convert` : element-wise conversion like a static cast. +* `round` : element-wise nearest integer to the input, rounding halfway cases + away from zero. +* `ceil` : element-wise smallest integer value not less than the input. +* `floor` : element-wise largest integer value not greater than the input. +* `trunc` : element-wise nearest integer not greater in magnitude than the input.
+ +For nearest-integer conversions (round, ceil, floor, trunc), the input dtype +must be a floating point array, and the output must be a byte-aligned integer +type between 8 and 32 bits. + +Args: + input: An input array of a floating point dtype. + dtype: If given, then this is the explicit output dtype. + out: If given, then the results are written to this array. This implies the + output dtype. + device_visible: Whether to make the result array visible to devices. Defaults to + False. + +Returns: + A device_array of the requested dtype, or the input dtype if not specified. +)"; + +static const char DOCSTRING_FILL_RANDN[] = + R"(Fills an array with numbers sampled from the standard normal distribution. + +Values are sampled with a mean of 0 and standard deviation of 1. + +This operates like torch.randn but only supports in place fills to an existing +array, deriving shape and dtype from the output array. + +Args: + out: Output array to fill. + generator: Uses an explicit generator. If not specified, uses a global + default. +)"; + +static const char DOCSTRING_RANDOM_GENERATOR[] = + R"(Returns an object for generating random numbers. + + Every instance is self-contained and does not share state with others. + + Args: + seed: Optional seed for the generator. Not setting a seed will cause an + implementation defined value to be used, which may in fact be a completely + fixed number. + )"; + +static const char DOCSTRING_TRANSPOSE[] = + R"(Transposes axes of an array according to a permutation vector. + +Args: + input: Array to transpose. + permutation: New sequence of axes. Must have same number of elements as the + rank of input. + out: If given, then the results are written to this array. + device_visible: Whether to make the result array visible to devices. Defaults + to False. +)"; + +#define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): \ + return compute.template operator()<cpp_type>() + +#define SF_UNARY_THUNK_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): \ + compute.template operator()<cpp_type>(); \ + break + +#define SF_MOVEMENT_OP_SWITCH(dtype) \ + if (!dtype.is_byte_aligned()) \ + throw std::invalid_argument( \ + "data movement ops are only defined for byte aligned dtypes"); \ + switch (dtype.dense_byte_count()) { \ + case 1: \ + return compute.template operator()<uint8_t>(); \ + case 2: \ + return compute.template operator()<uint16_t>(); \ + case 4: \ + return compute.template operator()<uint32_t>(); \ + case 8: \ + return compute.template operator()<uint64_t>(); \ + default: \ + throw std::invalid_argument( \ + "data movement ops are only defined for dtypes of size 1, 2, " \ + "4, 8"); \ + } + +struct PyRandomGenerator { + public: + using SeedType = xt::random::default_engine_type::result_type; + PyRandomGenerator(std::optional<SeedType> seed) { + if (seed) SetSeed(*seed); + } + + static PyRandomGenerator &get_default() { + static PyRandomGenerator default_generator(std::nullopt); + return default_generator; + } + + void SetSeed(SeedType seed) { engine().seed(seed); } + + xt::random::default_engine_type &engine() { return engine_; } + + private: + xt::random::default_engine_type engine_; +}; + +// Generic conversion templates, split into a bindable template and functors +// that operate on pre-allocated outputs. +template <typename ConvertFunc> +device_array GenericElementwiseConvert(device_array &input, + std::optional<DType> dtype, + std::optional<device_array> out, + bool device_visible) { + // Argument check and output allocation. + if (!dtype) { + dtype = out ?
out->dtype() : input.dtype(); + } else { + if (out && out->dtype() != dtype) { + throw std::invalid_argument( + "if both dtype and out are specified, they must match"); + } + } + if (!out) { + out.emplace(device_array::for_host(input.device(), input.shape(), *dtype, + device_visible)); + } + + ConvertFunc::Invoke(input, *dtype, *out); + return *out; +} + +// Generic elementwise conversion functor +struct ConvertFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::convert"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(*input_t); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(float16, half_float::half); + SF_STORE_CASE(float32, float); + SF_STORE_CASE(float64, double); + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + SF_STORE_CASE(uint64, uint64_t); + SF_STORE_CASE(int64, int64_t); + default: + throw std::invalid_argument("Invalid output dtype for convert op"); + } + +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + SF_UNARY_THUNK_CASE(float64, double); + SF_UNARY_THUNK_CASE(uint8, uint8_t); + SF_UNARY_THUNK_CASE(int8, int8_t); + SF_UNARY_THUNK_CASE(uint16, uint16_t); + SF_UNARY_THUNK_CASE(int16, int16_t); + SF_UNARY_THUNK_CASE(uint32, uint32_t); + SF_UNARY_THUNK_CASE(int32, uint32_t); + SF_UNARY_THUNK_CASE(uint64, uint64_t); + SF_UNARY_THUNK_CASE(int64, int64_t); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +// Converting round functor. +struct ConvertRoundFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::round"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::round(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +struct ConvertCeilFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::ceil"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::ceil(*input_t); + if (input.dtype() == dtype) { + // Same type output. 
+ auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +struct ConvertFloorFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::floor"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::floor(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +struct ConvertTruncFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::trunc"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::trunc(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. 
+#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +void OptionalArrayCast(py::handle handle, + std::optional &maybe_array) { + if (py::isinstance(handle)) { + maybe_array.emplace(py::cast(handle)); + } +} + +int DTypePromotionRank(DType dtype) { + int rank = 1; + if (dtype.is_boolean()) + rank *= 1000; + else if (dtype.is_integer()) + rank *= 2000; + else if (dtype.is_float()) + rank *= 4000; + else if (dtype.is_complex()) + rank *= 8000; + return rank + dtype.bit_count(); +} + +DType PromoteArithmeticTypes(std::optional lhs_dtype, + std::optional rhs_dtype) { + if (!lhs_dtype && !rhs_dtype) { + throw std::invalid_argument( + "Elementwise operators require at least one argument to be a " + "device_array"); + } + + // One not an array: promote to the array type. + if (!lhs_dtype) + return *rhs_dtype; + else if (!rhs_dtype) + return *lhs_dtype; + + int lhs_rank = DTypePromotionRank(*lhs_dtype); + int rhs_rank = DTypePromotionRank(*rhs_dtype); + DType promoted_dtype = lhs_rank < rhs_rank ? *rhs_dtype : *lhs_dtype; + + // If mismatched signed/unsigned, then need to promote to the next signed + // dtype. + if (promoted_dtype.is_integer()) { + bool lhs_unsigned = iree_all_bits_set( + lhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED); + bool rhs_unsigned = iree_all_bits_set( + rhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED); + if ((lhs_unsigned || rhs_unsigned) && !(lhs_unsigned && rhs_unsigned)) { + // Signed/unsigned mismatch. Promote to next. + switch (promoted_dtype) { + case DType::uint8(): + case DType::int8(): + return DType::int16(); + case DType::uint16(): + case DType::int16(): + return DType::int32(); + case DType::uint32(): + case DType::int32(): + return DType::int64(); + default: + // Jax's type promotion chart says this goes to a weak FP type, but + // we don't implement such a construct and I don't really see how + // that makes sense in a system setting like this, so we just saturate + // to 64bit. + return DType::int64(); + } + } + } + + return promoted_dtype; +} + +// ---------------------------------------------------------------------------// +// Elementwise support +// ---------------------------------------------------------------------------// + +// Python element type scalar conversion functions. 
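For orientation, a hedged sketch of how the conversion and elementwise bindings defined in this file are used from Python, assuming `a` and `b` are float32 host arrays and that the ops are exposed on `shortfin.array` (imported here as `sfnp`)::

    import shortfin.array as sfnp

    c = sfnp.convert(a, dtype=sfnp.float16)   # static-cast style conversion
    r = sfnp.round(a, dtype=sfnp.int32)       # nearest integer, halves away from zero
    s = sfnp.add(a, b)                        # dtypes promote per the rules above
    t = sfnp.multiply(a, 2.0)                 # Python scalars take the array's dtype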
+uint8_t ConvertPyToEltTy(py::handle py_value, uint8_t zero) { + return py::cast(py_value); +} + +int8_t ConvertPyToEltTy(py::handle py_value, int8_t zero) { + return py::cast(py_value); +} + +uint16_t ConvertPyToEltTy(py::handle py_value, uint16_t zero) { + return py::cast(py_value); +} + +int16_t ConvertPyToEltTy(py::handle py_value, int16_t zero) { + return py::cast(py_value); +} + +uint32_t ConvertPyToEltTy(py::handle py_value, uint32_t zero) { + return py::cast(py_value); +} + +int32_t ConvertPyToEltTy(py::handle py_value, int32_t zero) { + return py::cast(py_value); +} + +uint64_t ConvertPyToEltTy(py::handle py_value, uint64_t zero) { + return py::cast(py_value); +} + +int64_t ConvertPyToEltTy(py::handle py_value, int64_t zero) { + return py::cast(py_value); +} + +float ConvertPyToEltTy(py::handle py_value, float zero) { + return py::cast(py_value); +} + +double ConvertPyToEltTy(py::handle py_value, double zero) { + return py::cast(py_value); +} + +half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) { + // Python can't cast directly to half so first go to double. + return static_cast(py::cast(py_value)); +} + +struct AddFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs + rhs; + } +}; + +struct DivideFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs / rhs; + } +}; + +struct MultiplyFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs * rhs; + } +}; + +struct SubtractFunctor { + template + static auto Invoke(Lhs &&lhs, Rhs &&rhs) { + return lhs - rhs; + } +}; + +template +device_array ElementwiseOperation(py::handle lhs, py::handle rhs, + std::optional out, + bool device_visible) { + std::optional lhs_array; + OptionalArrayCast(lhs, lhs_array); + std::optional rhs_array; + OptionalArrayCast(rhs, rhs_array); + auto dtype = PromoteArithmeticTypes( + lhs_array ? std::optional(lhs_array->dtype()) : std::nullopt, + rhs_array ? 
std::optional(rhs_array->dtype()) : std::nullopt); + if (lhs_array && lhs_array->dtype() != dtype) { + auto converted = GenericElementwiseConvert( + *lhs_array, dtype, /*out=*/std::nullopt, + /*device_visible=*/false); + lhs_array.reset(); + lhs_array.emplace(std::move(converted)); + } + if (rhs_array && rhs_array->dtype() != dtype) { + auto converted = GenericElementwiseConvert( + *rhs_array, dtype, /*out=*/std::nullopt, + /*device_visible=*/false); + rhs_array.reset(); + rhs_array.emplace(std::move(converted)); + } + + auto compute = [&]() -> device_array { + auto handle_result = [&]( + D &&device, A &&result) -> device_array { + if (!out) { + out.emplace(device_array::for_host(device, result.shape(), dtype, + device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = result; + return *out; + }; + if (!rhs_array) { + auto lhs_t = lhs_array->map_xtensor(); + xt::xarray rhs_scalar = ConvertPyToEltTy(rhs, EltTy()); + return handle_result(lhs_array->device(), + ElementwiseFunctor::Invoke(*lhs_t, rhs_scalar)); + } else if (!lhs_array) { + xt::xarray lhs_scalar = ConvertPyToEltTy(lhs, EltTy()); + auto rhs_t = rhs_array->map_xtensor(); + return handle_result(rhs_array->device(), + ElementwiseFunctor::Invoke(lhs_scalar, *rhs_t)); + } else { + auto lhs_t = lhs_array->map_xtensor(); + auto rhs_t = rhs_array->map_xtensor(); + return handle_result(lhs_array->device(), + ElementwiseFunctor::Invoke(*lhs_t, *rhs_t)); + } + }; + + switch (dtype) { + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); + SF_UNARY_FUNCTION_CASE(float64, double); + SF_UNARY_FUNCTION_CASE(uint8, uint8_t); + SF_UNARY_FUNCTION_CASE(int8, int8_t); + SF_UNARY_FUNCTION_CASE(uint16, uint16_t); + SF_UNARY_FUNCTION_CASE(int16, int16_t); + SF_UNARY_FUNCTION_CASE(uint32, uint32_t); + SF_UNARY_FUNCTION_CASE(int32, uint32_t); + SF_UNARY_FUNCTION_CASE(uint64, uint64_t); + SF_UNARY_FUNCTION_CASE(int64, int64_t); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for in elementwise op", dtype.name())); + } +} + +} // namespace + +void BindArrayHostOps(py::module_ &m) { + // Simple op definitions. + m.def( + "argmax", + [](device_array &input, int axis, std::optional out, + bool keepdims, bool device_visible) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::argmax"); + if (axis < 0) axis += input.shape().size(); + if (axis < 0 || axis >= input.shape().size()) { + throw std::invalid_argument( + fmt::format("Axis out of range: Must be [0, {}) but got {}", + input.shape().size(), axis)); + } + if (out && (out->dtype() != DType::int64())) { + throw std::invalid_argument("out array must have dtype=int64"); + } + auto compute = [&]() { + auto input_t = input.map_xtensor(); + auto result = xt::argmax(*input_t, axis); + if (!out) { + out.emplace(device_array::for_host(input.device(), result.shape(), + DType::int64(), device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = result; + if (keepdims) { + out->expand_dims(axis); + } + return *out; + }; + + switch (input.dtype()) { + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); + default: + throw std::invalid_argument( + fmt::format("Unsupported dtype({}) for operator argmax", + input.dtype().name())); + } + }, + py::arg("input"), py::arg("axis") = -1, py::arg("out") = py::none(), + py::kw_only(), py::arg("keepdims") = false, + py::arg("device_visible") = false, DOCSTRING_ARGMAX); + + // Random number generation. 
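A short sketch of the random-fill and argmax bindings registered here, assuming `device` is a ScopedDevice from an existing fiber and that the ops are exposed on `shortfin.array` as `sfnp`::

    import shortfin.array as sfnp

    logits = sfnp.device_array.for_host(device, [4, 32], sfnp.float32)
    gen = sfnp.RandomGenerator(seed=42)        # omit seed for the default engine
    sfnp.fill_randn(logits, generator=gen)     # in-place standard normal fill
    best = sfnp.argmax(logits, axis=-1)        # int64 host array of shape [4]
    kept = sfnp.argmax(logits, keepdims=True)  # keeps a unit dim: shape [4, 1]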
+ py::class_(m, "RandomGenerator") + .def(py::init>(), + py::arg("seed") = py::none(), DOCSTRING_RANDOM_GENERATOR); + m.def( + "fill_randn", + [](device_array out, std::optional gen) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::fill_randn"); + if (!gen) gen = &PyRandomGenerator::get_default(); + auto compute = [&]() { + auto result = xt::random::randn(out.shape_container(), /*mean=*/0.0, + /*std_dev=*/1.0, (*gen)->engine()); + auto out_t = out.map_xtensor_w(); + *out_t = result; + }; + + switch (out.dtype()) { + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); + default: + throw std::invalid_argument( + fmt::format("Unsupported dtype({}) for operator randn", + out.dtype().name())); + } + }, + py::arg("out"), py::arg("generator") = py::none(), DOCSTRING_FILL_RANDN); + +// Data-type conversion and rounding. +#define SF_DEF_CONVERT(py_name, target) \ + m.def(py_name, target, py::arg("input"), py::kw_only(), \ + py::arg("dtype") = py::none(), py::arg("out") = py::none(), \ + py::arg("device_visible") = false, DOCSTRING_CONVERT) + SF_DEF_CONVERT("convert", GenericElementwiseConvert); + SF_DEF_CONVERT("ceil", GenericElementwiseConvert); + SF_DEF_CONVERT("floor", GenericElementwiseConvert); + SF_DEF_CONVERT("round", GenericElementwiseConvert); + SF_DEF_CONVERT("trunc", GenericElementwiseConvert); + + // Transpose. + m.def( + "transpose", + [](device_array input, std::vector permutation, + std::optional out, bool device_visible) { + auto compute = [&]() -> device_array { + auto input_t = input.map_xtensor(); + auto permuted_t = + xt::transpose(*input_t, permutation, xt::check_policy::full()); + if (!out) { + out.emplace(device_array::for_host(input.device(), + permuted_t.shape(), + input.dtype(), device_visible)); + } + auto out_t = out->map_xtensor_w(); + *out_t = permuted_t; + return *out; + }; + SF_MOVEMENT_OP_SWITCH(input.dtype()); + }, + py::arg("input"), py::arg("permutation"), py::arg("out") = py::none(), + py::arg("device_visible") = false, DOCSTRING_TRANSPOSE); + +// Elementwise. +#define SF_DEF_ELEMENTWISE(py_name, target) \ + m.def(py_name, target, py::arg("lhs"), py::arg("rhs"), py::kw_only(), \ + py::arg("out") = py::none(), py::arg("device_visible") = false) + SF_DEF_ELEMENTWISE("add", ElementwiseOperation); + SF_DEF_ELEMENTWISE("divide", ElementwiseOperation); + SF_DEF_ELEMENTWISE("multiply", ElementwiseOperation); + SF_DEF_ELEMENTWISE("subtract", ElementwiseOperation); + +} // namespace shortfin::python + +} // namespace shortfin::python diff --git a/libshortfin/bindings/python/lib_ext.cc b/shortfin/python/lib_ext.cc similarity index 52% rename from libshortfin/bindings/python/lib_ext.cc rename to shortfin/python/lib_ext.cc index cc7c21190..c668e6a8b 100644 --- a/libshortfin/bindings/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -10,10 +10,10 @@ #include "shortfin/array/array.h" #include "shortfin/array/storage.h" #include "shortfin/local/async.h" +#include "shortfin/local/fiber.h" #include "shortfin/local/messaging.h" #include "shortfin/local/process.h" #include "shortfin/local/program.h" -#include "shortfin/local/scope.h" #include "shortfin/local/system.h" #if defined(SHORTFIN_HAVE_AMDGPU) #include "shortfin/local/systems/amdgpu.h" @@ -26,6 +26,82 @@ namespace shortfin::python { namespace { +static const char DOCSTRING_SYSTEM_CTOR[] = + R"(Constructs a System based on system type and kwargs. + +System types depend on how the library was compiled and correspond to +SystemBuilder classes. 
This API is a shorthand for creating a SystemBuilder +and calling create_system() on it. +)"; + +static const char DOCSTRING_SYSTEM_BUILDER_CTOR[] = + R"(Constructs a system with environment dependent configuration. + +Most configuration is done by way of key/value arguments. Arguments are meant +to be derived from flags or config files and are expected to be simple strings +or integer values: + + * "system_type": A supported system type, corresponding to a subclass of + SystemBuilder. Typically available values include "hostcpu", "amdgpu", etc. + * Other keywords are passed to the concrete SystemBuilder subclass. + +Resolution for the `system_type` is special and is checked in three steps, with +the first defined winning: + +1. `kwargs["system_type"]` +2. Environment `SHORTFIN_SYSTEM_TYPE` variable (or as controlled by `env_prefix`) +3. `SystemBuilder.default_system_type` as a process-wide default. + +Args: + env_prefix: Controls how options are looked up in the environment. By default, + the prefix is "SHORTFIN_" and upper-cased options are appended to it. Any + option not explicitly specified but in the environment will be used. Pass + None to disable environment lookup. + **kwargs: Key/value arguments for controlling setup of the system. +)"; + +static const char DOCSTRING_HOSTCPU_SYSTEM_BUILDER_CTOR[] = + R"(Constructs a system with CPU based devices. + +Most configuration is done by way of key/value arguments. Arguments are meant +to be derived from flags or config files and are expected to be simple strings +or integer values: + + * "hostcpu_topology_nodes": Takes one of the special values "current" (default) + or "all". If not one of those, this should be a comma-delimited list of + NUMA node ids. Each NUMA node will be modeled as one device queue and will + show up on the system as a device. + * "hostcpu_topology_max_group_count": Maximum number of groups to create per + node. The actual number of groups is derived by a heuristic (which can be + influenced by other options) such that there will not be more groups than + eligible physical cores on the node. + +Args: + env_prefix: Controls how options are looked up in the environment. By default, + the prefix is "SHORTFIN_" and upper-cased options are appended to it. Any + option not explicitly specified but in the environment will be used. Pass + None to disable environment lookup. + **kwargs: Key/value arguments for controlling setup of the system. +)"; + +static const char DOCSTRING_HOSTCPU_SYSTEM_BUILDER_HOSTCPU_ALLOCATOR_SPECS[] = + R"(Allocator specs to apply to HOSTCPU devices configured by this builder. + +This uses syntax like:: + + some_allocator + some_allocator:key=value + some_allocator:key=value,key=value + some_allocator:key=value,key=value;other_allocator:key=value + +Typical values for `some_allocator` include `caching` and `debug`. + +This can be set via a keyword of `amdgpu_allocators`, which will only apply to +HOSTCPU devices or `allocators` which will apply to all contained devices. +Similarly, it is available on a `SHORTFIN_` prefixed env variable if environment +lookup is not disabled. +)"; + static const char DOCSTRING_PROGRAM_FUNCTION_INVOCATION[] = R"(Creates an invocation object targeting the function. 
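Note: the three-step `system_type` resolution described in the docstring above can be exercised as follows; a sketch assuming the `shortfin.SystemBuilder` alias exported later in this diff and a host CPU build:

```python
import os
import shortfin as sf

# 1. An explicit keyword always wins.
builder = sf.SystemBuilder(system_type="hostcpu")

# 2. Otherwise the environment is consulted (default prefix "SHORTFIN_").
os.environ["SHORTFIN_SYSTEM_TYPE"] = "hostcpu"
builder = sf.SystemBuilder()

# 3. Finally, a process-wide default can be installed.
sf.SystemBuilder.default_system_type = "hostcpu"
builder = sf.SystemBuilder()
```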
@@ -58,8 +134,19 @@ class Refs { return lazy_PyWorkerEventLoop_; } + std::optional default_system_type() { + iree::slim_mutex_lock_guard g(mu_); + return default_system_type_; + } + void set_default_system_type(std::optional st) { + iree::slim_mutex_lock_guard g(mu_); + default_system_type_ = st; + } + private: + iree::slim_mutex mu_; py::object lazy_PyWorkerEventLoop_; + std::optional default_system_type_; }; // We need a fair bit of accounting additions in order to make Workers usable @@ -86,6 +173,7 @@ class PyWorkerExtension : public local::Worker::Extension { py::handle loop() { return loop_; } void OnThreadStart() noexcept override { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStart"); // Python threading initialization. // If our own thread, teach Python about it. Not done for donated. if (worker().options().owned_thread) { @@ -100,6 +188,7 @@ class PyWorkerExtension : public local::Worker::Extension { } void OnThreadStop() noexcept override { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStop"); { // Do Python level thread cleanup. py::gil_scoped_acquire g; @@ -142,10 +231,18 @@ class PyWorkerExtension : public local::Worker::Extension { class PyProcess : public local::detail::BaseProcess { public: - PyProcess(std::shared_ptr scope, std::shared_ptr refs) - : BaseProcess(std::move(scope)), refs_(std::move(refs)) {} + PyProcess(std::shared_ptr refs) + : BaseProcess(), refs_(std::move(refs)) {} + using BaseProcess::Initialize; + using BaseProcess::is_initialized; using BaseProcess::Launch; + void AssertInitialized() { + if (!is_initialized()) { + throw std::logic_error("Process.__init__ not called in constructor"); + } + } + void ScheduleOnWorker() override { // This is tricky: We need to retain the object reference across the // thread transition, but on the receiving side, the GIL will not be @@ -154,14 +251,15 @@ class PyProcess : public local::detail::BaseProcess { // the callback. py::handle self_object = py::cast(this, py::rv_policy::none); self_object.inc_ref(); - scope()->worker().CallThreadsafe( + fiber()->worker().CallThreadsafe( std::bind(&PyProcess::RunOnWorker, self_object)); } static void RunOnWorker(py::handle self_handle) { + SHORTFIN_TRACE_SCOPE_NAMED("PyProcess:RunOnWorker"); py::gil_scoped_acquire g; // Steal the reference back from ScheduleOnWorker. Important: this is // very likely the last reference to the process. So self must not be - // touched after self_object goes out of scope. + // touched after self_object goes out of fiber. 
py::object self_object = py::steal(self_handle); PyProcess *self = py::cast(self_handle); // We assume that the run method either returns None (def) or a coroutine @@ -207,9 +305,10 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) { py::cast(py::repr(arg.type())))); } -local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self, - py::args args) { - auto inv = self.CreateInvocation(); +local::ProgramInvocation::Future PyFunctionCall( + local::ProgramFunction &self, py::args args, local::Fiber &fiber, + std::optional isolation) { + auto inv = self.CreateInvocation(fiber.shared_from_this(), isolation); py::capsule inv_capsule(inv.get()); for (py::handle arg : args) { PyAddProgramInvocationArg(inv_capsule, arg); @@ -246,6 +345,7 @@ py::object PyRehydrateRef(local::ProgramInvocation *inv, py::object RunInForeground(std::shared_ptr refs, local::System &self, py::object coro) { + SHORTFIN_TRACE_SCOPE_NAMED("CoroRunInForeground"); bool is_main_thread = refs->threading_current_thread().is(refs->threading_main_thread()); @@ -311,9 +411,35 @@ py::object RunInForeground(std::shared_ptr refs, local::System &self, return result; } +ConfigOptions CreateConfigOptions(std::optional &env_prefix, + py::kwargs &kwargs, bool validate_undef) { + ConfigOptions options(std::move(env_prefix), + validate_undef + ? ConfigOptions::ValidationLevel::UNDEF_ERROR + : ConfigOptions::ValidationLevel::UNDEF_WARN); + for (auto it = kwargs.begin(); it != kwargs.end(); ++it) { + std::string key = py::cast((*it).first); + std::string value = py::cast(py::str((*it).second)); + options.SetOption(std::move(key), std::move(value)); + } + return options; +} + } // namespace NB_MODULE(lib, m) { + // Tragically, debug builds of Python do the right thing and don't immortalize + // many identifiers and such. This makes the last chance leak checking that + // nanobind does somewhat unreliable since the reports it prints may be + // to identifiers that are no longer live (at a time in process shutdown + // where it is expected that everything left just gets dropped on the floor). + // This causes segfaults or ASAN violations in the leak checker on exit in + // certain scenarios where we have spurious "leaks" of global objects. + + py::set_leak_warnings(false); + + logging::InitializeFromEnv(); + py::register_exception_translator( [](const std::exception_ptr &p, void * /*unused*/) { try { @@ -339,6 +465,13 @@ NB_MODULE(lib, m) { }); py::class_(m, "_OpaqueVmRef"); + + // Logging entrypoints. 
+ m.def("log_debug", [](std::string_view sv) { logging::debug("{}", sv); }); + m.def("log_info", [](std::string_view sv) { logging::info("{}", sv); }); + m.def("log_warn", [](std::string_view sv) { logging::warn("{}", sv); }); + m.def("log_error", [](std::string_view sv) { logging::error("{}", sv); }); + auto local_m = m.def_submodule("local"); BindLocal(local_m); BindHostSystem(local_m); @@ -371,16 +504,94 @@ void BindLocal(py::module_ &m) { std::make_unique(worker, interp_state, refs)); }; + py::enum_(m, "ProgramIsolation") + .value("NONE", local::ProgramIsolation::NONE) + .value("PER_FIBER", local::ProgramIsolation::PER_FIBER) + .value("PER_CALL", local::ProgramIsolation::PER_CALL) + .export_values(); + py::class_(m, "SystemBuilder") + .def("__init__", [](py::args, py::kwargs) {}) + .def_static( + "__new__", + [refs](py::handle cls, std::optional env_prefix, + bool validate_undef, py::kwargs kwargs) { + auto options = + CreateConfigOptions(env_prefix, kwargs, validate_undef); + std::optional system_type = + options.GetOption("system_type"); + std::optional default_system_type; + if (!system_type) { + default_system_type = refs->default_system_type(); + if (!default_system_type) { + throw std::invalid_argument( + "In order to construct a generic SystemBuilder, a " + "`system_type=` keyword (or appropriate environment " + "variable) must be specified. Alternatively, a default can " + "be set process wide via " + "`SystemBuilder.default_system_type =`"); + } + system_type = *default_system_type; + } + return local::SystemBuilder::ForSystem( + iree_allocator_system(), *system_type, std::move(options)); + }, + // Note that for some reason, no-arg construction passes no arguments + // to __new__. We allow the single positional argument to be none, + // which satisfies this case in practice. + py::arg("cls") = py::none(), py::kw_only(), + py::arg("env_prefix").none() = "SHORTFIN_", + py::arg("validate_undef") = true, py::arg("kwargs"), + DOCSTRING_SYSTEM_BUILDER_CTOR) + .def_prop_rw_static( + "default_system_type", + [refs](py::handle /*unused*/) { return refs->default_system_type(); }, + [refs](py::handle /*unused*/, std::optional st) { + refs->set_default_system_type(std::move(st)); + }) .def("create_system", [live_system_refs, worker_initializer](local::SystemBuilder &self) { auto system_ptr = self.CreateSystem(); system_ptr->AddWorkerInitializer(worker_initializer); auto system_obj = py::cast(system_ptr, py::rv_policy::take_ownership); live_system_refs.attr("add")(system_obj); + try { + self.config_options().ValidateUndef(); + } catch (...) 
{ + system_obj.attr("shutdown")(); + throw; + } return system_obj; }); py::class_(m, "System", py::is_weak_referenceable()) + .def( + "__init__", + [live_system_refs](py::object self_obj, py::args, py::kwargs) { + live_system_refs.attr("add")(self_obj); + }, + DOCSTRING_SYSTEM_CTOR) + .def_static( + "__new__", + [worker_initializer](py::handle py_type, std::string_view system_type, + std::optional env_prefix, + bool validate_undef, py::kwargs kwargs) { + auto options = + CreateConfigOptions(env_prefix, kwargs, validate_undef); + auto system = local::System::Create( + iree_allocator_system(), system_type, std::move(options)); + system->AddWorkerInitializer(worker_initializer); + return system; + }, + py::arg("type"), py::arg("system_type"), py::kw_only(), + py::arg("env_prefix") = "SHORTFIN_", py::arg("validate_undef") = true, + py::arg("kwargs")) + .def("__enter__", [](py::object self_obj) { return self_obj; }) + .def( + "__exit__", + [](local::System &self, py::handle exc_type, py::handle exc_value, + py::handle exc_tb) { self.Shutdown(); }, + py::arg("exc_type").none(), py::arg("exc_value").none(), + py::arg("exc_tb").none()) .def("shutdown", &local::System::Shutdown) // Access devices by list, name, or lookup. .def_prop_ro("device_names", @@ -396,30 +607,33 @@ void BindLocal(py::module_ &m) { .def( "device", [](local::System &self, std::string_view key) { - auto it = self.named_devices().find(key); - if (it == self.named_devices().end()) { + local::Device *device = self.FindDeviceByName(key); + if (!device) { throw std::invalid_argument(fmt::format("No device '{}'", key)); } - return it->second; + return device; }, py::rv_policy::reference_internal) .def( "create_queue", - [](local::System &self, std::string name) -> local::Queue & { + [](local::System &self, + std::optional name) -> std::shared_ptr { local::Queue::Options options; - options.name = std::move(name); + if (name) { + options.name = std::move(*name); + } return self.CreateQueue(std::move(options)); }, - py::arg("name"), py::rv_policy::reference_internal) + py::arg("name") = py::none(), py::rv_policy::reference_internal) .def("named_queue", &local::System::named_queue, py::arg("name"), py::rv_policy::reference_internal) .def( - "create_scope", + "create_fiber", [](local::System &self, local::Worker *worker, py::handle raw_devices) { // TODO: I couldn't really figure out how to directly accept an // optional kw-only arg without it just being a raw object/handle. - // If the passed devices is none, then we create the scope with + // If the passed devices is none, then we create the fiber with // all devices in the system. Otherwise, with those explicitly // given. 
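Note: with the `__enter__`/`__exit__`, `create_fiber`, and `create_queue` bindings above, the typical lifecycle looks roughly like this (a sketch; leaving `worker` and `devices` unset picks the init worker and all system devices, as the comments above describe):

```python
import shortfin as sf

builder = sf.SystemBuilder(system_type="hostcpu")
with builder.create_system() as system:      # __exit__ calls shutdown()
    fiber = system.create_fiber()            # init worker, all devices
    queue = system.create_queue("work")      # or create_queue() for anonymous
    print(system.device_names, fiber.device_names)
```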
std::vector devices; @@ -434,7 +648,7 @@ void BindLocal(py::module_ &m) { worker = dynamic_cast(&self.init_worker()); } - return self.CreateScope(*worker, devices); + return self.CreateFiber(*worker, devices); }, py::rv_policy::reference_internal, py::arg("worker").none() = py::none(), py::kw_only(), @@ -469,7 +683,6 @@ void BindLocal(py::module_ &m) { py::class_(m, "Device") .def_prop_ro("name", &local::Device::name) .def_prop_ro("node_affinity", &local::Device::node_affinity) - .def_prop_ro("node_locked", &local::Device::node_locked) .def(py::self == py::self) .def("__repr__", &local::Device::to_s); py::class_(m, "DeviceAffinity") @@ -481,31 +694,55 @@ void BindLocal(py::module_ &m) { .def("__repr__", &local::DeviceAffinity::to_s); py::class_(m, "Program") - .def(py::new_([](std::span modules, - local::Scope &scope, bool trace_execution) { - local::Program::Options options; - options.trace_execution = trace_execution; - return local::Program::Load(scope.shared_from_this(), modules, - std::move(options)); - }), - py::arg("modules"), py::arg("scope"), py::kw_only(), - py::arg("trace_execution") = false) + .def( + py::new_([](std::span modules, + std::vector devices, + bool trace_execution, local::ProgramIsolation isolation) { + local::Program::Options options; + options.devices = devices; + options.trace_execution = trace_execution; + options.isolation = isolation; + return local::Program::Load(modules, std::move(options)); + }), + py::arg("modules"), py::kw_only(), py::arg("devices"), + py::arg("trace_execution") = false, + py::arg("isolation") = local::ProgramIsolation::PER_FIBER) .def_prop_ro("exports", &local::Program::exports) + .def_prop_ro("isolation", &local::Program::isolation) .def("lookup_function", &local::Program::LookupRequiredFunction) .def("__getitem__", &local::Program::LookupRequiredFunction); py::class_(m, "ProgramFunction") .def_prop_ro("name", &local::ProgramFunction::name) .def_prop_ro("calling_convention", &local::ProgramFunction::calling_convention) - .def("invocation", &local::ProgramFunction::CreateInvocation, - DOCSTRING_PROGRAM_FUNCTION_INVOCATION) - .def("__call__", PyFunctionCall, py::arg("args")) + .def( + "invocation", + [](local::ProgramFunction &self, local::Fiber &fiber, + std::optional isolation) { + return self.CreateInvocation(fiber.shared_from_this(), isolation); + }, + py::arg("fiber"), py::arg("isolation") = py::none(), + DOCSTRING_PROGRAM_FUNCTION_INVOCATION) + .def_prop_ro("isolation", &local::ProgramFunction::isolation) + .def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(), + py::arg("fiber"), py::arg("isolation") = py::none()) .def("__repr__", &local::ProgramFunction::to_s); py::class_(m, "ProgramModule") .def_prop_ro("exports", &local::ProgramModule::exports) .def("__repr__", &local::ProgramModule::to_s) .def_static("load", &local::ProgramModule::Load, py::arg("system"), - py::arg("path"), py::arg("mmap") = true); + py::arg("path"), py::arg("mmap") = true) + .def_static( + "parameter_provider", + [](local::System &system, py::args params) { + std::vector c_params; + c_params.reserve(params.size()); + for (py::handle h : params) { + c_params.push_back(py::cast(h)); + } + return local::ProgramModule::ParameterProvider(system, c_params); + }, + py::arg("system"), py::arg("params")); py::class_(m, "ProgramInvocation") .def("invoke", [](local::ProgramInvocation::Ptr &self) { @@ -552,41 +789,97 @@ void BindLocal(py::module_ &m) { } return PyRehydrateRef(self.get(), std::move(ref)); }, - "Gets the i'th result"); + "Gets the i'th result") + 
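Note: `Program` construction now takes `devices` directly instead of a scope, and both `ProgramFunction.invocation()` and `__call__` require a `fiber`. A hedged call-path sketch; the `.vmfb` path and the `"main"` export name are placeholders, and `system.devices` is assumed to expose the device list as in the existing bindings:

```python
import shortfin as sf

async def invoke_main(system: sf.System, fiber: sf.Fiber):
    module = sf.ProgramModule.load(system, "/path/to/module.vmfb")  # placeholder
    program = sf.Program(
        [module],
        devices=system.devices,
        isolation=sf.ProgramIsolation.PER_FIBER,  # the new default
    )
    main = program["main"]           # __getitem__ -> LookupRequiredFunction
    return await main(fiber=fiber)   # __call__ now requires a fiber keyword
```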
.def("__repr__", [](local::ProgramInvocation::Ptr &self) { + if (!self) return std::string("ProgramInvocation(INVALID)"); + return self->to_s(); + }); + + py::class_(m, "BaseProgramParameters"); + py::class_( + m, "StaticProgramParameters") + .def( + py::init(), + py::arg("system"), py::arg("parameter_scope"), + py::arg("max_concurrent_operations") = + IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS) + .def( + "load", + [](local::StaticProgramParameters &self, + std::filesystem::path file_path, std::string_view format, + bool readable, bool writable, bool mmap) { + local::StaticProgramParameters::LoadOptions options; + options.format = format; + options.readable = readable; + options.writable = writable; + options.mmap = mmap; + self.Load(file_path, options); + }, + py::arg("file_path"), py::arg("format") = std::string_view(), + py::arg("readable") = true, py::arg("writable") = false, + py::arg("mmap") = true); struct DevicesSet { - DevicesSet(local::Scope &scope) : scope(scope) {} - local::Scope &scope; + DevicesSet(py::object fiber_obj, std::optional index = {}) + : fiber_obj(std::move(fiber_obj)), index(index) {} + py::object KeepAlive(local::ScopedDevice device) { + py::object device_obj = py::cast(device); + py::detail::keep_alive(/*nurse=*/device_obj.ptr(), + /*patient=*/fiber_obj.ptr()); + return device_obj; + } + local::Fiber &fiber() { return py::cast(fiber_obj); } + py::object fiber_obj; + std::optional index; }; - py::class_(m, "Scope") - .def("__repr__", &local::Scope::to_s) - .def_prop_ro("raw_devices", &local::Scope::raw_devices, - py::rv_policy::reference_internal) + py::class_(m, "Fiber") + .def("__repr__", &local::Fiber::to_s) + .def_prop_ro( + "raw_devices", + [](local::Fiber &self) { + std::vector devices; + devices.reserve(self.raw_devices().size()); + for (auto it : self.raw_devices()) { + devices.push_back(it.second); + } + return devices; + }, + py::rv_policy::reference_internal) .def( "raw_device", - [](local::Scope &self, int index) { return self.raw_device(index); }, + [](local::Fiber &self, int index) { return self.raw_device(index); }, py::rv_policy::reference_internal) .def( "raw_device", - [](local::Scope &self, std::string_view name) { + [](local::Fiber &self, std::string_view name) { return self.raw_device(name); }, py::rv_policy::reference_internal) - .def_prop_ro( - "devices", [](local::Scope &self) { return DevicesSet(self); }, - py::rv_policy::reference_internal) - .def_prop_ro("device_names", &local::Scope::device_names) - .def_prop_ro("named_devices", &local::Scope::named_devices, - py::rv_policy::reference_internal) + .def_prop_ro("devices", + [](py::object self) { return DevicesSet(std::move(self)); }) + .def_prop_ro("devices_dict", + [](py::handle self_obj) { + local::Fiber &self = py::cast(self_obj); + py::dict d; + for (auto &it : self.raw_devices()) { + py::object scoped_device = + py::cast(self.device(it.second)); + py::detail::keep_alive(/*nurse=*/scoped_device.ptr(), + /*patient=*/self_obj.ptr()); + d[py::cast(it.first)] = scoped_device; + } + return d; + }) + .def_prop_ro("device_names", &local::Fiber::device_names) .def( "device", - [](local::Scope &self, py::args args) { + [](local::Fiber &self, py::args args) { return CastDeviceAffinity(self, args); }, py::rv_policy::reference_internal); py::class_(m, "ScopedDevice") - .def_prop_ro("scope", &local::ScopedDevice::scope, + .def_prop_ro("fiber", &local::ScopedDevice::fiber, py::rv_policy::reference) .def_prop_ro("affinity", &local::ScopedDevice::affinity, 
py::rv_policy::reference_internal) @@ -601,25 +894,35 @@ void BindLocal(py::module_ &m) { .def("__repr__", &local::ScopedDevice::to_s); py::class_(m, "_ScopeDevicesSet") + .def("__iter__", + [](DevicesSet &self) { return DevicesSet(self.fiber_obj, 0); }) + .def("__next__", + [](DevicesSet &self) { + auto &fiber = self.fiber(); + if (!self.index || *self.index >= fiber.raw_devices().size()) { + // Blurgh: Exception as flow control is not cheap. There is a + // very obnoxious way to make this not be exception based but + // this is a minority path. + throw py::stop_iteration(); + } + return self.KeepAlive(fiber.device((*self.index)++)); + }) .def("__len__", - [](DevicesSet &self) { return self.scope.raw_devices().size(); }) - .def( - "__getitem__", - [](DevicesSet &self, int index) { return self.scope.device(index); }, - py::rv_policy::reference_internal) - .def( - "__getitem__", - [](DevicesSet &self, std::string_view name) { - return self.scope.device(name); - }, - py::rv_policy::reference_internal) - .def( - "__getattr__", - [](DevicesSet &self, std::string_view name) { - return self.scope.device(name); - }, - py::rv_policy::reference_internal); + [](DevicesSet &self) { return self.fiber().raw_devices().size(); }) + .def("__getitem__", + [](DevicesSet &self, size_t index) { + return self.KeepAlive(self.fiber().device(index)); + }) + .def("__getitem__", + [](DevicesSet &self, std::string_view name) { + return self.KeepAlive(self.fiber().device(name)); + }) + .def("__getattr__", + [](DevicesSet &self, std::string_view name) -> py::object { + return self.KeepAlive(self.fiber().device(name)); + }); + ; py::class_(m, "Worker", py::is_weak_referenceable()) .def_prop_ro("loop", [](local::Worker &self) { @@ -636,6 +939,7 @@ void BindLocal(py::module_ &m) { callable.inc_ref(); // Stolen within the callback. auto thunk = +[](void *user_data, iree_loop_t loop, iree_status_t status) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::Callback"); py::gil_scoped_acquire g; py::object user_callable = py::steal(static_cast(user_data)); @@ -655,6 +959,7 @@ void BindLocal(py::module_ &m) { callable.inc_ref(); // Stolen within the callback. 
auto thunk = +[](void *user_data, iree_loop_t loop, iree_status_t status) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::DelayCallback"); py::gil_scoped_acquire g; py::object user_callable = py::steal(static_cast(user_data)); @@ -681,34 +986,56 @@ void BindLocal(py::module_ &m) { .def("__repr__", &local::Worker::to_s); py::class_(m, "Process") - .def("__init__", [](py::args, py::kwargs) {}) + .def( + "__init__", + [](py::handle self_obj, std::shared_ptr fiber) { + PyProcess &self = py::cast(self_obj); + self.Initialize(std::move(fiber)); + }, + py::kw_only(), py::arg("fiber")) .def_static( "__new__", - [refs](py::handle py_type, py::args, - std::shared_ptr scope, py::kwargs) { - return custom_new(py_type, std::move(scope), refs); + [refs](py::handle py_type, py::args, py::kwargs) { + return custom_new(py_type, refs); }, - py::arg("type"), py::arg("args"), py::arg("scope"), py::arg("kwargs")) + py::arg("type"), py::arg("args"), py::arg("kwargs")) .def_prop_ro("pid", &PyProcess::pid) - .def_prop_ro("scope", &PyProcess::scope) + .def_prop_ro("fiber", + [](PyProcess &self) -> std::shared_ptr { + self.AssertInitialized(); + return self.fiber(); + }) + .def_prop_ro("system", + [](PyProcess &self) { + self.AssertInitialized(); + return self.fiber()->system().shared_ptr(); + }) .def("launch", [](py::object self_obj) { PyProcess &self = py::cast(self_obj); + self.AssertInitialized(); self.Launch(); return self_obj; }) .def("__await__", [](PyProcess &self) { + self.AssertInitialized(); py::object future = py::cast(local::CompletionEvent(self.OnTermination()), py::rv_policy::move); return future.attr("__await__")(); }) - .def("__repr__", &PyProcess::to_s); + .def("__repr__", [](PyProcess &self) { + if (!self.is_initialized()) { + return std::string("Process(UNINITIALIZED)"); + } + return self.to_s(); + }); py::class_(m, "CompletionEvent") .def(py::init<>()) .def("__await__", [](py::handle self_obj) { + SHORTFIN_TRACE_SCOPE_NAMED("PyCompletionEvent::__await__"); auto &worker_ext = PyWorkerExtension::GetCurrent(); auto &self = py::cast(self_obj); py::object future = worker_ext.loop().attr("create_future")(); @@ -730,6 +1057,7 @@ void BindLocal(py::module_ &m) { self, iree_infinite_timeout(), +[](void *future_vp, iree_loop_t loop, iree_status_t status) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("PyCompletionEvent::OnComplete"); py::gil_scoped_acquire g; py::object future = py::steal(static_cast(future_vp)); try { @@ -758,28 +1086,30 @@ void BindLocal(py::module_ &m) { // happens, the owner struct is replaced and any C++ side reference counts // are turned into Python reference counts. Once transferred, only Python // reference counting is used, even if referenced from the C++ side. - py::intrusive_ptr( - [](local::Message *self, PyObject *self_py) noexcept { - local::detail::MessageRefOwner owner( - +[](local::detail::MessageRefOwner::Request req, - const local::Message &msg) { - py::gil_scoped_acquire g; - PyObject *msg_object = reinterpret_cast( - local::detail::MessageRefOwner::access_ref_data(msg)); - if (req == local::detail::MessageRefOwner::Request::RETAIN) { - py::handle(msg_object).inc_ref(); - } else { - py::handle(msg_object).dec_ref(); - } - }); - intptr_t orig_ref_data = - owner.set_owner(*self, reinterpret_cast(self_py)); - // Transfer any prior C++ references to the Python side (less 1 - // since we start with a live reference). 
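Note: `Process` construction changed above: `__init__` takes a keyword-only `fiber`, and the accessors assert initialization. A sketch of the expected subclassing pattern (the body of `run` is illustrative only):

```python
import shortfin as sf

class Worklet(sf.Process):
    def __init__(self, fiber: sf.Fiber):
        # Must be called, otherwise later accessors raise
        # "Process.__init__ not called in constructor".
        super().__init__(fiber=fiber)

    async def run(self):
        # Runs on the fiber's worker once launched.
        print(f"pid={self.pid} fiber={self.fiber}")

async def spawn(fiber: sf.Fiber):
    proc = Worklet(fiber).launch()   # launch() returns self
    await proc                       # resolves on process termination
```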
- for (int i = 0; i < orig_ref_data - 1; ++i) { - py::handle(self_py).inc_ref(); - } - })) + py::intrusive_ptr([](local::Message *self, + PyObject *self_py) noexcept { + local::detail::MessageLifetimeController owner( + +[](local::detail::MessageLifetimeController::Request req, + const local::Message &msg) { + py::gil_scoped_acquire g; + PyObject *msg_object = reinterpret_cast( + local::detail::MessageLifetimeController::AccessOwnedRefData( + msg)); + if (req == + local::detail::MessageLifetimeController::Request::RETAIN) { + py::handle(msg_object).inc_ref(); + } else { + py::handle(msg_object).dec_ref(); + } + }); + intptr_t orig_ref_data = + owner.TakeOwnership(*self, reinterpret_cast(self_py)); + // Transfer any prior C++ references to the Python side (less 1 + // since we start with a live reference). + for (int i = 0; i < orig_ref_data - 1; ++i) { + py::handle(self_py).inc_ref(); + } + })) .def(py::init<>()); py::class_(m, "Queue") @@ -791,10 +1121,15 @@ void BindLocal(py::module_ &m) { py::type(), /*keep_alive=*/self, /*queue=*/self); }) - .def("reader", [](local::Queue &self) { - return custom_new_keep_alive( - py::type(), - /*keep_alive=*/self, /*queue=*/self); + .def("reader", + [](local::Queue &self) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/self, /*queue=*/self); + }) + .def_prop_ro("closed", &local::Queue::is_closed) + .def("write_nodelay", [](local::Queue &self, local::Message &message) { + self.WriteNoDelay(local::Message::Ref(message)); }); py::class_(m, "QueueWriter") .def("__call__", @@ -817,6 +1152,7 @@ void BindLocal(py::module_ &m) { return py::none(); }) .def("__await__", [](py::handle self_obj) { + SHORTFIN_TRACE_SCOPE_NAMED("PyFuture::__await__"); // TODO: We should make our C++ future able to be used directly // vs needing to bridge it like this. auto &worker_ext = PyWorkerExtension::GetCurrent(); @@ -838,6 +1174,7 @@ void BindLocal(py::module_ &m) { self.AddCallback( [py_future_vp = static_cast(future.release().ptr())]( local::Future &sf_future) { + SHORTFIN_TRACE_SCOPE_NAMED("PyFuture::OnComplete"); py::gil_scoped_acquire g; py::object py_future = py::steal(static_cast(py_future_vp)); @@ -855,7 +1192,9 @@ void BindLocal(py::module_ &m) { }); return iter_ret; }); - py::class_(m, "VoidFuture"); + py::class_(m, "VoidFuture") + .def(py::init<>()) + .def("set_success", [](local::VoidFuture &self) { self.set_success(); }); py::class_( m, "ProgramInvocationFuture") .def("result", [](local::ProgramInvocation::Future &self) { @@ -886,18 +1225,262 @@ void BindHostSystem(py::module_ &global_m) { m, "SystemBuilder"); py::class_(m, "CPUSystemBuilder") - .def(py::init<>()); + .def("__init__", [](py::args, py::kwargs) {}) + .def_static( + "__new__", + [](py::handle cls, std::optional env_prefix, + bool validate_undef, py::kwargs kwargs) { + auto options = + CreateConfigOptions(env_prefix, kwargs, validate_undef); + return std::make_unique( + iree_allocator_system(), std::move(options)); + }, + // Note that for some reason, no-arg construction passes no arguments + // to __new__. We allow the single positional argument to be none, + // which satisfies this case in practice. 
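Note: the new `Queue.closed` property, `write_nodelay`, and the now-constructible `VoidFuture` shown above compose roughly as follows (a sketch; `Note` is a hypothetical `sf.Message` subclass and the awaits assume execution on a worker loop):

```python
import shortfin as sf

class Note(sf.Message):
    pass

async def pump(system: sf.System):
    queue = system.create_queue()      # anonymous queues are now allowed
    writer = queue.writer()
    reader = queue.reader()

    await writer(Note())               # normal write (may apply backpressure)
    queue.write_nodelay(Note())        # or enqueue without waiting
    first = await reader()             # pop messages back off the queue
    second = await reader()
    print(type(first), type(second), queue.closed)

    # VoidFuture can now be created and completed directly from Python.
    done = sf.VoidFuture()
    done.set_success()
    await done
```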
+ py::arg("cls") = py::none(), py::kw_only(), + py::arg("env_prefix").none() = "SHORTFIN_", + py::arg("validate_undef") = true, py::arg("kwargs"), + DOCSTRING_HOSTCPU_SYSTEM_BUILDER_CTOR) + .def_prop_rw( + "hostcpu_allocator_specs", + [](local::systems::HostCPUSystemBuilder &self) { + return self.hostcpu_allocator_specs(); + }, + [](local::systems::HostCPUSystemBuilder &self, + std::vector specs) { + self.hostcpu_allocator_specs() = std::move(specs); + }, + DOCSTRING_HOSTCPU_SYSTEM_BUILDER_HOSTCPU_ALLOCATOR_SPECS); py::class_(m, "HostCPUDevice"); } #if defined(SHORTFIN_HAVE_AMDGPU) + +namespace { +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_CTOR[] = + R"(Constructs a system with AMDGPU based devices. + +Most configuration is done by way of key/value arguments. See the properties +of this class, which document the option keywords that can be passed to this +constructor. + +Args: + env_prefix: Controls how options are looked up in the environment. By default, + the prefix is "SHORTFIN_" and upper-cased options are appended to it. Any + option not explicitly specified but in the environment will be used. Pass + None to disable environment lookup. + **kwargs: Key/value arguments for controlling setup of the system. +)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_AMDGPU_ALLOCATOR_SPECS[] = + R"(Allocator specs to apply to AMDGPU devices configured by this builder. + +This uses syntax like:: + + some_allocator + some_allocator:key=value + some_allocator:key=value,key=value + some_allocator:key=value,key=value;other_allocator:key=value + +Typical values for `some_allocator` include `caching` and `debug`. + +This can be set via a keyword of `amdgpu_allocators`, which will only apply to +AMDGPU devices or `allocators` which will apply to all contained devices. +Similarly, it is available on a `SHORTFIN_` prefixed env variable if environment +lookup is not disabled. +)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_AMDGPU_ASYNC_ALLOCATIONS[] = + R"(Whether to use async allocations if supported (default true).)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_CPU_DEVICES_ENABLED[] = + R"(Whether to create a heterogenous system with hostcpu and amdgpu devices. + +Defaults to false. If enabled, the resulting system will contain both device +types and it is up to application code to differentiate between them. All +options for the hostcpu system builder are applicable in this case. + +This option can be set as an option keyword with the name +"amdgpu_cpu_devices_enabled" or the environment variable +"SHORTFIN_AMDGPU_CPU_DEVICES_ENABLED=true" (if `env_prefix` was not changed +at construction). +)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_HIP_LIB_SEARCH_PATHS[] = + R"(List of directories to search for libamdhip64.so (or amdhip64.dll). + +If empty, then `dlopen` will be used without a path, meaning that the library +must be on the default search path or already loaded in the process (i.e. +if running within an overall framework). + +Each entry should be a directory, but a full path to a file can be given by +prefixing with "file:". + +This option can be set as an option keyword with the name +"amdgpu_hip_lib_search_path" or the environment variable +"SHORTFIN_AMDGPU_HIP_LIB_SEARCH_PATH" (if `env_prefix` was not changed at +construction). For compatibility with IREE tools, the "IREE_HIP_DYLIB_PATH" +environment variable is searched as a fallback in all cases. Multiple paths +can be separated by semicolons on all platforms. 
+)"; + +static const char + DOCSTRING_AMDGPU_SYSTEM_BUILDER_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE[] = + R"(Number of logical devices to open per physical, visible device. + +This option can be set as an option keyword with the name +"amgdpu_logical_devices_per_physical_device" or the environment variable +"SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE" (if `env_prefix` was not +changed at construction). +)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_TRACING_LEVEL[] = + R"(Tracing level for AMDGPU device behavior. + +Controls the verbosity of tracing when Tracy instrumentation is enabled. +The impact to benchmark timing becomes more severe as the verbosity +increases, and thus should be only enabled when needed. + +This is the equivalent of the `--hip_tracing` IREE tools flag. +Permissible values are: + * 0 : stream tracing disabled. + * 1 : coarse command buffer level tracing enabled. + * 2 : (default) fine-grained kernel level tracing enabled. + +The setting only has an effect if using a tracing enabled runtime (i.e. +by running with `SHORTFIN_PY_RUNTIME=tracy` or equiv). + +The default value for this setting is available as a +`amdgpu.SystemBuilder(amdgpu_tracing_level=2)` or (by default) from an +environment variable `SHORTFIN_AMDGPU_TRACING_LEVEL`. +)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_AVAILABLE_DEVICES[] = + R"(List of available device ids on the system. + +Accessing this property triggers enumeration, so configuration needed to load +libraries and perform basic system setup must be set first. +)"; + +static const char DOCSTRING_AMDGPU_SYSTEM_BUILDER_VISIBLE_DEVICES[] = + R"(Get or set the list of visible device ids. + +If not set or None, then all available devices will be opened and added to +the system. See the property `available_devices` to access this list of ids. + +If set, then each device with the given device id will be opened and added to +the system in the order listed. Note that in certain partitioned cases, multiple +devices may be available with the same device id. In this case, duplicates +in the visible devices list will cause instantiate a partition of the device +in enumeration order (so there can be as many duplicates as physical +partitions). This is an uncommon scenario and most users should not specify +duplicate device ids. Since there are several ways that partitioned devices +can be consumed, additional options will be available in the future for +controlling this behavior. + +This property can be set as an option keyword with the name +"amdgpu_visible_devices" or the environment variable +"SHORTFIN_AMDGPU_VISIBLE_DEVICES" (if `env_prefix` was not changed at +construction). Multiples can be separated by a semicolon. +)"; + +} // namespace + void BindAMDGPUSystem(py::module_ &global_m) { auto m = global_m.def_submodule("amdgpu", "AMDGPU system config"); py::class_(m, "SystemBuilder") - .def(py::init<>()) - .def_rw("cpu_devices_enabled", - &local::systems::AMDGPUSystemBuilder::cpu_devices_enabled); + .def("__init__", [](py::args, py::kwargs) {}) + .def_static( + "__new__", + [](py::handle cls, std::optional env_prefix, + bool validate_undef, py::kwargs kwargs) { + auto options = + CreateConfigOptions(env_prefix, kwargs, validate_undef); + return std::make_unique( + iree_allocator_system(), std::move(options)); + }, + // Note that for some reason, no-arg construction passes no arguments + // to __new__. We allow the single positional argument to be none, + // which satisfies this case in practice. 
+ py::arg("cls") = py::none(), py::kw_only(), + py::arg("env_prefix").none() = "SHORTFIN_", + py::arg("validate_undef") = true, py::arg("kwargs"), + DOCSTRING_AMDGPU_SYSTEM_BUILDER_CTOR) + .def_prop_rw( + "amdgpu_allocator_specs", + [](local::systems::AMDGPUSystemBuilder &self) { + return self.amdgpu_allocator_specs(); + }, + [](local::systems::AMDGPUSystemBuilder &self, + std::vector specs) { + self.amdgpu_allocator_specs() = std::move(specs); + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_AMDGPU_ALLOCATOR_SPECS) + .def_prop_ro( + "available_devices", + [](local::systems::AMDGPUSystemBuilder &self) { + return self.GetAvailableDeviceIds(); + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_AVAILABLE_DEVICES) + .def_prop_rw( + "async_allocations", + [](local::systems::AMDGPUSystemBuilder &self) { + return self.async_allocations(); + }, + [](local::systems::AMDGPUSystemBuilder &self, bool value) { + self.async_allocations() = value; + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_AMDGPU_ASYNC_ALLOCATIONS) + .def_prop_rw( + "cpu_devices_enabled", + [](local::systems::AMDGPUSystemBuilder &self) -> bool { + return self.cpu_devices_enabled(); + }, + [](local::systems::AMDGPUSystemBuilder &self, bool en) { + self.cpu_devices_enabled() = en; + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_CPU_DEVICES_ENABLED) + .def_prop_rw( + "hip_lib_search_paths", + [](local::systems::AMDGPUSystemBuilder &self) + -> std::vector { + return self.hip_lib_search_paths(); + }, + [](local::systems::AMDGPUSystemBuilder &self, + std::vector vs) { self.hip_lib_search_paths() = vs; }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_HIP_LIB_SEARCH_PATHS) + .def_prop_rw( + "tracing_level", + [](local::systems::AMDGPUSystemBuilder &self) -> int { + return self.tracing_level(); + }, + [](local::systems::AMDGPUSystemBuilder &self, int tracing_level) { + self.tracing_level() = tracing_level; + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_TRACING_LEVEL) + .def_prop_rw( + "logical_devices_per_physical_device", + [](local::systems::AMDGPUSystemBuilder &self) -> size_t { + return self.logical_devices_per_physical_device(); + }, + [](local::systems::AMDGPUSystemBuilder &self, size_t value) { + self.logical_devices_per_physical_device() = value; + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE) + .def_prop_rw( + "visible_devices", + [](local::systems::AMDGPUSystemBuilder &self) + -> std::optional> { + return self.visible_devices(); + }, + [](local::systems::AMDGPUSystemBuilder &self, + std::optional> vs) { + self.visible_devices() = std::move(vs); + }, + DOCSTRING_AMDGPU_SYSTEM_BUILDER_VISIBLE_DEVICES); + py::class_(m, "AMDGPUDevice"); } #endif // SHORTFIN_HAVE_AMDGPU diff --git a/libshortfin/bindings/python/lib_ext.h b/shortfin/python/lib_ext.h similarity index 98% rename from libshortfin/bindings/python/lib_ext.h rename to shortfin/python/lib_ext.h index f5f3b2733..e0a2e6cbb 100644 --- a/libshortfin/bindings/python/lib_ext.h +++ b/shortfin/python/lib_ext.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -24,6 +25,7 @@ namespace shortfin::python { namespace py = nanobind; void BindArray(py::module_ &module); +void BindArrayHostOps(py::module_ &module); void BindLocal(py::module_ &module); void BindHostSystem(py::module_ &module); void BindAMDGPUSystem(py::module_ &module); diff --git a/libshortfin/bindings/python/shortfin/__init__.py b/shortfin/python/shortfin/__init__.py similarity index 78% rename from libshortfin/bindings/python/shortfin/__init__.py rename to shortfin/python/shortfin/__init__.py index 050c2409c..c91058d62 
100644 --- a/libshortfin/bindings/python/shortfin/__init__.py +++ b/shortfin/python/shortfin/__init__.py @@ -6,25 +6,32 @@ from _shortfin import lib as _sfl +# Set up logging. +import shortfin.support.logging_setup as _logging_setup + # Most classes from the native "local" namespace are aliased to the top # level of the public API. +BaseProgramParameters = _sfl.local.BaseProgramParameters CompletionEvent = _sfl.local.CompletionEvent Device = _sfl.local.Device +Fiber = _sfl.local.Fiber Message = _sfl.local.Message Node = _sfl.local.Node Process = _sfl.local.Process Program = _sfl.local.Program ProgramFunction = _sfl.local.ProgramFunction +ProgramIsolation = _sfl.local.ProgramIsolation ProgramInvocation = _sfl.local.ProgramInvocation ProgramInvocationFuture = _sfl.local.ProgramInvocationFuture ProgramModule = _sfl.local.ProgramModule Queue = _sfl.local.Queue QueueReader = _sfl.local.QueueReader QueueWriter = _sfl.local.QueueWriter -Scope = _sfl.local.Scope ScopedDevice = _sfl.local.ScopedDevice +StaticProgramParameters = _sfl.local.StaticProgramParameters System = _sfl.local.System SystemBuilder = _sfl.local.SystemBuilder +VoidFuture = _sfl.local.VoidFuture Worker = _sfl.local.Worker # Array is auto-imported. @@ -35,8 +42,10 @@ from . import host __all__ = [ + "BaseProgramParameters", "CompletionEvent", "Device", + "Fiber", "Message", "Node", "Program", @@ -47,10 +56,11 @@ "Queue", "QueueReader", "QueueWriter", - "Scope", "ScopedDevice", + "StaticProgramParameters", "System", "SystemBuilder", + "VoidFuture", "Worker", # System namespaces. "amdgpu", diff --git a/libshortfin/bindings/python/shortfin/amdgpu.py b/shortfin/python/shortfin/amdgpu.py similarity index 100% rename from libshortfin/bindings/python/shortfin/amdgpu.py rename to shortfin/python/shortfin/amdgpu.py diff --git a/libshortfin/bindings/python/shortfin/array.py b/shortfin/python/shortfin/array/__init__.py similarity index 67% rename from libshortfin/bindings/python/shortfin/array.py rename to shortfin/python/shortfin/array/__init__.py index d665d280b..670102dfe 100644 --- a/libshortfin/bindings/python/shortfin/array.py +++ b/shortfin/python/shortfin/array/__init__.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import importlib.util + from _shortfin import lib as _sfl # All dtype aliases. @@ -40,6 +42,20 @@ storage = _sfl.array.storage DType = _sfl.array.DType +# Ops. +argmax = _sfl.array.argmax +add = _sfl.array.add +ceil = _sfl.array.ceil +convert = _sfl.array.convert +divide = _sfl.array.divide +fill_randn = _sfl.array.fill_randn +floor = _sfl.array.floor +multiply = _sfl.array.multiply +round = _sfl.array.round +subtract = _sfl.array.subtract +transpose = _sfl.array.transpose +trunc = _sfl.array.trunc +RandomGenerator = _sfl.array.RandomGenerator __all__ = [ # DType aliases. @@ -74,4 +90,25 @@ "device_array", "storage", "DType", + # Ops. + "add", + "argmax", + "ceil", + "convert", + "divide", + "fill_randn", + "floor", + "multiply", + "round", + "subtract", + "transpose", + "trunc", + "RandomGenerator", ] + +# Import nputils if numpy is present. +np_present = importlib.util.find_spec("numpy") is not None +if np_present: + from . 
import _nputils as nputils + + __all__.append("nputils") diff --git a/shortfin/python/shortfin/array/_nputils.py b/shortfin/python/shortfin/array/_nputils.py new file mode 100644 index 000000000..c5e37bfc6 --- /dev/null +++ b/shortfin/python/shortfin/array/_nputils.py @@ -0,0 +1,104 @@ +import logging + +import numpy as np + +from shortfin import array as sfnp + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def debug_dump_array(tensor: sfnp.device_array) -> None: + """Dump the contents of a device array to the debug log. + + Args: + tensor (sfnp.device_array): The device array to dump. + """ + np_array = np.array(tensor) + logger.debug(np_array) + + +def debug_fill_array(tensor: sfnp.device_array, fill_value: int | float) -> np.ndarray: + """Fill a device array with a given value and return the resulting numpy array. + + Args: + tensor (sfnp.device_array): The device array to fill. + fill_value (int | float): The value to fill the array with. + + Returns: + np.ndarray: The filled numpy array. + """ + np_array = np.array(tensor) + np_array.fill(fill_value) + return np_array + + +def _find_mode( + arr: np.ndarray, axis=0, keepdims=False +) -> tuple[np.ndarray, np.ndarray]: + """ + Find the mode of an array along a given axis. + + Args: + arr: The input array. + axis: The axis along which to find the mode. + keepdims: If True, the output shape is the same as arr except along the specified axis. + + Returns: + tuple: A tuple containing the mode values and the count of the mode values. + """ + + def _mode(arr): + if arr.size == 0: + return np.nan, 0 + + unique, counts = np.unique(arr, return_counts=True) + max_counts = counts.max() + + mode = unique[counts == max_counts][0] + return mode, max_counts + + result = np.apply_along_axis(_mode, axis, arr) + mode_values, mode_count = result[..., 0], result[..., 1] + + if keepdims: + mode_values = np.expand_dims(mode_values, axis) + mode_count = np.expand_dims(mode_count, axis) + + return mode_values, mode_count + + +def debug_log_tensor_stats(tensor: sfnp.device_array) -> None: + """Log statistics about a device array to the debug log. + + The following statistics are logged: + - NaN count + - Shape, dtype + - Min, max, mean, mode (excluding NaN values) + - First 10 elements + - Last 10 elements + + Args: + tensor (sfnp.device_array): The device array to log statistics for. 
+ """ + + np_array = np.array(tensor) + + nan_count = np.isnan(np_array).sum() + + # Remove NaN values + np_array_no_nan = np_array[~np.isnan(np_array)] + + logger.debug(f"NaN count: {nan_count} / {np_array.size}") + logger.debug(f"Shape: {np_array.shape}, dtype: {np_array.dtype}") + + if len(np_array_no_nan) > 0: + mode = _find_mode(np_array_no_nan)[0] + logger.debug(f"Min (excluding NaN): {np_array_no_nan.min()}") + logger.debug(f"Max (excluding NaN): {np_array_no_nan.max()}") + logger.debug(f"Mean (excluding NaN): {np_array_no_nan.mean()}") + logger.debug(f"Mode (excluding NaN): {mode}") + logger.debug(f"First 10 elements: {np_array_no_nan.flatten()[:10]}") + logger.debug(f"Last 10 elements: {np_array_no_nan.flatten()[-10:]}") + else: + logger.warning(f"All values are NaN") diff --git a/libshortfin/bindings/python/shortfin/host.py b/shortfin/python/shortfin/host.py similarity index 100% rename from libshortfin/bindings/python/shortfin/host.py rename to shortfin/python/shortfin/host.py diff --git a/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py b/shortfin/python/shortfin/interop/fastapi/__init__.py similarity index 80% rename from libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py rename to shortfin/python/shortfin/interop/fastapi/__init__.py index 98f6a41b4..d55bc2180 100644 --- a/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py +++ b/shortfin/python/shortfin/interop/fastapi/__init__.py @@ -5,14 +5,22 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio +import logging + +from shortfin.support.deps import ShortfinDepNotFoundError try: from fastapi import Request, Response from fastapi.responses import StreamingResponse except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Shortfin fastapi interop requires fastapi to be installed" - ) from e + raise ShortfinDepNotFoundError(__name__, "fastapi") from e + + +__all__ = [ + "FastAPIResponder", +] + +logger = logging.getLogger(__name__) class FastAPIResponder: @@ -42,30 +50,32 @@ def __init__(self, request: Request): # Capture the running loop so that we can send responses back. self._loop = asyncio.get_running_loop() self.response = asyncio.Future(loop=self._loop) - self._responded = False + self.responded = False self._streaming_queue: asyncio.Queue | None = None self.is_disconnected = False - def close_with_error(self): - # Called in a failsafe fashion as part of exception handlers seeking to - # shutdown the response. If not yet responded, this will response with - # a status code of 500. If streaming, then None will be streamed. - if self._responded: + def ensure_response(self): + """Called as part of some finally type block to ensure responses are made.""" + if self.responded: if self._streaming_queue: + logging.error("Streaming response not finished. Force finishing.") self.stream_part(None) else: + logging.error("One-shot response not finished. Responding with error.") self.send_response(Response(status_code=500)) - def send_response(self, response: Response): + def send_response(self, response: Response | bytes): """Sends a response back for this transaction. This is intended for sending single part responses back. See start_response() for sending back a streaming, multi-part response. 
""" - assert not self._responded, "Response already sent" + assert not self.responded, "Response already sent" if self._loop.is_closed(): raise IOError("Web server is shut down") - self._responded = True + self.responded = True + if not isinstance(response, Response): + response = Response(response) self._loop.call_soon_threadsafe(self.response.set_result, response) def start_response(self, **kwargs): @@ -77,10 +87,10 @@ def start_response(self, **kwargs): be used for bulk transfer (i.e. by scheduling on the webserver loop directly). """ - assert not self._responded, "Response already sent" + assert not self.responded, "Response already sent" if self._loop.is_closed(): raise IOError("Web server is shut down") - self._responded = True + self.responded = True self._streaming_queue = asyncio.Queue() async def gen(request, streaming_queue): diff --git a/shortfin/python/shortfin/interop/support/device_setup.py b/shortfin/python/shortfin/interop/support/device_setup.py new file mode 100644 index 000000000..afe6ca695 --- /dev/null +++ b/shortfin/python/shortfin/interop/support/device_setup.py @@ -0,0 +1,26 @@ +import shortfin as sf + + +def get_selected_devices(sb: sf.SystemBuilder, device_ids=None): + available = sb.available_devices + selected = [] + if device_ids is not None: + if len(device_ids) > len(available): + raise ValueError( + f"Requested more device ids ({device_ids}) than available ({available})." + ) + for did in device_ids: + if isinstance(did, str): + try: + did = int(did) + except ValueError: + did = did + if did in available: + selected.append(did) + elif isinstance(did, int): + selected.append(available[did]) + else: + raise ValueError(f"Device id {did} could not be parsed.") + else: + selected = available + return selected diff --git a/shortfin/python/shortfin/interop/support/logging_setup.py b/shortfin/python/shortfin/interop/support/logging_setup.py new file mode 100644 index 000000000..5edd0695f --- /dev/null +++ b/shortfin/python/shortfin/interop/support/logging_setup.py @@ -0,0 +1,52 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import sys + +from _shortfin import lib as _sfl + +_LOG_FUNCTIONS = { + logging.DEBUG: _sfl.log_debug, + logging.INFO: _sfl.log_info, + logging.WARNING: _sfl.log_warn, + logging.ERROR: _sfl.log_error, + logging.CRITICAL: _sfl.log_error, +} + +logger = logging.getLogger("shortfin") +logger.propagate = False + + +class NativeHandler(logging.Handler): + def emit(self, record): + formatted = self.format(record) + f = _LOG_FUNCTIONS.get(record.levelno) + if f is not None: + f(formatted) + + +class NativeFormatter(logging.Formatter): + def __init__(self): + super().__init__("[%(filename)s:%(lineno)d] %(message)s") + + +native_handler = NativeHandler() +native_handler.setFormatter(NativeFormatter()) + +# TODO: Source from env vars. +logger.setLevel(logging.DEBUG) +logger.addHandler(native_handler) + + +def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger: + """Configures logging from a main entrypoint. + Returns a logger that can be used for the main module itself. 
+ """ + logging.root.addHandler(native_handler) + logging.root.setLevel(logging.DEBUG) # TODO: source from env vars + main_module = sys.modules["__main__"] + return logging.getLogger(f"{main_module.__package__}.{module_suffix}") diff --git a/shortfin/python/shortfin/support/deps.py b/shortfin/python/shortfin/support/deps.py new file mode 100644 index 000000000..bcb9175d2 --- /dev/null +++ b/shortfin/python/shortfin/support/deps.py @@ -0,0 +1,38 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Utilities for managing dependencies. + +The overall shortfin namespace contains components with optional dependencies. +This module provides support for reacting to dependcy issues. +""" + + +class ShortfinDepNotFoundError(Exception): + """Raised from a ModuleNotFoundError for a missing or incorrect dep.""" + + def __init__( + self, caller_name: str, package_name: str, extras_name: str | None = None + ): + super().__init__() + self.caller_name = caller_name.removesuffix("._deps") + self.package_name = package_name + self.extras_name = extras_name + + def __str__(self): + msg = ( + f"Shortfin is missing a dependency to use {self.caller_name}. " + f"This is typically available via `pip install {self.package_name}`" + ) + if self.extras_name: + msg += ( + f" (or by installing with an extra like " + f"`pip install shortfin[{self.extras_name}])" + ) + return msg + + +ShortfinDepNotFoundError.__name__ = "ShortfinDepNotFoundError" diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py new file mode 100644 index 000000000..849d65bf3 --- /dev/null +++ b/shortfin/python/shortfin/support/logging_setup.py @@ -0,0 +1,52 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import sys + +from _shortfin import lib as _sfl + +_LOG_FUNCTIONS = { + logging.DEBUG: _sfl.log_debug, + logging.INFO: _sfl.log_info, + logging.WARNING: _sfl.log_warn, + logging.ERROR: _sfl.log_error, + logging.CRITICAL: _sfl.log_error, +} + +logger = logging.getLogger("shortfin") +logger.propagate = False + + +class NativeHandler(logging.Handler): + def emit(self, record): + formatted = self.format(record) + f = _LOG_FUNCTIONS.get(record.levelno) + if f is not None: + f(formatted) + + +class NativeFormatter(logging.Formatter): + def __init__(self): + super().__init__("[%(filename)s:%(lineno)d] %(message)s") + + +native_handler = NativeHandler() +native_handler.setFormatter(NativeFormatter()) + +# TODO: Source from env vars. +logger.setLevel(logging.WARNING) +logger.addHandler(native_handler) + + +def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger: + """Configures logging from a main entrypoint. + Returns a logger that can be used for the main module itself. 
+ """ + logging.root.addHandler(native_handler) + logging.root.setLevel(logging.WARNING) # TODO: source from env vars + main_module = sys.modules["__main__"] + return logging.getLogger(f"{main_module.__package__}.{module_suffix}") diff --git a/shortfin/python/shortfin_apps/llm/README.md b/shortfin/python/shortfin_apps/llm/README.md new file mode 100644 index 000000000..b62cca20b --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/README.md @@ -0,0 +1,10 @@ +# LLM Server and CLI + +This directory contains an LLM inference server, CLI and support components. + + +## Quick start + +``` +python -m shortfin_apps.llm.server --help +``` diff --git a/shortfin/python/shortfin_apps/llm/__init__.py b/shortfin/python/shortfin_apps/llm/__init__.py new file mode 100644 index 000000000..4a168079c --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import _deps diff --git a/shortfin/python/shortfin_apps/llm/_deps.py b/shortfin/python/shortfin_apps/llm/_deps.py new file mode 100644 index 000000000..7123d011e --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/_deps.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from shortfin.support.deps import ShortfinDepNotFoundError + +try: + import tokenizers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "tokenizers") from e + +try: + import dataclasses_json +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e diff --git a/shortfin/python/shortfin_apps/llm/client.py b/shortfin/python/shortfin_apps/llm/client.py new file mode 100644 index 000000000..e3ff3ec39 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/client.py @@ -0,0 +1,88 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import requests +import json +import uuid +import argparse +import time +from typing import Dict, Any + + +def main() -> None: + parser = argparse.ArgumentParser(description="Test LLM server") + parser.add_argument("--text", default="1 2 3 4 5 ", help="Input text prompt") + parser.add_argument( + "--max_completion_tokens", type=int, default=50, help="Max tokens to generate" + ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Sampling temperature" + ) + parser.add_argument( + "--stream", action="store_true", help="Enable response streaming" + ) + parser.add_argument( + "--port", + type=str, + default="8000", + help="Port that shortfin server is running on", + ) + args = parser.parse_args() + + base_url = f"http://localhost:{args.port}" + + data = { + "text": args.text, + "sampling_params": { + "max_completion_tokens": args.max_completion_tokens, + "temperature": args.temperature, + }, + "rid": uuid.uuid4().hex, + "return_logprob": False, + "logprob_start_len": -1, + "top_logprobs_num": 0, + "return_text_in_logprobs": False, + "stream": args.stream, + } + + print(f"Testing LLM server at {base_url}") + + # Health check with exponential backoff + backoff = 1 + while True: + try: + requests.get(f"{base_url}/health").raise_for_status() + break + except requests.exceptions.RequestException as e: + if backoff > 16: + print("Health check failed, max retries exceeded") + return + print(f"Health check failed ({str(e)}), retrying in {backoff}s...") + time.sleep(backoff) + backoff *= 2 + + # Generate request + try: + print("Prompt text:", data["text"]) + headers = {"Content-Type": "application/json"} + response = requests.post(f"{base_url}/generate", headers=headers, json=data) + response.raise_for_status() + + if response.text.startswith("data: "): + text = response.text[6:].rstrip("\n") + print("Generated text:", text) + print("\nTest passed") + else: + print("\nTest failed: unexpected response format") + + except requests.exceptions.RequestException as e: + print(f"\nTest failed: request error: {str(e)}") + except KeyboardInterrupt: + print("\nTest interrupted") + + +if __name__ == "__main__": + main() diff --git a/shortfin/python/shortfin_apps/llm/components/cache.py b/shortfin/python/shortfin_apps/llm/components/cache.py new file mode 100644 index 000000000..12794498f --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/cache.py @@ -0,0 +1,111 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Sequence + +import logging +import math +import threading + +import shortfin as sf + +from .config_struct import ModelParams, human_size + +logger = logging.getLogger(__name__) + + +class AttnPageEntry: + __slots__ = [ + "cache", + "index", + "in_use", + ] + + def __init__(self, cache: "AttnPageCache", index: int): + self.cache = cache + self.index = index + self.in_use = False + + def __repr__(self): + return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})" + + +class AttnPageCache: + """Page table based attention cache. + + While internal to a model, the cache is organized with additional structure + per page, outside of the model, it is just a list of pages of a certain + element type and number of elements (all inner dims are flattened). 
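+
+    A rough usage sketch (the page count here is illustrative):
+
+        pages = cache.acquire_free_pages(4)
+        if pages is None:
+            ...  # not enough free pages; retry later
+        else:
+            ...  # write to the pages, then release them when done:
+            cache.release_pages(pages)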
+ + One page table is allocated per device in a fiber. Currently, this is a + dense allocation with committed memory but in the future, we may just + allocate the address space and lazily populate it with committed memory. + + The cache is unique because usage of it can span fibers and concurrency + is implicitly managed at the block level (i.e. freshly acquired blocks + are assumed to be uninitialized and available immediately for use). + + It is initialized with a discrete list of fiberd devices from a fiber but + cache usage can be done from any fiber which includes those devices. + """ + + def __init__( + self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams + ): + self._lock = threading.Lock() + self.devices = list(devices) + self.model_params = model_params + self.page_tables: list[sf.array.device_array] = [] + cache_params = model_params.paged_kv_cache + alloc_page_count = cache_params.device_block_count + + # Setup accounting structs. + self.attn_page_entries = [ + AttnPageEntry(self, i) for i in range(alloc_page_count) + ] + self.attn_page_free = list(self.attn_page_entries) + + # Initialize a page table on each device. + assert cache_params is not None, "Model does not have a paged kv cache" + page_table_shape = [ + alloc_page_count, + model_params.paged_kv_block_size_elements, + ] + for device in devices: + logging.info( + "Allocating page table (shape=%r, dtype=%r, size=%s) on %r", + page_table_shape, + model_params.attn_dtype, + human_size( + math.prod(page_table_shape) + * model_params.attn_dtype.dense_byte_count + ), + device, + ) + page_table = sf.array.device_array.for_device( + device, page_table_shape, model_params.attn_dtype + ) + self.page_tables.append(page_table) + + def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None: + with self._lock: + available = len(self.attn_page_free) + if count > available: + return None + return [self.attn_page_free.pop() for _ in range(count)] + + def release_pages(self, pages: list[AttnPageEntry]): + with self._lock: + self.attn_page_free.extend(pages) + + def __repr__(self): + # No need to lock for repr (list is internally synchronized). + free_pages = len(self.attn_page_free) + total_pages = len(self.attn_page_entries) + return ( + f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: " + f"{100.0 * free_pages / total_pages}% free)" + ) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py new file mode 100644 index 000000000..141c7a7eb --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -0,0 +1,185 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Configuration objects. + +Parameters that are intrinsic to a specific model. + +In a typical transformer model, the KV cache is organized similar to (mapped to +our parameter names below): + k = tensor.empty(transformer_block_count, batch_size, seq, + attn_head_count, attn_head_dim) + v = ... 
+
+For context, a popular model has parameters of:
+    attn_dtype_size = 2  # (fp16)
+    max_seq_len = 2048
+    transformer_block_count = 32
+    attn_head_count = 32
+    attn_head_dim = 128  # (dim / head_count)
+
+If paging, then we primarily care about the organization of a single block, where
+a block represents a single position in the sequence for a single item in the batch.
+Therefore, it will be organized like:
+    block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
+
+In this scenario, we declare that one block holds the KV cache for all transformer
+block layers because it reduces the accounting. As such, for the above example,
+a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
+type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
+Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
+blocks for a total number of sequence positions of 24,576.
+
+These are well-known numbers but are derived above to give a sense of scale.
+
+In order to indirect through to the block cache, we have to provide the index map
+to specific invocations:
+
+* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will
+  need write indices of [batch_size, prompt_len // block_stride + 1].
+* Decode step: Decode is auto-regressive, and needs to first compute the new kv
+  row and then attend over all rows in the cache up to this point in the sequence.
+
+If wanting to avoid dynamic allocation of transients, we can also pool the index
+tables based on the maximum batch size and maximum sequence length. Since all
+block cache sizes are well within the range of an i16, we will use that for storage.
+Therefore, each batch invocation would need a block lookup table of:
+
+  byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t)
+
+For a max_batch_size of 16, this is 4KiB of block index table lookups per
+invocation. We don't have to statically allocate this, but the system is more
+predictable if we just reserve what we need. Again, numbers are given to give a
+sense of scale only: real workloads will vary.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+
+import dataclasses_json
+from dataclasses_json import dataclass_json, Undefined
+
+import shortfin.array as sfnp
+
+
+def _decode_dtype(name: str) -> sfnp.DType:
+    obj = getattr(sfnp, name, None)
+    if not isinstance(obj, sfnp.DType):
+        raise ValueError(f"{name} is not a recognized dtype")
+    return obj
+
+
+dataclasses_json.cfg.global_config.encoders[sfnp.DType] = lambda dt: dt.name
+dataclasses_json.cfg.global_config.decoders[sfnp.DType] = _decode_dtype
+
+
+@dataclass_json(undefined=Undefined.RAISE)
+@dataclass
+class PagedKVCacheParams:
+    """Parameters for the paged KV cache."""
+
+    # Position stride per attention block
+    block_seq_stride: int
+
+    # Size of the cache on each device.
+    device_block_count: int
+
+
+@dataclass_json(undefined=Undefined.RAISE)
+@dataclass
+class ModelParams:
+    """Parameters for a specific compiled model, sufficient to do cache planning and
+    invocations."""
+
+    # Maximum length of a sequence including prompt and output.
+    max_seq_len: int
+
+    # Number of transformer blocks.
+    transformer_block_count: int
+
+    # Number of attention heads per block.
+    attn_head_count: int
+
+    # Dimensionality of each attention head.
+    attn_head_dim: int
+
+    # Batch sizes that the prefill stage is compiled for. These are expected to be
+    # functions exported from the model with suffixes of "_bs{batch_size}". Must
+    # be in ascending order.
+    prefill_batch_sizes: list[int]
+
+    # Similarly, batch sizes that the decode stage is compiled for.
+    decode_batch_sizes: list[int]
+
+    # Name of the IREE module implementing the model.
+    module_name: str = "module"
+
+    # ABI of the module.
+    module_abi_version: int = 1
+
+    # The element type of the attention caches.
+    attn_dtype: sfnp.DType = sfnp.float16
+
+    # Cache parameters.
+    paged_kv_cache: PagedKVCacheParams | None = None
+
+    # Size in bytes of the KV cache dtype.
+    @property
+    def attn_dtype_size(self) -> int:
+        assert self.attn_dtype.is_byte_aligned()
+        return self.attn_dtype.dense_byte_count
+
+    @property
+    def max_prefill_batch_size(self) -> int:
+        return self.prefill_batch_sizes[-1]
+
+    @property
+    def max_decode_batch_size(self) -> int:
+        return self.decode_batch_sizes[-1]
+
+    @property
+    def max_batch_size(self):
+        return max(self.max_prefill_batch_size, self.max_decode_batch_size)
+
+    @property
+    def has_paged_kv_cache(self):
+        return self.paged_kv_cache is not None
+
+    @property
+    def paged_kv_unit_size_elements(self) -> int:
+        """Size in elements of each cache line in the attention cache.
+
+        Each cache line can store a unit position stride.
+        """
+        assert self.has_paged_kv_cache
+        size = 1
+        size *= self.transformer_block_count
+        size *= 2  # K and V cache line
+        size *= self.attn_head_count
+        size *= self.attn_head_dim
+        return size
+
+    @property
+    def paged_kv_block_size_elements(self) -> int:
+        """Size in elements of each attention block of {block_seq_stride}
+        positions.
+        """
+        assert self.paged_kv_cache is not None
+        return self.paged_kv_unit_size_elements * self.paged_kv_cache.block_seq_stride
+
+    @staticmethod
+    def load_json(path: Path | str):
+        with open(path, "rt") as f:
+            json_text = f.read()
+        return ModelParams.from_json(json_text)
+
+
+# From: https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
+def human_size(num, suffix="B"):
+    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
+        if abs(num) < 1024.0:
+            return f"{num:3.1f}{unit}{suffix}"
+        num /= 1024.0
+    return f"{num:.1f}Yi{suffix}"
diff --git a/shortfin/python/shortfin_apps/llm/components/generate.py b/shortfin/python/shortfin_apps/llm/components/generate.py
new file mode 100644
index 000000000..698f779fb
--- /dev/null
+++ b/shortfin/python/shortfin_apps/llm/components/generate.py
@@ -0,0 +1,183 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import asyncio
+import io
+import logging
+
+import shortfin as sf
+import shortfin.array as sfnp
+
+# TODO: Have a generic "Responder" interface vs just the concrete impl.
+from shortfin.interop.fastapi import FastAPIResponder
+
+from .io_struct import GenerateReqInput
+from .messages import InferenceExecRequest, InferencePhase
+from .service import GenerateService
+from .tokenizer import Encoding
+
+logger = logging.getLogger(__name__)
+
+
+class GenerateItemProcess(sf.Process):
+    """Process instantiated for each generation sequence.
+
+    This process breaks the sequence into individual inference and sampling
+    steps, submitting them to the batcher and marshaling incremental/final
+    results.
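+
+    Roughly, run() submits a single PREFILL request for the prompt, then
+    issues one DECODE request per generated token until the EOS token is
+    produced or max_completion_tokens is reached.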
+ """ + + def __init__( + self, + client: "ClientGenerateBatchProcess", + gen_req: GenerateReqInput, + index: int, + input_token_ids: list[int], + max_completion_tokens: int, + eos_token_id: int, + ): + super().__init__(fiber=client.fiber) + self.client = client + self.gen_req = gen_req + self.index = index + self.input_token_ids = input_token_ids + self.result_token_ids: list[int] = [] + self.max_completion_tokens = max_completion_tokens + self.eos_token_id = eos_token_id + + async def run(self): + exec = InferenceExecRequest(InferencePhase.PREFILL, self.input_token_ids) + try: + self.client.batcher.submit(exec) + await exec.done + + # Prefill result. + token = sfnp.argmax(exec.result_logits) + token_int = token.items[0] + + self.append_token(token_int) + # Decode loop. + exec.start_position = len(self.input_token_ids) - 1 + for i in range(self.max_completion_tokens): + exec.reset(InferencePhase.DECODE) + exec.input_token_ids.append(token_int) + exec.start_position += 1 + self.client.batcher.submit(exec) + await exec.done + token = sfnp.argmax(exec.result_logits) + token_int = token.items[0] + self.append_token(token_int) + if token_int == self.eos_token_id: + break + finally: + exec.free_cache_pages() + + def append_token(self, token: int): + self.result_token_ids.append(token) + self.client.stream_results(self) + + +class ClientGenerateBatchProcess(sf.Process): + """Process instantiated for handling a batch from a client. + + This takes care of several responsibilities: + + * Tokenization / Detokenization + * Splitting the batch into GenerateItemProcesses + * Streaming responses + * Final responses + """ + + __slots__ = [ + "batcher", + "complete_infeed", + "gen_req", + "responder", + "tokenizer", + ] + + def __init__( + self, + service: GenerateService, + gen_req: GenerateReqInput, + responder: FastAPIResponder, + ): + super().__init__(fiber=service.main_fiber) + self.gen_req = gen_req + self.responder = responder + self.tokenizer = service.tokenizer + self.batcher = service.batcher + self.complete_infeed = self.system.create_queue() + + async def run(self): + logger.debug("Started ClientBatchGenerateProcess: %r", self) + streaming = self.gen_req.stream + if streaming: + self.responder.start_response() + + try: + # Launch all individual generate processes and wait for them to finish. + gen_processes = [] + # TODO: We should send this to an executor and await the results. 
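+            # Note: tokenization currently runs inline on this fiber before
+            # any item is submitted to the batcher.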
+ input_batch = self.tokenize() + for index, input_tokens in enumerate(input_batch): + gen_process = GenerateItemProcess( + self, + self.gen_req, + index, + input_tokens.ids, + max_completion_tokens=self.gen_req.sampling_params[ + "max_completion_tokens" + ], + eos_token_id=self.tokenizer.eos_token_id, + ) + gen_processes.append(gen_process) + gen_process.launch() + + await asyncio.gather(*gen_processes) + + if streaming: + logger.debug("Responding to streaming batch") + self.responder.stream_part(b"data: [DONE]\n\n") + self.responder.stream_part(None) + else: + logging.debug("Responding to one shot batch") + out = io.BytesIO() + result_texts = self.tokenizer.decode( + [p.result_token_ids for p in gen_processes] + ) + for result_text in result_texts: + out.write(b"data: ") + out.write(result_text.encode()) + out.write(b"\n\n") + self.responder.send_response(out.getvalue()) + finally: + self.responder.ensure_response() + + def stream_results(self, gen_process: GenerateItemProcess): + if not self.gen_req.stream: + return + (result_text,) = self.tokenizer.decode([gen_process.result_token_ids]) + out = io.BytesIO() + out.write(b"data: ") + out.write(result_text.encode()) + out.write(b"\n\n") + self.responder.stream_part(out.getvalue()) + + def tokenize(self) -> list[Encoding]: + gen_req = self.gen_req + if gen_req.text is not None: + if self.gen_req.is_single: + texts = [self.gen_req.text] + logger.debug("Encoding single request") + else: + texts = self.gen_req.text + logger.debug("Encoding batch of %d", len(texts)) + encodings = self.tokenizer.encode(texts) + logger.debug("Generated encodings: %r", encodings) + return encodings + else: + raise NotImplementedError("GenerateReqInput.input_ids handling NYI") diff --git a/shortfin/python/shortfin_apps/llm/components/io_struct.py b/shortfin/python/shortfin_apps/llm/components/io_struct.py new file mode 100644 index 000000000..a739b731f --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/io_struct.py @@ -0,0 +1,147 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Objects transferred between components. + +Portions adapted from API definitions originating in: + +sglang: Copyright 2023-2024 SGLang Team, Licensed under the Apache License, Version 2.0 +""" + +from typing import Dict, List, Optional, Union +from dataclasses import dataclass +import uuid + + +# Adapted from: +# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/io_struct.py +@dataclass +class GenerateReqInput: + # The input prompt. It can be a single prompt or a batch of prompts. + text: Optional[Union[List[str], str]] = None + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None + # The image input. It can be a file name, a url, or base64 encoded string. + # See also python/sglang/srt/utils.py:load_image. + image_data: Optional[Union[List[str], str]] = None + # The sampling_params. See descriptions below. + sampling_params: Union[List[Dict], Dict] = None + # The request id. + rid: Optional[Union[List[str], str]] = None + # Whether to return logprobs. + return_logprob: Optional[Union[List[bool], bool]] = None + # If return logprobs, the start location in the prompt for returning logprobs. 
+ logprob_start_len: Optional[Union[List[int], int]] = None + # If return logprobs, the number of top logprobs to return at each position. + top_logprobs_num: Optional[Union[List[int], int]] = None + # Whether to detokenize tokens in text in the returned logprobs. + return_text_in_logprobs: bool = False + # Whether to stream output. + stream: bool = False + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None + + is_single: bool = True + + def post_init(self): + if (self.text is None and self.input_ids is None) or ( + self.text is not None and self.input_ids is not None + ): + raise ValueError("Either text or input_ids should be provided.") + if ( + isinstance(self.sampling_params, dict) + and self.sampling_params.get("n", 1) != 1 + ): + is_single = False + else: + if self.text is not None: + is_single = isinstance(self.text, str) + else: + is_single = isinstance(self.input_ids[0], int) + self.is_single = is_single + + if is_single: + if self.sampling_params is None: + self.sampling_params = {} + if self.rid is None: + self.rid = uuid.uuid4().hex + if self.return_logprob is None: + self.return_logprob = False + if self.logprob_start_len is None: + self.logprob_start_len = -1 + if self.top_logprobs_num is None: + self.top_logprobs_num = 0 + else: + parallel_sample_num_list = [] + if isinstance(self.sampling_params, dict): + parallel_sample_num = self.sampling_params.get("n", 1) + elif isinstance(self.sampling_params, list): + for sp in self.sampling_params: + parallel_sample_num = sp.get("n", 1) + parallel_sample_num_list.append(parallel_sample_num) + parallel_sample_num = max(parallel_sample_num_list) + all_equal = all( + element == parallel_sample_num + for element in parallel_sample_num_list + ) + if parallel_sample_num > 1 and (not all_equal): + # TODO cope with the case that the parallel_sample_num is different for different samples + raise ValueError( + "The parallel_sample_num should be the same for all samples in sample params." 
+ ) + else: + parallel_sample_num = 1 + self.parallel_sample_num = parallel_sample_num + + if parallel_sample_num != 1: + # parallel sampling +1 represents the original prefill stage + num = parallel_sample_num + 1 + if isinstance(self.text, list): + # support batch operation + self.batch_size = len(self.text) + num = num * len(self.text) + elif isinstance(self.input_ids, list) and isinstance( + self.input_ids[0], list + ): + self.batch_size = len(self.input_ids) + num = num * len(self.input_ids) + else: + self.batch_size = 1 + else: + # support select operation + num = len(self.text) if self.text is not None else len(self.input_ids) + self.batch_size = num + + if self.image_data is None: + self.image_data = [None] * num + elif not isinstance(self.image_data, list): + self.image_data = [self.image_data] * num + + if self.sampling_params is None: + self.sampling_params = [{}] * num + elif not isinstance(self.sampling_params, list): + self.sampling_params = [self.sampling_params] * num + + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(num)] + else: + if not isinstance(self.rid, list): + raise ValueError("The rid should be a list.") + + if self.return_logprob is None: + self.return_logprob = [False] * num + elif not isinstance(self.return_logprob, list): + self.return_logprob = [self.return_logprob] * num + + if self.logprob_start_len is None: + self.logprob_start_len = [-1] * num + elif not isinstance(self.logprob_start_len, list): + self.logprob_start_len = [self.logprob_start_len] * num + + if self.top_logprobs_num is None: + self.top_logprobs_num = [0] * num + elif not isinstance(self.top_logprobs_num, list): + self.top_logprobs_num = [self.top_logprobs_num] * num diff --git a/shortfin/python/shortfin_apps/llm/components/manager.py b/shortfin/python/shortfin_apps/llm/components/manager.py new file mode 100644 index 000000000..b44116b39 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/manager.py @@ -0,0 +1,47 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import threading + +import shortfin as sf +from shortfin.interop.support.device_setup import get_selected_devices + +logger = logging.getLogger(__name__) + + +class SystemManager: + def __init__(self, device="local-task", device_ids=None, async_allocs=True): + if any(x in device for x in ["local-task", "cpu"]): + self.ls = sf.host.CPUSystemBuilder().create_system() + elif any(x in device for x in ["hip", "amdgpu"]): + sb = sf.SystemBuilder( + system_type="amdgpu", amdgpu_async_allocations=async_allocs + ) + if device_ids: + sb.visible_devices = sb.available_devices + sb.visible_devices = get_selected_devices(sb, device_ids) + self.ls = sb.create_system() + logger.info(f"Created local system with {self.ls.device_names} devices") + # TODO: Come up with an easier bootstrap thing than manually + # running a thread. + self.t = threading.Thread(target=lambda: self.ls.run(self.run())) + self.command_queue = self.ls.create_queue("command") + self.command_writer = self.command_queue.writer() + + def start(self): + logger.info("Starting system manager") + self.t.start() + + def shutdown(self): + logger.info("Shutting down system manager") + self.command_queue.close() + + async def run(self): + reader = self.command_queue.reader() + while command := await reader(): + ... 
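+        # The loop exits once the command queue is closed by shutdown().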
+ logging.info("System manager command processor stopped") diff --git a/shortfin/python/shortfin_apps/llm/components/messages.py b/shortfin/python/shortfin_apps/llm/components/messages.py new file mode 100644 index 000000000..fdcbeefc1 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/messages.py @@ -0,0 +1,86 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from enum import Enum + +import shortfin as sf +import shortfin.array as sfnp + +from .cache import AttnPageCache, AttnPageEntry + + +class InferencePhase(Enum): + PREFILL = 1 + DECODE = 2 + + +class InferenceExecRequest(sf.Message): + """Performs a prefill operation.""" + + def __init__(self, phase: InferencePhase, input_token_ids: list[int]): + super().__init__() + self.phase = phase + self.start_position: int = 0 + self.input_token_ids = input_token_ids + self.done = sf.VoidFuture() + + # Response control. + # If True, return all sequence position logits. If False, return only + # the last. + self.return_all_logits: bool = False + + # Move the result array to the host and sync to ensure data is + # available. + self.return_host_array: bool = True + + # Result logits as [1, sl, d] where 1 is the preserved batch dim, + # sl is either 1 (not return_all_logits) or >=1 (return_all_logits). + self.result_logits: sfnp.device_array | None = None + + # Cache pages that have been locked for this request. + self._cache: AttnPageCache | None = None + self.locked_pages: list[AttnPageEntry] | None = None + + def reset(self, phase: InferencePhase): + """Resets all per request state in preparation for an subsequent execution.""" + self.phase = phase + self.done = sf.VoidFuture() + self.return_all_logits = False + self.return_host_array = True + self.result_logits = None + + def cache_page_indices(self, max_len: int) -> list[int]: + if not self.locked_pages: + return [] + indices = [p.index for p in self.locked_pages] + if len(indices) > max_len: + return indices[0:max_len] + return indices + + def free_cache_pages(self): + cache = self._cache + if cache: + pages = self.locked_pages + self._cache = None + self.locked_pages = None + cache.release_pages(pages) + + def lock_initial_cache_pages( + self, cache: AttnPageCache, pages: list[AttnPageEntry] + ): + assert not self._cache + self._cache = cache + self.locked_pages = pages + + def lock_new_cache_pages(self, cache: AttnPageCache, pages: list[AttnPageEntry]): + assert self._cache is cache + self.locked_pages.extend(pages) + + +class StrobeMessage(sf.Message): + """Sent to strobe a queue with fake activity (generate a wakeup).""" + + ... diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py new file mode 100644 index 000000000..bcd08b756 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -0,0 +1,442 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +from pathlib import Path + +import shortfin as sf +import shortfin.array as sfnp + +from .cache import AttnPageCache +from .config_struct import ModelParams +from .manager import SystemManager +from .messages import InferenceExecRequest, InferencePhase, StrobeMessage +from .tokenizer import Tokenizer + +logger = logging.getLogger(__name__) + +PROG_ISOLATIONS = { + isolation.name.lower(): isolation for isolation in sf.ProgramIsolation +} + + +class GenerateService: + """Top level service interface for generating text against a model.""" + + inference_program: sf.Program + prefill_functions: dict[int, sf.ProgramFunction] + decode_functions: dict[int, sf.ProgramFunction] + + def __init__( + self, + *, + name: str, + sysman: SystemManager, + tokenizer: Tokenizer, + model_params: ModelParams, + program_isolation: str = "per_call", + ): + self.name = name + + # Application objects. + self.sysman = sysman + self.tokenizer = tokenizer + self.model_params = model_params + self.inference_parameters: list[sf.BaseProgramParameters] = [] + self.inference_modules: list[sf.ProgramModule] = [] + + self.main_worker = sysman.ls.create_worker(f"{name}-inference") + self.main_fiber = sysman.ls.create_fiber(self.main_worker) + + # Scope dependent objects. + self.batcher = BatcherProcess(self) + self.page_cache = AttnPageCache( + devices=self.main_fiber.devices_dict.values(), model_params=model_params + ) + + self.program_isolation = PROG_ISOLATIONS[program_isolation] + + def load_inference_module(self, vmfb_path: Path): + self.inference_modules.append(sf.ProgramModule.load(self.sysman.ls, vmfb_path)) + + def load_inference_parameters( + self, *paths: Path, parameter_scope: str, format: str = "" + ): + p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope) + for path in paths: + logging.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) + p.load(path, format=format) + self.inference_parameters.append(p) + + def start(self): + self.inference_program = sf.Program( + modules=[ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters + ) + ] + + self.inference_modules, + devices=self.sysman.ls.devices, + trace_execution=False, + isolation=self.program_isolation, + ) + # Resolve prefill entrypoints. + self.prefill_functions = {} + for bs in self.model_params.prefill_batch_sizes: + self.prefill_functions[bs] = self.inference_program[ + f"{self.model_params.module_name}.prefill_bs{bs}" + ] + # Resolve decode entrypoints. + self.decode_functions = {} + for bs in self.model_params.decode_batch_sizes: + self.decode_functions[bs] = self.inference_program[ + f"{self.model_params.module_name}.decode_bs{bs}" + ] + + # Start persistent processes. + self.batcher.launch() + + def shutdown(self): + self.batcher.shutdown() + + def __repr__(self): + return ( + f"ServiceManager(\n" + f" model_params={self.model_params}\n" + f" inference_modules={self.inference_modules}\n" + f" page_cache={self.page_cache}\n" + f")" + ) + + +######################################################################################## +# Batcher +######################################################################################## + +import math + + +class BatcherProcess(sf.Process): + """The batcher is a persistent process responsible for flighting incoming work + into batches and handling the requisite cache allocations (since every batch needs + committed cache state). 
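+
+    Work arrives via submit(); a background strober periodically injects
+    StrobeMessage wakeups so that partially filled batches still get boarded
+    after a short delay instead of waiting indefinitely.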
+ """ + + STROBE_SHORT_DELAY = 0.1 + STROBE_LONG_DELAY = 0.25 + + def __init__(self, service: GenerateService): + super().__init__(fiber=service.main_fiber) + self.service = service + self.batcher_infeed = self.system.create_queue() + self.pending_prefills: set[InferenceExecRequest] = set() + self.pending_decodes: set[InferenceExecRequest] = set() + self.strobe_enabled = True + self.strobes: int = 0 + # TODO: There is no "ideal" batch size. Use prefill/decode dynamic + # batching in the scheduling algo. + self.ideal_batch_size: int = max(service.model_params.prefill_batch_sizes) + self.page_seq_stride = service.model_params.paged_kv_cache.block_seq_stride + + def shutdown(self): + self.batcher_infeed.close() + + def submit(self, request: StrobeMessage | InferenceExecRequest): + self.batcher_infeed.write_nodelay(request) + + async def _background_strober(self): + while not self.batcher_infeed.closed: + await asyncio.sleep( + BatcherProcess.STROBE_SHORT_DELAY + if len(self.pending_prefills) > 0 + else BatcherProcess.STROBE_LONG_DELAY + ) + if self.strobe_enabled: + self.submit(StrobeMessage()) + + async def run(self): + strober_task = asyncio.create_task(self._background_strober()) + reader = self.batcher_infeed.reader() + while item := await reader(): + self.strobe_enabled = False + if isinstance(item, InferenceExecRequest): + phase = item.phase + if phase == InferencePhase.PREFILL: + self.pending_prefills.add(item) + elif phase == InferencePhase.DECODE: + self.pending_decodes.add(item) + else: + logger.error("Illegal InferenceExecRequest phase: %r", phase) + elif isinstance(item, StrobeMessage): + self.strobes += 1 + else: + logger.error("Illegal message received by batcher: %r", item) + self.board_flights() + self.strobe_enabled = True + await strober_task + + def board_flights(self): + waiting_count = len(self.pending_prefills) + len(self.pending_decodes) + if waiting_count == 0: + return + if waiting_count < self.ideal_batch_size and self.strobes < 2: + logger.info("Waiting a bit longer to fill flight") + return + self.strobes = 0 + cache = self.service.page_cache + + # TODO: This is a very naive cache management algorithm. Burn with fire + # and implement a real one. + self.board_prefills(cache) + self.board_decodes(cache) + + # For now, kill anything that is left. + for prefill_request in self.pending_prefills: + prefill_request.done.set_success() + self.pending_prefills.clear() + logger.debug("Post boarding cache state: %r", cache) + + def board_prefills(self, cache: AttnPageCache): + # Fill prefill flights. + pending_prefills = self.pending_prefills + if len(pending_prefills) == 0: + return + exec_process = InferenceExecutorProcess( + self.service, + InferencePhase.PREFILL, + self.page_seq_stride, + cache.page_tables, + ) + for prefill_request in pending_prefills: + assert prefill_request.phase == InferencePhase.PREFILL + if len(exec_process.exec_requests) >= self.ideal_batch_size: + break + needed_pages = math.ceil( + len(prefill_request.input_token_ids) / self.page_seq_stride + ) + pages = cache.acquire_free_pages(needed_pages) + if pages is None: + logger.debug("Cannot fulfill request for %d pages", needed_pages) + continue + else: + logger.debug("Allocated %d cache pages to request", len(pages)) + prefill_request.lock_initial_cache_pages(cache, pages) + + # Can flight this request. + exec_process.exec_requests.append(prefill_request) + + # We've filled our flight. Remove from the boarding area. 
+ if exec_process.exec_requests: + for flighted_request in exec_process.exec_requests: + self.pending_prefills.remove(flighted_request) + # And takeoff. + exec_process.launch() + + def board_decodes(self, cache: AttnPageCache): + # Fill decode flights. + pending_decodes = self.pending_decodes + if len(pending_decodes) == 0: + return + exec_process = InferenceExecutorProcess( + self.service, InferencePhase.DECODE, self.page_seq_stride, cache.page_tables + ) + for decode_request in pending_decodes: + assert decode_request.phase == InferencePhase.DECODE + if len(exec_process.exec_requests) >= self.ideal_batch_size: + break + incoming_token_count = len(decode_request.input_token_ids) + needed_pages = math.ceil( + (decode_request.start_position + incoming_token_count) + / self.page_seq_stride + ) + if needed_pages > len(decode_request.locked_pages): + pages = cache.acquire_free_pages(needed_pages) + if pages is None: + logger.debug( + "Cannot fulfill decode request for %d pages", needed_pages + ) + continue + else: + logger.debug( + "Allocated %d cache pages to decode request", len(pages) + ) + decode_request.lock_new_cache_pages(cache, pages) + + # Can flight this request. + exec_process.exec_requests.append(decode_request) + + # We've filled our flight. Remove from the boarding area. + if exec_process.exec_requests: + for flighted_request in exec_process.exec_requests: + self.pending_decodes.remove(flighted_request) + # And takeoff. + exec_process.launch() + + +######################################################################################## +# Inference Executor +######################################################################################## + + +class InferenceExecutorProcess(sf.Process): + """Executes a prefill or decode batch.""" + + def __init__( + self, + service: GenerateService, + phase: InferencePhase, + seq_stride: int, + page_tables, + ): + super().__init__(fiber=service.main_fiber) + self.service = service + self.phase = phase + self.seq_stride = seq_stride + self.exec_requests: list[InferenceExecRequest] = [] + self.page_tables = page_tables + + async def run(self): + try: + is_decode = self.phase == InferencePhase.DECODE + req_bs = len(self.exec_requests) + seq_stride = self.seq_stride + # Select an entrypoint for the batch. + if is_decode: + entrypoints = self.service.decode_functions + else: + entrypoints = self.service.prefill_functions + for bs, fn in entrypoints.items(): + if bs >= req_bs: + break + else: + raise RuntimeError(f"No available entry point for bs {req_bs}") + + # Compute block sequence length as maximum sequence length, rounded + # up to the seq_stride. + if self.phase == InferencePhase.PREFILL: + for r in self.exec_requests: + assert r.start_position == 0 + + bsl = max( + (r.start_position + len(r.input_token_ids)) for r in self.exec_requests + ) + bsl = int(math.ceil(bsl / seq_stride) * seq_stride) + block_count = bsl // seq_stride + req_count = len(self.exec_requests) + logger.debug("Prefill bs=%d, bsl=%d", bs, bsl) + + # Prepare inputs. + # TODO: Better support in shortfin for h2d. The best way to do it is + # device dependent. 
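+            # For now, each argument below is staged in a host transfer array,
+            # filled, and explicitly copied to device0.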
+ device0 = self.fiber.device(0) + int_dtype = sfnp.int64 + if is_decode: + tokens = sfnp.device_array.for_device(device0, [bs, 1], int_dtype) + start_positions = sfnp.device_array.for_device(device0, [bs], int_dtype) + else: + tokens = sfnp.device_array.for_device(device0, [bs, bsl], int_dtype) + seq_lens = sfnp.device_array.for_device(device0, [bs], int_dtype) + seq_block_ids = sfnp.device_array.for_device( + device0, [bs, block_count], int_dtype + ) + + # Populate tokens. + tokens_host = tokens.for_transfer() + for i in range(bs): + with tokens_host.view(i).map(discard=True) as m: + m.fill(0) + if i < req_count: + if self.phase == InferencePhase.PREFILL: + m.items = self.exec_requests[i].input_token_ids + elif self.phase == InferencePhase.DECODE: + m.items = self.exec_requests[i].input_token_ids[-1:] + tokens_host.copy_to(tokens) + + # For prefill, populate seq_lens + if self.phase == InferencePhase.PREFILL: + seq_lens_host = seq_lens.for_transfer() + with seq_lens_host.map(discard=True) as m: + m.fill(0) + m.items = [len(req.input_token_ids) for req in self.exec_requests] + seq_lens_host.copy_to(seq_lens) + + # For decode, populate start_positions and seq_lens. + # paged_llm_v1 and export_paged_llm_v1 do some funky things with start_positions and seq_lens + # TODO: make them not so funky + if self.phase == InferencePhase.DECODE: + start_positions_host = start_positions.for_transfer() + with start_positions_host.map(discard=True) as m: + m.fill(0) + m.items = [req.start_position for req in self.exec_requests] + start_positions_host.copy_to(start_positions) + + seq_lens_host = seq_lens.for_transfer() + with seq_lens_host.map(discard=True) as m: + m.fill(0) + m.items = [ + req.start_position + len(req.input_token_ids) + for req in self.exec_requests + ] + seq_lens_host.copy_to(seq_lens) + + # Populate cache pages. + seq_block_ids_host = seq_block_ids.for_transfer() + for i in range(bs): + with seq_block_ids_host.view(i).map(discard=True) as m: + m.fill(0) + if i < req_count: + m.items = self.exec_requests[i].cache_page_indices(block_count) + seq_block_ids_host.copy_to(seq_block_ids) + + # V1 args: + # prefill: + # tokens: [bs, bsl] + # seq_lens: [bs] + # seq_block_ids: [bs, blocks] + # cache_slabs: ... + # decode: + # tokens: [bs, 1] + # seq_lens: [bs] + # start_positions: [bs] + # seq_block_ids: [bs, blocks] + # cache_slabs: ... + if is_decode: + args = [tokens, seq_lens, start_positions, seq_block_ids] + else: + args = [tokens, seq_lens, seq_block_ids] + args.extend(self.page_tables) + logger.info( + "INVOKE %r: %s", + fn, + "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(args)]), + ) + # Invoke. Logits are of shape [bs, bsl, d]. + (logits,) = await fn(*args, fiber=self.fiber) + + # Return results. 
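+            # Each request gets a view of its row in the batched logits:
+            # the full range of its positions when return_all_logits is set,
+            # otherwise only the final position.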
+ for i in range(req_count): + req = self.exec_requests[i] + sl = 1 if is_decode else len(req.input_token_ids) + if req.return_all_logits: + logits_item = logits.view(i, slice(0, sl)) + else: + logits_item = logits.view(i, sl - 1) + if req.return_host_array: + req.result_logits = logits_item.for_transfer() + req.result_logits.copy_from(logits_item) + await device0 + else: + req.result_logits = logits_item + req.done.set_success() + + except Exception: + logger.exception("Fatal error in prefetch invocation") + # TODO: Cancel and set error correctly + for req in self.exec_requests: + req.result_logits = None + req.free_cache_pages() + req.done.set_success() diff --git a/shortfin/python/shortfin_apps/llm/components/tokenizer.py b/shortfin/python/shortfin_apps/llm/components/tokenizer.py new file mode 100644 index 000000000..59112e4dd --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/tokenizer.py @@ -0,0 +1,90 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path + +import tokenizers + +import shortfin as sf +import shortfin.array as sfnp + +# Type alias from the backing library. +Encoding = tokenizers.Encoding + + +class Tokenizer: + def __init__( + self, raw_tk: tokenizers.Tokenizer, pad_id: int = 0, eos_token: str = None + ): + self.pad_id = pad_id + self.eos_token = eos_token + self.eos_token_id = ( + raw_tk.token_to_id(eos_token) if eos_token is not None else None + ) + self._raw = raw_tk + self._raw.enable_padding(pad_id=pad_id) + + @staticmethod + def from_pretrained(name: str) -> "Tokenizer": + raw_tk = tokenizers.Tokenizer.from_pretrained(name) + return Tokenizer(raw_tk) + + @staticmethod + def from_tokenizer_json_file(json_path: Path | str, eos_token: str): + return Tokenizer( + tokenizers.Tokenizer.from_file(str(json_path)), eos_token=eos_token + ) + + def encode(self, texts: list[str]) -> list[tokenizers.Encoding]: + """Encodes a batch of texts, applying no padding.""" + return self._raw.encode_batch(texts) + + def decode(self, sequences) -> list[str]: + """Decodes a batch of sequences to text.""" + return self._raw.decode_batch(sequences) + + def encoding_length(self, enc: tokenizers.Encoding) -> int: + """Gets the length of an encoding.""" + return len(enc.ids) + + def post_process_encodings( + self, encs: list[tokenizers.Encoding], batch_seq_len: int + ): + """Truncates and pads to a requested size.""" + for enc in encs: + enc.truncate(batch_seq_len) + enc.pad(batch_seq_len) + + def encodings_to_array( + self, + device: sf.ScopedDevice, + encs: list[tokenizers.Encoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + """Creates a device_array with the contents of a batch of encodings. + + It is expected that the user has called post_process_encodings with + the same batch_seq_len in order to properly truncate/pad. 
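+
+        A rough usage sketch (batch_seq_len is chosen by the caller):
+
+            encs = tokenizer.encode(["some text", "other text"])
+            tokenizer.post_process_encodings(encs, batch_seq_len)
+            ary = tokenizer.encodings_to_array(device, encs, batch_seq_len)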
+ """ + ary = sfnp.device_array.for_host(device, [len(encs), batch_seq_len], dtype) + for i, enc in enumerate(encs): + ary.view(i).items = enc.ids + return ary + + def attention_masks_to_array( + self, + device: sf.ScopedDevice, + encs: list[tokenizers.Encoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + ary = sfnp.device_array.for_host(device, [len(encs), batch_seq_len], dtype) + for i, enc in enumerate(encs): + ary.view(i).items = enc.attention_mask + return ary diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py new file mode 100644 index 000000000..2ab7a1b96 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -0,0 +1,221 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any + +import argparse +import logging +from pathlib import Path +import sys + +import uvicorn.logging + +# Import first as it does dep checking and reporting. +from shortfin import ProgramIsolation +from shortfin.interop.fastapi import FastAPIResponder + +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request, Response +import uvicorn + + +from .components.generate import ClientGenerateBatchProcess +from .components.config_struct import ModelParams +from .components.io_struct import GenerateReqInput +from .components.manager import SystemManager +from .components.service import GenerateService +from .components.tokenizer import Tokenizer + + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + sysman.start() + try: + for service_name, service in services.items(): + logging.info("Initializing service '%s': %r", service_name, service) + service.start() + except: + sysman.shutdown() + raise + yield + try: + for service_name, service in services.items(): + logging.info("Shutting down service '%s'", service_name) + service.shutdown() + finally: + sysman.shutdown() + + +sysman: SystemManager +services: dict[str, Any] = {} +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +async def health() -> Response: + return Response(status_code=200) + + +async def generate_request(gen_req: GenerateReqInput, request: Request): + service = services["default"] + gen_req.post_init() + responder = FastAPIResponder(request) + ClientGenerateBatchProcess(service, gen_req, responder).launch() + return await responder.response + + +app.post("/generate")(generate_request) +app.put("/generate")(generate_request) + + +def get_eos_from_tokenizer_config(json_path): + import json + + with open(json_path, "rt") as f: + json_text = f.read() + config = json.loads(json_text) + return config["eos_token"] + + +def configure(args) -> SystemManager: + # Setup system (configure devices, etc). + sysman = SystemManager( + device=args.device, + device_ids=args.device_ids, + async_allocs=args.amdgpu_async_allocations, + ) + + # Setup each service we are hosting. 
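+    # Each service is registered in the module-level `services` dict keyed by
+    # its name ("default" here) and started/stopped by the app lifespan hook.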
+ eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json) + tokenizer = Tokenizer.from_tokenizer_json_file( + args.tokenizer_json, eos_token=eos_token + ) + model_params = ModelParams.load_json(args.model_config) + sm = GenerateService( + name="default", + sysman=sysman, + tokenizer=tokenizer, + model_params=model_params, + program_isolation=args.isolation, + ) + sm.load_inference_module(args.vmfb) + sm.load_inference_parameters(*args.parameters, parameter_scope="model") + services[sm.name] = sm + return sysman + + +def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="Root path to use for installing behind path based proxy.", + ) + parser.add_argument( + "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" + ) + parser.add_argument( + "--tokenizer_json", + type=Path, + required=True, + help="Path to a tokenizer.json file", + ) + parser.add_argument( + "--tokenizer_config_json", + type=Path, + required=False, + help="Path to a tokenizer_config json file", + ) + parser.add_argument( + "--model_config", + type=Path, + required=True, + help="Path to the model config file", + ) + parser.add_argument( + "--vmfb", + type=Path, + required=True, + help="Model VMFB to load", + ) + # parameters are loaded with `iree_io_parameters_module_create` + parser.add_argument( + "--parameters", + type=Path, + nargs="*", + help="Parameter archives to load (supports: gguf, irpa, safetensors).", + metavar="FILE", + ) + parser.add_argument( + "--device", + type=str, + required=True, + choices=["local-task", "hip", "amdgpu"], + help="Device to serve on; e.g. local-task, hip. Same options as `iree-run-module --device` ", + ) + parser.add_argument( + "--device_ids", + type=str, + nargs="*", + default=None, + help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0", + ) + parser.add_argument( + "--isolation", + type=str, + default="per_call", + choices=[isolation.name.lower() for isolation in ProgramIsolation], + help="Concurrency control -- How to isolate programs.", + ) + parser.add_argument( + "--amdgpu_async_allocations", + action="store_true", + help="Enable asynchronous allocations for amdgpu device contexts.", + ) + args = parser.parse_args(argv) + + if args.tokenizer_config_json is None: + # this is only used for the EOS token + logging.info("Argument `--tokenizer_config_json` is not provided") + logging.info("Inferring tokenizer config path from tokenizer path") + inferred_tokenizer_config_path = args.tokenizer_json.with_name( + args.tokenizer_json.stem + "_config.json" + ) + args.tokenizer_config_json = inferred_tokenizer_config_path + global sysman + sysman = configure(args) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_config=log_config, + timeout_keep_alive=args.timeout_keep_alive, + ) + + +if __name__ == "__main__": + from shortfin.support.logging_setup import configure_main_logger + + logger = configure_main_logger("server") + main( + sys.argv[1:], + # Make logging defer to the default shortfin logging config. 
+ log_config={ + "version": 1, + "disable_existing_loggers": False, + "formatters": {}, + "handlers": {}, + "loggers": {}, + }, + ) diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md new file mode 100644 index 000000000..3397be6cf --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -0,0 +1,30 @@ +# SDXL Server and CLI + +This directory contains a [SDXL](https://stablediffusionxl.com/) inference server, CLI and support components. More information about SDXL on [huggingface](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). + +## Install + +For [nightly releases](../../../../docs/nightly_releases.md) +For our [stable release](../../../../docs/user_guide.md) + +## Start SDXL Server +The server will prepare runtime artifacts for you. + +By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands. + +You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`. + +``` +python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" +``` + - Wait until your server outputs: +``` +INFO - Application startup complete. +INFO - Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` +## Run the SDXL Client + + - Run a CLI client in a separate shell: +``` +python -m shortfin_apps.sd.simple_client --interactive +``` diff --git a/shortfin/python/shortfin_apps/sd/__init__.py b/shortfin/python/shortfin_apps/sd/__init__.py new file mode 100644 index 000000000..4a168079c --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import _deps diff --git a/shortfin/python/shortfin_apps/sd/_deps.py b/shortfin/python/shortfin_apps/sd/_deps.py new file mode 100644 index 000000000..92bd089ec --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/_deps.py @@ -0,0 +1,22 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from shortfin.support.deps import ShortfinDepNotFoundError + +try: + import transformers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "transformers") from e + +try: + import tokenizers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "tokenizers") from e + +try: + import dataclasses_json +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py new file mode 100644 index 000000000..98678c46d --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -0,0 +1,320 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile +import itertools +import os +import urllib +import shortfin.array as sfnp +import copy + +from shortfin_apps.sd.components.config_struct import ModelParams + +this_dir = os.path.dirname(os.path.abspath(__file__)) +parent = os.path.dirname(this_dir) +default_config_json = os.path.join(parent, "examples", "sdxl_config_i8.json") + +dtype_to_filetag = { + sfnp.float16: "fp16", + sfnp.float32: "fp32", + sfnp.int8: "i8", + sfnp.bfloat16: "bf16", +} + +ARTIFACT_VERSION = "11182024" +SDXL_BUCKET = ( + f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/" +) +SDXL_WEIGHTS_BUCKET = ( + "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/" +) + + +def filter_by_model(filenames, model): + if not model: + return filenames + filtered = [] + for i in filenames: + if model.lower() in i.lower(): + filtered.extend([i]) + return filtered + + +def get_mlir_filenames(model_params: ModelParams, model=None): + mlir_filenames = [] + file_stems = get_file_stems(model_params) + for stem in file_stems: + mlir_filenames.extend([stem + ".mlir"]) + return filter_by_model(mlir_filenames, model) + + +def get_vmfb_filenames( + model_params: ModelParams, model=None, target: str = "amdgpu-gfx942" +): + vmfb_filenames = [] + file_stems = get_file_stems(model_params) + for stem in file_stems: + vmfb_filenames.extend([stem + "_" + target + ".vmfb"]) + return filter_by_model(vmfb_filenames, model) + + +def get_params_filenames(model_params: ModelParams, model=None, splat: bool = False): + params_filenames = [] + base = ( + "stable_diffusion_xl_base_1_0" + if model_params.base_model_name.lower() == "sdxl" + else model_params.base_model_name + ) + modnames = ["clip", "vae"] + mod_precs = [ + dtype_to_filetag[model_params.clip_dtype], + dtype_to_filetag[model_params.unet_dtype], + ] + if model_params.use_i8_punet: + modnames.append("punet") + mod_precs.append("i8") + else: + modnames.append("unet") + mod_precs.append(dtype_to_filetag[model_params.unet_dtype]) + if splat == "True": + for idx, mod in enumerate(modnames): + params_filenames.extend( + ["_".join([mod, "splat", f"{mod_precs[idx]}.irpa"])] + ) + else: + for idx, mod in enumerate(modnames): + params_filenames.extend( + [base + "_" + mod + "_dataset_" + mod_precs[idx] + ".irpa"] + ) + return filter_by_model(params_filenames, model) + + +def get_file_stems(model_params: ModelParams): + file_stems = [] + base = ( + ["stable_diffusion_xl_base_1_0"] + if model_params.base_model_name.lower() == "sdxl" + else [model_params.base_model_name] + ) + mod_names = { + "clip": "clip", + "unet": "punet" if model_params.use_i8_punet else "unet", + "scheduler": model_params.scheduler_id + "Scheduler", + "vae": "vae", + } + for mod, modname in mod_names.items(): + ord_params = [ + base, + [modname], + ] + bsizes = [] + for bs in getattr(model_params, f"{mod}_batch_sizes", [1]): + bsizes.extend([f"bs{bs}"]) + ord_params.extend([bsizes]) + if mod in ["unet", "clip"]: + ord_params.extend([[str(model_params.max_seq_len)]]) + if mod in ["unet", "vae", "scheduler"]: + dims = [] + for dim_pair in model_params.dims: + dim_pair_str = [str(d) for d in dim_pair] + dims.extend(["x".join(dim_pair_str)]) + ord_params.extend([dims]) + if mod == "scheduler": + dtype_str = dtype_to_filetag[model_params.unet_dtype] + elif mod != "unet": + dtype_str = dtype_to_filetag[ + getattr(model_params, 
f"{mod}_dtype", sfnp.float16) + ] + else: + dtype_str = ( + "i8" + if model_params.use_i8_punet + else dtype_to_filetag[model_params.unet_dtype] + ) + ord_params.extend([[dtype_str]]) + for x in list(itertools.product(*ord_params)): + file_stems.extend(["_".join(x)]) + return file_stems + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + needed = True + if os.path.exists(out_file): + if url: + needed = not is_valid_size(out_file, url) + if not needed: + return False + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + return True + + +def needs_compile(filename, target, ctx): + vmfb_name = f"{filename}_{target}.vmfb" + namespace = FileNamespace.BIN + return needs_file(vmfb_name, ctx, namespace=namespace) + + +def get_cached_vmfb(filename, target, ctx): + vmfb_name = f"{filename}_{target}.vmfb" + return ctx.file(vmfb_name) + + +def is_valid_size(file_path, url): + if not url: + return True + with urllib.request.urlopen(url) as response: + content_length = response.getheader("Content-Length") + local_size = get_file_size(str(file_path)) + if content_length: + content_length = int(content_length) + if content_length != local_size: + return False + return True + + +def get_file_size(file_path): + """Gets the size of a local file in bytes as an integer.""" + + file_stats = os.stat(file_path) + return file_stats.st_size + + +def fetch_http_check_size(*, name: str, url: str) -> BuildFile: + context = BuildContext.current() + output_file = context.allocate_file(name) + action = FetchHttpWithCheckAction( + url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor + ) + output_file.deps.add(action) + return output_file + + +class FetchHttpWithCheckAction(BuildAction): + def __init__(self, url: str, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.url = url + self.output_file = output_file + + def _invoke(self, retries=4): + path = self.output_file.get_fs_path() + self.executor.write_status(f"Fetching URL: {self.url} -> {path}") + try: + urllib.request.urlretrieve(self.url, str(path)) + except urllib.error.HTTPError as e: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + else: + raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None + local_size = get_file_size(str(path)) + try: + with urllib.request.urlopen(self.url) as response: + content_length = response.getheader("Content-Length") + if content_length: + content_length = int(content_length) + if content_length != local_size: + raise IOError( + f"Size of downloaded artifact does not match content-length header! 
{content_length} != {local_size}" + ) + except IOError: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + + +@entrypoint(description="Retreives a set of SDXL submodels.") +def sdxl( + model_json=cl_arg( + "model-json", + default=default_config_json, + help="Local config filepath", + ), + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + splat=cl_arg( + "splat", default=False, type=str, help="Download empty weights (for testing)" + ), + build_preference=cl_arg( + "build-preference", + default="precompiled", + help="Sets preference for artifact generation method: [compile, precompiled]", + ), + model=cl_arg("model", type=str, help="Submodel to fetch/compile for."), +): + model_params = ModelParams.load_json(model_json) + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + mlir_bucket = SDXL_BUCKET + "mlir/" + vmfb_bucket = SDXL_BUCKET + "vmfbs/" + if "gfx" in target: + target = "amdgpu-" + target + + mlir_filenames = get_mlir_filenames(model_params, model) + mlir_urls = get_url_map(mlir_filenames, mlir_bucket) + for f, url in mlir_urls.items(): + if update or needs_file(f, ctx, url): + fetch_http(name=f, url=url) + + vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target) + vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket) + if build_preference == "compile": + for idx, f in enumerate(copy.deepcopy(vmfb_filenames)): + # We return .vmfb file stems for the compile builder. + file_stem = "_".join(f.split("_")[:-1]) + if needs_compile(file_stem, target, ctx): + for mlirname in mlir_filenames: + if file_stem in mlirname: + mlir_source = mlirname + break + obj = compile(name=file_stem, source=mlir_source) + vmfb_filenames[idx] = obj[0] + else: + vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx) + else: + for f, url in vmfb_urls.items(): + if update or needs_file(f, ctx, url): + fetch_http(name=f, url=url) + + params_filenames = get_params_filenames(model_params, model=model, splat=splat) + params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) + for f, url in params_urls.items(): + if needs_file(f, ctx, url): + fetch_http_check_size(name=f, url=url) + filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py new file mode 100644 index 000000000..432f08b4e --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py @@ -0,0 +1,104 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace +import os + +ARTIFACT_VERSION = "11182024" +SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/" + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + if os.path.exists(out_file): + needed = False + else: + # name_path = "bin" if namespace == FileNamespace.BIN else "" + # if name_path: + # filename = os.path.join(name_path, filename) + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + needed = True + return needed + + +@entrypoint(description="Retreives a set of SDXL configuration files.") +def sdxlconfig( + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + model=cl_arg("model", type=str, default="sdxl", help="Model architecture"), + topology=cl_arg( + "topology", + type=str, + default="spx_single", + help="System topology configfile keyword", + ), +): + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + model_config_filenames = [f"{model}_config_i8.json"] + model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) + for f, url in model_config_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + topology_config_filenames = [f"topology_config_{topology}.txt"] + topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET) + for f, url in topology_config_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + flagfile_filenames = [f"{model}_flagfile_{target}.txt"] + flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) + for f, url in flagfile_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + tuning_filenames = ( + [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] + ) + tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) + for f, url in tuning_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + filenames = [ + *model_config_filenames, + *topology_config_filenames, + *flagfile_filenames, + *tuning_filenames, + ] + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py new file mode 100644 index 000000000..2b954c18b --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -0,0 +1,118 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Configuration objects. + +Parameters that are intrinsic to a specific model. 
+ +Typically represented in something like a Huggingface config.json, +we extend the configuration to enumerate inference boundaries of some given set of compiled modules. +""" + +from dataclasses import dataclass +from pathlib import Path + +from dataclasses_json import dataclass_json, Undefined + +import shortfin.array as sfnp + +str_to_dtype = { + "int8": sfnp.int8, + "float16": sfnp.float16, +} + + +@dataclass_json(undefined=Undefined.RAISE) +@dataclass +class ModelParams: + """Parameters for a specific set of compiled SD submodels, sufficient to do batching / + invocations.""" + + # Maximum length of prompt sequence. + max_seq_len: int + + # Channel dim of latents. + num_latents_channels: int + + # Batch sizes that each stage is compiled for. These are expected to be + # functions exported from the model with suffixes of "_bs{batch_size}". Must + # be in ascending order. + clip_batch_sizes: list[int] + + # Similarly, batch sizes that the decode stage is compiled for. + unet_batch_sizes: list[int] + + # Same for VAE. + vae_batch_sizes: list[int] + + # Same for scheduler. + scheduler_batch_sizes: list[int] + + # Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]] + dims: list[list[int]] + + # Scheduler id. + scheduler_id: str = "EulerDiscrete" + + base_model_name: str = "SDXL" + # Name of the IREE module for each submodel. + clip_module_name: str = "compiled_clip" + unet_module_name: str = "compiled_unet" + vae_module_name: str = "compiled_vae" + scheduler_module_name: str = "compiled_scheduler" + + # some unet vmfbs have "main" as entrypoint. + unet_fn_name: str = "run_forward" + + # Classifer free guidance mode. If set to false, only positive prompts will matter. + cfg_mode = True + + # DTypes (not necessarily weights precision): + clip_dtype: sfnp.DType = sfnp.float16 + unet_dtype: sfnp.DType = sfnp.float16 + vae_dtype: sfnp.DType = sfnp.float16 + + use_i8_punet: bool = False + + # ABI of the module. + module_abi_version: int = 1 + + @property + def max_clip_batch_size(self) -> int: + return self.clip_batch_sizes[-1] + + @property + def max_unet_batch_size(self) -> int: + return self.unet_batch_sizes[-1] + + @property + def max_vae_batch_size(self) -> int: + return self.vae_batch_sizes[-1] + + @property + def all_batch_sizes(self) -> list: + return [self.clip_batch_sizes, self.unet_batch_sizes, self.vae_batch_sizes] + + @property + def max_batch_size(self): + return max(self.all_batch_sizes) + + @staticmethod + def load_json(path: Path | str): + with open(path, "rt") as f: + json_text = f.read() + raw_params = ModelParams.from_json(json_text) + if isinstance(raw_params.unet_dtype, str): + raw_params.unet_dtype = str_to_dtype[raw_params.unet_dtype] + return raw_params + + def __repr__(self): + return ( + f" base model : {self.base_model_name} \n" + f" output size (H,W) : {self.dims} \n" + f" max token sequence length : {self.max_seq_len} \n" + f" classifier free guidance : {self.cfg_mode} \n" + ) diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py new file mode 100644 index 000000000..62ac5e855 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -0,0 +1,102 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import json + +import shortfin as sf + +# TODO: Have a generic "Responder" interface vs just the concrete impl. +from shortfin.interop.fastapi import FastAPIResponder + +from .io_struct import GenerateReqInput +from .messages import InferenceExecRequest +from .service import GenerateService +from .metrics import measure + +logger = logging.getLogger("shortfin-sd.generate") + + +class GenerateImageProcess(sf.Process): + """Process instantiated for every image generation. + + This process breaks the sequence into individual inference and sampling + steps, submitting them to the batcher and marshaling final + results. + + Responsible for a single image. + """ + + def __init__( + self, + client: "ClientGenerateBatchProcess", + gen_req: GenerateReqInput, + index: int, + ): + super().__init__(fiber=client.fiber) + self.client = client + self.gen_req = gen_req + self.index = index + self.result_image = None + + async def run(self): + exec = InferenceExecRequest.from_batch(self.gen_req, self.index) + self.client.batcher.submit(exec) + await exec.done + self.result_image = exec.result_image + + +class ClientGenerateBatchProcess(sf.Process): + """Process instantiated for handling a batch from a client. + + This takes care of several responsibilities: + + * Tokenization + * Random Latents Generation + * Splitting the batch into GenerateImageProcesses + * Streaming responses + * Final responses + """ + + __slots__ = [ + "batcher", + "complete_infeed", + "gen_req", + "responder", + ] + + def __init__( + self, + service: GenerateService, + gen_req: GenerateReqInput, + responder: FastAPIResponder, + ): + super().__init__(fiber=service.fibers[0]) + self.gen_req = gen_req + self.responder = responder + self.batcher = service.batcher + self.complete_infeed = self.system.create_queue() + + async def run(self): + logger.debug("Started ClientBatchGenerateProcess: %r", self) + try: + # Launch all individual generate processes and wait for them to finish. + gen_processes = [] + for index in range(self.gen_req.num_output_images): + gen_process = GenerateImageProcess(self, self.gen_req, index) + gen_processes.append(gen_process) + gen_process.launch() + + await asyncio.gather(*gen_processes) + + # TODO: stream image outputs + logging.debug("Responding to one shot batch") + response_data = {"images": [p.result_image for p in gen_processes]} + json_str = json.dumps(response_data) + self.responder.send_response(json_str) + finally: + self.responder.ensure_response() diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py new file mode 100644 index 000000000..73e77316f --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Optional, Union +from dataclasses import dataclass +import uuid + + +@dataclass +class GenerateReqInput: + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None + # The input negative prompt. It can be a single prompt or a batch of prompts. + neg_prompt: Optional[Union[List[str], str]] = None + # Output image dimensions per prompt. 
+ height: Optional[Union[List[int], int]] = None
+ width: Optional[Union[List[int], int]] = None
+ # The number of inference steps; one int per prompt.
+ steps: Optional[Union[List[int], int]] = None
+ # The classifier-free-guidance scale for denoising; one float per prompt.
+ guidance_scale: Optional[Union[List[float], float]] = None
+ # The seed for random latents generation; one int per prompt.
+ seed: Optional[Union[List[int], int]] = None
+ # Token ids: only used in place of prompt.
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None
+ # Negative token ids: only used in place of negative prompt.
+ neg_input_ids: Optional[Union[List[List[int]], List[int]]] = None
+ # Output image format per image: "PIL" or "base64". Defaults to base64.
+ output_type: Optional[List[str]] = None
+ # The request id.
+ rid: Optional[Union[List[str], str]] = None
+
+ def post_init(self):
+ if (self.prompt is None and self.input_ids is None) or (
+ self.prompt is not None and self.input_ids is not None
+ ):
+ raise ValueError("Exactly one of prompt or input_ids must be provided.")
+
+ if isinstance(self.prompt, str):
+ self.prompt = [self.prompt]
+
+ self.num_output_images = (
+ len(self.prompt) if self.prompt is not None else len(self.input_ids)
+ )
+
+ batchable_args = [
+ self.prompt,
+ self.neg_prompt,
+ self.height,
+ self.width,
+ self.steps,
+ self.guidance_scale,
+ self.seed,
+ self.input_ids,
+ self.neg_input_ids,
+ ]
+ for arg in batchable_args:
+ if isinstance(arg, list):
+ if len(arg) != self.num_output_images and len(arg) != 1:
+ raise ValueError(
+ f"Batchable arguments should either be singular or as many as the full batch ({self.num_output_images})."
+ )
+ if self.rid is None:
+ self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)]
+ else:
+ if not isinstance(self.rid, list):
+ raise ValueError("The rid should be a list.")
+ if self.output_type is None:
+ self.output_type = ["base64"] * self.num_output_images
+ # Temporary restrictions
+ heights = [self.height] if not isinstance(self.height, list) else self.height
+ widths = [self.width] if not isinstance(self.width, list) else self.width
+ if any(dim != 1024 for dim in [*heights, *widths]):
+ raise ValueError(
+ "Currently, only 1024x1024 output image size is supported."
+ )
diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py
new file mode 100644
index 000000000..e416592d0
--- /dev/null
+++ b/shortfin/python/shortfin_apps/sd/components/manager.py
@@ -0,0 +1,48 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import threading + +import shortfin as sf +from shortfin.interop.support.device_setup import get_selected_devices + +logger = logging.getLogger("shortfin-sd.manager") + + +class SystemManager: + def __init__(self, device="local-task", device_ids=None, async_allocs=True): + if any(x in device for x in ["local-task", "cpu"]): + self.ls = sf.host.CPUSystemBuilder().create_system() + elif any(x in device for x in ["hip", "amdgpu"]): + sb = sf.SystemBuilder( + system_type="amdgpu", amdgpu_async_allocations=async_allocs + ) + if device_ids: + sb.visible_devices = sb.available_devices + sb.visible_devices = get_selected_devices(sb, device_ids) + self.ls = sb.create_system() + logger.info(f"Created local system with {self.ls.device_names} devices") + # TODO: Come up with an easier bootstrap thing than manually + # running a thread. + self.t = threading.Thread(target=lambda: self.ls.run(self.run())) + self.command_queue = self.ls.create_queue("command") + self.command_writer = self.command_queue.writer() + + def start(self): + logger.info("Starting system manager") + self.t.start() + + def shutdown(self): + logger.info("Shutting down system manager") + self.command_queue.close() + self.ls.shutdown() + + async def run(self): + reader = self.command_queue.reader() + while command := await reader(): + ... + logger.info("System manager command processor stopped") diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py new file mode 100644 index 000000000..6ae716bad --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/messages.py @@ -0,0 +1,185 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from enum import Enum + +import logging + +import shortfin as sf +import shortfin.array as sfnp + +from .io_struct import GenerateReqInput + +logger = logging.getLogger("shortfin-sd.messages") + + +class InferencePhase(Enum): + # Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays + PREPARE = 1 + # Run CLIP to encode tokenized prompts into text embeddings + ENCODE = 2 + # Run UNet to denoise the random sample + DENOISE = 3 + # Run VAE to decode the denoised latents into an image. + DECODE = 4 + # Postprocess VAE outputs. + POSTPROCESS = 5 + + +class InferenceExecRequest(sf.Message): + """ + Generalized request passed for an individual phase of image generation. + + Used for individual image requests. Bundled as lists by the batcher for inference processes, + and inputs joined for programs with bs>1. + + Inference execution processes are responsible for writing their outputs directly to the appropriate attributes here. 
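+
+ For illustration, a single text-to-image request could be built and submitted
+ roughly like this (hypothetical values; the server normally constructs these
+ via `from_batch` in `GenerateImageProcess` and awaits the result):
+
+ req = InferenceExecRequest(
+ prompt="a photo of a cat", neg_prompt="blurry",
+ height=1024, width=1024, steps=20, guidance_scale=7.5, seed=0,
+ )
+ service.batcher.submit(req)
+ await req.done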
+ """ + + def __init__( + self, + prompt: str | None = None, + neg_prompt: str | None = None, + height: int | None = None, + width: int | None = None, + steps: int | None = None, + guidance_scale: float | sfnp.device_array | None = None, + seed: int | None = None, + input_ids: list[list[int]] | None = None, + sample: sfnp.device_array | None = None, + prompt_embeds: sfnp.device_array | None = None, + text_embeds: sfnp.device_array | None = None, + timesteps: sfnp.device_array | None = None, + time_ids: sfnp.device_array | None = None, + denoised_latents: sfnp.device_array | None = None, + image_array: sfnp.device_array | None = None, + ): + super().__init__() + self.print_debug = True + + self.phases = {} + self.phase = None + self.height = height + self.width = width + + # Phase inputs: + # Prep phase. + self.prompt = prompt + self.neg_prompt = neg_prompt + self.height = height + self.width = width + self.seed = seed + + # Encode phase. + # This is a list of sequenced positive and negative token ids and pooler token ids. + self.input_ids = input_ids + + # Denoise phase. + self.prompt_embeds = prompt_embeds + self.text_embeds = text_embeds + self.sample = sample + self.steps = steps + self.timesteps = timesteps + self.time_ids = time_ids + self.guidance_scale = guidance_scale + + # Decode phase. + self.denoised_latents = denoised_latents + + # Postprocess. + self.image_array = image_array + + self.result_image = None + self.img_metadata = None + + self.done = sf.VoidFuture() + + # Response control. + # Move the result array to the host and sync to ensure data is + # available. + self.return_host_array: bool = True + + self.post_init() + + @staticmethod + def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest": + gen_inputs = [ + "prompt", + "neg_prompt", + "height", + "width", + "steps", + "guidance_scale", + "seed", + ] + rec_inputs = {} + for item in gen_inputs: + received = getattr(gen_req, item, None) + if isinstance(received, list): + if index >= (len(received)): + if len(received) == 1: + rec_input = received[0] + else: + logging.error( + "Inputs in request must be singular or as many as the list of prompts." 
+ )
+ rec_input = None
+ else:
+ rec_input = received[index]
+ else:
+ rec_input = received
+ rec_inputs[item] = rec_input
+ return InferenceExecRequest(**rec_inputs)
+
+ def post_init(self):
+ """Determines necessary inference phases and tags them with static program parameters."""
+ for p in reversed(list(InferencePhase)):
+ required, metadata = self.check_phase(p)
+ p_data = {"required": required, "metadata": metadata}
+ self.phases[p] = p_data
+ if not required:
+ if p not in [
+ InferencePhase.ENCODE,
+ InferencePhase.PREPARE,
+ ]:
+ break
+ self.phase = p
+
+ def check_phase(self, phase: InferencePhase):
+ match phase:
+ case InferencePhase.POSTPROCESS:
+ return True, None
+ case InferencePhase.DECODE:
+ required = not self.image_array
+ meta = [self.width, self.height]
+ return required, meta
+ case InferencePhase.DENOISE:
+ required = not self.denoised_latents
+ meta = [self.width, self.height, self.steps]
+ return required, meta
+ case InferencePhase.ENCODE:
+ p_results = [
+ self.prompt_embeds,
+ self.text_embeds,
+ ]
+ required = any([inp is None for inp in p_results])
+ return required, None
+ case InferencePhase.PREPARE:
+ p_results = [self.sample, self.input_ids]
+ required = any([inp is None for inp in p_results])
+ return required, None
+
+ def reset(self, phase: InferencePhase):
+ """Resets all per-request state in preparation for a subsequent execution."""
+ self.phase = None
+ self.phases = None
+ self.done = sf.VoidFuture()
+ self.return_host_array = True
+
+
+class StrobeMessage(sf.Message):
+ """Sent to strobe a queue with fake activity (generate a wakeup)."""
+
+ ...
diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py
new file mode 100644
index 000000000..62e855698
--- /dev/null
+++ b/shortfin/python/shortfin_apps/sd/components/metrics.py
@@ -0,0 +1,51 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import time +from typing import Any +import functools + +logger = logging.getLogger("shortfin-sd.metrics") + + +def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"): + assert callable(fn) or fn is None + + def _decorator(func): + @functools.wraps(func) + async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any: + start = time.time() + ret = await func(*args, **kwargs) + duration = time.time() - start + if type == "exec": + batch_size = len(getattr(args[0], "exec_requests", [])) + log_duration_str(duration, task=task, batch_size=batch_size) + if type == "throughput": + if isinstance(num_items, str): + items = getattr(args[0].gen_req, num_items) + else: + items = str(num_items) + log_throughput(duration, items, freq, label) + return ret + + return wrapped_fn_async + + return _decorator(fn) if callable(fn) else _decorator + + +def log_throughput(duration, num_items, freq, label) -> str: + sps = str(float(num_items) / duration) * freq + freq_str = "second" if freq == 1 else f"{freq} seconds" + logger.info(f"THROUGHPUT: {sps} {label} per {freq_str}") + + +def log_duration_str(duration: float, task, batch_size=0) -> str: + """Get human readable duration string from start time""" + if batch_size > 0: + task = f"{task} (batch size {batch_size})" + duration_str = f"{round(duration * 1e3)}ms" + logger.info(f"Completed {task} in {duration_str}") diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py new file mode 100644 index 000000000..9b09632a6 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -0,0 +1,695 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import numpy as np +from tqdm.auto import tqdm +from pathlib import Path +from PIL import Image +import base64 + +import shortfin as sf +import shortfin.array as sfnp + +from .config_struct import ModelParams +from .manager import SystemManager +from .messages import InferenceExecRequest, InferencePhase, StrobeMessage +from .tokenizer import Tokenizer +from .metrics import measure + +logger = logging.getLogger("shortfin-sd.service") + +prog_isolations = { + "none": sf.ProgramIsolation.NONE, + "per_fiber": sf.ProgramIsolation.PER_FIBER, + "per_call": sf.ProgramIsolation.PER_CALL, +} + + +class GenerateService: + """Top level service interface for image generation.""" + + inference_programs: dict[str, sf.Program] + + inference_functions: dict[str, dict[str, sf.ProgramFunction]] + + def __init__( + self, + *, + name: str, + sysman: SystemManager, + tokenizers: list[Tokenizer], + model_params: ModelParams, + fibers_per_device: int, + workers_per_device: int = 1, + prog_isolation: str = "per_fiber", + show_progress: bool = False, + trace_execution: bool = False, + ): + self.name = name + + # Application objects. 
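+ # workers_per_device workers are created for every device, and each worker
+ # serves a fixed group of fibers; programs are loaded per worker in start(),
+ # with entrypoints looked up by compiled batch size.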
+ self.sysman = sysman + self.tokenizers = tokenizers + self.model_params = model_params + self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {} + self.inference_modules: dict[str, sf.ProgramModule] = {} + self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {} + self.inference_programs: dict[int, dict[str, sf.Program]] = {} + self.trace_execution = trace_execution + self.show_progress = show_progress + + self.prog_isolation = prog_isolations[prog_isolation] + + self.workers_per_device = workers_per_device + self.fibers_per_device = fibers_per_device + if fibers_per_device % workers_per_device != 0: + raise ValueError( + "Currently, fibers_per_device must be divisible by workers_per_device" + ) + self.fibers_per_worker = int(fibers_per_device / workers_per_device) + + self.workers = [] + self.fibers = [] + self.idle_fibers = set() + for idx, device in enumerate(self.sysman.ls.devices): + for i in range(self.workers_per_device): + worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") + self.workers.append(worker) + for idx, device in enumerate(self.sysman.ls.devices): + for i in range(self.fibers_per_device): + tgt_worker = self.workers[i % len(self.workers)] + fiber = sysman.ls.create_fiber(tgt_worker, devices=[device]) + self.fibers.append(fiber) + self.idle_fibers.add(fiber) + for idx in range(len(self.workers)): + self.inference_programs[idx] = {} + self.inference_functions[idx] = {} + # Scope dependent objects. + self.batcher = BatcherProcess(self) + + def get_worker_index(self, fiber): + if fiber not in self.fibers: + raise ValueError("A worker was requested from a rogue fiber.") + fiber_idx = self.fibers.index(fiber) + worker_idx = int( + (fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker + ) + return worker_idx + + def load_inference_module(self, vmfb_path: Path, component: str = None): + if not self.inference_modules.get(component): + self.inference_modules[component] = [] + self.inference_modules[component].append( + sf.ProgramModule.load(self.sysman.ls, vmfb_path) + ) + + def load_inference_parameters( + self, + *paths: Path, + parameter_scope: str, + format: str = "", + component: str = None, + ): + p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope) + for path in paths: + logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) + p.load(path, format=format) + if not self.inference_parameters.get(component): + self.inference_parameters[component] = [] + self.inference_parameters[component].append(p) + + def start(self): + # Initialize programs. 
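+ # For each component (clip, unet/punet, scheduler, vae), a parameter provider
+ # is prepended to its loaded VMFB modules and an sf.Program is built per
+ # worker on that worker's devices; entrypoints are then registered under the
+ # batch sizes each component was compiled for.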
+ for component in self.inference_modules: + logger.info(f"Loading component: {component}") + component_modules = [ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters.get(component, []) + ), + *self.inference_modules[component], + ] + + for worker_idx, worker in enumerate(self.workers): + worker_devices = self.fibers[ + worker_idx * (self.fibers_per_worker) + ].raw_devices + logger.info( + f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}" + ) + self.inference_programs[worker_idx][component] = sf.Program( + modules=component_modules, + devices=worker_devices, + isolation=self.prog_isolation, + trace_execution=self.trace_execution, + ) + + for worker_idx, worker in enumerate(self.workers): + self.inference_functions[worker_idx]["encode"] = {} + for bs in self.model_params.clip_batch_sizes: + self.inference_functions[worker_idx]["encode"][ + bs + ] = self.inference_programs[worker_idx]["clip"][ + f"{self.model_params.clip_module_name}.encode_prompts" + ] + self.inference_functions[worker_idx]["denoise"] = {} + for bs in self.model_params.unet_batch_sizes: + self.inference_functions[worker_idx]["denoise"][bs] = { + "unet": self.inference_programs[worker_idx]["unet"][ + f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}" + ], + "init": self.inference_programs[worker_idx]["scheduler"][ + f"{self.model_params.scheduler_module_name}.run_initialize" + ], + "scale": self.inference_programs[worker_idx]["scheduler"][ + f"{self.model_params.scheduler_module_name}.run_scale" + ], + "step": self.inference_programs[worker_idx]["scheduler"][ + f"{self.model_params.scheduler_module_name}.run_step" + ], + } + self.inference_functions[worker_idx]["decode"] = {} + for bs in self.model_params.vae_batch_sizes: + self.inference_functions[worker_idx]["decode"][ + bs + ] = self.inference_programs[worker_idx]["vae"][ + f"{self.model_params.vae_module_name}.decode" + ] + self.batcher.launch() + + def shutdown(self): + self.batcher.shutdown() + + def __repr__(self): + modules = [ + f" {key} : {value}" for key, value in self.inference_modules.items() + ] + params = [ + f" {key} : {value}" for key, value in self.inference_parameters.items() + ] + # For python 3.11 since we can't have \ in the f"" expression. + new_line = "\n" + return ( + f"ServiceManager(" + f"\n INFERENCE DEVICES : \n" + f" {self.sysman.ls.devices}\n" + f"\n MODEL PARAMS : \n" + f"{self.model_params}" + f"\n SERVICE PARAMS : \n" + f" fibers per device : {self.fibers_per_device}\n" + f" program isolation mode : {self.prog_isolation}\n" + f"\n INFERENCE MODULES : \n" + f"{new_line.join(modules)}\n" + f"\n INFERENCE PARAMETERS : \n" + f"{new_line.join(params)}\n" + f")" + ) + + +######################################################################################## +# Batcher +######################################################################################## + + +class BatcherProcess(sf.Process): + """The batcher is a persistent process responsible for flighting incoming work + into batches. 
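+
+ Requests accumulate until either the ideal batch size (the largest batch size
+ the models were compiled for) is reached or the background strober has ticked
+ twice; matching requests are then grouped by their phase metadata and boarded
+ onto an idle fiber for execution.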
+ """ + + STROBE_SHORT_DELAY = 0.5 + STROBE_LONG_DELAY = 1 + + def __init__(self, service: GenerateService): + super().__init__(fiber=service.fibers[0]) + self.service = service + self.batcher_infeed = self.system.create_queue() + self.pending_requests: set[InferenceExecRequest] = set() + self.strobe_enabled = True + self.strobes: int = 0 + self.ideal_batch_size: int = max(service.model_params.max_batch_size) + self.num_fibers = len(service.fibers) + + def shutdown(self): + self.batcher_infeed.close() + + def submit(self, request: StrobeMessage | InferenceExecRequest): + self.batcher_infeed.write_nodelay(request) + + async def _background_strober(self): + while not self.batcher_infeed.closed: + await asyncio.sleep( + BatcherProcess.STROBE_SHORT_DELAY + if len(self.pending_requests) > 0 + else BatcherProcess.STROBE_LONG_DELAY + ) + if self.strobe_enabled: + self.submit(StrobeMessage()) + + async def run(self): + strober_task = asyncio.create_task(self._background_strober()) + reader = self.batcher_infeed.reader() + while item := await reader(): + self.strobe_enabled = False + if isinstance(item, InferenceExecRequest): + self.pending_requests.add(item) + elif isinstance(item, StrobeMessage): + self.strobes += 1 + else: + logger.error("Illegal message received by batcher: %r", item) + + self.board_flights() + + self.strobe_enabled = True + await strober_task + + def board_flights(self): + waiting_count = len(self.pending_requests) + if waiting_count == 0: + return + if waiting_count < self.ideal_batch_size and self.strobes < 2: + logger.info("Waiting a bit longer to fill flight") + return + self.strobes = 0 + batches = self.sort_batches() + for batch in batches.values(): + # Assign the batch to the next idle fiber. + if len(self.service.idle_fibers) == 0: + return + fiber = self.service.idle_fibers.pop() + fiber_idx = self.service.fibers.index(fiber) + worker_idx = self.service.get_worker_index(fiber) + logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})") + self.board(batch["reqs"], fiber=fiber) + if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(fiber) + + def sort_batches(self): + """Files pending requests into sorted batches suitable for program invocations.""" + reqs = self.pending_requests + next_key = 0 + batches = {} + for req in reqs: + is_sorted = False + req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()] + + for idx_key, data in batches.items(): + if not isinstance(data, dict): + logger.error( + "Expected to find a dictionary containing a list of requests and their shared metadatas." 
+ ) + if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size: + # Batch is full + next_key = idx_key + 1 + continue + elif data["meta"] == req_metas: + batches[idx_key]["reqs"].extend([req]) + is_sorted = True + break + else: + next_key = idx_key + 1 + if not is_sorted: + batches[next_key] = { + "reqs": [req], + "meta": req_metas, + } + return batches + + def board(self, request_bundle, fiber): + pending = request_bundle + if len(pending) == 0: + return + exec_process = InferenceExecutorProcess(self.service, fiber) + for req in pending: + if len(exec_process.exec_requests) >= self.ideal_batch_size: + break + exec_process.exec_requests.append(req) + if exec_process.exec_requests: + for flighted_request in exec_process.exec_requests: + self.pending_requests.remove(flighted_request) + exec_process.launch() + + +######################################################################################## +# Inference Executors +######################################################################################## + + +class InferenceExecutorProcess(sf.Process): + """Executes a stable diffusion inference batch""" + + def __init__( + self, + service: GenerateService, + fiber, + ): + super().__init__(fiber=fiber) + self.service = service + self.worker_index = self.service.get_worker_index(fiber) + self.exec_requests: list[InferenceExecRequest] = [] + + @measure(type="exec", task="inference process") + async def run(self): + try: + phase = None + for req in self.exec_requests: + if phase: + if phase != req.phase: + logger.error("Executor process recieved disjoint batch.") + phase = req.phase + phases = self.exec_requests[0].phases + req_count = len(self.exec_requests) + device0 = self.fiber.device(0) + if phases[InferencePhase.PREPARE]["required"]: + await self._prepare(device=device0, requests=self.exec_requests) + if phases[InferencePhase.ENCODE]["required"]: + await self._encode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DENOISE]["required"]: + await self._denoise(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DECODE]["required"]: + await self._decode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.POSTPROCESS]["required"]: + await self._postprocess(device=device0, requests=self.exec_requests) + await device0 + for i in range(req_count): + req = self.exec_requests[i] + req.done.set_success() + if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(self.fiber) + + except Exception: + logger.exception("Fatal error in image generation") + # TODO: Cancel and set error correctly + for req in self.exec_requests: + req.done.set_success() + + async def _prepare(self, device, requests): + for request in requests: + # Tokenize prompts and negative prompts. We tokenize in bs1 for now and join later. + input_ids_list = [] + neg_ids_list = [] + for tokenizer in self.service.tokenizers: + input_ids = tokenizer.encode(request.prompt) + input_ids_list.append(input_ids) + neg_ids = tokenizer.encode(request.neg_prompt) + neg_ids_list.append(neg_ids) + ids_list = [*input_ids_list, *neg_ids_list] + + request.input_ids = ids_list + + # Generate random sample latents. + seed = request.seed + channels = self.service.model_params.num_latents_channels + unet_dtype = self.service.model_params.unet_dtype + latents_shape = ( + 1, + channels, + request.height // 8, + request.width // 8, + ) + + # Create and populate sample device array. 
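+ # The latents are staged in a host transfer buffer: the mapping is
+ # zero-filled, populated with seeded Gaussian noise via sfnp.fill_randn,
+ # and copied to the device array, so a given seed reproduces the same
+ # starting latents.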
+ generator = sfnp.RandomGenerator(seed) + request.sample = sfnp.device_array.for_device( + device, latents_shape, unet_dtype + ) + + sample_host = request.sample.for_transfer() + with sample_host.map(discard=True) as m: + m.fill(bytes(1)) + + sfnp.fill_randn(sample_host, generator=generator) + + request.sample.copy_from(sample_host) + return + + async def _encode(self, device, requests): + req_bs = len(requests) + entrypoints = self.service.inference_functions[self.worker_index]["encode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._encode(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + + # Prepare tokenized input ids for CLIP inference + + clip_inputs = [ + sfnp.device_array.for_device( + device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64 + ), + sfnp.device_array.for_device( + device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64 + ), + sfnp.device_array.for_device( + device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64 + ), + sfnp.device_array.for_device( + device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64 + ), + ] + host_arrs = [None] * 4 + for idx, arr in enumerate(clip_inputs): + host_arrs[idx] = arr.for_transfer() + for i in range(req_bs): + with host_arrs[idx].view(i).map(write=True, discard=True) as m: + + # TODO: fix this attr redundancy + np_arr = requests[i].input_ids[idx].input_ids + + m.fill(np_arr) + clip_inputs[idx].copy_from(host_arrs[idx]) + + # Encode tokenized inputs. + logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), + ) + await device + pe, te = await fn(*clip_inputs, fiber=self.fiber) + + for i in range(req_bs): + cfg_mult = 2 + requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + requests[i].text_embeds = te.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + + return + + async def _denoise(self, device, requests): + req_bs = len(requests) + step_count = requests[0].steps + cfg_mult = 2 if self.service.model_params.cfg_mode else 1 + # Produce denoised latents + entrypoints = self.service.inference_functions[self.worker_index]["denoise"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._denoise(device, [request]) + return + for bs, fns in entrypoints.items(): + if bs == req_bs: + break + + # Get shape of batched latents. + # This assumes all requests are dense at this point. + latents_shape = [ + req_bs, + self.service.model_params.num_latents_channels, + requests[0].height // 8, + requests[0].width // 8, + ] + # Assume we are doing classifier-free guidance + hidden_states_shape = [ + req_bs * cfg_mult, + self.service.model_params.max_seq_len, + 2048, + ] + text_embeds_shape = [ + req_bs * cfg_mult, + 1280, + ] + denoise_inputs = { + "sample": sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.unet_dtype + ), + "encoder_hidden_states": sfnp.device_array.for_device( + device, hidden_states_shape, self.service.model_params.unet_dtype + ), + "text_embeds": sfnp.device_array.for_device( + device, text_embeds_shape, self.service.model_params.unet_dtype + ), + "guidance_scale": sfnp.device_array.for_device( + device, [req_bs], self.service.model_params.unet_dtype + ), + } + + # Send guidance scale to device. 
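+ # Each request contributes one float16 guidance scale entry; the staged host
+ # buffer is copied to the device after the per-request sample, prompt
+ # embedding, and text embedding slices are batched below.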
+ gs_host = denoise_inputs["guidance_scale"].for_transfer() + for i in range(req_bs): + cfg_dim = i * cfg_mult + with gs_host.view(i).map(write=True, discard=True) as m: + # TODO: do this without numpy + np_arr = np.asarray(requests[i].guidance_scale, dtype="float16") + + m.fill(np_arr) + # Batch sample latent inputs on device. + req_samp = requests[i].sample + denoise_inputs["sample"].view(i).copy_from(req_samp) + + # Batch CLIP hidden states. + enc = requests[i].prompt_embeds + denoise_inputs["encoder_hidden_states"].view( + slice(cfg_dim, cfg_dim + cfg_mult) + ).copy_from(enc) + + # Batch CLIP text embeds. + temb = requests[i].text_embeds + denoise_inputs["text_embeds"].view( + slice(cfg_dim, cfg_dim + cfg_mult) + ).copy_from(temb) + + denoise_inputs["guidance_scale"].copy_from(gs_host) + + num_steps = sfnp.device_array.for_device(device, [1], sfnp.sint64) + ns_host = num_steps.for_transfer() + with ns_host.map(write=True) as m: + ns_host.items = [step_count] + num_steps.copy_from(ns_host) + + init_inputs = [ + denoise_inputs["sample"], + num_steps, + ] + + # Initialize scheduler. + logger.debug( + "INVOKE %r", + fns["init"], + ) + (latents, time_ids, timesteps, sigmas) = await fns["init"]( + *init_inputs, fiber=self.fiber + ) + for i, t in tqdm( + enumerate(range(step_count)), + disable=(not self.service.show_progress), + desc=f"DENOISE (bs{req_bs})", + ): + step = sfnp.device_array.for_device(device, [1], sfnp.sint64) + s_host = step.for_transfer() + with s_host.map(write=True) as m: + s_host.items = [i] + step.copy_from(s_host) + scale_inputs = [latents, step, timesteps, sigmas] + logger.debug( + "INVOKE %r", + fns["scale"], + ) + latent_model_input, t, sigma, next_sigma = await fns["scale"]( + *scale_inputs, fiber=self.fiber + ) + await device + + unet_inputs = [ + latent_model_input, + t, + denoise_inputs["encoder_hidden_states"], + denoise_inputs["text_embeds"], + time_ids, + denoise_inputs["guidance_scale"], + ] + logger.debug( + "INVOKE %r", + fns["unet"], + ) + (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber) + + step_inputs = [noise_pred, latents, sigma, next_sigma] + logger.debug( + "INVOKE %r", + fns["step"], + ) + (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber) + latents.copy_from(latent_model_output) + + for idx, req in enumerate(requests): + req.denoised_latents = sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.vae_dtype + ) + req.denoised_latents.copy_from(latents.view(idx)) + return + + async def _decode(self, device, requests): + req_bs = len(requests) + # Decode latents to images + entrypoints = self.service.inference_functions[self.worker_index]["decode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._decode(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + + latents_shape = [ + req_bs, + self.service.model_params.num_latents_channels, + requests[0].height // 8, + requests[0].width // 8, + ] + latents = sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.vae_dtype + ) + for i in range(req_bs): + latents.view(i).copy_from(requests[i].denoised_latents) + + await device + # Decode the denoised latents. 
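+ # The VAE entrypoint returns a batched NCHW image tensor; it is copied into
+ # a float16 host array and reinterpreted with numpy so postprocessing can
+ # permute it to HWC and quantize to uint8.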
+ logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n 0: {latents.shape}"]), + ) + (image,) = await fn(latents, fiber=self.fiber) + + await device + images_shape = [ + req_bs, + 3, + requests[0].height, + requests[0].width, + ] + image_shape = [ + req_bs, + 3, + requests[0].height, + requests[0].width, + ] + images_host = sfnp.device_array.for_host(device, images_shape, sfnp.float16) + images_host.copy_from(image) + await device + for idx, req in enumerate(requests): + image_array = images_host.view(idx).items + dtype = image_array.typecode + if images_host.dtype == sfnp.float16: + dtype = np.float16 + req.image_array = np.frombuffer(image_array, dtype=dtype).reshape( + *image_shape + ) + return + + async def _postprocess(self, device, requests): + # Process output images + for req in requests: + # TODO: reimpl with sfnp + permuted = np.transpose(req.image_array, (0, 2, 3, 1))[0] + cast_image = (permuted * 255).round().astype("uint8") + image_bytes = Image.fromarray(cast_image).tobytes() + + image = base64.b64encode(image_bytes).decode("utf-8") + req.result_image = image + return diff --git a/shortfin/python/shortfin_apps/sd/components/tokenizer.py b/shortfin/python/shortfin_apps/sd/components/tokenizer.py new file mode 100644 index 000000000..5903d89a5 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/tokenizer.py @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from transformers import CLIPTokenizer, BatchEncoding + +import shortfin as sf +import shortfin.array as sfnp + + +class Tokenizer: + def __init__( + self, + raw_tk: CLIPTokenizer, + max_length: int = 64, + pad_id: int = 0, + attn_mask=False, + ): + self.pad_id = pad_id + self._raw = raw_tk + self.max_length = 64 + self.return_attention_mask = attn_mask + + @staticmethod + def from_pretrained(name: str, subfolder: str) -> "Tokenizer": + raw_tk = CLIPTokenizer.from_pretrained(name, subfolder=subfolder) + return Tokenizer(raw_tk) + + def encode(self, texts: list[str]): + """Encodes a batch of texts, applying no padding.""" + return self._raw( + texts, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_tensors="np", + return_attention_mask=self.return_attention_mask, + ) + + def encoding_length(self, enc: BatchEncoding) -> int: + """Gets the length of an encoding.""" + return len(enc.input_ids) + + def encodings_to_array( + self, + device: sf.ScopedDevice, + encs: dict[str, BatchEncoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + """Creates a device_array with the contents of a batch of encodings. + + It is expected that the user has called post_process_encodings with + the same batch_seq_len in order to properly truncate/pad. 
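+
+ For example (hypothetical batch), `encodings_to_array(device,
+ tokenizer.encode(["a cat", "a dog"]), batch_seq_len=64)` returns a host
+ array of shape [2, 64] holding the padded int32 token ids.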
+ """ + ary = sfnp.device_array.for_host( + device, [len(encs.input_ids), batch_seq_len], dtype + ) + for i, ids in enumerate(encs.input_ids): + ary.view(i).items = ids + return ary + + def attention_masks_to_array( + self, + device: sf.ScopedDevice, + encs: list[BatchEncoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + ary = sfnp.device_array.for_host( + device, [len(encs.attention_mask), batch_seq_len], dtype + ) + for i, enc in enumerate(encs.attention_mask): + ary.view(i).items = enc + return ary diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json new file mode 100644 index 000000000..2e03e4603 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json @@ -0,0 +1,24 @@ +{ + "max_seq_len": 64, + "num_latents_channels": 4, + "clip_batch_sizes": [ + 1 + ], + "unet_batch_sizes": [ + 1 + ], + "vae_batch_sizes": [ + 1 + ], + "scheduler_batch_sizes": [ + 1 + ], + "unet_module_name": "compiled_unet", + "unet_fn_name": "run_forward", + "dims": [ + [ + 1024, + 1024 + ] + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json new file mode 100644 index 000000000..804947d8f --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json @@ -0,0 +1,26 @@ +{ + "max_seq_len": 64, + "num_latents_channels": 4, + "clip_batch_sizes": [ + 1 + ], + "unet_batch_sizes": [ + 1 + ], + "vae_batch_sizes": [ + 1 + ], + "scheduler_batch_sizes": [ + 1 + ], + "unet_dtype": "float16", + "unet_module_name": "compiled_punet", + "unet_fn_name": "main", + "use_i8_punet": true, + "dims": [ + [ + 1024, + 1024 + ] + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt b/shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt new file mode 100644 index 000000000..731bc6da0 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt @@ -0,0 +1,23 @@ +all +--iree-hal-target-backends=rocm +--iree-hip-target=gfx942 +--iree-execution-model=async-external +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +--iree-global-opt-propagate-transposes=1 +--iree-opt-const-eval=0 +--iree-opt-outer-dim-concat=1 +--iree-opt-aggressively-propagate-transposes=1 +--iree-dispatch-creation-enable-aggressive-fusion +--iree-hal-force-indirect-command-buffers +--iree-codegen-llvmgpu-use-vector-distribution=1 +--iree-llvmgpu-enable-prefetch=1 +--iree-codegen-gpu-native-math-precision=1 +--iree-hip-legacy-sync=0 +--iree-opt-data-tiling=0 +--iree-vm-target-truncate-unsupported-floats +clip +unet +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 +vae +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 +scheduler diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request.json new file mode 100644 index 000000000..bf3d3ae26 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request.json @@ -0,0 +1,29 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": [ + "Watermark, blurry, oversaturated, low resolution, 
pollution" + ], + "height": [ + 1024 + ], + "width": [ + 1024 + ], + "steps": [ + 20 + ], + "guidance_scale": [ + 7.5 + ], + "seed": [ + 0 + ], + "output_type": [ + "base64" + ], + "rid": [ + "string" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json new file mode 100644 index 000000000..0ded22888 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs2.json @@ -0,0 +1,18 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 7.5, + 7.9 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json new file mode 100644 index 000000000..002f43f0e --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json @@ -0,0 +1,57 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium 
shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo" + ], + "neg_prompt": [ + "Watermark, blurry, oversaturated, low resolution, pollution" + ], + "height": [ + 1024 + ], + "width": [ + 1024 + ], + "steps": [ + 20 + ], + "guidance_scale": [ + 7.5 + ], + "seed": [ + 0 + ], + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json new file mode 100644 index 000000000..b59887b8f --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs4.json @@ -0,0 +1,22 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a dog under the snow with brown eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 10, + 10, + 10, + 10 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json new file mode 100644 index 000000000..394e3568e --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs8.json @@ -0,0 +1,33 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional 
photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo" + ], + "neg_prompt": [ + "Watermark, blurry, oversaturated, low resolution, pollution" + ], + "height": [ + 1024 + ], + "width": [ + 1024 + ], + "steps": [ + 20 + ], + "guidance_scale": [ + 7.5 + ], + "seed": [ + 0 + ], + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py new file mode 100644 index 000000000..4e3835690 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -0,0 +1,421 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any +import argparse +import logging +from pathlib import Path +import sys +import os +import copy +import subprocess +from contextlib import asynccontextmanager +import uvicorn + +# Import first as it does dep checking and reporting. +from shortfin.interop.fastapi import FastAPIResponder +from shortfin.support.logging_setup import native_handler + +from fastapi import FastAPI, Request, Response + +from .components.generate import ClientGenerateBatchProcess +from .components.config_struct import ModelParams +from .components.io_struct import GenerateReqInput +from .components.manager import SystemManager +from .components.service import GenerateService +from .components.tokenizer import Tokenizer + + +logger = logging.getLogger("shortfin-sd") +logger.addHandler(native_handler) +logger.propagate = False + +THIS_DIR = Path(__file__).resolve().parent + +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "format": "[{asctime}] {message}", + "datefmt": "%Y-%m-%d %H:%M:%S", + "style": "{", + "use_colors": True, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + + +@asynccontextmanager +async def lifespan(app: FastAPI): + sysman.start() + try: + for service_name, service in services.items(): + logger.info("Initializing service '%s':", service_name) + logger.info(str(service)) + service.start() + except: + sysman.shutdown() + raise + yield + try: + for service_name, service in services.items(): + logger.info("Shutting down service '%s'", service_name) + service.shutdown() + finally: + sysman.shutdown() + + +sysman: SystemManager +services: dict[str, Any] = {} +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +async def health() -> Response: + return Response(status_code=200) + + +async def generate_request(gen_req: GenerateReqInput, request: Request): + service = services["sd"] + gen_req.post_init() + responder = FastAPIResponder(request) + ClientGenerateBatchProcess(service, gen_req, responder).launch() + return await responder.response + + +app.post("/generate")(generate_request) +app.put("/generate")(generate_request) + + +def configure_sys(args) -> SystemManager: + # Setup system (configure devices, etc). 
+ model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) + sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) + return sysman, model_config, flagfile, tuning_spec + + +def configure_service(args, sysman, model_config, flagfile, tuning_spec): + # Setup each service we are hosting. + tokenizers = [] + for idx, tok_name in enumerate(args.tokenizers): + subfolder = f"tokenizer_{idx + 1}" if idx > 0 else "tokenizer" + tokenizers.append(Tokenizer.from_pretrained(tok_name, subfolder)) + + model_params = ModelParams.load_json(model_config) + vmfbs, params = get_modules(args, model_config, flagfile, tuning_spec) + + sm = GenerateService( + name="sd", + sysman=sysman, + tokenizers=tokenizers, + model_params=model_params, + fibers_per_device=args.fibers_per_device, + workers_per_device=args.workers_per_device, + prog_isolation=args.isolation, + show_progress=args.show_progress, + trace_execution=args.trace_execution, + ) + for key, vmfblist in vmfbs.items(): + for vmfb in vmfblist: + sm.load_inference_module(vmfb, component=key) + for key, datasets in params.items(): + sm.load_inference_parameters(*datasets, parameter_scope="model", component=key) + services[sm.name] = sm + return sysman + + +def get_configs(args): + # Returns one set of config artifacts. + modelname = "sdxl" + model_config = args.model_config if args.model_config else None + topology_config = None + tuning_spec = None + flagfile = args.flagfile if args.flagfile else None + topology_inp = args.topology if args.topology else "spx_single" + cfg_builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "config_artifacts.py"), + f"--target={args.target}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--topology={topology_inp}", + ] + outs = subprocess.check_output(cfg_builder_args).decode() + outs_paths = outs.splitlines() + for i in outs_paths: + if "sdxl_config" in i and not args.model_config: + model_config = i + elif "topology" in i and args.topology: + topology_config = i + elif "flagfile" in i and not args.flagfile: + flagfile = i + elif "attention_and_matmul_spec" in i and args.use_tuned: + tuning_spec = i + + if args.use_tuned and args.tuning_spec: + tuning_spec = os.path.abspath(args.tuning_spec) + + if topology_config: + with open(topology_config, "r") as f: + contents = [line.rstrip() for line in f] + for spec in contents: + if "--" in spec: + arglist = spec.strip("--").split("=") + arg = arglist[0] + if len(arglist) > 2: + value = arglist[1:] + for val in value: + try: + val = int(val) + except ValueError: + val = val + elif len(arglist) == 2: + value = arglist[-1] + try: + value = int(value) + except ValueError: + value = value + else: + # It's a boolean arg. + value = True + setattr(args, arg, value) + else: + # It's an env var. 
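+                # Topology entries that do not contain "--" are KEY=VALUE pairs and
+                # are exported into the process environment rather than set on args.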
+ arglist = spec.split("=") + os.environ[arglist[0]] = arglist[1] + return model_config, topology_config, flagfile, tuning_spec, args + + +def get_modules(args, model_config, flagfile, td_spec): + # TODO: Move this out of server entrypoint + vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []} + params = {"clip": [], "unet": [], "vae": []} + model_flags = copy.deepcopy(vmfbs) + model_flags["all"] = args.compile_flags + + if flagfile: + with open(flagfile, "r") as f: + contents = [line.rstrip() for line in f] + flagged_model = "all" + for elem in contents: + match = [keyw in elem for keyw in model_flags.keys()] + if any(match): + flagged_model = elem + else: + model_flags[flagged_model].extend([elem]) + if td_spec: + model_flags["unet"].extend( + [f"--iree-codegen-transform-dialect-library={td_spec}"] + ) + + filenames = [] + for modelname in vmfbs.keys(): + ireec_args = model_flags["all"] + model_flags[modelname] + ireec_extra_args = " ".join(ireec_args) + builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "builders.py"), + f"--model-json={model_config}", + f"--target={args.target}", + f"--splat={args.splat}", + f"--build-preference={args.build_preference}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--iree-hal-target-device={args.device}", + f"--iree-hip-target={args.target}", + f"--iree-compile-extra-args={ireec_extra_args}", + ] + logger.info(f"Preparing runtime artifacts for {modelname}...") + logger.debug( + "COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]) + ) + output = subprocess.check_output(builder_args).decode() + + output_paths = output.splitlines() + filenames.extend(output_paths) + for name in filenames: + for key in vmfbs.keys(): + if key in name.lower(): + if any(x in name for x in [".irpa", ".safetensors", ".gguf"]): + params[key].extend([name]) + elif "vmfb" in name: + vmfbs[key].extend([name]) + return vmfbs, params + + +def main(argv, log_config=UVICORN_LOG_CONFIG): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" + ) + parser.add_argument( + "--device", + type=str, + required=True, + choices=["local-task", "hip", "amdgpu"], + help="Primary inferencing device", + ) + parser.add_argument( + "--target", + type=str, + required=False, + default="gfx942", + choices=["gfx942", "gfx1100", "gfx90a"], + help="Primary inferencing device LLVM target arch.", + ) + parser.add_argument( + "--device_ids", + type=str, + nargs="*", + default=None, + help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0", + ) + parser.add_argument( + "--tokenizers", + type=Path, + nargs="*", + default=[ + "stabilityai/stable-diffusion-xl-base-1.0", + "stabilityai/stable-diffusion-xl-base-1.0", + ], + help="HF repo from which to load tokenizer(s).", + ) + parser.add_argument( + "--model_config", + type=Path, + help="Path to the model config file. 
If None, defaults to i8 punet, batch size 1", + ) + parser.add_argument( + "--workers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--fibers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--isolation", + type=str, + default="per_call", + choices=["per_fiber", "per_call", "none"], + help="Concurrency control -- How to isolate programs.", + ) + parser.add_argument( + "--show_progress", + action="store_true", + help="enable tqdm progress for unet iterations.", + ) + parser.add_argument( + "--trace_execution", + action="store_true", + help="Enable tracing of program modules.", + ) + parser.add_argument( + "--amdgpu_async_allocations", + action="store_true", + help="Enable asynchronous allocations for amdgpu device contexts.", + ) + parser.add_argument( + "--splat", + action="store_true", + help="Use splat (empty) parameter files, usually for testing.", + ) + parser.add_argument( + "--build_preference", + type=str, + choices=["compile", "precompiled"], + default="precompiled", + help="Specify preference for builder artifact generation.", + ) + parser.add_argument( + "--compile_flags", + type=str, + nargs="*", + default=[], + help="extra compile flags for all compile actions. For fine-grained control, use flagfiles.", + ) + parser.add_argument( + "--flagfile", + type=Path, + help="Path to a flagfile to use for SDXL. If not specified, will use latest flagfile from azure.", + ) + parser.add_argument( + "--artifacts_dir", + type=Path, + default=None, + help="Path to local artifacts cache.", + ) + parser.add_argument( + "--tuning_spec", + type=str, + default=None, + help="Path to transform dialect spec if compiling an executable with tunings.", + ) + parser.add_argument( + "--topology", + type=str, + default=None, + choices=["spx_single", "cpx_single", "spx_multi", "cpx_multi"], + help="Use one of four known performant preconfigured device/fiber topologies.", + ) + parser.add_argument( + "--use_tuned", + type=int, + default=1, + help="Use tunings for attention and matmul ops. 0 to disable.", + ) + args = parser.parse_args(argv) + if not args.artifacts_dir: + home = Path.home() + artdir = home / ".cache" / "shark" + args.artifacts_dir = str(artdir) + else: + args.artifacts_dir = Path(args.artifacts_dir).resolve() + + global sysman + sysman, model_config, flagfile, tuning_spec = configure_sys(args) + configure_service(args, sysman, model_config, flagfile, tuning_spec) + uvicorn.run( + app, + host=args.host, + port=args.port, + log_config=log_config, + timeout_keep_alive=args.timeout_keep_alive, + ) + + +if __name__ == "__main__": + logging.root.setLevel(logging.INFO) + main( + sys.argv[1:], + # Make logging defer to the default shortfin logging config. + log_config=UVICORN_LOG_CONFIG, + ) diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py new file mode 100644 index 000000000..0d88a59c7 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -0,0 +1,246 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from datetime import datetime as dt +import os +import sys +import time +import json +import argparse +import base64 +import asyncio +import aiohttp +import requests + +from PIL import Image + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [20], + "guidance_scale": [7.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(in_bytes) + ) + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + im_path = os.path.join(outputdir, f"shortfin_sd_output_{timestamp}_{idx}.png") + image.save(im_path) + print(f"Saved to {im_path}") + + +def get_batched(request, arg, idx): + if isinstance(request[arg], list): + # some args are broadcasted to each prompt, hence overriding idx for single-item entries + if len(request[arg]) == 1: + indexed = request[arg][0] + else: + indexed = request[arg][idx] + else: + indexed = request[arg] + return indexed + + +async def send_request(session, rep, args, data): + print("Sending request batch #", rep) + url = f"{args.host}:{args.port}/generate" + start = time.time() + async with session.post(url, json=data) as response: + end = time.time() + # Check if the response was successful + if response.status == 200: + response.raise_for_status() # Raise an error for bad responses + res_json = await response.json(content_type=None) + if args.save: + for idx, item in enumerate(res_json["images"]): + width = get_batched(data, "width", idx) + height = get_batched(data, "height", idx) + print("Saving response as image...") + bytes_to_img( + item.encode("utf-8"), args.outputdir, idx, width, height + ) + latency = end - start + print("Responses processed.") + return latency, len(data["prompt"]) + print(f"Error: Received {response.status} from server") + raise Exception + + +async def static(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if not args.file: + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + start = time.time() + + async for i in async_range(args.reps): + pending.append(asyncio.create_task(send_request(session, i, args, data))) + await asyncio.sleep(1) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + latencies.append(latency) + sample_counts.append(num_samples) + end = time.time() + if not any(i is None for i in [latencies, sample_counts]): + total_num_samples = sum(sample_counts) + sps = str(total_num_samples / (end - start)) + # Until we have better measurements, don't report the throughput that includes saving images. 
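+            # sps is the total number of generated samples divided by wall-clock time
+            # across all in-flight requests, i.e. end-to-end server throughput rather
+            # than single-request latency.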
+ if not args.save: + print(f"Average throughput: {sps} samples per second") + else: + raise ValueError("Received error response from server.") + + +async def interactive(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if not args.file: + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + while True: + prompt = await ainput("Enter a prompt: ") + data["prompt"] = [prompt] + data["steps"] = [args.steps] + print("Sending request with prompt: ", data["prompt"]) + + async for i in async_range(args.reps): + pending.append( + asyncio.create_task(send_request(session, i, args, data)) + ) + await asyncio.sleep( + 1 + ) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + _, _ = await task + pending = [] + if any(i is None for i in [latencies, sample_counts]): + raise ValueError("Received error response from server.") + + +async def ainput(prompt: str) -> str: + return await asyncio.to_thread(input, f"{prompt} ") + + +async def async_range(count): + for i in range(count): + yield i + await asyncio.sleep(0.0) + + +def check_health(url): + ready = False + print("Waiting for server.", end=None) + while not ready: + try: + if requests.get(f"{url}/health", timeout=20).status_code == 200: + print("Successfully connected to server.") + ready = True + return + time.sleep(2) + print(".", end=None) + except: + time.sleep(2) + print(".", end=None) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument( + "--file", + type=str, + default=None, + help="A non-default request to send to the server.", + ) + p.add_argument( + "--reps", + type=int, + default=1, + help="Number of times to duplicate each request in one second intervals.", + ) + p.add_argument( + "--save", + action=argparse.BooleanOptionalAction, + default=True, + help="Save images. To disable, use --no-save", + ) + p.add_argument( + "--outputdir", + type=str, + default="gen_imgs", + help="Directory to which images get saved.", + ) + p.add_argument( + "--host", type=str, default="http://0.0.0.0", help="Server host address." + ) + p.add_argument("--port", type=str, default="8000", help="Server port") + p.add_argument( + "--steps", + type=int, + default="20", + help="Number of inference steps. More steps usually means a better image. 
Interactive only.", + ) + p.add_argument( + "--interactive", + action="store_true", + help="Start as an example CLI client instead of sending static requests.", + ) + args = p.parse_args() + check_health(f"{args.host}:{args.port}") + if args.interactive: + asyncio.run(interactive(args)) + else: + asyncio.run(static(args)) + + +if __name__ == "__main__": + main() diff --git a/libshortfin/bindings/python/utils.h b/shortfin/python/utils.h similarity index 87% rename from libshortfin/bindings/python/utils.h rename to shortfin/python/utils.h index 86a85441f..90488d3f4 100644 --- a/libshortfin/bindings/python/utils.h +++ b/shortfin/python/utils.h @@ -8,31 +8,31 @@ #include "./lib_ext.h" #include "shortfin/local/device.h" -#include "shortfin/local/scope.h" +#include "shortfin/local/fiber.h" namespace shortfin::python { // Casts any of int, str, local::Device, DeviceAffinity to a DeviceAffinity. // If the object is a sequence, then the affinity is constructed from the union. -inline local::ScopedDevice CastDeviceAffinity(local::Scope& scope, +inline local::ScopedDevice CastDeviceAffinity(local::Fiber& fiber, py::handle object) { if (py::isinstance(object)) { - return scope.device(py::cast(object)); + return fiber.device(py::cast(object)); } else if (py::isinstance(object)) { - return local::ScopedDevice(scope, py::cast(object)); + return local::ScopedDevice(fiber, py::cast(object)); } else if (py::isinstance(object)) { - return scope.device(py::cast(object)); + return fiber.device(py::cast(object)); } else if (py::isinstance(object)) { - return scope.device(py::cast(object)); + return fiber.device(py::cast(object)); } else if (py::isinstance(object)) { // Important: sequence must come after string, since string is a sequence // and this will infinitely recurse (since the first element of the string // is a sequence, etc). local::DeviceAffinity affinity; for (auto item : py::cast(object)) { - affinity |= CastDeviceAffinity(scope, item).affinity(); + affinity |= CastDeviceAffinity(fiber, item).affinity(); } - return local::ScopedDevice(scope, affinity); + return local::ScopedDevice(fiber, affinity); } throw std::invalid_argument(fmt::format("Cannot cast {} to DeviceAffinity", diff --git a/shortfin/requirements-iree-compiler.txt b/shortfin/requirements-iree-compiler.txt new file mode 100644 index 000000000..ada82f2eb --- /dev/null +++ b/shortfin/requirements-iree-compiler.txt @@ -0,0 +1,4 @@ +# Keep in sync with "ref: iree-" in .github/workflows/* and GIT_TAG in CMakeLists.txt +-f https://iree.dev/pip-release-links.html +iree-base-compiler==3.0.0rc20241118 +iree-base-runtime==3.0.0rc20241118 diff --git a/libshortfin/requirements-tests.txt b/shortfin/requirements-tests-nogil.txt similarity index 75% rename from libshortfin/requirements-tests.txt rename to shortfin/requirements-tests-nogil.txt index 1049b0412..1769467ab 100644 --- a/libshortfin/requirements-tests.txt +++ b/shortfin/requirements-tests-nogil.txt @@ -1,4 +1,3 @@ pytest requests -fastapi uvicorn diff --git a/shortfin/requirements-tests.txt b/shortfin/requirements-tests.txt new file mode 100644 index 000000000..c04c97af2 --- /dev/null +++ b/shortfin/requirements-tests.txt @@ -0,0 +1,19 @@ +pytest +requests +fastapi +onnx +uvicorn + +# Libraries needed to build in dev setups. 
+setuptools +wheel + +# Deps needed for shortfin_apps.llm +dataclasses-json +tokenizers +huggingface_hub[cli] +sentencepiece + +# Deps needed for shortfin_apps.sd +pillow +transformers diff --git a/shortfin/requirements.txt b/shortfin/requirements.txt deleted file mode 100644 index 99c132cea..000000000 --- a/shortfin/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -fastapi==0.109.2 -uvicorn==0.27.0 -requests==2.32.0 diff --git a/shortfin/setup.cfg b/shortfin/setup.cfg deleted file mode 100644 index 358360671..000000000 --- a/shortfin/setup.cfg +++ /dev/null @@ -1,6 +0,0 @@ -[tool:pytest] -testpaths = - ./tests -filterwarnings = - # TODO: Remove once flatbuffer 'imp' usage resolved. - ignore::DeprecationWarning diff --git a/shortfin/setup.py b/shortfin/setup.py index 9ce1aa822..cf3762950 100644 --- a/shortfin/setup.py +++ b/shortfin/setup.py @@ -6,103 +6,382 @@ import json import os -import distutils.command.build +import shutil +import subprocess +import sys +import traceback +from distutils.command.build import build as _build +from distutils.core import setup, Extension from pathlib import Path +from setuptools import find_namespace_packages +from setuptools.command.build_py import build_py as _build_py +from setuptools.command.build_ext import build_ext as _build_ext -from setuptools import find_namespace_packages, setup # type: ignore -THIS_DIR = Path(__file__).resolve().parent -REPO_DIR = THIS_DIR.parent -VERSION_INFO_FILE = REPO_DIR / "version_info.json" +def get_env_boolean(name: str, default_value: bool = False) -> bool: + svalue = os.getenv(name) + if svalue is None: + return default_value + svalue = svalue.upper() + if svalue in ["1", "ON", "TRUE"]: + return True + elif svalue in ["0", "OFF", "FALSE"]: + return False + else: + print(f"WARNING: {name} env var cannot be interpreted as a boolean value") + return default_value -with open( - os.path.join( - THIS_DIR, - "README.md", - ), - "rt", -) as f: - README = f.read() +def get_env_cmake_option(name: str, default_value: bool = False) -> str: + svalue = os.getenv(name) + if not svalue: + svalue = "ON" if default_value else "OFF" + return f"-D{name}={svalue}" -def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: - return json.load(f) +def add_env_cmake_setting( + args, env_name: str, cmake_name=None, default_value=None +) -> str: + svalue = os.getenv(env_name) + if svalue is None and default_value is not None: + svalue = default_value + if svalue is not None: + if not cmake_name: + cmake_name = env_name + args.append(f"-D{cmake_name}={svalue}") -version_info = load_version_info() -PACKAGE_VERSION = version_info["package-version"] +def combine_dicts(*ds): + result = {} + for d in ds: + result.update(d) + return result -packages = find_namespace_packages( - include=[ - "shortfin", - "shortfin.*", - ], + +# This file can be generated into the build directory to allow an arbitrary +# CMake built version of the project to be installed into a venv for development. +# This can be detected if the CPP_PREBUILT global contains the string +# "TRUE", which will be the case if generated. 
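+# The @...@ tokens below are placeholders that CMake substitutes when generating
+# this file into the build tree; in a plain source checkout they are left verbatim,
+# so is_cpp_prebuilt() returns False.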
+CPP_PREBUILT = "@SHORTFIN_PYTHON_CPP_PREBUILT@" +CPP_PREBUILT_SOURCE_DIR = "@libshortfin_SOURCE_DIR@" +CPP_PREBUILT_BINARY_DIR = "@libshortfin_BINARY_DIR@" + +SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) +CMAKE_EXE = os.getenv("SHORTFIN_CMAKE", "cmake") + + +def is_cpp_prebuilt(): + return CPP_PREBUILT == "TRUE" + + +DEV_MODE = False +ENABLE_TRACY = get_env_boolean("SHORTFIN_ENABLE_TRACING", False) + +if ENABLE_TRACY: + print( + "*** Enabling Tracy instrumentation (disable with SHORTFIN_ENABLE_TRACING=OFF)", + ) +else: + print( + "*** Tracy instrumentation not enabled (enable with SHORTFIN_ENABLE_TRACING=ON)", + ) + + +if is_cpp_prebuilt(): + print("setup.py running in pre-built mode:") + SOURCE_DIR = Path(CPP_PREBUILT_SOURCE_DIR) + BINARY_DIR = Path(CPP_PREBUILT_BINARY_DIR) + CMAKE_DEFAULT_BUILD_DIR = BINARY_DIR + CMAKE_TRACY_BUILD_DIR = BINARY_DIR +else: + print("setup.py running in cmake build mode:") + # setup.py is in the source directory. + SOURCE_DIR = Path(SETUPPY_DIR) + BINARY_DIR = Path(os.path.join(SETUPPY_DIR, "build")) + # TODO: Should build default and tracing version to different dirs. + CMAKE_DEFAULT_BUILD_DIR = BINARY_DIR / "cmake" / "default" + CMAKE_TRACY_BUILD_DIR = BINARY_DIR / "cmake" / "tracy" + DEV_MODE = get_env_boolean("SHORTFIN_DEV_MODE") + +print(f" SOURCE_DIR = {SOURCE_DIR}") +print(f" BINARY_DIR = {BINARY_DIR}") + +if DEV_MODE: + print(f" DEV MODE ENABLED: Building debug with clang/lld and other dev settings") + +# Due to a quirk of setuptools, that package_dir map must only contain +# paths relative to the directory containing setup.py. Why? No one knows. +REL_SOURCE_DIR = Path(os.path.relpath(SOURCE_DIR, SETUPPY_DIR)) +REL_BINARY_DIR = Path(os.path.relpath(BINARY_DIR, SETUPPY_DIR)) +REL_CMAKE_DEFAULT_BUILD_DIR = Path( + os.path.relpath(CMAKE_DEFAULT_BUILD_DIR, SETUPPY_DIR) ) +REL_CMAKE_TRACY_BUILD_DIR = Path(os.path.relpath(CMAKE_TRACY_BUILD_DIR, SETUPPY_DIR)) + + +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CustomBuild(_build): + def run(self): + self.run_command("build_py") + self.run_command("build_ext") + self.run_command("build_scripts") + + +class NoopBuildExtension(_build_ext): + def build_extension(self, ext): + ... + + def copy_extensions_to_source(self, *args, **kwargs): + ... + + +# Setup and get version information. +VERSION_FILE = os.path.join(REL_SOURCE_DIR, "version.json") +VERSION_FILE_LOCAL = os.path.join(REL_SOURCE_DIR, "version_local.json") + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + -print("Found packages:", packages) +try: + version_info = load_version_info(VERSION_FILE_LOCAL) +except FileNotFoundError: + print("version_local.json not found. Default to dev build") + version_info = load_version_info(VERSION_FILE) -# Lookup version pins from requirements files. -requirement_pins = {} +PACKAGE_VERSION = version_info.get("package-version") +print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") -def load_requirement_pins(requirements_file: Path): - with open(requirements_file, "rt") as f: - lines = f.readlines() - pin_pairs = [line.strip().split("==") for line in lines if "==" in line] - requirement_pins.update(dict(pin_pairs)) +def maybe_nuke_cmake_cache(cmake_build_dir): + # From run to run under pip, we can end up with different paths to ninja, + # which isn't great and will confuse cmake. 
Detect if the location of + # ninja changes and force a cache flush. + ninja_path = "" + try: + import ninja + except ModuleNotFoundError: + pass + else: + ninja_path = ninja.__file__ + expected_stamp_contents = f"{sys.executable}\n{ninja_path}" + + # In order to speed things up on CI and not rebuild everything, we nuke + # the CMakeCache.txt file if the path to the Python interpreter changed. + # Ideally, CMake would let us reconfigure this dynamically... but it does + # not (and gets very confused). + PYTHON_STAMP_FILE = os.path.join(cmake_build_dir, "python_stamp.txt") + if os.path.exists(PYTHON_STAMP_FILE): + with open(PYTHON_STAMP_FILE, "rt") as f: + actual_stamp_contents = f.read() + if actual_stamp_contents == expected_stamp_contents: + # All good. + return + + # Mismatch or not found. Clean it. + cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt") + if os.path.exists(cmake_cache_file): + print("Removing CMakeCache.txt because Python version changed") + os.remove(cmake_cache_file) + + # And write. + with open(PYTHON_STAMP_FILE, "wt") as f: + f.write(expected_stamp_contents) -load_requirement_pins(REPO_DIR / "requirements.txt") +def build_cmake_configuration(CMAKE_BUILD_DIR: Path, extra_cmake_args=[]): + # Build extension using cmake. + cfg = os.getenv("SHORTFIN_CMAKE_BUILD_TYPE", "Debug" if DEV_MODE else "Release") + # Configure CMake. + os.makedirs(CMAKE_BUILD_DIR, exist_ok=True) + if not DEV_MODE: + maybe_nuke_cmake_cache(CMAKE_BUILD_DIR) + print(f"CMake build dir: {CMAKE_BUILD_DIR}") + cmake_args = [ + "-GNinja", + "-Wno-dev", + "--log-level=VERBOSE", + "-DSHORTFIN_BUNDLE_DEPS=ON", + f"-DCMAKE_BUILD_TYPE={cfg}", + "-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON", + f"-DPython3_EXECUTABLE={sys.executable}", + ] + extra_cmake_args -def get_version_spec(dep: str): - if dep in requirement_pins: - return f">={requirement_pins[dep]}" + if DEV_MODE: + if not os.getenv("CC"): + cmake_args.append("-DCMAKE_C_COMPILER=clang") + if not os.getenv("CXX"): + cmake_args.append("-DCMAKE_CXX_COMPILER=clang++") + add_env_cmake_setting(cmake_args, "CMAKE_LINKER_TYPE", default_value="LLD") + + add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_LTO", default_value="ON") + add_env_cmake_setting(cmake_args, "SHORTFIN_IREE_SOURCE_DIR") + add_env_cmake_setting(cmake_args, "SHORTFIN_ENABLE_ASAN") + + # Only do a from-scratch configure if not already configured. + cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt") + if not os.path.exists(cmake_cache_file): + print(f"Configuring with: {cmake_args}") + subprocess.check_call([CMAKE_EXE, SOURCE_DIR] + cmake_args, cwd=CMAKE_BUILD_DIR) + print(f"CMake configure complete.") else: - return "" + print(f"Not re-configing (already configured)") + + # Build. + subprocess.check_call([CMAKE_EXE, "--build", "."], cwd=CMAKE_BUILD_DIR) + print("Build complete.") + + # Optionally run CTests. + if get_env_boolean("SHORTFIN_RUN_CTESTS", False): + print("Running ctests...") + subprocess.check_call( + ["ctest", "--timeout", "30", "--output-on-failure"], + cwd=CMAKE_BUILD_DIR, + ) + + +class CMakeBuildPy(_build_py): + def run(self): + # The super-class handles the pure python build. + super().run() + + # Only build using cmake if not in prebuild mode. 
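+        # In prebuilt mode the native modules were already produced by an external
+        # CMake build and are picked up through package_dir, so nothing is compiled
+        # here.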
+ if is_cpp_prebuilt(): + return + + try: + self.build_default_configuration() + if ENABLE_TRACY: + self.build_tracy_configuration() + except subprocess.CalledProcessError as e: + print("Native build failed:") + traceback.print_exc() + # This is not great, but setuptools *swallows* exceptions from here + # and mis-reports them as deprecation warnings! This is fairly + # fatal, so just kill it. + sys.exit(1) + + def build_default_configuration(self): + print(" *********************************") + print(" * Building base shortfin *") + print(" *********************************") + + build_cmake_configuration(CMAKE_DEFAULT_BUILD_DIR) + + # Copy non-python binaries generated during the build. + target_dir = os.path.join(os.path.abspath(self.build_lib), "_shortfin_default") + + print(f"Building in target: {target_dir}") + os.makedirs(target_dir, exist_ok=True) + print("Copying build to target.") + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + shutil.copytree( + os.path.join( + CMAKE_DEFAULT_BUILD_DIR, + "python", + "_shortfin_default", + ), + target_dir, + symlinks=False, + ) + def build_tracy_configuration(self): + print(" *********************************") + print(" * Building tracy shortfin *") + print(" *********************************") -# Override build command so that we can build into _python_build -# instead of the default "build". This avoids collisions with -# typical CMake incantations, which can produce all kinds of -# hilarity (like including the contents of the build/lib directory). -class BuildCommand(distutils.command.build.build): - def initialize_options(self): - distutils.command.build.build.initialize_options(self) - self.build_base = "_python_build" + build_cmake_configuration( + CMAKE_TRACY_BUILD_DIR, ["-DSHORTFIN_ENABLE_TRACING=ON"] + ) + # Copy non-python binaries generated during the build. + target_dir = os.path.join(os.path.abspath(self.build_lib), "_shortfin_tracy") + + print(f"Building in target: {target_dir}") + os.makedirs(target_dir, exist_ok=True) + print("Copying build to target.") + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + shutil.copytree( + os.path.join( + CMAKE_TRACY_BUILD_DIR, + "python", + "_shortfin_tracy", + ), + target_dir, + symlinks=False, + ) + + +PYTHON_SOURCE_DIR = REL_SOURCE_DIR / "python" +PYTHON_DEFAULT_BINARY_DIR = REL_CMAKE_DEFAULT_BUILD_DIR / "python" +PYTHON_TRACY_BINARY_DIR = REL_CMAKE_TRACY_BUILD_DIR / "python" + + +# We need some directories to exist before setup. +def populate_built_package(abs_dir): + """Makes sure that a directory and __init__.py exist. + + This needs to unfortunately happen before any of the build process + takes place so that setuptools can plan what needs to be built. + We do this for any built packages (vs pure source packages). 
+ """ + os.makedirs(abs_dir, exist_ok=True) + with open(os.path.join(abs_dir, "__init__.py"), "wt"): + pass + + +populate_built_package(os.path.join(PYTHON_DEFAULT_BINARY_DIR / "_shortfin_default")) +if ENABLE_TRACY: + populate_built_package(os.path.join(PYTHON_TRACY_BINARY_DIR / "_shortfin_tracy")) + +packages = find_namespace_packages( + where=os.path.join(SOURCE_DIR, "python"), + include=[ + "_shortfin", + "_shortfin_default", + "shortfin", + "shortfin.*", + "shortfin_apps", + "shortfin_apps.*", + ] + + (["_shortfin_tracy"] if ENABLE_TRACY else []), +) +print(f"Found shortfin packages: {packages}") setup( - name=f"shortfin", version=f"{PACKAGE_VERSION}", - author="SHARK Authors", - author_email="stella@nod.ai", - description="SHARK Shortfin Machine Learning Deployment Tools", - long_description=README, - long_description_content_type="text/markdown", - url="https://github.com/nod-ai/sharktank", - license="Apache-2.0", - classifiers=[ - "Development Status :: 3 - Alpha", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - ], packages=packages, - package_data={"shortfin": ["py.typed"]}, - install_requires=[ - f"sharktank=={PACKAGE_VERSION}", - f"fastapi{get_version_spec('fastapi')}", - f"iree-runtime{get_version_spec('iree-runtime')}", - f"uvicorn{get_version_spec('uvicorn')}", - f"requests{get_version_spec('requests')}", - ], - extras_require={ - "testing": [ - f"pytest{get_version_spec('pytest')}", - f"pytest-xdist{get_version_spec('pytest-xdist')}", - ], + zip_safe=False, + package_dir=combine_dicts( + { + "_shortfin": str(PYTHON_SOURCE_DIR / "_shortfin"), + "_shortfin_default": str(PYTHON_DEFAULT_BINARY_DIR / "_shortfin_default"), + "shortfin": str(PYTHON_SOURCE_DIR / "shortfin"), + "shortfin_apps": str(PYTHON_SOURCE_DIR / "shortfin_apps"), + }, + ( + ({"_shortfin_tracy": str(PYTHON_TRACY_BINARY_DIR / "_shortfin_tracy")}) + if ENABLE_TRACY + else {} + ), + ), + ext_modules=( + [CMakeExtension("_shortfin_default.lib")] + + ([CMakeExtension("_shortfin_tracy.lib")] if ENABLE_TRACY else []) + ), + cmdclass={ + "build": CustomBuild, + "build_ext": NoopBuildExtension, + "build_py": CMakeBuildPy, }, - cmdclass={"build": BuildCommand}, ) diff --git a/shortfin/src/CMakeLists.txt b/shortfin/src/CMakeLists.txt new file mode 100644 index 000000000..53d801f36 --- /dev/null +++ b/shortfin/src/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Any definitions which must be reflected on the public library must be added +# to this library. +add_library(shortfin_public_defs INTERFACE) + +add_subdirectory(shortfin) + +# Common definitions exported from both static and dynamic libraries. 
+add_library(_shortfin_defs INTERFACE) +target_include_directories( + _shortfin_defs INTERFACE $ + $) + + +get_property( + _SHORTFIN_LIB_OPTIONAL_COMPONENTS GLOBAL PROPERTY SHORTFIN_LIB_OPTIONAL_COMPONENTS) + +message(STATUS "Linking optional components '${_SHORTFIN_LIB_OPTIONAL_COMPONENTS}'") +shortfin_public_library( + NAME + shortfin + LINUX_LD_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/shortfin.ld + COMPONENTS + shortfin_array + shortfin_local + shortfin_support + shortfin_systems_factory + ${_SHORTFIN_LIB_OPTIONAL_COMPONENTS} + USAGE_DEPS + shortfin_public_defs + spdlog::spdlog + fmt::fmt + xtensor + xtl + iree_defs +) diff --git a/shortfin/src/shortfin.ld b/shortfin/src/shortfin.ld new file mode 100644 index 000000000..5e62ca700 --- /dev/null +++ b/shortfin/src/shortfin.ld @@ -0,0 +1,4 @@ +SHORTFIN_API_3 { + /* Generally source level annotations are used. Exceptions only here. */ + global: *; +}; diff --git a/shortfin/src/shortfin/CMakeLists.txt b/shortfin/src/shortfin/CMakeLists.txt new file mode 100644 index 000000000..058e0e336 --- /dev/null +++ b/shortfin/src/shortfin/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_subdirectory(array) +add_subdirectory(local) +add_subdirectory(support) diff --git a/libshortfin/src/shortfin/array/CMakeLists.txt b/shortfin/src/shortfin/array/CMakeLists.txt similarity index 70% rename from libshortfin/src/shortfin/array/CMakeLists.txt rename to shortfin/src/shortfin/array/CMakeLists.txt index d40eed23f..48ab33590 100644 --- a/libshortfin/src/shortfin/array/CMakeLists.txt +++ b/shortfin/src/shortfin/array/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception shortfin_cc_component( NAME diff --git a/libshortfin/src/shortfin/array/api.h b/shortfin/src/shortfin/array/api.h similarity index 100% rename from libshortfin/src/shortfin/array/api.h rename to shortfin/src/shortfin/array/api.h diff --git a/libshortfin/src/shortfin/array/array.cc b/shortfin/src/shortfin/array/array.cc similarity index 56% rename from libshortfin/src/shortfin/array/array.cc rename to shortfin/src/shortfin/array/array.cc index 9a0b22bf2..882e4ef39 100644 --- a/libshortfin/src/shortfin/array/array.cc +++ b/shortfin/src/shortfin/array/array.cc @@ -11,11 +11,35 @@ #include "fmt/core.h" #include "fmt/ranges.h" #include "shortfin/array/xtensor_bridge.h" +#include "shortfin/support/logging.h" namespace shortfin::array { template class InlinedDims; +// -------------------------------------------------------------------------- // +// base_array +// -------------------------------------------------------------------------- // + +void base_array::expand_dims(Dims::value_type axis) { + auto shape = this->shape(); + if (axis > shape.size()) { + throw std::invalid_argument( + fmt::format("expand_dims axis must be <= rank ({}) but was {}", + shape.size(), axis)); + } + Dims new_dims(shape.size() + 1); + size_t j = 0; + for (size_t i = 0; i < axis; ++i) { + new_dims[j++] = shape[i]; + } + new_dims[j++] = 1; + for (size_t i = axis; i < shape.size(); ++i) { + new_dims[j++] = shape[i]; + } + set_shape(new_dims.span()); +} + // -------------------------------------------------------------------------- // // device_array // -------------------------------------------------------------------------- // @@ -40,6 +64,7 @@ mapping device_array::data_rw() { return storage_.map_read_write(); } mapping device_array::data_w() { return storage_.map_write_discard(); } std::optional device_array::map_memory_for_xtensor() { + SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::map_memory_for_xtensor"); if (storage_.is_mappable_for_read_write()) { return storage_.map_read_write(); } else if (storage_.is_mappable_for_read()) { @@ -73,6 +98,7 @@ std::string device_array::to_s() const { void device_array::AddAsInvocationArgument( local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::AddAsInvocationArgument"); auto dims_span = shape(); iree_hal_buffer_view_t *buffer_view; SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_view_create( @@ -93,6 +119,7 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() { device_array device_array::CreateFromInvocationResultRef( local::ProgramInvocation *inv, iree::vm_opaque_ref ref) { + SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef"); // We don't retain the buffer view in the device array, so just deref it // vs stealing the ref. iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get()); @@ -108,4 +135,60 @@ device_array device_array::CreateFromInvocationResultRef( DType::import_element_type(iree_hal_buffer_view_element_type(bv))); } +device_array device_array::view(Dims &offsets, Dims &sizes) { + auto rank = shape().size(); + if (offsets.size() != sizes.size() || offsets.empty() || + offsets.size() > rank) { + throw std::invalid_argument( + "view offsets and sizes must be of equal size and be of a rank " + "<= the array rank"); + } + if (rank == 0) { + throw std::invalid_argument("view cannot operate on rank 0 arrays"); + } + // Compute row strides. 
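+  // For a dense, row-major layout the byte stride of dim i is
+  // dtype().dense_byte_count() * product(shape()[i+1:]); the loop below builds this
+  // by accumulating from the innermost dimension outward.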
+ Dims row_stride_bytes(shape().size()); + iree_device_size_t accum = dtype().dense_byte_count(); + for (int i = rank - 1; i >= 0; --i) { + row_stride_bytes[i] = accum; + accum *= shape()[i]; + } + + Dims new_dims(shape_container()); + bool has_stride = false; + iree_device_size_t start_offset = 0; + iree_device_size_t span_size = storage().byte_length(); + for (size_t i = 0; i < offsets.size(); ++i) { + auto row_stride = row_stride_bytes[i]; + auto dim_size = shape()[i]; + auto slice_offset = offsets[i]; + auto slice_size = sizes[i]; + if (slice_offset >= dim_size || (slice_offset + slice_size) > dim_size) { + throw std::invalid_argument( + fmt::format("Cannot index ({}:{}) into dim size {} at position {}", + slice_offset, slice_size, dim_size, i)); + } + if (has_stride && (slice_offset > 0 || slice_size != dim_size)) { + throw std::invalid_argument( + fmt::format("Cannot create a view with dimensions following a " + "spanning dim (at position {})", + i)); + } + if (slice_size > 1) { + has_stride = true; + } + + // Since we are only narrowing a dense, row major slice, as we traverse + // the dims, we are narrowing the memory view at each step by advancing + // the beginning based on the requested offset and pulling the end in + // by the difference in size. + start_offset += row_stride * slice_offset; + span_size -= row_stride * (new_dims[i] - slice_size); + new_dims[i] = slice_size; + } + + return device_array(storage().subspan(start_offset, span_size), + new_dims.span(), dtype()); +} + } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/array.h b/shortfin/src/shortfin/array/array.h similarity index 67% rename from libshortfin/src/shortfin/array/array.h rename to shortfin/src/shortfin/array/array.h index 875155757..8af584849 100644 --- a/libshortfin/src/shortfin/array/array.h +++ b/shortfin/src/shortfin/array/array.h @@ -21,6 +21,26 @@ namespace shortfin::array { +// Wraps an owned mapping and an adapted xtensor together, ensuring that the +// mapping remains live for the duration of the tensor. This presents as a +// smart pointer or iterator in that you dereference the tensor via `*` or +// `->`. +template +struct mapped_xtensor_holder { + public: + using xtensor_type = + decltype(xt::adapt(static_cast(nullptr), Dims())); + mapped_xtensor_holder(class mapping mapping, xtensor_type t) + : mapping(std::move(mapping)), t(std::move(t)) {} + + xtensor_type &operator*() { return t; } + xtensor_type &operator->() { return t; } + + private: + class mapping mapping; + xtensor_type t; +}; + // Either a host or device nd-array view. class SHORTFIN_API base_array { public: @@ -49,6 +69,9 @@ class SHORTFIN_API base_array { Dims &shape_container() { return shape_; } const Dims &shape_container() const { return shape_; } + // Inserts a unit dim at axis, which must be <= rank. + void expand_dims(Dims::value_type axis); + private: DType dtype_; Dims shape_; @@ -78,9 +101,10 @@ class SHORTFIN_API device_array // arrays that are visible from different combinations of host/device. static device_array for_host(local::ScopedDevice &device, std::span shape, - DType dtype) { + DType dtype, bool device_visible = true) { return device_array( - storage::allocate_host(device, dtype.compute_dense_nd_size(shape)), + storage::allocate_host(device, dtype.compute_dense_nd_size(shape), + device_visible), shape, dtype); } @@ -125,21 +149,63 @@ class SHORTFIN_API device_array // Typed access to the backing data. 
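+  // Each typed accessor first checks that sizeof(EltTy) matches the dtype's dense
+  // byte count (DType::AssertCompatibleSize), throwing std::invalid_argument on a
+  // mismatch instead of silently reinterpreting memory with the wrong element size.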
template typed_mapping typed_data() { + dtype().AssertCompatibleSize(); return typed_mapping(data()); } template typed_mapping typed_data() const { + dtype().AssertCompatibleSize(); return typed_mapping(data()); } template typed_mapping typed_data_rw() { + dtype().AssertCompatibleSize(); return typed_mapping(data_rw()); } template typed_mapping typed_data_w() { + dtype().AssertCompatibleSize(); return typed_mapping(data_w()); } + // Maps a read-only xtensor for the given EltTy (which must be compatible with + // the array dtype). The returned holder maintains the mapping and the + // xtensor, allowing the xtensor to be dereferenced via `*` or `->` like a + // pointer. + template + auto map_xtensor() { + dtype().AssertCompatibleSize(); + auto m = data(); + auto *data = static_cast(static_cast((m.data()))); + return mapped_xtensor_holder( + std::move(m), xt::adapt(static_cast(data), shape_container())); + } + + // Same as `map_xtensor()` but maps read-write. + template + auto map_xtensor_rw() { + dtype().AssertCompatibleSize(); + auto m = data_rw(); + auto *data = static_cast(static_cast((m.data()))); + return mapped_xtensor_holder( + std::move(m), xt::adapt(static_cast(data), shape_container())); + } + + // Same as `map_xtensor()` but maps write-only. + template + auto map_xtensor_w() { + dtype().AssertCompatibleSize(); + auto m = data_w(); + auto *data = static_cast(static_cast((m.data()))); + return mapped_xtensor_holder( + std::move(m), xt::adapt(static_cast(data), shape_container())); + } + + // Creates a device array which aliases the backing storage by slicing. Only + // slice shapes that produce a dense view without strides are supported by + // this mechanism. + device_array view(Dims &indices, Dims &sizes); + std::string to_s() const override; protected: diff --git a/libshortfin/src/shortfin/array/array_test.cc b/shortfin/src/shortfin/array/array_test.cc similarity index 89% rename from libshortfin/src/shortfin/array/array_test.cc rename to shortfin/src/shortfin/array/array_test.cc index 50d9c2b00..b1ed1ba3b 100644 --- a/libshortfin/src/shortfin/array/array_test.cc +++ b/shortfin/src/shortfin/array/array_test.cc @@ -25,8 +25,8 @@ class DeviceArrayTest : public testing::Test { void SetUp() override { system = systems::HostCPUSystemBuilder().CreateSystem(); - scope = system->CreateScope(system->init_worker(), system->devices()); - device = scope->device(0); + fiber = system->CreateFiber(system->init_worker(), system->devices()); + device = fiber->device(0); } void TearDown() override { system->Shutdown(); @@ -34,7 +34,7 @@ class DeviceArrayTest : public testing::Test { } SystemPtr system; - std::shared_ptr scope; + std::shared_ptr fiber; ScopedDevice device; }; @@ -43,7 +43,7 @@ TEST_F(DeviceArrayTest, contents_to_s_valid) { device, std::to_array({2, 3}), DType::float32()); { auto map = ary1.typed_data_w(); - std::fill(map.begin(), map.end(), 42.0); + std::fill(map.begin(), map.end(), 42.0f); } std::optional contents = ary1.contents_to_s(); diff --git a/libshortfin/src/shortfin/array/dims.h b/shortfin/src/shortfin/array/dims.h similarity index 77% rename from libshortfin/src/shortfin/array/dims.h rename to shortfin/src/shortfin/array/dims.h index a0cbacd00..b76aeef99 100644 --- a/libshortfin/src/shortfin/array/dims.h +++ b/shortfin/src/shortfin/array/dims.h @@ -40,17 +40,36 @@ class SHORTFIN_API InlinedDims { using reference = T &; using iterator_category = std::random_access_iterator_tag; iterator(pointer p) : p(p) {} - iterator &operator++() { + constexpr iterator &operator++() { 
p++; return *this; } - iterator &operator++(int) { + constexpr iterator operator++(int) { + auto tmp = *this; p++; + return tmp; + } + constexpr iterator &operator--() { + p--; return *this; } - bool operator==(iterator other) const { return p == other.p; } - bool operator!=(iterator other) const { return p != other.p; } - reference operator*() { return *p; } + constexpr iterator operator--(int) { + auto tmp = *this; + p--; + return tmp; + } + constexpr bool operator==(iterator other) const { return p == other.p; } + constexpr bool operator!=(iterator other) const { return p != other.p; } + constexpr reference operator*() { return *p; } + constexpr iterator operator+(difference_type d) const { + return iterator(p + d); + } + constexpr iterator operator-(difference_type d) const { + return iterator(p - d); + } + constexpr difference_type operator-(iterator rhs) const { + return static_cast(p - rhs.p); + } private: pointer p; @@ -64,17 +83,40 @@ class SHORTFIN_API InlinedDims { using iterator_category = std::random_access_iterator_tag; const_iterator(pointer p) : p(p) {} - const_iterator &operator++() { + constexpr const_iterator &operator++() { p++; return *this; } - const_iterator &operator++(int) { + constexpr const_iterator operator++(int) { + auto tmp = *this; p++; + return tmp; + } + constexpr const_iterator &operator--() { + p--; return *this; } - bool operator==(const_iterator other) const { return p == other.p; } - bool operator!=(const_iterator other) const { return p != other.p; } - reference operator*() { return *p; } + constexpr const_iterator operator--(int) { + auto tmp = *this; + p--; + return tmp; + } + constexpr bool operator==(const_iterator other) const { + return p == other.p; + } + constexpr bool operator!=(const_iterator other) const { + return p != other.p; + } + constexpr reference operator*() { return *p; } + constexpr const_iterator operator+(difference_type d) const { + return const_iterator(p + d); + } + constexpr const_iterator operator-(difference_type d) const { + return const_iterator(p - d); + } + constexpr difference_type operator-(const_iterator rhs) const { + return static_cast(p - rhs.p); + } private: pointer p; @@ -95,6 +137,11 @@ class SHORTFIN_API InlinedDims { std::fill(dims_.inline_dims.begin(), dims_.inline_dims.end(), value); } } + template + InlinedDims(BeginTy begin, EndTy end) { + assert(end > begin); + set(std::span(&(*begin), end - begin)); + } InlinedDims(const InlinedDims &other) { new (&dims_.inline_dims) InlineTy(); set(other.span()); @@ -168,6 +215,12 @@ class SHORTFIN_API InlinedDims { const_iterator end() const { return const_iterator(data() + size()); } const_iterator cbegin() const { return const_iterator(data()); } const_iterator cend() const { return const_iterator(data() + size()); } + reverse_iterator rbegin() { return reverse_iterator(begin()); } + reverse_iterator rend() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(begin()); + } + const_reverse_iterator rend() const { return const_reverse_iterator(end()); } void resize(size_type count) { resize_impl(count, value_type()); } void resize(size_type count, value_type value) { resize_impl(count, value); } @@ -249,7 +302,6 @@ class SHORTFIN_API InlinedDims { _D dims_; }; -extern template class InlinedDims; using Dims = InlinedDims; } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/dims_test.cc b/shortfin/src/shortfin/array/dims_test.cc similarity index 100% rename from 
libshortfin/src/shortfin/array/dims_test.cc rename to shortfin/src/shortfin/array/dims_test.cc diff --git a/libshortfin/src/shortfin/array/dtype.cc b/shortfin/src/shortfin/array/dtype.cc similarity index 100% rename from libshortfin/src/shortfin/array/dtype.cc rename to shortfin/src/shortfin/array/dtype.cc diff --git a/libshortfin/src/shortfin/array/dtype.h b/shortfin/src/shortfin/array/dtype.h similarity index 75% rename from libshortfin/src/shortfin/array/dtype.h rename to shortfin/src/shortfin/array/dtype.h index 061122ae7..de1763698 100644 --- a/libshortfin/src/shortfin/array/dtype.h +++ b/shortfin/src/shortfin/array/dtype.h @@ -9,6 +9,7 @@ #include #include +#include #include #include "iree/hal/buffer_view.h" @@ -20,11 +21,11 @@ namespace shortfin::array { class SHORTFIN_API DType { public: #define SHORTFIN_DTYPE_HANDLE(et, ident) \ - static DType ident() { return DType(et, #ident); } + static constexpr DType ident() { return DType(et, #ident); } #include "shortfin/array/dtypes.inl" #undef SHORTFIN_DTYPE_HANDLE - operator iree_hal_element_type_t() const { return et_; } + constexpr operator iree_hal_element_type_t() const { return et_; } std::string_view name() const { return name_; } @@ -48,6 +49,9 @@ class SHORTFIN_API DType { bool is_integer_bitwidth(size_t bitwidth) const { return iree_hal_element_type_is_integer(et_, bitwidth); } + uint32_t numerical_type() const { + return iree_hal_element_numerical_type(et_); + } // Computes the size in bytes required to store densely packed nd-dims. // This presently only supports byte aligned dtypes. In the future, when @@ -56,13 +60,23 @@ class SHORTFIN_API DType { // pre-condition iree_device_size_t compute_dense_nd_size(std::span dims); - bool operator==(const DType &other) const { return et_ == other.et_; } + constexpr bool operator==(const DType &other) const { + return et_ == other.et_; + } // Imports a raw iree_hal_element_type_t from the ether. static DType import_element_type(iree_hal_element_type_t et); + // Asserts that the sizeof EltTy is equal to the size of this dtype. 
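// Usage sketch (not from this patch): with constexpr handles and operator==,
// DType values participate in constant expressions, and AssertCompatibleSize
// (declared just below) guards typed access at runtime.
static_assert(DType::float32() == DType::float32());
inline void example_dtype() {
  DType dt = DType::float16();
  // Numerical class (integer/float/etc.) as reported by the IREE HAL.
  uint32_t klass = dt.numerical_type();
  (void)klass;
  dt.AssertCompatibleSize<uint16_t>();  // ok: 2-byte element
  // dt.AssertCompatibleSize<float>();  // would throw std::invalid_argument
}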
+ template + void AssertCompatibleSize() { + if (!is_byte_aligned() || sizeof(EltTy) != dense_byte_count()) { + throw std::invalid_argument("Incompatible element size"); + } + } + private: - DType(iree_hal_element_type_t et, std::string_view name) + constexpr DType(iree_hal_element_type_t et, std::string_view name) : et_(et), name_(name) {} iree_hal_element_type_t et_; std::string_view name_; diff --git a/libshortfin/src/shortfin/array/dtype_test.cc b/shortfin/src/shortfin/array/dtype_test.cc similarity index 98% rename from libshortfin/src/shortfin/array/dtype_test.cc rename to shortfin/src/shortfin/array/dtype_test.cc index d5f38c974..783fe9d52 100644 --- a/libshortfin/src/shortfin/array/dtype_test.cc +++ b/shortfin/src/shortfin/array/dtype_test.cc @@ -10,6 +10,7 @@ #include #include +#include namespace shortfin::array { diff --git a/libshortfin/src/shortfin/array/dtypes.inl b/shortfin/src/shortfin/array/dtypes.inl similarity index 100% rename from libshortfin/src/shortfin/array/dtypes.inl rename to shortfin/src/shortfin/array/dtypes.inl diff --git a/libshortfin/src/shortfin/array/storage.cc b/shortfin/src/shortfin/array/storage.cc similarity index 90% rename from libshortfin/src/shortfin/array/storage.cc rename to shortfin/src/shortfin/array/storage.cc index 542f69025..ffbbd9ba2 100644 --- a/libshortfin/src/shortfin/array/storage.cc +++ b/shortfin/src/shortfin/array/storage.cc @@ -38,11 +38,12 @@ storage::~storage() { logging::destruct("array::storage", this); } storage storage::import_buffer(local::ScopedDevice &device, iree::hal_buffer_ptr buffer) { return storage(device, std::move(buffer), - device.scope().NewTimelineResource()); + device.fiber().NewTimelineResource()); } storage storage::allocate_device(ScopedDevice &device, iree_device_size_t allocation_size) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::allocate_device"); if (!device.raw_device()) { throw std::invalid_argument("Cannot allocate with a null device affinity"); } @@ -57,11 +58,13 @@ storage storage::allocate_device(ScopedDevice &device, SHORTFIN_THROW_IF_ERROR(iree_hal_allocator_allocate_buffer( allocator, params, allocation_size, buffer.for_output())); return storage(device, std::move(buffer), - device.scope().NewTimelineResource()); + device.fiber().NewTimelineResource()); } storage storage::allocate_host(ScopedDevice &device, - iree_device_size_t allocation_size) { + iree_device_size_t allocation_size, + bool device_visible) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::allocate_host"); if (!device.raw_device()) { throw std::invalid_argument("Cannot allocate with a null device affinity"); } @@ -70,17 +73,19 @@ storage storage::allocate_host(ScopedDevice &device, iree_hal_buffer_params_t params = { .usage = IREE_HAL_BUFFER_USAGE_MAPPING, .access = IREE_HAL_MEMORY_ACCESS_ALL, - .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST | - IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_HOST, .queue_affinity = device.affinity().queue_affinity(), }; - if (device.affinity().queue_affinity() != 0) { - params.usage |= IREE_HAL_BUFFER_USAGE_TRANSFER; + if (device_visible) { + params.type |= IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE; + if (device.affinity().queue_affinity() != 0) { + params.usage |= IREE_HAL_BUFFER_USAGE_TRANSFER; + } } SHORTFIN_THROW_IF_ERROR(iree_hal_allocator_allocate_buffer( allocator, params, allocation_size, buffer.for_output())); return storage(device, std::move(buffer), - device.scope().NewTimelineResource()); + device.fiber().NewTimelineResource()); } storage storage::subspan(iree_device_size_t 
byte_offset, @@ -92,7 +97,7 @@ storage storage::subspan(iree_device_size_t byte_offset, } void storage::fill(const void *pattern, iree_host_size_t pattern_length) { - device_.scope().scheduler().AppendCommandBuffer( + device_.fiber().scheduler().AppendCommandBuffer( device_, TransactionType::TRANSFER, [&](Account &account) { // Must depend on all of this buffer's use dependencies to avoid // write-after-read hazard. @@ -108,7 +113,7 @@ void storage::fill(const void *pattern, iree_host_size_t pattern_length) { iree_hal_make_buffer_ref( buffer_, /*offset=*/0, /*length=*/iree_hal_buffer_byte_length(buffer_)), - pattern, pattern_length)); + pattern, pattern_length, IREE_HAL_FILL_FLAG_NONE)); // And move our own mutation barrier to the current pending timeline // value. @@ -118,7 +123,7 @@ void storage::fill(const void *pattern, iree_host_size_t pattern_length) { } void storage::copy_from(storage &source_storage) { - device_.scope().scheduler().AppendCommandBuffer( + device_.fiber().scheduler().AppendCommandBuffer( device_, TransactionType::TRANSFER, [&](Account &account) { // Must depend on the source's mutation dependencies to avoid // read-before-write hazard. @@ -136,7 +141,8 @@ void storage::copy_from(storage &source_storage) { /*source_ref=*/ iree_hal_make_buffer_ref(source_storage.buffer_, 0, byte_length()), /*target_ref=*/ - iree_hal_make_buffer_ref(buffer_, 0, byte_length()))); + iree_hal_make_buffer_ref(buffer_, 0, byte_length()), + IREE_HAL_COPY_FLAG_NONE)); // Move our own mutation barrier to the current pending timeline // value. @@ -203,6 +209,7 @@ std::string storage::formatted_buffer_usage() const { void storage::AddAsInvocationArgument(local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::AddAsInvocationArgument"); iree::vm_opaque_ref ref; *(&ref) = iree_hal_buffer_retain_ref(buffer_); inv->AddArg(std::move(ref)); @@ -216,6 +223,7 @@ iree_vm_ref_type_t storage::invocation_marshalable_type() { storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv, iree::vm_opaque_ref ref) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::CreateFromInvocationResultRef"); // Steal the ref to one of our smart pointers. // TODO: Should have an opaque_ref::release(). 
iree::hal_buffer_ptr buffer = @@ -226,8 +234,9 @@ storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv, storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv, iree::hal_buffer_ptr buffer) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::ImportInvocationResultStorage"); local::ScopedDevice device = - local::ScopedDevice(*inv->scope(), inv->device_selection()); + local::ScopedDevice(*inv->fiber(), inv->device_selection()); auto imported_storage = storage::import_buffer(device, std::move(buffer)); auto coarse_signal = inv->coarse_signal(); @@ -247,6 +256,7 @@ storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv, void storage::AddInvocationArgBarrier(local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::AddInvocationArgBarrier"); switch (barrier) { case ProgramResourceBarrier::DEFAULT: case ProgramResourceBarrier::READ: diff --git a/libshortfin/src/shortfin/array/storage.h b/shortfin/src/shortfin/array/storage.h similarity index 91% rename from libshortfin/src/shortfin/array/storage.h rename to shortfin/src/shortfin/array/storage.h index 5f3a568b2..2ea8f5aef 100644 --- a/libshortfin/src/shortfin/array/storage.h +++ b/shortfin/src/shortfin/array/storage.h @@ -9,8 +9,8 @@ #include +#include "shortfin/local/fiber.h" #include "shortfin/local/program_interfaces.h" -#include "shortfin/local/scope.h" #include "shortfin/support/api.h" namespace shortfin::array { @@ -75,9 +75,9 @@ class SHORTFIN_API storage : public local::ProgramInvocationMarshalable { public: ~storage(); local::ScopedDevice &device() { return device_; } - local::Scope &scope() { return device_.scope(); } + local::Fiber &fiber() { return device_.fiber(); } const local::ScopedDevice &device() const { return device_; } - local::Scope &scope() const { return device_.scope(); } + local::Fiber &fiber() const { return device_.fiber(); } static storage import_buffer(local::ScopedDevice &device, iree::hal_buffer_ptr buffer); @@ -91,9 +91,11 @@ class SHORTFIN_API storage : public local::ProgramInvocationMarshalable { // By default, if there are any affinity bits set in the device, then // the storage will be device visible and have permitted usage for // transfers. This default policy can be overriden based on device defaults - // or explicit options. + // or explicit options. Pass `device_visible=false` to create a pure host + // heap buffer. static storage allocate_host(local::ScopedDevice &device, - iree_device_size_t allocation_size); + iree_device_size_t allocation_size, + bool device_visible = true); // Creates a subspan view of the current storage given a byte offset and // length. The returned storage shares the underlying allocation and @@ -193,7 +195,7 @@ class SHORTFIN_API storage : public local::ProgramInvocationMarshalable { static storage ImportInvocationResultStorage(local::ProgramInvocation *inv, iree::hal_buffer_ptr buffer); - // The timeline resource holds the back reference to the owning scope, + // The timeline resource holds the back reference to the owning fiber, // which keeps all devices alive. Buffers must be destroyed before devices, // so this must be declared first. 
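// Usage sketch (not from this patch): the new device_visible flag selects a
// plain host heap buffer; the default keeps the previous device-visible,
// transfer-capable behavior. Sizes here are arbitrary.
inline void example_storage(shortfin::local::ScopedDevice &device) {
  using namespace shortfin::array;
  storage host_only = storage::allocate_host(device, /*allocation_size=*/4096,
                                             /*device_visible=*/false);
  storage staging = storage::allocate_host(device, 4096);  // default: visible
  storage on_device = storage::allocate_device(device, 4096);
  // Both calls below append transfer work to the owning fiber's scheduler.
  uint32_t zero = 0;
  on_device.fill(&zero, sizeof(zero));
  staging.copy_from(on_device);
  (void)host_only;
}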
local::detail::TimelineResource::Ref timeline_resource_; @@ -230,14 +232,14 @@ class typed_mapping { span_type span() { return span_type(data(), size()); } const_span_type span() const { return const_span_type(data(), size()); } - span_type::iterator begin() { return span().begin(); } - span_type::iterator end() { return span().end(); } + typename span_type::iterator begin() { return span().begin(); } + typename span_type::iterator end() { return span().end(); } - const_span_type::iterator begin() const { return span().begin(); } - const_span_type::iterator end() const { return span().end(); } + typename const_span_type::iterator begin() const { return span().begin(); } + typename const_span_type::iterator end() const { return span().end(); } - const_span_type::iterator cbegin() const { return span().begin(); } - const_span_type::iterator cend() const { return span().end(); } + typename const_span_type::iterator cbegin() const { return span().begin(); } + typename const_span_type::iterator cend() const { return span().end(); } private: mapping untyped_mapping_; diff --git a/libshortfin/src/shortfin/array/xtensor_bridge.cc b/shortfin/src/shortfin/array/xtensor_bridge.cc similarity index 94% rename from libshortfin/src/shortfin/array/xtensor_bridge.cc rename to shortfin/src/shortfin/array/xtensor_bridge.cc index fa36c4ca1..da350b71a 100644 --- a/libshortfin/src/shortfin/array/xtensor_bridge.cc +++ b/shortfin/src/shortfin/array/xtensor_bridge.cc @@ -8,6 +8,9 @@ #include +#include "shortfin/support/logging.h" +#include "xtl/xhalf_float.hpp" + namespace shortfin::array { namespace { @@ -54,6 +57,7 @@ class typed_xt_methods final : public poly_xt_methods { bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype, void *array_memory, size_t array_memory_size, Dims &dims) { + SHORTFIN_TRACE_SCOPE_NAMED("array_xtensor_cast"); #define POLY_XT_CASE(et, cpp_type) \ case et: \ typed_xt_methods::concrete_inplace_new( \ @@ -65,6 +69,7 @@ bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype, POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_32, float); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_32, int32_t); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_32, int32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_16, half_float::half); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_32, uint32_t); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_64, int64_t); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_64, int64_t); @@ -77,8 +82,6 @@ bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype, POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_16, uint16_t); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_64, double); POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_BOOL_8, bool); - // TODO: float16 - // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_16, TODO); // TODO: bfloat16 // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_BFLOAT_16, TODO); // TODO: complex64 diff --git a/libshortfin/src/shortfin/array/xtensor_bridge.h b/shortfin/src/shortfin/array/xtensor_bridge.h similarity index 100% rename from libshortfin/src/shortfin/array/xtensor_bridge.h rename to shortfin/src/shortfin/array/xtensor_bridge.h diff --git a/libshortfin/src/shortfin/local/CMakeLists.txt b/shortfin/src/shortfin/local/CMakeLists.txt similarity index 63% rename from libshortfin/src/shortfin/local/CMakeLists.txt rename to shortfin/src/shortfin/local/CMakeLists.txt index 37d93bdcb..1c83f51da 100644 --- a/libshortfin/src/shortfin/local/CMakeLists.txt +++ b/shortfin/src/shortfin/local/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. 
# -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception add_subdirectory(systems) @@ -12,22 +12,22 @@ shortfin_cc_component( HDRS async.h device.h + fiber.h messaging.h process.h program.h worker.h scheduler.h - scope.h system.h SRCS async.cc device.cc + fiber.cc messaging.cc process.cc program.cc worker.cc scheduler.cc - scope.cc system.cc COMPONENTS shortfin_support @@ -35,6 +35,9 @@ shortfin_cc_component( iree_base_base iree_base_loop_sync iree_hal_hal + iree_hal_utils_allocators + iree_io_formats_parser_registry + iree_modules_io_parameters_parameters iree_modules_hal_hal iree_vm_vm iree_vm_bytecode_module diff --git a/libshortfin/src/shortfin/local/async.cc b/shortfin/src/shortfin/local/async.cc similarity index 100% rename from libshortfin/src/shortfin/local/async.cc rename to shortfin/src/shortfin/local/async.cc diff --git a/libshortfin/src/shortfin/local/async.h b/shortfin/src/shortfin/local/async.h similarity index 90% rename from libshortfin/src/shortfin/local/async.h rename to shortfin/src/shortfin/local/async.h index 75f5545d0..f4c40f706 100644 --- a/libshortfin/src/shortfin/local/async.h +++ b/shortfin/src/shortfin/local/async.h @@ -16,7 +16,7 @@ namespace shortfin::local { -class SHORTFIN_API Worker; +class Worker; // CompletionEvents are the most basic form of awaitable object. They // encapsulate a native iree_wait_source_t (which multiplexes any supported @@ -112,23 +112,25 @@ class SHORTFIN_API Future { virtual ~BaseState(); iree::slim_mutex lock_; Worker *worker_; - int ref_count_ = 1; - iree::ignorable_status failure_status_; - bool done_ = false; - std::vector callbacks_; + int ref_count_ SHORTFIN_GUARDED_BY(lock_) = 1; + iree::ignorable_status failure_status_ SHORTFIN_GUARDED_BY(lock_); + bool done_ SHORTFIN_GUARDED_BY(lock_) = false; + std::vector callbacks_ SHORTFIN_GUARDED_BY(lock_); }; Future(BaseState *state) : state_(state) {} void Retain() const; void Release() const; static Worker *GetRequiredWorker(); - void set_success() { state_->done_ = true; } + void SetSuccessWithLockHeld() SHORTFIN_REQUIRES_LOCK(state_->lock_) { + state_->done_ = true; + } // Posts a message to the worker to issue callbacks. Lock must be held. - void IssueCallbacksWithLockHeld(); + void IssueCallbacksWithLockHeld() SHORTFIN_REQUIRES_LOCK(state_->lock_); static iree_status_t RawHandleWorkerCallback(void *state_vp, iree_loop_t loop, iree_status_t status) noexcept; void HandleWorkerCallback(); - void ThrowFailureWithLockHeld(); + void ThrowFailureWithLockHeld() SHORTFIN_REQUIRES_LOCK(state_->lock_); mutable BaseState *state_; }; @@ -148,7 +150,11 @@ class SHORTFIN_API VoidFuture : public Future { return *this; } - using Future::set_success; + void set_success() { + iree::slim_mutex_lock_guard g(state_->lock_); + SetSuccessWithLockHeld(); + IssueCallbacksWithLockHeld(); + } }; // Value containing Future. 
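// Usage sketch (not from this patch): set_success() now acquires the state
// lock, marks the future done, and issues the callbacks in one call, so a
// producer simply does:
inline void example_complete(shortfin::local::VoidFuture &pending) {
  pending.set_success();  // locks, SetSuccessWithLockHeld(), then callbacks
}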
@@ -183,7 +189,7 @@ class SHORTFIN_API TypedFuture : public Future { "Cannot 'set_failure' on a Future that is already done"); } static_cast(state_)->result_ = std::move(result); - set_success(); + SetSuccessWithLockHeld(); IssueCallbacksWithLockHeld(); } diff --git a/libshortfin/src/shortfin/local/device.cc b/shortfin/src/shortfin/local/device.cc similarity index 90% rename from libshortfin/src/shortfin/local/device.cc rename to shortfin/src/shortfin/local/device.cc index 3cb1bd40d..3afd2b8ad 100644 --- a/libshortfin/src/shortfin/local/device.cc +++ b/shortfin/src/shortfin/local/device.cc @@ -55,19 +55,21 @@ std::string DeviceAffinity::to_s() const { // -------------------------------------------------------------------------- // Device::Device(DeviceAddress address, iree::hal_device_ptr hal_device, - int node_affinity, bool node_locked) + int node_affinity, uint32_t capabilities) : address_(std::move(address)), hal_device_(std::move(hal_device)), node_affinity_(node_affinity), - node_locked_(node_locked) {} + capabilities_(capabilities) { + assert(hal_device_ && "nullptr iree_hal_device_t"); +} Device::~Device() = default; std::string Device::to_s() const { return fmt::format( - "Device(name='{}', ordinal={}:{}, node_affinity={}, node_locked={})", + "Device(name='{}', ordinal={}:{}, node_affinity={}, capabilities=0x{:x})", name(), address().instance_ordinal, address().queue_ordinal, - node_affinity(), node_locked()); + node_affinity(), capabilities_); } } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/device.h b/shortfin/src/shortfin/local/device.h similarity index 93% rename from libshortfin/src/shortfin/local/device.h rename to shortfin/src/shortfin/local/device.h index eb24d6a9f..a70fb0dae 100644 --- a/libshortfin/src/shortfin/local/device.h +++ b/shortfin/src/shortfin/local/device.h @@ -123,14 +123,22 @@ struct SHORTFIN_API DeviceAddress { // A device attached to the LocalSystem. class SHORTFIN_API Device { public: + enum class Capabilities : uint32_t { + NONE = 0, + // Indicates that the device has unified memory with the host that should + // be preferred when performing device-visible buffer manipulation. Note + // that many devices technically support host unified memory, but this bit + // indicates that it should be used for loading/storing/accessing device + // buffers without managing a separate staging DMA buffer. 
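// Usage sketch (not from this patch): a system builder that detects unified
// memory would construct the device roughly as below; `address` and
// `hal_device` are assumed to come from its enumeration code.
//
//   Device dev(std::move(address), std::move(hal_device), /*node_affinity=*/0,
//              static_cast<uint32_t>(
//                  Device::Capabilities::PREFER_HOST_UNIFIED_MEMORY));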
+ PREFER_HOST_UNIFIED_MEMORY = 1, + }; Device(DeviceAddress address, iree::hal_device_ptr hal_device, - int node_affinity, bool node_locked); + int node_affinity, uint32_t capabilities); virtual ~Device(); const DeviceAddress &address() const { return address_; } std::string_view name() const { return address_.device_name; } int node_affinity() const { return node_affinity_; } - bool node_locked() const { return node_locked_; } iree_hal_device_t *hal_device() const { return hal_device_.get(); } std::string to_s() const; @@ -144,7 +152,7 @@ class SHORTFIN_API Device { DeviceAddress address_; iree::hal_device_ptr hal_device_; int node_affinity_; - bool node_locked_; + uint32_t capabilities_ = 0; }; // Holds a reference to a Device* and a bitmask of queues that are being diff --git a/libshortfin/src/shortfin/local/scope.cc b/shortfin/src/shortfin/local/fiber.cc similarity index 59% rename from libshortfin/src/shortfin/local/scope.cc rename to shortfin/src/shortfin/local/fiber.cc index e92368a52..8ad9f2960 100644 --- a/libshortfin/src/shortfin/local/scope.cc +++ b/shortfin/src/shortfin/local/fiber.cc @@ -4,7 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "shortfin/local/scope.h" +#include "shortfin/local/fiber.h" #include #include @@ -15,75 +15,73 @@ namespace shortfin::local { // -------------------------------------------------------------------------- // -// Scope +// Fiber // -------------------------------------------------------------------------- // -Scope::Scope(std::shared_ptr system, Worker &worker, +Fiber::Fiber(std::shared_ptr system, Worker &worker, std::span> devices) : system_(std::move(system)), host_allocator_(system_->host_allocator()), scheduler_(*system_), worker_(worker) { - logging::construct("Scope", this); + logging::construct("Fiber", this); for (auto &it : devices) { AddDevice(it.first, it.second); } Initialize(); } -Scope::Scope(std::shared_ptr system, Worker &worker, +Fiber::Fiber(std::shared_ptr system, Worker &worker, std::span devices) : system_(std::move(system)), host_allocator_(system_->host_allocator()), scheduler_(*system_), worker_(worker) { - logging::construct("Scope", this); + logging::construct("Fiber", this); for (auto *device : devices) { AddDevice(device->address().logical_device_class, device); } Initialize(); } -Scope::~Scope() { logging::destruct("Scope", this); } +Fiber::~Fiber() { logging::destruct("Fiber", this); } -std::string Scope::to_s() const { - return fmt::format("Scope(worker='{}', devices=[{}])", worker_.name(), +std::string Fiber::to_s() const { + return fmt::format("Fiber(worker='{}', devices=[{}])", worker_.name(), fmt::join(device_names(), ", ")); } -void Scope::Initialize() { scheduler_.Initialize(devices_); } +void Fiber::Initialize() { scheduler_.Initialize(devices_); } -void Scope::AddDevice(std::string_view device_class, Device *device) { +void Fiber::AddDevice(std::string_view device_class, Device *device) { device_class = interner_.intern(device_class); auto &count = device_class_count_[device_class]; std::string_view device_name = interner_.intern(fmt::format("{}{}", device_class, count++)); - named_devices_[device_name] = device; - devices_.push_back(device); + devices_.push_back(std::make_pair(device_name, device)); } -Device *Scope::raw_device(std::string_view name) const { - auto it = named_devices_.find(name); - if (it == named_devices_.end()) [[unlikely]] { - throw std::invalid_argument( - fmt::format("Device '{}' not found 
(available: {})", name, - fmt::join(device_names(), ", "))); +Device *Fiber::raw_device(std::string_view name) const { + for (auto &it : devices_) { + if (it.first == name) return it.second; } - return it->second; + throw std::invalid_argument( + fmt::format("Device '{}' not found (available: {})", name, + fmt::join(device_names(), ", "))); } -Device *Scope::raw_device(int ordinal) const { - if (ordinal < 0 || ordinal >= devices_.size()) { +Device *Fiber::raw_device(std::size_t ordinal) const { + if (ordinal >= devices_.size()) { throw std::invalid_argument( fmt::format("Device ordinal ({}) out of bounds", ordinal)); } - return devices_[ordinal]; + return devices_[ordinal].second; } -std::vector Scope::device_names() const { +std::vector Fiber::device_names() const { std::vector names; - names.reserve(named_devices_.size()); - for (auto &it : named_devices_) { + names.reserve(devices_.size()); + for (auto &it : devices_) { names.push_back(it.first); } return names; @@ -93,11 +91,11 @@ std::vector Scope::device_names() const { // ScopedDevice // -------------------------------------------------------------------------- // -CompletionEvent ScopedDevice::OnSync(bool flush) { +VoidFuture ScopedDevice::OnSync(bool flush) { if (flush) { - scope().scheduler().Flush(); + fiber().scheduler().Flush(); } - auto &default_account = scope().scheduler().GetDefaultAccount(*this); + auto &default_account = fiber().scheduler().GetDefaultAccount(*this); return default_account.OnSync(); } diff --git a/libshortfin/src/shortfin/local/scope.h b/shortfin/src/shortfin/local/fiber.h similarity index 60% rename from libshortfin/src/shortfin/local/scope.h rename to shortfin/src/shortfin/local/fiber.h index 14c4d3749..afd65b346 100644 --- a/libshortfin/src/shortfin/local/scope.h +++ b/shortfin/src/shortfin/local/fiber.h @@ -4,8 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#ifndef SHORTFIN_LOCAL_SCOPE_H -#define SHORTFIN_LOCAL_SCOPE_H +#ifndef SHORTFIN_LOCAL_FIBER_H +#define SHORTFIN_LOCAL_FIBER_H #include #include @@ -19,26 +19,26 @@ namespace shortfin::local { -class SHORTFIN_API Scope; -class SHORTFIN_API System; -class SHORTFIN_API Worker; +class Fiber; +class System; +class Worker; -// Wraps a Scope and a DeviceAffinity together. This is used in all -// Scope based APIs as a short-hand for "device" as it contains everything +// Wraps a Fiber and a DeviceAffinity together. This is used in all +// Fiber based APIs as a short-hand for "device" as it contains everything // needed to do thing with some slice of device queues. 
class SHORTFIN_API ScopedDevice { public: ScopedDevice() = default; - ScopedDevice(Scope &scope, DeviceAffinity affinity) - : scope_(&scope), affinity_(affinity) {} - ScopedDevice(Scope &scope, Device *device) - : scope_(&scope), affinity_(device) {} + ScopedDevice(Fiber &fiber, DeviceAffinity affinity) + : fiber_(&fiber), affinity_(affinity) {} + ScopedDevice(Fiber &fiber, Device *device) + : fiber_(&fiber), affinity_(device) {} ScopedDevice(const ScopedDevice &other) - : scope_(other.scope_), affinity_(other.affinity_) {} + : fiber_(other.fiber_), affinity_(other.affinity_) {} - Scope &scope() const { - assert(scope_ && "scope must not be null"); - return *scope_; + Fiber &fiber() const { + assert(fiber_ && "fiber must not be null"); + return *fiber_; } DeviceAffinity affinity() const { return affinity_; } Device *raw_device() const { return affinity_.device(); } @@ -46,75 +46,74 @@ class SHORTFIN_API ScopedDevice { std::string to_s() const { return affinity().to_s(); } bool operator==(const ScopedDevice &other) const { - return (scope_ == other.scope_) && affinity_ == other.affinity_; + return (fiber_ == other.fiber_) && affinity_ == other.affinity_; } // Returns a future which will be satisfied when the primary device timeline // of this affinity set progresses to "now". This will be true when all // currently queued work on the device has been completed. - CompletionEvent OnSync(bool flush = true); + VoidFuture OnSync(bool flush = true); private: - Scope *scope_ = nullptr; + Fiber *fiber_ = nullptr; DeviceAffinity affinity_; }; -// A logical scope of execution, consisting of participating devices, +// A logical fiber of execution, consisting of participating devices, // resources, and timelines. Most interaction with the compute resources // is done on these instances. // -// The scope is generally instantiated with a slice of system resources, +// The fiber is generally instantiated with a slice of system resources, // and produces an arrangement that is easy to use vs maximally diverse. // // Devices // ------- -// The scope is initialized with a list of participating devices, which is +// The fiber is initialized with a list of participating devices, which is // a subset of all devices managed by the LocalSystem. Each device is given // a logical name of the form ``, by default using the // DeviceAddress::logical_device_class as the ``. In exotic // situations, this can be customized. By default, devices are added in the // order defined by the system and will have an `` corresponding to // their order. It is up to the constructor to produce a sensible arrangement. -class SHORTFIN_API Scope : public std::enable_shared_from_this { +class SHORTFIN_API Fiber : public std::enable_shared_from_this { public: // Initialize with devices using logical_device_class as the device class. - Scope(std::shared_ptr system, Worker &worker, + Fiber(std::shared_ptr system, Worker &worker, std::span devices); // Initialize with devices with custom device class names. - Scope(std::shared_ptr system, Worker &worker, + Fiber(std::shared_ptr system, Worker &worker, std::span> devices); - Scope(const Scope &) = delete; + Fiber(const Fiber &) = delete; // Ensure polymorphic. - virtual ~Scope(); + virtual ~Fiber(); std::string to_s() const; // All scopes are created as shared pointers. - std::shared_ptr shared_ptr() { return shared_from_this(); } + std::shared_ptr shared_ptr() { return shared_from_this(); } // The host allocator. 
iree_allocator_t host_allocator() { return host_allocator_; } - // The worker that this scope is bound to. + // The worker that this fiber is bound to. Worker &worker() { return worker_; } - // System that this scope is bound to. + // System that this fiber is bound to. System &system() { return *system_; } // Device access. // Throws std::invalid_argument on lookup failure. Device *raw_device(std::string_view name) const; - const std::unordered_map named_devices() const { - return named_devices_; - } - Device *raw_device(int index) const; + Device *raw_device(std::size_t index) const; Device *raw_device(Device *device) const { return device; } - const std::vector &raw_devices() const { return devices_; } + std::span> raw_devices() const { + return devices_; + } std::vector device_names() const; // Variadic helper for making a DeviceAffinity from any of: // * Explicit Device* - // * Device name (from a Scope) - // * Device index (from a Scope) + // * Device name (from a Fiber) + // * Device index (from a Fiber) // If at any point during accumulation, the DeviceAffinity would be invalid, // then a std::invalid_argument exception is thrown. Any failure to resolve // a name or index will also throw a std::invalid_argument. @@ -145,12 +144,27 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { // Map of `` to the count of that class contained. std::unordered_map device_class_count_; - // Ordered devices. - std::vector devices_; - // Map of `` to Device. - std::unordered_map named_devices_; + // Ordered devices named as ``. + std::vector> devices_; + + // Program isolation control. + // This data structure is manipulated by APIs on the Program class hierarchy. + // It maps a parent context pointer to an isolate accounting struct. This + // struct contains a strong reference to the parent_context and a vector + // of fork contexts. For PER_FIBER invocations, there will only ever be either + // zero or one fork_contexts: when no calls have been issued there will be one + // and if a call is outstanding, there will be zero. This is used to guard + // concurrent access. For PER_CALL invocations, there will be as many + // fork_contexts as are needed to satisfy the peak number of calls in flight + // at any time. + // The program_isolate_mu_ must be held to manipulate the accounting structs. 
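// Usage sketch (not from this patch): devices are now stored as ordered
// (name, Device*) pairs rather than a separate name map. The "cpu0" name
// assumes a host-CPU system whose logical device class is "cpu".
inline void example_fiber_devices(shortfin::local::System &system) {
  using namespace shortfin::local;
  std::shared_ptr<Fiber> fiber =
      system.CreateFiber(system.init_worker(), system.devices());
  Device *cpu0 = fiber->raw_device("cpu0");  // lookup by logical name
  (void)cpu0;
  ScopedDevice scoped = fiber->device(0);    // as used in array_test above
  // OnSync() now returns a VoidFuture instead of a CompletionEvent.
  VoidFuture idle = scoped.OnSync(/*flush=*/true);
  (void)idle;
}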
+ iree::slim_mutex program_isolate_mu_; + std::unordered_map> + program_isolates_; + friend struct detail::ProgramIsolate; }; } // namespace shortfin::local -#endif // SHORTFIN_LOCAL_SCOPE_H +#endif // SHORTFIN_LOCAL_FIBER_H diff --git a/libshortfin/src/shortfin/local/messaging.cc b/shortfin/src/shortfin/local/messaging.cc similarity index 73% rename from libshortfin/src/shortfin/local/messaging.cc rename to shortfin/src/shortfin/local/messaging.cc index 7e3e166ab..a7df25700 100644 --- a/libshortfin/src/shortfin/local/messaging.cc +++ b/shortfin/src/shortfin/local/messaging.cc @@ -14,20 +14,58 @@ namespace shortfin::local { // Message // -------------------------------------------------------------------------- // -template class TypedFuture; - -Message::~Message() = default; - // -------------------------------------------------------------------------- // // Queue // -------------------------------------------------------------------------- // +namespace { + +struct QueueCreator : public Queue { + QueueCreator(Options options) : Queue(std::move(options)) {} +}; + +} // namespace + Queue::Queue(Options options) : options_(std::move(options)) {} +std::shared_ptr Queue::Create(Options options) { + return std::make_shared(std::move(options)); +} + std::string Queue::to_s() const { return fmt::format("Queue(name={})", options().name); } +bool Queue::is_closed() { + iree::slim_mutex_lock_guard g(lock_); + return closed_; +} + +void Queue::WriteNoDelay(Message::Ref mr) { + std::optional future; + { + iree::slim_mutex_lock_guard g(lock_); + if (pending_readers_.empty()) { + // No readers. Just add to the backlog. + backlog_.push_back(std::move(mr)); + return; + } else { + // Signal a reader. We must do this within the queue lock to avoid + // a QueueReader lifetime hazard. But we defer actually setting the + // future until out of the lock. + QueueReader *reader = pending_readers_.front(); + pending_readers_.pop_front(); + future = *reader->read_result_future_; + // Reset the worker for a new read. + reader->worker_ = nullptr; + reader->read_result_future_.reset(); + } + } + + // Signal the future outside of our lock. + future->set_result(std::move(mr)); +} + void Queue::Close() { std::vector async_close_readers; { @@ -58,48 +96,27 @@ void Queue::Close() { } QueueWriter::QueueWriter(Queue &queue, Options options) - : queue_(queue), options_(std::move(options)) {} + : queue_(queue.shared_from_this()), options_(std::move(options)) {} QueueWriter::~QueueWriter() = default; CompletionEvent QueueWriter::Write(Message::Ref mr) { - std::optional future; - { - iree::slim_mutex_lock_guard g(queue_.lock_); - if (queue_.pending_readers_.empty()) { - // No readers. Just add to the backlog. - queue_.backlog_.push_back(std::move(mr)); - return CompletionEvent(); - } else { - // Signal a reader. We must do this within the queue lock to avoid - // a QueueReader lifetime hazard. But we defer actually setting the - // future until out of the lock. - QueueReader *reader = queue_.pending_readers_.front(); - queue_.pending_readers_.pop_front(); - future = *reader->read_result_future_; - // Reset the worker for a new read. - reader->worker_ = nullptr; - reader->read_result_future_.reset(); - } - } - - // Signal the future outside of our lock. 
- future->set_result(std::move(mr)); + queue().WriteNoDelay(std::move(mr)); return CompletionEvent(); } QueueReader::QueueReader(Queue &queue, Options options) - : queue_(queue), options_(std::move(options)) {} + : queue_(queue.shared_from_this()), options_(std::move(options)) {} QueueReader::~QueueReader() { - iree::slim_mutex_lock_guard g(queue_.lock_); + iree::slim_mutex_lock_guard g(queue().lock_); if (read_result_future_) { logging::warn("QueueReader destroyed while pending"); // Reader is in progress: Cancel it from the queue. - auto it = std::find(queue_.pending_readers_.begin(), - queue_.pending_readers_.end(), this); - if (it != queue_.pending_readers_.end()) { - queue_.pending_readers_.erase(it); + auto it = std::find(queue().pending_readers_.begin(), + queue().pending_readers_.end(), this); + if (it != queue().pending_readers_.end()) { + queue().pending_readers_.erase(it); } } } @@ -107,7 +124,7 @@ QueueReader::~QueueReader() { MessageFuture QueueReader::Read() { // TODO: It should be possible to further constrain the scope of this lock, // but it is set here to be conservatively safe pending a full analysis. - iree::slim_mutex_lock_guard g(queue_.lock_); + iree::slim_mutex_lock_guard g(queue().lock_); if (worker_) { throw std::logic_error( "Cannot read concurrently from a single QueueReader"); @@ -120,17 +137,17 @@ MessageFuture QueueReader::Read() { } // See if there is a backlog that we can immediately satisfy. - if (!queue_.backlog_.empty()) { + if (!queue().backlog_.empty()) { // Service from the backlog. MessageFuture imm_future(worker_); - imm_future.set_result(std::move(queue_.backlog_.front())); - queue_.backlog_.pop_front(); + imm_future.set_result(std::move(queue().backlog_.front())); + queue().backlog_.pop_front(); worker_ = nullptr; return imm_future; } // Handle close. - if (queue_.closed_) { + if (queue().closed_) { MessageFuture imm_future(worker_); imm_future.set_result(Message::Ref()); worker_ = nullptr; @@ -138,7 +155,7 @@ MessageFuture QueueReader::Read() { } // Settle in for a wait. - queue_.pending_readers_.push_back(this); + queue().pending_readers_.push_back(this); read_result_future_ = MessageFuture(worker_); return *read_result_future_; } diff --git a/libshortfin/src/shortfin/local/messaging.h b/shortfin/src/shortfin/local/messaging.h similarity index 66% rename from libshortfin/src/shortfin/local/messaging.h rename to shortfin/src/shortfin/local/messaging.h index e60417fd7..7b33cdd18 100644 --- a/libshortfin/src/shortfin/local/messaging.h +++ b/shortfin/src/shortfin/local/messaging.h @@ -8,6 +8,7 @@ #define SHORTFIN_LOCAL_MESSAGING_H #include +#include #include #include @@ -22,18 +23,32 @@ namespace shortfin::local { // Message // -------------------------------------------------------------------------- // -class SHORTFIN_API Message; +class Message; namespace detail { -struct MessageRefOwner { - MessageRefOwner() : Control(nullptr) {} +// Message lifetime by default is managed by an internal reference count +// system. However, since Messages often need to be owned by some third +// party system with its own notion of lifetime, it is possible to provide +// a custom lifetime controller. This can only be done once, typically on +// construction by a proxy system. 
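// Usage sketch (not from this patch): a proxy system installs a lifetime
// controller once and then services RETAIN/RELEASE itself. The
// `external_handle` bookkeeping is purely illustrative.
inline void example_take_ownership(const shortfin::local::Message &msg,
                                   intptr_t external_handle) {
  using shortfin::local::detail::MessageLifetimeController;
  MessageLifetimeController controller(
      +[](MessageLifetimeController::Request req,
          const shortfin::local::Message &m) {
        // After the transfer, the proxy's ref data lives in the Message and
        // is accessed without locking.
        intptr_t &data = MessageLifetimeController::AccessOwnedRefData(m);
        (void)data;
        (void)req;  // forward RETAIN/RELEASE to the external system here
      });
  intptr_t prior_internal_refs = controller.TakeOwnership(msg, external_handle);
  (void)prior_internal_refs;
}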
+struct MessageLifetimeController { + MessageLifetimeController() : Control(nullptr) {} enum class Request { RETAIN, RELEASE } - MessageRefOwner(void (*Control)(Request req, const Message &msg)) + MessageLifetimeController(void (*Control)(Request req, const Message &msg)) : Control(Control) {} void (*Control)(Request req, const Message &msg); operator bool() { return Control != nullptr; } - static intptr_t &access_ref_data(const Message &msg); - intptr_t set_owner(const Message &msg, intptr_t ref_data); + // Takes ownership of the Message using this ownership controller, providing + // new ref_data that will be stored in the message and accessed from then + // on without internal locking. Returns the existing reference count at the + // time of transfer. + intptr_t TakeOwnership(const Message &msg, intptr_t ref_data); + // Accesses the ref_data memory within the Message. This is only valid + // if ownership has been transferred to a lifetime controller, and it is + // accessed without locking. This method purely exists to add some static + // thread/access safety. + static intptr_t &AccessOwnedRefData(const Message &msg) + SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(no_thread_safety_analysis); }; } // namespace detail @@ -58,7 +73,7 @@ class SHORTFIN_API Message { Message(const Message &) = delete; Message(Message &&) = delete; Message &operator=(const Message &) = delete; - virtual ~Message(); + virtual ~Message() = default; // RAII class for holding a reference to a Message. class Ref { @@ -104,10 +119,7 @@ class SHORTFIN_API Message { }; protected: - // Guard a scope with the fine grained lock. - iree::slim_mutex_lock_guard lock_guard() const { - return iree::slim_mutex_lock_guard(lock_); - } + mutable iree::slim_mutex lock_; // Manual retain and release. Callers must assume that the Message is no // longer valid after any call to Release() where they do not hold a known // reference. @@ -121,46 +133,66 @@ class SHORTFIN_API Message { // sized field that the allocator can use at it sees fit. Both fields // are managed within a lock_ scope and are optimized for single threaded // access and cross-thread transfers with coarse references. - mutable iree::slim_mutex lock_; - mutable intptr_t ref_data_ = 1; - mutable detail::MessageRefOwner owner_; - friend struct detail::MessageRefOwner; + mutable intptr_t ref_data_ SHORTFIN_GUARDED_BY(lock_) = 1; + mutable detail::MessageLifetimeController lifetime_controller_ + SHORTFIN_GUARDED_BY(lock_); + friend struct detail::MessageLifetimeController; }; // Future specialization for Message::Ref. -extern template class TypedFuture; +template class TypedFuture; using MessageFuture = TypedFuture; // -------------------------------------------------------------------------- // // Queue // -------------------------------------------------------------------------- // -class SHORTFIN_API QueueReader; -class SHORTFIN_API QueueWriter; +class Queue; +class QueueReader; +class QueueWriter; + +namespace { +struct QueueCreator; +} + +using QueuePtr = std::shared_ptr; // Queues are the primary form of communication in shortfin for exchanging // messages. They are inherently thread safe and coupled with the async/worker // system for enqueue/dequeue operations. -class SHORTFIN_API Queue { +class SHORTFIN_API Queue : public std::enable_shared_from_this { public: struct Options { - // Queues are generally managed by the system with a global name. + // Queues are generally managed by the system with a global name. If + // the name is empty, then this is an anonymous queue.
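// Usage sketch (not from this patch): queues are now shared_ptr-managed and
// constructed through the System (Queue::Create is private and friended to
// System). The `CreateQueue` factory name below is an assumption.
inline void example_queue(shortfin::local::System &system,
                          shortfin::local::Message::Ref msg) {
  using namespace shortfin::local;
  QueuePtr q = system.CreateQueue({.name = "work"});  // assumed factory
  QueueWriter writer(*q);
  QueueReader reader(*q);
  writer.Write(std::move(msg));           // returns a CompletionEvent
  MessageFuture pending = reader.Read();  // resolves to a Message::Ref
  q->Close();  // later reads resolve immediately with a null ref
  (void)pending;
}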
std::string name; }; - Queue(Options options); Queue(const Queue &) = delete; Queue &operator=(const Queue &) = delete; Queue(Queue &&) = delete; ~Queue() = default; + operator QueuePtr() { return shared_from_this(); } + const Options &options() const { return options_; } std::string to_s() const; + // Returns whether the queue has been closed. + bool is_closed(); + + // Writes a message to the queue without any possible delay, possibly + // overriding capacity and throttling policy. + void WriteNoDelay(Message::Ref mr); + // Closes the queue. All readers will return with a null message from here // on. Writers that attempt to write to the queue will throw an exception. void Close(); + protected: private: + // Queues can only be created as shared by the System. + static QueuePtr Create(Options options); + Queue(Options options); mutable iree::slim_mutex lock_; Options options_; // Backlog of messages not yet sent to a reader. Messages are pushed on the @@ -175,6 +207,8 @@ class SHORTFIN_API Queue { friend class QueueReader; friend class QueueWriter; + friend QueueCreator; + friend class System; }; // Writes messages to a queue. @@ -188,16 +222,18 @@ class SHORTFIN_API QueueWriter { QueueWriter(Queue &queue, Options options = {}); ~QueueWriter(); + Queue &queue() { return *queue_; } + // Writes a message to the queue. // The write must be awaited as it can produce backpressure and failures. // TODO: This should be a Future so that exceptions can propagate. CompletionEvent Write(Message::Ref mr); // Calls Close() on the backing queue. - void Close() { queue_.Close(); } + void Close() { queue_->Close(); } private: - Queue &queue_; + std::shared_ptr queue_; Options options_; }; @@ -207,11 +243,13 @@ class SHORTFIN_API QueueReader { QueueReader(Queue &queue, Options options = {}); ~QueueReader(); + Queue &queue() { return *queue_; } + // Reads a message from the queue. MessageFuture Read(); private: - Queue &queue_; + std::shared_ptr queue_; Options options_; // Reader state machine.
If worker_ is non null, then there must be a @@ -227,37 +265,50 @@ class SHORTFIN_API QueueReader { // Message allocation detail // -------------------------------------------------------------------------- // -inline intptr_t &detail::MessageRefOwner::access_ref_data(const Message &msg) { +inline intptr_t &detail::MessageLifetimeController::AccessOwnedRefData( + const Message &msg) { return msg.ref_data_; } -inline intptr_t detail::MessageRefOwner::set_owner(const Message &msg, - intptr_t ref_data) { - auto g = msg.lock_guard(); - assert(!msg.owner_ && "Message ref owner transfer more than once"); - msg.owner_ = *this; +inline intptr_t detail::MessageLifetimeController::TakeOwnership( + const Message &msg, intptr_t ref_data) { + iree::slim_mutex_lock_guard g(msg.lock_); + assert(!msg.lifetime_controller_ && + "Message ref owner transfer more than once"); + msg.lifetime_controller_ = *this; intptr_t orig_ref_data = msg.ref_data_; msg.ref_data_ = ref_data; return orig_ref_data; } inline void Message::Retain() const { - auto g = lock_guard(); - if (owner_) { - owner_.Control(detail::MessageRefOwner::Request::RETAIN, *this); + iree::slim_mutex_lock_guard g(lock_); + if (lifetime_controller_) { + lifetime_controller_.Control( + detail::MessageLifetimeController::Request::RETAIN, *this); } else { ref_data_ += 1; } } inline void Message::Release() const { - auto g = lock_guard(); - if (owner_) { - owner_.Control(detail::MessageRefOwner::Request::RELEASE, *this); + // Since the destructor of the lock asserts that it is not held, we must + // manually release the lock prior to an action that may result in + // destruction. As such, just manage lock manually/carefully vs using RAII. + lock_.Lock(); + auto *local_controller = &lifetime_controller_; + if (*local_controller) { + lock_.Unlock(); + local_controller->Control( + detail::MessageLifetimeController::Request::RELEASE, *this); + return; + } else if (--ref_data_ == 0) { + lock_.Unlock(); + delete this; + return; } else { - if (--ref_data_ == 0) { - delete this; - } + lock_.Unlock(); + return; } } diff --git a/libshortfin/src/shortfin/local/process.cc b/shortfin/src/shortfin/local/process.cc similarity index 73% rename from libshortfin/src/shortfin/local/process.cc rename to shortfin/src/shortfin/local/process.cc index 3d3f6ff54..bafe4cf26 100644 --- a/libshortfin/src/shortfin/local/process.cc +++ b/shortfin/src/shortfin/local/process.cc @@ -12,10 +12,13 @@ namespace shortfin::local { -detail::BaseProcess::BaseProcess(std::shared_ptr scope) - : scope_(std::move(scope)) {} +detail::BaseProcess::BaseProcess() = default; +detail::BaseProcess::~BaseProcess() = default; -detail::BaseProcess::~BaseProcess() {} +void detail::BaseProcess::Initialize(std::shared_ptr fiber) { + assert(!fiber_ && "BaseProcess::Initialize already called"); + fiber_ = std::move(fiber); +} int64_t detail::BaseProcess::pid() const { iree::slim_mutex_lock_guard g(lock_); @@ -31,24 +34,24 @@ std::string detail::BaseProcess::to_s() const { if (pid == 0) { return fmt::format("Process(NOT_STARTED, worker='{}')", - scope_->worker().name()); + fiber_->worker().name()); } else if (pid < 0) { return fmt::format("Process(TERMINATED, worker='{}')", - scope_->worker().name()); + fiber_->worker().name()); } else { return fmt::format("Process(pid={}, worker='{}')", pid, - scope_->worker().name()); + fiber_->worker().name()); } } void detail::BaseProcess::Launch() { - Scope* scope = scope_.get(); + Fiber* fiber = fiber_.get(); { iree::slim_mutex_lock_guard g(lock_); if (pid_ != 0) { throw 
std::logic_error("Process can only be launched a single time"); } - pid_ = scope->system().AllocateProcess(this); + pid_ = fiber->system().AllocateProcess(this); } ScheduleOnWorker(); @@ -67,7 +70,7 @@ void detail::BaseProcess::Terminate() { } } if (deallocate_pid > 0) { - scope_->system().DeallocateProcess(deallocate_pid); + fiber_->system().DeallocateProcess(deallocate_pid); } else { logging::warn("Process signalled termination multiple times (ignored)"); } @@ -81,4 +84,6 @@ CompletionEvent detail::BaseProcess::OnTermination() { return CompletionEvent(terminated_event_); } +Process::Process(std::shared_ptr fiber) { Initialize(std::move(fiber)); } + } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/process.h b/shortfin/src/shortfin/local/process.h similarity index 81% rename from libshortfin/src/shortfin/local/process.h rename to shortfin/src/shortfin/local/process.h index d10c88a3e..17dc67d58 100644 --- a/libshortfin/src/shortfin/local/process.h +++ b/shortfin/src/shortfin/local/process.h @@ -11,7 +11,7 @@ #include #include "shortfin/local/async.h" -#include "shortfin/local/scope.h" +#include "shortfin/local/fiber.h" #include "shortfin/local/worker.h" #include "shortfin/support/api.h" #include "shortfin/support/iree_concurrency.h" @@ -26,19 +26,27 @@ namespace detail { // structure and external lifetime management. class SHORTFIN_API BaseProcess { public: - BaseProcess(std::shared_ptr scope); + BaseProcess(); BaseProcess(const BaseProcess &) = delete; virtual ~BaseProcess(); // The unique pid of this process (or zero if not launched). int64_t pid() const; std::string to_s() const; - std::shared_ptr &scope() { return scope_; } + std::shared_ptr &fiber() { return fiber_; } // Returns a future that can be waited on for termination. CompletionEvent OnTermination(); protected: + // Derived classes must arrange to call Initialize() before any operation + // is taken on the instance. In C++ subclasses, this will typically be done + // in the constructor, but for bindings, this can be separated. + void Initialize(std::shared_ptr fiber); + + // Whether subclass initialization has been done. + bool is_initialized() const { return fiber_.get(); } + // Launches the process. void Launch(); @@ -51,7 +59,7 @@ class SHORTFIN_API BaseProcess { void Terminate(); private: - std::shared_ptr scope_; + std::shared_ptr fiber_; // Process control state. Since this can be accessed by multiple threads, // it is protected by a lock. Most process state can only be accessed on @@ -73,7 +81,7 @@ class SHORTFIN_API BaseProcess { // driven fashion (i.e. cps, async/await, co-routines, etc). 
class SHORTFIN_API Process : public detail::BaseProcess { public: - using BaseProcess::BaseProcess; + Process(std::shared_ptr fiber); }; } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc similarity index 62% rename from libshortfin/src/shortfin/local/program.cc rename to shortfin/src/shortfin/local/program.cc index f3154bd32..6ab1f47ae 100644 --- a/libshortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -8,9 +8,11 @@ #include "fmt/core.h" #include "fmt/std.h" +#include "iree/io/formats/parser_registry.h" #include "iree/modules/hal/module.h" +#include "iree/modules/io/parameters/module.h" #include "iree/vm/bytecode/module.h" -#include "shortfin/local/scope.h" +#include "shortfin/local/fiber.h" #include "shortfin/local/system.h" #include "shortfin/support/logging.h" @@ -34,12 +36,12 @@ void GetVmModuleExports(iree_vm_module_t *vm_module, // -------------------------------------------------------------------------- // ProgramFunction::ProgramFunction( - std::shared_ptr scope, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + iree::vm_context_ptr vm_context, iree_vm_function_t vm_function, + ProgramIsolation isolation, std::optional invocation_model) - : scope_(std::move(scope)), - vm_context_(std::move(vm_context)), + : vm_context_(std::move(vm_context)), vm_function_(vm_function), + isolation_(isolation), invocation_model_(invocation_model ? *invocation_model : GetInvocationModelFromFunction(vm_function)) {} @@ -71,9 +73,21 @@ std::string_view ProgramFunction::calling_convention() const { iree_vm_function_signature(&vm_function_).calling_convention); } -ProgramInvocation::Ptr ProgramFunction::CreateInvocation() { - return ProgramInvocation::New(scope_, vm_context_, vm_function_, - invocation_model_); +ProgramInvocation::Ptr ProgramFunction::CreateInvocation( + std::shared_ptr fiber, std::optional isolation) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramFunction::CreateInvocation"); + ProgramIsolation actual_isolation = isolation ? *isolation : isolation_; + // Low-overhead NONE isolation handling (saves some ref-count twiddling). + if (actual_isolation == ProgramIsolation::NONE) { + return ProgramInvocation::New(std::move(fiber), vm_context_, vm_function_, + invocation_model_, /*isolate=*/nullptr); + } + + // Create an isolated invocation. + auto [isolated_context, isolate] = detail::ProgramIsolate::AcquireIsolate( + *fiber, vm_context_, actual_isolation); + return ProgramInvocation::New(std::move(fiber), std::move(isolated_context), + vm_function_, invocation_model_, isolate); } std::string ProgramFunction::to_s() const { @@ -88,11 +102,13 @@ std::string ProgramFunction::to_s() const { ProgramModule ProgramModule::Load(System &system, const std::filesystem::path &path, bool mmap) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramModule::Load"); iree::file_contents_ptr contents; iree_file_read_flags_t flags = mmap ? IREE_FILE_READ_FLAG_MMAP : IREE_FILE_READ_FLAG_PRELOAD; - SHORTFIN_THROW_IF_ERROR(iree_file_read_contents( - path.c_str(), flags, system.host_allocator(), contents.for_output())); + SHORTFIN_THROW_IF_ERROR(iree_file_read_contents(path.string().c_str(), flags, + system.host_allocator(), + contents.for_output())); // Ownership hazard: iree_vm_bytecode_module_create only assumes ownership // of the contents when it returns *sucessfully*. 
In the exceptional case, @@ -103,7 +119,27 @@ ProgramModule ProgramModule::Load(System &system, system.vm_instance(), contents.const_buffer(), contents.deallocator(), system.host_allocator(), module.for_output())); contents.release(); // Must be invoked on success path only. - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); +} + +ProgramModule ProgramModule::ParameterProvider( + System &system, std::span params) { + std::vector providers; + providers.reserve(params.size()); + for (auto *param : params) { + iree_io_parameter_provider_t *provider = *param; + if (!provider) { + throw std::logic_error( + "Cannot pass uninitialized parameters to ParameterProvider"); + } + providers.push_back(provider); + } + + iree::vm_module_ptr module; + SHORTFIN_THROW_IF_ERROR(iree_io_parameters_module_create( + system.vm_instance(), providers.size(), providers.data(), + system.host_allocator(), module.for_output())); + return ProgramModule(system.shared_from_this(), std::move(module)); } std::string_view ProgramModule::name() const { @@ -135,14 +171,28 @@ std::vector ProgramModule::exports() const { // Program // -------------------------------------------------------------------------- // -Program Program::Load(std::shared_ptr scope, - std::span modules, Options options) { +Program Program::Load(std::span modules, + Options &&options) { + SHORTFIN_TRACE_SCOPE_NAMED("Program::Load"); std::vector all_modules; std::vector raw_devices; - // By default, bind all devices in the scope in order to the program. - for (Device *d : scope->raw_devices()) { - raw_devices.push_back(d->hal_device()); + System *system = nullptr; + // By default, bind all devices in the fiber in order to the program. + for (auto &it : options.devices) { + raw_devices.push_back(it->hal_device()); + } + + for (auto &mod : modules) { + if (system && &mod.system() != system) { + throw std::invalid_argument( + "Cannot create Program from modules loaded from multiple system " + "instances"); + } + system = &mod.system(); + } + if (!system) { + throw std::invalid_argument("Cannot create Program with no modules"); } // Add a HAL module. @@ -154,12 +204,11 @@ Program Program::Load(std::shared_ptr scope, // functionality (or module versions; iree_vm_module_dependency_t has the // minimum version required so you can switch between them, and whether they // are optional/required). - auto &system = scope->system(); iree::vm_module_ptr hal_module; - SHORTFIN_THROW_IF_ERROR( - iree_hal_module_create(system.vm_instance(), raw_devices.size(), - raw_devices.data(), IREE_HAL_MODULE_FLAG_NONE, - system.host_allocator(), hal_module.for_output())); + SHORTFIN_THROW_IF_ERROR(iree_hal_module_create( + system->vm_instance(), raw_devices.size(), raw_devices.data(), + IREE_HAL_MODULE_FLAG_NONE, iree_hal_module_debug_sink_stdio(stderr), + system->host_allocator(), hal_module.for_output())); all_modules.push_back(hal_module); // Add explicit modules. 
@@ -172,10 +221,10 @@ Program Program::Load(std::shared_ptr scope, iree_vm_context_flags_t flags = IREE_VM_CONTEXT_FLAG_CONCURRENT; if (options.trace_execution) flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION; SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules( - system.vm_instance(), flags, all_modules.size(), all_modules.data(), - system.host_allocator(), context.for_output())); + system->vm_instance(), flags, all_modules.size(), all_modules.data(), + system->host_allocator(), context.for_output())); - return Program(std::move(scope), std::move(context)); + return Program(std::move(context), options.isolation); } std::optional Program::LookupFunction(std::string_view name) { @@ -194,7 +243,7 @@ std::optional Program::LookupFunction(std::string_view name) { // TODO: Torch import is not setting the coarse-fences abi.model on // its functions. Get it from there instead of just assuming based on // name. - return ProgramFunction(scope_, vm_context_, f, + return ProgramFunction(vm_context_, f, isolation_, ProgramInvocationModel::COARSE_FENCES); } else if (!iree_status_is_not_found(status)) { SHORTFIN_THROW_IF_ERROR(status); @@ -206,7 +255,7 @@ std::optional Program::LookupFunction(std::string_view name) { vm_context_, to_iree_string_view(name), &f); if (iree_status_is_not_found(status)) return {}; SHORTFIN_THROW_IF_ERROR(status); - return ProgramFunction(scope_, vm_context_, f); + return ProgramFunction(vm_context_, f, isolation_); } ProgramFunction Program::LookupRequiredFunction(std::string_view name) { @@ -237,6 +286,15 @@ std::vector Program::exports() const { return results; } +void Program::PrepareIsolate(Fiber &fiber) { + if (isolation_ == ProgramIsolation::NONE) return; + auto [context, isolate] = + detail::ProgramIsolate::AcquireIsolate(fiber, vm_context_, isolation_); + if (isolate) { + detail::ProgramIsolate::ReleaseIsolate(fiber, std::move(context), isolate); + } +} + // -------------------------------------------------------------------------- // // ProgramInvocation // -------------------------------------------------------------------------- // @@ -264,18 +322,23 @@ void ProgramInvocation::Deleter::operator()(ProgramInvocation *inst) { } ProgramInvocation::ProgramInvocation() = default; -ProgramInvocation::~ProgramInvocation() { - if (!scheduled()) { - // This instance was dropped on the floor before scheduling. - // Clean up the initialization parameters. - iree::vm_context_ptr drop = - iree::vm_context_ptr::steal_reference(state.params.context); +ProgramInvocation::~ProgramInvocation() { ReleaseContext(); } + +void ProgramInvocation::ReleaseContext() { + if (vm_context_) { + if (isolate_) { + detail::ProgramIsolate::ReleaseIsolate(*fiber_, std::move(vm_context_), + isolate_); + } else { + vm_context_.reset(); + } } } ProgramInvocation::Ptr ProgramInvocation::New( - std::shared_ptr scope, iree::vm_context_ptr vm_context, - iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model) { + std::shared_ptr fiber, iree::vm_context_ptr vm_context, + iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate) { auto sig = iree_vm_function_signature(&vm_function); iree_host_size_t arg_count; iree_host_size_t result_count; @@ -313,9 +376,9 @@ ProgramInvocation::Ptr ProgramInvocation::New( Ptr inst(static_cast( static_cast(inst_storage.release())), Deleter()); - inst->scope_ = std::move(scope); - inst->state.params.context = - vm_context.release(); // Ref transfer to ProgramInvocation. 
+ inst->fiber_ = std::move(fiber); + inst->vm_context_ = std::move(vm_context); + inst->isolate_ = isolate; inst->state.params.function = vm_function; inst->state.params.invocation_model = invocation_model; inst->result_list_ = result_list; @@ -344,26 +407,27 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention( // Handle post-processing invocation model setup. if (invocation_model == ProgramInvocationModel::COARSE_FENCES) { // If we have a device_selection, set up to signal the leader account. + iree_hal_fence_t *maybe_wait_fence = nullptr; if (device_selection_) { - ScopedDevice scoped_device(*scope(), device_selection_); + ScopedDevice scoped_device(*fiber(), device_selection_); auto &sched_account = - scope()->scheduler().GetDefaultAccount(scoped_device); - iree_hal_fence_t *wait_fence = this->wait_fence(); + fiber()->scheduler().GetDefaultAccount(scoped_device); + maybe_wait_fence = this->wait_fence(); iree_hal_semaphore_t *timeline_sem = sched_account.timeline_sem(); uint64_t timeline_now = sched_account.timeline_idle_timepoint(); SHORTFIN_SCHED_LOG("Invocation {}: Wait on account timeline {}@{}", static_cast(this), static_cast(timeline_sem), timeline_now); IREE_RETURN_IF_ERROR( - iree_hal_fence_insert(wait_fence, timeline_sem, timeline_now)); + iree_hal_fence_insert(maybe_wait_fence, timeline_sem, timeline_now)); signal_sem_ = sched_account.timeline_sem(); signal_timepoint_ = sched_account.timeline_acquire_timepoint(); } // Push wait fence (or null if no wait needed). ::iree::vm::ref wait_ref; - if (wait_fence_) { - ::iree::vm::retain_ref(wait_fence()); + if (maybe_wait_fence) { + wait_ref = ::iree::vm::retain_ref(maybe_wait_fence); } IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(arg_list, wait_ref)); @@ -375,7 +439,7 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention( static_cast(signal_sem_), signal_timepoint_); IREE_RETURN_IF_ERROR( iree_hal_fence_create_at(signal_sem_, signal_timepoint_, - scope()->host_allocator(), &signal_ref)); + fiber()->host_allocator(), &signal_ref)); } IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(arg_list, signal_ref)); } else { @@ -390,21 +454,23 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention( ProgramInvocation::Future ProgramInvocation::Invoke( ProgramInvocation::Ptr invocation) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::Invoke"); invocation->CheckNotScheduled(); - Worker &worker = invocation->scope_->worker(); + Worker &worker = invocation->fiber_->worker(); // We're about to overwrite the instance level storage for params, so move // it to the stack and access there. Params params = invocation->state.params; auto schedule = [](ProgramInvocation *raw_invocation, Worker *worker, - iree_vm_context_t *owned_context, iree_vm_function_t function, ProgramInvocationModel invocation_model, std::optional failure_future) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::InvokeAsync"); auto complete_callback = [](void *user_data, iree_loop_t loop, iree_status_t status, iree_vm_list_t *outputs) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::Complete"); // Async invocation helpfully gives us a retained reference to the // outputs, but we already have one statically on the // ProgramInvocation. 
So release this one, which makes it safe to @@ -417,6 +483,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( ProgramInvocation::Ptr invocation( static_cast(user_data)); ProgramInvocation *raw_invocation = invocation.get(); + raw_invocation->ReleaseContext(); if (iree_status_is_ok(status)) { raw_invocation->future_->set_result(std::move(invocation)); } else { @@ -437,7 +504,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( // Multiple steps needed to schedule need to all exit via the same // path. if (iree_status_is_ok(status)) { - status = invocation->scope()->scheduler().FlushWithStatus(); + status = invocation->fiber()->scheduler().FlushWithStatus(); } if (iree_status_is_ok(status)) { status = invocation->FinalizeCallingConvention( @@ -446,7 +513,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (iree_status_is_ok(status)) { status = iree_vm_async_invoke(worker->loop(), &invocation->state.async_invoke_state, - owned_context, function, + invocation->vm_context_.get(), function, /*flags=*/IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr, /*inputs=*/invocation->arg_list(), @@ -455,10 +522,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( /*user_data=*/invocation.get()); } - // Regardless of status, the context reference we were holding is no - // longer needed. Drop it on the floor. - iree::vm_context_ptr::steal_reference(owned_context); - // On success, then the complete callback takes ownership of the // invocation, so we release it here and return. We have to treat // the invocation as possibly deallocated at this point, since the @@ -467,9 +530,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke( invocation.release(); } else if (failure_future) { // Requested to set any failure on the future. + invocation->ReleaseContext(); failure_future->set_failure(status); } else { // Synchronous: just throw. + invocation->ReleaseContext(); SHORTFIN_THROW_IF_ERROR(status); } }; @@ -481,14 +546,13 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (&worker == Worker::GetCurrent()) { // On the same worker: fast-path directly to the loop. - schedule(invocation.release(), &worker, params.context, params.function, + schedule(invocation.release(), &worker, params.function, params.invocation_model, /*failure_future=*/{}); } else { // Cross worker coordination: submit an external task to bootstrap. - auto bound_schedule = - std::bind(schedule, invocation.release(), &worker, params.context, - params.function, params.invocation_model, - /*failure_future=*/fork_future); + auto bound_schedule = std::bind(schedule, invocation.release(), &worker, + params.function, params.invocation_model, + /*failure_future=*/fork_future); worker.CallThreadsafe(bound_schedule); } @@ -509,7 +573,7 @@ iree::vm_opaque_ref ProgramInvocation::result_ref(iree_host_size_t i) { iree_hal_fence_t *ProgramInvocation::wait_fence() { if (!wait_fence_) { - wait_fence_ = scope_->scheduler().NewFence(); + wait_fence_ = fiber_->scheduler().NewFence(); } return wait_fence_.get(); } @@ -533,4 +597,137 @@ void ProgramInvocation::DeviceSelect(DeviceAffinity device_affinity) { device_selection_ |= device_affinity; } +std::string ProgramInvocation::to_s() { + return fmt::format("ProgramInvocation({}: result_size={})", + (scheduled_ ? 
"SCHEDULED" : "NOT_SCHEDULED"), + results_size()); +} + +// -------------------------------------------------------------------------- // +// BaseProgramParameters +// -------------------------------------------------------------------------- // + +BaseProgramParameters::~BaseProgramParameters() = default; + +// -------------------------------------------------------------------------- // +// StaticProgramParameters +// -------------------------------------------------------------------------- // + +StaticProgramParameters::StaticProgramParameters( + System &system, std::string_view parameter_scope, + iree_host_size_t max_concurrent_operations) + : host_allocator_(system.host_allocator()) { + SHORTFIN_THROW_IF_ERROR( + iree_io_parameter_index_create(host_allocator_, index_.for_output())); + SHORTFIN_THROW_IF_ERROR(iree_io_parameter_index_provider_create( + to_iree_string_view(parameter_scope), index_, max_concurrent_operations, + host_allocator_, provider_.for_output())); +} + +void StaticProgramParameters::Load(std::filesystem::path file_path, + LoadOptions options) { + SHORTFIN_TRACE_SCOPE_NAMED("StaticProgramParameters::Load"); + // Default format from extension. + if (options.format.empty()) { + options.format = file_path.extension().string(); + } + + // Open file. + iree_file_read_flags_t read_flags = IREE_FILE_READ_FLAG_DEFAULT; + if (options.mmap) { + read_flags = IREE_FILE_READ_FLAG_MMAP; + } else { + read_flags = IREE_FILE_READ_FLAG_PRELOAD; + } + iree_file_contents_t *file_contents = nullptr; + SHORTFIN_THROW_IF_ERROR(iree_file_read_contents( + file_path.string().c_str(), read_flags, host_allocator_, &file_contents)); + iree_io_file_handle_release_callback_t release_callback = { + +[](void *user_data, iree_io_file_handle_primitive_t handle_primitive) { + iree_file_contents_t *file_contents = (iree_file_contents_t *)user_data; + iree_file_contents_free(file_contents); + }, + file_contents, + }; + + // Wrap contents. + iree::io_file_handle_ptr file_handle; + iree_status_t status = iree_io_file_handle_wrap_host_allocation( + IREE_IO_FILE_ACCESS_READ, file_contents->buffer, release_callback, + host_allocator_, file_handle.for_output()); + if (!iree_status_is_ok(status)) { + iree_file_contents_free(file_contents); + SHORTFIN_THROW_IF_ERROR(status); + } + + // Parse. + SHORTFIN_THROW_IF_ERROR(iree_io_parse_file_index( + to_iree_string_view(options.format), file_handle.get(), index_.get())); +} + +// -------------------------------------------------------------------------- // +// ProgramIsolate +// -------------------------------------------------------------------------- // + +std::pair +detail::ProgramIsolate::AcquireIsolate(Fiber &fiber, + iree::vm_context_ptr root_context, + ProgramIsolation isolation) { + assert(isolation != ProgramIsolation::NONE && + "cannot AcquireIsolate when isolation == NONE"); + // Some isolation required. + detail::ProgramIsolate *isolate = nullptr; + { + iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_); + auto found_it = fiber.program_isolates_.find(root_context.get()); + if (found_it != fiber.program_isolates_.end()) { + isolate = found_it->second.get(); + } + if (isolate && !isolate->fork_contexts.empty()) { + // Fast path: there is an existing isolate and a context avaialable. + auto isolated_context = std::move(isolate->fork_contexts.back()); + isolate->fork_contexts.pop_back(); + return std::make_pair(std::move(isolated_context), isolate); + } else if (!isolate) { + // Initialize a new isolate accounting struct while in the lock. 
+ // Note that this can cause a fault for PER_FIBER mode if the call + // to fork fails below as it will leave the isolate with no available + // context and every future call will raise an exception indicating that + // the context is busy (vs trying to create a new one). This is deemed + // an acceptable situation for a system fault (which is the only reason + // a fork will fail). + auto [inserted_it, inserted] = + fiber.program_isolates_.insert(std::make_pair( + root_context.get(), + std::make_unique(root_context))); + isolate = inserted_it->second.get(); + } else if (isolation == ProgramIsolation::PER_FIBER) { + throw std::logic_error( + "Cannot make concurrent invocations of a PER_FIBER program from " + "the same Fiber. This typically means that two invocations were " + "attempted on the same program on the same fiber without an " + "await. Consider fixing adding appropriate sequencing or switching " + "to either PER_CALL or NONE isolation if appropriate for the use " + "case. This exception can also occur if the first invocation to this " + "Program failed, leaving no initialized Program for this fiber."); + } + } + + // Slow-path: fork needed (and possibly new isolate registration needed). + iree::vm_context_ptr new_context; + SHORTFIN_THROW_IF_ERROR(iree_vm_context_fork( + root_context.get(), fiber.host_allocator(), new_context.for_output())); + return std::make_pair(std::move(new_context), isolate); +} + +void detail::ProgramIsolate::ReleaseIsolate(Fiber &fiber, + iree::vm_context_ptr context, + detail::ProgramIsolate *isolate) { + assert(isolate && "attempt to release null isolate"); + { + iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_); + isolate->fork_contexts.push_back(std::move(context)); + } +} + } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/program.h b/shortfin/src/shortfin/local/program.h similarity index 60% rename from libshortfin/src/shortfin/local/program.h rename to shortfin/src/shortfin/local/program.h index d3c2984e7..450b29736 100644 --- a/libshortfin/src/shortfin/local/program.h +++ b/shortfin/src/shortfin/local/program.h @@ -22,8 +22,13 @@ namespace shortfin::local { -class SHORTFIN_API Scope; -class SHORTFIN_API System; +class BaseProgramParameters; +class Fiber; +class System; + +namespace detail { +struct ProgramIsolate; +} // namespace detail enum class ProgramInvocationModel { // Uses the coarse-fences invocation model. In this model, the last two @@ -36,6 +41,24 @@ enum class ProgramInvocationModel { UNKNOWN, }; +// The level of isolation that a program has with respect to concurrent use. +enum class ProgramIsolation { + // There is no isolation: Callers are completely on their own to only issue + // concurrent invocations if supported. + NONE = 0, + + // Each fiber in the system that makes calls into the program will have its + // own shallow fork of the module. This is done on-demand and the root + // program is retained for the lifetime of any referencing fibers. + // Concurrent calls on the same fiber are considered programming errors and + // will be flagged as such at an appropriate debug level. + PER_FIBER = 1, + + // Each call triggers a shallow fork of the module. This is the most expensive + // but safest way to ensure complete isolation of stateless invocations. + PER_CALL = 2, +}; + // State related to making an invocation of a function on a program. 
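[Reviewer note] The isolation level chosen in Program::Options becomes the default carried by every looked-up ProgramFunction; an individual invocation can still override it via the new optional argument on CreateInvocation(). A hedged sketch, where `program` and `fiber` (a std::shared_ptr to a Fiber) are assumed from context and "module.main" is a hypothetical export name:

  // Sketch only, not part of the patch.
  auto fn = program.LookupRequiredFunction("module.main");
  auto inv_default = fn.CreateInvocation(fiber);  // uses Options::isolation
  auto inv_forked = fn.CreateInvocation(
      fiber, shortfin::local::ProgramIsolation::PER_CALL);  // explicit override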
// // Since ownership of this object is transferred to the loop/callback and @@ -64,9 +87,10 @@ class SHORTFIN_API ProgramInvocation { static_assert(sizeof(Ptr) == sizeof(void *)); using Future = TypedFuture; - static Ptr New(std::shared_ptr scope, iree::vm_context_ptr vm_context, + static Ptr New(std::shared_ptr fiber, iree::vm_context_ptr vm_context, iree_vm_function_t &vm_function, - ProgramInvocationModel invocation_model); + ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate); ProgramInvocation(const ProgramInvocation &) = delete; ProgramInvocation &operator=(const ProgramInvocation &) = delete; ProgramInvocation &operator=(ProgramInvocation &&) = delete; @@ -78,8 +102,8 @@ class SHORTFIN_API ProgramInvocation { // accessed. bool scheduled() const { return scheduled_; } - // The scope this invocation was scheduled against. - Scope *scope() const { return scope_.get(); } + // The fiber this invocation was scheduled against. + Fiber *fiber() const { return fiber_.get(); } // Adds wait barriers to the invocation. For coarse fences invocations, these // will cause execution of the function to wait until all sempahores added @@ -127,9 +151,16 @@ class SHORTFIN_API ProgramInvocation { return std::make_pair(signal_sem_, signal_timepoint_); } + std::string to_s(); + private: ProgramInvocation(); void CheckNotScheduled(); + // Eagerly releases context when it is known that no further use of it can + // be made (allowing it to be returned to a pool prior to the invocation + // actually being recycled). Object destruction also does this, but possibly + // extending the context lifetime. + void ReleaseContext(); // Returns a pointer to the trailing arg list. iree_vm_list_t *arg_list(); @@ -153,8 +184,6 @@ class SHORTFIN_API ProgramInvocation { // This must not contain entities that require destruction or cannot be // trivially copied. struct Params { - // Context is retained upon construction and released when scheduled. - iree_vm_context_t *context; iree_vm_function_t function; ProgramInvocationModel invocation_model; }; @@ -165,7 +194,9 @@ class SHORTFIN_API ProgramInvocation { iree_vm_async_invoke_state_t async_invoke_state; } state; - std::shared_ptr scope_; + std::shared_ptr fiber_; + iree::vm_context_ptr vm_context_; + detail::ProgramIsolate *isolate_; iree_vm_list_t *result_list_ = nullptr; std::optional future_; iree::hal_fence_ptr wait_fence_; @@ -183,8 +214,11 @@ class SHORTFIN_API ProgramFunction { std::string_view name() const; std::string_view calling_convention() const; ProgramInvocationModel invocation_model() const { return invocation_model_; } - - ProgramInvocation::Ptr CreateInvocation(); + // Gets the default isolation level for this function. + ProgramIsolation isolation() const { return isolation_; } + ProgramInvocation::Ptr CreateInvocation( + std::shared_ptr fiber, + std::optional isolation = std::nullopt); std::string to_s() const; @@ -192,17 +226,16 @@ class SHORTFIN_API ProgramFunction { operator iree_vm_function_t &() { return vm_function_; } private: - ProgramFunction(std::shared_ptr scope, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + ProgramFunction(iree::vm_context_ptr vm_context, + iree_vm_function_t vm_function, ProgramIsolation isolation, std::optional invocation_model = {}); static ProgramInvocationModel GetInvocationModelFromFunction( iree_vm_function_t &f); - // The context that this function was resolved against. 
- std::shared_ptr scope_; iree::vm_context_ptr vm_context_; iree_vm_function_t vm_function_; + ProgramIsolation isolation_; ProgramInvocationModel invocation_model_; friend class Program; }; @@ -228,19 +261,28 @@ class SHORTFIN_API ProgramModule { std::string to_s() const; iree_vm_module_t *vm_module() const { return vm_module_; } std::string_view name() const; + System &system() const { return *system_; } // Loads a dynamic bytecode module (VMFB) from a path on the file system. static ProgramModule Load(System &system, const std::filesystem::path &path, bool mmap = true); + // Creates a ProgramModule that will provide the given list of parameters + // to modules loaded after it. In IREE parlance, this produces an + // 'io_parameters' VM module. + static ProgramModule ParameterProvider( + System &system, std::span params); + // Gets the name of all exported functions. std::vector exports() const; protected: - explicit ProgramModule(iree::vm_module_ptr vm_module) - : vm_module_(std::move(vm_module)) {} + explicit ProgramModule(std::shared_ptr system, + iree::vm_module_ptr vm_module) + : system_(std::move(system)), vm_module_(std::move(vm_module)) {} private: + std::shared_ptr system_; iree::vm_module_ptr vm_module_; }; @@ -248,7 +290,7 @@ class SHORTFIN_API ProgramModule { // having functions invoked on them. While the underlying programming model // is a bit broader and can be exploited in various advanced way, generally, // a program should be thought of as a fiber, and it is therefore bound to -// a Scope, which provides a logical thread of execution. By default, all +// a Fiber, which provides a logical thread of execution. By default, all // invocations will take place in logical order (there are certain ways to // violate this constraint safely that are provided for separately). // @@ -260,15 +302,19 @@ class SHORTFIN_API Program { struct Options { Options() {} + // Ordered list of devices to bind this program to. + std::span devices; + + // The isolation level to apply to program invocation. + ProgramIsolation isolation = ProgramIsolation::PER_FIBER; + // Enables program-wide execution tracing (to stderr). bool trace_execution = false; }; - // Loads a program attached to a scope with a list of user provided modules - // and options. - static Program Load(std::shared_ptr scope, - std::span modules, - Options options = {}); + // Load a program from a list of modules and options. + static Program Load(std::span modules, + Options &&options); // Looks up a public function by fully qualified name (i.e. module.function). // Returns nothing if not found. @@ -281,14 +327,97 @@ class SHORTFIN_API Program { // Gets the name of all exported functions. std::vector exports() const; + // Gets the default isolation level for all functions in this program. + ProgramIsolation isolation() const { return isolation_; } + + // Eagerly does any per-fiber isolation preparation for the program at a + // convenient point (usually init time) to avoid first-invocation overhead. 
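[Reviewer note] On the eager per-fiber preparation mentioned above: PrepareIsolate() forks the context once and immediately parks the fork in the fiber's isolate pool, so the first real invocation on that fiber takes the fast path in AcquireIsolate(). A hedged sketch of an init-time call, assuming `program` and `fiber` as elsewhere in this patch:

  // Sketch only: warm the per-fiber fork so the first invocation does not pay
  // the iree_vm_context_fork() cost (the call is a no-op when
  // isolation() == ProgramIsolation::NONE).
  program.PrepareIsolate(*fiber);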
+ void PrepareIsolate(Fiber &fiber); + private: - explicit Program(std::shared_ptr scope, - iree::vm_context_ptr vm_context) - : scope_(std::move(scope)), vm_context_(std::move(vm_context)) {} - std::shared_ptr scope_; + explicit Program(iree::vm_context_ptr vm_context, ProgramIsolation isolation) + : vm_context_(std::move(vm_context)), isolation_(isolation) {} + iree::vm_context_ptr vm_context_; - friend class Scope; + ProgramIsolation isolation_; + friend class Fiber; +}; + +// Base class for classes that can be interpreted as a provider of program +// parameters. +class SHORTFIN_API BaseProgramParameters { + public: + BaseProgramParameters() = default; + BaseProgramParameters(const BaseProgramParameters &) = delete; + BaseProgramParameters &operator=(const BaseProgramParameters &) = delete; + virtual ~BaseProgramParameters(); + + operator iree_io_parameter_provider_t *() { return provider_.get(); } + + protected: + iree::io_parameter_provider_ptr provider_; +}; + +// Pool of parameters that can be made available to ProgramModules. Each +// instance represents a unique "parameter scope" name which corresponds to +// some set of parameters that one or more ProgramModules were compiled to +// depend on. +// +// This class wraps the lower level iree_io_parameter_provider_t and a single +// iree_io_parameter_index_t. While the underlying APIs have many ways that +// they can be composed, populated and manipulated, this facility presumes +// that has been done elsewhere and primarily targets referencing them from +// somewhere statically known. More advanced use cases will be served by +// additional APIs. +class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters { + public: + StaticProgramParameters( + System &system, std::string_view parameter_scope, + iree_host_size_t max_concurrent_operations = + IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS); + + struct LoadOptions { + // File format. If empty, then it is inferred from the file name or + // contents. Can be one of "irpa", "gguf", "safetensors", etc. + std::string format; + + // Whether the backing file can be read. + bool readable = true; + // Whether the backing file can be written. + bool writable = false; + // Whether to mmap the file. + bool mmap = true; + }; + // Load parameters from a supported file format, applying no name + // transformation. + void Load(std::filesystem::path file_path, LoadOptions options); + void Load(std::filesystem::path file_path) { Load(file_path, LoadOptions()); } + + private: + iree_allocator_t host_allocator_; + iree::io_parameter_index_ptr index_; +}; + +namespace detail { +// See Fiber::program_isolates_. +struct ProgramIsolate { + ProgramIsolate(iree::vm_context_ptr parent_context) + : parent_context(std::move(parent_context)) {} + iree::vm_context_ptr parent_context; + std::vector fork_contexts; + + // Acquires an isolate for the given fiber. This will return a context which + // may be the original program context or may be a forked child that is + // available for use. It is only valid to call this when isolation != NONE. + static std::pair + AcquireIsolate(Fiber &fiber, iree::vm_context_ptr root_context, + ProgramIsolation isolation); + + // Releases an isolate obtained from a fiber in AcquireIsolate. 
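[Reviewer note] A hedged sketch tying the parameter pieces above together: load an .irpa/.gguf/.safetensors archive into a named parameter scope and expose it to subsequently loaded modules through the new io_parameters provider. The scope name and file path are hypothetical, and the span element type is assumed to be BaseProgramParameters * (angle brackets are stripped in the patch text).

  // Sketch only, not part of the patch.
  shortfin::local::StaticProgramParameters params(*system, "model");
  params.Load("/tmp/model_weights.irpa");  // format inferred from the extension
  std::array<shortfin::local::BaseProgramParameters *, 1> providers{&params};
  auto io_module =
      shortfin::local::ProgramModule::ParameterProvider(*system, providers);
  // io_module must precede the modules that consume the parameters in the
  // list handed to Program::Load().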
+ static void ReleaseIsolate(Fiber &fiber, iree::vm_context_ptr context, + ProgramIsolate *isolate); }; +}; // namespace detail } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/program_interfaces.h b/shortfin/src/shortfin/local/program_interfaces.h similarity index 98% rename from libshortfin/src/shortfin/local/program_interfaces.h rename to shortfin/src/shortfin/local/program_interfaces.h index b376c75d3..8ab46ab24 100644 --- a/libshortfin/src/shortfin/local/program_interfaces.h +++ b/shortfin/src/shortfin/local/program_interfaces.h @@ -16,7 +16,7 @@ namespace shortfin::local { -class SHORTFIN_API ProgramInvocation; +class ProgramInvocation; // The type of barrier that should be managed for a program resource. enum class ProgramResourceBarrier { diff --git a/libshortfin/src/shortfin/local/scheduler.cc b/shortfin/src/shortfin/local/scheduler.cc similarity index 75% rename from libshortfin/src/shortfin/local/scheduler.cc rename to shortfin/src/shortfin/local/scheduler.cc index 6f5581270..883951a20 100644 --- a/libshortfin/src/shortfin/local/scheduler.cc +++ b/shortfin/src/shortfin/local/scheduler.cc @@ -6,7 +6,7 @@ #include "shortfin/local/scheduler.h" -#include "shortfin/local/scope.h" +#include "shortfin/local/fiber.h" #include "shortfin/local/system.h" #include "shortfin/support/logging.h" @@ -60,43 +60,44 @@ void Account::active_deps_extend(iree_hal_semaphore_list_t sem_list) { } } -CompletionEvent Account::OnSync() { - bool hack_has_working_wait_source = false; - if (hack_has_working_wait_source) { - return CompletionEvent(sem_, idle_timepoint_); - } else { - // TODO: Burn this path with fire! No attempt has been made to make this - // particularly good: the backend is being implemented now to export - // HAL semaphores via iree_hal_semaphore_await, and that should be used - // when supported. This is merely here so as to unblock local progress. - iree::shared_event::ref satisfied(false); - iree::hal_semaphore_ptr sem = sem_; - auto idle_timepoint = idle_timepoint_; - SHORTFIN_SCHED_LOG("OnSync::Wait({}@{})", static_cast(sem.get()), - idle_timepoint); - scheduler_.system().blocking_executor().Schedule( - [sem = std::move(sem), idle_timepoint, satisfied]() { - iree_status_t status = iree_hal_semaphore_wait( - sem, idle_timepoint, iree_infinite_timeout()); - IREE_CHECK_OK(status); - SHORTFIN_SCHED_LOG("OnSync::Complete({}@{})", - static_cast(sem.get()), idle_timepoint); - satisfied->set(); - }); - return CompletionEvent(satisfied); - } +VoidFuture Account::OnSync() { + SHORTFIN_TRACE_SCOPE_NAMED("Account::OnSync"); + // TODO: Burn this path with fire! No attempt has been made to make this + // particularly good: the backend is being implemented now to export + // HAL semaphores via iree_hal_semaphore_await, and that should be used + // when supported. This is merely here so as to unblock local progress. 
+ // This should be something like: + // return CompletionEvent(sem_, idle_timepoint_); + iree::hal_semaphore_ptr sem = sem_; + auto idle_timepoint = idle_timepoint_; + SHORTFIN_SCHED_LOG("OnSync::Wait({}@{})", static_cast(sem.get()), + idle_timepoint); + VoidFuture future; + scheduler_.system().blocking_executor().Schedule([sem = std::move(sem), + idle_timepoint, future]() { + iree_status_t status = + iree_hal_semaphore_wait(sem, idle_timepoint, iree_infinite_timeout()); + if (!iree_status_is_ok(status)) { + const_cast(future).set_failure(status); + } else { + SHORTFIN_SCHED_LOG("OnSync::Complete({}@{})", + static_cast(sem.get()), idle_timepoint); + const_cast(future).set_success(); + } + }); + return future; } // -------------------------------------------------------------------------- // // TimelineResource // -------------------------------------------------------------------------- // -TimelineResource::TimelineResource(std::shared_ptr scope, +TimelineResource::TimelineResource(std::shared_ptr fiber, size_t semaphore_capacity) - : scope_(std::move(scope)) { + : fiber_(std::move(fiber)) { logging::construct("TimelineResource", this); SHORTFIN_THROW_IF_ERROR( - iree_hal_fence_create(semaphore_capacity, scope_->host_allocator(), + iree_hal_fence_create(semaphore_capacity, fiber_->host_allocator(), use_barrier_fence_.for_output())); } @@ -111,7 +112,7 @@ void TimelineResource::use_barrier_insert(iree_hal_semaphore_t *sem, } iree_allocator_t TimelineResource::host_allocator() { - return scope_->host_allocator(); + return fiber_->host_allocator(); } // -------------------------------------------------------------------------- // @@ -131,9 +132,11 @@ Scheduler::~Scheduler() { } } -void Scheduler::Initialize(std::span devices) { - for (Device *device : devices) { - accounts_.emplace_back(*this, device); +void Scheduler::Initialize( + std::span> devices) { + SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::Initialize"); + for (auto &it : devices) { + accounts_.emplace_back(*this, it.second); } for (Account &account : accounts_) { @@ -164,6 +167,7 @@ Account &Scheduler::GetDefaultAccount(ScopedDevice &device) { void Scheduler::AppendCommandBuffer(ScopedDevice &device, TransactionType tx_type, std::function callback) { + SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::AppendCommandBuffer"); Account &account = GetDefaultAccount(device); auto needed_affinity_bits = device.affinity().queue_affinity(); SHORTFIN_SCHED_LOG( @@ -207,6 +211,22 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device, account.active_queue_affinity_bits_ = needed_affinity_bits; account.active_deps_ = std::move(new_active_deps); account.active_command_buffer_ = std::move(new_cb); + + // Sence the command buffer will be submitted to signal the next + // timepoint on the main timeline, we must depend on its current value + // to be value (semaphores must strictly advance). This has the effect of + // serializing all submissions, which while correct, is not a particularly + // enlightened scheduling policy. + // TODO: Revisit this when scheduling is generalized and consider that such + // serialization be retained only as a debug feature. + iree_hal_semaphore_t *main_timeline_sem = account.sem_.get(); + account.active_deps_extend(iree_hal_semaphore_list_t{ + .count = 1, + .semaphores = &main_timeline_sem, + .payload_values = &account.idle_timepoint_, + }); + + // Signal an advance of the main timeline. 
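[Reviewer note] A worked illustration of the dependency added above (values illustrative):

  // With the account's idle timepoint currently at 7:
  //   command buffer N   : waits main-timeline sem >= 7, signals sem = 8
  //   command buffer N+1 : waits main-timeline sem >= 8, signals sem = 9
  // i.e. submissions against one account are strictly serialized, which is the
  // conservative policy the TODO above proposes to revisit.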
account.idle_timepoint_ += 1; SHORTFIN_SCHED_LOG( " : New command buffer (category={}, idle_timepoint={})", category, @@ -225,6 +245,7 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device, } iree_status_t Scheduler::FlushWithStatus() noexcept { + SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::FlushWithStatus"); // This loop is optimized for a small number of accounts, where it is // fine to just linearly probe. If this ever becomes cumbersome, we can // maintain a dirty list which is appended to when an account transitions @@ -258,9 +279,8 @@ iree_status_t Scheduler::FlushWithStatus() noexcept { .semaphores = &signal_sem, .payload_values = &signal_timepoint, }, - /*command_buffer_count=*/1, - /*command_buffers=*/&active_command_buffer, - /*binding_tables=*/&binding_tables)); + /*command_buffers=*/active_command_buffer, + /*binding_tables=*/binding_tables)); account.Reset(); } return iree_ok_status(); diff --git a/libshortfin/src/shortfin/local/scheduler.h b/shortfin/src/shortfin/local/scheduler.h similarity index 94% rename from libshortfin/src/shortfin/local/scheduler.h rename to shortfin/src/shortfin/local/scheduler.h index 5c514b1bb..cd493a41b 100644 --- a/libshortfin/src/shortfin/local/scheduler.h +++ b/shortfin/src/shortfin/local/scheduler.h @@ -16,14 +16,14 @@ namespace shortfin::local { -class SHORTFIN_API Scope; -class SHORTFIN_API ScopedDevice; -class SHORTFIN_API System; +class Fiber; +class ScopedDevice; +class System; namespace detail { -class SHORTFIN_API Account; -class SHORTFIN_API Scheduler; +class Account; +class Scheduler; // Transactions are accumulated into a command buffer by type and in // auto-flush mode, the command buffer is submitted upon a change of type. @@ -68,7 +68,7 @@ enum class TransactionMode { // Since TimelineResources are shared (i.e. across subspan storage, etc), // they are modeled as reference counted (using non atomics, since this is // "scoped" same thread access). They must only be held in a context that -// is keeping the containing Scope alive. +// is keeping the containing Fiber alive. // // Note to the future: in discussing the above, many cases were noted where // a more advanced programming model would be desirable in order to exercise @@ -145,7 +145,7 @@ class SHORTFIN_API TimelineResource { iree_allocator_t host_allocator(); private: - TimelineResource(std::shared_ptr scope, size_t semaphore_capacity); + TimelineResource(std::shared_ptr fiber, size_t semaphore_capacity); ~TimelineResource(); void Retain() { refcnt_++; } void Release() { @@ -154,8 +154,8 @@ class SHORTFIN_API TimelineResource { int refcnt_ = 0; - // Back reference to the owning scope. - std::shared_ptr scope_; + // Back reference to the owning fiber. + std::shared_ptr fiber_; // Non-owning mutation barrier semaphore and timepoint. The fact that this // is a single semaphore is an implementation detail that may be generalized @@ -199,7 +199,7 @@ class SHORTFIN_API Account { // Returns a future that is satisfied when the timeline of this account // reaches its current idle timepoint (i.e. all currently pending work // is complete). - CompletionEvent OnSync(); + VoidFuture OnSync(); private: void Initialize(); @@ -232,7 +232,7 @@ class SHORTFIN_API Account { friend class Scheduler; }; -// Handles scheduling state for a scope. +// Handles scheduling state for a fiber. 
class SHORTFIN_API Scheduler { public: Scheduler(System &system); @@ -261,9 +261,9 @@ class SHORTFIN_API Scheduler { // Gets a fresh TimelineResource which can be used for tracking resource // read/write and setting barriers. Note that these are all allocated fresh // on each call today but may be pooled in the future. - TimelineResource::Ref NewTimelineResource(std::shared_ptr scope) { + TimelineResource::Ref NewTimelineResource(std::shared_ptr fiber) { return TimelineResource::Ref( - new TimelineResource(std::move(scope), semaphore_count_)); + new TimelineResource(std::move(fiber), semaphore_count_)); } // Creates a new fence with capacity for all semaphores that are extant at @@ -273,7 +273,8 @@ class SHORTFIN_API Scheduler { System &system() { return system_; } private: - void Initialize(std::span devices); + void Initialize( + std::span> devices); System &system_; // Each distinct hal device gets an account. @@ -286,7 +287,7 @@ class SHORTFIN_API Scheduler { TransactionMode tx_mode_ = TransactionMode::EAGER; TransactionType current_tx_type_ = TransactionType::NONE; - friend class local::Scope; + friend class local::Fiber; }; } // namespace detail diff --git a/libshortfin/src/shortfin/local/system.cc b/shortfin/src/shortfin/local/system.cc similarity index 67% rename from libshortfin/src/shortfin/local/system.cc rename to shortfin/src/shortfin/local/system.cc index 23ecbc088..ef31bb001 100644 --- a/libshortfin/src/shortfin/local/system.cc +++ b/shortfin/src/shortfin/local/system.cc @@ -8,7 +8,8 @@ #include -#include "shortfin/local/scope.h" +#include "iree/hal/utils/allocators.h" +#include "shortfin/local/fiber.h" #include "shortfin/support/logging.h" namespace shortfin::local { @@ -19,6 +20,7 @@ namespace shortfin::local { System::System(iree_allocator_t host_allocator) : host_allocator_(host_allocator) { + SHORTFIN_TRACE_SCOPE_NAMED("System::System"); logging::construct("System", this); SHORTFIN_THROW_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, host_allocator_, @@ -28,6 +30,7 @@ System::System(iree_allocator_t host_allocator) } System::~System() { + SHORTFIN_TRACE_SCOPE_NAMED("System::~System"); logging::destruct("System", this); bool needs_shutdown = false; { @@ -60,14 +63,16 @@ System::~System() { } void System::Shutdown() { + SHORTFIN_TRACE_SCOPE_NAMED("System::Shutdown"); // Stop workers. - std::vector> local_workers; + std::vector local_workers; { iree::slim_mutex_lock_guard guard(lock_); if (!initialized_ || shutdown_) return; shutdown_ = true; - workers_by_name_.clear(); - local_workers.swap(workers_); + for (auto &w : workers_) { + local_workers.push_back(w.get()); + } } // Worker drain and shutdown. 
@@ -80,16 +85,17 @@ void System::Shutdown() { } } blocking_executor_.Kill(); - local_workers.clear(); } -std::shared_ptr System::CreateScope(Worker &worker, +std::shared_ptr System::CreateFiber(Worker &worker, std::span devices) { iree::slim_mutex_lock_guard guard(lock_); - return std::make_shared(shared_ptr(), worker, devices); + AssertRunning(); + return std::make_shared(shared_ptr(), worker, devices); } void System::InitializeNodes(int node_count) { + iree::slim_mutex_lock_guard guard(lock_); AssertNotInitialized(); if (!nodes_.empty()) { throw std::logic_error("System::InitializeNodes called more than once"); @@ -100,19 +106,26 @@ void System::InitializeNodes(int node_count) { } } -Queue &System::CreateQueue(Queue::Options options) { - iree::slim_mutex_lock_guard guard(lock_); - if (queues_by_name_.count(options.name) != 0) { - throw std::invalid_argument(fmt::format( - "Cannot create queue with duplicate name '{}'", options.name)); +QueuePtr System::CreateQueue(Queue::Options options) { + if (options.name.empty()) { + // Fast, lock-free path for anonymous queue creation. + return Queue::Create(std::move(options)); + } else { + // Lock and allocate a named queue. + iree::slim_mutex_lock_guard guard(lock_); + AssertRunning(); + if (queues_by_name_.count(options.name) != 0) { + throw std::invalid_argument(fmt::format( + "Cannot create queue with duplicate name '{}'", options.name)); + } + queues_.push_back(Queue::Create(std::move(options))); + Queue *unowned_queue = queues_.back().get(); + queues_by_name_[unowned_queue->options().name] = unowned_queue; + return *unowned_queue; } - queues_.push_back(std::make_unique(std::move(options))); - Queue *unowned_queue = queues_.back().get(); - queues_by_name_[unowned_queue->options().name] = unowned_queue; - return *unowned_queue; } -Queue &System::named_queue(std::string_view name) { +QueuePtr System::named_queue(std::string_view name) { iree::slim_mutex_lock_guard guard(lock_); auto it = queues_by_name_.find(name); if (it == queues_by_name_.end()) { @@ -140,6 +153,7 @@ Worker &System::CreateWorker(Worker::Options options) { Worker *unowned_worker; { iree::slim_mutex_lock_guard guard(lock_); + AssertRunning(); if (options.name == std::string_view("__init__")) { throw std::invalid_argument( "Cannot create worker '__init__' (reserved name)"); @@ -161,6 +175,7 @@ Worker &System::CreateWorker(Worker::Options options) { Worker &System::init_worker() { iree::slim_mutex_lock_guard guard(lock_); + AssertRunning(); auto found_it = workers_by_name_.find("__init__"); if (found_it != workers_by_name_.end()) { return *found_it->second; @@ -178,6 +193,7 @@ Worker &System::init_worker() { void System::InitializeHalDriver(std::string_view moniker, iree::hal_driver_ptr driver) { + iree::slim_mutex_lock_guard guard(lock_); AssertNotInitialized(); auto &slot = hal_drivers_[moniker]; if (slot) { @@ -188,6 +204,7 @@ void System::InitializeHalDriver(std::string_view moniker, } void System::InitializeHalDevice(std::unique_ptr device) { + iree::slim_mutex_lock_guard guard(lock_); AssertNotInitialized(); auto device_name = device->name(); auto [it, success] = named_devices_.try_emplace(device_name, device.get()); @@ -199,6 +216,14 @@ void System::InitializeHalDevice(std::unique_ptr device) { retained_devices_.push_back(std::move(device)); } +Device *System::FindDeviceByName(std::string_view name) { + auto it = named_devices_.find(name); + if (it == named_devices_.end()) { + return nullptr; + } + return it->second; +} + void System::FinishInitialization() { 
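[Reviewer note] A hedged sketch of the two CreateQueue() paths introduced above, assuming Queue::Options exposes the `name` field consulted by the implementation; "transfer" is a hypothetical queue name.

  // Sketch only, not part of the patch.
  auto anon = system->CreateQueue();                  // lock-free, unregistered
  shortfin::local::Queue::Options opts;
  opts.name = "transfer";
  auto named = system->CreateQueue(std::move(opts));  // locked, registered by name
  auto again = system->named_queue("transfer");       // later lookup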
iree::slim_mutex_lock_guard guard(lock_); AssertNotInitialized(); @@ -207,6 +232,7 @@ void System::FinishInitialization() { int64_t System::AllocateProcess(detail::BaseProcess *p) { iree::slim_mutex_lock_guard guard(lock_); + AssertRunning(); int pid = next_pid_++; processes_by_pid_[pid] = p; return pid; @@ -217,4 +243,42 @@ void System::DeallocateProcess(int64_t pid) { processes_by_pid_.erase(pid); } +// -------------------------------------------------------------------------- // +// SystemBuilder +// -------------------------------------------------------------------------- // + +void SystemBuilder::ConfigureAllocators(const std::vector &specs, + iree_hal_device_t *device, + std::string_view device_debug_desc) { + if (specs.empty()) return; + std::vector spec_views; + spec_views.reserve(specs.size()); + for (auto &spec : specs) { + spec_views.push_back(to_iree_string_view(spec)); + } + + logging::info("Configure allocator {} = [{}]", device_debug_desc, + fmt::join(specs, " ; ")); + + SHORTFIN_THROW_IF_ERROR(iree_hal_configure_allocator_from_specs( + spec_views.size(), spec_views.data(), device)); +} + +std::vector SystemBuilder::GetConfigAllocatorSpecs( + std::optional specific_config_key) { + std::optional value; + if (specific_config_key) { + value = config_options().GetOption(*specific_config_key); + } + if (!value) { + value = config_options().GetOption("allocators"); + } + if (!value) { + return {}; + } + + auto split_views = ConfigOptions::Split(*value, ';'); + return std::vector(split_views.begin(), split_views.end()); +} + } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/system.h b/shortfin/src/shortfin/local/system.h similarity index 67% rename from libshortfin/src/shortfin/local/system.h rename to shortfin/src/shortfin/local/system.h index 82fd4b489..df0c45c4d 100644 --- a/libshortfin/src/shortfin/local/system.h +++ b/shortfin/src/shortfin/local/system.h @@ -19,6 +19,7 @@ #include "shortfin/local/worker.h" #include "shortfin/support/api.h" #include "shortfin/support/blocking_executor.h" +#include "shortfin/support/config.h" #include "shortfin/support/iree_concurrency.h" #include "shortfin/support/iree_helpers.h" #include "shortfin/support/stl_extras.h" @@ -29,7 +30,7 @@ namespace detail { class BaseProcess; } // namespace detail -class Scope; +class Fiber; class System; class SystemBuilder; @@ -50,19 +51,19 @@ class SystemBuilder; // tenant), and all owning references to the System are via // `std::shared_ptr`. Every object in the system must either be // a managed child of the system or own a system reference. -// 2. Scope: Binds any number of devices to a coherent schedule, rooted on +// 2. Fiber: Binds any number of devices to a coherent schedule, rooted on // a Worker. Scopes are independent of the system and there are generally -// as many as needed logical concurrency in the application. Each scope +// as many as needed logical concurrency in the application. Each fiber // holds a system reference by way of a `std::shared_ptr`. These // are still heavy-weight objects mostly created at initialization time -// and are therefore managed held as a `std::shared_ptr` by anything +// and are therefore managed held as a `std::shared_ptr` by anything // that depends on them. // 3. TimelineResource: Any resource in the system (i.e. buffer, // synchronization, object, etc) will hold a unique TimelineResource. These // are light-weight objects managed via intrusive reference counting by // their contained `TimelineResource::Ref` class. 
Each `TimelineResource` -// maintains a `std::shared_ptr` back reference to its owning -// scope. +// maintains a `std::shared_ptr` back reference to its owning +// fiber. // // Leaf objects can have any lifetime that they wish, so long as they maintain // an appropriate ownership reference into the System hierarchy above. This @@ -82,6 +83,15 @@ class SHORTFIN_API System : public std::enable_shared_from_this { System(const System &) = delete; ~System(); + // One shot creation factory that is the equivalent of: + // SystemBuilder::ForSystem( + // host_allocator, system_type, + // std::move(config_options))->CreateSystem() + // Undef validation will be done on the config options prior to returning. + static std::shared_ptr Create(iree_allocator_t host_allocator, + std::string_view system_type, + ConfigOptions config_options = {}); + // Explicit shutdown (vs in destructor) is encouraged. void Shutdown(); @@ -95,16 +105,15 @@ class SHORTFIN_API System : public std::enable_shared_from_this { // Topology access. std::span nodes() { return {nodes_}; } std::span devices() { return {devices_}; } - const std::unordered_map &named_devices() { + const std::unordered_map named_devices() { return named_devices_; } + Device *FindDeviceByName(std::string_view name); // Queue access. - Queue &CreateQueue(Queue::Options options); - Queue &named_queue(std::string_view name); - const std::unordered_map named_queues() { - return queues_by_name_; - } + QueuePtr CreateQueue(Queue::Options options); + QueuePtr CreateQueue() { return CreateQueue(Queue::Options()); } + QueuePtr named_queue(std::string_view name); // Access the system wide blocking executor thread pool. This can be used // to execute thunks that can block on a dedicated thread and is needed @@ -112,10 +121,10 @@ class SHORTFIN_API System : public std::enable_shared_from_this { BlockingExecutor &blocking_executor() { return blocking_executor_; } // Scopes. - // Creates a new Scope bound to this System (it will internally + // Creates a new Fiber bound to this System (it will internally // hold a reference to this instance). All devices in system order will be - // added to the scope. - std::shared_ptr CreateScope(Worker &worker, + // added to the fiber. + std::shared_ptr CreateFiber(Worker &worker, std::span devices); // Creates and starts a worker (if it is configured to run in a thread). @@ -141,13 +150,20 @@ class SHORTFIN_API System : public std::enable_shared_from_this { void FinishInitialization(); private: - void AssertNotInitialized() { + void AssertNotInitialized() SHORTFIN_REQUIRES_LOCK(lock_) { if (initialized_) { throw std::logic_error( "System::Initialize* methods can only be called during " "initialization"); } } + void AssertRunning() SHORTFIN_REQUIRES_LOCK(lock_) { + if (!initialized_ || shutdown_) { + throw std::logic_error( + "System manipulation methods can only be called when initialized and " + "not shutdown"); + } + } // Allocates a process in the process table and returns its new pid. // This is done on process construction. Note that it acquires the @@ -173,7 +189,9 @@ class SHORTFIN_API System : public std::enable_shared_from_this { // after initialization, but mainly this is for keeping them alive. std::unordered_map hal_drivers_; - // Map of device name to a SystemDevice. + // Map of device name to a SystemDevice. Note that devices are immortal and + // enumerated at initialization time. As such, they are accessed without + // locking. 
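[Reviewer note] A hedged sketch of the new one-shot factory and lock-free device lookup. The system type string follows SystemBuilder::ForSystem() ("hostcpu", "amdgpu", ...); the device name below is hypothetical, and FindDeviceByName() returns nullptr when the name is absent.

  // Sketch only, not part of the patch.
  auto system =
      shortfin::local::System::Create(iree_allocator_system(), "hostcpu");
  if (shortfin::local::Device *dev = system->FindDeviceByName("hostcpu:0:0@0")) {
    // ... use dev ...
  }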
std::vector> retained_devices_; std::unordered_map named_devices_; std::vector devices_; @@ -185,22 +203,25 @@ class SHORTFIN_API System : public std::enable_shared_from_this { BlockingExecutor blocking_executor_; // Queues. - std::vector> queues_; - std::unordered_map queues_by_name_; + std::vector> queues_ SHORTFIN_GUARDED_BY(lock_); + std::unordered_map queues_by_name_ + SHORTFIN_GUARDED_BY(lock_); // Workers. - std::vector> workers_; + std::vector> workers_ SHORTFIN_GUARDED_BY(lock_); std::vector> worker_initializers_; - std::unordered_map workers_by_name_; + std::unordered_map workers_by_name_ + SHORTFIN_GUARDED_BY(lock_); // Process management. - int next_pid_ = 1; - std::unordered_map processes_by_pid_; + int next_pid_ SHORTFIN_GUARDED_BY(lock_) = 1; + std::unordered_map processes_by_pid_ + SHORTFIN_GUARDED_BY(lock_); // Whether initialization is complete. If true, various low level // mutations are disallowed. - bool initialized_ = false; - bool shutdown_ = false; + bool initialized_ SHORTFIN_GUARDED_BY(lock_) = false; + bool shutdown_ SHORTFIN_GUARDED_BY(lock_) = false; friend class detail::BaseProcess; }; @@ -209,18 +230,49 @@ using SystemPtr = std::shared_ptr; // Base class for configuration objects for setting up a System. class SHORTFIN_API SystemBuilder { public: - SystemBuilder(iree_allocator_t host_allocator) - : host_allocator_(host_allocator) {} + SystemBuilder(iree_allocator_t host_allocator, + ConfigOptions config_options = {}) + : host_allocator_(host_allocator), + config_options_(std::move(config_options)) {} SystemBuilder() : SystemBuilder(iree_allocator_system()) {} virtual ~SystemBuilder() = default; + // Creates a SystemBuilder subclass for a given named system (i.e. + // "hostcpu", "amdgpu", etc). + static std::unique_ptr ForSystem( + iree_allocator_t host_allocator, std::string_view system_type, + ConfigOptions config_options = {}); + iree_allocator_t host_allocator() { return host_allocator_; } + const ConfigOptions &config_options() const { return config_options_; } // Construct a System virtual SystemPtr CreateSystem() = 0; + protected: + // Uses the iree_hal_configure_allocator_from_specs() API to configure + // allocators for a device. The specs are parsed from the given config_key + // if it exists and take the form: + // some_allocator + // some_allocator:key=value + // some_allocator:key=value,key=value + // some_allocator:key=value,key=value;other_allocator:key=value + void ConfigureAllocators(const std::vector &specs, + iree_hal_device_t *device, + std::string_view device_debug_desc); + + // Gets a list of allocator specs from the config. If `specific_config_key` + // is given, this will be consulted first and used if available. Otherwise, + // "allocators" will be used. For SystemBuilders that handle multiple + // device types, the specific key will be something like "amdgpu_allocators" + // or "hostcpu_allocators" and will be used to allow independently scoped + // allocator specs. + std::vector GetConfigAllocatorSpecs( + std::optional specific_config_key); + private: const iree_allocator_t host_allocator_; + ConfigOptions config_options_; }; } // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/systems/CMakeLists.txt b/shortfin/src/shortfin/local/systems/CMakeLists.txt new file mode 100644 index 000000000..b1c9d8b44 --- /dev/null +++ b/shortfin/src/shortfin/local/systems/CMakeLists.txt @@ -0,0 +1,64 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +set(_SYSTEM_COMPONENTS) + +shortfin_cc_component( + NAME + shortfin_systems_host + HDRS + host.h + SRCS + host.cc + COMPONENTS + shortfin_local + shortfin_support + DEFINES + SHORTFIN_HAVE_HOSTCPU + DEPS + iree_hal_drivers_local_task_task_driver + iree_hal_local_executable_loader + iree_hal_local_executable_plugin + iree_hal_local_executable_plugin_manager + iree_hal_local_loaders_registration_registration + iree_hal_local_local + iree_task_api + iree_task_task +) +list(APPEND _SYSTEM_COMPONENTS shortfin_systems_host) +target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_HOSTCPU) + +if(SHORTFIN_SYSTEMS_AMDGPU) + shortfin_cc_component( + NAME + shortfin_systems_amdgpu + HDRS + amdgpu.h + SRCS + amdgpu.cc + DEFINES + SHORTFIN_HAVE_AMDGPU + COMPONENTS + shortfin_local + shortfin_support + DEPS + iree_hal_drivers_hip_hip + ) + list(APPEND _SYSTEM_COMPONENTS shortfin_systems_amdgpu) + target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_AMDGPU) +endif() + +shortfin_cc_component( + NAME + shortfin_systems_factory + SRCS + factory.cc + COMPONENTS + ${_SYSTEM_COMPONENTS} +) + +set_property(GLOBAL APPEND + PROPERTY SHORTFIN_LIB_OPTIONAL_COMPONENTS ${_SYSTEM_COMPONENTS}) diff --git a/shortfin/src/shortfin/local/systems/amdgpu.cc b/shortfin/src/shortfin/local/systems/amdgpu.cc new file mode 100644 index 000000000..cecedd1a0 --- /dev/null +++ b/shortfin/src/shortfin/local/systems/amdgpu.cc @@ -0,0 +1,253 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/local/systems/amdgpu.h" + +#include "shortfin/support/logging.h" +#include "shortfin/support/sysconfig.h" + +namespace shortfin::local::systems { + +namespace { +const std::string_view SYSTEM_DEVICE_CLASS = "amdgpu"; +const std::string_view LOGICAL_DEVICE_CLASS = "gpu"; +const std::string_view HAL_DRIVER_PREFIX = "hip"; +} // namespace + +AMDGPUSystemBuilder::AMDGPUSystemBuilder(iree_allocator_t host_allocator, + ConfigOptions options) + : HostCPUSystemBuilder(host_allocator, std::move(options)), + available_devices_(host_allocator) { + iree_hal_hip_device_params_initialize(&default_device_params_); + InitializeDefaultSettings(); + config_options().ValidateUndef(); +} + +AMDGPUSystemBuilder::~AMDGPUSystemBuilder() = default; + +void AMDGPUSystemBuilder::InitializeDefaultSettings() { + // Library search path. + std::optional search_path = + config_options().GetOption("amdgpu_hip_dylib_path"); + if (!search_path) { + // Fall back to the raw "IREE_HIP_DYLIB_PATH" for compatibility with IREE + // tools. + search_path = config_options().GetRawEnv("IREE_HIP_DYLIB_PATH"); + } + if (search_path) { + for (auto entry : config_options().Split(*search_path, ';')) { + hip_lib_search_paths_.push_back(std::string(entry)); + } + } + + // Gets allocator specs from either "amdgpu_allocators" or the fallback + // "allocators". + amdgpu_allocator_specs_ = GetConfigAllocatorSpecs("amdgpu_allocators"); + + // Whether to use async allocations if the device supports them (default + // true). There are various reasons to disable this in different usage + // scenarios. + default_device_params_.async_allocations = + config_options().GetBool("amdgpu_async_allocations", true); + + // HIP options. 
+ // "amdgpu_tracing_level": Matches IREE flag --hip_tracing: + // Permissible values are: + // 0 : stream tracing disabled. + // 1 : coarse command buffer level tracing enabled. + // 2 : fine-grained kernel level tracing enabled. + auto tracing_level = + config_options().GetInt("amdgpu_tracing_level", /*non_negative=*/true); + default_device_params_.stream_tracing = tracing_level ? *tracing_level : 2; + + // Override logical_devices_per_physical_device if present. + auto logical_devices_per_physical_device = config_options().GetInt( + "amdgpu_logical_devices_per_physical_device", /*non_negative=*/true); + if (logical_devices_per_physical_device) { + logical_devices_per_physical_device_ = *logical_devices_per_physical_device; + } + + // CPU devices. + cpu_devices_enabled_ = config_options().GetBool("amdgpu_cpu_devices_enabled"); + + // Visible devices. + std::optional visible_devices_option = + config_options().GetOption("amdgpu_visible_devices"); + if (visible_devices_option) { + auto splits = config_options().Split(*visible_devices_option, ';'); + visible_devices_.emplace(); + for (auto split_sv : splits) { + visible_devices_->emplace_back(split_sv); + } + } +} + +void AMDGPUSystemBuilder::Enumerate() { + if (hip_hal_driver_) return; + SHORTFIN_TRACE_SCOPE_NAMED("AMDGPUSystemBuilder::Enumerate"); + + iree_hal_hip_driver_options_t driver_options; + iree_hal_hip_driver_options_initialize(&driver_options); + + // Search path. + std::vector hip_lib_search_path_sv; + hip_lib_search_path_sv.resize(hip_lib_search_paths_.size()); + for (size_t i = 0; i < hip_lib_search_paths_.size(); ++i) { + hip_lib_search_path_sv[i].data = hip_lib_search_paths_[i].data(); + hip_lib_search_path_sv[i].size = hip_lib_search_paths_[i].size(); + } + driver_options.hip_lib_search_paths = hip_lib_search_path_sv.data(); + driver_options.hip_lib_search_path_count = hip_lib_search_path_sv.size(); + + SHORTFIN_THROW_IF_ERROR(iree_hal_hip_driver_create( + IREE_SV("hip"), &driver_options, &default_device_params_, + host_allocator(), hip_hal_driver_.for_output())); + + // Get available devices and filter into visible_devices_. + SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices( + hip_hal_driver_, host_allocator(), &available_devices_count_, + available_devices_.for_output())); + for (iree_host_size_t i = 0; i < available_devices_count_; ++i) { + iree_hal_device_info_t *info = &available_devices_.get()[i]; + logging::debug("Enumerated available AMDGPU device: {} ({})", + to_string_view(info->path), to_string_view(info->name)); + } +} + +std::vector AMDGPUSystemBuilder::GetAvailableDeviceIds() { + Enumerate(); + std::vector results; + for (iree_host_size_t i = 0; i < available_devices_count_; ++i) { + iree_hal_device_info_t *info = &available_devices_.get()[i]; + results.emplace_back(to_string_view(info->path)); + } + return results; +} + +SystemPtr AMDGPUSystemBuilder::CreateSystem() { + SHORTFIN_TRACE_SCOPE_NAMED("AMDGPUSystemBuilder::CreateSystem"); + auto lsys = std::make_shared(host_allocator()); + Enumerate(); + + // TODO: Real NUMA awareness. + lsys->InitializeNodes(1); + lsys->InitializeHalDriver(SYSTEM_DEVICE_CLASS, hip_hal_driver_); + + // Must have some device visible. + if (available_devices_count_ == 0 && + (!visible_devices_ || visible_devices_->empty())) { + throw std::invalid_argument("No AMDGPU devices found/visible"); + } + + // If a visibility list, process that. 
+  std::vector<iree_hal_device_id_t> used_device_ids;
+  if (visible_devices_) {
+    used_device_ids.reserve(visible_devices_->size());
+    // In large scale partitioned cases, there could be 64+ devices, so we want
+    // to avoid a linear scan. Also, in some cases with partitioned physical
+    // devices, there can be multiple devices with the same id. In this case,
+    // the visibility list also connotes order/repetition, so we store with
+    // vectors.
+    std::unordered_map<std::string_view,
+                       std::vector<std::optional<iree_hal_device_id_t>>>
+        visible_device_hal_ids;
+    for (size_t i = 0; i < available_devices_count_; ++i) {
+      iree_hal_device_info_t *info = &available_devices_.get()[i];
+      visible_device_hal_ids[to_string_view(info->path)].push_back(
+          info->device_id);
+    }
+
+    for (auto &visible_device_id : *visible_devices_) {
+      auto found_it = visible_device_hal_ids.find(visible_device_id);
+      if (found_it == visible_device_hal_ids.end()) {
+        throw std::invalid_argument(fmt::format(
+            "Requested visible device '{}' was not found on the system "
+            "(available: '{}')",
+            visible_device_id, fmt::join(GetAvailableDeviceIds(), ";")));
+      }
+
+      bool found = false;
+      auto &bucket = found_it->second;
+      for (auto &hal_id : bucket) {
+        if (hal_id) {
+          found = true;
+          used_device_ids.push_back(*hal_id);
+          hal_id.reset();
+          break;
+        }
+      }
+
+      if (!found) {
+        throw std::invalid_argument(
+            fmt::format("Requested visible device '{}' was requested more "
+                        "times than present on the system ({})",
+                        visible_device_id, bucket.size()));
+      }
+    }
+  } else {
+    for (iree_host_size_t i = 0; i < available_devices_count_; ++i) {
+      iree_hal_device_info_t *info = &available_devices_.get()[i];
+      used_device_ids.push_back(info->device_id);
+    }
+  }
+
+  // Estimate the resource requirements for the requested number of devices.
+  // As of 2024-11-08, the number of file handles required to open 64 device
+  // partitions was 31 times the number to open one device. Because it is not
+  // good to run near the limit, we conservatively round that up to 64 above
+  // an arbitrary baseline of 768. This means that on a small, four device
+  // system, we will not request to raise limits for the Linux default of
+  // 1024 file handles, but we will raise for everything larger (which tends
+  // to be where the problems are).
+  size_t expected_device_count =
+      used_device_ids.size() * logical_devices_per_physical_device_;
+  if (!sysconfig::EnsureFileLimit(expected_device_count * 64 + 768)) {
+    logging::error(
+        "Could not ensure sufficient file handles for minimum operations: "
+        "Suggest setting explicit limits with `ulimit -n` and system settings");
+  }
+
+  // Initialize all used GPU devices.
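For reviewers tracing the naming scheme: the Python tests added later in this patch assert device names such as "amdgpu:0:0@0", and "amdgpu:0:0@1" when logical_devices_per_physical_device is 2, which lines up with the DeviceAddress arguments constructed in the loop below (instance ordinal, then queue ordinal, then the logical index after the '@').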
+ for (size_t instance_ordinal = 0; instance_ordinal < used_device_ids.size(); + ++instance_ordinal) { + iree_hal_device_id_t device_id = used_device_ids[instance_ordinal]; + for (size_t logical_index = 0; + logical_index < logical_devices_per_physical_device_; + ++logical_index) { + iree::hal_device_ptr device; + SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id( + hip_hal_driver_, device_id, 0, nullptr, host_allocator(), + device.for_output())); + DeviceAddress address( + /*system_device_class=*/SYSTEM_DEVICE_CLASS, + /*logical_device_class=*/LOGICAL_DEVICE_CLASS, + /*hal_driver_prefix=*/HAL_DRIVER_PREFIX, + /*instance_ordinal=*/instance_ordinal, + /*queue_ordinal=*/0, + /*instance_topology_address=*/{logical_index}); + ConfigureAllocators(amdgpu_allocator_specs_, device, address.device_name); + lsys->InitializeHalDevice(std::make_unique( + address, + /*hal_device=*/device, + /*node_affinity=*/0, + /*capabilities=*/static_cast(Device::Capabilities::NONE))); + } + } + + // Initialize CPU devices if requested. + if (cpu_devices_enabled_) { + // Delegate to the HostCPUSystemConfig to configure CPU devices. + // This will need to become more complicated and should happen after + // GPU configuration when mating NUMA nodes, etc. + InitializeHostCPUDefaults(); + auto *driver = InitializeHostCPUDriver(*lsys); + InitializeHostCPUDevices(*lsys, driver); + } + + lsys->FinishInitialization(); + return lsys; +} + +} // namespace shortfin::local::systems diff --git a/shortfin/src/shortfin/local/systems/amdgpu.h b/shortfin/src/shortfin/local/systems/amdgpu.h new file mode 100644 index 000000000..7e3d86abe --- /dev/null +++ b/shortfin/src/shortfin/local/systems/amdgpu.h @@ -0,0 +1,114 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H +#define SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H + +#include + +#include "iree/hal/drivers/hip/api.h" +#include "shortfin/local/system.h" +#include "shortfin/local/systems/host.h" +#include "shortfin/support/api.h" +#include "shortfin/support/iree_helpers.h" + +namespace shortfin::local::systems { + +// AMD GPU device subclass. +class SHORTFIN_API AMDGPUDevice : public Device { + public: + using Device::Device; +}; + +// System configuration for some subset of AMD GPUs connected to the local +// system. Note that this inherits from HostCPUSystemBuilder, allowing joint +// configuration of a heterogenous CPU/GPU system. Depending on the specific +// system, this can involve more than simple starting CPU drivers: datacenter +// GPU systems have specific NUMA configurations that need to be mated. +class SHORTFIN_API AMDGPUSystemBuilder : public HostCPUSystemBuilder { + public: + AMDGPUSystemBuilder(iree_allocator_t host_allocator, + ConfigOptions options = {}); + AMDGPUSystemBuilder() : AMDGPUSystemBuilder(iree_allocator_system()) {} + ~AMDGPUSystemBuilder(); + + SystemPtr CreateSystem() override; + + // Settings. + bool &cpu_devices_enabled() { return cpu_devices_enabled_; } + + // See iree_hal_hip_driver_options_t::hip_lib_search_paths. Each is either + // a directory or "file:" prefixed path to a specific HIP dynamic library. + // This is typically libamdhip64.so or amdhip64.dll. + // If the environment variable IREE_HIP_DYLIB_PATH is present, then it is + // split on ';' and each entry added here (for compatibility with IREE + // tools). 
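A sketch of the two ways to point the builder at a specific HIP runtime; the paths are illustrative only:

    #include "shortfin/local/systems/amdgpu.h"

    void ConfigureHipSearchPathSketch() {
      shortfin::local::systems::AMDGPUSystemBuilder builder(iree_allocator_system());
      // Programmatic form; must precede CreateSystem(), which triggers enumeration.
      builder.hip_lib_search_paths().push_back("/opt/rocm/lib");  // directory
      builder.hip_lib_search_paths().push_back(
          "file:/opt/rocm/lib/libamdhip64.so");  // specific dylib
      auto system = builder.CreateSystem();
      (void)system;
    }

    // Environment form, read when the builder is constructed:
    //   export IREE_HIP_DYLIB_PATH="/opt/rocm/lib;/usr/local/lib"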
+ // Changing these paths after enumeration has no effect. + std::vector &hip_lib_search_paths() { + return hip_lib_search_paths_; + } + + // If set, then the system will be created to only include devices with + // the corresponding id (in the order listed). + std::optional> &visible_devices() { + return visible_devices_; + }; + + // Allocator specs to apply to amdgpu devices in this builder. + std::vector &amdgpu_allocator_specs() { + return amdgpu_allocator_specs_; + } + + // Whether to use async allocations if the device supports them (default + // true). There are various reasons to disable this in different usage + // scenarios. + bool &async_allocations() { return default_device_params_.async_allocations; } + + // "amdgpu_tracing_level": Matches IREE flag --hip_tracing: + // Permissible values are: + // 0 : stream tracing disabled. + // 1 : coarse command buffer level tracing enabled. + // 2 : fine-grained kernel level tracing enabled. + int32_t &tracing_level() { return default_device_params_.stream_tracing; } + + // The number of logical HAL devices to create per physical, visible device. + // This form of topology can be useful in certain cases where we aim to have + // oversubscription emulating what would usually be achieved with process + // level isolation. Defaults to 1. + size_t &logical_devices_per_physical_device() { + return logical_devices_per_physical_device_; + } + + // Gets all enumerated available device ids. This triggers enumeration, so + // any settings required for that must already be set. This does no filtering + // and will return all device ids. + std::vector GetAvailableDeviceIds(); + + private: + void InitializeDefaultSettings(); + // Triggers driver setup and initial device enumeration. No-op if already + // done. + void Enumerate(); + + // Valid at construction time. + iree_hal_hip_device_params_t default_device_params_; + + // Configuration. + bool cpu_devices_enabled_ = false; + std::vector hip_lib_search_paths_; + std::optional> visible_devices_; + size_t logical_devices_per_physical_device_ = 1; + std::vector amdgpu_allocator_specs_; + + // Valid post enumeration. + iree::hal_driver_ptr hip_hal_driver_; + iree_host_size_t available_devices_count_ = 0; + iree::allocated_ptr available_devices_; +}; + +} // namespace shortfin::local::systems + +#endif // SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H diff --git a/shortfin/src/shortfin/local/systems/factory.cc b/shortfin/src/shortfin/local/systems/factory.cc new file mode 100644 index 000000000..bf5b788dc --- /dev/null +++ b/shortfin/src/shortfin/local/systems/factory.cc @@ -0,0 +1,79 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/local/system.h" +#include "shortfin/support/logging.h" + +#ifdef SHORTFIN_HAVE_HOSTCPU +#include "shortfin/local/systems/host.h" +#endif + +#ifdef SHORTFIN_HAVE_AMDGPU +#include "shortfin/local/systems/amdgpu.h" +#endif + +namespace shortfin::local { + +SystemPtr System::Create(iree_allocator_t host_allocator, + std::string_view system_type, + ConfigOptions config_options) { + auto builder = SystemBuilder::ForSystem(host_allocator, system_type, + std::move(config_options)); + auto system = builder->CreateSystem(); + try { + builder->config_options().ValidateUndef(); + } catch (...) 
{ + system->Shutdown(); + throw; + } + return system; +} + +std::unique_ptr SystemBuilder::ForSystem( + iree_allocator_t host_allocator, std::string_view system_type, + ConfigOptions config_options) { + using Factory = std::unique_ptr (*)( + iree_allocator_t host_allocator, ConfigOptions); + static const std::vector> factories{ +#ifdef SHORTFIN_HAVE_HOSTCPU + std::make_pair( + "hostcpu", + +[](iree_allocator_t host_allocator, + ConfigOptions options) -> std::unique_ptr { + return std::make_unique( + host_allocator, std::move(options)); + }), +#endif +#ifdef SHORTFIN_HAVE_AMDGPU + std::make_pair( + "amdgpu", + +[](iree_allocator_t host_allocator, + ConfigOptions options) -> std::unique_ptr { + return std::make_unique( + host_allocator, std::move(options)); + }), +#endif + }; + + for (auto &it : factories) { + if (system_type == it.first) { + return it.second(host_allocator, std::move(config_options)); + } + } + + // Not found. + std::vector available; + available.reserve(factories.size()); + for (auto &it : factories) { + available.push_back(it.first); + } + + throw std::invalid_argument( + fmt::format("System type '{}' not known (available: {})", system_type, + fmt::join(available, ", "))); +} + +} // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/systems/host.cc b/shortfin/src/shortfin/local/systems/host.cc new file mode 100644 index 000000000..1da4b2af1 --- /dev/null +++ b/shortfin/src/shortfin/local/systems/host.cc @@ -0,0 +1,247 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/local/systems/host.h" + +#include + +#include "iree/hal/local/loaders/registration/init.h" +#include "shortfin/support/iree_helpers.h" +#include "shortfin/support/logging.h" +#include "shortfin/support/sysconfig.h" + +namespace shortfin::local::systems { + +namespace { +const std::string_view SYSTEM_DEVICE_CLASS = "hostcpu"; +const std::string_view LOGICAL_DEVICE_CLASS = "cpu"; +const std::string_view HAL_DRIVER_PREFIX = "local"; + +struct TopologyHolder { + TopologyHolder() { iree_task_topology_initialize(&topology); } + ~TopologyHolder() { iree_task_topology_deinitialize(&topology); } + + iree_task_topology_t topology; +}; + +} // namespace + +// -------------------------------------------------------------------------- // +// HostCPUSystemBuilder +// -------------------------------------------------------------------------- // + +HostCPUSystemBuilder::Deps::Deps(iree_allocator_t host_allocator) { + iree_task_executor_options_initialize(&task_executor_options); + iree_hal_task_device_params_initialize(&task_params); + +#ifndef NDEBUG + // TODO: In normal IREE programs, this is exposed as --task_abort_on_failure. + // It is a critical debug feature as it forces an eager program crash at + // the point encountered vs as a later, rolled up async status. Since it + // guards things that are API usage bugs in how we are using the runtime, + // from our perspective, it is assert like, and we treat it as such. + // However, it would be best to be independently controllable. 
+ task_params.queue_scope_flags |= IREE_TASK_SCOPE_FLAG_ABORT_ON_FAILURE; +#endif +} + +HostCPUSystemBuilder::Deps::~Deps() { + for (iree_host_size_t i = 0; i < loader_count; ++i) { + iree_hal_executable_loader_release(loaders[i]); + } + if (device_allocator) { + iree_hal_allocator_release(device_allocator); + } + if (plugin_manager) { + iree_hal_executable_plugin_manager_release(plugin_manager); + } +} + +HostCPUSystemBuilder::HostCPUSystemBuilder(iree_allocator_t host_allocator, + ConfigOptions config_options) + : HostSystemBuilder(host_allocator, std::move(config_options)), + host_cpu_deps_(host_allocator) { + hostcpu_allocator_specs_ = GetConfigAllocatorSpecs("hostcpu_allocators"); +} + +HostCPUSystemBuilder::~HostCPUSystemBuilder() = default; + +void HostCPUSystemBuilder::InitializeHostCPUDefaults() { + // Give it a default device allocator. + if (!host_cpu_deps_.device_allocator) { + SHORTFIN_THROW_IF_ERROR(iree_hal_allocator_create_heap( + iree_make_cstring_view("local"), host_allocator(), host_allocator(), + &host_cpu_deps_.device_allocator)); + } + + // And loaders. + if (host_cpu_deps_.loader_count == 0) { + SHORTFIN_THROW_IF_ERROR(iree_hal_create_all_available_executable_loaders( + /*plugin_manager=*/nullptr, IREE_ARRAYSIZE(host_cpu_deps_.loaders), + &host_cpu_deps_.loader_count, host_cpu_deps_.loaders, + host_allocator())); + } + + // Queue executors. +} + +std::vector +HostCPUSystemBuilder::SelectHostCPUNodesFromOptions() { + const unsigned MAX_NODE_COUNT = 64u; + const iree_host_size_t available_node_count = std::max( + 1u, std::min(MAX_NODE_COUNT, static_cast( + iree_task_topology_query_node_count()))); + auto topology_nodes = config_options().GetOption("hostcpu_topology_nodes"); + + std::vector nodes; + if (!topology_nodes || topology_nodes->empty() || + *topology_nodes == "current") { + // If topology_nodes not specified or "current", use a single default node. + nodes.push_back(iree_task_topology_query_current_node()); + } else if (*topology_nodes == "all") { + // If topology_nodes == "all", create a mask of all available nodes. + nodes.reserve(available_node_count); + for (iree_host_size_t i = 0; i < available_node_count; ++i) { + nodes.push_back(i); + } + } else { + // Otherwise, parse it as an integer list. + auto topology_node_ids = + config_options().GetIntList("hostcpu_topology_nodes"); + assert(topology_node_ids); + for (int64_t node_id : *topology_node_ids) { + if (node_id < 0 || (iree_host_size_t)node_id >= available_node_count) { + throw std::invalid_argument(fmt::format( + "Illegal value {} in hostcpu_topology_nodes: Expected [0..{}]", + node_id, available_node_count - 1)); + } + nodes.push_back(node_id); + } + } + return nodes; +} + +SystemPtr HostCPUSystemBuilder::CreateSystem() { + SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::CreateSystem"); + auto lsys = std::make_shared(host_allocator()); + // TODO: Real NUMA awareness. + lsys->InitializeNodes(1); + InitializeHostCPUDefaults(); + auto *driver = InitializeHostCPUDriver(*lsys); + InitializeHostCPUDevices(*lsys, driver); + lsys->FinishInitialization(); + return lsys; +} + +iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) { + SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::InitializeHostCPUDriver"); + // TODO: Kill these flag variants in favor of settings on the config + // object. + SHORTFIN_THROW_IF_ERROR(iree_task_executor_options_initialize_from_flags( + &host_cpu_deps_.task_executor_options)); + + // Determine NUMA nodes to use. 
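Node selection and the group cap consumed just below are plain config options. A minimal sketch that pins the host-CPU system to specific NUMA nodes, with illustrative values:

    #include "shortfin/local/systems/host.h"

    shortfin::local::SystemPtr CreatePinnedHostCpuSystemSketch() {
      shortfin::ConfigOptions options;
      // "current" (default), "all", or an explicit comma-separated node list.
      options.SetOption("hostcpu_topology_nodes", "0,1");
      // Upper bound on task topology groups created per queue executor.
      options.SetOption("hostcpu_topology_max_group_count", "8");
      shortfin::local::systems::HostCPUSystemBuilder builder(
          iree_allocator_system(), std::move(options));
      return builder.CreateSystem();
    }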
+ auto selected_nodes = SelectHostCPUNodesFromOptions(); + auto max_group_count = config_options().GetInt( + "hostcpu_topology_max_group_count", /*non_negative=*/true); + if (!max_group_count) { + max_group_count = 64; + } + + // Create one queue executor per node. + unsigned total_needed_file_handles = 512; + bool has_issued_limit_error = false; + std::vector queue_executors; + queue_executors.reserve(selected_nodes.size()); + queue_node_ids_.reserve(selected_nodes.size()); + for (auto node_id : selected_nodes) { + TopologyHolder topology; + iree_task_topology_performance_level_t performance_level = + IREE_TASK_TOPOLOGY_PERFORMANCE_LEVEL_ANY; + SHORTFIN_THROW_IF_ERROR(iree_task_topology_initialize_from_physical_cores( + node_id, performance_level, *max_group_count, &topology.topology)); + logging::debug("Creating hostcpu queue for NUMA node {} with {} groups", + node_id, iree_task_topology_group_count(&topology.topology)); + queue_executors.push_back({}); + auto &executor = queue_executors.back(); + // As of 2024-11-8, it took approximately 32 file handles per node-group. + // To be conservative because file handle limits are basically free, we + // round up to 64 and assume a floor of 512. This allows small, default + // 8 group, single node configs to require no limit increase for Linux + // 1024 default cases. + total_needed_file_handles += 64 * topology.topology.group_count; + if (!sysconfig::EnsureFileLimit(total_needed_file_handles) && + !has_issued_limit_error) { + logging::error( + "Could not ensure sufficient file handles for minimum operations: " + "Suggest setting explicit limits with `ulimit -n` and system " + "settings"); + has_issued_limit_error = true; + } + + SHORTFIN_THROW_IF_ERROR(iree_task_executor_create( + host_cpu_deps_.task_executor_options, &topology.topology, + host_allocator(), executor.for_output())); + queue_node_ids_.push_back(node_id); + } + + // Create the driver and save it in the System. 
+ iree::hal_driver_ptr driver; + iree_hal_driver_t *unowned_driver; + SHORTFIN_THROW_IF_ERROR(iree_hal_task_driver_create( + /*identifier=*/ + { + .data = HAL_DRIVER_PREFIX.data(), + .size = HAL_DRIVER_PREFIX.size(), + }, + &host_cpu_deps_.task_params, /*queue_count=*/queue_executors.size(), + reinterpret_cast(queue_executors.data()), + host_cpu_deps_.loader_count, host_cpu_deps_.loaders, + host_cpu_deps_.device_allocator, host_allocator(), driver.for_output())); + unowned_driver = driver.get(); + lsys.InitializeHalDriver(SYSTEM_DEVICE_CLASS, std::move(driver)); + return unowned_driver; +} + +void HostCPUSystemBuilder::InitializeHostCPUDevices(System &lsys, + iree_hal_driver_t *driver) { + SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::InitializeHostCPUDevices"); + iree_host_size_t device_info_count = 0; + iree::allocated_ptr device_infos(host_allocator()); + SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices( + driver, host_allocator(), &device_info_count, &device_infos)); + if (device_info_count != 1) { + throw std::logic_error("Expected a single CPU device"); + } + + iree::hal_device_ptr device; + iree_hal_device_info_t *it = &device_infos.get()[0]; + SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id( + driver, it->device_id, 0, nullptr, host_allocator(), + device.for_output())); + ConfigureAllocators(hostcpu_allocator_specs_, device, "hostcpu"); + + iree_host_size_t queue_index = 0; + for (auto node_id : queue_node_ids_) { + DeviceAddress address( + /*system_device_class=*/SYSTEM_DEVICE_CLASS, + /*logical_device_class=*/LOGICAL_DEVICE_CLASS, + /*hal_driver_prefix=*/HAL_DRIVER_PREFIX, + /*instance_ordinal=*/0, + /*queue_ordinal=*/queue_index, + /*instance_topology_address=*/{queue_index}); + lsys.InitializeHalDevice(std::make_unique( + address, + /*hal_device=*/device, + /*node_affinity=*/node_id, + /*capabilities=*/ + static_cast( + Device::Capabilities::PREFER_HOST_UNIFIED_MEMORY))); + queue_index += 1; + } +} + +} // namespace shortfin::local::systems diff --git a/libshortfin/src/shortfin/local/systems/host.h b/shortfin/src/shortfin/local/systems/host.h similarity index 83% rename from libshortfin/src/shortfin/local/systems/host.h rename to shortfin/src/shortfin/local/systems/host.h index 0655748ab..b80f997ab 100644 --- a/libshortfin/src/shortfin/local/systems/host.h +++ b/shortfin/src/shortfin/local/systems/host.h @@ -12,6 +12,7 @@ #include "iree/task/api.h" #include "shortfin/local/system.h" #include "shortfin/support/api.h" +#include "shortfin/support/config.h" namespace shortfin::local::systems { @@ -32,7 +33,8 @@ class SHORTFIN_API HostSystemBuilder : public SystemBuilder { // can extend this class (or provide features themselves). class SHORTFIN_API HostCPUSystemBuilder : public HostSystemBuilder { public: - HostCPUSystemBuilder(iree_allocator_t host_allocator); + HostCPUSystemBuilder(iree_allocator_t host_allocator, + ConfigOptions options = {}); HostCPUSystemBuilder() : HostCPUSystemBuilder(iree_allocator_system()) {} ~HostCPUSystemBuilder() override; @@ -40,6 +42,11 @@ class SHORTFIN_API HostCPUSystemBuilder : public HostSystemBuilder { // must wholly replace this method, using protected piece-wise components. SystemPtr CreateSystem() override; + // Allocator specs to apply to hostcpu devices in this builder. + std::vector& hostcpu_allocator_specs() { + return hostcpu_allocator_specs_; + } + protected: // Initializes any host-cpu defaults that have not been configured yet. 
void InitializeHostCPUDefaults(); @@ -54,15 +61,19 @@ class SHORTFIN_API HostCPUSystemBuilder : public HostSystemBuilder { struct Deps { Deps(iree_allocator_t host_allocator); ~Deps(); - iree_task_topology_t task_topology_options; iree_task_executor_options_t task_executor_options; iree_hal_task_device_params_t task_params; iree_hal_executable_plugin_manager_t* plugin_manager = nullptr; iree_hal_executable_loader_t* loaders[8] = {nullptr}; iree_host_size_t loader_count = 0; - iree_task_executor_t* executor = nullptr; iree_hal_allocator_t* device_allocator = nullptr; } host_cpu_deps_; + + std::vector hostcpu_allocator_specs_; + + private: + std::vector queue_node_ids_; + std::vector SelectHostCPUNodesFromOptions(); }; } // namespace shortfin::local::systems diff --git a/libshortfin/src/shortfin/local/worker.cc b/shortfin/src/shortfin/local/worker.cc similarity index 97% rename from libshortfin/src/shortfin/local/worker.cc rename to shortfin/src/shortfin/local/worker.cc index 09207e5e4..eed500891 100644 --- a/libshortfin/src/shortfin/local/worker.cc +++ b/shortfin/src/shortfin/local/worker.cc @@ -46,8 +46,8 @@ Worker::Worker(const Options options) iree_status_ignore(status); }; // TODO: We need a way to dynamically resize this vs having a hard limit. - iree_loop_sync_options_t loop_options = {.max_queue_depth = 256, - .max_wait_count = 256}; + iree_loop_sync_options_t loop_options = {.max_queue_depth = 2048, + .max_wait_count = 2048}; SHORTFIN_THROW_IF_ERROR( iree_loop_sync_allocate(loop_options, options_.allocator, &loop_sync_)); iree_loop_sync_scope_initialize(loop_sync_, OnError, this, &loop_scope_); @@ -109,6 +109,7 @@ iree_status_t Worker::TransactLoop(iree_status_t signal_status) { for (auto& next_thunk : next_thunks_) { // TODO: Make thunks have to return a status, propagate, and handle // exceptions. + SHORTFIN_TRACE_SCOPE_NAMED("Worker::ThreadsafeCallback"); next_thunk(); } next_thunks_.clear(); diff --git a/libshortfin/src/shortfin/local/worker.h b/shortfin/src/shortfin/local/worker.h similarity index 100% rename from libshortfin/src/shortfin/local/worker.h rename to shortfin/src/shortfin/local/worker.h diff --git a/libshortfin/src/shortfin/support/CMakeLists.txt b/shortfin/src/shortfin/support/CMakeLists.txt similarity index 67% rename from libshortfin/src/shortfin/support/CMakeLists.txt rename to shortfin/src/shortfin/support/CMakeLists.txt index 895c62052..ea8572466 100644 --- a/libshortfin/src/shortfin/support/CMakeLists.txt +++ b/shortfin/src/shortfin/support/CMakeLists.txt @@ -1,8 +1,8 @@ # Copyright 2024 Advanced Micro Devices, Inc. # -# Licensed under the Apache License v2.0 with LLVM Exceptions. See -# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier: -# Apache-2.0 WITH LLVM-exception +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception shortfin_cc_component( NAME @@ -10,23 +10,32 @@ shortfin_cc_component( HDRS api.h blocking_executor.h + config.h globals.h iree_helpers.h iree_concurrency.h logging.h stl_extras.h + sysconfig.h SRCS blocking_executor.cc + config.cc globals.cc iree_helpers.cc logging.cc + sysconfig.cc DEPS iree_base_base # TODO: Maybe reclassify some of these low level, shared support entities # as externally usable. 
iree_base_internal_threading + iree_io_file_handle iree_hal_hal + iree_io_parameter_index + iree_io_parameter_index_provider + iree_io_parameter_provider iree_modules_hal_types + iree_task_api spdlog::spdlog ) diff --git a/libshortfin/src/shortfin/support/api.h b/shortfin/src/shortfin/support/api.h similarity index 100% rename from libshortfin/src/shortfin/support/api.h rename to shortfin/src/shortfin/support/api.h diff --git a/libshortfin/src/shortfin/support/blocking_executor.cc b/shortfin/src/shortfin/support/blocking_executor.cc similarity index 100% rename from libshortfin/src/shortfin/support/blocking_executor.cc rename to shortfin/src/shortfin/support/blocking_executor.cc diff --git a/libshortfin/src/shortfin/support/blocking_executor.h b/shortfin/src/shortfin/support/blocking_executor.h similarity index 100% rename from libshortfin/src/shortfin/support/blocking_executor.h rename to shortfin/src/shortfin/support/blocking_executor.h diff --git a/libshortfin/src/shortfin/support/blocking_executor_test.cc b/shortfin/src/shortfin/support/blocking_executor_test.cc similarity index 100% rename from libshortfin/src/shortfin/support/blocking_executor_test.cc rename to shortfin/src/shortfin/support/blocking_executor_test.cc diff --git a/shortfin/src/shortfin/support/config.cc b/shortfin/src/shortfin/support/config.cc new file mode 100644 index 000000000..7de820d1c --- /dev/null +++ b/shortfin/src/shortfin/support/config.cc @@ -0,0 +1,184 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/support/config.h" + +#include +#include +#include +#include + +#include "fmt/format.h" +#include "shortfin/support/logging.h" + +namespace shortfin { + +void ConfigOptions::SetOption(std::string_view key, std::string value) { + options_[intern_.intern(key)] = std::move(value); +} + +const std::optional ConfigOptions::GetOption( + std::string_view key) const { + // Get explicit option. + auto found_it = options_.find(key); + consumed_keys_.insert(key); + if (found_it != options_.end()) { + return found_it->second; + } + + // Consult environment. 
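The fallback below concatenates the configured prefix with the upper-cased key before consulting std::getenv(). A small sketch; the "SHORTFIN_" prefix is only an assumed example of what a caller might pass, and the helper name is made up:

    #include <string>

    #include "shortfin/support/config.h"

    std::string LookupTracingLevelSketch() {
      shortfin::ConfigOptions options(/*env_lookup_prefix=*/"SHORTFIN_");
      // Checks the explicit option map first, then $SHORTFIN_AMDGPU_TRACING_LEVEL.
      auto level = options.GetOption("amdgpu_tracing_level");
      // Copy out of the ConfigOptions-owned storage before it goes out of scope.
      return level ? std::string(*level) : std::string("2");
    }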
+ if (env_lookup_prefix_) { + std::string env_key; + env_key.reserve(env_lookup_prefix_->size() + key.size()); + env_key.append(*env_lookup_prefix_); + for (char c : key) { + env_key.push_back(std::toupper(c)); + } + auto env_value = GetRawEnv(env_key.c_str()); + if (env_value) { + return env_value; + } + } + + return {}; +} + +std::optional ConfigOptions::GetInt(std::string_view key, + bool non_negative) const { + auto value = GetOption(key); + if (!value) return {}; + int64_t result; + auto last = value->data() + value->size(); + auto err = std::from_chars(value->data(), last, result); + if (err.ec != std::errc{} || err.ptr != last) { + throw std::invalid_argument( + fmt::format("Could not parse '{}' as an integer from config option " + "{}", + *value, key)); + } + if (non_negative && result < 0) { + throw std::invalid_argument(fmt::format( + "Could not parse '{}' as a non-negative integer from config option " + "{}", + *value, key)); + } + return result; +} + +bool ConfigOptions::GetBool(std::string_view key, bool default_value) const { + auto svalue = GetOption(key); + if (!svalue) return default_value; + auto iequal = [](std::string_view a, std::string_view b) -> bool { + return std::ranges::equal(a, b, [](char c1, char c2) { + return std::toupper(c1) == std::toupper(c2); + }); + }; + if (iequal(*svalue, "1") || iequal(*svalue, "TRUE") || + iequal(*svalue, "ON")) { + return true; + } else if (iequal(*svalue, "0") || iequal(*svalue, "FALSE") || + iequal(*svalue, "OFF")) { + return false; + } else { + throw std::invalid_argument( + fmt::format("Cannot interpret {} = '{}' as bool: must be one of '1', " + "'TRUE', 'ON', '0', 'FALSE', 'OFF'", + key, *svalue)); + } +} + +std::optional> ConfigOptions::GetIntList( + std::string_view key, bool non_negative) const { + auto value = GetOption(key); + if (!value) return {}; + + std::vector results; + auto Consume = [&](std::string_view atom) { + int64_t result; + auto last = atom.data() + atom.size(); + auto err = std::from_chars(atom.data(), last, result); + if (err.ec != std::errc{} || err.ptr != last) { + throw std::invalid_argument( + fmt::format("Could not parse '{}' as an integer from config option " + "{} (full value: {})", + atom, key, *value)); + } + if (non_negative && result < 0) { + throw std::invalid_argument(fmt::format( + "Could not parse '{}' as a non-negative integer from config option " + "{} (full value: {})", + atom, key, *value)); + } + results.push_back(result); + }; + std::string_view sv_value = *value; + for (;;) { + auto found_it = sv_value.find(','); + if (found_it == std::string_view::npos) { + Consume(sv_value); + break; + } + + Consume(sv_value.substr(0, found_it)); + sv_value.remove_prefix(found_it + 1); + } + + return results; +} + +void ConfigOptions::ValidateUndef() const { + std::vector unused_options; + for (auto it : options_) { + const auto &key = it.first; + if (!consumed_keys_.contains(key)) { + unused_options.push_back(key); + } + } + if (!unused_options.empty()) { + std::string message = fmt::format( + "Specified options were not used: {} (available: {})", + fmt::join(unused_options, ", "), fmt::join(consumed_keys_, ", ")); + switch (validation_) { + case ValidationLevel::UNDEF_DEBUG: + logging::debug("{}", message); + break; + case ValidationLevel::UNDEF_WARN: + logging::warn("{}", message); + break; + case ValidationLevel::UNDEF_ERROR: + throw std::invalid_argument(std::move(message)); + } + } +} + +std::optional ConfigOptions::GetRawEnv( + const char *key) const { + char *env_value = std::getenv(key); + 
if (env_value && std::strlen(env_value) > 0) { + return intern_.intern(env_value); + } + return {}; +} + +// Helper to split on a delimitter. +std::vector ConfigOptions::Split(std::string_view value, + char delim) { + std::vector results; + std::string_view rest(value); + for (;;) { + auto pos = rest.find(delim); + if (pos == std::string_view::npos) { + results.push_back(rest); + break; + } + std::string_view first = rest.substr(0, pos); + rest = rest.substr(pos + 1); + results.push_back(first); + } + return results; +} + +} // namespace shortfin diff --git a/shortfin/src/shortfin/support/config.h b/shortfin/src/shortfin/support/config.h new file mode 100644 index 000000000..c91ae6d84 --- /dev/null +++ b/shortfin/src/shortfin/support/config.h @@ -0,0 +1,97 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_SUPPORT_CONFIG_H +#define SHORTFIN_SUPPORT_CONFIG_H + +#include +#include +#include +#include +#include +#include + +#include "shortfin/support/api.h" +#include "shortfin/support/stl_extras.h" + +namespace shortfin { + +// Utility class for querying free-form config options from a string list, +// environment variables, etc. +// +// Config options are looked up in a dict-like map. If not found explicitly, +// they can be optionally found in the environment if |env_lookup_prefix| is +// provided. In this case, the requested upper-cased key is concatted to the +// prefix and lookup up via std::getenv(). +class SHORTFIN_API ConfigOptions { + public: + // Level of validation to do on configuration options when calling `Create`. + enum class ValidationLevel { + UNDEF_WARN, + UNDEF_DEBUG, + UNDEF_ERROR, + }; + + ConfigOptions(std::optional env_lookup_prefix = {}, + ValidationLevel validation = ValidationLevel::UNDEF_WARN) + : env_lookup_prefix_(std::move(env_lookup_prefix)), + validation_(validation) {} + ConfigOptions(const ConfigOptions &) = delete; + ConfigOptions(ConfigOptions &&) = default; + + void SetOption(std::string_view key, std::string value); + + const std::optional GetOption(std::string_view key) const; + + // Gets an option as an integer, optionall enforcing that each is + // non-negative. + std::optional GetInt(std::string_view key, + bool non_negative = false) const; + + // Gets an option as a bool, returning |default_value| if not found. + // Bools are interpreted strictly from the string value. The following + // evaluate to true (case insensitive): + // 1, TRUE, ON + // The following evaluate to false: + // 0, FALSE, OFF + bool GetBool(std::string_view key, bool default_value = false) const; + + // Gets an option as a list of integers, optionally enforcing that each + // is non-negative. + std::optional> GetIntList( + std::string_view key, bool non_negative = false) const; + + // Gets a raw environment variable without looking up in the options or + // translating the name. + std::optional GetRawEnv(const char *key) const; + + // Helper to split on a delimitter. + static std::vector Split(std::string_view value, + char delim); + + // After all configuration options have been consumed, perform validation + // that all options were recognized. + void ValidateUndef() const; + + private: + mutable string_interner intern_; + // Optional environment variable lookup prefix for resolving options not + // explicitly set. 
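A short usage sketch tying the accessors declared above together; the keys and values are arbitrary examples:

    #include "shortfin/support/config.h"

    void ConfigOptionsUsageSketch() {
      shortfin::ConfigOptions options;
      options.SetOption("hostcpu_topology_nodes", "0,2,3");
      options.SetOption("amdgpu_async_allocations", "OFF");

      // Comma-separated integers; non_negative=true rejects values below zero.
      auto nodes = options.GetIntList("hostcpu_topology_nodes", /*non_negative=*/true);
      // nodes -> {0, 2, 3}

      // Case-insensitive 1/TRUE/ON vs 0/FALSE/OFF; anything else throws.
      bool async = options.GetBool("amdgpu_async_allocations", /*default_value=*/true);
      // async -> false

      // Both keys were consumed above, so this reports nothing.
      options.ValidateUndef();
      (void)nodes;
      (void)async;
    }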
+ std::optional env_lookup_prefix_; + + // Level of validation to perform. + ValidationLevel validation_; + + // Explicit keyword options. + std::unordered_map options_; + + // Keep track of which keys were consumed. Used for error checking. + mutable std::unordered_set consumed_keys_; +}; + +} // namespace shortfin + +#endif // SHORTFIN_SUPPORT_CONFIG_H diff --git a/libshortfin/src/shortfin/support/globals.cc b/shortfin/src/shortfin/support/globals.cc similarity index 100% rename from libshortfin/src/shortfin/support/globals.cc rename to shortfin/src/shortfin/support/globals.cc diff --git a/libshortfin/src/shortfin/support/globals.h b/shortfin/src/shortfin/support/globals.h similarity index 100% rename from libshortfin/src/shortfin/support/globals.h rename to shortfin/src/shortfin/support/globals.h diff --git a/libshortfin/src/shortfin/support/iree_concurrency.h b/shortfin/src/shortfin/support/iree_concurrency.h similarity index 74% rename from libshortfin/src/shortfin/support/iree_concurrency.h rename to shortfin/src/shortfin/support/iree_concurrency.h index 2f355ded6..e6b9ffbdd 100644 --- a/libshortfin/src/shortfin/support/iree_concurrency.h +++ b/shortfin/src/shortfin/support/iree_concurrency.h @@ -14,12 +14,24 @@ #include "iree/base/internal/wait_handle.h" #include "shortfin/support/iree_helpers.h" +// Set up threading annotations. +#if defined(SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS) +#define SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(x) __attribute__((x)) +#else +#define SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(x) +#endif + +#define SHORTFIN_GUARDED_BY(x) \ + SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(guarded_by(x)) +#define SHORTFIN_REQUIRES_LOCK(...) \ + SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(requires_capability(__VA_ARGS__)) + namespace shortfin::iree { SHORTFIN_IREE_DEF_PTR(thread); // Wraps an iree::slim_mutex as an RAII object. -class slim_mutex { +class SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(capability("mutex")) slim_mutex { public: slim_mutex() { iree_slim_mutex_initialize(&mu_); } slim_mutex(const slim_mutex &) = delete; @@ -28,15 +40,31 @@ class slim_mutex { operator iree_slim_mutex_t *() { return &mu_; } + void Lock() SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(acquire_capability()) { + iree_slim_mutex_lock(&mu_); + } + + void Unlock() SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(release_capability()) { + iree_slim_mutex_unlock(&mu_); + } + private: iree_slim_mutex_t mu_; }; // RAII slim mutex lock guard. 
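The annotations are aimed at clang's -Wthread-safety analysis. A hypothetical class showing how a guarded member pairs with the RAII guard defined just below:

    #include "shortfin/support/iree_concurrency.h"

    class RequestCounter {
     public:
      void Increment() {
        shortfin::iree::slim_mutex_lock_guard guard(mu_);
        ++count_;  // OK: mu_ is held for the duration of the guard.
      }

     private:
      shortfin::iree::slim_mutex mu_;
      // clang warns if count_ is read or written without holding mu_.
      int count_ SHORTFIN_GUARDED_BY(mu_) = 0;
    };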
-class slim_mutex_lock_guard { +class SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(scoped_lockable) + slim_mutex_lock_guard { public: - slim_mutex_lock_guard(slim_mutex &mu) : mu_(mu) { iree_slim_mutex_lock(mu_); } - ~slim_mutex_lock_guard() { iree_slim_mutex_unlock(mu_); } + slim_mutex_lock_guard(slim_mutex &mu) + SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(acquire_capability(mu)) + : mu_(mu) { + mu_.Lock(); + } + ~slim_mutex_lock_guard() + SHORTFIN_THREAD_ANNOTATION_ATTRIBUTE(release_capability()) { + mu_.Unlock(); + } private: slim_mutex &mu_; diff --git a/libshortfin/src/shortfin/support/iree_concurrency_test.cc b/shortfin/src/shortfin/support/iree_concurrency_test.cc similarity index 100% rename from libshortfin/src/shortfin/support/iree_concurrency_test.cc rename to shortfin/src/shortfin/support/iree_concurrency_test.cc diff --git a/libshortfin/src/shortfin/support/iree_helpers.cc b/shortfin/src/shortfin/support/iree_helpers.cc similarity index 95% rename from libshortfin/src/shortfin/support/iree_helpers.cc rename to shortfin/src/shortfin/support/iree_helpers.cc index 417a9f443..17430bb71 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.cc +++ b/shortfin/src/shortfin/support/iree_helpers.cc @@ -86,13 +86,14 @@ error::error(std::string message, iree_status_t failing_status) message_(std::move(message)), failing_status_(failing_status) { message_.append(": "); + AppendStatusMessage(); } -error::error(iree_status_t failing_status) : failing_status_(failing_status) {} -void error::AppendStatus() const noexcept { - if (status_appended_) return; - status_appended_ = false; +error::error(iree_status_t failing_status) : failing_status_(failing_status) { + AppendStatusMessage(); +} +void error::AppendStatusMessage() { iree_allocator_t allocator = iree_allocator_system(); char *status_buffer = nullptr; iree_host_size_t length = 0; diff --git a/libshortfin/src/shortfin/support/iree_helpers.h b/shortfin/src/shortfin/support/iree_helpers.h similarity index 84% rename from libshortfin/src/shortfin/support/iree_helpers.h rename to shortfin/src/shortfin/support/iree_helpers.h index 7f2e28cb2..f8d3f1398 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.h +++ b/shortfin/src/shortfin/support/iree_helpers.h @@ -13,7 +13,9 @@ #include "iree/base/api.h" #include "iree/base/internal/file_io.h" #include "iree/hal/api.h" +#include "iree/io/parameter_index_provider.h" #include "iree/modules/hal/types.h" +#include "iree/task/api.h" #include "iree/vm/api.h" #include "iree/vm/ref_cc.h" #include "shortfin/support/api.h" @@ -143,24 +145,25 @@ class object_ptr { // Defines a reference counting helper struct named like // iree_hal_buffer_ptr_helper (for type_stem == hal_buffer). // These must be defined in the shortfin::iree::detail namespace. 
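A rough sketch of the out-parameter convention these wrappers enable, assuming an existing driver, device id, and allocator; the ownership comments reflect how for_output() and get() are used elsewhere in this patch:

    #include "shortfin/support/iree_helpers.h"

    void CreateDeviceSketch(iree_hal_driver_t *driver, iree_hal_device_id_t device_id,
                            iree_allocator_t host_allocator) {
      shortfin::iree::hal_device_ptr device;
      SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id(
          driver, device_id, /*param_count=*/0, /*params=*/nullptr, host_allocator,
          device.for_output()));  // Adopts the reference produced by the C API.
      iree_hal_device_t *borrowed = device.get();  // Borrow; no refcount change.
      (void)borrowed;
      // The held reference is released when `device` goes out of scope.
    }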
-#define SHORTFIN_IREE_DEF_PTR(type_stem) \ - namespace detail { \ - struct type_stem##_ptr_helper { \ - static void steal(iree_##type_stem##_t *obj) { \ - LogIREESteal(#type_stem "_t", obj); \ - } \ - static void retain(iree_##type_stem##_t *obj) { \ - LogIREERetain(#type_stem "_t", obj); \ - iree_##type_stem##_retain(obj); \ - } \ - static void release(iree_##type_stem##_t *obj) { \ - LogIREERelease(#type_stem "_t", obj); \ - iree_##type_stem##_release(obj); \ - } \ - }; \ - } \ - using type_stem##_ptr = \ - object_ptr +#define SHORTFIN_IREE_DEF_PTR(type_stem) \ + namespace detail { \ + struct type_stem##_ptr_helper { \ + static void steal(iree_##type_stem##_t *obj) { \ + LogIREESteal(#type_stem "_t", obj); \ + } \ + static void retain(iree_##type_stem##_t *obj) { \ + LogIREERetain(#type_stem "_t", obj); \ + iree_##type_stem##_retain(obj); \ + } \ + static void release(iree_##type_stem##_t *obj) { \ + LogIREERelease(#type_stem "_t", obj); \ + iree_##type_stem##_release(obj); \ + } \ + }; \ + } \ + using type_stem##_ptr = \ + object_ptr; \ + static_assert(sizeof(type_stem##_ptr) == sizeof(iree_##type_stem##_t *)) SHORTFIN_IREE_DEF_PTR(hal_command_buffer); SHORTFIN_IREE_DEF_PTR(hal_buffer); @@ -169,6 +172,10 @@ SHORTFIN_IREE_DEF_PTR(hal_device); SHORTFIN_IREE_DEF_PTR(hal_driver); SHORTFIN_IREE_DEF_PTR(hal_fence); SHORTFIN_IREE_DEF_PTR(hal_semaphore); +SHORTFIN_IREE_DEF_PTR(io_file_handle); +SHORTFIN_IREE_DEF_PTR(io_parameter_index); +SHORTFIN_IREE_DEF_PTR(io_parameter_provider); +SHORTFIN_IREE_DEF_PTR(task_executor); SHORTFIN_IREE_DEF_PTR(vm_context); SHORTFIN_IREE_DEF_PTR(vm_instance); SHORTFIN_IREE_DEF_PTR(vm_list); @@ -270,24 +277,21 @@ class SHORTFIN_API error : public std::exception { public: error(std::string message, iree_status_t failing_status); error(iree_status_t failing_status); - error(const error &) = delete; + error(const error &other) + : code_(other.code_), + message_(other.message_), + failing_status_(iree_status_clone(other.failing_status_)) {} error &operator=(const error &) = delete; ~error() { iree_status_ignore(failing_status_); } - const char *what() const noexcept override { - if (!status_appended_) { - AppendStatus(); - } - return message_.c_str(); - }; + const char *what() const noexcept override { return message_.c_str(); }; iree_status_code_t code() const { return code_; } private: - void AppendStatus() const noexcept; + void AppendStatusMessage(); iree_status_code_t code_; - mutable std::string message_; + std::string message_; mutable iree_status_t failing_status_; - mutable bool status_appended_ = false; }; #define SHORTFIN_IMPL_HANDLE_IF_API_ERROR(var, ...) 
\ diff --git a/libshortfin/src/shortfin/support/iree_helpers_test.cc b/shortfin/src/shortfin/support/iree_helpers_test.cc similarity index 87% rename from libshortfin/src/shortfin/support/iree_helpers_test.cc rename to shortfin/src/shortfin/support/iree_helpers_test.cc index 5f4adcdaa..53875d617 100644 --- a/libshortfin/src/shortfin/support/iree_helpers_test.cc +++ b/shortfin/src/shortfin/support/iree_helpers_test.cc @@ -97,10 +97,9 @@ TEST(iree_error, user_message) { "Something went wrong", iree_make_status(IREE_STATUS_CANCELLED, "because I said so")); } catch (iree::error &e) { - EXPECT_THAT( - std::string(e.what()), - testing::MatchesRegex( - "^Something went wrong: .*: CANCELLED; because I said so$")); + EXPECT_THAT(std::string(e.what()), + testing::ContainsRegex( + "^Something went wrong: .*: CANCELLED; because I said so")); } } @@ -110,7 +109,7 @@ TEST(iree_error, no_user_message) { iree_make_status(IREE_STATUS_CANCELLED, "because I said so")); } catch (iree::error &e) { EXPECT_THAT(std::string(e.what()), - testing::MatchesRegex("^.*: CANCELLED; because I said so$")); + testing::ContainsRegex("^.*: CANCELLED; because I said so")); } } @@ -129,7 +128,7 @@ TEST(iree_error, throw_if_error) { FAIL(); } catch (iree::error &e) { EXPECT_THAT(std::string(e.what()), - testing::MatchesRegex("^.*: CANCELLED; because I said so$")); + testing::ContainsRegex("^.*: CANCELLED; because I said so")); } } @@ -140,9 +139,9 @@ TEST(iree_error, throw_if_error_addl_message) { "oops: %d", 1); FAIL(); } catch (iree::error &e) { - EXPECT_THAT( - std::string(e.what()), - testing::MatchesRegex("^.*: CANCELLED; because I said so; oops: 1$")); + EXPECT_THAT(std::string(e.what()), + testing::ContainsRegex("^.*: CANCELLED; because I said so;")); + EXPECT_THAT(std::string(e.what()), testing::ContainsRegex("oops: 1")); } } diff --git a/libshortfin/src/shortfin/support/logging.cc b/shortfin/src/shortfin/support/logging.cc similarity index 57% rename from libshortfin/src/shortfin/support/logging.cc rename to shortfin/src/shortfin/support/logging.cc index 1b2ff56b5..668ba7812 100644 --- a/libshortfin/src/shortfin/support/logging.cc +++ b/shortfin/src/shortfin/support/logging.cc @@ -5,3 +5,14 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "shortfin/support/logging.h" + +#include "spdlog/cfg/env.h" + +namespace shortfin::logging { + +void InitializeFromEnv() { + // TODO: Also support our own env vars. + spdlog::cfg::load_env_levels(); +} + +} // namespace shortfin::logging diff --git a/libshortfin/src/shortfin/support/logging.h b/shortfin/src/shortfin/support/logging.h similarity index 74% rename from libshortfin/src/shortfin/support/logging.h rename to shortfin/src/shortfin/support/logging.h index 99d9c64e8..e70c54e99 100644 --- a/libshortfin/src/shortfin/support/logging.h +++ b/shortfin/src/shortfin/support/logging.h @@ -7,6 +7,8 @@ #ifndef SHORTFIN_SUPPORT_LOGGING_H #define SHORTFIN_SUPPORT_LOGGING_H +#include "iree/base/tracing.h" +#include "shortfin/support/api.h" #include "spdlog/spdlog.h" #if !defined(SHORTFIN_LOG_LIFETIMES) @@ -21,8 +23,18 @@ #define SHORTFIN_SCHED_LOG(...) #endif +// Tracing macros. These are currently just aliases of the underlying IREE +// macros, but we maintain the ability to redirect them in the future (i.e. +// for certain kinds of library builds, etc). 
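A sketch of the intended call sites in a hosting application; the function name is illustrative:

    #include "shortfin/support/logging.h"

    void HostProcessStartupSketch() {
      // Honors spdlog's environment configuration (e.g. SPDLOG_LEVEL=debug),
      // which surfaces the logging::debug() calls added throughout this patch.
      shortfin::logging::InitializeFromEnv();

      // Emits an IREE tracing zone when a tracing-enabled build is used.
      SHORTFIN_TRACE_SCOPE_NAMED("HostProcessStartupSketch");
    }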
+#define SHORTFIN_TRACE_SCOPE IREE_TRACE_SCOPE +#define SHORTFIN_TRACE_SCOPE_NAMED(name_literal) \ + IREE_TRACE_SCOPE_NAMED(name_literal) +#define SHORTFIN_TRACE_SCOPE_ID IREE_TRACE_SCOPE_ID + namespace shortfin::logging { +SHORTFIN_API void InitializeFromEnv(); + // TODO: Re-export doesn't really work like this. Need to define API // exported trampolines for cross library use. using spdlog::debug; diff --git a/libshortfin/src/shortfin/support/stl_extras.h b/shortfin/src/shortfin/support/stl_extras.h similarity index 100% rename from libshortfin/src/shortfin/support/stl_extras.h rename to shortfin/src/shortfin/support/stl_extras.h diff --git a/libshortfin/src/shortfin/support/stl_extras_test.cc b/shortfin/src/shortfin/support/stl_extras_test.cc similarity index 100% rename from libshortfin/src/shortfin/support/stl_extras_test.cc rename to shortfin/src/shortfin/support/stl_extras_test.cc diff --git a/shortfin/src/shortfin/support/sysconfig.cc b/shortfin/src/shortfin/support/sysconfig.cc new file mode 100644 index 000000000..486f5ffc4 --- /dev/null +++ b/shortfin/src/shortfin/support/sysconfig.cc @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/support/sysconfig.h" + +#include "shortfin/support/logging.h" + +#ifdef __linux__ +#include +#endif + +namespace shortfin::sysconfig { + +// ----------------------------------------------------------------------------- +// File handle limits +// ----------------------------------------------------------------------------- + +#ifdef __linux__ + +bool EnsureFileLimit(unsigned needed_limit) { + struct rlimit limit; + if (getrlimit(RLIMIT_NOFILE, &limit) != 0) { + return {}; + } + + if (limit.rlim_cur >= needed_limit) return true; + unsigned requested_limit = needed_limit; + if (limit.rlim_max >= needed_limit) { + logging::debug( + "Estimated number of open file handles ({}) < current limit ({}) but " + "within max limit ({}): Increasing limit", + needed_limit, limit.rlim_cur, limit.rlim_max); + } else if (limit.rlim_max > limit.rlim_cur) { + logging::warn( + "Esimated number of open file handles ({}) < current ({}) and max ({}) " + "limit: Increasing to max", + needed_limit, limit.rlim_cur, limit.rlim_max); + requested_limit = limit.rlim_max; + } else { + logging::warn("Esimated number of open file handles ({}) < max ({})", + needed_limit, limit.rlim_max); + return false; + } + + limit.rlim_cur = requested_limit; + if (setrlimit(RLIMIT_NOFILE, &limit) != 0) { + logging::error("Could not set open file handle limit to {}", + requested_limit); + return false; + } + + return limit.rlim_cur >= needed_limit; +} + +#else +// Fallback implementation. +bool EnsureFileLimit(unsigned needed_limit) { return true; } +#endif + +} // namespace shortfin::sysconfig diff --git a/shortfin/src/shortfin/support/sysconfig.h b/shortfin/src/shortfin/support/sysconfig.h new file mode 100644 index 000000000..864405efc --- /dev/null +++ b/shortfin/src/shortfin/support/sysconfig.h @@ -0,0 +1,25 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_SUPPORT_SYSCONFIG_H +#define SHORTFIN_SUPPORT_SYSCONFIG_H + +#include +#include + +namespace shortfin::sysconfig { + +// Attempts to ensure that the given number of file descriptors can be created. +// If the system does not support such a thing (i.e. GetOpenFileLimit() returns +// nothing), then nothing is done and true is returned. If the system does +// support it and heuristics say this should be allowed, then true will return. +// Otherwise, a warning will be logged and false returned. +// This is a best effort attempt. +bool EnsureFileLimit(unsigned needed_limit); + +} // namespace shortfin::sysconfig + +#endif // SHORTFIN_SUPPORT_SYSCONFIG_H diff --git a/shortfin/tests/__init__.py b/shortfin/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shortfin/tests/amdgpu_system_test.py b/shortfin/tests/amdgpu_system_test.py new file mode 100644 index 000000000..2ab7cbcfa --- /dev/null +++ b/shortfin/tests/amdgpu_system_test.py @@ -0,0 +1,118 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + +import shortfin as sf + + +@pytest.mark.system("amdgpu") +def test_create_amd_gpu_system_defaults(): + sc = sf.amdgpu.SystemBuilder() + with sc.create_system() as ls: + print(f"DEFAULTS:", ls) + for device_name in ls.device_names: + print(f" DEVICE: {device_name} = {ls.device(device_name)}") + assert "amdgpu:0:0@0" in ls.device_names + assert "hostcpu:0:0@0" not in ls.device_names + + +@pytest.mark.system("amdgpu") +def test_create_amd_gpu_tracing_level(): + sc = sf.amdgpu.SystemBuilder() + assert sc.tracing_level == 2 # Default + sc = sf.amdgpu.SystemBuilder(amdgpu_tracing_level=1) + assert sc.tracing_level == 1 + + +@pytest.mark.system("amdgpu") +def test_create_amd_gpu_allocator(): + sc = sf.amdgpu.SystemBuilder(allocators="caching;debug") + assert sc.amdgpu_allocator_specs == ["caching", "debug"] + with sc.create_system() as ls: + # Nothing to verify + pass + + +@pytest.mark.system("amdgpu") +def test_create_amd_gpu_async_allocations(): + sc = sf.amdgpu.SystemBuilder() + assert sc.async_allocations == True + sc = sf.amdgpu.SystemBuilder(amdgpu_async_allocations=False) + assert sc.async_allocations == False + with sc.create_system() as ls: + # Nothing to verify + pass + + +@pytest.mark.system("amdgpu") +def test_create_amd_gpu_logical_devices_per_physical_device(): + # Default. + sc = sf.amdgpu.SystemBuilder() + assert sc.logical_devices_per_physical_device == 1 + + # Override. 
+    sc = sf.amdgpu.SystemBuilder(amdgpu_logical_devices_per_physical_device=2)
+    assert sc.logical_devices_per_physical_device == 2
+    sc.visible_devices = sc.available_devices[0:1]
+    with sc.create_system() as ls:
+        assert "amdgpu:0:0@0" in ls.device_names
+        assert "amdgpu:0:0@1" in ls.device_names
+
+
+@pytest.mark.system("amdgpu")
+def test_create_amd_gpu_system_with_cpu_devices():
+    sc = sf.amdgpu.SystemBuilder(amdgpu_cpu_devices_enabled=True)
+    with sc.create_system() as ls:
+        print(f"WITH CPU:", ls)
+        for device_name in ls.device_names:
+            print(f" DEVICE: {device_name} = {ls.device(device_name)}")
+        assert "amdgpu:0:0@0" in ls.device_names
+        assert "hostcpu:0:0@0" in ls.device_names
+
+
+@pytest.mark.system("amdgpu")
+def test_create_amd_gpu_system_visible():
+    sc_query = sf.amdgpu.SystemBuilder()
+    available = sc_query.available_devices
+    print("AVAILABLE:", available)
+
+    # Create a system with the explicitly listed available device.
+    sc_query.visible_devices = [available[0]]
+    with sc_query.create_system() as ls:
+        assert "amdgpu:0:0@0" in ls.device_names
+        assert len(ls.devices) == 1
+
+    # Create via option.
+    sc = sf.amdgpu.SystemBuilder(amdgpu_visible_devices=available[0])
+    with sc.create_system() as ls:
+        assert "amdgpu:0:0@0" in ls.device_names
+        assert len(ls.devices) == 1
+
+    # Duplicates not available.
+    sc = sf.amdgpu.SystemBuilder(
+        amdgpu_visible_devices=";".join(available[0] for i in range(100))
+    )
+    with pytest.raises(
+        ValueError, match="was requested more times than present on the system"
+    ):
+        sc.create_system()
+
+
+@pytest.mark.system("amdgpu")
+def test_create_amd_gpu_system_visible_unknown():
+    sc = sf.amdgpu.SystemBuilder(amdgpu_visible_devices="foobar")
+    with pytest.raises(
+        ValueError,
+        match="Requested visible device 'foobar' was not found on the system",
+    ):
+        sc.create_system()
+
+
+@pytest.mark.system("amdgpu")
+def test_system_ctor():
+    with sf.System("amdgpu") as ls:
+        assert "amdgpu:0:0@0" in ls.device_names
diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py
new file mode 100644
index 000000000..164dfb479
--- /dev/null
+++ b/shortfin/tests/api/array_ops_test.py
@@ -0,0 +1,438 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import array
+import math
+import pytest
+
+import shortfin as sf
+import shortfin.array as sfnp
+
+
+@pytest.fixture
+def lsys():
+    # TODO: Port this test to use memory type independent access. It currently
+    # presumes unified memory.
+ # sc = sf.SystemBuilder() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +def test_argmax(device): + src = sfnp.device_array(device, [4, 16, 128], dtype=sfnp.float32) + data = [float(i) for i in range(math.prod([1, 16, 128]))] + for i in range(4): + src.view(i).items = data + data.reverse() + + # default variant + result = sfnp.argmax(src) + assert result.shape == [4, 16] + assert result.view(0).items.tolist() == [127] * 16 + assert result.view(1).items.tolist() == [0] * 16 + assert result.view(2).items.tolist() == [127] * 16 + assert result.view(3).items.tolist() == [0] * 16 + + # keepdims variant + result = sfnp.argmax(src, keepdims=True) + assert result.shape == [4, 16, 1] + + # out= variant + out = sfnp.device_array(device, [4, 16], dtype=sfnp.int64) + sfnp.argmax(src, out=out) + assert out.shape == [4, 16] + assert out.view(0).items.tolist() == [127] * 16 + assert out.view(1).items.tolist() == [0] * 16 + assert out.view(2).items.tolist() == [127] * 16 + assert out.view(3).items.tolist() == [0] * 16 + + # out= keepdims variant (left aligned rank broadcast is allowed) + out = sfnp.device_array(device, [4, 16, 1], dtype=sfnp.int64) + sfnp.argmax(src, keepdims=True, out=out) + assert out.shape == [4, 16, 1] + assert out.view(0).items.tolist() == [127] * 16 + assert out.view(1).items.tolist() == [0] * 16 + assert out.view(2).items.tolist() == [127] * 16 + assert out.view(3).items.tolist() == [0] * 16 + + +def test_argmax_axis0(device): + src = sfnp.device_array(device, [4, 16], dtype=sfnp.float32) + for j in range(4): + src.view(j).items = [ + float((j + 1) * (i + 1) - j * 4) for i in range(math.prod([1, 16])) + ] + print(repr(src)) + + # default variant + result = sfnp.argmax(src, axis=0) + assert result.shape == [16] + assert result.items.tolist() == [0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] + + # keepdims variant + result = sfnp.argmax(src, axis=0, keepdims=True) + assert result.shape == [1, 16] + + # out= variant + out = sfnp.device_array(device, [16], dtype=sfnp.int64) + sfnp.argmax(src, axis=0, out=out) + + # out= keepdims variant + out = sfnp.device_array(device, [1, 16], dtype=sfnp.int64) + sfnp.argmax(src, axis=0, keepdims=True, out=out) + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.float16, + sfnp.float32, + ], +) +def test_argmax_dtypes(device, dtype): + # Just verifies that the dtype functions. We don't have IO support for + # some of these. + src = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + sfnp.argmax(src) + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.float16, + sfnp.float32, + ], +) +def test_fill_randn_default_generator(device, dtype): + out1 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out1.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out1) + out2 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out2.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out2) + + with out1.map(read=True) as m1, out2.map(read=True) as m2: + # The default generator should populate two different arrays. 
+ contents1 = bytes(m1) + contents2 = bytes(m2) + assert contents1 != contents2 + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.float16, + sfnp.float32, + ], +) +def test_fill_randn_explicit_generator(device, dtype): + gen1 = sfnp.RandomGenerator(42) + gen2 = sfnp.RandomGenerator(42) + out1 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out1.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out1, generator=gen1) + out2 = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with out2.map(write=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(out2, generator=gen2) + zero = sfnp.device_array(device, [4, 16, 128], dtype=dtype) + with zero.map(write=True) as m: + m.fill(bytes(1)) + + with out1.map(read=True) as m1, out2.map(read=True) as m2, zero.map( + read=True + ) as mz: + # Using explicit generators with the same seed should produce the + # same distributions. + contents1 = bytes(m1) + contents2 = bytes(m2) + assert contents1 == contents2 + # And not be zero. + assert contents1 != bytes(mz) + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.uint8, + sfnp.uint16, + sfnp.uint32, + sfnp.uint64, + sfnp.int8, + sfnp.int16, + sfnp.int32, + sfnp.int64, + sfnp.float16, + sfnp.float32, + sfnp.float64, + ], +) +def test_convert(device, dtype): + input_array = sfnp.device_array(device, [2, 3], dtype=sfnp.int32) + with input_array.map(write=True) as m: + m.fill(16) + intermediate = sfnp.convert(input_array, dtype=dtype) + with input_array.map(write=True) as m: + m.fill(0) + sfnp.convert(intermediate, out=input_array) + assert list(input_array.items) == 6 * [16] + + +def round_half_up(n): + return math.floor(n + 0.5) + + +def round_half_away_from_zero(n): + rounded_abs = round_half_up(abs(n)) + return math.copysign(rounded_abs, n) + + +@pytest.mark.parametrize( + "dtype,sfnp_func,ref_round_func", + [ + (sfnp.float16, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.round, round_half_away_from_zero), + (sfnp.float16, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.ceil, math.ceil), + (sfnp.float16, sfnp.floor, math.floor), + (sfnp.float32, sfnp.floor, math.floor), + (sfnp.float16, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.trunc, math.trunc), + ], +) +def test_nearest_int_no_conversion(device, dtype, sfnp_func, ref_round_func): + input = sfnp.device_array(device, [2, 3], dtype=dtype) + sfnp.fill_randn(input) + ref_rounded = [ + ref_round_func(n) for n in sfnp.convert(input, dtype=sfnp.float32).items + ] + output = sfnp_func(input) + assert output.dtype == dtype + output_items = sfnp.convert(output, dtype=sfnp.float32).items + print(output_items) + for ref, actual in zip(ref_rounded, output_items): + assert ref == pytest.approx(actual) + + +@pytest.mark.parametrize( + "dtype,out_dtype,sfnp_func,ref_round_func", + [ + # Round + (sfnp.float16, sfnp.int8, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.int8, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.int16, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.int32, sfnp.round, round_half_away_from_zero), + # Note that we do not test unsigned conversion with random data. 
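+ # (fill_randn produces signed, normal-distributed samples; negative values
+ # have no well-defined unsigned representation, hence the omission.)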
+ # Ceil + (sfnp.float16, sfnp.int8, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.int8, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.int16, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.int32, sfnp.ceil, math.ceil), + # Floor + (sfnp.float16, sfnp.int8, sfnp.floor, math.floor), + (sfnp.float32, sfnp.int8, sfnp.floor, math.floor), + (sfnp.float32, sfnp.int16, sfnp.floor, math.floor), + (sfnp.float32, sfnp.int32, sfnp.floor, math.floor), + # Trunc + (sfnp.float16, sfnp.int8, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.int8, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.int16, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.int32, sfnp.trunc, math.trunc), + ], +) +def test_nearest_int_conversion(device, dtype, out_dtype, sfnp_func, ref_round_func): + input = sfnp.device_array(device, [2, 3], dtype=dtype) + sfnp.fill_randn(input) + ref_rounded = [ + int(ref_round_func(n)) for n in sfnp.convert(input, dtype=sfnp.float32).items + ] + output = sfnp_func(input, dtype=out_dtype) + assert output.dtype == out_dtype + for ref, actual in zip(ref_rounded, output.items): + assert ref == int(actual) + + +def test_elementwise_forms(device): + # All elementwise ops use the same template expansion which enforces + # certain common invariants. Here we test these on the multiply op, + # relying on a parametric test for actual behavior. + with pytest.raises( + ValueError, + match="Elementwise operators require at least one argument to be a device_array", + ): + sfnp.multiply(2, 2) + + ary = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32) + with ary.map(discard=True) as m: + m.fill(42.0) + + # Rhs scalar int accepted. + result = sfnp.multiply(ary, 2) + assert list(result.items) == [84.0] * 6 + + # Rhs scalar float accepted. + result = sfnp.multiply(ary, 2.0) + assert list(result.items) == [84.0] * 6 + + # Lhs scalar int accepted. + result = sfnp.multiply(2, ary) + assert list(result.items) == [84.0] * 6 + + # Lhs scalar float accepted. + result = sfnp.multiply(2.0, ary) + assert list(result.items) == [84.0] * 6 + + # Out. + out = sfnp.device_array.for_host(device, [2, 3], dtype=sfnp.float32) + sfnp.multiply(2.0, ary, out=out) + assert list(out.items) == [84.0] * 6 + + +@pytest.mark.parametrize( + "lhs_dtype,rhs_dtype,promoted_dtype", + [ + (sfnp.float32, sfnp.float16, sfnp.float32), + (sfnp.float16, sfnp.float32, sfnp.float32), + (sfnp.float32, sfnp.float64, sfnp.float64), + (sfnp.float64, sfnp.float32, sfnp.float64), + # Integer promotion. + (sfnp.uint8, sfnp.uint16, sfnp.uint16), + (sfnp.uint16, sfnp.uint32, sfnp.uint32), + (sfnp.uint32, sfnp.uint64, sfnp.uint64), + (sfnp.int8, sfnp.int16, sfnp.int16), + (sfnp.int16, sfnp.int32, sfnp.int32), + (sfnp.int32, sfnp.int64, sfnp.int64), + # Signed/unsigned promotion. + (sfnp.int8, sfnp.uint8, sfnp.int16), + (sfnp.int16, sfnp.uint16, sfnp.int32), + (sfnp.int32, sfnp.uint32, sfnp.int64), + (sfnp.int8, sfnp.uint32, sfnp.int64), + ], +) +def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): + # Tests that promotion infers an appropriate result type. + lhs = sfnp.device_array.for_host(device, [2, 3], lhs_dtype) + rhs = sfnp.device_array.for_host(device, [2, 3], rhs_dtype) + result = sfnp.multiply(lhs, rhs) + assert result.dtype == promoted_dtype + + +@pytest.mark.parametrize( + "dtype,op,check_value", + [ + # Add. 
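+ # Expected values follow from lhs=42 and rhs=2 in the test body:
+ # 42+2=44, 42/2=21, 42*2=84, 42-2=40.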
+ (sfnp.int8, sfnp.add, 44.0), + (sfnp.int16, sfnp.add, 44.0), + (sfnp.int32, sfnp.add, 44.0), + (sfnp.int64, sfnp.add, 44.0), + (sfnp.uint8, sfnp.add, 44.0), + (sfnp.uint16, sfnp.add, 44.0), + (sfnp.uint32, sfnp.add, 44.0), + (sfnp.uint64, sfnp.add, 44.0), + (sfnp.float16, sfnp.add, 44.0), + (sfnp.float32, sfnp.add, 44.0), + (sfnp.float64, sfnp.add, 44.0), + # Divide. + (sfnp.int8, sfnp.divide, 21.0), + (sfnp.int16, sfnp.divide, 21.0), + (sfnp.int32, sfnp.divide, 21.0), + (sfnp.int64, sfnp.divide, 21.0), + (sfnp.uint8, sfnp.divide, 21.0), + (sfnp.uint16, sfnp.divide, 21.0), + (sfnp.uint32, sfnp.divide, 21.0), + (sfnp.uint64, sfnp.divide, 21.0), + (sfnp.float16, sfnp.divide, 21.0), + (sfnp.float32, sfnp.divide, 21.0), + (sfnp.float64, sfnp.divide, 21.0), + # Multiply. + (sfnp.int8, sfnp.multiply, 84.0), + (sfnp.int16, sfnp.multiply, 84.0), + (sfnp.int32, sfnp.multiply, 84.0), + (sfnp.int64, sfnp.multiply, 84.0), + (sfnp.uint8, sfnp.multiply, 84.0), + (sfnp.uint16, sfnp.multiply, 84.0), + (sfnp.uint32, sfnp.multiply, 84.0), + (sfnp.uint64, sfnp.multiply, 84.0), + (sfnp.float16, sfnp.multiply, 84.0), + (sfnp.float32, sfnp.multiply, 84.0), + (sfnp.float64, sfnp.multiply, 84.0), + # Subtract. + (sfnp.int8, sfnp.subtract, 40.0), + (sfnp.int16, sfnp.subtract, 40.0), + (sfnp.int32, sfnp.subtract, 40.0), + (sfnp.int64, sfnp.subtract, 40.0), + (sfnp.uint8, sfnp.subtract, 40.0), + (sfnp.uint16, sfnp.subtract, 40.0), + (sfnp.uint32, sfnp.subtract, 40.0), + (sfnp.uint64, sfnp.subtract, 40.0), + (sfnp.float16, sfnp.subtract, 40.0), + (sfnp.float32, sfnp.subtract, 40.0), + (sfnp.float64, sfnp.subtract, 40.0), + ], +) +def test_elementwise_array_correctness(device, dtype, op, check_value): + lhs = sfnp.device_array.for_host(device, [2, 2], sfnp.int32) + with lhs.map(discard=True) as m: + m.fill(42) + + rhs = sfnp.device_array.for_host(device, [2], sfnp.int32) + with rhs.map(discard=True) as m: + m.fill(2) + + lhs = sfnp.convert(lhs, dtype=dtype) + rhs = sfnp.convert(rhs, dtype=dtype) + result = op(lhs, rhs) + assert result.shape == [2, 2] + result = sfnp.convert(result, dtype=sfnp.float32) + items = list(result.items) + assert items == [check_value] * 4 + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.int8, + sfnp.int16, + sfnp.int32, + sfnp.int64, + sfnp.uint8, + sfnp.uint16, + sfnp.uint32, + sfnp.uint64, + sfnp.float32, + sfnp.float16, + sfnp.float32, + sfnp.float64, + ], +) +def test_transpose(device, dtype): + input = sfnp.device_array.for_host(device, [3, 2], sfnp.int32) + input.items = [0, 1, 2, 3, 4, 5] + input = sfnp.convert(input, dtype=dtype) + permuted = sfnp.transpose(input, [1, 0]) + assert permuted.shape == [2, 3] + items = list(sfnp.convert(permuted, dtype=sfnp.int32).items) + assert items == [0, 2, 4, 1, 3, 5] + + out = sfnp.device_array.for_host(device, [2, 3], dtype) + sfnp.transpose(input, [1, 0], out=out) + items = list(sfnp.convert(permuted, dtype=sfnp.int32).items) + assert items == [0, 2, 4, 1, 3, 5] diff --git a/libshortfin/tests/api/array_storage_test.py b/shortfin/tests/api/array_storage_test.py similarity index 73% rename from libshortfin/tests/api/array_storage_test.py rename to shortfin/tests/api/array_storage_test.py index 9af208f5a..5ecf9bdc9 100644 --- a/libshortfin/tests/api/array_storage_test.py +++ b/shortfin/tests/api/array_storage_test.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import pytest +import sys import shortfin as sf import shortfin.array as sfnp @@ -12,30 +13,30 @@ @pytest.fixture def lsys(): - sc = 
sf.host.CPUSystemBuilder() + sc = sf.SystemBuilder() lsys = sc.create_system() yield lsys lsys.shutdown() @pytest.fixture -def scope(lsys): - return lsys.create_scope() +def fiber(lsys): + return lsys.create_fiber() @pytest.fixture -def device(scope): - return scope.device(0) +def device(fiber): + return fiber.device(0) def test_allocate_host(device): - s = sfnp.storage.allocate_host(device, 32) - assert len(bytes(s.data)) == 32 + h = sfnp.storage.allocate_host(device, 32) + assert len(h) == 32 def test_allocate_device(device): - s = sfnp.storage.allocate_device(device, 64) - assert len(bytes(s.data)) == 64 + d = sfnp.storage.allocate_device(device, 64) + assert len(d) == 64 def test_fill1(lsys, device): @@ -43,7 +44,7 @@ async def main(): s = sfnp.storage.allocate_host(device, 8) s.fill(b"0") await device - assert bytes(s.data) == b"00000000" + assert bytes(s.map(read=True)) == b"00000000" lsys.run(main()) @@ -53,7 +54,7 @@ async def main(): s = sfnp.storage.allocate_host(device, 8) s.fill(b"01") await device - assert bytes(s.data) == b"01010101" + assert bytes(s.map(read=True)) == b"01010101" lsys.run(main()) @@ -63,11 +64,14 @@ async def main(): s = sfnp.storage.allocate_host(device, 8) s.fill(b"0123") await device - assert bytes(s.data) == b"01230123" + assert bytes(s.map(read=True)) == b"01230123" lsys.run(main()) +@pytest.mark.skipif( + sys.platform == "win32", reason="Windows fatal exception: access violation" +) def test_fill_error(device): s = sfnp.storage.allocate_host(device, 8) with pytest.raises(RuntimeError): @@ -80,6 +84,9 @@ def test_fill_error(device): s.fill(b"01234567") +@pytest.mark.skipif( + sys.platform == "win32", reason="Windows fatal exception: access violation" +) @pytest.mark.parametrize( "pattern,size", [ @@ -135,7 +142,7 @@ async def main(): mv = memoryview(m) assert not mv.readonly mv[0] = ord(b"9") - assert bytes(src.data) == b"91230123" + assert bytes(src.map(read=True)) == b"91230123" lsys.run(main()) @@ -150,16 +157,27 @@ async def main(): assert not mv.readonly for i in range(8): mv[i] = ord(b"9") - i - assert bytes(src.data) == b"98765432" + assert bytes(src.map(read=True)) == b"98765432" lsys.run(main()) -def test_data_write(lsys, device): +@pytest.mark.parametrize( + "alloc_bytes,fill_value,expected_value", + [ + (8, b"9", b"99999999"), + (8, b"98", b"98989898"), + (8, b"9876", b"98769876"), + (8, b"98765432", b"98765432"), + (20, b"9876543210", b"98765432109876543210"), + ], +) +def test_mapping_fill1(lsys, device, alloc_bytes, fill_value, expected_value): async def main(): - src = sfnp.storage.allocate_host(device, 8) - src.data = b"98765432" - assert bytes(src.data) == b"98765432" + src = sfnp.storage.allocate_host(device, alloc_bytes) + with src.map(discard=True) as m: + m.fill(fill_value) + assert bytes(src.map(read=True)) == expected_value lsys.run(main()) diff --git a/shortfin/tests/api/array_test.py b/shortfin/tests/api/array_test.py new file mode 100644 index 000000000..be10c05b7 --- /dev/null +++ b/shortfin/tests/api/array_test.py @@ -0,0 +1,244 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import array +import math +import pytest + +import shortfin as sf +import shortfin.array as sfnp + + +@pytest.fixture +def lsys(): + # TODO: Port this test to use memory type independent access. It currently + # presumes unified memory. 
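+ # The generic sf.SystemBuilder() below would honor the --system selection
+ # made in tests/conftest.py; CPUSystemBuilder is pinned here because the
+ # assertions rely on host-visible (unified) memory.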
+ # sc = sf.SystemBuilder() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +def test_storage_constructor(lsys, device): + async def main(): + s = sfnp.storage.allocate_host(device, 8) + s.fill(b"\0\1\2\3") + await device + ary = sfnp.device_array(s, [2, 4], sfnp.uint8) + assert ary.dtype == sfnp.uint8 + assert ary.shape == [2, 4] + assert list(ary.items) == [0, 1, 2, 3, 0, 1, 2, 3] + assert ary.device == device + assert ary.storage == s + + lsys.run(main()) + + +def test_device_constructor(lsys, device): + async def main(): + ary = sfnp.device_array(device, [2, 4], sfnp.uint8) + ary.storage.fill(b"\0\1\2\3") + await device + assert ary.dtype == sfnp.uint8 + assert ary.shape == [2, 4] + assert list(ary.items) == [0, 1, 2, 3, 0, 1, 2, 3] + assert ary.device == device + + lsys.run(main()) + + +def test_fill_copy_from_for_transfer(lsys, device): + async def main(): + src = sfnp.device_array(device, [2, 4], sfnp.uint8) + src.fill(b"\0\1\2\3") + dst = src.for_transfer() + dst.copy_from(src) + await device + assert list(dst.items) == [0, 1, 2, 3, 0, 1, 2, 3] + + lsys.run(main()) + + +def test_fill_copy_to_for_transfer(lsys, device): + async def main(): + src = sfnp.device_array(device, [2, 4], sfnp.uint8) + src.fill(b"\0\1\2\3") + dst = src.for_transfer() + src.copy_to(dst) + await device + assert list(dst.items) == [0, 1, 2, 3, 0, 1, 2, 3] + + lsys.run(main()) + + +def test_shape_overflow(lsys, device): + async def main(): + s = sfnp.storage.allocate_host(device, 4) + _ = sfnp.device_array(s, [2, 4], sfnp.uint8) + + with pytest.raises( + ValueError, match="Array storage requires at least 8 bytes but has only 4" + ): + lsys.run(main()) + + +@pytest.mark.parametrize( + "dtype,code,py_value,expected_str", + [ + (sfnp.int8, "b", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + (sfnp.int16, "h", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + (sfnp.int32, "i", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + ( + sfnp.float32, + "f", + 42.0, + "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", + ), + ( + sfnp.float64, + "d", + 42.0, + "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", + ), + ], +) +def test_xtensor_types(fiber, dtype, code, py_value, expected_str): + ary = sfnp.device_array.for_host(fiber.device(0), [2, 4], dtype) + with ary.map(discard=True) as m: + m.fill(py_value) + s = str(ary) + print("__str__ =", s) + assert expected_str == s, f"Expected '{expected_str}' == '{s}'" + r = repr(ary) + print("__repr__ =", r) + assert expected_str in r, f"Expected '{expected_str}' in '{r}'" + + +@pytest.mark.parametrize( + "dtype,value,", + [ + (sfnp.int8, 42), + (sfnp.int16, 42), + (sfnp.int32, 42), + (sfnp.int64, 42), + (sfnp.float32, 42.0), + (sfnp.float64, 42.0), + ], +) +def test_items(fiber, dtype, value): + ary = sfnp.device_array.for_host(fiber.device(0), [2, 4], dtype) + ary.items = [value] * 8 + readback = ary.items.tolist() + assert readback == [value] * 8 + + +@pytest.mark.parametrize( + "dtype,value,", + [ + (sfnp.int8, 42), + (sfnp.int16, 42), + (sfnp.int32, 42), + (sfnp.int64, 42), + (sfnp.float32, 42.0), + (sfnp.float64, 42.0), + ], +) +def test_typed_mapping(fiber, dtype, value): + ary = sfnp.device_array.for_host(fiber.device(0), [2, 4], dtype) + with ary.map(discard=True) as m: + m.fill(value) + readback = ary.items.tolist() + assert readback == [value] * 8 + + # Map as read/write and validate. 
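+ # Unlike map(discard=True) above (write-only, prior contents dropped),
+ # read=True/write=True exposes the current items and writes updates back.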
+ with ary.map(read=True, write=True) as m: + new_values = m.items.tolist() + for i in range(len(new_values)): + new_values[i] += 1 + m.items = new_values + + readback = ary.items.tolist() + assert readback == [value + 1] * 8 + + +@pytest.mark.parametrize( + "keys,expected", + [ + # Simple indexing + ([0, 0], [0]), + # Row indexing + ([0], [0, 1, 2, 3]), + # Sliced indexing + ([1, slice(2, 4)], [2, 3]), + ([slice(1, 2), slice(2, 4)], [2, 3]), + ], +) +def test_view(device, keys, expected): + src = sfnp.device_array(device, [4, 4], sfnp.uint8) + with src.map(discard=True) as m: + m.fill(b"\0\1\2\3") + view = src.view(*keys) + assert list(view.items) == expected + + +def test_view_nd(device): + shape = [4, 16, 128] + data = [i for i in range(math.prod(shape))] + src = sfnp.device_array(device, [4, 16, 128], dtype=sfnp.uint32) + src.items = data + + # Validate left justified indexing into the first dimension. + for i in range(4): + v = src.view(i) + v_items = v.items.tolist() + assert len(v_items) == 2048 + assert v_items[0] == i * 2048 + assert v_items[-1] == (i + 1) * 2048 - 1 + for i in range(16): + v = src.view(1, i) + v_items = v.items.tolist() + assert len(v_items) == 128 + assert v_items[0] == 2048 + i * 128 + assert v_items[-1] == 2048 + (i + 1) * 128 - 1 + for i in range(128): + v = src.view(1, 1, 1) + v_items = v.items.tolist() + assert len(v_items) == 1 + assert v_items[0] == 2177 + + # Validate span. + for i in range(16): + v = src.view(slice(2, 4)) + v_items = v.items.tolist() + assert len(v_items) == 4096 + assert v_items[0] == 4096 + assert v_items[-1] == 8191 + + +def test_view_unsupported(lsys, device): + async def main(): + src = sfnp.device_array(device, [4, 4], sfnp.uint8) + src.fill(b"\0\1\2\3") + + with pytest.raises( + ValueError, + match="Cannot create a view with dimensions following a spanning dim", + ): + view = src.view(slice(0, 2), 1) + await device + + lsys.run(main()) diff --git a/shortfin/tests/api/array_use_case_test.py b/shortfin/tests/api/array_use_case_test.py new file mode 100644 index 000000000..d4a030d45 --- /dev/null +++ b/shortfin/tests/api/array_use_case_test.py @@ -0,0 +1,64 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import array +import math +import pytest + +import shortfin as sf +import shortfin.array as sfnp + + +@pytest.fixture +def lsys(): + # TODO: Port this test to use memory type independent access. It currently + # presumes unified memory. + # sc = sf.SystemBuilder() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +# Tests a typical image conversion from a model oriented layout to an array +# of contained images. +def test_image_to_bytes(device): + bs = 2 + height = 16 + width = 12 + images_shape = [bs, 3, height, width] + images_planar = sfnp.device_array.for_host(device, images_shape, sfnp.float32) + # Band the data so that each channel increases by 0.1 across images. + for i in range(bs): + for j in range(3): + data = [i * 0.3 + j * 0.1 for _ in range(height * width)] + images_planar.view(i, j).items = data + images_planar = sfnp.convert(images_planar, dtype=sfnp.float16) + + # Extract and convert each image to interleaved RGB bytes. 
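+ # Per image: take a view, transpose NCHW -> NHWC, scale by 255, round to
+ # uint8, then read the raw interleaved bytes from a host mapping.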
+ images = [] + for idx in range(images_planar.shape[0]): + image_planar = images_planar.view(idx) + assert image_planar.shape == [1, 3, 16, 12] + image_interleaved = sfnp.transpose(image_planar, (0, 2, 3, 1)) + assert image_interleaved.shape == [1, 16, 12, 3] + image_scaled = sfnp.multiply(image_interleaved, 255) + image = sfnp.round(image_scaled, dtype=sfnp.uint8) + image_bytes = bytes(image.map(read=True)) + images.append(image_bytes) + + assert images[0] == b"\x00\x1a3" * 192 + assert images[1] == b"Mf\x80" * 192 diff --git a/shortfin/tests/api/nputils_test.py b/shortfin/tests/api/nputils_test.py new file mode 100644 index 000000000..fdb650054 --- /dev/null +++ b/shortfin/tests/api/nputils_test.py @@ -0,0 +1,274 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import importlib +import logging +import math +import os +import pytest +import sys +from unittest.mock import patch + +import shortfin as sf +import shortfin.array as sfnp +import shortfin.host + +np = pytest.importorskip("numpy", reason="numpy is not installed") +from shortfin.array import nputils + + +@pytest.fixture +def lsys(): + # TODO: Port this test to use memory type independent access. It currently + # presumes unified memory. + # sc = sf.SystemBuilder() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +@pytest.fixture(scope="function") +def configure_caplog(caplog): + caplog.set_level(logging.INFO, logger=None) + yield caplog + + +def test_to_np_from_device_array(device): + def _verify_array(np_arr, compare_to, shape, dtype): + assert isinstance(np_arr, np.ndarray) + assert np_arr.shape == shape + assert np_arr.dtype == dtype + assert np.array_equal(np_arr, compare_to) + + shape = [4, 16, 128] + + # Test various dtypes (f32, f64, i32, i64) + src_f32 = sfnp.device_array(device, shape, dtype=sfnp.float32) + src_f64 = sfnp.device_array(device, shape, dtype=sfnp.float64) + src_i32 = sfnp.device_array(device, shape, dtype=sfnp.int32) + src_i64 = sfnp.device_array(device, shape, dtype=sfnp.int64) + compare_to = np.zeros([4, 16, 128], dtype=np.float32) + data = [i for i in range(math.prod([1, 16, 128]))] + for i in range(4): + src_f32.view(i).items = data + src_f64.view(i).items = data + src_i32.view(i).items = data + src_i64.view(i).items = data + compare_to[i] = np.array(data).reshape(16, 128) + data.reverse() + + # Convert to np array + np_array_f32 = np.array(src_f32) + np_array_f64 = np.array(src_f64) + np_array_i32 = np.array(src_i32) + np_array_i64 = np.array(src_i64) + + _verify_array(np_array_f32, compare_to, tuple(shape), np.float32) + _verify_array(np_array_f64, compare_to, tuple(shape), np.float64) + _verify_array(np_array_i32, compare_to, tuple(shape), np.int32) + _verify_array(np_array_i64, compare_to, tuple(shape), np.int64) + + +def test_to_np_from_device_f16(device, lsys): + def _int_to_f16_uint(n): + """Converts an integer to a float16 uint, for convenient testing.""" + if n == 0: + return 0 + exponent = n.bit_length() - 1 + if n == 2**exponent: + return (exponent + 15) << 10 + else: + fraction = n - 2**exponent + fraction_bits = fraction << (10 - exponent) + return ((exponent + 15) << 10) | fraction_bits + + def _verify_array(np_arr, 
compare_to, shape, dtype): + assert isinstance(np_arr, np.ndarray) + assert np_arr.shape == shape + assert np_arr.dtype == dtype + assert np.array_equal(np_arr, compare_to) + + shape = [4, 42, 48] + src = sfnp.device_array(device, shape, dtype=sfnp.float16) + compare_to = np.zeros([4, 42, 48], dtype=np.float16) + data_uint = [_int_to_f16_uint(i) for i in range(math.prod([1, 42, 48]))] + data = [i for i in range(math.prod([1, 42, 48]))] + for i in range(4): + src.view(i).items = data_uint + compare_to[i] = np.array(data, dtype=np.float16).reshape(42, 48) + data.reverse() + data_uint.reverse() + + # Convert to np array + np_array = np.array(src) + _verify_array(np_array, compare_to, tuple(shape), np.float16) + + +def test_dump_array(device): + shape = [4, 16, 128] + src = sfnp.device_array(device, shape, dtype=sfnp.float32) + data = [float(i) for i in range(math.prod([1, 16, 128]))] + for i in range(4): + src.view(i).items = data + data.reverse() + + # Ensure array is dumped properly to log output + log_messages = [] + with patch.object( + nputils.logger, + "debug", + side_effect=lambda message: log_messages.append(message), + ): + nputils.debug_dump_array(src) + src_np_array = np.array(src) + arr_str = str(src_np_array) + assert arr_str == str(log_messages[0]) + + +def test_fill_array(device, lsys): + shape = [4, 16, 128] + src = sfnp.device_array(device, shape, dtype=sfnp.float32) + data = [0 for _ in range(math.prod([1, 16, 128]))] + for i in range(4): + src.view(i).items = data + + # Fill array + fill_value = 3.14 + np_array = nputils.debug_fill_array(src, fill_value) + + # Check if the values are correct + compare_to = np.zeros([16, 128], dtype=np.float32) + compare_to.fill(fill_value) + for i in range(4): + assert np_array[i].tolist() == compare_to.tolist() + + +def test__find_mode_basic(): + arr = np.array([1, 2, 3, 3, 4, 5, 5, 5, 5, 5]) + mode, count = nputils._find_mode(arr) + assert mode == 5 + assert count == 5 + + +def test__find_mode_empty(): + arr = np.array([]) + mode, count = nputils._find_mode(arr) + assert math.isnan(mode) + assert count == 0 + + +def test__find_mode_multi_dim(): + arr = np.array([[1, 2, 3], [3, 4, 5], [5, 5, 5]]) + mode, count = nputils._find_mode(arr, axis=1) + assert mode.tolist() == [1, 3, 5] + assert count.tolist() == [1, 1, 3] + + +def test__find_mode_keep_dim(): + arr = np.array([[1, 2, 3], [3, 4, 5], [5, 5, 5]]) + mode, count = nputils._find_mode(arr, axis=1, keepdims=True) + assert mode.tolist() == [[1], [3], [5]] + assert count.tolist() == [[1], [1], [3]] + + +def test_log_tensor_stats_basic(device, lsys, caplog): + shape = [1, 6] + src = sfnp.device_array(device, shape, dtype=sfnp.float32) + data = [1, 2, 3, 3, 4, 5] + src.view(0).items = data + + # Ensure array stats are logged properly + log_messages = [] + with patch.object( + nputils.logger, + "debug", + side_effect=lambda message: log_messages.append(message), + ): + nputils.debug_log_tensor_stats(src) + assert log_messages[0] == "NaN count: 0 / 6" + assert log_messages[1] == "Shape: (1, 6), dtype: float32" + assert log_messages[2] == "Min (excluding NaN): 1.0" + assert log_messages[3] == "Max (excluding NaN): 5.0" + assert log_messages[4] == "Mean (excluding NaN): 3.0" + assert log_messages[5] == "Mode (excluding NaN): 3.0" + assert log_messages[6] == "First 10 elements: [1. 2. 3. 3. 4. 5.]" + assert log_messages[7] == "Last 10 elements: [1. 2. 3. 3. 4. 
5.]" + + +def test_log_tensor_stats_with_nan(device, lsys, caplog): + shape = [1, 8] + src = sfnp.device_array(device, shape, dtype=sfnp.float64) + data = [3, np.nan, 4, 3, 1, np.nan, 5, 9] + src.view(0).items = data + + # Ensure array stats are logged properly + log_messages = [] + with patch.object( + nputils.logger, + "debug", + side_effect=lambda message: log_messages.append(message), + ): + nputils.debug_log_tensor_stats(src) + assert log_messages[0] == "NaN count: 2 / 8" + assert log_messages[1] == "Shape: (1, 8), dtype: float64" + assert log_messages[2] == "Min (excluding NaN): 1.0" + assert log_messages[3] == "Max (excluding NaN): 9.0" + assert log_messages[4] == "Mean (excluding NaN): 4.166666666666667" + assert log_messages[5] == "Mode (excluding NaN): 3.0" + assert log_messages[6] == "First 10 elements: [3. 4. 3. 1. 5. 9.]" + assert log_messages[7] == "Last 10 elements: [3. 4. 3. 1. 5. 9.]" + + +def test_log_tensor_stats_empty(device, lsys, caplog): + shape = [1, 0] + src = sfnp.device_array(device, shape, dtype=sfnp.float32) + + # Ensure array stats are logged properly + log_messages = [] + with patch.object( + nputils.logger, + "debug", + side_effect=lambda message: log_messages.append(message), + ): + nputils.debug_log_tensor_stats(src) + assert log_messages[0] == "NaN count: 0 / 0" + assert log_messages[1] == "Shape: (1, 0), dtype: float32" + + +def test_log_tensor_stats_multi_dim(device, lsys, caplog): + shape = [3, 3] + src = sfnp.device_array(device, shape, dtype=sfnp.float32) + data = [[1, np.nan, 3], [np.nan, 4, 5], [5, np.nan, 7]] + for i in range(3): + src.view(i).items = data[i] + + # Ensure array stats are logged properly + log_messages = [] + with patch.object( + nputils.logger, + "debug", + side_effect=lambda message: log_messages.append(message), + ): + nputils.debug_log_tensor_stats(src) + assert log_messages[0] == "NaN count: 3 / 9" + assert log_messages[1] == "Shape: (3, 3), dtype: float32" + assert log_messages[2] == "Min (excluding NaN): 1.0" + assert log_messages[3] == "Max (excluding NaN): 7.0" + assert log_messages[4] == "Mean (excluding NaN): 4.166666507720947" + assert log_messages[5] == "Mode (excluding NaN): 5.0" + assert log_messages[6] == "First 10 elements: [1. 3. 4. 5. 5. 7.]" + assert log_messages[7] == "Last 10 elements: [1. 3. 4. 5. 5. 7.]" diff --git a/shortfin/tests/apps/llm/components/cache_test.py b/shortfin/tests/apps/llm/components/cache_test.py new file mode 100644 index 000000000..169d082b1 --- /dev/null +++ b/shortfin/tests/apps/llm/components/cache_test.py @@ -0,0 +1,94 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Tests for llm kvcache component. +""" + +import pytest +import time +import tempfile +import shortfin as sf +from _shortfin import lib as sfl +from shortfin_apps.llm.components import cache +from shortfin_apps.llm.components import config_struct +import json +from pathlib import Path + + +@pytest.fixture +def lsys(): + sc = sfl.local.host.CPUSystemBuilder() + ls = sc.create_system() + yield ls + ls.shutdown() + + +@pytest.fixture +def fiber(lsys): + # TODO: Should adopt the main thread. 
+ worker = lsys.create_worker("main") + return lsys.create_fiber(worker) + + +@pytest.fixture +def device(fiber): + return fiber.device(0) + + +@pytest.fixture +def model_params(): + model_params = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 2048, + "attn_head_count": 32, + "attn_head_dim": 100, + "prefill_batch_sizes": [4], + "decode_batch_sizes": [4], + "transformer_block_count": 26, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + + # Create a temporary file to store the JSON + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as tmp_file: + json.dump(model_params, tmp_file, indent=4) + tmp_path = Path(tmp_file.name) + + try: + # Load the JSON using config_struct + model_params = config_struct.ModelParams.load_json(tmp_path) + yield model_params + finally: + tmp_path.unlink + + +@pytest.fixture +def cache_fixture(fiber, model_params) -> cache.AttnPageCache: + # Create and return the cache object + return cache.AttnPageCache( + devices=fiber.devices_dict.values(), model_params=model_params + ) + + +@pytest.mark.parametrize("n_allocated", [1, 16, 255]) +def test_alloc( + cache_fixture: cache.AttnPageCache, + n_allocated, + model_params: config_struct.ModelParams, +): + alloc_page_count = cache_fixture.page_tables[0].shape[0] + + assert alloc_page_count == model_params.paged_kv_cache.device_block_count + + pages = cache_fixture.acquire_free_pages(n_allocated) + last_page = alloc_page_count - 1 + expected_indices = range(last_page, last_page - n_allocated, -1) + for p, expected_ix in zip(pages, expected_indices): + assert p.index == expected_ix + assert p.index > 0 diff --git a/shortfin/tests/apps/llm/components/tokenizer_test.py b/shortfin/tests/apps/llm/components/tokenizer_test.py new file mode 100644 index 000000000..b7e4ee8b9 --- /dev/null +++ b/shortfin/tests/apps/llm/components/tokenizer_test.py @@ -0,0 +1,70 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + + +@pytest.fixture +def bert_tokenizer(): + import shortfin_apps.llm.components.tokenizer as tokenizer + + return tokenizer.Tokenizer.from_pretrained("bert-base-cased") + + +def test_tokenizers_lib(bert_tokenizer): + enc0, enc1 = bert_tokenizer.encode(["This is sequence 1", "Sequence 2"]) + assert enc0.ids == [101, 1188, 1110, 4954, 122, 102] + assert enc1.ids == [101, 22087, 25113, 123, 102, 0] + texts = bert_tokenizer.decode([enc0.ids, enc1.ids]) + assert texts == ["This is sequence 1", "Sequence 2"] + + # Test manual padding. 
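+ # pad(12) right-pads the ids with the pad token id (0 for this tokenizer),
+ # and encoding_length then reports the padded length.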
+ enc0.pad(12) + assert enc0.ids == [101, 1188, 1110, 4954, 122, 102, 0, 0, 0, 0, 0, 0] + assert bert_tokenizer.encoding_length(enc0) == 12 + + +def test_tokenizer_to_array(cpu_fiber, bert_tokenizer): + batch_seq_len = 12 + encs = bert_tokenizer.encode(["This is sequence 1", "Sequence 2"]) + bert_tokenizer.post_process_encodings(encs, batch_seq_len) + ary = bert_tokenizer.encodings_to_array(cpu_fiber.device(0), encs, batch_seq_len) + print(ary) + assert ary.view(0).items.tolist() == [ + 101, + 1188, + 1110, + 4954, + 122, + 102, + 0, + 0, + 0, + 0, + 0, + 0, + ] + assert ary.view(1).items.tolist() == [ + 101, + 22087, + 25113, + 123, + 102, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ] + + masks = bert_tokenizer.attention_masks_to_array( + cpu_fiber.device(0), encs, batch_seq_len + ) + print(masks) + assert masks.view(0).items.tolist() == [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] + assert masks.view(1).items.tolist() == [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] diff --git a/shortfin/tests/apps/llm/conftest.py b/shortfin/tests/apps/llm/conftest.py new file mode 100644 index 000000000..6cd06f385 --- /dev/null +++ b/shortfin/tests/apps/llm/conftest.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + +from shortfin.support.deps import ShortfinDepNotFoundError + + +@pytest.fixture(autouse=True) +def require_deps(): + try: + import shortfin_apps.llm + except ShortfinDepNotFoundError as e: + pytest.skip(f"Dep not available: {e}") diff --git a/shortfin/tests/apps/sd/components/tokenizer_test.py b/shortfin/tests/apps/sd/components/tokenizer_test.py new file mode 100644 index 000000000..05515ec30 --- /dev/null +++ b/shortfin/tests/apps/sd/components/tokenizer_test.py @@ -0,0 +1,55 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + + +@pytest.fixture +def clip_tokenizer(): + from shortfin_apps.sd.components.tokenizer import Tokenizer + + return Tokenizer.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", "tokenizer" + ) + + +def test_transformers_tokenizer(clip_tokenizer): + enc0 = clip_tokenizer.encode(["This is sequence 1", "Sequence 2"]) + e0 = enc0.input_ids[0, :10] + e1 = enc0.input_ids[1, :10] + assert e0.tolist() == [ + 49406, + 589, + 533, + 18833, + 272, + 49407, + 49407, + 49407, + 49407, + 49407, + ] + assert e1.tolist() == [ + 49406, + 18833, + 273, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + ] + + +def test_tokenizer_to_array(cpu_fiber, clip_tokenizer): + batch_seq_len = 64 + encs = clip_tokenizer.encode(["This is sequence 1", "Sequence 2"]) + ary = clip_tokenizer.encodings_to_array(cpu_fiber.device(0), encs, batch_seq_len) + print(ary) + assert ary.view(0).items.tolist()[:5] == [49406, 589, 533, 18833, 272] + assert ary.view(1).items.tolist()[:5] == [49406, 18833, 273, 49407, 49407] diff --git a/shortfin/tests/apps/sd/conftest.py b/shortfin/tests/apps/sd/conftest.py new file mode 100644 index 000000000..1a08d9b4b --- /dev/null +++ b/shortfin/tests/apps/sd/conftest.py @@ -0,0 +1,17 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + +from shortfin.support.deps import ShortfinDepNotFoundError + + +@pytest.fixture(autouse=True) +def require_deps(): + try: + import shortfin_apps.sd + except ShortfinDepNotFoundError as e: + pytest.skip(f"Dep not available: {e}") diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py new file mode 100644 index 000000000..26c2e30f6 --- /dev/null +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -0,0 +1,265 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import requests +import time +import asyncio +import base64 +import pytest +import subprocess +import os +import socket +import sys +import copy +import math +import tempfile +from contextlib import closing + +from datetime import datetime as dt +from PIL import Image + +BATCH_SIZES = [1] + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [5], + "guidance_scale": [7.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def start_server(fibers_per_device=1, isolation="per_fiber"): + # Start the server + srv_args = [ + "python", + "-m", + "shortfin_apps.sd.server", + ] + with open("sdxl_config_i8.json", "wb") as f: + r = requests.get( + "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/11022024/configs/sdxl_config_i8.json", + allow_redirects=True, + ) + f.write(r.content) + srv_args.extend( + [ + f"--model_config=sdxl_config_i8.json", + f"--fibers_per_device={fibers_per_device}", + f"--isolation={isolation}", + f"--splat", + ] + ) + runner = ServerRunner(srv_args) + # Wait for server to start + time.sleep(3) + return runner + + +@pytest.fixture(scope="module") +def sd_server_fpd1(): + runner = start_server(fibers_per_device=1) + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.fixture(scope="module") +def sd_server_fpd1_per_call(): + runner = start_server(fibers_per_device=1, isolation="per_call") + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.fixture(scope="module") +def sd_server_fpd2(): + runner = start_server(fibers_per_device=2) + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.fixture(scope="module") +def sd_server_fpd8(): + runner = start_server(fibers_per_device=8) + + yield runner + + # Teardown: kill the server + del runner + + +@pytest.mark.system("amdgpu") +def test_sd_server(sd_server_fpd1): + imgs, status_code = send_json_file(sd_server_fpd1.url) + assert len(imgs) == 1 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs4_dense(sd_server_fpd1): + imgs, status_code = send_json_file(sd_server_fpd1.url, num_copies=4) + assert len(imgs) == 4 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs8_percall(sd_server_fpd1_per_call): + imgs, status_code = send_json_file(sd_server_fpd1_per_call.url, num_copies=8) + assert len(imgs) == 8 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs4_dense_fpd2(sd_server_fpd2): + imgs, status_code = 
send_json_file(sd_server_fpd2.url, num_copies=4) + assert len(imgs) == 4 + assert status_code == 200 + + +@pytest.mark.system("amdgpu") +def test_sd_server_bs8_dense_fpd8(sd_server_fpd8): + imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=8) + assert len(imgs) == 8 + assert status_code == 200 + + +@pytest.mark.skip +@pytest.mark.system("amdgpu") +def test_sd_server_bs64_dense_fpd8(sd_server_fpd8): + imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=64) + assert len(imgs) == 64 + assert status_code == 200 + + +@pytest.mark.skip +@pytest.mark.xfail(reason="Unexpectedly large client batch.") +@pytest.mark.system("amdgpu") +def test_sd_server_bs512_dense_fpd8(sd_server_fpd8): + imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=512) + assert len(imgs) == 512 + assert status_code == 200 + + +class ServerRunner: + def __init__(self, args): + port = str(find_free_port()) + self.url = "http://0.0.0.0:" + port + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.process = subprocess.Popen( + [ + *args, + "--port=" + port, + "--device=amdgpu", + ], + env=env, + # TODO: Have a more robust way of forking a subprocess. + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_ready() + + def _wait_for_ready(self): + start = time.time() + while True: + time.sleep(2) + try: + if requests.get(f"{self.url}/health").status_code == 200: + return + except Exception as e: + if self.process.errors is not None: + raise RuntimeError("API server process terminated") from e + time.sleep(1.0) + if (time.time() - start) > 30: + raise RuntimeError("Timeout waiting for server start") + + def __del__(self): + try: + process = self.process + except AttributeError: + pass + else: + process.terminate() + process.wait() + + +def bytes_to_img(bytes, idx=0, width=1024, height=1024): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(bytes) + ) + return image + + +def send_json_file(url="http://0.0.0.0:8000", num_copies=1): + # Read the JSON file + data = copy.deepcopy(sample_request) + imgs = [] + # Send the data to the /generate endpoint + data["prompt"] = ( + [data["prompt"]] + if isinstance(data["prompt"], str) + else data["prompt"] * num_copies + ) + try: + response = requests.post(url + "/generate", json=data) + response.raise_for_status() # Raise an error for bad responses + request = json.loads(response.request.body.decode("utf-8")) + + for idx, item in enumerate(response.json()["images"]): + width = getbatched(request, idx, "width") + height = getbatched(request, idx, "height") + img = bytes_to_img(item.encode("utf-8"), idx, width, height) + imgs.append(img) + + except requests.exceptions.RequestException as e: + print(f"Error sending the request: {e}") + + return imgs, response.status_code + + +def getbatched(req, idx, key): + if isinstance(req[key], list): + if len(req[key]) == 1: + return req[key][0] + elif len(req[key]) > idx: + return req[key][idx] + else: + return req[key] + + +def find_free_port(): + """This tries to find a free port to run a server on for the test. + + Race conditions are possible - the port can be acquired between when this + runs and when the server starts. 
+ + https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number + """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("localhost", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def test_placeholder(): + # Here in case this pytest is invoked via CPU CI and no tests are run. + pass diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py new file mode 100644 index 000000000..083698968 --- /dev/null +++ b/shortfin/tests/conftest.py @@ -0,0 +1,114 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import pytest +import shlex + +import shortfin as sf + + +def pytest_addoption(parser): + parser.addoption( + "--system", + action="store", + metavar="NAME", + default="hostcpu", + help="Enable tests for system name ('hostcpu', 'amdgpu', ...)", + ) + parser.addoption( + "--compile-flags", + action="store", + metavar="FLAGS", + help="Compile flags to run test on the --system (required if it cannot be inferred)", + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "system(name): mark test to run only on a named system" + ) + config.addinivalue_line( + "markers", "slow: mark test to run in a separate, slow suite." + ) + + +def pytest_runtest_setup(item): + system_type = item.config.getoption("--system") + # Filter tests based on system mark. + required_system_names = [mark.args[0] for mark in item.iter_markers("system")] + if required_system_names: + if not all(name == system_type for name in required_system_names): + pytest.skip( + f"test requires system in {required_system_names!r} but has " + f"{system_type!r} (set with --system arg)" + ) + # Set the default. + sf.SystemBuilder.default_system_type = system_type + + +# Keys that will be cleaned project wide prior to and after each test run. +# Test code can freely modify these. +CLEAN_ENV_KEYS = [ + "SHORTFIN_ALLOCATORS", + "SHORTFIN_AMDGPU_ALLOCATORS", + "SHORTFIN_AMDGPU_ASYNC_ALLOCATIONS", + "SHORTFIN_AMDGPU_LOGICAL_DEVICES_PER_PHYSICAL_DEVICE", + "SHORTFIN_AMDGPU_TRACING_LEVEL", + "SHORTFIN_HOSTCPU_ALLOCATORS", + "SHORTFIN_HOSTCPU_TOPOLOGY_NODES", + "SHORTFIN_HOSTCPU_TOPOLOGY_MAX_GROUP_COUNT", + "SHORTFIN_SYSTEM_TYPE", +] + + +@pytest.fixture(scope="session") +def compile_flags(pytestconfig) -> list[str]: + compile_flags = pytestconfig.getoption("--compile-flags") + if compile_flags is not None: + return shlex.split(compile_flags) + # Try to figure it out from the system. 
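+ # Only hostcpu flags can be inferred; for other systems (e.g. amdgpu) pass
+ # --compile-flags explicitly, otherwise compilation tests are skipped.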
+ system_type = pytestconfig.getoption("--system") + if system_type == "hostcpu": + return [ + "--iree-hal-target-device=llvm-cpu", + "--iree-llvmcpu-target-cpu=host", + ] + pytest.skip( + reason="Test needs to compile a binary and no --compile-flags set (or " + "could not be inferred)" + ) + + +@pytest.fixture(autouse=True) +def clean_env(): + def kill(): + for key in CLEAN_ENV_KEYS: + if key in os.environ: + del os.environ[key] + os.unsetenv(key) + + kill() + yield + kill() + + +@pytest.fixture +def cpu_lsys(): + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def cpu_fiber(cpu_lsys): + return cpu_lsys.create_fiber() + + +@pytest.fixture +def cpu_device(cpu_fiber): + return cpu_fiber.device(0) diff --git a/libshortfin/tests/examples/async_test.py b/shortfin/tests/examples/async_test.py similarity index 91% rename from libshortfin/tests/examples/async_test.py rename to shortfin/tests/examples/async_test.py index ea3dc20ab..06bc47587 100644 --- a/libshortfin/tests/examples/async_test.py +++ b/shortfin/tests/examples/async_test.py @@ -8,6 +8,7 @@ # those as examples and launch them here. from pathlib import Path +import pytest import subprocess import sys @@ -16,7 +17,7 @@ def run_example(path: Path): - subprocess.check_call([sys.executable, str(path)]) + subprocess.check_call([sys.executable, str(path)], timeout=60) def test_async_basic_asyncio(): diff --git a/libshortfin/tests/examples/fastapi_test.py b/shortfin/tests/examples/fastapi_test.py similarity index 96% rename from libshortfin/tests/examples/fastapi_test.py rename to shortfin/tests/examples/fastapi_test.py index 5640f0d4b..bd9c350ee 100644 --- a/libshortfin/tests/examples/fastapi_test.py +++ b/shortfin/tests/examples/fastapi_test.py @@ -20,6 +20,10 @@ @pytest.fixture(scope="session") def server(): + try: + import fastapi + except ModuleNotFoundError as e: + pytest.skip(f"Required dep not available: {e}") runner = ServerRunner([]) yield runner print("Sending kill signal") diff --git a/shortfin/tests/host_cpu_system_test.py b/shortfin/tests/host_cpu_system_test.py new file mode 100644 index 000000000..bd45f6e61 --- /dev/null +++ b/shortfin/tests/host_cpu_system_test.py @@ -0,0 +1,100 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import pytest +import re +import sys + +import shortfin as sf + + +def test_create_host_cpu_system_defaults(): + sc = sf.host.CPUSystemBuilder() + with sc.create_system() as ls: + print(f"DEFAULT LOCAL SYSTEM:", ls) + print("\n".join(repr(d) for d in ls.devices)) + assert len(ls.devices) > 0 + + +@pytest.mark.skipif( + sys.platform == "win32", reason="Windows fatal exception: access violation" +) +def test_create_host_cpu_system_topology_nodes_all(): + sc = sf.host.CPUSystemBuilder( + hostcpu_topology_nodes="all", hostcpu_topology_max_group_count=2 + ) + with sc.create_system() as ls: + print(f"NODES ALL LOCAL SYSTEM:", ls) + print("\n".join(repr(d) for d in ls.devices)) + assert len(ls.devices) > 0 + + +def test_create_host_cpu_system_topology_nodes_explicit(): + sc = sf.host.CPUSystemBuilder( + hostcpu_topology_nodes="0,0", hostcpu_topology_max_group_count=2 + ) + with sc.create_system() as ls: + print(f"NODES EXPLICIT LOCAL SYSTEM:", ls) + print("\n".join(repr(d) for d in ls.devices)) + assert len(ls.devices) == 2 + + +@pytest.mark.skipif( + sys.platform == "win32", + reason="Only detecting 1 device, check config setup from env vars?", +) +def test_create_host_cpu_system_env_vars(): + os.environ["SHORTFIN_HOSTCPU_TOPOLOGY_NODES"] = "0,0" + os.environ["SHORTFIN_HOSTCPU_TOPOLOGY_MAX_GROUP_COUNT"] = "2" + sc = sf.host.CPUSystemBuilder() + with sc.create_system() as ls: + print(f"ENV VARS LOCAL SYSTEM:", ls) + print("\n".join(repr(d) for d in ls.devices)) + assert len(ls.devices) == 2 + + +def test_create_host_cpu_system_allocators(): + pytest.skip("Setting allocators triggers LSAN leak. See #443") + sc = sf.host.CPUSystemBuilder(hostcpu_allocators="caching;debug") + assert sc.hostcpu_allocator_specs == ["caching", "debug"] + with sc.create_system() as ls: + pass + + +def test_create_host_cpu_system_unsupported_option(): + sc = sf.host.CPUSystemBuilder(unsupported="foobar") + with pytest.raises( + ValueError, match="Specified options were not used: unsupported" + ): + sc.create_system() + + +def test_system_ctor(): + with sf.System( + "hostcpu", hostcpu_topology_nodes="0,0", hostcpu_topology_max_group_count=2 + ) as ls: + print(f"NODES EXPLICIT LOCAL SYSTEM:", ls) + print("\n".join(repr(d) for d in ls.devices)) + assert len(ls.devices) == 2 + + +def test_system_ctor_unknown_type(): + with pytest.raises( + ValueError, + match=re.escape("System type 'NOTDEFINED' not known (available: hostcpu"), + ): + sf.System("NOTDEFINED") + + +def test_system_ctor_undef_error(): + with pytest.raises(ValueError, match="Specified options were not used: undef"): + sf.System("hostcpu", undef=1) + + +def test_system_ctor_undef_warn(): + with sf.System("hostcpu", validate_undef=False, undef=1) as ls: + ... diff --git a/shortfin/tests/invocation/conftest.py b/shortfin/tests/invocation/conftest.py new file mode 100644 index 000000000..e62373eb5 --- /dev/null +++ b/shortfin/tests/invocation/conftest.py @@ -0,0 +1,59 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import urllib.request + + +def upgrade_onnx(original_path, converted_path): + import onnx + + original_model = onnx.load_model(original_path) + converted_model = onnx.version_converter.convert_version(original_model, 17) + onnx.save(converted_model, converted_path) + + +@pytest.fixture(scope="session") +def mobilenet_onnx_path(tmp_path_factory): + try: + import onnx + except ModuleNotFoundError: + raise pytest.skip("onnx python package not available") + parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") + orig_onnx_path = parent_dir / "mobilenet_orig.onnx" + upgraded_onnx_path = parent_dir / "mobilenet.onnx" + if not upgraded_onnx_path.exists(): + print("Downloading mobilenet.onnx") + urllib.request.urlretrieve( + "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", + orig_onnx_path, + ) + upgrade_onnx(orig_onnx_path, upgraded_onnx_path) + return upgraded_onnx_path + + +@pytest.fixture(scope="session") +def mobilenet_compiled_path(mobilenet_onnx_path, compile_flags): + try: + import iree.compiler.tools as tools + import iree.compiler.tools.import_onnx.__main__ as import_onnx + except ModuleNotFoundError: + raise pytest.skip("iree.compiler packages not available") + mlir_path = mobilenet_onnx_path.parent / "mobilenet.mlir" + vmfb_path = mobilenet_onnx_path.parent / "mobilenet_cpu.vmfb" + if not vmfb_path.exists(): + print("Compiling mobilenet") + args = import_onnx.parse_arguments( + ["-o", str(mlir_path), str(mobilenet_onnx_path)] + ) + import_onnx.main(args) + tools.compile_file( + str(mlir_path), + output_file=str(vmfb_path), + input_type="onnx", + extra_args=compile_flags, + ) + return vmfb_path diff --git a/shortfin/tests/invocation/mobilenet_program_test.py b/shortfin/tests/invocation/mobilenet_program_test.py new file mode 100644 index 000000000..ff7b9bbf2 --- /dev/null +++ b/shortfin/tests/invocation/mobilenet_program_test.py @@ -0,0 +1,236 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import array +import asyncio +import time +import functools +import pytest + +import shortfin as sf +import shortfin.array as sfnp + + +@pytest.fixture +def lsys(): + sc = sf.SystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() + + +@pytest.fixture +def fiber0(lsys): + return lsys.create_fiber() + + +@pytest.fixture +def device(fiber0): + return fiber0.device(0) + + +@pytest.fixture +def mobilenet_program_function( + lsys, mobilenet_compiled_path +) -> tuple[sf.ProgramFunction]: + program_module = lsys.load_module(mobilenet_compiled_path) + program = sf.Program([program_module], devices=lsys.devices) + main_function = program["module.torch-jit-export"] + return main_function + + +@pytest.fixture +def mobilenet_program_function_per_call( + lsys, mobilenet_compiled_path +) -> tuple[sf.ProgramFunction]: + program_module = lsys.load_module(mobilenet_compiled_path) + program = sf.Program( + [program_module], devices=lsys.devices, isolation=sf.ProgramIsolation.PER_CALL + ) + main_function = program["module.torch-jit-export"] + return main_function + + +def get_mobilenet_ref_input(device) -> sfnp.device_array: + dummy_data = array.array( + "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) + ) + device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) + staging_input = device_input.for_transfer() + with staging_input.map(discard=True) as m: + m.fill(dummy_data) + device_input.copy_from(staging_input) + return device_input + + +async def assert_mobilenet_ref_output(device, device_output): + host_output = device_output.for_transfer() + host_output.copy_from(device_output) + await device + flat_output = host_output.items + absmean = functools.reduce( + lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 + ) + assert absmean == pytest.approx(5.01964943873882) + + +# Tests that a single invocation on a single fiber works. +def test_invoke_mobilenet_single_per_fiber(lsys, fiber0, mobilenet_program_function): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function(device_input, fiber=fiber0) + await assert_mobilenet_ref_output(device, device_output) + + lsys.run(main()) + + +# Tests that a single invocation on a single fiber in per_call mode works. +def test_invoke_mobilenet_single_per_call( + lsys, fiber0, mobilenet_program_function_per_call +): + assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function_per_call( + device_input, fiber=fiber0 + ) + await assert_mobilenet_ref_output(device, device_output) + + lsys.run(main()) + + +# Tests that chained back to back invocations on the same fiber work correctly. +# Does an async gather/assert with all results at the end. 
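+# Each call is awaited before the next is issued (serial awaits in the list
+# comprehension), as required for PER_FIBER isolation on a single fiber; only
+# the output checks are gathered concurrently afterwards.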
+def test_invoke_mobilenet_chained_per_fiber(lsys, fiber0, mobilenet_program_function): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + results = [ + await mobilenet_program_function(device_input, fiber=fiber0) + for _ in range(5) + ] + + await asyncio.gather( + *[ + assert_mobilenet_ref_output(device, device_output) + for (device_output,) in results + ] + ) + + lsys.run(main()) + + +# Tests that parallel invocations on a single fiber with a program in PER_CALL +# isolation functions properly. Note that in this variant, the await is done +# on all invocations vs serially per invocation (as in +# test_invoke_mobilenet_chained_per_fiber). This would be illegal if done on the +# same fiber without PER_CALL isolation managing forks. +# +# Note that since these are all operating on the same fiber, they are added to +# the device-side work graph with a one-after-the-other dependency, but the +# host side schedules concurrently. +def test_invoke_mobilenet_parallel_per_call( + lsys, fiber0, mobilenet_program_function_per_call +): + assert mobilenet_program_function_per_call.isolation == sf.ProgramIsolation.PER_CALL + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + results = await asyncio.gather( + *[ + mobilenet_program_function_per_call(device_input, fiber=fiber0) + for _ in range(5) + ] + ) + + await asyncio.gather( + *[ + assert_mobilenet_ref_output(device, device_output) + for (device_output,) in results + ] + ) + + lsys.run(main()) + + +# Same as above but uses explicit isolation controls on the function vs as the +# program level. If this constraint were violated, shortfin makes a best effort +# attempt to detect the situation and raise an exception, but there are a subset +# of programs which are purely async and would make detection of this exception +# lossy in the synchronous completion case. +def test_invoke_mobilenet_parallel_per_call_explicit( + lsys, fiber0, mobilenet_program_function +): + assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER + device = fiber0.device(0) + + async def main(): + device_input = get_mobilenet_ref_input(device) + results = await asyncio.gather( + *[ + mobilenet_program_function( + device_input, fiber=fiber0, isolation=sf.ProgramIsolation.PER_CALL + ) + for _ in range(50) + ] + ) + + await asyncio.gather( + *[ + assert_mobilenet_ref_output(device, device_output) + for (device_output,) in results + ] + ) + + lsys.run(main()) + + +# Tests that independent executions on multiple fibers all run concurrently. +# All fibers share the same host thread but schedule concurrently. Since +# each fiber has its own timeline, device side graphs have no dependency on +# each other and also schedule concurrently. 
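Reviewer note: a compressed sketch of the multi-fiber pattern described in the comment above, assuming the same `lsys` fixture, program function, and reference-input helpers from this file; the class and function names are illustrative only. Because each fiber has its own timeline, the launched processes overlap instead of serializing:

    class SketchProcess(sf.Process):
        def __init__(self, fn, **kwargs):
            super().__init__(**kwargs)  # sf.Process accepts fiber=... as in the test below
            self.fn = fn

        async def run(self):
            device = self.fiber.device(0)
            device_input = get_mobilenet_ref_input(device)
            (out,) = await self.fn(device_input, fiber=self.fiber)
            await assert_mobilenet_ref_output(device, out)

    async def run_on_fibers(lsys, fn, n=3):
        # One process per fiber; gather waits for all of them.
        procs = [SketchProcess(fn, fiber=lsys.create_fiber()).launch() for _ in range(n)]
        await asyncio.gather(*procs)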
+def test_invoke_mobilenet_multi_fiber_per_fiber(lsys, mobilenet_program_function):
+    assert mobilenet_program_function.isolation == sf.ProgramIsolation.PER_FIBER
+
+    class InferProcess(sf.Process):
+        async def run(self):
+            start_time = time.time()
+
+            def duration():
+                return round((time.time() - start_time) * 1000.0)
+
+            print(f"{self}: Start")
+            device = self.fiber.device(0)
+            device_input = get_mobilenet_ref_input(device)
+            (device_output,) = await mobilenet_program_function(
+                device_input, fiber=self.fiber
+            )
+            print(f"{self}: Program complete (+{duration()}ms)")
+            await assert_mobilenet_ref_output(device, device_output)
+            print(f"{self} End (+{duration()}ms)")
+
+    async def main():
+        start_time = time.time()
+
+        def duration():
+            return round((time.time() - start_time) * 1000.0)
+
+        fibers = [lsys.create_fiber() for _ in range(5)]
+        print("Fibers:", fibers)
+        processes = [InferProcess(fiber=f).launch() for f in fibers]
+        print("Waiting for processes:", processes)
+        await asyncio.gather(*processes)
+        print(f"All processes complete: (+{duration()}ms)")
+
+    lsys.run(main())
diff --git a/shortfin/tests/invocation/vmfb_buffer_access_test.py b/shortfin/tests/invocation/vmfb_buffer_access_test.py
new file mode 100644
index 000000000..d86a58822
--- /dev/null
+++ b/shortfin/tests/invocation/vmfb_buffer_access_test.py
@@ -0,0 +1,218 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import pytest
+import shortfin as sf
+import shortfin.array as sfnp
+import array
+import random
+
+
+@pytest.fixture
+def kvcache_compiled_cpu_path(tmp_path_factory, compile_flags):
+    try:
+        import iree.compiler.tools as tools
+    except ModuleNotFoundError:
+        raise pytest.skip("iree.compiler packages not available")
+
+    print("Compiling kvcache module")
+
+    KVCACHE_MODULE_CONTENTS = """
+    module @kvcache {
+        func.func @write_kvcache(%kvcache: !torch.tensor<[?,2662400],f16>, %new_data: !torch.vtensor<[16,32,100],f16>, %page_index: !torch.vtensor<[1],si64>, %layer_index: !torch.vtensor<[1],si64>) {
+            %int0 = torch.constant.int 0
+            %int1 = torch.constant.int 1
+            %int2 = torch.constant.int 2
+            %int16 = torch.constant.int 16
+            %int26 = torch.constant.int 26
+            %int32 = torch.constant.int 32
+            %int100 = torch.constant.int 100
+            %int2662400 = torch.constant.int 2662400
+            %false = torch.constant.bool false
+            %none = torch.constant.none
+
+            %0 = torch.copy.to_vtensor %kvcache : !torch.vtensor<[?,2662400],f16>
+
+            // Get the number of pages
+            %num_pages = torch.aten.size.int %0, %int0 : !torch.vtensor<[?,2662400],f16>, !torch.int -> !torch.int
+
+            // Reshape kvcache to [?,26,2,16,32,100]
+            %1 = torch.prim.ListConstruct %num_pages, %int26, %int2, %int16, %int32, %int100 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+            %2 = torch.aten.view %0, %1 : !torch.vtensor<[?,2662400],f16>, !torch.list<int> -> !torch.vtensor<[?,26,2,16,32,100],f16>
+
+            // Create index list with the provided tensors
+            %3 = torch.prim.ListConstruct %page_index, %layer_index : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<optional<vtensor>>
+
+            // Update the kvcache
+            %4 = torch.aten.index_put %2, %3, %new_data, %false : !torch.vtensor<[?,26,2,16,32,100],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[16,32,100],f16>, !torch.bool -> !torch.vtensor<[?,26,2,16,32,100],f16>
+
+            // Reshape back to original shape
+            %5 = torch.prim.ListConstruct %num_pages, %int2662400 : (!torch.int, !torch.int) -> !torch.list<int>
+            %6 = torch.aten.view %4, %5 : !torch.vtensor<[?,26,2,16,32,100],f16>, !torch.list<int> -> !torch.vtensor<[?,2662400],f16>
+
+            // Overwrite the original tensor
+            torch.overwrite.tensor.contents %6 overwrites %kvcache : !torch.vtensor<[?,2662400],f16>, !torch.tensor<[?,2662400],f16>
+
+            return
+        }
+
+        func.func @read_kvcache(%kvcache: !torch.tensor<[?,2662400],f16>, %page_index: !torch.vtensor<[1],si64>, %layer_index: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,16,32,100],f16> {
+            %int0 = torch.constant.int 0
+            %int1 = torch.constant.int 1
+            %int2 = torch.constant.int 2
+            %int16 = torch.constant.int 16
+            %int26 = torch.constant.int 26
+            %int32 = torch.constant.int 32
+            %int100 = torch.constant.int 100
+            %int2662400 = torch.constant.int 2662400
+            %none = torch.constant.none
+
+            %0 = torch.copy.to_vtensor %kvcache : !torch.vtensor<[?,2662400],f16>
+
+            // Get the number of pages
+            %num_pages = torch.aten.size.int %0, %int0 : !torch.vtensor<[?,2662400],f16>, !torch.int -> !torch.int
+
+            // Reshape kvcache to [?,26,2,16,32,100]
+            %1 = torch.prim.ListConstruct %num_pages, %int26, %int2, %int16, %int32, %int100 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+            %2 = torch.aten.view %0, %1 : !torch.vtensor<[?,2662400],f16>, !torch.list<int> -> !torch.vtensor<[?,26,2,16,32,100],f16>
+
+            // Create index list with the provided tensors
+            %3 = torch.prim.ListConstruct %page_index, %layer_index : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list<optional<vtensor>>
+
+            // Read from the kvcache and squeeze the result
+            %4 = torch.aten.index.Tensor %2, %3 : !torch.vtensor<[?,26,2,16,32,100],f16>, !torch.list<optional<vtensor>> -> !torch.vtensor<[1,2,16,32,100],f16>
+            %5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1,2,16,32,100],f16>, !torch.int -> !torch.vtensor<[2,16,32,100],f16>
+
+            return %5 : !torch.vtensor<[2,16,32,100],f16>
+        }
+    }
+    """
+
+    # Get a temporary directory using tmp_path_factory
+    tmp_dir = tmp_path_factory.mktemp("vmfb_buffer_access_test")
+    mlir_path = tmp_dir / "kvcache.mlir"
+    mlir_path.write_text(KVCACHE_MODULE_CONTENTS)
+    vmfb_path = tmp_dir / "kvcache_cpu.vmfb"
+
+    tools.compile_file(
+        str(mlir_path),
+        output_file=str(vmfb_path),
+        input_type="AUTO",
+        extra_args=compile_flags,
+    )
+
+    return vmfb_path
+
+
+@pytest.fixture
+def lsys():
+    sc = sf.SystemBuilder()
+    lsys = sc.create_system()
+    yield lsys
+    lsys.shutdown()
+
+
+@pytest.fixture
+def fiber(lsys):
+    return lsys.create_fiber()
+
+
+@pytest.fixture
+def device(fiber):
+    return fiber.device(0)
+
+
+def create_random_float16_array(size):
+    """Create an array of random uint16 values."""
+    return array.array("H", [random.randint(0, 65535) for _ in range(size)])
+
+
+def create_scalar_device_array(device, value, dtype=sfnp.int64):
+    """Helper function to create a scalar device array."""
+    arr = sfnp.device_array.for_device(device, [1], dtype)
+    staging = arr.for_transfer()
+    with staging.map(discard=True) as m:
+        m.fill(value)
+    arr.copy_from(staging)
+    return arr
+
+
+@pytest.mark.parametrize("await_before_invoke", [True, False])
+def test_kvcache_noreturn(lsys, fiber, kvcache_compiled_cpu_path, await_before_invoke):
+    device = fiber.device(0)
+    program_module = lsys.load_module(kvcache_compiled_cpu_path)
+    program = sf.Program([program_module], devices=fiber.raw_devices)
+
+    write_function = program["kvcache.write_kvcache"]
+    read_function = program["kvcache.read_kvcache"]
+
+    # Test parameters
+    num_pages = 4
+    num_layers = 26
+    num_kv = 2
+    batch_size = 16
+    num_heads = 32
+    head_dim = 100
+
+    test_data_size = batch_size * num_heads * head_dim
+    test_data = create_random_float16_array(test_data_size)
+
+    total_dim = num_layers * num_kv * batch_size * num_heads * head_dim
+    assert total_dim == 2662400
+    kvcache_shape = [num_pages, total_dim]
+    kvcache_data = array.array("H", [0] * (kvcache_shape[0] * kvcache_shape[1]))
+
+    async def main():
+        device_kvcache = sfnp.device_array(device, kvcache_shape, sfnp.float16)
+        device_new_data = sfnp.device_array(
+            device, [batch_size, num_heads, head_dim], sfnp.float16
+        )
+
+        staging_kvcache = device_kvcache.for_transfer()
+        with staging_kvcache.map(discard=True) as m:
+            m.fill(kvcache_data)
+        device_kvcache.copy_from(staging_kvcache)
+
+        staging_new_data = device_new_data.for_transfer()
+        with staging_new_data.map(discard=True) as m:
+            m.fill(test_data)
+        device_new_data.copy_from(staging_new_data)
+
+        for layer_idx in range(2):
+            for kv_idx in range(num_kv):
+                page_index = create_scalar_device_array(device, 1)
+                layer_index = create_scalar_device_array(device, layer_idx)
+
+                if await_before_invoke:
+                    await device
+                ret = await write_function(
+                    device_kvcache,
+                    device_new_data,
+                    page_index,
+                    layer_index,
+                    fiber=fiber,
+                )
+
+                if await_before_invoke:
+                    await device
+                (read_result,) = await read_function(
+                    device_kvcache, page_index, layer_index, fiber=fiber
+                )
+
+                host_result = read_result.for_transfer()
+                host_result.copy_from(read_result)
+                await device
+
+                # Simple byte comparison of the arrays
+                result_array = array.array("H", host_result.items)
+                offset = kv_idx * test_data_size
+                result_slice = result_array[offset : offset + test_data_size]
+                assert result_slice.tobytes() == test_data.tobytes(), (
+                    f"KV cache read/write mismatch for layer {layer_idx}, "
+                    f"{'key' if kv_idx == 0 else 'value'} state"
+                )
+
+    lsys.run(main())
diff --git a/shortfin/tests/local_scope_test.py b/shortfin/tests/local_scope_test.py
new file mode 100644
index 000000000..028e560e8
--- /dev/null
+++ b/shortfin/tests/local_scope_test.py
@@ -0,0 +1,68 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import pytest
+import time
+
+from _shortfin import lib as sfl
+
+
+@pytest.fixture
+def lsys():
+    sc = sfl.local.host.CPUSystemBuilder()
+    ls = sc.create_system()
+    yield ls
+    ls.shutdown()
+
+
+@pytest.fixture
+def fiber(lsys):
+    # TODO: Should adopt the main thread.
+    worker = lsys.create_worker("main")
+    return lsys.create_fiber(worker)
+
+
+def test_raw_device_access(fiber):
+    first_name = fiber.device_names[0]
+    assert first_name == "cpu0"
+    first_device = fiber.raw_device(0)  # By index
+    assert isinstance(first_device, sfl.local.host.HostCPUDevice)
+    assert first_device is fiber.raw_device(first_name)  # By name
+    print(first_device)
+    named_devices = fiber.devices_dict
+    assert first_name in named_devices
+    with pytest.raises(ValueError):
+        fiber.raw_device("cpu1")
+    with pytest.raises(ValueError):
+        fiber.raw_device(1)
+
+
+def test_devices_collection_access(fiber):
+    # Access via devices pseudo collection.
+ first_device = fiber.raw_device(0) + assert fiber.devices.cpu0.raw_device is first_device + assert fiber.devices[0].raw_device is first_device + assert fiber.devices["cpu0"].raw_device is first_device + assert len(fiber.devices) == 1 + with pytest.raises(ValueError): + fiber.devices.cpu1 + with pytest.raises(ValueError): + fiber.devices[1] + iter_list = list(fiber.devices) + assert iter_list == [fiber.device(0)] + + +def test_device_affinity_repr(fiber): + assert ( + repr(sfl.local.DeviceAffinity(fiber.raw_device(0))) + == "DeviceAffinity(hostcpu:0:0@0[0x1])" + ) + assert repr(sfl.local.DeviceAffinity()) == "DeviceAffinity(ANY)" + + +def test_device_affinity_resolve(fiber): + # TODO: Need a fiber with multiple devices to test errors. + print(fiber.device(0, "cpu0", fiber.raw_device(0))) diff --git a/shortfin/version.json b/shortfin/version.json new file mode 100644 index 000000000..9519501ae --- /dev/null +++ b/shortfin/version.json @@ -0,0 +1,3 @@ +{ + "package-version": "3.1.0.dev" +} diff --git a/tuner/.gitignore b/tuner/.gitignore new file mode 100644 index 000000000..94d12c89f --- /dev/null +++ b/tuner/.gitignore @@ -0,0 +1,4 @@ +.venv/ + +# Tuning artifacts +tuning_*/ diff --git a/tuner/README.md b/tuner/README.md index 69821496e..3737f6bdf 100644 --- a/tuner/README.md +++ b/tuner/README.md @@ -1,16 +1,20 @@ # IREE dispatch auto-tuning scripts -`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`. +`libtuner.py` is the core Python script that provides the fundamental functions +for the tuning loop. It imports `candidate_gen.py` for candidate generation. To +implement the full tuning loop, `libtuner.py` requires a separate Python script +that uses the provided `TuningClient` API from `libtuner.py`. ## Prerequisites [Optional] Using virtual environments: ```shell -cd tuning +cd tuner python -m venv .venv source .venv/bin/activate ``` Install python dependencies: ```shell -pip install -r ./requirements-tuner.txt +pip install -r requirements-tuner.txt +pip install -r requirements-dev.txt ``` Using the IREE's Python bindings: - Building with CMake @@ -21,47 +25,13 @@ Using the IREE's Python bindings: - Set environment ```shell source ../iree-build/.env && export PYTHONPATH + export PATH="$(realpath ../iree-build/tools):$PATH" ``` -For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings) +For more information, refer to the [IREE +documentation](https://iree.dev/building-from-source/getting-started/#python-bindings). -### Overall flow +## Examples -1. Symlink all scripts and mlir/irpa files in your build dir. - - Symlink `iree-build-dir/tools` inside `tuning`. - - Symlink ML model MLIR and weights based on `unet.sh`. - -2. Copy the attention/matmul spec as `config.mlir` in the tuning dir. - -3. Temporarily comment out all the existing configs in `config.mlir`. - - Example: - ```mlir - // , @match_mmt_2048x10240x1280 -> @apply_op_config - // , @match_mmt_2048x1280x5120 -> @apply_op_config - // , @match_mmt_2048x1280x1280 -> @apply_op_config - ``` - -4. Compile a baseline unet -```shell -./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd -``` - -5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir. 
-```shell -cp dump-winograd/*_141_*benchmark.mlir ./141.mlir -``` - -6. Run the tuning script. - - Example: - ```shell - python punet_autotune.py 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024 - ``` - -7. Check the winner candidate in `result_summary.log`, find and copy the transform spec. - -8. Paste the transform spec into the `config.mlir` and uncomment them. - -9. Add the match function to the entry point in `config.mlir` - - Example: - ```mlir - @match_something -> @apply_op_config - ``` +Check the `examples` directory for sample tuners implemented with `libtuner`. +The [`dispatch` example](https://github.com/nod-ai/shark-ai/tree/main/tuner/examples/dispatch) +should be a good starting point for most users. diff --git a/tuner/candidate_gen.py b/tuner/candidate_gen.py deleted file mode 100755 index 5a878e072..000000000 --- a/tuner/candidate_gen.py +++ /dev/null @@ -1,1408 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Given an input dispatch, this code modifies the hyperparameters -# in the code and runs it. - -""" -Generate candidates by tweaking op configuration for tuning. - -It can be invoked in two ways: - 1. From another python script, import and call `tune()` - 2. Run this script directly from the command - -Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk - -""" - -import argparse -import logging -import math -import pickle -import re -import z3 -from dataclasses import asdict, dataclass -from enum import Enum -from os import mkdir, path, makedirs -from typing import Callable, Optional -from textwrap import indent -from abc import ABC, abstractmethod - -import iree.compiler as ireec -from iree.compiler import ir -from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen - - -tune_logger = logging.getLogger("tune") - - -class DispatchKind(Enum): - conv = 1 - mmt = 2 - contraction = 3 - batch_mmt = 4 - batch_matmul = 5 - broadcast_rhs_mmt = 6 - - -class ElementType(Enum): - i8 = 1 - i32 = 2 - f8 = 3 - f16 = 4 - f32 = 5 - - @property - def bitwidth(self) -> int: - match self: - case ElementType.i8 | ElementType.f8: - return 8 - case ElementType.f16: - return 16 - case ElementType.i32 | ElementType.f32: - return 32 - case _: - assert False, "unhandled case" - - def __str__(self) -> str: - return self.name - - -@dataclass -class ShapedType: - shape: list[int] - element_type: ElementType - - def rank(self) -> int: - return len(self.shape) - - @property - def bitwidth(self) -> int: - return self.element_type.bitwidth - - def __str__(self) -> str: - dim_to_str = lambda dim: str(dim) if dim != -1 else "?" 
- return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) - - -@dataclass -class MatmulSize: - M: int - N: int - K: int - B: int = 1 - - -@dataclass -class ProblemSize: - matmul_size: MatmulSize - lhs_type: ShapedType - rhs_type: ShapedType - res_type: ShapedType - dispatch_kind: DispatchKind - - @property - def MNK(self) -> tuple[int, int, int]: - return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) - - -@dataclass -class MfmaIntrinsic: - input_type: ElementType - m: int - n: int - k: int - output_type: ElementType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{input}_{self.m}x{self.n}x{self.k}_{output}" - - @staticmethod - def mfma_f16_16x16x16_f32(): - return MfmaIntrinsic(ElementType.f16, 16, 16, 16, ElementType.f32) - - @staticmethod - def mfma_f16_32x32x8_f32(): - return MfmaIntrinsic(ElementType.f16, 32, 32, 8, ElementType.f32) - - @staticmethod - def mfma_i8_16x16x32_i32(): - return MfmaIntrinsic(ElementType.i8, 16, 16, 32, ElementType.i32) - - @staticmethod - def mfma_i8_32x32x16_i32(): - return MfmaIntrinsic(ElementType.i8, 32, 32, 16, ElementType.i32) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f16_16x16x16_f32(), - MfmaIntrinsic.mfma_f16_32x32x8_f32(), - MfmaIntrinsic.mfma_i8_16x16x32_i32(), - MfmaIntrinsic.mfma_i8_32x32x16_i32(), - ] - - -@dataclass -class Configuration: - subgroup_size: int - workgroup_size: list[int] - intrinsic: MfmaIntrinsic - tile_sizes: list[int] - subgroup_m_count: int - subgroup_n_count: int - waves_per_eu: int - - -class MlirRegex(str, Enum): - ssa_value = r"%[a-zA-Z0-9-_]+" - tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" - - @staticmethod - def dps_ins_two_args() -> str: - return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" - - @staticmethod - def dps_outs_one_arg() -> str: - return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" - - -def read_input_mlir(filename: str) -> list[str]: - with open(filename, "r") as f: - return f.readlines() - - -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - -@dataclass -class ConvDimInfo: - n: int - oh: int - ow: int - oc: int - fh: int - fw: int - ic: int - - @staticmethod - def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): - n, oh, ow, oc = res_shaped_type.shape - fh, fw, ic, _ = rhs_shaped_type.shape - return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) - - @staticmethod - def from_problem_size(problem_size: ProblemSize): - return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) - - -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = ", prefetch_shared_memory" - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config - - -def apply_configuration( - template: list[str], configuration: Configuration, tile_sizes: list[int] -) -> str: - 
tune_logger.info(f"Applying: {configuration}") - expr0 = re.compile( - r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" - ) - expr1 = re.compile( - r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," - ) - expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") - expr3 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" - repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' - repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' - repl3 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' - - new_mlir = "" - for line in template: - if "intrinsic =" in line: - line = re.sub(expr0, repl0, line) - if "LLVMGPUVectorDistribute " in line: - line = re.sub(expr1, repl1, line) - if "tile_sizes" in line: - line = re.sub(expr2, repl2, line) - if "amdgpu-waves-per-eu" in line: - line = re.sub(expr3, repl3, line) - new_mlir += line - - return new_mlir - - -def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(MlirRegex.tensor_type, tensor_type) - assert shape_match - - shape_str = shape_match.group(1) - dims_and_elem = shape_str.split("x") - dims = [int(x) for x in dims_and_elem[:-1]] - elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return ShapedType(dims, str_to_elem_ty[elem]) - - -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: - return False - if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: - return False - return True - - return list(filter(is_compatible, MfmaIntrinsic.all())) - - -def get_mfma_intrinsic_constraints( - problem_size: ProblemSize, - intrinsic_m: z3.ArithRef, - intrinsic_n: z3.ArithRef, - intrinsic_k: z3.ArithRef, -) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) - assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" - return z3.Or( - *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics - ) - ) - - -def get_dispatch_constraints( - problem_size: ProblemSize, - tile_m: z3.ArithRef, - tile_n: z3.ArithRef, - tile_k: z3.ArithRef, -) -> list[z3.BoolRef]: - if problem_size.dispatch_kind != DispatchKind.conv: - return [] - - dim_info = ConvDimInfo.from_problem_size(problem_size) - conv_constraints = [] - # WARNING: This sometimes makes the constraints UNSAT for some reason. 
- conv_constraints += [tile_m <= dim_info.ow] - conv_constraints += [tile_n <= dim_info.oc] - conv_constraints += [tile_k <= dim_info.ic] - return conv_constraints - - -def calculate_shared_memory_usage_in_bytes( - problem_size: ProblemSize, - m: int | z3.ArithRef, - n: int | z3.ArithRef, - k: int | z3.ArithRef, -) -> int | z3.ArithRef: - lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) - rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) - return lhs_memory + rhs_memory - - -def generate_constraints( - problem_size: ProblemSize, - tile_sizes, - num_subgroups, - subgroup_size, - intrinsic_size, - workgroup_size, - subgroup_m_count, - subgroup_n_count, - waves_per_eu, -): - M, N, K = ( - problem_size.matmul_size.M, - problem_size.matmul_size.N, - problem_size.matmul_size.K, - ) - m, n, k = tile_sizes - intrinsic_mn, intrinsic_k = intrinsic_size - wg_x, wg_y, wg_z = workgroup_size - wg_threads = z3.Int("wg_threads") - constraints = [wg_threads == wg_x * wg_y * wg_z] - constraints += [subgroup_size == 64, wg_threads <= 1024] - constraints += [ - get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k - ) - ] - subgroup_k_count = 1 - constraints += [ - m >= intrinsic_mn, - m <= 512, - m <= M, - ] - constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] - constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] - for x in (subgroup_m_count, subgroup_n_count): - constraints += [x >= 1, x <= 32] - - subgroup_m_tile_count = z3.Int("sg_m_tcnt") - subgroup_n_tile_count = z3.Int("sg_n_tcnt") - subgroup_k_tile_count = z3.Int("sg_k_tcnt") - for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): - constraints += [x >= 1, x <= 32] - - constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] - constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] - constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] - constraints += [wg_x == subgroup_size * subgroup_n_count] - constraints += [wg_y == subgroup_m_count] - constraints += [wg_z == subgroup_k_count] - constraints += [z3.Or(wg_x <= n, wg_x <= m)] - constraints += [k % intrinsic_mn == 0] - constraints += [(k * n) % wg_threads == 0] - constraints += [(k * m) % wg_threads == 0] - subgroups = subgroup_m_count * subgroup_n_count - if num_subgroups > 0: - constraints += [subgroups == num_subgroups] - else: - constraints += [subgroups >= 1, subgroups <= 10] - - constraints += [waves_per_eu == 2] - # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] - - shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) - constraints += [shared_memory <= 65536] - - constraints += get_dispatch_constraints(problem_size, m, n, k) - - return constraints - - -def generate_solutions(problem_size: ProblemSize, num_subgrups: int): - M, N, K = problem_size.MNK - tune_logger.info(f"{M},{N},{K}") - m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") - subgroup_size = z3.Int("subgroup_size") - intrinsic_mn = z3.Int("intrinsic_mn") - intrinsic_k = z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") - sg_m_cnt = z3.Int("sg_m_cnt") - sg_n_cnt = z3.Int("sg_n_cnt") - waves_per_eu = z3.Int("waves_per_eu") - all_vars = [ - m, - n, - k, - subgroup_size, - intrinsic_mn, - intrinsic_k, - wg_x, - wg_y, - wg_z, - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ] - - solver = z3.Solver() - constraints = generate_constraints( - problem_size, - [m, n, k], - num_subgrups, 
- subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - solver.add(z3.simplify(z3.And(constraints))) - tune_logger.debug(f"Initial constraints: {solver}") - i = 0 - while solver.check() == z3.sat: - model = solver.model() - lookup = lambda var: model[var].as_long() - - config = Configuration( - lookup(subgroup_size), - [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - MfmaIntrinsic( - problem_size.lhs_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.res_type.element_type, - ), - [lookup(m), lookup(n), lookup(k)], - lookup(sg_m_cnt), - lookup(sg_n_cnt), - lookup(waves_per_eu), - ) - solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) - i += 1 - yield config - - -def get_default_output_dir() -> str: - from datetime import datetime - - return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") - - -def parse_mlir(mlir_text: str) -> ir.Module: - mlir_module = None - with ireec.ir.Context() as context: - try: - mlir_module = ireec.ir.Module.parse(mlir_text) - tune_logger.info("MLIR parsing successful!") - except ireec.ir.MLIRError as e: - tune_logger.error(f"Error parsing MLIR: {e}") - raise RuntimeError(f"Error parsing MLIR: {e}") - - return mlir_module - - -@dataclass -class MLIRTransformation: - """Transformation of MLIR context""" - - template: str - modified: str - embeddable: str - - -class DispatchTuner(ABC): - @abstractmethod - def supports(self, op_name: str) -> bool: - """Check if the tuner can handle the type of operation represented by the input string.""" - pass - - @abstractmethod - def get_shapes(self, template: list[str]) -> ProblemSize: - """Extract problem size of thge operation.""" - pass - - @abstractmethod - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - """Apply parameter transformations to the operation.""" - pass - - -@dataclass -class OpWalkResult: - was_interrupted: bool = False - dispatch_tuner: Optional[DispatchTuner] = None - - -class DispatchTunerRegistry: - def __init__(self): - self.registry = set() - - def register(self, dispatch_tuners: list[DispatchTuner]) -> None: - for dispatch_tuner in dispatch_tuners: - self.registry.add(dispatch_tuner) - - def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool: - for attr in attrs: - if (attr.name == "translation_info") and ( - "LLVMGPUVectorDistribute" in str(attr.attr) - ): - return True - assert False, "Translation info not supported" - - def find_handler(self, op_name: str) -> DispatchTuner: - for dispatch_tuner in self.registry: - if dispatch_tuner.supports(op_name): - return dispatch_tuner - assert False, "Dispatch kind not supported" - - -class MmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - mmt_re = None - dps = None - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert 
lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], - rhs_shaped_type.shape[0], - lhs_shaped_type.shape[1], - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - assert mmt_re - assert dps, f"'{mmt_re}' not found in given context" - - def get_transform_function_mmt( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration - ) -> str: - tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" - transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ - %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op - %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value - %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value - transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value - transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %matmul, %config : !transform.any_op, !transform.any_param - }} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_mmt( - problem_size, f"match_mmt_{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_mmt_tile_sizes(configuration) - ) - embeddable = indent( - self.get_transform_function_mmt(problem_size, f"match_op", configuration), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -class ConvTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "conv_2d_nhwc_hwcf" in op_name - - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes - batch = 1 - fh = 1 - fw = 1 - - oh = 1 - - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(conv_re, line) - if 
dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - def get_transform_function_conv( - self, problem_size: ProblemSize, functionName: str, configuration: Configuration - ) -> str: - dynamic_batch_input_ty = problem_size.lhs_type - dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() - dynamic_batch_input_ty.shape[0] = -1 - - dynamic_batch_output_ty = problem_size.res_type - dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy() - dynamic_batch_output_ty.shape[0] - 1 - - input = f"tensor<{dynamic_batch_input_ty}>" - filter = f"tensor<{problem_size.rhs_type}>" - output = f"tensor<{dynamic_batch_output_ty}>" - - tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" - transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) - -> (!transform.any_op, !transform.any_param) {{ - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{ - ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}): - %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}} - ins(%lhs, %rhs : {input}, {filter}) - outs(%out : {output}) -> {output} - }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %conv, %config : !transform.any_op, !transform.any_param - }} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - conv_dims = ConvDimInfo.from_problem_size(problem_size) - modified = indent( - self.get_transform_function_conv( - problem_size, - f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, configuration, self.get_conv_tile_sizes(configuration) - ) - embeddable = indent( - 
self.get_transform_function_conv(problem_size, f"match_op", configuration), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -class ContractionTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.rhs_dims, 
rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - - def get_transform_function_broadcast_rhs_mmt( - self, - problem_size: ProblemSize, - functionName: str, - configuration: Configuration, - ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - lhs_dynamic_batch = problem_size.lhs_type - lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() - lhs_dynamic_batch.shape[0] = -1 - - return f""" -transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ -%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op -%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value -%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value -transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value -transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param -transform.yield %generic, %config : !transform.any_op, !transform.any_param -}} -""" - - def apply_params_broadcast_rhs_mmt( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - self.get_transform_function_broadcast_rhs_mmt( - problem_size, f"match_op", configuration - ), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - if self.is_broadcast_rhs_mmt(template): - return self.apply_params_broadcast_rhs_mmt( - problem_size, template, configuration - ) - - # TODO: Generate transform function. 
- return MLIRTransformation( - template, - apply_configuration( - template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), - ), - "", - ) - - -class BatchMmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "batch_matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - - def get_transform_function_batch_mmt( - self, - problem_size: ProblemSize, - functionName: str, - configuration: Configuration, - ) -> str: - tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" -transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ -%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op -%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value -%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value -transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value -transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value -%config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param -transform.yield %generic, %config : !transform.any_op, !transform.any_param -}} -""" - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - B = problem_size.matmul_size.B - modified = indent( - self.get_transform_function_batch_mmt( - problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration - ), - "// ", - ) - modified += apply_configuration( - template, configuration, get_batch_mmt_tile_sizes(configuration) - ) - - embeddable = indent( - self.get_transform_function_batch_mmt( - problem_size, f"match_op", configuration - ), - " ", - ) - 
return MLIRTransformation(template, modified, embeddable) - - -class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "batch_matmul" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - B0 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) - ) - B1 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) - ) - M = math.prod( - val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - N = math.prod( - val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - K0 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - K1 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - - def get_transform_function_batch_matmul( - self, - problem_size: ProblemSize, - tile_dims: str, - functionName: str, - configuration: Configuration, - ) -> str: - input0 = f"tensor<{problem_size.lhs_type}>" - input1 = f"tensor<{problem_size.rhs_type}>" - output = f"tensor<{problem_size.res_type}>" - - tile_sizes = ", ".join( - map(str, get_contract_tile_sizes(configuration, tile_dims)) - ) - - wg_x, wg_y, wg_z = configuration.workgroup_size - extra_config = get_pipeline_config(configuration) - - return f""" - transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) - -> (!transform.any_op, !transform.any_param) {{ - %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{ - ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}): - %13 = linalg.batch_matmul - ins(%lhs, %rhs : {input0}, {input1}) - outs(%out : {output}) -> {output} - }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) - %config = transform.param.constant #iree_codegen.compilation_info< - lowering_config = #iree_codegen.lowering_config, - translation_info = #iree_codegen.translation_info, - subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> - {extra_config}}}> - > -> !transform.any_param - transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param - 
}} - """ - - def apply_params( - self, - problem_size: ProblemSize, - template: list[str], - configuration: Configuration, - ) -> MLIRTransformation: - M, N, K = problem_size.MNK - modified = indent( - self.get_transform_function_batch_matmul( - problem_size, - self.tile_dims, - f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", - configuration, - ), - "// ", - ) - modified += apply_configuration( - template, - configuration, - get_contract_tile_sizes(configuration, self.tile_dims), - ) - - embeddable = indent( - self.get_transform_function_batch_matmul( - problem_size, self.tile_dims, f"match_op", configuration - ), - " ", - ) - return MLIRTransformation(template, modified, embeddable) - - -def walk_callback_get_fn( - op: ir.Operation, - walk_result: OpWalkResult, - dispatch_tuner_registry: DispatchTunerRegistry, -) -> ir.WalkResult: - if op.name == "func.func": - dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) - if op.name == "util.func": - func_name = str(op.opview.sym_name) - walk_result.was_interrupted = True - walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) - return ir.WalkResult.INTERRUPT - return ir.WalkResult.ADVANCE - - -def walk_mlir_op( - mlir_module: ir.Module, - dispatch_tuner_registry: DispatchTunerRegistry, -) -> OpWalkResult: - walk_result = OpWalkResult() - for op in mlir_module.body.operations: - op.walk( - lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), - ir.WalkOrder.POST_ORDER, - ) - if walk_result.was_interrupted: - break - return walk_result - - -def tune( - input: str, # Path to the mlir file to be tuned - output: str = "", # Path to the output directory, auto creates one if not given - limit: int = 4096, # Max candidates to be generated - num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints - lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations - rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations - tile_dims: str = "mnk", # Dimensions for the tile size -): - input_file = str(input) - - if not output: - output = get_default_output_dir() - - # Create the directory if it does not exist - makedirs(str(output), exist_ok=True) - - tune_logger.debug(f"Output directory {output}") - tune_logger.debug(f"Processing {input_file}") - mlir_template = read_input_mlir(input_file) - mlir_text = "".join(mlir_template) - - mlir_module = parse_mlir(mlir_text) - # Save the input file as the first candidate. 
- with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) - - walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry) - - dispatch_tuner = walk_result.dispatch_tuner - problem_size = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) - - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input", help="Input mlir file", type=str) - parser.add_argument( - "-o", "--output", help="Output dir", type=str, default=get_default_output_dir() - ) - parser.add_argument( - "-l", - "--limit", - help="Max number of candidates generated", - type=int, - default=4096, - ) - parser.add_argument( - "--num-subgroups", - help="Number of subgroups per workgroup to use. (-1 == unconstrained)", - type=int, - default=-1, - ) - parser.add_argument( - "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" - ) - parser.add_argument( - "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" - ) - parser.add_argument( - "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" - ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" - ) - - args = parser.parse_args() - tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) - - # Create printing formatter for logging info - formatter = logging.Formatter("%(message)s") - - # Create a handler to print to console - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - tune_logger.addHandler(console_handler) - - # # Optionally, add a file handler to log to a file - # file_handler = logging.FileHandler("tune.log") - # file_handler.setFormatter(formatter) - # tune_logger.addHandler(file_handler) - - tune( - args.input, - args.output, - args.limit, - args.num_subgroups, - args.lhs_dims, - args.rhs_dims, - args.tile_dims, - ) - - -if __name__ == "__main__": - args = main() diff --git a/tuner/candidate_gen_test.py b/tuner/candidate_gen_test.py deleted file mode 100644 index ee0a32c66..000000000 --- a/tuner/candidate_gen_test.py +++ /dev/null @@ -1,814 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Usage: python -m pytest candidate_gen_test.py -""" - -import pytest -import candidate_gen - - -def test_get_shaped_type_element_bitwidth(): - assert ( - candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 - ) - assert ( - candidate_gen.ShapedType( - [2048, 512, 384], candidate_gen.ElementType.f8 - ).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 - ) - - -def test_get_shaped_type_to_str(): - assert ( - str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) - == "1024x2048xi8" - ) - assert ( - str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) - == "1024xf32" - ) - assert ( - str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) - == "1x2x3xf16" - ) - assert ( - str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) - == "?x2x3xf16" - ) - - -def test_parse_tensor_type(): - assert candidate_gen.parse_tensor_type( - "tensor<1x2x3xf32>" - ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) - assert candidate_gen.parse_tensor_type( - "tensor<123xi8>" - ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) - - -def test_get_mmt_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic="", - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic="#iree_gpu.mma_layout", - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=1, - ) - assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ - 1, - 1, - 464, - 320, - 1, - 1, - 16, - ] - - -def test_get_contract_tile_sizes(): - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ - 16, - 16, - 16, - ] - - -def test_get_pipeline_config(): - config1 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=2, - ) - config2 = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic="", - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - waves_per_eu=4, - ) - assert candidate_gen.get_pipeline_config(config1) == ", prefetch_shared_memory" - assert ( - candidate_gen.get_pipeline_config(config2) - == ', prefetch_shared_memory, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - ) - - -def test_get_shapes_mmt(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = 
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv(): - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - - -def test_get_shapes_contract(): - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul(): - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( 
- candidate_gen.MatmulSize(32, 32, 1024, 1), - candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt(): - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMmtTuner().get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - -def test_mfma_intrinsic_to_str(): - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) - == "MFMA_F16_16x16x16_F32" - ) - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) - == "MFMA_I8_32x32x16_I32" - ) - - -def test_get_compatible_mfma_intrinsics(): - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_i8_16x16x32_i32(), - candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - ] - - -def test_generate_solutions(): - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = 
candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - configs = candidate_gen.generate_solutions(problem_size, 4) - assert configs is not None - - -def test_calculate_shared_memory_usage_in_bytes(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 147456 - ) - - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 81920 - ) - - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) - == 12288 - ) - - -def test_generate_constraints_valid_input(): - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt - ) - # Define input parameters as z3 Ints - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are satisfiable - assert solver.check() == candidate_gen.z3.sat - - -def test_generate_constraints_invalid_input(): - # Define input parameters that should lead to unsatisfiable constraints - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, 
candidate_gen.DispatchKind.mmt - ) - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - constraints.append(m > 1000) # Adding an additional unsatisfiable constraint - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are unsatisfiable - assert solver.check() == candidate_gen.z3.unsat - - -def test_apply_params_mmt(): - mlir_template = [ - ", subgroup_m_count = 16, subgroup_n_count = 16>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', - ] - - M, N, K = 2048, 1280, 1280 - - config = candidate_gen.Configuration( - subgroup_size=16, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tile_sizes=[8, 8, 8], - subgroup_m_count=16, - subgroup_n_count=16, - waves_per_eu=8, - ) - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(M, N, K), - candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" - in modified - ) - assert "tile_sizes = [[8, 8, 8]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified - - -def test_apply_params_conv(): - mlir_template = [ - ", subgroup_m_count = 16, subgroup_n_count = 16>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', - ] - - n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=2, - ) - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), - candidate_gen.ShapedType( - [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 - ), - candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - tf_mlir = candidate_gen.ConvTuner().apply_params( - problem_size, mlir_template, config - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [256, 
1, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_contract(): - mlir_template = [ - ", subgroup_m_count = 2, subgroup_n_count = 2>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - tile_dims = "*mnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 3840, 1280), - candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - tile_sizes=[480, 384, 32], - subgroup_m_count=1, - subgroup_n_count=4, - waves_per_eu=2, - ) - - tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, config - ) - - new_mlir = tf_mlir.modified - - assert new_mlir - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" - in new_mlir - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" - in new_mlir - ) - assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir - - -def test_apply_params_batch_matmul(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - tile_dims = "bmnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_32x32x8_f32(), - tile_sizes=[416, 320, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=2, - ) - - tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( - problem_size, mlir_template, config - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert embeddable - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 416, 320, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_batch_mmt_float(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_mmt, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - 
intrinsic=candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32(), - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=2, - ) - - tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, config - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert embeddable - assert modified - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified - - -def test_apply_params_batch_mmt_int(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=4, - ) - - tf_mlir = candidate_gen.BatchMmtTuner().apply_params( - problem_size, mlir_template, config - ) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified - - assert embeddable - assert "transform.named_sequence @match_op(" in embeddable - assert ( - "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "%config = transform.param.constant #iree_codegen.compilation_info<" - in embeddable - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable - assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable - assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable - - -def test_apply_params_broadcast_rhs_mmt(): - mlir_template = [ - ", subgroup_m_count = 4, subgroup_n_count = 1>}>", - "", - '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', - ] - - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.broadcast_rhs_mmt, - ) - - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[128, 2, 1], - 
intrinsic=candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32(), - tile_sizes=[128, 64, 128], - subgroup_m_count=2, - subgroup_n_count=2, - waves_per_eu=4, - ) - - tf_mlir = candidate_gen.ContractionTuner( - "mk", "nk", "mnk" - ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) - - modified = tf_mlir.modified - embeddable = tf_mlir.embeddable - - assert modified - assert ( - "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" - in modified - ) - assert ( - "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" - in modified - ) - assert ( - "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" - in modified - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in modified - assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified - - assert embeddable - assert "transform.named_sequence @match_op(" in embeddable - assert ( - "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value" - in embeddable - ) - assert ( - "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" - in embeddable - ) - assert ( - "%config = transform.param.constant #iree_codegen.compilation_info<" - in embeddable - ) - assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable - assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable - assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable - - -def test_detect_broadcast_rhs_mmt(): - mlir_lines = [ - r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( - mlir_lines - ) - - -def test_parse_mlir(): - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - mlir_module = candidate_gen.parse_mlir(mlir_str) - assert mlir_module != None - assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module) - assert isinstance( - mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp - ) diff --git a/tuner/examples/__init__.py b/tuner/examples/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/tuner/examples/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/tuner/examples/dispatch/.gitignore b/tuner/examples/dispatch/.gitignore new file mode 100644 index 000000000..9fb2fe16a --- /dev/null +++ b/tuner/examples/dispatch/.gitignore @@ -0,0 +1,3 @@ +# Test files/dirs recommended by README.md. 
+dump/ +benchmark.mlir diff --git a/tuner/examples/dispatch/README.md b/tuner/examples/dispatch/README.md new file mode 100644 index 000000000..70c46e08a --- /dev/null +++ b/tuner/examples/dispatch/README.md @@ -0,0 +1,35 @@ +# Dispatch Tuner + +Allows to tune a single dispatch in isolation. + +## Environments +Follow instructions in [`/tuner/README.md`](../README.md) + +## Running the Dispatch Tuner + +### Generate a benchmark file +Use the usual `iree-compile` command for your dispatch and add +`--iree-hal-dump-executable-files-to=dump`. For example: +```shell +iree-compile mmt.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-hal-dump-executable-files-to=dump -o /dev/null +``` + +Next, copy the `*_benchmark.mlir` file to some temporary directory of choice. +This will be the input to the dispatch tuner. + +### Recommended Trial Run +For an initial trial to test the tuning loop, use: +```shell +python -m examples.dispatch benchmark.mlir --num-candidates=20 +``` + +### Dry Run Test +To perform a dry run (no GPU required), use: +```shell +python -m examples.dispatch benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run +``` + +### Basic Usage +```shell +python -m examples.dispatch benchmark.mlir +``` diff --git a/tuner/examples/dispatch/__init__.py b/tuner/examples/dispatch/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/tuner/examples/dispatch/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/libshortfin/bindings/python/_shortfin/__init__.py b/tuner/examples/dispatch/__main__.py similarity index 53% rename from libshortfin/bindings/python/_shortfin/__init__.py rename to tuner/examples/dispatch/__main__.py index ac43edaea..9fb86fd9f 100644 --- a/libshortfin/bindings/python/_shortfin/__init__.py +++ b/tuner/examples/dispatch/__main__.py @@ -4,9 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# The proper way to import this package is via: -# from _shortfin import lib as sfl +from . import dispatch_tuner -# TODO: Use environment variables to determine which built variant to -# import. -from _shortfin_default import lib +dispatch_tuner.main() diff --git a/tuner/examples/dispatch/compile_dispatch.sh b/tuner/examples/dispatch/compile_dispatch.sh new file mode 100755 index 000000000..0b01ac991 --- /dev/null +++ b/tuner/examples/dispatch/compile_dispatch.sh @@ -0,0 +1,18 @@ +#! 
/usr/bin/env bash + +set -eou pipefail + +readonly INPUT="$1" +readonly DIR="$(dirname "$INPUT")" +readonly BASENAME="$(basename "$INPUT" .mlir)" +readonly OUT="${DIR}/compiled/${BASENAME}.vmfb" + +iree-compile "$INPUT" -o "$OUT" \ + --compile-from=executable-sources 2>/dev/null || (mv "$INPUT" "$DIR/failed" && exit 1) + +iree-dump-module "$OUT" | grep -q 'rocm-hsaco-fb' || (mv "$INPUT" "$DIR/failed" && rm -f "$OUT" && exit 1) +if [ -f "${DIR}/${BASENAME}_config.mlir" ]; then + cat "${DIR}/../config_prolog.mlir" "${DIR}/${BASENAME}_config.mlir" "${DIR}/../config_epilog.mlir" > "${DIR}/specs/${BASENAME}_spec.mlir" +fi + +echo "Compiling ${INPUT}: success" diff --git a/tuner/examples/dispatch/config_epilog.mlir b/tuner/examples/dispatch/config_epilog.mlir new file mode 100644 index 000000000..c15a30502 --- /dev/null +++ b/tuner/examples/dispatch/config_epilog.mlir @@ -0,0 +1,12 @@ + +//===----------------------------------------------------------------------===// +// Entry point +//===----------------------------------------------------------------------===// + + transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) { + transform.foreach_match in %variant_op + , @match_op -> @apply_op_config + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} //// module diff --git a/tuner/examples/dispatch/config_prolog.mlir b/tuner/examples/dispatch/config_prolog.mlir new file mode 100644 index 000000000..377ac3f8f --- /dev/null +++ b/tuner/examples/dispatch/config_prolog.mlir @@ -0,0 +1,32 @@ +// Transform dialect specification for attention on MI300 with MFMA. +module attributes { transform.with_named_sequence } { +//===----------------------------------------------------------------------===// +// Matmul tuning +//===----------------------------------------------------------------------===// + + transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %root ["linalg.generic"] : !transform.any_op + // transform.print %root {name = "Generic"} : !transform.any_op + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root { + ^bb0(%lhs: tensor, %rhs: tensor, %out: tensor): + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%out : tensor) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %8 = arith.extf %in : f16 to f32 + %9 = arith.extf %in_0 : f16 to f32 + %10 = arith.mulf %8, %9 : f32 + %11 = arith.addf %acc, %10 : f32 + linalg.yield %11 : f32 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.yield %root : !transform.any_op + } + + transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, %config: !transform.any_param {transform.readonly}) { + transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param + // transform.print %op {name = "Applied"} : !transform.any_op + transform.yield + } diff --git a/tuner/examples/dispatch/dispatch_tuner.py b/tuner/examples/dispatch/dispatch_tuner.py new file mode 100644 index 000000000..3c2d77f64 --- /dev/null +++ b/tuner/examples/dispatch/dispatch_tuner.py @@ -0,0 +1,138 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Sample Usage: + +python -m examples.dispatch benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64 + + +Recommended Trial Run: + +python -m examples.dispatch benchmark.mlir --num-candidates=10 + + +Dry Run Test (no gpu required): + +python -m examples.dispatch benchmark.mlir --num-candidates=64 --dry-run + +""" + +from tuner import libtuner +from pathlib import Path, PurePath +import os + + +class DispatchTuner(libtuner.TuningClient): + def get_dispatch_compile_timeout_s(self) -> int: + return 10 + + def get_dispatch_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + assert candidate_tracker.dispatch_mlir_path is not None + mlir_path: Path = candidate_tracker.dispatch_mlir_path + script_dir = Path(__file__).resolve().parent + command = [ + (script_dir / "compile_dispatch.sh").as_posix(), + mlir_path.as_posix(), + ] + return command + + def get_dispatch_benchmark_timeout_s(self) -> int: + return 15 + + def get_dispatch_benchmark_command( + self, + candidate_tracker: libtuner.CandidateTracker, + ) -> list[str]: + compiled_vmfb_path = candidate_tracker.compiled_dispatch_path + assert compiled_vmfb_path is not None + + command = [ + "iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + f"--module={compiled_vmfb_path.resolve()}", + "--batch_size=1000", + "--benchmark_repetitions=3", + "--benchmark_format=json", + ] + + return command + + def get_model_compile_timeout_s(self) -> int: + return 0 + + def get_model_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + return [] + + def get_model_benchmark_timeout_s(self) -> int: + return 0 + + def get_model_benchmark_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + return [] + + +def main(): + args = libtuner.parse_arguments() + path_config = libtuner.PathConfig() + # These will not be used, so always default to the empty config in the script dir. 
+ script_dir = Path(__file__).resolve().parent + path_config.global_config_prolog_mlir = ( + script_dir / path_config.global_config_prolog_mlir + ) + path_config.global_config_epilog_mlir = ( + script_dir / path_config.global_config_epilog_mlir + ) + path_config.base_dir.mkdir(parents=True, exist_ok=True) + path_config.output_unilog.touch() + candidate_trackers: list[libtuner.CandidateTracker] = [] + dispatch_tuner = DispatchTuner() + stop_after_phase: str = args.stop_after + + print("Setup logging") + libtuner.setup_logging(args, path_config) + print(path_config.run_log, end="\n\n") + + if not args.dry_run: + print("Validating devices") + libtuner.validate_devices(args.devices) + print("Validation successful!\n") + + print("Generating candidates...") + candidates = libtuner.generate_candidates(args, path_config, candidate_trackers) + print(f"Stored candidates in {path_config.candidates_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: + return + + print("Compiling candidates...") + compiled_candidates = libtuner.compile_dispatches( + args, path_config, candidates, candidate_trackers, dispatch_tuner + ) + print(f"Compiled files are stored in {path_config.compiled_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches: + return + + print("Benchmarking compiled candidates...") + top_candidates = libtuner.benchmark_dispatches( + args, path_config, compiled_candidates, candidate_trackers, dispatch_tuner + ) + print(f"\nStored results in {path_config.output_unilog.resolve()}\n") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: + return + + libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers) + print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n") + + print("Check the detailed execution logs in:") + print(path_config.run_log.resolve()) + + for candidate in candidate_trackers: + libtuner.logging.debug(candidate) diff --git a/tuner/examples/dispatch/mmt.mlir b/tuner/examples/dispatch/mmt.mlir new file mode 100644 index 000000000..b9d6c5f4c --- /dev/null +++ b/tuner/examples/dispatch/mmt.mlir @@ -0,0 +1,11 @@ +!matA_0 = tensor<2048x1280xf16> +!matB_0 = tensor<10240x1280xf16> +!matC_0 = tensor<2048x10240xf32> + +func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 { + %cst = arith.constant 0.000000e+00 : f16 + %5 = tensor.empty() : !matC_0 + %6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0 + %8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0 + return %8 : !matC_0 +} diff --git a/tuner/examples/punet/.gitignore b/tuner/examples/punet/.gitignore new file mode 100644 index 000000000..fae904ffb --- /dev/null +++ b/tuner/examples/punet/.gitignore @@ -0,0 +1,3 @@ +# Test files/dirs recommended by README.md. +dump-mmt +test-benchmark.mlir diff --git a/tuner/examples/punet/README.md b/tuner/examples/punet/README.md new file mode 100644 index 000000000..777d1c194 --- /dev/null +++ b/tuner/examples/punet/README.md @@ -0,0 +1,46 @@ +# Punet Tuner + +## Environments +Follow instructions in [`/tuner/README.md`](../README.md) + +## Shell Scripts + +The required shell scripts can be downloaded from: +[sdxl-scripts](https://github.com/nod-ai/sdxl-scripts). + +These scripts include: +1. `compile-punet-base.sh` - Used for compiling model candidates. +2. `compile_candidate.sh` - Used for compiling dispatch candidates. +3. `punet.sh` - Invoked by `compile_candidate.sh`. 
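+
+These need to be discoverable on your `PATH` when the tuner runs. As a minimal sketch (assuming the `sdxl-scripts` repository was cloned to `~/sdxl-scripts` and its tuning scripts sit under `tuning/`; adjust to your checkout):
+
+```shell
+# Hypothetical location; point this at wherever the sdxl-scripts tuning scripts actually live.
+export PATH="$HOME/sdxl-scripts/tuning:$PATH"
+```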
+ +Add the parent directories of these scripts to your `PATH` environment variable, +so that they can be picked up by `punet_autotune.py`. + +## Running the Tuner + +### [Optional] Generate a tunable mlir +Use +[`punet.sh`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/punet.sh) +to compile the sample matmul `mmt.mlir` (can also find here: +[`mmt_unet.mlir`](https://github.com/nod-ai/sdxl-scripts/blob/main/tuning/mmt_unet.mlir)): +```shell +punet.sh mmt.mlir -o mmt.vmfb --iree-hal-dump-executable-files-to=dump-mmt +cp ./dump-mmt/module_main_0_dispatch_0_rocm_hsaco_fb_benchmark.mlir test-benchmark.mlir +``` + +### Recommended Trial Run +For an initial trial to test the tuning loop, use: +```shell +python -m examples.punet test-benchmark.mlir --num-candidates=10 +``` + +### Dry Run Test +To perform a dry run (no GPU required), use: +```shell +python -m examples.punet test-benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run +``` + +### Basic Usage +```shell +python -m examples.punet test-benchmark.mlir +``` diff --git a/tuner/examples/punet/__init__.py b/tuner/examples/punet/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/tuner/examples/punet/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/libshortfin/tests/host_cpu_system_test.py b/tuner/examples/punet/__main__.py similarity index 50% rename from libshortfin/tests/host_cpu_system_test.py rename to tuner/examples/punet/__main__.py index 1a37f02c1..ca092d502 100644 --- a/libshortfin/tests/host_cpu_system_test.py +++ b/tuner/examples/punet/__main__.py @@ -4,13 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from . import punet_autotune -def test_create_host_cpu_system(): - from _shortfin import lib as sfl - - sc = sfl.local.host.CPUSystemBuilder() - ls = sc.create_system() - print(f"LOCAL SYSTEM:", ls) - - print("Sleeping in Python") - ls.shutdown() +punet_autotune.main() diff --git a/tuner/examples/punet/mmt.mlir b/tuner/examples/punet/mmt.mlir new file mode 100644 index 000000000..b9d6c5f4c --- /dev/null +++ b/tuner/examples/punet/mmt.mlir @@ -0,0 +1,11 @@ +!matA_0 = tensor<2048x1280xf16> +!matB_0 = tensor<10240x1280xf16> +!matC_0 = tensor<2048x10240xf32> + +func.func @main_0(%arg0: !matA_0, %arg1: !matB_0) -> !matC_0 { + %cst = arith.constant 0.000000e+00 : f16 + %5 = tensor.empty() : !matC_0 + %6 = linalg.fill ins(%cst : f16) outs(%5 : !matC_0) -> !matC_0 + %8 = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA_0, !matB_0) outs(%6 : !matC_0) -> !matC_0 + return %8 : !matC_0 +} diff --git a/tuner/examples/punet/punet_autotune.py b/tuner/examples/punet/punet_autotune.py new file mode 100644 index 000000000..3503c86df --- /dev/null +++ b/tuner/examples/punet/punet_autotune.py @@ -0,0 +1,185 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Sample Usage: + +python -m examples.punet benchmark.mlir --lhs-dims=bmk --rhs-dims=bkn --tile-dims=*mnk --devices=hip://0,hip://1 --num-candidates=64 + + +Recommended Trial Run: + +python -m examples.punet benchmark.mlir --num-candidates=1 + + +Dry Run Test (no gpu requried): + +python -m examples.punet benchmark.mlir --num-candidates=64 --num-model-candidates=10 --dry-run + +""" + +from tuner import libtuner +from pathlib import Path + + +class PunetClient(libtuner.TuningClient): + def get_dispatch_compile_timeout_s(self) -> int: + return 4 + + def get_dispatch_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + mlir_path = candidate_tracker.dispatch_mlir_path + assert mlir_path is not None + command = [ + "compile_candidate.sh", + mlir_path.as_posix(), + ] + return command + + def get_dispatch_benchmark_timeout_s(self) -> int: + return 15 + + def get_dispatch_benchmark_command( + self, + candidate_tracker: libtuner.CandidateTracker, + ) -> list[str]: + compiled_vmfb_path = candidate_tracker.compiled_dispatch_path + assert compiled_vmfb_path is not None + + command = [ + "iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + f"--module={compiled_vmfb_path.resolve()}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--batch_size=1000", + "--benchmark_repetitions=3", + "--benchmark_format=json", + ] + + return command + + def get_model_compile_timeout_s(self) -> int: + return 300 + + def get_model_compile_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + mlir_spec_path = candidate_tracker.spec_path + assert mlir_spec_path is not None + target_dir = mlir_spec_path.resolve().parent.parent.parent + output_name = f"unet_candidate_{candidate_tracker.candidate_id}.vmfb" + command = [ + "compile-punet-base.sh", + "iree-compile", + "gfx942", + f"{mlir_spec_path.resolve()}", + "./punet.mlir", + "-o", + (target_dir / output_name).as_posix(), + ] + return command + + def get_model_benchmark_timeout_s(self) -> int: + return 180 + + def get_model_benchmark_command( + self, candidate_tracker: libtuner.CandidateTracker + ) -> list[str]: + unet_candidate_path = candidate_tracker.compiled_model_path + assert unet_candidate_path is not None + + command = [ + "iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--device_allocator=caching", + f"--module={unet_candidate_path.resolve()}", + "--parameters=model=punet.irpa", + "--function=main", + "--input=1x4x128x128xf16", + "--input=1xsi32", + "--input=2x64x2048xf16", + "--input=2x1280xf16", + "--input=2x6xf16", + "--input=1xf16", + "--benchmark_repetitions=5", + "--benchmark_format=json", + ] + return command + + +def main(): + args = libtuner.parse_arguments() + path_config = libtuner.PathConfig() + path_config.base_dir.mkdir(parents=True, exist_ok=True) + path_config.output_unilog.touch() + candidate_trackers: list[libtuner.CandidateTracker] = [] + punet_client = PunetClient() + stop_after_phase: str = args.stop_after + + print("Setup logging") + libtuner.setup_logging(args, path_config) + print(path_config.run_log, end="\n\n") + + if not args.dry_run: + print("Validating devices") + libtuner.validate_devices(args.devices) + print("Validation successful!\n") + + print("Generating candidates...") + candidates = libtuner.generate_candidates(args, path_config, candidate_trackers) + print(f"Stored 
candidates in {path_config.candidates_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: + return + + print("Compiling candidates...") + compiled_candidates = libtuner.compile_dispatches( + args, path_config, candidates, candidate_trackers, punet_client + ) + print(f"Compiled files are stored in {path_config.compiled_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_dispatches: + return + + print("Benchmarking compiled candidates...") + top_candidates = libtuner.benchmark_dispatches( + args, path_config, compiled_candidates, candidate_trackers, punet_client + ) + print(f"Stored results in {path_config.output_unilog}\n") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_dispatches: + return + + print(f"Compiling top model candidates...") + punet_candidates = libtuner.compile_models( + args, path_config, top_candidates, candidate_trackers, punet_client + ) + print(f"Model candidates compiled in {path_config.base_dir}\n") + if stop_after_phase == libtuner.ExecutionPhases.compile_models: + return + + print("Benchmarking model candidates...") + libtuner.benchmark_models( + args, path_config, punet_candidates, candidate_trackers, punet_client + ) + print(f"Stored results in {path_config.output_unilog}") + if stop_after_phase == libtuner.ExecutionPhases.benchmark_models: + return + + libtuner.summerize_top_candidates(path_config, candidate_trackers) + print(f"Stored top candidates info in {path_config.result_summary_log}\n") + + libtuner.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers) + print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n") + + print("Check the detailed execution logs in:") + print(path_config.run_log) + + for candidate in candidate_trackers: + libtuner.logging.debug(candidate) + if args.verbose: + print(candidate) diff --git a/tuner/pyproject.toml b/tuner/pyproject.toml new file mode 100644 index 000000000..c36326bf7 --- /dev/null +++ b/tuner/pyproject.toml @@ -0,0 +1,24 @@ +[project] +name = "SHARK Tuner" +authors = [ + {name = "SHARK Authors"}, +] +description = "IREE Dispatch Tuner" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">= 3.10" + +# Version is set via the `setup.py`. +dynamic = ["version"] + +[project.urls] +Repository = "https://github.com/nod-ai/shark-ai" diff --git a/tuner/requirements-dev.txt b/tuner/requirements-dev.txt index 51d5b9ba0..747b28508 100644 --- a/tuner/requirements-dev.txt +++ b/tuner/requirements-dev.txt @@ -1,2 +1,3 @@ +mypy==1.8.0 pre-commit==3.8.0 virtualenv==20.13.0 diff --git a/tuner/setup.py b/tuner/setup.py new file mode 100644 index 000000000..aa450eaee --- /dev/null +++ b/tuner/setup.py @@ -0,0 +1,35 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import os + +from setuptools import setup + +SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) + +# Setup and get version information. 
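+# If a version_local.json has been generated next to this file (for example by a
+# release/packaging step), prefer it; otherwise fall back to the checked-in
+# version.json, which marks a development build.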
+VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json") +VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json") + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +try: + version_info = load_version_info(VERSION_FILE_LOCAL) +except FileNotFoundError: + print("version_local.json not found. Default to dev build") + version_info = load_version_info(VERSION_FILE) + +PACKAGE_VERSION = version_info.get("package-version") +print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") + +setup( + version=f"{PACKAGE_VERSION}", +) diff --git a/tuner/tuner/__init__.py b/tuner/tuner/__init__.py new file mode 100644 index 000000000..a85ba359d --- /dev/null +++ b/tuner/tuner/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py new file mode 100644 index 000000000..38696e6db --- /dev/null +++ b/tuner/tuner/candidate_gen.py @@ -0,0 +1,633 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +""" +Generate candidates by tweaking op configuration for tuning. + +It can be invoked in two ways: + 1. From another python script, import and call `tune()` + 2. Run this script directly from the command + +Usage: ./candidate_gen.py 121.mlir -o "tuning/candidates" -l 1024 --lhs-dims=mk --rhs-dims=nk --tile-dims=mnk + +""" + +import argparse +import logging +import pickle +import re +from dataclasses import dataclass +from os import path, makedirs +from typing import Optional +from textwrap import indent +from abc import abstractmethod + +from iree.compiler import ir # type: ignore + +from iree.compiler.dialects import iree_codegen # type: ignore + +from .common import * +from .dispatch_constraints import * +from .dispatch_parser import * + +tune_logger = logging.getLogger("tune") + + +def apply_configuration( + template: list[str], configuration: Configuration, tile_sizes: list[int] +) -> str: + tune_logger.info(f"Applying: {configuration}") + expr0 = re.compile( + r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" + ) + expr1 = re.compile( + r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+)," + ) + expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") + expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") + expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") + repl0 = f", subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" + repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' + repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' + repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" + repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"' + + new_mlir = "" + for line in template: + if "intrinsic =" in line: + line = re.sub(expr0, repl0, line) + if "LLVMGPUVectorDistribute " in line: + line = re.sub(expr1, repl1, line) + if 
"tile_sizes" in line: + line = re.sub(expr2, repl2, line) + if "gpu_pipeline_options =" in line: + line = re.sub(expr3, repl3, line) + if "amdgpu-waves-per-eu" in line: + line = re.sub(expr4, repl4, line) + new_mlir += line + + return new_mlir + + +class DispatchTuner(DispatchParser): + # TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove this in favor of configuring using transform dialect. + @abstractmethod + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + """Apply parameter transformations to the operation.""" + pass + + +class DispatchTunerRegistry: + def __init__(self): + self.registry = set() + + def register(self, dispatch_tuners: list[DispatchTuner]) -> None: + for dispatch_tuner in dispatch_tuners: + self.registry.add(dispatch_tuner) + + def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool: + for attr in attrs: + if (attr.name == "translation_info") and ( + "LLVMGPUVectorDistribute" in str(attr.attr) + ): + return True + assert False, "Translation info not supported" + + def find_handler(self, op_name: str) -> DispatchTuner: + for dispatch_tuner in self.registry: + if dispatch_tuner.supports(op_name): + return dispatch_tuner + assert False, "Dispatch kind not supported" + + +class MmtTuner(DispatchTuner, MmtParser): + def get_transform_function_mmt( + self, problem_size: ProblemSize, functionName: str, configuration: Configuration + ) -> str: + tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + }} + """ + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + modified = indent( + self.get_transform_function_mmt( + problem_size, f"match_mmt_{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_mmt_tile_sizes(configuration) + ) + embeddable = indent( + self.get_transform_function_mmt(problem_size, f"match_op", configuration), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +class ConvTuner(DispatchTuner, ConvParser): + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = 
filterShape[1]; + # int64_t ic = filterShape[2]; + def get_transform_function_conv( + self, problem_size: ProblemSize, functionName: str, configuration: Configuration + ) -> str: + dynamic_batch_input_ty = problem_size.lhs_type + dynamic_batch_input_ty.shape = dynamic_batch_input_ty.shape.copy() + dynamic_batch_input_ty.shape[0] = -1 + + dynamic_batch_output_ty = problem_size.res_type + dynamic_batch_output_ty.shape = dynamic_batch_output_ty.shape.copy() + dynamic_batch_output_ty.shape[0] - 1 + + input = f"tensor<{dynamic_batch_input_ty}>" + filter = f"tensor<{problem_size.rhs_type}>" + output = f"tensor<{dynamic_batch_output_ty}>" + + tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}}) + -> (!transform.any_op, !transform.any_param) {{ + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv {{ + ^bb0(%lhs: {input}, %rhs: {filter}, %out: {output}): + %13 = linalg.conv_2d_nhwc_hwcf {{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}} + ins(%lhs, %rhs : {input}, {filter}) + outs(%out : {output}) -> {output} + }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + }} + """ + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + conv_dims = ConvDimInfo.from_problem_size(problem_size) + modified = indent( + self.get_transform_function_conv( + problem_size, + f"match_conv_2d_nhwc_hwcf_Bx{conv_dims.oh}x{conv_dims.ow}x{conv_dims.oc}x{conv_dims.fh}x{conv_dims.fw}x{conv_dims.ic}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, configuration, self.get_conv_tile_sizes(configuration) + ) + embeddable = indent( + self.get_transform_function_conv(problem_size, f"match_op", configuration), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +class ContractionTuner(DispatchTuner, ContractionParser): + def get_transform_function_broadcast_rhs_mmt( + self, + problem_size: ProblemSize, + functionName: str, + configuration: Configuration, + ) -> str: + tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + lhs_dynamic_batch = problem_size.lhs_type + lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy() + lhs_dynamic_batch.shape[0] = -1 + + return f""" +transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ +%mmt = transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op +%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value +%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value +transform.iree.match.cast_compatible_type %lhs = 
tensor<{lhs_dynamic_batch}> : !transform.any_value +transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value +%config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param +transform.yield %generic, %config : !transform.any_op, !transform.any_param +}} +""" + + def apply_params_broadcast_rhs_mmt( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + modified = indent( + self.get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_broadcast_rhs_mmt_Bx{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + self.get_transform_function_broadcast_rhs_mmt( + problem_size, f"match_op", configuration + ), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + if self.is_broadcast_rhs_mmt(template): + return self.apply_params_broadcast_rhs_mmt( + problem_size, template, configuration + ) + + # TODO: Generate transform function. + return MLIRTransformation( + template, + apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), + ), + "", + ) + + +class BatchMmtTuner(DispatchTuner, BatchMmtParser): + def get_transform_function_batch_mmt( + self, + problem_size: ProblemSize, + functionName: str, + configuration: Configuration, + ) -> str: + tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration))) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" +transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ +%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op +%lhs = transform.get_operand %generic[0] : (!transform.any_op) -> !transform.any_value +%rhs = transform.get_operand %generic[1] : (!transform.any_op) -> !transform.any_value +transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value +transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value +%config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param +transform.yield %generic, %config : !transform.any_op, !transform.any_param +}} +""" + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + B = problem_size.matmul_size.B + modified = indent( + self.get_transform_function_batch_mmt( + problem_size, f"match_batch_mmt_{B}x{M}x{N}x{K}", configuration + ), + "// ", + ) + modified += apply_configuration( + 
template, configuration, get_batch_mmt_tile_sizes(configuration) + ) + + embeddable = indent( + self.get_transform_function_batch_mmt( + problem_size, f"match_op", configuration + ), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +class BatchMatmulTuner(DispatchTuner, BatchMatmulParser): + def get_transform_function_batch_matmul( + self, + problem_size: ProblemSize, + tile_dims: str, + functionName: str, + configuration: Configuration, + ) -> str: + input0 = f"tensor<{problem_size.lhs_type}>" + input1 = f"tensor<{problem_size.rhs_type}>" + output = f"tensor<{problem_size.res_type}>" + + tile_sizes = ", ".join( + map(str, get_contract_tile_sizes(configuration, tile_dims)) + ) + + wg_x, wg_y, wg_z = configuration.workgroup_size + extra_config = get_pipeline_config(configuration) + + return f""" + transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}}) + -> (!transform.any_op, !transform.any_param) {{ + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %batch_matmul {{ + ^bb0(%lhs: {input0}, %rhs: {input1}, %out: {output}): + %13 = linalg.batch_matmul + ins(%lhs, %rhs : {input0}, {input1}) + outs(%out : {output}) -> {output} + }} : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> + {extra_config}}}> + > -> !transform.any_param + transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param + }} + """ + + def apply_params( + self, + problem_size: ProblemSize, + template: list[str], + configuration: Configuration, + ) -> MLIRTransformation: + M, N, K = problem_size.MNK + modified = indent( + self.get_transform_function_batch_matmul( + problem_size, + self.tile_dims, + f"match_batch_matmul_{problem_size.matmul_size.B}x{M}x{N}x{K}", + configuration, + ), + "// ", + ) + modified += apply_configuration( + template, + configuration, + get_contract_tile_sizes(configuration, self.tile_dims), + ) + + embeddable = indent( + self.get_transform_function_batch_matmul( + problem_size, self.tile_dims, f"match_op", configuration + ), + " ", + ) + return MLIRTransformation(template, modified, embeddable) + + +@dataclass +class OpWalkResult: + was_interrupted: bool = False + dispatch_tuner: Optional[DispatchTuner] = None + + +def walk_callback_get_fn( + op: ir.Operation, + walk_result: OpWalkResult, + dispatch_tuner_registry: DispatchTunerRegistry, +) -> ir.WalkResult: + if op.name == "func.func": + dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) + if op.name == "util.func": + func_name = str(op.opview.sym_name) + walk_result.was_interrupted = True + walk_result.dispatch_tuner = dispatch_tuner_registry.find_handler(func_name) + return ir.WalkResult.INTERRUPT + return ir.WalkResult.ADVANCE + + +def walk_mlir_op( + mlir_module: ir.Module, + dispatch_tuner_registry: DispatchTunerRegistry, +) -> OpWalkResult: + walk_result = OpWalkResult() + for op in mlir_module.body.operations: + op.walk( + lambda op: walk_callback_get_fn(op, walk_result, dispatch_tuner_registry), + ir.WalkOrder.POST_ORDER, + ) + if walk_result.was_interrupted: + break + return walk_result + + +def get_default_output_dir() -> str: + from datetime import datetime + + return "tuning_" + 
datetime.now().strftime("%Y_%m_%d_%H_%M") + + +def tune( + input: str, # Path to the mlir file to be tuned + output: str = "", # Path to the output directory, auto creates one if not given + limit: int = 4096, # Max candidates to be generated + num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints + lhs_dims: str = "mk", # Dimensions for the left-hand side operand in matrix operations + rhs_dims: str = "nk", # Dimensions for the right-hand side operand in matrix operations + tile_dims: str = "mnk", # Dimensions for the tile size +): + input_file = str(input) + + if not output: + output = get_default_output_dir() + + # Create the directory if it does not exist + makedirs(str(output), exist_ok=True) + + tune_logger.debug(f"Output directory {output}") + tune_logger.debug(f"Processing {input_file}") + mlir_template = read_input_mlir(input_file) + mlir_text = "".join(mlir_template) + + with ir.Context() as ctx: + tuner_context = TunerContext(ctx, tune_logger) + mlir_module = parse_mlir(mlir_text, tuner_context) + # Save the input file as the first candidate. + with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) + + walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) + + variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module) + assert len(variant_op_list) == 1, "Expect one executable variant op" + variant_op = variant_op_list[0] + # Get the MMA intrinisic intructions supported by the target. + mma_list = iree_codegen.query_mma_intrinsics(variant_op) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate( + generate_solutions(tune_logger, problem_size, num_subgroups, mma_list) + ): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + f.write(tf_mlir.embeddable) + + # TODO: Fix pickling for ir types. + # with open(path.join(output, "configs.pkl"), "wb") as file: + # pickle.dump(configs, file) + + tune_logger.info(f"Generated {len(configs)} candidates") + tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input", help="Input mlir file", type=str) + parser.add_argument( + "-o", "--output", help="Output dir", type=str, default=get_default_output_dir() + ) + parser.add_argument( + "-l", + "--limit", + help="Max number of candidates generated", + type=int, + default=4096, + ) + parser.add_argument( + "--num-subgroups", + help="Number of subgroups per workgroup to use. 
(-1 == unconstrained)", + type=int, + default=-1, + ) + parser.add_argument( + "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" + ) + parser.add_argument( + "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" + ) + parser.add_argument( + "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" + ) + parser.add_argument( + "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" + ) + + args = parser.parse_args() + tune_logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) + + # Create printing formatter for logging info + formatter = logging.Formatter("%(message)s") + + # Create a handler to print to console + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + tune_logger.addHandler(console_handler) + + # # Optionally, add a file handler to log to a file + # file_handler = logging.FileHandler("tune.log") + # file_handler.setFormatter(formatter) + # tune_logger.addHandler(file_handler) + + tune( + args.input, + args.output, + args.limit, + args.num_subgroups, + args.lhs_dims, + args.rhs_dims, + args.tile_dims, + ) + + +if __name__ == "__main__": + args = main() diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py new file mode 100644 index 000000000..36fb87cbb --- /dev/null +++ b/tuner/tuner/candidate_gen_test.py @@ -0,0 +1,447 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest + +from typing import Generator + +from iree.compiler import ir # type: ignore + +from . import candidate_gen +from . 
import common + + +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + +def remove_comments(mlir: str) -> str: + return "\n".join( + filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) + ) + + +def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + "gpu_pipeline_options = #iree_gpu.pipeline_options", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + M, N, K = 2048, 1280, 1280 + + config = common.Configuration( + subgroup_size=16, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[8, 8, 8], + subgroup_m_count=16, + subgroup_n_count=16, + gpu_pipeline_options=common.GpuPipelineOptions(prefetch_shared_memory=True), + waves_per_eu=8, + ) + + problem_size = common.ProblemSize( + common.MatmulSize(M, N, K), + common.ShapedType([M, K], tuner_ctx.type.f16), + common.ShapedType([N, K], tuner_ctx.type.f16), + common.ShapedType([M, N], tuner_ctx.type.f32), + common.DispatchKind.mmt, + ) + tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 16, subgroup_n_count = 16" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [16, 16, 1] subgroup_size = 16" + in modified + ) + assert "tile_sizes = [[8, 8, 8]]" in modified + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in modified + ) + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified + + +def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 16, subgroup_n_count = 16>", + "", + 'gpu_pipeline_options = #iree_gpu.pipeline_options, {llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}', + ] + + n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + + config = common.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + gpu_pipeline_options=common.GpuPipelineOptions( + reorder_workgroups_strategy=common.ReorderWorkgroupsStrategy.TRANSPOSE + ), + waves_per_eu=2, + ) + + problem_size = common.ProblemSize( + common.MatmulSize(oh * ow, oc, fh * fw * ic), + common.ShapedType([n, oh + 2, ow + 2, oc], tuner_ctx.type.f16), + common.ShapedType([fh, fw, ic, oc], tuner_ctx.type.f16), + common.ShapedType([n, oh, ow, oc], tuner_ctx.type.f32), + common.DispatchKind.conv, + ) + tf_mlir = candidate_gen.ConvTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 1, 464, 320, 1, 1, 16]]" in modified + assert ( + "gpu_pipeline_options = #iree_gpu.pipeline_options" + in modified + ) + assert '{llvm_func_attrs = 
{"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 2, subgroup_n_count = 2>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "*mnk" + problem_size = common.ProblemSize( + common.MatmulSize(2048, 3840, 1280), + common.ShapedType([2, 1024, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 20, 64, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 2, 20, 1024, 64], tuner_ctx.type.f32), + common.DispatchKind.contraction, + ) + + config = common.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + tile_sizes=[480, 384, 32], + subgroup_m_count=1, + subgroup_n_count=4, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + new_mlir = tf_mlir.modified + + assert new_mlir + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 1, subgroup_n_count = 4" + in new_mlir + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64" + in new_mlir + ) + assert "tile_sizes = [[1, 480, 384, 32]]" in new_mlir + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir + + +def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + tile_dims = "bmnk" + problem_size = common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f16), + common.ShapedType([64, 640, 320], tuner_ctx.type.f16), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ) + + config = common.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + tile_sizes=[416, 320, 128], + subgroup_m_count=2, + subgroup_n_count=2, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + modified = remove_comments(modified) + + assert embeddable + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 416, 320, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], tuner_ctx.type.f16), + common.ShapedType([2, 640, 640], tuner_ctx.type.f16), + common.ShapedType([2, 4096, 640], tuner_ctx.type.f32), + common.DispatchKind.batch_mmt, + ) + + config = common.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + 
gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert embeddable + assert modified + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified + + +def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([2, 640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), + common.DispatchKind.batch_mmt, + ) + + config = common.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=4, + ) + + tf_mlir = candidate_gen.BatchMmtTuner().apply_params( + problem_size, mlir_template, config + ) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert "// transform.named_sequence @match_batch_mmt_2x4096x640x640(" in modified + modified = remove_comments(modified) + + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_batch_mmt_i8_i8_i32 failures(propagate)" in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor<2x4096x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<2x640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: + mlir_template = [ + ", subgroup_m_count = 4, subgroup_n_count = 1>}>", + "", + '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', + ] + + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), + common.DispatchKind.broadcast_rhs_mmt, + ) + + config = common.Configuration( + subgroup_size=64, + workgroup_size=[128, 2, 1], + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + tile_sizes=[128, 64, 128], + subgroup_m_count=2, + subgroup_n_count=2, + 
gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=4, + ) + + tf_mlir = candidate_gen.ContractionTuner( + "mk", "nk", "mnk" + ).apply_params_broadcast_rhs_mmt(problem_size, mlir_template, config) + + modified = tf_mlir.modified + embeddable = tf_mlir.embeddable + + assert modified + assert ( + "// transform.named_sequence @match_broadcast_rhs_mmt_Bx4096x640x640(" + in modified + ) + modified = remove_comments(modified) + + assert ( + "intrinsic = #iree_gpu.mma_layout, subgroup_m_count = 2, subgroup_n_count = 2" + in modified + ) + assert ( + "LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64" + in modified + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in modified + assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in modified + + assert embeddable + assert "transform.named_sequence @match_op(" in embeddable + assert ( + "transform.include @match_broadcast_rhs_mmt_i8_i8_i32 failures(propagate)" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %lhs = tensor : !transform.any_value" + in embeddable + ) + assert ( + "transform.iree.match.cast_compatible_type %rhs = tensor<640x640xi8> : !transform.any_value" + in embeddable + ) + assert ( + "%config = transform.param.constant #iree_codegen.compilation_info<" + in embeddable + ) + assert "tile_sizes = [[1, 128, 64, 128]]" in embeddable + assert 'llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' in embeddable + assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable + + +def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: + mlir_lines = [ + r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + ] + assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( + mlir_lines + ) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py new file mode 100644 index 000000000..b6e31768e --- /dev/null +++ b/tuner/tuner/common.py @@ -0,0 +1,249 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import re +import logging +from dataclasses import astuple, dataclass +from enum import Enum +from typing import Optional + +from iree.compiler import ir # type: ignore + +from iree.compiler.dialects import iree_gpu # type: ignore + + +class CommonTypes: + def __init__(self, ctx: ir.Context): + assert ctx + self.i1 = ir.IntegerType.get_signless(1, ctx) + self.i8 = ir.IntegerType.get_signless(8, ctx) + self.i16 = ir.IntegerType.get_signless(16, ctx) + self.i32 = ir.IntegerType.get_signless(32, ctx) + + self.f8E4M3FNUZ = ir.Float8E4M3FNUZType.get(ctx) + self.f8E5M2FNUZ = ir.Float8E5M2FNUZType.get(ctx) + self.f16 = ir.F16Type.get(ctx) + self.f32 = ir.F32Type.get(ctx) + + self.bf16 = ir.BF16Type.get(ctx) + + +class TunerContext: + def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): + self.mlir_ctx: ir.Context = mlir_ctx + self.logger: logging.Logger = logger + self.type: CommonTypes = CommonTypes(mlir_ctx) + + +class DispatchKind(Enum): + conv = 1 + mmt = 2 + contraction = 3 + batch_mmt = 4 + batch_matmul = 5 + broadcast_rhs_mmt = 6 + + +@dataclass +class ShapedType: + shape: list[int] + element_type: ir.IntegerType | ir.FloatType + + def rank(self) -> int: + return len(self.shape) + + @property + def bitwidth(self) -> int: + return self.element_type.width + + def __str__(self) -> str: + dim_to_str = lambda dim: str(dim) if dim != -1 else "?" + return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) + + +@dataclass +class MatmulSize: + M: int + N: int + K: int + B: int = 1 + + +@dataclass +class ProblemSize: + matmul_size: MatmulSize + lhs_type: ShapedType + rhs_type: ShapedType + res_type: ShapedType + dispatch_kind: DispatchKind + + @property + def MNK(self) -> tuple[int, int, int]: + return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) + + +@dataclass +class MfmaIntrinsic: + output_type: ir.IntegerType | ir.FloatType + m: int + n: int + k: int + input_type: ir.IntegerType | ir.FloatType + + def __str__(self) -> str: + input = str(self.input_type).upper() + output = str(self.output_type).upper() + return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" + + @staticmethod + def mfma_f32_16x16x16_f16(): + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + return MfmaIntrinsic(f32, 16, 16, 16, f16) + + @staticmethod + def mfma_f32_32x32x8_f16(): + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + return MfmaIntrinsic(f32, 32, 32, 8, f16) + + @staticmethod + def mfma_i32_16x16x32_i8(): + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + return MfmaIntrinsic(i32, 16, 16, 32, i8) + + @staticmethod + def mfma_i32_32x32x16_i8(): + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + return MfmaIntrinsic(i32, 32, 32, 16, i8) + + @staticmethod + def all(): + return [ + MfmaIntrinsic.mfma_f32_16x16x16_f16(), + MfmaIntrinsic.mfma_f32_32x32x8_f16(), + MfmaIntrinsic.mfma_i32_16x16x32_i8(), + MfmaIntrinsic.mfma_i32_32x32x16_i8(), + ] + + +def get_compatible_mfma_intrinsics( + problem_size: ProblemSize, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +) -> list[MfmaIntrinsic]: + available_mma_intrinsics = [str(mma) for mma in mma_intrinsics] + + def is_compatible(intrinsic: MfmaIntrinsic) -> bool: + if problem_size.res_type.element_type != intrinsic.output_type: + return False + if problem_size.dispatch_kind != DispatchKind.batch_matmul: + if problem_size.lhs_type.element_type != intrinsic.input_type: + return False + if 
problem_size.rhs_type.element_type != intrinsic.input_type: + return False + + if str(intrinsic) not in available_mma_intrinsics: + return False + + return True + + return list(filter(is_compatible, MfmaIntrinsic.all())) + + +class ReorderWorkgroupsStrategy(Enum): + NONE = 0 + SWIZZLE = 1 + TRANSPOSE = 2 + + def __str__(self) -> str: + return self.name.title() + + +@dataclass +class GpuPipelineOptions: + """Represents the `iree_gpu.pipeline_options` attribute""" + + prefetch_shared_memory: Optional[bool] = None + no_reduce_shared_memory_bank_conflicts: Optional[bool] = None + reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None + + def all_default(self) -> bool: + return all(x is None for x in astuple(self)) + + def __str__(self) -> str: + options: list[str] = [] + if self.prefetch_shared_memory is not None: + options.append( + f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}" + ) + if self.no_reduce_shared_memory_bank_conflicts is not None: + options.append( + f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}" + ) + if self.reorder_workgroups_strategy is not None: + options.append( + f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}" + ) + + return f"#iree_gpu.pipeline_options<{', '.join(options)}>" + + +@dataclass +class Configuration: + subgroup_size: int + workgroup_size: list[int] + intrinsic: MfmaIntrinsic + tile_sizes: list[int] + subgroup_m_count: int + subgroup_n_count: int + gpu_pipeline_options: GpuPipelineOptions + waves_per_eu: int + + +def get_pipeline_config(configuration: Configuration) -> str: + extra_config = "" + if not configuration.gpu_pipeline_options.all_default(): + extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" + if configuration.waves_per_eu != 2: + extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' + return extra_config + + +def read_input_mlir(filename: str) -> list[str]: + with open(filename, "r") as f: + return f.readlines() + + +@dataclass +class ConvDimInfo: + n: int + oh: int + ow: int + oc: int + fh: int + fw: int + ic: int + + @staticmethod + def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): + n, oh, ow, oc = res_shaped_type.shape + fh, fw, ic, _ = rhs_shaped_type.shape + return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) + + @staticmethod + def from_problem_size(problem_size: ProblemSize): + return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) + + +@dataclass +class MLIRTransformation: + """Transformation of MLIR context""" + + template: list[str] + modified: str + embeddable: str diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py new file mode 100644 index 000000000..297ac95a2 --- /dev/null +++ b/tuner/tuner/common_test.py @@ -0,0 +1,187 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest +from . 
import common + +from typing import Generator + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore + + +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + +@pytest.fixture +def mlir_ctx() -> Generator[ir.Context, None, None]: + with ir.Context() as ctx: + yield ctx + + +def test_get_shaped_type_element_bitwidth(tuner_ctx: common.TunerContext) -> None: + assert common.ShapedType([1024, 2048], tuner_ctx.type.i8).bitwidth == 8 + assert common.ShapedType([2048], tuner_ctx.type.i32).bitwidth == 32 + assert common.ShapedType([2048, 512, 384], tuner_ctx.type.f8E4M3FNUZ).bitwidth == 8 + assert common.ShapedType([1, 1], tuner_ctx.type.f16).bitwidth == 16 + + +def test_get_shaped_type_to_str(tuner_ctx: common.TunerContext) -> None: + assert str(common.ShapedType([1024, 2048], tuner_ctx.type.i8)) == "1024x2048xi8" + assert str(common.ShapedType([1024], tuner_ctx.type.f32)) == "1024xf32" + assert str(common.ShapedType([1, 2, 3], tuner_ctx.type.f16)) == "1x2x3xf16" + assert str(common.ShapedType([-1, 2, 3], tuner_ctx.type.f16)) == "?x2x3xf16" + + +def test_gpu_pipeline_options() -> None: + options = common.GpuPipelineOptions() + assert options.all_default() + assert str(options) == "#iree_gpu.pipeline_options<>" + + options.prefetch_shared_memory = True + assert not options.all_default() + assert str(options) == "#iree_gpu.pipeline_options" + + options.no_reduce_shared_memory_bank_conflicts = False + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + options = common.GpuPipelineOptions() + options.reorder_workgroups_strategy = common.ReorderWorkgroupsStrategy.TRANSPOSE + assert not options.all_default() + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + +def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: + config = common.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + config1_str: str = common.get_pipeline_config(config) + assert config1_str == "" + + config.waves_per_eu = 4 + config2_str: str = common.get_pipeline_config(config) + assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + + config.gpu_pipeline_options.prefetch_shared_memory = True + config3_str = common.get_pipeline_config(config) + assert ( + config3_str + == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + ) + + +def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None: + assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" + assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" + + +def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], tuner_ctx.type.f16), + common.ShapedType([1280, 1280], tuner_ctx.type.f16), + common.ShapedType([2048, 1280], tuner_ctx.type.f32), + common.DispatchKind.mmt, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], + ) == [ + 
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], tuner_ctx.type.i8), + common.ShapedType([1280, 1280], tuner_ctx.type.i8), + common.ShapedType([2048, 1280], tuner_ctx.type.i32), + common.DispatchKind.mmt, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) == [ + common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), + common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], + ) == [ + common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + ], + ) == [ + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] + + assert ( + common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), + common.DispatchKind.batch_matmul, + ), + [ + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + == [] + ) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py new file mode 100644 index 000000000..85039a1e8 --- /dev/null +++ b/tuner/tuner/dispatch_constraints.py @@ -0,0 +1,206 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. 
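The dispatch_constraints module that follows phrases tile-size and workgroup selection as a z3 satisfiability problem: generate_constraints builds integer constraints over the tile, intrinsic, and workgroup variables, and generate_solutions walks the satisfying models one at a time, blocking each model before asking the solver again. A minimal, self-contained sketch of that enumeration loop, assuming only the z3-solver package and using toy constraints in place of the real MFMA/workgroup rules (the helper name and bounds below are purely illustrative):

import z3

def enumerate_tiles(max_m: int, max_n: int, limit: int = 5):
    # Declare the unknowns the solver searches over.
    m, n = z3.Int("m"), z3.Int("n")
    solver = z3.Solver()
    # Toy constraints standing in for the real tile/intrinsic/workgroup constraints.
    solver.add(m >= 16, m <= max_m, m % 16 == 0)
    solver.add(n >= 16, n <= max_n, n % 16 == 0)
    found = 0
    while found < limit and solver.check() == z3.sat:
        model = solver.model()
        yield model[m].as_long(), model[n].as_long()
        # Block this assignment so the next check() produces a different solution.
        solver.add(z3.Not(z3.And(m == model[m], n == model[n])))
        found += 1

print(list(enumerate_tiles(64, 64)))  # e.g. five distinct (m, n) pairs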
+ +import z3 # type: ignore +from typing import Iterator + + +from iree.compiler.dialects import iree_gpu # type: ignore + +from .common import * + + +def get_mfma_intrinsic_constraints( + problem_size: ProblemSize, + intrinsic_m: z3.ArithRef, + intrinsic_n: z3.ArithRef, + intrinsic_k: z3.ArithRef, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +) -> z3.BoolRef: + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics) + assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + return z3.Or( + *( + z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) + for mfma in compatible_intrinsics + ) + ) + + +def get_dispatch_constraints( + problem_size: ProblemSize, + tile_m: z3.ArithRef, + tile_n: z3.ArithRef, + tile_k: z3.ArithRef, +) -> list[z3.BoolRef]: + if problem_size.dispatch_kind != DispatchKind.conv: + return [] + + dim_info = ConvDimInfo.from_problem_size(problem_size) + conv_constraints = [] + # WARNING: This sometimes makes the constraints UNSAT for some reason. + conv_constraints += [tile_m <= dim_info.ow] + conv_constraints += [tile_n <= dim_info.oc] + conv_constraints += [tile_k <= dim_info.ic] + return conv_constraints + + +def calculate_shared_memory_usage_in_bytes( + problem_size: ProblemSize, + m: int | z3.ArithRef, + n: int | z3.ArithRef, + k: int | z3.ArithRef, +) -> int | z3.ArithRef: + lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) + rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) + return lhs_memory + rhs_memory + + +def generate_constraints( + problem_size: ProblemSize, + tile_sizes, + num_subgroups, + subgroup_size, + intrinsic_size, + workgroup_size, + subgroup_m_count, + subgroup_n_count, + waves_per_eu, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +): + M, N, K = ( + problem_size.matmul_size.M, + problem_size.matmul_size.N, + problem_size.matmul_size.K, + ) + m, n, k = tile_sizes + intrinsic_mn, intrinsic_k = intrinsic_size + wg_x, wg_y, wg_z = workgroup_size + wg_threads = z3.Int("wg_threads") + constraints = [wg_threads == wg_x * wg_y * wg_z] + constraints += [subgroup_size == 64, wg_threads <= 1024] + constraints += [ + get_mfma_intrinsic_constraints( + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics + ) + ] + subgroup_k_count = 1 + constraints += [ + m >= intrinsic_mn, + m <= 512, + m <= M, + ] + constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] + constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] + for x in (subgroup_m_count, subgroup_n_count): + constraints += [x >= 1, x <= 32] + + subgroup_m_tile_count = z3.Int("sg_m_tcnt") + subgroup_n_tile_count = z3.Int("sg_n_tcnt") + subgroup_k_tile_count = z3.Int("sg_k_tcnt") + for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): + constraints += [x >= 1, x <= 32] + + constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] + constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] + constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] + constraints += [wg_x == subgroup_size * subgroup_n_count] + constraints += [wg_y == subgroup_m_count] + constraints += [wg_z == subgroup_k_count] + constraints += [z3.Or(wg_x <= n, wg_x <= m)] + constraints += [k % intrinsic_mn == 0] + constraints += [(k * n) % wg_threads == 0] + constraints += [(k * m) % wg_threads == 0] + subgroups = subgroup_m_count * subgroup_n_count + if num_subgroups > 0: + constraints += [subgroups == num_subgroups] + else: + 
constraints += [subgroups >= 1, subgroups <= 10] + + constraints += [waves_per_eu == 2] + # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] + + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + constraints += [shared_memory <= 65536] + + constraints += get_dispatch_constraints(problem_size, m, n, k) + + return constraints + + +def generate_solutions( + logger: logging.Logger, + problem_size: ProblemSize, + num_subgrups: int, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +) -> Iterator[Configuration]: + M, N, K = problem_size.MNK + logger.info(f"{M},{N},{K}") + m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + all_vars = [ + m, + n, + k, + subgroup_size, + intrinsic_mn, + intrinsic_k, + wg_x, + wg_y, + wg_z, + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ] + + solver = z3.Solver() + constraints = generate_constraints( + problem_size, + [m, n, k], + num_subgrups, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + mma_intrinsics, + ) + solver.add(z3.simplify(z3.And(constraints))) + logger.debug(f"Initial constraints: {solver}") + i = 0 + while solver.check() == z3.sat: + model = solver.model() + lookup = lambda var: model[var].as_long() + + config = Configuration( + lookup(subgroup_size), + [lookup(wg_x), lookup(wg_y), lookup(wg_z)], + MfmaIntrinsic( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + ), + [lookup(m), lookup(n), lookup(k)], + lookup(sg_m_cnt), + lookup(sg_n_cnt), + GpuPipelineOptions(), + lookup(waves_per_eu), + ) + solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) + i += 1 + yield config diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py new file mode 100644 index 000000000..9de4beeee --- /dev/null +++ b/tuner/tuner/dispatch_constraints_test.py @@ -0,0 +1,194 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest +import z3 # type: ignore + +from typing import Generator + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore + +from . import common +from . 
import dispatch_constraints + + +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + +def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: + matmul_size = common.MatmulSize(2048, 3840, 1280) + lhs_type = common.ShapedType([2048, 1280], tuner_ctx.type.f16) + rhs_type = common.ShapedType([3840, 1280], tuner_ctx.type.f16) + res_type = common.ShapedType([2048, 3840], tuner_ctx.type.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + configs = dispatch_constraints.generate_solutions( + tuner_ctx.logger, + problem_size, + 4, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + + assert configs is not None + + +def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) -> None: + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 512, 64, 128 + ) + == 147456 + ) + + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i8) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 512, 64, 128 + ) + == 81920 + ) + + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 128, 64, 32 + ) + == 12288 + ) + + +def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> None: + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + # Define input parameters as z3 Ints + m, n, k = ( + dispatch_constraints.z3.Int("m"), + z3.Int("n"), + z3.Int("k"), + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + + solver = z3.Solver() + solver.add(constraints) + + # 
Check if the constraints are satisfiable + assert solver.check() == z3.sat + + +def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> None: + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + m, n, k = ( + z3.Int("m"), + z3.Int("n"), + z3.Int("k"), + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == z3.unsat diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py new file mode 100644 index 000000000..c4b4b9ad5 --- /dev/null +++ b/tuner/tuner/dispatch_parser.py @@ -0,0 +1,457 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. 
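The dispatch_parser module that follows recovers problem shapes by matching the ins(...)/outs(...) operands of the target op against the MlirRegex patterns and converting each captured tensor<...> type into a ShapedType via the MLIR Python bindings. As a rough, bindings-free sketch of the same shape-extraction idea (the helper name and sample string are made up for illustration, and nested or vector element types are not handled):

import re

def parse_flat_tensor(text: str) -> tuple[list[int], str]:
    # Parse e.g. "2048x1280xf16" or "?x2x3xf16" into ([2048, 1280], "f16").
    *dims, elem = text.split("x")
    shape = [-1 if d == "?" else int(d) for d in dims]
    return shape, elem

line = "ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>)"
lhs = re.search(r"tensor<([^>]+)>", line)
assert lhs is not None
shape, elem = parse_flat_tensor(lhs.group(1))
assert shape == [2048, 1280] and elem == "f16"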
+ +import math +import re +from abc import ABCMeta, abstractmethod + +from .common import * + + +def parse_tensor_type(tensor_type: str) -> ShapedType: + shaped_ty = ir.RankedTensorType(ir.Type.parse(tensor_type)) + assert shaped_ty + return ShapedType(shaped_ty.shape, shaped_ty.element_type) + + +def get_mmt_tile_sizes(configuration: Configuration): + return configuration.tile_sizes + + +def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: + m, n, k = configuration.tile_sizes + tile_size = [1] * len(tile_dims) + for idx, dim in enumerate(tile_dims): + if dim == "m": + tile_size[idx] = m + if dim == "n": + tile_size[idx] = n + if dim == "k": + tile_size[idx] = k + return tile_size + + +def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: + return [1] + configuration.tile_sizes + + +class MlirRegex(Enum): + ssa_value = r"%[a-zA-Z0-9-_]+" + tensor_type = r"tensor<([^>]+)>" + + def __str__(self) -> str: + return self.value + + @staticmethod + def dps_ins_two_args() -> str: + return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" + + @staticmethod + def dps_outs_one_arg() -> str: + return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" + + +def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module: + mlir_module = None + try: + mlir_module = ir.Module.parse(mlir_text, ctx.mlir_ctx) + ctx.logger.info("MLIR parsing successful!") + except ir.MLIRError as e: + ctx.logger.error(f"Error parsing MLIR: {e}") + raise RuntimeError(f"Error parsing MLIR: {e}") + + return mlir_module + + +class DispatchParser(metaclass=ABCMeta): + @abstractmethod + def supports(self, op_name: str) -> bool: + """Check if the tuner can handle the type of operation represented by the input string.""" + pass + + @abstractmethod + def get_shapes(self, template: list[str]) -> ProblemSize: + """Extract problem size of the operation.""" + pass + + +class MmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + mmt_re = None + dps = None + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) + assert mmt_re + assert False, 
f"'{mmt_re}' not found in given context" + + +class ConvParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "conv_2d_nhwc_hwcf" in op_name + + def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: + m, n, k = configuration.tile_sizes + batch = 1 + fh = 1 + fw = 1 + + oh = 1 + + oc = n + ow = m + ic = k + return [batch, oh, ow, oc, fh, fw, ic] + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) + + assert False, "Shape not found" + + +class ContractionParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "matmul_like" in op_name + + def is_broadcast_rhs_mmt_op(self, line: str) -> bool: + if "linalg.generic" not in line: + return False + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + return False + if ( + r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" + not in line + ): + return False + return True + + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: + return any(self.is_broadcast_rhs_mmt_op(line) for line in template) + + def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: + for line in template: + if not self.is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 
+ assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + assert False, "Shape not found" + + +class BatchMmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) + + assert False, "Shape not found" + + +class BatchMatmulParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, 
template: list[str]) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() + + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py new file mode 100644 index 000000000..d3a99806f --- /dev/null +++ b/tuner/tuner/dispatch_parser_test.py @@ -0,0 +1,191 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest + +from typing import Generator + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import func # type: ignore + +from . import common +from . 
import dispatch_parser + + +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + +def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: + assert dispatch_parser.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType( + [1, 2, 3], tuner_ctx.type.f32 + ) + assert dispatch_parser.parse_tensor_type("tensor<123xi8>") == common.ShapedType( + [123], tuner_ctx.type.i8 + ) + + +def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: + config = dispatch_parser.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=0, + ) + assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: + config = dispatch_parser.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=1, + ) + assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [ + 1, + 1, + 464, + 320, + 1, + 1, + 16, + ] + + +def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: + config = dispatch_parser.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] + assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ + 16, + 16, + 16, + ] + + +def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], tuner_ctx.type.f16), + common.ShapedType([1280, 1280], tuner_ctx.type.f16), + common.ShapedType([2048, 1280], tuner_ctx.type.f32), + dispatch_parser.DispatchKind.mmt, + ) + + +def test_get_shapes_conv(tuner_ctx: common.TunerContext) -> None: + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, 
lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(32, 256, 11520), + common.ShapedType([1, 3, 34, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 3, 1280, 256], tuner_ctx.type.f16), + common.ShapedType([1, 1, 32, 256], tuner_ctx.type.f32), + dispatch_parser.DispatchKind.conv, + ) + + +def test_get_shapes_contract(tuner_ctx: common.TunerContext) -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.ContractionParser("mk", "nk", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], tuner_ctx.type.f16), + common.ShapedType([1280, 1280], tuner_ctx.type.f16), + common.ShapedType([2048, 1280], tuner_ctx.type.f32), + dispatch_parser.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul(tuner_ctx: common.TunerContext) -> None: + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMatmulParser("bmk", "bkn", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(32, 32, 1024, 1), + common.ShapedType([1, 32, 1024], tuner_ctx.type.f32), + common.ShapedType([1, 1024, 32], tuner_ctx.type.f32), + common.ShapedType([1, 32, 32], tuner_ctx.type.f32), + dispatch_parser.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt(tuner_ctx: common.TunerContext) -> None: + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert 
dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([2, 640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), + dispatch_parser.DispatchKind.batch_mmt, + ) + + +def test_parse_mlir(tuner_ctx: common.TunerContext) -> None: + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } +""" + mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_ctx) + assert mlir_module is not None + assert isinstance(mlir_module, ir.Module) + assert isinstance(mlir_module.body.operations[0], func.FuncOp) diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py new file mode 100644 index 000000000..3aa932dc4 --- /dev/null +++ b/tuner/tuner/libtuner.py @@ -0,0 +1,1418 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Provides fundamental functions for tuning: + - generate_candidates() + - compile_dispatches() + - benchmark_dispatches() + - compile_models() + - benchmark_models() + +Requires a wrapper Python script to import `libtuner`, +use the `TuningClient` API, customize compilation and benchmarking commands, +and implement a complete tuning loop for a specific model. +""" + + +import sys +import shutil +import subprocess +import logging +import argparse +from datetime import datetime +from enum import Enum +from pathlib import Path +import time +import multiprocessing +import queue +from tqdm import tqdm +import re +import hashlib +from dataclasses import dataclass, field +from typing import Type, Optional, Callable, Iterable, Any +import pickle +import random +import json +from abc import ABC, abstractmethod +import iree.runtime as ireert # type: ignore +from . import candidate_gen + + +# Default values for num_candidates and devices, change it as needed +DEFAULT_NUM_CANDIDATES = 2048 +DEFAULT_DEVICE_LIST = ["hip://0"] + +# Default values for max number of workers +DEFAULT_MAX_CPU_WORKERS = ( + multiprocessing.cpu_count() // 2 +) # the actual amount of worker that will be generated = min(max_cpu_workers, len(task_list)) + +# Declare global variables at the module level for multiprocessing +worker_id = None +device_id = None + +# Declare special symbols for libtuner to search and locate +DEVICE_ID_PLACEHOLDER = "!DEVICE_ID!" 
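+# Illustrative sketch of a wrapper script (hedged: `MyTuningClient`, the variable names,
+# and the empty pre-seeding of the tracker list are assumptions, not part of this module):
+# a wrapper subclasses TuningClient and drives the phases defined below in order, e.g.
+#
+#   class MyTuningClient(libtuner.TuningClient):
+#       ...  # implement the get_*_command() and get_*_timeout_s() hooks
+#
+#   args = libtuner.parse_arguments()
+#   path_config = libtuner.PathConfig()
+#   libtuner.setup_logging(args, path_config)
+#   trackers: list[libtuner.CandidateTracker] = []
+#   client = MyTuningClient()
+#   candidates = libtuner.generate_candidates(args, path_config, trackers)
+#   compiled = libtuner.compile_dispatches(args, path_config, candidates, trackers, client)
+#   top = libtuner.benchmark_dispatches(args, path_config, compiled, trackers, client)
+#   models = libtuner.compile_models(args, path_config, top, trackers, client)
+#   libtuner.benchmark_models(args, path_config, models, trackers, client)
+#   libtuner.summarize_top_candidates(path_config, trackers)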
+ + +@dataclass +class CandidateTracker: + candidate_id: int + dispatch_mlir_path: Optional[Path] = None + dispatch_config_path: Optional[Path] = None + configuration: Optional[candidate_gen.Configuration] = None + compilation_successful: Optional[bool] = None + compiled_dispatch_path: Optional[Path] = None + compiled_dispatch_hash: Optional[str] = None + first_benchmark_time: Optional[float] = None + first_benchmark_device_id: Optional[str] = None + spec_path: Optional[Path] = None + compiled_model_path: Optional[Path] = None + compiled_model_hash: Optional[str] = None + model_benchmark_time: Optional[float] = None + model_benchmark_device_id: Optional[str] = None + baseline_benchmark_time: Optional[float] = None + calibrated_benchmark_diff: Optional[float] = None + + +@dataclass() +class PathConfig: + # Preset constants + global_config_prolog_mlir: Path = Path("config_prolog.mlir") + global_config_epilog_mlir: Path = Path("config_epilog.mlir") + model_baseline_vmfb: Path = Path("baseline.vmfb") + + # Dynamic paths + base_dir: Path = field(init=False) + local_config_prolog_mlir: Path = field(init=False) + local_config_epilog_mlir: Path = field(init=False) + template_mlir: Path = field(init=False) + candidates_dir: Path = field(init=False) + candidate_configs_pkl: Path = field(init=False) + compiled_dir: Path = field(init=False) + compile_failed_dir: Path = field(init=False) + specs_dir: Path = field(init=False) + + output_unilog: Path = field(init=False) + result_summary_log: Path = field(init=False) + candidate_trackers_pkl: Path = field(init=False) + + # To be set outside of class + run_log: Optional[Path] = field(init=False, default=None) + + def __post_init__(self): + object.__setattr__(self, "base_dir", self._name_base_dir()) + object.__setattr__( + self, "local_config_prolog_mlir", self.base_dir / "config_prolog.mlir" + ) + object.__setattr__( + self, "local_config_epilog_mlir", self.base_dir / "config_epilog.mlir" + ) + object.__setattr__(self, "template_mlir", self.base_dir / "template.mlir") + object.__setattr__(self, "candidates_dir", self.base_dir / "candidates") + object.__setattr__( + self, "candidate_configs_pkl", self.candidates_dir / "configs.pkl" + ) + object.__setattr__(self, "compiled_dir", self.candidates_dir / "compiled") + object.__setattr__(self, "compile_failed_dir", self.candidates_dir / "failed") + object.__setattr__(self, "specs_dir", self.candidates_dir / "specs") + object.__setattr__(self, "output_unilog", self.base_dir / "output.log") + object.__setattr__( + self, "result_summary_log", self.base_dir / "result_summary.log" + ) + object.__setattr__( + self, "candidate_trackers_pkl", self.base_dir / "candidate_trackers.pkl" + ) + + def _name_base_dir(self) -> Path: + timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M") + base_dir = Path(f"./tuning_{timestamp}") + return base_dir + + def _set_run_log(self, run_log: Path): + object.__setattr__(self, "run_log", run_log) + + def get_candidate_mlir_path(self, candidate_id: int) -> Path: + return self.candidates_dir / f"{candidate_id}.mlir" + + def get_candidate_spec_mlir_path(self, candidate_id: int) -> Path: + return self.candidates_dir / "specs" / f"{candidate_id}_spec.mlir" + + def get_exe_format(self, path: Path) -> str: + return f"./{path.as_posix()}" + + def get_compiled_dispatch_index(self, file_path: Path) -> int: + return int(file_path.stem) + + def get_candidate_spec_filename(self, candidate_id: int) -> str: + return f"{candidate_id}_spec.mlir" + + def get_compiled_model_index(self, file_path: Path) -> 
int: + return int(file_path.stem.split("_")[-1]) + + +class TuningClient(ABC): + @abstractmethod + def get_dispatch_compile_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_dispatch_benchmark_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_model_compile_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_model_benchmark_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: + pass + + @abstractmethod + def get_dispatch_compile_timeout_s(self) -> int: + pass + + @abstractmethod + def get_dispatch_benchmark_timeout_s(self) -> int: + pass + + @abstractmethod + def get_model_compile_timeout_s(self) -> int: + pass + + @abstractmethod + def get_model_benchmark_timeout_s(self) -> int: + pass + + +@dataclass +class RunPack: + command: list[str] + check: bool = True + timeout_seconds: Optional[int] = None + + +@dataclass +class RunResult: + process_res: Optional[subprocess.CompletedProcess] + is_timeout: bool + + +@dataclass +class TaskPack: + run_pack: RunPack + candidate_id: int + command_need_device_id: bool = False + cooling_time: int = 0 + + +@dataclass +class TaskResult: + run_result: RunResult + candidate_id: int + device_id: str + + +@dataclass +class ParsedDisptachBenchmarkResult: + candidate_id: int + benchmark_time_in_seconds: float + candidate_mlir: Path + candidate_spec_mlir: Path + + +@dataclass +class IREEBenchmarkResult: + # Default format follows output of iree-benchmark-module + candidate_id: int + + # A list of dictionaries, each representing a benchmark result + # Each dictionary contains fields like: aggregate_name: string, real_time: float, cpu_time: float, time_unit: str, repetitions: int, etc. 
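+ # e.g. [{"aggregate_name": "mean", "real_time": 123.45, "time_unit": "us"}], taken from the
+ # "benchmarks" array of the benchmark JSON output (see extract_benchmark_from_run_result()).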
+ result_json: list[dict[str, Any]] + + def get_mean_time_us(self) -> Optional[float]: + """Compute the mean time (in microseconds) for all of the benchmarks""" + if not self.result_json: + return None + + mean_benchmark = self.find_mean_benchmark(self.result_json) + + if mean_benchmark: + real_time: float | None = mean_benchmark.get("real_time") + time_unit: str | None = mean_benchmark.get("time_unit") + + if real_time is not None: + assert time_unit is not None + return self.unit_to_microseconds(real_time, time_unit) + + return None + + @staticmethod + def find_mean_benchmark(result_json: list[dict[str, Any]]) -> Optional[dict]: + for benchmark in result_json: + if benchmark.get("aggregate_name") == "mean": + return benchmark + + return None + + @staticmethod + def unit_to_microseconds(real_time: float, time_unit: str) -> float: + unit_conversions = { + "s": 1e6, + "ms": 1e3, + "us": 1, + "ns": 1e-3, + } + + assert time_unit in unit_conversions, f"Unsupported time unit: {time_unit}" + + return real_time * unit_conversions[time_unit] + + +def generate_display_DBR(candidate_id: int, mean_time: float) -> str: + """Generate dispatch_benchmark_result string for displaying""" + return f"{candidate_id}\tMean Time: {mean_time:.1f}" + + +def generate_display_MBR( + candidate_vmfb_path_str: str, + device_id: str, + t1: float, + calibrated_diff: Optional[float] = None, +) -> str: + """Generate model_benchmark_result string for displaying""" + if calibrated_diff: + percentage_change = calibrated_diff * 100 + change_str = f"({percentage_change:+.3f}%)" + res_str = f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g} {change_str}" + else: + res_str = ( + f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g}" + ) + return res_str + + +def extract_driver_names(user_devices: list[str]) -> set[str]: + """Extract driver names from the user devices""" + return {device.split("://")[0] for device in user_devices} + + +def fetch_available_devices(drivers: list[str]) -> list[str]: + """ + Extract all available devices on the user's machine for the provided drivers + Only the user provided drivers will be queried + """ + all_device_ids: list[str] = [] + + for driver_name in drivers: + try: + driver = ireert.get_driver(driver_name) + devices = driver.query_available_devices() + all_device_ids.extend( + f"{driver_name}://{device['path']}" for device in devices + ) + all_device_ids.extend( + f"{driver_name}://{device['device_id'] - 1}" for device in devices + ) + except ValueError as e: + handle_error( + condition=True, + msg=f"Could not initialize driver {driver_name}: {e}", + error_type=ValueError, + exit_program=True, + ) + + return all_device_ids + + +def parse_devices(devices_str: str) -> list[str]: + """ + Parse a comma-separated list of device IDs e.g.: + --devices=hip://0,local-sync://default -> ["hip://0", "local-sync://default"]). + """ + devices = [device.strip() for device in devices_str.split(",")] + for device in devices: + if "://" not in device or not device: + handle_error( + condition=True, + msg=f"Invalid device list: {devices_str}. 
Error: {ValueError()}", + error_type=argparse.ArgumentTypeError, + ) + return devices + + +def validate_devices(user_devices: list[str]) -> None: + """Validates the user provided devices against the devices extracted by the IREE Runtime""" + user_drivers = extract_driver_names(user_devices) + + available_devices = fetch_available_devices(list(user_drivers)) + + for device in user_devices: + handle_error( + condition=(device not in available_devices), + msg=f"Invalid device specified: {device}\nFetched available devices: {available_devices}", + error_type=argparse.ArgumentError, + exit_program=True, + ) + + +class ExecutionPhases(str, Enum): + dont_stop = "" + generate_candidates = "generate-candidates" + compile_dispatches = "compile-dispatches" + benchmark_dispatches = "benchmark-dispatches" + compile_models = "compile-models" + benchmark_models = "benchmark-models" + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Autotune script") + + # Required arguments + required_args = parser.add_argument_group("Required Options") + required_args.add_argument( + "input_file", type=Path, help="Path to the input benchmark file (.mlir)" + ) + + # General options + general_args = parser.add_argument_group("General Options") + general_args.add_argument( + "--verbose", "-v", action="store_true", help="Enable verbose output to stdout" + ) + general_args.add_argument( + "--devices", + type=parse_devices, + default=DEFAULT_DEVICE_LIST, + help="Comma-separated list of device IDs (e.g., --devices=hip://,hip://GPU-UUID).", + ) + general_args.add_argument( + "--max-cpu-workers", + type=int, + default=DEFAULT_MAX_CPU_WORKERS, + help=f"Max number of workers for CPU-bounding tasks (default: {DEFAULT_MAX_CPU_WORKERS}, the number of CPUs in current system)", + ) + general_args.add_argument( + "--stop-after", + choices=[x.value for x in ExecutionPhases], + default=ExecutionPhases.dont_stop, + help="Stop execution after specified phase", + ) + general_args.add_argument( + "--num-model-candidates", + help="Maximum number of stage 2 candidates", + type=int, + default=50, + ) + general_args.add_argument( + "--dry-run", + action="store_true", + help="Do not attempt to run any modules or initialize the IREE runtime", + ) + + # candidate_gen.tune() options + candidate_gen_args = parser.add_argument_group("Candidate Generation Options") + candidate_gen_args.add_argument( + "--num-candidates", + type=int, + default=DEFAULT_NUM_CANDIDATES, + help=f"Number of candidates to be generated by candidate_gen.py (default: {DEFAULT_NUM_CANDIDATES})", + ) + candidate_gen_args.add_argument( + "--num-subgroups", + help="Number of subgroups per workgroup to use. 
(-1 == unconstrained)", + type=int, + default=-1, + ) + candidate_gen_args.add_argument( + "--lhs-dims", help="Map of LHS matmul dims", type=str, default="mk" + ) + candidate_gen_args.add_argument( + "--rhs-dims", help="Map of RHS matmul dims", type=str, default="nk" + ) + candidate_gen_args.add_argument( + "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" + ) + + return parser.parse_args() + + +def setup_logging(args: argparse.Namespace, path_config: PathConfig): + log_file_name = f"autotune_{args.input_file.stem}.log" + run_log_path = path_config.base_dir / log_file_name + path_config._set_run_log(run_log_path) + + # Create file handler for logging to a file + if path_config.run_log is None: + raise + file_handler = logging.FileHandler(path_config.run_log) + file_handler.setLevel(logging.DEBUG) + + # Create stream handler for logging to the console (only warnings and higher) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + + # Create a formatter that dynamically adds [levelname] for ERROR and WARNING + class CustomFormatter(logging.Formatter): + def format(self, record): + if record.levelno == logging.INFO: + return f"{record.message}" + else: + return f"[{record.levelname}] {record.message}" + + file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + console_formatter = CustomFormatter() + + # Set formatters to handlers + file_handler.setFormatter(file_formatter) + console_handler.setFormatter(console_formatter) + + # Configure the root logger + logging.basicConfig( + level=logging.DEBUG, # Set the root logger to the lowest level + handlers=[file_handler, console_handler], + ) + + # If verbose flag is set, add a console handler for INFO level and higher + if args.verbose: + verbose_console_handler = logging.StreamHandler() + verbose_console_handler.setLevel(logging.DEBUG) + verbose_console_handler.setFormatter(file_formatter) + logging.getLogger().addHandler(verbose_console_handler) + + # config logger in candidate_gen.py + tune_logger = logging.getLogger("tune") + tune_logger.setLevel(logging.DEBUG) + + # Log all arguments + logging.debug(f"Input Arguments:") + for arg, value in vars(args).items(): + tune_logger.info(f"{arg}: {value}") + + +def handle_error( + condition: bool, + msg: str, + level: int = logging.ERROR, + error_type: Type[BaseException] = Exception, + exit_program: bool = False, +) -> None: + """If meets the condition, handles errors with logging and optional program exit""" + if not condition: + return + + # Log the message with the specified level + if level == logging.CRITICAL: + logging.critical(msg) + raise error_type(msg) + if level == logging.ERROR: + logging.error(msg) + raise error_type(msg) + elif level == logging.WARNING: + logging.warning(msg) + elif level == logging.INFO: + logging.info(msg) + elif level == logging.DEBUG: + logging.debug(msg) + else: + raise ValueError( + "Invalid logging level specified: choose from logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG" + ) + + if exit_program: + sys.exit(1) + + +def init_worker_context(queue: multiprocessing.Queue) -> None: + """Assign a static index to current process as the worker ordinal, and specify the device indice to be used""" + global worker_id, device_id + + worker_id, device_id = queue.get() + + +def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int, int]]: + """Create queue contains Worker ID and Device ID for worker initialization""" + 
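+ # Each worker pops one (worker_id, device_id) pair from this queue in init_worker_context().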
worker_contexts_queue = multiprocessing.Manager().Queue() + for worker_id, device_id in enumerate(device_ids): + worker_contexts_queue.put((worker_id, device_id)) + + return worker_contexts_queue + + +def run_command(run_pack: RunPack) -> RunResult: + command = run_pack.command + check = run_pack.check + timeout_seconds = run_pack.timeout_seconds + + result = None + is_timeout = False + try: + # Convert the command list to a command string for logging + command_str = " ".join(command) + logging.debug(f"Run: {command_str}") + + # Add timeout to subprocess.run call + result = subprocess.run( + command, + check=check, + capture_output=True, + text=True, + timeout=timeout_seconds, + ) + + if result.stdout: + logging.debug(f"stdout: {result.stdout}") + if result.stderr: + logging.debug(f"stderr: {result.stderr}") + except subprocess.TimeoutExpired as e: + logging.warning( + f"Command '{command_str}' timed out after {timeout_seconds} seconds." + ) + is_timeout = True + except subprocess.CalledProcessError as e: + print(e.output) + logging.error( + f"Command '{command_str}' returned non-zero exit status {e.returncode}." + ) + logging.error(f"Command '{command_str}' failed with error: {e.stderr}") + if check: + raise + except KeyboardInterrupt: + print("Ctrl+C detected, terminating child processes...") + + return RunResult(result, is_timeout) + + +def run_command_wrapper(task_pack: TaskPack) -> TaskResult: + """Help handle extra requirements and record more data for run_command()""" + if task_pack.command_need_device_id: + # Worker searches for the special symbol and substitutes it with the actual device_id + pattern = re.compile(re.escape(DEVICE_ID_PLACEHOLDER)) + task_pack.run_pack.command = [ + pattern.sub(str(device_id), s) for s in task_pack.run_pack.command + ] + + run_result = run_command(task_pack.run_pack) + + task_result = TaskResult( + run_result, task_pack.candidate_id, device_id=str(-1) + ) # Main process + if device_id: + task_result = TaskResult( + run_result, task_pack.candidate_id, device_id + ) # Subprocess + + time.sleep(task_pack.cooling_time) + + return task_result + + +def multiprocess_progress_wrapper( + num_worker: int, + task_list: list, + function: Callable, + initializer: Optional[Callable] = None, + initializer_inputs: Optional[Iterable[Any]] = None, +) -> list[Any]: + """Wrapper of multiprocessing pool and progress bar""" + results = [] + initializer_inputs = initializer_inputs or () + + # Create a multiprocessing pool + with multiprocessing.Pool( + num_worker, initializer, initializer_inputs + ) as worker_pool: + # Use tqdm to create a progress bar + with tqdm(total=len(task_list)) as pbar: + try: + # Use imap_unordered to asynchronously execute the worker function on each task + for result in worker_pool.imap_unordered(function, task_list): + pbar.update(1) # Update progress bar + results.append(result) + except KeyboardInterrupt: + # If Ctrl+C is pressed, terminate all child processes + worker_pool.terminate() + worker_pool.join() + sys.exit(1) # Exit the script + + return results + + +def extract_benchmark_from_run_result( + run_result: RunResult, +) -> Optional[list[dict[str, Any]]]: + """Extract the benchmark from the result JSON""" + if run_result.process_res and run_result.process_res.stdout: + try: + result_json = json.loads(run_result.process_res.stdout) + + return result_json.get("benchmarks", None) + except json.JSONDecodeError as e: + handle_error( + condition=True, + msg=f"Failed to parse JSON from stdout: {e}", + error_type=ValueError, + 
exit_program=True, + ) + + return None + + +def numerical_sort_key(path: Path) -> tuple[int | float, str]: + """ + Define a sort key function that splits the filename into a numeric and a string part. + Order: 0 | 0_a | 0_b | 1 | 1_a | 2 + """ + numeric_part: int | float + # Extract the numeric part at the start of the filename + match = re.match(r"(\d+)", path.stem) + if match: + numeric_part = int(match.group(1)) + # The rest of the filename after the numeric part + remaining_part = path.stem[len(match.group(0)) :] + else: + numeric_part = float("inf") + remaining_part = path.stem + return (numeric_part, remaining_part) + + +def calculate_md5(file_path: Path) -> str: + md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + return md5.hexdigest() + + +def find_collisions( + hash_list: list[tuple[int, str]] +) -> tuple[bool, list[tuple[str, list[int]]]]: + """ + Detect hash value collisions + Take input list of candidate index numbers and hash value strings: ex. [(1, 'abc'), (2, 'def'), (3, 'abc')] + Return collision boolean value and list of unique hash values along with their corresponding indices: ex. [('abc', [1,3]), ('def', [2])] + """ + hash_count: dict[str, list[int]] = {} + + # Count occurrences of each hash_val + for index, hash_val in hash_list: + if hash_val in hash_count: + hash_count[hash_val].append(index) + else: + hash_count[hash_val] = [index] + + # Prepare output for all hash values + hash_values = [(hash_val, indices) for hash_val, indices in hash_count.items()] + + # Determine if there are collisions + collisions_exist = any(len(indices) > 1 for hash_val, indices in hash_count.items()) + + return collisions_exist, hash_values + + +def load_pickle(file_path: Path) -> list[Any]: + handle_error( + condition=(not file_path.exists()), + msg=f"Configuration file not found: {file_path}", + error_type=FileNotFoundError, + ) + with open(file_path, "rb") as file: + loaded_array = pickle.load(file) + return loaded_array + + +def save_pickle(file_path: Path, input_list: list[Any]) -> None: + with open(file_path, "wb") as file: + pickle.dump(input_list, file) + + +def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None: + """Appends new content to the end of the output.log.""" + title_str = "=" * 5 + f" {title} " + "=" * 5 + "\n" if title != "" else "" + with open(filepath, "a") as file: + file.write(title_str) + file.writelines(lines) + file.write("\n") + + +def generate_candidates( + args: argparse.Namespace, + path_config: PathConfig, + candidate_trackers: list[CandidateTracker], +) -> list[int]: + """Generate candidate files for tuning. 
Returns the list of candidate indexes""" + logging.debug("generate_candidates()") + + try: + shutil.copy( + path_config.global_config_epilog_mlir, path_config.local_config_epilog_mlir + ) + shutil.copy( + path_config.global_config_prolog_mlir, path_config.local_config_prolog_mlir + ) + except FileNotFoundError as e: + handle_error( + condition=True, + msg=f"Configuration file not found: {e}", + error_type=FileNotFoundError, + ) + + shutil.copy(args.input_file, path_config.template_mlir) + + mlirs = [] + try: + logging.debug("Captured messages from candidate_gen.py:") + candidate_gen.tune( + input=str(path_config.template_mlir), + output=str(path_config.candidates_dir), + limit=args.num_candidates, + num_subgroups=args.num_subgroups, + lhs_dims=args.lhs_dims, + rhs_dims=args.rhs_dims, + tile_dims=args.tile_dims, + ) + mlirs = sorted( + path_config.candidates_dir.glob("*.mlir"), key=numerical_sort_key + ) + except Exception as e: + logging.error("An error occurred during candidates generation: %s", str(e)) + # Capture and log debug messages from candidate_gen.py + tune_logger = logging.getLogger("tune") + for handler in logging.getLogger().handlers: + if isinstance(handler, logging.FileHandler): + tune_logger.handlers.append(handler) + tune_logger.exception("Error in candidate_gen.py:") + raise + logging.debug("candidate_gen.py ends") + + candidate_configs = load_pickle(path_config.candidate_configs_pkl) + candidate_configs.insert(0, None) # No Configuration class for 0.mlir + + # Create candidate trackers + assert len(mlirs) // 2 + 1 == len(candidate_configs) + candidates = [] + for mlir in mlirs: + if "_config.mlir" not in mlir.name: + candidates.append(int(mlir.stem)) + new_candidate = CandidateTracker( + candidate_id=int(mlir.stem), + dispatch_mlir_path=mlir, + configuration=candidate_configs[int(mlir.stem)], + ) + candidate_trackers.append(new_candidate) + else: + candidate_trackers[ + int(mlir.stem.split("_config")[0]) + ].dispatch_config_path = mlir + + handle_error( + condition=(len(candidates) == 0), msg="Failed to generate any candidates" + ) + + logging.info(f"Generated [{len(candidates)}] candidates") + + return candidates + + +def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, list[int]]: + """If a collision is found, generate a list of new indexes. 
If no collision, `unique_indexes = []`""" + # Check if candidate produces tbe same .vmfb + collision_detected, hash_list = find_collisions(index_hash_list) + unique_indexes: list[int] = [] + if not collision_detected: + return collision_detected, unique_indexes + + # If a collision is detected, select the first one from the collided list + logging.warning("Collisions detected") + for hash_val, indices in hash_list: + if len(indices) != 1: + logging.warning(f"Hash value '{hash_val}' collided at candidate {indices}.") + unique_indexes.append(indices[0]) + + return collision_detected, unique_indexes + + +def compile_dispatches( + args: argparse.Namespace, + path_config: PathConfig, + candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +) -> list[int]: + logging.debug("compile_dispatches()") + + if not candidates: + logging.warning("No candidates to compile.") + return [] + + path_config.compiled_dir.mkdir(parents=True, exist_ok=True) + path_config.compile_failed_dir.mkdir(parents=True, exist_ok=True) + path_config.specs_dir.mkdir(parents=True, exist_ok=True) + + task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_dispatch_compile_command( + candidate_trackers[i] + ), + check=False, + timeout_seconds=tuning_client.get_dispatch_compile_timeout_s(), + ), + candidate_id=i, + ) + for i in candidates + ] + num_worker = min(args.max_cpu_workers, len(task_list)) + multiprocess_progress_wrapper( + num_worker=num_worker, task_list=task_list, function=run_command_wrapper + ) + + # Note: failed/incomplete candidates can also be detected by checking if subprocess.res is None + compiled_files = sorted( + path_config.compiled_dir.glob("*.vmfb"), key=numerical_sort_key + ) + failed_files = sorted( + path_config.compile_failed_dir.glob("*.mlir"), key=numerical_sort_key + ) + + total, good, bad = len(task_list), len(compiled_files), len(failed_files) + compiling_rate = good / total * 100 + logging.info( + f"Total: {total} | Compiled: {good} | Failed: {bad} | Compiling Rate: {compiling_rate:.1f}%" + ) + + # Update candidate tracker + for failed_file in failed_files: + index = path_config.get_compiled_dispatch_index(failed_file) + candidate_trackers[index].compilation_successful = False + compiled_candidates = [] + compiled_candidates_hash_list = [] + for compiled_file in compiled_files: + index = path_config.get_compiled_dispatch_index(compiled_file) + compiled_candidates.append(index) + candidate_trackers[index].compilation_successful = True + candidate_trackers[index].compiled_dispatch_path = compiled_file + compiled_vmfb_path = candidate_trackers[index].compiled_dispatch_path + assert compiled_vmfb_path is not None + hash_val = calculate_md5(compiled_vmfb_path) + candidate_trackers[index].compiled_dispatch_hash = hash_val + compiled_candidates_hash_list.append((index, hash_val)) + + handle_error( + condition=(good == 0), + msg="All candidate dispatches .mlir files failed to compile", + ) + handle_error( + condition=(compiling_rate < 10), + msg=f"Compiling rate [{compiling_rate:.1f}%] < 10%", + level=logging.WARNING, + ) + + collision_detected, unique_indexes = collision_handler( + compiled_candidates_hash_list + ) + if collision_detected: + logging.info(f"Remains [{len(unique_indexes)}] unique candidate indexes") + + return compiled_candidates if not collision_detected else unique_indexes + + +def parse_dispatch_benchmark_results( + path_config: PathConfig, + benchmark_results: list[TaskResult], + candidate_trackers: list[CandidateTracker], +) -> 
tuple[list[ParsedDisptachBenchmarkResult], list[str]]: + benchmark_result_configs = [] + dump_list = [] + incomplete_list = [] + + for benchmark_result in benchmark_results: + candidate_id = benchmark_result.candidate_id + process_res = benchmark_result.run_result.process_res + + if not process_res: + if benchmark_result.run_result.is_timeout: + incomplete_list.append(candidate_id) + continue + + res_json = extract_benchmark_from_run_result(benchmark_result.run_result) + assert res_json is not None + res = IREEBenchmarkResult(candidate_id, res_json) + benchmark_time = res.get_mean_time_us() + assert benchmark_time is not None + candidate_trackers[candidate_id].first_benchmark_time = benchmark_time + candidate_trackers[ + candidate_id + ].spec_path = path_config.specs_dir / path_config.get_candidate_spec_filename( + candidate_id + ) + mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path + spec_path = candidate_trackers[candidate_id].spec_path + assert mlir_path is not None and spec_path is not None + dump_list.append(generate_display_DBR(candidate_id, benchmark_time) + "\n") + + benchmark_result_configs.append( + ( + ParsedDisptachBenchmarkResult( + candidate_id, + benchmark_time, + mlir_path, + spec_path, + ) + ) + ) + + if incomplete_list: + dump_list += [f"Candidate {i} not completed" for i in incomplete_list] + + return benchmark_result_configs, dump_list + + +def generate_sample_task_result( + stdout: str, candidate_id: int, device_id: str +) -> TaskResult: + res = subprocess.CompletedProcess( + args=[""], + stdout=stdout, + returncode=0, + ) + run_result = RunResult(res, False) + return TaskResult( + run_result=run_result, candidate_id=candidate_id, device_id=device_id + ) + + +def generate_dryrun_dispatch_benchmark_results( + compiled_candidates: list[int], +) -> list[TaskResult]: + logging.debug("generate_dryrun_dispatch_benchmark_results()") + + task_results = [ + generate_sample_task_result( + f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms", + i, + str(0), + ) + for i in compiled_candidates + ] + + return task_results + + +def generate_dryrun_model_benchmark_results( + model_candidates: list[int], +) -> tuple[list[TaskResult], list[TaskResult]]: + candidate_results = [] + for i, j in enumerate(model_candidates): + stdout = f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms" + candidate_results.append(generate_sample_task_result(stdout, j, str(i % 3))) + + baseline_results = [ + generate_sample_task_result( + f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms", + 0, + str(i), + ) + for i in range(3) + ] + + return candidate_results, baseline_results + + +def benchmark_dispatches( + args: argparse.Namespace, + path_config: PathConfig, + compiled_candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +): + logging.debug("benchmark_dispatches()") + + if args.dry_run: + benchmark_results = generate_dryrun_dispatch_benchmark_results( + compiled_candidates + ) + else: + # Benchmarking dispatch candidates + task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_dispatch_benchmark_command( + candidate_trackers[i] + ), + check=False, + timeout_seconds=tuning_client.get_dispatch_benchmark_timeout_s(), + ), + candidate_id=i, + command_need_device_id=True, + ) + for i in compiled_candidates + ] + worker_context_queue = create_worker_context_queue(args.devices) + benchmark_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=task_list, + 
function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + + ( + parsed_benchmark_results, + dispatch_benchmark_dump_list, + ) = parse_dispatch_benchmark_results( + path_config, benchmark_results, candidate_trackers + ) + append_to_file( + dispatch_benchmark_dump_list, + filepath=path_config.output_unilog, + title="All Dispatch Benchmark Results", + ) + + benchmarking_rate = (len(parsed_benchmark_results) / len(benchmark_results)) * 100 + logging.info( + f"Total: {len(benchmark_results)} | Benchmarked: {len(parsed_benchmark_results)} | Failed: {len(benchmark_results) - len(parsed_benchmark_results)} | Benchmarking Rate: {benchmarking_rate:.1f}%" + ) + handle_error( + condition=(len(benchmark_results) == 0), + msg="Failed to benchmark all candidate .vmfb files", + ) + + # Select top candidates + best_results = sorted( + parsed_benchmark_results, key=lambda x: float(x.benchmark_time_in_seconds) + )[: args.num_model_candidates] + logging.info(f"Selected top[{len(best_results)}]") + + dump_list = [ + f"{result.benchmark_time_in_seconds}\t{result.candidate_mlir.as_posix()}\t{result.candidate_spec_mlir.as_posix()}\n" + for result in best_results + ] + append_to_file( + dump_list, filepath=path_config.output_unilog, title="Top Candidates Results" + ) + + top_candidates = [result.candidate_id for result in best_results] + return top_candidates + + +def compile_models( + args: argparse.Namespace, + path_config: PathConfig, + candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +) -> list[int]: + logging.debug("compile_models()") + + candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb + + if args.dry_run: + for i in candidates: + candidate_trackers[i].compiled_model_path = Path(f"model_{i}.vmfb") + return candidates + + if not candidates: + logging.warning("No model candidates to compile.") + return [] + + task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_model_compile_command(candidate_trackers[i]), + check=False, + timeout_seconds=tuning_client.get_model_compile_timeout_s(), + ), + candidate_id=i, + ) + for i in candidates + if i != 0 + ] + num_worker = min(args.max_cpu_workers, len(task_list)) + multiprocess_progress_wrapper( + num_worker=num_worker, task_list=task_list, function=run_command_wrapper + ) + + model_candidates_files = list(path_config.base_dir.glob("*.vmfb")) + + model_candidates_indexes = [] + model_candidates_hash_list = [] + + # Update candidate tracker + for model_candidate in model_candidates_files: + assert model_candidate is not None + index = path_config.get_compiled_model_index(model_candidate) + candidate_trackers[index].compiled_model_path = model_candidate + hash_val = calculate_md5(model_candidate) + candidate_trackers[index].compiled_model_hash = hash_val + model_candidates_hash_list.append((index, hash_val)) + model_candidates_indexes.append(index) + + # Check if model candidate produces tbe same .vmfb + collision_detected, unique_model_candidates_indexes = collision_handler( + model_candidates_hash_list + ) + + if collision_detected: + logging.info( + f"Remains [{len(unique_model_candidates_indexes)}] unique candidate indexes" + ) + + return ( + unique_model_candidates_indexes + if collision_detected + else model_candidates_indexes + ) + + +def group_benchmark_results_by_device_id( + benchmark_results: list[TaskResult], +) -> list[list[TaskResult]]: + """ + Groups benchmark results by device ID. + + e.g. 
+ [TaskResult(res1, device_1), TaskResult(res2, device_2), TaskResult(res3, device_1)] + -----> + [ [TaskResult(res1, device_1), TaskResult(res3, device_1)], [TaskResult(res2, device_2)] ] + """ + grouped_results: dict[str, list[TaskResult]] = {} + for result in benchmark_results: + assert result.device_id is not None + if result.device_id not in grouped_results: + grouped_results[result.device_id] = [] + grouped_results[result.device_id].append(result) + + grouped_benchmark_results = [ + grouped_results[device_id] for device_id in sorted(grouped_results) + ] + + return grouped_benchmark_results + + +def parse_model_benchmark_results( + candidate_trackers: list[CandidateTracker], + candidate_results: list[TaskResult], + baseline_results: list[TaskResult], +): + """Update candidate_tracker and format a list of result strings to be saved later.""" + candidate_results = sorted(candidate_results, key=lambda br: br.device_id) + baseline_results = sorted(baseline_results, key=lambda tr: tr.device_id) + + # Assign candidates to the same groups by device_id + grouped_candidate_results = group_benchmark_results_by_device_id(candidate_results) + + # Insert baseline results to the head of each list + grouped_benchmark_results = [ + [x] + y for x, y in zip(baseline_results, grouped_candidate_results) + ] + + dump_list = [] + incomplete_list: list[ + tuple[int, Optional[str]] + ] = [] # format: [(candidate_id, device_id)] + + baseline_time = None + for same_device_results in grouped_benchmark_results: + dump_unsort_list: list[tuple[float, str]] = [] + for task_result in same_device_results: + candidate_id = task_result.candidate_id + device_id = task_result.device_id + process_res = task_result.run_result.process_res + + # Check if benchmarking has completed + if not process_res: + if task_result.run_result.is_timeout: + incomplete_list.append((candidate_id, device_id)) + if candidate_id == 0: + baseline_time = None + continue + + result_json = extract_benchmark_from_run_result(task_result.run_result) + assert result_json is not None + res = IREEBenchmarkResult(candidate_id, result_json) + benchmark_time = res.get_mean_time_us() + assert benchmark_time is not None + + # Record baseline benchmarking result and skip rest processes + if candidate_id == 0: + baseline_time = benchmark_time + baseline_vmfb_path = candidate_trackers[ + candidate_id + ].compiled_model_path + assert baseline_vmfb_path is not None + dump_str = ( + generate_display_MBR( + candidate_vmfb_path_str=baseline_vmfb_path.as_posix(), + device_id=device_id, + t1=benchmark_time, + ) + + "\n\n" + ) + dump_list.append(dump_str) + continue + + # Update candidate_tracker + candidate_trackers[candidate_id].model_benchmark_time = benchmark_time + candidate_trackers[candidate_id].model_benchmark_device_id = device_id + + # Calculate candidate improvement based on baseline. 
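+ # calibrated_benchmark_diff = (candidate_time - baseline_time) / baseline_time,
+ # so negative values mean the candidate is faster than the baseline.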
+ if baseline_time: + candidate_trackers[candidate_id].baseline_benchmark_time = baseline_time + calibrated_benchmark_diff = ( + benchmark_time - baseline_time + ) / baseline_time + candidate_trackers[ + candidate_id + ].calibrated_benchmark_diff = calibrated_benchmark_diff + else: + calibrated_benchmark_diff = None + + # Collect candidate dump str + candidate_vmfb_path = candidate_trackers[candidate_id].compiled_model_path + assert candidate_vmfb_path is not None + dump_str = ( + generate_display_MBR( + candidate_vmfb_path_str=candidate_vmfb_path.as_posix(), + device_id=device_id, + t1=benchmark_time, + calibrated_diff=calibrated_benchmark_diff, + ) + + "\n\n" + ) + + dump_unsort_list.append((benchmark_time, dump_str)) + + # Sort model candidate benchmarking result str in ascending time order. + dump_list = dump_list + [ + dump_str for _, dump_str in sorted(dump_unsort_list, key=lambda x: x[0]) + ] + + # Store incomplete .vmfb file at the end of dump_list. + for index, device in incomplete_list: + file_path = candidate_trackers[index].compiled_model_path + assert file_path is not None + error_msg = f"Benchmarking result of {file_path.as_posix()} on device {device} is incomplete" + handle_error(condition=True, msg=error_msg, level=logging.WARNING) + dump_list.append(error_msg + "\n") + + return dump_list + + +def benchmark_models( + args: argparse.Namespace, + path_config: PathConfig, + model_candidates: list[int], + candidate_trackers: list[CandidateTracker], + tuning_client: TuningClient, +): + """Benchmark U-Net candidate files and log the results.""" + logging.debug("benchmark_models()") + + if args.dry_run: + candidate_results, baseline_results = generate_dryrun_model_benchmark_results( + model_candidates + ) + else: + # Benchmarking model candidates + worker_context_queue = create_worker_context_queue(args.devices) + benchmark_task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_model_benchmark_command( + candidate_trackers[i] + ), + check=False, + timeout_seconds=tuning_client.get_dispatch_benchmark_timeout_s(), + ), + candidate_id=i, + command_need_device_id=True, + cooling_time=10, + ) + for i in model_candidates + ] + candidate_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=benchmark_task_list, + function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + + # Benchmarking baselines on each involved device + candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb + worker_context_queue = create_worker_context_queue(args.devices) + baseline_task_list = [ + TaskPack( + RunPack( + command=tuning_client.get_model_benchmark_command( + candidate_trackers[0] + ), + check=False, + timeout_seconds=tuning_client.get_model_benchmark_timeout_s(), + ), + candidate_id=0, + command_need_device_id=True, + ) + ] * len(group_benchmark_results_by_device_id(candidate_results)) + baseline_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=baseline_task_list, + function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) + + dump_list = parse_model_benchmark_results( + candidate_trackers, candidate_results, baseline_results + ) + + append_to_file( + dump_list, filepath=path_config.output_unilog, title="Model Benchmark Results" + ) + + +def summarize_top_candidates( + path_config: PathConfig, candidate_trackers: list[CandidateTracker] +): + dump_list = [] + top_candidates = [] + for candidate 
in candidate_trackers: + if candidate.candidate_id == 0 or candidate.model_benchmark_time is None: + continue + top_candidates.append( + (candidate.candidate_id, candidate.model_benchmark_time) + ) # collect (id, time) + + top_candidates = sorted( + top_candidates, key=lambda x: x[1] + ) # sort the list in ascending benchmark time order + top_candidate_ids = [item[0] for item in top_candidates] # get list of candidate id + + for candidate_id in top_candidate_ids: + candidate = candidate_trackers[candidate_id] + assert candidate.dispatch_config_path is not None + with open(candidate.dispatch_config_path, "r") as file: + config_file_contents = file.read() + final_str = f"Candidate {candidate.candidate_id}:\nModel benchmark time: {candidate.model_benchmark_time} on device {candidate.model_benchmark_device_id}\nDispatch benchmark time: {candidate.first_benchmark_time} on device {candidate.model_benchmark_device_id}\nSpec file path: {candidate.spec_path}\nSpec contents:{config_file_contents}\n\n" + dump_list.append(final_str) + + with open(path_config.result_summary_log, "w") as file: + file.writelines(dump_list) + + +def sanitize_filename(filename: str) -> str: + # Replace invalid characters by an underscore + sanitized = re.sub(r"[^\w\.-]", "_", filename) + return sanitized diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py new file mode 100644 index 000000000..11af59af4 --- /dev/null +++ b/tuner/tuner/libtuner_test.py @@ -0,0 +1,501 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import pytest +import json +from subprocess import CompletedProcess +from unittest.mock import call, patch, MagicMock +from . 
import libtuner + +""" +Usage: python -m pytest libtuner_test.py +""" + + +def test_group_benchmark_results_by_device_id() -> None: + # Create mock TaskResult objects with device_id attributes + task_result_1: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) + task_result_1.device_id = "device_1" + + task_result_2: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) + task_result_2.device_id = "device_2" + + task_result_3: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) + task_result_3.device_id = "device_1" + + benchmark_results = [task_result_1, task_result_2, task_result_3] + + expected_grouped_results = [ + [task_result_1, task_result_3], # Grouped by device_1 + [task_result_2], # Grouped by device_2 + ] + + grouped_results = libtuner.group_benchmark_results_by_device_id(benchmark_results) + + assert grouped_results == expected_grouped_results + assert grouped_results[0][0].device_id == "device_1" + assert grouped_results[1][0].device_id == "device_2" + + +def test_find_collisions() -> None: + input = [(1, "abc"), (2, "def"), (3, "abc")] + assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])]) + input = [(1, "abc"), (2, "def"), (3, "hig")] + assert libtuner.find_collisions(input) == ( + False, + [("abc", [1]), ("def", [2]), ("hig", [3])], + ) + + +def test_collision_handler() -> None: + input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")] + assert libtuner.collision_handler(input) == (True, [1, 2, 5]) + input = [(1, "abc"), (2, "def"), (3, "hig")] + assert libtuner.collision_handler(input) == (False, []) + + +def test_IREEBenchmarkResult_get() -> None: + # Time is int in us + int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=1, result_json=int_json) + assert res.get_mean_time_us() == float(1) + + # Time is float in us + float_json = [{"aggregate_name": "mean", "real_time": 123.45, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=2, result_json=float_json) + assert res.get_mean_time_us() == 123.45 + + # Time is in seconds + seconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "s"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=3, result_json=seconds_json) + assert res.get_mean_time_us() == 1.0 * 1e6 + + # Time is in miliseconds + miliseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ms"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=4, result_json=miliseconds_json) + assert res.get_mean_time_us() == 1.0 * 1e3 + + # Time is in nanoseconds + nanoseconds_json = [{"aggregate_name": "mean", "real_time": 1.0, "time_unit": "ns"}] + + res = libtuner.IREEBenchmarkResult(candidate_id=5, result_json=nanoseconds_json) + assert res.get_mean_time_us() == 1.0 * 1e-3 + + small_number_json = [ + { + "aggregate_name": "mean", + "real_time": 3.4591828516259519e-02, + "time_unit": "ms", + } + ] + + res = libtuner.IREEBenchmarkResult(candidate_id=6, result_json=small_number_json) + assert res.get_mean_time_us() == 34.591828516259519 + + # Invalid json: missing real_time + invalid_real_time_json = [{"aggregate_name": "mean", "real_time": None}] + + res = libtuner.IREEBenchmarkResult( + candidate_id=7, result_json=invalid_real_time_json + ) + assert res.get_mean_time_us() == None + + # Invalid json: empty dictionary + res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json=[]) + assert res.get_mean_time_us() is None + + # Invalid json: invalid time unit + invalid_time_unit_json = 
[ + {"aggregate_name": "mean", "real_time": 1.0, "time_unit": "invalid_unit"} + ] + + with pytest.raises(AssertionError, match="Unsupported time unit: invalid_unit"): + res = libtuner.IREEBenchmarkResult( + candidate_id=9, result_json=invalid_time_unit_json + ) + res.get_mean_time_us() + + # Invalid json: missing aggregate_name + invalid_aggregate_name_json = [{"real_time": 1.0, "time_unit": "us"}] + + res = libtuner.IREEBenchmarkResult( + candidate_id=10, result_json=invalid_aggregate_name_json + ) + assert res.get_mean_time_us() is None + + +def test_generate_display_BR() -> None: + output = libtuner.generate_display_DBR(1, 3.14) + expected = f"1\tMean Time: 3.1" + assert output == expected, "DispatchBenchmarkResult generates invalid sample string" + + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89) + expected = "Benchmarking: baseline.vmfb on device 1: 568" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, 0.0314) + expected = "Benchmarking: baseline.vmfb on device 1: 568 (+3.140%)" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, -3.14) + expected = "Benchmarking: baseline.vmfb on device 1: 568 (-314.000%)" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + + +def make_mock_task_result() -> libtuner.TaskResult: + process: CompletedProcess = MagicMock(spec=CompletedProcess) + run_result = libtuner.RunResult(process, False) + task_result = libtuner.TaskResult(run_result, 0, "") + return task_result + + +def test_parse_dispatch_benchmark_results() -> None: + base_path = libtuner.Path("/mock/base/dir") + spec_dir = base_path / "specs" + path_config = libtuner.PathConfig() + object.__setattr__(path_config, "specs_dir", spec_dir) + + mock_result_1 = make_mock_task_result() + mock_json_1 = { + "benchmarks": [ + {"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"} + ] + } + assert mock_result_1.run_result.process_res is not None + mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1) + mock_result_1.candidate_id = 1 + mock_result_2 = make_mock_task_result() + mock_json_2 = { + "benchmarks": [ + {"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"} + ] + } + assert mock_result_2.run_result.process_res is not None + mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2) + mock_result_2.candidate_id = 2 + mock_result_3 = make_mock_task_result() + mock_json_3 = { + "benchmarks": [ + { + "aggregate_name": "mean", + "real_time": 3.4591828516259519e-02, + "time_unit": "ms", + } + ] + } + assert mock_result_3.run_result.process_res is not None + mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3) + mock_result_3.candidate_id = 3 + # Incomplete result. 
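+ # mock_result_4 below is built from RunResult(None, True): it carries no
+ # process_res and therefore no benchmark JSON to parse, so
+ # parse_dispatch_benchmark_results is expected to report it as
+ # "Candidate 4 not completed" rather than extract a time for it.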
+ mock_result_4 = libtuner.TaskResult(libtuner.RunResult(None, True), 4, "4") + benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4] + + candidate_trackers = [] + for i in range(4): + tracker = libtuner.CandidateTracker(candidate_id=i) + tracker.dispatch_mlir_path = libtuner.Path(f"/mock/mlir/path/{i}.mlir") + candidate_trackers.append(tracker) + + expected_parsed_results = [ + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=1, + benchmark_time_in_seconds=100.0, + candidate_mlir=libtuner.Path("/mock/mlir/path/1.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/1_spec.mlir"), + ), + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=2, + benchmark_time_in_seconds=200.0, + candidate_mlir=libtuner.Path("/mock/mlir/path/2.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/2_spec.mlir"), + ), + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=3, + benchmark_time_in_seconds=34.591828516259519, + candidate_mlir=libtuner.Path("/mock/mlir/path/3.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/3_spec.mlir"), + ), + ] + expected_dump_list = [ + "1\tMean Time: 100.0\n", + "2\tMean Time: 200.0\n", + "3\tMean Time: 34.6\n", + "Candidate 4 not completed", + ] + + parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results( + path_config, benchmark_results, candidate_trackers + ) + + assert parsed_results == expected_parsed_results + assert dump_list == expected_dump_list + assert candidate_trackers[1].first_benchmark_time == 100.0 + assert candidate_trackers[1].spec_path == libtuner.Path( + "/mock/base/dir/specs/1_spec.mlir" + ) + assert candidate_trackers[2].first_benchmark_time == 200.0 + assert candidate_trackers[2].spec_path == libtuner.Path( + "/mock/base/dir/specs/2_spec.mlir" + ) + assert candidate_trackers[3].first_benchmark_time == 34.591828516259519 + assert candidate_trackers[3].spec_path == libtuner.Path( + "/mock/base/dir/specs/3_spec.mlir" + ) + + +def test_parse_model_benchmark_results() -> None: + # Setup mock data for candidate_trackers + tracker0 = libtuner.CandidateTracker(0) + tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb") + + tracker1 = libtuner.CandidateTracker(1) + tracker1.compiled_model_path = libtuner.Path("/path/to/model_1.vmfb") + + tracker2 = libtuner.CandidateTracker(2) + tracker2.compiled_model_path = libtuner.Path("/path/to/model_2.vmfb") + + tracker3 = libtuner.CandidateTracker(3) + tracker3.compiled_model_path = libtuner.Path("/path/to/model_3.vmfb") + + candidate_trackers = [tracker0, tracker1, tracker2, tracker3] + + # Setup mock data for task results + result1 = make_mock_task_result() + result_json_1 = {"benchmarks": [{"real_time": 1.23}]} + assert result1.run_result.process_res is not None + result1.run_result.process_res.stdout = json.dumps(result_json_1) + result1.candidate_id = 1 + result1.device_id = "device1" + + result2 = make_mock_task_result() + result_json_2 = {"benchmarks": [{"real_time": 4.56}]} + assert result2.run_result.process_res is not None + result2.run_result.process_res.stdout = json.dumps(result_json_2) + result2.candidate_id = 2 + result2.device_id = "device2" + + result3 = make_mock_task_result() + result_json_3 = {"benchmarks": [{"real_time": 0.98}]} + assert result3.run_result.process_res is not None + result3.run_result.process_res.stdout = json.dumps(result_json_3) + result3.candidate_id = 0 + result3.device_id = "device1" + + result4 = make_mock_task_result() + result_json_4 = {"benchmarks": [{"real_time": 
4.13}]} + assert result4.run_result.process_res is not None + result4.run_result.process_res.stdout = json.dumps(result_json_4) + result4.candidate_id = 0 + result4.device_id = "device2" + + # Incomplete baseline on device3 + result5 = libtuner.TaskResult(libtuner.RunResult(None, True), 0, "device3") + + result6 = make_mock_task_result() + result_json_6 = {"benchmarks": [{"real_time": 3.38}]} + assert result6.run_result.process_res is not None + result6.run_result.process_res.stdout = json.dumps(result_json_6) + result6.candidate_id = 3 + result6.device_id = "device3" + + candidate_results = [result1, result2, result6] + baseline_results = [result3, result4, result5] + + # Skip real benchmark extraction, directly use given values from above + def mock_get_mean_time_us(self): + return float(self.result_json[0]["real_time"]) if self.result_json else None + + # Mock IREEBenchmarkResult to return wanted benchmark times + with patch( + f"{libtuner.__name__}.IREEBenchmarkResult.get_mean_time_us", + new=mock_get_mean_time_us, + ): + # Mock handle_error to avoid actual logging during tests + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + dump_list = libtuner.parse_model_benchmark_results( + candidate_trackers, candidate_results, baseline_results + ) + + # Verify interactions with candidate_trackers + assert tracker1.model_benchmark_time == 1.23 + assert tracker1.model_benchmark_device_id == "device1" + assert tracker1.baseline_benchmark_time == 0.98 + assert tracker1.calibrated_benchmark_diff == pytest.approx( + (1.23 - 0.98) / 0.98, rel=1e-6 + ) + + assert tracker2.model_benchmark_time == 4.56 + assert tracker2.model_benchmark_device_id == "device2" + assert tracker2.baseline_benchmark_time == 4.13 + assert tracker2.calibrated_benchmark_diff == pytest.approx( + (4.56 - 4.13) / 4.13, rel=1e-6 + ) + + assert tracker3.model_benchmark_time == 3.38 + assert tracker3.model_benchmark_device_id == "device3" + + assert dump_list == [ + "Benchmarking: /path/to/baseline.vmfb on device device1: 0.98\n" "\n", + "Benchmarking: /path/to/model_1.vmfb on device device1: 1.23 (+25.510%)\n" + "\n", + "Benchmarking: /path/to/baseline.vmfb on device device2: 4.13\n" "\n", + "Benchmarking: /path/to/model_2.vmfb on device device2: 4.56 (+10.412%)\n" + "\n", + "Benchmarking: /path/to/model_3.vmfb on device device3: 3.38\n" "\n", + "Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete\n", + ] + + # Verify handle_error was called correctly + mock_handle_error.assert_called_once_with( + condition=True, + msg="Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete", + level=libtuner.logging.WARNING, + ) + + +def test_extract_driver_names() -> None: + user_devices = ["hip://0", "local-sync://default", "cuda://default"] + expected_output = {"hip", "local-sync", "cuda"} + + assert libtuner.extract_driver_names(user_devices) == expected_output + + +def test_fetch_available_devices_success() -> None: + drivers = ["hip", "local-sync", "cuda"] + mock_devices = { + "hip": [{"path": "ABCD", "device_id": 1}], + "local-sync": [{"path": "default", "device_id": 2}], + "cuda": [{"path": "default", "device_id": 3}], + } + + with patch(f"{libtuner.__name__}.ireert.get_driver") as mock_get_driver: + mock_driver = MagicMock() + + def get_mock_driver(name): + mock_driver.query_available_devices.side_effect = lambda: mock_devices[name] + return mock_driver + + mock_get_driver.side_effect = get_mock_driver + + actual_output = libtuner.fetch_available_devices(drivers) + 
expected_output = [ + "hip://ABCD", + "hip://0", + "local-sync://default", + "local-sync://1", + "cuda://default", + "cuda://2", + ] + + assert actual_output == expected_output + + +def test_fetch_available_devices_failure() -> None: + drivers = ["hip", "local-sync", "cuda"] + mock_devices = { + "hip": [{"path": "ABCD", "device_id": 1}], + "local-sync": ValueError("Failed to initialize"), + "cuda": [{"path": "default", "device_id": 1}], + } + + with patch(f"{libtuner.__name__}.ireert.get_driver") as mock_get_driver: + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + mock_driver = MagicMock() + + def get_mock_driver(name): + if isinstance(mock_devices[name], list): + mock_driver.query_available_devices.side_effect = ( + lambda: mock_devices[name] + ) + else: + mock_driver.query_available_devices.side_effect = lambda: ( + _ for _ in () + ).throw(mock_devices[name]) + return mock_driver + + mock_get_driver.side_effect = get_mock_driver + + actual_output = libtuner.fetch_available_devices(drivers) + expected_output = ["hip://ABCD", "hip://0", "cuda://default", "cuda://0"] + + assert actual_output == expected_output + mock_handle_error.assert_called_once_with( + condition=True, + msg="Could not initialize driver local-sync: Failed to initialize", + error_type=ValueError, + exit_program=True, + ) + + +def test_parse_devices() -> None: + user_devices_str = "hip://0, local-sync://default, cuda://default" + expected_output = ["hip://0", "local-sync://default", "cuda://default"] + + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + actual_output = libtuner.parse_devices(user_devices_str) + assert actual_output == expected_output + + mock_handle_error.assert_not_called() + + +def test_parse_devices_with_invalid_input() -> None: + user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default" + expected_output = [ + "hip://0", + "local-sync://default", + "invalid_device", + "cuda://default", + ] + + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + actual_output = libtuner.parse_devices(user_devices_str) + assert actual_output == expected_output + + mock_handle_error.assert_called_once_with( + condition=True, + msg=f"Invalid device list: {user_devices_str}. 
Error: {ValueError()}", + error_type=argparse.ArgumentTypeError, + ) + + +def test_validate_devices() -> None: + user_devices = ["hip://0", "local-sync://default"] + user_drivers = {"hip", "local-sync"} + + with patch(f"{libtuner.__name__}.extract_driver_names", return_value=user_drivers): + with patch( + f"{libtuner.__name__}.fetch_available_devices", + return_value=["hip://0", "local-sync://default"], + ): + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + libtuner.validate_devices(user_devices) + assert all( + call[1]["condition"] is False + for call in mock_handle_error.call_args_list + ) + + +def test_validate_devices_with_invalid_device() -> None: + user_devices = ["hip://0", "local-sync://default", "cuda://default"] + user_drivers = {"hip", "local-sync", "cuda"} + + with patch(f"{libtuner.__name__}.extract_driver_names", return_value=user_drivers): + with patch( + f"{libtuner.__name__}.fetch_available_devices", + return_value=["hip://0", "local-sync://default"], + ): + with patch(f"{libtuner.__name__}.handle_error") as mock_handle_error: + libtuner.validate_devices(user_devices) + expected_call = call( + condition=True, + msg=f"Invalid device specified: cuda://default\nFetched available devices: ['hip://0', 'local-sync://default']", + error_type=argparse.ArgumentError, + exit_program=True, + ) + assert expected_call in mock_handle_error.call_args_list diff --git a/tuner/tuner/py.typed b/tuner/tuner/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/tuner/version.json b/tuner/version.json new file mode 100644 index 000000000..9519501ae --- /dev/null +++ b/tuner/version.json @@ -0,0 +1,3 @@ +{ + "package-version": "3.1.0.dev" +} diff --git a/turbine-requirements.txt b/turbine-requirements.txt deleted file mode 100644 index 70078ffa4..000000000 --- a/turbine-requirements.txt +++ /dev/null @@ -1 +0,0 @@ --e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine" diff --git a/version_info.json b/version_info.json deleted file mode 100644 index 2d8f19afb..000000000 --- a/version_info.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "package-version": "0.1.dev3" -}
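For readers skimming the tests above, here is a minimal sketch of the microsecond conversion that test_IREEBenchmarkResult_get pins down. This is not the shipped libtuner.IREEBenchmarkResult.get_mean_time_us implementation; the helper name mean_time_us and its structure are assumptions, and only the behavior asserted by the tests (mean aggregate only, us/ms/s/ns scaling, an assertion on unknown units, None for missing data) is taken from the patch.

# Hypothetical helper mirroring the behavior asserted in
# test_IREEBenchmarkResult_get; not the actual libtuner code.
def mean_time_us(result_json: list[dict]) -> float | None:
    # Scale factors from each supported time unit to microseconds.
    scale = {"s": 1e6, "ms": 1e3, "us": 1.0, "ns": 1e-3}
    for entry in result_json:
        if entry.get("aggregate_name") != "mean":
            continue  # only the mean aggregate is reported
        real_time = entry.get("real_time")
        if real_time is None:
            return None  # incomplete benchmark entry
        unit = entry.get("time_unit")
        assert unit in scale, f"Unsupported time unit: {unit}"
        return float(real_time) * scale[unit]
    return None  # no mean aggregate found (empty or malformed JSON)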