diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml
new file mode 100644
index 000000000..3a660dec9
--- /dev/null
+++ b/.github/workflows/build_packages.yml
@@ -0,0 +1,158 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: Build packages
+
+on:
+ workflow_dispatch:
+ schedule:
+ # Runs at 11:00 AM UTC, which is 3:00 AM PST (UTC-8)
+ - cron: '0 11 * * *'
+
+permissions:
+ contents: read
+
+jobs:
+  # Note: metadata generation could happen in a separate trigger/schedule
+  # workflow. For cross-platform builds, it's useful to generate the
+  # metadata once on Linux and pass it to later jobs using artifacts.
+ setup_metadata:
+ if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
+ runs-on: ubuntu-24.04
+ outputs:
+ shark_package_version: ${{ steps.version.outputs.shark_package_version }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ submodules: false
+ - name: Setup Python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: 3.12
+ cache: "pip"
+ - name: Install Python packages
+ run: |
+ pip install packaging
+ pip freeze
+ - name: Generate release candidate versions
+ id: version_rc
+ run: |
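+          # bash's printf '%(fmt)T' expands to the current time, so this yields a
+          # date-stamped suffix such as rc20241119.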
+ version_suffix="$(printf 'rc%(%Y%m%d)T')"
+ echo "version_suffix=${version_suffix}" >> $GITHUB_ENV
+ sharktank_package_version=$(python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} sharktank)
+ shortfin_package_version=$(python3 build_tools/python_deploy/compute_local_version.py --version-suffix=${version_suffix} shortfin)
+ sharkai_package_version=$(python3 build_tools/python_deploy/compute_common_version.py -rc --version-suffix=${version_suffix} --write-json)
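+          # The steps above produce per-package version_local.json files, which
+          # are uploaded as an artifact for the build jobs below.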
+ - name: Upload version_local.json
+ uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+ with:
+ name: version_local
+ path: |
+ sharktank/version_local.json
+ shortfin/version_local.json
+ shark-ai/version_local.json
+
+ build_packages:
+ name: "${{ matrix.package }} :: ${{ matrix.platform }} :: ${{ matrix.python-version }}"
+ runs-on: ${{ matrix.runs-on }}
+ permissions:
+ contents: write
+ needs: [setup_metadata]
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ # Ubuntu packages.
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shark-ai
+ python-version: cp311-cp311 # Ignored (generic wheel), set for workflow naming
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: sharktank
+ python-version: cp311-cp311 # Ignored (generic wheel), set for workflow naming
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shortfin
+ python-version: cp310-cp310
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shortfin
+ python-version: cp311-cp311
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shortfin
+ python-version: cp312-cp312
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shortfin
+ python-version: cp313-cp313
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shortfin
+ python-version: cp313-cp313t
+
+ # TODO(#130): macOS platform
+ # TODO(#130): Windows platform
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ path: "c" # Windows can hit path length limits, so use a short path.
+ submodules: false
+
+ - name: Download version_local.json
+ uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
+ with:
+ name: version_local
+ path: ./c/
+ merge-multiple: true
+
+ - name: Build shark-ai (Linux x86_64)
+ if: "matrix.package == 'shark-ai' && matrix.platform == 'linux-x86_64'"
+ env:
+ OUTPUT_DIR: "${{ github.workspace }}/bindist"
+ run: |
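+          # Clear out wheels left over from a previous run before building.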
+ [ -e ./bindist/* ] && rm ./bindist/*
+ ./c/build_tools/python_deploy/write_requirements.py --version-suffix=${version_suffix}
+ ./c/shark-ai/build_tools/build_linux_package.sh
+
+ - name: Build sharktank (Linux x86_64)
+ if: "matrix.package == 'sharktank' && matrix.platform == 'linux-x86_64'"
+ env:
+ OUTPUT_DIR: "${{ github.workspace }}/bindist"
+ run: |
+ [ -e ./bindist/* ] && rm ./bindist/*
+ ./c/sharktank/build_tools/build_linux_package.sh
+
+ - name: Build shortfin (Linux x86_64, ${{ matrix.python-version }})
+ if: "matrix.package == 'shortfin' && matrix.platform == 'linux-x86_64'"
+ env:
+ OUTPUT_DIR: "${{ github.workspace }}/bindist"
+ OVERRIDE_PYTHON_VERSIONS: "${{ matrix.python-version }}"
+ run: |
+ [ -e ./bindist/* ] && rm ./bindist/*
+ ./c/shortfin/build_tools/build_linux_package.sh
+
+ - name: Upload python wheels
+ uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+ with:
+ if-no-files-found: error
+ name: snapshot-${{ matrix.package }}-${{ matrix.platform }}-${{ matrix.python-version }}
+ path: bindist
+
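+      # Each run updates the rolling "dev-wheels" release in place, replacing
+      # wheels that share the same filename.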
+ - name: Release python wheels
+ uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0
+ with:
+ artifacts: bindist/*.whl
+ tag: "dev-wheels"
+ name: "dev-wheels"
+ body: "Automatic snapshot release of shark-ai python wheels."
+ removeArtifacts: false
+ allowUpdates: true
+ replacesArtifacts: true
+ makeLatest: false
diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml
new file mode 100644
index 000000000..644066094
--- /dev/null
+++ b/.github/workflows/ci-llama-large-tests.yaml
@@ -0,0 +1,93 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: Llama Benchmarking Tests
+
+on:
+ workflow_dispatch:
+ schedule:
+    # Weekdays at 4:00 AM UTC = 8:00 PM PST / 9:00 PM PDT.
+ - cron: "0 4 * * 1-5"
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ test_llama_large:
+ if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
+ name: "Llama Benchmarking Tests"
+ strategy:
+ matrix:
+ version: [3.11]
+ fail-fast: false
+ runs-on: llama-mi300x-1
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ VENV_DIR: ${{ github.workspace }}/.venv
+ steps:
+ - name: Get Current Date
+ id: date
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+
+ - name: Install pip deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+ # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+
+ # Test with nightly releases, not what iree-turbine uses.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
+
+ - name: Run llama tests
+ run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --html=out/llm/llama/benchmark/index.html
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/llama/benchmark
+ destination_dir: ./llm/llama/benchmark
+ keep_files: true
+
+ - name: Upload llama executable files
+ uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+ with:
+ name: llama-files
+ path: ${{ github.workspace }}/${{ steps.date.outputs.date }}
diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml
new file mode 100644
index 000000000..a8c315ec8
--- /dev/null
+++ b/.github/workflows/ci-llama-quick-tests.yaml
@@ -0,0 +1,85 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: Llama Benchmarking 8B Tests
+
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - main
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ test_llama_quick:
+ name: "Llama Benchmarking 8B Tests"
+ strategy:
+ matrix:
+ version: [3.11]
+ fail-fast: false
+ runs-on: llama-mi300x-1
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ VENV_DIR: ${{ github.workspace }}/.venv
+ steps:
+ - name: Get Current Date
+ id: date
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+
+ - name: Install pip deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+ # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+
+ # Test with nightly releases, not what iree-turbine uses.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
+
+ - name: Run llama 8b f16 decomposed test
+ run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-quick-llama-test
+
+ - name: Upload llama executable files
+ uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
+ with:
+ name: llama-files
+ path: ${{ github.workspace }}/${{ steps.date.outputs.date }}
diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml
new file mode 100644
index 000000000..355ffcf8b
--- /dev/null
+++ b/.github/workflows/ci-sdxl.yaml
@@ -0,0 +1,110 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - shortfin - SDXL
+
+on:
+ workflow_dispatch:
+ pull_request:
+ paths:
+ - '.github/workflows/ci-sdxl.yaml'
+ - 'shortfin/**'
+ push:
+ branches:
+ - main
+ paths:
+ - '.github/workflows/ci-sdxl.yaml'
+ - 'shortfin/**'
+
+permissions:
+ contents: read
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+env:
+ IREE_REPO_DIR: ${{ github.workspace }}/iree
+ LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
+
+jobs:
+ build-and-test:
+ name: Build and test
+ runs-on: mi300-sdxl-kernel
+
+ steps:
+ - name: Install dependencies
+ run: |
+ if dpkg -s cmake &>/dev/null; then
+ echo 'cmake is installed'
+ else
+ sudo apt install cmake -y
+ fi
+ if dpkg -s ninja-build &>/dev/null; then
+ echo 'ninja is installed'
+ else
+            sudo apt install ninja-build -y
+ fi
+
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ submodules: false
+
+ - name: Checkout IREE repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ repository: iree-org/iree
+ path: ${{ env.IREE_REPO_DIR }}
+ submodules: false
+ ref: iree-3.0.0rc20241118
+
+      - name: Initialize IREE submodules
+        working-directory: ${{ env.IREE_REPO_DIR }}
+        run: |
+ git submodule update --init --depth 1 -- third_party/benchmark
+ git submodule update --init --depth 1 -- third_party/cpuinfo/
+ git submodule update --init --depth 1 -- third_party/flatcc
+ git submodule update --init --depth 1 -- third_party/googletest
+ git submodule update --init --depth 1 -- third_party/hip-build-deps/
+
+ - name: Setup Python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: "3.12"
+ cache: "pip"
+ - name: Install Python packages
+ # TODO: Switch to `pip install -r requirements.txt -e shortfin/`.
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
+ pip install -r requirements-tests.txt
+ pip install -r requirements-iree-compiler.txt
+ pip freeze
+
+ - name: Build shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
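+          # SHORTFIN_IREE_SOURCE_DIR points the build at the IREE checkout above,
+          # so IREE is built from source as part of the shortfin build.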
+ mkdir build
+ cmake -GNinja \
+ -S. \
+ -Bbuild \
+ -DCMAKE_C_COMPILER=clang-18 \
+ -DCMAKE_CXX_COMPILER=clang++-18 \
+ -DSHORTFIN_BUNDLE_DEPS=ON \
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
+ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+ cmake --build build --target all
+ pip install -v -e build/
+
+ - name: Test shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
+ ctest --timeout 30 --output-on-failure --test-dir build
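+          # Restrict the GPU e2e test to HIP device 0.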
+ HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
diff --git a/.github/workflows/test.yaml b/.github/workflows/ci-sglang-benchmark.yml
similarity index 54%
rename from .github/workflows/test.yaml
rename to .github/workflows/ci-sglang-benchmark.yml
index 6eb519717..504e7e5e3 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/ci-sglang-benchmark.yml
@@ -1,10 +1,16 @@
-name: Integration Tests
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: SGLang Llama Benchmarking Tests
on:
workflow_dispatch:
schedule:
- # Weekdays at 13:00 UTC = 05:00 PST / 06:00 PDT.
- - cron: "5 4 * * 1-5"
+    # Weekdays at 4:00 AM UTC = 8:00 PM PST / 9:00 PM PDT.
+ - cron: "0 4 * * 1-5"
concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
@@ -15,31 +21,35 @@ concurrency:
cancel-in-progress: true
jobs:
- test_llama:
- name: "Integration Tests - llama"
+ sglang_bench_serve:
+ if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
+ name: "SGLang Serving Benchmark Tests"
strategy:
matrix:
version: [3.11]
- os: [ubuntu-latest, windows-latest]
fail-fast: false
- runs-on: ${{matrix.os}}
+ runs-on: llama-mi300x-3
defaults:
run:
shell: bash
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
steps:
+ - name: Get Current Date
+ id: date
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -53,42 +63,27 @@ jobs:
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
- -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine"
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
# Try with the latest nightly releases, not what iree-turbine pins.
# We could also pin to a known working or stable version.
# This should eventually stabilize. Do the best we can for now.
pip install -f https://iree.dev/pip-release-links.html --upgrade \
- iree-compiler \
- iree-runtime \
+ iree-base-compiler==3.0.0rc20241118 \
+ iree-base-runtime==3.0.0rc20241118 \
"numpy<2.0"
- - name: Run llama test
- run: ./build_tools/integration_tests/llama_export_compile_serve.sh
-
- test_punet:
- name: "Integration Tests - punet"
- runs-on: nodai-amdgpu-mi250-x86-64
- env:
- VENV_DIR: ${{ github.workspace }}/.venv
- steps:
- - name: "Checkout Code"
- uses: actions/checkout@v3
-
- - name: "Setup Python venv"
- run: python3 -m venv ${VENV_DIR}
+ - name: Install SGLang
+ run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
- - name: Install pip deps
- run: |
- source ${VENV_DIR}/bin/activate
- python -m pip install --no-compile --upgrade pip
- pip install --no-compile -r pytorch-rocm-requirements.txt
- pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
- -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine"
- pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
+ - name: Launch Shortfin Server
+ run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
- - name: Run punet tests
- run: |
- source ${VENV_DIR}/bin/activate
- pytest -v sharktank/ -m model_punet
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/sglang
+ destination_dir: ./llm/sglang
+ keep_files: true
diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml
new file mode 100644
index 000000000..1c382617d
--- /dev/null
+++ b/.github/workflows/ci-sglang-integration-tests.yml
@@ -0,0 +1,84 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: SGLang Llama Integration Tests
+
+on:
+ workflow_dispatch:
+ schedule:
+    # Run periodically, every 4 hours, with the intent of catching
+    # regressions early and allowing them to be triaged to a small
+    # subset of commits.
+ - cron: '0 */4 * * *'
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ sglang_bench_serve:
+ name: "SGLang Integration Tests"
+ strategy:
+ matrix:
+ version: [3.11]
+ fail-fast: false
+ runs-on: llama-mi300x-3
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ steps:
+ - name: Get Current Date
+ id: date
+        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"
+
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+
+ - name: Install pip deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+ pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
+
+ # Use newest possible releases to be able to track commits that may
+ # cause errors.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade \
+ iree-base-compiler \
+ iree-base-runtime \
+ "numpy<2.0"
+
+ - name: Install SGLang
+ run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
+
+ - name: Install sentence_transformers
+ run: pip install sentence_transformers
+
+ - name: Run Integration Tests
+ run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci-shark-ai.yml
similarity index 58%
rename from .github/workflows/ci.yaml
rename to .github/workflows/ci-shark-ai.yml
index b2d302a8e..bf8007e65 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci-shark-ai.yml
@@ -1,4 +1,10 @@
-name: CI
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - shark-ai
on:
workflow_dispatch:
@@ -16,14 +22,13 @@ concurrency:
cancel-in-progress: true
jobs:
- test:
- name: "Unit Tests and Type Checking"
+ test_shortfin_llm_server:
+ name: "Integration Tests - Shortfin LLM Server"
strategy:
matrix:
version: [3.11]
- os: [ubuntu-latest, windows-latest]
fail-fast: false
- runs-on: ${{matrix.os}}
+ runs-on: nodai-amdgpu-mi250-x86-64
defaults:
run:
shell: bash
@@ -32,15 +37,15 @@ jobs:
steps:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -53,22 +58,18 @@ jobs:
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
- pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
- -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine"
pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
- - name: Run sharktank tests
- if: ${{ !cancelled() }}
- run: |
- pytest -n 4 sharktank/
+          # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
- - name: Run shortfin tests
- if: ${{ !cancelled() }}
- run: |
- pytest -n 4 shortfin/
+ # Try with the latest IREE nightly releases, not what iree-turbine pins.
+ # We could also pin to a known working or stable version.
+ # This should eventually stabilize. Do the best we can for now.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
- # TODO: Enable type checking of sharktank
- - name: MyPy Type Checking Shortfin
- if: ${{ !cancelled() }}
- run: |
- (cd shortfin && mypy)
+ - name: Run LLM Integration Tests
+ run: pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
new file mode 100644
index 000000000..1d3960b43
--- /dev/null
+++ b/.github/workflows/ci-sharktank.yml
@@ -0,0 +1,131 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - sharktank
+
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - main
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ test:
+ name: "Unit Tests and Type Checking"
+ strategy:
+ matrix:
+ version: [3.11]
+ os: [ubuntu-latest, windows-latest]
+ fail-fast: false
+ runs-on: ${{matrix.os}}
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ steps:
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+ - name: Install pip deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+ # Update to the latest iree packages.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler iree-base-runtime --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+ - name: Run sharktank tests
+ if: ${{ !cancelled() }}
+ run: |
+ pytest -n 4 sharktank/
+
+
+ test_with_data:
+ name: "Data-dependent Tests"
+ strategy:
+ matrix:
+ version: [3.11]
+ runs-on: [llama-mi300x-3]
+ fail-fast: false
+ runs-on: ${{matrix.runs-on}}
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ HF_HOME: "/data/huggingface"
+ SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
+ steps:
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+ - name: Install sharktank deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+          # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+ # Try with the latest IREE nightly releases, not what iree-turbine pins.
+ # We could also pin to a known working or stable version.
+ # This should eventually stabilize. Do the best we can for now.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
+
+ - name: Run tests
+ run: |
+ pytest \
+ --with-t5-data \
+ sharktank/tests/models/t5/t5_test.py
diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml
index 5de7d4182..81b920e31 100644
--- a/.github/workflows/ci-tuner.yml
+++ b/.github/workflows/ci-tuner.yml
@@ -1,11 +1,23 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
name: CI - Tuner
on:
workflow_dispatch:
pull_request:
+ paths:
+ - '.github/workflows/ci-tuner.yml'
+ - 'tuner/**'
push:
branches:
- main
+ paths:
+ - '.github/workflows/ci-tuner.yml'
+ - 'tuner/**'
concurrency:
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
@@ -20,10 +32,10 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4.1.7
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: '3.10.12'
@@ -37,8 +49,11 @@ jobs:
pip install -r tuner/requirements-tuner.txt
python -m pip install \
--find-links https://iree.dev/pip-release-links.html \
- --upgrade \
- iree-compiler iree-runtime
+ --upgrade --pre \
+ iree-base-compiler iree-base-runtime
- name: Run tuner tests
run: pytest tuner/
+
+ - name: Run mypy type checker
+ run: mypy tuner/tuner
diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml
new file mode 100644
index 000000000..0164b6cdc
--- /dev/null
+++ b/.github/workflows/ci_eval.yaml
@@ -0,0 +1,143 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - sharktank perplexity
+
+on:
+ workflow_dispatch:
+ schedule:
+ # Weekdays nightly at 07:00 UTC = 23:00 PST / 00:00 PDT.
+ - cron: "0 7 * * 1-5"
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ test_perplexity_iree:
+ if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
+ timeout-minutes: 1000
+ name: "Perplexity-IREE"
+ strategy:
+ matrix:
+ version: [3.11]
+ runs-on: [llama-mi300x-3]
+ fail-fast: false
+ runs-on: ${{matrix.runs-on}}
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
+ steps:
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+ - name: Install sharktank deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+          # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+ # Try with the latest IREE nightly releases, not what iree-turbine pins.
+ # We could also pin to a known working or stable version.
+ # This should eventually stabilize. Do the best we can for now.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
+
+ - name: Run perplexity test with IREE
+ run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/llama/perplexity/iree_perplexity
+ destination_dir: ./llm/llama/perplexity/iree_perplexity
+ keep_files: true
+
+ test_perplexity_torch:
+ if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
+ timeout-minutes: 1000
+ name: "Perplexity-Torch"
+ strategy:
+ matrix:
+ version: [3.11]
+ runs-on: [llama-mi300x-3]
+ fail-fast: false
+ runs-on: ${{matrix.runs-on}}
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
+ steps:
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+
+ - name: Install sharktank deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+          # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+ - name: Run perplexity test with Torch
+ run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/llama/perplexity/torch_perplexity
+ destination_dir: ./llm/llama/perplexity/torch_perplexity
+ keep_files: true
diff --git a/.github/workflows/ci_eval_short.yaml b/.github/workflows/ci_eval_short.yaml
new file mode 100644
index 000000000..4622f5c57
--- /dev/null
+++ b/.github/workflows/ci_eval_short.yaml
@@ -0,0 +1,77 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - sharktank perplexity short
+
+on:
+ workflow_dispatch:
+ pull_request:
+ push:
+ branches:
+ - main
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ test_perplexity_iree:
+ name: "Llama3.1 8B FP16"
+ strategy:
+ matrix:
+ version: [3.11]
+ runs-on: [llama-mi300x-3]
+ fail-fast: false
+ runs-on: ${{matrix.runs-on}}
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
+ steps:
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+ - name: Install sharktank deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+
+          # Install latest iree-turbine.
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+
+ # Try with the latest IREE nightly releases, not what iree-turbine pins.
+ # We could also pin to a known working or stable version.
+ # This should eventually stabilize. Do the best we can for now.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler \
+ iree-base-runtime
+
+ - name: Run perplexity test with vmfb
+ run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --bs=5 --iree-device='hip://6' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json
diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml
index 793058570..afeca11a6 100644
--- a/.github/workflows/ci_linux_x64-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64-libshortfin.yml
@@ -4,125 +4,122 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-name: CI - libshortfin
+name: CI - shortfin
on:
workflow_dispatch:
pull_request:
+ paths:
+ - '.github/workflows/ci_linux_x64-libshortfin.yml'
+ - 'shortfin/**'
push:
branches:
- main
paths:
- '.github/workflows/ci_linux_x64-libshortfin.yml'
- - 'libshortfin/**'
+ - 'shortfin/**'
permissions:
contents: read
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
env:
IREE_REPO_DIR: ${{ github.workspace }}/iree
- LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/
+ LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
jobs:
build-and-test:
name: Build and test
runs-on: ubuntu-24.04
+ strategy:
+ matrix:
+ python-version: ["3.10", "3.11", "3.12"]
steps:
- name: Install dependencies
run: |
sudo apt update
sudo apt install clang lld cmake ninja-build
- sudo apt install libspdlog-dev libxtensor-dev
- name: Checkout repository
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: false
- name: Checkout IREE repo
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
- ref: candidate-20240904.1006
+ ref: iree-3.0.0rc20241118
- name: Initalize IREE submodules
+ working-directory: ${{ env.IREE_REPO_DIR }}
run : |
- cd ${{ env.IREE_REPO_DIR }}
git submodule update --init --depth 1 -- third_party/benchmark
git submodule update --init --depth 1 -- third_party/cpuinfo/
git submodule update --init --depth 1 -- third_party/flatcc
git submodule update --init --depth 1 -- third_party/googletest
git submodule update --init --depth 1 -- third_party/hip-build-deps/
- - name: Build IREE runtime
- run: |
- mkdir ${{ env.IREE_REPO_DIR }}/build
- cd ${{ env.IREE_REPO_DIR }}/build
- cmake -GNinja \
- -DCMAKE_C_COMPILER=clang-18 \
- -DCMAKE_CXX_COMPILER=clang++-18 \
- -DIREE_ENABLE_LLD=ON \
- -DIREE_ERROR_ON_MISSING_SUBMODULES=OFF \
- -DIREE_HAL_DRIVER_DEFAULTS=OFF \
- -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \
- -DIREE_HAL_DRIVER_LOCAL_TASK=ON \
- -DIREE_HAL_DRIVER_HIP=ON \
- -DIREE_BUILD_COMPILER=OFF \
- -DIREE_BUILD_SAMPLES=OFF \
- -DIREE_BUILD_TESTS=OFF \
- ..
- cmake --build . --target all
-
- - name: Setup Python
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
- python-version: "3.12"
+ python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Install Python packages
- # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`.
+ # TODO: Switch to `pip install -r requirements.txt -e shortfin/`.
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
+ pip install -r requirements-tests.txt
+ pip install -r requirements-iree-compiler.txt
pip freeze
- - name: Build libshortfin (full)
+ - name: Build shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- mkdir ${{ env.LIBSHORTFIN_DIR }}/build
- cd ${{ env.LIBSHORTFIN_DIR }}/build
+ mkdir build
cmake -GNinja \
+ -S. \
+ -Bbuild \
-DCMAKE_C_COMPILER=clang-18 \
-DCMAKE_CXX_COMPILER=clang++-18 \
-DCMAKE_LINKER_TYPE=LLD \
-DSHORTFIN_BUNDLE_DEPS=ON \
- -DCMAKE_PREFIX_PATH=${{ env.IREE_REPO_DIR }}/build/lib/cmake/IREE \
- -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
- ..
- cmake --build . --target all
- pip install -v -e .
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
+ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+ cmake --build build --target all
+ pip install -v -e build/
- - name: Test libshortfin (full)
+ - name: Test shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- cd ${{ env.LIBSHORTFIN_DIR }}/build
- ctest --timeout 30 --output-on-failure
- cd ${{ env.LIBSHORTFIN_DIR }}
- pytest -s -v -m "not requires_amd_gpu"
+ ctest --timeout 30 --output-on-failure --test-dir build
+ pytest -s
- - name: Build libshortfin (host-only)
+ - name: Build shortfin (host-only)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
- mkdir ${{ env.LIBSHORTFIN_DIR }}/build-host-only
- cd ${{ env.LIBSHORTFIN_DIR }}/build-host-only
+ mkdir build-host-only
# In this configuration, also build static+dynamic in order to verify
# that path structurally works.
cmake -GNinja \
+ -S. \
+ -Bbuild-host-only \
-DCMAKE_C_COMPILER=clang-18 \
-DCMAKE_CXX_COMPILER=clang++-18 \
-DCMAKE_LINKER_TYPE=LLD \
- -DCMAKE_PREFIX_PATH=${{ env.IREE_REPO_DIR }}/build/lib/cmake/IREE \
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
-DSHORTFIN_HAVE_AMDGPU=OFF \
-DSHORTFIN_BUILD_STATIC=ON \
- -DSHORTFIN_BUILD_DYNAMIC=ON \
- ..
- cmake --build . --target all
+ -DSHORTFIN_BUILD_DYNAMIC=ON
+ cmake --build build-host-only --target all
diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml
index c998fbc77..42de8f0f6 100644
--- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml
@@ -4,29 +4,40 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-name: CI - libshortfin - ASan
+name: CI - shortfin - ASan
on:
workflow_dispatch:
pull_request:
+ paths:
+ - '.github/workflows/ci_linux_x64_asan-libshortfin.yml'
+ - 'shortfin/**'
push:
branches:
- main
paths:
- '.github/workflows/ci_linux_x64_asan-libshortfin.yml'
- - 'libshortfin/**'
+ - 'shortfin/**'
permissions:
contents: read
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
env:
PYENV_ROOT: ${{ github.workspace }}/pyenv
- PYENV_REF: 9ecd803bffaffb949fbdd8c70cb086227f6a3202 # v2.4.10
- PYTHON_VER: 3.12.3
+ PYENV_REF: 96b3fb2fc3bee85650cb22e2cb06c83c24509a6d # v2.4.17
+ PYTHON_VER: 3.12.7
CACHE_ASAN_VER: 2
CACHE_DEPS_VER: 1
IREE_SOURCE_DIR: ${{ github.workspace }}/iree
- LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/
+ LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
jobs:
setup-python-asan:
@@ -40,7 +51,7 @@ jobs:
steps:
- name: Cache Python ASan
id: cache-python-asan
- uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
with:
path: ${{ env.PYENV_ROOT }}
key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }}
@@ -55,7 +66,7 @@ jobs:
- name: Checkout pyenv
if: steps.cache-python-asan.outputs.cache-hit != 'true'
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: pyenv/pyenv
ref: ${{ env.PYENV_REF }}
@@ -63,8 +74,8 @@ jobs:
- name: Install pyenv & Python
if: steps.cache-python-asan.outputs.cache-hit != 'true'
+ working-directory: ${{ env.PYENV_ROOT }}
run: |
- cd ${{ env.PYENV_ROOT }}
src/configure && make -C src
export PATH=${{ env.PYENV_ROOT }}/bin:$PATH && eval "$(pyenv init -)"
CC=clang-18 CXX=clang++-18 LDFLAGS="-lstdc++" PYTHON_CONFIGURE_OPTS="--with-address-sanitizer" pyenv install -v ${{ env.PYTHON_VER }}
@@ -72,13 +83,15 @@ jobs:
build-and-test:
- name: Build and test libshortfin
+ name: Build and test shortfin
needs: [setup-python-asan]
runs-on: ubuntu-24.04
env:
- # TODO(#151): Don't ignore ODR violations
- ASAN_OPTIONS: detect_odr_violation=0
- LSAN_OPTIONS: suppressions=${{ github.workspace }}/libshortfin/build_tools/python_lsan_suppressions.txt
+      # We can't count on being leak-free in general (e.g. pip), so disable the
+      # leak checker by default. Here we suppress any ASan features needed to
+      # pass the build; test configuration is done separately for the test step.
+ ASAN_OPTIONS: detect_leaks=0,detect_odr_violation=0
+ LSAN_OPTIONS: suppressions=${{ github.workspace }}/shortfin/build_tools/python_lsan_suppressions.txt
steps:
- name: Install dependencies
run: |
@@ -86,21 +99,21 @@ jobs:
sudo apt install clang lld cmake ninja-build
- name: Checkout repository
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
submodules: false
- name: Checkout IREE repo
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: iree-org/iree
path: ${{ env.IREE_SOURCE_DIR }}
submodules: false
- ref: candidate-20240904.1006
+ ref: iree-3.0.0rc20241118
- name: Initalize IREE submodules
+ working-directory: ${{ env.IREE_SOURCE_DIR }}
run : |
- cd ${{ env.IREE_SOURCE_DIR }}
git submodule update --init --depth 1 -- third_party/benchmark
git submodule update --init --depth 1 -- third_party/cpuinfo/
git submodule update --init --depth 1 -- third_party/flatcc
@@ -112,7 +125,7 @@ jobs:
uses: actions/cache/restore@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
with:
path: ${{ env.PYENV_ROOT }}
- key: ${{ runner.os }}-python-deps-${{ hashFiles('libshortfin/requirements-tests.txt') }}-v${{ env.CACHE_DEPS_VER }}
+ key: ${{ runner.os }}-python-deps-${{ hashFiles('shortfin/requirements-tests.txt') }}-v${{ env.CACHE_DEPS_VER }}
- name: Restore Python ASan cache
id: cache-python-asan
@@ -128,10 +141,11 @@ jobs:
- name: Install Python dependencies
if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
eval "$(pyenv init -)"
- pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
- pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-iree-compiler.txt
+ pip install -r requirements-tests.txt
+ pip install -r requirements-iree-compiler.txt
pip freeze
- name: Save Python dependencies cache
@@ -142,35 +156,22 @@ jobs:
path: ${{ env.PYENV_ROOT }}
key: ${{ steps.cache-python-deps-restore.outputs.cache-primary-key }}
- - name: Build libshortfin
+ - name: Build shortfin
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
eval "$(pyenv init -)"
- mkdir ${{ env.LIBSHORTFIN_DIR }}/build
- cd ${{ env.LIBSHORTFIN_DIR }}/build
- cmake -GNinja \
- -DCMAKE_BUILD_TYPE=Debug \
- -DCMAKE_C_COMPILER=clang-18 \
- -DCMAKE_CXX_COMPILER=clang++-18 \
- -DCMAKE_LINKER_TYPE=LLD \
- -DSHORTFIN_BUNDLE_DEPS=ON \
- -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_SOURCE_DIR }} \
- -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
- -DSHORTFIN_ENABLE_ASAN=ON \
- ..
- cmake --build . --target all
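+          # The SHORTFIN_* environment variables configure the pip-driven CMake
+          # build: ASan enabled and IREE built from the source checkout.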
+ SHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_SOURCE_DIR }}" \
+ SHORTFIN_ENABLE_ASAN=ON \
+ SHORTFIN_DEV_MODE=ON \
+ SHORTFIN_RUN_CTESTS=ON \
pip install -v -e .
- - name: Run ctest
- if: ${{ !cancelled() }}
- env:
- CTEST_OUTPUT_ON_FAILURE: 1
- run: |
- cd ${{ env.LIBSHORTFIN_DIR }}/build
- ctest --timeout 30 --output-on-failure
-
- name: Run pytest
if: ${{ !cancelled() }}
+ env:
+ # TODO(#151): Don't ignore ODR violations
+ ASAN_OPTIONS: detect_odr_violation=0
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
run: |
eval "$(pyenv init -)"
- cd ${{ env.LIBSHORTFIN_DIR }}
- pytest -m "not requires_amd_gpu" -s
+ pytest -s
diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
new file mode 100644
index 000000000..0e0e1db2a
--- /dev/null
+++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -0,0 +1,103 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - shortfin - Python 3.13 Free-threaded
+
+on:
+ workflow_dispatch:
+ pull_request:
+ paths:
+      - '.github/workflows/ci_linux_x64_nogil-libshortfin.yml'
+ - 'shortfin/**'
+
+ push:
+ branches:
+ - main
+ paths:
+      - '.github/workflows/ci_linux_x64_nogil-libshortfin.yml'
+ - 'shortfin/**'
+
+permissions:
+ contents: read
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+env:
+ IREE_REPO_DIR: ${{ github.workspace }}/iree
+ LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
+
+jobs:
+ build-and-test:
+ name: Build and test
+ runs-on: ubuntu-24.04
+
+ steps:
+ - name: Install dependencies
+ run: |
+ sudo apt update
+ sudo apt install clang lld cmake ninja-build
+
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ submodules: false
+
+ - name: Checkout IREE repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ repository: iree-org/iree
+ path: ${{ env.IREE_REPO_DIR }}
+ submodules: false
+ ref: iree-3.0.0rc20241118
+
+      - name: Initialize IREE submodules
+        working-directory: ${{ env.IREE_REPO_DIR }}
+        run: |
+ git submodule update --init --depth 1 -- third_party/benchmark
+ git submodule update --init --depth 1 -- third_party/cpuinfo/
+ git submodule update --init --depth 1 -- third_party/flatcc
+ git submodule update --init --depth 1 -- third_party/googletest
+ git submodule update --init --depth 1 -- third_party/hip-build-deps/
+
+ - name: Setup Python
+ uses: deadsnakes/action@e640ac8743173a67cca4d7d77cd837e514bf98e8 # v3.2.0
+ with:
+ python-version: "3.13-dev"
+          nogil: true
+ - name: Install Python packages
+ run: |
+ pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests-nogil.txt
+ pip freeze
+
+ - name: Build shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
+ mkdir build
+ cmake -GNinja \
+ -S. \
+ -Bbuild \
+ -DCMAKE_BUILD_TYPE=Debug \
+ -DCMAKE_C_COMPILER=clang-18 \
+ -DCMAKE_CXX_COMPILER=clang++-18 \
+ -DCMAKE_LINKER_TYPE=LLD \
+ -DSHORTFIN_BUNDLE_DEPS=ON \
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
+ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+ cmake --build build --target all
+ pip install -v -e build/
+
+ - name: Run shortfin Python tests (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
+ pytest -s --ignore=tests/examples/fastapi_test.py --ignore=tests/apps/llm/components/cache_test.py --ignore=tests/apps/sd
+ # TODO: Enable further tests and switch to
+ # pytest -s
diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml
new file mode 100644
index 000000000..00873c432
--- /dev/null
+++ b/.github/workflows/ci_windows_x64-libshortfin.yml
@@ -0,0 +1,98 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: CI - shortfin - Windows
+
+on:
+ workflow_dispatch:
+ pull_request:
+ paths:
+ - '.github/workflows/ci_windows_x64-libshortfin.yml'
+ - 'shortfin/**'
+ push:
+ branches:
+ - main
+ paths:
+ - '.github/workflows/ci_windows_x64-libshortfin.yml'
+ - 'shortfin/**'
+
+permissions:
+ contents: read
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+env:
+ IREE_REPO_DIR: ${{ github.workspace }}/iree
+ LIBSHORTFIN_DIR: ${{ github.workspace }}/shortfin/
+
+jobs:
+ build-and-test:
+ name: Build and test
+ runs-on: windows-2022
+
+ steps:
+ - name: Configure MSVC
+ uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1.13.0
+
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ submodules: false
+
+ - name: Checkout IREE repo
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ repository: iree-org/iree
+ path: ${{ env.IREE_REPO_DIR }}
+ submodules: false
+ ref: iree-3.0.0rc20241118
+
+ - name: Initialize IREE submodules
+ working-directory: ${{ env.IREE_REPO_DIR }}
+ run: |
+ git submodule update --init --depth 1 -- third_party/benchmark
+ git submodule update --init --depth 1 -- third_party/cpuinfo/
+ git submodule update --init --depth 1 -- third_party/flatcc
+ git submodule update --init --depth 1 -- third_party/googletest
+ git submodule update --init --depth 1 -- third_party/hip-build-deps/
+
+ - name: Setup Python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: "3.12"
+ cache: "pip"
+ - name: Install Python packages
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
+ pip install -r requirements-tests.txt
+ pip install -r requirements-iree-compiler.txt
+ pip freeze
+
+ - name: Build shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ shell: bash
+ run: |
+ mkdir build
+ cmake -GNinja \
+ -S. \
+ -Bbuild \
+ -DSHORTFIN_BUNDLE_DEPS=ON \
+ -DSHORTFIN_IREE_SOURCE_DIR="${{ env.IREE_REPO_DIR }}" \
+ -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+ cmake --build build --target all
+ pip install -v -e build/
+
+ - name: Test shortfin (full)
+ working-directory: ${{ env.LIBSHORTFIN_DIR }}
+ run: |
+ ctest --timeout 30 --output-on-failure --test-dir build
+ pytest -s
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index 2b11178bf..8ec1e8d55 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -9,6 +9,6 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v3
- - uses: pre-commit/action@v3.0.1
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
diff --git a/.gitignore b/.gitignore
index b766af2d2..bdb0b5387 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,12 +8,17 @@
*.suo
*.user
+# Common IREE source/build paths
+iree/
+iree-build/
+
# macOS files
.DS_Store
# CMake artifacts
build/
build-*/
+_build/
# Python
__pycache__
@@ -25,8 +30,20 @@ wheelhouse
*.whl
*.venv
+# Local-only config options
+version_local.json
+
#Model artifacts
*.pt
*.safetensors
*.gguf
*.vmfb
+genfiles/
+*.zip
+tmp/
+
+# Known inference result blobs
+*output*.png
+
+# Log files.
+*.log
diff --git a/README.md b/README.md
index 7ff8d0126..44f1e6113 100644
--- a/README.md
+++ b/README.md
@@ -1,65 +1,71 @@
-# SHARK Modeling and Serving Libraries
-
-**WARNING: This is an early preview that is in progress. It is not ready for
-general use.**
+# shark-ai: SHARK Modeling and Serving Libraries
+![GitHub License](https://img.shields.io/github/license/nod-ai/shark-ai)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
+## SHARK Users
+
+If you're looking to use SHARK, check out our [User Guide](docs/user_guide.md). For developers, read on.
+
+
+
+## Sub-projects
+
+### [`shortfin/`](./shortfin/)
+
+
+
+[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush)
-## Development Getting Started
+The shortfin sub-project is SHARK's high-performance inference library and
+serving engine.
-Use this as a guide to get started developing the project using pinned,
-pre-release dependencies. You are welcome to deviate as you see fit, but
-these canonical directions mirror what the CI does.
+* API documentation for shortfin is available on
+ [readthedocs](https://shortfin.readthedocs.io/en/latest/).
-### Setup a venv
+### [`sharktank/`](./sharktank/)
-We recommend setting up a virtual environment (venv). The project is configured
-to ignore `.venv` directories, and editors like VSCode pick them up by default.
+[![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml?query=event%3Apush)
-```
-python -m venv --prompt sharktank .venv
-source .venv/bin/activate
-```
+The SHARK Tank sub-project contains a collection of model recipes and
+conversion tools to produce inference-optimized programs.
-### Install PyTorch for Your System
+> [!WARNING]
+> SHARK Tank is still under development. Experienced users may want to try it
+> out, but we currently recommend most users download pre-exported or
+> pre-compiled model files for serving with shortfin.
-If no explicit action is taken, the default PyTorch version will be installed.
-This will give you a current CUDA-based version. Install a different variant
-by doing so explicitly first:
+
-*CPU:*
+* See the [SHARK Tank Programming Guide](./docs/programming_guide.md) for
+ information about core concepts, the development model, dataset management,
+ and more.
+* See [Direct Quantization with SHARK Tank](./docs/quantization.md)
+ for information about quantization support.
-```
-pip install -r pytorch-cpu-requirements.txt
-```
+### [`tuner/`](./tuner/)
-*ROCM:*
+[![CI - Tuner](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml?query=event%3Apush)
-```
-pip install -r pytorch-rocm-requirements.txt
-```
+The Tuner sub-project assists with tuning program performance by searching for
+optimal parameter configurations to use during model compilation.
-### Install Development Packages
+> [!WARNING]
+> SHARK Tuner is still in early development. Interested users may want
+> to try it out, but the tuner is not ready for general use yet. Check out
+> [the readme](tuner/README.md) for more details.
-```
-# Clone and install editable iree-turbine dep in deps/
-pip install -f https://iree.dev/pip-release-links.html --src deps \
- -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine"
+## Support matrix
-# Install editable local projects.
-pip install -r requirements.txt -e sharktank/ shortfin/
-```
+
-### Running Tests
+### Models
-```
-pytest sharktank
-pytest shortfin
-```
+Model name | Model recipes | Serving apps
+---------- | ------------- | ------------
+SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/sd)
+llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/llm)
-### Optional: Pre-commits and developer settings
+## SHARK Developers
-This project is set up to use the `pre-commit` tooling. To install it in
-your local repo, run: `pre-commit install`. After this point, when making
-commits locally, hooks will run. See https://pre-commit.com/
+If you're looking to develop SHARK, check out our [Developer Guide](docs/developer_guide.md).
diff --git a/app_tests/__init__.py b/app_tests/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/benchmark_tests/__init__.py b/app_tests/benchmark_tests/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/benchmark_tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/conftest.py
new file mode 100644
index 000000000..cc354b7eb
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/conftest.py
@@ -0,0 +1,46 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+import pytest
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from integration_tests.llm.utils import compile_model, export_paged_llm_v1
+
+
+@pytest.fixture(scope="module")
+def pre_process_model(request, tmp_path_factory):
+ tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
+
+ model_path = request.param["model_path"]
+ settings = request.param["settings"]
+ batch_sizes = request.param["batch_sizes"]
+
+ mlir_path = tmp_dir / "model.mlir"
+ config_path = tmp_dir / "config.json"
+ vmfb_path = tmp_dir / "model.vmfb"
+
+ export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
+
+ config = {
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": batch_sizes,
+ "decode_batch_sizes": batch_sizes,
+ "transformer_block_count": 32,
+ "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+ }
+ with open(config_path, "w") as file:
+ json.dump(config, file)
+
+ compile_model(mlir_path, vmfb_path, settings)
+
+ return tmp_dir
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
new file mode 100644
index 000000000..0de775795
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
@@ -0,0 +1,122 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import logging
+import multiprocessing
+import os
+from pathlib import Path
+import pytest
+import time
+from unittest.mock import patch
+
+pytest.importorskip("sglang")
+from sglang import bench_serving
+
+from utils import SGLangBenchmarkArgs
+
+from integration_tests.llm.utils import (
+ find_available_port,
+ start_llm_server,
+)
+
+logger = logging.getLogger("__name__")
+
+device_settings = {
+ "device_flags": [
+ "--iree-hal-target-backends=rocm",
+ "--iree-hip-target=gfx942",
+ ],
+ "device": "hip",
+}
+
+# TODO: Download on demand instead of assuming files exist at this path
+MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
+TOKENIZER_DIR = Path("/data/llama3.1/8b/")
+
+
+def log_jsonl_result(file_path):
+ with open(file_path, "r") as file:
+ json_string = file.readline().strip()
+
+ json_data = json.loads(json_string)
+ for key, val in json_data.items():
+ logger.info(f"{key.upper()}: {val}")
+
+
+@pytest.mark.parametrize(
+ "request_rate",
+ [1, 2, 4, 8, 16, 32],
+)
+@pytest.mark.parametrize(
+ "pre_process_model",
+ [
+ (
+ {
+ "model_path": MODEL_PATH,
+ "settings": device_settings,
+ "batch_sizes": [1, 4],
+ }
+ )
+ ],
+ indirect=True,
+)
+def test_sglang_benchmark_server(request_rate, pre_process_model):
+ # TODO: Remove when multi-device is fixed
+ os.environ["ROCR_VISIBLE_DEVICES"] = "1"
+
+ tmp_dir = pre_process_model
+
+ config_path = tmp_dir / "config.json"
+ vmfb_path = tmp_dir / "model.vmfb"
+ tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
+
+ # Start shortfin llm server
+ port = find_available_port()
+ server_process = start_llm_server(
+ port,
+ tokenizer_path,
+ config_path,
+ vmfb_path,
+ MODEL_PATH,
+ device_settings,
+ timeout=30,
+ )
+
+ # Run and collect SGLang Serving Benchmark
+ benchmark_args = SGLangBenchmarkArgs(
+ backend="shortfin",
+ num_prompt=10,
+ base_url=f"http://localhost:{port}",
+ tokenizer=TOKENIZER_DIR,
+ request_rate=request_rate,
+ )
+ output_file = (
+ tmp_dir
+ / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
+ )
+ benchmark_args.output_file = output_file
+
+ logger.info("Running SGLang Benchmark with the following args:")
+ logger.info(benchmark_args)
+ try:
+ start = time.time()
+ with patch.object(bench_serving, "print", side_effect=logger.info):
+ benchmark_process = multiprocessing.Process(
+ target=bench_serving.run_benchmark,
+ args=(benchmark_args.as_namespace(),),
+ )
+ benchmark_process.start()
+ benchmark_process.join()
+
+ logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
+ logger.info("======== RESULTS ========")
+ log_jsonl_result(benchmark_args.output_file)
+ except Exception as e:
+ logger.info(e)
+
+ server_process.terminate()
+ server_process.wait()
diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/utils.py
new file mode 100644
index 000000000..55b01da04
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/utils.py
@@ -0,0 +1,56 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from argparse import Namespace
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class SGLangBenchmarkArgs:
+ base_url: str
+ num_prompt: int
+ request_rate: int
+ tokenizer: str | Path
+
+ seed: int = 1
+ extra_request_body: str | None = None
+ output_file: str | Path | None = None
+ port: int = 8000
+ backend: str = "shortfin"
+
+ def as_namespace(self) -> Namespace:
+ return Namespace(
+ num_prompts=self.num_prompt,
+ base_url=self.base_url,
+ tokenizer=str(self.tokenizer),
+ request_rate=self.request_rate,
+ backend=self.backend,
+ output_file=self.output_file,
+ seed=self.seed,
+ extra_request_body=self.extra_request_body,
+ port=self.port,
+ model=None,
+ dataset_name="sharegpt",
+ random_input_len=None,
+ random_output_len=None,
+ random_range_ratio=0.0,
+ dataset_path="",
+ sharegpt_output_len=None,
+ multi=False,
+ disable_tqdm=False,
+ disable_stream=False,
+ disable_ignore_eos=False,
+ )
+
+ def __repr__(self):
+ return (
+ f"Backend: {self.backend}\n"
+ f"Base URL: {self.base_url}\n"
+ f"Num Prompt: {self.num_prompt}\n"
+ f"Tokenizer: {self.tokenizer}\n"
+ f"Request Rate: {self.request_rate}"
+ )
diff --git a/app_tests/integration_tests/__init__.py b/app_tests/integration_tests/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/integration_tests/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/integration_tests/llm/__init__.py b/app_tests/integration_tests/llm/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/integration_tests/llm/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/integration_tests/llm/sglang/__init__.py b/app_tests/integration_tests/llm/sglang/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/integration_tests/llm/sglang/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/integration_tests/llm/sglang/conftest.py b/app_tests/integration_tests/llm/sglang/conftest.py
new file mode 100644
index 000000000..8543708da
--- /dev/null
+++ b/app_tests/integration_tests/llm/sglang/conftest.py
@@ -0,0 +1,123 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import logging
+import os
+import pytest
+
+from ..utils import (
+ find_available_port,
+ start_llm_server,
+ download_with_hf_datasets,
+ export_paged_llm_v1,
+ compile_model,
+)
+
+pytest.importorskip("sglang")
+import sglang as sgl
+from sglang.lang.chat_template import get_chat_template
+
+pytest.importorskip("sentence_transformers")
+from sentence_transformers import SentenceTransformer
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="module")
+def register_shortfin_backend(available_port):
+ backend = sgl.Shortfin(
+ chat_template=get_chat_template("llama-3-instruct"),
+ base_url=f"http://localhost:{available_port}",
+ )
+ sgl.set_default_backend(backend)
+
+
+@pytest.fixture(scope="module")
+def pre_process_model(request, tmp_path_factory):
+ device_settings = request.param["device_settings"]
+ tmp_dir = tmp_path_factory.mktemp("sglang_integration_tests")
+
+ # Download model
+ model_params_path = tmp_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
+ download_with_hf_datasets(tmp_dir, "llama3_8B_fp16")
+
+ # Export to mlir
+ mlir_path = tmp_dir / "model.mlir"
+ config_path = tmp_dir / "config.json"
+ batch_sizes = [1, 4]
+ export_paged_llm_v1(
+ mlir_path,
+ config_path,
+ model_params_path,
+ batch_sizes,
+ )
+
+ # Compile Model
+ vmfb_path = tmp_dir / "model.vmfb"
+ compile_model(
+ mlir_path,
+ vmfb_path,
+ device_settings,
+ )
+
+ config = {
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": [1, 4],
+ "decode_batch_sizes": [1, 4],
+ "transformer_block_count": 32,
+ "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+ }
+ config_path = tmp_dir / "config.json"
+ with open(config_path, "w") as f:
+ json.dump(config, f)
+
+ return tmp_dir
+
+
+@pytest.fixture(scope="module")
+def available_port():
+ return find_available_port()
+
+
+@pytest.fixture(scope="module")
+def start_server(request, pre_process_model, available_port):
+ os.environ["ROCR_VISIBLE_DEVICES"] = "1"
+ device_settings = request.param["device_settings"]
+
+ export_dir = pre_process_model
+
+ tokenizer_path = export_dir / "tokenizer.json"
+ model_params_path = export_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
+ vmfb_path = export_dir / "model.vmfb"
+ config_path = export_dir / "config.json"
+
+ logger.info("Starting server...")
+ server_process = start_llm_server(
+ available_port,
+ tokenizer_path,
+ config_path,
+ vmfb_path,
+ model_params_path,
+ device_settings,
+ timeout=30,
+ )
+ logger.info("Server started")
+
+ yield server_process
+
+ server_process.terminate()
+ server_process.wait()
+
+
+@pytest.fixture(scope="module")
+def load_comparison_model():
+ model = SentenceTransformer("all-MiniLM-L6-v2")
+ return model
diff --git a/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py b/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py
new file mode 100644
index 000000000..efab14ea7
--- /dev/null
+++ b/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py
@@ -0,0 +1,309 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import re
+import pytest
+
+from ..utils import (
+ AccuracyValidationException,
+)
+
+pytest.importorskip("sglang")
+import sglang as sgl
+from sglang.lang.chat_template import get_chat_template
+
+pytest.importorskip("sentence_transformers")
+from sentence_transformers import SentenceTransformer, util
+
+logger = logging.getLogger(__name__)
+
+DEVICE_SETTINGS = {
+ "device_flags": [
+ "--iree-hal-target-backends=rocm",
+ "--iree-hip-target=gfx942",
+ ],
+ "device": "hip",
+}
+
+ACCEPTED_THRESHOLD = 0.8
+
+
+def compute_similarity(model: SentenceTransformer, sentence_1: str, sentence_2: str):
+ embeddings = model.encode([sentence_1, sentence_2])
+ return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
+
+
+@sgl.function
+def multi_turn_question(s, question_1, question_2):
+ s += sgl.user(question_1)
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=50, temperature=1.0))
+ s += sgl.user(question_2)
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=50, temperature=1.0))
+
+
+@sgl.function
+def tip_suggestion(s):
+ s += (
+ "Here are two tips for staying healthy: "
+ "1. Balanced Diet. 2. Regular Exercise.\n\n"
+ )
+
+ forks = s.fork(2)
+ for i, f in enumerate(forks):
+ f += f"Now, expand tip {i+1} into a paragraph:\n"
+ f += sgl.gen(f"detailed_tip", max_tokens=50, temperature=1.0)
+
+ s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
+ s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
+ s += "In summary" + sgl.gen("summary")
+
+
+@pytest.mark.parametrize(
+ "pre_process_model,start_server",
+ [
+ (
+ {"device_settings": DEVICE_SETTINGS},
+ {"device_settings": DEVICE_SETTINGS},
+ )
+ ],
+ indirect=True,
+)
+def test_multi_turn_qa(load_comparison_model, start_server, register_shortfin_backend):
+ model = load_comparison_model
+
+ question_1 = "Name the capital city of the USA."
+ question_2 = "The Smithsonian is in this location."
+
+ answer_1 = "The capital city of the United States of America is Washington, D.C. (short for District of Columbia).assistant\n\nWould you like to know more about Washington, D.C. or is there something else I can help you with?"
+ answer_2 = "The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. It was founded in 1846 and is named after British scientist James Smithson, who left a bequest to"
+
+ logger.info("Testing multi-turn Q&A run...")
+ state = multi_turn_question.run(
+ question_1=question_1,
+ question_2=question_2,
+ )
+ messages = state.messages()
+ logger.info("Received messages from multi-turn call.")
+
+ assert messages[0] == {
+ "role": "user",
+ "content": question_1,
+ }
+ assert messages[1]["role"] == "assistant"
+
+ logger.info("Computing similarity between first question and first answer...")
+ first_q_answer = messages[1]["content"]
+ score = compute_similarity(model, answer_1, first_q_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {answer_1} and {first_q_answer}:\n SCORE: {score}"
+ )
+ logger.info("Similarity passed")
+
+ assert messages[2] == {
+ "role": "user",
+ "content": question_2,
+ }
+ assert messages[3]["role"] == "assistant"
+
+ logger.info("Testing similarity between second question and second answer...")
+ second_q_answer = messages[3]["content"]
+ score = compute_similarity(model, answer_2, second_q_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {answer_2} and {second_q_answer}:\n SCORE: {score}"
+ )
+ logger.info("Similarity passed.")
+
+
+@pytest.mark.parametrize(
+ "pre_process_model,start_server",
+ [
+ (
+ {"device_settings": DEVICE_SETTINGS},
+ {"device_settings": DEVICE_SETTINGS},
+ )
+ ],
+ indirect=True,
+)
+def test_stream_multi_turn_qa(
+ load_comparison_model, start_server, register_shortfin_backend
+):
+ def clean_message(message: str):
+ """Remove chat tags from message before comparison.
+
+ Args:
+ message (str): Message to clean.
+
+ Returns:
+ str: Message without tags (i.e. <|start_header_id|>)
+ """
+ pattern = r"<\|.*?\|>"
+ return re.sub(pattern, "", message)
+
+ model = load_comparison_model
+ question_1 = "Name the capital city of the USA."
+ question_2 = "The Smithsonian is in this location."
+ expected_answer_1 = "The capital city of the United States of America is Washington, D.C. (short for District of Columbia).assistant\n\nWould you like to know more about Washington, D.C. or is there something else I can help you with?"
+ expected_answer_2 = "The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. It was founded in 1846 and is named after British scientist James Smithson, who left a bequest to"
+
+ logger.info("Testing multi-turn Q&A run w/ stream...")
+ state = multi_turn_question.run(
+ question_1=question_1,
+ question_2=question_2,
+ stream=True,
+ )
+ messages = ""
+ for chunk in state.text_iter():
+ messages += chunk
+ logger.info("Received messages from multi-turn call.")
+
+ logger.info("Computing similarity between expectation and result")
+ expected_result = f"user: {question_1}\nassistant: {expected_answer_1}\nuser: {question_2}\nassistant: {expected_answer_2}"
+ cleaned_messages = clean_message(messages)
+ score = compute_similarity(model, cleaned_messages, expected_result)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {expected_result} and {messages}:\n SCORE: {score}"
+ )
+ logger.info("Similarity passed.")
+
+
+@pytest.mark.parametrize(
+ "pre_process_model,start_server",
+ [
+ (
+ {"device_settings": DEVICE_SETTINGS},
+ {"device_settings": DEVICE_SETTINGS},
+ )
+ ],
+ indirect=True,
+)
+def test_batch_multi_turn_qa(
+ load_comparison_model, start_server, register_shortfin_backend
+):
+ model = load_comparison_model
+
+ question_1_1 = "Name the capital city of the USA."
+ question_1_2 = "The Smithsonian is in this location."
+ expected_answer_1_1 = "The capital city of the United States of America is Washington, D.C. (short for District of Columbia).assistant\n\nWould you like to know more about Washington, D.C. or is there something else I can help you with?"
+ expected_answer_1_2 = "The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes. It was founded in 1846 and is named after British scientist James Smithson, who left a bequest to"
+
+ question_2_1 = "Name the largest city in the USA."
+ question_2_2 = "The Empire State Building is in this location."
+ expected_answer_2_1 = "The largest city in the USA is New York City, with a population of over 8.4 million people, according to the United States Census Bureau (2020 estimates).assistant\n\nHowever, I should note that the largest city in the"
+ expected_answer_2_2 = "That's correct, the iconic Empire State Building is located in Midtown Manhattan, New York City. It's one of the most recognizable landmarks in the world and a symbol of the city's grandeur and history.assistant\n\nAnd, by"
+
+ logger.info("Testing batch multi-turn Q&A run...")
+ states = multi_turn_question.run_batch(
+ [
+ {
+ "question_1": question_1_1,
+ "question_2": question_1_2,
+ },
+ {
+ "question_1": question_2_1,
+ "question_2": question_2_2,
+ },
+ ]
+ )
+
+ first_qa = states[0]
+ second_qa = states[1]
+
+ first_qa_messages = first_qa.messages()
+ second_qa_messages = second_qa.messages()
+
+ logger.info("Testing first batch of messages...")
+ assert first_qa_messages[0] == {
+ "role": "user",
+ "content": question_1_1,
+ }
+
+ assert first_qa_messages[1]["role"] == "assistant"
+ first_answer = first_qa_messages[1]["content"]
+ expected_answer = expected_answer_1_1
+ score = compute_similarity(model, expected_answer, first_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {expected_answer} and {first_answer}:\n SCORE: {score}"
+ )
+
+ assert first_qa_messages[2] == {
+ "role": "user",
+ "content": question_1_2,
+ }
+ first_qa_messages[3]["role"] = "assistant"
+ second_answer = first_qa_messages[3]["content"]
+ expected_answer = expected_answer_1_2
+ score = compute_similarity(model, expected_answer, second_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {expected_answer} and {second_answer}:\n SCORE: {score}"
+ )
+ logger.info("First batch passed.")
+
+ logger.info("Testing second batch of messages...")
+ assert second_qa_messages[0] == {
+ "role": "user",
+ "content": question_2_1,
+ }
+
+ assert second_qa_messages[1]["role"] == "assistant"
+ first_answer = second_qa_messages[1]["content"]
+ expected_answer = expected_answer_2_1
+ score = compute_similarity(model, expected_answer, first_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {expected_answer} and {first_answer}:\n SCORE: {score}"
+ )
+
+ assert second_qa_messages[2] == {
+ "role": "user",
+ "content": question_2_2,
+ }
+ second_qa_messages[3]["role"] = "assistant"
+ second_answer = second_qa_messages[3]["content"]
+ expected_answer = expected_answer_2_2
+ score = compute_similarity(model, expected_answer, second_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {expected_answer} and {second_answer}:\n SCORE: {score}"
+ )
+ logger.info("Second batch passed.")
+
+
+@pytest.mark.parametrize(
+ "pre_process_model,start_server",
+ [
+ (
+ {"device_settings": DEVICE_SETTINGS},
+ {"device_settings": DEVICE_SETTINGS},
+ )
+ ],
+ indirect=True,
+)
+def test_fork(load_comparison_model, start_server, register_shortfin_backend):
+ model = load_comparison_model
+
+ logger.info("Testing fork...")
+ state = tip_suggestion.run()
+ result = state.text()
+ logger.info("Fork response received.")
+
+ logger.info("Computing similarity...")
+ expected_answer = """Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise.
+ Tip 1:A balanced diet is essential for maintaining good health. It involves consuming a variety of foods from different food groups, including fruits, vegetables, whole grains, lean proteins, and healthy fats. A balanced diet provides the body with the necessary nutrients, vitamins, and
+ Tip 2:Regular exercise is essential for maintaining a healthy body. It helps to improve cardiovascular health, increase strength and flexibility, and boost the immune system. Regular physical activity can also reduce the risk of chronic diseases such as heart disease, diabetes, and certain types of cancer
+ In summary, a balanced diet and regular exercise are two of the most important tips for staying healthy. By following these tips, you can maintain a healthy body and reduce the risk of chronic diseases.
+ """
+ score = compute_similarity(model, result, expected_answer)
+ if not score > ACCEPTED_THRESHOLD:
+ raise AccuracyValidationException(
+ f"Accuracy error between {expected_answer} and {result}:\n SCORE: {score}"
+ )
+ logger.info("Similarity passed.")
diff --git a/app_tests/integration_tests/llm/shortfin/__init__.py b/app_tests/integration_tests/llm/shortfin/__init__.py
new file mode 100644
index 000000000..a85ba359d
--- /dev/null
+++ b/app_tests/integration_tests/llm/shortfin/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py
new file mode 100644
index 000000000..0d40119c7
--- /dev/null
+++ b/app_tests/integration_tests/llm/shortfin/conftest.py
@@ -0,0 +1,140 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import logging
+import os
+from pathlib import Path
+import pytest
+import shutil
+
+pytest.importorskip("transformers")
+from ..utils import (
+ download_huggingface_model,
+ download_tokenizer,
+ export_paged_llm_v1,
+ compile_model,
+ find_available_port,
+ start_llm_server,
+ start_log_group,
+ end_log_group,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="module")
+def model_test_dir(request, tmp_path_factory):
+ """Prepare model artifacts for starting the LLM server.
+
+ Args:
+ request (FixtureRequest): The following params are accepted:
+ - repo_id (str): The Hugging Face repo ID.
+ - model_file (str): The model file to download.
+ - tokenizer_id (str): The tokenizer ID to download.
+ - settings (dict): The settings for sharktank export.
+ - batch_sizes (list): The batch sizes to use for the model.
+ tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
+
+ Yields:
+ Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
+ """
+ logger.info(
+ "Preparing model artifacts..." + start_log_group("Preparing model artifacts")
+ )
+
+ repo_id = request.param["repo_id"]
+ model_file = request.param["model_file"]
+ tokenizer_id = request.param["tokenizer_id"]
+ settings = request.param["settings"]
+ batch_sizes = request.param["batch_sizes"]
+
+ tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
+ hf_home = os.environ.get("HF_HOME", None)
+ hf_home = Path(hf_home) if hf_home is not None else tmp_dir
+ try:
+ # Download model if it doesn't exist
+ model_path = hf_home / model_file
+ download_huggingface_model(hf_home, repo_id, model_file)
+
+ # Set up tokenizer if it doesn't exist
+ download_tokenizer(hf_home, tokenizer_id)
+
+ # Export model
+ mlir_path = tmp_dir / "model.mlir"
+ config_path = tmp_dir / "config.json"
+ export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
+
+ # Compile model
+ vmfb_path = tmp_dir / "model.vmfb"
+ compile_model(mlir_path, vmfb_path, settings)
+
+ # Write config
+ edited_config_path = tmp_dir / "edited_config.json"
+ config = {
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 2048,
+ "attn_head_count": 32,
+ "attn_head_dim": 100,
+ "prefill_batch_sizes": batch_sizes,
+ "decode_batch_sizes": batch_sizes,
+ "transformer_block_count": 26,
+ "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+ }
+ logger.info(f"Saving edited config to: {edited_config_path}\n")
+ logger.info(f"Config: {json.dumps(config, indent=2)}")
+ with open(edited_config_path, "w") as f:
+ json.dump(config, f)
+ logger.info("Model artifacts setup successfully" + end_log_group())
+ yield hf_home, tmp_dir
+ finally:
+ shutil.rmtree(tmp_dir)
+
+
+@pytest.fixture(scope="module")
+def available_port():
+ return find_available_port()
+
+
+@pytest.fixture(scope="module")
+def llm_server(request, model_test_dir, available_port):
+ """Start the LLM server.
+
+ Args:
+ request (FixtureRequest): The following params are accepted:
+ - model_file (str): The model file to download.
+ - settings (dict): The settings for starting the server.
+ model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
+ available_port (int): The available port to start the server on.
+
+ Yields:
+ subprocess.Popen: The server process that was started.
+ """
+ logger.info("Starting LLM server..." + start_log_group("Starting LLM server"))
+ hf_home, tmp_dir = model_test_dir
+ model_file = request.param["model_file"]
+ settings = request.param["settings"]
+
+ tokenizer_path = hf_home / "tokenizer.json"
+ config_path = tmp_dir / "edited_config.json"
+ vmfb_path = tmp_dir / "model.vmfb"
+ parameters_path = hf_home / model_file
+
+ # Start llm server
+ server_process = start_llm_server(
+ available_port,
+ tokenizer_path,
+ config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ )
+ logger.info("LLM server started!" + end_log_group())
+ yield server_process
+ # Teardown: kill the server
+ server_process.terminate()
+ server_process.wait()
diff --git a/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
new file mode 100644
index 000000000..c4da9e4eb
--- /dev/null
+++ b/app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import os
+import pytest
+import requests
+import uuid
+
+from ..utils import AccuracyValidationException, start_log_group, end_log_group
+
+logger = logging.getLogger(__name__)
+
+CPU_SETTINGS = {
+ "device_flags": [
+ "-iree-hal-target-backends=llvm-cpu",
+ "--iree-llvmcpu-target-cpu=host",
+ ],
+ "device": "local-task",
+}
+IREE_HIP_TARGET = os.environ.get("IREE_HIP_TARGET", "gfx1100")
+gpu_settings = {
+ "device_flags": [
+ "-iree-hal-target-backends=rocm",
+ f"--iree-hip-target={IREE_HIP_TARGET}",
+ ],
+ "device": "hip",
+}
+
+
+def do_generate(prompt, port):
+ logger.info("Generating request...")
+ headers = {"Content-Type": "application/json"}
+ # Create a GenerateReqInput-like structure
+ data = {
+ "text": prompt,
+ "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
+ "rid": uuid.uuid4().hex,
+ "return_logprob": False,
+ "logprob_start_len": -1,
+ "top_logprobs_num": 0,
+ "return_text_in_logprobs": False,
+ "stream": False,
+ }
+ logger.info("Prompt text:")
+ logger.info(data["text"])
+ BASE_URL = f"http://localhost:{port}"
+ response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
+ logger.info(f"Generate endpoint status code: {response.status_code}")
+ if response.status_code == 200:
+ logger.info("Generated text:")
+ data = response.text
+ assert data.startswith("data: ")
+ data = data[6:]
+ assert data.endswith("\n\n")
+ data = data[:-2]
+ return data
+ else:
+ response.raise_for_status()
+
+
+@pytest.mark.parametrize(
+ "model_test_dir,llm_server",
+ [
+ (
+ {
+ "repo_id": "SlyEcho/open_llama_3b_v2_gguf",
+ "model_file": "open-llama-3b-v2-f16.gguf",
+ "tokenizer_id": "openlm-research/open_llama_3b_v2",
+ "settings": CPU_SETTINGS,
+ "batch_sizes": [1, 4],
+ },
+ {"model_file": "open-llama-3b-v2-f16.gguf", "settings": CPU_SETTINGS},
+ )
+ ],
+ indirect=True,
+)
+def test_llm_server(llm_server, available_port):
+ # Here you would typically make requests to your server
+ # and assert on the responses
+ assert llm_server.poll() is None
+ PROMPT = "1 2 3 4 5 "
+ expected_output_prefix = "6 7 8"
+ logger.info(
+ "Sending HTTP Generation Request"
+ + start_log_group("Sending HTTP Generation Request")
+ )
+ output = do_generate(PROMPT, available_port)
+ # log to GITHUB_STEP_SUMMARY if we are in a GitHub Action
+ if "GITHUB_ACTION" in os.environ:
+ with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
+ # log prompt
+ f.write("LLM results:\n")
+ f.write(f"- llm_prompt:`{PROMPT}`\n")
+ f.write(f"- llm_output:`{output}`\n")
+ logger.info(output)
+ if not output.startswith(expected_output_prefix):
+ raise AccuracyValidationException(
+ f"Expected '{output}' to start with '{expected_output_prefix}'"
+ )
+ logger.info("HTTP Generation Request Successful" + end_log_group())
diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py
new file mode 100644
index 000000000..05712039e
--- /dev/null
+++ b/app_tests/integration_tests/llm/utils.py
@@ -0,0 +1,220 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import multiprocessing
+import os
+from pathlib import Path
+import subprocess
+import sys
+import time
+
+import requests
+from transformers import AutoTokenizer
+
+logger = logging.getLogger("__name__")
+
+
+class AccuracyValidationException(RuntimeError):
+ pass
+
+
+def download_huggingface_model(local_dir, repo_id, model_file):
+ model_path = local_dir / model_file
+ logger.info(f"Preparing model_path: {model_path}..")
+ if not os.path.exists(model_path):
+ logger.info(f"Downloading model {repo_id} {model_file} from Hugging Face...")
+ subprocess.run(
+ f"huggingface-cli download --local-dir {local_dir} {repo_id} {model_file}",
+ shell=True,
+ check=True,
+ )
+ logger.info(f"Model downloaded to {model_path}")
+ else:
+ logger.info("Using cached model")
+
+
+def download_with_hf_datasets(local_dir: Path | str, model_name: str):
+ """Download a model using `sharktank.utils.hf_datasets` script.
+
+ Args:
+ local_dir (Path | str): Local directory to download model to.
+ model_name (str): Name of model to download.
+ """
+ if isinstance(local_dir, Path):
+ local_dir = str(local_dir)
+
+ logger.info(f"Download model {model_name} with `hf_datasets` to {local_dir}...")
+ subprocess.run(
+ [
+ "python",
+ "-m",
+ "sharktank.utils.hf_datasets",
+ model_name,
+ "--local-dir",
+ local_dir,
+ ],
+ check=True,
+ )
+ logger.info(f"Model {model_name} successfully downloaded.")
+
+
+def download_tokenizer(local_dir, tokenizer_id):
+ # Set up tokenizer if it doesn't exist
+ tokenizer_path = local_dir / "tokenizer.json"
+ logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
+ if not os.path.exists(tokenizer_path):
+ logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_id,
+ )
+ tokenizer.save_pretrained(local_dir)
+ logger.info(f"Tokenizer saved to {tokenizer_path}")
+ else:
+ logger.info("Using cached tokenizer")
+
+
+def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes):
+ bs_string = ",".join(map(str, batch_sizes))
+ logger.info(
+ "Exporting model with following settings:\n"
+ f" MLIR Path: {mlir_path}\n"
+ f" Config Path: {config_path}\n"
+ f" Batch Sizes: {bs_string}"
+ )
+ subprocess.run(
+ [
+ "python",
+ "-m",
+ "sharktank.examples.export_paged_llm_v1",
+ f"--{model_path.suffix.strip('.')}-file={model_path}",
+ f"--output-mlir={mlir_path}",
+ f"--output-config={config_path}",
+ f"--bs={bs_string}",
+ ],
+ check=True,
+ )
+ logger.info(f"Model successfully exported to {mlir_path}")
+
+
+def compile_model(mlir_path, vmfb_path, device_settings):
+ logger.info(f"Compiling model to {vmfb_path}")
+ subprocess.run(
+ [
+ "iree-compile",
+ mlir_path,
+ "-o",
+ vmfb_path,
+ ]
+ + device_settings["device_flags"],
+ check=True,
+ )
+ logger.info(f"Model successfully compiled to {vmfb_path}")
+
+
+def find_available_port():
+ import socket
+ from contextlib import closing
+
+ logger.info(f"Finding available port...")
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ port = s.getsockname()[1]
+ logger.info(f"Found available port: {port}")
+ return port
+
+
+def wait_for_server(url, timeout=10):
+ logger.info(f"Waiting for server to start at {url}...")
+ start = time.time()
+ while time.time() - start < timeout:
+ try:
+ requests.get(f"{url}/health")
+ logger.info("Server successfully started")
+ return
+ except requests.exceptions.ConnectionError:
+ time.sleep(1)
+ raise TimeoutError(f"Server did not start within {timeout} seconds")
+
+
+def _start_llm_server_args(
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ port,
+):
+ return [
+ sys.executable,
+ "-m",
+ "shortfin_apps.llm.server",
+ f"--tokenizer_json={tokenizer_path}",
+ f"--model_config={model_config_path}",
+ f"--vmfb={vmfb_path}",
+ f"--parameters={parameters_path}",
+ f"--device={settings['device']}",
+ f"--port={port}",
+ ]
+
+
+def start_llm_server(
+ port,
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ timeout=10,
+ multi=False,
+):
+ logger.info("Starting LLM server...")
+ if multi:
+ # Pass the callable and its argument list separately so that
+ # subprocess.Popen runs inside the spawned process instead of being
+ # invoked immediately in the parent.
+ server_process = multiprocessing.Process(
+ target=subprocess.Popen,
+ args=(
+ _start_llm_server_args(
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ port,
+ ),
+ ),
+ )
+ server_process.start()
+
+ else:
+ # Start the server
+ server_process = subprocess.Popen(
+ _start_llm_server_args(
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ port,
+ )
+ )
+ logger.info("Process started... waiting for server")
+ # Wait for server to start
+ wait_for_server(f"http://localhost:{port}", timeout)
+ return server_process
+
+
+def start_log_group(headline):
+ # check if we are in github ci
+ if os.environ.get("GITHUB_ACTIONS") == "true":
+ return f"\n::group::{headline}"
+ return ""
+
+
+def end_log_group():
+ # check if we are in github ci
+ if os.environ.get("GITHUB_ACTIONS") == "true":
+ return "\n::endgroup::"
+ return ""
diff --git a/build_tools/integration_tests/llama_export_compile_serve.sh b/build_tools/integration_tests/llama_export_compile_serve.sh
deleted file mode 100755
index edd54b688..000000000
--- a/build_tools/integration_tests/llama_export_compile_serve.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-set -xeuo pipefail
-
-# Assume that the environment is already set up:
-# * Python venv set up with requirements, sharktank, and shortfin
-# * iree-compile and iree-run-module on $PATH
-# * authenticated with `huggingface-cli login`
-
-# Input variables.
-# Default model: https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf
-# Default tokenizer: https://huggingface.co/openlm-research/open_llama_3b_v2
-TEMP_DIR="${TEMP_DIR:-/tmp/sharktank/llama}"
-HUGGING_FACE_MODEL_NAME="${HUGGING_FACE_MODEL_NAME:-SlyEcho/open_llama_3b_v2_gguf}"
-HUGGING_FACE_MODEL_FILE="${HUGGING_FACE_MODEL_FILE:-open-llama-3b-v2-f16.gguf}"
-HUGGING_FACE_TOKENIZER="${HUGGING_FACE_TOKENIZER:-openlm-research/open_llama_3b_v2}"
-
-# Derived variables.
-LOCAL_GGUF_FILE="${TEMP_DIR}/${HUGGING_FACE_MODEL_FILE}"
-LOCAL_MLIR_FILE="${TEMP_DIR}/model.mlir"
-LOCAL_CONFIG_FILE="${TEMP_DIR}/config.json"
-LOCAL_VMFB_FILE="${TEMP_DIR}/model.vmfb"
-
-mkdir -p ${TEMP_DIR}
-
-huggingface-cli download --local-dir ${TEMP_DIR} ${HUGGING_FACE_MODEL_NAME} ${HUGGING_FACE_MODEL_FILE}
-
-python -m sharktank.examples.export_paged_llm_v1 \
- --gguf-file="${LOCAL_GGUF_FILE}" \
- --output-mlir="${LOCAL_MLIR_FILE}" \
- --output-config="${LOCAL_CONFIG_FILE}"
-
-iree-compile "${LOCAL_MLIR_FILE}" \
- --iree-hal-target-backends=llvm-cpu \
- --iree-llvmcpu-target-cpu-features=host \
- -o ${LOCAL_VMFB_FILE}
-
-python -m shortfin.llm.impl.service_v1_cli \
- --tokenizer="${HUGGING_FACE_TOKENIZER}" \
- --config="${LOCAL_CONFIG_FILE}" \
- --vmfb="${LOCAL_VMFB_FILE}" \
- --gguf="${LOCAL_GGUF_FILE}"
diff --git a/build_tools/python_deploy/README.md b/build_tools/python_deploy/README.md
new file mode 100644
index 000000000..d36545a9c
--- /dev/null
+++ b/build_tools/python_deploy/README.md
@@ -0,0 +1,48 @@
+# Python Deployment
+
+These scripts assist with building Python packages and pushing them to
+[PyPI (the Python Package Index)](https://pypi.org/). See also
+
+* The Python Packaging User Guide: <https://packaging.python.org/en/latest/>
+
+## Overview
+
+See comments in scripts for canonical usage. This page includes additional
+notes.
+
+### Package building
+
+These scripts build packages:
+
+* [`/shark-ai/build_tools/build_linux_package.sh`](/shark-ai/build_tools/build_linux_package.sh)
+* [`/sharktank/build_tools/build_linux_package.sh`](/sharktank/build_tools/build_linux_package.sh)
+* [`/shortfin/build_tools/build_linux_package.sh`](/shortfin/build_tools/build_linux_package.sh)
+
+### Version management
+
+These scripts handle versioning across packages, including considerations like
+major, minor, and patch levels (`X.Y.Z`), as well as suffixes like
+`rc20241107`:
+
+* [`compute_common_version.py`](./compute_common_version.py)
+* [`compute_local_version.py`](./compute_local_version.py)
+* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py)
+* [`write_requirements.py`](./write_requirements.py)
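+
+As a rough illustration of the scheme these scripts implement (a minimal
+sketch only; the canonical logic lives in the scripts above), a nightly
+version is the base `X.Y.Z` version from a package's `version.json` with an
+`rcYYYYMMDD` suffix appended:
+
+```python
+from datetime import datetime
+
+from packaging.version import Version
+
+# Example base version; real values are read from a package's version.json.
+base_version = Version("3.1.0").base_version
+version_suffix = "rc" + datetime.today().strftime("%Y%m%d")
+print(base_version + version_suffix)  # e.g. "3.1.0rc20241107"
+```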
+
+### PyPI deployment
+
+These scripts handle promoting nightly release packages to stable versions and
+pushing them to PyPI:
+
+* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py)
+* [`pypi_deploy.sh`](./pypi_deploy.sh)
+
+Both of these scripts expect to have the dependencies from
+[`requirements-pypi-deploy.txt`](./requirements-pypi-deploy.txt) installed.
+This can be easily managed by using a Python virtual environment:
+
+```bash
+python -m venv .venv
+source .venv/bin/activate
+python -m pip install -r ./requirements-pypi-deploy.txt
+```
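+
+With the environment prepared, a typical promotion run is driven by
+[`pypi_deploy.sh`](./pypi_deploy.sh); the release identifier below is only an
+example, taken from that script's own usage comment:
+
+```bash
+./pypi_deploy.sh 2.9.0rc20241108
+```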
diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py
new file mode 100755
index 000000000..aa193bcc1
--- /dev/null
+++ b/build_tools/python_deploy/compute_common_version.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script grabs the `X.Y.Z[.dev]` version identifier from the
+# 'sharktank' and 'shortfin' version files and computes the version
+# for the meta 'shark-ai' package.
+#
+# Usage:
+# ./compute_common_version.py --stable-release --write-json
+# cat ../../shark-ai/version_local.json
+
+import argparse
+from pathlib import Path
+import json
+from datetime import datetime
+import sys
+
+from packaging.version import Version
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--write-json", action="store_true")
+parser.add_argument("--version-suffix", action="store", type=str)
+
+release_type = parser.add_mutually_exclusive_group()
+release_type.add_argument("-stable", "--stable-release", action="store_true") # default
+release_type.add_argument("-rc", "--nightly-release", action="store_true")
+
+
+args = parser.parse_args()
+
+if not (args.stable_release or args.nightly_release):
+ parser.print_usage(sys.stderr)
+ sys.stderr.write("error: A release type is required\n")
+ sys.exit(1)
+
+if args.stable_release and args.version_suffix:
+ sys.stderr.write("error: A version suffix is only supported for stable releases\n")
+ sys.exit(1)
+
+THIS_DIR = Path(__file__).parent.resolve()
+REPO_ROOT = THIS_DIR.parent.parent
+
+VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version.json"
+VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version.json"
+VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json"
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+def write_version_info():
+ with open(VERSION_FILE_LOCAL, "w") as f:
+ json.dump(version_local, f, indent=2)
+ f.write("\n")
+
+
+sharktank_version = load_version_info(VERSION_FILE_SHARKTANK)
+SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version")
+SHARKTANK_BASE_VERSION = Version(SHARKTANK_PACKAGE_VERSION).base_version
+
+shortfin_version = load_version_info(VERSION_FILE_SHORTFIN)
+SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version")
+SHORTFIN_BASE_VERSION = Version(SHORTFIN_PACKAGE_VERSION).base_version
+
+if Version(SHARKTANK_BASE_VERSION) > Version(SHORTFIN_BASE_VERSION):
+ COMMON_VERSION = SHARKTANK_BASE_VERSION
+else:
+ COMMON_VERSION = SHORTFIN_BASE_VERSION
+
+if args.nightly_release:
+ if args.version_suffix:
+ VERSION_SUFFIX = args.version_suffix
+ else:
+ VERSION_SUFFIX = "rc" + datetime.today().strftime("%Y%m%d")
+
+ COMMON_VERSION += VERSION_SUFFIX
+
+if args.write_json:
+ version_local = {"package-version": COMMON_VERSION}
+ write_version_info()
+
+print(COMMON_VERSION)
diff --git a/build_tools/python_deploy/compute_local_version.py b/build_tools/python_deploy/compute_local_version.py
new file mode 100755
index 000000000..0465fa443
--- /dev/null
+++ b/build_tools/python_deploy/compute_local_version.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script grabs the `X.Y.Z[.dev]` version identifier from a
+# `version.json` and writes the corresponding
+# `X.Y.ZrcYYYYMMDD` version identifier to `version_local.json`.
+
+import argparse
+from pathlib import Path
+import json
+from datetime import datetime
+
+from packaging.version import Version
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("path", type=Path)
+parser.add_argument("--version-suffix", action="store", type=str)
+args = parser.parse_args()
+
+VERSION_FILE = args.path / "version.json"
+VERSION_FILE_LOCAL = args.path / "version_local.json"
+
+
+def load_version_info():
+ with open(VERSION_FILE, "rt") as f:
+ return json.load(f)
+
+
+def write_version_info():
+ with open(VERSION_FILE_LOCAL, "w") as f:
+ json.dump(version_local, f, indent=2)
+ f.write("\n")
+
+
+version_info = load_version_info()
+
+if args.version_suffix:
+ VERSION_SUFFIX = args.version_suffix
+else:
+ VERSION_SUFFIX = "rc" + datetime.today().strftime("%Y%m%d")
+
+PACKAGE_VERSION = version_info.get("package-version")
+PACKAGE_BASE_VERSION = Version(PACKAGE_VERSION).base_version
+PACKAGE_LOCAL_VERSION = PACKAGE_BASE_VERSION + VERSION_SUFFIX
+
+version_local = {"package-version": PACKAGE_LOCAL_VERSION}
+
+write_version_info()
+
+print(PACKAGE_LOCAL_VERSION)
diff --git a/build_tools/python_deploy/promote_whl_from_rc_to_final.py b/build_tools/python_deploy/promote_whl_from_rc_to_final.py
new file mode 100755
index 000000000..061dd933b
--- /dev/null
+++ b/build_tools/python_deploy/promote_whl_from_rc_to_final.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script takes a file like 'sharktank-2.9.0rc20241110-py3-none-any.whl'
+# with embedded version '2.9.0rc20241110' as input and then drops the
+# 'rcYYYYMMDD' suffix from both the embedded version and file name.
+#
+# Typical usage:
+# pip install -r requirements-pypi-deploy.txt
+# ./promote_whl_from_rc_to_final.py /path/to/file.whl --delete-old-wheel
+
+import argparse
+from change_wheel_version import change_wheel_version
+from packaging.version import Version
+from pathlib import Path
+from pkginfo import Wheel
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "input_file",
+ help="Path to the input .whl file to promote",
+ type=Path,
+ )
+ parser.add_argument(
+ "--delete-old-wheel",
+ help="Deletes the original wheel after successfully promoting it",
+ action="store_true",
+ default=False,
+ )
+ return parser.parse_args()
+
+
+def main(args):
+ original_wheel_path = args.input_file
+ print(f"Promoting whl from rc to final: '{original_wheel_path}'")
+
+ original_wheel = Wheel(original_wheel_path)
+ original_version = Version(original_wheel.version)
+ base_version = original_version.base_version
+ print(
+ f" Original wheel version is '{original_version}' with base '{base_version}'"
+ )
+
+ if str(base_version) == str(original_version):
+ print(" Version is already a release version, skipping")
+ return
+
+ print(f" Changing to base version: '{base_version}'")
+ new_wheel_path = change_wheel_version(original_wheel_path, str(base_version), None)
+ print(f" New wheel path is '{new_wheel_path}'")
+
+ new_wheel = Wheel(new_wheel_path)
+ new_version = Version(new_wheel.version)
+ print(f" New wheel version is '{new_version}'")
+
+ if args.delete_old_wheel:
+ print(" Deleting original wheel")
+ original_wheel_path.unlink()
+
+
+if __name__ == "__main__":
+ main(parse_arguments())
diff --git a/build_tools/python_deploy/pypi_deploy.sh b/build_tools/python_deploy/pypi_deploy.sh
new file mode 100755
index 000000000..63f123ac0
--- /dev/null
+++ b/build_tools/python_deploy/pypi_deploy.sh
@@ -0,0 +1,126 @@
+#!/bin/bash
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script promotes Python packages from nightly releases to PyPI.
+#
+# Prerequisites:
+# * You will need to have PyPI credentials set up. See
+# https://packaging.python.org/en/latest/tutorials/packaging-projects/#uploading-the-distribution-archives
+# * Install requirements, e.g. in a Python virtual environment (venv):
+# `pip install -r requirements-pypi-deploy.txt`
+# * Install python3.13t and install pip. On Ubuntu:
+# ```bash
+# sudo add-apt-repository ppa:deadsnakes
+# sudo apt-get update
+# sudo apt-get install python3.13-nogil
+# python3.13t -m ensurepip --upgrade
+# ```
+# * Choose a release candidate to promote from
+# https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels
+#
+# Usage:
+# ./pypi_deploy.sh 2.9.0rc20241108
+
+set -euo pipefail
+
+RELEASE="$1"
+
+SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")";
+REPO_ROOT="$(cd "$SCRIPT_DIR"/../../ && pwd)"
+TMPDIR="$(mktemp --directory --tmpdir shark_platform_pypi_wheels.XXXXX)"
+ASSETS_PAGE="https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels"
+
+# TODO: rewrite in Python?
+
+function download_wheels() {
+ echo ""
+ echo "Downloading wheels for '${RELEASE}'..."
+
+ # sharktank
+ python -m pip download sharktank==${RELEASE} \
+ --no-deps --python-version 3.11 -f ${ASSETS_PAGE}
+
+ # shortfin
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.11 -f ${ASSETS_PAGE}
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.12 -f ${ASSETS_PAGE}
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.13 -f ${ASSETS_PAGE}
+ # TODO: fetch 3.13t using the same `python` somehow
+ # * https://pip.pypa.io/en/stable/cli/pip_download/
+ # * https://py-free-threading.github.io/installing_cpython/
+ # * https://pip.pypa.io/en/stable/installation/
+ python3.13t -m pip download shortfin==${RELEASE} --no-deps -f ${ASSETS_PAGE}
+
+ # TODO: shark-ai meta package when it is published to nightlies
+
+ echo ""
+ echo "Downloaded wheels:"
+ ls
+}
+
+function edit_release_versions() {
+ echo ""
+ echo "Editing release versions..."
+ for file in *
+ do
+ ${SCRIPT_DIR}/promote_whl_from_rc_to_final.py ${file} --delete-old-wheel
+ done
+
+ echo "Edited wheels:"
+ ls
+}
+
+function upload_wheels() {
+ # TODO: list packages that would be uploaded, pause, prompt to continue
+ echo ""
+ echo "Uploading wheels:"
+ ls
+ twine upload --verbose *
+}
+
+function build_shark_ai_meta_package() {
+ # TODO: download meta package from nightly releases instead of this
+ # Be aware that nightly releases pin dependencies via the generated
+ # `requirements.txt` differently than stable releases do.
+ echo ""
+
+ # TODO: rework `write_requirements.py` to use the versions from the downloaded whls?
+ echo "Computing local versions for sharktank and shortfin..."
+ ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/sharktank
+ ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/shortfin
+
+ echo "Computing common version for shark-ai meta package..."
+ ${SCRIPT_DIR}/compute_common_version.py --stable-release --write-json
+
+ echo "Writing requirements for shark-ai meta package..."
+ ${SCRIPT_DIR}/write_requirements.py
+
+ echo "Building shark-ai meta package..."
+ ${REPO_ROOT}/shark-ai/build_tools/build_linux_package.sh
+
+ # TODO: This is error-prone. We only want to publish the whl for this release.
+ # Copy instead? Specify exact file name? Clear directory before building?
+ mv ${REPO_ROOT}/shark-ai/build_tools/wheelhouse/* .
+}
+
+function main() {
+ echo "Changing into ${TMPDIR}"
+ cd "${TMPDIR}"
+ # TODO: check_requirements (using pip)
+
+ download_wheels
+ edit_release_versions
+ build_shark_ai_meta_package
+ upload_wheels
+}
+
+main
diff --git a/build_tools/python_deploy/requirements-pypi-deploy.txt b/build_tools/python_deploy/requirements-pypi-deploy.txt
new file mode 100644
index 000000000..dcc32d47a
--- /dev/null
+++ b/build_tools/python_deploy/requirements-pypi-deploy.txt
@@ -0,0 +1,4 @@
+change_wheel_version
+packaging
+pkginfo
+twine
diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py
new file mode 100755
index 000000000..224e01fd0
--- /dev/null
+++ b/build_tools/python_deploy/write_requirements.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script writes the `packaging/shark-ai/requirements.txt` file and pins
+# the versions of the dependencies accordingly. For nightly releases,
+# * sharktank
+# * shortfin
+# get pinned to the corresponding nightly version. The IREE packages are
+# unpinned. For stable releases,
+# * iree-base-compiler
+# * iree-base-runtime
+# * iree-turbine
+# * sharktank
+# * shortfin
+# get pinned to the corresponding `X.Y.*` version.
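+#
+# For illustration only (hypothetical versions), a nightly requirements.txt
+# might contain:
+#   iree-base-compiler
+#   iree-base-runtime
+#   iree-turbine
+#   shortfin==3.0.0rc20241205
+# while a stable one might contain:
+#   iree-base-compiler==3.0.*
+#   iree-base-runtime==3.0.*
+#   iree-turbine==3.0.*
+#   shortfin==3.0.0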
+
+import argparse
+from pathlib import Path
+import json
+
+from packaging.version import Version
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--version-suffix", action="store", type=str)
+
+args = parser.parse_args()
+
+
+THIS_DIR = Path(__file__).parent.resolve()
+REPO_ROOT = THIS_DIR.parent.parent
+
+VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_local.json"
+VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_local.json"
+VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json"
+REQUIREMENTS_TXT = REPO_ROOT / "shark-ai/requirements.txt"
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+def write_requirements(requirements):
+ with open(REQUIREMENTS_TXT, "w") as f:
+ f.write("%s\n" % requirements)
+
+
+metapackage_version = load_version_info(VERSION_FILE_LOCAL)
+PACKAGE_VERSION = metapackage_version.get("package-version")
+
+# sharktank_version = load_version_info(VERSION_FILE_SHARKTANK)
+# SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version")
+
+shortfin_version = load_version_info(VERSION_FILE_SHORTFIN)
+SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version")
+
+stable_packages_list = ["iree-base-compiler", "iree-base-runtime", "iree-turbine"]
+
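+# Nightly (pre-release) versions leave the IREE packages unpinned and pin
+# shortfin to the exact rc version; stable versions pin the IREE packages
+# to `X.Y.*` and shortfin to its exact base version.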
+if Version(PACKAGE_VERSION).is_prerelease:
+ requirements = ""
+ for package in stable_packages_list:
+ requirements += package + "\n"
+ # TODO: Include sharktank as a dependency of future releases
+ # requirements = (
+ # "sharktank=="
+ # + Version(SHARKTANK_PACKAGE_VERSION).base_version
+ # + args.version_suffix
+ # + "\n"
+ # )
+ requirements += (
+ "shortfin=="
+ + Version(SHORTFIN_PACKAGE_VERSION).base_version
+ + args.version_suffix
+ )
+
+ write_requirements(requirements)
+
+else:
+ MAJOR_VERSION = Version(PACKAGE_VERSION).major
+ MINOR_VERSION = Version(PACKAGE_VERSION).minor
+
+ STABLE_VERSION_TO_PIN = str(MAJOR_VERSION) + "." + str(MINOR_VERSION) + ".*"
+
+ requirements = ""
+ for package in stable_packages_list:
+ requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n"
+ # TODO: Include sharktank as a dependency of future releases
+ # requirements += (
+ # "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n"
+ # )
+ requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version
+
+ write_requirements(requirements)
diff --git a/docs/developer_guide.md b/docs/developer_guide.md
new file mode 100644
index 000000000..73aee61f7
--- /dev/null
+++ b/docs/developer_guide.md
@@ -0,0 +1,118 @@
+# SHARK Developer Guide
+
+Each sub-project has its own developer guide. If you would like to work across
+projects, these instructions should help you get started:
+
+
+### Install Dependencies
+
+Install shortfin dependencies:
+
+```bash
+sudo apt update && sudo apt install -y clang lld
+```
+
+### Prepare your python environment
+
+Install the system Python packages:
+
+```bash
+sudo apt install python-is-python3 python3-venv python3-dev
+```
+
+Alternatively, use `pyenv` to manage a separate Python installation for more
+control over its version. The following instructions are adapted from pyenv's
+guide: https://github.com/pyenv/pyenv?tab=readme-ov-file#a-getting-pyenv
+
+First, install pyenv and its dependencies.
+
+```bash
+sudo apt update; sudo apt install build-essential libssl-dev zlib1g-dev \
+libbz2-dev libreadline-dev libsqlite3-dev curl git \
+libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev
+curl https://pyenv.run | bash
+```
+
+Then, make pyenv available by adding the below to your `~/.bashrc`:
+
+```bash
+export PYENV_ROOT="$HOME/.pyenv"
+command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"
+eval "$(pyenv init -)"
+```
+
+Finally, install a pyenv-managed version of Python:
+
+```bash
+pyenv install 3.12 # or whichever python version you'd like
+pyenv local 3.12
+```
+
+Now, your python, pip, and venv should be managed by pyenv instead.
+
+
+
+### Setup a venv
+
+We recommend setting up a Python
+[virtual environment (venv)](https://docs.python.org/3/library/venv.html).
+The project is configured to ignore `.venv` directories, and editors like
+VSCode pick them up by default.
+
+```bash
+python -m venv .venv
+source .venv/bin/activate
+```
+
+### Install PyTorch for your system
+
+If no explicit action is taken, the default PyTorch version will be installed.
+This will give you a current CUDA-based version, which takes longer to download
+and includes other dependencies that SHARK does not require. To install a
+different variant, run one of these commands first:
+
+* *CPU:*
+
+ ```bash
+ pip install -r pytorch-cpu-requirements.txt
+ ```
+
+* *ROCM:*
+
+ ```bash
+ pip install -r pytorch-rocm-requirements.txt
+ ```
+
+* *Other:* see instructions at https://pytorch.org/get-started/locally/.
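+
+After installing PyTorch, a quick sanity check can confirm which variant you
+ended up with. This is only a minimal sketch; the `torch.version` attributes
+queried below may differ across PyTorch builds:
+
+```python
+# Quick check of the installed PyTorch variant (CPU, CUDA, or ROCm).
+import torch
+
+print("torch version:", torch.__version__)
+print("CUDA build:", getattr(torch.version, "cuda", None))
+print("ROCm/HIP build:", getattr(torch.version, "hip", None))
+```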
+
+### Install development packages
+
+```bash
+# Install editable local projects.
+pip install -r requirements.txt -e sharktank/ shortfin/
+
+# Optionally clone and install the latest editable iree-turbine dep in deps/,
+# along with nightly versions of iree-base-compiler and iree-base-runtime.
+pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler iree-base-runtime --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+```
+
+See also: [nightly_releases.md](nightly_releases.md).
+
+### Running tests
+
+```bash
+pip install -r shortfin/requirements-tests.txt
+pytest sharktank
+pytest shortfin
+pytest app_tests/integration_tests
+```
+
+### Optional: pre-commits and developer settings
+
+This project is set up to use the `pre-commit` tooling. To install it in
+your local repo, run `pre-commit install`. After that, hooks will run
+automatically when you make commits locally. See https://pre-commit.com/ for details.
diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md
index 64137956d..ddc0cb3bb 100644
--- a/docs/model_cookbook.md
+++ b/docs/model_cookbook.md
@@ -1,7 +1,7 @@
# Model cookbook
-Note: These are early notes and commands that the sharktank team is using and
-will turn into proper docs later.
+Note: These are early notes and commands that the shark-ai team is using
+and will turn into proper docs later.
## Diagrams
@@ -165,18 +165,13 @@ tokenizer_config.json: 100%|█████████████████
Setup (from [README.md](../README.md)):
-* TODO: this could be replaced with `pip install iree-turbine` or
- `pip install sharktank` at some point. For now these are dev packages.
-
```bash
# Setup venv.
python -m venv --prompt sharktank .venv
source .venv/bin/activate
-# Install requirements.
+# (Optional) Install PyTorch for CPU only, to save on download time.
pip install -r pytorch-cpu-requirements.txt
-pip install -f https://iree.dev/pip-release-links.html --src deps \
- -e "git+https://github.com/iree-org/iree-turbine.git#egg=shark-turbine"
# Install local projects.
pip install -r requirements.txt -e sharktank/ shortfin/
@@ -257,6 +252,21 @@ iree-run-module \
--parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf
```
+## Evaluation pipeline
+
+Run perplexity test:
+
+```bash
+pytest sharktank/tests/evaluate/perplexity_test.py --longrun
+```
+
+Run perplexity for a new model:
+```bash
+python -m sharktank.evaluate.perplexity \
+ --gguf-file=llama8b_f16.gguf \
+ --tokenizer-config-json=tokenizer_config.json
+```
+
## Generating data for llama models
```bash
diff --git a/docs/nightly_releases.md b/docs/nightly_releases.md
new file mode 100644
index 000000000..f41374445
--- /dev/null
+++ b/docs/nightly_releases.md
@@ -0,0 +1,213 @@
+# Nightly releases
+
+Nightly releases are uploaded to
+https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels.
+
+* The "expanded_assets" version of a release page is compatible with the
+ `-f, --find-links <url>` options of `pip install`
+ ([docs here](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-f)).
+ For the "dev-wheels" release above, that page is
+ https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+
+* These releases are generated using
+ [`.github/workflows/build_packages.yml`](../.github/workflows/build_packages.yml)
+* That workflow runs the
+ [`sharktank/build_tools/build_linux_package.sh`](../sharktank/build_tools/build_linux_package.sh)
+ and
+[`shortfin/build_tools/build_linux_package.sh`](../shortfin/build_tools/build_linux_package.sh)
+ scripts
+* Workflow history can be viewed on the repository's GitHub Actions page.
+
+## Prerequisites
+
+### Operating system
+
+Currently we only officially support Linux with published packages. Windows and
+macOS support is possible, but may need additional setup, code changes, and
+source builds.
+
+### Python
+
+You will need a recent version of Python.
+
+* As of Nov 1, 2024, sharktank is compatible with Python 3.11. See
+ https://github.com/nod-ai/shark-ai/issues/349 for Python 3.12 support.
+* As of Nov 4, 2024, shortfin publishes packages for Python 3.11, 3.12, 3.13,
+ and 3.13t.
+
+For example, to install Python 3.11 on Ubuntu:
+
+```bash
+sudo apt install python3.11 python3.11-dev python3.11-venv
+
+which python3.11
+# /usr/bin/python3.11
+```
+
+> [!TIP]
+> Manage multiple Python versions using `pyenv`
+> (https://github.com/pyenv/pyenv), or the
+> [Python Launcher for Windows](https://docs.python.org/3/using/windows.html#python-launcher-for-windows)
+> on Windows.
+
+## Quickstart - sharktank
+
+```bash
+# Set up a virtual environment to isolate packages from other envs.
+python3.11 -m venv 3.11.venv
+source 3.11.venv/bin/activate
+
+# Install 'sharktank' package from nightly releases.
+pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+
+# Test the installation.
+python -c "from sharktank import ops; print('Sanity check passed')"
+
+# Deactivate the virtual environment when done.
+deactivate
+```
+
+## Quickstart - shortfin
+
+```bash
+# Set up a virtual environment to isolate packages from other envs.
+python3.11 -m venv 3.11.venv
+source 3.11.venv/bin/activate
+
+# Install 'shortfin' package from nightly releases.
+pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+
+# Test the installation.
+python -c "import shortfin as sf; print('Sanity check passed')"
+
+# Deactivate the virtual environment when done.
+deactivate
+```
+
+## Installing newer versions of dependencies
+
+To install the `iree-turbine` package from the latest source:
+
+```bash
+pip install --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+```
+
+To install the `iree-base-compiler` and `iree-base-runtime` packages from
+nightly releases:
+
+```bash
+pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler iree-base-runtime
+```
+
+To install all three packages together:
+
+```bash
+pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler iree-base-runtime --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+```
+
+## Switching between stable and nightly channels
+
+The [`shark-ai` package on PyPI](https://pypi.org/project/shark-ai/) is a
+meta-package that pins specific stable versions of each component package,
+all sharing at least the same major and minor version:
+
+```bash
+pip install shark-ai==2.9.1
+
+pip freeze
+# ...
+# iree-base-compiler==2.9.0
+# iree-base-runtime==2.9.0
+# iree-turbine==2.9.0
+# ...
+# shark-ai==2.9.1
+# shortfin==2.9.1
+# ...
+```
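+
+To check these versions programmatically, the standard library's
+`importlib.metadata` can be used. The snippet below is a minimal sketch,
+using the package names listed above:
+
+```python
+# Print the installed versions of the shark-ai family of packages, if present.
+from importlib.metadata import PackageNotFoundError, version
+
+packages = [
+    "shark-ai", "sharktank", "shortfin",
+    "iree-base-compiler", "iree-base-runtime", "iree-turbine",
+]
+for name in packages:
+    try:
+        print(f"{name}=={version(name)}")
+    except PackageNotFoundError:
+        print(f"{name} is not installed")
+```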
+
+If you attempt to update any individual package outside of those supported
+versions, pip will log an error but continue anyway:
+
+```bash
+pip install --upgrade --pre \
+ -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels \
+ shortfin==3.0.0rc20241118
+
+# Looking in links: https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+# Collecting shortfin==3.0.0rc20241118
+# Downloading https://github.com/nod-ai/shark-ai/releases/download/dev-wheels/shortfin-3.0.0rc20241118-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB)
+# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 24.3 MB/s eta 0:00:00
+# Installing collected packages: shortfin
+# Attempting uninstall: shortfin
+# Found existing installation: shortfin 2.9.1
+# Uninstalling shortfin-2.9.1:
+# Successfully uninstalled shortfin-2.9.1
+# ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
+# shark-ai 2.9.1 requires shortfin==2.9.1, but you have shortfin 3.0.0rc20241118 which is incompatible.
+# Successfully installed shortfin-3.0.0rc20241118
+
+pip freeze
+# ...
+# shark-ai==2.9.1
+# shortfin==3.0.0rc20241118
+# ...
+```
+
+Installing the `shark-ai` package again should restore aligned versions:
+
+```bash
+pip install shark-ai==2.9.1
+# ...
+# Installing collected packages: shortfin
+# Attempting uninstall: shortfin
+# Found existing installation: shortfin 3.0.0rc20241118
+# Uninstalling shortfin-3.0.0rc20241118:
+# Successfully uninstalled shortfin-3.0.0rc20241118
+# Successfully installed shortfin-2.9.1
+
+pip freeze
+# ...
+# shark-ai==2.9.1
+# shortfin==2.9.1
+# ...
+```
+
+You can also uninstall the `shark-ai` package to bypass the error and take full
+control of package versions yourself:
+
+```bash
+pip uninstall shark-ai
+
+pip freeze
+# ...
+# (note: no shark-ai package)
+# shortfin==2.9.1
+# ...
+
+pip install --upgrade --pre \
+ -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels \
+ shortfin==3.0.0rc20241118
+
+# Looking in links: https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+# Collecting shortfin==3.0.0rc20241118
+# Using cached https://github.com/nod-ai/shark-ai/releases/download/dev-wheels/shortfin-3.0.0rc20241118-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB)
+# Installing collected packages: shortfin
+# Attempting uninstall: shortfin
+# Found existing installation: shortfin 2.9.1
+# Uninstalling shortfin-2.9.1:
+# Successfully uninstalled shortfin-2.9.1
+# Successfully installed shortfin-3.0.0rc20241118
+
+pip freeze
+# ...
+# (note: no shark-ai package)
+# shortfin==3.0.0rc20241118
+# ...
+```
+
+If you ever get stuck, consider creating a fresh
+[virtual environment](https://docs.python.org/3/library/venv.html).
diff --git a/docs/quantization.md b/docs/quantization.md
index 0563e8108..25bfc9f8d 100644
--- a/docs/quantization.md
+++ b/docs/quantization.md
@@ -4,10 +4,10 @@ author: Stella Laurenzo
date: June 30, 2024
---
-# Direct Quantization with sharktank
+# Direct Quantization with SHARK Tank
As a toolkit for building and adapting PyTorch based models for deployment,
-sharktank provides rich quantization support. By targeting the
+SHARK Tank provides rich quantization support. By targeting the
[IREE compiler](https://github.com/iree-org/iree) for optimizations, we can
strike a balance with our quantization setup that:
@@ -36,7 +36,7 @@ supports these indirect schemes -- effectively using compiler transformations
under the covers to do opaque model transformations that mirror a subset of
what is exposed directly to the user in the rest of this document.
-As an alternative, when developing sharktank and bringing up the initial
+As an alternative, when developing SHARK Tank and bringing up the initial
models, we wanted something more flexible, easier to debug/extend, and
less laden with needing to lowest common denominator something for everyone
in order to fit into fixed-function op sets that are very expensive to change.
@@ -63,12 +63,12 @@ amount of Python code implementing direct math and packing schemes.
drop-in replacements for subsets of the functionality available in stock
PyTorch modules like `Linear` and `Conv2D`.
2. Types/Ops: The `nn.Module` implementations we provide are built in terms
- of sharktank custom
- [`InferenceTensor`](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/types/tensors.py#L153)
- and [polymorphic functional ops library](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/ops/signatures.py).
+ of SHARK Tank custom
+ [`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/tensors.py#L153)
+ and [polymorphic functional ops library](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py).
3. Op specializations for optimized subsets of op type signatures and features
(for example, [an optimized affine quantized linear specialization for
- supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/ops/qlinear_impls.py)).
+ supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py)).
(TODO: good place for a diagram)
@@ -78,18 +78,18 @@ amount of Python code implementing direct math and packing schemes.
Available modules that support direct quantization (TODO: refactor to use
torch "Module" terminology and naming schemes consistently):
-* [`LinearLayer`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/layers/linear.py)
-* [convolution layers](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/layers/conv.py)
+* [`LinearLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/linear.py)
+* [convolution layers](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/conv.py)
Note that most sharktank modules extend
-[`ThetaLayer`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/layers/base.py#L63),
+[`ThetaLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/base.py#L63),
which calls for a bit of explanation. Traditional PyTorch Modules directly
instantiate their backing parameters in their constructor. For dataset-heavy
and polymorphic implementations like we commonly see in quantization and
distribution, however, it can be beneficial to separate these concerns.
The `ThetaLayer` simply takes a
-[`Theta` object](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/theta.py#L74),
+[`Theta` object](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L74),
which is a tree-structured bag of native `torch.Tensor` or `InferenceTensor`
instances, and it adopts the tensors in the bag as its own vs creating them.
For those familiar with the concept, this is a form of dependency-injection
@@ -114,7 +114,7 @@ tree to a specific Module instance.
We've already met the `Theta` object above, which holds a tree of something
called an
-[`InferenceTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153).
+[`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153).
Now we describe what this is. Note that presently, `InferenceTensor` is not a
`torch.Tensor` but its own `ABC` type that:
@@ -140,11 +140,11 @@ pipelines.
There is a growing list of `InferenceTensor` sub-types, many of which are
related to quantization:
-* [`PrimitiveTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286):
+* [`PrimitiveTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286):
A simple composition of a single `torch.Tensor`. This is often used
interchangeably with a `torch.Tensor` but is present for completeness of
the type hierarchy and to be able to type select on.
-* [`QuantizedTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372):
+* [`QuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372):
Abstract base class of all quantized tensors, providing two primary operations:
* `unpack`: Accesses the backing `QuantizedLayout` of the tensor, which is
@@ -154,12 +154,12 @@ related to quantization:
layout, this explodes it into a canonical representation of individual
tensors which can be algebraically implemented individually/generically).
-* [`PlanarQuantizedTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
+* [`PlanarQuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
Concrete implementation for all non-packed quantized tensors that can be
losslessly represented by a layout based on individual tensor components.
All `QuantizedTensor` instances can be converted to a `PlanarQuantizedTensor`.
-* [`QuantizerTensor`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
+* [`QuantizerTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
(note the "r" in the name) An abstract `InferenceTensor` that exposes a
`quantize(torch.Tensor | InferenceTensor) -> QuantizedTensor` operation used
to transform an arbitrary tensor to a quantized form. There are a handful
@@ -178,7 +178,7 @@ manipulate tensor contents via `QuantizedLayout`, but we haven't yet defined
that. The *Tensor types are structural and exist to give identity, but the
`QuantizedLayout` is where the "magic happens".
-[`QuantizedLayout`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44)
+[`QuantizedLayout`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44)
is an `ABC`, supporting:
* Serialization/interop with parameter archives.
@@ -193,7 +193,7 @@ is an `ABC`, supporting:
There are a number of implementations, as every quantization scheme typically
needs at least one concrete `QuantizedLayout`. Simple schemes like affine
quantization can be fully defined in terms of a single
-[`TensorScaledLayout`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/layouts.py#L43).
+[`TensorScaledLayout`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/layouts.py#L43).
Whereas packed schemes like we find in inference engines like GGML and XNNPACK
optimally require both a packed layout and a planar layout.
@@ -224,7 +224,7 @@ interpreting/transforming using their natively defined forms.
Previously, we found a rich type system defining all manner of layouts and
quantization schemes, but what can be done with it? That is where the
sharktank functional op library comes in. These
-[logical ops](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/signatures.py)
+[logical ops](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py)
provide the building blocks to implement built-in and custom `nn.Module`
implementations operating on `InferenceTensor` (and torch.Tensor) types.
@@ -239,12 +239,12 @@ implementation at any needed level of granularity:
structures and preserve it when computing (when combined with a
fusing compiler, this alone provides decent fallback implementations for a
variety of "weight compression" oriented techniques). See
- [some examples](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/custom_impls.py#L51).
+ [some examples](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/custom_impls.py#L51).
* Pure-Torch decompositions for algebraic techniques like affine quantization
(when combined with a fusing compiler, this alone is sufficient for
optimization). See
- [qlinear](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/qlinear_impls.py) and
- [qconv](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/ops/qconv_impls.py)
+ [qlinear](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py) and
+ [qconv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qconv_impls.py)
implementations of actual affine quantized decompositions.
* Completely custom packed/optimized implementation. These can be written to
activate on any level of detail of the type hierarchy. The implementation
@@ -277,11 +277,11 @@ is everything). We're just starting to exploit some of this as the PyTorch
level. Some examples:
* Something as simple as a humble runtime
-[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/iree.py#L52)
-* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/_jinja_test_ops.py#L28)
- (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/shark_turbine/ops/templates/test_add_jinja.mlir)).
-* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py)
- (see backing [jinja template](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)).
+[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/iree.py#L52)
+* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/_jinja_test_ops.py#L28)
+ (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/templates/test_add_jinja.mlir)).
+* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py)
+ (see backing [jinja template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)).
* DSL based [like this fused attention kernel](https://github.com/iree-org/iree-turbine/blob/main/tests/kernel/fused_attention_test.py#L20)
(note that in this case, the DSL exports to the unerlying IR-based registration
mechanism used in the previous examples).
@@ -292,8 +292,8 @@ Since all of these types of custom kernels are just defined with simple Python
tooling, they are really fast to iterate on. The linalg based kernels specifically
tend to be highly portable, and we don't hesitate to write one of those when
we need something specific that PyTorch doesn't provide out of the box
-(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py)
-([template](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))).
+(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py)
+([template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))).
## Dataset transformation
@@ -307,7 +307,7 @@ We take a practical approach to this, writing implementation specific converters
where needed, and taking advantage of industry-standard consolidation points
where available (like GGUF) in order to cover a wider surface area.
-Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/types/theta.py#L263),
+Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L263),
which combines some set of hyper-parameters with a root `Theta` object
(typically representing the layer-tree of frozen tensors). Datasets can be
losslessly persisted to IREE IRPA files, which can then be loaded by either
@@ -321,9 +321,9 @@ transform, shard, etc.
See some examples:
-* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) :
+* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) :
Creating a `Dataset` object from an HF diffusers safetensors file and config.json.
-* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/sharktank/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) :
+* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) :
Creates a quantized `Dataset` by combining:
* HF diffusers `config.json`
diff --git a/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md
new file mode 100644
index 000000000..1ce2d1e8d
--- /dev/null
+++ b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md
@@ -0,0 +1,242 @@
+# Llama 8b GPU Instructions on MI300X
+
+**NOTE: This was run on the `mi300x-3` system**
+
+## Setup
+
+We will use an example with `llama_8b_f16_decomposed` in order to describe the
+process of exporting a model for use in the shortfin llm server with an MI300 GPU.
+
+### Pre-Requisites
+
+- Python >= 3.11 is recommended for this flow
+ - You can check out [pyenv](https://github.com/pyenv/pyenv) as a good tool
+ to be able to manage multiple versions of python on the same system.
+
+### Setting Up Environment
+
+Follow the `Development Getting Started` docs
+[here](https://github.com/nod-ai/shark-ai/blob/main/README.md#development-getting-started)
+to set up your environment for development.
+
+### Define a directory for export files
+
+Create a new directory for us to export files like `model.mlir`, `model.vmfb`, etc.
+
+```bash
+mkdir $PWD/export
+export EXPORT_DIR=$PWD/export
+```
+
+### Define environment variables
+
+Define the following environment variables to make running this example a bit easier:
+
+#### Model/Tokenizer vars
+
+This example uses the `llama8b_f16.irpa` and `tokenizer.json` files that are
+pre-existing on the MI300X-3 system.
+You may need to change the paths for your own system.
+
+```bash
+export MODEL_PARAMS_PATH=/data/llama3.1/8b/llama8b_f16.irpa # Path to existing .irpa file, may need to change w/ system
+export TOKENIZER_PATH=/data/llama3.1/8b/tokenizer.json # Path to existing tokenizer.json, may need to change w/ system
+```
+
+#### General env vars
+
+The following env vars can be copy + pasted directly:
+
+```bash
+export MLIR_PATH=$EXPORT_DIR/model.mlir # Path to export model.mlir file
+export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json # Path to export config.json file
+export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json # Path to export edited_config.json file
+export VMFB_PATH=$EXPORT_DIR/model.vmfb # Path to export model.vmfb file
+export BS=1,4 # Batch size for kvcache
+export ROCR_VISIBLE_DEVICES=1 # NOTE: This is temporary, until multi-device is fixed
+```
+
+### Export to MLIR
+
+We will now use the `sharktank.examples.export_paged_llm_v1` script to export
+our model to `.mlir` format.
+
+```bash
+python -m sharktank.examples.export_paged_llm_v1 \
+ --irpa-file=$MODEL_PARAMS_PATH \
+ --output-mlir=$MLIR_PATH \
+ --output-config=$OUTPUT_CONFIG_PATH \
+ --bs=$BS
+```
+
+## Compiling to `.vmfb`
+
+Now that we have generated a `model.mlir` file, we can compile it to `.vmfb`
+format, which is required for running the `shortfin` LLM server.
+
+We will use the [iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile)
+tool for compiling our model.
+
+### Compile for MI300
+
+**NOTE: This command is specific to MI300 GPUs.
+For other `--iree-hip-target` GPU options,
+look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)**
+
+```bash
+iree-compile $MLIR_PATH \
+ --iree-hal-target-backends=rocm \
+ --iree-hip-target=gfx942 \
+ -o $VMFB_PATH
+```
+
+## Write an edited config
+
+We need to write a config for our model with a slightly edited structure
+to run with shortfin. This will work for the example in our docs.
+You may need to modify some of the parameters for a specific model.
+
+### Write edited config
+
+```bash
+cat > $EDITED_CONFIG_PATH << EOF
+{
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": [
+ $BS
+ ],
+ "decode_batch_sizes": [
+ $BS
+ ],
+ "transformer_block_count": 32,
+ "paged_kv_cache": {
+ "block_seq_stride": 16,
+ "device_block_count": 256
+ }
+}
+EOF
+```
+
+## Running the `shortfin` LLM server
+
+We should now have all of the files that we need to run the shortfin LLM server.
+
+Verify that you have the following in your specified directory (`$EXPORT_DIR`):
+
+```bash
+ls $EXPORT_DIR
+```
+
+- edited_config.json
+- model.vmfb
+
+### Launch server:
+
+#### Set the target device
+
+The target device is selected above via the `ROCR_VISIBLE_DEVICES` environment
+variable, so no additional configuration is needed here.
+
+#### Run the shortfin server
+
+Run the following command to launch the Shortfin LLM Server in the background:
+
+> **Note**
+> By default, our server will start at `http://localhost:8000`.
+> You can specify the `--host` and/or `--port` arguments, to run at a different address.
+>
+> If you receive an error similar to the following:
+>
+> `[errno 98] address already in use`
+>
+> Then, you can confirm the port is in use with `ss -ntl | grep 8000`
+> and either kill the process running at that port,
+> or start the shortfin server at a different port.
+
+```bash
+python -m shortfin_apps.llm.server \
+ --tokenizer_json=$TOKENIZER_PATH \
+ --model_config=$EDITED_CONFIG_PATH \
+ --vmfb=$VMFB_PATH \
+ --parameters=$MODEL_PARAMS_PATH \
+ --device=hip > shortfin_llm_server.log 2>&1 &
+shortfin_process=$!
+```
+
+You can verify the server launched successfully when you see the following
+logs output to the terminal:
+
+```bash
+cat shortfin_llm_server.log
+```
+
+#### Expected output
+
+```text
+[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete.
+[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
+
+## Verify server
+
+### Client script
+
+We can test the LLM server by running our client script:
+
+```bash
+python shortfin/python/shortfin_apps/llm/client.py --port 8000
+```
+
+### Simple request
+
+Alternatively, you can send a simple request directly:
+
+### Open python shell
+
+```bash
+python
+```
+
+### Send request
+
+```python
+import requests
+
+port = 8000 # Change if running at a different port
+
+generate_url = f"http://localhost:{port}/generate"
+
+def generation_request():
+ payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}}
+ try:
+ resp = requests.post(generate_url, json=payload)
+ resp.raise_for_status() # Raises an HTTPError for bad responses
+ print(resp.text)
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+generation_request()
+```
+
+After you receive the response, you can exit the Python shell:
+
+```python
+quit()
+```
+
+## Cleanup
+
+When done, stop the shortfin LLM server by killing the process:
+
+```bash
+kill -9 $shortfin_process
+```
diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
new file mode 100644
index 000000000..4a8423bc8
--- /dev/null
+++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
@@ -0,0 +1,278 @@
+# Llama 8b GPU instructions on MI300X
+
+## Setup
+
+We will use an example with `llama_8b_f16` in order to describe the
+process of exporting a model for use in the shortfin llm server with an
+MI300 GPU.
+
+### Pre-Requisites
+
+- Python >= 3.11 is recommended for this flow
+ - You can check out [pyenv](https://github.com/pyenv/pyenv)
+ as a good tool to be able to manage multiple versions of python
+ on the same system.
+
+### Create virtual environment
+
+To start, create a new virtual environment:
+
+```bash
+python -m venv --prompt shark-ai .venv
+source .venv/bin/activate
+```
+
+### Install `shark-ai`
+
+You can install either the `latest stable` version of `shark-ai`
+or the `nightly` version:
+
+#### Stable
+
+```bash
+pip install shark-ai
+```
+
+#### Nightly
+
+```bash
+pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+```
+
+#### Install dataclasses-json
+
+```bash
+pip install dataclasses-json
+```
+
+### Define a directory for export files
+
+Create a new directory for us to export files like
+`model.mlir`, `model.vmfb`, etc.
+
+```bash
+mkdir $PWD/export
+export EXPORT_DIR=$PWD/export
+```
+
+### Download llama3_8b_fp16.gguf
+
+We will use the `hf_datasets` module in `sharktank` to download a
+Llama 3.1 8b f16 model.
+
+```bash
+python -m sharktank.utils.hf_datasets llama3_8B_fp16 --local-dir $EXPORT_DIR
+```
+
+### Define environment variables
+
+Define the following environment variables to make running
+this example a bit easier:
+
+#### Model/Tokenizer vars
+
+This example uses the `llama8b_f16.gguf` and `tokenizer.json` files
+that were downloaded in the previous step.
+
+```bash
+export MODEL_PARAMS_PATH=$EXPORT_DIR/llama3.1-8b/llama8b_f16.gguf
+export TOKENIZER_PATH=$EXPORT_DIR/llama3.1-8b/tokenizer.json
+```
+
+#### General env vars
+
+The following env vars can be copy + pasted directly:
+
+```bash
+# Path to export model.mlir file
+export MLIR_PATH=$EXPORT_DIR/model.mlir
+# Path to export config.json file
+export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json
+# Path to export edited_config.json file
+export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json
+# Path to export model.vmfb file
+export VMFB_PATH=$EXPORT_DIR/model.vmfb
+# Batch size for kvcache
+export BS=1,4
+# NOTE: This is temporary, until multi-device is fixed
+export ROCR_VISIBLE_DEVICES=1
+```
+
+## Export to MLIR
+
+We will now use the `sharktank.examples.export_paged_llm_v1` script
+to export our model to `.mlir` format.
+
+```bash
+python -m sharktank.examples.export_paged_llm_v1 \
+ --irpa-file=$MODEL_PARAMS_PATH \
+ --output-mlir=$MLIR_PATH \
+ --output-config=$OUTPUT_CONFIG_PATH \
+ --bs=$BS
+```
+
+## Compiling to `.vmfb`
+
+Now that we have generated a `model.mlir` file,
+we can compile it to `.vmfb` format, which is required for running
+the `shortfin` LLM server.
+
+We will use the
+[iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile)
+tool for compiling our model.
+
+### Compile for MI300
+
+**NOTE: This command is specific to MI300 GPUs.
+For other `--iree-hip-target` GPU options,
+look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)**
+
+```bash
+iree-compile $MLIR_PATH \
+ --iree-hal-target-backends=rocm \
+ --iree-hip-target=gfx942 \
+ -o $VMFB_PATH
+```
+
+## Write an edited config
+
+We need to write a config for our model with a slightly edited structure
+to run with shortfin. This will work for the example in our docs.
+You may need to modify some of the parameters for a specific model.
+
+### Write edited config
+
+```bash
+cat > $EDITED_CONFIG_PATH << EOF
+{
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": [
+ $BS
+ ],
+ "decode_batch_sizes": [
+ $BS
+ ],
+ "transformer_block_count": 32,
+ "paged_kv_cache": {
+ "block_seq_stride": 16,
+ "device_block_count": 256
+ }
+}
+EOF
+```
+
+## Running the `shortfin` LLM server
+
+We should now have all of the files that we need to run the shortfin LLM server.
+
+Verify that you have the following in your specified directory (`$EXPORT_DIR`):
+
+```bash
+ls $EXPORT_DIR
+```
+
+- edited_config.json
+- model.vmfb
+
+### Launch server:
+
+#### Run the shortfin server
+
+Now that we are finished with setup, we can start the Shortfin LLM Server.
+
+Run the following command to launch the Shortfin LLM Server in the background:
+
+> **Note**
+> By default, our server will start at `http://localhost:8000`.
+> You can specify the `--host` and/or `--port` arguments, to run at a different address.
+>
+> If you receive an error similar to the following:
+>
+> `[errno 98] address already in use`
+>
+> Then, you can confirm the port is in use with `ss -ntl | grep 8000`
+> and either kill the process running at that port,
+> or start the shortfin server at a different port.
+
+```bash
+python -m shortfin_apps.llm.server \
+ --tokenizer_json=$TOKENIZER_PATH \
+ --model_config=$EDITED_CONFIG_PATH \
+ --vmfb=$VMFB_PATH \
+ --parameters=$MODEL_PARAMS_PATH \
+ --device=hip > shortfin_llm_server.log 2>&1 &
+shortfin_process=$!
+```
+
+You can verify the server launched successfully
+when you see the following logs output to the terminal:
+
+```bash
+cat shortfin_llm_server.log
+```
+
+#### Expected output
+
+```text
+[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete.
+[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
+
+## Verify server
+
+We can now verify our LLM server by sending a simple request:
+
+### Open python shell
+
+```bash
+python
+```
+
+### Send request
+
+```python
+import requests
+
+port = 8000 # Change if running on a different port
+
+generate_url = f"http://localhost:{port}/generate"
+
+def generation_request():
+ payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}}
+ try:
+ resp = requests.post(generate_url, json=payload)
+ resp.raise_for_status() # Raises an HTTPError for bad responses
+ print(resp.text)
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+generation_request()
+```
+
+After you receive the response, you can exit the Python shell:
+
+```python
+quit()
+```
+
+## Cleanup
+
+When done, stop the shortfin LLM server by killing the process:
+
+```bash
+kill -9 $shortfin_process
+```
diff --git a/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md
new file mode 100644
index 000000000..b63861a56
--- /dev/null
+++ b/docs/shortfin/llm/user/shortfin_with_sglang_frontend_language.md
@@ -0,0 +1,254 @@
+# Using `shortfin` with `sglang`
+
+This doc includes basic steps for hooking up sglang with a running Shortfin server.
+
+## Current Support Status
+
+| Feature | Description | Enabled | Reference |
+| ----------- | ----------- | ---------- | ------------ |
+| `gen` | Generate shortfin completion, given a prompt | ✅ | [Shortfin Implementation](https://github.com/nod-ai/sglang/blob/main/python/sglang/lang/backend/shortfin.py) |
+| `streaming` | Stream shortfin completion, given a prompt | ✅ | [Streaming](https://sgl-project.github.io/frontend/frontend.html#streaming) |
+| `run_batch` | Run batch of disjoint requests with continuous batching | ✅ | [Batching](https://sgl-project.github.io/frontend/frontend.html#batching) |
+| `fork` | Generate sections of the same prompt in parallel | ✅ | [Fork Docs](https://sgl-project.github.io/frontend/frontend.html#parallelism) |
+| `choices` | Given set of choices, generate response based on best log probs | ❌ | [Choices Methods](https://sgl-project.github.io/frontend/choices_methods.html#choices-methods-in-sglang) |
+| `image` | Pass image as part of multi-modal prompt | ❌ | [sgl.image](https://sgl-project.github.io/frontend/frontend.html#multi-modality) |
+| `regex` | Specify regular expression as decoding constraint | ❌ | [Regex](https://sgl-project.github.io/frontend/frontend.html#constrained-decoding) |
+
+## Prerequisites
+
+For this tutorial, you will need to meet the following prerequisites:
+
+### Software
+
+- Python >= 3.11
+ - You can check out [pyenv](https://github.com/pyenv/pyenv)
+ as a good tool to be able to manage multiple versions of python
+ on the same system.
+- A running `shortfin` LLM server as described [below](#installstart-shortfin-llm-server)
+ - We will use the shortfin server as the `backend` to generate completions
+ from SGLang's `frontend language`. In this tutorial, you can think of
+ `sglang` as the client and `shortfin` as the server.
+
+### Hardware
+
+- This tutorial is designed to run on an [AMD MI300X GPU](https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html)
+
+## Install/Start `shortfin` LLM server
+
+Follow the steps [here](https://github.com/nod-ai/shark-ai/blob/main/docs/shortfin/llm/user/e2e_llama8b_mi300x.md)
+to export a model with `sharktank` and start a `shortfin` LLM server
+with that model.
+
+## Install sglang
+
+### Install sglang inside of virtual environment
+
+Currently, we have our SGLang integration located at this [forked repo](https://github.com/nod-ai/sglang).
+We can use pip to install it in the same virtual environment that we used
+to start our Shortfin LLM Server.
+
+```bash
+pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
+```
+
+## Getting started
+
+You can verify the installation/setup through the following examples:
+
+- [Multi-Turn Q&A Example](#multi-turn-qa-example)
+- [Fork Example](#fork-example)
+- [Benchmark Shortfin](#benchmark-shortfin-w-sglang-bench_serving-script)
+
+## Multi-Turn Q&A example
+
+Now that we have sglang installed, we can run an example to show a multi-turn
+Q&A flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html):
+
+### Open python interpreter
+
+```bash
+python
+```
+
+### Run example
+
+You can copy and paste the following example into your interpreter:
+
+```python
+import sglang as sgl
+
+from sglang.lang.chat_template import get_chat_template
+
+backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000") # Change base_url if running at a different address
+
+sgl.set_default_backend(backend)
+
+@sgl.function
+def multi_turn_question(s, question_1, question_2):
+ s += sgl.user(question_1)
+ s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+ s += sgl.user(question_2)
+ s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))
+
+state = multi_turn_question.run(question_1="Name the capital city of the USA.", question_2="The Smithsonian is in this location.")
+
+for m in state.messages():
+ print(m["role"], m["content"])
+```
+
+### Shortfin example output
+
+You should see an output similar to this:
+
+```text
+========== single ==========
+
+user : Name the capital city of the USA
+assistant : The capital city of the United States of America is Washington, D.C. (short for District of Columbia).
+user : The Smithsonian is in this location.
+assistant : The Smithsonian Institution is indeed located in Washington, D.C. and is one of the world's largest and most comprehensive museums and research complexes.
+```
+
+## Fork example
+
+Now that we have sglang installed, we can run an example to show a `fork`
+flow with the SGLang [Frontend Language](https://sgl-project.github.io/frontend/frontend.html):
+
+### Open python interpreter
+
+```bash
+python
+```
+
+### Run example
+
+You can copy and paste the following example into your interpreter:
+
+```python
+import sglang as sgl
+
+from sglang.lang.chat_template import get_chat_template
+
+backend = sgl.Shortfin(chat_template=get_chat_template("llama-3-instruct"), base_url="http://localhost:8000") # Change base_url if running at a different address
+
+sgl.set_default_backend(backend)
+
+@sgl.function
+def tip_suggestion(s):
+ s += (
+ "Here are two tips for staying healthy: "
+ "1. Balanced Diet. 2. Regular Exercise.\n\n"
+ )
+ forks = s.fork(2)
+ for i, f in enumerate(forks):
+ f += f"Now, expand tip {i+1} into a paragraph:\n"
+ f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")
+ s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
+ s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
+ s += "In summary" + sgl.gen("summary")
+
+state = tip_suggestion.run()
+
+print(state.text())
+```
+
+### Shortfin example output
+
+You should see an output similar to this:
+
+```text
+Here are two tips for staying healthy: 1. Balanced Diet. 2. Regular Exercise.
+
+Tip 1:A balanced diet is important for maintaining good health. It should
+include a variety of foods from all the major food groups, such as fruits,
+vegetables, grains, proteins, and dairy. Eating a balanced diet can help
+prevent chronic diseases such as heart disease, diabetes, and obesity.
+
+Now, expand tip 2 into a paragraph:
+Regular exercise is also important for maintaining good health. It can help
+improve cardiovascular health, strengthen muscles and bones, and reduce the
+risk of chronic diseases. Exercise can also help improve mental health by
+reducing stress and anxiety. It is recommended that adults get at least 150
+minutes of moderate-intensity exercise or 75 minutes of vigorous-intensity
+exercise per week.
+
+Now, combine the two paragraphs into a single paragraph:
+A balanced diet and regular exercise are both important for maintaining good
+health. A balanced diet should include a variety of foods from all the major
+food groups, such as fruits, vegetables, grains, proteins, and dairy.
+Eating a balanced diet can help prevent chronic diseases such as heart disease,
+diabetes, and obesity. Regular exercise is also important for maintaining good
+health. It can help improve cardiovascular health, strengthen muscles and bones,
+and reduce the risk of chronic diseases. Exercise can also help improve mental
+health by reducing stress and anxiety. It is recommended that
+
+Tip 2:Regular exercise is important for maintaining a healthy body and mind.
+It can help improve cardiovascular health, strengthen muscles and bones,
+and reduce the risk of chronic diseases such as diabetes and heart disease.
+Additionally, exercise has been shown to improve mood, reduce stress,
+and increase overall well-being. It is recommended that adults engage in
+at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of
+vigorous-intensity aerobic activity per week, as well as strength training
+exercises at least two days per week.
+
+In summary, a balanced diet and regular exercise are both essential for
+maintaining good health. A balanced diet should include a variety of foods from
+all the major food groups, while regular exercise can help improve
+cardiovascular health, strengthen muscles and bones, reduce the risk of
+chronic diseases, and improve mental health. It is recommended that adults
+engage in at least 150 minutes of moderate-intensity aerobic activity or
+75 minutes of vigorous-intensity aerobic activity per week,
+as well as strength training exercises at least two days per week.
+```
+
+## Benchmark shortfin w/ sglang `bench_serving` script
+
+We can obtain benchmarking metrics using the `bench_serving` script
+provided by SGLang:
+
+**NOTE: Change `--base-url` if running at a different address**
+
+```bash
+python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer /path/to/tokenizer/dir --request-rate 1
+```
+
+More metrics are captured, but the most relevant are the following:
+
+- E2E Latency
+- TTFT (Time to First Token)
+- TPOT (Time per Output Token)
+- ITL (Inter-Token Latency)
+- Request Throughput
+- Benchmark Duration
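+
+For intuition, here is a minimal sketch of how these latency metrics are
+typically derived from per-request token timings. The timestamps below are
+made up, and the exact bookkeeping inside `bench_serving` may differ:
+
+```python
+# Hypothetical arrival times (in seconds) of each output token for one request.
+request_start = 0.0
+token_times = [2.15, 2.31, 2.48, 2.66]
+
+ttft = token_times[0] - request_start          # Time to First Token
+e2e_latency = token_times[-1] - request_start  # End-to-End Latency
+itl = [b - a for a, b in zip(token_times, token_times[1:])]   # Inter-Token Latencies
+tpot = (e2e_latency - ttft) / (len(token_times) - 1)          # Time per Output Token (excl. 1st)
+
+print(f"TTFT={ttft:.2f}s E2E={e2e_latency:.2f}s TPOT={tpot:.3f}s ITL={itl}")
+```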
+
+When the benchmark run completes, you should see output similar to this:
+
+```text
+============ Serving Benchmark Result ============
+Backend: shortfin
+Traffic request rate: 1.0
+Successful requests: 10
+Benchmark duration (s): 427.91
+Total input tokens: 1960
+Total generated tokens: 2774
+Total generated tokens (retokenized): 63
+Request throughput (req/s): 0.02
+Input token throughput (tok/s): 4.58
+Output token throughput (tok/s): 6.48
+----------------End-to-End Latency----------------
+Mean E2E Latency (ms): 416268.77
+Median E2E Latency (ms): 417159.14
+---------------Time to First Token----------------
+Mean TTFT (ms): 292404.29
+Median TTFT (ms): 365989.01
+P99 TTFT (ms): 367325.63
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 1359.41
+Median TPOT (ms): 163.96
+P99 TPOT (ms): 6316.12
+---------------Inter-token Latency----------------
+Mean ITL (ms): 2238.99
+Median ITL (ms): 958.75
+P99 ITL (ms): 2719.50
+==================================================
+```
diff --git a/docs/user_guide.md b/docs/user_guide.md
new file mode 100644
index 000000000..a0415eb63
--- /dev/null
+++ b/docs/user_guide.md
@@ -0,0 +1,116 @@
+# SHARK User Guide
+
+These instructions cover the usage of the latest stable release of SHARK. For a more bleeding-edge release, please install the [nightly releases](nightly_releases.md).
+
+## Prerequisites
+
+Our current user guide requires that you have:
+- Access to a computer with an AMD Instinct™ MI300X Series Accelerator installed
+- A compatible version of Linux and ROCm installed on that computer (see the [ROCm compatibility matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html))
+
+## Set up Environment
+
+This section will help you install Python and set up a Python environment with venv.
+
+We officially support Python versions 3.11, 3.12, and 3.13.
+
+The rest of this guide assumes you are using Python 3.11.
+
+### Install Python
+To install Python 3.11 on Ubuntu:
+
+```bash
+sudo apt install python3.11 python3.11-dev python3.11-venv
+
+which python3.11
+# /usr/bin/python3.11
+```
+
+### Create a Python Environment
+
+Set up your Python environment with the following commands:
+
+```bash
+# Set up a virtual environment to isolate packages from other envs.
+python3.11 -m venv 3.11.venv
+source 3.11.venv/bin/activate
+
+# Optional: faster installation of torch with just CPU support.
+# See other options at https://pytorch.org/get-started/locally/
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+```
+
+## Install SHARK and its dependencies
+
+```bash
+pip install shark-ai[apps]
+```
+
+> [!TIP]
+> To switch from the stable release channel to the nightly release channel,
+> see [`nightly_releases.md`](./nightly_releases.md).
+
+### Test the installation
+
+```bash
+python -m shortfin_apps.sd.server --help
+```
+
+## Quickstart
+
+### Run the SDXL Server
+
+Run the [SDXL Server](../shortfin/python/shortfin_apps/sd/README.md#Start-SDXL-Server)
+
+### Run the SDXL Client
+
+```bash
+python -m shortfin_apps.sd.simple_client --interactive
+```
+
+Congratulations! At this point you can experiment with the server and client to fit your use case.
+
+### Note: Server implementation scope
+
+The SDXL server is not designed to handle extremely large client batches. For heavy workloads, multiple service instances would normally be composed behind a load balancer so that each instance is fed requests at an optimal rate. For most cases outside of large-scale deployments, the server's internal batching and load balancing are sufficient.
+
+### Update flags
+
+Please see `--help` for both the server and client for full usage instructions. Here's a quick snapshot, with example invocations after each table.
+
+#### Server options
+
+| Flags | Options | Notes |
+|---|---|---|
+| --host HOST | | |
+| --port PORT | | server port |
+| --root-path ROOT_PATH | | |
+| --timeout-keep-alive | | |
+| --device | local-task,hip,amdgpu | amdgpu only supported in this release |
+| --target | gfx942,gfx1100 | gfx942 only supported in this release |
+| --device_ids | | |
+| --tokenizers | | |
+| --model_config | | |
+| --workers_per_device | | |
+| --fibers_per_device | | |
+| --isolation | per_fiber, per_call, none | |
+| --show_progress | | |
+| --trace_execution | | |
+| --amdgpu_async_allocations | | |
+| --splat | | |
+| --build_preference | compile,precompiled | |
+| --compile_flags | | |
+| --flagfile FLAGFILE | | |
+| --artifacts_dir ARTIFACTS_DIR | | Where to store cached artifacts from the Cloud |
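+
+For illustration, a server launch on an MI300X might combine a few of these
+flags as shown below. The values are examples rather than required settings,
+and `--help` remains the authoritative reference:
+
+```bash
+python -m shortfin_apps.sd.server \
+  --device amdgpu --target gfx942 \
+  --port 8000 \
+  --build_preference precompiled
+```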
+
+#### Client options
+
+| Flags | Options |
+|---|---|
+| --file | |
+| --reps | |
+| --save | Whether to save image generated by the server |
+| --outputdir | output directory to store images generated by SDXL |
+| --steps | |
+| --interactive | |
+| --port | port to interact with server |
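+
+Combining the flags above, a non-interactive client run might look like the
+following. The values are illustrative and the exact argument forms may differ,
+so check `--help` before relying on them:
+
+```bash
+python -m shortfin_apps.sd.simple_client \
+  --steps 20 \
+  --reps 1 \
+  --save \
+  --outputdir ./sdxl_outputs
+```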
diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt
deleted file mode 100644
index 86d0267a5..000000000
--- a/libshortfin/CMakeLists.txt
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-cmake_minimum_required(VERSION 3.28)
-
-if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR)
- message(
- FATAL_ERROR
- "Do not build in-source. Please remove CMakeCache.txt and the CMakeFiles/ directory. Then build out-of-source."
- )
-endif()
-
-project(
- "libshortfin"
- VERSION 0.9
- LANGUAGES C CXX)
-
-include(CMakeDependentOption)
-
-set(SOVERSION 1)
-
-set(CMAKE_C_STANDARD 11)
-set(CMAKE_CXX_STANDARD 20)
-# https://discourse.cmake.org/t/cmake-3-28-cmake-cxx-compiler-clang-scan-deps-notfound-not-found/9244/3
-set(CMAKE_CXX_SCAN_FOR_MODULES 0)
-
-# Problems with linking libfmt without PIC.
-# Turn on PIC on non windows targets.
-if(NOT WIN32)
- set(CMAKE_POSITION_INDEPENDENT_CODE ON)
-endif()
-
-# build options
-option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF)
-option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON)
-option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON)
-
-set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source")
-
-# Options for building static or dynamic libraries.
-option(SHORTFIN_BUILD_STATIC "Builds static libraries" OFF)
-option(SHORTFIN_BUILD_DYNAMIC "Builds dynamic libraries" ON)
-cmake_dependent_option(SHORTFIN_LINK_DYNAMIC "Links internal binaries against static libshortfin.a" ON "SHORTFIN_BUILD_DYNAMIC" OFF)
-if(NOT SHORTFIN_BUILD_STATIC AND NOT SHORTFIN_BUILD_DYNAMIC)
- message(FATAL_ERROR "One of SHORTFIN_BUILD_STATIC or SHORTFIN_BUILD_DYNAMIC must be ON")
-endif()
-message(STATUS "Shortfin build static = ${SHORTFIN_BUILD_STATIC}, dynamic = ${SHORTFIN_BUILD_DYNAMIC}")
-if(SHORTFIN_LINK_DYNAMIC)
- message(STATUS "Dynamic linking to shortfin")
- set(SHORTFIN_LINK_LIBRARY_NAME "shortfin")
-else()
- message(STATUS "Static linking to shortfin-static")
- set(SHORTFIN_LINK_LIBRARY_NAME "shortfin-static")
-endif()
-
-# Enabling ASAN. Note that this will work best if building in a completely
-# bundled fashion and with an ASAN rigged CPython. Otherwise, various LD_PRELOAD
-# hacks are needed. This is merely a develope convenience: people are more
-# than welcome to set flags themselves.
-option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF)
-if(SHORTFIN_ENABLE_ASAN)
- add_compile_options(-fsanitize=address)
- add_link_options(-fsanitize=address)
-
- # Enable more ASAN checks.
- add_compile_definitions(IREE_SANITIZER_ADDRESS)
-endif()
-
-option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON)
-message(STATUS "libshortfin supported systems:")
-if(SHORTFIN_SYSTEMS_AMDGPU)
- message(STATUS " - AMD GPU")
- add_compile_definitions("SHORTFIN_HAVE_AMDGPU")
-endif()
-message(STATUS " - Host")
-
-include(FetchContent)
-
-# Includes.
-list(APPEND CMAKE_MODULE_PATH
- ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/
-)
-include(shortfin_library)
-
-# Dependencies.
-
-if(SHORTFIN_BUNDLE_DEPS)
- ## fmt
- FetchContent_Declare(
- fmt
- GIT_REPOSITORY https://github.com/fmtlib/fmt.git
- GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog)
- )
-
- ## spdlog
- # We build fmt from source instead, because we also use fmt.
- set(SPDLOG_FMT_EXTERNAL ON)
- FetchContent_Declare(
- spdlog
- GIT_REPOSITORY https://github.com/gabime/spdlog.git
- GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1
- )
-
- ## xtl: required for xtensor
- FetchContent_Declare(
- xtl
- GIT_REPOSITORY https://github.com/xtensor-stack/xtl.git
- GIT_TAG a7c1c5444dfc57f76620391af4c94785ff82c8d6 # v0.7.7
- )
-
- ## xtensor
- FetchContent_Declare(
- xtensor
- GIT_REPOSITORY https://github.com/xtensor-stack/xtensor.git
- GIT_TAG 3634f2ded19e0cf38208c8b86cea9e1d7c8e397d # v0.25.0
- )
-
- FetchContent_MakeAvailable(fmt spdlog xtl xtensor)
-else()
- find_package(spdlog)
- find_package(xtensor)
-endif()
-
-## iree runtime
-
-if (NOT SHORTFIN_IREE_SOURCE_DIR AND SHORTFIN_BUNDLE_DEPS)
- FetchContent_Declare(
- iree
- GIT_REPOSITORY https://github.com/iree-org/iree.git
- GIT_TAG candidate-20240904.1006
- # TODO: We shouldn't have to pull googletest when we are not building tests.
- # This needs to be fixed with IREE.
- GIT_SUBMODULES "third_party/benchmark third_party/cpuinfo third_party/flatcc third_party/hip-build-deps third_party/googletest"
- GIT_SHALLOW TRUE
- )
- FetchContent_GetProperties(iree)
- if(NOT iree_POPULATED)
- FetchContent_Populate(iree)
- endif()
- set(SHORTFIN_IREE_SOURCE_DIR ${iree_SOURCE_DIR})
-endif()
-
-if(SHORTFIN_IREE_SOURCE_DIR)
- set(IREE_BUILD_COMPILER OFF)
- set(IREE_BUILD_TESTS OFF)
- set(IREE_BUILD_SAMPLES OFF)
- # Disable missing submodules error because we are only building the runtime.
- set(IREE_ERROR_ON_MISSING_SUBMODULES OFF)
- # Only enable local_sync/local_task/hip drivers for now.
- set(IREE_HAL_DRIVER_DEFAULTS OFF)
- set(IREE_HAL_DRIVER_LOCAL_SYNC ON)
- set(IREE_HAL_DRIVER_LOCAL_TASK ON)
- if(SHORTFIN_SYSTEMS_AMDGPU)
- set(IREE_HAL_DRIVER_HIP ON)
- endif()
- add_subdirectory(${SHORTFIN_IREE_SOURCE_DIR} shortfin_iree SYSTEM EXCLUDE_FROM_ALL)
-else()
- # Try to find iree using find_package
- find_package(IREERuntime)
-endif()
-
-# tests
-
-if(SHORTFIN_BUILD_TESTS)
- if (NOT SHORTFIN_IREE_SOURCE_DIR)
- # For now we use gtest shipped alongside with IREE.
- FetchContent_Declare(
- googletest
- URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
- )
- # For Windows: Prevent overriding the parent project's compiler/linker settings
- set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
- FetchContent_MakeAvailable(googletest)
- endif()
- include(GoogleTest)
- enable_testing()
-endif()
-
-
-add_subdirectory(src)
-
-if(SHORTFIN_BUILD_PYTHON_BINDINGS)
- find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
- add_subdirectory(bindings/python)
- set(SHORTFIN_PYTHON_CPP_PREBUILT "TRUE") # See setup.py.
- configure_file(setup.py setup.py @ONLY)
- configure_file(pyproject.toml pyproject.toml COPYONLY)
-endif()
diff --git a/libshortfin/README.md b/libshortfin/README.md
deleted file mode 100644
index 435dcbb87..000000000
--- a/libshortfin/README.md
+++ /dev/null
@@ -1,73 +0,0 @@
-# libshortfin - SHARK C++ inference library
-
-## Dev Builds
-
-Library dependencies:
-
-* [spdlog](https://github.com/gabime/spdlog)
-* [xtensor](https://github.com/xtensor-stack/xtensor)
-* [iree runtime](https://github.com/iree-org/iree)
-
-On recent Ubuntu, the primary dependencies can be satisfied via:
-
-```
-apt install libspdlog-dev libxtensor-dev
-```
-
-CMake must be told how to find the IREE runtime, either from a distribution
-tarball, or local build/install dir. For a local build directory, pass:
-
-```
-# Assumes that the iree-build directory is adjacent to this repo.
--DCMAKE_PREFIX_PATH=$(pwd)/../../iree-build/lib/cmake/IREE
-```
-
-One liner recommended CMake command (note that CMAKE_LINKER_TYPE requires
-cmake>=3.29):
-
-```
-cmake -GNinja -S. -Bbuild \
- -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
- -DCMAKE_LINKER_TYPE=LLD \
- -DCMAKE_PREFIX_PATH=$(pwd)/../../iree-build/lib/cmake/IREE
-```
-
-## Building Python Bindings
-
-If using a Python based development flow, there are two options:
-
-1. `pip install -v .` to build and install the library (TODO: Not yet implemented).
-2. Build with cmake and `-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON` and then
- from the `build/` directory, run `pip install -v -e .` to create an
- editable install that will update as you build the C++ project.
-
-If predominantly developing with a C++ based flow, the second option is
-recommended. Your python install should track any source file changes or
-builds without further interaction. Re-installing will be necessary if package
-structure changes significantly.
-
-## Running Tests
-
-The project uses a combination of ctest for native C++ tests and pytest. Much
-of the functionality is only tested via the Python tests, using the
-`_shortfin.lib` internal implementation directly. In order to run these tests,
-you must have installed the Python package as per the above steps.
-
-Which style of test is used is pragmatic and geared at achieving good test
-coverage with a minimum of duplication. Since it is often much more expensive
-to build native tests of complicated flows, many things are only tested via
-Python. This does not preclude having other language bindings later, but it
-does mean that the C++ core of the library must always be built with the
-Python bindings to test the most behavior. Given the target of the project,
-this is not considered to be a significant issue.
-
-# Production Library Building
-
-In order to build a production library, additional build steps are typically
-recommended:
-
-* Compile all deps with the same compiler/linker for LTO compatibility
-* Provide library dependencies manually and compile them with LTO
-* Compile dependencies with `-fvisibility=hidden`
-* Enable LTO builds of libshortfin
-* Set flags to enable symbol versioning
diff --git a/libshortfin/bindings/python/CMakeLists.txt b/libshortfin/bindings/python/CMakeLists.txt
deleted file mode 100644
index f1e999783..000000000
--- a/libshortfin/bindings/python/CMakeLists.txt
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-# libshortfin publishes multiple python packages: - _shortfin: Trampoline
-# __init__.py which looks at environment variables to load an appropriate native
-# library. - _shortfin_default.lib: Native library as a default, uninstrumented
-# build. - _shortfin_tracing.lib: Native library with tracing enabled (TODO). -
-# Others.
-
-# nanobind
-FetchContent_Declare(
- nanobind
- GIT_REPOSITORY https://github.com/wjakob/nanobind.git
- GIT_TAG 9641bb7151f04120013b812789b3ebdfa7e7324f # 2.1.0
-)
-FetchContent_MakeAvailable(nanobind)
-
-nanobind_add_module(shortfin_python_extension NB_STATIC LTO
- array_binding.cc
- lib_ext.cc
-)
-
-set_target_properties(shortfin_python_extension
- PROPERTIES OUTPUT_NAME "_shortfin_default/lib")
-
-target_link_libraries(shortfin_python_extension
- PRIVATE ${SHORTFIN_LINK_LIBRARY_NAME}
-)
-
-nanobind_add_stub(
- shortfin_python_extension_stub
- MODULE _shortfin_default.lib
- OUTPUT _shortfin_default/lib.pyi
- DEPENDS shortfin_python_extension
-)
diff --git a/libshortfin/bindings/python/array_binding.cc b/libshortfin/bindings/python/array_binding.cc
deleted file mode 100644
index 608aa4944..000000000
--- a/libshortfin/bindings/python/array_binding.cc
+++ /dev/null
@@ -1,297 +0,0 @@
-// Copyright 2024 Advanced Micro Devices, Inc.
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "./lib_ext.h"
-#include "./utils.h"
-#include "shortfin/array/api.h"
-
-using namespace shortfin::array;
-
-namespace shortfin::python {
-
-namespace {
-static const char DOCSTRING_ARRAY_COPY_FROM[] =
- R"(Copy contents from a source array to this array.
-
-Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
-)";
-
-static const char DOCSTRING_ARRAY_COPY_TO[] =
- R"(Copy contents this array to a destination array.
-
-Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
-)";
-
-static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value.
-
-Equivalent to `array.storage.fill(pattern)`.
-)";
-
-static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents.
-
-Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`.
-The returned object is a context manager that will close on exit.
-
-Assigning `storage.data = array.array("f", [1.0])` will copy that raw data
-from the source object using the buffer protocol. The source data must be
-less than or equal to the length of the storage object. Note that the entire
-storage is mapped as write-only/discardable, and writing less than the storage
-bytes leaves any unwritten contents in an undefined state.
-
-As with `map`, this will only work on buffers that are host visible, which
-includes all host buffers and device buffers created with the necessary access.
-)";
-
-static const char DOCSTRING_STORAGE_COPY_FROM[] =
- R"(Copy contents from a source storage to this array.
-
-This operation executes asynchronously and the effect will only be visible
-once the execution scope has been synced to the point of mutation.
-)";
-
-static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value.
-
-Takes as argument any value that can be interpreted as a buffer with the Python
-buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly
-with the pattern.
-
-This operation executes asynchronously and the effect will only be visible
-once the execution scope has been synced to the point of mutation.
-)";
-
-static const char DOCSTRING_STORAGE_MAP[] =
- R"(Create a mapping of the buffer contents in host memory.
-
-Support kwargs of:
-
-read: Enables read access to the mapped memory.
-write: Enables write access to the mapped memory and will flush upon close
- (for non-unified memory systems).
-discard: Indicates that the entire memory map should be treated as if it will
- be overwritten. Initial contents will be undefined.
-
-Mapping memory for access from the host requires a compatible buffer that has
-been created with host visibility (which includes host buffers).
-
-The returned mapping object is a context manager that will close/flush on
-exit. Alternatively, the `close()` method can be invoked explicitly.
-)";
-
-// Does in-place creation of a mapping object and stores a pointer to the
-// contained array::mapping C++ object.
-py::object CreateMappingObject(mapping **out_cpp_mapping) {
- py::object py_mapping = py::inst_alloc(py::type());
- mapping *cpp_mapping = py::inst_ptr(py_mapping);
- new (cpp_mapping) mapping();
- py::inst_mark_ready(py_mapping);
- *out_cpp_mapping = cpp_mapping;
- return py_mapping;
-}
-
-} // namespace
-
-void BindArray(py::module_ &m) {
- py::class_(m, "DType")
- .def_prop_ro("is_boolean", &DType::is_boolean)
- .def_prop_ro("is_integer", &DType::is_integer)
- .def_prop_ro("is_float", &DType::is_float)
- .def_prop_ro("is_complex", &DType::is_complex)
- .def_prop_ro("bit_count", &DType::bit_count)
- .def_prop_ro("is_byte_aligned", &DType::is_byte_aligned)
- .def_prop_ro("dense_byte_count", &DType::dense_byte_count)
- .def("is_integer_bitwidth", &DType::is_integer_bitwidth)
- .def(py::self == py::self)
- .def("__repr__", &DType::name);
-
-#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident();
-#include "shortfin/array/dtypes.inl"
-#undef SHORTFIN_DTYPE_HANDLE
-
- // storage
- py::class_(m, "storage")
- .def("__sfinv_marshal__",
- [](device_array *self, py::capsule inv_capsule, int barrier) {
- auto *inv =
- static_cast(inv_capsule.data());
- static_cast(self)
- ->AddAsInvocationArgument(
- inv, static_cast(barrier));
- })
- .def_static(
- "allocate_host",
- [](local::ScopedDevice &device, iree_device_size_t allocation_size) {
- return storage::allocate_host(device, allocation_size);
- },
- py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
- .def_static(
- "allocate_device",
- [](local::ScopedDevice &device, iree_device_size_t allocation_size) {
- return storage::allocate_device(device, allocation_size);
- },
- py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
- .def(
- "fill",
- [](storage &self, py::handle buffer) {
- Py_buffer py_view;
- int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
- if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
- throw py::python_error();
- }
- PyBufferReleaser py_view_releaser(py_view);
- self.fill(py_view.buf, py_view.len);
- },
- py::arg("pattern"), DOCSTRING_STORAGE_FILL)
- .def(
- "copy_from", [](storage &self, storage &src) { self.copy_from(src); },
- py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM)
- .def(
- "map",
- [](storage &self, bool read, bool write, bool discard) {
- int access = 0;
- if (read) access |= IREE_HAL_MEMORY_ACCESS_READ;
- if (write) access |= IREE_HAL_MEMORY_ACCESS_WRITE;
- if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD;
- if (!access) {
- throw std::invalid_argument(
- "One of the access flags must be set");
- }
- mapping *cpp_mapping = nullptr;
- py::object py_mapping = CreateMappingObject(&cpp_mapping);
- self.map_explicit(
- *cpp_mapping,
- static_cast(access));
- return py_mapping;
- },
- py::kw_only(), py::arg("read") = false, py::arg("write") = false,
- py::arg("discard") = false, DOCSTRING_STORAGE_MAP)
- // The 'data' prop is a short-hand for accessing the backing storage
- // in a one-shot manner (as for reading or writing). Getting the attribute
- // will map for read and return a memory view (equiv to map(read=True)).
- // On write, it will accept an object implementing the buffer protocol
- // and write/discard the backing storage.
- .def_prop_rw(
- "data",
- [](storage &self) {
- mapping *cpp_mapping = nullptr;
- py::object py_mapping = CreateMappingObject(&cpp_mapping);
- *cpp_mapping = self.map_read();
- return py_mapping;
- },
- [](storage &self, py::handle buffer_obj) {
- PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE);
- auto dest_data = self.map_write_discard();
- if (src_info.view().len > dest_data.size()) {
- throw std::invalid_argument(
- fmt::format("Cannot write {} bytes into buffer of {} bytes",
- src_info.view().len, dest_data.size()));
- }
- std::memcpy(dest_data.data(), src_info.view().buf,
- src_info.view().len);
- },
- DOCSTRING_STORAGE_DATA)
- .def(py::self == py::self)
- .def("__repr__", &storage::to_s);
-
- // mapping
- auto mapping_class = py::class_(m, "mapping");
- mapping_class.def("close", &mapping::reset)
- .def_prop_ro("valid", [](mapping &self) -> bool { return self; })
- .def("__enter__", [](py::object self_obj) { return self_obj; })
- .def(
- "__exit__",
- [](mapping &self, py::handle exc_type, py::handle exc_value,
- py::handle exc_tb) { self.reset(); },
- py::arg("exc_type").none(), py::arg("exc_value").none(),
- py::arg("exc_tb").none());
- struct MappingBufferHandler {
- int operator()(mapping &self, Py_buffer *view, int flags) {
- view->buf = self.data();
- view->len = self.size();
- view->readonly = !self.writable();
- view->itemsize = 1;
- view->format = (char *)"B"; // Byte
- view->ndim = 1;
- view->shape = nullptr;
- view->strides = nullptr;
- view->suboffsets = nullptr;
- view->internal = nullptr;
- return 0;
- }
- };
- BindBufferProtocol(mapping_class);
-
- // base_array and subclasses
- py::class_(m, "base_array")
- .def_prop_ro("dtype", &base_array::dtype)
- .def_prop_ro("shape", &base_array::shape);
- py::class_(m, "device_array")
- .def("__init__", [](py::args, py::kwargs) {})
- .def_static("__new__",
- [](py::handle py_type, class storage storage,
- std::span shape, DType dtype) {
- return custom_new_keep_alive(
- py_type, /*keep_alive=*/storage.scope(), storage, shape,
- dtype);
- })
- .def_static("__new__",
- [](py::handle py_type, local::ScopedDevice &device,
- std::span shape, DType dtype) {
- return custom_new_keep_alive(
- py_type, /*keep_alive=*/device.scope(),
- device_array::for_device(device, shape, dtype));
- })
- .def("__sfinv_marshal__",
- [](device_array *self, py::capsule inv_capsule, int barrier) {
- auto *inv =
- static_cast(inv_capsule.data());
- static_cast(self)
- ->AddAsInvocationArgument(
- inv, static_cast(barrier));
- })
- .def_static("for_device",
- [](local::ScopedDevice &device, std::span shape,
- DType dtype) {
- return custom_new_keep_alive(
- py::type(), /*keep_alive=*/device.scope(),
- device_array::for_device(device, shape, dtype));
- })
- .def_static("for_host",
- [](local::ScopedDevice &device, std::span shape,
- DType dtype) {
- return custom_new_keep_alive(
- py::type(), /*keep_alive=*/device.scope(),
- device_array::for_host(device, shape, dtype));
- })
- .def("for_transfer",
- [](device_array &self) {
- return custom_new_keep_alive(
- py::type(),
- /*keep_alive=*/self.device().scope(), self.for_transfer());
- })
- .def_prop_ro("device", &device_array::device,
- py::rv_policy::reference_internal)
- .def_prop_ro("storage", &device_array::storage,
- py::rv_policy::reference_internal)
-
- .def(
- "fill",
- [](py::handle_t self, py::handle buffer) {
- self.attr("storage").attr("fill")(buffer);
- },
- py::arg("pattern"), DOCSTRING_ARRAY_FILL)
- .def("copy_from", &device_array::copy_from, py::arg("source_array"),
- DOCSTRING_ARRAY_COPY_FROM)
- .def("copy_to", &device_array::copy_to, py::arg("dest_array"),
- DOCSTRING_ARRAY_COPY_TO)
- .def("__repr__", &device_array::to_s)
- .def("__str__", [](device_array &self) -> std::string {
- auto contents = self.contents_to_s();
- if (!contents) return "<>";
- return *contents;
- });
-}
-
-} // namespace shortfin::python
diff --git a/libshortfin/build_tools/cmake/shortfin_library.cmake b/libshortfin/build_tools/cmake/shortfin_library.cmake
deleted file mode 100644
index bae9fe115..000000000
--- a/libshortfin/build_tools/cmake/shortfin_library.cmake
+++ /dev/null
@@ -1,136 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-set(SHORTFIN_DEFAULT_COPTS
- # General clang and GCC options application to C and C++.
- $<$:
- -Wall
- -Werror
- >
-
- # General MSVC options applicable to C and C++.
- $<$:
- >
-)
-
-function(shortfin_public_library)
- cmake_parse_arguments(
- _RULE
- ""
- "NAME"
- "COMPONENTS"
- ${ARGN}
- )
- if(SHORTFIN_BUILD_STATIC)
- # Static library.
- shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS})
- add_library("${_RULE_NAME}-static" STATIC)
- target_link_libraries(
- "${_RULE_NAME}-static" PUBLIC ${_STATIC_COMPONENTS}
- )
- endif()
-
- if(SHORTFIN_BUILD_DYNAMIC)
- # Dylib library.
- shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS})
- add_library("${_RULE_NAME}" SHARED)
- target_compile_definitions("${_RULE_NAME}" INTERFACE _SHORTFIN_USING_DYLIB)
- target_link_libraries(
- "${_RULE_NAME}" PUBLIC ${_DYLIB_COMPONENTS}
- )
- endif()
-endfunction()
-
-function(shortfin_cc_component)
- cmake_parse_arguments(
- _RULE
- ""
- "NAME"
- "HDRS;SRCS;DEPS;COMPONENTS"
- ${ARGN}
- )
- if(SHORTFIN_BUILD_STATIC)
- # Static object library.
- set(_STATIC_OBJECTS_NAME "${_RULE_NAME}.objects")
- shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS})
- add_library(${_STATIC_OBJECTS_NAME} OBJECT)
- target_sources(${_STATIC_OBJECTS_NAME}
- PRIVATE
- ${_RULE_SRCS}
- ${_RULE_HDRS}
- )
- target_compile_options(${_STATIC_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS})
- target_link_libraries(${_STATIC_OBJECTS_NAME}
- PUBLIC
- _shortfin_defs
- ${_STATIC_COMPONENTS}
- ${_RULE_DEPS}
- )
- endif()
-
- if(SHORTFIN_BUILD_DYNAMIC)
- set(CMAKE_POSITION_INDEPENDENT_CODE ON)
- set(_DYLIB_OBJECTS_NAME "${_RULE_NAME}.dylib.objects")
- shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS})
- # Dylib object library.
- add_library(${_DYLIB_OBJECTS_NAME} OBJECT)
- target_sources(${_DYLIB_OBJECTS_NAME}
- PRIVATE
- ${_RULE_SRCS}
- ${_RULE_HDRS}
- )
- target_compile_options(${_DYLIB_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS})
- target_link_libraries(${_DYLIB_OBJECTS_NAME}
- PUBLIC
- _shortfin_defs
- ${_DYLIB_COMPONENTS}
- ${_RULE_DEPS}
- )
- set_target_properties(
- ${_DYLIB_OBJECTS_NAME} PROPERTIES
- CXX_VISIBILITY_PRESET hidden
- C_VISIBILITY_PRESET hidden
- VISIBILITY_INLINES_HIDDEN ON
- )
- target_compile_definitions(${_DYLIB_OBJECTS_NAME}
- PRIVATE _SHORTFIN_BUILDING_DYLIB)
- endif()
-endfunction()
-
-function(shortfin_components_to_static_libs out_static_libs)
- set(_LIBS ${ARGN})
- list(TRANSFORM _LIBS APPEND ".objects")
- set(${out_static_libs} ${_LIBS} PARENT_SCOPE)
-endfunction()
-
-function(shortfin_components_to_dynamic_libs out_dynamic_libs)
- set(_LIBS ${ARGN})
- list(TRANSFORM _LIBS APPEND ".dylib.objects")
- set(${out_dynamic_libs} "${_LIBS}" PARENT_SCOPE)
-endfunction()
-
-function(shortfin_gtest_test)
- cmake_parse_arguments(
- _RULE
- ""
- "NAME"
- "SRCS;DEPS"
- ${ARGN}
- )
-
- if(NOT SHORTFIN_BUILD_TESTS)
- return()
- endif()
-
- add_executable(${_RULE_NAME} ${_RULE_SRCS})
- target_link_libraries(${_RULE_NAME} PRIVATE
- ${_RULE_DEPS}
- ${SHORTFIN_LINK_LIBRARY_NAME}
- GTest::gmock
- GTest::gtest_main
- )
- gtest_discover_tests(${_RULE_NAME})
-endfunction()
diff --git a/libshortfin/build_tools/python_lsan_suppressions.txt b/libshortfin/build_tools/python_lsan_suppressions.txt
deleted file mode 100644
index cc768d575..000000000
--- a/libshortfin/build_tools/python_lsan_suppressions.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-leak:PyUnicode_New
-leak:_PyUnicodeWriter_Finish
diff --git a/libshortfin/pyproject.toml b/libshortfin/pyproject.toml
deleted file mode 100644
index f1b34c64a..000000000
--- a/libshortfin/pyproject.toml
+++ /dev/null
@@ -1,20 +0,0 @@
-[build-system]
-requires = [
- "cmake>=3.29",
- "setuptools>=61.0",
- "wheel",
- "ninja",
-]
-build-backend = "setuptools.build_meta"
-
-[tool.pytest.ini_options]
-addopts = [
- "-ra",
- "--import-mode=importlib",
-]
-markers = [
- "requires_amd_gpu: tests that require and AMD GPU (deselect with '-m \"not requires_amd_gpu\"')",
-]
-testpaths = [
- "tests",
-]
diff --git a/libshortfin/requirements-iree-compiler.txt b/libshortfin/requirements-iree-compiler.txt
deleted file mode 100644
index 34228ccf4..000000000
--- a/libshortfin/requirements-iree-compiler.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-# Keep in sync with IREE_REF in CI and GIT_TAG in CMakeLists.txt
--f https://iree.dev/pip-release-links.html
-iree-compiler==20240904.1006
diff --git a/libshortfin/setup.py b/libshortfin/setup.py
deleted file mode 100644
index b28b8d114..000000000
--- a/libshortfin/setup.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from distutils.core import setup, Extension
-import sys
-import shutil
-import subprocess
-import os
-from pathlib import Path
-from distutils.command.build import build as _build
-from setuptools.command.build_ext import build_ext as _build_ext
-from setuptools.command.build_py import build_py as _build_py
-
-
-# This file can be generated into the build directory to allow an arbitrary
-# CMake built version of the project to be installed into a venv for development.
-# This can be detected if the CPP_PREBUILT global contains the string
-# "TRUE", which will be the case if generated.
-CPP_PREBUILT = "@SHORTFIN_PYTHON_CPP_PREBUILT@"
-CPP_PREBUILT_SOURCE_DIR = "@libshortfin_SOURCE_DIR@"
-CPP_PREBUILT_BINARY_DIR = "@libshortfin_BINARY_DIR@"
-
-SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
-
-
-def is_cpp_prebuilt():
- return CPP_PREBUILT == "TRUE"
-
-
-if is_cpp_prebuilt():
- print("setup.py running in pre-built mode:", file=sys.stderr)
- SOURCE_DIR = Path(CPP_PREBUILT_SOURCE_DIR)
- BINARY_DIR = Path(CPP_PREBUILT_BINARY_DIR)
-else:
- print("setup.py running in cmake build mode:", file=sys.stderr)
- # setup.py is in the source directory.
- SOURCE_DIR = Path(SETUPPY_DIR)
- BINARY_DIR = Path(os.path.join(SETUPPY_DIR, "build", "b"))
-
-print(f" SOURCE_DIR = {SOURCE_DIR}", file=sys.stderr)
-print(f" BINARY_DIR = {BINARY_DIR}", file=sys.stderr)
-
-# Due to a quirk of setuptools, that package_dir map must only contain
-# paths relative to the directory containing setup.py. Why? No one knows.
-REL_SOURCE_DIR = SOURCE_DIR.relative_to(SETUPPY_DIR, walk_up=True)
-REL_BINARY_DIR = BINARY_DIR.relative_to(SETUPPY_DIR, walk_up=True)
-
-
-class CMakeExtension(Extension):
- def __init__(self, name, sourcedir=""):
- Extension.__init__(self, name, sources=[])
- self.sourcedir = os.path.abspath(sourcedir)
-
-
-class CustomBuild(_build):
- def run(self):
- self.run_command("build_py")
- self.run_command("build_ext")
- self.run_command("build_scripts")
-
-
-class NoopBuildExtension(_build_ext):
- def build_extension(self, ext):
- ...
-
- def copy_extensions_to_source(self, *args, **kwargs):
- ...
-
-
-def maybe_nuke_cmake_cache(cmake_build_dir):
- # From run to run under pip, we can end up with different paths to ninja,
- # which isn't great and will confuse cmake. Detect if the location of
- # ninja changes and force a cache flush.
- ninja_path = ""
- try:
- import ninja
- except ModuleNotFoundError:
- pass
- else:
- ninja_path = ninja.__file__
- expected_stamp_contents = f"{sys.executable}\n{ninja_path}"
-
- # In order to speed things up on CI and not rebuild everything, we nuke
- # the CMakeCache.txt file if the path to the Python interpreter changed.
- # Ideally, CMake would let us reconfigure this dynamically... but it does
- # not (and gets very confused).
- PYTHON_STAMP_FILE = os.path.join(cmake_build_dir, "python_stamp.txt")
- if os.path.exists(PYTHON_STAMP_FILE):
- with open(PYTHON_STAMP_FILE, "rt") as f:
- actual_stamp_contents = f.read()
- if actual_stamp_contents == expected_stamp_contents:
- # All good.
- return
-
- # Mismatch or not found. Clean it.
- cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt")
- if os.path.exists(cmake_cache_file):
- print("Removing CMakeCache.txt because Python version changed", file=sys.stderr)
- os.remove(cmake_cache_file)
-
- # And write.
- with open(PYTHON_STAMP_FILE, "wt") as f:
- f.write(expected_stamp_contents)
-
-
-class CMakeBuildPy(_build_py):
- def run(self):
- # The super-class handles the pure python build.
- super().run()
-
- # Build using cmake if not in prebuild mode.
- if not is_cpp_prebuilt():
-
- # Build extension using cmake.
- print("*****************************", file=sys.stderr)
- print("* Building libshortfin *", file=sys.stderr)
- print("*****************************", file=sys.stderr)
-
- cfg = os.getenv("SHORTFIN_CMAKE_BUILD_TYPE", "Release")
-
- CMAKE_BUILD_DIR = BINARY_DIR
-
- # Configure CMake.
- os.makedirs(BINARY_DIR, exist_ok=True)
- maybe_nuke_cmake_cache(CMAKE_BUILD_DIR)
- print(f"CMake build dir: {CMAKE_BUILD_DIR}", file=sys.stderr)
- cmake_args = [
- "-GNinja",
- "--log-level=VERBOSE",
- "-DSHORTFIN_BUNDLE_DEPS=ON",
- f"-DCMAKE_BUILD_TYPE={cfg}",
- "-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON",
- # TODO: This shouldn't be hardcoded... but shortfin doesn't
- # compile without it.
- "-DCMAKE_C_COMPILER=clang",
- "-DCMAKE_CXX_COMPILER=clang++",
- ]
-
- # Only do a from-scratch configure if not already configured.
- cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt")
- if not os.path.exists(cmake_cache_file):
- print(f"Configuring with: {cmake_args}", file=sys.stderr)
- subprocess.check_call(
- ["cmake", SOURCE_DIR] + cmake_args, cwd=CMAKE_BUILD_DIR
- )
- else:
- print(f"Not re-configing (already configured)", file=sys.stderr)
-
- # Build.
- subprocess.check_call(["cmake", "--build", "."], cwd=CMAKE_BUILD_DIR)
- print("Build complete.", file=sys.stderr)
-
- # We only take _shortfin_default from the build.
- target_dir = os.path.join(
- os.path.abspath(self.build_lib), "_shortfin_default"
- )
- print(f"Building in target: {target_dir}", file=sys.stderr)
- os.makedirs(target_dir, exist_ok=True)
- print("Copying build to target.", file=sys.stderr)
- if os.path.exists(target_dir):
- shutil.rmtree(target_dir)
- shutil.copytree(
- os.path.join(
- CMAKE_BUILD_DIR,
- "bindings",
- "python",
- "_shortfin_default",
- ),
- target_dir,
- symlinks=False,
- )
-
-
-PYTHON_SOURCE_DIR = REL_SOURCE_DIR / "bindings" / "python"
-PYTHON_BINARY_DIR = REL_BINARY_DIR / "bindings" / "python"
-
-# We need some directories to exist before setup.
-def populate_built_package(abs_dir):
- """Makes sure that a directory and __init__.py exist.
-
- This needs to unfortunately happen before any of the build process
- takes place so that setuptools can plan what needs to be built.
- We do this for any built packages (vs pure source packages).
- """
- os.makedirs(abs_dir, exist_ok=True)
- with open(os.path.join(abs_dir, "__init__.py"), "wt"):
- pass
-
-
-populate_built_package(os.path.join(PYTHON_BINARY_DIR / "_shortfin_default"))
-
-setup(
- name="shortfin",
- version="0.9",
- description="Shortfin native library implementation",
- author="SHARK Authors",
- packages=[
- "_shortfin",
- "_shortfin_default",
- # TODO: Conditionally map additional native library variants.
- "shortfin",
- ],
- zip_safe=False,
- package_dir={
- "_shortfin": str(PYTHON_SOURCE_DIR / "_shortfin"),
- "_shortfin_default": str(PYTHON_BINARY_DIR / "_shortfin_default"),
- # TODO: Conditionally map additional native library variants.
- "shortfin": str(PYTHON_SOURCE_DIR / "shortfin"),
- },
- ext_modules=[
- CMakeExtension("_shortfin_default.lib")
- # TODO: Conditionally map additional native library variants.
- ],
- cmdclass={
- "build": CustomBuild,
- "build_ext": NoopBuildExtension,
- "build_py": CMakeBuildPy,
- },
-)
diff --git a/libshortfin/src/CMakeLists.txt b/libshortfin/src/CMakeLists.txt
deleted file mode 100644
index 51dd5f87f..000000000
--- a/libshortfin/src/CMakeLists.txt
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-add_subdirectory(shortfin)
-
-# Common definitions exported from both static and dynamic libraries.
-add_library(_shortfin_defs INTERFACE)
-target_include_directories(
- _shortfin_defs INTERFACE $
- $)
-
-
-set(_INIT_INTERNAL_DEPS)
-if(SHORTFIN_SYSTEMS_AMDGPU)
- list(APPEND _INIT_INTERNAL_DEPS shortfin_systems_amdgpu)
-endif()
-
-shortfin_public_library(
- NAME
- shortfin
- COMPONENTS
- shortfin_array
- shortfin_local
- shortfin_support
- shortfin_systems_host
- ${_INIT_INTERNAL_DEPS}
-)
-
-set_target_properties(shortfin PROPERTIES VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} SOVERSION ${SOVERSION})
diff --git a/libshortfin/src/shortfin/CMakeLists.txt b/libshortfin/src/shortfin/CMakeLists.txt
deleted file mode 100644
index 1bea0003b..000000000
--- a/libshortfin/src/shortfin/CMakeLists.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-add_subdirectory(array)
-add_subdirectory(local)
-add_subdirectory(support)
diff --git a/libshortfin/src/shortfin/local/systems/CMakeLists.txt b/libshortfin/src/shortfin/local/systems/CMakeLists.txt
deleted file mode 100644
index effd15204..000000000
--- a/libshortfin/src/shortfin/local/systems/CMakeLists.txt
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions. See
-# https://llvm.org/LICENSE.txt for license information. SPDX-License-Identifier:
-# Apache-2.0 WITH LLVM-exception
-
-shortfin_cc_component(
- NAME
- shortfin_systems_host
- HDRS
- host.h
- SRCS
- host.cc
- COMPONENTS
- shortfin_local
- shortfin_support
- DEPS
- iree_hal_drivers_local_task_task_driver
- iree_hal_local_executable_loader
- iree_hal_local_executable_plugin
- iree_hal_local_executable_plugin_manager
- iree_hal_local_loaders_registration_registration
- iree_hal_local_local
- iree_task_api
- iree_task_task
-)
-
-shortfin_cc_component(
- NAME
- shortfin_systems_amdgpu
- HDRS
- amdgpu.h
- SRCS
- amdgpu.cc
- COMPONENTS
- shortfin_local
- shortfin_support
- DEPS
- iree_hal_drivers_hip_hip
-)
diff --git a/libshortfin/src/shortfin/local/systems/amdgpu.cc b/libshortfin/src/shortfin/local/systems/amdgpu.cc
deleted file mode 100644
index b12cb1a5f..000000000
--- a/libshortfin/src/shortfin/local/systems/amdgpu.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-// Copyright 2024 Advanced Micro Devices, Inc.
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "shortfin/local/systems/amdgpu.h"
-
-#include "shortfin/support/logging.h"
-
-namespace shortfin::local::systems {
-
-namespace {
-const std::string_view SYSTEM_DEVICE_CLASS = "amdgpu";
-const std::string_view LOGICAL_DEVICE_CLASS = "gpu";
-const std::string_view HAL_DRIVER_PREFIX = "hip";
-} // namespace
-
-AMDGPUSystemBuilder::AMDGPUSystemBuilder(iree_allocator_t host_allocator)
- : HostCPUSystemBuilder(host_allocator) {
- InitializeDefaultSettings();
- iree_hal_hip_device_params_initialize(&default_device_params_);
-}
-
-AMDGPUSystemBuilder::~AMDGPUSystemBuilder() = default;
-
-void AMDGPUSystemBuilder::InitializeDefaultSettings() {
- char *raw_dylib_path_env_cstr = std::getenv("IREE_HIP_DYLIB_PATH");
- if (raw_dylib_path_env_cstr) {
- std::string_view rest(raw_dylib_path_env_cstr);
- for (;;) {
- auto pos = rest.find(';');
- if (pos == std::string_view::npos) {
- hip_lib_search_paths.emplace_back(rest);
- break;
- }
- std::string_view first = rest.substr(0, pos);
- rest = rest.substr(pos + 1);
- hip_lib_search_paths.emplace_back(first);
- }
- }
-}
-
-void AMDGPUSystemBuilder::Enumerate() {
- if (hip_hal_driver_) return;
-
- iree_hal_hip_driver_options_t driver_options;
- iree_hal_hip_driver_options_initialize(&driver_options);
-
- // Search path.
- std::vector hip_lib_search_path_sv;
- hip_lib_search_path_sv.resize(hip_lib_search_paths.size());
- for (size_t i = 0; i < hip_lib_search_paths.size(); ++i) {
- hip_lib_search_path_sv[i].data = hip_lib_search_paths[i].data();
- hip_lib_search_path_sv[i].size = hip_lib_search_paths[i].size();
- }
- driver_options.hip_lib_search_paths = hip_lib_search_path_sv.data();
- driver_options.hip_lib_search_path_count = hip_lib_search_path_sv.size();
-
- SHORTFIN_THROW_IF_ERROR(iree_hal_hip_driver_create(
- IREE_SV("hip"), &driver_options, &default_device_params_,
- host_allocator(), hip_hal_driver_.for_output()));
-
- // Get available devices and filter into visible_devices_.
- iree_host_size_t available_devices_count = 0;
- iree::allocated_ptr raw_available_devices(
- host_allocator());
- SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices(
- hip_hal_driver_, host_allocator(), &available_devices_count,
- raw_available_devices.for_output()));
- for (iree_host_size_t i = 0; i < available_devices_count; ++i) {
- iree_hal_device_info_t *info = &raw_available_devices.get()[i];
- // TODO: Filter based on visibility list.
- visible_devices_.push_back(*info);
- logging::info("Enumerated visible AMDGPU device: {} ({})",
- to_string_view(visible_devices_.back().path),
- to_string_view(visible_devices_.back().name));
- }
-}
-
-SystemPtr AMDGPUSystemBuilder::CreateSystem() {
- auto lsys = std::make_shared(host_allocator());
- Enumerate();
- // TODO: Real NUMA awareness.
- lsys->InitializeNodes(1);
- lsys->InitializeHalDriver(SYSTEM_DEVICE_CLASS, hip_hal_driver_);
-
- // Initialize all visible GPU devices.
- for (size_t i = 0; i < visible_devices_.size(); ++i) {
- auto &it = visible_devices_[i];
- iree::hal_device_ptr device;
- SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id(
- hip_hal_driver_, it.device_id, 0, nullptr, host_allocator(),
- device.for_output()));
- lsys->InitializeHalDevice(std::make_unique(
- DeviceAddress(
- /*system_device_class=*/SYSTEM_DEVICE_CLASS,
- /*logical_device_class=*/LOGICAL_DEVICE_CLASS,
- /*hal_driver_prefix=*/HAL_DRIVER_PREFIX,
- /*instance_ordinal=*/i,
- /*queue_ordinal=*/0,
- /*instance_topology_address=*/{0}),
- std::move(device), /*node_affinity=*/0,
- /*node_locked=*/false));
- }
-
- // Initialize CPU devices if requested.
- if (cpu_devices_enabled) {
- // Delegate to the HostCPUSystemConfig to configure CPU devices.
- // This will need to become more complicated and should happen after
- // GPU configuration when mating NUMA nodes, etc.
- InitializeHostCPUDefaults();
- auto *driver = InitializeHostCPUDriver(*lsys);
- InitializeHostCPUDevices(*lsys, driver);
- }
-
- lsys->FinishInitialization();
- return lsys;
-}
-
-} // namespace shortfin::local::systems
diff --git a/libshortfin/src/shortfin/local/systems/amdgpu.h b/libshortfin/src/shortfin/local/systems/amdgpu.h
deleted file mode 100644
index b1a182a8d..000000000
--- a/libshortfin/src/shortfin/local/systems/amdgpu.h
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2024 Advanced Micro Devices, Inc.
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H
-#define SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H
-
-#include
-
-#include "iree/hal/drivers/hip/api.h"
-#include "shortfin/local/system.h"
-#include "shortfin/local/systems/host.h"
-#include "shortfin/support/api.h"
-#include "shortfin/support/iree_helpers.h"
-
-namespace shortfin::local::systems {
-
-// AMD GPU device subclass.
-class SHORTFIN_API AMDGPUDevice : public Device {
- public:
- using Device::Device;
-};
-
-// System configuration for some subset of AMD GPUs connected to the local
-// system. Note that this inherits from HostCPUSystemBuilder, allowing joint
-// configuration of a heterogenous CPU/GPU system. Depending on the specific
-// system, this can involve more than simple starting CPU drivers: datacenter
-// GPU systems have specific NUMA configurations that need to be mated.
-class SHORTFIN_API AMDGPUSystemBuilder : public HostCPUSystemBuilder {
- public:
- AMDGPUSystemBuilder(iree_allocator_t host_allocator);
- AMDGPUSystemBuilder() : AMDGPUSystemBuilder(iree_allocator_system()) {}
- ~AMDGPUSystemBuilder();
-
- // Triggers driver setup and initial device enumeration. No-op if already
- // done.
- void Enumerate();
-
- SystemPtr CreateSystem() override;
-
- // Settings.
- bool cpu_devices_enabled = false;
-
- // See iree_hal_hip_driver_options_t::hip_lib_search_paths. Each is either
- // a directory or "file:" prefixed path to a specific HIP dynamic library.
- // This is typically libamdhip64.so or amdhip64.dll.
- // If the environment variable IREE_HIP_DYLIB_PATH is present, then it is
- // split on ';' and each entry added here (for compatibility with IREE
- // tools).
- // Changing these paths after enumeration has no effect.
- std::vector hip_lib_search_paths;
-
- private:
- void InitializeDefaultSettings();
-
- // Valid at construction time.
- iree_hal_hip_device_params_t default_device_params_;
-
- // Valid post enumeration.
- iree::hal_driver_ptr hip_hal_driver_;
- std::vector visible_devices_;
-};
-
-} // namespace shortfin::local::systems
-
-#endif // SHORTFIN_LOCAL_SYSTEMS_AMDGPU_H
diff --git a/libshortfin/src/shortfin/local/systems/host.cc b/libshortfin/src/shortfin/local/systems/host.cc
deleted file mode 100644
index 311c13398..000000000
--- a/libshortfin/src/shortfin/local/systems/host.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-// Copyright 2024 Advanced Micro Devices, Inc.
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "shortfin/local/systems/host.h"
-
-#include "iree/hal/local/loaders/registration/init.h"
-#include "shortfin/support/iree_helpers.h"
-#include "shortfin/support/logging.h"
-
-namespace shortfin::local::systems {
-
-namespace {
-const std::string_view SYSTEM_DEVICE_CLASS = "host-cpu";
-const std::string_view LOGICAL_DEVICE_CLASS = "cpu";
-const std::string_view HAL_DRIVER_PREFIX = "local";
-} // namespace
-
-// -------------------------------------------------------------------------- //
-// HostCPUSystemBuilder
-// -------------------------------------------------------------------------- //
-
-HostCPUSystemBuilder::Deps::Deps(iree_allocator_t host_allocator) {
- iree_task_executor_options_initialize(&task_executor_options);
- iree_hal_task_device_params_initialize(&task_params);
- iree_task_topology_initialize(&task_topology_options);
-}
-
-HostCPUSystemBuilder::Deps::~Deps() {
- iree_task_topology_deinitialize(&task_topology_options);
- for (iree_host_size_t i = 0; i < loader_count; ++i) {
- iree_hal_executable_loader_release(loaders[i]);
- }
- if (device_allocator) {
- iree_hal_allocator_release(device_allocator);
- }
- if (executor) {
- iree_task_executor_release(executor);
- }
- if (plugin_manager) {
- iree_hal_executable_plugin_manager_release(plugin_manager);
- }
-}
-
-HostCPUSystemBuilder::HostCPUSystemBuilder(iree_allocator_t host_allocator)
- : HostSystemBuilder(host_allocator), host_cpu_deps_(host_allocator) {}
-
-HostCPUSystemBuilder::~HostCPUSystemBuilder() = default;
-
-void HostCPUSystemBuilder::InitializeHostCPUDefaults() {
- // Give it a default device allocator.
- if (!host_cpu_deps_.device_allocator) {
- SHORTFIN_THROW_IF_ERROR(iree_hal_allocator_create_heap(
- iree_make_cstring_view("local"), host_allocator(), host_allocator(),
- &host_cpu_deps_.device_allocator));
- }
-
- // And loaders.
- if (host_cpu_deps_.loader_count == 0) {
- SHORTFIN_THROW_IF_ERROR(iree_hal_create_all_available_executable_loaders(
- /*plugin_manager=*/nullptr, IREE_ARRAYSIZE(host_cpu_deps_.loaders),
- &host_cpu_deps_.loader_count, host_cpu_deps_.loaders,
- host_allocator()));
- }
-}
-
-SystemPtr HostCPUSystemBuilder::CreateSystem() {
- auto lsys = std::make_shared(host_allocator());
- // TODO: Real NUMA awareness.
- lsys->InitializeNodes(1);
- InitializeHostCPUDefaults();
- auto *driver = InitializeHostCPUDriver(*lsys);
- InitializeHostCPUDevices(*lsys, driver);
- lsys->FinishInitialization();
- return lsys;
-}
-
-iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) {
- // TODO: Kill these flag variants in favor of settings on the config
- // object.
- SHORTFIN_THROW_IF_ERROR(iree_task_executor_options_initialize_from_flags(
- &host_cpu_deps_.task_executor_options));
- // TODO: Do something smarter than pinning to NUMA node 0.
- SHORTFIN_THROW_IF_ERROR(iree_task_topology_initialize_from_flags(
- /*node_id=*/0, &host_cpu_deps_.task_topology_options));
-
- SHORTFIN_THROW_IF_ERROR(
- iree_task_executor_create(host_cpu_deps_.task_executor_options,
- &host_cpu_deps_.task_topology_options,
- host_allocator(), &host_cpu_deps_.executor));
-
- // Create the driver and save it in the System.
- iree::hal_driver_ptr driver;
- iree_hal_driver_t *unowned_driver;
- SHORTFIN_THROW_IF_ERROR(iree_hal_task_driver_create(
- /*identifier=*/
- {
- .data = HAL_DRIVER_PREFIX.data(),
- .size = HAL_DRIVER_PREFIX.size(),
- },
- &host_cpu_deps_.task_params, /*queue_count=*/1, &host_cpu_deps_.executor,
- host_cpu_deps_.loader_count, host_cpu_deps_.loaders,
- host_cpu_deps_.device_allocator, host_allocator(), driver.for_output()));
- unowned_driver = driver.get();
- lsys.InitializeHalDriver(SYSTEM_DEVICE_CLASS, std::move(driver));
- return unowned_driver;
-}
-
-void HostCPUSystemBuilder::InitializeHostCPUDevices(System &lsys,
- iree_hal_driver_t *driver) {
- iree_host_size_t device_info_count = 0;
- iree::allocated_ptr device_infos(host_allocator());
- SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices(
- driver, host_allocator(), &device_info_count, &device_infos));
-
- for (iree_host_size_t i = 0; i < device_info_count; ++i) {
- iree::hal_device_ptr device;
- iree_hal_device_info_t *it = &device_infos.get()[i];
- SHORTFIN_THROW_IF_ERROR(iree_hal_driver_create_device_by_id(
- driver, it->device_id, 0, nullptr, host_allocator(),
- device.for_output()));
- lsys.InitializeHalDevice(std::make_unique(
- DeviceAddress(
- /*system_device_class=*/SYSTEM_DEVICE_CLASS,
- /*logical_device_class=*/LOGICAL_DEVICE_CLASS,
- /*hal_driver_prefix=*/HAL_DRIVER_PREFIX,
- /*instance_ordinal=*/i,
- /*queue_ordinal=*/0,
- /*instance_topology_address=*/{0}),
- /*hal_device=*/std::move(device),
- /*node_affinity=*/0,
- /*node_locked=*/false));
- }
-}
-
-} // namespace shortfin::local::systems
diff --git a/libshortfin/tests/amdgpu_system_test.py b/libshortfin/tests/amdgpu_system_test.py
deleted file mode 100644
index cf04a7caf..000000000
--- a/libshortfin/tests/amdgpu_system_test.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-
-
-@pytest.mark.requires_amd_gpu
-def test_create_amd_gpu_system():
- from _shortfin import lib as sfl
-
- sc = sfl.local.amdgpu.SystemBuilder()
- ls = sc.create_system()
- print(f"LOCAL SYSTEM:", ls)
- for device_name in ls.device_names:
- print(f" DEVICE: {device_name} = {ls.device(device_name)}")
-
- print(ls.devices)
- print("Shutting down")
- ls.shutdown()
diff --git a/libshortfin/tests/api/array_test.py b/libshortfin/tests/api/array_test.py
deleted file mode 100644
index 729f091d8..000000000
--- a/libshortfin/tests/api/array_test.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import array
-import pytest
-
-import shortfin as sf
-import shortfin.array as sfnp
-
-
-@pytest.fixture
-def lsys():
- sc = sf.host.CPUSystemBuilder()
- lsys = sc.create_system()
- yield lsys
- lsys.shutdown()
-
-
-@pytest.fixture
-def scope(lsys):
- return lsys.create_scope()
-
-
-@pytest.fixture
-def device(scope):
- return scope.device(0)
-
-
-def test_storage_constructor(lsys, device):
- async def main():
- s = sfnp.storage.allocate_host(device, 8)
- s.fill(b"\0\1\2\3")
- await device
- ary = sfnp.device_array(s, [2, 4], sfnp.uint8)
- assert ary.dtype == sfnp.uint8
- assert ary.shape == [2, 4]
- assert str(ary) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}"
- assert ary.device == device
- assert ary.storage == s
-
- lsys.run(main())
-
-
-def test_device_constructor(lsys, device):
- async def main():
- ary = sfnp.device_array(device, [2, 4], sfnp.uint8)
- ary.storage.fill(b"\0\1\2\3")
- await device
- assert ary.dtype == sfnp.uint8
- assert ary.shape == [2, 4]
- assert str(ary) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}"
- assert ary.device == device
-
- lsys.run(main())
-
-
-def test_fill_copy_from_for_transfer(lsys, device):
- async def main():
- src = sfnp.device_array(device, [2, 4], sfnp.uint8)
- src.fill(b"\0\1\2\3")
- dst = src.for_transfer()
- dst.copy_from(src)
- await device
- assert str(dst) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}"
-
- lsys.run(main())
-
-
-def test_fill_copy_to_for_transfer(lsys, device):
- async def main():
- src = sfnp.device_array(device, [2, 4], sfnp.uint8)
- src.fill(b"\0\1\2\3")
- dst = src.for_transfer()
- src.copy_to(dst)
- await device
- assert str(dst) == "{{0, 1, 2, 3},\n {0, 1, 2, 3}}"
-
- lsys.run(main())
-
-
-def test_shape_overflow(lsys, device):
- async def main():
- s = sfnp.storage.allocate_host(device, 4)
- _ = sfnp.device_array(s, [2, 4], sfnp.uint8)
-
- with pytest.raises(
- ValueError, match="Array storage requires at least 8 bytes but has only 4"
- ):
- lsys.run(main())
-
-
-@pytest.mark.parametrize(
- "dtype,code,py_value,expected_str",
- [
- (sfnp.int8, "b", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"),
- (sfnp.int16, "h", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"),
- (sfnp.int32, "i", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"),
- (
- sfnp.float32,
- "f",
- 42.0,
- "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}",
- ),
- (
- sfnp.float64,
- "d",
- 42.0,
- "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}",
- ),
- ],
-)
-def test_xtensor_types(scope, dtype, code, py_value, expected_str):
- ary = sfnp.device_array.for_host(scope.device(0), [2, 4], dtype)
- ary.storage.data = array.array(code, [py_value] * 8)
- s = str(ary)
- print("__str__ =", s)
- assert expected_str == s, f"Expected '{expected_str}' == '{s}'"
- r = repr(ary)
- print("__repr__ =", r)
- assert expected_str in r, f"Expected '{expected_str}' in '{r}'"
diff --git a/libshortfin/tests/local_scope_test.py b/libshortfin/tests/local_scope_test.py
deleted file mode 100644
index 379f13d4b..000000000
--- a/libshortfin/tests/local_scope_test.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-import time
-
-from _shortfin import lib as sfl
-
-
-@pytest.fixture
-def lsys():
- sc = sfl.local.host.CPUSystemBuilder()
- ls = sc.create_system()
- yield ls
- ls.shutdown()
-
-
-@pytest.fixture
-def scope(lsys):
- # TODO: Should adopt the main thread.
- worker = lsys.create_worker("main")
- return lsys.create_scope(worker)
-
-
-def test_raw_device_access(scope):
- first_name = scope.device_names[0]
- assert first_name == "cpu0"
- first_device = scope.raw_device(0) # By index
- assert isinstance(first_device, sfl.local.host.HostCPUDevice)
- assert first_device is scope.raw_device(first_name) # By name
- print(first_device)
- devices = scope.raw_devices
- named_devices = scope.named_devices
- assert first_name in named_devices
- assert devices[0] is named_devices[first_name]
- assert devices[0] is first_device
- with pytest.raises(ValueError):
- scope.raw_device("cpu1")
- with pytest.raises(ValueError):
- scope.raw_device(1)
-
-
-def test_devices_collection_access(scope):
- # # Access via devices pseudo collection.
- first_device = scope.raw_device(0)
- assert scope.devices.cpu0.raw_device is first_device
- assert scope.devices[0].raw_device is first_device
- assert scope.devices["cpu0"].raw_device is first_device
- assert len(scope.devices) == 1
- with pytest.raises(ValueError):
- scope.devices.cpu1
- with pytest.raises(ValueError):
- scope.devices[1]
-
-
-def test_device_affinity_repr(scope):
- assert (
- repr(sfl.local.DeviceAffinity(scope.raw_device(0)))
- == "DeviceAffinity(host-cpu:0:0@0[0x1])"
- )
- assert repr(sfl.local.DeviceAffinity()) == "DeviceAffinity(ANY)"
-
-
-def test_device_affinity_resolve(scope):
- # TODO: Need a scope with multiple devices to test errors.
- print(scope.device(0, "cpu0", scope.raw_device(0)))
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 000000000..e736fe3bd
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,10 @@
+# Used for managing pre-commit flows.
+pre-commit
+
+# Type checking
+mypy==1.8.0
+types-requests==2.31.0.20240125
+
+# Testing
+pytest==8.0.0
+pytest-xdist==3.5.0
diff --git a/requirements.txt b/requirements.txt
index 6a5cccc92..cc2edf876 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,26 +1,4 @@
-# Runtime deps.
-gguf==0.6.0
-numpy==1.26.3
-onnx==1.15.0
-
-# Model deps.
-huggingface-hub==0.22.2
-transformers==4.40.0
-sentencepiece==0.2.0
-
-# It is expected that you have installed a PyTorch version/variant specific
-# to your needs, so we only include a minimum version spec.
-# TODO: Use a versioned release once 2.3.0 drops.
-torch>=2.3.0.dev1
-
-# Used for managing pre-commit flows.
-pre-commit
-
-# Type checking
-mypy==1.8.0
-types-requests==2.31.0.20240125
-
-# Testing
-parameterized
-pytest==8.0.0
-pytest-xdist==3.5.0
+-r sharktank/requirements.txt
+-r sharktank/requirements-tests.txt
+-r shortfin/requirements-tests.txt
+-r requirements-dev.txt
diff --git a/shark-ai/.gitignore b/shark-ai/.gitignore
new file mode 100644
index 000000000..80bf001b8
--- /dev/null
+++ b/shark-ai/.gitignore
@@ -0,0 +1,2 @@
+# Local-only config options
+requirements.txt
diff --git a/shark-ai/README.md b/shark-ai/README.md
new file mode 100644
index 000000000..0bb1abafd
--- /dev/null
+++ b/shark-ai/README.md
@@ -0,0 +1,3 @@
+# SHARK AI meta package
+
+Meta package to install `shortfin` and compatible IREE packages.
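+
+For example, assuming the wheel has been built locally or published to a package index, installing it together with the optional app dependencies declared in `pyproject.toml` might look like:
+
+```bash
+pip install "shark-ai[apps]"
+```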
diff --git a/shark-ai/build_tools/build_linux_package.sh b/shark-ai/build_tools/build_linux_package.sh
new file mode 100755
index 000000000..d16f339b1
--- /dev/null
+++ b/shark-ai/build_tools/build_linux_package.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# build_linux_package.sh
+#
+# Builds shark-ai Python package for Linux.
+#
+# Usage:
+# ./build_tools/build_linux_package.sh
+
+set -xeu -o errtrace
+
+THIS_DIR="$(cd "$(dirname "$0")" && pwd)"
+REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)"
+OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}"
+
+python -m pip wheel --disable-pip-version-check --no-deps -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shark-ai"
+
+wheel_output="$(echo "${OUTPUT_DIR}/shark_ai-"*".whl")"
+ls "${wheel_output}"
diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml
new file mode 100644
index 000000000..3f7e4a1da
--- /dev/null
+++ b/shark-ai/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "shark-ai"
+authors = [
+ {name = "SHARK Authors"},
+]
+description = "SHARK AI meta package"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+]
+# Version is set via the `setup.py` and requirements are set via files below.
+dynamic = ["version", "dependencies"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/shark-ai"
+
+[project.optional-dependencies]
+onnx = [
+ "iree-base-compiler[onnx]",
+]
+apps = [
+ "shortfin[apps]",
+]
+
+[tool.setuptools]
+packages = []
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
diff --git a/shark-ai/setup.py b/shark-ai/setup.py
new file mode 100644
index 000000000..5ceac55bd
--- /dev/null
+++ b/shark-ai/setup.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+from pathlib import Path
+
+from setuptools import setup
+
+THIS_DIR = Path(__file__).parent.resolve()
+
+# Setup and get version information.
+# The `version_local.json` is generated by calling:
+# `build_tools/python_deploy/compute_common_version.py -stable --write-json`
+VERSION_FILE_LOCAL = os.path.join(THIS_DIR, "version_local.json")
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+version_info = load_version_info(VERSION_FILE_LOCAL)
+
+PACKAGE_VERSION = version_info.get("package-version")
+print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
+
+setup(
+ version=f"{PACKAGE_VERSION}",
+)
diff --git a/sharktank/.gitignore b/sharktank/.gitignore
new file mode 100644
index 000000000..000e575d5
--- /dev/null
+++ b/sharktank/.gitignore
@@ -0,0 +1,2 @@
+# Local-only config options
+version_info_rc.json
diff --git a/sharktank/README.md b/sharktank/README.md
index 9dd239eba..7770595ed 100644
--- a/sharktank/README.md
+++ b/sharktank/README.md
@@ -10,6 +10,10 @@ This sub-project is a work in progress. It is intended to be a repository of
layers, model recipes, and conversion tools from popular LLM quantization
tooling.
+## Project Status
+
+[![CI - Perplexity](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml/badge.svg?branch=main&event=schedule)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml)
+
## Examples
The repository will ultimately grow a curated set of models and tools for
@@ -40,3 +44,30 @@ python -m sharktank.examples.export_paged_llm_v1 \
```shell
python -m sharktank.tools.dump_gguf --hf-dataset=open_llama_3b_v2_f16_gguf
```
+
+## Package Python Release Builds
+
+* To build wheels for Linux:
+
+ ```bash
+ ./build_tools/build_linux_package.sh
+ ```
+
+ That should produce
+ `build_tools/wheelhouse/sharktank-{X.Y.Z}.dev0-py3-none-any.whl`, which can
+ then be installed with
+
+ ```bash
+ python3 -m pip install build_tools/wheelhouse/sharktank-{X.Y.Z}.dev0-py3-none-any.whl
+ ```
+
+* To build a wheel for your host OS/arch manually:
+
+ ```bash
+ # Build sharktank.*.whl into the dist/ directory
+ # e.g. `sharktank-3.0.0.dev0-py3-none-any.whl`
+ python3 -m pip wheel -v -w dist .
+
+ # Install the built wheel.
+ python3 -m pip install dist/*.whl
+ ```
diff --git a/sharktank/build_tools/build_linux_package.sh b/sharktank/build_tools/build_linux_package.sh
new file mode 100755
index 000000000..3ed3a7fbb
--- /dev/null
+++ b/sharktank/build_tools/build_linux_package.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# build_linux_package.sh
+#
+# Builds sharktank Python package for Linux.
+#
+# Note: requires a modern Python (3.12+ seems to work). Troubleshooting help:
+# * https://stackoverflow.com/a/77284076
+# * https://stackoverflow.com/a/77364602
+# Older versions like 3.10 may not include the package name and set it as UNKNOWN.
+# * Might just need some local packages updated?
+#
+# Usage:
+# ./build_tools/build_linux_package.sh
+
+set -xeu -o errtrace
+
+THIS_DIR="$(cd "$(dirname "$0")" && pwd)"
+REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)"
+OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}"
+
+python -m pip wheel --disable-pip-version-check --no-deps -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/sharktank"
+
+wheel_output="$(echo "${OUTPUT_DIR}/sharktank-"*".whl")"
+ls "${wheel_output}"
diff --git a/sharktank/conftest.py b/sharktank/conftest.py
index fd9c47447..475f386be 100644
--- a/sharktank/conftest.py
+++ b/sharktank/conftest.py
@@ -5,6 +5,9 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from pathlib import Path
+import pytest
+from pytest import FixtureRequest
+from typing import Optional, Any
# Tests under each top-level directory will get a mark.
@@ -24,3 +27,312 @@ def pytest_collection_modifyitems(items, config):
mark = TLD_MARKS.get(tld)
if mark:
item.add_marker(mark)
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--mlir",
+ type=Path,
+ default=None,
+ help="Path to exported MLIR program. If not specified a temporary file will be used.",
+ )
+ parser.addoption(
+ "--module",
+ type=Path,
+ default=None,
+ help="Path to exported IREE module. If not specified a temporary file will be used.",
+ )
+ parser.addoption(
+ "--parameters",
+ type=Path,
+ default=None,
+ help="Exported model parameters. If not specified a temporary file will be used.",
+ )
+ parser.addoption(
+ "--prefix",
+ type=str,
+ default=None,
+ help=(
+ "Path prefix for test artifacts. "
+ "Other arguments may override this for specific values."
+ ),
+ )
+ parser.addoption(
+ "--caching",
+ action="store_true",
+ default=False,
+ help="Load cached results if present instead of recomputing.",
+ )
+
+ parser.addoption(
+ "--longrun",
+ action="store_true",
+ dest="longrun",
+ default=False,
+ help="Enable long tests",
+ )
+
+ parser.addoption(
+ "--run-quick-llama-test",
+ action="store_true",
+ dest="run-quick-llama-test",
+ default=False,
+ help="Enable llama 8b f16 decomposed benchmarking test",
+ )
+
+ parser.addoption(
+ "--run-nightly-llama-tests",
+ action="store_true",
+ dest="run-nightly-llama-tests",
+ default=False,
+ help="Enable all llama benchmarking tests",
+ )
+
+ parser.addoption(
+ "--with-t5-data",
+ action="store_true",
+ default=False,
+ help=(
+ "Enable tests that use T5 data like models that is not a part of the source "
+ "code. The user is expected to provide the data"
+ ),
+ )
+
+ # TODO: Remove all hardcoded paths in CI tests
+ parser.addoption(
+ "--llama3-8b-tokenizer-path",
+ type=Path,
+ action="store",
+ help="Llama3.1 8b tokenizer path, defaults to 30F CI system path",
+ )
+
+ parser.addoption(
+ "--llama3-8b-f16-model-path",
+ type=Path,
+ action="store",
+ help="Llama3.1 8b model path, defaults to 30F CI system path",
+ )
+
+ parser.addoption(
+ "--llama3-8b-fp8-model-path",
+ type=Path,
+ action="store",
+ default=None,
+ help="Llama3.1 8b fp8 model path",
+ )
+
+ parser.addoption(
+ "--llama3-405b-tokenizer-path",
+ type=Path,
+ action="store",
+ help="Llama3.1 405b tokenizer path, defaults to 30F CI system path",
+ )
+
+ parser.addoption(
+ "--llama3-405b-f16-model-path",
+ type=Path,
+ action="store",
+ help="Llama3.1 405b model path, defaults to 30F CI system path",
+ )
+
+ parser.addoption(
+ "--llama3-405b-fp8-model-path",
+ type=Path,
+ action="store",
+ default=None,
+ help="Llama3.1 405b fp8 model path",
+ )
+
+ # To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py.
+ # https://github.com/ggerganov/llama.cpp/blob/9abe9eeae98b11fa93b82632b264126a010225ff/convert_hf_to_gguf.py
+ # E.g.
+ # git lfs install
+ # git clone https://huggingface.co/google/t5-v1_1-small
+ # convert_hf_to_gguf.py \
+ # --outfile t5-v1_1-small.gguf \
+ # --outtype=f32 \
+ # t5-v1_1-small
+ parser.addoption(
+ "--google-t5-v1-1-small-fp32-model-path",
+ type=Path,
+ default="/data/t5/small/google__t5-v1_1-small_fp32.gguf",
+ help="Google T5 v1.1 small fp32 model path",
+ )
+ parser.addoption(
+ "--google-t5-v1-1-xxl-fp32-model-path",
+ type=Path,
+ default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf",
+ help="Google T5 v1.1 XXL fp32 model path",
+ )
+
+ parser.addoption(
+ "--baseline-perplexity-scores",
+ type=Path,
+ action="store",
+ default="sharktank/tests/evaluate/baseline_perplexity_scores.json",
+ help="Llama3.1 8B & 405B model baseline perplexity scores",
+ )
+
+ parser.addoption(
+ "--iree-device",
+ type=str,
+ action="store",
+ help="List an IREE device from iree-run-module --list_devices",
+ )
+
+ parser.addoption(
+ "--iree-hip-target",
+ action="store",
+ help="Specify the iree-hip target version (e.g., gfx942)",
+ )
+
+ parser.addoption(
+ "--iree-hal-target-backends",
+ action="store",
+ help="Specify the iree-hal target backend (e.g., rocm)",
+ )
+
+ parser.addoption(
+ "--tensor-parallelism-size",
+ action="store",
+ type=int,
+ default=1,
+ help="Number of devices for tensor parallel sharding",
+ )
+
+ parser.addoption(
+ "--bs",
+ action="store",
+ type=int,
+ default=4,
+ help="Batch size for mlir export",
+ )
+
+
+def set_fixture_from_cli_option(
+ request: FixtureRequest,
+ cli_option_name: str,
+ class_attribute_name: Optional[str] = None,
+) -> Optional[Any]:
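+    """Get the value of a pytest CLI option for use in a fixture.
+
+    For function-style tests (no test class), the option value is returned
+    directly. For class-based tests, the value is instead set as an attribute
+    on the requesting test class (named `class_attribute_name`, defaulting to
+    the option name) and None is returned.
+    """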
+ res = request.config.getoption(cli_option_name)
+ if request.cls is None:
+ return res
+ else:
+ if class_attribute_name is None:
+ class_attribute_name = cli_option_name
+ setattr(request.cls, class_attribute_name, res)
+
+
+@pytest.fixture(scope="class")
+def mlir_path(request: FixtureRequest) -> Optional[Path]:
+ return set_fixture_from_cli_option(request, "mlir", "mlir_path")
+
+
+@pytest.fixture(scope="class")
+def module_path(request: FixtureRequest) -> Optional[Path]:
+ return set_fixture_from_cli_option(request, "module", "module_path")
+
+
+@pytest.fixture(scope="class")
+def parameters_path(request: FixtureRequest) -> Optional[Path]:
+ return set_fixture_from_cli_option(request, "parameters", "parameters_path")
+
+
+@pytest.fixture(scope="class")
+def path_prefix(request: FixtureRequest) -> Optional[str]:
+ return set_fixture_from_cli_option(request, "prefix", "path_prefix")
+
+
+@pytest.fixture(scope="class")
+def caching(request: FixtureRequest) -> Optional[bool]:
+ return set_fixture_from_cli_option(request, "caching")
+
+
+@pytest.fixture(scope="class")
+def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]:
+ return set_fixture_from_cli_option(
+ request, "tensor_parallelism_size", "tensor_parallelism_size"
+ )
+
+
+@pytest.fixture(scope="class")
+def baseline_perplexity_scores(request: FixtureRequest) -> Optional[str]:
+ return set_fixture_from_cli_option(
+ request, "baseline_perplexity_scores", "baseline_perplexity_scores"
+ )
+
+
+@pytest.fixture(scope="class")
+def batch_size(request: FixtureRequest) -> Optional[str]:
+ return set_fixture_from_cli_option(request, "bs", "batch_size")
+
+
+@pytest.fixture(scope="class")
+def get_model_artifacts(request: FixtureRequest):
+ model_path = {}
+ model_path["llama3_8b_tokenizer_path"] = set_fixture_from_cli_option(
+ request, "--llama3-8b-tokenizer-path", "llama3_8b_tokenizer"
+ )
+ model_path["llama3_8b_f16_model_path"] = set_fixture_from_cli_option(
+ request, "--llama3-8b-f16-model-path", "llama3_8b_f16_model"
+ )
+ model_path["llama3_8b_fp8_model_path"] = set_fixture_from_cli_option(
+ request, "--llama3-8b-fp8-model-path", "llama3_8b_fp8_model"
+ )
+ model_path["llama3_405b_tokenizer_path"] = set_fixture_from_cli_option(
+ request, "--llama3-405b-tokenizer-path", "llama3_405b_tokenizer"
+ )
+ model_path["llama3_405b_f16_model_path"] = set_fixture_from_cli_option(
+ request, "--llama3-405b-f16-model-path", "llama3_405b_f16_model"
+ )
+ model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
+ request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
+ )
+ model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(
+ request,
+ "--google-t5-v1-1-small-fp32-model-path",
+ "google__t5_v1_1_small_fp32_model",
+ )
+ model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
+ request,
+ "--google-t5-v1-1-xxl-fp32-model-path",
+ "google__t5_v1_1_xxl_fp32_model",
+ )
+ return model_path
+
+
+@pytest.fixture(scope="class")
+def get_iree_flags(request: FixtureRequest):
+ model_path = {}
+ model_path["iree_device"] = set_fixture_from_cli_option(
+ request, "--iree-device", "iree_device"
+ )
+ model_path["iree_hip_target"] = set_fixture_from_cli_option(
+ request, "--iree-hip-target", "iree_hip_target"
+ )
+ model_path["iree_hal_target_backends"] = set_fixture_from_cli_option(
+ request, "--iree-hal-target-backends", "iree_hal_target_backends"
+ )
+
+
+# The following three functions allow us to add an "XFail Reason" column to the HTML reports for each test.
+@pytest.hookimpl(optionalhook=True)
+def pytest_html_results_table_header(cells):
+ cells.insert(2, "XFail Reason | ")
+
+
+@pytest.hookimpl(optionalhook=True)
+def pytest_html_results_table_row(report, cells):
+ if hasattr(report, "wasxfail"):
+ cells.insert(2, f"{report.wasxfail} | ")
+ else:
+ cells.insert(2, f" | ")
+
+
+@pytest.hookimpl(hookwrapper=True)
+def pytest_runtest_makereport(item, call):
+ outcome = yield
+ report = outcome.get_result()
+
+ if report.when == "call" and hasattr(item, "wasxfail"):
+ report.wasxfail = item.wasxfail
diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml
index bc5203c86..01cad409b 100644
--- a/sharktank/pyproject.toml
+++ b/sharktank/pyproject.toml
@@ -2,6 +2,40 @@
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
+[project]
+name = "sharktank"
+authors = [
+ {name = "SHARK Authors"},
+]
+description = "SHARK layers and inference models for genai"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+]
+
+# Version is set via the `setup.py` and requirements are set via files below.
+dynamic = ["version", "dependencies", "optional-dependencies"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/shark-ai"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["sharktank*"]
+namespaces = true
+
+[tool.setuptools.package-data]
+sharktank = ["py.typed", "kernels/templates/*.mlir"]
+
+[tool.setuptools.dynamic.dependencies]
+file = ["requirements.txt"]
+
+[tool.setuptools.dynamic.optional-dependencies]
+testing = {file = ["requirements-tests.txt"]}
+
[tool.pytest.ini_options]
addopts = [
"-ra",
diff --git a/sharktank/requirements-dev-turbine.txt b/sharktank/requirements-dev-turbine.txt
new file mode 100644
index 000000000..0d0dc7619
--- /dev/null
+++ b/sharktank/requirements-dev-turbine.txt
@@ -0,0 +1 @@
+-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
diff --git a/sharktank/requirements-tests.txt b/sharktank/requirements-tests.txt
new file mode 100644
index 000000000..d5b4b0c0e
--- /dev/null
+++ b/sharktank/requirements-tests.txt
@@ -0,0 +1,4 @@
+datasets==3.0.0
+parameterized
+pytest==8.0.0
+pytest-html
diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt
index 6b21f239f..6b533d977 100644
--- a/sharktank/requirements.txt
+++ b/sharktank/requirements.txt
@@ -1 +1,21 @@
-gguf
+iree-turbine
+
+# Runtime deps.
+gguf==0.6.0
+numpy<2.0
+
+# Needed for newer gguf versions (TODO: remove when gguf package includes this)
+# sentencepiece>=0.1.98,<=0.2.0
+
+# Model deps.
+huggingface-hub==0.22.2
+transformers==4.40.0
+datasets
+
+# It is expected that you have installed a PyTorch version/variant specific
+# to your needs, so we only include a minimum version spec.
+torch>=2.3.0
+
+# Serving deps.
+fastapi==0.112.2
+uvicorn==0.30.6
diff --git a/sharktank/setup.py b/sharktank/setup.py
index b8caf9e7d..182f94abc 100644
--- a/sharktank/setup.py
+++ b/sharktank/setup.py
@@ -6,102 +6,31 @@
import json
import os
-import distutils.command.build
from pathlib import Path
-from setuptools import find_namespace_packages, setup # type: ignore
+from setuptools import setup
-THIS_DIR = Path(__file__).resolve().parent
-REPO_DIR = THIS_DIR.parent
-VERSION_INFO_FILE = REPO_DIR / "version_info.json"
+SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
+# Setup and get version information.
+VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json")
+VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json")
-with open(
- os.path.join(
- THIS_DIR,
- "README.md",
- ),
- "rt",
-) as f:
- README = f.read()
-
-def load_version_info():
- with open(VERSION_INFO_FILE, "rt") as f:
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
return json.load(f)
-version_info = load_version_info()
-PACKAGE_VERSION = version_info["package-version"]
-
-packages = find_namespace_packages(
- include=[
- "sharktank",
- "sharktank.*",
- ],
-)
-
-print("Found packages:", packages)
-
-# Lookup version pins from requirements files.
-requirement_pins = {}
-
-
-def load_requirement_pins(requirements_file: Path):
- with open(requirements_file, "rt") as f:
- lines = f.readlines()
- pin_pairs = [line.strip().split("==") for line in lines if "==" in line]
- requirement_pins.update(dict(pin_pairs))
-
-
-load_requirement_pins(REPO_DIR / "requirements.txt")
-
-
-def get_version_spec(dep: str):
- if dep in requirement_pins:
- return f">={requirement_pins[dep]}"
- else:
- return ""
-
-
-# Override build command so that we can build into _python_build
-# instead of the default "build". This avoids collisions with
-# typical CMake incantations, which can produce all kinds of
-# hilarity (like including the contents of the build/lib directory).
-class BuildCommand(distutils.command.build.build):
- def initialize_options(self):
- distutils.command.build.build.initialize_options(self)
- self.build_base = "_python_build"
+try:
+ version_info = load_version_info(VERSION_FILE_LOCAL)
+except FileNotFoundError:
+ print("version_local.json not found. Default to dev build")
+ version_info = load_version_info(VERSION_FILE)
+PACKAGE_VERSION = version_info.get("package-version")
+print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
setup(
- name=f"sharktank",
version=f"{PACKAGE_VERSION}",
- author="SHARK Authors",
- author_email="stella@nod.ai",
- description="SHARK layers and inference models for genai",
- long_description=README,
- long_description_content_type="text/markdown",
- url="https://github.com/nod-ai/sharktank",
- license="Apache-2.0",
- classifiers=[
- "Development Status :: 3 - Alpha",
- "License :: OSI Approved :: Apache Software License",
- "Programming Language :: Python :: 3",
- ],
- packages=packages,
- include_package_data=True,
- package_data={
- "sharktank": ["py.typed", "kernels/templates/*.mlir"],
- },
- install_requires=[
- "shark-turbine",
- ],
- extras_require={
- "testing": [
- f"pytest{get_version_spec('pytest')}",
- f"pytest-xdist{get_version_spec('pytest-xdist')}",
- ],
- },
- cmdclass={"build": BuildCommand},
)
diff --git a/sharktank/sharktank/evaluate/README.md b/sharktank/sharktank/evaluate/README.md
new file mode 100644
index 000000000..beb0281cd
--- /dev/null
+++ b/sharktank/sharktank/evaluate/README.md
@@ -0,0 +1,40 @@
+# LLM Evaluation Pipeline
+
+## Setup
+Set up SHARK Platform's evaluation pipeline:
+
+```bash
+pip install -r sharktank/requirements-tests.txt
+```
+
+### Perplexity
+
+The perplexity score measures the ability of a language model to predict the next token in a sequence. A lower score indicates that the model has higher certainty in its predictions. Perplexity acts as an intrinsic evaluation metric of model quality, independent of any downstream task.
+
+In SHARK-Platform, we use perplexity to track code regressions and quality loss across quantized models (with FP16 as the baseline). We use 100 prompts randomly selected from the Wikitext-2 test set and report the mean perplexities shown below. These numbers are not comparable across models with different tokenizers, nor with other projects, because implementations vary.
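+
+Concretely, the evaluators in this directory compute the masked cross-entropy of the model's logits against the shifted token ids and exponentiate the per-prompt mean. A minimal standalone sketch of that calculation (function and variable names here are illustrative only):
+
+```python
+import torch
+from torch.nn import CrossEntropyLoss
+
+def perplexity(logits, token_ids, mask):
+    # logits: [batch, seq, vocab]; token_ids, mask: [batch, seq]
+    loss_fct = CrossEntropyLoss(reduction="none")
+    # Per-token negative log-likelihood, with padded positions masked out.
+    nll = (loss_fct(logits.transpose(1, 2), token_ids) * mask).sum(1)
+    # perplexity = e ^ (sum(losses) / num_tokenized_tokens), per prompt.
+    return torch.exp(nll / mask.sum(1))
+```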
+
+* Test perplexity for Llama3.1 8B (FP16) model:
+
+```bash
+pytest sharktank/tests/evaluate/perplexity_test.py --longrun
+```
+
+* Calculate perplexity for a new model:
+
+```bash
+python -m sharktank.evaluate.perplexity \
+ --gguf-file=llama3_70b_f16.gguf \
+ --tokenizer-config-json=tokenizer_config.json
+```
+
+### Perplexity Scoreboard
+
+| CPU | GPU |
+|:-------------: |:----------:|
+| AMD EPYC 9554 | MI300X |
+
+#### LLaMA 3.1
+
+|Models |Model size (GB) |Torch score |IREE score |
+|:----------------------|:---------------|:-------------|:-------------|
+|8B FP16 TP1 decomposed |16.07 |14.930181 |14.991893 |
diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py
new file mode 100644
index 000000000..6060eb91b
--- /dev/null
+++ b/sharktank/sharktank/evaluate/perplexity_iree.py
@@ -0,0 +1,492 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+import logging
+import json
+import time
+import random
+import re
+from datetime import timedelta
+from tqdm import tqdm
+
+import numpy as np
+
+from datasets import load_dataset
+
+import torch
+from torch.nn import CrossEntropyLoss
+
+from sharktank.models.llama.llama import *
+from sharktank.models.mixtral.mixtral import *
+from sharktank.models.grok.grok import *
+
+from ..models.llama.sharding import shard_theta
+
+from sharktank.layers import *
+from sharktank.types import *
+
+from sharktank.utils import cli
+from sharktank.utils.vmfb_runner import *
+from sharktank.utils.load_llm import *
+from sharktank.utils.create_cache import *
+from sharktank.utils.export_artifacts import *
+
+log_levels = {
+ "info": logging.INFO,
+ "debug": logging.DEBUG,
+}
+logger = logging.getLogger("eval")
+
+logger.setLevel(log_levels["info"])
+
+logger.root.handlers[0].setFormatter(
+ logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
+)
+
+__all__ = ["Perplexity", "run_perplexity"]
+
+
+class Perplexity:
+ """
+ Perplexity (PPL) is one of the most common metrics for evaluating language models.
+ It is defined as the exponentiated average negative log-likelihood of a sequence,
+ calculated with exponent base `e`.
+
+ For more information, see https://huggingface.co/docs/transformers/perplexity
+ """
+
+ def __init__(
+ self,
+ torch_device,
+ iree_device,
+ iree_hip_target,
+ iree_hal_target_backends,
+ kv_cache_type,
+ tensor_parallelism_size,
+ attention_kernel,
+ ):
+ self.torch_device = torch_device
+ self.iree_device = iree_device
+ self.iree_hip_target = iree_hip_target
+ self.iree_hal_target_backends = iree_hal_target_backends
+ self.kv_cache_type = kv_cache_type
+ self.activation_dtype = torch.float16
+ self.attention_dtype = torch.float16
+ self.tensor_parallelism_size = tensor_parallelism_size
+ self.attention_kernel = attention_kernel
+
+ def timeit(func):
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ result = func(*args, **kwargs)
+ end = time.time()
+ total_seconds = end - start
+ time_taken = abs(timedelta(seconds=total_seconds))
+ hours, minutes, seconds = re.split(":", str(time_taken))
+
+ if total_seconds < 1:
+ time_taken = f" {round(total_seconds * 1000, 3)} ms"
+ elif total_seconds < 60:
+ time_taken = "{:.2f} secs".format(round(float(total_seconds), 2))
+ else:
+ time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
+ int(hours), int(minutes), round(float(seconds), 2)
+ )
+
+ func_name = func.__name__
+ if func_name == "get_perplexity":
+ func_name = f"Calculate perplexity"
+ elif func_name == "compile_model":
+ func_name = f"Export & compile"
+ logger.info(f" {func_name}: {time_taken}")
+ return result
+
+ return wrapper
+
+ def print_token_comparison(self, i):
+ if i <= self.max_prompt_length:
+ batch_predicted_token_id = [[i[-1]] for i in self.batch.results]
+ batch_predicted_token = self.generator.tokenizer.decode(
+ batch_predicted_token_id
+ )
+ logger.debug(f"Predicted:")
+ logger.debug(f"{batch_predicted_token}")
+ logger.debug(f"{batch_predicted_token_id}")
+
+ expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist()
+ expected_token = self.generator.tokenizer.decode(expected_token_id)
+ logger.debug(f"Expected:")
+ logger.debug(f"{expected_token}")
+ logger.debug(f"{expected_token_id}")
+
+ @timeit
+ def compile_model(self, weight_path_str):
+ self.weight_path_str = weight_path_str
+
+ logger.info(f" Compiling: {self.weight_path_str}")
+
+ export_artifacts = ExportArtifacts(
+ irpa_path=self.weight_path_str,
+ batch_size=self.bs,
+ iree_hip_target=self.iree_hip_target,
+ iree_hal_target_backends=self.iree_hal_target_backends,
+ attention_kernel=self.attention_kernel,
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ vmfb_path = export_artifacts.get_artifacts()
+ return vmfb_path
+
+ @timeit
+ def load_model(self, weight_path, tokenizer, vmfb_path):
+
+ self.config = LlamaModelConfig(
+ hp=configs.LlamaHParams.from_gguf_props(weight_path.properties),
+ block_seq_stride=16,
+ kv_cache_type=self.kv_cache_type,
+ device=self.torch_device,
+ activation_dtype=self.activation_dtype,
+ attention_dtype=self.attention_dtype,
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+
+ if self.config.tensor_parallelism_size > 1:
+ weight_path.root_theta = shard_theta(weight_path.root_theta, self.config)
+
+ theta = weight_path.root_theta
+
+ if self.config.hp.expert_count:
+ if self.config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(theta, self.config)
+ else:
+ model = PagedMixtralModelV1(theta, self.config)
+ else:
+ model = PagedLlamaModelV1(theta, self.config)
+
+ self.generator = TorchGenerator(model, tokenizer)
+
+ self.runner = vmfbRunner(
+ device=self.iree_device,
+ vmfb_path=vmfb_path,
+ external_weight_path=self.weight_path_str,
+ )
+
+ self.haldevice = self.runner.config.device
+
+ @timeit
+ def get_prompts(self, num_prompts):
+ test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
+ "text"
+ ]
+ num_test_prompts = 219
+
+ random.seed(0)
+ test_prompts = random.sample(test_prompts, num_test_prompts)
+
+        # Ignore prompts that are empty, shorter than 20 words, or titles.
+ test_prompts = [
+ s.replace("\n", "").rstrip()
+ for s in test_prompts
+ if s != "" and len(s.split()) >= 20 and s.count("=") < 2
+ ][0:num_prompts]
+
+ self.test_prompts = test_prompts
+
+ self.bs = len(test_prompts)
+
+ logger.info(f" Batch size: {self.bs}")
+
+ @timeit
+ def prefill_vmfb(self, token_batch, i):
+
+ seq_block_ids = self.batch.pad_block_ids()
+ prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
+ token_batch,
+ self.batch.seq_lens,
+ seq_block_ids,
+ self.cache_state,
+ )
+
+ prefill_logits = torch.tensor(prefill_logits[:, :, :])
+
+ tokens = torch.tensor(
+ self.generator.model.extract_tokens_from_logits(
+ prefill_logits, self.batch.seq_lens
+ )
+ ).unsqueeze(1)
+ self.batch.add_result_token(tokens)
+
+ self.print_token_comparison(i)
+ return prefill_logits
+
+ def decode_vmfb(self, token_batch, i):
+ logger.debug("Decode:")
+
+ logger.debug("Input:")
+ logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+ logger.debug(f"{token_batch.tolist()}")
+
+ start_positions = self.batch.seq_lens.clone()
+ self.batch.seq_lens.add_(1)
+ self.batch.allocate_seq_block_ids()
+ seq_block_ids = self.batch.pad_block_ids()
+
+ decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
+ token_batch,
+ self.batch.seq_lens,
+ start_positions,
+ seq_block_ids,
+ self.cache_state,
+ )
+
+ decode_logits = torch.tensor(decode_logits[:, :, :])
+
+ tokens = torch.tensor(
+ self.generator.model.extract_tokens_from_logits(
+ decode_logits, [1] * self.bs
+ ),
+ device=self.generator.model.device,
+ ).unsqueeze(1)
+ self.batch.add_result_token(tokens)
+ self.print_token_comparison(i)
+ return decode_logits
+
+ @timeit
+ def get_logits(self, page_cache_size):
+
+ is_first_token = True
+ start = 0
+ for i in tqdm(
+ range(start, self.max_prompt_length - 1),
+ mininterval=300,
+ desc="eval: Calculating logits",
+ ):
+ logger.debug(f"Iteration: {i}")
+
+ if is_first_token:
+
+ token_batch = self.token_ids[:, : i + 1]
+
+ logger.debug(f"Prefill:")
+
+ logger.debug("Input:")
+ logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+
+ token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
+ token_ids=token_batch.tolist(),
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
+ logger.debug(f"{token_batch}")
+
+ token_batch = torch.tensor(token_batch, device=self.torch_device)
+ self.seq_lens_batch = torch.tensor(
+ seq_lens_batch, device=self.torch_device
+ )
+
+ self.batch = self.generator.begin_eval_batch(
+ token_batch=token_batch,
+ seq_lens_batch=self.seq_lens_batch,
+ bs=self.bs,
+ page_cache_size=page_cache_size,
+ )
+
+ self.cache_state = ireert.asdevicearray(
+ self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
+ )
+
+ prefill_logits = self.prefill_vmfb(token_batch, i)
+ self.out_logits = prefill_logits[:, -1:, :]
+
+ is_first_token = False
+
+ else:
+ token_batch = self.token_ids[:, i : i + 1]
+
+ decode_logits = self.decode_vmfb(token_batch, i)
+ self.out_logits = torch.cat((self.out_logits, decode_logits), 1)
+
+ pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1]
+
+ self.pad_logits = torch.zeros(
+ self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2]
+ )
+
+ self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to(
+ self.torch_device
+ )
+
+ @timeit
+ def compute_perplexity(self):
+ loss_fct = CrossEntropyLoss(reduction="none")
+
+ ## perplexity = e ^ (sum(losses) / num_tokenized_tokens)
+ crossentropy_loss = (
+ loss_fct(self.out_logits.transpose(1, 2), self.token_ids)
+ * self.attention_mask
+ ).sum(1)
+ crossentropy_loss = torch.tensor(crossentropy_loss.tolist())
+ perplexity_batch = torch.exp(
+ crossentropy_loss / self.attention_mask.sum(1)
+ ).tolist()
+
+ perplexity_batch = [round(ppl, 6) for ppl in perplexity_batch]
+
+ return {
+ "perplexities": perplexity_batch,
+ "mean_perplexity": round(np.mean(perplexity_batch), 6),
+ }
+
+ @timeit
+ def get_perplexity(self):
+
+ token_ids, seq_lens = self.generator.tokenizer.encode(
+ self.test_prompts,
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
+ self.page_cache_size = (
+ len(token_ids[0]) // self.config.block_seq_stride
+ ) * self.bs + 1
+
+ logger.debug(f" Prompts for Evaluation:")
+ for idx, prompt in enumerate(self.test_prompts):
+ logger.debug(
+ f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
+ )
+
+ self.max_prompt_length = max(seq_lens)
+
+ self.token_ids = torch.tensor(token_ids, device=self.torch_device)
+ self.attention_mask = (
+ (self.token_ids != 0).int().detach().clone().to(self.torch_device)
+ )
+
+ self.get_logits(page_cache_size=self.page_cache_size)
+
+ self.out_logits = self.out_logits[..., :-1, :].contiguous()
+ self.token_ids = self.token_ids[..., 1:].contiguous()
+ self.attention_mask = self.attention_mask[..., 1:].contiguous()
+
+ logger.debug(f"Final Logits shape: {self.out_logits.shape}")
+ logger.debug(f"Token ids: {self.token_ids}, \n{self.token_ids.shape}")
+ logger.debug(
+ f"Mask shape: {self.attention_mask}, \n{self.attention_mask.shape}"
+ )
+
+ assert self.token_ids.shape == self.out_logits.shape[0:2]
+
+ return self.compute_perplexity()
+
+
+def run_perplexity(
+ weight_path,
+ weight_path_str,
+ tokenizer,
+ torch_device,
+ iree_device,
+ iree_hip_target,
+ iree_hal_target_backends,
+ kv_cache_type,
+ tensor_parallelism_size,
+ attention_kernel,
+ num_prompts,
+):
+ start = time.time()
+ perplexity = Perplexity(
+ torch_device=torch_device,
+ iree_device=iree_device,
+ iree_hip_target=iree_hip_target,
+ iree_hal_target_backends=iree_hal_target_backends,
+ kv_cache_type=kv_cache_type,
+ tensor_parallelism_size=tensor_parallelism_size,
+ attention_kernel=attention_kernel,
+ )
+
+ perplexity.get_prompts(num_prompts=num_prompts)
+
+ vmfb_path = perplexity.compile_model(weight_path_str)
+ perplexity.load_model(weight_path, tokenizer, vmfb_path)
+ ppl = perplexity.get_perplexity()
+
+ end = time.time()
+ total_time = round(end - start, 2)
+ if total_time < 60:
+ total_time = str(total_time) + " secs"
+ else:
+ total_time = str(round(total_time / 60, 2)) + " mins"
+ logger.info(f" Total time taken: {total_time}")
+
+ return ppl
+
+
+def main(argv):
+ parser = cli.create_parser()
+ parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
+ parser.add_argument("--torch-device", help="Torch device (or default)")
+ parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
+ parser.add_argument(
+ "--iree-hip-target",
+ action="store",
+ default="gfx942",
+ help="Specify the iree-hip target version (e.g., gfx942)",
+ )
+ parser.add_argument(
+ "--iree-hal-target-backends",
+ action="store",
+ default="rocm",
+ help="Specify the iree-hal target backends (e.g., rocm)",
+ )
+ parser.add_argument(
+ "--attention-kernel",
+ type=str,
+ default="decomposed",
+ choices=["decomposed", "torch_sdpa"],
+ )
+ parser.add_argument(
+ "--tensor-parallelism-size",
+ type=int,
+ default=1,
+ help="Number of devices for tensor parallel sharding",
+ )
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=100,
+ help="Number of prompts for perplexity test",
+ )
+
+ cli.add_tokenizer_options(parser)
+ cli.add_input_dataset_options(parser)
+ args = cli.parse(parser, args=argv)
+
+ torch_device = torch.device(args.torch_device) if args.torch_device else None
+ iree_device = args.iree_device
+ kv_cache_type = args.kv_cache_type
+ weight_path = cli.get_input_dataset(args)
+ tokenizer = cli.get_tokenizer(args)
+ weight_path_str = str(args.irpa_file)
+
+ ppl = run_perplexity(
+ weight_path=weight_path,
+ weight_path_str=weight_path_str,
+ tokenizer=tokenizer,
+ torch_device=torch_device,
+ iree_device=iree_device,
+ iree_hip_target=args.iree_hip_target,
+ iree_hal_target_backends=args.iree_hal_target_backends,
+ kv_cache_type=kv_cache_type,
+ tensor_parallelism_size=args.tensor_parallelism_size,
+ attention_kernel=args.attention_kernel,
+ num_prompts=args.num_prompts,
+ )
+
+ logger.info(f"\n{json.dumps(ppl, indent=2)}")
+ return ppl
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/sharktank/sharktank/evaluate/perplexity_prefill.py b/sharktank/sharktank/evaluate/perplexity_prefill.py
new file mode 100644
index 000000000..2bb785801
--- /dev/null
+++ b/sharktank/sharktank/evaluate/perplexity_prefill.py
@@ -0,0 +1,276 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+import logging
+import time
+from datetime import timedelta
+
+import json
+import numpy as np
+from tqdm import tqdm
+
+import torch
+from torch.nn import CrossEntropyLoss
+
+from sharktank.layers import *
+from sharktank.types import *
+
+from sharktank.models.llama.llama import *
+from sharktank.models.mixtral.mixtral import *
+from sharktank.models.grok.grok import *
+
+from sharktank.utils import cli
+from sharktank.utils.load_llm import *
+
+log_levels = {
+ "info": logging.INFO,
+ "debug": logging.DEBUG,
+}
+logger = logging.getLogger("eval")
+
+logger.setLevel(log_levels["debug"])
+
+logger.root.handlers[0].setFormatter(
+ logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
+)
+
+__all__ = ["Perplexity", "run_perplexity"]
+
+
+class Perplexity:
+ """
+ Perplexity (PPL) is one of the most common metrics for evaluating language models.
+ It is defined as the exponentiated average negative log-likelihood of a sequence,
+ calculated with exponent base `e`.
+
+ For more information, see https://huggingface.co/docs/transformers/perplexity
+ """
+
+ def __init__(
+ self,
+ prompts: list,
+ device,
+ kv_cache_type,
+ ):
+ self.prompts = prompts
+ self.add_start_token = False
+ self.batch_size = 16
+ self.bs = len(prompts)
+ self.device = device
+ self.kv_cache_type = kv_cache_type
+
+ def timeit(func):
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ result = func(*args, **kwargs)
+ end = time.time()
+ seconds = end - start
+ time_taken = abs(timedelta(seconds=round(seconds)))
+
+ if seconds < 1:
+ time_taken = f" {seconds * 1000} ms"
+
+ func_name = func.__name__
+ if func_name == "get_perplexity":
+ func_name = "Total time"
+ logger.info(f" {func_name}: {time_taken}")
+ return result
+
+ return wrapper
+
+ def print_token_comparison(self, i):
+ if i <= self.max_prompt_length:
+ batch_predicted_token_id = [[i[-1]] for i in self.batch.results]
+ batch_predicted_token = self.generator.tokenizer.decode(
+ batch_predicted_token_id
+ )
+ logger.debug(f"Predicted:")
+ logger.debug(f"{batch_predicted_token}")
+ logger.debug(f"{batch_predicted_token_id}")
+
+ expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist()
+ expected_token = self.generator.tokenizer.decode(expected_token_id)
+ logger.debug(f"Expected:")
+ logger.debug(f"{expected_token}")
+ logger.debug(f"{expected_token_id}")
+
+ @timeit
+ def load_model(self, dataset, tokenizer):
+
+ theta = dataset.root_theta
+
+ config = LlamaModelConfig(
+ hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
+ block_seq_stride=16,
+ kv_cache_type=self.kv_cache_type,
+ device=self.device,
+ activation_dtype=torch.float32,
+ attention_dtype=torch.float32,
+ )
+
+ if config.hp.expert_count:
+ if config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(theta, config)
+ else:
+ model = PagedMixtralModelV1(theta, config)
+ else:
+ model = PagedLlamaModelV1(theta, config)
+
+ self.generator = TorchGenerator(model, tokenizer)
+
+ @timeit
+ def get_logits(self):
+
+ token_ids, seq_lens = self.generator.tokenizer.encode(
+ self.prompts,
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ add_start_token=self.add_start_token,
+ )
+
+ logger.info(f" Prompts:")
+ for idx, prompt in enumerate(self.prompts):
+ logger.info(f" Prompt {idx} - {prompt.encode()}\n{token_ids[idx]}")
+
+ self.max_prompt_length = max(seq_lens)
+
+ self.token_ids = torch.tensor(token_ids, device=self.device)
+ self.attention_mask = (
+ (self.token_ids != 0).int().detach().clone().to(self.device)
+ )
+
+ is_first_token = True
+ for i in tqdm(
+ range(0, self.max_prompt_length - 1),
+ desc="eval: Calculating logits",
+ ):
+ token_batch = self.token_ids[:, : i + 1]
+ logger.debug(f"Prefill:")
+
+ logger.debug("Input:")
+ logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+
+ token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
+ token_ids=token_batch.tolist(),
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
+ token_batch = torch.tensor(token_batch, device=self.device)
+ seq_lens_batch = torch.tensor(seq_lens_batch, device=self.device)
+
+ self.batch = self.generator.begin_eval_batch(
+ token_batch=token_batch,
+ seq_lens_batch=seq_lens_batch,
+ bs=self.bs,
+ )
+
+ self.cache_state = self.batch.prefill()
+ self.print_token_comparison(i)
+
+ if is_first_token:
+ self.out_logits = self.batch.prefill_logits[:, 0:1, :]
+ is_first_token = False
+ else:
+ self.out_logits = torch.cat(
+ (self.out_logits, self.batch.prefill_logits[:, 0:1, :]), 1
+ )
+
+ pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1]
+
+ self.pad_logits = torch.zeros(
+ self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2]
+ )
+
+ self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to(
+ self.device
+ )
+
+ @timeit
+ def compute_perplexity(self):
+ loss_fct = CrossEntropyLoss(reduction="none")
+
+ ## perplexity = e ^ (sum(losses) / num_tokenized_tokens)
+ crossentropy_loss = (
+ loss_fct(self.out_logits.transpose(1, 2), self.token_ids)
+ * self.attention_mask
+ ).sum(1)
+ crossentropy_loss = torch.tensor(crossentropy_loss.tolist())
+ perplexity_batch = torch.exp(
+ crossentropy_loss / self.attention_mask.sum(1)
+ ).tolist()
+
+ return {
+ "perplexities": perplexity_batch,
+ "mean_perplexity": np.mean(perplexity_batch),
+ }
+
+ @timeit
+ def get_perplexity(self):
+
+ self.get_logits()
+
+ self.out_logits = self.out_logits[..., :-1, :].contiguous()
+ self.token_ids = self.token_ids[..., 1:].contiguous()
+ self.attention_mask = self.attention_mask[..., 1:].contiguous()
+
+ assert self.token_ids.shape == self.out_logits.shape[0:2]
+
+ logger.debug(f"Logits shape: {self.out_logits.shape}")
+ logger.debug(f"Token ids: {self.token_ids}, {self.token_ids.shape}")
+ logger.debug(
+ f"Logits shape: {self.attention_mask}, {self.attention_mask.shape}"
+ )
+
+ return self.compute_perplexity()
+
+
+def run_perplexity(
+ prompts: list[str],
+ dataset,
+ tokenizer,
+ device,
+ kv_cache_type,
+):
+ perplexity = Perplexity(prompts=prompts, device=device, kv_cache_type=kv_cache_type)
+
+ perplexity.load_model(dataset, tokenizer)
+ ppl = perplexity.get_perplexity()
+
+ return ppl
+
+
+def main(argv):
+ parser = cli.create_parser()
+ parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
+ parser.add_argument("--device", help="Torch device (or default)")
+
+ cli.add_input_dataset_options(parser)
+ cli.add_tokenizer_options(parser)
+ args = cli.parse(parser, args=argv)
+
+ device = torch.device(args.device) if args.device else None
+ kv_cache_type = args.kv_cache_type
+ dataset = cli.get_input_dataset(args)
+ tokenizer = cli.get_tokenizer(args)
+
+ prompt_path = "sharktank/evaluate/data/eval_prompts.txt"
+ with open(prompt_path, "r") as f:
+ input_texts = f.read().splitlines()
+
+ ppl = run_perplexity(
+ prompts=input_texts[0:1],
+ dataset=dataset,
+ tokenizer=tokenizer,
+ device=device,
+ kv_cache_type=kv_cache_type,
+ )
+
+ logger.info(f"\n{json.dumps(ppl, indent=2)}")
+ return ppl
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
new file mode 100644
index 000000000..258e8c9a0
--- /dev/null
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -0,0 +1,370 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import sys
+import logging
+import time
+import random
+import re
+from datetime import timedelta
+import json
+import numpy as np
+from tqdm import tqdm
+
+from datasets import load_dataset
+
+import torch
+from torch.nn import CrossEntropyLoss
+
+from sharktank.layers import *
+from sharktank.types import *
+
+from sharktank.models.llama.llama import *
+from sharktank.models.mixtral.mixtral import *
+from sharktank.models.grok.grok import *
+
+from ..models.llama.sharding import shard_theta
+
+from sharktank.utils import cli
+from sharktank.utils.load_llm import *
+
+log_levels = {
+ "info": logging.INFO,
+ "debug": logging.DEBUG,
+}
+logger = logging.getLogger("eval")
+
+logger.setLevel(log_levels["info"])
+
+logger.root.handlers[0].setFormatter(
+ logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
+)
+
+__all__ = ["Perplexity_torch", "run_perplexity_torch"]
+
+
+class Perplexity_torch:
+ """
+ Perplexity (PPL) is one of the most common metrics for evaluating language models.
+ It is defined as the exponentiated average negative log-likelihood of a sequence,
+ calculated with exponent base `e`.
+
+ For more information, see https://huggingface.co/docs/transformers/perplexity
+ """
+
+ def __init__(
+ self,
+ device,
+ kv_cache_type,
+ ):
+ self.device = device
+ self.kv_cache_type = kv_cache_type
+ self.activation_dtype = torch.float32
+ self.attention_dtype = torch.float32
+
+ def timeit(func):
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ result = func(*args, **kwargs)
+ end = time.time()
+ total_seconds = end - start
+ time_taken = abs(timedelta(seconds=total_seconds))
+ hours, minutes, seconds = re.split(":", str(time_taken))
+
+ if total_seconds < 1:
+ time_taken = f" {round(total_seconds * 1000, 3)} ms"
+ elif total_seconds < 60:
+ time_taken = "{:.2f} secs".format(round(float(total_seconds), 2))
+ else:
+ time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
+ int(hours), int(minutes), round(float(seconds), 2)
+ )
+
+ func_name = func.__name__
+ if func_name == "get_perplexity":
+ func_name = "Calculate perplexity"
+ logger.info(f" {func_name}: {time_taken}")
+ return result
+
+ return wrapper
+
+ def print_token_comparison(self, i):
+ if i <= self.max_prompt_length:
+ batch_predicted_token_id = [[i[-1]] for i in self.batch.results]
+ batch_predicted_token = self.generator.tokenizer.decode(
+ batch_predicted_token_id
+ )
+ logger.debug(f"Predicted:")
+ logger.debug(f"{batch_predicted_token}")
+ logger.debug(f"{batch_predicted_token_id}")
+
+ expected_token_id = self.token_ids[:, i + 1 : i + 2].tolist()
+ expected_token = self.generator.tokenizer.decode(expected_token_id)
+ logger.debug(f"Expected:")
+ logger.debug(f"{expected_token}")
+ logger.debug(f"{expected_token_id}")
+
+ @timeit
+ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kernel):
+
+ self.config = LlamaModelConfig(
+ hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
+ block_seq_stride=16,
+ kv_cache_type=self.kv_cache_type,
+ device=self.device,
+ activation_dtype=self.activation_dtype,
+ attention_dtype=self.attention_dtype,
+ tensor_parallelism_size=tensor_parallelism_size,
+ )
+
+ if self.config.tensor_parallelism_size > 1:
+ dataset.root_theta = shard_theta(dataset.root_theta, self.config)
+
+ theta = dataset.root_theta
+
+ if self.config.hp.expert_count:
+ if self.config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(theta, self.config)
+ else:
+ model = PagedMixtralModelV1(theta, self.config)
+ else:
+ model = PagedLlamaModelV1(theta, self.config)
+
+ self.generator = TorchGenerator(model, tokenizer)
+
+ @timeit
+ def get_prompts(self, num_prompts):
+
+ test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
+ "text"
+ ]
+
+ num_test_prompts = 219
+
+ random.seed(0)
+ test_prompts = random.sample(test_prompts, num_test_prompts)
+
+        # Ignore prompts that are empty, shorter than 20 words, or titles.
+ test_prompts = [
+ s.replace("\n", "").rstrip()
+ for s in test_prompts
+ if s != "" and len(s.split()) >= 20 and s.count("=") < 2
+ ][0:num_prompts]
+
+ self.test_prompts = test_prompts
+
+ self.bs = len(test_prompts)
+
+ logger.info(f" Batch size: {self.bs}")
+
+ @timeit
+ def get_logits(self, page_cache_size):
+
+ is_first_token = True
+ start = 0
+ for i in tqdm(
+ range(start, self.max_prompt_length - 1),
+ mininterval=300,
+ desc="eval: Calculating logits",
+ ):
+ logger.debug(f"Iteration: {i}")
+
+ if is_first_token:
+
+ token_batch = self.token_ids[:, : i + 1]
+ logger.debug(f"Prefill:")
+
+ logger.debug("Input:")
+ logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+
+ token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
+ token_ids=token_batch.tolist(),
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
+ logger.debug(f"{token_batch}")
+
+ token_batch = torch.tensor(token_batch, device=self.device)
+ seq_lens_batch = torch.tensor(seq_lens_batch, device=self.device)
+
+ self.batch = self.generator.begin_eval_batch(
+ token_batch=token_batch,
+ seq_lens_batch=seq_lens_batch,
+ bs=self.bs,
+ page_cache_size=page_cache_size,
+ )
+
+ self.batch.prefill()
+ self.out_logits = self.batch.prefill_logits[:, 0:1, :]
+ is_first_token = False
+
+ self.print_token_comparison(i)
+
+ else:
+ token_batch = self.token_ids[:, i : i + 1]
+
+ logger.debug("Decode:")
+
+ logger.debug("Input:")
+ logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+ logger.debug(f"{token_batch.tolist()}")
+
+ self.batch.decode(token_batch=token_batch)
+ self.out_logits = torch.cat(
+ (self.out_logits, self.batch.decode_logits), 1
+ )
+
+ self.print_token_comparison(i)
+
+ pad_logits_shape = self.token_ids.shape[1] - self.out_logits.shape[1]
+
+ self.pad_logits = torch.zeros(
+ self.out_logits.shape[0], pad_logits_shape, self.out_logits.shape[2]
+ )
+
+ self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to(
+ self.device
+ )
+
+ @timeit
+ def compute_perplexity(self):
+ loss_fct = CrossEntropyLoss(reduction="none")
+
+ ## perplexity = e ^ (sum(losses) / num_tokenized_tokens)
+ crossentropy_loss = (
+ loss_fct(self.out_logits.transpose(1, 2), self.token_ids)
+ * self.attention_mask
+ ).sum(1)
+ crossentropy_loss = torch.tensor(crossentropy_loss.tolist())
+ perplexity_batch = torch.exp(
+ crossentropy_loss / self.attention_mask.sum(1)
+ ).tolist()
+
+ perplexity_batch = [round(ppl, 6) for ppl in perplexity_batch]
+
+ return {
+ "perplexities": perplexity_batch,
+ "mean_perplexity": round(np.mean(perplexity_batch), 6),
+ }
+
+ @timeit
+ def get_perplexity(self):
+
+ token_ids, seq_lens = self.generator.tokenizer.encode(
+ self.test_prompts,
+ pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+ )
+
+ self.page_cache_size = (
+ len(token_ids[0]) // self.config.block_seq_stride
+ ) * self.bs + 1
+
+ logger.debug(f" Prompts for Evaluation:")
+ for idx, prompt in enumerate(self.test_prompts):
+ logger.debug(
+ f" Prompt {idx}: \nTokens: {prompt.encode()}\nToken ids: {token_ids[idx]}\n"
+ )
+
+ self.max_prompt_length = max(seq_lens)
+
+ self.token_ids = torch.tensor(token_ids, device=self.device)
+ self.attention_mask = (
+ (self.token_ids != 0).int().detach().clone().to(self.device)
+ )
+
+ self.get_logits(page_cache_size=self.page_cache_size)
+
+ self.out_logits = self.out_logits[..., :-1, :].contiguous()
+ self.token_ids = self.token_ids[..., 1:].contiguous()
+ self.attention_mask = self.attention_mask[..., 1:].contiguous()
+
+ logger.debug(f"Final Logits shape: {self.out_logits.shape}")
+ logger.debug(f"Token ids: {self.token_ids}, \n{self.token_ids.shape}")
+ logger.debug(
+ f"Mask shape: {self.attention_mask}, \n{self.attention_mask.shape}"
+ )
+
+ assert self.token_ids.shape == self.out_logits.shape[0:2]
+
+ return self.compute_perplexity()
+
+
+def run_perplexity_torch(
+ dataset,
+ tokenizer,
+ device,
+ kv_cache_type,
+ tensor_parallelism_size,
+ attention_kernel,
+ num_prompts,
+):
+ start = time.time()
+
+ perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type)
+ perplexity.get_prompts(num_prompts=num_prompts)
+ perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
+ ppl = perplexity.get_perplexity()
+
+ end = time.time()
+ total_time = round(end - start, 2)
+ if total_time < 60:
+ total_time = str(total_time) + " secs"
+ else:
+ total_time = str(round(total_time / 60, 2)) + " mins"
+ logger.info(f" Total time taken: {total_time}")
+
+ return ppl
+
+
+def main(argv):
+ parser = cli.create_parser()
+ parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
+ parser.add_argument("--device", help="Torch device (or default)")
+ parser.add_argument(
+ "--attention-kernel",
+ type=str,
+ default="decomposed",
+ choices=["decomposed", "torch_sdpa"],
+ )
+
+ parser.add_argument(
+ "--tensor-parallelism-size",
+ type=int,
+ default=1,
+ help="Number of devices for tensor parallel sharding.",
+ )
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=100,
+ help="Number of prompts for perplexity test",
+ )
+
+ cli.add_input_dataset_options(parser)
+ cli.add_tokenizer_options(parser)
+ args = cli.parse(parser, args=argv)
+
+ device = torch.device(args.device) if args.device else None
+ kv_cache_type = args.kv_cache_type
+ dataset = cli.get_input_dataset(args)
+ tokenizer = cli.get_tokenizer(args)
+
+ ppl = run_perplexity_torch(
+ dataset=dataset,
+ tokenizer=tokenizer,
+ device=device,
+ kv_cache_type=kv_cache_type,
+ tensor_parallelism_size=args.tensor_parallelism_size,
+ attention_kernel=args.attention_kernel,
+ num_prompts=args.num_prompts,
+ )
+
+ logger.info(f"\n{json.dumps(ppl, indent=2)}")
+ return ppl
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 5439cc38b..6dd9785c3 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -9,14 +9,18 @@
import json
import torch
-from shark_turbine.aot import *
+from iree.turbine.aot import *
from sharktank.layers import *
from sharktank.types import *
+from sharktank.utils.math import ceildiv
# TODO: Should be using a base class with the protocol supported.
from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
+from ..models.llama.sharding import shard_theta
from ..models.mixtral.mixtral import *
+from ..models.grok.grok import *
+from .. import ops
def main():
@@ -32,7 +36,7 @@ def main():
parser.add_argument(
"--output-config",
help="Output file path for exported config file",
- default="tmp/batch_llama_v1.json",
+ default="/tmp/batch_llama_v1.json",
)
parser.add_argument(
"--bs",
@@ -45,16 +49,40 @@ def main():
help="Include verbose logging",
action="store_true",
)
+ parser.add_argument(
+ "--strict",
+ help="Enables strictness during export",
+ action="store_true",
+ )
+ cli.add_quantization_options(parser)
+ cli.add_model_options(parser)
args = cli.parse(parser)
+ dataset_type = cli.get_input_data_files(args)
+ dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)
-
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
- llama_config = LlamaModelConfig(hp)
- llama_config.static_tables = False # Rely on the compiler for hoisting tables.
- llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
+ tensor_parallelism_size = (
+ dataset.properties["tensor_parallelism_size"]
+ if "tensor_parallelism_size" in dataset.properties
+ else 1
+ )
+
+ llama_config = LlamaModelConfig(
+ hp,
+ tensor_parallelism_size=tensor_parallelism_size,
+ use_hf=False,
+ static_tables=False, # Rely on the compiler for hoisting tables.
+ kv_cache_type="direct" if args.bs == [1] else "paged",
+ attention_kernel=args.attention_kernel,
+ )
+ llama_config.fake_quant = args.fake_quant
+
if llama_config.hp.expert_count:
- model = PagedMixtralModelV1(dataset.root_theta, llama_config)
+ if llama_config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(dataset.root_theta, llama_config)
+ else:
+ model = PagedMixtralModelV1(dataset.root_theta, llama_config)
else:
model = PagedLlamaModelV1(dataset.root_theta, llama_config)
@@ -80,86 +108,157 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
fxb = FxProgramsBuilder(model)
- def generate_batch_prefill(bs: int):
- tokens = torch.empty(bs, 64, dtype=torch.int64)
- seq_lens = torch.empty(bs, dtype=torch.int64)
- seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
- block_dim = torch.export.Dim(
- "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
- )
- sl_dim = llama_config.block_seq_stride * block_dim
-
+ def setup_cache(model, shard_count):
if model.config.kv_cache_type == "paged":
cache_state = model.cache.allocate(
page_count=hp.context_length // llama_config.block_seq_stride
)
page_dim = torch.export.Dim("page")
- cache_state_dynamic_shapes = [{0: page_dim}]
+
+ dynamic_shapes = [{0: page_dim}]
+ unpacked = cache_state
+ arg_affinities = {}
+ shard_dim = None
+
+            # Need to unpack the cache state when sharded
+ if llama_config.tensor_parallelism_size > 1:
+ shard_dim = cache_state[0].shard_dim
+
+ unpacked = [[shard._data for shard in cs.shards] for cs in cache_state]
+ dynamic_shapes = [
+ [ds] * llama_config.tensor_parallelism_size for ds in dynamic_shapes
+ ]
+
+ for i in range(llama_config.tensor_parallelism_size):
+ arg_affinities[i] = DeviceAffinity(str(i))
+
+ return unpacked, shard_dim, dynamic_shapes, arg_affinities
+
elif model.config.kv_cache_type == "direct":
cache_state = model.cache.allocate(bs=1)
# Direct cache dimensions:
# 2 * transformer_block_count of...
# [bs, seq_length, attn_head_count, attn_head_dim]
- cache_state_dynamic_shapes = (2 * hp.block_count) * [{}]
+ dynamic_shapes = [None]
+ arg_affinities = {}
+ shard_dim = None
+ return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")
+ def repack_cache(cache, shard_dim):
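+        # For example, with tensor_parallelism_size=2 each cache entry arrives
+        # as a list of two plain shard tensors; wrap each one back into a
+        # SplitPrimitiveTensor so the model sees the sharded type again.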
+ return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
+
+ def generate_batch_prefill(bs: int):
+        # torch.export.Dim requires min to be at least 2
+ block_dim_min = 2
+ block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
+ block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
+ sl_dim = llama_config.block_seq_stride * block_dim
+ seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
+ tokens = torch.empty(
+ bs,
+ seq_block_ids.shape[1] * llama_config.block_seq_stride,
+ dtype=torch.int64,
+ )
+ seq_lens = torch.empty(bs, dtype=torch.int64)
+
+ cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
+ model, llama_config.tensor_parallelism_size
+ )
+
+ if llama_config.tensor_parallelism_size > 1:
+ # We need to offset the indices for the cache
+ arg_affinities = {key + 3: arg_affinities[key] for key in arg_affinities}
+
+ for i in range(3):
+ arg_affinities[i] = DeviceAffinity("0")
+
dynamic_shapes = {
"tokens": {1: sl_dim},
"seq_lens": {},
"seq_block_ids": {1: block_dim},
- "cache_state": cache_state_dynamic_shapes,
+ "cs": cache_dynamic_shapes,
}
print(f"Exporting prefill_bs{bs}")
@fxb.export_program(
name=f"prefill_bs{bs}",
- args=(tokens, seq_lens, seq_block_ids, cache_state),
+ args=(tokens, seq_lens, seq_block_ids, cache),
dynamic_shapes=dynamic_shapes,
+ strict=args.strict,
+ arg_device=arg_affinities,
)
- def _(model, tokens, seq_lens, seq_block_ids, cache_state):
+ def _(model, tokens, seq_lens, seq_block_ids, cs):
+ if (
+ model.config.tensor_parallelism_size == 1
+ and model.config.kv_cache_type == "direct"
+ ):
+ cache_tensors = torch.unbind(cs)
+ else:
+ cache_tensors = cs
+
sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
attention_mask = model.attention_mask(input_mask)
+
+ if llama_config.tensor_parallelism_size != 1:
+ shard_count = llama_config.tensor_parallelism_size
+
+ tokens = ops.replicate(tokens, count=shard_count)
+ attention_mask = ops.replicate(attention_mask, count=shard_count)
+ seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
+ cache_tensors = repack_cache(cs, cache_shard_dim)
+
logits = model.prefill(
tokens,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids,
- cache_state=cache_state,
+ cache_state=cache_tensors,
)
+
+ if llama_config.tensor_parallelism_size != 1:
+ logits = ops.unshard(logits)
+
return logits
def generate_batch_decode(bs: int):
- tokens = torch.ones(bs, 1, dtype=torch.int64)
- seq_lens = torch.ones(bs, dtype=torch.int64)
- start_positions = torch.ones(bs, dtype=torch.int64)
- seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
- block_dim = torch.export.Dim(
- "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
+        # torch.export.Dim requires min to be at least 2
+ block_dim_min = 2
+ block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
+ block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
+ tokens = torch.empty(
+ bs,
+ 1,
+ dtype=torch.int64,
)
+ seq_lens = torch.empty(bs, dtype=torch.int64)
+ start_positions = torch.ones(bs, dtype=torch.int64)
+ seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
- if model.config.kv_cache_type == "paged":
- cache_state = model.cache.allocate(
- page_count=hp.context_length // llama_config.block_seq_stride
- )
- page_dim = torch.export.Dim("page")
- cache_state_dynamic_shapes = [{0: page_dim}]
- elif model.config.kv_cache_type == "direct":
- cache_state = model.cache.allocate(bs=1)
- # Direct cache dimensions:
- # 2 * transformer_block_count of...
- # [bs, seq_length, attn_head_count, attn_head_dim]
- cache_state_dynamic_shapes = (2 * hp.block_count) * [{}]
- else:
- raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")
+ (
+ cache_state,
+ cache_shard_dim,
+ cache_dynamic_shapes,
+ arg_affinities,
+ ) = setup_cache(model, llama_config.tensor_parallelism_size)
+
+ if llama_config.tensor_parallelism_size > 1:
+ # We need to offset the indices for the cache
+ arg_affinities = {key + 4: arg_affinities[key] for key in arg_affinities}
+
+ # Inputs have default affinity 0
+ for i in range(4):
+ arg_affinities[i] = DeviceAffinity("0")
dynamic_shapes = {
"tokens": {},
"seq_lens": {},
"start_positions": {},
"seq_block_ids": {1: block_dim},
- "cache_state": cache_state_dynamic_shapes,
+ "cache_state": cache_dynamic_shapes,
}
print(f"Exporting decode_bs{bs}")
@@ -174,6 +273,8 @@ def generate_batch_decode(bs: int):
cache_state,
),
dynamic_shapes=dynamic_shapes,
+ strict=args.strict,
+ arg_device=arg_affinities,
)
def _(
model,
@@ -187,6 +288,17 @@ def _(
seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
)
attention_mask = model.decode_attention_mask(input_mask)
+
+ if llama_config.tensor_parallelism_size != 1:
+ shard_count = llama_config.tensor_parallelism_size
+
+ tokens = ops.replicate(tokens, count=shard_count)
+ attention_mask = ops.replicate(attention_mask, count=shard_count)
+ start_positions = ops.replicate(start_positions, count=shard_count)
+ seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
+
+ cache_state = repack_cache(cache_state, cache_shard_dim)
+
logits = model.decode(
tokens,
attention_mask=attention_mask,
@@ -194,12 +306,17 @@ def _(
seq_block_ids=seq_block_ids,
cache_state=cache_state,
)
+
+ if llama_config.tensor_parallelism_size != 1:
+ logits = ops.unshard(logits)
+
return logits
bsizes = []
for bs in args.bs:
generate_batch_prefill(bs)
- generate_batch_decode(bs)
+ if not args.skip_decode:
+ generate_batch_decode(bs)
bsizes.append(bs)
config = generate_params_json(hp, bsizes, bsizes)
print("GENERATED!")
@@ -209,7 +326,7 @@ def _(
print(f"EXPORT {name}:\n{ep}")
print("Exporting")
- output = export(fxb)
+ output = export(fxb, import_symbolic_shape_expressions=True)
print(f"Saving to '{args.output_mlir}'")
output.save_mlir(args.output_mlir)
json.dump(config, open(args.output_config, "w"))
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index c415b2ce1..b30acc026 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -8,6 +8,8 @@
from typing import Optional
+from safetensors import safe_open
+
import math
import sys
@@ -16,11 +18,15 @@
from ..layers import *
from ..types import *
+from ..ops import replicate, unshard
+
# TODO: Should be using a base class with the protocol supported.
from ..models.mixtral.mixtral import *
+from ..models.grok.grok import *
from ..models.llama.llama import *
+from ..models.llama.sharding import shard_theta
from ..utils.debugging import trace_tensor
-from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
+from ..utils.tokenizer import InferenceTokenizer
class TorchGenerator:
@@ -38,9 +44,9 @@ def __init__(
self.tokenizer = tokenizer
if model.cache.is_paged:
self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
+ self.free_pages = list(range(1, page_cache_size))
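+            # The free list only covers pages 1..page_cache_size-1; page
+            # index 0 is never handed out to a batch.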
else:
self.shared_cache_state = None
- self.free_pages = list(range(1, 128))
self.end_token = end_token
@property
@@ -51,6 +57,7 @@ def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)
+
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
if self.shared_cache_state is not None:
@@ -145,13 +152,23 @@ def prefill(self):
trace_tensor("prefill.token_ids", self.token_ids)
trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
trace_tensor("prefill.attention_mask", attention_mask)
+
+ token_ids = self.token_ids
+ if model.config.tensor_parallelism_size != 1:
+ tp = model.config.tensor_parallelism_size
+ token_ids = replicate(token_ids, tp)
+ attention_mask = replicate(attention_mask, tp)
+ seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
+
logits = model.prefill(
- self.token_ids,
+ token_ids,
attention_mask=attention_mask,
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)
+ logits = unshard(logits)
+
# TODO: Generalize the sampling and don't make it swap on/off cpu.
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
@@ -179,6 +196,14 @@ def decode(self):
trace_tensor("decode.start_positions", start_positions)
trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
trace_tensor("decode.attention_mask", decode_attention_mask)
+
+ if model.config.tensor_parallelism_size != 1:
+ tp = model.config.tensor_parallelism_size
+ self.next_tokens = replicate(self.next_tokens, tp)
+ start_positions = replicate(start_positions, tp)
+ seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
+ decode_attention_mask = replicate(decode_attention_mask, tp)
+
logits = model.decode(
self.next_tokens,
attention_mask=decode_attention_mask,
@@ -186,6 +211,8 @@ def decode(self):
seq_block_ids=seq_block_ids_tensor,
cache_state=self.cache_state,
)
+
+ logits = unshard(logits)
trace_tensor("decode.logits", logits)
# TODO: Normalize the output of extract_tokens_from_logits into
# tensor [bs, 1].
@@ -218,17 +245,28 @@ def main():
help="DType to use for activations in the model",
default="float32",
)
+ parser.add_argument(
+ "--use-hf",
+ action="store_true",
+ default=False,
+ )
+ parser.add_argument(
+ "--tensor-parallelism-size",
+ type=int,
+ default=1,
+ help="How many devices are involved for tensor parallel sharding.",
+ )
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
+ cli.add_quantization_options(parser)
+ cli.add_model_options(parser)
args = cli.parse(parser)
-
device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
assert isinstance(activation_dtype, torch.dtype)
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
-
config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
@@ -236,13 +274,22 @@ def main():
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
+ attention_kernel=args.attention_kernel,
+ use_hf=args.use_hf,
+ tensor_parallelism_size=args.tensor_parallelism_size,
+ fake_quant=args.fake_quant,
)
+ if config.tensor_parallelism_size > 1:
+ dataset.root_theta = shard_theta(dataset.root_theta, config)
if config.hp.expert_count:
- model = PagedMixtralModelV1(dataset.root_theta, config)
+ if config.hp.model_arch == "grok":
+ model = PagedGrokModelV1(dataset.root_theta, config)
+ else:
+ model = PagedMixtralModelV1(dataset.root_theta, config)
else:
model = PagedLlamaModelV1(dataset.root_theta, config)
-
+
if args.save_intermediates_path:
from ..utils.patching import SaveModuleResultTensorsPatch
@@ -272,6 +319,7 @@ def main():
)
print(f":: Result tokens: {batch.results}")
batch.print_current_results()
+ counter += 1
if __name__ == "__main__":
diff --git a/sharktank/sharktank/examples/sharding/export_ffn_net.py b/sharktank/sharktank/examples/sharding/export_ffn_net.py
index f80b9a2ac..f261a92e1 100644
--- a/sharktank/sharktank/examples/sharding/export_ffn_net.py
+++ b/sharktank/sharktank/examples/sharding/export_ffn_net.py
@@ -50,6 +50,7 @@ def forward(self, x: torch.Tensor):
ffn_gate_weight = self.theta.tensor("ffn_gate", "weight")
ffn_up_weight = self.theta.tensor("ffn_up", "weight")
ffn_down_weight = self.theta.tensor("ffn_down", "weight")
+ x = ops.replicate(x, count=ffn_gate_weight.shard_count)
ffn_gate = ops.elementwise(
torch.nn.functional.silu, ops.linear(x, ffn_gate_weight)
)
@@ -89,7 +90,7 @@ def main(raw_args=None):
ds = Dataset.load(args.output_irpa_file)
mdl = ShardedFFN(ds.root_theta)
- from shark_turbine import aot
+ from iree.turbine import aot
example_arg = torch.empty(bs, sl, primary_dim, dtype=torch.float16)
ep = torch.export.export(mdl, (example_arg,))
diff --git a/sharktank/sharktank/examples/sharding/export_gemm.py b/sharktank/sharktank/examples/sharding/export_gemm.py
index 7a4322e38..9744a6d82 100644
--- a/sharktank/sharktank/examples/sharding/export_gemm.py
+++ b/sharktank/sharktank/examples/sharding/export_gemm.py
@@ -4,7 +4,7 @@
import torch
from torch import Tensor
from sharktank import ops
-from shark_turbine import aot
+from iree.turbine import aot
def export_gemm(
diff --git a/sharktank/sharktank/examples/sharding/shard_llm_dataset.py b/sharktank/sharktank/examples/sharding/shard_llm_dataset.py
index 0921a2c83..91e88d071 100644
--- a/sharktank/sharktank/examples/sharding/shard_llm_dataset.py
+++ b/sharktank/sharktank/examples/sharding/shard_llm_dataset.py
@@ -10,7 +10,8 @@
weights of an LLM by converting the RHS of all eligible layers to a sharded
form.
"""
-from ...transforms.dataset import MmtRHSShardingTransform
+from ...models.llama.sharding import shard_theta
+from ...layers import LlamaHParams, LlamaModelConfig
from ...types import *
@@ -21,16 +22,30 @@ def main(raw_args=None):
cli.add_input_dataset_options(parser)
cli.add_output_dataset_options(parser)
parser.add_argument(
- "--num-shards", type=int, required=True, help="Number of shards to split"
+ "--tensor-parallelism-size",
+ type=int,
+ required=True,
+ help="Number of shards to split",
)
args = cli.parse(parser, args=raw_args)
dataset = cli.get_input_dataset(args)
- tr = MmtRHSShardingTransform(
- r"^blk\.[0-9]+\.(attn_k|attn_q|attn_v|ffn_gate|ffn_up|ffn_down)\.weight$",
- num_shards=8,
+ if args.output_irpa_file is None:
+ raise RuntimeError(f"Need file destination for IRPA file")
+
+ if args.tensor_parallelism_size < 2:
+ raise RuntimeError(
+ f"Expect sharding greater than 1 found {args.tensor_parallelism_size}"
+ )
+
+ hp = LlamaHParams.from_gguf_props(dataset.properties)
+ llama_config = LlamaModelConfig(
+ hp, tensor_parallelism_size=args.tensor_parallelism_size
)
- dataset.transform(tr)
+ sharded_theta = shard_theta(dataset.root_theta, llama_config)
+ sharded_theta.rename_tensors_to_paths()
+ dataset.root_theta = sharded_theta
+ dataset.properties["tensor_parallelism_size"] = args.tensor_parallelism_size
dataset.save(args.output_irpa_file, io_report_callback=print)
diff --git a/sharktank/sharktank/export.py b/sharktank/sharktank/export.py
new file mode 100644
index 000000000..0a1c6940d
--- /dev/null
+++ b/sharktank/sharktank/export.py
@@ -0,0 +1,174 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Callable, Any
+import torch
+from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
+from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
+from .types.tensors import ShardedTensor
+from torch.utils._pytree import PyTree, _is_leaf
+import functools
+
+
+def flatten_signature(
+ *sample_args: list[PyTree], **sample_kwargs: dict[str, PyTree]
+) -> Callable[[Callable], Any]:
+ """Decorator that flattens the signature of a function using PyTorch's type
+ registration.
+    It will flatten the same way PyTorch does, returning a function that accepts
+ and returns a flat list of torch.Tensor.
+ The decorator requires sample arguments of the unflattened function.
+
+ ```
+ @flatten_signature(
+ {
+ "a1": SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0),
+ "a2": torch.tensor([2]),
+ },
+ [DefaultPrimitiveTensor(data=torch.tensor([3]))]
+ )
+ def f(a, b):
+ return a["a1"], b
+ ```
+
+ will result in a function with signature
+
+ ```
+ (
+ torch.Tensor of size 1,
+ torch.Tensor of size 2,
+ torch.Tensor of size 3,
+ ) -> (
+ torch.Tensor of size 1,
+ torch.Tensor of size 2,
+ )
+ ```
+ """
+ flat_sample_args, args_tree_spec = tree_flatten(sample_args)
+ n_args = len(flat_sample_args)
+ kwargs_tree_spec = tree_structure(sample_kwargs)
+
+ def _decorator(f: Callable) -> Callable:
+ def _wrapper(*flat_args: list[Any]) -> list[Any]:
+            unflattened_args = tree_unflatten(flat_args[:n_args], args_tree_spec)
+            unflattened_kwargs = tree_unflatten(flat_args[n_args:], kwargs_tree_spec)
+            return tree_flatten(f(*unflattened_args, **unflattened_kwargs))[0]
+
+ return _wrapper
+
+ return _decorator
+
+
+def get_argument_flat_device_affinities(
+ *args: list[PyTree], **kwargs: dict[str, PyTree]
+) -> dict[int, DeviceAffinity]:
+ """Return the flat device affinities for unflattened arguments.
+ ShardedTensor types have their device affinities assigned.
+ All other arguments are left unassigned.
+
+ ```
+ get_argument_flat_device_affinities(
+ torch.Tensor([1]),
+ [ReplicatedTensor(ts=[torch.tensor([2]), torch.tensor([3])])]
+ )
+ ```
+ returns
+ ```
+ {
+ 1: DeviceAffinity("0"),
+ 2: DeviceAffinity("1"),
+ }
+ ```
+ """
+
+ def is_leaf(v: PyTree) -> bool:
+ if isinstance(v, ShardedTensor):
+ return True
+        # TODO: It is unfortunate that _is_leaf is private. Find a way not to use it.
+ from torch.utils._pytree import _is_leaf
+
+ return _is_leaf(v)
+
+    # Flatten up to, but not into, sharded tensors: they are treated as leaves.
+ flat_args_up_to_sharded_tensor = tree_flatten((args, kwargs), is_leaf=is_leaf)[0]
+ nested_device_affinities: list[list[DeviceAffinity | None]] = [
+ [DeviceAffinity(f"{shard_idx}") for shard_idx in range(len(arg.shards))]
+ if isinstance(arg, ShardedTensor)
+ else [None]
+ for arg in flat_args_up_to_sharded_tensor
+ ]
+ flat_device_affinities: list[DeviceAffinity | None] = [
+ affinity
+ for affinity_list in nested_device_affinities
+ for affinity in affinity_list
+ ]
+ return {
+ arg_idx: affinity
+ for arg_idx, affinity in enumerate(flat_device_affinities)
+ if affinity is not None
+ }
+
+
+def export(
+ f: Callable | None = None,
+ fx_builder: FxProgramsBuilder | None = None,
+ args: tuple[PyTree] | None = None,
+ kwargs: dict[PyTree] | None = None,
+ arg_device: dict[int, DeviceAffinity] | None = None,
+ *transitive_args,
+ **transitive_kwargs,
+) -> torch.export.ExportedProgram:
+ """Wrapper around FxProgramsBuilder.export_program that handles
+ the sharktank custom tensor types.
+
+    If `arg_device` is not specified, the affinities are extracted from the
+    passed `args`.
+    `arg_device` must provide the affinities for the flattened arguments,
+    i.e. the ones that correspond to individual torch.Tensor values.
+ For example a sharded tensor with 2 shards would result in 2 arguments in the MLIR
+ signature."""
+
+ if f is None:
+ return functools.partial(
+ export,
+ fx_builder=fx_builder,
+ args=args,
+ kwargs=kwargs,
+ arg_device=arg_device,
+ *transitive_args,
+ **transitive_kwargs,
+ )
+
+ if args is None:
+ args = []
+ if kwargs is None:
+ kwargs = {}
+ if arg_device is None:
+ arg_device = get_argument_flat_device_affinities(*args, **kwargs)
+ flat_args = tree_flatten((args, kwargs))[0]
+ if fx_builder is not None:
+ # Flatten the signature of the function.
+ # Technically this is done during export, but we want the signature to match
+ # the flat device affinities.
+ def module_fn_with_flat_signature(module, *flat_args):
+ @flatten_signature(*args, **kwargs)
+ def flat_fn(*args, **kwargs):
+ return f(module, *args, **kwargs)
+
+ return flat_fn(*flat_args)
+
+ amended_kwargs = dict(**transitive_kwargs)
+ if "name" not in amended_kwargs or amended_kwargs["name"] is None:
+ amended_kwargs["name"] = f.__name__
+ return fx_builder.export_program(
+ module_fn_with_flat_signature,
+ *transitive_args,
+ args=flat_args,
+ arg_device=arg_device,
+ **amended_kwargs,
+ )
+
+ assert False, "TODO: implement the case when not using an FxProgramsBuilder"
diff --git a/sharktank/sharktank/export_layer/export_kv_cache.py b/sharktank/sharktank/export_layer/export_kv_cache.py
new file mode 100644
index 000000000..09f0a1c15
--- /dev/null
+++ b/sharktank/sharktank/export_layer/export_kv_cache.py
@@ -0,0 +1,125 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+import torch.nn.functional as F
+
+from iree.turbine.aot import *
+
+from sharktank.types import SplitPrimitiveTensor
+from sharktank.ops import reshard_split, replicate
+from sharktank.layers.kv_cache import PagedKVCache
+from ..utils import cli
+
+
+def main():
+ parser = cli.create_parser()
+ parser.add_argument(
+ "--output-mlir",
+ help="Output file path for exported MLIR file",
+ default="/tmp/kv_cache.mlir",
+ )
+ parser.add_argument(
+ "--batch-size",
+ "-bs",
+ help="Batch size to generate, e.g. `4` or `2`",
+ type=lambda arg: int(arg),
+ default="2",
+ )
+ parser.add_argument(
+ "--sharding",
+ help="Sharding level of kv-cache",
+ type=lambda arg: int(arg),
+ default="1",
+ )
+ parser.add_argument(
+ "--verbose",
+ "-v",
+ help="Include verbose logging",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--strict",
+ help="Enables strictness during export",
+ action="store_true",
+ )
+
+ args = cli.parse(parser)
+
+ bs = args.batch_size
+
+ seq_length = 24
+ attn_head_count = 4
+ attn_head_dim = 16
+ transformer_block_count = 4
+ block_seq_stride = 4
+ page_count = bs * seq_length // block_seq_stride
+ write_seq_length = seq_length - 4
+
+ cache = PagedKVCache(
+ block_seq_stride=block_seq_stride,
+ transformer_block_count=transformer_block_count,
+ attn_head_count=attn_head_count,
+ attn_head_dim=attn_head_dim,
+ shard_count=args.sharding,
+ dtype=torch.float32,
+ device=None,
+ )
+
+ alloc = cache.allocate(page_count=page_count)
+ allocation = alloc
+
+ model = torch.nn.Module()
+ fxb = FxProgramsBuilder(model)
+
+ page_ids = torch.empty(bs, seq_length // block_seq_stride, dtype=torch.int64)
+ write_page_ids = page_ids[:, : write_seq_length // block_seq_stride]
+ partition_0 = torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ )
+
+ if args.sharding > 1:
+ partition_0 = reshard_split(partition_0, dim=2, count=args.sharding).shards
+ allocation = allocation[0].shards
+
+ argument_device_affinities = {}
+ for i in range(args.sharding):
+ argument_device_affinities[i] = DeviceAffinity(f"{i}")
+ argument_device_affinities[i + args.sharding] = DeviceAffinity(f"{i}")
+
+ @fxb.export_program(
+ name="write",
+ args=(allocation, partition_0, write_page_ids),
+ strict=False,
+ argument_device_affinities=argument_device_affinities,
+ )
+ def _(model, state, partition_0, write_page_ids: torch.Tensor) -> torch.Tensor:
+ old_state = state
+ if args.sharding > 1:
+ state = [SplitPrimitiveTensor(ts=state, shard_dim=alloc[0].shard_dim)]
+ partition_0 = SplitPrimitiveTensor(ts=partition_0, shard_dim=2)
+ write_page_ids = replicate(write_page_ids, count=args.sharding)
+ cache.write(
+ state,
+ cache_partitions=[partition_0, partition_0],
+ transformer_block_index=1,
+ page_ids=write_page_ids,
+ )
+ return state
+
+ if args.verbose:
+ for name, ep in fxb.programs.items():
+ print(f"EXPORT {name}:\n{ep}")
+
+ print("Exporting")
+ output = export(fxb)
+ print(f"Saving to '{args.output_mlir}'")
+ output.save_mlir(args.output_mlir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py
new file mode 100644
index 000000000..af4ed51d2
--- /dev/null
+++ b/sharktank/sharktank/export_layer/export_moe.py
@@ -0,0 +1,83 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+import torch.nn.functional as F
+
+from iree.turbine.aot import *
+
+from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
+from sharktank.layers.mixture_of_experts_block import MoeBlock
+from ..utils import cli
+
+
+def main():
+ parser = cli.create_parser()
+ parser.add_argument(
+ "--output-mlir",
+ help="Output file path for exported MLIR file",
+ default="/tmp/batch_llama_v1.mlir",
+ )
+ parser.add_argument(
+ "--batch-size",
+ "-bs",
+ help="Batch size to generate, e.g. `4` or `2`",
+ type=lambda arg: int(arg),
+ default="2",
+ )
+ parser.add_argument(
+ "--verbose",
+ "-v",
+ help="Include verbose logging",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--strict",
+ help="Enables strictness during export",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--use-gelu",
+ help="Enable to use gelu for moe activation",
+ action="store_true",
+ )
+
+ args = cli.parse(parser)
+
+ bs = args.batch_size
+
+ model = MoeBlock(
+ theta=make_moe_block_theta()("blk.0"),
+ expert_count=8,
+ expert_used_count=2,
+ rms_epsilon=1e-5,
+ moe_activation=F.gelu if args.use_gelu else F.silu,
+ )
+ fxb = FxProgramsBuilder(model)
+ input = make_rand_torch((bs, 32, 6144))
+
+ @fxb.export_program(name="prefill_moe", args=(input,))
+ def _(model, input: torch.Tensor) -> torch.Tensor:
+ return model(input)
+
+ input = make_rand_torch((bs, 1, 6144))
+
+ @fxb.export_program(name="decode_moe", args=(input,))
+ def _(model, input: torch.Tensor) -> torch.Tensor:
+ return model(input)
+
+ if args.verbose:
+ for name, ep in fxb.programs.items():
+ print(f"EXPORT {name}:\n{ep}")
+
+ print("Exporting")
+ output = export(fxb)
+ print(f"Saving to '{args.output_mlir}'")
+ output.save_mlir(args.output_mlir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py
new file mode 100644
index 000000000..cb28371bb
--- /dev/null
+++ b/sharktank/sharktank/export_layer/export_paged_attention.py
@@ -0,0 +1,424 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Export support for the PagedLLMV1 protocol of models."""
+
+import json
+import torch
+
+from typing import Optional
+
+import torch.nn.functional as F
+
+from iree.turbine.aot import *
+
+from sharktank.layers import *
+from sharktank.types import *
+
+from sharktank.models.llama.testing import *
+from sharktank.layers import causal_llm
+
+from sharktank.utils.create_cache import *
+
+# TODO: Should be using a base class with the protocol supported.
+from ..models.llama.llama import LlamaModelConfig, PagedLlamaAttentionBlock
+
+
+def paged_attention(
+ attention_block: PagedLlamaAttentionBlock,
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ xv: torch.Tensor,
+ is_causal: bool,
+ seq_block_ids: torch.Tensor,
+ start_positions: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cache_state: list[torch.Tensor] = None,
+ xk_temp: Optional[torch.Tensor] = None,
+ xv_temp: Optional[torch.Tensor] = None,
+):
+
+ bs, batch_seq_len, _, _ = xq.shape
+
+ # Full sequence length.
+ kv_seq_len = seq_block_ids.shape[1] * attention_block.cache.block_seq_stride
+
+ if attention_block.cache.is_paged:
+ xk, xv = attention_block.transact_cache_paged(
+ xk_cache_update=xk,
+ xv_cache_update=xv,
+ seq_block_ids=seq_block_ids,
+ kv_seq_len=kv_seq_len,
+ start_positions=start_positions,
+ cache_state=cache_state,
+ xk_temp=xk_temp,
+ xv_temp=xv_temp,
+ )
+ elif attention_block.cache.is_direct:
+ xk, xv = attention_block.transact_cache_direct(
+ xk_cache_update=xk,
+ xv_cache_update=xv,
+ start_positions=start_positions,
+ kv_seq_len=kv_seq_len,
+ cache_state=cache_state,
+ )
+ else:
+ raise NotImplementedError(f"Unsupported KV cache type: {type(cache)}")
+
+ # Expand kv heads for GQA.
+ gqa_n_rep = attention_block.head_count // attention_block.head_count_kv
+ assert gqa_n_rep > 0
+ if gqa_n_rep > 1:
+
+ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
+ bs, slen, n_kv_heads, head_dim = x.shape
+ return (
+ x.unsqueeze(-2)
+ .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim)
+ )
+
+ xk = repeat_kv(xk)
+ xv = repeat_kv(xv)
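+        # Shape sketch (illustrative numbers): with 32 query heads and 8 KV
+        # heads, gqa_n_rep == 4 and [bs, sl, 8, head_dim] expands to
+        # [bs, sl, 32, head_dim] so K/V line up with the query heads for SDPA.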
+
+ # Transpose into [bs, heads, sl, dim]
+ xq = xq.transpose(1, 2)
+ keys = xk.transpose(1, 2)
+ values = xv.transpose(1, 2)
+ attention_mask = None
+ attn_output = F.scaled_dot_product_attention(
+ xq, keys, values, attn_mask=attention_mask, is_causal=is_causal
+ )
+ attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)
+ return attn_output
+
+
+def run_llama(
+ model: PagedLlamaAttentionBlock,
+ config: LlamaModelConfig,
+ phase: str,
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ xv: torch.Tensor,
+ # [1, 1, batch_seq_len, batch_seq_len]
+ attention_mask: torch.Tensor,
+ # [bs, batch_seq_len // block_seq_stride]
+ seq_block_ids: torch.Tensor,
+ cache_state: list[torch.Tensor],
+ # [bs] of starting positions
+ start_positions: Optional[torch.Tensor] = None,
+):
+
+ if phase == "decode":
+ bs, _, _, _ = xq.shape
+
+ # Allocate per-block temporary K/V tensors. These temporaries hold
+ # one block's K/V state for the maximum context length.
+ xk_temp = torch.empty(
+ [
+ bs,
+ config.hp.context_length,
+ config.hp.attention_head_count_kv,
+ config.hp.attn_head_dim,
+ ],
+ dtype=config.activation_dtype,
+ device=config.device,
+ )
+ xv_temp = torch.empty(
+ [
+ bs,
+ config.hp.context_length,
+ config.hp.attention_head_count_kv,
+ config.hp.attn_head_dim,
+ ],
+ dtype=config.activation_dtype,
+ device=config.device,
+ )
+ elif phase == "prefill":
+ xk_temp = None
+ xv_temp = None
+ else:
+ raise ValueError("'phase' argument needs to be either 'prefill' or 'decode'")
+
+ h = paged_attention(
+ model,
+ xq=xq,
+ xk=xk,
+ xv=xv,
+ is_causal=config.is_causal,
+ start_positions=start_positions,
+ attention_mask=attention_mask,
+ cache_state=cache_state,
+ seq_block_ids=seq_block_ids,
+ xk_temp=xk_temp,
+ xv_temp=xv_temp,
+ )
+
+ return h
+
+
+def main():
+ from ..utils import cli
+
+ parser = cli.create_parser()
+ # cli.add_input_dataset_options(parser)
+ parser.add_argument(
+ "--output-mlir",
+ help="Output file path for exported MLIR file",
+ default="/tmp/sharktank/artifacts/paged_llama.mlir",
+ )
+ parser.add_argument(
+ "--output-config",
+ help="Output file path for exported config file",
+ default="/tmp/sharktank/artifacts/paged_llama.json",
+ )
+ parser.add_argument(
+ "--bs",
+ help="Comma-separated batch size(s) to generate, e.g. `4` or `2,4`",
+ type=lambda arg: [int(bs) for bs in arg.split(",")],
+ default="4",
+ )
+ parser.add_argument(
+ "--verbose",
+ help="Include verbose logging",
+ action="store_true",
+ )
+
+ parser.add_argument(
+ "--is-causal",
+ help="Enable Causal attention",
+ action="store_true",
+ )
+ # TODO: move this to CLI to enable re-use with eager
+ parser.add_argument(
+ "--attention_kernel",
+ help="decomposed/torch",
+ default="decomposed",
+ )
+
+ args = cli.parse(parser)
+
+ # dataset = cli.get_input_dataset(args)
+ # hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
+
+ hp = configs.LlamaHParams(
+ context_length=4096,
+ embedding_length=4096,
+ block_count=1,
+ feed_forward_length=11008,
+ attn_head_dim=128,
+ rope_dimension_count=128,
+ attention_head_count=32,
+ attention_layer_norm_rms_epsilon=9.999999747378752e-06,
+ attention_head_count_kv=32,
+ model_arch="llama",
+ )
+
+ llama_config = LlamaModelConfig(hp)
+ llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged"
+ llama_config.bs = args.bs
+ llama_config.is_causal = args.is_causal
+
+ attention_block_theta = make_attention_block_theta(
+ feature_dim=llama_config.hp.attention_head_count
+ * llama_config.hp.attn_head_dim,
+ ffn_dim=llama_config.hp.feed_forward_length,
+ dtype=llama_config.attention_dtype,
+ )
+
+ causal_model = causal_llm.BaseCausalLMModel(
+ attention_block_theta, context_length=llama_config.hp.context_length
+ )
+
+ model = PagedLlamaAttentionBlock(
+ theta=attention_block_theta,
+ block_index=0,
+ cache=create_kv_cache(llama_config),
+ head_count=llama_config.hp.attention_head_count,
+ head_dim=llama_config.hp.attn_head_dim,
+ head_count_kv=llama_config.hp.attention_head_count_kv,
+ rms_epsilon=llama_config.hp.attention_layer_norm_rms_epsilon,
+ attention_kernel=args.attention_kernel,
+ )
+
+ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
+ return {
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": hp.context_length,
+ "attn_head_count": hp.attention_head_count,
+ "attn_head_dim": hp.attn_head_dim,
+ "prefill_batch_sizes": prefill_bs,
+ "decode_batch_sizes": decode_bs,
+ "transformer_block_count": hp.block_count,
+ "block_seq_stride": llama_config.block_seq_stride,
+ }
+
+ fxb = FxProgramsBuilder(model)
+
+ def generate_batch_prefill(bs: int):
+ tokens = torch.empty(bs, 64, dtype=torch.int64)
+ seq_lens = torch.empty(bs, dtype=torch.int64)
+ seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
+ block_dim = torch.export.Dim(
+ "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
+ )
+ sl_dim = llama_config.block_seq_stride * block_dim
+
+ if llama_config.kv_cache_type == "paged":
+ cache_state = model.cache.allocate(
+ page_count=hp.context_length // llama_config.block_seq_stride
+ )
+ page_dim = torch.export.Dim("page")
+ cache_state_dynamic_shapes = [{0: page_dim}]
+ elif llama_config.kv_cache_type == "direct":
+ cache_state = model.cache.allocate(bs=1)
+ # Direct cache dimensions:
+ # 2 * transformer_block_count of...
+ # [bs, seq_length, attn_head_count, attn_head_dim]
+ cache_state_dynamic_shapes = (2 * hp.block_count) * [{}]
+ else:
+ raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")
+
+ dynamic_shapes = {
+ "tokens": {1: sl_dim},
+ "seq_lens": {},
+ "seq_block_ids": {1: block_dim},
+ "cache_state": cache_state_dynamic_shapes,
+ }
+
+ q = torch.zeros((bs, 64, 32, 128), dtype=torch.float16)
+ k = torch.zeros((bs, 64, 32, 128), dtype=torch.float16)
+ v = torch.zeros((bs, 64, 32, 128), dtype=torch.float16)
+
+ print(f"Exporting prefill_bs{bs}")
+ example_args = (q, k, v, seq_lens, seq_block_ids, cache_state)
+
+ @fxb.export_program(
+ name=f"prefill_bs{bs}",
+ args=example_args,
+ )
+ def _(model, q, k, v, seq_lens, seq_block_ids, cache_state):
+
+ if llama_config.is_causal:
+ attention_mask = None
+ else:
+ sl = tokens.shape[1]
+ input_mask = causal_model.input_mask(seq_lens, sl)
+ attention_mask = causal_model.attention_mask(input_mask)
+
+ h = run_llama(
+ model=model,
+ config=llama_config,
+ phase="prefill",
+ xq=q,
+ xk=k,
+ xv=v,
+ attention_mask=attention_mask,
+ seq_block_ids=seq_block_ids,
+ cache_state=cache_state,
+ )
+ return h
+
+ def generate_batch_decode(bs: int):
+ tokens = torch.ones(bs, 1, dtype=torch.int64)
+ seq_lens = torch.ones(bs, dtype=torch.int64)
+ start_positions = torch.ones(bs, dtype=torch.int64)
+ seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
+ block_dim = torch.export.Dim(
+ "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
+ )
+
+ if llama_config.kv_cache_type == "paged":
+ cache_state = model.cache.allocate(
+ page_count=hp.context_length // llama_config.block_seq_stride
+ )
+ page_dim = torch.export.Dim("page")
+ cache_state_dynamic_shapes = [{0: page_dim}]
+ elif llama_config.kv_cache_type == "direct":
+ cache_state = model.cache.allocate(bs=1)
+ # Direct cache dimensions:
+ # 2 * transformer_block_count of...
+ # [bs, seq_length, attn_head_count, attn_head_dim]
+ cache_state_dynamic_shapes = (2 * hp.block_count) * [{}]
+ else:
+ raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")
+
+ dynamic_shapes = {
+ "tokens": {},
+ "seq_lens": {},
+ "start_positions": {},
+ "seq_block_ids": {1: block_dim},
+ "cache_state": cache_state_dynamic_shapes,
+ }
+
+ q = torch.zeros((bs, 1, 32, 128), dtype=torch.float16)
+ k = torch.zeros((bs, 1, 32, 128), dtype=torch.float16)
+ v = torch.zeros((bs, 1, 32, 128), dtype=torch.float16)
+
+ print(f"Exporting decode_bs{bs}")
+ example_args = (q, k, v, seq_lens, start_positions, seq_block_ids, cache_state)
+
+ @fxb.export_program(
+ name=f"decode_bs{bs}",
+ args=example_args,
+ )
+ def _(
+ model,
+ q,
+ k,
+ v,
+ seq_lens,
+ start_positions,
+ seq_block_ids,
+ cache_state,
+ ):
+
+ if llama_config.is_causal:
+ attention_mask = None
+ else:
+ input_mask = causal_model.input_mask(
+ seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride
+ )
+ attention_mask = causal_model.decode_attention_mask(input_mask)
+
+ h = run_llama(
+ model=model,
+ config=llama_config,
+ phase="decode",
+ xq=q,
+ xk=k,
+ xv=v,
+ attention_mask=attention_mask,
+ start_positions=start_positions,
+ seq_block_ids=seq_block_ids,
+ cache_state=cache_state,
+ )
+
+ return h
+
+ bsizes = []
+ for bs in llama_config.bs:
+ generate_batch_prefill(bs)
+ generate_batch_decode(bs)
+ bsizes.append(bs)
+
+ if args.verbose:
+ for name, ep in fxb.programs.items():
+ print(f"EXPORT {name}:\n{ep}")
+
+ config = generate_params_json(hp, bsizes, bsizes)
+ print("GENERATED!")
+
+ print("Exporting")
+ output = export(fxb)
+ print(f"Saving to '{args.output_mlir}'")
+ output.save_mlir(args.output_mlir)
+ json.dump(config, open(args.output_config, "w"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py
index 308e20ef4..445f44852 100644
--- a/sharktank/sharktank/kernels/__init__.py
+++ b/sharktank/sharktank/kernels/__init__.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .attention import *
+from .einsum_2args_q4 import *
from .mmtfp import *
from .mmt_block_scaled_offset_q4 import *
from .mmt_block_scaled_q8 import *
@@ -13,3 +14,4 @@
from .conv_2d_nchw_fchw import *
from .pooling_nchw_sum import *
from .base import *
+from .bitcast import *
diff --git a/sharktank/sharktank/kernels/base.py b/sharktank/sharktank/kernels/base.py
index 8c99c81d9..ce792b525 100644
--- a/sharktank/sharktank/kernels/base.py
+++ b/sharktank/sharktank/kernels/base.py
@@ -12,7 +12,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape
-from shark_turbine.support.ir_imports import (
+from iree.turbine.support.ir_imports import (
FlatSymbolRefAttr,
FunctionType,
IrType,
@@ -24,7 +24,7 @@
Value,
)
-from shark_turbine.runtime.op_reg import (
+from iree.turbine.runtime.op_reg import (
def_library,
CustomOp,
KernelBuilder,
@@ -32,7 +32,7 @@
TensorArg,
)
-from shark_turbine.transforms.merger import Merger
+from iree.turbine.transforms.merger import Merger
from ..utils.logging import get_logger
diff --git a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
index 11a6b5fc2..21f9e9ed4 100644
--- a/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
+++ b/sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -8,6 +8,8 @@
import torch
+from iree.compiler.ir import IntegerType
+
__all__ = [
"batch_matmul_transpose_b",
]
@@ -80,7 +82,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
spec_sig = f"L{a_ident}_R{b_ident}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
-
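+    # "cst_zero" is interpolated into the MLIR template; integer accumulators
+    # need a plain "0" literal while floating point ones need "0.".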
+ cst_zero = "0" if IntegerType.isinstance(accum_type) else "0."
# Template params.
c_asm_type = f"tensor<{'x'.join('?' if d is None else str(d) for d in result_desc.spec_dims)}x{accum_type}>"
@@ -93,5 +95,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
b_asm_type=b_asm_type,
c_asm_type=c_asm_type,
dtype=str(accum_type),
+ cst_zero=cst_zero,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
diff --git a/sharktank/sharktank/kernels/bitcast.py b/sharktank/sharktank/kernels/bitcast.py
new file mode 100644
index 000000000..66850008f
--- /dev/null
+++ b/sharktank/sharktank/kernels/bitcast.py
@@ -0,0 +1,138 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from sharktank.kernels.base import *
+
+import torch
+
+from iree.turbine.support.ir_imports import (
+ ComplexType,
+ F16Type,
+ F32Type,
+ RankedTensorType,
+ ShapedType,
+ Value,
+ flow_d,
+ tensor_d,
+)
+
+from iree.turbine.runtime.op_reg import (
+ CustomOp,
+ KernelBuilder,
+ KernelSelection,
+)
+
+__all__ = [
+ "bitcast_to_complex",
+ "bitcast_to_real",
+]
+
+_ftype_to_ctype_table = {
+ torch.float16: torch.complex32,
+ torch.float32: torch.complex64,
+}
+
+_ctype_to_ftype_table = {
+ torch.complex32: torch.float16,
+ torch.complex64: torch.float32,
+}
+
+_type_to_irtype_table = {
+ torch.float16: lambda: F16Type.get(),
+ torch.float32: lambda: F32Type.get(),
+ torch.complex32: lambda: ComplexType.get(F16Type.get()),
+ torch.complex64: lambda: ComplexType.get(F32Type.get()),
+}
+
+
+@CustomOp.register(library=LIBRARY)
+class bitcast_to_complex(CustomOp):
+
+ signature = "bitcast_to_complex(Tensor q) -> (Tensor)"
+
+ def select(self, ksel: KernelSelection):
+ ta = ksel.arg_tensor(0)
+
+ torch._check(ta.t.dtype in _ftype_to_ctype_table)
+ torch._check(isinstance(ta.t.shape[-1], int))
+
+ new_shape = [i for i in ta.t.shape]
+ new_shape[-1] = new_shape[-1] // 2
+
+ ctype = _ftype_to_ctype_table[ta.t.dtype]
+ ret = ksel.return_new_tensor(new_shape, dtype=ctype)
+ specialize_all_known_dims(ta)
+ specialize_all_known_dims(ret)
+
+ def eager_execute(self, tensor):
+ return torch.view_as_complex(tensor.unflatten(-1, (-1, 2)))
+
+ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+ t = kb.arg_bindings[0]
+ result_desc = ksel.result_descs[0]
+ result_shape = [
+ d if isinstance(d, int) else RankedTensorType.get_dynamic_size()
+ for d in result_desc.t.shape
+ ]
+
+ dynamic_dims: list[Value] = []
+ _append_dynamic_dims(kb, dynamic_dims, t)
+
+ c64 = _type_to_irtype_table[result_desc.t.dtype]()
+ rtt = RankedTensorType.get(result_shape, c64)
+ result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result
+ kb.yield_results(result)
+
+
+@CustomOp.register(library=LIBRARY)
+class bitcast_to_real(CustomOp):
+
+ signature = "bitcast_to_real(Tensor q) -> (Tensor)"
+
+ def select(self, ksel: KernelSelection):
+ ta = ksel.arg_tensor(0)
+
+ torch._check(ta.t.dtype in _ctype_to_ftype_table)
+ torch._check(isinstance(ta.t.shape[-1], int))
+
+ new_shape = [i for i in ta.t.shape]
+ new_shape[-1] = new_shape[-1] * 2
+
+ ftype = _ctype_to_ftype_table[ta.t.dtype]
+ ret = ksel.return_new_tensor(new_shape, dtype=ftype)
+ specialize_all_known_dims(ta)
+ specialize_all_known_dims(ret)
+
+ def eager_execute(self, tensor):
+ return torch.view_as_real(tensor).flatten(-2, -1)
+
+ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+ t = kb.arg_bindings[0]
+ result_desc = ksel.result_descs[0]
+ result_shape = [
+ d if isinstance(d, int) else RankedTensorType.get_dynamic_size()
+ for d in result_desc.t.shape
+ ]
+
+ dynamic_dims: list[Value] = []
+ _append_dynamic_dims(kb, dynamic_dims, t)
+
+ ftype = _type_to_irtype_table[result_desc.t.dtype]()
+ rtt = RankedTensorType.get(result_shape, ftype)
+ result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result
+ kb.yield_results(result)
+
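+# Usage sketch (eager semantics): bitcast_to_complex reinterprets the last
+# dimension of a real tensor as interleaved (real, imag) pairs, halving it,
+# e.g. float16 [4, 8] -> complex32 [4, 4]; bitcast_to_real is the inverse,
+# e.g. complex32 [4, 4] -> float16 [4, 8].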
+
+################################################################################
+# Emission utilities
+################################################################################
+
+
+def _append_dynamic_dims(kb: KernelBuilder, dynamic_dims: list[Value], tensor: Value):
+ rtt = RankedTensorType(tensor.type)
+ for i in range(rtt.rank):
+ if rtt.is_dynamic_dim(i):
+ dynamic_dims.append(tensor_d.dim(tensor, kb.constant_index(i)))
diff --git a/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py b/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
index 9ada3b099..529511e02 100644
--- a/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
+++ b/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py
@@ -23,6 +23,8 @@
(torch.int16, torch.int16, "torch.int16"): torch.int16,
(torch.int16, torch.int16, "torch.int32"): torch.int32,
# Legal fp types.
+ (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float16"): torch.float16,
+ (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, "torch.float32"): torch.float32,
(torch.float16, torch.float16, "torch.float16"): torch.float16,
(torch.float16, torch.float16, "torch.float32"): torch.float32,
(torch.float32, torch.float32, "torch.float32"): torch.float32,
@@ -33,6 +35,7 @@
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
+ torch.float8_e4m3fnuz: "f8E4M3FNUZ",
torch.float16: "f16",
torch.float32: "f32",
}
diff --git a/sharktank/sharktank/kernels/einsum_2args_q4.py b/sharktank/sharktank/kernels/einsum_2args_q4.py
new file mode 100644
index 000000000..76d8ad61c
--- /dev/null
+++ b/sharktank/sharktank/kernels/einsum_2args_q4.py
@@ -0,0 +1,259 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .base import *
+
+import torch
+
+__all__ = [
+ "einsum_2args_q4",
+]
+
+
+def einsum_util(einsum_str):
+ es_in, es_out = einsum_str.split("->")
+ es_in0, es_in1 = es_in.split(",")
+ es_set = set(es_out)
+ es_set = es_set.union(es_in0)
+ es_set = es_set.union(es_in1)
+ size = len(es_set)
+ imap = dict()
+ lmap = dict()
+ for i in range(len(es_out)):
+ imap[i] = es_out[i]
+ lmap[es_out[i]] = i
+ count = len(es_out)
+ for c in es_set:
+ if c not in lmap:
+ imap[count] = c
+ lmap[c] = count
+ count += 1
+
+ assert count == len(es_set)
+
+ in0_idx = [lmap[i] for i in es_in0]
+ in1_idx = [lmap[i] for i in es_in1]
+ out_idx = [lmap[i] for i in es_out]
+
+ input_idx_str = ", ".join(["d" + str(i) for i in range(size)])
+ in0_idx_str = ", ".join(["d" + str(i) for i in in0_idx])
+ in1_idx_str = ", ".join(["d" + str(i) for i in in1_idx])
+ out_idx_str = ", ".join(["d" + str(i) for i in out_idx])
+
+ iterators = ", ".join(
+ ['"parallel"' if i in out_idx else '"reduction"' for i in range(size)]
+ )
+
+ affine_map_in0 = f"affine_map<({input_idx_str}) -> ({in0_idx_str})>"
+ affine_map_in1 = f"affine_map<({input_idx_str}) -> ({in1_idx_str})>"
+ affine_map_out = f"affine_map<({input_idx_str}) -> ({out_idx_str})>"
+
+ indexing_maps = f"""{affine_map_in0},
+ {affine_map_in1},
+ {affine_map_out}
+"""
+
+ out_dyn_dim_size_str = ""
+ for c in es_out:
+ if c in es_in0:
+ out_dyn_dim_size_str += "%a" + str(es_in0.find(c)) + ","
+ elif c in es_in1:
+ if es_in1.find(c) == len(es_in1) - 1:
+ out_dyn_dim_size_str += "%b_unblocked_dim,"
+ else:
+ out_dyn_dim_size_str += "%b" + str(es_in1.find(c)) + ","
+ else:
+ raise Exception("Invalid einsum string")
+ out_dyn_dim_size_str = out_dyn_dim_size_str[:-1]
+ return (
+ (in0_idx, in1_idx, out_idx),
+ iterators,
+ indexing_maps,
+ out_dyn_dim_size_str,
+ )
+
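+# Worked example (illustrative): for einsum_str == "mek,menk->men" the helper
+# assigns d0=m, d1=e, d2=n, d3=k and produces
+#   iterators:     "parallel", "parallel", "parallel", "reduction"
+#   indexing maps: (d0, d1, d2, d3) -> (d0, d1, d3)        for `a`
+#                  (d0, d1, d2, d3) -> (d0, d1, d2, d3)    for the dequantized `b`
+#                  (d0, d1, d2, d3) -> (d0, d1, d2)        for the output
+#   dynamic output sizes: "%a0,%a1,%b2"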
+
+@CustomOp.register(library=LIBRARY)
+class einsum_2args_q4(CustomOp):
+ """Einsum that takes two tensor inputs and returns one tensor.
+
+ The first input is expected to be a normal tensor.
+
+ The second input corresponds to the BlockScaledLayout and operates on planar `d`
+ and `qs` tensors as specified there:
+
+ * `d`: `[..., K // BLOCK_SIZE, 1]`
+ * `qs`: `[..., K // BLOCK_SIZE, BLOCK_SIZE // 2]` (of uint8)
+ * `m`: `[..., K // BLOCK_SIZE, 1]`
+ """
+
+ signature = (
+ "einsum_2args_q4(Tensor a, Tensor d, Tensor qs, Tensor m, str es) -> (Tensor)"
+ )
+
+ def select(self, ksel: KernelSelection):
+ a_desc = ksel.arg_tensor(0) # Shape [b, ] m, k
+ d_desc = ksel.arg_tensor(1) # Shape [N, K // BLOCK_SIZE, 1]
+ qs_desc = ksel.arg_tensor(2) # Shape [N, K // BLOCK_SIZE, BLOCK_SIZE // 2]
+ m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1]
+ einsum_str = ksel.attr_str(4).v
+
+ # a arg
+ a_dims = a_desc.t.shape
+ torch._check(
+ a_desc.t.dtype.is_floating_point,
+ lambda: f"einsum_2args_q4 arg 'a': Expected floating point (got {a_desc.t.dtype})",
+ )
+
+ # qs arg
+ *qs_dims, qs_group0, qs_bs_div_2 = qs_desc.t.shape
+ block_size = qs_bs_div_2 * 2
+
+ # d arg
+ *d_dims, d_group0, d_one = d_desc.t.shape
+ torch._check(
+ d_group0 == qs_group0 and d_one == 1 and len(d_dims) == len(qs_dims),
+ lambda: f"einsum_2args_q4 arg 'd': Incorrect shape (got {d_desc.t.shape})",
+ )
+
+ # m arg
+ *m_dims, m_group0, m_one = m_desc.t.shape
+ torch._check(
+ m_desc.t.dtype == d_desc.t.dtype and len(m_dims) == len(qs_dims),
+ lambda: f"einsum_2args_q4 arg 'm': Incorrect dtype (got {m_desc.t.dtype})",
+ )
+ # einsum_str
+ torch._check(
+ einsum_str.count(",") == 1 and einsum_str.count("->") == 1,
+ lambda: f"einsum_2args_q4 arg 'einsum_str': Expected format '{{}},{{}}->{{}}' (got '{einsum_str}')",
+ )
+
+ es_in, es_out = einsum_str.split("->")
+ es_in0, es_in1 = es_in.split(",")
+ es_set = set(es_out)
+
+ shp = qs_desc.t.shape
+ b_dims = list(shp[:-2]) + [shp[-2] * block_size]
+ torch._check(
+ len(es_in0) == len(a_desc.t.shape)
+ and len(es_in1)
+ == len(qs_desc.t.shape)
+ - 1, # The quantized shape is larger until the blocks are collapsed
+ lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dimensions (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
+ )
+ torch._check(
+ len(es_in0) == len(set(es_in0))
+ and len(es_in1) == len(set(es_in1))
+ and len(es_in0) != 0
+ and len(es_in1) != 0,
+ lambda: f"einsum_2args_q4 arg 'einsum_str': Unsupported einsum str (got '{einsum_str}')",
+ )
+
+ # Check corresponding dimensions match
+ for i in range(len(es_in0)):
+ a_dim = a_dims[i]
+ c = es_in0[i]
+ pos = es_in1.find(c)
+ if pos >= 0:
+ b_dim = b_dims[pos]
+ torch._check(
+ a_dim == b_dim,
+ lambda: f"einsum_2args_q4 arg 'einsum_str': Einsum str dimensions do not match input dim for idx {c} (got '{einsum_str}' with inputs: {a_desc.t.shape} and {b_dims})",
+ )
+
+ # Determine the output shape by referencing corresponding input shapes
+ out_dims = []
+ for c in es_out:
+ pos0 = es_in0.find(c)
+ pos1 = es_in1.find(c)
+ a_dim = a_dims[pos0]
+ b_dim = b_dims[pos1]
+ if pos0 >= 0:
+ out_dims.append(a_dim)
+ elif pos1 >= 0:
+ out_dims.append(b_dim)
+ else:
+ torch._check(
+ False,
+ lambda: f"einsum_2args_q4 arg 'einsum_str': output indices must be in input indices (got '{einsum_str}')",
+ )
+
+ # Specialize on BS
+ qs_desc.specialize_dims(-1)
+ d_desc.specialize_dims(-1)
+ m_desc.specialize_dims(-1)
+
+ # Shape batch..., m, n
+ c_desc = ksel.return_new_tensor(out_dims, dtype=a_desc.t.dtype)
+
+ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+ a = kb.arg_value(0)
+ a_tensor_type = RankedTensorType(a.type)
+ d = kb.arg_value(1)
+ d_tensor_type = RankedTensorType(d.type)
+ qs = kb.arg_value(2)
+ qs_tensor_type = RankedTensorType(qs.type)
+ einsum_str = ksel.arg_descs[4].v
+ # einsum_str = "mek,menk->men"
+
+ es_in, es_out = einsum_str.split("->")
+ es_in0, es_in1 = es_in.split(",")
+
+ es_name = "_".join([es_in0, es_in1, es_out])
+
+ (
+ (es_0, es_1, es_2),
+ einsum_iterators,
+ einsum_indexing_maps,
+ oddss,
+ ) = einsum_util(einsum_str)
+
+ rank1 = len(es_1)
+ dequant_iterators = ", ".join(
+ ['"parallel"' for i in range(rank1 + 1)]
+ ) # rank + 1 because of the group dimensions
+ input_idx_str = ", ".join(["d" + str(i) for i in range(rank1 + 1)])
+ broadcast_idx_str = ", ".join(
+ ["d" + str(i) if i != rank1 else "0" for i in range(rank1 + 1)]
+ )
+ affine_map_parallel = f"affine_map<({input_idx_str}) -> ({input_idx_str})>"
+ affine_map_broadcast = f"affine_map<({input_idx_str}) -> ({broadcast_idx_str})>"
+ dequant_indexing_maps = f"""{affine_map_broadcast},
+ {affine_map_broadcast},
+ {affine_map_parallel},
+ {affine_map_parallel}"""
+
+ size_str = "x".join("?" for i in range(rank1 - 2))
+
+ rank = a_tensor_type.rank
+ *n_dims, group0, bs_i8 = qs_tensor_type.shape
+ bs = bs_i8 * 2 # 2 nibbles per byte.
+ group = group0 * bs
+ a_type_str = str(a_tensor_type.element_type)
+ scale_type_str = str(d_tensor_type.element_type)
+
+ template_file = "einsum_2args_q4.mlir"
+ target_function_name = f"sharktank_einsum_2args_q4_{es_name}_{bs}_{a_type_str}"
+
+ target_function = inline_template_function(
+ kb,
+ template_file,
+ target_function_name,
+ bs=bs,
+ bs_i8=bs_i8,
+ a_type=a_type_str,
+ scale_type=scale_type_str,
+ dequant_indexing_maps=dequant_indexing_maps,
+ dequant_iterator_types=dequant_iterators,
+ einsum_indexing_maps=einsum_indexing_maps,
+ einsum_iterator_types=einsum_iterators,
+ es_name=es_name,
+ a_size=len(es_in0),
+ b_size=len(es_in1),
+ c_size=len(es_out),
+ out_dyn_dim_size_str=oddss,
+ )
+ kb.yield_results(*call_function(target_function, *kb.arg_bindings))
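# Standalone sketch, for intuition: how the output shape of a two-operand
# einsum follows from the subscript string, mirroring the select() logic above.
# The function and its names are illustrative, not part of the kernel API.
def einsum_output_shape(einsum_str: str, a_shape, b_shape):
    ins, out = einsum_str.split("->")
    in0, in1 = ins.split(",")
    dims = {}
    for spec, shape in ((in0, a_shape), (in1, b_shape)):
        assert len(spec) == len(shape), "subscripts must match operand rank"
        for c, d in zip(spec, shape):
            dims.setdefault(c, d)
    # Every output index must appear in one of the inputs.
    return [dims[c] for c in out]

# e.g. "mek,menk->men" with a=[4, 8, 32] and unblocked b=[4, 8, 16, 32]
assert einsum_output_shape("mek,menk->men", [4, 8, 32], [4, 8, 16, 32]) == [4, 8, 16]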
diff --git a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py
index 2ed171115..0c8a61f32 100644
--- a/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py
+++ b/sharktank/sharktank/kernels/mmt_block_scaled_offset_q4.py
@@ -37,28 +37,33 @@ def select(self, ksel: KernelSelection):
m_desc = ksel.arg_tensor(3) # Shape [N, K // BLOCK_SIZE, 1]
# a arg
- *batch_dims, a_m, a_k = a_desc.t.shape
+ *a_batch_dims, a_m, a_k = a_desc.t.shape
torch._check(
a_desc.t.dtype.is_floating_point,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected floating point (got {a_desc.t.dtype})",
)
torch._check(
- len(batch_dims) == 1,
+ len(a_batch_dims) == 1,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'a': Expected 3d tensor (got {a_desc.t.shape})",
)
# qs arg
- qs_n, qs_group0, qs_bs_div_2, *rest = qs_desc.t.shape
+ *qs_batch_dims, qs_n, qs_group0, qs_bs_div_2 = qs_desc.t.shape
torch._check(
- len(rest) == 0 and (qs_group0 * qs_bs_div_2 * 2) == a_k,
+ (
+ len(qs_batch_dims) == 0
+ or len(qs_batch_dims) == 1
+ and qs_batch_dims == a_batch_dims
+ )
+ and (qs_group0 * qs_bs_div_2 * 2) == a_k,
lambda: f"mmt_block_scaled_offset_q4_unsigned arg 'qs': Incorrect shape (got {qs_desc.t.shape})",
)
block_size = qs_bs_div_2 * 2
# d arg
- d_n, d_group0, d_one, *rest = d_desc.t.shape
+ *d_batch_dims, d_n, d_group0, d_one = d_desc.t.shape
torch._check(
- len(rest) == 0
+ d_batch_dims == qs_batch_dims
and (d_group0 * block_size) == a_k
and d_one == 1
and d_n == qs_n,
@@ -66,9 +71,9 @@ def select(self, ksel: KernelSelection):
)
# m arg
- m_n, m_group0, m_one, *rest = m_desc.t.shape
+ *m_batch_dims, m_n, m_group0, m_one = m_desc.t.shape
torch._check(
- len(rest) == 0
+ m_batch_dims == qs_batch_dims
and (m_group0 * block_size) == a_k
and m_one == 1
and m_n == qs_n,
@@ -81,12 +86,17 @@ def select(self, ksel: KernelSelection):
# Specialize on K, N, BS
a_desc.specialize_dims(-1)
- qs_desc.specialize_all_dims()
- d_desc.specialize_all_dims()
- m_desc.specialize_all_dims()
+ if len(qs_batch_dims) == 0:
+ qs_desc.specialize_all_dims()
+ d_desc.specialize_all_dims()
+ m_desc.specialize_all_dims()
+ else:
+ qs_desc.specialize_dims(1, 2, 3)
+ d_desc.specialize_dims(1, 2, 3)
+ m_desc.specialize_dims(1, 2, 3)
# Shape batch..., m, n
- c_desc = ksel.return_new_tensor(batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
+ c_desc = ksel.return_new_tensor(a_batch_dims + [a_m, d_n], dtype=a_desc.t.dtype)
c_desc.specialize_dims(-1)
def generate(self, ksel: KernelSelection, kb: KernelBuilder):
@@ -99,13 +109,14 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
rank = a_tensor_type.rank
k = a_tensor_type.get_dim_size(rank - 1)
- n, group0, bs_i8 = qs_tensor_type.shape
+ *qs_batch_dims, n, group0, bs_i8 = qs_tensor_type.shape
+ batched_rhs = len(qs_batch_dims) == 1
bs = bs_i8 * 2 # 2 nibbles per byte.
a_type_str = str(a_tensor_type.element_type)
scale_type_str = str(d_tensor_type.element_type)
template_file = "mmt_block_scaled_offset_q4_unsigned.mlir"
- target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}"
+ target_function_name = f"sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{n}_{k}_{bs}_{a_type_str}_{batched_rhs}"
target_function = inline_template_function(
kb,
@@ -118,5 +129,6 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
group0=group0,
a_type=a_type_str,
scale_type=scale_type_str,
+ batched_rhs=batched_rhs,
)
kb.yield_results(*call_function(target_function, *kb.arg_bindings))
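# Rough reference for the kernel above, written in plain PyTorch. Assumptions:
# unsigned 4-bit values packed two per byte (low nibble first), qs of shape
# [N, G, BS//2] or, with the new batched RHS path, [B, N, G, BS//2], and d/m of
# shape [..., G, 1]. The helper name is illustrative only.
import torch

def mmt_q4_reference(a, d, qs, m):
    lo, hi = qs & 0x0F, (qs >> 4) & 0x0F
    q = torch.stack([lo, hi], dim=-1).flatten(-2)   # unpack nibbles: [..., G, BS]
    rhs = (q.to(a.dtype) * d + m).flatten(-2)       # dequantize, collapse to [..., N, K]
    return torch.matmul(a, rhs.transpose(-1, -2))   # transposed-B matmul: [B, M, N]

B, M, N, K, BS = 2, 4, 6, 32, 16
a = torch.randn(B, M, K)
qs = torch.randint(-128, 128, (N, K // BS, BS // 2), dtype=torch.int8)
d = torch.rand(N, K // BS, 1)
m = torch.rand(N, K // BS, 1)
assert mmt_q4_reference(a, d, qs, m).shape == (B, M, N)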
diff --git a/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir b/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir
index 908ca1c7f..056f72af3 100644
--- a/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir
+++ b/sharktank/sharktank/kernels/templates/batch_matmul_transpose_b.mlir
@@ -15,7 +15,7 @@ module {
util.func private @sharktank_batch_matmul_transpose_b_{{spec_sig}}(
%a: !a_tensor_type, %b: !b_tensor_type)
-> !c_tensor_type {
- %zero = arith.constant 0: !dtype
+ %zero = arith.constant {{cst_zero}}: !dtype
%c0 = arith.constant 0: index
%c1 = arith.constant 1: index
%batch_dim = tensor.dim %a, %c0 : !a_tensor_type // b, m, k
diff --git a/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
new file mode 100644
index 000000000..47ca6b331
--- /dev/null
+++ b/sharktank/sharktank/kernels/templates/einsum_2args_q4.mlir
@@ -0,0 +1,104 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+{% set accum_type = "f32" %}
+
+!lowp_type = i4
+!a_type = {{a_type}}
+!scale_type = {{scale_type}}
+!accum_type = {{accum_type}}
+!a_tensor_type = tensor<{% for i in range(a_size) %}?x{% endfor %}!a_type>
+!qs_raw_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs_i8}}xi8>
+!qs_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!lowp_type>
+!d_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
+!m_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}1x!scale_type>
+!accum_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!accum_type>
+!c_tensor_type = tensor<{% for i in range(c_size) %}?x{% endfor %}!a_type>
+!b_grouped_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}{{bs}}x!a_type>
+!b_tensor_type = tensor<{% for i in range(b_size) %}?x{% endfor %}!a_type>
+
+module {
+
+util.func private @sharktank_einsum_2args_q4_{{es_name}}_{{bs}}_{{a_type}}(
+ %a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
+ -> !c_tensor_type {
+ %debug = tensor.empty() : tensor<1xf32>
+ %zero = arith.constant 0.0: !accum_type
+ {% for i in range(a_size) %}
+ %k{{i}} = arith.constant {{i}} : index
+ {% endfor %}
+ {% for i in range(a_size, b_size) %}
+ %k{{i}} = arith.constant {{i}} : index
+ {% endfor %}
+ {% for i in range(a_size) %}
+ %a{{i}} = tensor.dim %a, %k{{i}}: !a_tensor_type
+ {% endfor %}
+ {% for i in range(b_size) %}
+ %b{{i}} = tensor.dim %qs_raw, %k{{i}}: !qs_raw_tensor_type
+ {% endfor %}
+ %bs = arith.constant {{bs}} : index
+ %b_unblocked_dim = arith.muli %b{{b_size-1}}, %bs : index
+
+ //%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
+ %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}} -> !qs_tensor_type{{"{"}}{% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}{{"}"}}
+
+ // Dequantize.
+ %b_grouped = tensor.empty({% for i in range(b_size-1) %}%b{{i}},{% endfor %}%b{{b_size-1}}) : !b_grouped_tensor_type
+ %b_grouped_dequant = linalg.generic {
+ indexing_maps = [
+ {{dequant_indexing_maps}}],
+ iterator_types = [{{dequant_iterator_types}}] }
+ ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
+ outs(%b_grouped : !b_grouped_tensor_type) {
+ ^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
+ %q_element_ext = arith.extui %q_element : !lowp_type to i32
+ %q_element_fp = arith.uitofp %q_element_ext : i32 to !a_type
+ {% if scale_type == a_type %}
+ %q_element_scaled = arith.mulf %q_element_fp, %d_element : !a_type
+ %q_element_offset = arith.addf %q_element_scaled, %m_element : !a_type
+ {% else %}
+ %d_element_ext = arith.extf %d_element : !scale_type to !a_type
+ %m_element_ext = arith.extf %m_element : !scale_type to !a_type
+ %q_element_scaled = arith.mulf %q_element_fp, %d_element_ext : !a_type
+ %q_element_offset = arith.addf %q_element_scaled, %m_element_ext : !a_type
+ {% endif %}
+ linalg.yield %q_element_offset : !a_type
+ } -> !b_grouped_tensor_type
+
+ // Collapse %b to the same unblocked structure.
+ %b_unblocked = tensor.collapse_shape %b_grouped_dequant [{% for i in range(b_size-1) %}[{{i}}], {% endfor %}[{{b_size-1}}, {{b_size}}]] : !b_grouped_tensor_type into !b_tensor_type
+
+ // Einsum
+ %result_empty = tensor.empty({{out_dyn_dim_size_str}}) : !accum_tensor_type
+ %result_fill = linalg.fill ins(%zero: !accum_type) outs(%result_empty: !accum_tensor_type) -> !accum_tensor_type
+ %result = linalg.generic {
+ indexing_maps = [
+ {{einsum_indexing_maps}}],
+ iterator_types = [{{einsum_iterator_types}}] }
+ ins(%a, %b_unblocked : !a_tensor_type, !b_tensor_type)
+ outs(%result_fill : !accum_tensor_type) {
+ ^bb0(%a_element: !a_type, %b_element: !a_type, %out: !accum_type):
+ %bmm_mul = arith.mulf %a_element, %b_element : !a_type
+ {% if accum_type == a_type %}
+ %bmm_accum = arith.addf %bmm_mul, %out : !a_type
+ {% else %}
+ %bmm_mul_ext = arith.extf %bmm_mul : !a_type to !accum_type
+ %bmm_accum = arith.addf %bmm_mul_ext, %out : !accum_type
+ {% endif %}
+ linalg.yield %bmm_accum : !accum_type
+ } -> !accum_tensor_type
+
+ // Cast.
+ %result_cast_empty = tensor.empty({{out_dyn_dim_size_str}}) : !c_tensor_type
+ %result_cast = linalg.copy
+ ins(%result : !accum_tensor_type)
+ outs(%result_cast_empty : !c_tensor_type) -> !c_tensor_type
+
+ //iree_input.tensor.trace "foobar" = [%a : !a_tensor_type, %d : !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type, %b_grouped_dequant: !b_grouped_tensor_type]
+ util.return %result_cast : !c_tensor_type
+}
+
+}
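# For intuition only: a plain-PyTorch rendering of what this template computes
# end to end (dequantize the grouped Q4 RHS, collapse the block dimension,
# apply the einsum with f32 accumulation, cast back). The nibble ordering and
# the reference function itself are assumptions of this sketch.
import torch

def einsum_2args_q4_reference(a, d, qs_raw, m, einsum_str):
    q = torch.stack([qs_raw & 0x0F, (qs_raw >> 4) & 0x0F], dim=-1).flatten(-2)
    b_grouped = q.to(a.dtype) * d.to(a.dtype) + m.to(a.dtype)  # per-group scale + offset
    b = b_grouped.flatten(-2)                                  # collapse [group, bs] -> k
    out = torch.einsum(einsum_str, a.float(), b.float())       # accumulate in f32
    return out.to(a.dtype)                                     # cast back like %result_cast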
diff --git a/sharktank/sharktank/kernels/templates/flash_attention.mlir b/sharktank/sharktank/kernels/templates/flash_attention.mlir
index db76db84f..15d75c372 100644
--- a/sharktank/sharktank/kernels/templates/flash_attention.mlir
+++ b/sharktank/sharktank/kernels/templates/flash_attention.mlir
@@ -7,6 +7,7 @@
!q_type = tensor
!k_type = tensor
!v_type = tensor
+!trans_v_type = tensor
!o_type = tensor
!o_dyn_type = tensor
!s_type = tensor<{{scale_type}}>
@@ -32,15 +33,22 @@ util.func private @sharktank_flash_attention_{{l}}_{{s}}_{{d}}_{{e}}_{{i_type}}_
%scale = tensor.extract %s[] : !s_type
+ %init_trans_v = tensor.empty(%b0, %b1) : !trans_v_type
+ %transpose_v = linalg.transpose ins(%v: !v_type) outs(%init_trans_v: !trans_v_type) permutation = [0, 1, 3, 2]
+
%empty_dyn = tensor.empty(%b0, %b1, %l, %e) : !o_dyn_type
%empty = tensor.cast %empty_dyn : !o_dyn_type to !o_type
%atten = iree_linalg_ext.attention {indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
- affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
- affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
- ins(%q, %k, %v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) -> !o_type
+ affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>]}
+ ins(%q, %k, %transpose_v, %scale : !q_type, !k_type, !v_type, {{scale_type}}) outs(%empty : !o_type) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score : f32
+ } -> !o_type
util.return %atten : !o_type
}
}
diff --git a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir
index a7f3138cb..afe2928c0 100644
--- a/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir
+++ b/sharktank/sharktank/kernels/templates/mmt_block_scaled_offset_q4_unsigned.mlir
@@ -12,17 +12,25 @@
!accum_type = {{accum_type}}
!a_tensor_type = tensor
!aexp_tensor_type = tensor
+{% if batched_rhs %}
+!qs_raw_tensor_type = tensor
+!qs_tensor_type = tensor
+!d_tensor_type = tensor
+!m_tensor_type = tensor
+!b_grouped_tensor_type = tensor
+{% else %}
!qs_raw_tensor_type = tensor<{{n}}x{{group0}}x{{bs_i8}}xi8>
!qs_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!lowp_type>
!d_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
!m_tensor_type = tensor<{{n}}x{{group0}}x1x!scale_type>
+!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>
+{% endif %}
!accum_tensor_type = tensor
!c_tensor_type = tensor
-!b_grouped_tensor_type = tensor<{{n}}x{{group0}}x{{bs}}x!a_type>
module {
-util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}(
+util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_{{bs}}_{{a_type}}_{{batched_rhs}}(
%a: !a_tensor_type, %d: !d_tensor_type, %qs_raw: !qs_raw_tensor_type, %m: !m_tensor_type)
-> !c_tensor_type {
%zero = arith.constant 0.0: !accum_type
@@ -32,17 +40,31 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
%m_dim = tensor.dim %a, %c1 : !a_tensor_type
// Cast qs_raw from i8 to lowp type.
+{% if batched_rhs %}
+ %qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type{ %batch0_dim } -> !qs_tensor_type{ %batch0_dim }
+ %b_grouped = tensor.empty(%batch0_dim) : !b_grouped_tensor_type
+{% else %}
%qs = flow.tensor.bitcast %qs_raw : !qs_raw_tensor_type -> !qs_tensor_type
+ %b_grouped = tensor.empty() : !b_grouped_tensor_type
+{% endif %}
// Dequantize.
- %b_grouped = tensor.empty() : !b_grouped_tensor_type
%b_grouped_dequant = linalg.generic {
+{% if batched_rhs %}
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"] }
+{% else %}
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"] }
+{% endif %}
ins(%d, %m, %qs : !d_tensor_type, !m_tensor_type, !qs_tensor_type)
outs(%b_grouped : !b_grouped_tensor_type) {
^bb0(%d_element: !scale_type, %m_element: !scale_type, %q_element: !lowp_type, %out: !a_type):
@@ -70,7 +92,7 @@ util.func private @sharktank_mmt_block_scaled_offset_q4_unsigned_3d_{{n}}_{{k}}_
indexing_maps = [
// d0 = b, d1 = m, d2 = n, d3 = group0 (r), d4 = block (r)
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ({% if batched_rhs %}d0,{% endif %} d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] }
ins(%aexp, %b_grouped_dequant : !aexp_tensor_type, !b_grouped_tensor_type)
diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py
index 181544763..fd56ec872 100644
--- a/sharktank/sharktank/layers/__init__.py
+++ b/sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
-from .mixture_of_experts_block import SparseMoeBlock
+from .mixture_of_experts_block import MoeBlock
-from . import configs
+from .configs import *
diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 3d91683fe..8ace77981 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -33,12 +33,14 @@ def __init__(
device: Optional[torch.device] = None,
activation_dtype: torch.dtype = torch.float32,
attention_dtype: torch.dtype = torch.float32,
+ fake_quant: bool = True,
):
super().__init__(theta)
self.device = device
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype
self.context_length = context_length
+ self.fake_quant = fake_quant
if static_tables:
self.register_buffer(
@@ -89,16 +91,15 @@ def input_mask(
masked.
"""
range_vector = torch.arange(0, batch_seqlen, 1, device=self.device)
- matrix = torch.unsqueeze(seq_lens, dim=-1)
+ matrix = seq_lens.unsqueeze(dim=-1)
mask = range_vector >= matrix
return mask
def decode_attention_mask(self, boolean_input_mask: torch.Tensor):
dtype = self.attention_dtype
- numeric_mask = torch.zeros_like(boolean_input_mask, dtype=dtype)
- numeric_mask.masked_fill_(
- boolean_input_mask, self._maximally_negative_value(dtype)
- )
+ numeric_mask = torch.where(
+ boolean_input_mask, self._maximally_negative_value(dtype), 0
+ ).to(dtype)
return numeric_mask.unsqueeze(1).unsqueeze(1).to(self.device)
def attention_mask(
@@ -127,9 +128,10 @@ def attention_mask(
dtype = self.attention_dtype
_, batch_seq_len = input_mask.shape
causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
- boolean_mask = causal_mask + input_mask[:, None, None, :]
- numeric_mask = torch.zeros_like(boolean_mask, dtype=dtype)
- numeric_mask.masked_fill_(boolean_mask, self._maximally_negative_value(dtype))
+ boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
+ numeric_mask = torch.where(
+ boolean_mask, self._maximally_negative_value(dtype), 0
+ ).to(dtype)
return numeric_mask.to(self.device)
def extract_tokens_from_logits(
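# Small standalone illustration of the mask construction above: combine the
# causal mask with the padding mask, then map True -> a very negative value and
# False -> 0 in the attention dtype. float("-inf") stands in for the layer's
# dtype-dependent maximally negative value; all sizes are made up.
import torch

seq_lens = torch.tensor([2, 4])
batch_seq_len, dtype = 4, torch.float32

input_mask = torch.arange(batch_seq_len) >= seq_lens.unsqueeze(-1)          # [bs, sl]
causal = torch.triu(torch.ones(batch_seq_len, batch_seq_len, dtype=torch.bool), 1)
boolean_mask = torch.logical_or(causal[None, None], input_mask[:, None, None, :])
numeric_mask = torch.where(
    boolean_mask, torch.tensor(float("-inf")), torch.tensor(0.0)
).to(dtype)                                                                 # [bs, 1, sl, sl]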
diff --git a/sharktank/sharktank/layers/configs/__init__.py b/sharktank/sharktank/layers/configs/__init__.py
index 21336d1d2..c5d75c602 100644
--- a/sharktank/sharktank/layers/configs/__init__.py
+++ b/sharktank/sharktank/layers/configs/__init__.py
@@ -4,4 +4,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from .llm_configs import LlamaHParams
+from .llm_configs import *
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index ab3a582e4..35a2ee570 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -14,12 +14,11 @@
(and indeed, can bootstrap these off of GGUF files).
"""
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Any, Optional
-
import torch
-__all__ = ["LlamaHParams"]
+__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]
@dataclass
@@ -29,51 +28,79 @@ class LlamaHParams:
Comments are only provided if they differ from this source.
"""
+ model_arch: str
context_length: int
embedding_length: int
block_count: int
feed_forward_length: int
- rope_dimension_count: int
- rope_freq_base: float
attention_head_count: int
attn_head_dim: int
attention_layer_norm_rms_epsilon: float
attention_head_count_kv: int
- expert_count: int
- expert_used_count: int
+ rope_dimension_count: Optional[int] = None
+ rope_freq_base: Optional[float] = None
+ expert_count: Optional[int] = None
+ expert_used_count: Optional[int] = None
@staticmethod
def from_gguf_props(p: dict[str, Any]):
+ name_prefix = p.get("general.architecture", "llama")
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
- attention_head_count = _int_prop(p, "llama.attention.head_count")
+ default_rope_dimension_count = 128
+ attention_head_count = _int_prop(p, f"{name_prefix}.attention.head_count")
+ rope_dimension_count = _optional_int_prop(
+ p, f"{name_prefix}.rope.dimension_count", default_rope_dimension_count
+ )
return LlamaHParams(
- context_length=_int_prop(p, "llama.context_length"),
- embedding_length=_int_prop(p, "llama.embedding_length"),
- block_count=_int_prop(p, "llama.block_count"),
- feed_forward_length=_int_prop(p, "llama.feed_forward_length"),
- attn_head_dim=_int_prop(p, "llama.rope.dimension_count"),
- rope_dimension_count=_int_prop(p, "llama.rope.dimension_count"),
+ model_arch=name_prefix,
+ context_length=_int_prop(p, f"{name_prefix}.context_length"),
+ embedding_length=_int_prop(p, f"{name_prefix}.embedding_length"),
+ block_count=_int_prop(p, f"{name_prefix}.block_count"),
+ feed_forward_length=_int_prop(p, f"{name_prefix}.feed_forward_length"),
attention_head_count=attention_head_count,
attention_layer_norm_rms_epsilon=_float_prop(
- p, "llama.attention.layer_norm_rms_epsilon"
+ p, f"{name_prefix}.attention.layer_norm_rms_epsilon"
),
attention_head_count_kv=_optional_int_prop(
- p, "llama.attention.head_count_kv", attention_head_count
+ p, f"{name_prefix}.attention.head_count_kv", attention_head_count
),
+ attn_head_dim=rope_dimension_count,
+ rope_dimension_count=rope_dimension_count,
rope_freq_base=_optional_float_prop(
- p, "llama.rope.freq_base", default_rope_freq_base
+ p, f"{name_prefix}.rope.freq_base", default_rope_freq_base
),
expert_count=_optional_int_prop(
- p, "llama.expert_count", default_expert_count
+ p, f"{name_prefix}.expert_count", default_expert_count
),
expert_used_count=_optional_int_prop(
- p, "llama.expert_used_count", default_expert_used_count
+ p, f"{name_prefix}.expert_used_count", default_expert_used_count
),
)
+ def to_gguf_props(self) -> dict[str, Any]:
+ res = {
+ "general.architecture": self.model_arch,
+ f"{self.model_arch}.context_length": self.context_length,
+ f"{self.model_arch}.embedding_length": self.embedding_length,
+ f"{self.model_arch}.block_count": self.block_count,
+ f"{self.model_arch}.feed_forward_length": self.feed_forward_length,
+ f"{self.model_arch}.attention.head_count": self.attention_head_count,
+ f"{self.model_arch}.attention.layer_norm_rms_epsilon": self.attention_layer_norm_rms_epsilon,
+ f"{self.model_arch}.attention.head_count_kv": self.attention_head_count_kv,
+ }
+ if self.rope_dimension_count is not None:
+ res[f"{self.model_arch}.rope.dimension_count"] = self.rope_dimension_count
+ if self.rope_freq_base is not None:
+ res[f"{self.model_arch}.rope.freq_base"] = self.rope_freq_base
+ if self.expert_count is not None:
+ res[f"{self.model_arch}.expert_count"] = self.expert_count
+ if self.expert_used_count is not None:
+ res[f"{self.model_arch}.expert_used_count"] = self.expert_used_count
+ return res
+
def _float_prop(p: dict[str, Any], name: str) -> float:
try:
@@ -107,3 +134,122 @@ def _optional_int_prop(p: dict[str, Any], name: str, default_value: int) -> int:
return int(value)
except ValueError as e:
raise ValueError(f"Property '{name}' expected to be an int and was not") from e
+
+
+@dataclass
+class LlamaModelConfig:
+ hp: LlamaHParams
+
+ # Block sequence stride for a paged KV cache. This must divide evenly
+ # into the context length.
+ block_seq_stride: int = 16
+
+ # Either "paged" or "direct".
+ kv_cache_type: str = "paged"
+
+ # The device on which to place intermediate state.
+ device: Optional[torch.device] = None
+
+ # Dtype to use for general FP activations not otherwise configured.
+ activation_dtype: torch.dtype = torch.float16
+
+ # Dtype to use for attention.
+ attention_dtype: torch.dtype = torch.float16
+
+ # fake_quant determines the mode in which the layer thetas operate w.r.t. quantized tensors.
+ fake_quant: bool = True
+
+ # How many devices are involved for tensor parallel sharding.
+ # If greater than 1, the model will expect sharded model parameters and function
+ # arguments.
+ tensor_parallelism_size: int = 1
+
+ # Which attention kernel to use.
+ attention_kernel: str = "decomposed"
+
+ # Indicates if running with HuggingFace implementation and ensures
+ # numerical equivalency to HuggingFace's LLaMa if true (by modifying
+ # rotary embedding).
+ use_hf: bool = False
+
+ # If true, then the model may pre-initialize certain tables during
+ # init. This can be better for eager execution but when capturing a program,
+ # it is often better to preserve the calculation explicitly and rely on
+ # the compiler to transform it to an initialization time step. This can
+ # be the difference between many gigabytes of static data being embedded
+ # in the program or not.
+ static_tables: bool = True
+
+
+@dataclass
+class T5Config:
+ return_dict: bool = True
+ output_hidden_states: bool = False
+ output_attentions: bool = False
+ is_encoder_decoder: bool = True
+ is_decoder: bool = False
+ vocab_size: int = 32128
+ context_length: int = 512
+ d_model: int = 512
+ d_kv: int = 64
+ d_ff: int = 2048
+ num_layers: int = 6
+ num_decoder_layers: int = 6
+ num_heads: int = 8
+ relative_attention_num_buckets: int = 32
+ relative_attention_max_distance: int = 128
+ layer_norm_epsilon: float = 1e-6
+ feed_forward_proj: str = "relu"
+ is_gated_act: bool = field(init=False)
+ activation_dtype: torch.dtype = torch.float32
+ dense_act_fn: str = field(init=False)
+ use_cache: bool = True
+ pad_token_id: int = 0
+ eos_token_id: int = 1
+ decoder_start_token_id: int = 0
+ context_length_padding_block_size: int = 16
+
+ def __post_init__(self):
+ self.is_gated_act = self.feed_forward_proj.startswith("gated-")
+ self.dense_act_fn = (
+ self.feed_forward_proj.split("-")[1]
+ if "-" in self.feed_forward_proj
+ else self.feed_forward_proj
+ )
+ if self.dense_act_fn == "gelu":
+ self.dense_act_fn = "gelu_new"
+
+ @staticmethod
+ def from_gguf_properties(properties: dict[str, Any], **kwargs):
+ assert properties["general.architecture"] == "t5"
+ assert (
+ properties["t5.attention.layer_norm_epsilon"]
+ == properties["t5.attention.layer_norm_rms_epsilon"]
+ )
+
+ gguf_to_config_names_map = {
+ "t5.context_length": ["context_length"],
+ "t5.embedding_length": ["d_model"],
+ "t5.feed_forward_length": ["d_ff"],
+ "t5.block_count": ["num_layers", "num_decoder_layers"],
+ "t5.attention.head_count": ["num_heads"],
+ "t5.attention.key_length": ["d_kv"],
+ "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
+ "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
+ "t5.decoder_start_token_id": ["decoder_start_token_id"],
+ "tokenizer.ggml.eos_token_id": ["eos_token_id"],
+ "tokenizer.ggml.padding_token_id": ["pad_token_id"],
+ }
+ all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
+ all_kwargs.update(
+ {
+ config_name: properties[gguf_name]
+ for gguf_name, config_names in gguf_to_config_names_map.items()
+ for config_name in config_names
+ }
+ )
+ if "tokenizer.ggml.tokens" in properties:
+ all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
+ all_kwargs.update(kwargs)
+
+ return T5Config(**all_kwargs)
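# Illustrative usage of the helpers above, with made-up hyperparameter values
# (no real model implied). The import path assumes the sharktank package layout
# shown in this patch.
from sharktank.layers.configs import LlamaHParams

props = {
    "general.architecture": "llama",
    "llama.context_length": 4096,
    "llama.embedding_length": 4096,
    "llama.block_count": 32,
    "llama.feed_forward_length": 11008,
    "llama.attention.head_count": 32,
    "llama.attention.layer_norm_rms_epsilon": 1e-5,
}
hp = LlamaHParams.from_gguf_props(props)
assert hp.model_arch == "llama" and hp.attn_head_dim == 128  # default rope dim
assert hp.to_gguf_props()["llama.attention.head_count"] == 32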
diff --git a/sharktank/sharktank/layers/ffn_block.py b/sharktank/sharktank/layers/ffn_block.py
index 420b06893..cde603b98 100644
--- a/sharktank/sharktank/layers/ffn_block.py
+++ b/sharktank/sharktank/layers/ffn_block.py
@@ -4,10 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Optional
+from typing import Optional, Callable
import torch
import torch.nn.functional as F
+from .. import ops
+from ..types import AnyTensor
from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -20,18 +22,29 @@ class FFN(ThetaLayer):
def __init__(
self,
theta: Theta,
+ is_gated: bool = True,
+ activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu,
):
super().__init__(theta)
- self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
+ self.is_gated = is_gated
+ self.activation_fn = activation_fn
+ if self.is_gated:
+ self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
self.add_module("ffn_down", LinearLayer(theta("ffn_down")))
def forward(
self,
- h: torch.Tensor,
- ):
- ffn_gate = F.silu(self.ffn_gate(h))
- ffn_up = self.ffn_up(h)
- ffn_down = self.ffn_down(ffn_gate * ffn_up)
- return ffn_down
+ h: AnyTensor,
+ ) -> AnyTensor:
+ if self.is_gated:
+ ffn_gate = ops.elementwise(self.activation_fn, self.ffn_gate(h))
+ ffn_up = self.ffn_up(h)
+ ffn_down = self.ffn_down(ffn_gate * ffn_up)
+ return ffn_down
+ else:
+ h = self.ffn_up(h)
+ h = ops.elementwise(self.activation_fn, h)
+ h = self.ffn_down(h)
+ return h
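# Plain-torch sketch of the two forward paths above; the weights are random
# stand-ins, whereas the real layer reads them from a Theta via LinearLayer and
# routes the activation through ops.elementwise.
import torch
import torch.nn.functional as F

h = torch.randn(2, 8, 64)
w_gate, w_up, w_down = torch.randn(128, 64), torch.randn(128, 64), torch.randn(64, 128)

# Gated path: down( act(gate(h)) * up(h) )
gated = F.linear(F.silu(F.linear(h, w_gate)) * F.linear(h, w_up), w_down)

# Non-gated path: down( act(up(h)) )
plain = F.linear(F.relu(F.linear(h, w_up)), w_down)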
diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index d833882f5..2adb9464f 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -12,12 +12,68 @@
from .base import ThetaLayer
from .linear import LinearLayer
from ..types import Theta, DefaultPrimitiveTensor
+from ..ops import einsum_2args, elementwise
__all__ = [
"FFNMOE",
+ "PreGatherFFNMOE",
]
+class PreGatherFFNMOE(ThetaLayer):
+ def __init__(
+ self,
+ theta: Theta,
+ activation=F.silu,
+ ):
+
+ super().__init__(theta)
+
+ self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
+ self.ffn_up = theta.tensor("ffn_up_exps", "weight")
+ self.ffn_down = theta.tensor("ffn_down_exps", "weight")
+ self.activation = activation
+
+ def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
+ inputs = inputs[:, :]
+ weights = weights[experts, :, :]
+ matmul = einsum_2args(inputs, weights, einstring)
+ return matmul
+
+ def bigger_mmg(self, inputs, weights, experts):
+ inputs = inputs[:, :]
+ weights = weights[experts, :, :]
+ matmul = einsum_2args(inputs, weights, "mek,menk->men")
+ return matmul
+
+ def one_hot_matmul(self, inputs, weights, experts):
+ matmul = einsum_2args(inputs, weights, "mk,bnk->bmn")
+ # Post mix the experts
+ oh = (
+ torch.nn.functional.one_hot(experts.reshape(-1), num_classes=8)
+ .transpose(0, 1)
+ .to(torch.float32)
+ )
+ output = einsum_2args(oh, matmul, "bm,bmn->mn")
+ return output
+
+ def forward(
+ self,
+ h: torch.Tensor,
+ experts: torch.Tensor,
+ expert_gate: torch.Tensor,
+ ):
+ ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
+ ffn_gate = elementwise(self.activation, ffn_gate)
+
+ ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
+ ffn_down = self.pre_matmul_gather(
+ ffn_gate * ffn_up, self.ffn_down, experts, einstring="mek,menk->men"
+ )
+ ffn_down = einsum_2args(expert_gate, ffn_down, "me,men->men")
+ return torch.sum(ffn_down, dim=1)
+
+
class FFNMOE(ThetaLayer):
def __init__(
self,
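# Sketch of the pre-gather expert matmul used by PreGatherFFNMOE, written with
# plain tensor indexing and torch.einsum. Shapes are assumptions chosen to match
# the "mk,menk->men" subscripts: m tokens, e selected experts per token,
# n output features, k input features.
import torch

m, e, n, k, num_experts = 6, 2, 32, 16, 8
h = torch.randn(m, k)                            # token activations
expert_weights = torch.randn(num_experts, n, k)  # stacked per-expert [n, k] weights
experts = torch.randint(0, num_experts, (m, e))  # top-k expert ids per token

gathered = expert_weights[experts]               # gather first: [m, e, n, k]
out = torch.einsum("mk,menk->men", h, gathered)  # then matmul per (token, expert)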
diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index 47b465bd2..c73b7a8f4 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -11,7 +11,7 @@
and dims floating around everywhere.
"""
-from typing import Optional
+from typing import Optional, Union, List
import abc
import math
@@ -19,6 +19,8 @@
import torch
from ..utils.debugging import trace_tensor
+from ..types import SplitPrimitiveTensor, ReplicatedTensor
+from .. import ops
__all__ = [
"BaseKVCache",
@@ -90,6 +92,7 @@ def __init__(
attn_head_count: int,
attn_head_dim: int,
seq_length: int,
+ shard_count: int = 1,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
@@ -98,6 +101,7 @@ def __init__(
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.seq_length = seq_length
+ self.shard_count = shard_count
self.device = device
self.dtype = dtype
@@ -111,15 +115,109 @@ def allocate(self, *, bs: int) -> list[torch.Tensor]:
Each tensor has shape: [bs, sl, attn_head_count, attn_head_dim]
"""
- return [
+ allocations = [
torch.empty(
- [bs, self.seq_length, self.attn_head_count, self.attn_head_dim],
+ [
+ bs,
+ self.seq_length,
+ self.attn_head_count,
+ self.attn_head_dim,
+ ],
dtype=self.dtype,
device=self.device,
)
for _ in range(2 * self.transformer_block_count)
]
+ if self.shard_count == 1:
+ return allocations
+
+ return [
+ ops.reshard_split(allocation, dim=2, count=self.shard_count)
+ for allocation in allocations
+ ]
+
+ def read(
+ self,
+ state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ *,
+ read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ transformer_block_index: int,
+ seq_len: int,
+ page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
+ ):
+ """Reads cache partitions from the page table for the given page_ids.
+
+ Args:
+ state: State struct as returned from allocate().
+ read_into_partitions: List of cache partitions to read into in-place.
+ transformer_block_index: The index of the transformer block accessing
+ the cache.
+ page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
+ to access.
+
+ Returns a tuple of cache partitions (i.e. k and v caches for the transformer
+ block), linearized. Note that this reference approach to reading by
+ materializing linearly may not be terribly efficient unless if the
+ compiler can fuse the gather.
+ """
+ read_count = len(read_into_partitions)
+ reads = []
+ for i in range(read_count):
+ reads.append(
+ state[transformer_block_index * read_count + i][:, :seq_len, :, :]
+ )
+
+ return tuple(reads)
+
+ def write_timestep(
+ self,
+ state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ # List of [bs, 1, attn_head_count, attn_head_dim]
+ cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ *,
+ transformer_block_index: int,
+ # [bs]
+ seq_positions: Union[torch.Tensor, ReplicatedTensor],
+ # [bs, max_seqlen // block_pos_stride]
+ page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
+ ):
+ """Writes a single batched timestep across all cache partitions.
+
+ Note that this internally loops over the batch size, which cannot be
+ dynamic.
+ """
+ bs, _, _, _ = cache_partitions[0].shape
+ update_count = len(cache_partitions)
+
+ for b in range(bs):
+ row_index = torch.tensor([b], dtype=torch.int64)
+ row_start_pos = seq_positions[row_index].unsqueeze(0)
+
+ for i, update in enumerate(cache_partitions):
+ cache = state[transformer_block_index * update_count + i]
+ cache.index_put_((row_index, row_start_pos), update[row_index, 0])
+
+ def write(
+ self,
+ state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ *,
+ transformer_block_index: int,
+ page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
+ ):
+ """Writes cache partitions from a linear layout to the page table.
+
+ This is the inverse of the linear read. The same caveat applies if the
+ in-place scatter cannot be fused.
+ """
+ update_count = len(cache_partitions)
+
+ for idx, update_src in enumerate(cache_partitions):
+ cache_dest = state[transformer_block_index * update_count + idx]
+ _, batch_seq_len, _, _ = update_src.shape
+ cache_dest[:, :batch_seq_len, :, :] = update_src
+
class PagedKVCache(BaseKVCache):
"""Implementation of a KV cache on top of a 'page table'.
@@ -138,6 +236,11 @@ class PagedKVCache(BaseKVCache):
Note that the internal page structure matches the organization of the
model, allowing contiguous individual local reads and writes at a sub-block
granularity if indexing deeply into the structure.
+
+ When `shard_count > 1`, the cache splits the `attn_head_count` dimension.
+ The page slab is then a split tensor sharded along its flat dimension;
+ it is reinterpreted as a 6D tensor on access, working around the lack of
+ a block-cyclic sharded tensor type.
"""
def __init__(
@@ -150,30 +253,59 @@ def __init__(
block_seq_stride: int = 16,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
+ shard_count: int = 1,
):
self.transformer_block_count = transformer_block_count
self.attn_head_count = attn_head_count
self.attn_head_dim = attn_head_dim
self.cache_partition_count = cache_partition_count
self.block_seq_stride = block_seq_stride
+ self.shard_count = shard_count
+ if attn_head_count % shard_count != 0:
+ raise ValueError(
+ f"The attention head count {attn_head_count} must be a multiple of the tensor parallelism size {shard_count}."
+ )
# Some derived values based on attributes.
self.sub_page_dims = [
self.transformer_block_count,
self.cache_partition_count,
self.block_seq_stride,
- self.attn_head_count,
+ self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
self.page_slab_flat_dim = math.prod(self.sub_page_dims)
self.device = device
self.dtype = dtype
- def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
+ def unflatten_page_table(
+ self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]]
+ ) -> Union[torch.Tensor, SplitPrimitiveTensor]:
"""Unflattens the 2D page table to a 6D tensor."""
assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}"
page_slab = state[0]
- return page_slab.reshape(
+ if self.shard_count == 1:
+ assert not isinstance(page_slab, SplitPrimitiveTensor)
+ return page_slab.unflatten(1, self.sub_page_dims)
+ else:
+ assert self.shard_count == page_slab.shard_count
+ shards = [
+ shard.unflatten(1, self.sub_page_dims) for shard in page_slab.shards
+ ]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=4)
+
+ def shard_state(
+ self, state: List[torch.Tensor]
+ ) -> List[Union[torch.Tensor, SplitPrimitiveTensor]]:
+ """Shard an unsharded state.
+ We can't just split the slab on the sub page dims.
+ First it needs to be reinterpreted into the actual shape.
+ The split the head dimension, then flatten each shard.
+ This is a work-around for the lack of block-cyclic sharded tensor type."""
+ if self.shard_count == 1:
+ return state
+
+ page_table = state[0].reshape(
[
-1,
self.transformer_block_count,
@@ -183,30 +315,47 @@ def unflatten_page_table(self, state: list[torch.Tensor]) -> torch.Tensor:
self.attn_head_dim,
]
)
+ sharded_page_table = ops.reshard_split(
+ page_table, dim=4, count=self.shard_count
+ )
+ shards = [
+ ops.flatten(shard, start_dim=1) for shard in sharded_page_table.shards
+ ]
+ flat_sharded_page_table = SplitPrimitiveTensor(ts=shards, shard_dim=1)
+ return [flat_sharded_page_table]
@property
def pad_sequence_stride(self) -> int:
return self.block_seq_stride
- def allocate(self, page_count: int) -> list[torch.Tensor]:
+ def allocate(
+ self, page_count: int
+ ) -> list[Union[torch.Tensor, SplitPrimitiveTensor]]:
"""Allocates tensor state for a page table for the given capacity in
pages.
"""
- return [
+ shards = [
torch.empty(
[page_count, self.page_slab_flat_dim],
dtype=self.dtype,
device=self.device,
)
+ for _ in range(self.shard_count)
]
+ if self.shard_count == 1:
+ return shards
+
+ return [SplitPrimitiveTensor(ts=shards, shard_dim=1)]
+
def read(
self,
- state: list[torch.Tensor],
+ state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
- read_into_partitions: list[torch.Tensor],
+ read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
- page_ids: torch.Tensor,
+ seq_len: int,
+ page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Reads cache partitions from the page table for the given page_ids.
@@ -231,7 +380,7 @@ def read(
bs,
block_seq_len,
self.block_seq_stride,
- self.attn_head_count,
+ self.attn_head_count // self.shard_count,
self.attn_head_dim,
]
@@ -249,7 +398,9 @@ def read(
transformer_block_index * transformer_block_stride
)
- def read_cache_partition(index: int, into_partition: torch.Tensor):
+ def read_cache_partition(
+ index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
+ ):
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
@@ -262,7 +413,7 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
selected = (
- torch.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
+ ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
.unflatten(0, blocked_shape[0:2])
.flatten(1, 2)
)
@@ -272,17 +423,19 @@ def read_cache_partition(index: int, into_partition: torch.Tensor):
for index, read_into_partition in enumerate(read_into_partitions):
read_cache_partition(index, read_into_partition)
+ return tuple([p[:, :seq_len, :] for p in read_into_partitions])
+
def write_timestep(
self,
- state: list[torch.Tensor],
+ state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
# List of [bs, 1, attn_head_count, attn_head_dim]
- cache_partitions: list[torch.Tensor],
+ cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
# [bs]
- seq_positions: torch.Tensor,
+ seq_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, max_seqlen // block_pos_stride]
- page_ids: torch.Tensor,
+ page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes a single batched timestep across all cache partitions.
@@ -293,29 +446,41 @@ def write_timestep(
page_table = self.unflatten_page_table(state) # 6D
bs, *_ = seq_positions.shape
assert len(cache_partitions) == self.cache_partition_count
- for i in range(bs):
- position = seq_positions[i]
- # TODO: Let's clamp to the allowable range so that we don't need
- # an assert.
- page_id = page_ids[i, :].index_select(0, position // self.block_seq_stride)
- page_offset = position % self.block_seq_stride
- for partition_index in range(self.cache_partition_count):
- cache_partition = cache_partitions[partition_index]
- indices = (
- page_id,
- torch.tensor([transformer_block_index], device=device),
- torch.tensor([partition_index], device=device),
- page_offset.unsqueeze(0),
- )
- page_table.index_put_(indices=indices, values=cache_partition[i, 0])
+
+ partition_count = len(cache_partitions)
+
+ # [bs, partitions, atten_head_count, attn_head_dim]
+ cache_partitions = ops.cat(cache_partitions, dim=1)
+
+ # [bs, 1]
+ page_index = seq_positions // self.block_seq_stride
+
+ page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
+ page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
+
+ # [1, partitions]
+ partitions = torch.arange(0, self.cache_partition_count).unsqueeze(0)
+
+ # [bs, partitions]
+ page_id = page_id.repeat(1, partition_count)
+ transformer_block = torch.full(
+ (bs, partition_count), transformer_block_index, device=device
+ )
+ page_offset = page_offset.repeat(1, partition_count)
+ partitions = partitions.repeat(bs, 1)
+
+ indices = (page_id, transformer_block, partitions, page_offset)
+ page_table.index_put_(indices=indices, values=cache_partitions)
+
+ return
def write(
self,
- state: list[torch.Tensor],
- cache_partitions: list[torch.Tensor],
+ state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ cache_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
transformer_block_index: int,
- page_ids: torch.Tensor,
+ page_ids: Union[torch.Tensor, ReplicatedTensor],
):
"""Writes cache partitions from a linear layout to the page table.
@@ -348,21 +513,21 @@ def write(
transformer_block_index * transformer_block_stride
)
- def write_cache_partition(index: int, part: torch.Tensor):
- part_block_view = part.reshape(blocked_shape)
+ part_block_views = []
+ subblock_ids_kv = []
+ for index, partition in enumerate(cache_partitions):
+ part_block_view = partition.unflatten(
+ 1, (block_seq_len, self.block_seq_stride)
+ )
+ part_block_view = part_block_view.flatten(0, 1)
+ part_block_views.append(part_block_view)
+
subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
- )
- # TODO: Potentially clamp all page 0 indices to the mask value.
- # Or even better, require that the ids are replicated such that access is
- # legal.
- # Now for each of the k/v attn_block_ids, which have been adjusted to
- # index into the sub-pages, we flatten to do a linear index_select
- # copy of the sub-blocks by collapsing the first two dims so we have
- # a linear list.
- subblock_table.index_copy_(
- 0, subblock_ids.flatten(0, 1), part_block_view.flatten(0, 1)
- )
+ ).flatten(0, 1)
+ subblock_ids_kv.append(subblock_ids)
- for index, partition in enumerate(cache_partitions):
- write_cache_partition(index, partition)
+ subblock_ids = ops.cat(subblock_ids_kv)
+ part_block_view = ops.cat(part_block_views, dim=0)
+
+ subblock_table.index_copy_(0, subblock_ids, part_block_view)
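# Standalone illustration of the batched single-timestep scatter in
# write_timestep above: advanced-index the unflattened page table as
# table[page, block, partition, offset] and index_put_ a [bs, partitions,
# heads, dim] update in one shot. All sizes below are made up.
import torch

bs, blocks, parts, stride, heads, dim, page_count = 2, 3, 2, 4, 2, 8, 10
table = torch.zeros(page_count, blocks, parts, stride, heads, dim)

block_index = 1
seq_positions = torch.tensor([5, 2])              # current write position per row
page_ids = torch.tensor([[0, 1, 2], [3, 4, 5]])   # pages assigned to each row
update = torch.randn(bs, parts, heads, dim)       # k and v for this timestep

page_id = torch.gather(page_ids, 1, (seq_positions // stride).unsqueeze(1))
page_id = page_id.repeat(1, parts)                                  # [bs, parts]
offset = (seq_positions % stride).unsqueeze(1).repeat(1, parts)     # [bs, parts]
block = torch.full((bs, parts), block_index)
part = torch.arange(parts).unsqueeze(0).repeat(bs, 1)

table.index_put_((page_id, block, part, offset), update)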
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index 86e43d715..b679dccde 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -7,16 +7,15 @@
from typing import Optional
import torch
-
from .. import ops
from .base import Theta, ThetaLayer
-from ..types.layout_utils import saturate_cast
from ..types import (
DynamicScaledQuantizer,
QuantizedTensor,
QuantizerTensor,
StaticScaledQuantizer,
TensorScaledLayout,
+ PlanarQuantizedTensor,
)
__all__ = [
@@ -31,6 +30,10 @@ class LinearLayer(ThetaLayer):
if premul_input is not None:
x = x * premul_input
matmul(x, weight.T) + bias
+
+ fake_quant exists to allow export without adding dequant ops.
+ When fake_quant is True, the layer runs in quantize-dequantize (qdq)
+ fashion; when False, it keeps quantized types.
```
"""
@@ -40,11 +43,13 @@ def __init__(
*,
weight_name: str = "weight",
bias_name: str = "bias",
+ fake_quant: bool = True,
):
super().__init__(theta)
self._simulate_native_quant = True
self.weight = self.theta_tensor(weight_name)
self.bias = None
+ self.fake_quant = fake_quant
if bias_name in self.theta.keys:
self.bias = self.theta_tensor(bias_name)
@@ -54,26 +59,36 @@ def __init__(
self.qdq_input: Optional[QuantizedTensor] = theta.optional_tensor("qdq_input")
if self.q_input is not None and self.qdq_input is not None:
raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input")
+ self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output")
def forward(self, x):
weight = self.weight
bias = self.bias
q_input = self.q_input
qdq_input = self.qdq_input
-
+ qdq_output = self.qdq_output
if self.premul_input is not None:
x = ops.elementwise(torch.mul, x, self.premul_input)
if q_input is not None:
x = q_input.quantize(x)
- elif qdq_input is not None:
+ if self.fake_quant:
+ x = x.unpack().dequant()
+ elif qdq_input is not None and self.fake_quant:
x = qdq_input.quantize(x).unpack().dequant()
y = ops.linear(x, weight, bias)
# Unconditionally dequantize.
- # TODO: Support a q_output specifier that signals the layer to let
- # the QuantizedTensor escape.
- if isinstance(y, QuantizedTensor):
+ if isinstance(y, QuantizedTensor) and not self.fake_quant:
y = y.unpack().dequant()
+ # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
+ # We can truncate to fp16 in IREE, so we insert a cast here to account
+ # for this in the IR. This may not be the right level to do it, but for
+ # now it lives here.
+ if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
+ y = ops.to(y, torch.float16)
+ return y
+ if qdq_output is not None and self.fake_quant:
+ y = qdq_output.quantize(y).unpack().dequant()
return y
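# Conceptual sketch of "fake quant" (quantize-dequantize) versus keeping the
# quantized representation, using plain int8 scaling rather than the project's
# quantizer tensors, so the numerics here are only illustrative.
import torch

def qdq(x: torch.Tensor, scale: float) -> torch.Tensor:
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q.to(x.dtype) * scale     # same dtype as x, quantization error baked in

def real_quant(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

x = torch.randn(4, 8)
y_fake = qdq(x, scale=0.05)          # roughly what fake_quant=True feeds to ops.linear
y_real = real_quant(x, scale=0.05)   # roughly what fake_quant=False keeps around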
diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py
index 7be8c7366..0cdb5d713 100644
--- a/sharktank/sharktank/layers/llama_attention_block.py
+++ b/sharktank/sharktank/layers/llama_attention_block.py
@@ -6,8 +6,6 @@
from typing import Optional
-import math
-
import torch
import torch.nn.functional as F
@@ -110,7 +108,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = values.transpose(1, 2)
# Flash attention.
- attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+ attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(
+ torch.tensor(self.head_dim, dtype=xq.dtype)
+ )
# Apply attention mask.
if attention_mask is not None:
diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py
index 09bf491b7..ddce16c55 100644
--- a/sharktank/sharktank/layers/mixture_of_experts_block.py
+++ b/sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -13,14 +13,14 @@
from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .norm import RMSNormLayer
-from .ffn_moe_block import FFNMOE
+from .ffn_moe_block import FFNMOE, PreGatherFFNMOE
__all__ = [
- "SparseMoeBlock",
+ "MoeBlock",
]
-class SparseMoeBlock(ThetaLayer):
+class MoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
@@ -34,9 +34,13 @@ def __init__(
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
+ moe_activation=F.silu,
):
super().__init__(theta)
+ self.expert_count = expert_count
+ self.expert_used_count = expert_used_count
+
# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -45,13 +49,17 @@ def __init__(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)
- # Add expert_count x FFN
- self.experts = nn.ModuleList(
- [FFNMOE(theta, expert_idx=i) for i in range(expert_count)]
- )
+ # Add optional FFN output norm layer
+ if theta.optional_tensor("layer_output_norm") is not None:
+ self.add_module(
+ "layer_output_norm",
+ RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
+ )
+ else:
+ self.add_module("layer_output_norm", torch.nn.Identity())
- self.expert_count = expert_count
- self.expert_used_count = expert_used_count
+ # Add expert_count x FFN
+ self.experts = PreGatherFFNMOE(theta, activation=moe_activation)
def forward(
self,
@@ -67,42 +75,16 @@ def forward(
router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
# Select top k experts from router weights
- router_weights, top_k_experts = torch.topk(
+ expert_gate, top_k_experts = torch.topk(
router_weights, self.expert_used_count, dim=-1
)
- router_weights /= router_weights.sum(dim=-1, keepdim=True)
- router_weights = router_weights.to(ffn_input.dtype)
- moe_output = torch.zeros(
- (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype
- )
-
- # Create an expert mask by one hot encoding the selected top k experts
- # used to index which expert is to be invoked for each token
- # expert_mask: (expert_count, expert_used_count, sequence_length)
- expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute(
- 2, 1, 0
- )
-
- # Iterate over all experts in the model
- for expert_idx in range(self.expert_count):
- expert_layer = self.experts[expert_idx]
- top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx])
-
- # Given the hidden states, index the tokens assigned to this expert
- # and calculate the current expert's hidden state and weigh the
- # output expert hidden states by the router weights, based on the
- # appropriate tokens
- current_expert_tokens = ffn_input[None, token_idx]
+ expert_gate /= expert_gate.sum(dim=-1, keepdim=True)
+ expert_gate = expert_gate.to(ffn_input.dtype)
- current_expert = (
- expert_layer(current_expert_tokens)
- * router_weights[token_idx, top_k_expert_idx, None]
- )
-
- current_expert = current_expert.reshape(-1, feature_dim)
-
- moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
+ moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
+ moe_output = self.layer_output_norm(moe_output)
+
return h + moe_output
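# Minimal sketch of the routing math in MoeBlock.forward above: softmax over the
# expert logits, take the top-k experts per token, then renormalize their gate
# weights so they sum to one. Sizes are arbitrary.
import torch
import torch.nn.functional as F

tokens, expert_count, expert_used_count = 6, 8, 2
router_logits = torch.randn(tokens, expert_count)

router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
expert_gate, top_k_experts = torch.topk(router_weights, expert_used_count, dim=-1)
expert_gate = expert_gate / expert_gate.sum(dim=-1, keepdim=True)   # [tokens, k]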
diff --git a/sharktank/sharktank/layers/norm.py b/sharktank/sharktank/layers/norm.py
index d062f1ffb..4fa08050a 100644
--- a/sharktank/sharktank/layers/norm.py
+++ b/sharktank/sharktank/layers/norm.py
@@ -33,9 +33,9 @@ def __init__(
def forward(self, x: torch.Tensor):
orig_dtype = x.dtype
- x = x.to(self.dtype)
+ x = ops.to(x, self.dtype)
norm = ops.rms_norm(x, self.weight, epsilon=self.epsilon)
# Will automatically upcast to the dtype of the weight, which is
# often in higher precision. Downcast back to expected.
- norm = norm.to(orig_dtype)
+ norm = ops.to(norm, orig_dtype)
return norm
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 41f3a20cc..22647bf49 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -10,12 +10,13 @@
import torch
import torch.nn.functional as F
-
+from ..types import QuantizerTensor
from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .norm import RMSNormLayer
from .rotary_embedding import RotaryEmbeddingLayer
from .kv_cache import PagedKVCache
+from .. import ops
__all__ = [
"PagedLlamaAttentionBlock",
@@ -36,24 +37,54 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
- use_hf: bool = False,
-
+ attention_kernel: str = "decomposed",
+ attention_scale: Optional[float] = None,
+ softcap: Optional[float] = None,
+ fake_quant: Optional[bool] = True,
):
super().__init__(theta)
- self.add_module(
- "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
- )
- self.add_module("attn_q", LinearLayer(theta("attn_q")))
- self.add_module("attn_k", LinearLayer(theta("attn_k")))
- self.add_module("attn_v", LinearLayer(theta("attn_v")))
- self.add_module("attn_output", LinearLayer(theta("attn_output")))
self.block_index = block_index
self.cache = cache
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
- self.use_hf = use_hf
+ self.attention_kernel = attention_kernel
+ self.attention_scale = attention_scale
+ self.softcap = softcap
+ self.fake_quant = fake_quant
+
+ self.add_module(
+ "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
+ )
+ self.add_module(
+ "attn_q", LinearLayer(theta("attn_q"), fake_quant=self.fake_quant)
+ )
+ self.add_module(
+ "attn_k", LinearLayer(theta("attn_k"), fake_quant=self.fake_quant)
+ )
+ self.add_module(
+ "attn_v", LinearLayer(theta("attn_v"), fake_quant=self.fake_quant)
+ )
+ self.add_module(
+ "attn_output", LinearLayer(theta("attn_output"), fake_quant=self.fake_quant)
+ )
+ self.cache_quantizer = None
+ if "kv_cache" in theta.keys:
+ self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor(
+ "kv_cache.quantizer"
+ )
+
+ if theta.optional_tensor("attn_output_norm") is None:
+ self.add_module(
+ "attn_output_norm",
+ torch.nn.Identity(),
+ )
+ else:
+ self.add_module(
+ "attn_output_norm",
+ RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
+ )
def forward(
self,
@@ -88,36 +119,48 @@ def forward(
# Fast path to start_index based embedding lookup if available.
# Falls back to a slower position based index lookup.
if start_index is not None:
- xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index)
+ xq = embedding.forward(xt=xq, start_index=start_index)
+ xk = embedding.forward(xt=xk, start_index=start_index)
else:
- xq, xk = embedding.apply_batched_mask(
- xq=xq, xk=xk, mask=embedding_batch_mask
- )
+ xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask)
+ xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask)
# Full sequence length.
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
- if self.cache.is_paged:
- xk, xv = self.transact_cache_paged(
- xk_cache_update=xk,
- xv_cache_update=xv,
- seq_block_ids=seq_block_ids,
- kv_seq_len=kv_seq_len,
- start_positions=start_positions,
- cache_state=cache_state,
- xk_temp=xk_temp,
- xv_temp=xv_temp,
- )
- elif self.cache.is_direct:
- xk, xv = self.transact_cache_direct(
- xk_cache_update=xk,
- xv_cache_update=xv,
- start_positions=start_positions,
- kv_seq_len=kv_seq_len,
- cache_state=cache_state,
- )
- else:
- raise NotImplementedError(f"Unsupported KV cache type: {type(self.cache)}")
+ # Used by fp8_e4m3fnuz model
+ if self.cache_quantizer is not None:
+ # For fake quant, store the fp16 qdq value in the cache
+ if self.fake_quant:
+ xk = (
+ self.cache_quantizer.quantize(xk)
+ .unpack()
+ .dequant()
+ .to(torch.float16)
+ )
+ xv = (
+ self.cache_quantizer.quantize(xv)
+ .unpack()
+ .dequant()
+ .to(torch.float16)
+ )
+ # For real quant, store the quantized fp8 value in the cache
+ else:
+ # TODO: this seems like a bastardization of our quantized tensor api
+ # Probably want to add support for using quantized tensors more directly
+ xk = self.cache_quantizer.quantize(xk).unpack().qs
+ xv = self.cache_quantizer.quantize(xv).unpack().qs
+
+ xk, xv = self.transact_cache(
+ xk_cache_update=xk,
+ xv_cache_update=xv,
+ seq_block_ids=seq_block_ids,
+ kv_seq_len=kv_seq_len,
+ start_positions=start_positions,
+ cache_state=cache_state,
+ xk_temp=xk_temp,
+ xv_temp=xv_temp,
+ )
# Expand kv heads for GQA.
gqa_n_rep = self.head_count // self.head_count_kv
@@ -126,87 +169,88 @@ def forward(
def repeat_kv(x: torch.Tensor) -> torch.Tensor:
bs, slen, n_kv_heads, head_dim = x.shape
- return (
- x.unsqueeze(-2)
- .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim)
- .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim)
- )
+ unsq = x.unsqueeze(-2)
+ exp = ops.expand(unsq, (bs, slen, n_kv_heads, gqa_n_rep, head_dim))
+ return exp.flatten(2, 3)
xk = repeat_kv(xk)
xv = repeat_kv(xv)
+ # Fake quant is already dequantized when stored in the cache.
+ if self.cache_quantizer and not self.fake_quant:
+ xk = self.cache_quantizer.dequantize_raw_tensor(
+ xk, torch.float16, name="xk_deq"
+ )
+ xv = self.cache_quantizer.dequantize_raw_tensor(
+ xv, torch.float16, name="xv_deq"
+ )
# Transpose into [bs, heads, sl, dim]
xq = xq.transpose(1, 2)
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)
- # Flash attention.
- attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
- self.assert_not_nan(attn_weights)
+ if self.attention_kernel == "decomposed":
+ attn_weights = ops.matmul(xq, keys.transpose(2, 3))
+ if self.attention_scale is None:
+ attn_weights = attn_weights / math.sqrt(self.head_dim)
+ else:
+ attn_weights = attn_weights * self.attention_scale
- # Apply attention mask.
- self.trace_tensor("attn_weights", attn_weights, values=False)
- if attention_mask is not None:
- # self.trace_tensor("attn_mask", attention_mask)
- attn_weights = attn_weights + attention_mask
+ # Flash attention.
+ if self.softcap is not None:
+ attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)
- attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(xq)
- attn_output = torch.matmul(attn_weights, values) # (bs, heads, slen, head_dim)
- attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)
+ self.assert_not_nan(attn_weights)
+
+ # Apply attention mask.
+ self.trace_tensor("attn_weights", attn_weights, values=False)
+ if attention_mask is not None:
+ # self.trace_tensor("attn_mask", attention_mask)
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = ops.softmax(
+ ops.to(attn_weights, dtype=torch.float32), dim=-1
+ )
+ attn_weights = ops.to(attn_weights, dtype=xq.dtype)
+ attn_output = ops.matmul(
+ attn_weights, values
+ ) # (bs, heads, slen, head_dim)
+ else:
+ is_causal = True
+ attention_mask = None
+ attn_output = ops.scaled_dot_product_attention(
+ q=xq, # [bs, ..., sl, dim]
+ k=keys, # [bs, ..., sl, dim]
+ v=values, # [bs, ..., sl, dim]
+ a=attention_mask, # [bs, ..., sl, sl]
+ is_causal=is_causal, # assumes causal masking when true
+ scale=None, # defaults to 1/sqrt(dim)
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.flatten(2, 3)
# Project.
attn_output = self.attn_output(attn_output)
+ attn_output = self.attn_output_norm(attn_output)
- # Remainder of the block.
h = h + attn_output
-
return h
- def transact_cache_direct(
- self,
- *,
- cache_state: list[torch.Tensor],
- xk_cache_update: torch.Tensor,
- xv_cache_update: torch.Tensor,
- kv_seq_len: int,
- start_positions: Optional[torch.Tensor] = None,
- ):
- bs, batch_seq_len, _, _ = xk_cache_update.shape
- cache_k = cache_state[self.block_index * 2]
- cache_v = cache_state[self.block_index * 2 + 1]
-
- if start_positions is None:
- # Prefill. Write the entire cache.
- cache_k[:, :batch_seq_len] = xk_cache_update
- cache_v[:, :batch_seq_len] = xv_cache_update
- return xk_cache_update, xv_cache_update
- else:
- # Decode. Write a single timestep.
- # TODO: This needs to be reworked with index ops.
- assert xk_cache_update.shape[1] == 1
- assert xv_cache_update.shape[1] == 1
- max_start_pos = 0
- for row_index in range(bs):
- row_start_pos = start_positions[row_index].item()
- max_start_pos = max(row_start_pos, max_start_pos)
- cache_k[row_index, row_start_pos] = xk_cache_update[row_index, 0]
- cache_v[row_index, row_start_pos] = xv_cache_update[row_index, 0]
- return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len]
-
- def transact_cache_paged(
+ def transact_cache(
self,
*,
xk_cache_update: torch.Tensor,
xv_cache_update: torch.Tensor,
cache_state: list[torch.Tensor],
# [bs, batch_seq_len // block_seq_stride]
- seq_block_ids: torch.Tensor,
+ seq_block_ids: Optional[torch.Tensor],
kv_seq_len: int,
start_positions: Optional[torch.Tensor] = None,
xk_temp: Optional[torch.Tensor] = None,
xv_temp: Optional[torch.Tensor] = None,
):
- cache = self.cache.paged
+ cache = self.cache
# Manage the cache.
if start_positions is None:
# Prefill: Write the entire cache.
@@ -217,46 +261,45 @@ def transact_cache_paged(
page_ids=seq_block_ids,
)
return xk_cache_update, xv_cache_update
- else:
- # Decode at ragged start positions.
- # We need to initialize/read the K/V from the cache for the whole
- # sequence. Note that at this point, it is possible to fork and
- # use a memory efficient attention kernel that can do indirect
- # reads, skipping this materialization. This path is taken for
- # a decode step.
- assert xk_temp is not None and xv_temp is not None
- assert xk_cache_update.shape[1] == 1
- assert xv_cache_update.shape[1] == 1
- assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride
-
- # Write our one updated cache row into the cache.
- cache.write_timestep(
- cache_state,
- cache_partitions=[
- xk_cache_update,
- xv_cache_update,
- ],
- transformer_block_index=self.block_index,
- seq_positions=start_positions + 1,
- page_ids=seq_block_ids,
- )
- # Restore from the cache.
- cache.read(
- cache_state,
- read_into_partitions=[
- xk_temp[:, 0:kv_seq_len, ...],
- xv_temp[:, 0:kv_seq_len, ...],
- ],
- transformer_block_index=self.block_index,
- page_ids=seq_block_ids,
- )
+ # Decode at ragged start positions.
+ # We need to initialize/read the K/V from the cache for the whole
+ # sequence. Note that at this point, it is possible to fork and
+ # use a memory efficient attention kernel that can do indirect
+ # reads, skipping this materialization. This path is taken for
+ # a decode step.
+ assert xk_temp is not None and xv_temp is not None
+ assert xk_cache_update.shape[1] == 1
+ assert xv_cache_update.shape[1] == 1
+ assert kv_seq_len == seq_block_ids.shape[1] * cache.block_seq_stride
+
+ # Write our one updated cache row into the cache.
+ cache.write_timestep(
+ cache_state,
+ cache_partitions=[
+ xk_cache_update,
+ xv_cache_update,
+ ],
+ transformer_block_index=self.block_index,
+ seq_positions=start_positions,
+ page_ids=seq_block_ids,
+ )
+
+ # Restore from the cache.
+ xk, xv = cache.read(
+ cache_state,
+ read_into_partitions=[
+ xk_temp[:, 0:kv_seq_len, ...],
+ xv_temp[:, 0:kv_seq_len, ...],
+ ],
+ transformer_block_index=self.block_index,
+ page_ids=seq_block_ids,
+ seq_len=kv_seq_len,
+ )
- # For computation, we create a subview of the xk/xv tensors to have
- # a sequence length covering the blocked size. This must include
- # the newly added row (the caller is responsible for ensuring that
- # every block has at least one row left). We'll compute on this
- # ragged view and use an appropriate mask.
- xk = xk_temp[:, 0:kv_seq_len, ...]
- xv = xv_temp[:, 0:kv_seq_len, ...]
- return xk, xv
+ # For computation, we create a subview of the xk/xv tensors to have
+ # a sequence length covering the blocked size. This must include
+ # the newly added row (the caller is responsible for ensuring that
+ # every block has at least one row left). We'll compute on this
+ # ragged view and use an appropriate mask.
+ return xk, xv
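For reference, the GQA expansion that `repeat_kv` now performs with `ops.expand` + `flatten` is equivalent to the previous `reshape`-based formulation. A minimal plain-PyTorch sketch (the shapes below are arbitrary example values):

```python
import torch

bs, slen, n_kv_heads, head_dim, gqa_n_rep = 2, 7, 4, 16, 8
x = torch.randn(bs, slen, n_kv_heads, head_dim)

# Previous formulation: expand a new repeat dim, then reshape it away.
old = (
    x.unsqueeze(-2)
    .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim)
    .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim)
)

# New formulation: expand, then flatten the (kv_head, repeat) dims together.
new = x.unsqueeze(-2).expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim).flatten(2, 3)

assert torch.equal(old, new)
```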
diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py
index eadd9e6b0..0664a9a46 100644
--- a/sharktank/sharktank/layers/rotary_embedding.py
+++ b/sharktank/sharktank/layers/rotary_embedding.py
@@ -4,11 +4,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Optional
+from collections import namedtuple
+from typing import Optional, Union
import torch
from .base import BaseLayer
+from .. import ops
+from ..types import SplitPrimitiveTensor, ReplicatedTensor, unbox_tensor
class RotaryEmbeddingLayer(BaseLayer):
@@ -19,35 +22,80 @@ def __init__(
*,
rope_dimension_count: int,
max_seqlen: int,
- rope_freq_base: Optional[float] = 10000.0,
+ rope_freq_base: Optional[float],
device: Optional[torch.device] = None,
use_hf: bool = False,
- static_tables: bool = True,
+ static_tables: bool = False,
+ use_table: bool = True,
+ tensor_parallelism_size: int = 1,
):
super().__init__()
- # Force static_tables until compiler limitations are solved.
- # See https://github.com/nod-ai/sharktank/issues/156
- static_tables = True
self.device = device
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
- self.rope_freq_base = rope_freq_base
+ self.static_tables = static_tables
+ self.use_table = use_table
+
+ self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
+ self.tensor_parallelism_size = tensor_parallelism_size
if static_tables:
- self.register_buffer(
- "static_rotary_embed_table", self._create_rotary_embed_table()
+ ops.module_register_buffer(
+ self, "static_rotary_embed_table", self._create_rotary_embed_table()
)
else:
self.static_rotary_embed_table = None
@property
def rotary_embed_table(self):
- if self.static_rotary_embed_table is None:
+ if self.use_table:
+ if self.static_tables:
+ return self.static_rotary_embed_table
return self._create_rotary_embed_table()
+
+ return None
+
+ def forward(
+ self,
+ *,
+ xt: Union[torch.Tensor, SplitPrimitiveTensor],
+ start_index: int,
+ ):
+ if isinstance(xt, SplitPrimitiveTensor):
+ rotary_shards = [None] * xt.shard_count
+ if self.rotary_embed_table is not None:
+ assert (
+ isinstance(self.rotary_embed_table, ReplicatedTensor)
+ and xt.shard_count == self.rotary_embed_table.shard_count
+ )
+ rotary_shards = [
+ unbox_tensor(shard) for shard in self.rotary_embed_table.shards
+ ]
+
+ xt_shards = [
+ self.forward_unsharded(
+ xt=unbox_tensor(xt_shard),
+ start_index=start_index,
+ rotary_embed_table=rotary_shard,
+ )
+ for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
+ ]
+ xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
+ return xt
else:
- return self.static_rotary_embed_table
+ return self.forward_unsharded(
+ xt=xt,
+ start_index=start_index,
+ rotary_embed_table=self.rotary_embed_table,
+ )
- def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int):
+ def forward_unsharded(
+ self,
+ *,
+ xt: torch.Tensor,
+ start_index: int,
+ rotary_embed_table: Optional[torch.Tensor],
+ ):
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
@@ -89,57 +137,33 @@ def create_ordering_tensor(dim):
return order_tensor
if self.use_hf:
- xq = xq[..., create_interleaved_tensor(xq.shape[-1])]
- xk = xk[..., create_interleaved_tensor(xq.shape[-1])]
-
- xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
- xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
- _, sl, _, dim = xq_.shape
+ xt = xt[..., create_interleaved_tensor(xt.shape[-1])]
+ xt_ = xt
+ _, sl, _, _ = xt_.shape
# Offset the table based on starting position.
- freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :]
- assert freqs_cis.shape[-1] == dim
+ if self.use_table:
+ freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
+ freqs_cis = freqs_cis[None, 0:sl, None, :]
+ else:
+ freqs_cis = torch.arange(sl, device=xt.device) + start_index
+ freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]
+
assert (
- freqs_cis.shape[0] >= sl
+ freqs_cis.shape[1] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"
- broadcast_freqs_cis = freqs_cis[None, 0:sl, None, :]
+ xt_ = ops.view_as_complex(xt_)
+ xt_ = xt_ * freqs_cis
+ xt_out = ops.view_as_real(xt_)
if self.use_hf:
- xq_out = torch.view_as_real(
- self.complex_multiply(xq_, broadcast_freqs_cis)
- ).flatten(3)
- xk_out = torch.view_as_real(
- self.complex_multiply(xk_, broadcast_freqs_cis)
- ).flatten(3)
-
- xq_out = xq_out[..., create_ordering_tensor(xq_out.shape[-1])]
- xk_out = xk_out[..., create_ordering_tensor(xq_out.shape[-1])]
-
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
- xq_out = torch.view_as_real(xq_ * broadcast_freqs_cis).flatten(3)
- xk_out = torch.view_as_real(xk_ * broadcast_freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
+ xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])]
- def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
- """Function for elementwise-multiplication of two complex torch tensors.
- Functionally similar to a*b, but numerically accurate for HuggingFace
- LLaMa implementation.
-
- Args:
- a: First torch tensor operand
- b: Second torch tensor operand
- Returns:
- Tensor of same size to a, b whose elements is product of corresponding
- elements in a, b
- """
- return torch.complex(
- a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real
- )
+ return ops.to(xt_out, xt.dtype)
def compute_batch_mask(
- self, start_positions: torch.Tensor, batch_seq_len: int
+ self, start_positions: Union[torch.Tensor, ReplicatedTensor], batch_seq_len: int
) -> torch.Tensor:
"""Computes a mask for a batch that can be repeatedly applied.
@@ -156,15 +180,46 @@ def compute_batch_mask(
) + start_positions.unsqueeze(1)
# Broadcast lookup to [b, ...].
self.trace_tensor("rope.positions_seq", positions_seq)
- freqs_cis = self.rotary_embed_table[positions_seq]
+
+ if self.use_table:
+ freqs_cis = self.rotary_embed_table[positions_seq]
+ else:
+ shape = positions_seq.shape
+ if isinstance(positions_seq, ReplicatedTensor):
+ ts = [
+ self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape)
+ for s in positions_seq.shards
+ ]
+ freqs_cis = ReplicatedTensor(ts=ts)
+ else:
+ freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
+ freqs_cis = freqs_cis.unflatten(0, shape)
# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
return broadcast_freqs_cis
def apply_batched_mask(
- self, *, xq: torch.Tensor, xk: torch.Tensor, mask: torch.Tensor
+ self,
+ *,
+ xt: Union[torch.Tensor, SplitPrimitiveTensor],
+ mask: Union[torch.Tensor, ReplicatedTensor],
):
+ if not isinstance(xt, SplitPrimitiveTensor):
+ return self.apply_batched_mask_unsharded(xt=xt, mask=mask)
+
+ assert isinstance(mask, ReplicatedTensor) and mask.shard_count == xt.shard_count
+ xt_shards = [
+ self.apply_batched_mask_unsharded(
+ xt=unbox_tensor(xt_shard),
+ mask=unbox_tensor(mask_shard),
+ )
+ for xt_shard, mask_shard in zip(xt.shards, mask.shards)
+ ]
+ xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
+ return xt
+
+ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
"""Applies the embedding to a ragged batch of queries and keys.
This does a more complicated indexing operation for cases when the each
@@ -174,30 +229,32 @@ def apply_batched_mask(
"""
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
- xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
- xk_ = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
- _, sl, _, dim = xq_.shape
+ xt_ = ops.view_as_complex(xt)
+ xt_ = xt_ * mask
+ xt_out = ops.view_as_real(xt_)
- xq_out = torch.view_as_real(xq_ * mask).flatten(3)
- xk_out = torch.view_as_real(xk_ * mask).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
+ return xt_out.type_as(xt)
- def _create_rotary_embed_table(
- self,
- ):
+ def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
- max_seqlen = self.max_seqlen
freqs = 1.0 / (
- self.rope_freq_base
- ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim)
+ self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
- t = torch.arange(max_seqlen, device=freqs.device)
freqs = torch.outer(t, freqs).float()
- freqs_cis = (
- torch.complex(torch.cos(freqs), torch.sin(freqs))
- if self.use_hf
- else torch.polar(torch.ones_like(freqs), freqs)
- )
+ cos = torch.cos(freqs)
+ sin = torch.sin(freqs)
+ complex = torch.complex(cos, sin)
+ return complex
+
+ def _create_rotary_embed_table(self):
+ t = torch.arange(self.max_seqlen, device=self.device)
+ freqs_cis = self._compute_rotary_embed_table(t)
+ return self._replicate(freqs_cis)
+
+ def _replicate(self, t):
+ if self.tensor_parallelism_size > 1:
+ # Replicate across all devices; the data is small and the computation is cheap.
+ t = ops.replicate(t, self.tensor_parallelism_size)
- return freqs_cis
+ return t
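As a standalone illustration of what `_compute_rotary_embed_table` builds and how `forward_unsharded` applies it, here is a minimal plain-PyTorch sketch; the helper names and dimensions are illustrative assumptions, not the layer's API:

```python
import torch

def compute_rotary_table(positions: torch.Tensor, dim: int, base: float = 10000.0):
    # freqs[i] = base ** (-2i / dim) for i in [0, dim // 2)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    angles = torch.outer(positions.float(), freqs)  # [sl, dim // 2]
    return torch.complex(torch.cos(angles), torch.sin(angles))

def apply_rotary(xt: torch.Tensor, table: torch.Tensor, start_index: int = 0):
    # xt: [bs, sl, heads, dim]; pair adjacent channels into complex numbers.
    bs, sl, heads, dim = xt.shape
    xt_ = torch.view_as_complex(xt.reshape(bs, sl, heads, dim // 2, 2))
    freqs_cis = table[start_index : start_index + sl][None, :, None, :]
    return torch.view_as_real(xt_ * freqs_cis).flatten(3).type_as(xt)

xt = torch.randn(2, 5, 8, 64)
table = compute_rotary_table(torch.arange(128), dim=64)
assert apply_rotary(xt, table).shape == xt.shape
```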
diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py
new file mode 100644
index 000000000..fb330aadd
--- /dev/null
+++ b/sharktank/sharktank/layers/testing.py
@@ -0,0 +1,45 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+from ..types.theta import Theta
+from ..types.tensors import DefaultPrimitiveTensor
+from ..utils.testing import make_rand_torch
+
+
+def make_llama_attention_block_theta(
+ *,
+ head_count: int,
+ head_count_kv: int,
+ head_dim: int,
+ embedding_length: int,
+ dtype: torch.dtype | None = None,
+) -> Theta:
+ return Theta(
+ {
+ "attn_q.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch(
+ (head_count * head_dim, embedding_length), dtype=dtype
+ )
+ ),
+ "attn_k.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch(
+ (head_count_kv * head_dim, embedding_length), dtype=dtype
+ )
+ ),
+ "attn_v.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch(
+ (head_count_kv * head_dim, embedding_length), dtype=dtype
+ )
+ ),
+ "attn_output.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch((embedding_length, embedding_length), dtype=dtype)
+ ),
+ "attn_norm.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch((embedding_length), dtype=dtype)
+ ),
+ }
+ )
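A quick usage sketch for the new testing helper; the parameter values are arbitrary examples:

```python
import torch
from sharktank.layers.testing import make_llama_attention_block_theta

theta = make_llama_attention_block_theta(
    head_count=8,
    head_count_kv=2,
    head_dim=64,
    embedding_length=512,
    dtype=torch.float32,
)
# The resulting Theta holds random q/k/v/output projection weights plus the
# attention RMS norm weight, keyed by gguf-style names such as "attn_q.weight".
```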
diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py
new file mode 100644
index 000000000..077e4e064
--- /dev/null
+++ b/sharktank/sharktank/models/grok/grok.py
@@ -0,0 +1,234 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from ...layers import *
+from ...utils.create_cache import *
+from ...types import Theta
+
+torch.set_printoptions(profile="full")
+
+__all__ = [
+ "PagedGrokModelV1",
+]
+
+################################################################################
+# Models
+################################################################################
+
+
+class PagedGrokModelV1(BaseCausalLMModel):
+ """Grok model with a paged KV cache and supporting variable sequence
+ length batched inference.
+
+ As both the caching and batching setup is complicated, this model variant
+ is modular, intended to be instantiated and used in an overall assembly
+ rather than trying to provide one-stop methods that do everything.
+
+ The inference procedure is typically:
+
+ 1. Initialize the PagedKVCache state tensors.
+ 2. Generate an input mask given a vector of sequence lengths.
+ 3. Generate an attention mask from the input mask.
+ 4. Allocate a block mapping table.
+ 5. Invoke prefill() with a batch of sequences.
+ 6. Extract tokens from batched logits.
+ 7. Iteratively invoke decode() for as long as there are sequences needing
+ to be serviced.
+
+ Various samplers and schedulers can be interleaved throughout.
+ """
+
+ def __init__(self, theta: Theta, config: LlamaModelConfig):
+ hp = config.hp
+ super().__init__(
+ theta,
+ context_length=config.hp.context_length,
+ device=config.device,
+ activation_dtype=config.activation_dtype,
+ attention_dtype=config.attention_dtype,
+ )
+ self.config = config
+ self.hp = hp
+ self.cache = create_kv_cache(self.config)
+ self.activation_dtype = config.activation_dtype
+ self.add_module(
+ "token_embedding",
+ TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+ )
+ self.add_module(
+ "attention_embedding",
+ RotaryEmbeddingLayer(
+ rope_dimension_count=hp.rope_dimension_count,
+ rope_freq_base=hp.rope_freq_base,
+ max_seqlen=hp.context_length,
+ device=self.device,
+ use_hf=True,
+ ),
+ )
+ self.add_module(
+ "output_norm",
+ RMSNormLayer(
+ theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon
+ ),
+ )
+ self.add_module("output_lm_head", LinearLayer(theta("output")))
+
+ self.attn_blocks = nn.ModuleList()
+ self.moe_blocks = nn.ModuleList()
+
+ for n in range(hp.block_count):
+ self.attn_blocks.append(
+ PagedLlamaAttentionBlock(
+ theta("blk", n),
+ block_index=n,
+ cache=self.cache,
+ head_count=hp.attention_head_count,
+ head_dim=hp.attn_head_dim,
+ head_count_kv=hp.attention_head_count_kv,
+ rms_epsilon=hp.attention_layer_norm_rms_epsilon,
+ softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864
+ )
+ )
+ self.moe_blocks.append(
+ MoeBlock(
+ theta("blk", n),
+ expert_count=hp.expert_count,
+ expert_used_count=hp.expert_used_count,
+ rms_epsilon=hp.attention_layer_norm_rms_epsilon,
+ moe_activation=F.gelu,
+ )
+ )
+
+ def prefill(
+ self,
+ # [bs, batch_seq_len]
+ tokens: torch.Tensor,
+ *,
+ # [1, 1, batch_seq_len, batch_seq_len]
+ attention_mask: torch.Tensor,
+ # [bs, batch_seq_len // block_seq_stride]
+ seq_block_ids: torch.Tensor,
+ cache_state: list[torch.Tensor],
+ ):
+ self._assert_device(tokens)
+ self._assert_device(attention_mask, dtype=self.activation_dtype)
+ self._assert_device(seq_block_ids)
+ self._assert_device(*cache_state, dtype=self.activation_dtype)
+ h = self.token_embedding(tokens)
+ h *= math.sqrt(h.shape[-1])
+ self.trace_tensor("grok.token_embedding", h)
+
+ # Iterate over attention blocks.
+ for block_idx, (attn_block, moe_block) in enumerate(
+ zip(self.attn_blocks, self.moe_blocks)
+ ):
+ if block_idx == 0:
+ self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
+
+ h = attn_block(
+ h,
+ embedding=self.attention_embedding,
+ start_index=0,
+ attention_mask=attention_mask,
+ cache_state=cache_state,
+ seq_block_ids=seq_block_ids,
+ )
+ self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
+
+ h = moe_block(h)
+ self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
+
+ h = self.output_norm(h)
+ logits = self.output_lm_head(h)
+ logits = logits / math.sqrt(3.0)
+ return logits
+
+ def decode(
+ self,
+ # [bs, 1]
+ tokens: torch.Tensor,
+ *,
+ # [bs, 1, 1, batch_seq_len]
+ attention_mask: torch.Tensor,
+ # [bs] of starting positions
+ start_positions: torch.Tensor,
+ # [bs, batch_seq_len // block_seq_stride]
+ seq_block_ids: torch.Tensor,
+ cache_state: list[torch.Tensor],
+ ):
+ self._assert_device(tokens)
+ self._assert_device(attention_mask, dtype=self.activation_dtype)
+ self._assert_device(start_positions)
+ self._assert_device(*cache_state, dtype=self.activation_dtype)
+ bs, _ = tokens.shape
+ # Precompute a position based mask for computing rope embeddings
+ # as it is the same for all blocks.
+ embedding_batch_mask = self.attention_embedding.compute_batch_mask(
+ start_positions, batch_seq_len=1
+ )
+ self.trace_tensor("grok.embedding_batch_mask", embedding_batch_mask)
+
+ # Allocate per-block temporary K/V tensors. These temporaries hold
+ # one block's K/V state for the maximum context length.
+ xk_temp = torch.empty(
+ [
+ bs,
+ self.context_length,
+ self.hp.attention_head_count_kv,
+ self.hp.attn_head_dim,
+ ],
+ dtype=self.config.activation_dtype,
+ device=self.device,
+ )
+ xv_temp = torch.empty(
+ [
+ bs,
+ self.context_length,
+ self.hp.attention_head_count_kv,
+ self.hp.attn_head_dim,
+ ],
+ dtype=self.config.activation_dtype,
+ device=self.device,
+ )
+
+ h = self.token_embedding(tokens)
+ h *= math.sqrt(h.shape[-1])
+ self.trace_tensor("grok.token_embedding", h)
+
+ # Iterate over attention blocks.
+ for block_idx, (attn_block, moe_block) in enumerate(
+ zip(self.attn_blocks, self.moe_blocks)
+ ):
+ if block_idx == 0:
+ self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
+
+ h = attn_block(
+ h,
+ start_positions=start_positions,
+ embedding=self.attention_embedding,
+ embedding_batch_mask=embedding_batch_mask,
+ attention_mask=attention_mask,
+ cache_state=cache_state,
+ seq_block_ids=seq_block_ids,
+ xk_temp=xk_temp,
+ xv_temp=xv_temp,
+ )
+ self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
+
+ h = moe_block(h)
+ self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
+
+ h = self.output_norm(h)
+ logits = self.output_lm_head(h)
+ logits = logits / math.sqrt(3.0)
+ return logits
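Grok's attention logit soft-capping (the `softcap=30.0` passed to each attention block above) bounds the raw attention scores before the softmax in the decomposed attention path. A minimal sketch of the transform:

```python
import torch

def softcap_logits(attn_weights: torch.Tensor, softcap: float = 30.0) -> torch.Tensor:
    # Squashes scores into (-softcap, softcap) while staying roughly linear
    # for small magnitudes.
    return softcap * torch.tanh(attn_weights / softcap)

scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap_logits(scores))  # bounded to roughly (-30, 30)
```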
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 8266872ad..0a9a6f1c3 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -7,7 +7,7 @@
from typing import Optional
from dataclasses import dataclass
-import math
+from typing import Union
import torch
import torch.nn as nn
@@ -15,77 +15,13 @@
from ...layers import *
from ...types import *
+from ...utils.create_cache import *
from ... import ops
__all__ = [
- "LlamaModelConfig",
"PagedLlamaModelV1",
]
-################################################################################
-# Config
-################################################################################
-
-
-@dataclass
-class LlamaModelConfig:
- hp: configs.LlamaHParams
-
- # Block sequence stride for a paged KV cache. This must divide evenly
- # into the context length.
- block_seq_stride: int = 16
-
- # Either "paged" or "direct".
- kv_cache_type: str = "paged"
-
- # The device on which to place intermediate state.
- device: Optional[torch.device] = None
-
- # Dtype to use for general FP activations not otherwise configured.
- activation_dtype: torch.dtype = torch.float16
-
- # Dtype to use for attention.
- attention_dtype: torch.dtype = torch.float16
-
- # Indicates if running with HuggingFace implementation and ensures
- # numerical equivalency to HuggingFace's LLaMa if true (by modifying
- # rotary embedding).
- use_hf: bool = False
-
- # If true, then the model may pre-initialize certain tables during
- # init. This can be better for eager execution but when capturing a program,
- # it is often better to preserve the calculation explicitly and rely on
- # the compiler to transform it to an initialization time step. This can
- # be the difference of many gigabytes of static data being embedded in
- # the program and not.
- static_tables: bool = True
-
- def create_kv_cache(self) -> BaseKVCache:
- hp = self.hp
- if self.kv_cache_type == "direct":
- return DirectKVCache(
- block_seq_stride=self.block_seq_stride,
- transformer_block_count=hp.block_count,
- attn_head_count=hp.attention_head_count_kv,
- attn_head_dim=hp.attn_head_dim,
- seq_length=hp.context_length,
- device=self.device,
- dtype=self.attention_dtype,
- )
- elif self.kv_cache_type == "paged":
- return PagedKVCache(
- transformer_block_count=hp.block_count,
- attn_head_count=hp.attention_head_count_kv,
- attn_head_dim=hp.attn_head_dim,
- cache_partition_count=2, # One for each of K/V.
- block_seq_stride=self.block_seq_stride,
- device=self.device,
- dtype=self.attention_dtype,
- )
- else:
- raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}")
-
-
################################################################################
# Models
################################################################################
@@ -111,6 +47,19 @@ class PagedLlamaModelV1(BaseCausalLMModel):
to be serviced.
Various samplers and schedulers can be interleaved throughout.
+
+ In the case of tensor sharding (config.tensor_parallelism_size > 1) the model's KV
+ cache head dimension is sharded.
+ The number of KV cache heads must be divisible by the parallelism size.
+ With this sharding approach the KV cache is not replicated across devices.
+ The cache is split across the devices while the indexing logic/computation is
+ replicated.
+ All other arguments aside from the cache state are replicated.
+ After the attention we all-reduce.
+ The first fully connected layer is split along the parallel dimension,
+ which means the reduction dimension is split for the second FC layer.
+ We return the unreduced tensor. The user is free to reduce it to obtain the
+ unsharded result or chain it with other tensor-parallel operations.
"""
def __init__(self, theta: Theta, config: LlamaModelConfig):
@@ -122,12 +71,14 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
device=config.device,
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
+ fake_quant=config.fake_quant,
)
self.config = config
self.hp = hp
- self.cache = config.create_kv_cache()
+ self.cache = create_kv_cache(self.config)
self.activation_dtype = config.activation_dtype
self.use_hf = config.use_hf
+ self.attention_kernel = config.attention_kernel
self.add_module(
"token_embedding",
@@ -142,6 +93,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
device=self.device,
use_hf=self.use_hf,
static_tables=config.static_tables,
+ tensor_parallelism_size=config.tensor_parallelism_size,
),
)
self.add_module(
@@ -151,7 +103,6 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
),
)
self.add_module("output_lm_head", LinearLayer(theta("output")))
-
self.attn_blocks = nn.ModuleList(
[
AttentionFFNBlock(
@@ -162,7 +113,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
- use_hf=self.use_hf,
+ attention_kernel=self.attention_kernel,
+ fake_quant=self.fake_quant,
)
for n in range(hp.block_count)
]
@@ -171,18 +123,19 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
def prefill(
self,
# [bs, batch_seq_len]
- tokens: torch.Tensor,
+ tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [1, 1, batch_seq_len, batch_seq_len]
- attention_mask: torch.Tensor,
+ attention_mask: Union[torch.Tensor, ReplicatedTensor],
# [bs, batch_seq_len // block_seq_stride]
- seq_block_ids: torch.Tensor,
- cache_state: list[torch.Tensor],
+ seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
+ cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
):
self._assert_device(tokens)
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(seq_block_ids)
self._assert_device(*cache_state, dtype=self.activation_dtype)
+
h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)
@@ -207,20 +160,34 @@ def prefill(
def decode(
self,
# [bs, 1]
- tokens: torch.Tensor,
+ tokens: Union[torch.Tensor, ReplicatedTensor],
*,
# [bs, 1, 1, batch_seq_len]
- attention_mask: torch.Tensor,
+ attention_mask: Union[torch.Tensor, ReplicatedTensor],
# [bs] of starting positions
- start_positions: torch.Tensor,
+ start_positions: Union[torch.Tensor, ReplicatedTensor],
# [bs, batch_seq_len // block_seq_stride]
- seq_block_ids: torch.Tensor,
- cache_state: list[torch.Tensor],
+ seq_block_ids: Union[torch.Tensor, ReplicatedTensor],
+ cache_state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
):
+ assert len(tokens.shape) == 2
+ assert len(attention_mask.shape) == 4
+ assert len(start_positions.shape) == 1
+ assert len(seq_block_ids.shape) == 2
+ assert tokens.shape[0] == attention_mask.shape[0]
+ assert tokens.shape[0] == start_positions.shape[0]
+ assert tokens.shape[0] == seq_block_ids.shape[0]
+ assert tokens.shape[1] == 1
+ assert attention_mask.shape[1] == 1 and attention_mask.shape[2] == 1
+ assert (
+ attention_mask.shape[3]
+ == seq_block_ids.shape[1] * self.config.block_seq_stride
+ )
self._assert_device(tokens)
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)
+
bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
@@ -231,26 +198,48 @@ def decode(
# Allocate per-block temporary K/V tensors. These temporaries hold
# one block's K/V state for the maximum context length.
- xk_temp = torch.empty(
- [
- bs,
- self.context_length,
- self.hp.attention_head_count_kv,
- self.hp.attn_head_dim,
- ],
- dtype=self.config.activation_dtype,
- device=self.device,
- )
- xv_temp = torch.empty(
- [
+ if self.config.tensor_parallelism_size == 1:
+ xk_temp = torch.empty(
+ [
+ bs,
+ self.context_length,
+ self.hp.attention_head_count_kv,
+ self.hp.attn_head_dim,
+ ],
+ dtype=self.config.activation_dtype,
+ device=self.device,
+ )
+ xv_temp = torch.empty(
+ [
+ bs,
+ self.context_length,
+ self.hp.attention_head_count_kv,
+ self.hp.attn_head_dim,
+ ],
+ dtype=self.config.activation_dtype,
+ device=self.device,
+ )
+ else:
+ shard_size = [
bs,
self.context_length,
- self.hp.attention_head_count_kv,
+ self.hp.attention_head_count_kv // self.config.tensor_parallelism_size,
self.hp.attn_head_dim,
- ],
- dtype=self.config.activation_dtype,
- device=self.device,
- )
+ ]
+ xk_temp_shard = [
+ torch.empty(
+ shard_size, dtype=self.config.activation_dtype, device=self.device
+ )
+ for _ in range(self.config.tensor_parallelism_size)
+ ]
+ xv_temp_shard = [
+ torch.empty(
+ shard_size, dtype=self.config.activation_dtype, device=self.device
+ )
+ for _ in range(self.config.tensor_parallelism_size)
+ ]
+ xk_temp = SplitPrimitiveTensor(ts=xk_temp_shard, shard_dim=2)
+ xv_temp = SplitPrimitiveTensor(ts=xv_temp_shard, shard_dim=2)
h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)
@@ -296,7 +285,8 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
- use_hf: bool = False,
+ attention_kernel: str = "decomposed",
+ fake_quant: bool = True,
):
super().__init__(theta)
self.add_module(
@@ -309,7 +299,8 @@ def __init__(
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
- use_hf=use_hf,
+ attention_kernel=attention_kernel,
+ fake_quant=fake_quant,
),
)
self.add_module(
@@ -324,7 +315,7 @@ def __init__(
def forward(
self,
- h: torch.Tensor,
+ h: Union[torch.Tensor, ReplicatedTensor],
*,
embedding: RotaryEmbeddingLayer,
# [bs, batch_seq_len // block_seq_stride]
@@ -349,6 +340,7 @@ def forward(
xk_temp=xk_temp,
xv_temp=xv_temp,
)
+
# Feed forward network.
ffn_input = self.ffn_norm(h)
ffn_down = self.ffn(ffn_input)
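The new `decode()` assertions pin down the shape contract between the arguments. A small sketch with dummy tensors that satisfy it (batch size, block stride, and block count are arbitrary example values):

```python
import torch

bs, block_seq_stride, n_blocks = 4, 16, 3
batch_seq_len = n_blocks * block_seq_stride

tokens = torch.zeros(bs, 1, dtype=torch.int64)                # [bs, 1]
start_positions = torch.zeros(bs, dtype=torch.int64)          # [bs]
seq_block_ids = torch.zeros(bs, n_blocks, dtype=torch.int64)  # [bs, batch_seq_len // stride]
attention_mask = torch.zeros(bs, 1, 1, batch_seq_len)         # [bs, 1, 1, batch_seq_len]

assert attention_mask.shape[3] == seq_block_ids.shape[1] * block_seq_stride
```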
diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py
index 74ed9e8e0..9f77daa40 100644
--- a/sharktank/sharktank/models/llama/llama_ref.py
+++ b/sharktank/sharktank/models/llama/llama_ref.py
@@ -7,7 +7,6 @@
from typing import Optional
from dataclasses import dataclass
-import math
import torch
import torch.nn as nn
@@ -230,7 +229,9 @@ def forward(
values = values.transpose(1, 2)
# Flash attention.
- attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+ # torch.sqrt requires a tensor argument; use a float power now that math is no longer imported.
+ attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / (self.head_dim**0.5)
# Apply attention mask.
if attention_mask is not None:
diff --git a/sharktank/sharktank/models/llama/sharding.py b/sharktank/sharktank/models/llama/sharding.py
new file mode 100644
index 000000000..3715a3923
--- /dev/null
+++ b/sharktank/sharktank/models/llama/sharding.py
@@ -0,0 +1,114 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Specifications describing how the Llama model is sharded."""
+
+from ...types.sharding import *
+from ...types import Theta
+from ... import ops
+
+
+class PagedLlamaAttentionBlockSharding(ThetaLayerSharding):
+ def __init__(self, shard_count: int):
+ super().__init__()
+ self.shard_count = shard_count
+
+ def theta_sharding(self) -> ThetaSharding:
+ return ThetaSharding(
+ {
+ # The size of this is the token embedding length, which is not a memory
+ # space concern if replicated even for all attention blocks.
+ "attn_norm": RmsNormReplicatedSharding(
+ self.shard_count
+ ).theta_sharding(),
+ "attn_q": LinearSplitParallelWeightAndBiasSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ "attn_k": LinearSplitParallelWeightAndBiasSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ "attn_v": LinearSplitParallelWeightAndBiasSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ "attn_output": LinearSplitReductionDimSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ }
+ )
+
+
+class AttentionFFNBlockSharding(ThetaLayerSharding):
+ def __init__(self, shard_count: int):
+ super().__init__()
+ self.shard_count = shard_count
+
+ def theta_sharding(self) -> ThetaSharding:
+ result = PagedLlamaAttentionBlockSharding(self.shard_count).theta_sharding()
+ result.update(FFNSharding(self.shard_count).theta_sharding())
+ result.update(
+ {
+ # The size of this is the token embedding length, which is not a memory
+ # space concern if replicated.
+ "ffn_norm": RmsNormReplicatedSharding(self.shard_count).theta_sharding()
+ }
+ )
+ return result
+
+
+class LlamaSharding(ThetaLayerSharding):
+ """Shards the input channel and output channels of the convolutions."""
+
+ def __init__(self, shard_count: int, attention_block_count: int):
+ super().__init__()
+ self.shard_count = shard_count
+ self.attention_block_count = attention_block_count
+
+ def theta_sharding(self) -> ThetaSharding:
+ result = ThetaSharding(
+ {
+ # Replicate the vocabulary. For llama 1-3 this will require 0.5 GiB.
+ # For devices with large memory this may be an acceptable tradeoff where
+ # we save on communication by not all-gathering the result afterwards.
+ # The computation is just indexing and replication is not a concern.
+ # Alternatively, we could split the index dimension, but this would require
+ # custom logic for index partitioning and gathering.
+ "token_embd": TokenEmbeddingLayerReplicatedSharding(
+ self.shard_count
+ ).theta_sharding(),
+ "rope_freqs": Ignore(),
+ "output_norm": RmsNormReplicatedSharding(
+ self.shard_count
+ ).theta_sharding(),
+ "output": LinearSplitReductionDimSharding(
+ self.shard_count
+ ).theta_sharding(),
+ }
+ )
+ result.update(
+ {
+ "blk": ThetaSharding(
+ {
+ f"{i}": AttentionFFNBlockSharding(
+ self.shard_count
+ ).theta_sharding()
+ for i in range(self.attention_block_count)
+ }
+ )
+ }
+ )
+ return result
+
+
+def shard_theta(
+ theta: Theta, config: "sharktank.models.llama.llama.LlamaModelConfig"
+) -> Theta:
+ return ops.reshard(
+ theta,
+ LlamaSharding(
+ shard_count=config.tensor_parallelism_size,
+ attention_block_count=config.hp.block_count,
+ ),
+ )
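A usage sketch for the new sharding spec. It assumes `theta` is an unsharded Llama `Theta` loaded elsewhere; the shard count and block count are placeholder values:

```python
from sharktank import ops
from sharktank.models.llama.sharding import LlamaSharding

# 2-way tensor parallelism over a hypothetical 32-block model; equivalent to
# calling shard_theta(theta, config) with matching config values.
sharded_theta = ops.reshard(
    theta,
    LlamaSharding(shard_count=2, attention_block_count=32),
)
```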
diff --git a/sharktank/sharktank/models/llama/testing.py b/sharktank/sharktank/models/llama/testing.py
index b63fd5d07..079602b28 100644
--- a/sharktank/sharktank/models/llama/testing.py
+++ b/sharktank/sharktank/models/llama/testing.py
@@ -10,12 +10,11 @@
from ...types.tensors import *
from ...types.theta import Theta
-
-
-# Range of torch.rand() is [0,1)
-# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
-def make_rand_torch(shape, dtype=torch.float32):
- return torch.rand(shape, dtype=dtype) * 2 - 1
+from typing import Optional
+from .llama import LlamaModelConfig
+import torch
+from ...utils.testing import make_rand_torch
+from ...layers.testing import make_llama_attention_block_theta
def make_attention_block_theta(
@@ -56,11 +55,54 @@ def make_attention_block_theta(
)
+def make_attention_block_ffn_theta_v2(
+ *,
+ head_count: int,
+ head_count_kv: int,
+ head_dim: int,
+ embedding_length: int,
+ feed_forward_length: int,
+ dtype: torch.dtype | None = None,
+) -> Theta:
+ attention_theta = make_llama_attention_block_theta(
+ head_count=head_count,
+ head_count_kv=head_count_kv,
+ head_dim=head_dim,
+ embedding_length=embedding_length,
+ dtype=dtype,
+ )
+ ffn_theta = Theta(
+ {
+ "ffn_norm.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch((head_count * head_dim), dtype=dtype)
+ ),
+ "ffn_gate.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch(
+ (feed_forward_length, embedding_length), dtype=dtype
+ )
+ ),
+ "ffn_up.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch(
+ (feed_forward_length, embedding_length), dtype=dtype
+ )
+ ),
+ "ffn_down.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch(
+ (embedding_length, feed_forward_length), dtype=dtype
+ )
+ ),
+ }
+ )
+ res_dict = attention_theta.tree
+ res_dict.update(ffn_theta.tree)
+ return Theta(res_dict)
+
+
def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta:
return Theta(
{
"blk.0.ffn_gate_inp.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((feature_dim, ffn_dim))
+ data=make_rand_torch((num_experts, ffn_dim))
),
"blk.0.ffn_norm.weight": DefaultPrimitiveTensor(
data=make_rand_torch((ffn_dim))
@@ -69,13 +111,41 @@ def make_moe_block_theta(feature_dim=1024, ffn_dim=6144, num_experts=8) -> Theta
data=make_rand_torch((ffn_dim))
),
"blk.0.ffn_gate_exps.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((8, feature_dim * num_experts, ffn_dim))
+ data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim))
),
"blk.0.ffn_up_exps.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((8, feature_dim * num_experts, ffn_dim))
+ data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim))
),
"blk.0.ffn_down_exps.weight": DefaultPrimitiveTensor(
- data=make_rand_torch((8, ffn_dim, feature_dim * num_experts))
+ data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts))
),
}
)
+
+
+def make_random_llama_theta(
+ config: LlamaModelConfig, vocab_size: int, dtype: Optional[torch.dtype] = None
+) -> Theta:
+ res = {
+ "token_embd.weight": DefaultPrimitiveTensor(
+ data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype)
+ )
+ }
+ for i in range(config.hp.block_count):
+ res[f"blk.{i}"] = make_attention_block_ffn_theta_v2(
+ head_count=config.hp.attention_head_count,
+ head_count_kv=config.hp.attention_head_count_kv,
+ head_dim=config.hp.attn_head_dim,
+ embedding_length=config.hp.embedding_length,
+ feed_forward_length=config.hp.feed_forward_length,
+ dtype=dtype,
+ ).tree
+
+ res[f"output.weight"] = DefaultPrimitiveTensor(
+ data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype)
+ )
+ res[f"output_norm.weight"] = DefaultPrimitiveTensor(
+ data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype)
+ )
+
+ return Theta(res)
diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
new file mode 100644
index 000000000..052593748
--- /dev/null
+++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -0,0 +1,371 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Imports quark pre-processed weights and quantization config into a
+Dataset of the gguf format.
+
+Usage:
+ python -m sharktank.models.llama.tools.import_quark_dataset \
+ --params=llama2-7b-fp8.safetensors --output-irpa-file=new.irpa \
+ --config-json=../llama2/config.json
+
+"""
+from typing import Optional
+
+from safetensors.torch import save_file
+import json
+from pathlib import Path
+import safetensors
+import sys
+import torch
+
+from sharktank.types import *
+from sharktank.layers.configs.llm_configs import (
+ _int_prop,
+ _float_prop,
+ _optional_int_prop,
+ _int_prop,
+)
+
+
+def _load_json(p: Path):
+ print(f"Loading {p}")
+ with open(p, "rb") as f:
+ return json.load(f)
+
+
+def _get_dataset_props(config_json_struct) -> dict:
+ # Separate meta parameters (prefixed with _) from hparams.
+ meta_params = {k: v for k, v in config_json_struct.items() if k.startswith("_")}
+ hparams = {k: v for k, v in config_json_struct.items() if not k.startswith("_")}
+ return {
+ "meta": meta_params,
+ "hparams": hparams,
+ }
+
+
+def _load_theta(st_source) -> Theta:
+ tensors = [
+ DefaultPrimitiveTensor(name=name, data=st_source.get_tensor(name))
+ for name in st_source.keys()
+ ]
+ return Theta(tensors)
+
+
+def as_torch_or_none(tensor: Optional[InferenceTensor]) -> Optional[torch.Tensor]:
+ if tensor is None:
+ return None
+ return tensor.as_torch()
+
+
+def hf_to_gguf(layer_name: str) -> str:
+ assert layer_name.startswith("model.layers")
+ mapping = {
+ "input_layernorm": "attn_norm",
+ "self_attn.q_proj": "attn_q",
+ "self_attn.k_proj": "attn_k",
+ "self_attn.v_proj": "attn_v",
+ "self_attn.o_proj": "attn_output",
+ "post_attention_layernorm": "ffn_norm",
+ "mlp.gate_proj": "ffn_gate",
+ "mlp.up_proj": "ffn_up",
+ "mlp.down_proj": "ffn_down",
+ }
+
+ # Split the input string
+ parts = layer_name.split(".")
+
+ # Extract the numerical value and the key to be mapped
+ numerical_value = parts[2] # The part after "models.layers" and its number
+ key_to_map = ".".join(parts[3:])
+
+ # Map the key
+ if key_to_map in mapping:
+ mapped_value = mapping[key_to_map]
+ else:
+ raise ValueError(f"Mapping for '{key_to_map}' not found.")
+
+ # Construct the output string
+ output_str = f"blk.{numerical_value}.{mapped_value}"
+ return output_str
+
+
+def apply_per_layer_quant(
+ root_theta: Theta,
+ layer_name: str,
+ updated_tensors: dict[str, InferenceTensor],
+ n_head: int,
+ split_sizes: list[int],
+):
+ """Take the quantization parameters and hf weights from the imported Theta
+ and create InferenceTensors out of them, converting their names to gguf format
+ in the process.
+ """
+
+ layer_theta = root_theta(layer_name)
+
+ weight_quant_scale = layer_theta.tensor("weight_scale").as_torch()
+
+ weight = layer_theta.tensor("weight").as_torch()
+
+ # It looks dumb, but this step is required for numerical correctness against quark.
+ # weight = weight.view(torch.float8_e4m3fn)
+ weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16)
+
+ weight_quant_zero_point = layer_theta.optional_tensor("weight_zero_point")
+ if weight_quant_zero_point is None:
+ weight_quant_zero_point = torch.zeros(1, dtype=torch.float32)
+ else:
+ weight_quant_zero_point = weight_quant_zero_point.as_torch()
+ input_quant_scale = as_torch_or_none(layer_theta.optional_tensor("input_scale"))
+ output_quant_scale = as_torch_or_none(layer_theta.optional_tensor("output_scale"))
+
+ if weight_quant_scale is None:
+ print("weight quant scale not found for layer ", layer_name)
+ return
+
+ layer_parent = ".".join(layer_name.split(".")[:-1])
+
+ def quantize_weight(
+ weight_name: str,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ weight_zp: Optional[torch.Tensor],
+ ):
+ # Our scale is the reciprocal of the quark scale
+ # We multiply scale by two to account for diff between fnuz and fn
+ weight_quantizer = StaticScaledQuantizer(
+ scale=1.0 / (weight_scale * 2.0),
+ reciprocal_scale=(weight_scale * 2.0),
+ offset=None
+ if (weight_zp is None or torch.count_nonzero(weight_zp) == 0)
+ else weight_zp,
+ dtype=torch.float8_e4m3fnuz,
+ )
+ weight_quant = weight_quantizer.quantize(weight, name=weight_name)
+ updated_tensors[weight_quant.name] = weight_quant
+
+ if "qkv" in layer_name:
+ # The qkv layer is fused in the quark model; decompose it back into
+ # individual q, k, and v weights.
+ q_weight, k_weight, v_weight = torch.split(weight, split_sizes)
+ q_weight = (
+ q_weight.reshape(
+ n_head, 2, q_weight.shape[0] // n_head // 2, *q_weight.shape[1:]
+ )
+ .swapaxes(1, 2)
+ .reshape(q_weight.shape)
+ )
+ k_weight = (
+ k_weight.reshape(
+ n_head, 2, k_weight.shape[0] // n_head // 2, *k_weight.shape[1:]
+ )
+ .swapaxes(1, 2)
+ .reshape(k_weight.shape)
+ )
+ q_name = hf_to_gguf(layer_parent + ".q_proj")
+ k_name = hf_to_gguf(layer_parent + ".k_proj")
+ v_name = hf_to_gguf(layer_parent + ".v_proj")
+ quantize_weight(
+ q_name + ".weight", q_weight, weight_quant_scale, weight_quant_zero_point
+ )
+ quantize_weight(
+ k_name + ".weight", k_weight, weight_quant_scale, weight_quant_zero_point
+ )
+ quantize_weight(
+ v_name + ".weight", v_weight, weight_quant_scale, weight_quant_zero_point
+ )
+ # The output and input quantizers are duplicated for each of the q, k, and v weights
+ names = [f"{i}.qdq_output" for i in [q_name, k_name, v_name]]
+ for name in names:
+ updated_tensors[name] = StaticScaledQuantizer(
+ name=name,
+ scale=1.0 / (output_quant_scale * 2.0),
+ reciprocal_scale=output_quant_scale * 2.0,
+ dtype=torch.float8_e4m3fnuz,
+ )
+ names = [f"{i}.q_input" for i in [q_name, k_name, v_name]]
+ for name in names:
+ updated_tensors[name] = StaticScaledQuantizer(
+ name=name,
+ scale=1.0 / (input_quant_scale * 2.0),
+ reciprocal_scale=input_quant_scale * 2.0,
+ dtype=torch.float8_e4m3fnuz,
+ )
+ # Remove the updated tensors from the original tree.
+ root_theta.pop(layer_parent + ".q_proj")
+ root_theta.pop(layer_parent + ".k_proj")
+ root_theta.pop(layer_parent + ".v_proj")
+ root_theta.pop(layer_name)
+
+ else:
+ new_layer_name = hf_to_gguf(layer_name)
+ quantize_weight(
+ new_layer_name + ".weight",
+ weight,
+ weight_quant_scale,
+ weight_quant_zero_point,
+ )
+ # We explicitly provide the reciprocal scale because converting from float16 to
+ # float8 after computing 1/scale results in significant numerical differences.
+ if input_quant_scale is not None:
+ updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer(
+ name=new_layer_name + ".q_input",
+ scale=1.0 / (input_quant_scale * 2.0),
+ reciprocal_scale=input_quant_scale * 2.0,
+ dtype=torch.float8_e4m3fnuz,
+ )
+ if output_quant_scale is not None:
+ updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer(
+ name=new_layer_name + ".qdq_output",
+ scale=1.0 / output_quant_scale,
+ reciprocal_scale=output_quant_scale,
+ dtype=torch.float8_e4m3fnuz,
+ )
+
+ # Remove the updated tensor from the original tree.
+ root_theta.pop(layer_name)
+
+
+def convert_hf_hparams_to_gguf(hf_hparams: dict[str, any]) -> dict[str, any]:
+ hp = hf_hparams["hparams"]
+ attention_head_count = _int_prop(hp, "num_attention_heads")
+ attn_head_dim = int(
+ _int_prop(hp, "hidden_size") // _int_prop(hp, "num_attention_heads")
+ )
+
+ return {
+ "llama.context_length": _int_prop(hp, "max_position_embeddings"),
+ "llama.embedding_length": _int_prop(hp, "hidden_size"),
+ "llama.block_count": _int_prop(hp, "num_hidden_layers"),
+ "llama.feed_forward_length": _int_prop(hp, "intermediate_size"),
+ "llama.rope.dimension_count": attn_head_dim,
+ "llama.attention.head_count": attention_head_count,
+ "llama.attention.layer_norm_rms_epsilon": _float_prop(hp, "rms_norm_eps"),
+ "llama.attention.head_count_kv": _optional_int_prop(
+ hp, "num_key_value_heads", attention_head_count
+ ),
+ }
+
+
+def update_norm_layer(
+ quant_theta: Theta, layer_name: str, updated_tensors: dict[str, InferenceTensor]
+):
+ """Convert layernames for non quantized tensors and add them to the updated_tensors dict"""
+ for sub in ["input_layernorm", "post_attention_layernorm"]:
+ sub_name = layer_name + "." + sub
+ new_name = hf_to_gguf(sub_name) + ".weight"
+ single_replace(quant_theta, sub_name, new_name, updated_tensors)
+ kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch()
+ layer_idx = layer_name.split(".")[-1]
+ new_name = f"blk.{layer_idx}.kv_cache"
+ updated_tensors[new_name] = StaticScaledQuantizer(
+ name=new_name + ".quantizer",
+ scale=1.0 / (kv_cache_scale * 2.0),
+ reciprocal_scale=kv_cache_scale * 2.0,
+ dtype=torch.float8_e4m3fnuz,
+ )
+
+
+def single_replace(
+ quant_theta: Theta,
+ layer_name: str,
+ gguf_name: str,
+ updated_tensors: dict[str, InferenceTensor],
+):
+ data = quant_theta(layer_name).tensor("weight").as_torch()
+ if data.dtype == torch.bfloat16:
+ data = data.to(torch.float32)
+ updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data)
+
+
+def main(argv):
+ from sharktank.utils import cli
+
+ parser = cli.create_parser()
+ cli.add_output_dataset_options(parser)
+ parser.add_argument(
+ "--config-json", type=Path, required=True, help="Path to the config.json file"
+ )
+ parser.add_argument(
+ "--params",
+ type=Path,
+ default=Path("params.safetensors"),
+ help="Parameter file name, relative to config.json",
+ )
+ parser.add_argument(
+ "--model-base",
+ type=str,
+ default="7b",
+ help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b, 70b is also supported.",
+ choices=["7b", "70b"],
+ )
+ args = cli.parse(parser, args=argv)
+
+ config_json_path: Path = args.config_json
+ params_path: Path = args.params
+ # TODO: find a way to get this programmatically so we don't have to flag for it
+ split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024]
+ num_layers = 32 if args.model_base == "7b" else 80
+
+ # Construct the pre-transform dataset.
+ dataset_props = _get_dataset_props(_load_json(config_json_path))
+ with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
+ quant_theta = _load_theta(st)
+ ds = Dataset(dataset_props, quant_theta)
+
+ # Convert hyperparams to gguf format
+ updated_properties = convert_hf_hparams_to_gguf(ds.properties)
+
+ head_count = updated_properties["llama.attention.head_count"]
+
+ updated_tensors: dict[str, InferenceTensor] = {}
+ model_layers = [f"model.layers.{i}" for i in range(num_layers)]
+
+ sub_layers = [
+ "mlp.gate_proj",
+ "mlp.down_proj",
+ "mlp.up_proj",
+ "self_attn.o_proj",
+ "self_attn.q_proj",
+ "self_attn.k_proj",
+ "self_attn.v_proj",
+ ]
+ for layer in model_layers:
+ for sub in sub_layers:
+ layer_name = layer + "." + sub
+ apply_per_layer_quant(
+ quant_theta,
+ layer_name,
+ updated_tensors,
+ n_head=head_count,
+ split_sizes=split_sizes,
+ )
+
+ # Update the non quantized weights (norm layers)
+ for layer_idx in model_layers:
+ update_norm_layer(
+ quant_theta,
+ layer_idx,
+ updated_tensors,
+ )
+
+ # The stragglers
+ stragglers = [
+ ("model.embed_tokens", "token_embd.weight"),
+ ("model.norm", "output_norm.weight"),
+ ("lm_head", "output.weight"),
+ ]
+ for layer, new_name in stragglers:
+ single_replace(quant_theta, layer, new_name, updated_tensors)
+
+ new_theta = Theta(updated_tensors)
+ # Make a new Dataset from the updated properties and tensors.
+ new_ds = Dataset(updated_properties, new_theta)
+
+ new_ds.save(args.output_irpa_file, io_report_callback=print)
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
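The scale convention used throughout the importer (our scale is the reciprocal of the quark scale, doubled to bridge the fn and fnuz float8 formats) can be illustrated with a plain-PyTorch round trip. This is only a sketch of the arithmetic, not the `StaticScaledQuantizer` implementation:

```python
import torch

weight = torch.randn(4, 4, dtype=torch.float16)
quark_scale = torch.tensor(0.05)         # per-tensor scale exported by quark

scale = 1.0 / (quark_scale * 2.0)        # our quantizer scale
reciprocal_scale = quark_scale * 2.0     # stored explicitly to avoid fp16 1/x error

quantized = (weight.to(torch.float32) * scale).to(torch.float8_e4m3fnuz)
dequantized = quantized.to(torch.float32) * reciprocal_scale
print((weight.to(torch.float32) - dequantized).abs().max())
```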
diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py
index 5a179e5b9..e2995dfde 100644
--- a/sharktank/sharktank/models/mixtral/mixtral.py
+++ b/sharktank/sharktank/models/mixtral/mixtral.py
@@ -13,65 +13,15 @@
from ...layers import *
+from ...utils.create_cache import *
from ...types import Theta
torch.set_printoptions(profile="full")
__all__ = [
- "LlamaModelConfig",
"PagedMixtralModelV1",
]
-################################################################################
-# Config
-################################################################################
-
-
-@dataclass
-class LlamaModelConfig:
- hp: configs.LlamaHParams
-
- # Block sequence stride for a paged KV cache. This must divide evenly
- # into the context length.
- block_seq_stride: int = 16
-
- # Either "paged" or "direct".
- kv_cache_type: str = "paged"
-
- # The device on which to place intermediate state.
- device: Optional[torch.device] = None
-
- # Dtype to use for general FP activations not otherwise configured.
- activation_dtype: torch.dtype = torch.float16
-
- # Dtype to use for attention.
- attention_dtype: torch.dtype = torch.float16
-
- def create_kv_cache(self) -> BaseKVCache:
- hp = self.hp
- if self.kv_cache_type == "direct":
- return DirectKVCache(
- block_seq_stride=self.block_seq_stride,
- transformer_block_count=hp.block_count,
- attn_head_count=hp.attention_head_count_kv,
- attn_head_dim=hp.attn_head_dim,
- seq_length=hp.context_length,
- device=self.device,
- dtype=self.attention_dtype,
- )
- elif self.kv_cache_type == "paged":
- return PagedKVCache(
- transformer_block_count=hp.block_count,
- attn_head_count=hp.attention_head_count_kv,
- attn_head_dim=hp.attn_head_dim,
- cache_partition_count=2, # One for each of K/V.
- block_seq_stride=self.block_seq_stride,
- device=self.device,
- dtype=self.attention_dtype,
- )
- else:
- raise NotImplementedError(f"kv_cache_type = {self.kv_cache_type}")
-
################################################################################
# Models
@@ -111,7 +61,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
)
self.config = config
self.hp = hp
- self.cache = config.create_kv_cache()
+ self.cache = create_kv_cache(self.config)
self.activation_dtype = config.activation_dtype
self.add_module(
"token_embedding",
@@ -135,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.add_module("output_lm_head", LinearLayer(theta("output")))
self.attn_blocks = nn.ModuleList()
+ self.moe_blocks = nn.ModuleList()
for n in range(hp.block_count):
self.attn_blocks.append(
@@ -148,8 +99,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
)
)
- self.attn_blocks.append(
- SparseMoeBlock(
+ self.moe_blocks.append(
+ MoeBlock(
theta("blk", n),
expert_count=hp.expert_count,
expert_used_count=hp.expert_used_count,
@@ -176,25 +127,26 @@ def prefill(
self.trace_tensor("mixtral.token_embedding", h)
# Iterate over attention blocks.
- for block_idx, block in enumerate(self.attn_blocks):
+ for block_idx, (attn_block, moe_block) in enumerate(
+ zip(self.attn_blocks, self.moe_blocks)
+ ):
if block_idx == 0:
self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
- if block.__class__.__name__ == "PagedLlamaAttentionBlock":
- h = block(
- h,
- embedding=self.attention_embedding,
- start_index=0,
- attention_mask=attention_mask,
- cache_state=cache_state,
- seq_block_ids=seq_block_ids,
- )
- self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
- elif block.__class__.__name__ == "SparseMoeBlock":
- h = block(
- h,
- )
- self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+ h = attn_block(
+ h,
+ embedding=self.attention_embedding,
+ start_index=0,
+ attention_mask=attention_mask,
+ cache_state=cache_state,
+ seq_block_ids=seq_block_ids,
+ )
+ self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+ h = moe_block(
+ h,
+ )
+ self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
h = self.output_norm(h)
logits = self.output_lm_head(h)
@@ -252,28 +204,29 @@ def decode(
self.trace_tensor("mixtral.token_embedding", h)
# Iterate over attention blocks.
- for block_idx, block in enumerate(self.attn_blocks):
+ for block_idx, (attn_block, moe_block) in enumerate(
+ zip(self.attn_blocks, self.moe_blocks)
+ ):
if block_idx == 0:
self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
- if block.__class__.__name__ == "PagedLlamaAttentionBlock":
- h = block(
- h,
- start_positions=start_positions,
- embedding=self.attention_embedding,
- embedding_batch_mask=embedding_batch_mask,
- attention_mask=attention_mask,
- cache_state=cache_state,
- seq_block_ids=seq_block_ids,
- xk_temp=xk_temp,
- xv_temp=xv_temp,
- )
- self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
- elif block.__class__.__name__ == "SparseMoeBlock":
- h = block(
- h,
- )
- self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+ h = attn_block(
+ h,
+ start_positions=start_positions,
+ embedding=self.attention_embedding,
+ embedding_batch_mask=embedding_batch_mask,
+ attention_mask=attention_mask,
+ cache_state=cache_state,
+ seq_block_ids=seq_block_ids,
+ xk_temp=xk_temp,
+ xv_temp=xv_temp,
+ )
+ self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+ h = moe_block(
+ h,
+ )
+ self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
h = self.output_norm(h)
logits = self.output_lm_head(h)
diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py
index 392f60a25..70a9b9cf8 100644
--- a/sharktank/sharktank/models/mixtral/mixtral_ref.py
+++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py
@@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
self.add_module("output_lm_head", LinearLayer(theta("output")))
self.attn_blocks = nn.ModuleList()
+ self.moe_blocks = nn.ModuleList()
for n in range(hp.block_count):
self.attn_blocks.append(
@@ -78,8 +79,8 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
)
)
- self.attn_blocks.append(
- SparseMoeBlock(
+ self.moe_blocks.append(
+ MoeBlock(
theta("blk", n),
expert_count=hp.expert_count,
expert_used_count=hp.expert_used_count,
@@ -130,28 +131,29 @@ def forward(
block_count = len(self.attn_blocks) // 2
# print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count)
# Iterate over attention + MoE blocks.
- for block_idx, block in enumerate(self.attn_blocks):
+ for block_idx, (attn_block, moe_block) in enumerate(
+ zip(self.attn_blocks, self.moe_blocks)
+ ):
# print("block_idx, block", block_idx, block)
if block_idx == 0:
self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
- if block.__class__.__name__ == "LlamaAttentionBlock":
- attn_block_idx = block_idx // 2
- block_cache_k = local_kv_cache[attn_block_idx]
- block_cache_v = local_kv_cache[block_count + attn_block_idx]
- h = block(
- h,
- cache_k=block_cache_k,
- cache_v=block_cache_v,
- start_index=start_index,
- attention_mask=attention_mask,
- )
- self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
- elif block.__class__.__name__ == "SparseMoeBlock":
- h = block(
- h,
- )
- self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+ attn_block_idx = block_idx // 2
+ block_cache_k = local_kv_cache[attn_block_idx]
+ block_cache_v = local_kv_cache[block_count + attn_block_idx]
+ h = attn_block(
+ h,
+ cache_k=block_cache_k,
+ cache_v=block_cache_v,
+ start_index=start_index,
+ attention_mask=attention_mask,
+ )
+ self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+ h,
+ )
+ self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
h = self.output_norm(h)
logits = self.output_lm_head(h)
diff --git a/sharktank/sharktank/models/punet/sharding.py b/sharktank/sharktank/models/punet/sharding.py
index 22237b8bb..a827b7942 100644
--- a/sharktank/sharktank/models/punet/sharding.py
+++ b/sharktank/sharktank/models/punet/sharding.py
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""Specifications describing how block/layers of punet are sharded."""
+"""Specifications describing how blocks/layers of punet are sharded."""
from ...types.sharding import *
@@ -31,7 +31,7 @@ def theta_sharding(self) -> ThetaSharding:
"conv2": Conv2DSplitOutputChannelSharding(
shard_count=self.shard_count
).theta_sharding(),
- "time_emb_proj": LinearReplicatedInputSplitWeightAndBiasSharding(
+ "time_emb_proj": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"conv_shortcut": Conv2DSplitOutputChannelSharding(
diff --git a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
index fa862105c..22ca1f591 100644
--- a/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
+++ b/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
@@ -133,8 +133,19 @@ def _get_json_tensor(
# for signed arithmetic.
input_zp = _get_json_tensor("input_zp", dtype=None)
if input_zp is not None:
- assert torch.count_nonzero(input_zp) == 0
-
+ assert torch.count_nonzero(input_zp.float()) == 0
+
+    # Currently there is no standardization in `quant_params.json` for the fields
+    # present in each layer across different quantization schemes (int8, fp8).
+    # int8 was the first scheme tested end-to-end, so the defaults below assume it.
+ quantization_type = (
+ qp.get("input_zp_dtype")
+ if qp.get("input_zp_dtype") is not None
+ else "torch.int8"
+ )
+ quantization_dtype = tensors.serialized_name_to_dtype(
+ quantization_type.split(".")[-1]
+ )
if output_scale is not None:
output_quantizer = StaticScaledQuantizer(
name=f"{layer_name}.q_output",
@@ -175,10 +186,12 @@ def quantize_weight(
weight_quantizer = StaticScaledQuantizer(
scale=1.0 / weight_scale,
reciprocal_scale=weight_scale,
- offset=None
- if (weight_zp is None or torch.count_nonzero(weight_zp) == 0)
- else weight_zp,
- dtype=torch.int8,
+ offset=(
+ None
+ if (weight_zp is None or torch.count_nonzero(weight_zp) == 0)
+ else weight_zp
+ ),
+ dtype=quantization_dtype,
)
weight_quant = weight_quantizer.quantize(weight, name=weight_name)
updated_tensors[weight_quant.name] = weight_quant
@@ -195,7 +208,7 @@ def quantize_bias(
bias_scale = 1.0 / (input_scale * weight_scale)
bias_quantizer = StaticScaledQuantizer(
scale=bias_scale,
- dtype=torch.int32,
+ dtype=torch.int32 if quantization_dtype == torch.int8 else torch.float16,
disable_saturate=True,
)
bias_quant = bias_quantizer.quantize(bias, name=bias_name)
@@ -227,7 +240,7 @@ def quantize_bias(
f"{layer_name}.to_v.weight", weight_v, weight_scale_v, weight_zp_v
)
updated_tensors[weight.name] = None
- if bias is not None:
+ if bias is not None and quantization_dtype == torch.int8:
bias_k, bias_v = bias.as_torch().chunk(2, dim=0)
quantize_bias(
f"{layer_name}.to_k.bias", bias_k, input_scale, weight_scale_k
@@ -259,7 +272,7 @@ def quantize_bias(
f"{layer_name}.to_v.weight", weight_v, weight_scale_v, weight_zp_v
)
updated_tensors[weight.name] = None
- if bias is not None:
+ if bias is not None and quantization_dtype == torch.int8:
bias_q, bias_k, bias_v = bias.as_torch().chunk(3, dim=0)
quantize_bias(
f"{layer_name}.to_q.bias", bias_q, input_scale, weight_scale_q
@@ -284,7 +297,7 @@ def quantize_bias(
name=f"{layer_name}.q_input",
scale=1.0 / input_scale,
reciprocal_scale=input_scale,
- dtype=torch.int8,
+ dtype=quantization_dtype,
)
updated_tensors[input_quantizer.name] = input_quantizer
diff --git a/sharktank/sharktank/models/punet/tools/run_punet.py b/sharktank/sharktank/models/punet/tools/run_punet.py
index ace279a3b..b2ad58d9d 100644
--- a/sharktank/sharktank/models/punet/tools/run_punet.py
+++ b/sharktank/sharktank/models/punet/tools/run_punet.py
@@ -9,7 +9,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from ..model import Unet2DConditionModel, ClassifierFreeGuidanceUnetModel
from ....utils.patching import SaveModuleResultTensorsPatch
diff --git a/sharktank/sharktank/models/t5/__init__.py b/sharktank/sharktank/models/t5/__init__.py
new file mode 100644
index 000000000..7c7e76704
--- /dev/null
+++ b/sharktank/sharktank/models/t5/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .t5 import *
+from .export import *
diff --git a/sharktank/sharktank/models/t5/export.py b/sharktank/sharktank/models/t5/export.py
new file mode 100644
index 000000000..7bd5eef3d
--- /dev/null
+++ b/sharktank/sharktank/models/t5/export.py
@@ -0,0 +1,97 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Union
+from pathlib import Path
+import torch
+
+from .t5 import T5Config, T5Encoder
+from ...types import Dataset
+from iree.turbine.aot import FxProgramsBuilder, export
+
+__all__ = [
+ "export_encoder_mlir",
+ "export_encoder_iree_parameters",
+ "prune_decoder_parameters",
+]
+
+
+def export_encoder_mlir(
+ model: Union[T5Encoder, Path, str],
+ batch_sizes: list[int],
+ mlir_output_path: str,
+):
+ """
+ Args:
+ model: either the torch module or path to GGUF/IRPA.
+ """
+ if isinstance(model, (Path, str)):
+ dataset = Dataset.load(model)
+ config = T5Config.from_gguf_properties(
+ dataset.properties,
+ # TODO: add this property to our HuggingFace-to-GGUF conversion script.
+            # We currently use llama.cpp's converter, and it cannot distinguish
+            # between T5 V1 and V1.1.
+ # V1 uses ReLU and V1.1 uses gated GeLU.
+ feed_forward_proj="gated-gelu",
+ )
+ model = T5Encoder(theta=dataset.root_theta, config=config)
+
+ fxb = FxProgramsBuilder(model)
+
+ for batch_size in batch_sizes:
+ sample_inputs = model.sample_inputs(batch_size)
+
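+        # The context (sequence) length is exported as a dynamic dimension counted
+        # in padding blocks, so the sample input's length must be a multiple of
+        # the padding block size.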
+ context_length_dim_idx = 1
+ assert (
+ sample_inputs["input_ids"].shape[context_length_dim_idx]
+ % config.context_length_padding_block_size
+ == 0
+ )
+ context_length_block_dim_max = (
+ sample_inputs["input_ids"].shape[context_length_dim_idx]
+ // config.context_length_padding_block_size
+ )
+ context_length_block_dim = torch.export.Dim(
+ "block", max=context_length_block_dim_max
+ )
+ context_length_dim = (
+ config.context_length_padding_block_size * context_length_block_dim
+ )
+ dynamic_shapes = {"input_ids": {context_length_dim_idx: context_length_dim}}
+
+ @fxb.export_program(
+ name=f"forward_bs{batch_size}",
+ args=tuple(sample_inputs.values()),
+ dynamic_shapes=dynamic_shapes,
+ strict=False,
+ )
+ def _(
+ model,
+ input_ids,
+ ):
+ return model(input_ids)
+
+ output = export(fxb, import_symbolic_shape_expressions=True)
+ output.save_mlir(mlir_output_path)
+
+
+def prune_decoder_parameters(dataset: Dataset):
+ # Remove decoder tensors/parameters if present.
+ try:
+ del dataset.root_theta.tree["dec"]
+ except KeyError:
+ pass
+ try:
+ del dataset.properties["t5.decoder_start_token_id"]
+ except KeyError:
+ pass
+
+
+def export_encoder_iree_parameters(model_path: str, output_path: str):
+ dataset = Dataset.load(model_path)
+ prune_decoder_parameters(dataset)
+ dataset.save(output_path)
diff --git a/sharktank/sharktank/models/t5/t5.py b/sharktank/sharktank/models/t5/t5.py
new file mode 100644
index 000000000..4ae9108d5
--- /dev/null
+++ b/sharktank/sharktank/models/t5/t5.py
@@ -0,0 +1,1095 @@
+# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""T5 LLM adapted from transformers
+https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/t5/modeling_t5.py
+"""
+
+from typing import Any, Optional, Tuple
+from dataclasses import dataclass, field
+import math
+import torch
+from torch import nn
+import copy
+import logging
+import warnings
+from collections import OrderedDict
+
+from ...layers import (
+ BaseLayer,
+ RMSNormLayer,
+ TokenEmbeddingLayer,
+ LinearLayer,
+)
+from ... import ops
+from ...types.theta import Theta
+from ...types.tensors import AnyTensor
+from ...layers import FFN, T5Config
+
+__all__ = [
+ "T5Config",
+ "T5LayerFF",
+ "T5Attention",
+ "T5SelfAttention",
+ "T5CrossAttention",
+ "T5Block",
+ "T5Stack",
+ "T5Encoder",
+]
+
+logger = logging.getLogger(__name__)
+
+
+ACT2FN = {
+ "gelu": nn.functional.gelu,
+ "gelu_new": ops.gelu_tanh_approximation,
+ "relu": nn.functional.relu,
+}
+
+
+class T5LayerFF(nn.Module):
+ def __init__(
+ self,
+ theta: Theta,
+ is_gated_act: bool,
+ dense_act_fn: str,
+ layer_norm_epsilon: float,
+ activation_dtype: torch.dtype,
+ ):
+ super().__init__()
+ self.dense_activation_dense = FFN(
+ theta=theta, is_gated=is_gated_act, activation_fn=ACT2FN[dense_act_fn]
+ )
+
+ self.layer_norm = RMSNormLayer(
+ theta=theta("ffn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype
+ )
+
+ def forward(self, hidden_states):
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.dense_activation_dense(forwarded_states)
+ hidden_states = hidden_states + forwarded_states
+ return hidden_states
+
+
+class T5Attention(BaseLayer):
+ def __init__(
+ self,
+ theta: Theta,
+ is_decoder: bool,
+ relative_attention_num_buckets: int,
+ relative_attention_max_distance: int,
+ d_model: int,
+ d_kv: int,
+ num_heads: int,
+ activation_dtype: torch.dtype,
+ has_relative_attention_bias: bool = False,
+ ):
+ super().__init__()
+ self.is_decoder = is_decoder
+ self.relative_attention_num_buckets = relative_attention_num_buckets
+ self.relative_attention_max_distance = relative_attention_max_distance
+ self.d_model = d_model
+ self.key_value_proj_dim = d_kv
+ self.n_heads = num_heads
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
+ self.activation_dtype = activation_dtype
+
+ self.q = LinearLayer(theta("attn_q"))
+ self.k = LinearLayer(theta("attn_k"))
+ self.v = LinearLayer(theta("attn_v"))
+ self.o = LinearLayer(theta("attn_o"))
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = TokenEmbeddingLayer(
+ theta("attn_rel_b"), dtype=activation_dtype
+ )
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ # See transformers implementation
+ raise NotImplementedError()
+
+ @staticmethod
+ def _relative_position_bucket(
+ relative_position, bidirectional=True, num_buckets=32, max_distance=128
+ ):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (
+ ops.to(relative_position > 0, dtype=torch.long) * num_buckets
+ )
+ relative_position = ops.elementwise(torch.abs, relative_position)
+ else:
+ relative_position = -ops.elementwise(
+ torch.min, relative_position, torch.zeros_like(relative_position)
+ )
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_position_if_large = max_exact + (
+ ops.elementwise(torch.log, relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_position_if_large = ops.elementwise(
+ torch.min,
+ relative_position_if_large,
+ torch.full_like(relative_position_if_large, num_buckets - 1),
+ )
+
+ relative_buckets += ops.elementwise(
+ torch.where, is_small, relative_position, relative_position_if_large
+ )
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length, device=None):
+ """Compute binned relative position bias"""
+ if device is None:
+ device = self.relative_attention_bias.weight.device
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[
+ :, None
+ ]
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
+ None, :
+ ]
+ relative_position = (
+ memory_position - context_position
+ ) # shape (query_length, key_length)
+ relative_position_bucket = self._relative_position_bucket(
+ relative_position, # shape (query_length, key_length)
+ bidirectional=(not self.is_decoder),
+ num_buckets=self.relative_attention_num_buckets,
+ max_distance=self.relative_attention_max_distance,
+ )
+ values = self.relative_attention_bias(
+ relative_position_bucket
+ ) # shape (query_length, key_length, num_heads)
+ values = values.permute([2, 0, 1]).unsqueeze(
+ 0
+ ) # shape (1, num_heads, query_length, key_length)
+ return values
+
+ def forward(
+ self,
+ hidden_states,
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+ past_key_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ real_seq_length = seq_length
+
+ if past_key_value is not None:
+ if len(past_key_value) != 2:
+ raise ValueError(
+ f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+ )
+ real_seq_length += (
+ past_key_value[0].shape[2] if query_length is None else query_length
+ )
+
+ key_length = (
+ real_seq_length if key_value_states is None else key_value_states.shape[1]
+ )
+
+ def shape(states):
+ """projection"""
+ return states.view(
+ batch_size, -1, self.n_heads, self.key_value_proj_dim
+ ).transpose(1, 2)
+
+ def unshape(states):
+ """reshape"""
+ return (
+ states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+ )
+
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
+ """projects hidden states correctly to key/query states"""
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(hidden_states))
+ elif past_key_value is None:
+ # cross-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, key_length, dim_per_head)
+ hidden_states = ops.cat([past_key_value, hidden_states], dim=2)
+ elif past_key_value.shape[2] != key_value_states.shape[1]:
+ # checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ # cross-attn
+ # (batch_size, n_heads, seq_length, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+ return hidden_states
+
+ # get query states
+ query_states = shape(
+ self.q(hidden_states)
+ ) # (batch_size, n_heads, seq_length, dim_per_head)
+
+ # get key/value states
+ key_states = project(
+ hidden_states,
+ self.k,
+ key_value_states,
+ past_key_value[0] if past_key_value is not None else None,
+ )
+ value_states = project(
+ hidden_states,
+ self.v,
+ key_value_states,
+ past_key_value[1] if past_key_value is not None else None,
+ )
+
+ # compute scores
+ scores = ops.matmul(
+ query_states, key_states.transpose(3, 2)
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+ (1, self.n_heads, real_seq_length, key_length),
+ device=scores.device,
+ dtype=scores.dtype,
+ )
+ else:
+ position_bias = self.compute_bias(
+ real_seq_length, key_length, device=scores.device
+ )
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+ if past_key_value is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+ if mask is not None:
+ position_bias = (
+ position_bias + mask
+ ) # (batch_size, n_heads, seq_length, key_length)
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ scores += position_bias_masked
+ attn_weights = ops.softmax(scores.float(), dim=-1).type_as(
+ scores
+ ) # (batch_size, n_heads, seq_length, key_length)
+
+ # Mask heads if we want to
+ if layer_head_mask is not None:
+ attn_weights = attn_weights * layer_head_mask
+
+ attn_output = unshape(
+ ops.matmul(attn_weights, value_states)
+ ) # (batch_size, seq_length, dim)
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = (
+ (key_states, value_states) if (self.is_decoder and use_cache) else None
+ )
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
+class T5SelfAttention(BaseLayer):
+ def __init__(
+ self,
+ theta: Theta,
+ is_decoder: bool,
+ relative_attention_num_buckets: int,
+ relative_attention_max_distance: int,
+ d_model: int,
+ d_kv: int,
+ num_heads: int,
+ layer_norm_epsilon: float,
+ activation_dtype: torch.dtype,
+ has_relative_attention_bias: bool = False,
+ ):
+ super().__init__()
+ self.attention = T5Attention(
+ theta=theta,
+ is_decoder=is_decoder,
+ relative_attention_num_buckets=relative_attention_num_buckets,
+ relative_attention_max_distance=relative_attention_max_distance,
+ d_model=d_model,
+ d_kv=d_kv,
+ num_heads=num_heads,
+ activation_dtype=activation_dtype,
+ has_relative_attention_bias=has_relative_attention_bias,
+ )
+ self.layer_norm = RMSNormLayer(
+ theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.attention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + attention_output[0]
+ outputs = (hidden_states,) + attention_output[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class T5CrossAttention(nn.Module):
+ def __init__(
+ self,
+ theta: Theta,
+ is_decoder: bool,
+ relative_attention_num_buckets: int,
+ relative_attention_max_distance: int,
+ d_model: int,
+ d_kv: int,
+ num_heads: int,
+ layer_norm_epsilon: float,
+ activation_dtype: torch.dtype,
+ ):
+ super().__init__()
+ self.enc_dec_attention = T5Attention(
+ theta=theta,
+ is_decoder=is_decoder,
+ relative_attention_num_buckets=relative_attention_num_buckets,
+ relative_attention_max_distance=relative_attention_max_distance,
+ d_model=d_model,
+ d_kv=d_kv,
+ num_heads=num_heads,
+ activation_dtype=activation_dtype,
+ has_relative_attention_bias=False,
+ )
+ self.layer_norm = RMSNormLayer(
+ theta=theta("attn_norm"), epsilon=layer_norm_epsilon, dtype=activation_dtype
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ key_value_states,
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+ ):
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.enc_dec_attention(
+ normed_hidden_states,
+ mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+ )
+ layer_output = hidden_states + attention_output[0]
+ outputs = (layer_output,) + attention_output[
+ 1:
+ ] # add attentions if we output them
+ return outputs
+
+
+class T5Block(nn.Module):
+ def __init__(
+ self,
+ theta: Theta,
+ is_decoder: bool,
+ relative_attention_num_buckets: int,
+ relative_attention_max_distance: int,
+ d_model: int,
+ d_kv: int,
+ num_heads: int,
+ layer_norm_epsilon: float,
+ is_gated_act: bool,
+ dense_act_fn: str,
+ activation_dtype: torch.dtype,
+ has_relative_attention_bias=False,
+ ):
+ super().__init__()
+ self.is_decoder = is_decoder
+ self.layer = nn.ModuleList()
+ self.layer.append(
+ T5SelfAttention(
+ theta=theta,
+ is_decoder=is_decoder,
+ relative_attention_num_buckets=relative_attention_num_buckets,
+ relative_attention_max_distance=relative_attention_max_distance,
+ d_model=d_model,
+ d_kv=d_kv,
+ num_heads=num_heads,
+ layer_norm_epsilon=layer_norm_epsilon,
+ activation_dtype=activation_dtype,
+ has_relative_attention_bias=has_relative_attention_bias,
+ )
+ )
+ if self.is_decoder:
+ self.layer.append(
+ T5CrossAttention(
+ theta=theta,
+ is_decoder=is_decoder,
+ relative_attention_num_buckets=relative_attention_num_buckets,
+ relative_attention_max_distance=relative_attention_max_distance,
+ d_model=d_model,
+ d_kv=d_kv,
+ num_heads=num_heads,
+ activation_dtype=activation_dtype,
+ layer_norm_epsilon=layer_norm_epsilon,
+ )
+ )
+
+ self.layer.append(
+ T5LayerFF(
+ theta=theta,
+ is_gated_act=is_gated_act,
+ dense_act_fn=dense_act_fn,
+ layer_norm_epsilon=layer_norm_epsilon,
+ activation_dtype=activation_dtype,
+ )
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ encoder_decoder_position_bias=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+ past_key_value=None,
+ use_cache=False,
+ output_attentions=False,
+ return_dict=True,
+ ):
+ if past_key_value is not None:
+ if not self.is_decoder:
+ logger.warning(
+ "`past_key_values` is passed to the encoder. Please make sure this is intended."
+ )
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+
+ if len(past_key_value) != expected_num_past_key_values:
+ raise ValueError(
+ f"There should be {expected_num_past_key_values} past states. "
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+ f"Got {len(past_key_value)} past key / value states"
+ )
+
+ self_attn_past_key_value = past_key_value[:2]
+ cross_attn_past_key_value = past_key_value[2:]
+ else:
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
+
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=self_attn_past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
+ attention_outputs = self_attention_outputs[
+ 2:
+ ] # Keep self-attention outputs and relative position weights
+
+ # clamp inf values to enable fp16 training
+ if hidden_states.dtype == torch.float16:
+ clamp_value = ops.elementwise(
+ torch.where,
+ ops.elementwise(torch.isinf, hidden_states).any(),
+ torch.finfo(hidden_states.dtype).max - 1000,
+ torch.finfo(hidden_states.dtype).max,
+ )
+ hidden_states = ops.elementwise(
+ torch.clamp, hidden_states, min=-clamp_value, max=clamp_value
+ )
+
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+ if do_cross_attention:
+ # the actual query length is unknown for cross attention
+ # if using past key value states. Need to inject it here
+ if present_key_value_state is not None:
+ query_length = present_key_value_state[0].shape[2]
+ else:
+ query_length = None
+
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ query_length=query_length,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = cross_attention_outputs[0]
+
+ # clamp inf values to enable fp16 training
+ if hidden_states.dtype == torch.float16:
+ clamp_value = ops.elementwise(
+ torch.where,
+ ops.elementwise(torch.isinf, hidden_states).any(),
+ torch.finfo(hidden_states.dtype).max - 1000,
+ torch.finfo(hidden_states.dtype).max,
+ )
+ hidden_states = ops.elementwise(
+ torch.clamp, hidden_states, min=-clamp_value, max=clamp_value
+ )
+
+ # Combine self attn and cross attn key value states
+ if present_key_value_state is not None:
+ present_key_value_state = (
+ present_key_value_state + cross_attention_outputs[1]
+ )
+
+ # Keep cross-attention outputs and relative position weights
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states)
+
+ # clamp inf values to enable fp16 training
+ if hidden_states.dtype == torch.float16:
+ clamp_value = ops.elementwise(
+ torch.where,
+ ops.elementwise(torch.isinf, hidden_states).any(),
+ torch.finfo(hidden_states.dtype).max - 1000,
+ torch.finfo(hidden_states.dtype).max,
+ )
+ hidden_states = ops.elementwise(
+ torch.clamp, hidden_states, min=-clamp_value, max=clamp_value
+ )
+
+ outputs = (hidden_states,)
+
+ if use_cache:
+ outputs = outputs + (present_key_value_state,) + attention_outputs
+ else:
+ outputs = outputs + attention_outputs
+
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+
+
+class T5Stack(BaseLayer):
+ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None):
+ super().__init__()
+
+ self.embed_tokens = embed_tokens
+ self.config = config
+ self.is_decoder = config.is_decoder
+ theta_prefix = "dec" if config.is_decoder else "enc"
+
+ self.block = torch.nn.ModuleList(
+ [
+ T5Block(
+ theta=theta(f"{theta_prefix}.blk.{i}"),
+ is_decoder=config.is_decoder,
+ relative_attention_num_buckets=config.relative_attention_num_buckets,
+ relative_attention_max_distance=config.relative_attention_max_distance,
+ d_model=config.d_model,
+ d_kv=config.d_kv,
+ num_heads=config.num_heads,
+ layer_norm_epsilon=config.layer_norm_epsilon,
+ is_gated_act=config.is_gated_act,
+ dense_act_fn=config.dense_act_fn,
+ activation_dtype=config.activation_dtype,
+ has_relative_attention_bias=bool(i == 0),
+ )
+ for i in range(config.num_layers)
+ ]
+ )
+ self.add_module(
+ "final_layer_norm",
+ RMSNormLayer(
+ theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon
+ ),
+ )
+
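+        # All parameters are expected to share a single dtype, which is reused
+        # below as the default dtype for attention masks.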
+ dtypes = set(tensor.dtype for tensor in theta.flatten().values())
+ assert len(dtypes) == 1
+ self.dtype = dtypes.pop()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
+ @staticmethod
+ def create_extended_attention_mask_for_decoder(
+ input_shape, attention_mask, device=None
+ ):
+ if device is not None:
+ warnings.warn(
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ else:
+ device = attention_mask.device
+ batch_size, seq_length = input_shape
+ seq_ids = torch.arange(seq_length, device=device)
+ causal_mask = (
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+ <= seq_ids[None, :, None]
+ )
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones(
+ (batch_size, seq_length, prefix_seq_len),
+ device=device,
+ dtype=causal_mask.dtype,
+ ),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = (
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ )
+ return extended_attention_mask
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_shape: Tuple[int],
+ device: torch.device = None,
+ dtype: torch.dtype = None,
+ ) -> torch.Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (`Tuple[int]`):
+ The shape of the input to the model.
+
+ Returns:
+            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
+ """
+ if dtype is None:
+ dtype = self.dtype
+
+ if not (attention_mask.dim() == 2 and self.config.is_decoder):
+ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+ if device is not None:
+ warnings.warn(
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.",
+ FutureWarning,
+ )
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder:
+ extended_attention_mask = (
+ T5Stack.create_extended_attention_mask_for_decoder(
+ input_shape, attention_mask, device
+ )
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = ops.to(
+ extended_attention_mask, dtype=dtype
+ ) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(
+ dtype
+ ).min
+ return extended_attention_mask
+
+ def get_head_mask(
+ self,
+ head_mask: Optional[torch.Tensor],
+ num_hidden_layers: int,
+ is_attention_chunked: bool = False,
+ ) -> torch.Tensor:
+ """
+ Prepare the head mask if needed.
+
+ Args:
+ head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+ num_hidden_layers (`int`):
+ The number of hidden layers in the model.
+ is_attention_chunked (`bool`, *optional*, defaults to `False`):
+ Whether or not the attentions scores are computed by chunks or not.
+
+ Returns:
+ `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+ `[None]` for each layer.
+ """
+ if head_mask is not None:
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+ if is_attention_chunked is True:
+ head_mask = head_mask.unsqueeze(-1)
+ else:
+ head_mask = [None] * num_hidden_layers
+
+ return head_mask
+
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = (
+ head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+ ) # We can specify head_mask for each layer
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+ head_mask = head_mask.to(
+ dtype=self.dtype
+ ) # switch to float if need + fp16 compatibility
+ return head_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.return_dict
+ )
+
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
+ )
+
+ if inputs_embeds is None:
+ if self.embed_tokens is None:
+ raise ValueError(
+ "You have to initialize the model with valid token embeddings"
+ )
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = (
+ past_key_values[0][0].shape[2] + seq_length
+ if past_key_values is not None
+ else seq_length
+ )
+
+ if use_cache is True:
+ if not self.is_decoder:
+ raise ValueError(
+ f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+ )
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ batch_size, mask_seq_length, device=inputs_embeds.device
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, input_shape
+ )
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.is_decoder and encoder_hidden_states is not None:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+ )
+ encoder_extended_attention_mask = self.invert_attention_mask(
+ encoder_attention_mask
+ )
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(
+ cross_attn_head_mask, self.config.num_layers
+ )
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = inputs_embeds
+
+ for i, (layer_module, past_key_value) in enumerate(
+ zip(self.block, past_key_values)
+ ):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if use_cache is False:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer stores them.
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[
+ 4 if output_attentions else 3
+ ]
+ # append next layer key value states
+ if use_cache:
+ present_key_value_states = present_key_value_states + (
+ present_key_value_state,
+ )
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+ if self.is_decoder:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return OrderedDict(
+ (k, v)
+ for k, v in [
+ ("last_hidden_state", hidden_states),
+ ("past_key_values", present_key_value_states),
+ ("hidden_states", all_hidden_states),
+ ("attentions", all_attentions),
+ ("cross_attentions", all_cross_attentions),
+ ]
+ if v is not None
+ )
+
+
+class T5Encoder(BaseLayer):
+ def __init__(self, theta: Theta, config: T5Config):
+ super().__init__()
+ self.add_module(
+ "token_embedding",
+ TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+ )
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = T5Stack(
+ theta=theta, config=encoder_config, embed_tokens=self.token_embedding
+ )
+
+ @property
+ def config(self):
+ return self.encoder.config
+
+ def sample_inputs(self, batch_size: int) -> OrderedDict[str, AnyTensor]:
+ return OrderedDict(
+ [
+ (
+ "input_ids",
+ torch.empty(
+ size=[batch_size, self.config.context_length], dtype=torch.long
+ ),
+ )
+ ]
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> tuple[torch.FloatTensor]:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ return encoder_outputs
diff --git a/sharktank/sharktank/ops/_registry.py b/sharktank/sharktank/ops/_registry.py
index 3c8ea0aed..f067c047b 100644
--- a/sharktank/sharktank/ops/_registry.py
+++ b/sharktank/sharktank/ops/_registry.py
@@ -18,6 +18,7 @@
__all__ = [
"AllOfExprs",
+ "AllOfExprsVariadic",
"AllOfType",
"AnyOfType",
"IsOfType",
@@ -65,7 +66,8 @@ def __call__(self, *args: type) -> bool:
class AllOfExprs(BoolTypeExpr):
- """Returns True if all types match their respective boolean type expression.
+ """Returns True if all type arguments match their respective boolean type
+ expression.
```python
# True. int == int and str in (float, str).
@@ -87,6 +89,38 @@ def expr(*types: type):
super().__init__(expr)
+class AllOfExprsVariadic(BoolTypeExpr):
+ """Returns True if all type arguments match their respective boolean type
+ expression and any remaining trailing arguments match the last type expression.
+
+ ```python
+ # True. int == int
+ # str in (float, str).
+ # float in (float, str).
+ AllOfExprsVariadic(IsOfType(int), IsOfType(float, str))(int, str, float)
+
+ # False. str is not in (int, float).
+ AllOfExprsVariadic(IsOfType(int), IsOfType(int, float))(int, float, str, int)
+ ```
+ """
+
+ def __init__(self, *exprs: BoolTypeExpr):
+ if len(exprs) == 0:
+ raise ValueError("At least one expression is required.")
+ self._exprs = list(exprs)
+
+ def expr(*types: type):
+ if len(types) < len(self._exprs):
+ return False
+ exprs = self._exprs
+ if len(types) > len(exprs):
+ # pad with the trailing expression.
+ exprs = exprs + ([exprs[-1]] * (len(types) - len(self._exprs)))
+ return all([e(t) for e, t in zip(exprs, types)])
+
+ super().__init__(expr)
+
+
class AllOfType(BoolTypeExpr):
"""Returns True if all of the types are from a set of types.
diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py
index fe6ae27b1..9acc7c562 100644
--- a/sharktank/sharktank/ops/custom_impls.py
+++ b/sharktank/sharktank/ops/custom_impls.py
@@ -5,20 +5,24 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
+
from torch import Tensor, dtype
+from typing import Union
+
import torch.nn.functional as F
from ..kernels import (
+ einsum_2args_q4,
mmt_block_scaled_offset_q4_unsigned,
mmt_block_scaled_q8,
- mmtfp,
mmt_super_block_scaled_offset_q4_unsigned,
+ bitcast_to_complex,
+ bitcast_to_real,
)
from ..types import (
BlockScaledLayout,
BlockScaledI4Layout,
- InferenceTensor,
PrimitiveTensor,
QuantizedTensor,
SuperBlockOffsetScaled_4_6_Layout,
@@ -29,7 +33,7 @@
# Fused FP matmul.
-# Disabled: See https://github.com/nod-ai/sharktank/issues/44
+# Disabled: See https://github.com/nod-ai/shark-ai/issues/44
# @matmul.override(Tensor, Tensor)
# def matmul_mmtfp_tensor_tensor(lhs, rhs, *, transpose_rhs: bool):
# lhs = unbox_tensor(lhs)
@@ -44,6 +48,18 @@
# return mmtfp(lhs, rhs)
+# Einsum
+
+
+@einsum_2args.override(Tensor, QuantizedTensor)
+def einsum_2args_QuantizedTensor(input0, input1, einsum_str):
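+    # Only block-scaled i4 layouts have a fused einsum kernel here; everything
+    # else falls back to the other registered implementations.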
+ unpacked = input1.unpack()
+ layout = input1.layout_type
+ if not isinstance(unpacked, BlockScaledI4Layout):
+ return NotImplemented
+ return einsum_2args_q4(input0, unpacked.d, unpacked._qs, unpacked.m, einsum_str)
+
+
# Quantized Matmul
@@ -110,3 +126,15 @@ def matmul_generic_tensor_super_block_offset_scaled_4_6_i4(
sb_mins_low,
rhs_unpacked.qs_bit_packed,
)
+
+
+@view_as_complex.override(Union[Tensor, PrimitiveTensor])
+def view_as_complex(t):
+ t = unbox_tensor(t)
+ return bitcast_to_complex(t)
+
+
+@view_as_real.override(Union[Tensor, PrimitiveTensor])
+def view_as_real(t):
+ t = unbox_tensor(t)
+ return bitcast_to_real(t)
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index 72f3db711..d117ada23 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -7,17 +7,24 @@
# This file contains overrides of the standard ops for normal torch and
# generic primitive/quantized types.
-from typing import Optional, List, Sequence, Union
+from typing import Optional, List, Sequence, Union, Tuple
import torch
from torch import Tensor, dtype
import torch.nn.functional as F
from numbers import Number
-from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor
+from ..types import (
+ PrimitiveTensor,
+ QuantizedTensor,
+ InferenceTensor,
+ PlanarQuantizedTensor,
+ BlockScaledI4Layout,
+)
from ..types.tensors import unbox_tensor, AnyTensor
-from ._registry import AllOfType, AllOfExprs, IsOfType
+from ._registry import AllOfType, AllOfExprs, AllOfExprsVariadic, IsOfType
from .signatures import *
+import iree.turbine.ops.iree
@cat.override(AllOfType(Tensor, PrimitiveTensor))
@@ -61,11 +68,64 @@ def conv2d_default(
conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default)
conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default)
+# Einsum
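+# The helpers below lower specific einsum subscript patterns to batched matmuls;
+# each function name spells out the subscripts it implements, e.g. mk_menk_men
+# handles "mk,menk->men".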
+def mk_menk_men(inputs, weights):
+ # batch dims: m, lhs pdims: none, lhs rdims: k, rhs pdims: en, rhs rdims: k
+ inputs = inputs.unsqueeze(1)
+ weights_shape = weights.shape
+ weights = weights.view(
+ weights_shape[0], weights_shape[1] * weights_shape[2], weights_shape[3]
+ )
+ result = matmul(inputs, weights, transpose_rhs=True)
+ result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+ return result
+
+
+def mek_menk_men(inputs, weights):
+ # batch dims: me, lhs pdims: none, lhs rdims: k, rhs pdims: n, rhs rdims: k
+ inputs_shape = inputs.shape
+ inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, inputs_shape[2])
+ weights_shape = weights.shape
+ weights = weights.view(
+ weights_shape[0] * weights_shape[1], weights_shape[2], weights_shape[3]
+ )
+ result = matmul(inputs, weights, transpose_rhs=True)
+ result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+ return result
+
+
+def me_men_men(inputs, weights):
+ # batch dims: me, lhs pdims: none, lhs rdims: none, rhs pdims: n, rhs rdims: none
+ inputs_shape = inputs.shape
+ inputs = inputs.view(inputs_shape[0] * inputs_shape[1], 1, 1)
+ weights_shape = weights.shape
+ weights = weights.view(weights_shape[0] * weights_shape[1], weights_shape[2], 1)
+ result = matmul(inputs, weights, transpose_rhs=True)
+ result = result.view(weights_shape[0], weights_shape[1], weights_shape[2])
+ return result
+
+
+@einsum_2args.override(AllOfType(Tensor, PrimitiveTensor, QuantizedTensor))
+def einsum_2args(input0, input1, einsum_str):
+ # Special optimized einsum kernels that lower to batch matmul
+ if einsum_str == "mk,menk->men":
+ return mk_menk_men(input0, input1)
+ elif einsum_str == "mek,menk->men":
+ return mek_menk_men(input0, input1)
+ elif einsum_str == "me,men->men":
+ return me_men_men(input0, input1)
+ # Default non-QuantizedTensor einsum
+ if not isinstance(input1, QuantizedTensor):
+        return torch.einsum(einsum_str, unbox_tensor(input0), unbox_tensor(input1))
+ # Fallback to other kernels
+ return NotImplemented
+
+
# Elementwise
@elementwise.override(Tensor)
-def elementwise_unary(operator, x):
+def elementwise_unary(operator, x, *args, **kwargs):
x = unbox_tensor(x)
- return operator(x)
+ return operator(x, *args, **kwargs)
@elementwise.override(
@@ -73,11 +133,27 @@ def elementwise_unary(operator, x):
IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number)
)
)
-def elementwise_binary(operator, x, y):
+def elementwise_binary(operator, x, y, *args, **kwargs):
x = unbox_tensor(x)
if isinstance(y, PrimitiveTensor):
y = unbox_tensor(y)
- return operator(x, y)
+ return operator(x, y, *args, **kwargs)
+
+
+@elementwise.override(
+ AllOfExprs(
+ IsOfType(Tensor, PrimitiveTensor),
+ IsOfType(Tensor, PrimitiveTensor, Number),
+ IsOfType(Tensor, PrimitiveTensor, Number),
+ )
+)
+def elementwise_ternary(operator, x, y, z, *args, **kwargs):
+ x = unbox_tensor(x)
+ if isinstance(y, PrimitiveTensor):
+ y = unbox_tensor(y)
+ if isinstance(z, PrimitiveTensor):
+ z = unbox_tensor(z)
+ return operator(x, y, z, *args, **kwargs)
# Embedding Lookup
@@ -99,6 +175,49 @@ def equal_default(a, b) -> bool:
return torch.equal(unbox_tensor(a), unbox_tensor(b))
+@expand.override(Tensor)
+def expand_default(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
+ return unbox_tensor(tensor).expand(*shape)
+
+
+@flatten.override(Tensor)
+def flatten_default(
+ input: Union[Tensor, PrimitiveTensor], start_dim: int, end_dim: int
+) -> Tensor:
+ return torch.flatten(unbox_tensor(input), start_dim, end_dim)
+
+
+@gather.override(Tensor, Tensor)
+def gather_default(
+ input: Union[Tensor, PrimitiveTensor],
+ dim: int,
+ index: Union[Tensor, PrimitiveTensor],
+) -> Tensor:
+ return torch.gather(unbox_tensor(input), dim, unbox_tensor(index))
+
+
+@get_index.override(AllOfType(Tensor, PrimitiveTensor))
+def get_index_default(tensor, key):
+    return unbox_tensor(tensor).__getitem__(key)
+
+
+@get_index.override(QuantizedTensor)
+def get_index_QuantizedTensor(tensor: QuantizedTensor, key: slice):
+ unpacked = tensor.unpack()
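+    # Block-scaled i4 layouts pack two 4-bit values per byte of `qs`, so the
+    # flattened logical dimension is twice the packed element count.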
+ if isinstance(unpacked, BlockScaledI4Layout):
+ mul = 2
+ else:
+ return NotImplemented
+ new_d = unpacked._d[key]
+ new_qs = unpacked._qs[key]
+    new_m = unpacked.m[key] if unpacked.m is not None else None
+ dims = new_qs.shape
+ dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
+ layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
+ return PlanarQuantizedTensor(shape=dims, layout=layout)
+
+
@gemm.override(AllOfType(Tensor, InferenceTensor))
def gemm(
a: AnyTensor,
@@ -133,6 +252,37 @@ def group_norm_affine_default(input, weight, bias, *, num_groups, eps):
return F.group_norm(input, num_groups=num_groups, weight=weight, bias=bias, eps=eps)
+@index_copy_.override(Tensor, Tensor, Tensor)
+def index_copy__default(
+ inout: Union[Tensor, PrimitiveTensor],
+ dim: int,
+ index: Union[Tensor, PrimitiveTensor],
+ tensor: Union[Tensor, PrimitiveTensor],
+) -> Union[Tensor, PrimitiveTensor]:
+ unbox_tensor(inout).index_copy_(dim, unbox_tensor(index), unbox_tensor(tensor))
+ return inout
+
+
+@index_put_.override(AllOfType(Tensor, PrimitiveTensor))
+def index_put__default(
+ inout: Union[Tensor, PrimitiveTensor],
+ indices: Tuple[Union[Tensor, PrimitiveTensor]],
+ values: Union[Tensor, PrimitiveTensor],
+) -> Union[Tensor, PrimitiveTensor]:
+ indices = tuple(unbox_tensor(index) for index in indices)
+ unbox_tensor(inout).index_put_(indices, unbox_tensor(values))
+ return inout
+
+
+@index_select.override(Tensor, Tensor)
+def index_select_default(
+ tensor: Union[Tensor, PrimitiveTensor],
+ dim: int,
+ index: Union[Tensor, PrimitiveTensor],
+) -> Union[Tensor, PrimitiveTensor]:
+ return torch.index_select(unbox_tensor(tensor), dim, unbox_tensor(index))
+
+
@interpolate.override(Tensor)
def interpolate_default(
input: Tensor,
@@ -187,15 +337,21 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
lhs = unbox_tensor(lhs)
rhs = unbox_tensor(rhs)
if transpose_rhs:
- rhs = rhs.T
- return torch.matmul(lhs, rhs.to(lhs.dtype))
+ rhs = rhs.mT
+ rhs = rhs.to(lhs.dtype)
+
+ if len(lhs.shape) > 2 and len(rhs.shape) < 3:
+ bdims = lhs.shape[:-1]
+ lhs = torch.flatten(lhs, 0, -2)
+ mm = torch.matmul(lhs, rhs)
+ return torch.unflatten(mm, 0, bdims)
+
+ return torch.matmul(lhs, rhs)
# Scaled dot product attention
-@scaled_dot_product_attention.override(
- Tensor, Tensor, Tensor, Optional[Tensor], auto_dequant=True
-)
-def scaled_dot_product_attention(q, k, v, a) -> Tensor:
+@scaled_dot_product_attention.override(Tensor, Tensor, Tensor, None)
+def scaled_dot_product_attention_torch(q, k, v, a, is_causal, scale) -> Tensor:
q = unbox_tensor(q)
k = unbox_tensor(k)
v = unbox_tensor(v)
@@ -204,18 +360,41 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor:
# TODO: plumb dropout and is_causal through ops
return torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask=a, dropout_p=0.0, is_causal=False
+ q, k, v, attn_mask=a, dropout_p=0.0, is_causal=is_causal, scale=scale
)
+@mean.override(Tensor)
+def mean_default(
+ x: Tensor, dim: Union[int, List[int]], keepdim: bool, *, dtype: torch.dtype
+) -> Tensor:
+ return torch.mean(unbox_tensor(x), dim=dim, keepdim=keepdim, dtype=dtype)
+
+
+@module_register_buffer.override(torch.nn.Module, Tensor)
+def module_register_buffer_default(
+ module: torch.nn.Module, name: str, tensor: Union[Tensor, InferenceTensor]
+) -> None:
+ return module.register_buffer(name, unbox_tensor(tensor))
+
+
+@repeat.override(Tensor)
+def repeat_default(input: Union[Tensor, PrimitiveTensor], *sizes: List[int]) -> Tensor:
+ return unbox_tensor(input).repeat(*sizes)
+
+
+@reshape.override(Tensor)
+def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> Tensor:
+ return torch.reshape(unbox_tensor(input), shape)
+
+
# RMS norm
-@rms_norm.override(Tensor, Tensor)
+@rms_norm.override(AllOfType(Tensor, InferenceTensor))
def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
- x = unbox_tensor(x)
- weight = unbox_tensor(weight)
variance = x.pow(2).mean(-1, keepdim=True)
- output = x * torch.rsqrt(variance + epsilon)
- output = output * weight
+ output = x * elementwise(torch.rsqrt, variance + epsilon)
+    # The cast here matches the HF implementation and affects numerics.
+ output = elementwise(torch.mul, weight, to(output, weight.dtype))
return output
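For reference, a minimal plain-torch sketch of the computation `rms_norm_default` now performs: variance over the last axis, rsqrt in the input dtype, then a cast to the weight dtype before the final scale (the HF-matching ordering noted in the comment above):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Unbiased RMS norm: no mean subtraction, variance over the last dimension.
    variance = x.pow(2).mean(-1, keepdim=True)
    output = x * torch.rsqrt(variance + epsilon)
    # Cast before multiplying by the weight; this is what affects numerics when
    # x is kept in a wider dtype than weight.
    return weight * output.to(weight.dtype)

y = rms_norm_reference(torch.randn(2, 8), torch.ones(8, dtype=torch.float16), epsilon=1e-6)
assert y.dtype == torch.float16
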
@@ -234,6 +413,34 @@ def permute(tensor: Tensor, dims: List[int]):
return torch.permute(torch_tensor, dims)
+@softmax.override(Tensor)
+def softmax_default(
+ tensor: Union[Tensor, PrimitiveTensor],
+ dim: Optional[int],
+ dtype: Optional[torch.dtype],
+) -> Tensor:
+ return F.softmax(unbox_tensor(tensor), dim=dim, dtype=dtype)
+
+
+@to.override(Tensor)
+def to_default(tensor: Tensor, *args, **kwargs):
+ return unbox_tensor(tensor).to(*args, **kwargs)
+
+
+@transfer_to_logical_device.override(Tensor)
+def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
+ return iree.turbine.ops.iree.transfer_to_logical_device(
+ f"{ordinal}", unbox_tensor(tensor)
+ )
+
+
+@transpose.override(Tensor)
+def transpose_default(
+ tensor: Union[Tensor, PrimitiveTensor], dim0: int, dim1: int
+) -> Tensor:
+ return torch.transpose(unbox_tensor(tensor), dim0, dim1)
+
+
# Sharded default impls (do nothing).
@@ -245,3 +452,46 @@ def sharded_cat_unsharded(maybe_sharded):
@sharded_sum.override(Tensor)
def sharded_sum_unsharded(maybe_sharded):
return unbox_tensor(maybe_sharded)
+
+
+@unflatten.override(Tensor)
+def unflatten_default(
+ input: Union[Tensor, PrimitiveTensor], dim: int, sizes: Tuple[int]
+) -> Tensor:
+ return torch.unflatten(unbox_tensor(input), dim, sizes)
+
+
+@unsqueeze.override(Tensor)
+def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tensor:
+    return torch.unsqueeze(unbox_tensor(tensor), dim)
+
+
+@view.override(Tensor)
+def view_default(tensor: Union[Tensor, PrimitiveTensor], shape: List[int]) -> Tensor:
+ return unbox_tensor(tensor).view(*shape)
+
+
+@view.override(QuantizedTensor)
+def view_QuantizedTensor(tensor: QuantizedTensor, shape):
+ unpacked = tensor.unpack()
+ if not isinstance(unpacked, BlockScaledI4Layout):
+ return NotImplemented
+ bs = 16
+ shape = list(shape)
+ new_d = unpacked._d.view(shape[:-1] + [shape[-1] // 32, 1])
+ qs_shape = shape[:-1] + [shape[-1] // 32, 16]
+ new_qs = unpacked._qs.view(qs_shape)
+ if unpacked.m is not None:
+ new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1])
+ layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m)
+ return PlanarQuantizedTensor(shape=shape, layout=layout)
+
+
+@view_as_complex.override(Tensor)
+def view_as_complex_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
+ return torch.view_as_complex(unbox_tensor(tensor))
+
+
+@view_as_real.override(Tensor)
+def view_as_real_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
+ return torch.view_as_real(unbox_tensor(tensor))
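The view_as_complex/view_as_real overrides above are thin wrappers over torch; a quick sanity sketch of the semantics they rely on (a trailing dimension of size 2 holds the real and imaginary parts):

import torch

real = torch.randn(4, 3, 2)          # (..., 2): interleaved (re, im) pairs
z = torch.view_as_complex(real)      # complex view with shape (4, 3)
assert z.shape == (4, 3) and z.is_complex()

# view_as_real is the exact inverse view; no copy is involved.
assert torch.equal(torch.view_as_real(z), real)
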
diff --git a/sharktank/sharktank/ops/qconv_impls.py b/sharktank/sharktank/ops/qconv_impls.py
index af1199976..add64a605 100644
--- a/sharktank/sharktank/ops/qconv_impls.py
+++ b/sharktank/sharktank/ops/qconv_impls.py
@@ -12,6 +12,7 @@
import warnings
import torch
+import torch.nn.functional as F
from sharktank import kernels
@@ -31,7 +32,7 @@
)
-def qconv2d_tensor_scaled_integer(
+def qconv2d_tensor_scaled(
input: QuantizedTensor,
weight: QuantizedTensor,
bias: Optional[AnyTensor] = None,
@@ -59,12 +60,16 @@ def qconv2d_tensor_scaled_integer(
input_layout: TensorScaledLayout = input.unpack()
weight_layout: TensorScaledLayout = weight.unpack()
- # Only handle integer quantizations.
+    # Handle integer and fp8 quantizations.
if (
input_layout.qs.dtype.is_floating_point
or weight_layout.qs.dtype.is_floating_point
):
- return NotImplemented
+ if (
+ input_layout.qs.dtype != torch.float8_e4m3fnuz
+ or weight_layout.qs.dtype != torch.float8_e4m3fnuz
+ ):
+ return NotImplemented
# Bias is both optional and may either be quantized or fp.
bias_qs = None
@@ -86,6 +91,8 @@ def qconv2d_tensor_scaled_integer(
# Alias components (d=scale, qs=quantized samples, m=offset).
if accum_dtype is None:
accum_dtype = torch.int32
+ if weight_layout.qs.dtype.is_floating_point:
+ accum_dtype = torch.float32
input_d = input_layout.d
input_dtype = input_layout.dtype
input_qs = input_layout.qs
@@ -113,8 +120,8 @@ def qconv2d_tensor_scaled_integer(
padding = _expand_int_to_2_tuple(padding)
dilation = _expand_int_to_2_tuple(dilation)
extended_padding_list = [item for item in padding for _ in range(2)]
- padded_input = _pad_last_2d(input_qs, extended_padding_list)
- y_qs = _invoke_int32_conv2d(
+ padded_input = F.pad(input_qs, pad=extended_padding_list)
+ y_qs = _invoke_conv2d_kernel(
padded_input,
weight_qs,
bias_qs.to(accum_dtype) if bias_qs is not None else None,
@@ -145,7 +152,7 @@ def qconv2d_tensor_scaled_integer(
weight_offset_fix = torch.sum(
padded_input, dim=1, keepdim=True, dtype=accum_dtype
)
- weight_offset_fix = _invoke_int32_pooling_sum(
+ weight_offset_fix = _invoke_pooling_sum_kernel(
weight_offset_fix,
[weight_qs.shape[2], weight_qs.shape[3]],
stride,
@@ -188,13 +195,11 @@ def qconv2d_tensor_scaled_integer(
return y
-conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled_integer)
-conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
- qconv2d_tensor_scaled_integer
-)
+conv2d.override(QuantizedTensor, QuantizedTensor)(qconv2d_tensor_scaled)
+conv2d.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qconv2d_tensor_scaled)
-def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype):
+def _invoke_conv2d_kernel(input, weight, bias, stride, dilation, *, accum_dtype):
"""Does a low level invocation of a conv2d integer kernel on an explicitly padded input.
This presumes that the stride/padding/dilation have already been normalized
@@ -233,7 +238,7 @@ def _invoke_int32_conv2d(input, weight, bias, stride, dilation, *, accum_dtype):
return y_qs
-def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dtype):
+def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dtype):
"""Invokes either a custom integer pooling sum or the built-in fp avg_pool2d
kernel on an explicitly padded input.
"""
@@ -254,27 +259,6 @@ def _invoke_int32_pooling_sum(input, kernel_size, stride, dilation, *, accum_dty
return output
-def _pad_last_2d(input_tensor, pad_width):
- # pad_width should be in the format [pad_left, pad_right, pad_top, pad_bottom]
- pad_left, pad_right, pad_top, pad_bottom = pad_width
- batch_size, channels, height, width = input_tensor.shape
-
- # Create a new tensor with the desired padded size filled with zeros
- padded_height = height + pad_top + pad_bottom
- padded_width = width + pad_left + pad_right
- padded_tensor = torch.zeros(
- (batch_size, channels, padded_height, padded_width),
- dtype=input_tensor.dtype,
- device=input_tensor.device,
- )
-
- # Copy the values from the input tensor to the appropriate location in the padded tensor
- padded_tensor[
- :, :, pad_top : pad_top + height, pad_left : pad_left + width
- ] = input_tensor
- return padded_tensor
-
-
def _flatten_input_scale_offset_channels(d, m):
"""Flattens either a 4d or 0d scale/offset as [N, C, H, W] to 1D.
diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py
index 0a381d613..b66d3be1d 100644
--- a/sharktank/sharktank/ops/qlinear_impls.py
+++ b/sharktank/sharktank/ops/qlinear_impls.py
@@ -28,7 +28,7 @@
from sharktank import kernels
-def qlinear_tensor_scaled_integer(
+def qlinear_tensor_scaled(
x: QuantizedTensor,
weight: QuantizedTensor,
bias: Optional[AnyTensor],
@@ -48,9 +48,15 @@ def qlinear_tensor_scaled_integer(
x_layout: TensorScaledLayout = x.unpack()
weight_layout: TensorScaledLayout = weight.unpack()
- # Only handle integer quantizations.
+ # Handle only integer and fp8 quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
- return NotImplemented
+ if x_layout.qs.dtype == torch.float8_e4m3fnuz:
+            # Assume a Quark-style fp8 quantization scheme.
+ return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to(
+ torch.float16
+ )
+ else:
+ return NotImplemented
# Bias.
quantized_bias_accum = False
@@ -65,6 +71,8 @@ def qlinear_tensor_scaled_integer(
# Alias components (d=scale, qs=quantized samples, m=offset)
if accum_dtype is None:
accum_dtype = torch.int32
+ if weight_layout.qs.dtype.is_floating_point:
+ accum_dtype = torch.float32
x_d = x_layout.d
x_dtype = x_layout.dtype
x_qs = x_layout.qs
@@ -86,7 +94,7 @@ def qlinear_tensor_scaled_integer(
# TODO: Handle permutation that we have a kernel for.
# Fall back to automatic fusion based on integer, high precision matmul.
- y_qs = _invoke_int32_mmt(x_qs, weight_qs, accum_dtype=accum_dtype)
+ y_qs = _invoke_mmt_kernel(x_qs, weight_qs, accum_dtype=accum_dtype)
# Offset correction. By applying the offset correction in post, it is
# set up to fuse with its consumer, which is already doing additional
@@ -143,10 +151,8 @@ def qlinear_tensor_scaled_integer(
# Overrload for both bias and no bias.
-linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled_integer)
-linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(
- qlinear_tensor_scaled_integer
-)
+linear.override(QuantizedTensor, QuantizedTensor)(qlinear_tensor_scaled)
+linear.override(QuantizedTensor, QuantizedTensor, AnyTensor)(qlinear_tensor_scaled)
def linear_quantized_weight(
@@ -166,19 +172,30 @@ def linear_quantized_weight(
linear.override(Tensor, QuantizedTensor, AnyTensor)(linear_quantized_weight)
-def _invoke_int32_mmt(lhs, rhs, *, accum_dtype):
+def _invoke_mmt_kernel(lhs, rhs, *, accum_dtype):
if debugging.flags.use_custom_iree_kernels:
# The custom kernel requires that the lhs and rhs be the same
# rank. Broadcast the rhs to match.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
+ # If input to the kernel is 2D, expand the tensor to add the batch
+ # dimension.
+ if lhs_rank == 2:
+ lhs_size = [1] + list(lhs.shape)
+ lhs = lhs.unsqueeze(0).expand(lhs_size)
+ lhs_rank = len(lhs.shape)
if rhs_rank < lhs_rank:
assert (rhs_rank + 1) == lhs_rank
rhs_size = [lhs.shape[0]] + list(rhs.shape)
rhs = rhs.unsqueeze(0).expand(rhs_size)
+ rhs_rank = len(rhs.shape)
y_qs = kernels.batch_matmul_transpose_b(
lhs.to(accum_dtype), rhs.to(accum_dtype)
)
+ # Squeeze the batch dimension to maintain shape parity with other
+ # layers.
+ if len(y_qs.shape) > 2:
+ y_qs = y_qs.squeeze(0)
else:
# FP emulation.
y_qs = torch.matmul(
diff --git a/sharktank/sharktank/ops/shape.py b/sharktank/sharktank/ops/shape.py
index 69683f84f..0616b97b6 100644
--- a/sharktank/sharktank/ops/shape.py
+++ b/sharktank/sharktank/ops/shape.py
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Sequence
+from typing import Sequence, Optional
from ..types.tensors import AnyTensor
@@ -53,3 +53,12 @@ def broadcast_dims(
ranks = [len(shape) for shape in shaped_or_shape]
broadcast_rank = max(ranks)
return [dim + max(0, broadcast_rank - rank) for dim, rank in zip(dims, ranks)]
+
+
+def unbroadcast_dim(dim: int, shapes: Sequence[Sequence[int]]) -> Optional[int]:
+ """Returns the dimension in `shapes[0]` such that it would correspond to `dim`
+ after broadcasting the shapes `shapes`."""
+ ranks = [len(shape) for shape in shapes]
+ broadcast_rank = max(ranks)
+ res = dim - max(0, broadcast_rank - ranks[0])
+ return None if res < 0 else res
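A worked example of the new helper (the import path is assumed from the file location sharktank/sharktank/ops/shape.py): broadcasting (3, 1) against (2, 3, 4) right-aligns the shapes, so dimensions of the lower-rank operand shift right by the rank difference, and `unbroadcast_dim` undoes that shift or returns `None` for dimensions that only exist after rank expansion:

# Assumed import path; adjust if the package layout differs.
from sharktank.ops.shape import unbroadcast_dim

shapes = [(3, 1), (2, 3, 4)]                # (3, 1) broadcasts against (2, 3, 4) to (2, 3, 4)
assert unbroadcast_dim(2, shapes) == 1      # result dim 2 maps back to dim 1 of (3, 1)
assert unbroadcast_dim(0, shapes) is None   # result dim 0 comes purely from rank expansion
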
diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index ed639915e..07f466f5b 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -6,14 +6,17 @@
import torch
from torch import Tensor
-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, Union, Any, Tuple
import itertools
from numbers import Number
+import math
+import functools
from ..types import (
AnyTensor,
DefaultPrimitiveTensor,
InferenceTensor,
+ PrimitiveTensor,
ReplicatedTensor,
ShardedTensor,
sharding,
@@ -22,29 +25,67 @@
UnreducedTensor,
)
from ..types.tensors import unbox_tensor
-from ._registry import AllOfType
+from ._registry import AllOfType, AllOfExprsVariadic, IsOfType
from .signatures import *
-from .shape import broadcast_dims
+from .shape import broadcast_dims, broadcast_dim, unbroadcast_dim
+from ..utils import longest_equal_range
@all_gather.override(SplitPrimitiveTensor)
def all_gather_split(
input: SplitPrimitiveTensor, *, dim: int | None
) -> ReplicatedTensor:
- assert (
- dim is None
- ), "gather dimension other than `input.shard_dim` is not supported."
- # TODO: figure out how to avoid common sub-expression elimination to not
- # merge all these into one.
- # Even if we place each resulting shard inside of ReplicatedTensor on a
- # distinct logical device with an explicit operation, CSE should still
- # collapse them.
- shards = [sharded_cat(input) for i in range(input.shard_count)]
+ dim = input.shard_dim if dim is None else dim
+ # For each device move the shards to it and do a concatenation.
+ # If we don't move first, common sub-expression elimination is free to collapse all
+ # concatenations into one and then copy to all devices, which is not what we want.
+ shards = [
+ cat(
+ [
+ shard if i == j else transfer_to_logical_device(shard, i)
+ for j, shard in enumerate(input.shards)
+ ],
+ dim=dim,
+ )
+ for i in range(input.shard_count)
+ ]
+ return ReplicatedTensor(ts=shards)
+
+
+@all_reduce.override(AllOfType(SplitPrimitiveTensor, UnreducedTensor))
+def all_reduce_split_or_unreduced(
+ input: Union[SplitPrimitiveTensor, UnreducedTensor],
+) -> ReplicatedTensor:
+ # For each device move the shards to it and do a reduction.
+ # If we don't move first, common sub-expression elimination is free to collapse all
+ # reductions into one and then copy to all devices, which is not what we want.
+ shards = [
+ functools.reduce(
+ lambda x, y: elementwise(torch.add, x, y),
+ [
+ shard if i == j else transfer_to_logical_device(shard, i)
+ for j, shard in enumerate(input.shards)
+ ],
+ )
+ for i in range(input.shard_count)
+ ]
+ return ReplicatedTensor(ts=shards)
+
+
+@cat.override(AllOfType(ReplicatedTensor))
+def cat_replicated(tensors: Sequence[ReplicatedTensor], dim: int) -> ReplicatedTensor:
+ assert len(tensors) > 0
+ shard_count = tensors[0].shard_count
+ assert all([t.shard_count == shard_count for t in tensors])
+
+ shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])]
return ReplicatedTensor(ts=shards)
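A device-free sketch of the data-movement pattern used by `all_gather_split` and `all_reduce_split_or_unreduced` above, modeling one shard per logical device with a plain list (the real overrides additionally route non-local shards through `transfer_to_logical_device` so CSE cannot merge the per-device results):

import functools
import torch

shards = [torch.full((2,), float(i)) for i in range(3)]  # one shard per logical device

# all-gather: each device ends up with the concatenation of every shard.
gathered = [torch.cat(shards, dim=0) for _device in range(len(shards))]
assert all(torch.equal(g, torch.tensor([0.0, 0.0, 1.0, 1.0, 2.0, 2.0])) for g in gathered)

# all-reduce: each device ends up with the elementwise sum of every shard.
reduced = [functools.reduce(torch.add, shards) for _device in range(len(shards))]
assert all(torch.equal(r, torch.tensor([3.0, 3.0])) for r in reduced)
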
@cat.override(AllOfType(SplitPrimitiveTensor))
-def cat_sharded(tensors: Sequence[SplitPrimitiveTensor], dim: int):
+def cat_split(
+ tensors: Sequence[SplitPrimitiveTensor], dim: int
+) -> SplitPrimitiveTensor:
assert len(tensors) > 0
assert all(
[
@@ -61,6 +102,7 @@ def cat_sharded(tensors: Sequence[SplitPrimitiveTensor], dim: int):
return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)
else:
# TODO: implement efficient cat along split dim.
+ # This would probably result in doing the concatenation on one device.
concatenated_unsharded = cat(
[shard for t in tensors for shard in t.shards], dim
)
@@ -88,7 +130,7 @@ def conv2d_all_split(
input.is_replicated or input.shard_dim == 1
), "Only sharding of input channel dimension is supported"
assert (
- weight.shard_dim == 0 and bias.shard_dim == 0
+ bias is None or weight.shard_dim == 0 and bias.shard_dim == 0
), "Only sharding of output channel dimension is supported"
# TODO: allow for implementation where we don't all-gather, but gather
@@ -146,7 +188,7 @@ def conv2d_replicated_input_split_weight_and_bias(
assert input.shard_count == weight.shard_count
assert bias is None or weight.shard_count == bias.shard_count
assert (
- weight.shard_dim == 0 and bias.shard_dim == 0
+ bias is None or weight.shard_dim == 0 and bias.shard_dim == 0
), "Only sharding of output channel dimension is supported"
assert groups == 1
@@ -189,7 +231,8 @@ def conv2d_split_weight_and_bias(
accum_dtype,
) -> SplitPrimitiveTensor:
assert accum_dtype is None, "accum_dtype not supported"
- assert weight.shard_count == bias.shard_count
+ if bias is not None:
+ assert weight.shard_count == bias.shard_count
# Output channels dimension is split.
if weight.shard_dim == 0 and groups == 1:
@@ -225,49 +268,67 @@ def conv2d_split_weight_and_bias(
@elementwise.override(ReplicatedTensor)
-def replicated_elementwise_unary(operator, x: ReplicatedTensor):
- partials = [operator(unbox_tensor(pt)) for pt in x.shards]
+def replicated_elementwise_unary(operator, x: ReplicatedTensor, *args, **kwargs):
+ partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards]
return ReplicatedTensor(ts=partials)
@elementwise.override(SplitPrimitiveTensor)
-def split_elementwise_unary(operator, x: SplitPrimitiveTensor):
- partials = [operator(unbox_tensor(pt)) for pt in x.shards]
+def split_elementwise_unary(operator, x: SplitPrimitiveTensor, *args, **kwargs):
+ partials = [operator(unbox_tensor(pt), *args, **kwargs) for pt in x.shards]
return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials)
+@elementwise.override(ReplicatedTensor, ReplicatedTensor)
+def replicated_elementwise_binary(
+ operator, x: ReplicatedTensor, y: ReplicatedTensor, *args, **kwargs
+):
+ assert x.shard_count == y.shard_count
+ shards = [
+ operator(unbox_tensor(shard_x), unbox_tensor(shard_y), *args, **kwargs)
+ for shard_x, shard_y in zip(x.shards, y.shards)
+ ]
+ return ReplicatedTensor(ts=shards)
+
+
@elementwise.override(SplitPrimitiveTensor, SplitPrimitiveTensor)
def split_elementwise_binary(
- operator, x: SplitPrimitiveTensor, y: SplitPrimitiveTensor
+ operator, x: SplitPrimitiveTensor, y: SplitPrimitiveTensor, *args, **kwargs
):
assert x.shard_count == y.shard_count
x_shard_dim, y_shard_dim = broadcast_dims([x.shard_dim, y.shard_dim], [x, y])
assert x_shard_dim == y_shard_dim
pt_xs = [unbox_tensor(pt) for pt in x.shards]
pt_ys = [unbox_tensor(pt) for pt in y.shards]
- partials = [operator(pt_x, pt_y) for pt_x, pt_y in zip(pt_xs, pt_ys)]
- return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials)
+ partials = [
+ operator(pt_x, pt_y, *args, **kwargs) for pt_x, pt_y in zip(pt_xs, pt_ys)
+ ]
+ return SplitPrimitiveTensor(
+ shard_dim=x.shard_dim,
+ shape=torch.broadcast_shapes(x.shape, y.shape),
+ ts=partials,
+ )
@elementwise.override(SplitPrimitiveTensor, Number)
def elementwise_binary_split_lhs_scalar_rhs(
- operator, x: SplitPrimitiveTensor, y: Number
+ operator, x: SplitPrimitiveTensor, y: Number, *args, **kwargs
):
pt_xs = [unbox_tensor(pt) for pt in x.shards]
- partials = [operator(pt_x, y) for pt_x in pt_xs]
+ partials = [operator(pt_x, y, *args, **kwargs) for pt_x in pt_xs]
return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials)
@elementwise.override(SplitPrimitiveTensor, Tensor)
def elementwise_binary_split_lhs_tensor_rhs(
- operator, x: SplitPrimitiveTensor, y: Tensor
+ operator, x: SplitPrimitiveTensor, y: Tensor, *args, **kwargs
):
- return elementwise(operator, x, replicate(y, count=x.shard_count))
+ return elementwise(operator, x, reshard_like(y, like=x), *args, **kwargs)
@elementwise.override(ReplicatedTensor, SplitPrimitiveTensor)
def elementwise_binary_replicated_lhs_sharder_rhs(
- operator, x: ReplicatedTensor, y: SplitPrimitiveTensor
+ operator, x: ReplicatedTensor, y: SplitPrimitiveTensor, *args, **kwargs
):
if x.shard_count != y.shard_count:
raise ValueError(
@@ -276,20 +337,76 @@ def elementwise_binary_replicated_lhs_sharder_rhs(
# A replicated tensor can be split with no cost.
# It is natural to propagate the split instead of the replication.
x_sharded = reshard_like(x, like=y)
- return elementwise(operator, x_sharded, y)
+ return elementwise(operator, x_sharded, y, *args, **kwargs)
@elementwise.override(SplitPrimitiveTensor, ReplicatedTensor)
def elementwise_binary_split_lhs_replicated_rhs(
- operator, x: SplitPrimitiveTensor, y: ReplicatedTensor
+ operator, x: SplitPrimitiveTensor, y: ReplicatedTensor, *args, **kwargs
):
assert len(y.shape) > 0, "0-rank not supported"
if x.shard_count != y.shard_count:
raise ValueError(
f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})"
)
+
+ shard_dim_in_res = broadcast_dim(x.shard_dim, [x.shape, y.shape])
+ shard_dim_in_y = unbroadcast_dim(shard_dim_in_res, [y.shape, x.shape])
+ is_shard_dim_broadcasted_in_y = (
+ shard_dim_in_y is None or y.shape[shard_dim_in_y] == 1
+ )
+ if is_shard_dim_broadcasted_in_y:
+ shards = [
+ elementwise(operator, x_shard, y_shard)
+ for x_shard, y_shard in zip(x.shards, y.shards)
+ ]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim_in_res)
+
y_sharded = reshard_like(y, like=x)
- return elementwise(operator, x, y_sharded)
+ return elementwise(operator, x, y_sharded, *args, **kwargs)
+
+
+@elementwise.override(ReplicatedTensor, UnreducedTensor)
+def elementwise_binary_replicated_lhs_unreduced_rhs(
+ operator, x: ReplicatedTensor, y: UnreducedTensor, *args, **kwargs
+):
+ if x.shard_count != y.shard_count:
+ raise ValueError(
+ f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})"
+ )
+ y_replicated = reshard_like(y, like=x)
+ return elementwise(operator, x, y_replicated, *args, **kwargs)
+
+
+@elementwise.override(ReplicatedTensor, Tensor)
+def elementwise_binary_replicated_lhs_unsharded_rhs(
+ operator, x: ReplicatedTensor, y: Tensor, *args, **kwargs
+):
+ y_replicated = reshard_like(y, like=x)
+ return elementwise(operator, x, y_replicated, *args, **kwargs)
+
+
+@elementwise.override(Tensor, ReplicatedTensor)
+def elementwise_binary_unsharded_lhs_replicated_rhs(
+ operator, x: Tensor, y: ReplicatedTensor, *args, **kwargs
+):
+ x_replicated = reshard_like(x, like=y)
+ return elementwise(operator, x_replicated, y, *args, **kwargs)
+
+
+# Embedding Lookup
+@embedding_lookup.override(ReplicatedTensor, ReplicatedTensor)
+def embedding_lookup_default(
+ input: ReplicatedTensor, embedding_matrix: ReplicatedTensor, dtype: torch.dtype
+):
+ assert input.shard_count == embedding_matrix.shard_count
+ shards = [
+ embedding_lookup(input_shard, embedding_matrix_shard, dtype)
+ for input_shard, embedding_matrix_shard in zip(
+ input.shards, embedding_matrix.shards
+ )
+ ]
+ return ReplicatedTensor(ts=shards)
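The split/replicated elementwise overrides above all rely on elementwise ops commuting with sharding; a plain-torch cross-check of that identity, using torch.chunk to stand in for the shards:

import torch

x, y = torch.randn(4, 6), torch.randn(4, 6)
shard_dim, shard_count = 1, 3

x_shards = torch.chunk(x, shard_count, dim=shard_dim)
y_shards = torch.chunk(y, shard_count, dim=shard_dim)

# Applying the operator shard-by-shard and re-concatenating along the shard
# dimension reproduces the unsharded result.
per_shard = [torch.add(xs, ys) for xs, ys in zip(x_shards, y_shards)]
assert torch.allclose(torch.cat(per_shard, dim=shard_dim), x + y)
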
@equal.override(ReplicatedTensor)
@@ -302,6 +419,77 @@ def equal_split(a: SplitPrimitiveTensor, b: AnyTensor) -> bool:
return a.is_deep_equal(b)
+@expand.override(SplitPrimitiveTensor)
+def expand_split(
+ tensor: SplitPrimitiveTensor, shape: List[int]
+) -> SplitPrimitiveTensor:
+ assert len(shape) == len(tensor.shape)
+ expanded_dims = [
+ i
+ for i, (old_dim, new_dim) in enumerate(zip(tensor.shape, shape))
+ if old_dim == 1 and new_dim != 1
+ ]
+ assert (
+ tensor.shard_dim not in expanded_dims
+ ), "Expanding a split dimension is not supported"
+
+ def set_element(l: List, idx: int, el: Any) -> List:
+ l[idx] = el
+ return l
+
+ shards = [
+ expand(
+ shard,
+ set_element(list(shape), tensor.shard_dim, shard.shape[tensor.shard_dim]),
+ )
+ for shard in tensor.shards
+ ]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
+@flatten.override(ReplicatedTensor)
+def flatten_replicated(
+ input: ReplicatedTensor, start_dim: int, end_dim: int
+) -> ReplicatedTensor:
+ shards = [shard.flatten(start_dim, end_dim) for shard in input.shards]
+ return ReplicatedTensor(ts=shards)
+
+
+@flatten.override(SplitPrimitiveTensor)
+def flatten_split(
+ input: SplitPrimitiveTensor, start_dim: int, end_dim: int
+) -> SplitPrimitiveTensor:
+ end_dim_resolved = len(input.shape) - 1 if end_dim == -1 else end_dim
+ assert input.shard_dim <= start_dim or end_dim_resolved < input.shard_dim, (
+ "Flattening of a sharded dimension that is not the leading dimension in the"
+ " flattening dimension range is not supported. This would result in a"
+ " block-cyclic sharding which is not implemented."
+ )
+ assert (
+ input.shard_dim != start_dim
+ or input.shape[input.shard_dim] % input.shard_count == 0
+ ), "If the leading flattening dimension is the split dimension, its size must be divisible by the shard count."
+ shards = [shard.flatten(start_dim, end_dim) for shard in input.shards]
+ shard_dim = (
+ input.shard_dim
+ if input.shard_dim <= start_dim
+ else input.shard_dim - (end_dim_resolved - start_dim)
+ )
+ return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)
+
+
+@gather.override(ReplicatedTensor, ReplicatedTensor)
+def gather_replicated(
+ input: ReplicatedTensor, dim: int, index: ReplicatedTensor
+) -> ReplicatedTensor:
+ assert input.shard_count == index.shard_count
+ shards = [
+ gather(input_shard, dim, index_shard)
+ for input_shard, index_shard in zip(input.shards, index.shards)
+ ]
+ return ReplicatedTensor(ts=shards)
+
+
@group_norm_affine.override(
SplitPrimitiveTensor, SplitPrimitiveTensor, SplitPrimitiveTensor
)
@@ -322,6 +510,78 @@ def shareded_group_norm_affine(input, weight, bias, *, num_groups, eps):
return SplitPrimitiveTensor(shard_dim=1, ts=result_shards)
+@index_copy_.override(SplitPrimitiveTensor, ReplicatedTensor, SplitPrimitiveTensor)
+def index_copy__split_replicated_split(
+ inout: SplitPrimitiveTensor,
+ dim: int,
+ index: ReplicatedTensor,
+ tensor: SplitPrimitiveTensor,
+) -> SplitPrimitiveTensor:
+ assert (
+ inout.shard_count == index.shard_count
+ and inout.shard_count == tensor.shard_count
+ )
+ assert inout.shard_dim == tensor.shard_dim
+ assert inout.shard_dim != dim
+ for inout_shard, index_shard, tensor_shard in zip(
+ inout.shards, index.shards, tensor.shards
+ ):
+ index_copy_(inout_shard, dim, index_shard, tensor_shard)
+ return inout
+
+
+@index_put_.override(
+ AllOfExprsVariadic(
+ IsOfType(SplitPrimitiveTensor),
+ IsOfType(SplitPrimitiveTensor),
+ IsOfType(Tensor, PrimitiveTensor, ReplicatedTensor),
+ )
+)
+def index_put__split(
+ inout: SplitPrimitiveTensor,
+ indices: Tuple[Union[Tensor, PrimitiveTensor, ReplicatedTensor]],
+ values: SplitPrimitiveTensor,
+) -> SplitPrimitiveTensor:
+ # TODO: verify that the values split dimension is not being indexed or implement
+ # this case.
+ indices = [replicate(idx, count=inout.shard_count) for idx in indices]
+ for i, shard in enumerate(inout.shards):
+ shard_indices = [idx.shards[i] for idx in indices]
+ shard.index_put_(shard_indices, values.shards[i])
+ return inout
+
+
+@index_select.override(ReplicatedTensor, ReplicatedTensor)
+def index_select_replicated(
+ tensor: ReplicatedTensor,
+ dim: int,
+ index: ReplicatedTensor,
+) -> ReplicatedTensor:
+ assert tensor.shard_count == index.shard_count
+ shards = [
+ index_select(tensor_shard, dim, index_shard)
+ for tensor_shard, index_shard in zip(tensor.shards, index.shards)
+ ]
+ return ReplicatedTensor(ts=shards)
+
+
+@index_select.override(SplitPrimitiveTensor, ReplicatedTensor)
+def index_select_split_replicated(
+ tensor: SplitPrimitiveTensor,
+ dim: int,
+ index: ReplicatedTensor,
+) -> SplitPrimitiveTensor:
+ assert tensor.shard_count == index.shard_count
+ assert (
+ dim != tensor.shard_dim
+ ), "Indexing along the split dimension is not supported."
+ shards = [
+ index_select(tensor_shard, dim, index_shard)
+ for tensor_shard, index_shard in zip(tensor.shards, index.shards)
+ ]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
@interpolate.override(ReplicatedTensor)
def interpolate_replicated(
input: ReplicatedTensor,
@@ -416,13 +676,14 @@ def matmul_replicated_lhs_split_rhs(
lhs: ReplicatedTensor, rhs: SplitPrimitiveTensor, *, transpose_rhs: bool
) -> SplitPrimitiveTensor | UnreducedTensor:
assert lhs.shard_count == rhs.shard_count
+ assert len(rhs.shape) == 2
if transpose_rhs:
return matmul(lhs, rhs.T)
rhs_reduction_dim = 1
if rhs_reduction_dim != rhs.shard_dim:
- lhs_reduction_dimension = 0
+ lhs_reduction_dimension = len(lhs.shape) - 1
lhs_split = reshard_split(
lhs, dim=lhs_reduction_dimension, count=lhs.shard_count
)
@@ -496,30 +757,108 @@ def matmul_split(
f"Cannot matmul split tensors of different shard_count: "
f"({lhs.shard_count} vs {rhs.shard_count})"
)
+ if transpose_rhs:
+ return matmul(lhs, rhs.T)
lhs_reduction_dim = len(lhs.shape) - 1
- rhs_reduction_dim = 1 if transpose_rhs else 0
+ rhs_reduction_dim = len(rhs.shape) - 2 if len(rhs.shape) > 1 else len(rhs.shape) - 1
# The reduction dimension is split on both tensors.
if lhs_reduction_dim == lhs.shard_dim and rhs_reduction_dim == rhs.shard_dim:
partials = [
- matmul(partial_lhs, partial_rhs, transpose_rhs=transpose_rhs)
+ matmul(partial_lhs, partial_rhs)
for partial_lhs, partial_rhs in zip(lhs.shards, rhs.shards)
]
return UnreducedTensor(ts=partials)
+ is_batched_matmul = len(lhs.shape) > 2 or len(rhs.shape) > 2
+ if (
+ is_batched_matmul
+ and len(lhs.shape) == len(rhs.shape)
+ and lhs.shard_dim == rhs.shard_dim
+ ):
+ # The same batch dim is sharded for both arguments.
+ shards = [
+ matmul(lhs_shard, rhs_shard)
+ for lhs_shard, rhs_shard in zip(lhs.shards, rhs.shards)
+ ]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim)
+
+ # -1 for missing parallel dim.
+ lhs_parallel_dim = len(lhs.shape) - 2
+ rhs_parallel_dim = len(rhs.shape) - 1 if len(rhs.shape) > 1 else -1
+
# One parallel dimension is split for each tensor.
- if lhs_reduction_dim != lhs.shard_dim and rhs_reduction_dim != rhs.shard_dim:
- if transpose_rhs:
- rhs = rhs.T
+ # Or lhs batch dim and rhs parallel dim are split.
+ if lhs.shard_dim <= lhs_parallel_dim and rhs_parallel_dim == rhs.shard_dim:
# We gather along the rhs shard dim.
# It is more natural to preserve the sharding axis of the input.
- shards = [sharded_cat(matmul(lhs_shard, rhs)) for lhs_shard in lhs.shards]
- return SplitPrimitiveTensor(ts=shards, shard_dim=lhs.shard_dim)
+    # TODO: This assumes non-peered memory. We prepare the operands so they are
+    # available on the required devices before the matmul. Whether that is
+    # necessary should eventually be decided from a device-topology config.
+ replicated_rhs = replicate(rhs, count=lhs.shard_count)
+ return matmul(lhs, replicated_rhs)
assert False, "Sharding configuration not supported"
+# Scaled dot product attention
+@scaled_dot_product_attention.override(
+ SplitPrimitiveTensor,
+ SplitPrimitiveTensor,
+ SplitPrimitiveTensor,
+ Optional[ReplicatedTensor],
+)
+def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor:
+ if q.shard_count != k.shard_count or q.shard_count != v.shard_count:
+ raise ValueError("Incompatible number of shards for qkv")
+
+ if a and q.shard_count != a.shard_count:
+ raise ValueError(
+ f"Incompatible number of shards for a ({a.shard_count}) should be ({q.shard_count})"
+ )
+
+ if q.shard_dim != k.shard_dim or q.shard_dim != v.shard_dim:
+ raise ValueError("Incompatible shard dim across qkv")
+
+ if q.shard_dim > len(q.shards[0].shape) - 2:
+ raise ValueError("Sharding must occur as batch dimension")
+
+ a_shards = [None] * q.shard_count
+ if a is not None:
+ a_shards = a.shards
+
+ output_shards = []
+ for q_s, k_s, v_s, a_s in zip(q.shards, k.shards, v.shards, a_shards):
+ o_s = scaled_dot_product_attention(
+ q_s, k_s, v_s, a_s, is_causal=is_causal, scale=scale
+ )
+ output_shards.append(o_s)
+
+ return SplitPrimitiveTensor(ts=output_shards, shard_dim=q.shard_dim)
+
+
+@mean.override(ReplicatedTensor)
+def mean_replicated(
+ x: ReplicatedTensor,
+ dim: Union[int, List[int]],
+ keepdim: bool,
+ *,
+ dtype: torch.dtype,
+) -> ReplicatedTensor:
+ shards = [mean(shard, dim=dim, keepdim=keepdim, dtype=dtype) for shard in x.shards]
+ return ReplicatedTensor(ts=shards)
+
+
+@module_register_buffer.override(torch.nn.Module, ShardedTensor)
+def module_register_buffer_sharded(
+ module: torch.nn.Module, name: str, tensor: ShardedTensor
+) -> None:
+ for i, shard in enumerate(tensor.shards):
+ module_register_buffer(module, f"{name}__shard__{i}", shard)
+ setattr(module, name, tensor)
+
+
@permute.override(SplitPrimitiveTensor)
def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]):
permuted_shards = [permute(shard, dims) for shard in tensor.shards]
@@ -528,23 +867,60 @@ def permute_split(tensor: SplitPrimitiveTensor, dims: List[int]):
@permute.override(ReplicatedTensor)
-def permute_split(tensor: ReplicatedTensor, dims: List[int]):
+def permute_replicated(tensor: ReplicatedTensor, dims: List[int]):
permuted_shards = [permute(shard, dims) for shard in tensor.shards]
return ReplicatedTensor(ts=permuted_shards)
+@repeat.override(ReplicatedTensor)
+def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor:
+ shards = [repeat(shard, *sizes) for shard in input.shards]
+ return ReplicatedTensor(ts=shards)
+
+
@replicate.override(ReplicatedTensor)
def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor:
if input.shard_count != count:
raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
- assert input.shard_count == count
return input
+@replicate.override(SplitPrimitiveTensor)
+def replicate_split(input: SplitPrimitiveTensor, *, count: int) -> ReplicatedTensor:
+ if input.shard_count != count:
+ raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
+ return all_gather(input)
+
+
+@replicate.override(UnreducedTensor)
+def replicate_unreduced(input: UnreducedTensor, *, count: int) -> ReplicatedTensor:
+ if input.shard_count != count:
+ raise ValueError(f"Number of shards not equal ({input.shard_count} != {count})")
+ return all_reduce(input)
+
+
@replicate.override(Tensor)
def replicate_unsharded(input, *, count: int) -> ReplicatedTensor:
torch_input = unbox_tensor(input)
- return ReplicatedTensor(ts=torch_input, shard_count=count)
+    # If we have a plain torch input, replicating it implies transferring it to each device:
+ torch_inputs = [transfer_to_logical_device(torch_input, i) for i in range(count)]
+ return ReplicatedTensor(ts=torch_inputs)
+
+
+@reshape.override(SplitPrimitiveTensor)
+def reshape_split(
+ tensor: SplitPrimitiveTensor, shape: List[int]
+) -> SplitPrimitiveTensor:
+ if _reshape_get_single_split_dim(tensor.shape, shape) is not None:
+ return view(tensor, shape)
+
+ flatten_dim_range = _reshape_get_flatten_dim_range(tensor.shape, shape)
+ if flatten_dim_range is not None:
+ return flatten(tensor, flatten_dim_range[0], flatten_dim_range[1] - 1)
+
+ raise ValueError(
+ f"Unsupported reshaping of sharded split tensor of shape {tensor.shape} to shape {shape}"
+ )
@reshard.override(Tensor, sharding.Split)
@@ -572,7 +948,13 @@ def make_value(input: Theta | InferenceTensor, spec) -> dict | InferenceTensor:
result.name = input.name
return result
- return Theta({k: make_value(input(k), spec[k]) for k in input.keys})
+ return Theta(
+ {
+ k: make_value(input(k), spec[k])
+ for k in input.keys
+ if not isinstance(spec[k], sharding.Ignore)
+ }
+ )
@reshard.override(Theta, sharding.ThetaLayerSharding)
@@ -691,13 +1073,23 @@ def reshard_like_split_to_split(
return tensor
-# Sharded sum.
+@reshard_like.override(UnreducedTensor, ReplicatedTensor)
+def reshard_like_unreduced_to_replicated(
+ tensor: UnreducedTensor, like: ReplicatedTensor
+) -> ReplicatedTensor:
+ return replicate(tensor, count=like.shard_count)
@sharded_cat.override(SplitPrimitiveTensor)
-def sharded_cat_unsharded(maybe_sharded: SplitPrimitiveTensor):
- shard_ts = [t.as_torch() for t in maybe_sharded.shards]
- return torch.cat(shard_ts, dim=maybe_sharded.shard_dim)
+def sharded_cat_unsharded(tensor: SplitPrimitiveTensor):
+ shard_ts = [
+ transfer_to_logical_device(shard.as_torch(), 0) if i != 0 else shard.as_torch()
+ for i, shard in enumerate(tensor.shards)
+ ]
+ return torch.cat(shard_ts, dim=tensor.shard_dim)
+
+
+# Sharded sum.
def _sharded_sum_sharded(tensor: ShardedTensor) -> Tensor:
@@ -708,16 +1100,67 @@ def _sharded_sum_sharded(tensor: ShardedTensor) -> Tensor:
@sharded_sum.override(SplitPrimitiveTensor)
-def sharded_sum_split(maybe_sharded: SplitPrimitiveTensor):
+def sharded_sum_split(maybe_sharded: SplitPrimitiveTensor) -> Tensor:
# TODO: Should implement as an all reduce.
return _sharded_sum_sharded(maybe_sharded)
@sharded_sum.override(UnreducedTensor)
-def sharded_sum_unreduced(maybe_sharded: UnreducedTensor):
+def sharded_sum_unreduced(maybe_sharded: UnreducedTensor) -> Tensor:
return _sharded_sum_sharded(maybe_sharded)
+@softmax.override(SplitPrimitiveTensor)
+def softmax_split(
+ tensor: SplitPrimitiveTensor, dim: Optional[int], dtype: Optional[torch.dtype]
+) -> Tensor:
+ dim = dim if dim is None or dim >= 0 else len(tensor.shape) + dim
+ assert (
+ dim is not None and dim != tensor.shard_dim
+ ), "Softmax along split dimension is not supported."
+ shards = [softmax(shard, dim=dim, dtype=dtype) for shard in tensor.shards]
+ return SplitPrimitiveTensor(
+ ts=shards, shard_dim=tensor.shard_dim, shape=tensor.shape
+ )
+
+
+@to.override(ReplicatedTensor)
+def to_replicated(tensor: ReplicatedTensor, *args, **kwargs):
+ shards = [to(shard, *args, **kwargs) for shard in tensor.shards]
+ return ReplicatedTensor(ts=shards)
+
+
+@to.override(SplitPrimitiveTensor)
+def to_split(tensor: SplitPrimitiveTensor, *args, **kwargs):
+ shards = [to(shard, *args, **kwargs) for shard in tensor.shards]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
+@transpose.override(SplitPrimitiveTensor)
+def transpose_split(
+ tensor: SplitPrimitiveTensor, dim0: int, dim1: int
+) -> SplitPrimitiveTensor:
+ shards = [transpose(shard, dim0, dim1) for shard in tensor.shards]
+ shard_dim = tensor.shard_dim
+ if shard_dim == dim0:
+ shard_dim = dim1
+ elif shard_dim == dim1:
+ shard_dim = dim0
+ return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)
+
+
+@unflatten.override(SplitPrimitiveTensor)
+def unflatten_split(
+ input: SplitPrimitiveTensor, dim: int, sizes: Tuple[int]
+) -> SplitPrimitiveTensor:
+ assert dim != input.shard_dim, "Unflattening the split dimension is not supported."
+ shards = [unflatten(shard, dim, sizes) for shard in input.shards]
+ shard_dim = input.shard_dim
+ if dim < shard_dim:
+ shard_dim += len(sizes) - 1
+ return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)
+
+
@unshard.override(ReplicatedTensor)
def unshard_replicated(input: ReplicatedTensor) -> Tensor:
return input.shards[0]
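`transpose_split` and `unflatten_split` above only relabel the shard dimension; the data in each shard is untouched. A plain-torch check of that bookkeeping for a tensor split along dim 2:

import torch

full = torch.randn(2, 3, 8)
shards = torch.chunk(full, 2, dim=2)  # split along dim 2

# transpose(1, 2): the shard dimension follows the swap, 2 -> 1.
assert torch.equal(
    torch.cat([s.transpose(1, 2) for s in shards], dim=1), full.transpose(1, 2)
)

# unflatten(0, (1, 2)): one new dim is inserted before the shard dim, so it shifts 2 -> 3.
assert torch.equal(
    torch.cat([s.unflatten(0, (1, 2)) for s in shards], dim=3), full.unflatten(0, (1, 2))
)
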
@@ -728,6 +1171,160 @@ def unshard_split(input: SplitPrimitiveTensor) -> Tensor:
return sharded_cat(input)
+@unshard.override(UnreducedTensor)
+def unshard_unreduced(input: UnreducedTensor) -> Tensor:
+ shards = input.shards
+ shards = [
+ shard if i == 0 else transfer_to_logical_device(shard, 0)
+ for i, shard in enumerate(shards)
+ ]
+ return functools.reduce(lambda x, y: elementwise(torch.add, x, y), shards)
+
+
@unshard.override(Tensor)
def unshard_unsharded(input: Tensor) -> Tensor:
return input
+
+
+def _reshape_get_flatten_dim_range(
+ from_shape: List[int], to_shape: List[int]
+) -> Optional[Tuple[int, int]]:
+ """If a reshape would flatten a range of dimensions return that index range [begin, end).
+ If the reshape is not of that kind return `None`."""
+ flatten_start_len = _reshape_get_single_split_dim(to_shape, from_shape)
+ if flatten_start_len is None:
+ return None
+ start, length = flatten_start_len
+ return start, start + length
+
+
+def _reshape_infer_dynamic_dim(
+ shape1: List[int], shape2: List[int]
+) -> Tuple[List[int], List[int]]:
+ assert (
+ len([d for d in list(shape1) + list(shape2) if d < 0]) <= 1
+ ), "Only one dynamic dimension is allowed"
+ shape1_dynamic_dims = [i for i, d in enumerate(shape1) if d <= 0]
+ if len(shape1_dynamic_dims) > 0:
+ s2, s1 = _reshape_infer_dynamic_dim(shape2, shape1)
+ return s1, s2
+
+ shape2_dynamic_dims = [i for i, d in enumerate(shape2) if d <= 0]
+ if len(shape2_dynamic_dims) == 0:
+ return shape1, shape2
+ shape2_dynamic_dim = shape2_dynamic_dims[0]
+ shape1_size = math.prod(shape1)
+ shape2_size_without_dynamic_dim = math.prod(d for d in shape2 if d > 0)
+ shape2_res = list(shape2)
+ assert shape1_size % shape2_size_without_dynamic_dim == 0
+ shape2_res[shape2_dynamic_dim] = shape1_size // shape2_size_without_dynamic_dim
+ assert shape2_res[shape2_dynamic_dim] > 0
+ return shape1, shape2_res
+
+
+def _reshape_get_single_split_dim(
+ from_shape: List[int], to_shape: List[int]
+) -> Optional[Tuple[int, int]]:
+ """If a reshape would split a single dimension, return its index and the length of the new dimensions.
+ If the reshape is not of that kind return `None`.
+ E.g.
+ _reshape_get_single_split_dim(from_shape=(2, 12, 5), to_shape=(2, 3, 4, 5))
+ results in
+ (1, 2)"""
+ from_shape, to_shape = _reshape_infer_dynamic_dim(from_shape, to_shape)
+
+ if len(to_shape) < len(from_shape):
+ return None
+ i = longest_equal_range(from_shape, to_shape)
+ split_dims_length = len(to_shape) - len(from_shape) + 1
+ if i == len(from_shape):
+ return (
+ i,
+ split_dims_length,
+ )
+ j = len(to_shape) - longest_equal_range(reversed(from_shape), reversed(to_shape))
+ assert i < j
+ expected_split_dim_size = math.prod(to_shape[i:j])
+ if expected_split_dim_size == 1:
+ # 1's were inserted.
+ return (
+ i,
+ split_dims_length,
+ )
+ if expected_split_dim_size != from_shape[i]:
+ return None
+ return (
+ i,
+ split_dims_length,
+ )
+
+
+@unsqueeze.override(SplitPrimitiveTensor)
+def unsqueeze_split(tensor: SplitPrimitiveTensor, dim: int) -> SplitPrimitiveTensor:
+ shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards]
+ shard_dim = tensor.shard_dim
+ dim_resolved = dim if dim >= 0 else dim + len(tensor.shape) + 1
+ if shard_dim >= dim_resolved:
+ shard_dim += 1
+ return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)
+
+
+@unsqueeze.override(ReplicatedTensor)
+def unsqueeze_replicated(tensor: ReplicatedTensor, dim: int) -> ReplicatedTensor:
+ shards = [torch.unsqueeze(unbox_tensor(shard), dim) for shard in tensor.shards]
+ return ReplicatedTensor(ts=shards)
+
+
+@view.override(SplitPrimitiveTensor)
+def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitiveTensor:
+ view_split_range = _reshape_get_single_split_dim(tensor.shape, shape)
+ if view_split_range is None:
+ raise ValueError(
+ "Only taking a tensor view where splitting a single dimension is supported"
+ )
+ view_split_dim = view_split_range[0]
+
+ if view_split_dim == tensor.shard_dim:
+ if tensor.shape[view_split_dim] % tensor.shard_count != 0:
+ raise ValueError(
+ "Only splitting a dimension that is multiple of the shard count is supported"
+ )
+ if shape[view_split_dim] % tensor.shard_count != 0:
+ raise ValueError(
+ "The resulting leading splitting dimension must be multiple of the shard count"
+ )
+
+ shard_dim = tensor.shard_dim
+ if shard_dim > view_split_dim:
+ new_dims_count = len(shape) - len(tensor.shape)
+ shard_dim += new_dims_count
+ new_shard_shape = list(shape)
+ new_shard_shape[shard_dim] //= tensor.shard_count
+ shards = [view(shard, new_shard_shape) for shard in tensor.shards]
+ res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
+ assert math.prod(res.shape) == math.prod(tensor.shape)
+ return res
+
+
+@view_as_complex.override(SplitPrimitiveTensor)
+def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor:
+ shards = [view_as_complex(shard) for shard in tensor.shards]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
+@view_as_complex.override(ReplicatedTensor)
+def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor:
+ shards = [view_as_complex(shard) for shard in tensor.shards]
+ return ReplicatedTensor(ts=shards)
+
+
+@view_as_real.override(SplitPrimitiveTensor)
+def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor:
+ shards = [view_as_real(shard) for shard in tensor.shards]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
+@view_as_real.override(ReplicatedTensor)
+def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor:
+ shards = [view_as_real(shard) for shard in tensor.shards]
+ return ReplicatedTensor(ts=shards)
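The private reshape helpers near the end of this file classify a sharded reshape as either splitting one dimension (lowered to `view`) or flattening a contiguous range (lowered to `flatten`). Using the docstring's own example, a plain-torch restatement of the two cases:

import torch

t = torch.randn(2, 12, 5)

# Splitting a single dimension: (2, 12, 5) -> (2, 3, 4, 5) is a pure view;
# _reshape_get_single_split_dim reports (index=1, new_dims=2) for this pair.
split = t.view(2, 3, 4, 5)

# The inverse flattens the contiguous range [1, 3); reshape_split lowers it to
# flatten(start_dim=1, end_dim=2).
assert torch.equal(split.flatten(1, 2), t)
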
diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py
index 6350d3e0b..408f00ec7 100644
--- a/sharktank/sharktank/ops/signatures.py
+++ b/sharktank/sharktank/ops/signatures.py
@@ -11,34 +11,58 @@
import torch
import numbers
from torch import Tensor, dtype
-from ..types import AnyTensor, ShardedTensor, Theta, sharding
+from ..types import AnyTensor, ShardedTensor, Theta, sharding, InferenceTensor
from numbers import Number
+import math
from ._registry import *
__all__ = [
"all_gather",
+ "all_reduce",
"cat",
"conv2d",
+ "einsum_2args",
"elementwise",
"embedding_lookup",
"equal",
+ "expand",
+ "flatten",
+ "gather",
+ "gelu_tanh_approximation",
+ "get_index",
"gemm",
"group_norm_affine",
"layer_norm",
+ "index_copy_",
+ "index_put_",
+ "index_select",
"interpolate",
"linear",
"matmul",
+ "mean",
+ "module_register_buffer",
"permute",
"rms_norm",
+ "repeat",
"replicate",
+ "reshape",
"reshard",
"reshard_split",
"reshard_like",
"scaled_dot_product_attention",
"sharded_cat",
"sharded_sum",
+ "softmax",
+ "to",
+ "transfer_to_logical_device",
+ "transpose",
+ "unflatten",
"unshard",
+ "unsqueeze",
+ "view",
+ "view_as_complex",
+ "view_as_real",
]
IntOrSequenceInt = Union[int, Sequence[int]]
@@ -46,6 +70,7 @@
@overridable
def all_gather(maybe_sharded: AnyTensor, *, dim: int | None = None) -> AnyTensor:
+ "Gather/concatenate on all devices along dimension `dim`."
...
@@ -62,6 +87,23 @@ def _all_gather_trampoline(
d.fail(tensors)
+@overridable
+def all_reduce(tensor: AnyTensor) -> AnyTensor:
+ "Reduce on all devices."
+ ...
+
+
+@all_reduce.trampoline
+def _all_reduce_trampoline(d: SignatureDispatcher, tensor: AnyTensor):
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
@overridable
def cat(tensors: Tuple[AnyTensor, ...] | List[AnyTensor], dim: int = 0) -> AnyTensor:
...
@@ -132,16 +174,52 @@ def _conv2d_trampoline(
@overridable
-def elementwise(operator, *args: AnyTensor) -> AnyTensor:
+def einsum_2args(
+ input0: AnyTensor,
+ input1: AnyTensor,
+ einsum_str: str,
+ *,
+ accum_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+ """Executes a given Einstein summation notation string on the provided tensors.
+
+ Equivalent to:
+ ```
+ y = torch.einsum(einsum_str, input0, input1)
+ ```
+ """
+ raise NotImplementedError
+
+
+@einsum_2args.trampoline
+def _einsum_trampoline(
+ d: SignatureDispatcher, input0: AnyTensor, input1: AnyTensor, einsum_str: str
+):
+ tensors = (input0, input1)
+ for override in d.find_overrides(tensors):
+ result = override(input0, input1, einsum_str)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def elementwise(operator, *args, **kwargs) -> AnyTensor:
"""Applies an elementwise operator against arguments."""
raise NotImplementedError
@elementwise.trampoline
-def _elementwise_trampoline(d: SignatureDispatcher, operator, *args: AnyTensor):
- tensors = args
+def _elementwise_trampoline(d: SignatureDispatcher, operator, *args, **kwargs):
+ tensors = []
+ for a in args:
+ if isinstance(a, (Tensor, InferenceTensor)):
+ tensors.append(a)
+ else:
+ break
for override in d.find_overrides(tensors):
- result = override(operator, *args)
+ result = override(operator, *args, **kwargs)
if result is not NotImplemented:
return override, result
else:
@@ -212,6 +290,109 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor):
d.fail(tensors)
+@overridable
+def expand(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
+ """See torch.Tensor.expand"""
+ ...
+
+
+@expand.trampoline
+def _expand_trampoline(
+ d: SignatureDispatcher, tensor: AnyTensor, shape: List[int]
+) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, shape)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def get_index(
+ tensor: AnyTensor,
+ key: slice,
+) -> torch.Tensor:
+ """Indexes the tensor using the key.
+
+ Equivalent to:
+ ```
+ out = tensor[key]
+ ```
+ """
+ raise NotImplementedError
+
+
+@get_index.trampoline
+def _get_index_trampoline(d: SignatureDispatcher, tensor: AnyTensor, key: slice):
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, key)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def flatten(input: AnyTensor, start_dim: int = 0, end_dim: int = -1) -> AnyTensor:
+ """See torch.flatten"""
+ ...
+
+
+@flatten.trampoline
+def _flatten_trampoline(
+ d: SignatureDispatcher, input: AnyTensor, start_dim: int = 0, end_dim: int = -1
+) -> AnyTensor:
+ dispatch_args = (input,)
+ for override in d.find_overrides(dispatch_args):
+ result = override(input, start_dim, end_dim)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
+@overridable
+def gather(input: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor:
+ """See torch.gather"""
+ ...
+
+
+@gather.trampoline
+def _gather_trampoline(
+ d: SignatureDispatcher, input: AnyTensor, dim: int, index: AnyTensor
+) -> AnyTensor:
+ dispatch_args = (
+ input,
+ index,
+ )
+ for override in d.find_overrides(dispatch_args):
+ result = override(input, dim, index)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
+def gelu_tanh_approximation(input: AnyTensor) -> AnyTensor:
+ """Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ Approximation with tanh"""
+ return (
+ 0.5
+ * input
+ * (
+ 1.0
+ + elementwise(
+ torch.tanh,
+ math.sqrt(2.0 / math.pi)
+ * (input + 0.044715 * elementwise(torch.pow, input, 3.0)),
+ )
+ )
+ )
+
+
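`gelu_tanh_approximation` computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))) through the elementwise dispatcher so it also works on sharded tensors. A plain-torch restatement of the same formula, cross-checked against PyTorch's built-in tanh-approximated GELU:

import math
import torch

x = torch.linspace(-3.0, 3.0, steps=25)
tanh_gelu = 0.5 * x * (
    1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))
)
assert torch.allclose(tanh_gelu, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-6)
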
@overridable
def gemm(
a: AnyTensor,
@@ -278,6 +459,75 @@ def _group_norm_affine_trampoline(
d.fail(tensors)
+@overridable
+def index_copy_(
+ inout: AnyTensor, dim: int, index: AnyTensor, tensor: AnyTensor
+) -> AnyTensor:
+ """See torch.Tensor.index_copy_"""
+ ...
+
+
+@index_copy_.trampoline
+def _index_copy__trampoline(
+ d: SignatureDispatcher,
+ inout: AnyTensor,
+ dim: int,
+ index: AnyTensor,
+ tensor: AnyTensor,
+) -> AnyTensor:
+ tensors = (inout, index, tensor)
+ for override in d.find_overrides(tensors):
+ result = override(inout, dim, index, tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def index_put_(
+ inout: AnyTensor, indices: Tuple[AnyTensor], values: AnyTensor
+) -> AnyTensor:
+ """See torch.Tensor.index_put_"""
+ ...
+
+
+@index_put_.trampoline
+def _index_put__trampoline(
+ d: SignatureDispatcher,
+ inout: AnyTensor,
+ indices: Tuple[AnyTensor],
+ values: AnyTensor,
+) -> AnyTensor:
+ # We change the order for the variadic indices to be last.
+ tensors = (inout, values, *indices)
+ for override in d.find_overrides(tensors):
+ result = override(inout, indices, values)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def index_select(tensor: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor:
+ """See torch.Tensor.index_select"""
+ ...
+
+
+@index_select.trampoline
+def _index_select_trampoline(
+ d: SignatureDispatcher, tensor: AnyTensor, dim: int, index: AnyTensor
+) -> AnyTensor:
+ tensors = (tensor, index)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, dim, index)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
@overridable
def interpolate(
input: AnyTensor,
@@ -416,7 +666,6 @@ def _matmul_trampoline(
d: SignatureDispatcher, lhs, rhs, *, transpose_rhs: bool = False
):
tensors = (lhs, rhs)
- assert isinstance(rhs, numbers.Number) or len(rhs.shape) == 2
for override in d.find_overrides(tensors):
result = override(lhs, rhs, transpose_rhs=transpose_rhs)
if result is not NotImplemented:
@@ -444,6 +693,57 @@ def _permute_trampoline(d: SignatureDispatcher, tensor: AnyTensor, dims: List[in
d.fail(tensors)
+@overridable
+def mean(
+ x: AnyTensor,
+ dim: Union[int, List[int]],
+ keepdim: bool = False,
+ *,
+    dtype: Optional[torch.dtype] = None,
+) -> AnyTensor:
+ """See torch.mean"""
+ raise NotImplementedError
+
+
+@mean.trampoline
+def _mean_trampoline(
+ d: SignatureDispatcher,
+ x: AnyTensor,
+ dim: Union[int, List[int]],
+ keepdim: bool = False,
+ *,
+    dtype: Optional[torch.dtype] = None,
+) -> AnyTensor:
+ tensors = (x,)
+ for override in d.find_overrides(tensors):
+ result = override(x, dim, keepdim, dtype=dtype)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def module_register_buffer(
+ module: torch.nn.Module, name: str, tensor: AnyTensor
+) -> None:
+ """Register the tensor into the module. See torch.nn.Module.register_buffer."""
+ ...
+
+
+@module_register_buffer.trampoline
+def _module_register_buffer_trampoline(
+ d: SignatureDispatcher, module: torch.nn.Module, name: str, tensor: AnyTensor
+) -> None:
+ args = (module, tensor)
+ for override in d.find_overrides(args):
+ result = override(module, name, tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(args)
+
+
@overridable
def rms_norm(x: AnyTensor, weight: AnyTensor, *, epsilon: float) -> AnyTensor:
"""Computes the full, unbiased RMS normalization of an input."""
@@ -463,6 +763,25 @@ def _rms_norm_trampoline(
d.fail(tensors)
+@overridable
+def repeat(input: AnyTensor, *sizes: List[int]) -> AnyTensor:
+ """See torch.Tensor.repeat"""
+ ...
+
+
+@repeat.trampoline
+def _repeat_trampoline(
+ d: SignatureDispatcher, input: AnyTensor, *sizes: List[int]
+) -> AnyTensor:
+ dispatch_args = (input,)
+ for override in d.find_overrides(dispatch_args):
+ result = override(input, *sizes)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
@overridable
def replicate(input: AnyTensor, count: int) -> ShardedTensor:
"""Replicate across devices.
@@ -486,7 +805,7 @@ def _replicate_trampoline(
@overridable
def scaled_dot_product_attention(
- q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor]
+    q: AnyTensor,
+    k: AnyTensor,
+    v: AnyTensor,
+    a: Optional[AnyTensor],
+    is_causal: bool = False,
+    scale: Optional[float] = None,
) -> AnyTensor:
"""Computes the scaled dot product attention using QKV."""
raise NotImplementedError
@@ -499,16 +818,38 @@ def _scaled_dot_product_attention(
k: AnyTensor,
v: AnyTensor,
a: Optional[AnyTensor],
+ is_causal: bool = False,
+ scale: Optional[float] = None,
):
tensors = (q, k, v, a)
for override in d.find_overrides(tensors):
- result = override(q, k, v, a)
+ result = override(q, k, v, a, is_causal=is_causal, scale=scale)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)
+@overridable
+def reshape(input: AnyTensor, shape: List[int]) -> AnyTensor:
+ """Returns a tensor with the same data and number of elements as input, but with
+ the specified shape.
+ See torch.reshape.
+ """
+ ...
+
+
+@reshape.trampoline
+def _reshape_trampoline(d: SignatureDispatcher, input, shape) -> AnyTensor:
+ dispatch_args = (input,)
+ for override in d.find_overrides(dispatch_args):
+ result = override(input, shape)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
@overridable
def reshard(
input: AnyTensor | Theta,
@@ -616,6 +957,104 @@ def _sharded_sum_trampoline(d: SignatureDispatcher, maybe_sharded: AnyTensor):
d.fail(tensors)
+@overridable
+def softmax(
+ tensor: AnyTensor, dim: Optional[int] = None, dtype: Optional[torch.dtype] = None
+) -> AnyTensor:
+ """See torch.nn.functional.softmax"""
+ ...
+
+
+@softmax.trampoline
+def _softmax_trampoline(
+ d: SignatureDispatcher,
+ tensor: AnyTensor,
+ dim: Optional[int] = None,
+ dtype: Optional[torch.dtype] = None,
+) -> AnyTensor:
+ dispatch_args = [tensor]
+ for override in d.find_overrides(dispatch_args):
+ result = override(tensor, dim=dim, dtype=dtype)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
+@overridable
+def to(tensor: AnyTensor, *args, **kwargs) -> AnyTensor:
+ """See torch.Tensor.to"""
+ ...
+
+
+@to.trampoline
+def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):
+ dispatch_args = [tensor]
+ for override in d.find_overrides(dispatch_args):
+ result = override(tensor, *args, **kwargs)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
+@overridable
+def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
+ """Transfer the tensor to a device with ordinal `ordinal`."""
+ ...
+
+
+@transfer_to_logical_device.trampoline
+def _transfer_to_logical_device_trampoline(
+ d: SignatureDispatcher, tensor: AnyTensor, ordinal: int
+):
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, ordinal)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def transpose(tensor: AnyTensor, dim0: int, dim1: int) -> AnyTensor:
+ """See torch.transpose"""
+ ...
+
+
+@transpose.trampoline
+def _transpose_trampoline(
+ d: SignatureDispatcher, tensor: AnyTensor, dim0: int, dim1: int
+) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, dim0, dim1)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def unflatten(input: AnyTensor, dim: int, sizes: Tuple[int]) -> AnyTensor:
+ """See torch.unflatten"""
+ ...
+
+
+@unflatten.trampoline
+def _unflatten_trampoline(
+ d: SignatureDispatcher, input: AnyTensor, dim: int, sizes: Tuple[int]
+) -> AnyTensor:
+ dispatch_args = (input,)
+ for override in d.find_overrides(dispatch_args):
+ result = override(input, dim, sizes)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(dispatch_args)
+
+
@overridable
def unshard(tensor: AnyTensor) -> AnyTensor:
"""Return the tensor that has the same elements and shape, but is not sharded."""
@@ -631,3 +1070,75 @@ def _unshard_trampoline(d: SignatureDispatcher, tensor: AnyTensor):
return override, result
else:
d.fail(tensors)
+
+
+@overridable
+def unsqueeze(tensor: AnyTensor, dim: int) -> AnyTensor:
+ """See torch.unsqueeze"""
+ ...
+
+
+@unsqueeze.trampoline
+def _unsqueeze_trampoline(
+ d: SignatureDispatcher, tensor: AnyTensor, dim: int
+) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, dim)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def view(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
+ """See torch.Tensor.view"""
+ ...
+
+
+@view.trampoline
+def _view_trampoline(
+ d: SignatureDispatcher, tensor: AnyTensor, shape: List[int]
+) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor, shape)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def view_as_complex(tensor: AnyTensor) -> AnyTensor:
+ """See torch.Tensor.view_as_complex"""
+ ...
+
+
+@view_as_complex.trampoline
+def _view_as_complex_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def view_as_real(tensor: AnyTensor) -> AnyTensor:
+ """See torch.Tensor.view_as_complex"""
+ ...
+
+
+@view_as_real.trampoline
+def _view_as_real_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
diff --git a/shortfin/shortfin/__init__.py b/sharktank/sharktank/serving_poc/__init__.py
similarity index 100%
rename from shortfin/shortfin/__init__.py
rename to sharktank/sharktank/serving_poc/__init__.py
diff --git a/shortfin/shortfin/framework/logging.py b/sharktank/sharktank/serving_poc/framework/logging.py
similarity index 95%
rename from shortfin/shortfin/framework/logging.py
rename to sharktank/sharktank/serving_poc/framework/logging.py
index 4843ee358..fe5ffc069 100644
--- a/shortfin/shortfin/framework/logging.py
+++ b/sharktank/sharktank/serving_poc/framework/logging.py
@@ -24,7 +24,7 @@ def __init__(self):
def _setup_logger():
- root_logger = logging.getLogger("shortfin")
+ root_logger = logging.getLogger("sharktank.serving_poc")
root_logger.setLevel(logging.DEBUG)
default_handler = logging.StreamHandler(sys.stderr)
default_handler.flush = sys.stderr.flush
diff --git a/shortfin/shortfin/framework/session.py b/sharktank/sharktank/serving_poc/framework/session.py
similarity index 99%
rename from shortfin/shortfin/framework/session.py
rename to sharktank/sharktank/serving_poc/framework/session.py
index c28b81f5b..28af0fd44 100644
--- a/shortfin/shortfin/framework/session.py
+++ b/sharktank/sharktank/serving_poc/framework/session.py
@@ -319,7 +319,7 @@ def __init__(self, session: DeviceSession, index: int = 0):
self._semaphore = session.device.create_semaphore(0)
self._step = 0
- def execute_sequential(self, command_buffers: list[HalCommandBuffer]):
+ def execute_sequential(self, command_buffer: HalCommandBuffer):
"""Executes a list of command buffers at the current step, advancing to the
next.
"""
@@ -329,7 +329,7 @@ def execute_sequential(self, command_buffers: list[HalCommandBuffer]):
self._step = next_step
sem = self._semaphore
self._device.queue_execute(
- command_buffers, [(sem, current_step)], [(sem, next_step)]
+ command_buffer, [(sem, current_step)], [(sem, next_step)]
)
def current_fence(self) -> HalFence:
diff --git a/shortfin/shortfin/llm/__init__.py b/sharktank/sharktank/serving_poc/llm/__init__.py
similarity index 100%
rename from shortfin/shortfin/llm/__init__.py
rename to sharktank/sharktank/serving_poc/llm/__init__.py
diff --git a/shortfin/shortfin/llm/api/rest_server.py b/sharktank/sharktank/serving_poc/llm/api/rest_server.py
similarity index 98%
rename from shortfin/shortfin/llm/api/rest_server.py
rename to sharktank/sharktank/serving_poc/llm/api/rest_server.py
index 33810428b..67536173f 100644
--- a/shortfin/shortfin/llm/api/rest_server.py
+++ b/sharktank/sharktank/serving_poc/llm/api/rest_server.py
@@ -27,7 +27,7 @@
GenerateRequest,
)
-logger = get_logger("shortfin.llm.api_server")
+logger = get_logger("sharktank.serving_poc.llm.api_server")
app = FastAPI()
service: Optional[GenerateService] = None
diff --git a/shortfin/shortfin/llm/attn_block_cache.py b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py
similarity index 98%
rename from shortfin/shortfin/llm/attn_block_cache.py
rename to sharktank/sharktank/serving_poc/llm/attn_block_cache.py
index a9443a1e3..a2299c67e 100644
--- a/shortfin/shortfin/llm/attn_block_cache.py
+++ b/sharktank/sharktank/serving_poc/llm/attn_block_cache.py
@@ -21,7 +21,7 @@
from .config import human_size, CacheParams
-logger = get_logger("shortfin.llm.cache")
+logger = get_logger("sharktank.serving_poc.llm.cache")
class AttnBlockCacheEntry:
diff --git a/shortfin/shortfin/llm/config.py b/sharktank/sharktank/serving_poc/llm/config.py
similarity index 100%
rename from shortfin/shortfin/llm/config.py
rename to sharktank/sharktank/serving_poc/llm/config.py
diff --git a/shortfin/shortfin/llm/impl/service_v1.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py
similarity index 99%
rename from shortfin/shortfin/llm/impl/service_v1.py
rename to sharktank/sharktank/serving_poc/llm/impl/service_v1.py
index f3fd80a93..8ae0be637 100644
--- a/shortfin/shortfin/llm/impl/service_v1.py
+++ b/sharktank/sharktank/serving_poc/llm/impl/service_v1.py
@@ -42,7 +42,7 @@
)
-logger = get_logger("shortfin.llm.impl.service_v1")
+logger = get_logger("sharktank.serving_poc.llm.impl.service_v1")
EXPECTED_CONCURRENCY = 10
@@ -340,7 +340,7 @@ async def prefill(self) -> TimelineGuarded[HalBufferView]:
# Perform h2d transfers.
cb.end()
- work_queue.execute_sequential([cb])
+ work_queue.execute_sequential(cb)
# Inputs:
# token_ids
@@ -468,7 +468,7 @@ async def decode(self) -> TimelineGuarded[HalBufferView]:
# Perform h2d transfers.
cb.end()
- work_queue.execute_sequential([cb])
+ work_queue.execute_sequential(cb)
# Inputs:
# token_ids
diff --git a/shortfin/shortfin/llm/impl/service_v1_cli.py b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py
similarity index 92%
rename from shortfin/shortfin/llm/impl/service_v1_cli.py
rename to sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py
index 9fc55b5ac..7895341c9 100644
--- a/shortfin/shortfin/llm/impl/service_v1_cli.py
+++ b/sharktank/sharktank/serving_poc/llm/impl/service_v1_cli.py
@@ -15,21 +15,21 @@
HalElementType,
)
-from shortfin.framework.session import DeviceSession
+from sharktank.serving_poc.framework.session import DeviceSession
-from shortfin.llm.attn_block_cache import (
+from sharktank.serving_poc.llm.attn_block_cache import (
create_attn_block_cache_module,
AttnBlockCache,
)
-from shortfin.llm.config import (
+from sharktank.serving_poc.llm.config import (
CacheParams,
ModelParams,
ServiceParams,
)
-from shortfin.llm.impl.service_v1 import GenerateServiceV1
-from shortfin.llm.service import GenerateRequest
+from sharktank.serving_poc.llm.impl.service_v1 import GenerateServiceV1
+from sharktank.serving_poc.llm.service import GenerateRequest
def setup(vmfb_path, config_path, gguf_path):
diff --git a/shortfin/shortfin/llm/service.py b/sharktank/sharktank/serving_poc/llm/service.py
similarity index 100%
rename from shortfin/shortfin/llm/service.py
rename to sharktank/sharktank/serving_poc/llm/service.py
diff --git a/shortfin/shortfin/llm/testing/fake_v1_module.py b/sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py
similarity index 100%
rename from shortfin/shortfin/llm/testing/fake_v1_module.py
rename to sharktank/sharktank/serving_poc/llm/testing/fake_v1_module.py
diff --git a/shortfin/shortfin/py.typed b/sharktank/sharktank/serving_poc/py.typed
similarity index 100%
rename from shortfin/shortfin/py.typed
rename to sharktank/sharktank/serving_poc/py.typed
diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py
index ab383a14c..9a7dcf1ee 100644
--- a/sharktank/sharktank/types/gguf_interop/base.py
+++ b/sharktank/sharktank/types/gguf_interop/base.py
@@ -11,9 +11,9 @@
import numpy as np
import torch
-from gguf import GGUFReader, GGUFValueType
+from gguf import GGUFReader, GGUFValueType, ReaderField
-from shark_turbine.aot import (
+from iree.turbine.aot import (
ExternalTensorTrait,
)
@@ -44,12 +44,26 @@ def _sanitize_scalar(scalar):
return scalar
+def _load_array(field: ReaderField) -> list:
+ if len(field.types) != 2:
+ raise ValueError(f"Unsupported array type {field.types}")
+ element_type = field.types[1]
+ if element_type == GGUFValueType.STRING:
+ return [
+ str(bytes(field.parts[parts_index]), encoding="utf8")
+ for parts_index in field.data
+ ]
+ elif element_type in GGUFReader.gguf_scalar_to_np:
+ return [
+ _sanitize_scalar(field.parts[parts_index][0]) for parts_index in field.data
+ ]
+ else:
+ raise ValueError(f"Unsupported array element type f{element_type}")
+
+
def _load_properties(reader: GGUFReader) -> dict[str, Any]:
- # TODO: Figure out what to do with tables.
- tables: dict[str, Any] = {}
properties: dict[str, Any] = {
"schema": "GGUF",
- # "tables": tables,
}
# Extract hyper-parameters. Adapted from gguf-dump.py
@@ -60,8 +74,10 @@ def _load_properties(reader: GGUFReader) -> dict[str, Any]:
properties[field.name] = str(bytes(field.parts[-1]), encoding="utf8")
elif field.types[0] in reader.gguf_scalar_to_np:
properties[field.name] = _sanitize_scalar(field.parts[-1][0])
+ elif field.types[0] == GGUFValueType.ARRAY:
+ properties[field.name] = _load_array(field)
else:
- tables[field.name] = field.parts
+ raise ValueError(f"Invalid field type.")
return properties
diff --git a/sharktank/sharktank/types/layouts.py b/sharktank/sharktank/types/layouts.py
index 54210da9b..586e4f673 100644
--- a/sharktank/sharktank/types/layouts.py
+++ b/sharktank/sharktank/types/layouts.py
@@ -22,8 +22,8 @@
register_quantized_layout,
MetaDataValueType,
QuantizedLayout,
- _dtype_to_serialized_name,
- _serialized_name_to_dtype,
+ dtype_to_serialized_name,
+ serialized_name_to_dtype,
)
from .layout_utils import (
@@ -96,7 +96,7 @@ def create(
m = planes.get("m")
dtype_str = metadata.get("dtype")
if dtype_str is not None:
- dtype = _serialized_name_to_dtype(dtype_str)
+ dtype = serialized_name_to_dtype(dtype_str)
else:
# Backwards compat with old serialized. Emulate original behavior
# before mixed precision.
@@ -106,7 +106,7 @@ def create(
@property
def metadata(self) -> Optional[dict[str, MetaDataValueType]]:
"""Additional metadata needed to reconstruct a layout."""
- return {"dtype": _dtype_to_serialized_name(self._dtype)}
+ return {"dtype": dtype_to_serialized_name(self._dtype)}
@property
def planes(self) -> dict[str, torch.Tensor]:
diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py
index 75189bdf3..d3c093b85 100644
--- a/sharktank/sharktank/types/quantizers.py
+++ b/sharktank/sharktank/types/quantizers.py
@@ -38,8 +38,8 @@
QuantizedTensor,
UnnamedTensorName,
register_inference_tensor,
- _serialized_name_to_dtype,
- _dtype_to_serialized_name,
+ serialized_name_to_dtype,
+ dtype_to_serialized_name,
)
__all__ = [
@@ -131,6 +131,25 @@ def __init__(
else:
assert len(self._scale.shape) == 0, "Expected per-tensor scale to be 0D"
+ def dequantize_raw_tensor(
+ self, t: torch.Tensor, to: torch.dtype, *, name: str
+ ) -> torch.Tensor:
+ return (
+ PlanarQuantizedTensor(
+ shape=t.shape,
+ name=t.name,
+ layout=TensorScaledLayout(
+ shape=t.shape,
+ d=self._reciprocal_scale,
+ qs=t,
+ m=self.offset,
+ dtype=to,
+ ),
+ )
+ .unpack()
+ .dequant()
+ )
+
def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor:
"""Performs a quantizing transformation on t, returning a QuantizeTensor."""
shape = list(t.shape)
@@ -139,14 +158,15 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor
if axis is None:
# Per tensor.
if offset is None:
+ # Use t / reciprocal_scale instead of t * scale because narrow float
+ # types lose too much precision for the multiplied form.
qs = saturate_cast(
- t * self._scale,
+ t / self._reciprocal_scale,
dtype=self.dtype,
disable_saturate=self._disable_saturate,
)
else:
qs = saturate_cast(
- t * self._scale + offset,
+ t / self._reciprocal_scale + offset,
dtype=self.dtype,
disable_saturate=self._disable_saturate,
)
@@ -245,7 +265,7 @@ def create(
raise IOError("Missing property") from e
axis = int(extra_properties["axis"]) if "axis" in extra_properties else None
disable_saturate = bool(extra_properties.get("disable_saturate"))
- dtype = _serialized_name_to_dtype(dtype_name)
+ dtype = serialized_name_to_dtype(dtype_name)
return cls(
name=name,
scale=scale,
@@ -271,7 +291,7 @@ def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetad
scale_name = f"{self.name}:scale"
rscale_name = f"{self.name}:rscale"
offset_name = f"{self.name}:offset"
- extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)}
+ extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)}
if self._axis is not None:
extra_properties["axis"] = self._axis
if self._disable_saturate:
@@ -387,7 +407,7 @@ def create(
dtype_name = extra_properties["dtype"]
except KeyError as e:
raise IOError("Missing property") from e
- dtype = _serialized_name_to_dtype(dtype_name)
+ dtype = serialized_name_to_dtype(dtype_name)
return cls(
name=name,
dtype=dtype,
@@ -399,7 +419,7 @@ def globals(self) -> dict[str, torch.Tensor]:
def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata:
"""Adds this tensor to the global archive."""
- extra_properties = {"dtype": _dtype_to_serialized_name(self._dtype)}
+ extra_properties = {"dtype": dtype_to_serialized_name(self._dtype)}
raw_tensors = {}
return InferenceTensorMetadata(
self.serialized_name(),
diff --git a/sharktank/sharktank/types/sharding.py b/sharktank/sharktank/types/sharding.py
index a670cae33..81d2f31a5 100644
--- a/sharktank/sharktank/types/sharding.py
+++ b/sharktank/sharktank/types/sharding.py
@@ -18,7 +18,7 @@ def __init__(self):
class TensorSharding(Sharding):
- def __init__(self, *, shard_count: int):
+ def __init__(self, shard_count: int):
super().__init__()
self.shard_count = shard_count
@@ -29,7 +29,7 @@ def __init__(self):
class Replicated(TensorSharding):
- def __init__(self, *, shard_count: int):
+ def __init__(self, shard_count: int):
super().__init__(shard_count=shard_count)
@@ -39,6 +39,17 @@ def __init__(self, *, shard_count: int, shard_dim: int):
self.shard_dim = shard_dim
+class Ignore(TensorSharding):
+ """When a theta is sharded, a tensor or a branch with this sharding type will be
+ ignored.
+ It will not appear in the resulting sharded theta.
+ This is not strictly a TensorSharding. It will terminate further traversal of a
+ branch of a theta tree as well."""
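+ # Illustrative use (key name assumed): ThetaSharding({"freqs": Ignore(), ...})
+ # omits the "freqs" branch from the resulting sharded theta.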
+
+ def __init__(self):
+ super().__init__(shard_count=0)
+
+
class ThetaSharding(dict):
"""Sharding for each tensor in a theta.
It is of type dict[str, "ThetaSharding" | TensorSharding].
@@ -49,7 +60,15 @@ def __init__(self, *args, **kwargs):
for k, v in d.items():
d[k] = tree.map_nodes(
tree=v,
- f=lambda x: x if isinstance(x, TensorSharding) else ThetaSharding(x),
+ f=lambda x: x
+ if isinstance(
+ x,
+ (
+ TensorSharding,
+ ThetaSharding,
+ ),
+ )
+ else ThetaSharding(x),
)
super().__init__(d)
@@ -89,6 +108,27 @@ def theta_sharding(self) -> ThetaSharding:
)
+class FFNSharding(ThetaLayerSharding):
+ def __init__(self, shard_count: int):
+ super().__init__()
+ self.shard_count = shard_count
+
+ def theta_sharding(self) -> ThetaSharding:
+ return ThetaSharding(
+ {
+ "ffn_gate": LinearSplitParallelWeightAndBiasSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ "ffn_up": LinearSplitParallelWeightAndBiasSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ "ffn_down": LinearSplitReductionDimSharding(
+ shard_count=self.shard_count
+ ).theta_sharding(),
+ }
+ )
+
+
class GroupNormSplitChannelSharding(ThetaLayerSharding):
def __init__(self, shard_count: int):
super().__init__()
@@ -103,23 +143,67 @@ def theta_sharding(self) -> ThetaSharding:
)
-class LinearReplicatedInputSplitWeightAndBiasSharding(ThetaLayerSharding):
+class LinearLayerSharding(ThetaLayerSharding):
+ def __init__(
+ self, premul_input: TensorSharding, weight: TensorSharding, bias: TensorSharding
+ ):
+ super().__init__()
+ self.premul_input = premul_input
+ self.weight = weight
+ self.bias = bias
+
+ def theta_sharding(self) -> ThetaSharding:
+ return ThetaSharding(
+ {
+ "premul_input": self.premul_input,
+ "weight": self.weight,
+ "bias": self.bias,
+ }
+ )
+
+
+class LinearSplitParallelWeightAndBiasSharding(LinearLayerSharding):
def __init__(self, shard_count: int, weight_and_bias_spit_dim: int = 0):
+ """Split one parallel dimension for both the weight and bias.
+ Since the weight is transposed before multiplying, the weight parallel
+ dimension is the same as the output(bias) dimension."""
+ super().__init__(
+ premul_input=Replicated(shard_count=shard_count),
+ weight=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim),
+ bias=Split(shard_count=shard_count, shard_dim=weight_and_bias_spit_dim),
+ )
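+ # For example, a weight of shape [out_features, in_features] is split along
+ # the out_features dimension, and the bias of shape [out_features] is split
+ # along that same dimension.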
+
+
+class LinearSplitReductionDimSharding(LinearLayerSharding):
+ def __init__(self, shard_count: int):
+ super().__init__(
+ premul_input=Replicated(shard_count=shard_count),
+ weight=Split(shard_count=shard_count, shard_dim=1),
+ bias=Replicated(shard_count=shard_count),
+ )
+
+
+class RmsNormReplicatedSharding(ThetaLayerSharding):
+ def __init__(self, shard_count: int):
+ super().__init__()
+ self.shard_count = shard_count
+
+ def theta_sharding(self) -> ThetaSharding:
+ return ThetaSharding(
+ {
+ "weight": Replicated(shard_count=self.shard_count),
+ }
+ )
+
+
+class TokenEmbeddingLayerReplicatedSharding(ThetaLayerSharding):
+ def __init__(self, shard_count: int):
super().__init__()
self.shard_count = shard_count
- self.weight_and_bias_spit_dim = weight_and_bias_spit_dim
def theta_sharding(self) -> ThetaSharding:
return ThetaSharding(
{
- "premul_input": Replicated(shard_count=self.shard_count),
- "weight": Split(
- shard_count=self.shard_count,
- shard_dim=self.weight_and_bias_spit_dim,
- ),
- "bias": Split(
- shard_count=self.shard_count,
- shard_dim=self.weight_and_bias_spit_dim,
- ),
+ "weight": Replicated(shard_count=self.shard_count),
}
)
diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py
index eab7d5823..f870aa101 100644
--- a/sharktank/sharktank/types/tensors.py
+++ b/sharktank/sharktank/types/tensors.py
@@ -17,20 +17,22 @@
Tuple,
)
from copy import deepcopy
-from collections.abc import Collection
-from numbers import Integral
+from collections.abc import Collection, Sequence
+from numbers import Integral, Number
+import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from torch import Tensor
-from torch.utils._pytree import register_pytree_node
+from torch.utils._pytree import register_pytree_node, SequenceKey
+import torch.utils._pytree
from ..utils.math import ceildiv
-from shark_turbine.aot import (
+from iree.turbine.aot import (
+ DeviceTensorTrait,
ExternalTensorTrait,
)
-from shark_turbine.ops.iree import transfer_to_logical_device
from ..utils import tree as tree_utils
from ..utils.io import ShardedArchiveBuilder
@@ -38,6 +40,7 @@
__all__ = [
"AnyTensor",
"DefaultPrimitiveTensor",
+ "dtype_to_serialized_name",
"flatten_tensor_tree",
"InferenceTensor",
"MetaDataValueType",
@@ -47,12 +50,26 @@
"QuantizedTensor",
"register_quantized_layout",
"ReplicatedTensor",
+ "serialized_name_to_dtype",
"ShardedTensor",
"SplitPrimitiveTensor",
+ "torch_tree_flatten",
"unbox_tensor",
"UnreducedTensor",
]
+if (
+ "SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR" in os.environ
+ and os.environ["SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR"] != "0"
+):
+
+ def _tensor_debugger_friendly_repr(self: torch.Tensor):
+ """Override for the torch.Tensor.__repr__ so it does not take forever when the
+ debugger wants to query many/large tensors."""
+ return f"Tensor({list(self.shape)}, {self.dtype})"
+
+ Tensor.__repr__ = _tensor_debugger_friendly_repr
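+ # For example, launching with SHARKTANK_OVERRIDE_TORCH_TENSOR_REPR=1 makes a
+ # debugger display compact reprs like "Tensor([2, 3], torch.float32)".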
+
# JSON encodable value types.
MetaDataValueType = Union[int, bool, float, str]
UnnamedTensorName = ""
@@ -255,8 +272,15 @@ def transform_globals(
return self._clone_with_globals(prev_globals)
def to(
- self, *, device: Optional[Union[str, torch.device]] = None
+ self,
+ *,
+ device: Optional[Union[str, torch.device]] = None,
) -> "InferenceTensor":
+ # TODO: reconcile with ops.to(...) and torch.Tensor.to(...).
+ # Do we always want to clone with globals?
+ # This makes our type inconsistent with torch tensors.
+ # If we use this to transform a theta we want to change the theta.
+ # If we want to use this in a computation we don't want to change the theta.
return self.transform_globals(
lambda d: {k: t.to(device=device) for k, t in d.items()}
)
@@ -279,26 +303,144 @@ def T(self) -> "InferenceTensor":
return permute(self, dims=dims)
+ @property
+ def dtype(self) -> torch.dtype:
+ raise NotImplementedError()
+
+ def expand(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
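+ # This wrapper accepts either t.expand(2, 3, 4) or t.expand([2, 3, 4]).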
+ from ..ops import expand
+
+ if all(isinstance(a, int) for a in args):
+ shape = args
+ else:
+ assert len(args) == 1
+ shape = args[0]
+ return expand(self, shape)
+
+ def flatten(self, start_dim: int = 0, end_dim: int = -1) -> "AnyTensor":
+ from ..ops import flatten
+
+ return flatten(self, start_dim, end_dim)
+
+ def index_copy_(
+ self, dim: int, index: "AnyTensor", tensor: "AnyTensor"
+ ) -> "InferenceTensor":
+ from ..ops import index_copy_
+
+ return index_copy_(self, dim, index, tensor)
+
+ def index_put_(
+ self, indices: Tuple["AnyTensor"], values: "AnyTensor"
+ ) -> "InferenceTensor":
+ from ..ops import index_put_
+
+ return index_put_(self, indices, values)
+
+ def index_select(
+ self,
+ dim: int,
+ index: "AnyTensor",
+ ) -> "InferenceTensor":
+ from ..ops import index_select
+
+ return index_select(self, dim, index)
+
+ def mean(
+ self,
+ dim: Union[int, List[int]],
+ keepdim: bool = False,
+ *,
+ dtype: Optional[torch.dtype] = None,
+ ) -> "AnyTensor":
+ from ..ops import mean
+
+ return mean(self, dim, keepdim, dtype=dtype)
+
+ def pow(self, exponent: Union["AnyTensor", Number]) -> "AnyTensor":
+ from ..ops import elementwise
+
+ return elementwise(torch.pow, self, exponent)
+
+ def repeat(self, *sizes: List[int]) -> "AnyTensor":
+ from ..ops import repeat
+
+ return repeat(self, *sizes)
+
+ def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
+ from ..ops import reshape
+
+ if all(isinstance(a, int) for a in args):
+ shape = args
+ else:
+ assert len(args) == 1
+ shape = args[0]
+ return reshape(self, shape)
+
+ def transpose(self, dim0: int, dim1: int) -> "AnyTensor":
+ from ..ops import transpose
+
+ return transpose(self, dim0, dim1)
+
+ def unflatten(self, dim: int, sizes: Tuple[int]) -> "AnyTensor":
+ from ..ops import unflatten
+
+ return unflatten(self, dim, sizes)
+
+ def unsqueeze(self, dim: int) -> "AnyTensor":
+ from ..ops import unsqueeze
+
+ return unsqueeze(self, dim)
+
+ def view(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
+ from ..ops import view
+
+ if all(isinstance(a, int) or isinstance(a, torch.SymInt) for a in args):
+ shape = args
+ else:
+ assert len(args) == 1
+ shape = args[0]
+ return view(self, shape)
+
def __add__(self, rhs):
from ..ops import elementwise
return elementwise(torch.add, self, rhs)
def __radd__(self, lhs):
- # Assumes commutative addition due to torch.elementwise not handling numbers on
- # the lhs.
+ # Assumes commutative addition due to torch elementwise ops not handling
+ # numbers on the lhs.
return self.__add__(lhs)
+ def __mod__(self, rhs):
+ from ..ops import elementwise
+
+ return elementwise(torch.remainder, self, rhs)
+
def __mul__(self, rhs):
from ..ops import elementwise
return elementwise(torch.mul, self, rhs)
def __rmul__(self, lhs):
- # Assumes commutative multiplication due to torch.elementwise not handling
+ # Assumes commutative multiplication due to torch elementwise ops not handling
# numbers on the lhs.
return self.__mul__(lhs)
+ def __truediv__(self, rhs):
+ from ..ops import elementwise
+
+ return elementwise(torch.true_divide, self, rhs)
+
+ def __floordiv__(self, rhs):
+ from ..ops import elementwise
+
+ return elementwise(torch.floor_divide, self, rhs)
+
+ def __getitem__(self, key):
+ from ..ops import get_index
+
+ return get_index(self, key)
+
REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {}
@@ -338,6 +480,17 @@ def as_torch(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
...
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.as_torch().dtype
+
+ def __setitem__(self, key, value: "AnyTensor"):
+ if not isinstance(key, list) and not isinstance(key, tuple):
+ key = (key,)
+
+ key = [unbox_tensor(k) if isinstance(k, PrimitiveTensor) else k for k in key]
+ self.as_torch()[*key] = unbox_tensor(value)
+
@register_inference_tensor
class DefaultPrimitiveTensor(PrimitiveTensor):
@@ -391,7 +544,15 @@ def _clone_with_globals(
return DefaultPrimitiveTensor(name=self.name, data=new_globals[self.name])
def __getitem__(self, key):
- return self._data[key]
+ keys = [key]
+ if isinstance(key, tuple) or isinstance(key, list):
+ keys = key
+
+ keys = [
+ unbox_tensor(key) if isinstance(key, PrimitiveTensor) else key
+ for key in keys
+ ]
+ return self._data[*keys]
def __repr__(self):
return f"PrimitiveTensor({self.name}, {self.shape}, {self._data.dtype})"
@@ -434,7 +595,9 @@ def to_planar(self) -> "PlanarQuantizedTensor":
it should override this method to implement properly or raise
NotImplementedError.
"""
- return PlanarQuantizedTensor(self.name, self.shape, self.unpack())
+ return PlanarQuantizedTensor(
+ name=self.name, shape=self.shape, layout=self.unpack()
+ )
def add_to_archive(self, builder: ShardedArchiveBuilder) -> InferenceTensorMetadata:
"""By default all QuantizedTensors serialize as a generic PlanarQuantizedTensor.
@@ -601,6 +764,10 @@ def name(self, name: str):
for i, shard in enumerate(self.shards):
shard.name = f"{name}.shard.{i}"
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.shards[0].dtype
+
@register_inference_tensor
class ShardedTensorBase(ShardedTensor):
@@ -619,13 +786,15 @@ def __init__(
shape: Optional[list[int]],
):
assert len(ts) > 0
- assert shard_dim is None or len(ts[0].shape) > shard_dim
+ assert shard_dim is None or (shard_dim >= 0 and len(ts[0].shape) > shard_dim)
super().__init__(name=name, shape=shape, shard_dim=shard_dim)
self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
- data=transfer_to_logical_device(f"{i}", unbox_tensor(t)),
+ data=t,
)
+ if isinstance(t, torch.Tensor)
+ else t
for i, t in enumerate(ts)
)
@@ -695,6 +864,8 @@ def create(
try:
t = raw_tensors[t_name]
ts.append(t)
+ # TODO: this should be changed to tracked device affinity
+ DeviceTensorTrait(i).set(t)
except KeyError as e:
raise IOError(
f"Missing component tensor '{t_name}' in {raw_tensors.keys()}"
@@ -745,6 +916,32 @@ def _is_full_slice(s: slice, dim_size: int) -> bool:
)
+def _resolve_ellipsis_in_slicing(key: Tuple[Any], shape: Tuple[int]) -> Tuple[Any]:
+ """Example:
+ key = [1:2, ..., 0]
+ shape = [2, 3, 4, 5, 6]
+ Returns:
+ [1:2, :, :, :, 0]"""
+ num_ellipsis = len([k for k in key if k == Ellipsis])
+ assert num_ellipsis <= 1, "Only one Ellipsis is allowed."
+ if num_ellipsis <= 0:
+ return key
+ assert len(key) <= len(
+ shape
+ ), "Inserting trailing singleton dimensions is not supported."
+ dim = 0
+ res = []
+ for k in key:
+ if k == Ellipsis:
+ ellipsis_num_dims = len(shape) - len(key) + 1
+ res.extend([slice(None)] * ellipsis_num_dims)
+ dim += ellipsis_num_dims
+ else:
+ dim += 1
+ res.append(k)
+ return tuple(res)
+
+
@register_inference_tensor
class SplitPrimitiveTensor(ShardedTensorBase):
"""Sharded tensor split along a dimension into primitive tensors.
@@ -770,8 +967,11 @@ def __init__(
number of pieces.
"""
if isinstance(ts, torch.Tensor):
+ from ..ops import transfer_to_logical_device
+
assert shard_count is not None
ts = ts.split(ceildiv(ts.shape[shard_dim], shard_count), dim=shard_dim)
+ ts = [transfer_to_logical_device(t, i) for i, t in enumerate(ts)]
assert len(ts) == shard_count
shard_count = None
@@ -779,21 +979,23 @@ def __init__(
assert len(ts) > 0
first_shape = ts[0].shape
assert len(first_shape) > shard_dim
- if shape is None:
- # Compute the shape.
- shape = list(first_shape)
- shape[shard_dim] *= len(ts)
-
- # Assert the shape.
- shard_dim_size = first_shape[shard_dim]
- for t in ts[1:]:
- assert (
- t.shape == first_shape
- ), f"Shape mismatch for split tensors: {t.shape} vs {first_shape}"
- shard_dim_size += t.shape[shard_dim]
- assert (
- shard_dim_size == shape[shard_dim]
- ), f"Sharding mismatch: Sharded dims do not cover the whole volume {shard_dim_size} vs {shape[shard_dim]}"
+ expected_shape = list(first_shape)
+ expected_shape[shard_dim] = sum([t.shape[shard_dim] for t in ts])
+ if shape is not None:
+ shape = list(shape)
+ assert expected_shape == shape
+ else:
+ shape = expected_shape
+
+ # Assert the shapes.
+ for i, t in enumerate(ts):
+ t_shape = list(t.shape)
+ assert len(shape) == len(
+ t_shape
+ ), f"Shape size mismatch tensor shard {i} with shape {t.shape}. Expected shape size {len(shape)}. Got {len(t_shape)}."
+ assert all(
+ s == t for i, (s, t) in enumerate(zip(shape, t_shape)) if i != shard_dim
+ ), f"Shape mismatch for non-split dimension for tensor shard {i} with shape {t.shape}"
super().__init__(name=name, ts=ts, shape=shape, shard_dim=shard_dim)
@@ -813,7 +1015,7 @@ def _is_slicing_split_dim(self, key):
else:
# Any other collection is a indexing only dimension 0.
return self.shard_dim == 0
- if len(key) < self.shard_dim:
+ if len(key) <= self.shard_dim:
return False
if not isinstance(key[self.shard_dim], slice):
return True
@@ -834,19 +1036,52 @@ def _get_shard_slice(self, key):
if len(key) <= self.shard_count:
return key
new_key = list(key)
- new_key[self.shard_dim] = slice(None)
+
+ if self.shard_dim < len(new_key):
+ new_key[self.shard_dim] = slice(None)
return new_key
def __getitem__(self, key):
# TODO: implement all cases.
- # Only slicing of non-split dimension is supported.
+ if not isinstance(key, Sequence):
+ key = (key,)
+ key = _resolve_ellipsis_in_slicing(key, self.shape)
if self._is_slicing_split_dim(key):
raise NotImplementedError(
f"Slicing of the split dimension {self.shard_dim} is not supported."
)
new_key = self._get_shard_slice(key)
- shards = [shard[new_key] for shard in self.shards]
- return SplitPrimitiveTensor(ts=shards, shard_dim=self.shard_dim)
+
+ shards = []
+ for i, shard in enumerate(self.shards):
+ shard_keys = [
+ k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in new_key
+ ]
+ shards.append(shard[*shard_keys])
+
+ shard_dim = self.shard_dim
+ for i in range(min(shard_dim, len(key))):
+ if isinstance(key[i], Number) and key[i] >= 0:
+ # Rank reduction dimension before the split dim.
+ shard_dim -= 1
+
+ return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)
+
+ def __setitem__(self, key, value):
+ assert isinstance(value, SplitPrimitiveTensor)
+ assert self.shard_count == value.shard_count
+ if not isinstance(key, Sequence):
+ key = (key,)
+ key = _resolve_ellipsis_in_slicing(key, self.shape)
+ if self._is_slicing_split_dim(key):
+ raise NotImplementedError(
+ f"Slicing of the split dimension {self.shard_dim} is not supported."
+ )
+ for i, (shard, value_shard) in enumerate(zip(self.shards, value.shards)):
+ shard_keys = [
+ k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in key
+ ]
+ shard[*shard_keys] = unbox_tensor(value_shard)
@register_inference_tensor
@@ -866,10 +1101,11 @@ def __init__(
If `ts` is a tensor then `shard_count` must be provided and it,
will be replicated that many times.
"""
-
if isinstance(ts, torch.Tensor):
assert shard_count is not None
- ts = [ts] * shard_count
+ from ..ops import transfer_to_logical_device
+
+ ts = [transfer_to_logical_device(ts, i) for i in range(shard_count)]
shard_count = None
assert shard_count is None
@@ -884,8 +1120,10 @@ def __init__(
self._shards: tuple[DefaultPrimitiveTensor] = tuple(
DefaultPrimitiveTensor(
name=f"{name}.shard.{i}",
- data=transfer_to_logical_device(f"{i}", unbox_tensor(t)),
+ data=t,
)
+ if isinstance(t, torch.Tensor)
+ else t
for i, t in enumerate(ts)
)
@@ -936,13 +1174,35 @@ def create(
) -> "InferenceTensor":
shard_count = int(extra_properties["shard_count"])
try:
- ts = raw_tensors[""]
+ # We have to do this to avoid exporting as part of the `mlir` blob:
+ t = raw_tensors[""]
+ ts = [raw_tensors[""]]
+ for i in range(1, shard_count):
+ nt = deepcopy(t)
+ ts.append(nt)
+
+ # TODO This should be changed to assigned affinities
+ for i in range(shard_count):
+ DeviceTensorTrait(i).set(ts[i])
+
except KeyError as e:
raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e
- return cls(name=name, ts=ts, shard_count=shard_count)
+ return cls(name=name, ts=ts)
def __getitem__(self, key):
- shards = [shard[key] for shard in self.shards]
+ keys = [key]
+ if isinstance(key, tuple) or isinstance(key, list):
+ keys = key
+
+ shards = []
+ for i, shard in enumerate(self.shards):
+ shard_keys = []
+ for k in keys:
+ if isinstance(k, ReplicatedTensor):
+ shard_keys.append(k.shards[i])
+ else:
+ shard_keys.append(k)
+ shards.append(shard[*shard_keys])
return ReplicatedTensor(ts=shards)
def __repr__(self):
@@ -988,6 +1248,7 @@ def __init__(
def flatten_tensor_tree(
tree: tree_utils.Tree,
) -> Iterable[torch.Tensor | InferenceTensor]:
+ """Flatten up to our tensor types."""
return tree_utils.flatten(
tree,
is_leaf=lambda x: isinstance(
@@ -1016,7 +1277,7 @@ def unbox_tensor(t: Any) -> Tensor:
########################################################################################
-def _dtype_to_serialized_name(dtype: torch.dtype) -> str:
+def dtype_to_serialized_name(dtype: torch.dtype) -> str:
try:
return _DTYPE_TO_NAME[dtype]
except KeyError as e:
@@ -1025,7 +1286,7 @@ def _dtype_to_serialized_name(dtype: torch.dtype) -> str:
) from e
-def _serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
+def serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
try:
return _NAME_TO_DTYPE[dtype_name]
except KeyError as e:
@@ -1047,6 +1308,7 @@ def _serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
+ "float8_e4m3fnuz": torch.float8_e4m3fnuz,
}
@@ -1097,10 +1359,16 @@ def unflatten_defult_primitive_tensor(
return DefaultPrimitiveTensor(data=values_as_list[0], name=ctx["name"])
+def flatten_with_keys_default_primitive_tensor(t: DefaultPrimitiveTensor):
+ values, context = flatten_default_primitive_tensor(t)
+ return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
register_pytree_node(
DefaultPrimitiveTensor,
flatten_fn=flatten_default_primitive_tensor,
unflatten_fn=unflatten_defult_primitive_tensor,
+ flatten_with_keys_fn=flatten_with_keys_default_primitive_tensor,
)
@@ -1118,10 +1386,16 @@ def unflatten_split_primitive_tensor(
)
+def flatten_with_keys_split_primitive_tensor(t: SplitPrimitiveTensor):
+ values, context = flatten_split_primitive_tensor(t)
+ return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
register_pytree_node(
SplitPrimitiveTensor,
flatten_fn=flatten_split_primitive_tensor,
unflatten_fn=unflatten_split_primitive_tensor,
+ flatten_with_keys_fn=flatten_with_keys_split_primitive_tensor,
)
@@ -1137,8 +1411,45 @@ def unflatten_replicated_tensor(
return ReplicatedTensor(ts=list(values), name=ctx["name"])
+def flatten_with_keys_replicated_tensor(t: ReplicatedTensor):
+ values, context = flatten_replicated_tensor(t)
+ return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
register_pytree_node(
ReplicatedTensor,
flatten_fn=flatten_replicated_tensor,
unflatten_fn=unflatten_replicated_tensor,
+ flatten_with_keys_fn=flatten_with_keys_replicated_tensor,
+)
+
+
+def flatten_unreduced_tensor(
+ t: UnreducedTensor,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+ return list(t.shards), {"name": t.name}
+
+
+def unflatten_unreduced_tensor(
+ values: Iterable[Any], ctx: torch.utils._pytree.Context
+) -> UnreducedTensor:
+ return UnreducedTensor(ts=list(values), name=ctx["name"])
+
+
+def flatten_with_keys_unreduced_tensor(t: UnreducedTensor):
+ values, context = flatten_unreduced_tensor(t)
+ return [(SequenceKey(i), v) for i, v in enumerate(values)], context
+
+
+register_pytree_node(
+ UnreducedTensor,
+ flatten_fn=flatten_unreduced_tensor,
+ unflatten_fn=unflatten_unreduced_tensor,
+ flatten_with_keys_fn=flatten_with_keys_unreduced_tensor,
)
+
+
+def torch_tree_flatten(tree: tree_utils.Tree):
+ """Flatten a tree of tensors the same way they will be flattened during torch.export.export
+ if they are arguments or results of a function signature."""
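+ # For example (illustrative), torch_tree_flatten({"x": t0, "y": [t1]}) returns
+ # the leaf tensors [t0, t1] together with a treespec describing the structure.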
+ return torch.utils._pytree.tree_flatten(tree=tree)
diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py
index 975f54d24..29bc29bb8 100644
--- a/sharktank/sharktank/types/theta.py
+++ b/sharktank/sharktank/types/theta.py
@@ -15,7 +15,7 @@
import torch
import torch.nn.functional as F
-from shark_turbine.aot import (
+from iree.turbine.aot import (
ExternalTensorTrait,
ParameterArchive,
ParameterArchiveEntry,
@@ -110,6 +110,18 @@ def transform(self, *transforms: InferenceTensorTransform) -> "Theta":
def to(self, *, device: Optional[Union[str, torch.device]] = None) -> "Theta":
return self.transform(InferenceTensorTransforms.to_device(device))
+ def pop(self, *name_path: str | int) -> "Theta":
+ """Prune a subtree from the tree and return it as a new Theta object."""
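+ # Example (illustrative path): theta.pop("blk", 0) detaches the "blk.0"
+ # subtree from this theta and returns it as a new Theta; the remaining
+ # tree no longer contains those tensors.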
+ name_path = ".".join(_norm_name_path(name_path))
+ flat = self.flatten()
+ accum = {}
+ key_list = list(flat.keys())
+ for key in key_list:
+ if key.startswith(name_path):
+ accum[key] = flat.pop(key)
+ self._tree = flat_to_nested_dict(flat)
+ return Theta(flat_to_nested_dict(accum))
+
def flatten(self) -> dict[str, InferenceTensor]:
results = {}
diff --git a/sharktank/sharktank/utils/__init__.py b/sharktank/sharktank/utils/__init__.py
new file mode 100644
index 000000000..3651913ca
--- /dev/null
+++ b/sharktank/sharktank/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .misc import *
diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py
index 3a13b1d45..bc0b3b0b6 100644
--- a/sharktank/sharktank/utils/cli.py
+++ b/sharktank/sharktank/utils/cli.py
@@ -61,6 +61,29 @@ def add_output_dataset_options(parser: argparse.ArgumentParser):
)
+def add_model_options(parser: argparse.ArgumentParser):
+ """Adds model config options not exclusive to export or eager"""
+ parser.add_argument(
+ "--attention-kernel",
+ type=str,
+ default="decomposed",
+ choices=["decomposed", "torch"],
+ )
+ parser.add_argument(
+ "--skip-decode",
+ help="Enables prefill only, skips decode",
+ action="store_true",
+ )
+
+
+def add_quantization_options(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--fake-quant",
+ action=argparse.BooleanOptionalAction,
+ help="whether or not to run/export the model in fake quant mode. Note, running eagerly without fake quant is dependent on torch types supporting operations. YMMV",
+ )
+
+
def add_tokenizer_options(parser: argparse.ArgumentParser):
"""Adds options for specifying a tokenizer.
@@ -106,6 +129,8 @@ def get_input_dataset(args) -> Dataset:
if "irpa" in data_files:
return Dataset.load(data_files["irpa"], file_type="irpa")
+ raise ValueError('Dataset format unsupported. Must be "gguf" or "irpa".')
+
def get_tokenizer(args) -> tokenizer.InferenceTokenizer:
"""Gets a tokenizer based on arguments.
diff --git a/sharktank/sharktank/utils/create_cache.py b/sharktank/sharktank/utils/create_cache.py
new file mode 100644
index 000000000..c1691c8a8
--- /dev/null
+++ b/sharktank/sharktank/utils/create_cache.py
@@ -0,0 +1,34 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ..layers import *
+
+
+def create_kv_cache(config: LlamaModelConfig) -> BaseKVCache:
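+ """Create the KV cache implementation selected by config.kv_cache_type
+ ("direct" or "paged") from the model hyperparameters."""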
+ hp = config.hp
+ if config.kv_cache_type == "direct":
+ return DirectKVCache(
+ block_seq_stride=config.block_seq_stride,
+ transformer_block_count=hp.block_count,
+ attn_head_count=hp.attention_head_count_kv,
+ attn_head_dim=hp.attn_head_dim,
+ seq_length=hp.context_length,
+ device=config.device,
+ dtype=config.attention_dtype,
+ )
+ elif config.kv_cache_type == "paged":
+ return PagedKVCache(
+ transformer_block_count=hp.block_count,
+ attn_head_count=hp.attention_head_count_kv,
+ attn_head_dim=hp.attn_head_dim,
+ cache_partition_count=2, # One for each of K/V.
+ block_seq_stride=config.block_seq_stride,
+ device=config.device,
+ dtype=config.attention_dtype,
+ shard_count=config.tensor_parallelism_size,
+ )
+ else:
+ raise NotImplementedError(f"kv_cache_type = {config.kv_cache_type}")
diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py
new file mode 100644
index 000000000..c950a875a
--- /dev/null
+++ b/sharktank/sharktank/utils/export_artifacts.py
@@ -0,0 +1,333 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import sys
+import subprocess
+import logging
+import time
+import re
+from pathlib import Path
+from datetime import timedelta
+from typing import List, Optional
+
+import iree.compiler as ireec
+
+logger = logging.getLogger("eval")
+
+logger.setLevel(logging.INFO)
+
+logger.root.handlers[0].setFormatter(
+ logging.Formatter(fmt="\n%(levelname)s:%(name)-8s %(message)s")
+)
+
+
+class ExportMlirException(Exception):
+ """shark-ai export MLIR exception that preserves the command line and error output."""
+
+ def __init__(self, process: subprocess.CompletedProcess, cwd: str):
+ try:
+ errs = process.stderr.decode("utf-8")
+ except:
+ errs = str(process.stderr)
+ super().__init__(
+ f"Error invoking export_paged_llama_v1.py\n"
+ f"Error code: {process.returncode}\n"
+ f"Stderr diagnostics:\n{errs}\n\n"
+ f"Invoked with:\n"
+ f" cd {cwd} && {process.args}\n\n"
+ )
+
+
+class IreeCompileException(Exception):
+ """Compiler exception that preserves the command line and error output."""
+
+ def __init__(self, process: subprocess.CompletedProcess, cwd: str):
+ try:
+ errs = process.stderr.decode("utf-8")
+ except:
+ errs = str(process.stderr)
+ super().__init__(
+ f"Error invoking iree-compile\n"
+ f"Error code: {process.returncode}\n"
+ f"Stderr diagnostics:\n{errs}\n\n"
+ f"Invoked with:\n"
+ f" cd {cwd} && {process.args}\n\n"
+ )
+
+
+class IreeBenchmarkException(Exception):
+ """Runtime exception that preserves the command line and error output."""
+
+ def __init__(self, process: subprocess.CompletedProcess, cwd: str):
+ # iree-run-module sends output to both stdout and stderr
+ try:
+ errs = process.stderr.decode("utf-8")
+ except:
+ errs = str(process.stderr)
+ try:
+ outs = process.stdout.decode("utf-8")
+ except:
+ outs = str(process.stdout)
+ super().__init__(
+ f"Error invoking iree-benchmark-module\n"
+ f"Error code: {process.returncode}\n"
+ f"Stderr diagnostics:\n{errs}\n"
+ f"Stdout diagnostics:\n{outs}\n"
+ f"Run with:\n"
+ f" cd {cwd} && {process.args}\n\n"
+ )
+
+
+class ExportArtifacts:
+ def __init__(
+ self,
+ *,
+ irpa_path: str,
+ batch_size: int,
+ iree_hip_target: str,
+ iree_hal_target_backends: str,
+ attention_kernel: str,
+ tensor_parallelism_size: int,
+ ):
+ self.sharktank_dir = str(
+ Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
+ )
+ self.irpa_path = irpa_path
+ self.batch_size = batch_size
+ self.iree_hip_target = iree_hip_target
+ self.iree_hal_target_backends = iree_hal_target_backends
+ self.attention_kernel = attention_kernel
+ self.tensor_parallelism_size = tensor_parallelism_size
+
+ def timeit(func):
+ def wrapper(*args, **kwargs):
+ start = time.time()
+ result = func(*args, **kwargs)
+ end = time.time()
+ total_seconds = end - start
+ time_taken = abs(timedelta(seconds=total_seconds))
+ hours, minutes, seconds = re.split(":", str(time_taken))
+
+ if total_seconds < 1:
+ time_taken = f" {round(total_seconds * 1000, 3)} ms"
+ elif total_seconds < 60:
+ time_taken = "{:.2f} secs".format(round(float(total_seconds), 2))
+ else:
+ time_taken = "{:02d} hrs : {:02d} mins : {:.2f} secs".format(
+ int(hours), int(minutes), round(float(seconds), 2)
+ )
+
+ func_name = func.__name__
+ logger.info(f" {func_name}: {time_taken}")
+ return result
+
+ return wrapper
+
+ @timeit
+ def shard_irpa_file(
+ self,
+ *,
+ irpa_file: str,
+ output_irpa: str,
+ ):
+ shard_irpa_args = [
+ "python3",
+ "-m",
+ "sharktank.examples.sharding.shard_llm_dataset",
+ "--irpa-file",
+ irpa_file,
+ "--output-irpa-file",
+ output_irpa,
+ "--tensor-parallelism-size",
+ str(self.tensor_parallelism_size),
+ ]
+
+ cwd = self.sharktank_dir
+ cmd = subprocess.list2cmdline(shard_irpa_args)
+
+ logger.info(f"Sharding irpa file:\n" f"cd {cwd} && {cmd}")
+
+ proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True)
+ if proc.returncode != 0:
+ logger.error(
+ f"Error sharding irpa file with shard_llm_dataset.py\n"
+ f"{proc.stdout+proc.stderr}"
+ )
+ else:
+ logger.info(f"Sharded irpa file successfully:\n" f"{proc.stdout}")
+
+ return proc.returncode
+
+ @timeit
+ def export_to_mlir(
+ self,
+ *,
+ mlir_path: str,
+ json_path: str,
+ skip_decode: Optional[bool] = None,
+ ):
+ export_args = [
+ "python3",
+ "-m",
+ "sharktank.examples.export_paged_llm_v1",
+ f"--irpa-file={self.irpa_path}",
+ f"--output-mlir={mlir_path}",
+ f"--output-config={json_path}",
+ f"--bs={str(self.batch_size)}",
+ ]
+ if skip_decode:
+ export_args.append("--skip-decode")
+ if self.attention_kernel in ["decomposed", "torch"]:
+ export_args.append("--attention-kernel")
+ export_args.append(self.attention_kernel)
+
+ cwd = self.sharktank_dir
+ cmd = subprocess.list2cmdline(export_args)
+
+ logger.info(f" Exporting mlir:\n" f"cd {cwd} && {cmd}")
+
+ proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd, text=True)
+ if proc.returncode != 0:
+ raise ExportMlirException(proc, cwd)
+ else:
+ logger.info(f" Exported to mlir successfully:\n" f"{proc.stdout}")
+
+ return proc.returncode
+
+ @timeit
+ def compile_to_vmfb(
+ self,
+ *,
+ mlir_path,
+ vmfb_path,
+ cwd,
+ hal_dump_path: Optional[Path] = None,
+ args: Optional[List[str]] = None,
+ ):
+ # TODO: Control flag to enable multiple backends
+ compile_args = [
+ f"iree-compile",
+ f"{mlir_path}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"-o={vmfb_path}",
+ ]
+ if self.tensor_parallelism_size > 1:
+ iree_hal_target_devices = [
+ f"--iree-hal-target-device=hip[{i}]"
+ for i in range(self.tensor_parallelism_size)
+ ]
+ compile_args += iree_hal_target_devices
+ if hal_dump_path:
+ compile_args += [
+ f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"
+ ]
+ # Append optional arguments if provided
+ if args:
+ compile_args += args
+ cmd = subprocess.list2cmdline(compile_args)
+
+ logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}")
+ proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
+ return_code = proc.returncode
+ if return_code != 0:
+ raise IreeCompileException(proc, cwd)
+
+ def iree_benchmark_vmfb(
+ self,
+ *,
+ hip_device_id: str,
+ vmfb_name: str,
+ irpa_path: str,
+ args: List[str],
+ cwd: str | Path,
+ ):
+ """Runs a compiled program with the given args using `iree-benchmark-module`.
+ This assumes that the `iree-benchmark-module` command is available (usually via PATH).
+ Args:
+ hip_device_id: HIP device id to run on when the model is not sharded.
+ vmfb_name: Name of the .vmfb file (relative to `cwd`).
+ irpa_path: Path to the .irpa parameter file for the model.
+ args: List of arguments to pass to `iree-benchmark-module`.
+ cwd: Working directory to run the command within (string or Path).
+ Raises Exception if running fails for some reason.
+ """
+ benchmark_args = []
+ if self.tensor_parallelism_size > 1:
+ base_irpa_path, _ = os.path.splitext(irpa_path)
+ rocr_visible_devices = [
+ f"ROCR_VISIBLE_DEVICES={','.join(str(i) for i in range(self.tensor_parallelism_size))}"
+ ]
+ params = [f"--parameters=model={base_irpa_path}.irpa"]
+ params += [
+ f"--parameters=model={base_irpa_path}.rank{i}.irpa"
+ for i in range(self.tensor_parallelism_size)
+ ]
+ devices = [
+ f"--device=hip://{i}" for i in range(self.tensor_parallelism_size)
+ ]
+ else:
+ rocr_visible_devices = [f"ROCR_VISIBLE_DEVICES={hip_device_id}"]
+ params = [f"--parameters=model={irpa_path}"]
+ devices = [f"--device=hip://{hip_device_id}"]
+ benchmark_args += rocr_visible_devices
+ benchmark_args += [
+ "iree-benchmark-module",
+ "--hip_use_streams=true",
+ "--hip_allow_inline_execution=true",
+ "--device_allocator=caching",
+ f"--module={vmfb_name}",
+ ]
+ benchmark_args += params
+ benchmark_args += devices
+ benchmark_args += args
+ cmd = subprocess.list2cmdline(benchmark_args)
+ logger.info(f" Launching run command:\n" f"cd {cwd} && {cmd}")
+ proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd)
+ return_code = proc.returncode
+ if return_code != 0:
+ raise IreeBenchmarkException(proc, cwd)
+
+ def create_file(self, *, suffix, prefix):
+ file_path = Path(prefix).with_suffix(suffix)
+ f = open(file_path, "w")
+ return file_path
+
+ def get_artifacts(self):
+
+ self.dir_path = self.sharktank_dir + "/" + "tmp_perplexity_ci_artifacts/"
+ temp_dir = Path(self.dir_path)
+ temp_dir.mkdir(parents=True, exist_ok=True)
+
+ model_name = (
+ str(self.irpa_path).split("/")[-1].split(".")[0]
+ + "_"
+ + self.attention_kernel
+ )
+ mlir_path = str(
+ self.create_file(suffix=".mlir", prefix=self.dir_path + model_name)
+ )
+ json_path = str(
+ self.create_file(suffix=".json", prefix=self.dir_path + model_name)
+ )
+ vmfb_path = str(
+ self.create_file(suffix=".vmfb", prefix=self.dir_path + model_name)
+ )
+
+ if self.attention_kernel == "decomposed":
+ returncode = self.export_to_mlir(
+ mlir_path=mlir_path,
+ json_path=json_path,
+ )
+
+ if returncode == 0:
+ self.compile_to_vmfb(
+ mlir_path=mlir_path,
+ vmfb_path=vmfb_path,
+ cwd=self.sharktank_dir,
+ )
+
+ return vmfb_path
diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py
index e49623a8d..0562d5854 100644
--- a/sharktank/sharktank/utils/hf_datasets.py
+++ b/sharktank/sharktank/utils/hf_datasets.py
@@ -83,6 +83,23 @@ def alias_dataset(from_name: str, to_name: str):
# Dataset definitions
################################################################################
+Dataset(
+ "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
+ (
+ RemoteFile(
+ "gguf",
+ "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
+ "meta-llama-3.1-8b-instruct.f16.gguf",
+ ),
+ RemoteFile(
+ "tokenizer_config.json",
+ "NousResearch/Meta-Llama-3-8B",
+ "tokenizer_config.json",
+ extra_filenames=["tokenizer.json"],
+ ),
+ ),
+).alias_to("llama3_8B_fp16")
+
Dataset(
"QuantFactory/Llama-3-8B_q4_1_gguf",
(
@@ -247,6 +264,86 @@ def alias_dataset(from_name: str, to_name: str):
),
).alias_to("mixtral_8x7b_q8_0_gguf")
+Dataset(
+ "amd-shark/llama3.1-8B",
+ (
+ RemoteFile(
+ "gguf",
+ "amd-shark/llama-quant-models",
+ "llama3.1-8b/llama8b_f16.gguf",
+ ),
+ RemoteFile(
+ "tokenizer_config.json",
+ "amd-shark/llama-quant-models",
+ "llama3.1-8b/tokenizer_config.json",
+ extra_filenames=["llama3.1-8b/tokenizer.json"],
+ ),
+ ),
+).alias_to("llama3_8B_f16")
+
+Dataset(
+ "amd-shark/llama2-7B",
+ (
+ RemoteFile(
+ "gguf",
+ "amd-shark/llama-quant-models",
+ "llama2-7b/llama2_7b_f16.gguf",
+ ),
+ RemoteFile(
+ "tokenizer_config.json",
+ "amd-shark/llama-quant-models",
+ "llama2-7b/tokenizer_config.json",
+ extra_filenames=["llama2-7b/tokenizer.json"],
+ ),
+ ),
+).alias_to("llama2_7B_f16")
+
+Dataset(
+ "google/t5-v1_1-small",
+ (
+ RemoteFile(
+ "config",
+ "google/t5-v1_1-small",
+ "config.json",
+ extra_filenames=["generation_config.json", "special_tokens_map.json"],
+ ),
+ RemoteFile(
+ "tokenizer_config.json",
+ "google/t5-v1_1-small",
+ "tokenizer_config.json",
+ extra_filenames=["spiece.model"],
+ ),
+ RemoteFile(
+ "pytorch_model.bin",
+ "google/t5-v1_1-small",
+ "pytorch_model.bin",
+ ),
+ ),
+)
+
+Dataset(
+ "google/t5-v1_1-xxl",
+ (
+ RemoteFile(
+ "config",
+ "google/t5-v1_1-xxl",
+ "config.json",
+ extra_filenames=["generation_config.json", "special_tokens_map.json"],
+ ),
+ RemoteFile(
+ "tokenizer_config.json",
+ "google/t5-v1_1-xxl",
+ "tokenizer_config.json",
+ extra_filenames=["spiece.model"],
+ ),
+ RemoteFile(
+ "pytorch_model.bin",
+ "google/t5-v1_1-xxl",
+ "pytorch_model.bin",
+ ),
+ ),
+)
+
################################################################################
# Tool entrypoint
################################################################################
diff --git a/sharktank/sharktank/utils/io.py b/sharktank/sharktank/utils/io.py
index 62fd78f33..ac2480846 100644
--- a/sharktank/sharktank/utils/io.py
+++ b/sharktank/sharktank/utils/io.py
@@ -6,7 +6,7 @@
from pathlib import Path
-from shark_turbine.aot import (
+from iree.turbine.aot import (
ParameterArchiveBuilder,
)
diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py
new file mode 100644
index 000000000..d5976ec48
--- /dev/null
+++ b/sharktank/sharktank/utils/iree.py
@@ -0,0 +1,192 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import iree.runtime
+from typing import List, Tuple, Optional, Union
+from pathlib import Path
+import torch
+import numpy as np
+import collections.abc
+from collections import OrderedDict
+from ..types.tensors import (
+ AnyTensor,
+ InferenceTensor,
+ ShardedTensor,
+ DefaultPrimitiveTensor,
+ unbox_tensor,
+ torch_tree_flatten,
+)
+from .tree import Tree
+
+
+def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]:
+ hal_driver = iree.runtime.get_driver(driver)
+ available_devices = hal_driver.query_available_devices()
+ if driver in ["local-task", "local-sync"]:
+ # Use the same actual device for all devices.
+ return [
+ hal_driver.create_device(available_devices[0]) for _ in range(device_count)
+ ]
+ else:
+ return [
+ hal_driver.create_device(available_devices[i]) for i in range(device_count)
+ ]
+
+
+def load_iree_module(
+ module_path: str,
+ devices: List[iree.runtime.HalDevice],
+ parameters_path: Optional[str] = None,
+) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]:
+ """The VmContext and VmInstance need to outlive the VmModule and any device
+ buffers."""
+ vm_instance = iree.runtime.VmInstance()
+ hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
+ modules = [hal_module]
+ if parameters_path is not None:
+ params_path = Path(parameters_path)
+ parameter_index = iree.runtime.ParameterIndex()
+ if len(devices) > 1:
+ # TODO: make IREE able to load the parameters from the top parameter file
+ # without having to specify the parameter file for each shard separately.
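+            # e.g. "model.irpa" becomes "model.rank0.irpa", "model.rank1.irpa", ...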
+ for i in range(len(devices)):
+ parameter_index.load(
+ file_path=str(
+ Path(params_path).with_suffix(f".rank{i}{params_path.suffix}")
+ )
+ )
+ else:
+ parameter_index.load(file_path=str(params_path))
+ parameter_provider = parameter_index.create_provider(scope="model")
+ parameters_module = iree.runtime.create_io_parameters_module(
+ vm_instance, parameter_provider
+ )
+ modules.append(parameters_module)
+ vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
+ modules.append(vm_module)
+ vm_context = iree.runtime.VmContext(instance=vm_instance, modules=modules)
+ return vm_module, vm_context, vm_instance
+
+
+def run_iree_module_function(
+ module: iree.runtime.VmModule,
+ vm_context: iree.runtime.VmContext,
+ args: List[iree.runtime.DeviceArray],
+ driver: str,
+ function_name: str = "main",
+ trace_path_prefix: Optional[str] = None,
+) -> List[iree.runtime.DeviceArray]:
+ """Run IREE module function with optional tracing of arguments/results."""
+ vm_function = module.lookup_function(function_name)
+ invoker = iree.runtime.FunctionInvoker(
+ vm_context=vm_context,
+ # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
+ # This works, but does not look right.
+ device=iree.runtime.get_device(driver, cache=False),
+ vm_function=vm_function,
+ )
+ if trace_path_prefix is not None:
+ for i, arg in enumerate(args):
+ np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host())
+ results = invoker(*args)
+ if isinstance(results, iree.runtime.DeviceArray):
+ results = (results,)
+
+ if trace_path_prefix is not None:
+ for i, arg in enumerate(args):
+ np.save(
+ f"{trace_path_prefix}{function_name}_arg{i}_post_call.npy",
+ arg.to_host(),
+ )
+ for i, arg in enumerate(results):
+ np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host())
+ return results
+
+
+def prepare_iree_module_function_args(
+ args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice]
+) -> List[iree.runtime.DeviceArray]:
+ """Flatten composite tensors into their parts and place them on devices.
+    Sharded tensors are expanded into their shards, each placed on its
+    corresponding device.
+ All unsharded tensors go on device 0.
+ """
+ res = []
+ for arg in args:
+ if isinstance(arg, ShardedTensor):
+ assert len(devices) == len(arg.shards)
+ res.extend(
+ [
+ prepare_iree_module_function_args([shard], [device])[0]
+ for shard, device in zip(arg.shards, devices)
+ ]
+ )
+ elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)):
+ res.append(
+ iree.runtime.asdevicearray(
+ devices[0], unbox_tensor(arg).to("cpu").numpy()
+ )
+ )
+ else:
+ assert isinstance(arg, collections.abc.Sequence)
+ res.extend(prepare_iree_module_function_args(arg, devices))
+ return res
+
+
+def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]:
+ """Flatten a tree of arguments or results for an IREE call.
+    E.g. sharded tensors get flattened into their shards."""
+ return torch_tree_flatten(tree)[0]
+
+
+def call_torch_module_function(
+ module: torch.nn.Module,
+ function_name: str,
+ kwargs: OrderedDict,
+ trace_path_prefix: Optional[str] = None,
+):
+ """Call a torch module function with optional tracing.
+ For tracing the arguments/results are flattened to match IREE's signature."""
+    assert isinstance(
+        kwargs, OrderedDict
+    ), "kwargs must be an OrderedDict so that flattening preserves argument order"
+ if trace_path_prefix is not None:
+ flat_args = flatten_for_iree_signature(kwargs)
+ for i, arg in enumerate(flat_args):
+ np.save(
+ f"{trace_path_prefix}{function_name}_arg{i}.npy",
+ arg.to("cpu").numpy(),
+ )
+ res = getattr(module, function_name)(**kwargs)
+ if trace_path_prefix is not None:
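+        # The arguments are saved again after the call under the same file
+        # names, so the resulting trace reflects any in-place mutations made
+        # by the call (e.g. updated cache state).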
+ flat_args = flatten_for_iree_signature(kwargs)
+ for i, arg in enumerate(flat_args):
+ np.save(
+ f"{trace_path_prefix}{function_name}_arg{i}.npy",
+ arg.to("cpu").numpy(),
+ )
+ results = (
+ (res,)
+ if isinstance(
+ res,
+ (
+ torch.Tensor,
+ InferenceTensor,
+ ),
+ )
+ else res
+ )
+ flat_results = flatten_for_iree_signature(results)
+ for i, result in enumerate(flat_results):
+ np.save(
+ f"{trace_path_prefix}{function_name}_result{i}.npy",
+ result.to("cpu").numpy(),
+ )
+ return res
+
+
+def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
+ return [torch.tensor(tensor.to_host()) for tensor in tensors]
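+
+
+# Usage sketch (illustrative only; the driver, paths, and inputs below are
+# placeholders, not a tested configuration):
+#
+#   devices = get_iree_devices(driver="hip", device_count=1)
+#   vm_module, vm_context, vm_instance = load_iree_module(
+#       module_path="model.vmfb", devices=devices, parameters_path="model.irpa"
+#   )
+#   iree_args = prepare_iree_module_function_args(args=[torch_inputs], devices=devices)
+#   iree_results = run_iree_module_function(
+#       module=vm_module, vm_context=vm_context, args=iree_args, driver="hip"
+#   )
+#   torch_results = iree_to_torch(*iree_results)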
diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py
new file mode 100644
index 000000000..47d9f0244
--- /dev/null
+++ b/sharktank/sharktank/utils/load_llm.py
@@ -0,0 +1,211 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import math
+
+import torch
+
+from sharktank.layers import *
+from sharktank.types import *
+from sharktank.models.llama.llama import *
+
+from ..utils.debugging import trace_tensor
+from ..utils.tokenizer import InferenceTokenizer
+
+
+class TorchGenerator:
+ """Generator that runs directly on the Torch model."""
+
+ def __init__(
+ self,
+ model: PagedLlamaModelV1,
+ tokenizer: InferenceTokenizer,
+ # Need to look at the model more for this.
+ end_token: int = 2,
+ ):
+ self.model = model
+ self.tokenizer = tokenizer
+ self.end_token = end_token
+
+ @property
+ def block_seq_stride(self) -> int:
+ return self.model.cache.block_seq_stride
+
+ def begin_batch(
+ self, prompts: list[str], add_start_token: bool, page_cache_size: int = 128
+ ):
+ token_ids, seq_lens = self.tokenizer.encode(
+ prompts,
+ pad_to_multiple_of=self.model.cache.pad_sequence_stride,
+ add_start_token=add_start_token,
+ )
+ token_ids = torch.tensor(token_ids, device=self.model.device)
+ seq_lens = torch.tensor(seq_lens, device=self.model.device)
+
+ if self.model.cache.is_paged:
+ cache_state = self.model.cache.paged.allocate(page_cache_size)
+ self.free_pages = list(range(1, page_cache_size))
+ else:
+ cache_state = self.model.cache.direct.allocate(bs=len(prompts))
+ return Batch(self, token_ids, seq_lens, cache_state)
+
+ def begin_eval_batch(
+ self,
+        token_batch: torch.Tensor,
+        seq_lens_batch: torch.Tensor,
+ bs: int,
+ page_cache_size: int = 128,
+ ):
+ if self.model.cache.is_paged:
+ cache_state = self.model.cache.paged.allocate(page_cache_size)
+ self.free_pages = list(range(1, page_cache_size))
+ else:
+ cache_state = self.model.cache.direct.allocate(bs=bs)
+ return Batch(self, token_batch, seq_lens_batch, cache_state)
+
+ def alloc_page(self) -> int:
+ if self.model.cache.is_direct:
+ # We don't allocate block ids for the direct cache.
+ return 0
+
+ return self.free_pages.pop()
+
+ def release_page(self, index: int):
+ if self.model.cache.is_direct:
+ return
+ self.free_pages.append(index)
+
+
+class Batch:
+ def __init__(
+ self,
+ parent: TorchGenerator,
+ token_ids: torch.Tensor,
+ seq_lens: torch.Tensor,
+ cache_state: list[torch.Tensor],
+ ):
+ self.bs = token_ids.shape[0]
+ # assert seq_lens.shape[0] == self.bs
+ self.parent = parent
+ self.token_ids = token_ids
+ self.seq_lens = seq_lens
+ self.cache_state = cache_state
+ self.results: list[list[int]] = [[] for _ in range(self.bs)]
+ self.done_result_indices: set[int] = set()
+
+ # Assemble the batch.
+ seq_stride = self.parent.block_seq_stride
+ self.seq_block_ids: list[list[int]] = []
+ for seq_len in self.seq_lens:
+ blocks_needed = (
+ int(math.ceil(seq_len / seq_stride)) if seq_stride > 0 else 0
+ )
+ row = []
+ for _ in range(blocks_needed):
+ row.append(self.parent.alloc_page())
+ self.seq_block_ids.append(row)
+
+ @property
+ def done(self) -> bool:
+ return len(self.done_result_indices) == self.bs
+
+ def detokenize(self) -> list[str]:
+ return self.parent.tokenizer.decode(self.results)
+
+ def print_current_results(self):
+ results = self.detokenize()
+ for i, s in enumerate(results):
+ seq_len = int(self.seq_lens[i])
+ print(f" {i}({len(self.results[i])}, {seq_len}): {s}")
+
+ def add_result_token(self, tokens: torch.Tensor):
+ for i in range(self.bs):
+ token = tokens[i][0]
+ if token == self.parent.end_token:
+ self.done_result_indices.add(i)
+ if i in self.done_result_indices:
+ continue
+ token = int(tokens[i, 0])
+ self.results[i].append(token)
+
+ def allocate_seq_block_ids(self):
+ for i in range(self.bs):
+ sl = int(self.seq_lens[i])
+ if (sl % self.parent.block_seq_stride) == 0:
+ needed_blocks = sl // self.parent.block_seq_stride + 1
+ else:
+ needed_blocks = math.ceil(sl / self.parent.block_seq_stride)
+ block_ids_row = self.seq_block_ids[i]
+ while len(block_ids_row) < needed_blocks:
+ block_ids_row.append(self.parent.alloc_page())
+
+ def prefill(self):
+ model = self.parent.model
+ attention_mask = model.attention_mask(
+ model.input_mask(self.seq_lens, self.token_ids.shape[1])
+ )
+ seq_block_ids_tensor = self.pad_block_ids()
+ trace_tensor("prefill.token_ids", self.token_ids)
+ trace_tensor("prefill.seq_block_ids", seq_block_ids_tensor)
+ trace_tensor("prefill.attention_mask", attention_mask)
+ self.prefill_logits = model.prefill(
+ self.token_ids,
+ attention_mask=attention_mask,
+ seq_block_ids=seq_block_ids_tensor,
+ cache_state=self.cache_state,
+ )
+
+ # TODO: Generalize the sampling and don't make it swap on/off cpu.
+ # TODO: Normalize the output of extract_tokens_from_logits into
+ # tensor [bs, 1].
+ tokens = torch.tensor(
+ model.extract_tokens_from_logits(self.prefill_logits, self.seq_lens)
+ ).unsqueeze(1)
+ self.add_result_token(tokens)
+ self.next_tokens = tokens.to(device=model.device)
+
+ def decode(self, token_batch):
+ self.token_batch = token_batch
+
+ model = self.parent.model
+ start_positions = self.seq_lens.clone()
+ self.seq_lens.add_(1)
+ self.allocate_seq_block_ids()
+ # TODO: Allocate more blocks on overflow.
+ seq_block_ids_tensor = self.pad_block_ids()
+ decode_attention_mask = model.decode_attention_mask(
+ model.input_mask(
+ self.seq_lens,
+ seq_block_ids_tensor.shape[1] * self.parent.block_seq_stride,
+ )
+ )
+ trace_tensor("decode.token_ids", self.token_ids)
+ trace_tensor("decode.start_positions", start_positions)
+ trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
+ trace_tensor("decode.attention_mask", decode_attention_mask)
+
+ self.decode_logits = model.decode(
+ self.token_batch,
+ attention_mask=decode_attention_mask,
+ start_positions=start_positions,
+ seq_block_ids=seq_block_ids_tensor,
+ cache_state=self.cache_state,
+ )
+
+ trace_tensor("decode.logits", self.decode_logits)
+        # TODO: Normalize the output of extract_tokens_from_logits into
+        # tensor [bs, 1].
+ tokens = torch.tensor(
+ model.extract_tokens_from_logits(self.decode_logits, [1] * self.bs),
+ device=self.parent.model.device,
+ ).unsqueeze(1)
+ self.add_result_token(tokens)
+ self.next_tokens = tokens
+
+ def pad_block_ids(self) -> torch.Tensor:
+ max_length = max(len(r) for r in self.seq_block_ids)
+ rows = [r + (max_length - len(r)) * [0] for r in self.seq_block_ids]
+ return torch.tensor(rows, device=self.parent.model.device)
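+
+
+# Usage sketch (illustrative only; `model` and `tokenizer` are assumed to be an
+# already constructed PagedLlamaModelV1 and InferenceTokenizer):
+#
+#   generator = TorchGenerator(model, tokenizer)
+#   batch = generator.begin_batch(["The capital of France is"], add_start_token=True)
+#   batch.prefill()
+#   while not batch.done:  # a real loop would also cap the number of decode steps
+#       batch.decode(batch.next_tokens)
+#   print(batch.detokenize())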
diff --git a/sharktank/sharktank/utils/logging.py b/sharktank/sharktank/utils/logging.py
index 977462d86..3801f96cb 100644
--- a/sharktank/sharktank/utils/logging.py
+++ b/sharktank/sharktank/utils/logging.py
@@ -6,7 +6,7 @@
import logging
-from shark_turbine.support.logging import get_logger
+from iree.turbine.support.logging import get_logger
transform_logger: logging.Logger = get_logger("sharktank.transforms")
diff --git a/sharktank/sharktank/utils/math.py b/sharktank/sharktank/utils/math.py
index df47b5ae6..3f32ac952 100644
--- a/sharktank/sharktank/utils/math.py
+++ b/sharktank/sharktank/utils/math.py
@@ -4,6 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from numbers import Number
+
def ceildiv(a: int | float, b: int | float) -> int | float:
return -(a // -b)
+
+
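+# Example: round_up_to_multiple_of(5, 4) == 8 and round_up_to_multiple_of(8, 4) == 8.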
+def round_up_to_multiple_of(x: Number, multiple: Number) -> Number:
+ return x + (-x % multiple)
diff --git a/sharktank/sharktank/utils/misc.py b/sharktank/sharktank/utils/misc.py
new file mode 100644
index 000000000..dd1ebf4d6
--- /dev/null
+++ b/sharktank/sharktank/utils/misc.py
@@ -0,0 +1,30 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Any, Callable, List
+from collections.abc import Iterable
+from operator import eq
+
+
+def longest_equal_range(l1: List[Any], l2: List[Any]) -> int:
+ """Find the longest range that is the same from the start of both lists.
+ Returns the greatest `i` such that `l1[0:i] == l2[0:i]`."""
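+    # Example: longest_equal_range([1, 2, 3], [1, 2, 4]) == 2.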
+ for i, (a, b) in enumerate(zip(l1, l2)):
+ if a != b:
+ return i
+    return min(len(l1), len(l2))
+
+
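+# Example: iterables_equal([1, 2], (1, 2)) is True; iterables of different
+# lengths raise ValueError because of zip(..., strict=True).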
+def iterables_equal(
+ iterable1: Iterable,
+ iterable2: Iterable,
+ *,
+ elements_equal: Callable[[Any, Any], bool] | None = None
+) -> bool:
+ elements_equal = elements_equal or eq
+ return all(
+ elements_equal(v1, v2) for v1, v2 in zip(iterable1, iterable2, strict=True)
+ )
diff --git a/sharktank/sharktank/utils/testing.py b/sharktank/sharktank/utils/testing.py
index 8494c1d9b..933bfd2b6 100644
--- a/sharktank/sharktank/utils/testing.py
+++ b/sharktank/sharktank/utils/testing.py
@@ -10,15 +10,26 @@
import shutil
import tempfile
import unittest
+import torch
+from typing import Any, Callable
+from operator import eq
+from collections.abc import Iterable
+import gc
from ..types import *
+# Range of torch.rand() is [0,1)
+# Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
+def make_rand_torch(shape, dtype=torch.float32):
+ return torch.rand(shape, dtype=dtype) * 2 - 1
+
class TempDirTestBase(unittest.TestCase):
def setUp(self):
self._temp_dir = Path(tempfile.mkdtemp(type(self).__qualname__))
def tearDown(self):
+ gc.collect()
shutil.rmtree(self._temp_dir, ignore_errors=True)
@@ -93,6 +104,30 @@ def get_best_torch_device() -> str:
return "cpu"
+def assert_dicts_equal(
+ dict1: dict, dict2: dict, *, values_equal: Callable[[Any, Any], bool] | None = None
+) -> None:
+ values_equal = values_equal or eq
+ assert len(dict1) == len(
+ dict2
+ ), f"Dictionaries not equal. {dict1} and {dict2} have different number of elements {len(dict1)} != {len(dict2)}"
+ for k, v1 in dict1.items():
+ assert (
+ k in dict2
+ ), f"Dictionaries {dict1} and {dict2} not equal. Key {k} not found in {dict2}"
+ v2 = dict2[k]
+        assert values_equal(
+            v1, v2
+        ), f"Dictionaries {dict1} and {dict2} not equal for key {k}. Values {v1} and {v2} not equal"
+
+
+def assert_equal(
+ a: Any, b: Any, *, equal: Callable[[Any, Any], bool] | None = None
+) -> None:
+ equal = equal or eq
+ assert equal(a, b), f"{a} and {b} are not equal"
+
+
def assert_golden_safetensors(actual_path, ref_path):
"""Asserts that actual and reference safetensors files are within tolerances."""
from safetensors import safe_open
@@ -127,3 +162,36 @@ def print_stats(label, t):
actual = actual_f.get_tensor(name)
ref = ref_f.get_tensor(name)
torch.testing.assert_close(actual, ref, msg=name)
+
+
+def assert_iterables_equal(
+ iterable1: Iterable,
+ iterable2: Iterable,
+ *,
+ elements_equal: Callable[[Any, Any], bool] | None = None,
+) -> None:
+ elements_equal = elements_equal or eq
+ for i, (v1, v2) in enumerate(zip(iterable1, iterable2, strict=True)):
+ assert elements_equal(
+ v1, v2
+ ), f"Iterables not equal at index {i} for elements {v1} and {v2}"
+
+
+SHARKTANK_TEST_SKIP_ENV_VAR = "SHARKTANK_TEST_SKIP"
+
+
+def skip(*decorator_args, **decorator_kwargs):
+    """Skip the decorated test unless the SHARKTANK_TEST_SKIP env var is set to "0"."""
+
+ def decorator(test_item: Callable):
+ if SHARKTANK_TEST_SKIP_ENV_VAR not in os.environ:
+ should_skip = True
+ else:
+ should_skip = os.environ[SHARKTANK_TEST_SKIP_ENV_VAR] != "0"
+
+ if should_skip:
+ return unittest.skip(*decorator_args, **decorator_kwargs)(test_item)
+
+ return test_item
+
+ return decorator
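+
+
+# Illustrative usage (the test name below is hypothetical):
+#
+#   @skip("Skipped by default; set SHARKTANK_TEST_SKIP=0 to run")
+#   def test_expensive_model(self): ...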
diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py
index 29a57f958..b459c706a 100644
--- a/sharktank/sharktank/utils/tokenizer.py
+++ b/sharktank/sharktank/utils/tokenizer.py
@@ -22,34 +22,57 @@ class InferenceTokenizer(ABC):
"""Simple inference tokenizer."""
def encode(
- self, texts: list[str], pad_to_multiple_of: int = 1, pad_token: int = 0
+ self,
+ texts: list[str],
+ pad_to_multiple_of: int = 1,
+ add_start_token: bool = False,
) -> tuple[list[list[int]]]:
"""Encodes a list of texts into a padded list of tokens.
Returns a list of list of tokens and a list of unpadded lengths.
"""
- raw_rows = self._encode(texts)
+ raw_rows = self._encode(texts, add_start_token)
+ raw_rows, lengths = self.pad_tokens(
+ token_ids=raw_rows, pad_to_multiple_of=pad_to_multiple_of
+ )
+ return raw_rows, lengths
+
+ def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None):
+ """Decodes a list of tokens."""
+ if lens is not None:
+ tokens = list(tokens)
+ for i, row_length in enumerate(lens):
+ tokens[i] = tokens[i][0:row_length]
+ return self._decode(tokens)
+
+ def get_prompt_lengths(
+ self,
+ token_ids: list[list[int]],
+ ):
max_length = 0
lengths: list[int] = []
- for row in raw_rows:
+ for row in token_ids:
lengths.append(len(row))
max_length = max(max_length, len(row))
+
+ return lengths, max_length
+
+ def pad_tokens(
+ self,
+ token_ids: list[list[int]],
+ pad_to_multiple_of: int,
+ pad_token: int = 0,
+ ):
+ lengths, max_length = self.get_prompt_lengths(token_ids)
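+        # Example: pad_tokens([[1, 2, 3]], pad_to_multiple_of=4) returns
+        # ([[1, 2, 3, 0]], [3]).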
if pad_to_multiple_of > 1:
max_length = int(
pad_to_multiple_of * math.ceil(max_length / pad_to_multiple_of)
)
- for row in raw_rows:
+ for row in token_ids:
pad_count = max_length - len(row)
row.extend(pad_count * [pad_token])
- return raw_rows, lengths
- def decode(self, tokens: Union[list[list[int]]], lens: Optional[list[int]] = None):
- """Decodes a list of tokens."""
- if lens is not None:
- tokens = list(tokens)
- for i, row_length in enumerate(lens):
- tokens[i] = tokens[i][0:row_length]
- return self._decode(tokens)
+ return token_ids, lengths
@abstractmethod
def _encode(self, texts: list[str]) -> list[list[int]]:
@@ -76,9 +99,10 @@ class _TransformersTokenizer(InferenceTokenizer):
def __init__(self, t: AutoTokenizer):
self._t = t
- def _encode(self, texts: list[str]) -> list[list[int]]:
+ def _encode(self, texts: list[str], add_start_token: bool) -> list[list[int]]:
         results = self._t.batch_encode_plus(
texts,
+ add_special_tokens=add_start_token,
padding=False,
truncation=False,
)
diff --git a/sharktank/sharktank/utils/vmfb_runner.py b/sharktank/sharktank/utils/vmfb_runner.py
new file mode 100644
index 000000000..cdbf96c9d
--- /dev/null
+++ b/sharktank/sharktank/utils/vmfb_runner.py
@@ -0,0 +1,82 @@
+from iree import runtime as ireert
+from iree.runtime._binding import create_hal_driver
+
+
+class vmfbRunner:
+ def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None):
+
+ # If an extra plugin is requested, add a global flag to load the plugin
+ # and create the driver using the non-caching creation function, as
+ # the caching creation function may ignore the flag.
+ if extra_plugin:
+ ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
+ haldriver = create_hal_driver(device)
+
+ # No plugin requested: create the driver with the caching create
+ # function.
+ else:
+ haldriver = ireert.get_driver(device)
+ if "://" in device:
+ try:
+ device_idx = int(device.split("://")[-1])
+ device_uri = None
+            except ValueError:
+ device_idx = None
+ device_uri = device.split("://")[-1]
+ else:
+ device_idx = 0
+ device_uri = None
+ if device_uri:
+ if not any(x in device for x in ["cpu", "task"]):
+ allocators = ["caching"]
+ haldevice = haldriver.create_device_by_uri(
+ device_uri, allocators=allocators
+ )
+ else:
+ haldevice = haldriver.create_device_by_uri(device_uri)
+ else:
+ hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"]
+ if not any(x in device for x in ["cpu", "task"]):
+ allocators = ["caching"]
+ haldevice = haldriver.create_device(
+ hal_device_id, allocators=allocators
+ )
+ else:
+ haldevice = haldriver.create_device(hal_device_id)
+
+ self.config = ireert.Config(device=haldevice)
+ mods = []
+ if not isinstance(vmfb_path, list):
+ vmfb_path = [vmfb_path]
+ for path in vmfb_path:
+ mods.append(ireert.VmModule.mmap(self.config.vm_instance, path))
+ vm_modules = [
+ *mods,
+ ireert.create_hal_module(self.config.vm_instance, self.config.device),
+ ]
+
+ # TODO: Enable multiple weight files
+ if external_weight_path:
+ index = ireert.ParameterIndex()
+ if not isinstance(external_weight_path, list):
+ external_weight_path = [external_weight_path]
+ for i, path in enumerate(external_weight_path):
+ if path in ["", None]:
+ continue
+ index.load(path)
+ # TODO: extend scope
+ param_module = ireert.create_io_parameters_module(
+ self.config.vm_instance, index.create_provider(scope="model")
+ )
+ vm_modules.insert(i, param_module)
+ del param_module
+ del index
+
+ self.ctx = ireert.SystemContext(
+ vm_modules=vm_modules,
+ config=self.config,
+ )
+
+ def unload(self):
+ self.ctx = None
+ self.config = None
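+
+
+# Usage sketch (illustrative only; the paths and the exported module/function
+# names are placeholders that depend on how the vmfb was compiled):
+#
+#   runner = vmfbRunner("hip://0", "model.vmfb", external_weight_path="model.irpa")
+#   outputs = runner.ctx.modules.module["main"](*inputs)
+#   runner.unload()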
diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json
new file mode 100644
index 000000000..24511b05f
--- /dev/null
+++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json
@@ -0,0 +1,318 @@
+{
+ "llama3_8B_f16_decomposed": {
+ "perplexities": [
+ 6.677369,
+ 21.807926,
+ 15.424338,
+ 17.332415,
+ 14.951956,
+ 7.913092,
+ 8.728321,
+ 22.425966,
+ 8.184698,
+ 20.977249,
+ 7.088408,
+ 14.574989,
+ 9.036912,
+ 7.277581,
+ 16.132208,
+ 6.685175,
+ 6.525683,
+ 7.080791,
+ 10.680925,
+ 9.034086,
+ 10.639015,
+ 41.102894,
+ 11.723896,
+ 64.305908,
+ 47.054577,
+ 19.9259,
+ 18.918842,
+ 13.842684,
+ 9.974381,
+ 5.919641,
+ 10.181265,
+ 23.609016,
+ 14.340417,
+ 9.712208,
+ 5.602878,
+ 14.088163,
+ 5.680599,
+ 17.377926,
+ 9.037231,
+ 8.305407,
+ 8.028031,
+ 17.744528,
+ 11.5076,
+ 3.936302,
+ 12.987297,
+ 10.371798,
+ 11.927772,
+ 21.387051,
+ 37.799526,
+ 25.67762,
+ 15.429109,
+ 13.923962,
+ 7.594806,
+ 10.983875,
+ 14.595965,
+ 11.022234,
+ 5.853358,
+ 15.609065,
+ 8.044486,
+ 14.389134,
+ 5.917565,
+ 6.892455,
+ 2.30309,
+ 15.974725,
+ 42.017342,
+ 8.022307,
+ 12.284297,
+ 10.018423,
+ 9.268936,
+ 10.680118,
+ 8.12535,
+ 21.550434,
+ 3.638689,
+ 15.345065,
+ 23.742884,
+ 14.288899,
+ 17.796623,
+ 16.515446,
+ 8.746647,
+ 12.922096,
+ 12.94269,
+ 13.574061,
+ 14.013302,
+ 10.76523,
+ 14.746032,
+ 28.208134,
+ 17.646687,
+ 9.848188,
+ 15.280471,
+ 15.621455,
+ 29.126505,
+ 12.302313,
+ 32.452534,
+ 31.192411,
+ 14.371797,
+ 17.490683,
+ 14.689407,
+ 15.284843,
+ 12.252508,
+ 16.460979
+ ],
+ "mean_perplexity": 14.930181
+ },
+
+ "llama3_405B_f16_decomposed": {
+ "perplexities": [
+ 2.170036,
+ 8.014498,
+ 3.743922,
+ 10.629776,
+ 8.965701,
+ 2.884743,
+ 2.886767,
+ 3.853816,
+ 2.73785,
+ 15.235562,
+ 2.65135,
+ 1.970936,
+ 5.08259,
+ 2.507602,
+ 7.571635,
+ 3.005182,
+ 1.904492,
+ 3.182651,
+ 6.249443,
+ 4.661795,
+ 12.68933,
+ 35.432453,
+ 5.50336,
+ 60.950359,
+ 18.433432,
+ 5.001391,
+ 4.814827,
+ 2.99482,
+ 2.697508,
+ 2.617349,
+ 2.359061,
+ 16.697233,
+ 2.145065,
+ 2.1207,
+ 2.496015,
+ 1.822896,
+ 4.671626,
+ 2.389186,
+ 2.701802,
+ 1.921128,
+ 2.236057,
+ 4.741998,
+ 4.946936,
+ 2.758695,
+ 2.446043,
+ 2.146302,
+ 8.72202,
+ 4.180647,
+ 11.449497,
+ 13.429152,
+ 3.72468,
+ 2.407385,
+ 3.592854,
+ 5.412414,
+ 3.189998,
+ 4.186216,
+ 1.642744,
+ 2.279058,
+ 1.855652,
+ 3.453852,
+ 1.436223,
+ 1.516955,
+ 1.716439,
+ 4.715765,
+ 21.48657,
+ 2.208737,
+ 6.420449,
+ 2.001433,
+ 2.400955,
+ 3.543744,
+ 3.054271,
+ 7.904545,
+ 1.950376,
+ 3.983746,
+ 6.28265,
+ 2.64157,
+ 5.473378,
+ 3.444444,
+ 1.926046,
+ 3.092915,
+ 3.996159,
+ 3.125222,
+ 1.718025,
+ 3.856093,
+ 3.041075,
+ 11.798485,
+ 14.881112,
+ 5.631516,
+ 4.407883,
+ 4.840533,
+ 21.351448,
+ 2.065821,
+ 6.658993,
+ 28.123312,
+ 1.673253,
+ 3.729975,
+ 5.336116,
+ 8.579758,
+ 2.979404,
+ 1.915619
+ ],
+ "mean_perplexity": 6.060831
+ },
+ "llama3_8B_f16_decomposed_iree": {
+ "perplexities": [
+ 6.651368,
+ 22.059452,
+ 15.392176,
+ 17.418619,
+ 15.206824,
+ 7.907998,
+ 8.829535,
+ 22.355659,
+ 8.29262,
+ 20.958277,
+ 7.167404,
+ 14.592677,
+ 9.060788,
+ 7.274667,
+ 16.238981,
+ 6.666115,
+ 6.535679,
+ 7.086256,
+ 10.676177,
+ 8.979206,
+ 10.597121,
+ 42.038162,
+ 11.70071,
+ 65.731316,
+ 47.42622,
+ 20.109543,
+ 18.897541,
+ 13.781085,
+ 9.99165,
+ 5.955308,
+ 10.175659,
+ 23.628405,
+ 14.306578,
+ 9.719462,
+ 5.594786,
+ 14.198979,
+ 5.711433,
+ 17.381332,
+ 9.058512,
+ 8.286205,
+ 8.016202,
+ 18.4515,
+ 11.600831,
+ 3.945074,
+ 13.000222,
+ 10.373363,
+ 12.237907,
+ 21.408463,
+ 37.858665,
+ 25.794065,
+ 15.489001,
+ 14.004895,
+ 7.625473,
+ 10.993184,
+ 14.698832,
+ 11.062652,
+ 5.855446,
+ 15.625135,
+ 8.052419,
+ 14.365479,
+ 5.927001,
+ 6.931933,
+ 2.3014,
+ 15.769623,
+ 40.843319,
+ 8.022024,
+ 12.544907,
+ 10.090073,
+ 9.304819,
+ 10.679907,
+ 8.136175,
+ 21.540607,
+ 3.736973,
+ 15.381804,
+ 24.21562,
+ 14.385005,
+ 17.791706,
+ 16.498833,
+ 8.753955,
+ 12.941816,
+ 12.887664,
+ 13.725715,
+ 13.994792,
+ 10.769128,
+ 14.734674,
+ 26.970015,
+ 17.811842,
+ 9.847188,
+ 15.124973,
+ 15.623392,
+ 29.147844,
+ 12.309229,
+ 32.15152,
+ 33.225769,
+ 14.426914,
+ 17.496277,
+ 14.7356,
+ 15.503921,
+ 12.336852,
+ 16.469248
+ ],
+ "mean_perplexity": 14.991893
+ }
+}
diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py
new file mode 100644
index 000000000..d10d9f5db
--- /dev/null
+++ b/sharktank/tests/evaluate/perplexity_iree_test.py
@@ -0,0 +1,327 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+import pytest
+import json
+import numpy as np
+
+from sharktank.evaluate import perplexity_iree
+
+is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
+skipif_run_quick_llama_test = pytest.mark.skipif(
+ 'not config.getoption("run-nightly-llama-tests")',
+ reason="Run large tests if --run-nightly-llama-tests is passed",
+)
+
+
+@pytest.mark.usefixtures(
+ "get_model_artifacts",
+ "get_iree_flags",
+ "tensor_parallelism_size",
+ "baseline_perplexity_scores",
+ "batch_size",
+)
+@is_mi300x
+class PerplexityTest(unittest.TestCase):
+ def setUp(self):
+ self.current_perplexity_all = {}
+ self.delta = 5e-1
+ self.tensor_parallelism_size = 8
+ with open(self.baseline_perplexity_scores, "r") as f:
+ self.baseline_perplexity = json.load(f)
+
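+    # Each test runs perplexity_iree.main on the first `batch_size` prompts and
+    # asserts that the resulting mean perplexity matches the stored baseline
+    # within an absolute tolerance of `self.delta`.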
+ def test_llama3_8B_f16_decomposed(self):
+
+ # Llama 3.1 8B decomposed
+
+ model_name = "llama3_8B_f16_decomposed_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_8b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size=1",
+ f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
+ def test_llama3_8B_f16(self):
+
+ # Llama 3.1 8B non-decomposed
+
+ model_name = "llama3_8B_f16_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_8b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size=1",
+ f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
+ def test_llama3_8B_fp8_decomposed(self):
+
+ # Llama 3.1 8B decomposed
+
+ model_name = "llama3_8B_fp8_decomposed_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_8b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size=1",
+ f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
+ def test_llama3_8B_fp8(self):
+
+ # Llama 3.1 8B non-decomposed
+
+ model_name = "llama3_8B_fp8_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_8b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size=1",
+ f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(
+ reason="Sharding is unsupported",
+ )
+ def test_llama3_405B_f16_decomposed(self):
+
+ # Llama 3.1 405B decomposed
+
+ model_name = "llama3_405B_f16_decomposed_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_405b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
+ def test_llama3_405B_f16(self):
+
+ # Llama 3.1 405B non-decomposed
+
+ model_name = "llama3_405B_f16_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_405b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
+ def test_llama3_405B_fp8_decomposed(self):
+
+ # Llama 3.1 405B decomposed
+
+ model_name = "llama3_405B_fp8_decomposed_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_405b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ f"--attention-kernel=decomposed",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error")
+ def test_llama3_405B_fp8(self):
+
+ # Llama 3.1 405B non-decomposed
+
+ model_name = "llama3_405B_fp8_iree"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_iree.main(
+ [
+ f"--irpa-file={self.llama3_405b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--iree-device={self.iree_device}",
+ f"--iree-hal-target-backends={self.iree_hal_target_backends}",
+ f"--iree-hip-target={self.iree_hip_target}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ f"--attention-kernel=torch_sdpa",
+ f"--num-prompts={self.batch_size}",
+ ]
+ )
+
+ baseline_mean_perplexity = round(
+ np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6
+ )
+ current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6)
+
+ perplexity_difference = current_mean_perplexity - baseline_mean_perplexity
+
+ self.assertAlmostEqual(
+ baseline_mean_perplexity,
+ current_mean_perplexity,
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/sharktank/tests/evaluate/perplexity_torch_test.py b/sharktank/tests/evaluate/perplexity_torch_test.py
new file mode 100644
index 000000000..042132f20
--- /dev/null
+++ b/sharktank/tests/evaluate/perplexity_torch_test.py
@@ -0,0 +1,274 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+import pytest
+import json
+
+from sharktank.evaluate import perplexity_torch
+
+longrun = pytest.mark.skipif("not config.getoption('longrun')")
+
+
+@pytest.mark.usefixtures(
+ "get_model_artifacts", "tensor_parallelism_size", "baseline_perplexity_scores"
+)
+class PerplexityTest(unittest.TestCase):
+ def setUp(self):
+ self.current_perplexity_all = {}
+ self.delta = 5e-1
+ self.tensor_parallelism_size = 8
+ with open(self.baseline_perplexity_scores, "r") as f:
+ self.baseline_perplexity = json.load(f)
+
+ @longrun
+ def test_llama3_8B_f16_decomposed(self):
+
+ # Llama 3.1 8B decomposed
+
+ model_name = "llama3_8B_f16_decomposed"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_8b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="Non-decomposed attention is not supported yet",
+ )
+ @longrun
+ def test_llama3_8B_f16(self):
+
+ # Llama 3.1 8B non-decomposed
+
+ model_name = "llama3_8B_f16"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_8b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ f"--attention-kernel=torch_sdpa",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="FP8 model is unsupported",
+ )
+ @longrun
+ def test_llama3_8B_fp8_decomposed(self):
+
+ # Llama 3.1 8B decomposed
+
+ model_name = "llama3_8B_fp8_decomposed"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_8b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="Non-decomposed attention is not supported yet",
+ )
+ @longrun
+ def test_llama3_8B_fp8(self):
+
+ # Llama 3.1 8B non-decomposed
+
+ model_name = "llama3_8B_fp8"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_8b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_8b_tokenizer}",
+ f"--attention-kernel=torch_sdpa",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="Sharding needs to be fixed",
+ )
+ @longrun
+ def test_llama3_405B_f16_decomposed(self):
+
+ # Llama 3.1 405B decomposed
+
+ model_name = "llama3_405B_f16_decomposed"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_405b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="Non-decomposed attention is not supported yet",
+ )
+ @longrun
+ def test_llama3_405B_f16(self):
+
+ # Llama 3.1 405B non-decomposed
+
+ model_name = "llama3_405B_f16"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_405b_f16_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ f"--attention-kernel=torch_sdpa",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="FP8 model is unsupported",
+ )
+ @longrun
+ def test_llama3_405B_fp8_decomposed(self):
+
+ # Llama 3.1 405B decomposed
+
+ model_name = "llama3_405B_fp8_decomposed"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_405b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+ @pytest.mark.xfail(
+ reason="Non-decomposed attention is not supported yet",
+ )
+ @longrun
+ def test_llama3_405B_fp8(self):
+
+ # Llama 3.1 405B non-decomposed
+
+ model_name = "llama3_405B_fp8"
+ baseline_perplexity = self.baseline_perplexity[model_name]
+
+ current_perplexity = perplexity_torch.main(
+ [
+ f"--irpa-file={self.llama3_405b_fp8_model}",
+ f"--tokenizer-config-json={self.llama3_405b_tokenizer}",
+ f"--tensor-parallelism-size={self.tensor_parallelism_size}",
+ f"--attention-kernel=torch_sdpa",
+ ]
+ )
+
+ perplexity_difference = (
+ current_perplexity["mean_perplexity"]
+ - baseline_perplexity["mean_perplexity"]
+ )
+
+ self.assertAlmostEqual(
+ baseline_perplexity["mean_perplexity"],
+ current_perplexity["mean_perplexity"],
+ delta=self.delta,
+            msg=f"Current perplexity deviates from baseline by {perplexity_difference}",
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/sharktank/tests/export_test.py b/sharktank/tests/export_test.py
new file mode 100644
index 000000000..20b7de734
--- /dev/null
+++ b/sharktank/tests/export_test.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from sharktank.types import (
+ ReplicatedTensor,
+ SplitPrimitiveTensor,
+ DefaultPrimitiveTensor,
+ unbox_tensor,
+)
+from sharktank.export import (
+ export,
+ flatten_signature,
+ get_argument_flat_device_affinities,
+)
+from sharktank import ops
+from sharktank.utils.testing import (
+ assert_equal,
+ assert_iterables_equal,
+ assert_dicts_equal,
+)
+from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
+from iree.turbine import aot
+from unittest import TestCase
+import torch
+
+
+class ExportTest(TestCase):
+ def testFlattenSignature(self):
+ expected_a = [SplitPrimitiveTensor(ts=[torch.tensor([1])], shard_dim=0)]
+ expected_b = {"element": DefaultPrimitiveTensor(data=torch.tensor([2]))}
+ expected_c = torch.tensor([3])
+
+ @flatten_signature(expected_a, expected_b, expected_c)
+ def f(
+ a: list[SplitPrimitiveTensor],
+ b: dict[str, DefaultPrimitiveTensor],
+ c: torch.Tensor,
+ ):
+ assert_iterables_equal(a, expected_a, elements_equal=ops.equal)
+ assert_dicts_equal(b, expected_b, values_equal=ops.equal)
+ assert_equal(c, expected_c, equal=ops.equal)
+
+ f(
+ unbox_tensor(expected_a[0].shards[0]),
+ expected_b["element"].as_torch(),
+ expected_c,
+ )
+
+ def testGetFlatArgumentDeviceAffinities(self):
+ args = [
+ {
+ "a": [
+ SplitPrimitiveTensor(
+ ts=[torch.tensor([1]), torch.tensor([2])], shard_dim=0
+ )
+ ]
+ },
+ torch.tensor([3]),
+ ReplicatedTensor(ts=[torch.tensor([4]), torch.tensor([5])]),
+ ]
+ affinities = get_argument_flat_device_affinities(*args)
+ expected_affinities = {
+ 0: DeviceAffinity("0"),
+ 1: DeviceAffinity("1"),
+ 3: DeviceAffinity("0"),
+ 4: DeviceAffinity("1"),
+ }
+ assert_dicts_equal(affinities, expected_affinities)
+
+ def testExportWithArgumentDeviceAffinities(self):
+ args = (ReplicatedTensor(ts=[torch.tensor([1])]), torch.tensor([[2]]))
+
+ class Module(torch.nn.Module):
+ def f(self, a, b):
+ return a, b
+
+ module = Module()
+ fxb = FxProgramsBuilder(module)
+ export(
+ Module.f,
+ fx_builder=fxb,
+ args=args,
+ strict=False,
+ )
+ export_output = aot.export(
+ fxb,
+ )
+ asm = str(export_output.mlir_module)
+ print(asm)
+ self.assertRegex(
+ asm,
+ expected_regex=(
+ "func.func @f\\("
+ "%.+: !torch.vtensor<\\[1\\],si64> "
+ "{iree.abi.affinity = #hal.device.promise<@__device_0>}, "
+ "%.+: !torch.vtensor<\\[1,1\\],si64>\\)"
+ ),
+ )
diff --git a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py
index 30cc2296c..208d54782 100644
--- a/sharktank/tests/kernels/batch_matmul_transpose_b_test.py
+++ b/sharktank/tests/kernels/batch_matmul_transpose_b_test.py
@@ -13,7 +13,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
diff --git a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
index b03293523..ff1430a1a 100644
--- a/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
+++ b/sharktank/tests/kernels/conv_2d_nchw_fchw_test.py
@@ -12,10 +12,10 @@
from parameterized import parameterized
import torch
+import torch.nn.functional as F
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d
class conv_2d_nchw_fchw_test(unittest.TestCase):
@@ -36,7 +36,8 @@ def testBS32(self, input_dtype, output_dtype_name, atol, rtol):
inputs = (torch.rand([2, 4, 64, 64]) * 64).to(input_dtype)
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
- inputs_pad = _pad_last_2d(inputs, extended_list)
+ inputs_pad = F.pad(inputs, pad=extended_list)
+
weights = (torch.rand([8, 4, 3, 3]) * 64).to(input_dtype)
bias = (torch.rand([8]) * 64).to(dtype=output_dtype)
result = kernels.conv_2d_nchw_fchw(
@@ -68,7 +69,7 @@ def forward(self, a, b, c):
inputs = torch.rand([2, 320, 64, 64]) * 64
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
- inputs_pad = _pad_last_2d(inputs, extended_list)
+ inputs_pad = F.pad(inputs, pad=extended_list)
ep = torch.export.export(
mod,
args=(
diff --git a/sharktank/tests/kernels/einsum_q4_test.py b/sharktank/tests/kernels/einsum_q4_test.py
new file mode 100644
index 000000000..5f037ba9a
--- /dev/null
+++ b/sharktank/tests/kernels/einsum_q4_test.py
@@ -0,0 +1,141 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+
+logging.basicConfig(level=logging.DEBUG)
+
+import unittest
+from parameterized import parameterized
+
+import torch
+
+from iree.turbine import aot
+from sharktank import kernels
+from sharktank.types import layout_utils
+
+
+class einsum_2args_q4_test(unittest.TestCase):
+ def setUp(self):
+ torch.manual_seed(42)
+
+ @parameterized.expand(
+ [
+ (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3),
+ (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3),
+ (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3),
+ ]
+ )
+ def test_basic_mk_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol):
+ a = torch.rand([2, 320], dtype=a_dtype) / 256.0
+ d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0
+ qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
+ m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0
+ einsum_string = "mk,menk->men"
+ result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string)
+
+ # Dequantize and test with normal matmul.
+ # Tolerances are empirical and results are not expected to match exactly.
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs)
+ b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3)
+ ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype))
+ torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol)
+
+ @parameterized.expand(
+ [
+ (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3),
+ (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3),
+ (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3),
+ ]
+ )
+ def test_basic_mek_menk_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol):
+ a = torch.rand([2, 4, 320], dtype=a_dtype) / 256.0
+ d = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) / 256.0
+ qs = (torch.rand([2, 4, 8, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
+ m = torch.rand([2, 4, 8, 10, 1], dtype=d_dtype) + 16.0
+ einsum_string = "mek,menk->men"
+ result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string)
+
+ # Dequantize and test with normal matmul.
+ # Tolerances are empirical and results are not expected to match exactly.
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs)
+ b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(3)
+ ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype))
+ torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol)
+
+ @parameterized.expand(
+ [
+ (torch.float32, torch.float32, torch.float32, 1e-2, 1e-3),
+ (torch.float32, torch.float16, torch.float32, 1e-2, 1e-3),
+ (torch.float16, torch.float16, torch.float32, 1e-2, 1e-3),
+ ]
+ )
+ def test_basic_me_men_men(self, a_dtype, d_dtype, ref_dtype, atol, rtol):
+ a = torch.rand([2, 4], dtype=a_dtype) / 256.0
+ d = torch.rand([2, 4, 10, 1], dtype=d_dtype) / 256.0
+ qs = (torch.rand([2, 4, 10, 16], dtype=ref_dtype) * 255.0).to(torch.uint8)
+ m = torch.rand([2, 4, 10, 1], dtype=d_dtype) + 16.0
+ einsum_string = "me,men->men"
+ result = kernels.einsum_2args_q4(a, d, qs, m, einsum_string)
+
+ # Dequantize and test with normal matmul.
+ # Tolerances are empirical and results are not expected to match exactly.
+ qs_i8 = layout_utils.promote_linear_i4_block_to_i8(qs)
+ b = (d.to(ref_dtype) * qs_i8.to(ref_dtype) + m.to(ref_dtype)).flatten(2)
+ ref = torch.einsum(einsum_string, a.to(ref_dtype), b.to(ref_dtype))
+ torch.testing.assert_close(result.to(ref_dtype), ref, atol=atol, rtol=rtol)
+
+ def testExportDynamicDims(self):
+ class MyModule(torch.nn.Module):
+ def forward(self, a, d, qs, m):
+ return kernels.einsum_2args_q4(a, d, qs, m, "ij,jk->ik")
+
+ mod = MyModule()
+ ep = torch.export.export(
+ mod,
+ args=(
+ torch.rand([16, 320], dtype=torch.float32),
+ torch.rand([320, 2, 1], dtype=torch.float16),
+ (torch.rand([320, 2, 16], dtype=torch.float32) * 32).to(torch.uint8),
+ torch.rand([320, 2, 1], dtype=torch.float16),
+ ),
+ dynamic_shapes={
+ "a": {},
+ "d": {},
+ "qs": {},
+ "m": {},
+ },
+ )
+ output = aot.export(ep)
+ output.verify()
+ asm = str(output.mlir_module)
+ self.assertIn("@sharktank_einsum_2args_q4_ij_jk_ik_32_f32", asm)
+
+ def testExportStaticDims(self):
+ class MyModule(torch.nn.Module):
+ def forward(self, a, d, qs, m):
+ return kernels.einsum_2args_q4(a, d, qs, m, "mek,menk->men")
+
+ mod = MyModule()
+ ep = torch.export.export(
+ mod,
+ args=(
+ torch.rand([4, 16, 320], dtype=torch.float32),
+ torch.rand([4, 16, 2, 10, 1], dtype=torch.float16),
+ (torch.rand([4, 16, 2, 10, 16], dtype=torch.float32) * 32).to(
+ torch.uint8
+ ),
+ torch.rand([4, 16, 2, 10, 1], dtype=torch.float16),
+ ),
+ )
+ output = aot.export(ep)
+ output.verify()
+ asm = str(output.mlir_module)
+ self.assertIn("@sharktank_einsum_2args_q4_mek_menk_men_32_f32", asm)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py b/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py
index d9fc7370a..dca474446 100644
--- a/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py
+++ b/sharktank/tests/kernels/mmt_block_scaled_offset_q4_test.py
@@ -13,7 +13,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
from sharktank.types import layout_utils
diff --git a/sharktank/tests/kernels/mmt_block_scaled_q8_test.py b/sharktank/tests/kernels/mmt_block_scaled_q8_test.py
index f3fdf2ed9..08aa8d179 100644
--- a/sharktank/tests/kernels/mmt_block_scaled_q8_test.py
+++ b/sharktank/tests/kernels/mmt_block_scaled_q8_test.py
@@ -13,7 +13,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
diff --git a/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py b/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py
index 41c04106d..01272553a 100644
--- a/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py
+++ b/sharktank/tests/kernels/mmt_super_block_scaled_offset_q4_test.py
@@ -13,7 +13,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
from sharktank.types import layout_utils
diff --git a/sharktank/tests/kernels/mmtfp_test.py b/sharktank/tests/kernels/mmtfp_test.py
index 281498f90..e2c36e4ac 100644
--- a/sharktank/tests/kernels/mmtfp_test.py
+++ b/sharktank/tests/kernels/mmtfp_test.py
@@ -13,7 +13,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
diff --git a/sharktank/tests/kernels/pooling_nchw_sum_test.py b/sharktank/tests/kernels/pooling_nchw_sum_test.py
index 5c4e8ac0a..e512eb484 100644
--- a/sharktank/tests/kernels/pooling_nchw_sum_test.py
+++ b/sharktank/tests/kernels/pooling_nchw_sum_test.py
@@ -12,10 +12,10 @@
from parameterized import parameterized
import torch
+import torch.nn.functional as F
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d
class pooling_nchw_sum_test(unittest.TestCase):
@@ -34,7 +34,7 @@ def testBS32(self, atol, rtol):
a = (torch.randint(0, 100, (2, 1, 128, 128))).to(torch.float32)
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
- inputs_pad = _pad_last_2d(a, extended_list)
+ inputs_pad = F.pad(a, pad=extended_list)
weight_shape = [3, 3]
stride = [1, 1]
dilations = [1, 1]
@@ -62,7 +62,7 @@ def forward(self, a):
inputs = torch.rand([2, 1, 128, 128]) * 64
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
- inputs_pad = _pad_last_2d(inputs, extended_list)
+ inputs_pad = F.pad(inputs, pad=extended_list)
ep = torch.export.export(
mod,
args=((inputs_pad).to(dtype),),
diff --git a/sharktank/tests/layers/configs_test.py b/sharktank/tests/layers/configs_test.py
new file mode 100644
index 000000000..024eb6bb6
--- /dev/null
+++ b/sharktank/tests/layers/configs_test.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from sharktank.layers.configs import LlamaHParams
+
+
+def test_llama_hp_params_to_from_gguf_props_roundtrip():
+ params = LlamaHParams(
+ model_arch="llama",
+ context_length=1,
+ embedding_length=2,
+ block_count=3,
+ feed_forward_length=3,
+ rope_dimension_count=4,
+ rope_freq_base=5.0,
+ attention_head_count=6,
+ attn_head_dim=4,
+ attention_layer_norm_rms_epsilon=8.0,
+ attention_head_count_kv=9,
+ expert_count=10,
+ expert_used_count=11,
+ )
+ roundtripped_params = LlamaHParams.from_gguf_props(params.to_gguf_props())
+ assert params == roundtripped_params
diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py
new file mode 100644
index 000000000..65b42c986
--- /dev/null
+++ b/sharktank/tests/layers/kv_cache_test.py
@@ -0,0 +1,502 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+
+import torch
+
+from sharktank.ops import replicate, reshard_split, unshard
+from sharktank.layers import *
+from sharktank.types import *
+
+
+def test_direct():
+ bs = 4
+ seq_length = 24
+ attn_head_count = 4
+ attn_head_dim = 16
+ transformer_block_count = 4
+ cache = DirectKVCache(
+ block_seq_stride=4,
+ transformer_block_count=transformer_block_count,
+ attn_head_count=attn_head_count,
+ attn_head_dim=attn_head_dim,
+ seq_length=seq_length,
+ dtype=torch.float32,
+ device=None,
+ )
+
+ allocation = cache.allocate(bs=bs)
+ allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation]
+
+ write_seq_length = seq_length - 5
+
+ # Write a prefill in:
+ write_ones = torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32
+ )
+ write_twos = torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32
+ )
+ cache.write(
+ allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1
+ )
+
+ # Check the written values have updated:
+ read_empty = [
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ ]
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length,
+ )
+ torch.testing.assert_close(write_ones, read_back[0])
+ torch.testing.assert_close(write_twos, read_back[1])
+
+ # Check the others are still zero:
+ for i in range(transformer_block_count):
+ if i == 1:
+ continue
+ read_ones = [
+ torch.zeros(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ torch.zeros(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ ]
+ read_ones = cache.read(
+ allocation,
+ read_into_partitions=read_ones,
+ transformer_block_index=i,
+ seq_len=write_seq_length,
+ )
+ torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0))
+ torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0))
+
+ # Write timestep
+ write_threes = torch.full(
+ (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32
+ )
+ write_fours = torch.full(
+ (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32
+ )
+ write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64)
+ cache.write_timestep(
+ allocation,
+ cache_partitions=[write_threes, write_fours],
+ transformer_block_index=1,
+ seq_positions=write_pos,
+ )
+
+ read_empty = [
+ torch.zeros(
+ (bs, write_seq_length + 1, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ torch.zeros(
+ (bs, write_seq_length + 1, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ ]
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length + 1,
+ )
+
+ check_concat_0 = torch.concat([write_ones, write_threes], dim=1)
+ check_concat_1 = torch.concat([write_twos, write_fours], dim=1)
+
+ torch.testing.assert_close(check_concat_0, read_back[0])
+ torch.testing.assert_close(check_concat_1, read_back[1])
+
+
+def test_sharded_direct():
+ bs = 4
+ seq_length = 24
+ attn_head_count = 8
+ attn_head_dim = 16
+ transformer_block_count = 4
+ shard_count = 4
+ cache = DirectKVCache(
+ block_seq_stride=4,
+ transformer_block_count=transformer_block_count,
+ attn_head_count=attn_head_count,
+ attn_head_dim=attn_head_dim,
+ seq_length=seq_length,
+ shard_count=shard_count,
+ dtype=torch.float32,
+ device=None,
+ )
+
+ allocation = cache.allocate(bs=bs)
+ # allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation]
+
+ write_seq_length = seq_length - 5
+
+ # Write a prefill in:
+ write_ones = reshard_split(
+ torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ 1.0,
+ dtype=torch.float32,
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ write_twos = reshard_split(
+ torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ 2.0,
+ dtype=torch.float32,
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ cache.write(
+ allocation, cache_partitions=[write_ones, write_twos], transformer_block_index=1
+ )
+
+ # Check the written values have updated:
+ read_empty = [
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ ]
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length,
+ )
+ torch.testing.assert_close(unshard(write_ones), unshard(read_back[0]))
+ torch.testing.assert_close(unshard(write_twos), unshard(read_back[1]))
+
+ # Write timestep
+ write_threes = reshard_split(
+ torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32),
+ dim=2,
+ count=shard_count,
+ )
+ write_fours = reshard_split(
+ torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32),
+ dim=2,
+ count=shard_count,
+ )
+
+ write_pos = replicate(
+ torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count
+ )
+ cache.write_timestep(
+ allocation,
+ cache_partitions=[write_threes, write_fours],
+ transformer_block_index=1,
+ seq_positions=write_pos,
+ )
+
+ read_empty = [
+ torch.zeros(
+ (bs, write_seq_length + 1, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ torch.zeros(
+ (bs, write_seq_length + 1, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ ]
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length + 1,
+ )
+
+ check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1)
+ check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1)
+
+ torch.testing.assert_close(check_concat_0, unshard(read_back[0]))
+ torch.testing.assert_close(check_concat_1, unshard(read_back[1]))
+
+
+def test_paged():
+ bs = 4
+ seq_length = 24
+ attn_head_count = 4
+ attn_head_dim = 16
+ transformer_block_count = 4
+ block_seq_stride = 4
+ cache = PagedKVCache(
+ block_seq_stride=block_seq_stride,
+ transformer_block_count=transformer_block_count,
+ attn_head_count=attn_head_count,
+ attn_head_dim=attn_head_dim,
+ dtype=torch.float32,
+ device=None,
+ )
+
+ write_seq_length = seq_length - 4
+ page_count = bs * seq_length // block_seq_stride
+ page_ids = torch.arange(page_count, dtype=torch.int64)
+ page_ids = page_ids.view(bs, seq_length // block_seq_stride)
+ write_page_ids = page_ids[:, : write_seq_length // block_seq_stride]
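+ # With bs=4, seq_length=24 and block_seq_stride=4 this gives 24 pages total
+ # (6 per sequence); the prefill below writes the first 5 pages per sequence
+ # (write_seq_length=20 tokens).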
+
+ allocation = cache.allocate(page_count=page_count)
+ allocation = [torch.full(t.shape, 0.0, out=t) for t in allocation]
+
+ # Write a prefill in:
+ write_ones = torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), 1.0, dtype=torch.float32
+ )
+ write_twos = torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), 2.0, dtype=torch.float32
+ )
+
+ cache.write(
+ allocation,
+ cache_partitions=[write_ones, write_twos],
+ transformer_block_index=1,
+ page_ids=write_page_ids,
+ )
+
+ # Check the written values have updated:
+ read_empty = [
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ ]
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length,
+ page_ids=write_page_ids,
+ )
+ torch.testing.assert_close(write_ones, read_back[0])
+ torch.testing.assert_close(write_twos, read_back[1])
+
+ # Check the others are still zero:
+ for i in range(transformer_block_count):
+ if i == 1:
+ continue
+ read_ones = [
+ torch.zeros(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ torch.zeros(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ ]
+ read_ones = cache.read(
+ allocation,
+ read_into_partitions=read_ones,
+ transformer_block_index=i,
+ seq_len=write_seq_length,
+ page_ids=write_page_ids,
+ )
+ torch.testing.assert_close(read_ones[0], torch.full(read_ones[0].shape, 0.0))
+ torch.testing.assert_close(read_ones[1], torch.full(read_ones[0].shape, 0.0))
+
+ # Write timestep
+ write_threes = torch.full(
+ (bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32
+ )
+ write_fours = torch.full(
+ (bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32
+ )
+ write_pos = torch.full((bs,), write_seq_length, dtype=torch.int64)
+ cache.write_timestep(
+ allocation,
+ cache_partitions=[write_threes, write_fours],
+ transformer_block_index=1,
+ seq_positions=write_pos,
+ page_ids=page_ids,
+ )
+
+ read_empty = [
+ torch.zeros(
+ (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ torch.zeros(
+ (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ ]
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length + 1,
+ page_ids=page_ids,
+ )
+
+ check_concat_0 = torch.concat([write_ones, write_threes], dim=1)
+ check_concat_1 = torch.concat([write_twos, write_fours], dim=1)
+
+ torch.testing.assert_close(check_concat_0, read_back[0])
+ torch.testing.assert_close(check_concat_1, read_back[1])
+
+
+def test_sharded_paged():
+ bs = 4
+ seq_length = 24
+ attn_head_count = 8
+ attn_head_dim = 16
+ transformer_block_count = 4
+ block_seq_stride = 4
+ shard_count = 4
+ cache = PagedKVCache(
+ block_seq_stride=block_seq_stride,
+ transformer_block_count=transformer_block_count,
+ attn_head_count=attn_head_count,
+ attn_head_dim=attn_head_dim,
+ shard_count=shard_count,
+ dtype=torch.float32,
+ device=None,
+ )
+
+ write_seq_length = seq_length - 4
+ page_count = bs * seq_length // block_seq_stride
+ page_ids = torch.arange(page_count, dtype=torch.int64)
+ page_ids = page_ids.view(bs, seq_length // block_seq_stride)
+ page_ids = replicate(page_ids, shard_count)
+ write_page_ids = page_ids[:, : write_seq_length // block_seq_stride]
+
+ allocation = cache.allocate(page_count=page_count)
+
+ # Write a prefill in:
+ write_ones = reshard_split(
+ torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ 1.0,
+ dtype=torch.float32,
+ ),
+ dim=2,
+ count=shard_count,
+ )
+ write_twos = reshard_split(
+ torch.full(
+ (bs, write_seq_length, attn_head_count, attn_head_dim),
+ 2.0,
+ dtype=torch.float32,
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ cache.write(
+ allocation,
+ cache_partitions=[write_ones, write_twos],
+ transformer_block_index=1,
+ page_ids=write_page_ids,
+ )
+
+ # Check the written values have updated:
+ empty_k = reshard_split(
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ empty_v = reshard_split(
+ torch.empty(
+ (bs, write_seq_length, attn_head_count, attn_head_dim), dtype=torch.float32
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ read_empty = [empty_k, empty_v]
+
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=read_empty,
+ transformer_block_index=1,
+ seq_len=write_seq_length,
+ page_ids=write_page_ids,
+ )
+ torch.testing.assert_close(unshard(write_ones), unshard(read_back[0]))
+ torch.testing.assert_close(unshard(write_twos), unshard(read_back[1]))
+
+ # Write timestep
+ write_threes = reshard_split(
+ torch.full((bs, 1, attn_head_count, attn_head_dim), 3.0, dtype=torch.float32),
+ dim=2,
+ count=shard_count,
+ )
+
+ write_fours = reshard_split(
+ torch.full((bs, 1, attn_head_count, attn_head_dim), 4.0, dtype=torch.float32),
+ dim=2,
+ count=shard_count,
+ )
+
+ write_pos = replicate(
+ torch.full((bs,), write_seq_length, dtype=torch.int64), shard_count
+ )
+
+ cache.write_timestep(
+ allocation,
+ cache_partitions=[write_threes, write_fours],
+ transformer_block_index=1,
+ seq_positions=write_pos,
+ page_ids=page_ids,
+ )
+
+ empty_k = reshard_split(
+ torch.zeros(
+ (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ empty_v = reshard_split(
+ torch.zeros(
+ (bs, write_seq_length + block_seq_stride, attn_head_count, attn_head_dim),
+ dtype=torch.float32,
+ ),
+ dim=2,
+ count=shard_count,
+ )
+
+ read_back = cache.read(
+ allocation,
+ read_into_partitions=[empty_k, empty_v],
+ transformer_block_index=1,
+ seq_len=write_seq_length + 1,
+ page_ids=page_ids,
+ )
+
+ check_concat_0 = torch.concat([unshard(write_ones), unshard(write_threes)], dim=1)
+ check_concat_1 = torch.concat([unshard(write_twos), unshard(write_fours)], dim=1)
+
+ torch.testing.assert_close(check_concat_0, unshard(read_back[0]))
+ torch.testing.assert_close(check_concat_1, unshard(read_back[1]))
diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py
index e2d038f72..ad657889d 100644
--- a/sharktank/tests/layers/linear_test.py
+++ b/sharktank/tests/layers/linear_test.py
@@ -84,7 +84,7 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self):
bias_quant,
]
)
- linear = LinearLayer(theta)
+ linear = LinearLayer(theta, fake_quant=False)
output = linear(lhs)
output_ref = torch.matmul(lhs, rhs.T) + bias
diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py
new file mode 100644
index 000000000..a55782329
--- /dev/null
+++ b/sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -0,0 +1,196 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+
+logging.basicConfig(level=logging.DEBUG)
+
+import unittest
+
+import torch
+
+from iree.turbine import aot
+from sharktank.layers import (
+ PagedLlamaAttentionBlock,
+ PagedKVCache,
+ RotaryEmbeddingLayer,
+)
+from sharktank.layers.testing import make_llama_attention_block_theta
+from sharktank.types.tensors import DefaultPrimitiveTensor
+
+
+class PagedLlamaAttentionBlockTest(unittest.TestCase):
+ def setUp(self):
+ torch.manual_seed(12345)
+ self.transformer_block_count = 13
+ self.block_index = 1
+ self.shard_count = 3
+ self.head_count_kv = 2 * self.shard_count
+ self.attention_head_count = 5 * self.head_count_kv
+ self.attention_head_dim = 11 * 2
+ self.rms_epsilon = 0.01
+ self.block_seq_stride = 17
+ self.cache_partition_count = 2
+ self.page_count = 23
+ self.embedding_length = self.attention_head_count * self.attention_head_dim
+ self.rope_dimension_count = self.attention_head_dim
+ self.block_seqlen = 7
+ self.max_seqlen = self.block_seq_stride * self.block_seqlen
+ self.rope_freq_base = None
+ self.batch_size = 3
+ self.start_index = 0
+
+ def testExportDecomposed(self):
+ dtype = torch.float32
+
+ cache = PagedKVCache(
+ transformer_block_count=self.transformer_block_count,
+ attn_head_count=self.head_count_kv,
+ attn_head_dim=self.attention_head_dim,
+ cache_partition_count=self.cache_partition_count,
+ block_seq_stride=self.block_seq_stride,
+ dtype=dtype,
+ )
+
+ cache_state = cache.paged.allocate(self.page_count)
+ cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)
+
+ theta = make_llama_attention_block_theta(
+ head_count=self.attention_head_count,
+ head_count_kv=self.head_count_kv,
+ head_dim=self.attention_head_dim,
+ embedding_length=self.embedding_length,
+ )
+ attn = PagedLlamaAttentionBlock(
+ theta=theta,
+ block_index=self.block_index,
+ cache=cache,
+ head_count=self.attention_head_count,
+ head_dim=self.attention_head_dim,
+ head_count_kv=self.head_count_kv,
+ rms_epsilon=self.rms_epsilon,
+ attention_kernel="decomposed",
+ )
+
+ seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
+ self.batch_size, -1
+ )
+
+ embedding_module = RotaryEmbeddingLayer(
+ rope_dimension_count=self.rope_dimension_count,
+ max_seqlen=self.max_seqlen,
+ rope_freq_base=self.rope_freq_base,
+ )
+
+ class MyModule(torch.nn.Module):
+ def forward(self, h, seq_block_ids, cache_state):
+ return attn.forward(
+ h,
+ seq_block_ids=seq_block_ids,
+ embedding=embedding_module,
+ start_index=0,
+ cache_state=cache_state,
+ )
+
+ mod = MyModule()
+ h = torch.rand(
+ [
+ self.batch_size,
+ self.max_seqlen,
+ self.attention_head_count * self.attention_head_dim,
+ ]
+ )
+ mod.forward(h, seq_block_ids, cache_state)
+ ep = torch.export.export(
+ mod,
+ args=(
+ h,
+ seq_block_ids,
+ cache_state,
+ ),
+ )
+ output = aot.export(ep)
+ output.verify()
+ asm = str(output.mlir_module)
+ self.assertNotIn("scaled_dot_product_attention", asm)
+
+ def testExportNondecomposed(self):
+ dtype = torch.float32
+
+ cache = PagedKVCache(
+ transformer_block_count=self.transformer_block_count,
+ attn_head_count=self.head_count_kv,
+ attn_head_dim=self.attention_head_dim,
+ cache_partition_count=self.cache_partition_count,
+ block_seq_stride=self.block_seq_stride,
+ dtype=dtype,
+ )
+
+ cache_state = cache.paged.allocate(self.page_count)
+ cache_state[0] = torch.rand(cache_state[0].shape, dtype=dtype)
+
+ theta = make_llama_attention_block_theta(
+ head_count=self.attention_head_count,
+ head_count_kv=self.head_count_kv,
+ head_dim=self.attention_head_dim,
+ embedding_length=self.embedding_length,
+ )
+ attn = PagedLlamaAttentionBlock(
+ theta=theta,
+ block_index=self.block_index,
+ cache=cache,
+ head_count=self.attention_head_count,
+ head_dim=self.attention_head_dim,
+ head_count_kv=self.head_count_kv,
+ rms_epsilon=self.rms_epsilon,
+ attention_kernel="torch",
+ )
+
+ seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
+ self.batch_size, -1
+ )
+
+ embedding_module = RotaryEmbeddingLayer(
+ rope_dimension_count=self.rope_dimension_count,
+ max_seqlen=self.max_seqlen,
+ rope_freq_base=self.rope_freq_base,
+ )
+
+ class MyModule(torch.nn.Module):
+ def forward(self, h, seq_block_ids, cache_state):
+ return attn.forward(
+ h,
+ seq_block_ids=seq_block_ids,
+ embedding=embedding_module,
+ start_index=0,
+ cache_state=cache_state,
+ )
+
+ mod = MyModule()
+ h = torch.rand(
+ [
+ self.batch_size,
+ self.max_seqlen,
+ self.attention_head_count * self.attention_head_dim,
+ ]
+ )
+ mod.forward(h, seq_block_ids, cache_state)
+ ep = torch.export.export(
+ mod,
+ args=(
+ h,
+ seq_block_ids,
+ cache_state,
+ ),
+ )
+ output = aot.export(ep)
+ output.verify()
+ asm = str(output.mlir_module)
+ self.assertIn("torch.aten._scaled_dot_product_flash_attention_for_cpu", asm)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/sharktank/tests/layers/sharded_conv2d_with_iree_test.py b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py
new file mode 100644
index 000000000..9b29e5761
--- /dev/null
+++ b/sharktank/tests/layers/sharded_conv2d_with_iree_test.py
@@ -0,0 +1,210 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+
+from pathlib import Path
+import tempfile
+import torch
+from iree.turbine import aot
+from sharktank.models.punet.layers import Conv2DLayer
+from sharktank import ops
+from sharktank.types import (
+ Dataset,
+ DefaultPrimitiveTensor,
+ Theta,
+ ShardedTensor,
+ SplitPrimitiveTensor,
+ unbox_tensor,
+)
+from sharktank.types.sharding import Conv2DSplitOutputChannelSharding
+import iree.runtime
+from typing import List, Optional
+import os
+
+vm_context: iree.runtime.VmContext = None
+
+
+def get_compiler_args(target_device_kind: str, shard_count: int) -> List[str]:
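+ # For example, shard_count=2 with target_device_kind="llvm-cpu" yields
+ # ["--iree-hal-target-device=llvm-cpu[0]", "--iree-hal-target-device=llvm-cpu[1]"].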
+ result = [
+ f"--iree-hal-target-device={target_device_kind}[{i}]"
+ for i in range(shard_count)
+ ]
+ return result
+
+
+def compile_iree_module(
+ export_output: aot.ExportOutput, module_path: str, shard_count: int
+):
+ export_output.session.set_flags(
+ *get_compiler_args(target_device_kind="llvm-cpu", shard_count=shard_count)
+ )
+ export_output.compile(save_to=module_path, target_backends=None)
+
+
+# TODO: improve IREE's Python API to be more concise in a multi-device context.
+# This run function should be way shorter.
+def run_iree_module(
+ sharded_input_image: ShardedTensor,
+ module_path: str,
+ parameters_path: str,
+) -> ShardedTensor:
+ shard_count = sharded_input_image.shard_count
+ hal_driver = iree.runtime.get_driver("local-task")
+ vm_instance = iree.runtime.VmInstance()
+ available_devices = hal_driver.query_available_devices()
+ # Back all logical devices with the same underlying physical device.
+ devices = [
+ hal_driver.create_device(available_devices[0]) for _ in range(shard_count)
+ ]
+ hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
+ params_path = Path(parameters_path)
+ # TODO: make IREE able to load the parameters from the top parameter file
+ # without having to specify the parameter file for each shard separately.
+ parameter_index = iree.runtime.ParameterIndex()
+ for i in range(shard_count):
+ parameter_index.load(
+ file_path=str(
+ Path(params_path).with_suffix(f".rank{i}{params_path.suffix}")
+ )
+ )
+ parameter_provider = parameter_index.create_provider(scope="model")
+ parameters_module = iree.runtime.create_io_parameters_module(
+ vm_instance, parameter_provider
+ )
+
+ vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
+
+ # The context needs to be destroyed after the buffers, although
+ # it is not associated with them at the API level.
+ global vm_context
+ vm_context = iree.runtime.VmContext(
+ instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
+ )
+ module_input_args = [
+ iree.runtime.asdevicearray(
+ devices[i], sharded_input_image.shards[i].as_torch().to("cpu").numpy()
+ )
+ for i in range(shard_count)
+ ]
+
+ vm_function = vm_module.lookup_function("main")
+ invoker = iree.runtime.FunctionInvoker(
+ vm_context=vm_context,
+ # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
+ # This works, but does not look right.
+ device=devices[0],
+ vm_function=vm_function,
+ )
+ results = invoker(*module_input_args)
+ shards = [torch.tensor(tensor.to_host()) for tensor in results]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=1)
+
+
+def run_test_sharded_conv2d_with_iree(
+ mlir_path: Path, module_path: Path, parameters_path: Path, caching: bool
+):
+ torch.set_default_dtype(torch.float32)
+ torch.manual_seed(123456)
+ batches = 2
+ in_channels = 6
+ out_channels = 8
+ height = 11
+ width = 13
+ kernel_height = 5
+ kernel_width = 5
+ shard_count = 2
+ unsharded_theta = Theta(
+ {
+ "weight": DefaultPrimitiveTensor(
+ data=torch.rand(
+ out_channels,
+ in_channels,
+ kernel_height,
+ kernel_width,
+ )
+ ),
+ }
+ )
+ unsharded_theta.rename_tensors_to_paths()
+
+ if not caching or not os.path.exists(parameters_path):
+ sharding_spec = Conv2DSplitOutputChannelSharding(shard_count=shard_count)
+ sharded_theta = ops.reshard(unsharded_theta, sharding_spec)
+
+ # Roundtrip the dataset, which anchors the tensors as parameters to be loaded
+ # vs constants to be frozen (TODO: This is a bit wonky).
+ sharded_dataset = Dataset({}, sharded_theta)
+ sharded_dataset.save(parameters_path)
+
+ sharded_dataset = Dataset.load(parameters_path)
+
+ input_image = torch.rand(
+ batches,
+ in_channels,
+ height,
+ width,
+ )
+
+ sharded_torch_module = Conv2DLayer(sharded_dataset.root_theta, padding=(0, 0))
+ sharded_input_image = ops.reshard_split(input_image, dim=1, count=shard_count)
+ expected_result = sharded_torch_module(sharded_input_image)
+
+ if not caching or not os.path.exists(module_path):
+ exported_module = aot.export(
+ sharded_torch_module,
+ args=(sharded_input_image,),
+ )
+ exported_module.save_mlir(mlir_path)
+
+ compile_iree_module(
+ export_output=exported_module,
+ module_path=module_path,
+ shard_count=shard_count,
+ )
+
+ actual_result = run_iree_module(
+ sharded_input_image=sharded_input_image,
+ module_path=module_path,
+ parameters_path=parameters_path,
+ )
+ assert len(actual_result.shards) == len(expected_result.shards)
+ assert actual_result.shard_dim == expected_result.shard_dim
+ for actual_shard, expected_shard in zip(
+ actual_result.shards, expected_result.shards
+ ):
+ torch.testing.assert_close(
+ unbox_tensor(actual_shard), unbox_tensor(expected_shard)
+ )
+
+
+def test_sharded_conv2d_with_iree(
+ mlir_path: Optional[Path],
+ module_path: Optional[Path],
+ parameters_path: Optional[Path],
+ caching: bool,
+):
+ """Test sharding, exporting and running with IREE a 2D convolution layer."""
+
+ with tempfile.TemporaryDirectory(
+ # TODO: verify hypothesis and remove ignore_cleanup_errors=True after a fix.
+ # torch.export.export spawns some processes that don't exit when the
+ # function returns; this keeps some objects from being destroyed, which
+ # in turn holds the files params.rank0.irpa and params.rank1.irpa open.
+ ignore_cleanup_errors=True
+ ) as tmp_dir:
+ mlir_path = Path(tmp_dir) / "model.mlir" if mlir_path is None else mlir_path
+ module_path = (
+ Path(tmp_dir) / "module.vmfb" if module_path is None else module_path
+ )
+ parameters_path = (
+ Path(tmp_dir) / "params.irpa"
+ if parameters_path is None
+ else parameters_path
+ )
+ run_test_sharded_conv2d_with_iree(
+ mlir_path, module_path, parameters_path, caching
+ )
diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
new file mode 100644
index 000000000..d7b6a0b33
--- /dev/null
+++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -0,0 +1,236 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+from sharktank.layers import PagedKVCache
+import torch
+from sharktank.utils import iterables_equal
+from copy import deepcopy
+from typing import List, Tuple
+from sharktank import ops
+from sharktank.types import SplitPrimitiveTensor
+
+
+class ShardedPagedKVCacheTest(unittest.TestCase):
+ """Verify that the sharded paged KV cache behaves as the unsharded variant."""
+
+ def setUp(self):
+ torch.manual_seed(12345)
+ self.dtype = torch.float32
+ torch.set_default_dtype(self.dtype)
+ self.shard_count = 3
+ self.transformer_block_count = 5
+ self.attn_head_count = self.shard_count * 7
+ self.block_seq_stride = 19
+ self.attn_head_dim = 17
+ self.cache_partition_count = 2
+ self.page_count = 23
+ self.batch_size = 11
+ self.block_seq_len = 2
+ self.max_seq_len = self.block_seq_len * self.block_seq_stride
+
+ self.cache = PagedKVCache(
+ transformer_block_count=self.transformer_block_count,
+ attn_head_count=self.attn_head_count,
+ block_seq_stride=self.block_seq_stride,
+ attn_head_dim=self.attn_head_dim,
+ cache_partition_count=self.cache_partition_count,
+ dtype=self.dtype,
+ )
+ self.sharded_cache = PagedKVCache(
+ shard_count=self.shard_count,
+ transformer_block_count=self.transformer_block_count,
+ attn_head_count=self.attn_head_count,
+ block_seq_stride=self.block_seq_stride,
+ attn_head_dim=self.attn_head_dim,
+ cache_partition_count=self.cache_partition_count,
+ dtype=self.dtype,
+ )
+
+ def make_unsharded_and_sharded_equal_cache_states(
+ self,
+ ) -> Tuple[List[torch.Tensor], List[SplitPrimitiveTensor]]:
+ cache_state = self.cache.allocate(self.page_count)
+ cache_state[0] = torch.rand_like(cache_state[0])
+ sharded_cache_state = self.sharded_cache.shard_state(deepcopy(cache_state))
+ self.assert_equal_unsharded_and_sharded_cache_states(
+ cache_state, sharded_cache_state
+ )
+ return cache_state, sharded_cache_state
+
+ def assert_equal_unsharded_and_sharded_cache_states(
+ self,
+ cache_state: List[torch.Tensor],
+ sharded_cache_state: List[SplitPrimitiveTensor],
+ ):
+ sharded_state_as_unsharded = ops.unshard(
+ self.sharded_cache.unflatten_page_table(sharded_cache_state)
+ ).flatten(start_dim=1)
+ assert ops.equal(
+ cache_state[0],
+ sharded_state_as_unsharded,
+ )
+
+ def testAllocate(self):
+ cache_state = self.cache.allocate(self.page_count)
+ sharded_cache_state = self.sharded_cache.allocate(self.page_count)
+ assert len(cache_state) == 1
+ assert len(sharded_cache_state) == 1
+ assert iterables_equal(cache_state[0].shape, sharded_cache_state[0].shape)
+ assert sharded_cache_state[0].shard_dim == 1
+ assert sharded_cache_state[0].shard_count == self.shard_count
+
+ def testUnflattenPageTable(self):
+ cache_state = self.cache.allocate(self.page_count)
+ sharded_cache_state = self.sharded_cache.allocate(self.page_count)
+
+ unflattened_cache_state = self.cache.unflatten_page_table(cache_state)
+ sharded_unflattened_cache_state = self.sharded_cache.unflatten_page_table(
+ sharded_cache_state
+ )
+ assert iterables_equal(
+ unflattened_cache_state.shape, sharded_unflattened_cache_state.shape
+ )
+ assert sharded_unflattened_cache_state.shard_dim == 4
+ assert sharded_unflattened_cache_state.shard_count == self.shard_count
+ assert sharded_unflattened_cache_state.shape[0] == self.page_count
+
+ def testRead(self):
+ (
+ cache_state,
+ sharded_cache_state,
+ ) = self.make_unsharded_and_sharded_equal_cache_states()
+
+ read_into_partitions_snapshot = [
+ torch.rand(
+ self.batch_size,
+ self.block_seq_len * self.block_seq_stride,
+ self.attn_head_count,
+ self.attn_head_dim,
+ )
+ for _ in range(self.cache_partition_count)
+ ]
+ read_into_partitions = deepcopy(read_into_partitions_snapshot)
+ transformer_block_index = 1
+ page_ids = torch.randint(
+ low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len]
+ ).reshape([self.batch_size, self.block_seq_len])
+ self.cache.read(
+ state=cache_state,
+ read_into_partitions=read_into_partitions,
+ transformer_block_index=transformer_block_index,
+ page_ids=page_ids,
+ seq_len=self.block_seq_len * self.block_seq_stride,
+ )
+ sharded_read_into_partitions = deepcopy(
+ [
+ ops.reshard_split(t, dim=2, count=self.shard_count)
+ for t in read_into_partitions_snapshot
+ ]
+ )
+ sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
+ self.sharded_cache.read(
+ state=sharded_cache_state,
+ read_into_partitions=sharded_read_into_partitions,
+ transformer_block_index=transformer_block_index,
+ page_ids=sharded_page_ids,
+ seq_len=self.block_seq_len * self.block_seq_stride,
+ )
+ for unsharded, sharded in zip(
+ read_into_partitions, sharded_read_into_partitions
+ ):
+ assert ops.equal(unsharded, ops.unshard(sharded))
+
+ def testWriteTimestep(self):
+ (
+ cache_state,
+ sharded_cache_state,
+ ) = self.make_unsharded_and_sharded_equal_cache_states()
+
+ cache_partitions = [
+ torch.rand(
+ self.batch_size,
+ 1,
+ self.attn_head_count,
+ self.attn_head_dim,
+ )
+ for _ in range(self.cache_partition_count)
+ ]
+ transformer_block_index = 1
+ seq_positions = torch.randint(
+ low=0, high=self.max_seq_len, size=[self.batch_size]
+ )
+ page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape(
+ [self.batch_size, self.block_seq_len]
+ )
+ self.cache.write_timestep(
+ state=cache_state,
+ cache_partitions=cache_partitions,
+ transformer_block_index=transformer_block_index,
+ seq_positions=seq_positions,
+ page_ids=page_ids,
+ )
+ sharded_cache_partitions = deepcopy(
+ [
+ ops.reshard_split(t, dim=2, count=self.shard_count)
+ for t in cache_partitions
+ ]
+ )
+ sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count)
+ sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
+ self.sharded_cache.write_timestep(
+ state=sharded_cache_state,
+ cache_partitions=sharded_cache_partitions,
+ transformer_block_index=transformer_block_index,
+ seq_positions=sharded_seq_positions,
+ page_ids=sharded_page_ids,
+ )
+ self.assert_equal_unsharded_and_sharded_cache_states(
+ cache_state, sharded_cache_state
+ )
+
+ def testWrite(self):
+ (
+ cache_state,
+ sharded_cache_state,
+ ) = self.make_unsharded_and_sharded_equal_cache_states()
+
+ cache_partitions = [
+ torch.rand(
+ self.batch_size,
+ self.block_seq_len * self.block_seq_stride,
+ self.attn_head_count,
+ self.attn_head_dim,
+ )
+ for _ in range(self.cache_partition_count)
+ ]
+ transformer_block_index = 1
+ assert self.batch_size * self.block_seq_len <= self.page_count
+ page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape(
+ [self.batch_size, self.block_seq_len]
+ )
+ self.cache.write(
+ state=cache_state,
+ cache_partitions=cache_partitions,
+ transformer_block_index=transformer_block_index,
+ page_ids=page_ids,
+ )
+ sharded_cache_partitions = deepcopy(
+ [
+ ops.reshard_split(t, dim=2, count=self.shard_count)
+ for t in cache_partitions
+ ]
+ )
+ sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
+ self.sharded_cache.write(
+ state=sharded_cache_state,
+ cache_partitions=sharded_cache_partitions,
+ transformer_block_index=transformer_block_index,
+ page_ids=sharded_page_ids,
+ )
+ self.assert_equal_unsharded_and_sharded_cache_states(
+ cache_state, sharded_cache_state
+ )
diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
new file mode 100644
index 000000000..c94fd44ab
--- /dev/null
+++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -0,0 +1,163 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+from sharktank.layers import (
+ PagedLlamaAttentionBlock,
+ PagedKVCache,
+ RotaryEmbeddingLayer,
+)
+from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch
+from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding
+from sharktank.types import SplitPrimitiveTensor, unbox_tensor
+import torch
+from sharktank import ops
+from copy import deepcopy
+import pytest
+
+
+class ShardedPagedLlamaAttentionBlockTest(unittest.TestCase):
+ """Verify that the sharded Llama paged attention block behaves in PyTorch as the
+ unsharded variant."""
+
+ def setUp(self):
+ torch.manual_seed(12345)
+ self.transformer_block_count = 13
+ self.block_index = 1
+ self.shard_count = 3
+ self.head_count_kv = 2 * self.shard_count
+ self.attention_head_count = 5 * self.head_count_kv
+ self.attention_head_dim = 11 * 2
+ self.rms_epsilon = 0.01
+ self.block_seq_stride = 17
+ self.cache_partition_count = 2
+ self.page_count = 23
+ self.embedding_length = self.attention_head_count * self.attention_head_dim
+ self.rope_dimension_count = self.attention_head_dim
+ self.block_seqlen = 7
+ self.max_seqlen = self.block_seq_stride * self.block_seqlen
+ self.rope_freq_base = None
+ self.batch_size = 3
+ self.start_index = 0
+
+ def testSmallSizedLayerFp64(self):
+ self.runTestSmallSizedLayer(dtype=torch.float64)
+
+ @pytest.mark.xfail(
+ reason="The accuracy seems low (atol=0.0018, rtol=0.5065)",
+ strict=True,
+ raises=AssertionError,
+ )
+ def testSmallSizedLayerFp32(self):
+ self.runTestSmallSizedLayer(dtype=torch.float32)
+
+ def runTestSmallSizedLayer(self, dtype: torch.dtype):
+ torch.set_default_dtype(dtype)
+
+ def make_paged_kv_cache(shard_count: int) -> PagedKVCache:
+ return PagedKVCache(
+ transformer_block_count=self.transformer_block_count,
+ attn_head_count=self.head_count_kv,
+ attn_head_dim=self.attention_head_dim,
+ cache_partition_count=self.cache_partition_count,
+ block_seq_stride=self.block_seq_stride,
+ dtype=dtype,
+ shard_count=shard_count,
+ )
+
+ cache = make_paged_kv_cache(shard_count=1)
+ sharded_cache = make_paged_kv_cache(shard_count=self.shard_count)
+
+ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
+ list[torch.Tensor], list[SplitPrimitiveTensor]
+ ]:
+ cache_state = cache.allocate(self.page_count)
+ cache_state[0] = make_rand_torch(cache_state[0].shape, dtype=dtype)
+ sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state))
+ return cache_state, sharded_cache_state
+
+ (
+ cache_state,
+ sharded_cache_state,
+ ) = make_unsharded_and_sharded_equal_cache_states()
+
+ input_tensor = make_rand_torch(
+ (
+ self.batch_size,
+ self.max_seqlen,
+ self.attention_head_count * self.attention_head_dim,
+ ),
+ dtype=dtype,
+ )
+ seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
+ self.batch_size, -1
+ )
+ embedding_module = RotaryEmbeddingLayer(
+ rope_dimension_count=self.rope_dimension_count,
+ max_seqlen=self.max_seqlen,
+ rope_freq_base=self.rope_freq_base,
+ )
+
+ theta = make_llama_attention_block_theta(
+ head_count=self.attention_head_count,
+ head_count_kv=self.head_count_kv,
+ head_dim=self.attention_head_dim,
+ embedding_length=self.embedding_length,
+ )
+ attention_block = PagedLlamaAttentionBlock(
+ theta=theta,
+ block_index=self.block_index,
+ cache=cache,
+ head_count=self.attention_head_count,
+ head_dim=self.attention_head_dim,
+ head_count_kv=self.head_count_kv,
+ rms_epsilon=self.rms_epsilon,
+ )
+ expected_result = attention_block(
+ input_tensor,
+ embedding=embedding_module,
+ seq_block_ids=seq_block_ids,
+ start_index=self.start_index,
+ cache_state=cache_state,
+ )
+
+ sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count)
+ sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count)
+ sharded_embedding_module = RotaryEmbeddingLayer(
+ rope_dimension_count=self.rope_dimension_count,
+ max_seqlen=self.max_seqlen,
+ rope_freq_base=self.rope_freq_base,
+ tensor_parallelism_size=self.shard_count,
+ )
+
+ theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count)
+ sharded_theta = ops.reshard(theta, theta_sharding)
+ sharded_attention_block = PagedLlamaAttentionBlock(
+ theta=sharded_theta,
+ block_index=self.block_index,
+ cache=sharded_cache,
+ head_count=self.attention_head_count,
+ head_dim=self.attention_head_dim,
+ head_count_kv=self.head_count_kv,
+ rms_epsilon=self.rms_epsilon,
+ )
+ sharded_result = sharded_attention_block(
+ sharded_input_tensor,
+ embedding=sharded_embedding_module,
+ seq_block_ids=sharded_seq_block_ids,
+ start_index=self.start_index,
+ cache_state=sharded_cache_state,
+ )
+
+ actual_result = unbox_tensor(ops.unshard(sharded_result))
+ actual_cache_state = unbox_tensor(
+ ops.unshard(
+ sharded_cache.unflatten_page_table(sharded_cache_state)
+ ).flatten(start_dim=1)
+ )
+
+ torch.testing.assert_close(actual_result, expected_result)
+ torch.testing.assert_close(actual_cache_state, cache_state[0])
diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py
new file mode 100644
index 000000000..f24b8313a
--- /dev/null
+++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py
@@ -0,0 +1,58 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+import torch
+
+from sharktank.layers import RotaryEmbeddingLayer
+from sharktank import ops
+from sharktank.types import (
+ ShardedTensor,
+ SplitPrimitiveTensor,
+ unbox_tensor,
+)
+
+import unittest
+from typing import List, Optional
+import os
+
+
+def test_sharded_rotary_table():
+ bs = 4
+ rope_dims = 16
+ heads = 8
+ max_seqlen = 128
+ rope_freq_base = None
+
+ # First we setup and get the default rotary embedding layer
+ xq = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float)
+ xk = torch.rand((bs, max_seqlen, heads, rope_dims), dtype=torch.float)
+ default_layer = RotaryEmbeddingLayer(
+ rope_dimension_count=rope_dims,
+ max_seqlen=max_seqlen,
+ rope_freq_base=rope_freq_base,
+ )
+ oq = default_layer(xt=xq, start_index=0)
+ ok = default_layer(xt=xk, start_index=0)
+
+ # Then we can shard the same inputs and layer
+ xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4)
+ xk = SplitPrimitiveTensor(ts=xk, shard_dim=2, shard_count=4)
+ shard_layer = RotaryEmbeddingLayer(
+ rope_dimension_count=rope_dims,
+ max_seqlen=max_seqlen,
+ rope_freq_base=rope_freq_base,
+ tensor_parallelism_size=4,
+ )
+ sq = shard_layer(xt=xq, start_index=0)
+ sk = shard_layer(xt=xk, start_index=0)
+
+ # Gathering and unboxing should yield the same results
+ sq = ops.unshard(sq)
+ sk = ops.unshard(sk)
+
+ torch.testing.assert_close(sq, oq)
+ torch.testing.assert_close(sk, ok)
diff --git a/sharktank/tests/models/llama/README.md b/sharktank/tests/models/llama/README.md
new file mode 100644
index 000000000..6adf38588
--- /dev/null
+++ b/sharktank/tests/models/llama/README.md
@@ -0,0 +1,14 @@
+# How to run Llama 3.1 Benchmarking Tests
+To run the Llama 3.1 8B F16 Decomposed test:
+```
+pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s \
+ --run-quick-test --iree-hip-target=gfx942
+```
+
+To filter by test, use the `-k` option. For example, to run
+only the Llama 3.1 70B F16 Decomposed test:
+```
+pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s \
+ --run-nightly-llama-tests --iree-hip-target=gfx942 \
+ -k 'testBenchmark70B_f16_TP8_Decomposed'
+```
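+
+If you are unsure which test name to pass to `-k`, one way to list the
+collected tests first is a dry run with pytest's standard `--collect-only`
+flag (a suggested invocation; adjust the flags to your setup):
+```
+pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py \
+  --collect-only -q --iree-hip-target=gfx942
+```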
diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py
index bb5eb254d..211fab5a0 100644
--- a/sharktank/tests/models/llama/attention_test.py
+++ b/sharktank/tests/models/llama/attention_test.py
@@ -25,6 +25,7 @@
class AttentionBlockTest(unittest.TestCase):
def test(self):
+ torch.manual_seed(123456)
torch.set_default_dtype(torch.float32)
block_index = 0
seq_len = 13
@@ -58,7 +59,7 @@ def test(self):
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
- use_hf=True,
+ attention_kernel="decomposed",
)
attention_embedding = RotaryEmbeddingLayer(
rope_dimension_count=rope_dimension_count,
@@ -148,7 +149,9 @@ def test(self):
)[0]
assert sharktank_output.shape == huggingface_output.shape
- torch.testing.assert_close(sharktank_output, huggingface_output)
+ torch.testing.assert_close(
+ sharktank_output, huggingface_output, atol=1e-5, rtol=5e-2
+ )
if __name__ == "__main__":
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
new file mode 100644
index 000000000..751615a85
--- /dev/null
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -0,0 +1,930 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+from datetime import datetime
+import os
+import sys
+import unittest
+import pytest
+import subprocess
+from pathlib import Path
+from typing import List
+from sharktank.utils.export_artifacts import (
+ ExportArtifacts,
+ ExportMlirException,
+ IreeBenchmarkException,
+ IreeCompileException,
+)
+
+is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
+skipif_run_quick_llama_test = pytest.mark.skipif(
+ 'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")',
+ reason="Skipping largs tests when --run-quick-llama-test is set.",
+)
+
+
+@pytest.mark.usefixtures("get_iree_flags")
+class BaseBenchmarkTest(unittest.TestCase):
+ directory_created = False
+ current_date = datetime.now()
+ dir_path_suffix = current_date.strftime("%Y-%m-%d")
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ models_dir = os.path.dirname(cur_dir)
+ tests_dir = os.path.dirname(models_dir)
+ sharktank_dir = os.path.dirname(tests_dir)
+ repo_root = os.path.dirname(sharktank_dir)
+ dir_path = Path(repo_root + "/" + dir_path_suffix)
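+ # cur_dir is .../sharktank/tests/models/llama, so repo_root resolves to the
+ # repository root; benchmark artifacts are written to a dated directory such
+ # as <repo_root>/2024-12-01.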
+
+ @classmethod
+ def setUpClass(cls):
+ """This method will be run once per class to create the directory."""
+ if not cls.directory_created:
+ if not os.path.exists(cls.dir_path):
+ os.makedirs(cls.dir_path)
+ cls.directory_created = True
+
+ def setUp(self):
+ self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0")
+ self.compile_args = [
+ "--iree-dispatch-creation-enable-aggressive-fusion=true",
+ "--iree-global-opt-propagate-transposes=true",
+ "--iree-opt-aggressively-propagate-transposes=true",
+ "--iree-opt-data-tiling=false",
+ "--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'",
+ "--iree-stream-resource-memory-model=discrete",
+ "--iree-hip-legacy-sync=false",
+ "--iree-hal-indirect-command-buffers=true",
+ "--iree-hal-memoization=true",
+ "--iree-opt-strip-assertions",
+ ]
+
+
+@is_mi300x
+class BenchmarkLlama3_1_8B(BaseBenchmarkTest):
+ def setUp(self):
+ super().setUp()
+ # TODO: add numpy files to Azure and download from it
+ self.artifacts_dir = Path("/data/llama3.1/weights/8b")
+ self.irpa_path = self.artifacts_dir / "fp16/llama3.1_8b_fp16.irpa"
+ self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_8b_fp8.irpa"
+ self.tensor_parallelism_size = 1
+ self.dir_path_8b = self.dir_path / "llama-8b"
+ self.temp_dir_8b = Path(self.dir_path_8b)
+ self.temp_dir_8b.mkdir(parents=True, exist_ok=True)
+ self.llama8b_f16_decomposed_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="decomposed",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama8b_f16_torch_sdpa_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="torch",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama8b_fp8_decomposed_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path_fp8),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="decomposed",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama8b_fp8_torch_sdpa_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path_fp8),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="torch",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.prefill_args_f16 = self.artifacts_dir / "prefill_args"
+ self.prefill_args_bs4_128_in_tokens_f16 = (
+ self.artifacts_dir / "prefill_args_bs4_128"
+ )
+ self.decode_args_f16 = self.artifacts_dir / "decode_args"
+ self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
+ self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
+ self.iree_run_prefill_args = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_f16}/tokens.npy",
+ f"--input=@{self.prefill_args_f16}/seq_lens.npy",
+ f"--input=@{self.prefill_args_f16}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_f16}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_prefill_nondecomposed_args_fp16 = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_decode_args = [
+ "--function=decode_bs4",
+ f"--input=@{self.decode_args_f16}/tokens.npy",
+ f"--input=@{self.decode_args_f16}/seq_lens.npy",
+ f"--input=@{self.decode_args_f16}/start_positions.npy",
+ f"--input=@{self.decode_args_f16}/seq_block_ids.npy",
+ f"--input=@{self.decode_args_f16}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_prefill_args_fp8 = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_fp8}/tokens.npy",
+ f"--input=@{self.prefill_args_fp8}/seq_lens.npy",
+ f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_decode_args_fp8 = [
+ "--function=decode_bs4",
+ f"--input=@{self.decode_args_fp8}/tokens.npy",
+ f"--input=@{self.decode_args_fp8}/seq_lens.npy",
+ f"--input=@{self.decode_args_fp8}/start_positions.npy",
+ f"--input=@{self.decode_args_fp8}/seq_block_ids.npy",
+ f"--input=@{self.decode_args_fp8}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+
+ def testBenchmark8B_f16_Decomposed(self):
+ output_file_name = self.dir_path_8b / "f16_decomposed"
+ output_mlir = self.llama8b_f16_decomposed_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama8b_f16_decomposed_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama8b_f16_decomposed_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ export_return_code = self.llama8b_f16_decomposed_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama8b_f16_decomposed_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama8b_f16_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @skipif_run_quick_llama_test
+ def testBenchmark8B_f16_Non_Decomposed_Prefill(self):
+ output_file_name = self.dir_path_8b / "f16_torch_prefill"
+ output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch"
+ export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ skip_decode=True,
+ )
+ self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_nondecomposed_args_fp16,
+ cwd=self.repo_root,
+ )
+
+ @skipif_run_quick_llama_test
+ @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ def testBenchmark8B_f16_Non_Decomposed(self):
+ output_file_name = self.dir_path_8b / "f16_torch"
+ output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama8b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama8b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ self.llama8b_f16_torch_sdpa_artifacts.attention_kernel = "torch"
+ export_return_code = self.llama8b_f16_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ def testBenchmark8B_fp8_Decomposed(self):
+ output_file_name = self.dir_path_8b / "fp8_decomposed"
+ output_mlir = self.llama8b_fp8_decomposed_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama8b_fp8_decomposed_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama8b_fp8_decomposed_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ export_return_code = self.llama8b_fp8_decomposed_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama8b_fp8_decomposed_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama8b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama8b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ def testBenchmark8B_fp8_Non_Decomposed(self):
+ output_file_name = self.dir_path_8b / "fp8_torch"
+ output_mlir = self.llama8b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama8b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama8b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ export_return_code = self.llama8b_fp8_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama8b_fp8_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+
+@is_mi300x
+@skipif_run_quick_llama_test
+class BenchmarkLlama3_1_70B(BaseBenchmarkTest):
+ def setUp(self):
+ super().setUp()
+        # TODO: add the numpy input files to Azure and download them from there
+ self.artifacts_dir = Path("/data/llama3.1/weights/70b")
+ self.irpa_path = self.artifacts_dir / "fp16/llama3.1_70b_f16.irpa"
+ self.irpa_path_fp8 = self.artifacts_dir / "f8/llama70b_fp8.irpa"
+ self.tensor_parallelism_size = 8
+ self.dir_path_70b = self.dir_path / "llama-70b"
+ self.temp_dir_70b = Path(self.dir_path_70b)
+ self.temp_dir_70b.mkdir(parents=True, exist_ok=True)
+ self.llama70b_f16_decomposed_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="decomposed",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama70b_f16_torch_sdpa_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="torch",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama70b_fp8_decomposed_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path_fp8),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="decomposed",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama70b_fp8_torch_sdpa_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path_fp8),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="torch",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.prefill_args_f16 = self.artifacts_dir / "prefill_args"
+ self.prefill_args_bs4_128_in_tokens_f16 = (
+ self.artifacts_dir / "prefill_args_bs4_128"
+ )
+ self.decode_args_f16 = self.artifacts_dir / "decode_args"
+ self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
+ self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
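+        # Benchmark tool flags: each --input=@<file>.npy supplies one function
+        # argument from a saved numpy array.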
+ self.iree_run_prefill_args = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_f16}/tokens.npy",
+ f"--input=@{self.prefill_args_f16}/seq_lens.npy",
+ f"--input=@{self.prefill_args_f16}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_f16}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_prefill_nondecomposed_args_fp16 = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_decode_args = [
+ "--function=decode_bs4",
+ f"--input=@{self.decode_args_f16}/tokens.npy",
+ f"--input=@{self.decode_args_f16}/seq_lens.npy",
+ f"--input=@{self.decode_args_f16}/start_positions.npy",
+ f"--input=@{self.decode_args_f16}/seq_block_ids.npy",
+ f"--input=@{self.decode_args_f16}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_prefill_args_fp8 = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_fp8}/tokens.npy",
+ f"--input=@{self.prefill_args_fp8}/seq_lens.npy",
+ f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_decode_args_fp8 = [
+ "--function=decode_bs4",
+ f"--input=@{self.decode_args_fp8}/tokens.npy",
+ f"--input=@{self.decode_args_fp8}/seq_lens.npy",
+ f"--input=@{self.decode_args_fp8}/start_positions.npy",
+ f"--input=@{self.decode_args_fp8}/seq_block_ids.npy",
+ f"--input=@{self.decode_args_fp8}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+
+ @pytest.mark.xfail(
+ reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
+ )
+ def testBenchmark70B_f16_TP8_Decomposed(self):
+ output_file_name = self.dir_path_70b / "f16_decomposed"
+ output_mlir = self.llama70b_f16_decomposed_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama70b_f16_decomposed_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama70b_f16_decomposed_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
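+        # Use the pre-sharded tp8 parameter file if it has already been generated.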
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama70b_f16_decomposed_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama70b_f16_decomposed_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama70b_f16_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ def testBenchmark70B_f16_TP8_Non_Decomposed(self):
+ output_file_name = self.dir_path_70b / "f16_torch"
+ output_mlir = self.llama70b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama70b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama70b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ self.llama70b_f16_torch_sdpa_artifacts.attention_kernel = "torch"
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama70b_f16_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama70b_f16_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama70b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama70b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ def testBenchmark70B_fp8_TP8_Decomposed(self):
+ output_file_name = self.dir_path_70b / "fp8_decomposed"
+ output_mlir = self.llama70b_fp8_decomposed_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama70b_fp8_decomposed_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama70b_fp8_decomposed_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"f8/tp8/llama3.1_70b_fp8_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama70b_fp8_decomposed_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama70b_fp8_decomposed_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama70b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama70b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
+ def testBenchmark70B_fp8_TP8_Non_Decomposed(self):
+ output_file_name = self.dir_path_70b / "fp8_torch"
+ output_mlir = self.llama70b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama70b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama70b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"f8/tp8/llama3.1_70b_f8_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama70b_fp8_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama70b_fp8_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+
+@is_mi300x
+@skipif_run_quick_llama_test
+class BenchmarkLlama3_1_405B(BaseBenchmarkTest):
+ def setUp(self):
+ super().setUp()
+        # TODO: add the numpy input files to Azure and download them from there
+ self.artifacts_dir = Path("/data/llama3.1/weights/405b")
+ self.irpa_path = self.artifacts_dir / "fp16/llama3.1_405b_fp16.irpa"
+ self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_405b_fp8.irpa"
+ self.tensor_parallelism_size = 8
+ self.dir_path_405b = self.dir_path / "llama-405b"
+ self.temp_dir_405b = Path(self.dir_path_405b)
+ self.temp_dir_405b.mkdir(parents=True, exist_ok=True)
+ self.llama405b_f16_decomposed_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="decomposed",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama405b_f16_torch_sdpa_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="torch",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama405b_fp8_decomposed_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path_fp8),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="decomposed",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.llama405b_fp8_torch_sdpa_artifacts = ExportArtifacts(
+ irpa_path=str(self.irpa_path_fp8),
+ batch_size=4,
+ iree_hip_target="gfx942",
+ iree_hal_target_backends="rocm",
+ attention_kernel="torch",
+ tensor_parallelism_size=self.tensor_parallelism_size,
+ )
+ self.prefill_args_f16 = self.artifacts_dir / "prefill_args"
+ self.prefill_args_bs4_128_in_tokens_f16 = (
+ self.artifacts_dir / "prefill_args_bs4_128"
+ )
+ self.decode_args_f16 = self.artifacts_dir / "decode_args"
+ self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
+ self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
+ self.iree_run_prefill_args = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_f16}/tokens.npy",
+ f"--input=@{self.prefill_args_f16}/seq_lens.npy",
+ f"--input=@{self.prefill_args_f16}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_f16}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_prefill_nondecomposed_args_fp16 = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/random_tokens.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_lens.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_bs4_128_in_tokens_f16}/cs_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_decode_args = [
+ "--function=decode_bs4",
+ f"--input=@{self.decode_args_f16}/tokens.npy",
+ f"--input=@{self.decode_args_f16}/seq_lens.npy",
+ f"--input=@{self.decode_args_f16}/start_positions.npy",
+ f"--input=@{self.decode_args_f16}/seq_block_ids.npy",
+ f"--input=@{self.decode_args_f16}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_prefill_args_fp8 = [
+ "--function=prefill_bs4",
+ f"--input=@{self.prefill_args_fp8}/tokens.npy",
+ f"--input=@{self.prefill_args_fp8}/seq_lens.npy",
+ f"--input=@{self.prefill_args_fp8}/seq_block_ids.npy",
+ f"--input=@{self.prefill_args_fp8}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+ self.iree_run_decode_args_fp8 = [
+ "--function=decode_bs4",
+ f"--input=@{self.decode_args_fp8}/tokens.npy",
+ f"--input=@{self.decode_args_fp8}/seq_lens.npy",
+ f"--input=@{self.decode_args_fp8}/start_positions.npy",
+ f"--input=@{self.decode_args_fp8}/seq_block_ids.npy",
+ f"--input=@{self.decode_args_fp8}/cache_state_f16.npy",
+ "--benchmark_repetitions=3",
+ ]
+
+ @pytest.mark.xfail(
+ reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
+ )
+ def testBenchmark405B_f16_TP8_Decomposed(self):
+ output_file_name = self.dir_path_405b / "f16_decomposed"
+ output_mlir = self.llama405b_f16_decomposed_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama405b_f16_decomposed_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama405b_f16_decomposed_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama405b_f16_decomposed_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama405b_f16_decomposed_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama405b_f16_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(
+ reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
+ )
+ def testBenchmark405B_f16_TP8_Non_Decomposed(self):
+ output_file_name = self.dir_path_405b / "f16_torch"
+ output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama405b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama405b_f16_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ self.llama405b_f16_torch_sdpa_artifacts.attention_kernel = "torch"
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama405b_f16_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ skip_decode=True,
+ )
+ self.llama405b_f16_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(
+ reason="KeyError in theta.py", strict=True, raises=ExportMlirException
+ )
+ def testBenchmark405B_fp8_TP8_Decomposed(self):
+ output_file_name = self.dir_path_405b / "fp8_decomposed"
+ output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama405b_fp8_decomposed_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama405b_fp8_decomposed_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"f8/tp8/llama3.1_405b_f8_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama405b_fp8_decomposed_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama405b_fp8_decomposed_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama405b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama405b_fp8_decomposed_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+ @pytest.mark.xfail(
+ reason="KeyError in theta.py", strict=True, raises=ExportMlirException
+ )
+ def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
+ output_file_name = self.dir_path_405b / "fp8_torch"
+ output_mlir = self.llama405b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".mlir", prefix=output_file_name
+ )
+ output_json = self.llama405b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".json", prefix=output_file_name
+ )
+ output_vmfb = self.llama405b_fp8_torch_sdpa_artifacts.create_file(
+ suffix=".vmfb", prefix=output_file_name
+ )
+ output_shard_file_name = (
+ self.artifacts_dir
+ / f"f8/tp8/llama3.1_405b_f8_tp{self.tensor_parallelism_size}_parameters.irpa"
+ )
+ if output_shard_file_name.exists():
+ self.irpa_path = output_shard_file_name
+ export_return_code = self.llama405b_fp8_torch_sdpa_artifacts.export_to_mlir(
+ mlir_path=output_mlir,
+ json_path=output_json,
+ )
+ self.llama405b_fp8_torch_sdpa_artifacts.compile_to_vmfb(
+ mlir_path=str(output_mlir),
+ vmfb_path=output_vmfb,
+ hal_dump_path=output_file_name,
+ cwd=self.repo_root,
+ args=self.compile_args,
+ )
+ # benchmark prefill
+ self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_prefill_args,
+ cwd=self.repo_root,
+ )
+ # benchmark decode
+ self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+ hip_device_id=self.hip_device_id,
+ vmfb_name=output_vmfb,
+ irpa_path=self.irpa_path_fp8,
+ args=self.iree_run_decode_args,
+ cwd=self.repo_root,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py
index 9d36db2e2..a80575951 100644
--- a/sharktank/tests/models/llama/kv_cache_test.py
+++ b/sharktank/tests/models/llama/kv_cache_test.py
@@ -28,6 +28,7 @@ def setUp(self):
self.block_seq_stride = 16
self.rms_epsilon = 1e-5
self.rope_dimension_count = 128
+ self.rope_freq_base = 10000.0
self.max_seq_len = 4096
self.start_positions = torch.tensor([8])
self.bs = 1
@@ -58,6 +59,7 @@ def setUp(self):
)
self.attention_embedding = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
+ rope_freq_base=self.rope_freq_base,
max_seqlen=self.max_seq_len,
device=self.device,
use_hf=False,
@@ -72,7 +74,6 @@ def setUp(self):
head_dim=self.head_dim,
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
- use_hf=False,
)
for n in range(self.block_count)
]
@@ -87,7 +88,6 @@ def setUp(self):
head_dim=self.head_dim,
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
- use_hf=False,
)
for n in range(self.block_count)
]
diff --git a/sharktank/tests/models/llama/llama_cpp_instructions.md b/sharktank/tests/models/llama/llama_cpp_instructions.md
new file mode 100644
index 000000000..1ca6fa3cb
--- /dev/null
+++ b/sharktank/tests/models/llama/llama_cpp_instructions.md
@@ -0,0 +1,19 @@
+### How to build llama.cpp logit comparison branch
+```
+git clone https://github.com/aviator19941/llama.cpp.git
+cd llama.cpp/
+git checkout llama_comparison
+cmake -B build
+cmake --build build --config Release
+```
+
+### How to run llama.cpp
+First download the model and convert it to GGUF:
+```
+huggingface-cli download meta-llama/Meta-Llama-3.1-70B --local-dir /home/avsharma/Meta-Llama-3.1-70B
+python convert_hf_to_gguf.py --outtype f16 --outfile Llama-3.1-70B-f16.gguf ../Meta-Llama-3.1-70B/
+```
+
+To predict the prefill token, use `--n-predict 1`; to predict the first decode token, use `--n-predict 2`:
+```
+./build/bin/llama-cli -m Llama-3.1-70B-f16.gguf -p "I believe the meaning of life is" --threads 1 --temp 0 --n-predict 1 --no-warmup
+```
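+
+For example, predicting the first decode token uses the same invocation with only `--n-predict` changed:
+```
+./build/bin/llama-cli -m Llama-3.1-70B-f16.gguf -p "I believe the meaning of life is" --threads 1 --temp 0 --n-predict 2 --no-warmup
+```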
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index e04ca11fd..9b3daabdf 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -8,25 +8,24 @@
from typing import List
import torch
-from shark_turbine.aot import *
+from iree.turbine.aot import *
from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
-from sharktank.layers.mixture_of_experts_block import SparseMoeBlock
+from sharktank.layers.mixture_of_experts_block import MoeBlock
from sharktank import ops
-class SparseMoeBlockTest(unittest.TestCase):
- @unittest.skip("Skip test until grok implementation")
+class MoeBlockTest(unittest.TestCase):
def test(self):
- model = SparseMoeBlock(
+ model = MoeBlock(
theta=make_moe_block_theta()("blk.0"),
expert_count=8,
expert_used_count=2,
rms_epsilon=1e-5,
)
fxb = FxProgramsBuilder(model)
- input = make_rand_torch((2, 16, 6144))
+ input = make_rand_torch((2, 32, 6144))
- @fxb.export_program(name="moe_block", args=(input,))
+ @fxb.export_program(name="moe_block", args=(input,), strict=False)
def _(model, input: torch.Tensor) -> torch.Tensor:
return model(input)
diff --git a/sharktank/tests/models/llama/prefill_tests.py b/sharktank/tests/models/llama/prefill_tests.py
new file mode 100644
index 000000000..093ecdfc9
--- /dev/null
+++ b/sharktank/tests/models/llama/prefill_tests.py
@@ -0,0 +1,175 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+from sharktank.examples.paged_llm_v1 import *
+from sharktank.utils import tokenizer
+from sharktank.utils import hf_datasets
+import unittest
+from pathlib import Path
+
+
+class BaseLlamaTest(unittest.TestCase):
+ def setUp(self):
+ raise NotImplementedError("Subclasses should implement this method.")
+
+ def createConfigModel(self, kv_cache_type):
+ return LlamaModelConfig(
+ hp=configs.LlamaHParams.from_gguf_props(self.dataset.properties),
+ block_seq_stride=16,
+ kv_cache_type=kv_cache_type,
+ device=self.device,
+ activation_dtype=self.activation_dtype,
+ attention_dtype=self.activation_dtype,
+ )
+
+ def runPrefill(self, *, kv_cache_type):
+ config = self.createConfigModel(kv_cache_type)
+ model = PagedLlamaModelV1(self.dataset.root_theta, config)
+ generator = TorchGenerator(model, self.tokenizer_config)
+ batch = generator.begin_batch(self.prompts)
+ attention_mask = model.attention_mask(
+ model.input_mask(batch.seq_lens, batch.token_ids.shape[1])
+ )
+ seq_block_ids_tensor = batch.pad_block_ids()
+ logits = batch.compute_prefill_logits(
+ model,
+ batch.token_ids,
+ attention_mask=attention_mask,
+ seq_block_ids=seq_block_ids_tensor,
+ cache_state=batch.cache_state,
+ )
+ batch.prefill()
+
+ bs, *_ = logits.shape
+ assert len(batch.seq_lens) == bs
+        # Logits at the last real token of the sequence; record the greedy
+        # (argmax) token's logit for comparison against the llama.cpp golden value.
+        step_logits = logits[0, batch.seq_lens[0] - 1]
+        greedy_token_logit = step_logits[torch.argmax(step_logits)]
+
+ return batch.results, greedy_token_logit
+
+ def comparePrefillResults(
+ self,
+ batch_results,
+ greedy_token_logit,
+ golden_prefill_token,
+ golden_prefill_token_logit,
+ ):
+ rtol = 3e-4
+ atol = 4e-3
+ assert batch_results == golden_prefill_token
+ torch.testing.assert_close(
+ greedy_token_logit, golden_prefill_token_logit, rtol=rtol, atol=atol
+ )
+
+
+class Llama7BTest(BaseLlamaTest):
+ def setUp(self):
+ default_arguments = {
+ "hf_dataset": "llama2_7B_f16",
+ "tokenizer-config-json": Path("./llama2-7b/tokenizer_config.json"),
+ "prompt": ["I believe the meaning of life is"],
+ "device": None,
+ "activation-dtype": "float32",
+ }
+ self.device = (
+ torch.device(default_arguments["device"])
+ if default_arguments["device"]
+ else None
+ )
+ self.activation_dtype = getattr(torch, default_arguments["activation-dtype"])
+ assert isinstance(self.activation_dtype, torch.dtype)
+ self.data_files = hf_datasets.get_dataset(
+ default_arguments["hf_dataset"]
+ ).download(local_dir=Path("."))
+ self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf")
+ self.tokenizer_config = tokenizer.load_tokenizer(
+ default_arguments["tokenizer-config-json"].parent,
+ tokenizer_type="transformers",
+ )
+ self.prompts = default_arguments["prompt"]
+ # token and logit determined by running llama.cpp (llama_cpp_instructions.md).
+ self.llama_cpp_7b_prefill_token = [[304]]
+ self.llama_cpp_7b_prefill_token_logit = torch.tensor(19.356606)
+
+ def testPrefillPaged7B(self):
+ batch_results_paged, greedy_token_logit_paged = self.runPrefill(
+ kv_cache_type="paged"
+ )
+ self.comparePrefillResults(
+ batch_results_paged,
+ greedy_token_logit_paged,
+ self.llama_cpp_7b_prefill_token,
+ self.llama_cpp_7b_prefill_token_logit,
+ )
+
+ def testPrefillDirect7B(self):
+ batch_results_direct, greedy_token_logit_direct = self.runPrefill(
+ kv_cache_type="direct"
+ )
+ self.comparePrefillResults(
+ batch_results_direct,
+ greedy_token_logit_direct,
+ self.llama_cpp_7b_prefill_token,
+ self.llama_cpp_7b_prefill_token_logit,
+ )
+
+
+class Llama8BTest(BaseLlamaTest):
+ def setUp(self):
+ default_arguments = {
+ "hf_dataset": "llama3_8B_f16",
+ "tokenizer-config-json": Path("./llama3.1-8b/tokenizer_config.json"),
+ "prompt": ["I believe the meaning of life is"],
+ "device": None,
+ "activation-dtype": "float32",
+ }
+ self.device = (
+ torch.device(default_arguments["device"])
+ if default_arguments["device"]
+ else None
+ )
+ self.activation_dtype = getattr(torch, default_arguments["activation-dtype"])
+ assert isinstance(self.activation_dtype, torch.dtype)
+ self.data_files = hf_datasets.get_dataset(
+ default_arguments["hf_dataset"]
+ ).download(local_dir=Path("."))
+ self.dataset = Dataset.load(self.data_files["gguf"], file_type="gguf")
+ self.tokenizer_config = tokenizer.load_tokenizer(
+ default_arguments["tokenizer-config-json"].parent,
+ tokenizer_type="transformers",
+ )
+ self.prompts = default_arguments["prompt"]
+ # token and logit determined by running llama.cpp (llama_cpp_instructions.md).
+ self.llama_cpp_8b_prefill_token = [[311]]
+ self.llama_cpp_8b_prefill_token_logit = torch.tensor(15.612568)
+
+ def testPrefillPaged8B(self):
+ batch_results_paged, greedy_token_logit_paged = self.runPrefill(
+ kv_cache_type="paged"
+ )
+ self.comparePrefillResults(
+ batch_results_paged,
+ greedy_token_logit_paged,
+ self.llama_cpp_8b_prefill_token,
+ self.llama_cpp_8b_prefill_token_logit,
+ )
+
+ def testPrefillDirect8B(self):
+ batch_results_direct, greedy_token_logit_direct = self.runPrefill(
+ kv_cache_type="direct"
+ )
+ self.comparePrefillResults(
+ batch_results_direct,
+ greedy_token_logit_direct,
+ self.llama_cpp_8b_prefill_token,
+ self.llama_cpp_8b_prefill_token_logit,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py
new file mode 100644
index 000000000..386061731
--- /dev/null
+++ b/sharktank/tests/models/llama/sharded_llama_test.py
@@ -0,0 +1,406 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+import pytest
+from typing import Any, List, Tuple, OrderedDict
+from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
+import sharktank.ops as ops
+from sharktank.types import unbox_tensor, Dataset, UnreducedTensor, SplitPrimitiveTensor
+from sharktank.models.llama.testing import make_random_llama_theta
+from sharktank.utils.testing import skip
+from sharktank.models.llama.sharding import shard_theta
+from sharktank.layers.configs import LlamaHParams
+from sharktank.utils.math import round_up_to_multiple_of
+from sharktank.utils import iterables_equal
+from sharktank.utils.iree import (
+ get_iree_devices,
+ load_iree_module,
+ run_iree_module_function,
+ prepare_iree_module_function_args,
+ call_torch_module_function,
+ iree_to_torch,
+)
+from sharktank.export import export as sharktank_export
+import tempfile
+import torch
+from copy import deepcopy
+from iree.turbine.aot import FxProgramsBuilder, export
+import iree.runtime
+import numpy as np
+import os
+
+
+@pytest.mark.usefixtures("caching", "path_prefix")
+class ShardedLlamaTest(unittest.TestCase):
+ def setUp(self):
+ torch.random.manual_seed(123456)
+ self.dtype = torch.float32
+ torch.set_default_dtype(self.dtype)
+ self.batch_size = 3
+ self.attention_head_count_kv = 4
+ self.attention_head_count = self.attention_head_count_kv * 5
+ self.vocabulary_size = 19
+ self.rope_dimension_count = 7 * 2
+ self.attn_head_dim = self.rope_dimension_count
+ self.block_seq_stride = 13
+ self.cache_page_count = 11
+ self.config = LlamaModelConfig(
+ hp=LlamaHParams(
+ context_length=self.block_seq_stride * 2,
+ embedding_length=self.attention_head_count * self.attn_head_dim,
+ block_count=3,
+ feed_forward_length=23,
+ rope_dimension_count=self.rope_dimension_count,
+ rope_freq_base=500000.0,
+ attention_head_count=self.attention_head_count,
+ attn_head_dim=self.attn_head_dim,
+ attention_layer_norm_rms_epsilon=0.01,
+ attention_head_count_kv=self.attention_head_count_kv,
+ expert_count=0,
+ expert_used_count=0,
+ model_arch="llama",
+ ),
+ block_seq_stride=self.block_seq_stride,
+ activation_dtype=self.dtype,
+ attention_dtype=self.dtype,
+ )
+ self.sharded_config = deepcopy(self.config)
+ self.sharded_config.tensor_parallelism_size = 2
+ self.theta = make_random_llama_theta(
+ config=self.config,
+ vocab_size=self.vocabulary_size,
+ )
+ self.prefill_seq_lens = torch.tensor(
+ [14, 9, self.block_seq_stride - 1], dtype=torch.int64
+ )
+
+ def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]:
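+        # Pad the longest prefill sequence length up to the cache's pad_sequence_stride.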
+ batch_seq_len = round_up_to_multiple_of(
+ int(torch.max(self.prefill_seq_lens)), model.cache.pad_sequence_stride
+ )
+ token_ids = torch.randint(
+ low=0,
+ high=self.vocabulary_size,
+ size=[self.batch_size, batch_seq_len],
+ dtype=torch.int32,
+ )
+ attention_mask = model.attention_mask(
+ model.input_mask(self.prefill_seq_lens, batch_seq_len)
+ )
+ seq_block_ids = torch.arange(
+ self.batch_size * batch_seq_len // self.config.block_seq_stride
+ ).view(self.batch_size, -1)
+ cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
+ cache_state = [torch.rand_like(cache_state[0])]
+ return OrderedDict(
+ [
+ ("tokens", token_ids),
+ ("attention_mask", attention_mask),
+ ("seq_block_ids", seq_block_ids),
+ ("cache_state", cache_state),
+ ]
+ )
+
+ def make_equal_unsharded_and_sharded_prefill_args(
+ self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
+ ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]:
+ prefill_kwargs = self.make_prefill_args(model)
+ sharded_cache_state = sharded_model.cache.paged.allocate(
+ page_count=self.cache_page_count
+ )
+ assert iterables_equal(
+ prefill_kwargs["cache_state"][0].shape, sharded_cache_state[0].shape
+ )
+ sharded_prefill_kwargs = deepcopy(prefill_kwargs)
+ sharded_cache_state = sharded_model.cache.paged.shard_state(
+ sharded_prefill_kwargs["cache_state"]
+ )
+ sharded_prefill_kwargs["cache_state"] = sharded_cache_state
+
+ sharding = sharded_model.config.tensor_parallelism_size
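+        # Replicate every other argument across all shards; the cache state was
+        # already sharded above.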
+ for k in sharded_prefill_kwargs:
+ if k == "cache_state":
+ continue
+ sharded_prefill_kwargs[k] = ops.replicate(
+ sharded_prefill_kwargs[k], count=sharding
+ )
+
+ return prefill_kwargs, sharded_prefill_kwargs
+
+ def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]:
+ start_positions = self.prefill_seq_lens.clone()
+ seq_lens = self.prefill_seq_lens + 1
+ batch_seq_len = round_up_to_multiple_of(
+ int(torch.max(seq_lens)), model.cache.pad_sequence_stride
+ )
+ decode_token_ids = torch.randint(
+ low=0,
+ high=self.vocabulary_size,
+ size=[self.batch_size, 1],
+ dtype=torch.int32,
+ )
+ attention_mask = model.decode_attention_mask(
+ model.input_mask(seq_lens, batch_seq_len)
+ )
+ seq_block_ids = torch.arange(
+ self.batch_size * batch_seq_len // self.config.block_seq_stride
+ ).view(self.batch_size, -1)
+ cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
+ cache_state = [torch.rand_like(cache_state[0])]
+ return OrderedDict(
+ [
+ ("tokens", decode_token_ids),
+ ("attention_mask", attention_mask),
+ ("start_positions", start_positions),
+ ("seq_block_ids", seq_block_ids),
+ ("cache_state", cache_state),
+ ]
+ )
+
+ def make_equal_unsharded_and_sharded_decode_args(
+ self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
+ ) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]:
+ decode_kwargs = self.make_decode_args(model)
+ sharded_decode_kwargs = deepcopy(decode_kwargs)
+ sharded_decode_kwargs["cache_state"] = sharded_model.cache.paged.shard_state(
+ sharded_decode_kwargs["cache_state"]
+ )
+
+ sharding = sharded_model.config.tensor_parallelism_size
+ for k in sharded_decode_kwargs:
+ if k == "cache_state":
+ continue
+ sharded_decode_kwargs[k] = ops.replicate(
+ sharded_decode_kwargs[k], count=sharding
+ )
+
+ return decode_kwargs, sharded_decode_kwargs
+
+ def testCompareToySizedModelToUnsharded(self):
+ """Run a sharded variant of a toy model size and compare it against the
+ unsharded variant."""
+ model = PagedLlamaModelV1(self.theta, self.config)
+ sharded_theta = shard_theta(self.theta, self.sharded_config)
+ sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config)
+
+ # Verify prefill step.
+ (
+ prefill_kwargs,
+ sharded_prefill_kwargs,
+ ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model)
+
+ expected_prefill_result = model.prefill(**prefill_kwargs)
+ sharded_prefill_result = sharded_model.prefill(**sharded_prefill_kwargs)
+ sharded_prefill_result = ops.unshard(sharded_prefill_result)
+ # The errors are quite high, but for float64 both errors drop to < 1e-12.
+ # The numerics are probably correct.
+ torch.testing.assert_close(
+ sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
+ )
+ expected_cache_state = prefill_kwargs["cache_state"][0]
+ actual_cache_state = ops.unshard(
+ sharded_model.cache.paged.unflatten_page_table(
+ sharded_prefill_kwargs["cache_state"]
+ )
+ ).flatten(start_dim=1)
+ torch.testing.assert_close(
+ actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1
+ )
+
+ # Verify decode step.
+ (
+ decode_kwargs,
+ sharded_decode_kwargs,
+ ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)
+ expected_decode_result = model.decode(**decode_kwargs)
+ sharded_decode_result = sharded_model.decode(**sharded_decode_kwargs)
+ sharded_decode_result = ops.unshard(sharded_decode_result)
+ torch.testing.assert_close(
+ sharded_decode_result, expected_decode_result, atol=1e-4, rtol=1e-5
+ )
+ expected_decode_cache_state = decode_kwargs["cache_state"][0]
+ actual_decode_cache_state = ops.unshard(
+ sharded_model.cache.paged.unflatten_page_table(
+ sharded_decode_kwargs["cache_state"]
+ )
+ ).flatten(start_dim=1)
+ # TODO: investigate why the Windows machine CI is producing a larger numerical
+ # error.
+ # The Ubuntu CI runs fine with default tolerances.
+ torch.testing.assert_close(
+ actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4
+ )
+
+ @skip(
+ (
+            "This test crashes until "
+            "https://github.com/iree-org/iree/pull/18663 is merged."
+ )
+ )
+ def testExportAndRunToySizedModelWithIree(self):
+ """Test exporting to MLIR and compiling with IREE the sharded Llama model.
+ Test numerical accuracy of the IREE module against PyTorch."""
+
+ if self.path_prefix is not None:
+ self.runTestExportAndRunToySizedModelWithIree(
+ path_prefix=self.path_prefix, dump_enabled=True
+ )
+ else:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ self.runTestExportAndRunToySizedModelWithIree(
+ path_prefix=f"{temp_dir}/", dump_enabled=False
+ )
+
+ def runTestExportAndRunToySizedModelWithIree(
+ self, path_prefix: str, dump_enabled: bool
+ ):
+ sharded_theta = shard_theta(self.theta, self.sharded_config)
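+        # Round-trip the sharded parameters through an .irpa archive; the same
+        # file is later passed to the IREE module as its parameter source.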
+ sharded_theta.rename_tensors_to_paths()
+ sharded_dataset = Dataset({}, sharded_theta)
+ sharded_parameters_path = f"{path_prefix}parameters.irpa"
+ sharded_dataset.save(sharded_parameters_path)
+ sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False)
+ iree_driver = "local-task"
+
+ model = PagedLlamaModelV1(self.theta, self.config)
+ sharded_model = PagedLlamaModelV1(
+ sharded_dataset.root_theta, self.sharded_config
+ )
+ (
+ _,
+ sharded_prefill_kwargs,
+ ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model)
+ (
+ _,
+ sharded_decode_kwargs,
+ ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)
+
+ iree_module_path = f"{path_prefix}program.vmfb"
+ if not self.caching or not os.path.exists(iree_module_path):
+ # Export and compile the IREE module.
+ sharded_fxb = FxProgramsBuilder(sharded_model)
+
+ @sharktank_export(
+ fx_builder=sharded_fxb,
+ name="prefill",
+ kwargs=sharded_prefill_kwargs,
+ strict=False,
+ )
+ def _(model, *args, **kwargs) -> torch.Tensor:
+ return model.prefill(*args, **kwargs)
+
+ # TODO: remove strict=False when
+ # https://github.com/pytorch/pytorch/issues/136757
+ # is resolved.
+ @sharktank_export(
+ fx_builder=sharded_fxb,
+ name="decode",
+ kwargs=sharded_decode_kwargs,
+ strict=False,
+ )
+ def _(model, *args, **kwargs) -> torch.Tensor:
+ return model.decode(*args, **kwargs)
+
+ output = export(sharded_fxb)
+ if dump_enabled:
+ output.save_mlir(f"{path_prefix}program.mlir")
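+            # Target one llvm-cpu device per tensor-parallel shard so the compiled
+            # module matches the device count requested at runtime.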
+ output.session.set_flags(
+ *[
+ f"--iree-hal-target-device=llvm-cpu[{i}]"
+ for i in range(self.sharded_config.tensor_parallelism_size)
+ ]
+ )
+ output.compile(
+ save_to=iree_module_path,
+ target_backends=None,
+ )
+
+ iree_devices = get_iree_devices(
+ driver=iree_driver,
+ device_count=self.sharded_config.tensor_parallelism_size,
+ )
+ iree_module, vm_context, vm_instance = load_iree_module(
+ module_path=iree_module_path,
+ devices=iree_devices,
+ parameters_path=sharded_parameters_path,
+ )
+
+ # Run prefill step.
+ prefill_iree_args = prepare_iree_module_function_args(
+ args=deepcopy(sharded_prefill_kwargs).values(), devices=iree_devices
+ )
+ for i, arg in enumerate(prefill_iree_args):
+ np.save(f"{path_prefix}prefill_arg{i}.npy", arg.to_host())
+ prefill_iree_result = run_iree_module_function(
+ args=prefill_iree_args,
+ function_name="prefill",
+ module=iree_module,
+ vm_context=vm_context,
+ driver=iree_driver,
+ trace_path_prefix=path_prefix if dump_enabled else None,
+ )
+ prefill_iree_result = UnreducedTensor(ts=iree_to_torch(*prefill_iree_result))
+ expected_prefill_result = call_torch_module_function(
+ module=sharded_model,
+ function_name="prefill",
+ kwargs=sharded_prefill_kwargs,
+ trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None,
+ )
+ prefill_iree_cache_state_shards = prefill_iree_args[
+ -self.config.tensor_parallelism_size - 1 :
+ ]
+ prefill_iree_cache_state = SplitPrimitiveTensor(
+ ts=iree_to_torch(*prefill_iree_cache_state_shards),
+ shard_dim=sharded_prefill_kwargs["cache_state"][0].shard_dim,
+ )
+
+ # Run decode step.
+ decode_iree_args = prepare_iree_module_function_args(
+ args=deepcopy(sharded_decode_kwargs).values(), devices=iree_devices
+ )
+ decode_iree_result = run_iree_module_function(
+ args=decode_iree_args,
+ function_name="decode",
+ module=iree_module,
+ vm_context=vm_context,
+ driver=iree_driver,
+ trace_path_prefix=path_prefix if dump_enabled else None,
+ )
+ decode_iree_result = UnreducedTensor(ts=iree_to_torch(*decode_iree_result))
+ expected_decode_result = call_torch_module_function(
+ module=sharded_model,
+ function_name="decode",
+ kwargs=sharded_decode_kwargs,
+ trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None,
+ )
+ decode_iree_cache_state_shards = decode_iree_args[
+ -self.config.tensor_parallelism_size - 1 :
+ ]
+ decode_iree_cache_state = SplitPrimitiveTensor(
+ ts=iree_to_torch(*decode_iree_cache_state_shards),
+ shard_dim=sharded_decode_kwargs["cache_state"][0].shard_dim,
+ )
+
+ # Check IREE's numerical correctness against PyTorch.
+        # TODO: the results are not entirely wrong, but investigate why the
+        # accuracy is this low for fp32 (atol=0.0011, rtol=0.013).
+ torch.testing.assert_close(
+ ops.unshard(prefill_iree_result),
+ ops.unshard(expected_prefill_result),
+ )
+ torch.testing.assert_close(
+ ops.unshard(prefill_iree_cache_state),
+ ops.unshard(sharded_prefill_kwargs["cache_state"][0]),
+ )
+ torch.testing.assert_close(
+ ops.unshard(decode_iree_result),
+ ops.unshard(expected_decode_result),
+ )
+ torch.testing.assert_close(
+ ops.unshard(decode_iree_cache_state),
+ ops.unshard(sharded_decode_kwargs["cache_state"][0]),
+ )
diff --git a/sharktank/tests/models/punet/conftest.py b/sharktank/tests/models/punet/conftest.py
deleted file mode 100644
index 55364c4ef..000000000
--- a/sharktank/tests/models/punet/conftest.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-import pytest
-from pathlib import Path
-from typing import Optional
-
-
-def pytest_addoption(parser):
- parser.addoption(
- "--mlir",
- type=Path,
- default=None,
- help="Path to exported MLIR program. If not specified a temporary file will be used.",
- )
- parser.addoption(
- "--module",
- type=Path,
- default=None,
- help="Path to exported IREE module. If not specified a temporary file will be used.",
- )
- parser.addoption(
- "--parameters",
- type=Path,
- default=None,
- help="Exported model parameters. If not specified a temporary file will be used.",
- )
- parser.addoption(
- "--caching",
- action="store_true",
- default=False,
- help="Load cached results if present instead of recomputing.",
- )
-
-
-@pytest.fixture(scope="session")
-def mlir_path(pytestconfig: pytest.Config) -> Optional[Path]:
- return pytestconfig.getoption("mlir")
-
-
-@pytest.fixture(scope="session")
-def module_path(pytestconfig: pytest.Config) -> Optional[Path]:
- return pytestconfig.getoption("module")
-
-
-@pytest.fixture(scope="session")
-def parameters_path(pytestconfig: pytest.Config) -> Optional[Path]:
- return pytestconfig.getoption("parameters")
-
-
-@pytest.fixture(scope="session")
-def caching(pytestconfig: pytest.Config) -> Optional[Path]:
- return pytestconfig.getoption("caching")
diff --git a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py
index e2b602a7c..581584369 100644
--- a/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py
+++ b/sharktank/tests/models/punet/sharded_resnet_block_with_iree_test.py
@@ -10,7 +10,7 @@
import torch
-from shark_turbine import aot
+from iree.turbine import aot
from sharktank.models.punet.testing import make_resnet_block_2d_theta
from sharktank.models.punet.layers import ResnetBlock2D
from sharktank.models.punet.sharding import ResnetBlock2DSplitOutputChannelsSharding
@@ -19,6 +19,7 @@
import iree.runtime
from typing import List, Optional
import os
+import pytest
vm_context: iree.runtime.VmContext = None
@@ -40,6 +41,8 @@ def compile_iree_module(
export_output.compile(save_to=module_path, target_backends=None)
+# TODO: improve IREE's Python API to be more concise in a multi-device context.
+# This run function should be way shorter.
def run_iree_module(
sharded_input_image: ShardedTensor,
sharded_input_time_emb: ShardedTensor,
@@ -205,18 +208,26 @@ def run_test_sharded_resnet_block_with_iree(
parameters_path=parameters_path,
)
assert len(actual_result.shards) == len(expected_result.shards)
- # TODO: reenable this check once numerical issues are resolved.
- # for actual_shard, expected_shard in zip(
- # actual_result.shards, expected_result.shards
- # ):
- # torch.testing.assert_close(
- # unbox_tensor(actual_shard), unbox_tensor(expected_shard)
- # )
+    # TODO: re-enable this test once the numerical issues are resolved.
+    # The absolute error is > 0.00042. Is this good enough?
+    # Maybe add an fp64 variant; if the accuracy is high there, that would give
+    # us more confidence that the fp32 results are also OK.
+ for actual_shard, expected_shard in zip(
+ actual_result.shards, expected_result.shards
+ ):
+ torch.testing.assert_close(
+ unbox_tensor(actual_shard), unbox_tensor(expected_shard)
+ )
global vm_context
del vm_context
+@pytest.mark.xfail(
+    reason="Possibly a numerical issue; the accuracy is low.",
+ strict=True,
+ raises=AssertionError,
+)
def test_sharded_resnet_block_with_iree(
mlir_path: Optional[Path],
module_path: Optional[Path],
diff --git a/sharktank/tests/models/t5/t5_test.py b/sharktank/tests/models/t5/t5_test.py
new file mode 100644
index 000000000..076404e5d
--- /dev/null
+++ b/sharktank/tests/models/t5/t5_test.py
@@ -0,0 +1,410 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from transformers.models.t5.modeling_t5 import (
+ T5Attention as ReferenceT5Attention,
+ T5LayerSelfAttention as ReferenceT5LayerSelfAttention,
+ T5LayerFF as ReferenceT5LayerFF,
+)
+from transformers import (
+ AutoTokenizer,
+ T5EncoderModel as ReferenceT5EncoderModel,
+ T5Config as ReferenceT5Config,
+)
+import os
+from collections import OrderedDict
+import pytest
+import torch
+from unittest import TestCase
+from parameterized import parameterized
+from sharktank.types import Theta, DefaultPrimitiveTensor, unbox_tensor, Dataset
+from sharktank.models.t5 import (
+ T5Attention,
+ T5SelfAttention,
+ T5Config,
+ T5Encoder,
+ T5LayerFF,
+ export_encoder_mlir,
+ export_encoder_iree_parameters,
+)
+from sharktank.utils.testing import make_rand_torch, TempDirTestBase
+from sharktank.utils.hf_datasets import get_dataset
+from sharktank.utils.iree import (
+ get_iree_devices,
+ load_iree_module,
+ run_iree_module_function,
+ prepare_iree_module_function_args,
+ call_torch_module_function,
+ flatten_for_iree_signature,
+ iree_to_torch,
+)
+import iree.compiler
+
+with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')")
+
+
+def make_random_mask(shape: tuple[int], dtype: torch.dtype):
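+    # Draw random values and keep positions with non-negative draws, producing
+    # a 0/1 mask in the requested dtype.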
+ mask = make_rand_torch(shape=shape, dtype=dtype)
+ mask = (mask >= 0).to(dtype=dtype)
+ return mask
+
+
+test_prompts = [
+ "Studies have been shown that owning a dog is good for you",
+ "The horse went into the river",
+ "We need at least one sentence long enough so that it spans more than one padding block which by default is of size 16.",
+ "Make the batch size 4",
+]
+
+
+@pytest.mark.usefixtures("get_model_artifacts")
+class T5EncoderEagerTest(TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.random.manual_seed(12345)
+ torch.no_grad()
+
+ def runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace(
+ self, huggingface_repo_id: str
+ ):
+ get_dataset(
+ huggingface_repo_id,
+ ).download()
+ tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id)
+ reference_model = ReferenceT5EncoderModel.from_pretrained(huggingface_repo_id)
+ reference_model.eval()
+
+ input_ids = tokenizer(
+ test_prompts,
+ return_tensors="pt",
+ padding=True,
+ ).input_ids
+
+ target_model_name = (
+ f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}_fp32_model"
+ )
+ target_model_path = getattr(self, target_model_name)
+ dataset = Dataset.load(target_model_path)
+ config = T5Config.from_gguf_properties(
+ dataset.properties,
+ feed_forward_proj="gated-gelu",
+ )
+ model = T5Encoder(theta=dataset.root_theta, config=config)
+ model.eval()
+
+ expected_outputs = reference_model(input_ids=input_ids)
+ actual_outputs = model(input_ids=input_ids)
+ torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0)
+
+ @with_t5_data
+ def testV1_1SmallFp32CompareTorchEagerAgainstHuggingFace(self):
+ self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-small")
+
+ @with_t5_data
+ def testV1_1XxlFp32CompareTorchEagerAgainstHuggingFace(self):
+ self.runTestV1_1Fp32CompareTorchEagerAgainstHuggingFace("google/t5-v1_1-xxl")
+
+
+@pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix")
+class T5EncoderIreeTest(TempDirTestBase):
+ def setUp(self):
+ super().setUp()
+ if self.path_prefix is None:
+ self.path_prefix = f"{self._temp_dir}/"
+
+ @parameterized.expand(
+ [
+ "google/t5-v1_1-small",
+ "google/t5-v1_1-xxl",
+ ]
+ )
+ @with_t5_data
+ def testV1_1Fp32CompareIreeAgainstTorchEager(self, huggingface_repo_id: str):
+ get_dataset(
+ huggingface_repo_id,
+ ).download()
+ tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id)
+
+ huggingface_repo_id_as_path = (
+ f"{huggingface_repo_id.replace('/', '__').replace('-', '_')}"
+ )
+ source_model_name = f"{huggingface_repo_id_as_path}_fp32_model"
+ source_model_path = getattr(self, source_model_name)
+
+ dataset = Dataset.load(source_model_path)
+ config = T5Config.from_gguf_properties(
+ dataset.properties,
+ feed_forward_proj="gated-gelu",
+ )
+
+ input_ids = tokenizer(
+ test_prompts,
+ return_tensors="pt",
+ padding=True,
+ pad_to_multiple_of=config.context_length_padding_block_size,
+ ).input_ids
+ input_args = OrderedDict([("input_ids", input_ids)])
+ batch_size = input_ids.shape[0]
+
+ reference_model = T5Encoder(theta=dataset.root_theta, config=config)
+ reference_result = flatten_for_iree_signature(
+ call_torch_module_function(
+ module=reference_model,
+ function_name="forward",
+ kwargs=input_args,
+ trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_torch_",
+ )
+ )
+
+ mlir_path = f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.mlir"
+ if not self.caching or not os.path.exists(mlir_path):
+ export_encoder_mlir(
+ source_model_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
+ )
+ iree_module_path = (
+ f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.vmfb"
+ )
+ if not self.caching or not os.path.exists(iree_module_path):
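+            # Compile for the HIP backend targeting gfx942 (MI300-class GPUs).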
+ iree.compiler.compile_file(
+ mlir_path,
+ output_file=iree_module_path,
+ extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
+ )
+
+ parameters_path = (
+ f"{self.path_prefix}{huggingface_repo_id_as_path}_encoder_fp32.irpa"
+ )
+ if not self.caching or not os.path.exists(parameters_path):
+ export_encoder_iree_parameters(source_model_path, parameters_path)
+
+ iree_devices = get_iree_devices(driver="hip", device_count=1)
+ iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
+ module_path=iree_module_path,
+ devices=iree_devices,
+ parameters_path=parameters_path,
+ )
+ iree_args = prepare_iree_module_function_args(
+ args=flatten_for_iree_signature(input_args), devices=iree_devices
+ )
+ iree_result = iree_to_torch(
+ *run_iree_module_function(
+ module=iree_module,
+ vm_context=iree_vm_context,
+ args=iree_args,
+ driver="hip",
+ function_name=f"forward_bs{batch_size}",
+ trace_path_prefix=f"{self.path_prefix}{huggingface_repo_id_as_path}_iree_",
+ )
+ )
+
+ torch.testing.assert_close(
+ reference_result, iree_result, atol=1e-4, rtol=2.0e-3
+ )
+
+
+class T5AttentionTest(TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.random.manual_seed(12345)
+ torch.no_grad()
+
+ def testCompareAgainstTransformersFp32(self):
+ dtype = torch.float32
+ batch_size = 19
+ batch_seq_len = 23
+ reference_config = ReferenceT5Config(
+ vocab_size=11,
+ d_model=13,
+ d_kv=7,
+ d_ff=3,
+ num_heads=2,
+ relative_attention_num_buckets=5,
+ relative_attention_max_distance=17,
+ dropout_rate=0.0,
+ )
+ reference_model = ReferenceT5Attention(
+ reference_config, has_relative_attention_bias=True
+ )
+ reference_model.eval()
+
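+        # Remap the HuggingFace T5 attention weights onto sharktank's theta
+        # tensor names.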
+ theta = Theta(
+ {
+ "attn_q.weight": DefaultPrimitiveTensor(
+ data=reference_model.q.weight.data
+ ),
+ "attn_k.weight": DefaultPrimitiveTensor(
+ data=reference_model.k.weight.data
+ ),
+ "attn_v.weight": DefaultPrimitiveTensor(
+ data=reference_model.v.weight.data
+ ),
+ "attn_o.weight": DefaultPrimitiveTensor(
+ data=reference_model.o.weight.data
+ ),
+ "attn_rel_b.weight": DefaultPrimitiveTensor(
+ data=reference_model.relative_attention_bias.weight.data
+ ),
+ }
+ )
+ model = T5Attention(
+ theta=theta,
+ is_decoder=reference_config.is_decoder,
+ relative_attention_num_buckets=reference_config.relative_attention_num_buckets,
+ relative_attention_max_distance=reference_config.relative_attention_max_distance,
+ d_model=reference_config.d_model,
+ d_kv=reference_config.d_kv,
+ num_heads=reference_config.num_heads,
+ activation_dtype=dtype,
+ has_relative_attention_bias=True,
+ )
+ model.eval()
+
+ hidden_states = make_rand_torch(
+ shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype
+ )
+ mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype)
+ expected_outputs = reference_model(hidden_states=hidden_states, mask=mask)
+ actual_outputs = model(
+ hidden_states=DefaultPrimitiveTensor(data=hidden_states),
+ mask=DefaultPrimitiveTensor(data=mask),
+ )
+ torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0)
+
+ def testCompareSelfAttentionAgainstTransformersFp32(self):
+ dtype = torch.float32
+ batch_size = 19
+ batch_seq_len = 23
+ reference_config = ReferenceT5Config(
+ vocab_size=11,
+ d_model=13,
+ d_kv=7,
+ d_ff=3,
+ num_heads=2,
+ relative_attention_num_buckets=5,
+ relative_attention_max_distance=17,
+ dropout_rate=0.0,
+ layer_norm_epsilon=1e-6,
+ )
+ reference_model = ReferenceT5LayerSelfAttention(
+ reference_config, has_relative_attention_bias=True
+ )
+ reference_model.eval()
+
+ theta = Theta(
+ {
+ "attn_q.weight": DefaultPrimitiveTensor(
+ data=reference_model.SelfAttention.q.weight.data
+ ),
+ "attn_k.weight": DefaultPrimitiveTensor(
+ data=reference_model.SelfAttention.k.weight.data
+ ),
+ "attn_v.weight": DefaultPrimitiveTensor(
+ data=reference_model.SelfAttention.v.weight.data
+ ),
+ "attn_o.weight": DefaultPrimitiveTensor(
+ data=reference_model.SelfAttention.o.weight.data
+ ),
+ "attn_rel_b.weight": DefaultPrimitiveTensor(
+ data=reference_model.SelfAttention.relative_attention_bias.weight.data
+ ),
+ "attn_norm.weight": DefaultPrimitiveTensor(
+ data=reference_model.layer_norm.weight.data
+ ),
+ }
+ )
+ model = T5SelfAttention(
+ theta=theta,
+ is_decoder=reference_config.is_decoder,
+ relative_attention_num_buckets=reference_config.relative_attention_num_buckets,
+ relative_attention_max_distance=reference_config.relative_attention_max_distance,
+ d_model=reference_config.d_model,
+ d_kv=reference_config.d_kv,
+ num_heads=reference_config.num_heads,
+ activation_dtype=dtype,
+ layer_norm_epsilon=reference_config.layer_norm_epsilon,
+ has_relative_attention_bias=True,
+ )
+ model.eval()
+
+ hidden_states = make_rand_torch(
+ shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype
+ )
+ mask = make_random_mask(shape=[batch_size, 1, 1, batch_seq_len], dtype=dtype)
+ position_bias = make_rand_torch(
+ shape=[batch_size, reference_config.num_heads, batch_seq_len, batch_seq_len]
+ )
+ expected_outputs = reference_model(
+ hidden_states=hidden_states,
+ attention_mask=mask,
+ position_bias=position_bias,
+ )
+ actual_outputs = model(
+ hidden_states=DefaultPrimitiveTensor(data=hidden_states),
+ attention_mask=DefaultPrimitiveTensor(data=mask),
+ position_bias=DefaultPrimitiveTensor(data=position_bias),
+ )
+ actual_outputs = [
+ unbox_tensor(t) if t is not None else t for t in actual_outputs
+ ]
+ torch.testing.assert_close(actual_outputs, expected_outputs, atol=1e-5, rtol=0)
+
+
+class T5LayerFFTest(TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.random.manual_seed(12345)
+ torch.no_grad()
+
+ def testCompareAgainstTransformersFp32(self):
+ dtype = torch.float32
+ batch_size = 19
+ batch_seq_len = 23
+ reference_config = ReferenceT5Config(
+ d_model=13,
+ d_ff=3,
+ dropout_rate=0.0,
+ layer_norm_epsilon=1e-6,
+ feed_forward_proj="gated-gelu",
+ )
+
+ reference_model = ReferenceT5LayerFF(reference_config)
+ reference_model.eval()
+
+ theta = Theta(
+ {
+ "ffn_gate.weight": DefaultPrimitiveTensor(
+ data=reference_model.DenseReluDense.wi_0.weight
+ ),
+ "ffn_up.weight": DefaultPrimitiveTensor(
+ data=reference_model.DenseReluDense.wi_1.weight
+ ),
+ "ffn_down.weight": DefaultPrimitiveTensor(
+ data=reference_model.DenseReluDense.wo.weight
+ ),
+ "ffn_norm.weight": DefaultPrimitiveTensor(
+ data=reference_model.layer_norm.weight
+ ),
+ }
+ )
+ model = T5LayerFF(
+ theta=theta,
+ is_gated_act=reference_config.is_gated_act,
+ dense_act_fn=reference_config.dense_act_fn,
+ layer_norm_epsilon=reference_config.layer_norm_epsilon,
+ activation_dtype=dtype,
+ )
+
+ hidden_states = make_rand_torch(
+ shape=[batch_size, batch_seq_len, reference_config.d_model], dtype=dtype
+ )
+
+ expected_output = reference_model(
+ hidden_states=hidden_states,
+ )
+ actual_output = model(
+ hidden_states=DefaultPrimitiveTensor(data=hidden_states),
+ )
+ torch.testing.assert_close(actual_output, expected_output, atol=1e-5, rtol=0)
diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py
index b282d03fe..d303f9f43 100644
--- a/sharktank/tests/ops/ops_test.py
+++ b/sharktank/tests/ops/ops_test.py
@@ -115,7 +115,7 @@ def testMatchFail(self):
):
ops.matmul(1, 2)
- @unittest.skip("https://github.com/nod-ai/sharktank/issues/44")
+ @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44")
def testTorchImplTransposedRHS(self):
ops._registry._test_enable_last_op_dispatch(True)
t1 = torch.rand(32, 16, dtype=torch.float32)
@@ -128,7 +128,7 @@ def testTorchImplTransposedRHS(self):
ops.custom_impls.matmul_mmtfp_tensor_tensor,
)
- @unittest.skip("https://github.com/nod-ai/sharktank/issues/44")
+ @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44")
def testTorchImplNonTransposedRHS(self):
ops._registry._test_enable_last_op_dispatch(True)
t1 = torch.rand(32, 16, dtype=torch.float32)
@@ -141,7 +141,7 @@ def testTorchImplNonTransposedRHS(self):
ops.custom_impls.matmul_mmtfp_tensor_tensor,
)
- @unittest.skip("https://github.com/nod-ai/sharktank/issues/44")
+ @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44")
def testTorchImplTransposedPrimitiveRHS(self):
ops._registry._test_enable_last_op_dispatch(True)
t1 = torch.rand(32, 16, dtype=torch.float32)
@@ -155,6 +155,15 @@ def testTorchImplTransposedPrimitiveRHS(self):
ops.custom_impls.matmul_mmtfp_tensor_tensor,
)
+ def testTorchImplImplicitBatch(self):
+ ops._registry._test_enable_last_op_dispatch(True)
+ t1 = torch.rand(4, 32, 16, dtype=torch.float32)
+ t2 = torch.rand(48, 16, dtype=torch.float16)
+ t2_pt = DefaultPrimitiveTensor(data=t2)
+ result = ops.matmul(t1, t2_pt.T)
+ expected = torch.matmul(t1, t2.T.to(torch.float32))
+ torch.testing.assert_close(result, expected)
+
def testTorchImplTransposedQuantizedRHS_BlockScaledLayout(self):
ops._registry._test_enable_last_op_dispatch(True)
a_dtype = torch.float32
diff --git a/sharktank/tests/ops/qconv_test.py b/sharktank/tests/ops/qconv_test.py
index 4440202eb..97b0efd66 100644
--- a/sharktank/tests/ops/qconv_test.py
+++ b/sharktank/tests/ops/qconv_test.py
@@ -71,7 +71,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
- ops.qconv_impls.qconv2d_tensor_scaled_integer,
+ ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -105,7 +105,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_FloatBias(self):
y_actual = ops.conv2d(input_q, weight_q, bias, stride=1, padding=(1, 1))
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
- ops.qconv_impls.qconv2d_tensor_scaled_integer,
+ ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -147,7 +147,7 @@ def testInputSymPerTensor_WeightAsymPerChannel_QuantizedBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
- ops.qconv_impls.qconv2d_tensor_scaled_integer,
+ ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -184,7 +184,7 @@ def testInputSymPerTensor_WeightSymPerTensor_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
- ops.qconv_impls.qconv2d_tensor_scaled_integer,
+ ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
@@ -224,7 +224,7 @@ def testInputAsymPerChannel_WeightAsymPerChannel_NoBias(self):
)
self.assertIs(
ops._registry._test_get_last_op_dispatch(),
- ops.qconv_impls.qconv2d_tensor_scaled_integer,
+ ops.qconv_impls.qconv2d_tensor_scaled,
)
y_ref = torch.nn.functional.conv2d(
input_q.unpack().dequant(),
diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py
index 39078c5b9..e5efaa948 100644
--- a/sharktank/tests/ops/sharded_test.py
+++ b/sharktank/tests/ops/sharded_test.py
@@ -15,6 +15,40 @@
from sharktank.layers import Conv2DLayer
+class AllGatherTest(unittest.TestCase):
+ def testAllGather(self):
+ shard_count = 3
+ shard_shape = [3, 4]
+ shard_dim = 1
+ shards = [
+ torch.rand(shard_shape, dtype=torch.float32) for i in range(shard_count)
+ ]
+ expected_result = torch.cat(shards, dim=shard_dim)
+
+ sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
+ actual_result = ops.all_gather(sharded)
+
+ for shard in actual_result.shards:
+ torch.testing.assert_close(shard.as_torch(), expected_result)
+
+
+class AllReduceTest(unittest.TestCase):
+ def testAllReduce(self):
+ shard_count = 3
+ shard_shape = [3, 4]
+ shard_dim = 1
+ shards = [
+ torch.rand(shard_shape, dtype=torch.float32) for i in range(shard_count)
+ ]
+ expected_result = torch.add(torch.add(shards[0], shards[1]), shards[2])
+
+ sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
+ actual_result = ops.all_reduce(sharded)
+
+ for shard in actual_result.shards:
+ torch.testing.assert_close(shard.as_torch(), expected_result)
+
+
class CatTest(unittest.TestCase):
def testCatSplitDim(self):
"""Concatenation along the sharded split dimension."""
@@ -50,21 +84,6 @@ def testCatNonSplitDim(self):
class ConvTest(unittest.TestCase):
- def testAllGather(self):
- shard_count = 3
- shard_shape = [3, 4]
- shard_dim = 1
- shards = [
- torch.rand(shard_shape, dtype=torch.float32) for i in range(shard_count)
- ]
- expected_result = torch.cat(shards, dim=shard_dim)
-
- sharded = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
- actual_result = ops.all_gather(sharded)
-
- for shard in actual_result.shards:
- torch.testing.assert_close(shard.as_torch(), expected_result)
-
def testConv2dShardedInputAndOutputChannelsOneGroup(self):
batches = 2
in_channels = 6
@@ -337,6 +356,32 @@ def testNotEqualSharded(self):
assert not ops.equal(b_sharded, a_sharded)
+class FlattenTest(unittest.TestCase):
+ def testReplicated(self):
+ tensor = torch.rand(2, 3, 4, 5)
+ unsharded_expected_result = torch.flatten(tensor, start_dim=1, end_dim=2)
+ expected_result = ops.replicate(unsharded_expected_result, count=2)
+ sharded_tensor = ops.replicate(tensor, count=2)
+ actual_result = ops.flatten(sharded_tensor, start_dim=1, end_dim=2)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorFlattenNonSplitDim(self):
+ tensor = torch.rand(2, 3, 4, 5)
+ unsharded_expected_result = torch.flatten(tensor, start_dim=1, end_dim=2)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=3, count=2)
+ actual_result = ops.flatten(sharded_tensor, start_dim=1, end_dim=2)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorSplitDimIsLeadingFlattenDim(self):
+ tensor = torch.rand(3, 4, 5, 6)
+ unsharded_expected_result = torch.flatten(tensor, start_dim=1, end_dim=2)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=1, count=2)
+ actual_result = ops.flatten(sharded_tensor, start_dim=1, end_dim=2)
+ assert expected_result.is_deep_equal(actual_result)
+
+
class GemmTest(unittest.TestCase):
def testShardedParallelDim(self):
a = torch.rand(4, 3)
@@ -356,6 +401,47 @@ def testShardedParallelDim(self):
torch.testing.assert_close(actual, expected)
+class IndexCopyTest(unittest.TestCase):
+ def testSplitInPlace(self):
+ torch.set_default_dtype(torch.float32)
+ tensor = torch.rand(3, 4, 5, 6)
+ dim = 2
+ source = torch.rand(3, 4, 2, 6)
+ index = torch.tensor([1, 3])
+ expected_result = torch.index_copy(tensor, dim, index, source)
+
+ split_dim = 1
+ shard_count = 2
+ sharded_tensor = ops.reshard_split(tensor, dim=split_dim, count=shard_count)
+ sharded_index = ops.replicate(index, count=shard_count)
+ sharded_source = ops.reshard_split(source, dim=split_dim, count=shard_count)
+ sharded_result = ops.index_copy_(
+ sharded_tensor, dim, sharded_index, sharded_source
+ )
+ assert sharded_tensor is sharded_result
+ actual_result = ops.unshard(sharded_tensor)
+ assert ops.equal(actual_result, expected_result)
+
+
+class IndexPutTest(unittest.TestCase):
+ def testSplitNonIndexDimInPlace(self):
+ torch.set_default_dtype(torch.float32)
+ tensor = torch.rand(3, 4, 5, 6)
+ indices = (
+ torch.tensor([1, 2], dtype=torch.long),
+ torch.tensor([2, 3], dtype=torch.long),
+ )
+ values = torch.rand(2, 5, 6)
+ expected_result = tensor.clone().index_put_(indices, values)
+ shard_count = 2
+ sharded_tensor = ops.reshard_split(tensor.clone(), dim=3, count=shard_count)
+ sharded_values = ops.reshard_split(values, dim=2, count=shard_count)
+ sharded_result = ops.index_put_(sharded_tensor, indices, sharded_values)
+ assert sharded_tensor is sharded_result
+ actual_result = ops.unshard(sharded_tensor)
+ assert ops.equal(actual_result, expected_result)
+
+
class InterpolateTest(unittest.TestCase):
def testInterpolateSplitChannelDim(self):
batches = 2
@@ -502,6 +588,60 @@ def testShardedPrimitiveTensorPermute(self):
assert ops.equal(expected_result, result)
+class AttentionTest(unittest.TestCase):
+ def testAttentionShardedBatch(self):
+ q = torch.rand(4, 32, 16, dtype=torch.float32)
+ k = torch.rand(4, 32, 16, dtype=torch.float32)
+ v = torch.rand(4, 32, 16, dtype=torch.float32)
+
+ qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0))
+ ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0))
+ vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0))
+
+ expected_result = ops.scaled_dot_product_attention(q, k, v, a=None)
+ sharded_result = ops.scaled_dot_product_attention(qs, ks, vs, a=None)
+ unsharded_result = ops.sharded_cat(sharded_result)
+ torch.testing.assert_close(unsharded_result, expected_result)
+
+ def testAttentionShardedBatchCausal(self):
+ q = torch.rand(4, 32, 16, dtype=torch.float32)
+ k = torch.rand(4, 32, 16, dtype=torch.float32)
+ v = torch.rand(4, 32, 16, dtype=torch.float32)
+
+ qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0))
+ ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0))
+ vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0))
+
+ expected_result = ops.scaled_dot_product_attention(
+ q, k, v, a=None, is_causal=True
+ )
+ sharded_result = ops.scaled_dot_product_attention(
+ qs, ks, vs, a=None, is_causal=True
+ )
+ unsharded_result = ops.sharded_cat(sharded_result)
+ torch.testing.assert_close(unsharded_result, expected_result)
+
+ def testAttentionShardedBatchMask(self):
+ q = torch.rand(4, 32, 16, dtype=torch.float32)
+ k = torch.rand(4, 32, 16, dtype=torch.float32)
+ v = torch.rand(4, 32, 16, dtype=torch.float32)
+ a = torch.rand(1, 32, 32, dtype=torch.float32) > 0.5
+
+ q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0))
+ k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0))
+ v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0))
+ a_s = ReplicatedTensor(ts=a, shard_count=4)
+
+ expected_result = ops.scaled_dot_product_attention(
+ q, k, v, a=a, is_causal=False
+ )
+ sharded_result = ops.scaled_dot_product_attention(
+ q_s, k_s, v_s, a=a_s, is_causal=False
+ )
+ unsharded_result = ops.sharded_cat(sharded_result)
+ torch.testing.assert_close(unsharded_result, expected_result)
+
+
class MatmulTest(unittest.TestCase):
def testTorchRHSColumnShardedTransposed(self):
t1 = torch.rand(4, 32, 16, dtype=torch.float32)
@@ -669,6 +809,21 @@ def compute(input, ffn_gate_weight, ffn_down_weight, ffn_up_weight):
)
torch.testing.assert_close(Z_sharded, Z_ref)
+ def testSameSplitLhsAndRhsBatchDim(self):
+ a = torch.rand(3, 4, 5, 6)
+ b = torch.rand(3, 4, 6, 7)
+ shard_count = 2
+ shard_dim = 1
+ expected_result = torch.matmul(a, b)
+ sharded_a = ops.reshard_split(a, dim=shard_dim, count=shard_count)
+ sharded_b = ops.reshard_split(b, dim=shard_dim, count=shard_count)
+ sharded_result = ops.matmul(sharded_a, sharded_b)
+ assert isinstance(sharded_result, SplitPrimitiveTensor)
+ assert sharded_result.shard_count == shard_count
+ assert sharded_result.shard_dim == shard_dim
+ actual_result = unbox_tensor(ops.unshard(sharded_result))
+ torch.testing.assert_close(actual_result, expected_result)
+
class ReplicateTest(unittest.TestCase):
def testReplicateReplicated(self):
@@ -685,6 +840,102 @@ def testReplicateUnsharded(self):
expected_result = ReplicatedTensor(ts=tensor, shard_count=shard_count)
assert expected_result.is_deep_equal(actual_result)
+        # Test that it is a copy.
+ tensor[...] = torch.rand_like(tensor)
+ assert all(not ops.equal(tensor, shard) for shard in actual_result.shards)
+
+
+class ReshapeTest(unittest.TestCase):
+ def testSplitTensorFlattenNonSplitDim(self):
+ tensor = torch.rand(2, 3, 4, 5)
+ new_shape = [2, 12, 5]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=3, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorSplitDimIsLeadingFlattenDim(self):
+ tensor = torch.rand(3, 4, 5, 6)
+ new_shape = [3, 20, 6]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=1, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorInsertSize1DimBeforeSplitDim(self):
+ tensor = torch.rand(4, 5, 6, 7)
+ new_shape = [4, 1, 5, 6, 7]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ shard_dim = 2
+ expected_result = ops.reshard_split(
+ unsharded_expected_result, dim=shard_dim + 1, count=2
+ )
+ sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorInsertMultipleSize1DimsBeforeSplitDim(self):
+ tensor = torch.rand(4, 5, 6, 7)
+ new_shape = [4, 1, 1, 5, 6, 7]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ shard_dim = 2
+ expected_result = ops.reshard_split(
+ unsharded_expected_result, dim=shard_dim + 2, count=2
+ )
+ sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorInsertMultipleSize1TrailingDimsNotRightAfterSplitDim(self):
+ tensor = torch.rand(4, 5, 6, 7)
+ new_shape = [4, 5, 6, 7, 1, 1]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ shard_dim = 2
+ expected_result = ops.reshard_split(
+ unsharded_expected_result, dim=shard_dim, count=2
+ )
+ sharded_tensor = ops.reshard_split(tensor, dim=shard_dim, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorUnflattenNonSplitDim(self):
+ tensor = torch.rand(3, 20, 6)
+ new_shape = [3, 4, 5, 6]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=3, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=2, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorUnflattenTrailingNonSplitDim(self):
+ tensor = torch.rand(3, 4, 30)
+ new_shape = [3, 4, 5, 6]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=1, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorUnflattenSplitDim(self):
+ tensor = torch.rand(3, 20, 6)
+ new_shape = [3, 4, 5, 6]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=1, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=1, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
+ def testSplitTensorUnflattenTrailingSplitDim(self):
+ tensor = torch.rand(2, 3, 20)
+ new_shape = [2, 3, 4, 5]
+ unsharded_expected_result = torch.reshape(tensor, new_shape)
+ expected_result = ops.reshard_split(unsharded_expected_result, dim=2, count=2)
+ sharded_tensor = ops.reshard_split(tensor, dim=2, count=2)
+ actual_result = ops.reshape(sharded_tensor, new_shape)
+ assert expected_result.is_deep_equal(actual_result)
+
class ReshardSplitTest(unittest.TestCase):
def testReshardReplicated(self):
@@ -708,6 +959,11 @@ def testReshardUnsharded(self):
)
assert expected_result.is_deep_equal(actual_result)
+        # Test that it is a copy.
+ tensor[...] = torch.rand_like(tensor)
+ result_split2 = ops.reshard_split(tensor, dim=shard_dim, count=shard_count)
+ assert not ops.equal(actual_result, result_split2)
+
def testReshardSharded(self):
tensor = torch.rand(4, 5, 6, dtype=torch.float32)
shard_dim = 2
diff --git a/shortfin/tests/framework/device_session_test.py b/sharktank/tests/serving_poc/framework/device_session_test.py
similarity index 96%
rename from shortfin/tests/framework/device_session_test.py
rename to sharktank/tests/serving_poc/framework/device_session_test.py
index 7b4916eb4..5dfdd5f46 100644
--- a/shortfin/tests/framework/device_session_test.py
+++ b/sharktank/tests/serving_poc/framework/device_session_test.py
@@ -6,7 +6,7 @@
import pytest
-from shortfin.framework.session import (
+from sharktank.serving_poc.framework.session import (
DeviceSession,
)
diff --git a/shortfin/tests/llm/api_server_test.py b/sharktank/tests/serving_poc/llm/api_server_test.py
similarity index 93%
rename from shortfin/tests/llm/api_server_test.py
rename to sharktank/tests/serving_poc/llm/api_server_test.py
index fe4153dae..c2d2cc36a 100644
--- a/shortfin/tests/llm/api_server_test.py
+++ b/sharktank/tests/serving_poc/llm/api_server_test.py
@@ -39,7 +39,7 @@ def __init__(self, args):
[
sys.executable,
"-m",
- "shortfin.llm.api.rest_server",
+ "sharktank.serving_poc.llm.api.rest_server",
"--testing-mock-service",
"--port=" + port,
]
@@ -77,6 +77,11 @@ def __del__(self):
@pytest.fixture(scope="session")
def server():
+ try:
+ import fastapi
+ import uvicorn
+ except ModuleNotFoundError as e:
+ pytest.skip(f"Skipping server test because deps are missing: {e}")
runner = ServerRunner([])
yield runner
diff --git a/shortfin/tests/llm/service_v1_test.py b/sharktank/tests/serving_poc/llm/service_v1_test.py
similarity index 90%
rename from shortfin/tests/llm/service_v1_test.py
rename to sharktank/tests/serving_poc/llm/service_v1_test.py
index 56472ecc2..c010e2034 100644
--- a/shortfin/tests/llm/service_v1_test.py
+++ b/sharktank/tests/serving_poc/llm/service_v1_test.py
@@ -10,28 +10,28 @@
HalElementType,
)
-from shortfin.framework.session import DeviceSession
-from shortfin.llm.config import (
+from sharktank.serving_poc.framework.session import DeviceSession
+from sharktank.serving_poc.llm.config import (
CacheParams,
ModelParams,
ServiceParams,
)
-from shortfin.llm.service import (
+from sharktank.serving_poc.llm.service import (
GenerateRequest,
GenerateResponsePart,
)
-from shortfin.llm.attn_block_cache import (
+from sharktank.serving_poc.llm.attn_block_cache import (
create_attn_block_cache_module,
AttnBlockCache,
)
-from shortfin.llm.impl.service_v1 import (
+from sharktank.serving_poc.llm.impl.service_v1 import (
GenerateServiceV1,
)
-from shortfin.llm.testing.fake_v1_module import (
+from sharktank.serving_poc.llm.testing.fake_v1_module import (
create_fake_module,
)
diff --git a/sharktank/tests/transforms/dataset_transforms_test.py b/sharktank/tests/transforms/dataset_transforms_test.py
index 4a57b6229..928d1e4ac 100644
--- a/sharktank/tests/transforms/dataset_transforms_test.py
+++ b/sharktank/tests/transforms/dataset_transforms_test.py
@@ -19,8 +19,8 @@
from sharktank.utils.testing import MainRunnerTestBase
-class MmtRHSShardingTransformTest(MainRunnerTestBase):
- def testPrimitive(self):
+class DatasetShardingTransformTest(MainRunnerTestBase):
+ def testShardLlmDataset(self):
orig_pts = [
DefaultPrimitiveTensor(
name="blk.1.attn_k.weight", data=torch.randn([32, 128])
@@ -28,9 +28,19 @@ def testPrimitive(self):
DefaultPrimitiveTensor(
name="blk.2.attn_q.weight", data=torch.randn([48, 64])
),
- DefaultPrimitiveTensor(name="other", data=torch.randn([2, 2])),
]
- ds_orig = Dataset({}, Theta(orig_pts))
+ ds_orig = Dataset(
+ {
+ "general.architecture": "llm",
+ "llm.attention.head_count": 1,
+ "llm.context_length": 2,
+ "llm.embedding_length": 3,
+ "llm.block_count": 4,
+ "llm.feed_forward_length": 5,
+ "llm.attention.layer_norm_rms_epsilon": 0.1,
+ },
+ Theta(orig_pts),
+ )
input_path = self.save_dataset(ds_orig, "input")
output_path = self.get_irpa_path("output")
from sharktank.examples.sharding import shard_llm_dataset
@@ -41,38 +51,38 @@ def testPrimitive(self):
input_path,
"--output-irpa-file",
output_path,
- "--num-shards",
+ "--tensor-parallelism-size",
8,
)
ds_tran = Dataset.load(output_path, mmap=False)
+ ds_tran.properties["tensor_parallelism_size"] = 8
+
# Verify.
flat_sts = ds_tran.root_theta.flatten()
- self.assertEqual(3, len(flat_sts))
+ self.assertEqual(2, len(flat_sts))
st_1 = flat_sts["blk.1.attn_k.weight"]
st_2 = flat_sts["blk.2.attn_q.weight"]
- pt_3 = flat_sts["other"]
self.assertIsInstance(st_1, SplitPrimitiveTensor)
self.assertIsInstance(st_2, SplitPrimitiveTensor)
- self.assertIsInstance(pt_3, DefaultPrimitiveTensor)
self.assertListEqual(st_1.shape, [32, 128])
self.assertListEqual(st_2.shape, [48, 64])
# Verify component shapes for st_1.
self.assertEqual(8, len(st_1.shards))
- self.assertTrue(all(pt.shape == [32, 16] for pt in st_1.shards))
+ self.assertTrue(all(pt.shape == [4, 128] for pt in st_1.shards))
self.assertTrue(
- all(list(pt.as_torch().shape) == [32, 16] for pt in st_1.shards)
+ all(list(pt.as_torch().shape) == [4, 128] for pt in st_1.shards)
)
# Verify component shapes for st_2.
self.assertEqual(8, len(st_2.shards))
- self.assertTrue(all(pt.shape == [48, 8] for pt in st_2.shards))
- self.assertTrue(all(list(pt.as_torch().shape) == [48, 8] for pt in st_2.shards))
+ self.assertTrue(all(pt.shape == [6, 64] for pt in st_2.shards))
+ self.assertTrue(all(list(pt.as_torch().shape) == [6, 64] for pt in st_2.shards))
# Verify contents for one shard for sanity.
new_t = st_1.shards[0].as_torch()
- torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(16, dim=1)[0])
+ torch.testing.assert_close(new_t, orig_pts[0].as_torch().split(4, dim=0)[0])
if __name__ == "__main__":
diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py
index 0d79785f6..4494eab2f 100644
--- a/sharktank/tests/types/dataset_test.py
+++ b/sharktank/tests/types/dataset_test.py
@@ -11,12 +11,12 @@
import torch
-from shark_turbine.aot import ExternalTensorTrait
+from iree.turbine.aot import ExternalTensorTrait
from sharktank.types import *
def _t(name: str, *dims: int):
- return DefaultPrimitiveTensor(name=name, data=torch.empty(*dims))
+ return DefaultPrimitiveTensor(name=name, data=torch.ones(*dims))
def _flat_t_dict(*ts):
@@ -77,6 +77,22 @@ def testTransform(self):
self.assertIsNot(pt1, pt2)
torch.testing.assert_close(pt1, pt2)
+ def testPop(self):
+ t1 = Theta(
+ _flat_t_dict(
+ _t("a.b.c", 1, 2),
+ _t("a.c.d", 10, 11),
+ _t("a.b.3", 3, 4),
+ )
+ )
+ popped = t1.pop("a.b").flatten()
+ t1 = t1.flatten()
+
+        self.assertIn("a.c.d", t1.keys())
+ self.assertNotIn("a.b.c", t1.keys())
+ self.assertNotIn("a.b.3", t1.keys())
+ self.assertIn("a.b.3", popped.keys())
+
class DatasetTest(unittest.TestCase):
def setUp(self):
diff --git a/sharktank/tests/types/quantizers_test.py b/sharktank/tests/types/quantizers_test.py
index 787725e88..b712da06a 100644
--- a/sharktank/tests/types/quantizers_test.py
+++ b/sharktank/tests/types/quantizers_test.py
@@ -9,6 +9,7 @@
import torch
from sharktank.types import *
+from sharktank.types.layout_utils import saturate_cast
from sharktank.utils.testing import TempDirTestBase
@@ -164,6 +165,80 @@ def testQuantDequantf8fnuz(self):
dequant_value = layout.dequant()
torch.testing.assert_close(orig_value, dequant_value, atol=1e-1, rtol=1e-1)
+ def testQuarkF8Hell(self):
+ # we use hardcoded values here because they're representative of actual values from a quark model
+ scale = torch.tensor(0.0118, dtype=torch.float64)
+ orig = torch.tensor(
+ [
+ -58,
+ -48,
+ -70,
+ 53,
+ -53,
+ 76,
+ -71,
+ -90,
+ 50,
+ 77,
+ 62,
+ -98,
+ 66,
+ -54,
+ 55,
+ -80,
+ -66,
+ -62,
+ -61,
+ -56,
+ 56,
+ -67,
+ 79,
+ -60,
+ -71,
+ 42,
+ 72,
+ -73,
+ 91,
+ 63,
+ 124,
+ -128,
+ ],
+ dtype=torch.int8,
+ )
+ # mirrors dequant logic in quark and our importer
+ orig = orig.view(torch.float8_e4m3fn)
+ orig = (orig.to(torch.float64) * scale).to(torch.float16)
+ # Note that for fnuz we have to do scale*2 to account for the difference between types
+ # We specify the reciprocal scale explicitly to avoid adding more floating point error noise
+ fnuz = StaticScaledQuantizer(
+ name="dopoo",
+ scale=1.0 / (scale * 2),
+ reciprocal_scale=scale * 2,
+ offset=None,
+ dtype=torch.float8_e4m3fnuz,
+ )
+ fn = StaticScaledQuantizer(
+ name="poodoo",
+ scale=1.0 / scale,
+ reciprocal_scale=scale,
+ offset=None,
+ dtype=torch.float8_e4m3fn,
+ )
+ fnuz_quant = fnuz.quantize(orig)
+ fn_quant = fn.quantize(orig)
+
+ dequant_fnuz = fnuz_quant.unpack().dequant()
+ dequant_fn = fn_quant.unpack().dequant()
+
+ # redundant asserts for sanity
+ torch.testing.assert_close(
+ orig.to(torch.float16), dequant_fnuz, atol=1e-3, rtol=1e-3
+ )
+ torch.testing.assert_close(
+ orig.to(torch.float16), dequant_fn, atol=1e-3, rtol=1e-3
+ )
+ torch.testing.assert_close(dequant_fnuz, dequant_fn, atol=1e-3, rtol=1e-3)
+
if __name__ == "__main__":
unittest.main()
diff --git a/sharktank/tests/types/tensors_test.py b/sharktank/tests/types/tensors_test.py
index 8ee839b6e..4af4a513f 100644
--- a/sharktank/tests/types/tensors_test.py
+++ b/sharktank/tests/types/tensors_test.py
@@ -62,10 +62,8 @@ def transform2(d):
class ShardedTensorTest(unittest.TestCase):
def testReplicatedTensorSaveLoad(self):
- tensor = torch.rand([2, 3, 4], dtype=torch.float32)
- replicated_tensor = ReplicatedTensor(
- ts=tensor, shard_count=3, name="the_tensor"
- )
+ tensor = [torch.rand([2, 3, 4], dtype=torch.float32)] * 3
+ replicated_tensor = ReplicatedTensor(ts=tensor, name="the_tensor")
theta = Theta([replicated_tensor])
dataset = Dataset({}, theta)
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -140,6 +138,34 @@ def testSplitTensorExtractSliceOfNonSplitDim(self):
actual_result = ops.reshard_like(sharded_slice, expected_result)
assert ops.equal(expected_result, actual_result)
+ def testSplitTensorExtractSliceWithEllipsis(self):
+ tensor = torch.rand([2, 3, 4, 5])
+ sharded_tensor = ops.reshard_split(tensor, dim=2, count=2)
+ expected_result = tensor[0, ..., 1:3]
+ expected_sharded_result = ops.reshard_split(expected_result, dim=1, count=2)
+ actual_sharded_result = sharded_tensor[0, ..., 1:3]
+ assert ops.equal(actual_sharded_result, expected_sharded_result)
+
+ def testSplitTensorInsertSliceOfAllDimsWithEllipsis(self):
+ dst = torch.rand([2, 3, 4])
+ src = torch.rand([2, 3, 4])
+ sharded_dst = ops.reshard_split(dst.clone(), dim=1, count=3)
+ sharded_src = ops.reshard_like(src, like=sharded_dst)
+ dst[...] = src
+ sharded_dst[...] = sharded_src
+ actual_result = ops.unshard(sharded_dst)
+ assert ops.equal(actual_result, dst)
+
+ def testSplitTensorInsertSliceWithEllipsis(self):
+ dst = torch.rand([2, 3, 4, 5])
+ src = torch.rand([3, 4, 2])
+ sharded_dst = ops.reshard_split(dst.clone(), dim=2, count=2)
+ sharded_src = ops.reshard_split(src, dim=1, count=2)
+ dst[0, ..., 1:3] = src
+ sharded_dst[0, ..., 1:3] = sharded_src
+ actual_result = ops.unshard(sharded_dst)
+ assert ops.equal(actual_result, dst)
+
if __name__ == "__main__":
unittest.main()
diff --git a/sharktank/version.json b/sharktank/version.json
new file mode 100644
index 000000000..9519501ae
--- /dev/null
+++ b/sharktank/version.json
@@ -0,0 +1,3 @@
+{
+ "package-version": "3.1.0.dev"
+}
diff --git a/libshortfin/.clang-format b/shortfin/.clang-format
similarity index 100%
rename from libshortfin/.clang-format
rename to shortfin/.clang-format
diff --git a/shortfin/.gitignore b/shortfin/.gitignore
new file mode 100644
index 000000000..000e575d5
--- /dev/null
+++ b/shortfin/.gitignore
@@ -0,0 +1,2 @@
+# Local-only config options
+version_info_rc.json
diff --git a/shortfin/.readthedocs.yaml b/shortfin/.readthedocs.yaml
new file mode 100644
index 000000000..582088587
--- /dev/null
+++ b/shortfin/.readthedocs.yaml
@@ -0,0 +1,16 @@
+version: "2"
+
+build:
+ os: "ubuntu-24.04"
+ tools:
+ python: "3.12"
+ jobs:
+ pre_build:
+ - python -m pip install -v shortfin/
+
+python:
+ install:
+ - requirements: shortfin/docs/requirements.txt
+
+sphinx:
+ configuration: shortfin/docs/conf.py
diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt
new file mode 100644
index 000000000..f025eccfe
--- /dev/null
+++ b/shortfin/CMakeLists.txt
@@ -0,0 +1,268 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+cmake_minimum_required(VERSION 3.29)
+
+if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR)
+ message(
+ FATAL_ERROR
+ "Do not build in-source. Please remove CMakeCache.txt and the CMakeFiles/ directory. Then build out-of-source."
+ )
+endif()
+
+# Get version number from file
+file(READ ${CMAKE_CURRENT_SOURCE_DIR}/version.json VERSION_JSON_STRING)
+string(JSON PACKAGE_VERSION GET ${VERSION_JSON_STRING} package-version)
+string(REGEX MATCH "(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*" BASE_VERSION ${PACKAGE_VERSION})
+
+project(
+ "libshortfin"
+ VERSION ${BASE_VERSION}
+ LANGUAGES C CXX)
+
+include(CMakeDependentOption)
+
+set(SOVERSION 1)
+
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_CXX_STANDARD 20)
+# https://discourse.cmake.org/t/cmake-3-28-cmake-cxx-compiler-clang-scan-deps-notfound-not-found/9244/3
+set(CMAKE_CXX_SCAN_FOR_MODULES 0)
+set(CMAKE_EXPORT_COMPILE_COMMANDS 1)
+
+# Problems with linking libfmt without PIC.
+# Turn on PIC on non windows targets.
+if(NOT WIN32)
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+endif()
+
+# Pins
+set(SHORTFIN_IREE_GIT_TAG "iree-3.0.0rc20241118")
+
+# build options
+option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF)
+option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON)
+option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" ON)
+option(SHORTFIN_ENABLE_TRACING "Enable runtime tracing for iree and shortfin" OFF)
+option(SHORTFIN_ENABLE_LTO "Enables LTO if supported" ON)
+
+set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source")
+
+# Options for building static or dynamic libraries.
+# Default to dynamic linking, unless on Windows.
+# TODO(#211): Unify the defaults once Windows dynamic linking issues are fixed.
+set(SHORTFIN_BUILD_STATIC_DEFAULT OFF)
+set(SHORTFIN_BUILD_DYNAMIC_DEFAULT ON)
+if(WIN32)
+ set(SHORTFIN_BUILD_STATIC_DEFAULT ON)
+ set(SHORTFIN_BUILD_DYNAMIC_DEFAULT OFF)
+endif()
+option(SHORTFIN_BUILD_STATIC "Builds static libraries" ${SHORTFIN_BUILD_STATIC_DEFAULT})
+option(SHORTFIN_BUILD_DYNAMIC "Builds dynamic libraries" ${SHORTFIN_BUILD_DYNAMIC_DEFAULT})
+cmake_dependent_option(SHORTFIN_LINK_DYNAMIC "Links internal binaries against the dynamic libshortfin library" ON "SHORTFIN_BUILD_DYNAMIC" OFF)
+if(NOT SHORTFIN_BUILD_STATIC AND NOT SHORTFIN_BUILD_DYNAMIC)
+ message(FATAL_ERROR "One of SHORTFIN_BUILD_STATIC or SHORTFIN_BUILD_DYNAMIC must be ON")
+endif()
+message(STATUS "Shortfin build static = ${SHORTFIN_BUILD_STATIC}, dynamic = ${SHORTFIN_BUILD_DYNAMIC}")
+if(SHORTFIN_LINK_DYNAMIC)
+ message(STATUS "Dynamic linking to shortfin")
+ set(SHORTFIN_LINK_LIBRARY_NAME "shortfin")
+else()
+ message(STATUS "Static linking to shortfin-static")
+ set(SHORTFIN_LINK_LIBRARY_NAME "shortfin-static")
+endif()
+
+# Includes.
+list(APPEND CMAKE_MODULE_PATH
+ ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/
+)
+include(shortfin_library)
+include(CheckCXXCompilerFlag)
+include(FetchContent)
+
+################################################################################
+# Toolchain features
+################################################################################
+
+if(SHORTFIN_ENABLE_LTO)
+ include(CheckIPOSupported)
+ check_ipo_supported(RESULT SHORTFIN_LTO_SUPPORTED OUTPUT SHORTFIN_LTO_ERROR)
+ if(SHORTFIN_LTO_SUPPORTED)
+ message(STATUS "Shortfin LTO Enabled")
+ set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
+ else()
+ message(WARNING "Could not enable LTO (not supported): ${SHORTFIN_LTO_ERROR}")
+ endif()
+endif()
+
+# Enabling ASAN. Note that this will work best if building in a completely
+# bundled fashion and with an ASAN rigged CPython. Otherwise, various LD_PRELOAD
+# hacks are needed. This is merely a developer convenience: people are more
+# than welcome to set flags themselves.
+option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF)
+if(SHORTFIN_ENABLE_ASAN)
+ add_compile_options(-fsanitize=address)
+ add_link_options(-fsanitize=address)
+
+ # Enable more ASAN checks.
+ add_compile_definitions(IREE_SANITIZER_ADDRESS)
+endif()
+
+# Thread safety annotations: Enabled if the compiler supports it.
+check_cxx_compiler_flag("-Wthread-safety" SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS)
+if(SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS)
+ add_compile_options(-Wthread-safety)
+ add_compile_definitions(SHORTFIN_HAS_THREAD_SAFETY_ANNOTATIONS)
+endif()
+
+option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON)
+message(STATUS "shortfin supported systems:")
+if(SHORTFIN_SYSTEMS_AMDGPU)
+ message(STATUS " - AMD GPU")
+endif()
+message(STATUS " - Host")
+
+################################################################################
+# Dependencies
+################################################################################
+
+if(SHORTFIN_BUNDLE_DEPS)
+ ## fmt
+ FetchContent_Declare(
+ fmt
+ GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+ GIT_TAG e69e5f977d458f2650bb346dadf2ad30c5320281 # 10.2.1 (sync with spdlog)
+ )
+
+ ## spdlog
+  # Use the fmt we fetch above instead of spdlog's bundled copy, since we also use fmt directly.
+ set(SPDLOG_FMT_EXTERNAL ON)
+ FetchContent_Declare(
+ spdlog
+ GIT_REPOSITORY https://github.com/gabime/spdlog.git
+ GIT_TAG 2d4acf8cc321d7783d8f2e22e17a794c6d0e9450 # v1.14.1
+ )
+
+ ## xtl: required for xtensor
+ FetchContent_Declare(
+ xtl
+ GIT_REPOSITORY https://github.com/xtensor-stack/xtl.git
+ GIT_TAG a7c1c5444dfc57f76620391af4c94785ff82c8d6 # v0.7.7
+ )
+
+ ## xtensor
+ FetchContent_Declare(
+ xtensor
+ GIT_REPOSITORY https://github.com/xtensor-stack/xtensor.git
+ GIT_TAG 3634f2ded19e0cf38208c8b86cea9e1d7c8e397d # v0.25.0
+ )
+
+ # In order to bundle libraries without conflict, we have to tweak settings.
+ shortfin_push_bundled_lib_options()
+ # Enable spdlog shared library options so we can export it.
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPDLOG_SHARED_LIB -Dspdlog_EXPORTS")
+ FetchContent_MakeAvailable(fmt spdlog xtl xtensor)
+ shortfin_pop_bundled_lib_options()
+else()
+ find_package(spdlog)
+ find_package(xtensor)
+endif()
+
+################################################################################
+# IREE
+################################################################################
+
+# Set IREE build flags.
+# We currently rely on IREE to have visible symbols in order to re-export
+# its API for further use.
+# TODO: Turn this back on and use explicit visibility control in the IREE
+# runtime and linker scripts.
+set(IREE_VISIBILITY_HIDDEN OFF)
+set(IREE_BUILD_COMPILER OFF)
+set(IREE_BUILD_TESTS OFF)
+set(IREE_BUILD_SAMPLES OFF)
+# Disable missing submodules error because we are only building the runtime.
+set(IREE_ERROR_ON_MISSING_SUBMODULES OFF)
+# Only enable local_sync/local_task/hip drivers for now.
+set(IREE_HAL_DRIVER_DEFAULTS OFF)
+set(IREE_HAL_DRIVER_LOCAL_SYNC ON)
+set(IREE_HAL_DRIVER_LOCAL_TASK ON)
+if(SHORTFIN_SYSTEMS_AMDGPU)
+ set(IREE_HAL_DRIVER_HIP ON)
+endif()
+if (SHORTFIN_ENABLE_TRACING)
+ set(IREE_ENABLE_RUNTIME_TRACING ON)
+ # When using shared libraries there are some issues that need to be
+ # explored more on static initialization order. Something is getting
+ # initialized and is emitting tracy events before tracy objects are
+ # initialized. This can point to some shared library overloading allocation
+ # functions and making them emit tracy events, which are further used in
+ # some static allocation. See https://github.com/wolfpld/tracy/issues/196
+ # for a similar issue and discussion. Using the workaround suggested in
+ # that issue for now. Note that this does not happen when using static
+ # libraries.
+ set(TRACY_DELAYED_INIT ON CACHE BOOL "Enable delayed init for tracy")
+endif()
+
+# In order to bundle libraries without conflict, we have to tweak settings.
+shortfin_push_bundled_lib_options()
+if(SHORTFIN_IREE_SOURCE_DIR)
+ message(STATUS "Using existing IREE sources: ${SHORTFIN_IREE_SOURCE_DIR}")
+ add_subdirectory(${SHORTFIN_IREE_SOURCE_DIR} shortfin_iree SYSTEM EXCLUDE_FROM_ALL)
+else()
+ message(STATUS "Fetching IREE sources from tag ${SHORTFIN_IREE_GIT_TAG}")
+
+ # TODO: We shouldn't have to pull googletest when we are not building tests.
+ # This needs to be fixed with IREE.
+ set(IREE_SUBMODULES "third_party/benchmark third_party/cpuinfo third_party/flatcc third_party/hip-build-deps third_party/googletest")
+ if (SHORTFIN_ENABLE_TRACING)
+ set(IREE_SUBMODULES "${IREE_SUBMODULES} third_party/tracy")
+ endif()
+ FetchContent_Declare(
+ shortfin_iree
+ GIT_REPOSITORY https://github.com/iree-org/iree.git
+ GIT_TAG "${SHORTFIN_IREE_GIT_TAG}"
+ GIT_SUBMODULES ${IREE_SUBMODULES}
+ GIT_SHALLOW TRUE
+ SYSTEM
+ EXCLUDE_FROM_ALL
+ )
+ FetchContent_GetProperties(shortfin_iree)
+ if(NOT shortfin_iree_POPULATED)
+ FetchContent_MakeAvailable(shortfin_iree)
+ endif()
+endif()
+shortfin_pop_bundled_lib_options()
+
+################################################################################
+# Tests
+################################################################################
+
+if(SHORTFIN_BUILD_TESTS)
+ if (NOT SHORTFIN_BUNDLE_DEPS AND NOT SHORTFIN_IREE_SOURCE_DIR)
+ # For now we use gtest shipped alongside with IREE.
+ FetchContent_Declare(
+ googletest
+ URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
+ )
+ # For Windows: Prevent overriding the parent project's compiler/linker settings
+ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+ FetchContent_MakeAvailable(googletest)
+ endif()
+ include(GoogleTest)
+ enable_testing()
+endif()
+
+
+add_subdirectory(src)
+
+if(SHORTFIN_BUILD_PYTHON_BINDINGS)
+ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
+ add_subdirectory(python)
+ set(SHORTFIN_PYTHON_CPP_PREBUILT "TRUE") # See setup.py.
+ configure_file(setup.py setup.py @ONLY)
+ configure_file(pyproject.toml pyproject.toml COPYONLY)
+endif()
diff --git a/shortfin/README.md b/shortfin/README.md
index a76fad8c7..6269ca702 100644
--- a/shortfin/README.md
+++ b/shortfin/README.md
@@ -1,15 +1,215 @@
-# SHARK Shortfin Serving Infrastructure
+# shortfin - SHARK inference library and serving engine
-**WARNING: This is an early preview that is in progress. It is not ready for
-general use.**
+The shortfin project is SHARK's open source, high performance inference library
+and serving engine. Shortfin consists of these major components:
-This sub-project contains components and infrastructure for serving various
-forms of sharktank compiled models. Instead of coming with models, it defines
-ABIs that compiled models should adhere to in order to be served. It then
-allows them to be delivered as web endpoints via popular APIs.
+* The "libshortfin" inference library written in C/C++ and built on
+ [IREE](https://github.com/iree-org/iree)
+* Python bindings for the underlying inference library
+* Example applications in
+ ['shortfin_apps'](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps)
+ built using the python bindings
-As emulation can be the sincerest form of flattery, this project derives
-substantial inspiration from vllm and the OpenAI APIs, emulating and
-interopping with them where possible. It is intended to be the lightest
-weight possible reference implementation for serving models with an
-opinionated compiled form, built elsewhere in the project.
+## Prerequisites
+
+* Python 3.11+
+
+## Simple user installation
+
+Install the latest stable version:
+
+```bash
+pip install shortfin
+```
+
+## Developer guides
+
+### Quick start: install local packages and run tests
+
+After cloning this repository, from the `shortfin/` directory:
+
+```bash
+pip install -e .
+```
+
+Install test requirements:
+
+```bash
+pip install -r requirements-tests.txt
+```
+
+Run tests:
+
+```bash
+pytest -s tests/
+```
+
+### Simple dev setup
+
+We recommend this development setup for core contributors:
+
+1. Check out this repository as a sibling to [IREE](https://github.com/iree-org/iree)
+ if you already have an IREE source checkout. Otherwise, a pinned version will
+ be downloaded for you
+2. Ensure that `python --version` reads 3.11 or higher (3.12 preferred).
+3. Run `./dev_me.py` to build and install the `shortfin` Python package with both
+   a tracing-enabled and default build. Run it again to do an incremental build,
+   or delete the `build/` directory to start over.
+4. Run tests with `python -m pytest -s tests/`
+5. Test optional features:
+ * `pip install iree-base-compiler` to run a small suite of model tests intended
+ to exercise the runtime (or use a [source build of IREE](https://iree.dev/building-from-source/getting-started/#using-the-python-bindings)).
+ * `pip install onnx` to run some more model tests that depend on downloading
+ ONNX models
+ * Run tests on devices other than the CPU with flags like:
+ `--system amdgpu --compile-flags="--iree-hal-target-backends=rocm --iree-hip-target=gfx1100"`
+ * Use the tracy instrumented runtime to collect execution traces:
+ `export SHORTFIN_PY_RUNTIME=tracy`
+
+Refer to the advanced build options below for other scenarios.
+
+### Advanced build options
+
+1. Native C++ build
+2. Local Python release build
+3. Package Python release build
+4. Python dev build
+
+Prerequisites
+
+* A modern C/C++ compiler, such as clang 18 or gcc 12
+* A modern Python, such as Python 3.12
+
+#### Native C++ builds
+
+```bash
+cmake -GNinja -S. -Bbuild \
+ -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
+ -DCMAKE_LINKER_TYPE=LLD
+cmake --build build --target all
+```
+
+If Python bindings are enabled in this mode (`-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON`),
+then `pip install -e build/` will install from the build dir (and supports an
+incremental build/continue workflow).
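+
+For example, a minimal bindings-enabled flow might look like the following
+(the compiler and linker choices simply mirror the native build example above):
+
+```bash
+cmake -GNinja -S. -Bbuild \
+    -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
+    -DCMAKE_LINKER_TYPE=LLD \
+    -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON
+cmake --build build --target all
+pip install -e build/
+```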
+
+#### Package Python release builds
+
+* To build wheels for Linux using a manylinux Docker container:
+
+ ```bash
+ sudo ./build_tools/build_linux_package.sh
+ ```
+
+* To build a wheel for your host OS/arch manually:
+
+ ```bash
+ # Build shortfin.*.whl into the dist/ directory
+ # e.g. `shortfin-0.9-cp312-cp312-linux_x86_64.whl`
+ python3 -m pip wheel -v -w dist .
+
+ # Install the built wheel.
+ python3 -m pip install dist/*.whl
+ ```
+
+#### Python dev builds
+
+```bash
+# Install build system pre-reqs (since we are building in dev mode, this
+# is not done for us). See source of truth in pyproject.toml:
+pip install setuptools wheel
+
+# Optionally install cmake and ninja if you don't have them or need a newer
+# version. If doing heavy development in Python, it is strongly recommended
+# to install these natively on your system as it will make it easier to
+# switch Python interpreters and build options (and the launcher in debug/asan
+# builds of Python is much slower). Note CMakeLists.txt for minimum CMake
+# version, which is usually quite recent.
+pip install cmake ninja
+
+SHORTFIN_DEV_MODE=ON pip install --no-build-isolation -v -e .
+```
+
+Note that the `--no-build-isolation` flag is useful in development setups
+because it avoids creating an intermediate venv, which would keep later
+invocations of cmake/ninja from working at the command line. If just doing
+a one-shot build, it can be omitted.
+
+Once built the first time, `cmake`, `ninja`, and `ctest` commands can be run
+directly from `build/cmake` and changes will apply directly to the next
+process launch.
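+
+For example, a typical iteration loop after the initial editable install might
+look like:
+
+```bash
+cd build/cmake
+ninja   # rebuild the native pieces
+ctest   # run the C++ tests
+```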
+
+Several optional environment variables can be used with setup.py (an example
+invocation follows this list):
+
+* `SHORTFIN_CMAKE_BUILD_TYPE=Debug` : Sets the CMAKE_BUILD_TYPE. Defaults to
+ `Debug` for dev mode and `Release` otherwise.
+* `SHORTFIN_ENABLE_ASAN=ON` : Enables an ASAN build. Requires a Python runtime
+ setup that is ASAN clean (either by env vars to preload libraries or set
+ suppressions or a dev build of Python with ASAN enabled).
+* `SHORTFIN_IREE_SOURCE_DIR=$(pwd)/../../iree`
+* `SHORTFIN_RUN_CTESTS=ON` : Runs `ctest` as part of the build. Useful for CI
+ as it uses the version of ctest installed in the pip venv.
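+
+For example, a Debug dev build against a local IREE checkout might be invoked
+as:
+
+```bash
+SHORTFIN_CMAKE_BUILD_TYPE=Debug \
+SHORTFIN_IREE_SOURCE_DIR=$(pwd)/../../iree \
+SHORTFIN_DEV_MODE=ON pip install --no-build-isolation -v -e .
+```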
+
+### Running tests
+
+The project uses a combination of ctest for native C++ tests and pytest. Much
+of the functionality is only tested via the Python tests, using the
+`_shortfin.lib` internal implementation directly. In order to run these tests,
+you must have installed the Python package as per the above steps.
+
+The choice of test style is pragmatic and geared toward achieving good test
+coverage with a minimum of duplication. Since it is often much more expensive
+to build native tests of complicated flows, many things are only tested via
+Python. This does not preclude having other language bindings later, but it
+does mean that the C++ core of the library must always be built with the
+Python bindings to test most of its behavior. Given the target of the project,
+this is not considered to be a significant issue.
+
+#### Python tests
+
+Run platform independent tests only:
+
+```bash
+pytest tests/
+```
+
+Run tests, including those for a specific platform (in this example, a gfx1100 AMDGPU):
+
+(Note that not all tests are system-aware yet; some may only run on the CPU.)
+
+```bash
+pytest tests/ --system amdgpu \
+ --compile-flags="--iree-hal-target-backends=rocm --iree-hip-target=gfx1100"
+```
+
+## Production library building
+
+In order to build a production library, additional build steps are typically
+recommended (a rough configure sketch follows this list):
+
+* Compile all deps with the same compiler/linker for LTO compatibility
+* Provide library dependencies manually and compile them with LTO
+* Compile dependencies with `-fvisibility=hidden`
+* Enable LTO builds of libshortfin
+* Set flags to enable symbol versioning
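+
+As a rough sketch only (exact flags depend on your toolchain and on how the
+dependencies are provided), a configure covering several of the points above
+might look like the following; `SHORTFIN_ENABLE_LTO` and `SHORTFIN_BUNDLE_DEPS`
+are real options in `CMakeLists.txt`, while the visibility flags are just an
+illustration of the kind of settings to apply when building the dependencies:
+
+```bash
+cmake -GNinja -S. -Bbuild \
+    -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
+    -DCMAKE_LINKER_TYPE=LLD \
+    -DSHORTFIN_ENABLE_LTO=ON \
+    -DSHORTFIN_BUNDLE_DEPS=OFF \
+    -DCMAKE_C_FLAGS="-fvisibility=hidden" \
+    -DCMAKE_CXX_FLAGS="-fvisibility=hidden"
+```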
+
+## Miscellaneous build topics
+
+### Free-threaded Python
+
+Support for free-threaded Python builds (aka. "nogil") is in progress. It
+is currently being tested via CPython 3.13 with the `--disable-gil` option set.
+There are multiple ways to acquire such an environment:
+
+* Generally, see the documentation at
+
+* If using `pyenv`:
+
+ ```bash
+ # Install a free-threaded 3.13 version.
+ pyenv install 3.13t
+
+ # Test (should print "False").
+ pyenv shell 3.13t
+ python -c 'import sys; print(sys._is_gil_enabled())'
+ ```
diff --git a/shortfin/build_tools/build_linux_package.sh b/shortfin/build_tools/build_linux_package.sh
new file mode 100755
index 000000000..9f7388119
--- /dev/null
+++ b/shortfin/build_tools/build_linux_package.sh
@@ -0,0 +1,122 @@
+#!/bin/bash
+# Copyright 2024 Advanced Micro Devices, Inc.
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# build_linux_package.sh
+# One stop build of shortfin Python packages for Linux. The Linux build is
+# complicated because it has to be done via a manylinux docker container.
+#
+# Usage:
+# Build everything (all python versions):
+# sudo ./build_tools/build_linux_package.sh
+#
+# Build specific Python versions to custom directory:
+# OVERRIDE_PYTHON_VERSIONS="cp312-cp312 cp313-cp313" \
+# OUTPUT_DIR="/tmp/wheelhouse" \
+# sudo -E ./build_tools/build_linux_package.sh
+#
+# Valid Python versions match a subdirectory under /opt/python in the docker
+# image. Typically:
+# cp312-cp312 cp313-cp313
+#
+# Note that this script is meant to be run on CI and it will pollute both the
+# output directory and in-tree build/ directories with docker-created,
+# root-owned builds. Sorry - there is no good way around it.
+#
+# It can be run on a workstation, but we recommend using a git worktree
+# dedicated to packaging to avoid stomping on development artifacts.
+set -xeu -o errtrace
+
+THIS_DIR="$(cd $(dirname $0) && pwd)"
+REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)"
+SCRIPT_NAME="$(basename $0)"
+ARCH="$(uname -m)"
+
+# TODO(#130): Update to manylinux_2_28, upstream or a fork
+# * upstream uses a version of gcc that has build warnings/errors
+# * https://github.com/nod-ai/base-docker-images is a bit out of date but can include a recent clang
+# MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux_2_28_${ARCH}:latest}"
+MANYLINUX_DOCKER_IMAGE="${MANYLINUX_DOCKER_IMAGE:-quay.io/pypa/manylinux2014_${ARCH}:latest}"
+PYTHON_VERSIONS="${OVERRIDE_PYTHON_VERSIONS:-cp311-cp311 cp312-cp312 cp313-cp313}"
+OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}"
+
+function run_on_host() {
+ echo "Running on host"
+ echo "Launching docker image ${MANYLINUX_DOCKER_IMAGE}"
+
+ # Canonicalize paths.
+ mkdir -p "${OUTPUT_DIR}"
+ OUTPUT_DIR="$(cd "${OUTPUT_DIR}" && pwd)"
+ echo "Outputting to ${OUTPUT_DIR}"
+ mkdir -p "${OUTPUT_DIR}"
+ docker run --rm \
+ -v "${REPO_ROOT}:${REPO_ROOT}" \
+ -v "${OUTPUT_DIR}:${OUTPUT_DIR}" \
+ -e __MANYLINUX_BUILD_WHEELS_IN_DOCKER=1 \
+ -e "OVERRIDE_PYTHON_VERSIONS=${PYTHON_VERSIONS}" \
+ -e "OUTPUT_DIR=${OUTPUT_DIR}" \
+ "${MANYLINUX_DOCKER_IMAGE}" \
+ -- ${THIS_DIR}/${SCRIPT_NAME}
+
+ echo "******************** BUILD COMPLETE ********************"
+ echo "Generated binaries:"
+ ls -l "${OUTPUT_DIR}"
+}
+
+function run_in_docker() {
+ echo "Running in docker"
+ echo "Marking git safe.directory"
+ git config --global --add safe.directory '*'
+
+ echo "Using python versions: ${PYTHON_VERSIONS}"
+ local orig_path="${PATH}"
+
+ # Build phase.
+ echo "******************** BUILDING PACKAGE ********************"
+ for python_version in ${PYTHON_VERSIONS}; do
+ python_dir="/opt/python/${python_version}"
+ if ! [ -x "${python_dir}/bin/python" ]; then
+ echo "ERROR: Could not find python: ${python_dir} (skipping)"
+ continue
+ fi
+ export PATH="${python_dir}/bin:${orig_path}"
+ echo ":::: Python version $(python --version)"
+ clean_wheels "shortfin" "${python_version}"
+ build_shortfin
+ run_audit_wheel "shortfin" "${python_version}"
+ done
+}
+
+function build_shortfin() {
+ export SHORTFIN_ENABLE_TRACING=ON
+ python -m pip wheel --disable-pip-version-check -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shortfin"
+}
+
+function run_audit_wheel() {
+ local wheel_basename="$1"
+ local python_version="$2"
+ # Force wildcard expansion here
+ generic_wheel="$(echo "${OUTPUT_DIR}/${wheel_basename}-"*"-${python_version}-linux_${ARCH}.whl")"
+ ls "${generic_wheel}"
+ echo ":::: Auditwheel ${generic_wheel}"
+ auditwheel repair -w "${OUTPUT_DIR}" "${generic_wheel}"
+ rm -v "${generic_wheel}"
+}
+
+function clean_wheels() {
+ local wheel_basename="$1"
+ local python_version="$2"
+ echo ":::: Clean wheels ${wheel_basename} ${python_version}"
+ rm -f -v "${OUTPUT_DIR}/${wheel_basename}-"*"-${python_version}-"*".whl"
+}
+
+# Trampoline to the docker container if running on the host.
+if [ -z "${__MANYLINUX_BUILD_WHEELS_IN_DOCKER-}" ]; then
+ run_on_host "$@"
+else
+ run_in_docker "$@"
+fi
diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake
new file mode 100644
index 000000000..aaa97a6c1
--- /dev/null
+++ b/shortfin/build_tools/cmake/shortfin_library.cmake
@@ -0,0 +1,212 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+set(SHORTFIN_DEFAULT_COPTS
+  # General clang and GCC options applicable to C and C++.
+  $<$<OR:$<C_COMPILER_ID:AppleClang,Clang,GNU>,$<CXX_COMPILER_ID:AppleClang,Clang,GNU>>:
+ -Wall
+ -Werror
+ >
+
+ # General MSVC options applicable to C and C++.
+  $<$<OR:$<C_COMPILER_ID:MSVC>,$<CXX_COMPILER_ID:MSVC>>:
+ >
+)
+
+function(shortfin_public_library)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME;LINUX_LD_SCRIPT"
+ "SRCS;COMPONENTS;USAGE_DEPS"
+ ${ARGN}
+ )
+
+  # Get usage requirements from the requested USAGE_DEPS and forward them from
+  # the public library. This is needed because the usage requirements of the
+  # component libraries stop at the public library instead of propagating to
+  # callers, so we must manually assemble the combined usage requirements of
+  # the aggregate library.
+ set(_usage_include_directories)
+ set(_usage_compile_definitions)
+ foreach(_usage_dep _shortfin_defs ${_RULE_USAGE_DEPS})
+ get_target_property(_value ${_usage_dep} INTERFACE_INCLUDE_DIRECTORIES)
+ if(_value)
+ list(APPEND _usage_include_directories ${_value})
+ endif()
+ get_target_property(_value ${_usage_dep} INTERFACE_COMPILE_DEFINITIONS)
+ if(_value)
+ list(APPEND _usage_compile_definitions ${_value})
+ endif()
+ endforeach()
+
+ # Useful for debugging include/definition issues.
+ # message(STATUS "Public library ${_RULE_NAME}: Includes = ${_usage_include_directories}")
+ # message(STATUS "Public library ${_RULE_NAME}: Definitions = ${_usage_compile_definitions}")
+
+ if(SHORTFIN_BUILD_STATIC)
+ # Static library.
+ shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS})
+ add_library("${_RULE_NAME}-static" STATIC ${_RULE_SRCS})
+ target_compile_definitions("${_RULE_NAME}-static" INTERFACE
+ _SHORTFIN_USING_DYLIB
+ ${_usage_compile_definitions}
+ )
+ target_include_directories("${_RULE_NAME}-static" INTERFACE ${_usage_include_directories})
+ target_link_libraries(
+ "${_RULE_NAME}-static"
+ PRIVATE ${_STATIC_COMPONENTS}
+ )
+ endif()
+
+ if(SHORTFIN_BUILD_DYNAMIC)
+ # Dylib library.
+ shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS})
+ add_library("${_RULE_NAME}" SHARED ${_RULE_SRCS})
+ target_compile_definitions("${_RULE_NAME}" INTERFACE
+ _SHORTFIN_USING_DYLIB
+ ${_usage_compile_definitions}
+ )
+ target_include_directories("${_RULE_NAME}" INTERFACE ${_usage_include_directories})
+ if(_RULE_LINUX_LD_SCRIPT)
+ target_link_options("${_RULE_NAME}" PRIVATE
+      "$<$<PLATFORM_ID:Linux>:-Wl,--version-script=${_RULE_LINUX_LD_SCRIPT}>"
+ )
+ endif()
+ target_link_libraries(
+ "${_RULE_NAME}"
+ PRIVATE ${_DYLIB_COMPONENTS}
+ )
+ set_target_properties("${_RULE_NAME}" PROPERTIES
+ VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}
+ SOVERSION ${SOVERSION}
+ )
+ endif()
+endfunction()
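+
+# Illustrative invocation of shortfin_public_library (commented out; the
+# library, source, and component names below are assumptions for demonstration,
+# not targets defined here):
+#
+#   shortfin_public_library(
+#     NAME mylib
+#     SRCS mylib_entry.cc
+#     COMPONENTS mylib_component
+#     USAGE_DEPS some_interface_defs
+#     LINUX_LD_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/mylib.ld
+#   )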
+
+function(shortfin_cc_component)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME"
+ "HDRS;SRCS;DEFINES;DEPS;COMPONENTS"
+ ${ARGN}
+ )
+ if(SHORTFIN_BUILD_STATIC)
+ # Static object library.
+ set(_STATIC_OBJECTS_NAME "${_RULE_NAME}.objects")
+ shortfin_components_to_static_libs(_STATIC_COMPONENTS ${_RULE_COMPONENTS})
+ add_library(${_STATIC_OBJECTS_NAME} OBJECT)
+ target_sources(${_STATIC_OBJECTS_NAME}
+ PRIVATE
+ ${_RULE_SRCS}
+ ${_RULE_HDRS}
+ )
+ target_compile_options(${_STATIC_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS})
+ target_link_libraries(${_STATIC_OBJECTS_NAME}
+ PUBLIC
+ _shortfin_defs
+ ${_STATIC_COMPONENTS}
+ ${_RULE_DEPS}
+ )
+ target_compile_definitions(${_STATIC_OBJECTS_NAME} PUBLIC ${_RULE_DEFINES})
+ endif()
+
+ if(SHORTFIN_BUILD_DYNAMIC)
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+ set(_DYLIB_OBJECTS_NAME "${_RULE_NAME}.dylib.objects")
+ shortfin_components_to_dynamic_libs(_DYLIB_COMPONENTS ${_RULE_COMPONENTS})
+ # Dylib object library.
+ add_library(${_DYLIB_OBJECTS_NAME} OBJECT)
+ target_sources(${_DYLIB_OBJECTS_NAME}
+ PRIVATE
+ ${_RULE_SRCS}
+ ${_RULE_HDRS}
+ )
+ target_compile_options(${_DYLIB_OBJECTS_NAME} PRIVATE ${SHORTFIN_DEFAULT_COPTS})
+ target_link_libraries(${_DYLIB_OBJECTS_NAME}
+ PUBLIC
+ _shortfin_defs
+ ${_DYLIB_COMPONENTS}
+ ${_RULE_DEPS}
+ )
+ set_target_properties(
+ ${_DYLIB_OBJECTS_NAME} PROPERTIES
+ CXX_VISIBILITY_PRESET hidden
+ C_VISIBILITY_PRESET hidden
+ VISIBILITY_INLINES_HIDDEN ON
+ )
+ target_compile_definitions(${_DYLIB_OBJECTS_NAME}
+ PRIVATE
+ _SHORTFIN_BUILDING_DYLIB
+ # Mate up with spdlog export settings since this is part of the
+ # library that is exporting these symbols.
+ SPDLOG_SHARED_LIB
+ spdlog_EXPORTS
+ )
+ target_compile_definitions(${_DYLIB_OBJECTS_NAME} PUBLIC ${_RULE_DEFINES})
+ endif()
+endfunction()
+
+function(shortfin_components_to_static_libs out_static_libs)
+ set(_LIBS ${ARGN})
+ list(TRANSFORM _LIBS APPEND ".objects")
+ set(${out_static_libs} ${_LIBS} PARENT_SCOPE)
+endfunction()
+
+function(shortfin_components_to_dynamic_libs out_dynamic_libs)
+ set(_LIBS ${ARGN})
+ list(TRANSFORM _LIBS APPEND ".dylib.objects")
+ set(${out_dynamic_libs} "${_LIBS}" PARENT_SCOPE)
+endfunction()
+
+function(shortfin_gtest_test)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME"
+ "SRCS;DEPS"
+ ${ARGN}
+ )
+
+ if(NOT SHORTFIN_BUILD_TESTS)
+ return()
+ endif()
+
+ add_executable(${_RULE_NAME} ${_RULE_SRCS})
+ target_link_libraries(${_RULE_NAME} PRIVATE
+ ${_RULE_DEPS}
+ ${SHORTFIN_LINK_LIBRARY_NAME}
+ GTest::gmock
+ GTest::gtest_main
+ )
+ gtest_discover_tests(${_RULE_NAME})
+endfunction()
+
+
+# Make changes to the global compile flags and properties before including
+# bundled deps. This configures various options aimed at making the bundled
+# dependencies private.
+# The effect can be undone with shortfin_pop_bundled_lib_options().
+# After this call, additional changes can be made to CMAKE_CXX_FLAGS as desired.
+macro(shortfin_push_bundled_lib_options)
+ set(SHORTFIN_ORIG_CXX_VISIBILITY_PRESET "${CMAKE_CXX_VISIBILITY_PRESET}")
+ set(SHORTFIN_ORIG_C_VISIBILITY_PRESET "${CMAKE_C_VISIBILITY_PRESET}")
+ set(SHORTFIN_ORIG_VISIBILITY_INLINES_HIDDEN "${CMAKE_VISIBILITY_INLINES_HIDDEN}")
+ set(SHORTFIN_ORIG_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+
+ # Callers get a known state for visibility controls and can make changes from
+ # there.
+ set(CMAKE_C_VISIBILITY_PRESET "default")
+ set(CMAKE_CXX_VISIBILITY_PRESET "default")
+ set(CMAKE_VISIBILITY_INLINES_HIDDEN ON)
+endmacro()
+
+macro(shortfin_pop_bundled_lib_options)
+ set(CMAKE_CXX_VISIBILITY_PRESET ${SHORTFIN_ORIG_CXX_VISIBILITY_PRESET})
+ set(CMAKE_C_VISIBILITY_PRESET ${SHORTFIN_ORIG_C_VISIBILITY_PRESET})
+ set(CMAKE_VISIBILITY_INLINES_HIDDEN ${SHORTFIN_ORIG_VISIBILITY_INLINES_HIDDEN})
+ set(CMAKE_CXX_FLAGS "${SHORTFIN_ORIG_CXX_FLAGS}")
+endmacro()
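+
+# Sketch of the intended usage of the push/pop macros above (the bundled
+# dependency path is an assumption for illustration):
+#
+#   shortfin_push_bundled_lib_options()
+#   # Optional: further tweak CMAKE_CXX_FLAGS for third-party code here.
+#   add_subdirectory(third_party/some_bundled_dep EXCLUDE_FROM_ALL)
+#   shortfin_pop_bundled_lib_options()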
diff --git a/shortfin/build_tools/python_lsan_suppressions.txt b/shortfin/build_tools/python_lsan_suppressions.txt
new file mode 100644
index 000000000..f3ac58064
--- /dev/null
+++ b/shortfin/build_tools/python_lsan_suppressions.txt
@@ -0,0 +1,11 @@
+leak:PyUnicode_New
+leak:_PyUnicodeWriter_PrepareInternal
+leak:_PyUnicodeWriter_Finish
+leak:numpy
+leak:_mlir_libs
+leak:google/_upb
+leak:import_find_and_load
+leak:pyo3::pyclass::create_type_object
+leak:ufunc
+leak:pydantic_core
+leak:sentencepiece
diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py
new file mode 100755
index 000000000..8eacca274
--- /dev/null
+++ b/shortfin/dev_me.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python3
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# dev_me.py
+#
+# This is an opinionated development environment setup procedure aimed at
+# making core contributors on the same golden path. It is not the only way
+# to develop this project.
+#
+# First time build usage:
+# rm -Rf build # Start with a fresh build dir
+# python dev_me.py [--cmake=/path/to/cmake] [--clang=/path/to/clang] \
+# [--iree=/path/to/iree] [--asan] [--build-type=Debug] \
+# [--no-tracing]
+#
+# Subsequent build:
+# ./dev_me.py
+#
+# This will perform an editable install into the used python with both
+# default and tracing packages installed. After the initial build, ninja
+# can be invoked directly under build/cmake/default or build/cmake/tracy.
+# This can be done automatically just by running dev_me.py in a tree with
+# an existing build directory.
+#
+# By default, if there is an iree source dir adjacent to this parent repository,
+# that will be used (so you can just directly edit IREE runtime code and build).
+# Otherwise, the shortfin build will download a pinned IREE source tree.
+
+import argparse
+import importlib
+import os
+from pathlib import Path
+import re
+import subprocess
+import shutil
+import sys
+import sysconfig
+
+try:
+ from packaging.version import Version
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"'packaging' package not installed and required: Install with:\n"
+ f" {sys.executable} -m pip install packaging"
+ )
+
+
+CMAKE_REQUIRED_VERSION = Version("3.29")
+PYTHON_REQUIRED_VERSION = Version("3.12")
+CLANG_REQUIRED_VERSION = Version("16")
+SETUPTOOLS_REQUIRED_VERSION = Version("61.0")
+
+
+class EnvInfo:
+ def __init__(self, args):
+ self.this_dir = Path(__file__).resolve().parent
+ self.python_exe = sys.executable
+        self.python_version = Version(".".join(str(v) for v in sys.version_info[:2]))
+ self.debug = bool(sysconfig.get_config_var("Py_DEBUG"))
+ self.asan = "-fsanitize=address" in sysconfig.get_config_var("PY_LDFLAGS")
+ self.gil_disabled = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
+ self.cmake_exe, self.cmake_version = self.find_cmake(args)
+ self.ninja_exe = shutil.which("ninja")
+ self.clang_exe, self.clang_version = self.find_clang(args)
+ self.iree_dir = self.find_iree(args)
+ self.setuptools_version = self.find_package_version("setuptools")
+ self.wheel_version = self.find_package_version("wheel")
+
+ self.configured_dirs = []
+ self.add_configured(self.this_dir / "build" / "cmake" / "default")
+ self.add_configured(self.this_dir / "build" / "cmake" / "tracy")
+
+ def add_configured(self, path: Path):
+ probe = path / "CMakeCache.txt"
+ if probe.resolve().exists():
+ self.configured_dirs.append(path)
+
+ def find_cmake(self, args):
+ paths = []
+ if args.cmake:
+ paths.append(str(args.cmake))
+ else:
+ default_cmake = shutil.which("cmake")
+ if default_cmake:
+ paths.append(default_cmake)
+ for cmake_path in paths:
+ try:
+ cmake_output = subprocess.check_output(
+ [cmake_path, "--version"]
+ ).decode()
+ except:
+ continue
+ if m := re.search("cmake version (.+)", cmake_output):
+ return cmake_path, Version(m.group(1))
+ return None, None
+
+ def find_clang(self, args):
+ if args.clang:
+ clang_exe = args.clang
+ else:
+ clang_exe = shutil.which("clang")
+ if not clang_exe:
+ return None, None
+ try:
+ clang_output = subprocess.check_output([clang_exe, "--version"]).decode()
+ except:
+ return None, None
+ if m := re.search(r"clang version ([0-9\.]+)", clang_output):
+ return clang_exe, Version(m.group(1))
+ return None, None
+
+ def find_iree(self, args):
+ iree_dir = args.iree
+ if not iree_dir:
+ # See if a sibling iree directory exists.
+ iree_dir = self.this_dir.parent.parent / "iree"
+ if (iree_dir / "CMakeLists.txt").exists():
+ return str(iree_dir)
+ if not iree_dir.exists():
+ print(f"ERROR: --iree={iree_dir} directory does not exist")
+ sys.exit(1)
+ return str(iree_dir)
+
+ def find_package_version(self, package_name: str) -> Version | None:
+ try:
+ m = importlib.import_module(package_name)
+ except ModuleNotFoundError:
+ return None
+ return Version(m.__version__)
+
+ def check_prereqs(self, args):
+ if self.cmake_version is None or self.cmake_version < CMAKE_REQUIRED_VERSION:
+ print(
+ f"ERROR: cmake not found or of an insufficient version: {self.cmake_exe}"
+ )
+ print(f" Required: {CMAKE_REQUIRED_VERSION}, Found: {self.cmake_version}")
+ print(f" Configure explicitly with --cmake=")
+ sys.exit(1)
+ if self.python_version < PYTHON_REQUIRED_VERSION:
+ print(f"ERROR: python version too old: {self.python_exe}")
+ print(
+ f" Required: {PYTHON_REQUIRED_VERSION}, Found: {self.python_version}"
+ )
+ sys.exit(1)
+ if self.clang_exe and self.clang_version < CLANG_REQUIRED_VERSION:
+ print(f"WARNING: clang version too old: {self.clang_exe}")
+ print(f" REQUIRED: {CLANG_REQUIRED_VERSION}, Found {self.clang_version}")
+ elif not self.clang_exe:
+ print(f"WARNING: Building the project with clang is highly recommended")
+ print(f" (pass --clang= to select clang)")
+
+ if args.asan and not self.asan:
+ print(
+ f"ERROR: An ASAN build was requested but python was not built with ASAN support"
+ )
+ sys.exit(1)
+
+ if (
+ self.setuptools_version is None
+ or self.setuptools_version < SETUPTOOLS_REQUIRED_VERSION
+ ):
+ print(
+ f"ERROR: 'setuptools' packaging is not installed or too old. "
+ f"Found {self.setuptools_version}, Need {SETUPTOOLS_REQUIRED_VERSION}"
+ )
+ sys.exit(1)
+ if self.wheel_version is None:
+ print(f"'wheel' package is not installed")
+ sys.exit(1)
+
+ def __repr__(self):
+ report = [
+ f"python: {self.python_exe}",
+ f"debug: {self.debug}",
+ f"asan: {self.asan}",
+ f"gil_disabled: {self.gil_disabled}",
+ f"cmake: {self.cmake_exe} ({self.cmake_version})",
+ f"ninja: {self.ninja_exe}",
+ f"clang: {self.clang_exe} ({self.clang_version})",
+ f"iree: {self.iree_dir}",
+ f"setuptools: {self.setuptools_version}",
+ f"wheel: {self.wheel_version}",
+ ]
+ return "\n".join(report)
+
+
+def main(argv: list[str]):
+ parser = argparse.ArgumentParser(
+ prog="shortfin dev", description="Shortfin dev setup helper"
+ )
+ parser.add_argument("--cmake", type=Path, help="CMake path")
+ parser.add_argument("--clang", type=Path, help="Clang path")
+ parser.add_argument("--iree", type=Path, help="Path to IREE source checkout")
+ parser.add_argument("--asan", action="store_true", help="Build with ASAN support")
+ parser.add_argument(
+ "--no-tracing", action="store_true", help="Disable IREE tracing build"
+ )
+ parser.add_argument(
+ "--build-type", default="Debug", help="CMake build type (default Debug)"
+ )
+ args = parser.parse_args(argv)
+ env_info = EnvInfo(args)
+
+    if env_info.configured_dirs:
+        build_mode(env_info)
+    else:
+        print("First time configure...")
+        configure_mode(env_info, args)
+
+
+def configure_mode(env_info: EnvInfo, args):
+ print("Environment info:")
+ print(env_info)
+ env_info.check_prereqs(args)
+
+ env_vars = {
+ "SHORTFIN_DEV_MODE": "ON",
+ "SHORTFIN_CMAKE_BUILD_TYPE": args.build_type,
+ "SHORTFIN_ENABLE_ASAN": "ON" if args.asan else "OFF",
+ "SHORTFIN_CMAKE": env_info.cmake_exe,
+ }
+ if env_info.iree_dir:
+ env_vars["SHORTFIN_IREE_SOURCE_DIR"] = env_info.iree_dir
+ if env_info.clang_exe:
+ env_vars["CC"] = env_info.clang_exe
+ env_vars["CXX"] = f"{env_info.clang_exe}++"
+ env_vars["CMAKE_LINKER_TYPE"] = "LLD"
+ env_vars["SHORTFIN_ENABLE_TRACING"] = "OFF" if args.no_tracing else "ON"
+
+ print("Executing setup:")
+ setup_args = [
+ env_info.python_exe,
+ "-m",
+ "pip",
+ "install",
+ "--no-build-isolation",
+ "-v",
+ "-e",
+ str(env_info.this_dir),
+ ]
+    print(f"{' '.join(f'{k}={v}' for k, v in env_vars.items())} \\")
+ print(f" {' '.join(setup_args)}")
+ actual_env_vars = dict(os.environ)
+ actual_env_vars.update(env_vars)
+ subprocess.check_call(setup_args, cwd=env_info.this_dir, env=actual_env_vars)
+ print("You are now DEV'd!")
+
+
+def build_mode(env_info: EnvInfo):
+ print("Building")
+ for build_dir in env_info.configured_dirs:
+ subprocess.check_call([env_info.cmake_exe, "--build", str(build_dir)])
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/shortfin/docs/README.md b/shortfin/docs/README.md
new file mode 100644
index 000000000..3d993628a
--- /dev/null
+++ b/shortfin/docs/README.md
@@ -0,0 +1,20 @@
+# Python API Docs
+
+Documentation for the Python API is built with Sphinx under this directory.
+
+## Building docs
+
+The Python modules will be automatically imported if installed or if the build
+is located at `../build`, relative to this file.
+
+### Install dependencies
+
+```shell
+python3 -m pip install -r requirements.txt
+```
+
+### Build the docs
+
+```shell
+sphinx-build -b html . _build
+```
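+
+The generated HTML is written to `_build`. To preview it locally, any static
+file server works; for example (one option among many):
+
+```shell
+python3 -m http.server --directory _build
+```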
diff --git a/shortfin/docs/conf.py b/shortfin/docs/conf.py
new file mode 100644
index 000000000..62a49b026
--- /dev/null
+++ b/shortfin/docs/conf.py
@@ -0,0 +1,43 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+import os
+import sys
+
+try:
+ import _shortfin_default
+except ImportError:
+ sys.path.insert(0, os.path.abspath("../build/python/"))
+ import _shortfin_default
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = "shortfin"
+copyright = "2024, Advanced Micro Devices, Inc"
+author = "shortfin Authors"
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.napoleon",
+]
+
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]
diff --git a/shortfin/docs/index.rst b/shortfin/docs/index.rst
new file mode 100644
index 000000000..3176543ea
--- /dev/null
+++ b/shortfin/docs/index.rst
@@ -0,0 +1,27 @@
+.. Copyright 2024 Advanced Micro Devices, Inc.
+
+.. shortfin documentation master file, created by
+ sphinx-quickstart on Fri Sep 6 16:31:45 2024.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to shortfin's documentation!
+=======================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Reference
+
+ reference
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/shortfin/docs/reference.rst b/shortfin/docs/reference.rst
new file mode 100644
index 000000000..8fc96be22
--- /dev/null
+++ b/shortfin/docs/reference.rst
@@ -0,0 +1,64 @@
+.. Copyright 2024 Advanced Micro Devices, Inc.
+
+.. py:module:: _shortfin_default.lib
+
+.. _reference:
+
+API Reference
+=============
+
+Array
+--------------
+.. automodule:: _shortfin_default.lib.array
+.. autoclass:: DType
+.. autoclass:: storage
+ :members:
+.. autoclass:: base_array
+.. autoclass:: device_array
+ :members:
+.. autoclass:: RandomGenerator
+
+.. autofunction:: _shortfin_default.lib.array.fill_randn
+.. autofunction:: _shortfin_default.lib.array.argmax
+
+Local
+--------------
+
+.. automodule:: _shortfin_default.lib.local
+
+.. autoclass:: SystemBuilder
+.. autoclass:: System
+.. autoclass:: Node
+.. autoclass:: Device
+.. autoclass:: DeviceAffinity
+.. autoclass:: Program
+.. autoclass:: ProgramFunction
+ :members:
+.. autoclass:: ProgramModule
+.. autoclass:: ProgramInvocation
+.. autoclass:: Fiber
+.. autoclass:: ScopedDevice
+.. autoclass:: Worker
+.. autoclass:: Process
+.. autoclass:: CompletionEvent
+.. autoclass:: Message
+.. autoclass:: Queue
+.. autoclass:: QueueWriter
+.. autoclass:: QueueReader
+.. autoclass:: Future
+.. autoclass:: VoidFuture
+.. autoclass:: MessageFuture
+
+
+AMD GPU
+^^^^^^^
+.. automodule:: _shortfin_default.lib.local.amdgpu
+.. autoclass:: SystemBuilder
+ :members:
+.. autoclass:: AMDGPUDevice
+
+Host
+^^^^^^^
+.. automodule:: _shortfin_default.lib.local.host
+.. autoclass:: CPUSystemBuilder
+.. autoclass:: HostCPUDevice
diff --git a/shortfin/docs/requirements.txt b/shortfin/docs/requirements.txt
new file mode 100644
index 000000000..1aef75db6
--- /dev/null
+++ b/shortfin/docs/requirements.txt
@@ -0,0 +1,2 @@
+sphinx==7.4.7
+sphinx_rtd_theme==2.0.0
diff --git a/libshortfin/examples/python/async/basic_asyncio.py b/shortfin/examples/python/async/basic_asyncio.py
similarity index 100%
rename from libshortfin/examples/python/async/basic_asyncio.py
rename to shortfin/examples/python/async/basic_asyncio.py
diff --git a/libshortfin/examples/python/async/device_sync.py b/shortfin/examples/python/async/device_sync.py
similarity index 86%
rename from libshortfin/examples/python/async/device_sync.py
rename to shortfin/examples/python/async/device_sync.py
index b88049ae0..2c1a8248c 100644
--- a/libshortfin/examples/python/async/device_sync.py
+++ b/shortfin/examples/python/async/device_sync.py
@@ -14,7 +14,7 @@
class MyProcess(sf.Process):
async def run(self):
- device = self.scope.device(0)
+ device = self.fiber.device(0)
ary1 = snp.device_array(device, [32, 1, 4], snp.int32)
ary1.storage.fill(array.array("i", [0]))
print(f"[pid:{self.pid}] ARY1:", ary1)
@@ -27,11 +27,11 @@ async def run(self):
async def main():
worker = lsys.create_worker("main")
- scope = lsys.create_scope(worker)
+ fiber = lsys.create_fiber(worker)
print("+++ Launching process")
await asyncio.gather(
- MyProcess(scope=scope).launch(),
- MyProcess(scope=scope).launch(),
+ MyProcess(fiber=fiber).launch(),
+ MyProcess(fiber=fiber).launch(),
)
print("--- Process terminated")
diff --git a/libshortfin/examples/python/async/process.py b/shortfin/examples/python/async/process.py
similarity index 82%
rename from libshortfin/examples/python/async/process.py
rename to shortfin/examples/python/async/process.py
index dcf539516..96db63052 100644
--- a/libshortfin/examples/python/async/process.py
+++ b/shortfin/examples/python/async/process.py
@@ -32,7 +32,7 @@ async def run(self):
processes = []
if self.arg < 10:
await asyncio.sleep(0.1)
- processes.append(MyProcess(self.arg + 1, scope=self.scope).launch())
+ processes.append(MyProcess(self.arg + 1, fiber=self.fiber).launch())
await asyncio.gather(*processes)
print(f"[pid:{self.pid}] Goodbye async:", self.arg, self)
tick_total()
@@ -41,14 +41,14 @@ async def run(self):
async def main():
def create_worker(i):
worker = lsys.create_worker(f"main-{i}")
- return lsys.create_scope(worker)
+ return lsys.create_fiber(worker)
workers = [create_worker(i) for i in range(3)]
processes = []
for i in range(10):
- processes.append(MyProcess(i, scope=workers[i % len(workers)]).launch())
- processes.append(MyProcess(i * 100, scope=workers[i % len(workers)]).launch())
- processes.append(MyProcess(i * 1000, scope=workers[i % len(workers)]).launch())
+ processes.append(MyProcess(i, fiber=workers[i % len(workers)]).launch())
+ processes.append(MyProcess(i * 100, fiber=workers[i % len(workers)]).launch())
+ processes.append(MyProcess(i * 1000, fiber=workers[i % len(workers)]).launch())
await asyncio.sleep(0.1)
print("<>")
diff --git a/libshortfin/examples/python/async/queue.py b/shortfin/examples/python/async/queue.py
similarity index 85%
rename from libshortfin/examples/python/async/queue.py
rename to shortfin/examples/python/async/queue.py
index dabc2c7f0..90f9dc592 100644
--- a/libshortfin/examples/python/async/queue.py
+++ b/shortfin/examples/python/async/queue.py
@@ -55,17 +55,19 @@ async def run(self):
async def main():
- queue = lsys.create_queue("infeed")
- main_scope = lsys.create_scope()
+ queue = lsys.create_queue()
+ main_fiber = lsys.create_fiber()
+ # TODO: Also test named queues.
+ # queue = lsys.create_queue("infeed")
w1 = lsys.create_worker("w1")
- w1_scope = lsys.create_scope(w1)
+ w1_fiber = lsys.create_fiber(w1)
await asyncio.gather(
- WriterProcess(queue, scope=main_scope).launch(),
+ WriterProcess(queue, fiber=main_fiber).launch(),
# By having a reader on the main worker and a separate worker,
# we test both intra and inter worker future resolution, which
# take different paths internally.
- ReaderProcess(queue, scope=main_scope).launch(),
- ReaderProcess(queue, scope=w1_scope).launch(),
+ ReaderProcess(queue, fiber=main_fiber).launch(),
+ ReaderProcess(queue, fiber=w1_fiber).launch(),
)
diff --git a/shortfin/examples/python/enumerate_devices.py b/shortfin/examples/python/enumerate_devices.py
new file mode 100644
index 000000000..2c1cc8203
--- /dev/null
+++ b/shortfin/examples/python/enumerate_devices.py
@@ -0,0 +1,32 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+r"""Simple test program that enumerates devices available.
+
+Run with SystemBuilder keyword args on the command line like::
+
+ python examples/python/enumerate_devices.py \
+ system_type=amdgpu amdgpu_logical_devices_per_physical_device=4
+
+"""
+
+import sys
+
+import shortfin as sf
+
+
+def main():
+ args = [arg.split("=", maxsplit=1) for arg in sys.argv[1:]]
+ arg_dict = {k: v for k, v in args}
+ print(f"Creating system with args: {arg_dict}")
+ builder = sf.SystemBuilder(**arg_dict)
+ with builder.create_system() as ls:
+ for device in ls.devices:
+ print(device)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/libshortfin/examples/python/fastapi/server.py b/shortfin/examples/python/fastapi/server.py
similarity index 97%
rename from libshortfin/examples/python/fastapi/server.py
rename to shortfin/examples/python/fastapi/server.py
index 639cf2d00..8807cf4c6 100644
--- a/libshortfin/examples/python/fastapi/server.py
+++ b/shortfin/examples/python/fastapi/server.py
@@ -67,9 +67,10 @@ async def run(self):
)
await asyncio.sleep(0.01)
responder.stream_part(None)
- except Exception as e:
- responder.close_with_error()
+ except:
traceback.print_exc()
+ finally:
+ responder.ensure_response()
@asynccontextmanager
diff --git a/libshortfin/examples/python/mobilenet_server/.gitignore b/shortfin/examples/python/mobilenet_server/.gitignore
similarity index 100%
rename from libshortfin/examples/python/mobilenet_server/.gitignore
rename to shortfin/examples/python/mobilenet_server/.gitignore
diff --git a/libshortfin/examples/python/mobilenet_server/build_model.sh b/shortfin/examples/python/mobilenet_server/build_model.sh
similarity index 87%
rename from libshortfin/examples/python/mobilenet_server/build_model.sh
rename to shortfin/examples/python/mobilenet_server/build_model.sh
index 6011072c9..26ae2fad2 100755
--- a/libshortfin/examples/python/mobilenet_server/build_model.sh
+++ b/shortfin/examples/python/mobilenet_server/build_model.sh
@@ -39,5 +39,11 @@ echo "Import onnx model"
python -m iree.compiler.tools.import_onnx $onnx_upgrade_path -o $mlir_path
echo "Compile onnx model"
-python -m iree.compiler.tools.scripts.ireec \
- $mlir_path -o "$vmfb_path" --iree-input-type=onnx --iree-hal-target-backends=llvm-cpu
+if [ $# -eq 0 ]; then
+ compile_flags="--iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host"
+else
+ compile_flags="$@"
+fi
+echo "Using compile flags: $compile_flags"
+python -m iree.compiler.tools.scripts.iree_compile \
+ $mlir_path -o "$vmfb_path" --iree-input-type=onnx $compile_flags
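+
+# Example: pass alternate compile flags on the command line, e.g.
+#   ./build_model.sh --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host
+# (These are the same flags used by the default above; substitute others as needed.)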
diff --git a/libshortfin/examples/python/mobilenet_server/inference_system.py b/shortfin/examples/python/mobilenet_server/inference_system.py
similarity index 81%
rename from libshortfin/examples/python/mobilenet_server/inference_system.py
rename to shortfin/examples/python/mobilenet_server/inference_system.py
index 068204f49..cc8b6dd1b 100644
--- a/libshortfin/examples/python/mobilenet_server/inference_system.py
+++ b/shortfin/examples/python/mobilenet_server/inference_system.py
@@ -26,7 +26,7 @@ def __init__(self, program, request_queue, **kwargs):
super().__init__(**kwargs)
self.main_function = program["module.torch-jit-export"]
self.request_reader = request_queue.reader()
- self.device = self.scope.device(0)
+ self.device = self.fiber.device(0)
self.device_input = sfnp.device_array(
self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32
)
@@ -40,15 +40,17 @@ async def run(self):
# just writing to the backing storage is the best we have API
# support for. Generally, APIs on storage should be mirrored onto
# the array.
- self.host_staging.storage.data = request.raw_image_data
+ # TODO: Easier to use API for writing into the storage
+ with self.host_staging.storage.map(write=True, discard=True) as m:
+ m.fill(request.raw_image_data)
print("host_staging =", self.host_staging)
self.device_input.copy_from(self.host_staging)
# Simple call. Note that the await here is merely awaiting the
# result being *available* (i.e. that the VM coroutine has
# completed) but does not indicate that the result is ready.
- (result1,) = await self.main_function(self.device_input)
- (result2,) = await self.main_function(self.device_input)
+ (result1,) = await self.main_function(self.device_input, fiber=self.fiber)
+ (result2,) = await self.main_function(self.device_input, fiber=self.fiber)
# TODO: Implement await on individual results. The accounting is
# there but currently we can only await on the device itself.
@@ -57,20 +59,20 @@ async def run(self):
print("Result 2:", result2)
# Explicit invocation object.
- # inv = self.main_function.invocation(scope=self.scope)
+ # inv = self.main_function.invocation(fiber=self.fiber)
# inv.add_arg(self.device_input)
# results = await inv.invoke()
# print("results:", results)
# Multiple invocations in parallel.
# all_results = await asyncio.gather(
- # self.main_function(self.device_input, scope=self.scope),
- # self.main_function(self.device_input, scope=self.scope),
- # self.main_function(self.device_input, scope=self.scope),
+ # self.main_function(self.device_input, fiber=self.fiber),
+ # self.main_function(self.device_input, fiber=self.fiber),
+ # self.main_function(self.device_input, fiber=self.fiber),
# )
# print("All results:", all_results)
- # output = await self.scope.invoke(self.main_function, self.device_input)
+ # output = await self.fiber.invoke(self.main_function, self.device_input)
# print("OUTPUT:", output)
# read_back = self.device_input.for_transfer()
# read_back.copy_from(self.device_input)
@@ -88,14 +90,14 @@ def __init__(self, lsys: sf.System, home_dir: Path):
print(f"Loaded: {self.program_module}")
self.processes = []
- async def start_scope(self, scope):
+ async def start_fiber(self, fiber):
# Note that currently, program load is synchronous. But we do it
# in a task so we can await it in the future and let program loads
# overlap.
for _ in range(self.processes_per_worker):
- program = sf.Program([self.program_module], scope=scope)
+ program = sf.Program([self.program_module], devices=fiber.raw_devices)
self.processes.append(
- InferenceProcess(program, self.request_queue, scope=scope).launch()
+ InferenceProcess(program, self.request_queue, fiber=fiber).launch()
)
async def main(self):
@@ -104,14 +106,14 @@ async def main(self):
f"System created with {len(devices)} devices:\n "
f"{' '.join(repr(d) for d in devices)}"
)
- # We create a physical worker and initial scope for each device.
+ # We create a physical worker and initial fiber for each device.
# This isn't a hard requirement and there are advantages to other
# topologies.
initializers = []
for device in devices:
worker = self.lsys.create_worker(f"device-{device.name}")
- scope = self.lsys.create_scope(worker, devices=[device])
- initializers.append(self.start_scope(scope))
+ fiber = self.lsys.create_fiber(worker, devices=[device])
+ initializers.append(self.start_fiber(fiber))
# Run all initializers in parallel. These launch inference processes.
print("Waiting for initializers")
@@ -142,7 +144,8 @@ def client():
# Done.
writer.close()
- lsys = sf.host.CPUSystemBuilder().create_system()
+ sf.SystemBuilder.default_system_type = "hostcpu"
+ lsys = sf.SystemBuilder().create_system()
main = Main(lsys, home_dir)
lsys.init_worker.call_threadsafe(client)
lsys.run(main.main())
diff --git a/libshortfin/examples/python/mobilenet_server/upgrade_onnx.py b/shortfin/examples/python/mobilenet_server/upgrade_onnx.py
similarity index 100%
rename from libshortfin/examples/python/mobilenet_server/upgrade_onnx.py
rename to shortfin/examples/python/mobilenet_server/upgrade_onnx.py
diff --git a/shortfin/mypy.ini b/shortfin/mypy.ini
deleted file mode 100644
index d10567407..000000000
--- a/shortfin/mypy.ini
+++ /dev/null
@@ -1,5 +0,0 @@
-[mypy]
-
-explicit_package_bases = True
-mypy_path = $MYPY_CONFIG_FILE_DIR
-packages = shortfin.llm
diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml
index d347ed631..1abb49ef6 100644
--- a/shortfin/pyproject.toml
+++ b/shortfin/pyproject.toml
@@ -1,12 +1,54 @@
[build-system]
-requires = ["setuptools", "wheel"]
+requires = [
+ "cmake>=3.29",
+ "setuptools>=61.0",
+ "wheel",
+ "ninja",
+ 'typing-extensions ; python_version == "3.10" ',
+]
build-backend = "setuptools.build_meta"
+[project]
+name = "shortfin"
+authors = [
+ {name = "SHARK Authors"},
+]
+description = "SHARK inference library and serving engine"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+]
+requires-python = ">= 3.10"
+
+# Version is set via the `setup.py`.
+dynamic = ["version"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/shark-ai"
+Documentation = "https://shortfin.readthedocs.io/en/latest/"
+
+[project.optional-dependencies]
+apps = [
+ "transformers",
+ "dataclasses-json",
+ "pillow",
+ "fastapi",
+ "uvicorn",
+ "aiohttp",
+]
+
[tool.pytest.ini_options]
-addopts = "-ra"
+addopts = [
+ "-ra",
+ "--import-mode=importlib",
+]
testpaths = [
"tests",
]
-pythonpath = [
- ".",
-]
diff --git a/shortfin/python/CMakeLists.txt b/shortfin/python/CMakeLists.txt
new file mode 100644
index 000000000..d125416af
--- /dev/null
+++ b/shortfin/python/CMakeLists.txt
@@ -0,0 +1,79 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# shortfin publishes multiple python packages:
+# - _shortfin: Trampoline __init__.py which looks at environment variables to
+#   load an appropriate native library.
+# - _shortfin_default.lib: Native library as a default, uninstrumented build.
+# - _shortfin_tracing.lib: Native library with tracing enabled.
+# - Others.
+
+# nanobind
+FetchContent_Declare(
+ nanobind
+ GIT_REPOSITORY https://github.com/wjakob/nanobind.git
+ GIT_TAG 784efa2a0358a4dc5432c74f5685ee026e20f2b6 # 2.2.0
+)
+FetchContent_MakeAvailable(nanobind)
+
+nanobind_add_module(shortfin_python_extension
+ NB_STATIC LTO FREE_THREADED
+ array_binding.cc
+ array_host_ops.cc
+ lib_ext.cc
+)
+
+if (SHORTFIN_ENABLE_TRACING)
+ set_target_properties(shortfin_python_extension
+ PROPERTIES OUTPUT_NAME "_shortfin_tracy/lib")
+else()
+ set_target_properties(shortfin_python_extension
+ PROPERTIES OUTPUT_NAME "_shortfin_default/lib")
+endif()
+
+target_link_libraries(shortfin_python_extension
+ PRIVATE ${SHORTFIN_LINK_LIBRARY_NAME}
+)
+
+function(shortfin_python_stubs build_type)
+ nanobind_add_stub(
+ shortfin_python_extension_stub
+ MODULE _shortfin_${build_type}.lib
+ OUTPUT _shortfin_${build_type}/lib.pyi
+ DEPENDS shortfin_python_extension
+ )
+
+endfunction()
+
+function(shortfin_python_stubs build_variant)
+ set(output_root "${CMAKE_CURRENT_BINARY_DIR}/_shortfin_${build_variant}")
+ file(MAKE_DIRECTORY ${output_root})
+ nanobind_add_stub(
+ shortfin_python_extension_stub_lib_${build_variant}
+ MODULE _shortfin_${build_variant}.lib
+ OUTPUT ${output_root}/lib/__init__.pyi
+ DEPENDS shortfin_python_extension
+ )
+
+ nanobind_add_stub(
+ shortfin_python_extension_stub_array_${build_variant}
+ MODULE _shortfin_${build_variant}.lib.array
+ OUTPUT ${output_root}/lib/array.pyi
+ DEPENDS shortfin_python_extension
+ )
+
+ nanobind_add_stub(
+ shortfin_python_extension_stub_local_${build_variant}
+ MODULE _shortfin_${build_variant}.lib.local
+ OUTPUT ${output_root}/lib/local.pyi
+ DEPENDS shortfin_python_extension
+ )
+endfunction()
+
+if (SHORTFIN_ENABLE_TRACING)
+ shortfin_python_stubs(tracy)
+else()
+ shortfin_python_stubs(default)
+endif()
diff --git a/shortfin/python/_shortfin/__init__.py b/shortfin/python/_shortfin/__init__.py
new file mode 100644
index 000000000..9bfa3a497
--- /dev/null
+++ b/shortfin/python/_shortfin/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# The proper way to import this package is via:
+# from _shortfin import lib as sfl
+
+from typing import TYPE_CHECKING
+
+import os
+import sys
+import warnings
+
+if TYPE_CHECKING:
+ from _shortfin_default import lib
+else:
+ variant = os.getenv("SHORTFIN_PY_RUNTIME", "default")
+
+ if variant == "tracy":
+ try:
+ from _shortfin_tracy import lib
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ "Shortfin Tracy runtime requested via SHORTFIN_PY_RUNTIME but it is not enabled in this build"
+ )
+ print("-- Using Tracy runtime (SHORTFIN_PY_RUNTIME=tracy)", file=sys.stderr)
+ else:
+ if variant != "default":
+ warnings.warn(
+ f"Unknown value for SHORTFIN_PY_RUNTIME env var ({variant}): Using default"
+ )
+ from _shortfin_default import lib
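+
+# Example (illustrative): selecting the tracing runtime for a single run from a
+# shell, assuming the package was built with tracing enabled:
+#
+#   SHORTFIN_PY_RUNTIME=tracy python -c "from _shortfin import lib as sfl"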
diff --git a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py b/shortfin/python/_shortfin/asyncio_bridge.py
similarity index 74%
rename from libshortfin/bindings/python/_shortfin/asyncio_bridge.py
rename to shortfin/python/_shortfin/asyncio_bridge.py
index 78aff49a6..28264e9e3 100644
--- a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py
+++ b/shortfin/python/_shortfin/asyncio_bridge.py
@@ -5,10 +5,19 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import asyncio
+import inspect
from . import lib as sfl
+# Feature detect Python versions where signatures changed.
+if "context" in inspect.signature(asyncio.Task).parameters:
+ # Python > 3.10
+ _ASYNCIO_TASK_HAS_CONTEXT = True
+else:
+ _ASYNCIO_TASK_HAS_CONTEXT = False
+
+
class PyWorkerEventLoop(asyncio.AbstractEventLoop):
def __init__(self, worker: sfl.local.Worker):
self._worker = worker
@@ -17,8 +26,15 @@ def get_debug(self):
# Requirement of asyncio.
return False
- def create_task(self, coro):
- return asyncio.Task(coro, loop=self)
+ if _ASYNCIO_TASK_HAS_CONTEXT:
+
+ def create_task(self, coro, *, name=None, context=None):
+ return asyncio.Task(coro, loop=self, name=name, context=context)
+
+ else:
+
+ def create_task(self, coro, *, name=None):
+ return asyncio.Task(coro, loop=self, name=name)
def create_future(self):
return asyncio.Future(loop=self)
@@ -48,6 +64,13 @@ def call_later(
w.delay_call(deadline, handle._sf_maybe_run)
return handle
+ def call_at(self, when, callback, *args, context=None) -> asyncio.TimerHandle:
+ w = self._worker
+ deadline = int(when * 1e9)
+ handle = _TimerHandle(when, callback, args, self, context)
+ w.delay_call(deadline, handle._sf_maybe_run)
+ return handle
+
def call_exception_handler(self, context) -> None:
# TODO: Should route this to the central exception handler. Should
# also play with ergonomics of how the errors get reported in
@@ -62,7 +85,7 @@ def call_exception_handler(self, context) -> None:
def _timer_handle_cancelled(self, handle):
# We don't do anything special: just skip it if it comes up.
- pass
+ ...
class _Handle(asyncio.Handle):
diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc
new file mode 100644
index 000000000..08a4071a8
--- /dev/null
+++ b/shortfin/python/array_binding.cc
@@ -0,0 +1,679 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./lib_ext.h"
+#include "./utils.h"
+#include "shortfin/array/api.h"
+#include "shortfin/support/logging.h"
+
+using namespace shortfin::array;
+
+namespace shortfin::python {
+
+namespace {
+static const char DOCSTRING_ARRAY_COPY_FROM[] =
+ R"(Copy contents from a source array to this array.
+
+Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
+)";
+
+static const char DOCSTRING_ARRAY_COPY_TO[] =
+    R"(Copy contents of this array to a destination array.
+
+Equivalent to `dest_array.storage.copy_from(source_array.storage)`.
+)";
+
+static const char DOCSTRING_ARRAY_FILL[] = R"(Fill an array with a value.
+
+Note that `fill` is asynchronous and its effects may not be visible immediately.
+For immediate manipulation of host visible arrays, assign to the `items` property
+or use `map(discard=True)` to get a mapping object which can be used to directly
+update the contents.
+
+Equivalent to `array.storage.fill(pattern)`.
+)";
+
+static const char DOCSTRING_ARRAY_ITEMS[] =
+ R"(Convenience shorthand for map(...).items)";
+
+static const char DOCSTRING_ARRAY_MAP[] =
+ R"(Create a typed mapping of the buffer contents in host memory.
+
+Supported kwargs:
+
+| read: Enables read access to the mapped memory.
+| write: Enables write access to the mapped memory and will flush upon close
+ (for non-unified memory systems).
+| discard: Indicates that the entire memory map should be treated as if it will
+ be overwritten. Initial contents will be undefined. Implies `write=True`.
+
+Mapping memory for access from the host requires a compatible buffer that has
+been created with host visibility (which includes host buffers).
+
+The returned mapping object is a context manager that will close/flush on
+exit. Alternatively, the `close()` method can be invoked explicitly.
+
+See also `storage.map()` which functions similarly but does not allow access
+to dtype specific functionality.
+)";
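+
+// Python-side sketch of the mapping flow described above (illustrative only;
+// `ary` stands for any host-visible device_array and `values` for a sequence
+// of appropriate length):
+//
+//   with ary.map(write=True, discard=True) as m:
+//       m.items = values  # flushed when the context manager exits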
+
+static const char DOCSTRING_ARRAY_VIEW[] =
+ R"(Create a view of an array.
+
+Either integer indices or slices can be passed to the view() method to create
+an aliased device_array that shares a subset of the storage. Only view()
+organizations that result in a row-major, dense array are currently supported.
+)";
+
+static const char DOCSTRING_MAPPING_FILL[] =
+ R"(Fill the host mapping with a pattern.
+
+The pattern can either be an object implementing the buffer protocol or a Python
+int/float if the mapping has a dtype. In this case, the numeric value will be
+converted to the appropriate typed pattern. Only dtypes supported by the
+array.array class are supported in this fashion.
+
+The pattern must evenly divide the mapping.
+
+Note that like all methods on a mapping, any changes are immediately visible
+(whereas the `fill` method on the array and storage are async operations).
+)";
+
+static const char DOCSTRING_MAPPING_ITEMS[] =
+ R"(Access contents as a Python array.
+
+When reading this attribute, an array.array will be constructed with the
+contents of the mapping. This supports a subset of element types (byte aligned
+integers, floats and doubles) corresponding to Python types.
+
+On write, the mapping will be written with arbitrary Python types marshaled
+via array.array into its contents.
+)";
+
+static const char DOCSTRING_STORAGE_COPY_FROM[] =
+ R"(Copy contents from a source storage to this array.
+
+This operation executes asynchronously and the effect will only be visible
+once the execution fiber has been synced to the point of mutation.
+)";
+
+static const char DOCSTRING_STORAGE_FILL[] = R"(Fill a storage with a value.
+
+Takes as argument any value that can be interpreted as a buffer with the Python
+buffer protocol of size 1, 2, or 4 bytes. The storage will be filled uniformly
+with the pattern.
+
+This operation executes asynchronously and the effect will only be visible
+once the execution fiber has been synced to the point of mutation.
+)";
+
+static const char DOCSTRING_STORAGE_MAP[] =
+ R"(Create a mapping of the buffer contents in host memory.
+
+Supported kwargs:
+
+| read: Enables read access to the mapped memory.
+| write: Enables write access to the mapped memory and will flush upon close
+ (for non-unified memory systems).
+| discard: Indicates that the entire memory map should be treated as if it will
+ be overwritten. Initial contents will be undefined. Implies `write=True`.
+
+Mapping memory for access from the host requires a compatible buffer that has
+been created with host visibility (which includes host buffers).
+
+The returned mapping object is a context manager that will close/flush on
+exit. Alternatively, the `close()` method can be invoked explicitly.
+
+See also `device_array.map()` which functions similarly but allows some
+additional dtype specific accessors.
+)";
+
+device_array PyDeviceArrayView(device_array &array, py::args keys) {
+ size_t rank = array.shape().size();
+ Dims c_offsets(rank, 0);
+ Dims c_sizes(array.shape_container());
+
+ if (keys.size() > rank) {
+ throw std::invalid_argument(
+ "Cannot create view into device_array greater than its rank");
+ }
+
+  size_t idx = 0;
+  for (py::handle key : keys) {
+    if (py::isinstance<py::slice>(key)) {
+      // Slice key.
+      auto slice = py::cast<py::slice>(key);
+      auto [start, stop, step, length] = slice.compute(c_sizes[idx]);
+      if (step != 1) {
+        throw std::logic_error("view does not support strided slices");
+      }
+      c_offsets[idx] = start;
+      c_sizes[idx] = length;
+    } else if (py::isinstance<py::int_>(key)) {
+      // Integer key.
+      c_offsets[idx] = py::cast<size_t>(key);
+      c_sizes[idx] = 1;
+ } else {
+ throw std::invalid_argument(
+ "Args to view must either be integer indices or slices");
+ }
+ idx += 1;
+ }
+
+ return array.view(c_offsets, c_sizes);
+}
+
+class Refs {
+ public:
+ std::unordered_map
+ element_type_array_type_code_table =
+ CreateElementTypeArrayTypeCodeTable();
+ py::object array_array_ctor = py::module_::import_("array").attr("array");
+
+ private:
+ static std::unordered_map
+ CreateElementTypeArrayTypeCodeTable() {
+ std::unordered_map table;
+ // This is really gross. Python's array type codes are pegged to C types,
+ // which do not have portable sizes. We pick portablish things here and
+ // carp on mismatch.
+ auto add_type = [&](DType dt, const char *code, size_t size) {
+ if (dt.dense_byte_count() != size) {
+ throw std::invalid_argument(
+ fmt::format("Illegal native type size for dtype {}, type code {}. "
+ "Native size mismatch: {} vs {}",
+ dt.name(), code, dt.dense_byte_count(), size));
+ }
+ table[dt] = py::str(code);
+ };
+
+ // See table at https://docs.python.org/3/library/array.html
+ add_type(DType::int8(), "b", sizeof(char));
+ add_type(DType::sint8(), "b", sizeof(char));
+ add_type(DType::uint8(), "B", sizeof(unsigned char));
+ add_type(DType::int16(), "h", sizeof(signed short));
+ add_type(DType::sint16(), "h", sizeof(signed short));
+ add_type(DType::uint16(), "H", sizeof(unsigned short));
+ add_type(DType::int32(), "i", sizeof(signed int));
+ add_type(DType::sint32(), "i", sizeof(signed int));
+ add_type(DType::uint32(), "I", sizeof(unsigned int));
+ add_type(DType::int64(), "q", sizeof(signed long long));
+ add_type(DType::sint64(), "q", sizeof(signed long long));
+ add_type(DType::uint64(), "Q", sizeof(unsigned long long));
+ add_type(DType::float16(), "H", sizeof(unsigned short));
+ add_type(DType::float32(), "f", sizeof(float));
+ add_type(DType::float64(), "d", sizeof(double));
+ return table;
+ }
+};
+
+// Holder for a `mapping` class. Also holds additional metadata.
+class PyMapping {
+ public:
+ PyMapping() = default;
+ class mapping &mapping() { return mapping_; }
+  std::optional<DType> dtype() { return dtype_; }
+ void set_dtype(DType dtype) { dtype_ = dtype; }
+
+ void AssertValid() {
+ if (!mapping_) {
+ throw std::logic_error("Mapping has been closed");
+ }
+ }
+
+ void FillFromScalar(Refs *refs, py::handle value) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromScalar");
+ if (!dtype()) {
+ throw std::invalid_argument(
+ "The `fill` method is only valid for typed mappings but "
+ "this does not have a dtype");
+ }
+ auto &table = refs->element_type_array_type_code_table;
+ auto it = table.find(*dtype());
+ if (it == table.end()) {
+ throw std::invalid_argument(
+ fmt::format("Python array.array type code not known for dtype "
+ "{}: Cannot access items",
+ dtype()->name()));
+ }
+ py::object pattern =
+ refs->array_array_ctor(it->second, py::make_tuple(value));
+ FillFromBuffer(pattern);
+ }
+
+ void FillFromBuffer(py::handle buffer) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromBuffer");
+ Py_buffer py_view;
+ int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
+ if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
+ throw py::python_error();
+ }
+ PyBufferReleaser py_view_releaser(py_view);
+ if (mapping().size() % py_view.len != 0) {
+ throw std::invalid_argument(
+ fmt::format("Cannot fill mapping of size {} with pattern of "
+ "size {} (it does not evenly divide)",
+ mapping().size(), py_view.len));
+ }
+
+ // Specialize by fundamental sizes for fast-path implementations.
+    if (py_view.len == 1) {
+      std::memset(mapping().data(), *static_cast<uint8_t *>(py_view.buf),
+                  mapping().size());
+    } else if (py_view.len == 2) {
+      uint16_t v = *static_cast<uint16_t *>(py_view.buf);
+      uint16_t *begin =
+          static_cast<uint16_t *>(static_cast<void *>(mapping().data()));
+      std::fill(begin, begin + mapping().size() / sizeof(v), v);
+    } else if (py_view.len == 4) {
+      uint32_t v = *static_cast<uint32_t *>(py_view.buf);
+      uint32_t *begin =
+          static_cast<uint32_t *>(static_cast<void *>(mapping().data()));
+      std::fill(begin, begin + mapping().size() / sizeof(v), v);
+    } else if (py_view.len == 8) {
+      uint64_t v = *static_cast<uint64_t *>(py_view.buf);
+      uint64_t *begin =
+          static_cast<uint64_t *>(static_cast<void *>(mapping().data()));
+      std::fill(begin, begin + mapping().size() / sizeof(v), v);
+ } else {
+ // Slow path.
+ uint8_t *begin = mapping().data();
+ uint8_t *end = begin + mapping().size();
+ while (begin < end) {
+ std::memcpy(begin, py_view.buf, py_view.len);
+ begin += py_view.len;
+ }
+ }
+ }
+
+ py::object GetItems(py::handle self_obj, Refs *refs) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::GetItems");
+ if (!dtype()) {
+ throw std::invalid_argument(
+ "The `items` property is only valid for typed mappings but "
+ "this does not have a dtype");
+ }
+ AssertValid();
+ auto &table = refs->element_type_array_type_code_table;
+ auto it = table.find(*dtype());
+ if (it == table.end()) {
+ throw std::invalid_argument(
+ fmt::format("Python array.array type code not known for dtype "
+ "{}: Cannot access items",
+ dtype()->name()));
+ }
+ py::object py_bytes = py::steal(PyBytes_FromObject(self_obj.ptr()));
+ py::object items = refs->array_array_ctor(it->second, py_bytes);
+ return items;
+ }
+
+ void SetItems(Refs *refs, py::handle initializer) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::SetItems");
+ if (!dtype()) {
+ throw std::invalid_argument(
+ "The `items` property is only valid for typed mappings but "
+ "this does not have a dtype");
+ }
+ AssertValid();
+ auto &table = refs->element_type_array_type_code_table;
+ auto it = table.find(*dtype());
+ if (it == table.end()) {
+ throw std::invalid_argument(
+ fmt::format("Python array.array type code not known for dtype "
+ "{}: Cannot access items",
+ dtype()->name()));
+ }
+
+ py::object items = refs->array_array_ctor(it->second, initializer);
+ PyBufferRequest src_info(items, PyBUF_SIMPLE);
+ if (src_info.view().len > mapping().size()) {
+ throw std::invalid_argument(
+ fmt::format("Cannot write {} bytes into buffer of {} bytes",
+ src_info.view().len, mapping().size()));
+ }
+ std::memcpy(mapping().data(), src_info.view().buf, src_info.view().len);
+ }
+
+ private:
+ class mapping mapping_;
+  std::optional<DType> dtype_;
+};
+
+// Does in-place creation of a mapping object and stores a pointer to the
+// contained array::mapping C++ object.
+py::object CreateMappingObject(PyMapping **out_cpp_mapping) {
+  py::object py_mapping = py::inst_alloc(py::type<PyMapping>());
+  PyMapping *cpp_mapping = py::inst_ptr<PyMapping>(py_mapping);
+  new (cpp_mapping) PyMapping();
+ py::inst_mark_ready(py_mapping);
+ *out_cpp_mapping = cpp_mapping;
+ return py_mapping;
+}
+
+} // namespace
+
+void BindArray(py::module_ &m) {
+  auto refs = std::make_shared<Refs>();
+
+  py::class_<DType>(m, "DType")
+ .def_prop_ro("name", &DType::name)
+ .def_prop_ro("is_boolean", &DType::is_boolean)
+ .def_prop_ro("is_integer", &DType::is_integer)
+ .def_prop_ro("is_float", &DType::is_float)
+ .def_prop_ro("is_complex", &DType::is_complex)
+ .def_prop_ro("bit_count", &DType::bit_count)
+ .def_prop_ro("is_byte_aligned", &DType::is_byte_aligned)
+ .def_prop_ro("dense_byte_count", &DType::dense_byte_count)
+ .def("is_integer_bitwidth", &DType::is_integer_bitwidth)
+ .def("compute_dense_nd_size", &DType::compute_dense_nd_size)
+ .def(py::self == py::self)
+ .def("__repr__", &DType::name);
+
+#define SHORTFIN_DTYPE_HANDLE(et, ident) m.attr(#ident) = DType::ident();
+#include "shortfin/array/dtypes.inl"
+#undef SHORTFIN_DTYPE_HANDLE
+
+ // storage
+  py::class_<storage>(m, "storage")
+      .def("__sfinv_marshal__",
+           [](class storage *self, py::capsule inv_capsule, int barrier) {
+             auto *inv =
+                 static_cast<local::ProgramInvocation *>(inv_capsule.data());
+             static_cast<local::ProgramInvocationMarshalable *>(self)
+                 ->AddAsInvocationArgument(
+                     inv, static_cast<local::ProgramResourceBarrier>(barrier));
+           })
+ .def_static(
+ "allocate_host",
+ [](local::ScopedDevice &device, iree_device_size_t allocation_size) {
+ return storage::allocate_host(device, allocation_size);
+ },
+ py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
+ .def_static(
+ "allocate_device",
+ [](local::ScopedDevice &device, iree_device_size_t allocation_size) {
+ return storage::allocate_device(device, allocation_size);
+ },
+ py::arg("device"), py::arg("allocation_size"), py::keep_alive<0, 1>())
+ .def(
+ "fill",
+ [](storage &self, py::handle buffer) {
+ Py_buffer py_view;
+ int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
+ if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
+ throw py::python_error();
+ }
+ PyBufferReleaser py_view_releaser(py_view);
+ self.fill(py_view.buf, py_view.len);
+ },
+ py::arg("pattern"), DOCSTRING_STORAGE_FILL)
+ .def(
+ "copy_from", [](storage &self, storage &src) { self.copy_from(src); },
+ py::arg("source_storage"), DOCSTRING_STORAGE_COPY_FROM)
+ .def(
+ "map",
+ [](storage &self, bool read, bool write, bool discard) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyStorage::map");
+ int access = 0;
+ if (read) access |= IREE_HAL_MEMORY_ACCESS_READ;
+ if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE;
+ if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD;
+ if (!access) {
+ throw std::invalid_argument(
+ "One of the access flags must be set");
+ }
+ PyMapping *cpp_mapping = nullptr;
+ py::object py_mapping = CreateMappingObject(&cpp_mapping);
+ self.map_explicit(
+ cpp_mapping->mapping(),
+                static_cast<iree_hal_memory_access_t>(access));
+ return py_mapping;
+ },
+ py::kw_only(), py::arg("read") = false, py::arg("write") = false,
+ py::arg("discard") = false, DOCSTRING_STORAGE_MAP)
+ .def(py::self == py::self)
+ .def("__len__", &storage::byte_length)
+ .def("__repr__", &storage::to_s);
+
+ // mapping
+  auto mapping_class = py::class_<PyMapping>(m, "mapping");
+ mapping_class.def("close", [](PyMapping &self) { self.mapping().reset(); })
+ .def_prop_ro("valid",
+ [](PyMapping &self) -> bool { return self.mapping(); })
+ .def("__enter__", [](py::object self_obj) { return self_obj; })
+ .def(
+ "__exit__",
+ [](PyMapping &self, py::handle exc_type, py::handle exc_value,
+ py::handle exc_tb) { self.mapping().reset(); },
+ py::arg("exc_type").none(), py::arg("exc_value").none(),
+ py::arg("exc_tb").none())
+ .def(
+ "fill",
+ [refs](PyMapping &self, py::int_ value) {
+ self.AssertValid();
+ self.FillFromScalar(refs.get(), value);
+ },
+ py::arg("value"), DOCSTRING_MAPPING_FILL)
+ .def(
+ "fill",
+ [refs](PyMapping &self, py::float_ value) {
+ self.AssertValid();
+ self.FillFromScalar(refs.get(), value);
+ },
+ py::arg("value"), DOCSTRING_MAPPING_FILL)
+ .def(
+ "fill",
+ [](PyMapping &self, py::handle buffer) {
+ self.AssertValid();
+ self.FillFromBuffer(buffer);
+ },
+ py::arg("buffer"), DOCSTRING_MAPPING_FILL)
+ .def_prop_rw(
+ "items",
+ [refs](py::handle self_obj) {
+            PyMapping &self = py::cast<PyMapping &>(self_obj);
+ return self.GetItems(self_obj, refs.get());
+ },
+ [refs](PyMapping &self, py::handle initializer) {
+ self.SetItems(refs.get(), initializer);
+ },
+ DOCSTRING_MAPPING_ITEMS);
+
+ struct MappingBufferHandler {
+ int operator()(PyMapping &self, Py_buffer *view, int flags) {
+ view->buf = self.mapping().data();
+ view->len = self.mapping().size();
+ view->readonly = !self.mapping().writable();
+ view->itemsize = 1;
+ view->format = (char *)"B"; // Byte
+ view->ndim = 1;
+ view->shape = nullptr;
+ view->strides = nullptr;
+ view->suboffsets = nullptr;
+ view->internal = nullptr;
+ return 0;
+ }
+ };
+  BindBufferProtocol<PyMapping, MappingBufferHandler>(mapping_class);
+
+ // base_array and subclasses
+  py::class_<base_array>(m, "base_array")
+ .def_prop_ro("dtype", &base_array::dtype)
+ .def_prop_ro("shape", &base_array::shape);
+
+  py::class_<device_array, base_array>(m, "device_array")
+ .def("__init__", [](py::args, py::kwargs) {})
+ .def_static(
+ "__new__",
+ [](py::handle py_type, class storage storage,
+             std::span<const size_t> shape, DType dtype) {
+            return custom_new_keep_alive<device_array>(
+ py_type, /*keep_alive=*/storage.fiber(), storage, shape, dtype);
+ },
+ py::arg("cls"), py::arg("storage"), py::arg("shape"),
+ py::arg("dtype"))
+ .def_static(
+ "__new__",
+ [](py::handle py_type, local::ScopedDevice &device,
+             std::span<const size_t> shape, DType dtype) {
+            return custom_new_keep_alive<device_array>(
+ py_type, /*keep_alive=*/device.fiber(),
+ device_array::for_device(device, shape, dtype));
+ },
+ py::arg("cls"), py::arg("device"), py::arg("shape"), py::arg("dtype"))
+ .def("__sfinv_marshal__",
+ [](device_array *self, py::capsule inv_capsule, int barrier) {
+ auto *inv =
+             auto *inv =
+                 static_cast<local::ProgramInvocation *>(inv_capsule.data());
+             static_cast<local::ProgramInvocationMarshalable *>(self)
+                 ->AddAsInvocationArgument(
+                     inv, static_cast<local::ProgramResourceBarrier>(barrier));
+ })
+ .def_static(
+ "for_device",
+          [](local::ScopedDevice &device, std::span<const size_t> shape,
+             DType dtype) {
+            return custom_new_keep_alive<device_array>(
+                py::type<device_array>(),
+ /*keep_alive=*/device.fiber(),
+ device_array::for_device(device, shape, dtype));
+ },
+ py::arg("device"), py::arg("shape"), py::arg("dtype"))
+ .def_static(
+ "for_host",
+          [](local::ScopedDevice &device, std::span<const size_t> shape,
+             DType dtype) {
+            return custom_new_keep_alive<device_array>(
+                py::type<device_array>(),
+ /*keep_alive=*/device.fiber(),
+ device_array::for_host(device, shape, dtype));
+ },
+ py::arg("device"), py::arg("shape"), py::arg("dtype"))
+ .def("for_transfer",
+ [](device_array &self) {
+             return custom_new_keep_alive<device_array>(
+                 py::type<device_array>(),
+ /*keep_alive=*/self.device().fiber(), self.for_transfer());
+ })
+ .def_prop_ro("device", &device_array::device,
+ py::rv_policy::reference_internal)
+ .def_prop_ro("storage", &device_array::storage,
+ py::rv_policy::reference_internal)
+ .def(
+ "fill",
+          [](py::handle_t<device_array> self, py::handle buffer) {
+ self.attr("storage").attr("fill")(buffer);
+ },
+ py::arg("pattern"), DOCSTRING_ARRAY_FILL)
+ .def("copy_from", &device_array::copy_from, py::arg("source_array"),
+ DOCSTRING_ARRAY_COPY_FROM)
+ .def("copy_to", &device_array::copy_to, py::arg("dest_array"),
+ DOCSTRING_ARRAY_COPY_TO)
+ .def("view", PyDeviceArrayView, DOCSTRING_ARRAY_VIEW)
+ .def(
+ "map",
+ [](device_array &self, bool read, bool write, bool discard) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyArray::map");
+ int access = 0;
+ if (read) access |= IREE_HAL_MEMORY_ACCESS_READ;
+ if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE;
+ if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD;
+ if (!access) {
+ throw std::invalid_argument(
+ "One of the access flags must be set");
+ }
+ PyMapping *cpp_mapping = nullptr;
+ py::object py_mapping = CreateMappingObject(&cpp_mapping);
+ cpp_mapping->set_dtype(self.dtype());
+ self.storage().map_explicit(
+ cpp_mapping->mapping(),
+                static_cast<iree_hal_memory_access_t>(access));
+ return py_mapping;
+ },
+ py::kw_only(), py::arg("read") = false, py::arg("write") = false,
+ py::arg("discard") = false, DOCSTRING_ARRAY_MAP)
+ .def_prop_rw(
+ "items",
+ [refs](device_array &self) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyArray::items");
+ PyMapping *mapping;
+ py::object mapping_obj = CreateMappingObject(&mapping);
+ mapping->set_dtype(self.dtype());
+ self.storage().map_explicit(
+                mapping->mapping(), static_cast<iree_hal_memory_access_t>(
+                                        IREE_HAL_MEMORY_ACCESS_READ));
+ return mapping->GetItems(mapping_obj, refs.get());
+ },
+ [refs](device_array &self, py::handle initializer) {
+ PyMapping mapping;
+ mapping.set_dtype(self.dtype());
+ self.storage().map_explicit(
+                mapping.mapping(), static_cast<iree_hal_memory_access_t>(
+                                       IREE_HAL_MEMORY_ACCESS_READ));
+ return mapping.SetItems(refs.get(), initializer);
+ },
+ DOCSTRING_ARRAY_ITEMS)
+ .def_prop_ro(
+ "__array_interface__",
+ [refs](device_array &self) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyArray::__array_interface__");
+ py::dict interface;
+ interface["version"] = 3;
+ interface["strides"] = py::none();
+
+ auto shape = self.shape();
+ py::list shapeList;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ shapeList.append(shape[i]);
+ }
+ py::tuple shape_tuple(py::cast(shapeList));
+ interface["shape"] = shape_tuple;
+
+ auto &table = refs->element_type_array_type_code_table;
+ auto it = table.find(self.dtype());
+ if (it == table.end()) {
+ throw std::invalid_argument(fmt::format(
+ "Python array.array type code not known for dtype "
+ "{}: Cannot access items",
+ self.dtype().name()));
+ }
+
+ auto typeString = py::str();
+ if (it->first == DType::float16()) {
+ typeString = py::str("float16");
+ } else {
+ typeString = py::str(it->second);
+ }
+ interface["typestr"] = typeString;
+
+ PyMapping *mapping;
+ py::object mapping_obj = CreateMappingObject(&mapping);
+ mapping->set_dtype(self.dtype());
+ self.storage().map_explicit(
+                mapping->mapping(), static_cast<iree_hal_memory_access_t>(
+                                        IREE_HAL_MEMORY_ACCESS_READ));
+ auto items = mapping->GetItems(mapping_obj, refs.get());
+
+ // Obtain pointer to first element in items
+ Py_buffer buffer;
+ if (PyObject_GetBuffer(items.ptr(), &buffer, PyBUF_SIMPLE) != 0) {
+ throw std::runtime_error("Failed to get buffer from items");
+ }
+ void *itemsPtr = buffer.buf;
+ interface["data"] =
+                py::make_tuple(reinterpret_cast<uintptr_t>(itemsPtr), false);
+ return interface;
+ })
+ .def("__repr__", &device_array::to_s)
+ .def("__str__", [](device_array &self) -> std::string {
+ auto contents = self.contents_to_s();
+ if (!contents) return "<>";
+ return *contents;
+ });
+
+ BindArrayHostOps(m);
+}
+
+} // namespace shortfin::python
diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc
new file mode 100644
index 000000000..3e2a8ebe3
--- /dev/null
+++ b/shortfin/python/array_host_ops.cc
@@ -0,0 +1,747 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "./lib_ext.h"
+#include "./utils.h"
+#include "shortfin/array/api.h"
+#include "shortfin/support/logging.h"
+#include "xtensor/xrandom.hpp"
+#include "xtensor/xsort.hpp"
+#include "xtl/xhalf_float.hpp"
+
+using namespace shortfin::array;
+
+namespace shortfin::python {
+
+namespace {
+
+static const char DOCSTRING_ARGMAX[] =
+ R"(Returns the indices of the maximum values along an axis.
+
+Implemented for dtypes: float16, float32.
+
+Args:
+ input: An input array.
+  axis: Axis along which to compute the argmax. Defaults to the last axis (note
+    that the numpy default is to reduce over the flattened array, which we do
+    not support).
+  keepdims: Whether to preserve the reduced axis. If true, it will become a unit
+    dim. If false, it will be removed.
+  out: Array to write into. If specified, it must have the expected shape and
+    an int64 dtype.
+ device_visible: Whether to make the result array visible to devices. Defaults to
+ False.
+
+Returns:
+  A device_array of dtype=int64, allocated on the host and only made
+  device-visible if requested.
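+
+Example (an illustrative sketch; assumes this module is imported as `sfnp` and
+`logits` is an existing host-mappable float32 device_array)::
+
+  # Reduce along the last axis, keeping it as a unit dim of dtype int64.
+  indices = sfnp.argmax(logits, axis=-1, keepdims=True)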
+)";
+
+static const char DOCSTRING_CONVERT[] =
+ R"(Does an elementwise conversion from one dtype to another.
+
+The same behavior exists for several conversion ops:
+
+* `convert` : element-wise conversion like a static cast.
+* `round` : element-wise nearest integer to the input, rounding halfway cases
+ away from zero.
+* `ceil` : element-wise smallest integer value not less than the input.
+* `floor` : element-wise largest integer value not greater than the input.
+* `trunc` : element-wise nearest integer not greater in magnitude than the input.
+
+For nearest-integer conversions (round, ceil, floor, trunc), the input must be
+an array of floating point dtype, and the output must be a byte-aligned integer
+type between 8 and 32 bits.
+
+Args:
+ input: An input array of a floating point dtype.
+ dtype: If given, then this is the explicit output dtype.
+ out: If given, then the results are written to this array. This implies the
+ output dtype.
+ device_visible: Whether to make the result array visible to devices. Defaults to
+ False.
+
+Returns:
+ A device_array of the requested dtype, or the input dtype if not specified.
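+
+Example (an illustrative sketch; assumes this module is imported as `sfnp` and
+`arr` is an existing host-mappable float32 device_array)::
+
+  as_f16 = sfnp.convert(arr, dtype=sfnp.float16)
+  as_i32 = sfnp.round(arr, dtype=sfnp.int32)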
+)";
+
+static const char DOCSTRING_FILL_RANDN[] =
+    R"(Fills an array with numbers sampled from the standard normal distribution.
+
+Values are sampled with a mean of 0 and a standard deviation of 1.
+
+This operates like torch.randn but only supports in place fills to an existing
+array, deriving shape and dtype from the output array.
+
+Args:
+ out: Output array to fill.
+ generator: Uses an explicit generator. If not specified, uses a global
+ default.
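+
+Example (an illustrative sketch; assumes this module is imported as `sfnp` and
+`out` is an existing host-mappable float32 device_array)::
+
+  gen = sfnp.RandomGenerator(seed=42)
+  sfnp.fill_randn(out, generator=gen)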
+)";
+
+static const char DOCSTRING_RANDOM_GENERATOR[] =
+ R"(Returns an object for generating random numbers.
+
+  Every instance is self-contained and does not share state with others.
+
+ Args:
+ seed: Optional seed for the generator. Not setting a seed will cause an
+ implementation defined value to be used, which may in fact be a completely
+ fixed number.
+ )";
+
+static const char DOCSTRING_TRANSPOSE[] =
+ R"(Transposes axes of an array according to a permutation vector.
+
+Args:
+ input: Array to transpose.
+ permutation: New sequence of axes. Must have same number of elements as the
+ rank of input.
+ out: If given, then the results are written to this array.
+ device_visible: Whether to make the result array visible to devices. Defaults
+ to False.
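+
+Example (an illustrative sketch; assumes this module is imported as `sfnp` and
+`arr` is an existing host-mappable array of shape [2, 3, 4])::
+
+  # Swap the last two axes, producing an array of shape [2, 4, 3].
+  swapped = sfnp.transpose(arr, permutation=[0, 2, 1])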
+)";
+
+#define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \
+  case DType::dtype_name():                          \
+    return compute.template operator()<cpp_type>()
+
+#define SF_UNARY_THUNK_CASE(dtype_name, cpp_type) \
+  case DType::dtype_name():                       \
+    compute.template operator()<cpp_type>();      \
+    break
+
+#define SF_MOVEMENT_OP_SWITCH(dtype) \
+ if (!dtype.is_byte_aligned()) \
+ throw std::invalid_argument( \
+ "data movement ops are only defined for byte aligned dtypes"); \
+ switch (dtype.dense_byte_count()) { \
+    case 1:                                              \
+      return compute.template operator()<uint8_t>();     \
+    case 2:                                              \
+      return compute.template operator()<uint16_t>();    \
+    case 4:                                              \
+      return compute.template operator()<uint32_t>();    \
+    case 8:                                              \
+      return compute.template operator()<uint64_t>();    \
+ default: \
+ throw std::invalid_argument( \
+ "data movement ops are only defined for dtypes of size 1, 2, " \
+ "4, 8"); \
+ }
+
+struct PyRandomGenerator {
+ public:
+ using SeedType = xt::random::default_engine_type::result_type;
+  PyRandomGenerator(std::optional<SeedType> seed) {
+ if (seed) SetSeed(*seed);
+ }
+
+ static PyRandomGenerator &get_default() {
+ static PyRandomGenerator default_generator(std::nullopt);
+ return default_generator;
+ }
+
+ void SetSeed(SeedType seed) { engine().seed(seed); }
+
+ xt::random::default_engine_type &engine() { return engine_; }
+
+ private:
+ xt::random::default_engine_type engine_;
+};
+
+// Generic conversion templates, split into a bindable template and functors
+// that operate on pre-allocated outputs.
+template <typename ConvertFunc>
+device_array GenericElementwiseConvert(device_array &input,
+                                       std::optional<DType> dtype,
+                                       std::optional<device_array> out,
+ bool device_visible) {
+ // Argument check and output allocation.
+ if (!dtype) {
+ dtype = out ? out->dtype() : input.dtype();
+ } else {
+ if (out && out->dtype() != dtype) {
+ throw std::invalid_argument(
+ "if both dtype and out are specified, they must match");
+ }
+ }
+ if (!out) {
+ out.emplace(device_array::for_host(input.device(), input.shape(), *dtype,
+ device_visible));
+ }
+
+ ConvertFunc::Invoke(input, *dtype, *out);
+ return *out;
+}
+
+// Generic elementwise conversion functor
+struct ConvertFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::convert");
+    auto compute = [&]<typename EltTy>() -> void {
+      auto input_t = input.map_xtensor<EltTy>();
+      // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type)     \
+  case DType::dtype_name(): {                   \
+    auto out_t = out.map_xtensor_w<cpp_type>(); \
+    *out_t = xt::cast<cpp_type>(*input_t);      \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(float16, half_float::half);
+ SF_STORE_CASE(float32, float);
+ SF_STORE_CASE(float64, double);
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ SF_STORE_CASE(uint64, uint64_t);
+ SF_STORE_CASE(int64, int64_t);
+ default:
+ throw std::invalid_argument("Invalid output dtype for convert op");
+ }
+
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ SF_UNARY_THUNK_CASE(float64, double);
+ SF_UNARY_THUNK_CASE(uint8, uint8_t);
+ SF_UNARY_THUNK_CASE(int8, int8_t);
+ SF_UNARY_THUNK_CASE(uint16, uint16_t);
+ SF_UNARY_THUNK_CASE(int16, int16_t);
+ SF_UNARY_THUNK_CASE(uint32, uint32_t);
+      SF_UNARY_THUNK_CASE(int32, int32_t);
+ SF_UNARY_THUNK_CASE(uint64, uint64_t);
+ SF_UNARY_THUNK_CASE(int64, int64_t);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+// Converting round functor.
+struct ConvertRoundFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::round");
+    auto compute = [&]<typename EltTy>() -> void {
+      auto input_t = input.map_xtensor<EltTy>();
+      auto rounded = xt::round(*input_t);
+      if (input.dtype() == dtype) {
+        // Same type output.
+        auto out_t = out.map_xtensor_w<EltTy>();
+        *out_t = rounded;
+      } else {
+        // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type)     \
+  case DType::dtype_name(): {                   \
+    auto out_t = out.map_xtensor_w<cpp_type>(); \
+    *out_t = xt::cast<cpp_type>(rounded);       \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+struct ConvertCeilFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::ceil");
+    auto compute = [&]<typename EltTy>() -> void {
+      auto input_t = input.map_xtensor<EltTy>();
+      auto rounded = xt::ceil(*input_t);
+      if (input.dtype() == dtype) {
+        // Same type output.
+        auto out_t = out.map_xtensor_w<EltTy>();
+        *out_t = rounded;
+      } else {
+        // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type)     \
+  case DType::dtype_name(): {                   \
+    auto out_t = out.map_xtensor_w<cpp_type>(); \
+    *out_t = xt::cast<cpp_type>(rounded);       \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+struct ConvertFloorFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::floor");
+    auto compute = [&]<typename EltTy>() -> void {
+      auto input_t = input.map_xtensor<EltTy>();
+      auto rounded = xt::floor(*input_t);
+      if (input.dtype() == dtype) {
+        // Same type output.
+        auto out_t = out.map_xtensor_w<EltTy>();
+        *out_t = rounded;
+      } else {
+        // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type)     \
+  case DType::dtype_name(): {                   \
+    auto out_t = out.map_xtensor_w<cpp_type>(); \
+    *out_t = xt::cast<cpp_type>(rounded);       \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+struct ConvertTruncFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::trunc");
+    auto compute = [&]<typename EltTy>() -> void {
+      auto input_t = input.map_xtensor<EltTy>();
+      auto rounded = xt::trunc(*input_t);
+      if (input.dtype() == dtype) {
+        // Same type output.
+        auto out_t = out.map_xtensor_w<EltTy>();
+        *out_t = rounded;
+      } else {
+        // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type)     \
+  case DType::dtype_name(): {                   \
+    auto out_t = out.map_xtensor_w<cpp_type>(); \
+    *out_t = xt::cast<cpp_type>(rounded);       \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+void OptionalArrayCast(py::handle handle,
+                       std::optional<device_array> &maybe_array) {
+  if (py::isinstance<device_array>(handle)) {
+    maybe_array.emplace(py::cast<device_array>(handle));
+ }
+}
+
+int DTypePromotionRank(DType dtype) {
+ int rank = 1;
+ if (dtype.is_boolean())
+ rank *= 1000;
+ else if (dtype.is_integer())
+ rank *= 2000;
+ else if (dtype.is_float())
+ rank *= 4000;
+ else if (dtype.is_complex())
+ rank *= 8000;
+ return rank + dtype.bit_count();
+}
+
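+// Note on ranking: the thousands digit encodes the category (bool < integer <
+// float < complex) and the low digits encode the bit width, so the type with
+// the higher rank is the "wider" one to promote toward. For example,
+// DTypePromotionRank(int32) == 2032 and DTypePromotionRank(float16) == 4016,
+// so an int32/float16 pair promotes to float16.
+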
+DType PromoteArithmeticTypes(std::optional<DType> lhs_dtype,
+                             std::optional<DType> rhs_dtype) {
+ if (!lhs_dtype && !rhs_dtype) {
+ throw std::invalid_argument(
+ "Elementwise operators require at least one argument to be a "
+ "device_array");
+ }
+
+ // One not an array: promote to the array type.
+ if (!lhs_dtype)
+ return *rhs_dtype;
+ else if (!rhs_dtype)
+ return *lhs_dtype;
+
+ int lhs_rank = DTypePromotionRank(*lhs_dtype);
+ int rhs_rank = DTypePromotionRank(*rhs_dtype);
+ DType promoted_dtype = lhs_rank < rhs_rank ? *rhs_dtype : *lhs_dtype;
+
+ // If mismatched signed/unsigned, then need to promote to the next signed
+ // dtype.
+ if (promoted_dtype.is_integer()) {
+ bool lhs_unsigned = iree_all_bits_set(
+ lhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED);
+ bool rhs_unsigned = iree_all_bits_set(
+ rhs_dtype->numerical_type(), IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED);
+ if ((lhs_unsigned || rhs_unsigned) && !(lhs_unsigned && rhs_unsigned)) {
+ // Signed/unsigned mismatch. Promote to next.
+ switch (promoted_dtype) {
+ case DType::uint8():
+ case DType::int8():
+ return DType::int16();
+ case DType::uint16():
+ case DType::int16():
+ return DType::int32();
+ case DType::uint32():
+ case DType::int32():
+ return DType::int64();
+ default:
+ // Jax's type promotion chart says this goes to a weak FP type, but
+ // we don't implement such a construct and I don't really see how
+ // that makes sense in a system setting like this, so we just saturate
+ // to 64bit.
+ return DType::int64();
+ }
+ }
+ }
+
+ return promoted_dtype;
+}
+
+// ---------------------------------------------------------------------------//
+// Elementwise support
+// ---------------------------------------------------------------------------//
+
+// Python element type scalar conversion functions.
+uint8_t ConvertPyToEltTy(py::handle py_value, uint8_t zero) {
+  return py::cast<uint8_t>(py_value);
+}
+
+int8_t ConvertPyToEltTy(py::handle py_value, int8_t zero) {
+  return py::cast<int8_t>(py_value);
+}
+
+uint16_t ConvertPyToEltTy(py::handle py_value, uint16_t zero) {
+  return py::cast<uint16_t>(py_value);
+}
+
+int16_t ConvertPyToEltTy(py::handle py_value, int16_t zero) {
+  return py::cast<int16_t>(py_value);
+}
+
+uint32_t ConvertPyToEltTy(py::handle py_value, uint32_t zero) {
+  return py::cast<uint32_t>(py_value);
+}
+
+int32_t ConvertPyToEltTy(py::handle py_value, int32_t zero) {
+  return py::cast<int32_t>(py_value);
+}
+
+uint64_t ConvertPyToEltTy(py::handle py_value, uint64_t zero) {
+  return py::cast<uint64_t>(py_value);
+}
+
+int64_t ConvertPyToEltTy(py::handle py_value, int64_t zero) {
+  return py::cast<int64_t>(py_value);
+}
+
+float ConvertPyToEltTy(py::handle py_value, float zero) {
+  return py::cast<float>(py_value);
+}
+
+double ConvertPyToEltTy(py::handle py_value, double zero) {
+  return py::cast<double>(py_value);
+}
+
+half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) {
+  // Python can't cast directly to half so first go to double.
+  return static_cast<half_float::half>(py::cast<double>(py_value));
+}
+}
+
+struct AddFunctor {
+  template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs + rhs;
+ }
+};
+
+struct DivideFunctor {
+  template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs / rhs;
+ }
+};
+
+struct MultiplyFunctor {
+  template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs * rhs;
+ }
+};
+
+struct SubtractFunctor {
+  template <typename Lhs, typename Rhs>
+ static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
+ return lhs - rhs;
+ }
+};
+
+template <typename ElementwiseFunctor>
+device_array ElementwiseOperation(py::handle lhs, py::handle rhs,
+                                  std::optional<device_array> out,
+                                  bool device_visible) {
+  std::optional<device_array> lhs_array;
+  OptionalArrayCast(lhs, lhs_array);
+  std::optional<device_array> rhs_array;
+ OptionalArrayCast(rhs, rhs_array);
+ auto dtype = PromoteArithmeticTypes(
+      lhs_array ? std::optional<DType>(lhs_array->dtype()) : std::nullopt,
+      rhs_array ? std::optional<DType>(rhs_array->dtype()) : std::nullopt);
+ if (lhs_array && lhs_array->dtype() != dtype) {
+    auto converted = GenericElementwiseConvert<ConvertFunctor>(
+ *lhs_array, dtype, /*out=*/std::nullopt,
+ /*device_visible=*/false);
+ lhs_array.reset();
+ lhs_array.emplace(std::move(converted));
+ }
+ if (rhs_array && rhs_array->dtype() != dtype) {
+    auto converted = GenericElementwiseConvert<ConvertFunctor>(
+ *rhs_array, dtype, /*out=*/std::nullopt,
+ /*device_visible=*/false);
+ rhs_array.reset();
+ rhs_array.emplace(std::move(converted));
+ }
+
+  auto compute = [&]<typename EltTy>() -> device_array {
+    auto handle_result = [&]<typename D, typename A>(
+                             D &&device, A &&result) -> device_array {
+ if (!out) {
+ out.emplace(device_array::for_host(device, result.shape(), dtype,
+ device_visible));
+ }
+      auto out_t = out->map_xtensor_w<EltTy>();
+ *out_t = result;
+ return *out;
+ };
+ if (!rhs_array) {
+      auto lhs_t = lhs_array->map_xtensor<EltTy>();
+      xt::xarray<EltTy> rhs_scalar = ConvertPyToEltTy(rhs, EltTy());
+ return handle_result(lhs_array->device(),
+ ElementwiseFunctor::Invoke(*lhs_t, rhs_scalar));
+ } else if (!lhs_array) {
+      xt::xarray<EltTy> lhs_scalar = ConvertPyToEltTy(lhs, EltTy());
+      auto rhs_t = rhs_array->map_xtensor<EltTy>();
+ return handle_result(rhs_array->device(),
+ ElementwiseFunctor::Invoke(lhs_scalar, *rhs_t));
+ } else {
+      auto lhs_t = lhs_array->map_xtensor<EltTy>();
+      auto rhs_t = rhs_array->map_xtensor<EltTy>();
+ return handle_result(lhs_array->device(),
+ ElementwiseFunctor::Invoke(*lhs_t, *rhs_t));
+ }
+ };
+
+ switch (dtype) {
+ SF_UNARY_FUNCTION_CASE(float16, half_float::half);
+ SF_UNARY_FUNCTION_CASE(float32, float);
+ SF_UNARY_FUNCTION_CASE(float64, double);
+ SF_UNARY_FUNCTION_CASE(uint8, uint8_t);
+ SF_UNARY_FUNCTION_CASE(int8, int8_t);
+ SF_UNARY_FUNCTION_CASE(uint16, uint16_t);
+ SF_UNARY_FUNCTION_CASE(int16, int16_t);
+ SF_UNARY_FUNCTION_CASE(uint32, uint32_t);
+    SF_UNARY_FUNCTION_CASE(int32, int32_t);
+ SF_UNARY_FUNCTION_CASE(uint64, uint64_t);
+ SF_UNARY_FUNCTION_CASE(int64, int64_t);
+ default:
+ throw std::invalid_argument(fmt::format(
+          "Unsupported dtype({}) for elementwise op", dtype.name()));
+ }
+}
+
+} // namespace
+
+void BindArrayHostOps(py::module_ &m) {
+ // Simple op definitions.
+ m.def(
+ "argmax",
+      [](device_array &input, int axis, std::optional<device_array> out,
+ bool keepdims, bool device_visible) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::argmax");
+ if (axis < 0) axis += input.shape().size();
+ if (axis < 0 || axis >= input.shape().size()) {
+ throw std::invalid_argument(
+ fmt::format("Axis out of range: Must be [0, {}) but got {}",
+ input.shape().size(), axis));
+ }
+ if (out && (out->dtype() != DType::int64())) {
+ throw std::invalid_argument("out array must have dtype=int64");
+ }
+        auto compute = [&]<typename EltTy>() {
+          auto input_t = input.map_xtensor<EltTy>();
+ auto result = xt::argmax(*input_t, axis);
+ if (!out) {
+ out.emplace(device_array::for_host(input.device(), result.shape(),
+ DType::int64(), device_visible));
+ }
+ auto out_t = out->map_xtensor_w();
+ *out_t = result;
+ if (keepdims) {
+ out->expand_dims(axis);
+ }
+ return *out;
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_FUNCTION_CASE(float16, half_float::half);
+ SF_UNARY_FUNCTION_CASE(float32, float);
+ default:
+ throw std::invalid_argument(
+ fmt::format("Unsupported dtype({}) for operator argmax",
+ input.dtype().name()));
+ }
+ },
+ py::arg("input"), py::arg("axis") = -1, py::arg("out") = py::none(),
+ py::kw_only(), py::arg("keepdims") = false,
+ py::arg("device_visible") = false, DOCSTRING_ARGMAX);
+
+ // Random number generation.
+  py::class_<PyRandomGenerator>(m, "RandomGenerator")
+      .def(py::init<std::optional<PyRandomGenerator::SeedType>>(),
+ py::arg("seed") = py::none(), DOCSTRING_RANDOM_GENERATOR);
+ m.def(
+ "fill_randn",
+      [](device_array out, std::optional<PyRandomGenerator *> gen) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::fill_randn");
+ if (!gen) gen = &PyRandomGenerator::get_default();
+        auto compute = [&]<typename EltTy>() {
+          auto result = xt::random::randn<EltTy>(out.shape_container(),
+                                                 /*mean=*/0.0, /*std_dev=*/1.0,
+                                                 (*gen)->engine());
+          auto out_t = out.map_xtensor_w<EltTy>();
+ *out_t = result;
+ };
+
+ switch (out.dtype()) {
+ SF_UNARY_FUNCTION_CASE(float16, half_float::half);
+ SF_UNARY_FUNCTION_CASE(float32, float);
+ default:
+ throw std::invalid_argument(
+ fmt::format("Unsupported dtype({}) for operator randn",
+ out.dtype().name()));
+ }
+ },
+ py::arg("out"), py::arg("generator") = py::none(), DOCSTRING_FILL_RANDN);
+
+// Data-type conversion and rounding.
+#define SF_DEF_CONVERT(py_name, target) \
+ m.def(py_name, target, py::arg("input"), py::kw_only(), \
+ py::arg("dtype") = py::none(), py::arg("out") = py::none(), \
+ py::arg("device_visible") = false, DOCSTRING_CONVERT)
+  SF_DEF_CONVERT("convert", GenericElementwiseConvert<ConvertFunctor>);
+  SF_DEF_CONVERT("ceil", GenericElementwiseConvert<ConvertCeilFunctor>);
+  SF_DEF_CONVERT("floor", GenericElementwiseConvert<ConvertFloorFunctor>);
+  SF_DEF_CONVERT("round", GenericElementwiseConvert<ConvertRoundFunctor>);
+  SF_DEF_CONVERT("trunc", GenericElementwiseConvert<ConvertTruncFunctor>);
+
+ // Transpose.
+ m.def(
+ "transpose",
+      [](device_array input, std::vector<size_t> permutation,
+         std::optional<device_array> out, bool device_visible) {
+        auto compute = [&]<typename EltTy>() -> device_array {
+          auto input_t = input.map_xtensor<EltTy>();
+ auto permuted_t =
+ xt::transpose(*input_t, permutation, xt::check_policy::full());
+ if (!out) {
+ out.emplace(device_array::for_host(input.device(),
+ permuted_t.shape(),
+ input.dtype(), device_visible));
+ }
+ auto out_t = out->map_xtensor_w();
+ *out_t = permuted_t;
+ return *out;
+ };
+ SF_MOVEMENT_OP_SWITCH(input.dtype());
+ },
+ py::arg("input"), py::arg("permutation"), py::arg("out") = py::none(),
+ py::arg("device_visible") = false, DOCSTRING_TRANSPOSE);
+
+// Elementwise.
+#define SF_DEF_ELEMENTWISE(py_name, target) \
+ m.def(py_name, target, py::arg("lhs"), py::arg("rhs"), py::kw_only(), \
+ py::arg("out") = py::none(), py::arg("device_visible") = false)
+  SF_DEF_ELEMENTWISE("add", ElementwiseOperation<AddFunctor>);
+  SF_DEF_ELEMENTWISE("divide", ElementwiseOperation<DivideFunctor>);
+  SF_DEF_ELEMENTWISE("multiply", ElementwiseOperation<MultiplyFunctor>);
+  SF_DEF_ELEMENTWISE("subtract", ElementwiseOperation<SubtractFunctor>);
+
+}
+
+} // namespace shortfin::python
diff --git a/libshortfin/bindings/python/lib_ext.cc b/shortfin/python/lib_ext.cc
similarity index 52%
rename from libshortfin/bindings/python/lib_ext.cc
rename to shortfin/python/lib_ext.cc
index cc7c21190..c668e6a8b 100644
--- a/libshortfin/bindings/python/lib_ext.cc
+++ b/shortfin/python/lib_ext.cc
@@ -10,10 +10,10 @@
#include "shortfin/array/array.h"
#include "shortfin/array/storage.h"
#include "shortfin/local/async.h"
+#include "shortfin/local/fiber.h"
#include "shortfin/local/messaging.h"
#include "shortfin/local/process.h"
#include "shortfin/local/program.h"
-#include "shortfin/local/scope.h"
#include "shortfin/local/system.h"
#if defined(SHORTFIN_HAVE_AMDGPU)
#include "shortfin/local/systems/amdgpu.h"
@@ -26,6 +26,82 @@ namespace shortfin::python {
namespace {
+static const char DOCSTRING_SYSTEM_CTOR[] =
+ R"(Constructs a System based on system type and kwargs.
+
+System types depend on how the library was compiled and correspond to
+SystemBuilder classes. This API is a shorthand for creating a SystemBuilder
+and calling create_system() on it.
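+
+Example (an illustrative sketch; assumes the package is imported as `sf` and a
+"hostcpu" system type was compiled in)::
+
+  import shortfin as sf
+
+  with sf.System("hostcpu") as system:
+    fiber = system.create_fiber()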
+)";
+
+static const char DOCSTRING_SYSTEM_BUILDER_CTOR[] =
+ R"(Constructs a system with environment dependent configuration.
+
+Most configuration is done by way of key/value arguments. Arguments are meant
+to be derived from flags or config files and are expected to be simple strings
+or integer values:
+
+ * "system_type": A supported system type, corresponding to a subclass of
+ SystemBuilder. Typically available values include "hostcpu", "amdgpu", etc.
+ * Other keywords are passed to the concrete SystemBuilder subclass.
+
+Resolution for the `system_type` is special and is checked in three steps, with
+the first defined winning:
+
+1. `kwargs["system_type"]`
+2. Environment `SHORTFIN_SYSTEM_TYPE` variable (or as controlled by `env_prefix`)
+3. `SystemBuilder.default_system_type` as a process-wide default.
+
+Args:
+ env_prefix: Controls how options are looked up in the environment. By default,
+ the prefix is "SHORTFIN_" and upper-cased options are appended to it. Any
+ option not explicitly specified but in the environment will be used. Pass
+ None to disable environment lookup.
+ **kwargs: Key/value arguments for controlling setup of the system.
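+
+Example (an illustrative sketch; assumes the package is imported as `sf` and a
+"hostcpu" system type was compiled in)::
+
+  import shortfin as sf
+
+  builder = sf.SystemBuilder(system_type="hostcpu")
+  system = builder.create_system()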
+)";
+
+static const char DOCSTRING_HOSTCPU_SYSTEM_BUILDER_CTOR[] =
+ R"(Constructs a system with CPU based devices.
+
+Most configuration is done by way of key/value arguments. Arguments are meant
+to be derived from flags or config files and are expected to be simple strings
+or integer values:
+
+ * "hostcpu_topology_nodes": Takes one of the special values "current" (default)
+ or "all". If not one of those, this should be a comma-delimited list of
+ NUMA node ids. Each NUMA node will be modeled as one device queue and will
+ show up on the system as a device.
+ * "hostcpu_topology_max_group_count": Maximum number of groups to create per
+ node. The actual number of groups is derived by a heuristic (which can be
+ influenced by other options) such that there will not be more groups than
+ eligible physical cores on the node.
+
+Args:
+ env_prefix: Controls how options are looked up in the environment. By default,
+ the prefix is "SHORTFIN_" and upper-cased options are appended to it. Any
+ option not explicitly specified but in the environment will be used. Pass
+ None to disable environment lookup.
+ **kwargs: Key/value arguments for controlling setup of the system.
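+
+Example (an illustrative sketch; assumes this builder is exposed as
+`shortfin.host.CPUSystemBuilder`)::
+
+  import shortfin as sf
+
+  builder = sf.host.CPUSystemBuilder(hostcpu_topology_nodes="current")
+  system = builder.create_system()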
+)";
+
+static const char DOCSTRING_HOSTCPU_SYSTEM_BUILDER_HOSTCPU_ALLOCATOR_SPECS[] =
+ R"(Allocator specs to apply to HOSTCPU devices configured by this builder.
+
+This uses syntax like::
+
+ some_allocator
+ some_allocator:key=value
+ some_allocator:key=value,key=value
+ some_allocator:key=value,key=value;other_allocator:key=value
+
+Typical values for `some_allocator` include `caching` and `debug`.
+
+This can be set via a keyword of `hostcpu_allocators`, which will only apply to
+HOSTCPU devices, or `allocators`, which will apply to all contained devices.
+Similarly, it is available on a `SHORTFIN_` prefixed env variable if environment
+lookup is not disabled.
+)";
+
static const char DOCSTRING_PROGRAM_FUNCTION_INVOCATION[] =
R"(Creates an invocation object targeting the function.
@@ -58,8 +134,19 @@ class Refs {
return lazy_PyWorkerEventLoop_;
}
+  std::optional<std::string> default_system_type() {
+ iree::slim_mutex_lock_guard g(mu_);
+ return default_system_type_;
+ }
+  void set_default_system_type(std::optional<std::string> st) {
+ iree::slim_mutex_lock_guard g(mu_);
+ default_system_type_ = st;
+ }
+
private:
+ iree::slim_mutex mu_;
py::object lazy_PyWorkerEventLoop_;
+  std::optional<std::string> default_system_type_;
};
// We need a fair bit of accounting additions in order to make Workers usable
@@ -86,6 +173,7 @@ class PyWorkerExtension : public local::Worker::Extension {
py::handle loop() { return loop_; }
void OnThreadStart() noexcept override {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStart");
// Python threading initialization.
// If our own thread, teach Python about it. Not done for donated.
if (worker().options().owned_thread) {
@@ -100,6 +188,7 @@ class PyWorkerExtension : public local::Worker::Extension {
}
void OnThreadStop() noexcept override {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStop");
{
// Do Python level thread cleanup.
py::gil_scoped_acquire g;
@@ -142,10 +231,18 @@ class PyWorkerExtension : public local::Worker::Extension {
class PyProcess : public local::detail::BaseProcess {
public:
- PyProcess(std::shared_ptr scope, std::shared_ptr refs)
- : BaseProcess(std::move(scope)), refs_(std::move(refs)) {}
+  PyProcess(std::shared_ptr<Refs> refs)
+ : BaseProcess(), refs_(std::move(refs)) {}
+ using BaseProcess::Initialize;
+ using BaseProcess::is_initialized;
using BaseProcess::Launch;
+ void AssertInitialized() {
+ if (!is_initialized()) {
+ throw std::logic_error("Process.__init__ not called in constructor");
+ }
+ }
+
void ScheduleOnWorker() override {
// This is tricky: We need to retain the object reference across the
// thread transition, but on the receiving side, the GIL will not be
@@ -154,14 +251,15 @@ class PyProcess : public local::detail::BaseProcess {
// the callback.
py::handle self_object = py::cast(this, py::rv_policy::none);
self_object.inc_ref();
- scope()->worker().CallThreadsafe(
+ fiber()->worker().CallThreadsafe(
std::bind(&PyProcess::RunOnWorker, self_object));
}
static void RunOnWorker(py::handle self_handle) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyProcess:RunOnWorker");
py::gil_scoped_acquire g;
// Steal the reference back from ScheduleOnWorker. Important: this is
// very likely the last reference to the process. So self must not be
- // touched after self_object goes out of scope.
+ // touched after self_object goes out of fiber.
py::object self_object = py::steal(self_handle);
PyProcess *self = py::cast(self_handle);
// We assume that the run method either returns None (def) or a coroutine
@@ -207,9 +305,10 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) {
py::cast<std::string>(py::repr(arg.type()))));
}
-local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self,
- py::args args) {
- auto inv = self.CreateInvocation();
+local::ProgramInvocation::Future PyFunctionCall(
+ local::ProgramFunction &self, py::args args, local::Fiber &fiber,
+    std::optional<local::ProgramIsolation> isolation) {
+ auto inv = self.CreateInvocation(fiber.shared_from_this(), isolation);
py::capsule inv_capsule(inv.get());
for (py::handle arg : args) {
PyAddProgramInvocationArg(inv_capsule, arg);
@@ -246,6 +345,7 @@ py::object PyRehydrateRef(local::ProgramInvocation *inv,
py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
py::object coro) {
+ SHORTFIN_TRACE_SCOPE_NAMED("CoroRunInForeground");
bool is_main_thread =
refs->threading_current_thread().is(refs->threading_main_thread());
@@ -311,9 +411,35 @@ py::object RunInForeground(std::shared_ptr refs, local::System &self,
return result;
}
+ConfigOptions CreateConfigOptions(std::optional<std::string> &env_prefix,
+ py::kwargs &kwargs, bool validate_undef) {
+ ConfigOptions options(std::move(env_prefix),
+ validate_undef
+ ? ConfigOptions::ValidationLevel::UNDEF_ERROR
+ : ConfigOptions::ValidationLevel::UNDEF_WARN);
+ for (auto it = kwargs.begin(); it != kwargs.end(); ++it) {
+    std::string key = py::cast<std::string>((*it).first);
+    std::string value = py::cast<std::string>(py::str((*it).second));
+ options.SetOption(std::move(key), std::move(value));
+ }
+ return options;
+}
+
} // namespace
NB_MODULE(lib, m) {
+ // Tragically, debug builds of Python do the right thing and don't immortalize
+ // many identifiers and such. This makes the last chance leak checking that
+  // nanobind does somewhat unreliable since the reports it prints may refer
+  // to identifiers that are no longer live (at a time in process shutdown
+ // where it is expected that everything left just gets dropped on the floor).
+ // This causes segfaults or ASAN violations in the leak checker on exit in
+ // certain scenarios where we have spurious "leaks" of global objects.
+
+ py::set_leak_warnings(false);
+
+ logging::InitializeFromEnv();
+
py::register_exception_translator(
[](const std::exception_ptr &p, void * /*unused*/) {
try {
@@ -339,6 +465,13 @@ NB_MODULE(lib, m) {
});
py::class_<iree::vm_opaque_ref>(m, "_OpaqueVmRef");
+
+ // Logging entrypoints.
+ m.def("log_debug", [](std::string_view sv) { logging::debug("{}", sv); });
+ m.def("log_info", [](std::string_view sv) { logging::info("{}", sv); });
+ m.def("log_warn", [](std::string_view sv) { logging::warn("{}", sv); });
+ m.def("log_error", [](std::string_view sv) { logging::error("{}", sv); });
+
auto local_m = m.def_submodule("local");
BindLocal(local_m);
BindHostSystem(local_m);
@@ -371,16 +504,94 @@ void BindLocal(py::module_ &m) {
std::make_unique(worker, interp_state, refs));
};
+  py::enum_<local::ProgramIsolation>(m, "ProgramIsolation")
+ .value("NONE", local::ProgramIsolation::NONE)
+ .value("PER_FIBER", local::ProgramIsolation::PER_FIBER)
+ .value("PER_CALL", local::ProgramIsolation::PER_CALL)
+ .export_values();
+
py::class_<local::SystemBuilder>(m, "SystemBuilder")
+ .def("__init__", [](py::args, py::kwargs) {})
+ .def_static(
+ "__new__",
+          [refs](py::handle cls, std::optional<std::string> env_prefix,
+ bool validate_undef, py::kwargs kwargs) {
+ auto options =
+ CreateConfigOptions(env_prefix, kwargs, validate_undef);
+            std::optional<std::string_view> system_type =
+ options.GetOption("system_type");
+            std::optional<std::string> default_system_type;
+ if (!system_type) {
+ default_system_type = refs->default_system_type();
+ if (!default_system_type) {
+ throw std::invalid_argument(
+ "In order to construct a generic SystemBuilder, a "
+ "`system_type=` keyword (or appropriate environment "
+ "variable) must be specified. Alternatively, a default can "
+ "be set process wide via "
+ "`SystemBuilder.default_system_type =`");
+ }
+ system_type = *default_system_type;
+ }
+ return local::SystemBuilder::ForSystem(
+ iree_allocator_system(), *system_type, std::move(options));
+ },
+ // Note that for some reason, no-arg construction passes no arguments
+ // to __new__. We allow the single positional argument to be none,
+ // which satisfies this case in practice.
+ py::arg("cls") = py::none(), py::kw_only(),
+ py::arg("env_prefix").none() = "SHORTFIN_",
+ py::arg("validate_undef") = true, py::arg("kwargs"),
+ DOCSTRING_SYSTEM_BUILDER_CTOR)
+ .def_prop_rw_static(
+ "default_system_type",
+ [refs](py::handle /*unused*/) { return refs->default_system_type(); },
+          [refs](py::handle /*unused*/, std::optional<std::string> st) {
+ refs->set_default_system_type(std::move(st));
+ })
.def("create_system", [live_system_refs,
worker_initializer](local::SystemBuilder &self) {
auto system_ptr = self.CreateSystem();
system_ptr->AddWorkerInitializer(worker_initializer);
auto system_obj = py::cast(system_ptr, py::rv_policy::take_ownership);
live_system_refs.attr("add")(system_obj);
+ try {
+ self.config_options().ValidateUndef();
+ } catch (...) {
+ system_obj.attr("shutdown")();
+ throw;
+ }
return system_obj;
});
py::class_<local::System>(m, "System", py::is_weak_referenceable())
+ .def(
+ "__init__",
+ [live_system_refs](py::object self_obj, py::args, py::kwargs) {
+ live_system_refs.attr("add")(self_obj);
+ },
+ DOCSTRING_SYSTEM_CTOR)
+ .def_static(
+ "__new__",
+ [worker_initializer](py::handle py_type, std::string_view system_type,
+                               std::optional<std::string> env_prefix,
+ bool validate_undef, py::kwargs kwargs) {
+ auto options =
+ CreateConfigOptions(env_prefix, kwargs, validate_undef);
+ auto system = local::System::Create(
+ iree_allocator_system(), system_type, std::move(options));
+ system->AddWorkerInitializer(worker_initializer);
+ return system;
+ },
+ py::arg("type"), py::arg("system_type"), py::kw_only(),
+ py::arg("env_prefix") = "SHORTFIN_", py::arg("validate_undef") = true,
+ py::arg("kwargs"))
+ .def("__enter__", [](py::object self_obj) { return self_obj; })
+ .def(
+ "__exit__",
+ [](local::System &self, py::handle exc_type, py::handle exc_value,
+ py::handle exc_tb) { self.Shutdown(); },
+ py::arg("exc_type").none(), py::arg("exc_value").none(),
+ py::arg("exc_tb").none())
.def("shutdown", &local::System::Shutdown)
// Access devices by list, name, or lookup.
.def_prop_ro("device_names",
@@ -396,30 +607,33 @@ void BindLocal(py::module_ &m) {
.def(
"device",
[](local::System &self, std::string_view key) {
- auto it = self.named_devices().find(key);
- if (it == self.named_devices().end()) {
+ local::Device *device = self.FindDeviceByName(key);
+ if (!device) {
throw std::invalid_argument(fmt::format("No device '{}'", key));
}
- return it->second;
+ return device;
},
py::rv_policy::reference_internal)
.def(
"create_queue",
- [](local::System &self, std::string name) -> local::Queue & {
+ [](local::System &self,
+             std::optional<std::string> name) -> std::shared_ptr<local::Queue> {
local::Queue::Options options;
- options.name = std::move(name);
+ if (name) {
+ options.name = std::move(*name);
+ }
return self.CreateQueue(std::move(options));
},
- py::arg("name"), py::rv_policy::reference_internal)
+ py::arg("name") = py::none(), py::rv_policy::reference_internal)
.def("named_queue", &local::System::named_queue, py::arg("name"),
py::rv_policy::reference_internal)
.def(
- "create_scope",
+ "create_fiber",
[](local::System &self, local::Worker *worker,
py::handle raw_devices) {
// TODO: I couldn't really figure out how to directly accept an
// optional kw-only arg without it just being a raw object/handle.
- // If the passed devices is none, then we create the scope with
+ // If the passed devices is none, then we create the fiber with
// all devices in the system. Otherwise, with those explicitly
// given.
std::vector<local::Device *> devices;
@@ -434,7 +648,7 @@ void BindLocal(py::module_ &m) {
worker = dynamic_cast<local::Worker *>(&self.init_worker());
}
- return self.CreateScope(*worker, devices);
+ return self.CreateFiber(*worker, devices);
},
py::rv_policy::reference_internal,
py::arg("worker").none() = py::none(), py::kw_only(),
@@ -469,7 +683,6 @@ void BindLocal(py::module_ &m) {
py::class_<local::Device>(m, "Device")
.def_prop_ro("name", &local::Device::name)
.def_prop_ro("node_affinity", &local::Device::node_affinity)
- .def_prop_ro("node_locked", &local::Device::node_locked)
.def(py::self == py::self)
.def("__repr__", &local::Device::to_s);
py::class_<local::DeviceAffinity>(m, "DeviceAffinity")
@@ -481,31 +694,55 @@ void BindLocal(py::module_ &m) {
.def("__repr__", &local::DeviceAffinity::to_s);
py::class_<local::Program>(m, "Program")
- .def(py::new_([](std::span modules,
- local::Scope &scope, bool trace_execution) {
- local::Program::Options options;
- options.trace_execution = trace_execution;
- return local::Program::Load(scope.shared_from_this(), modules,
- std::move(options));
- }),
- py::arg("modules"), py::arg("scope"), py::kw_only(),
- py::arg("trace_execution") = false)
+ .def(
+          py::new_([](std::span<const local::ProgramModule> modules,
+                      std::vector<local::Device *> devices,
+ bool trace_execution, local::ProgramIsolation isolation) {
+ local::Program::Options options;
+ options.devices = devices;
+ options.trace_execution = trace_execution;
+ options.isolation = isolation;
+ return local::Program::Load(modules, std::move(options));
+ }),
+ py::arg("modules"), py::kw_only(), py::arg("devices"),
+ py::arg("trace_execution") = false,
+ py::arg("isolation") = local::ProgramIsolation::PER_FIBER)
.def_prop_ro("exports", &local::Program::exports)
+ .def_prop_ro("isolation", &local::Program::isolation)
.def("lookup_function", &local::Program::LookupRequiredFunction)
.def("__getitem__", &local::Program::LookupRequiredFunction);
py::class_<local::ProgramFunction>(m, "ProgramFunction")
.def_prop_ro("name", &local::ProgramFunction::name)
.def_prop_ro("calling_convention",
&local::ProgramFunction::calling_convention)
- .def("invocation", &local::ProgramFunction::CreateInvocation,
- DOCSTRING_PROGRAM_FUNCTION_INVOCATION)
- .def("__call__", PyFunctionCall, py::arg("args"))
+ .def(
+ "invocation",
+ [](local::ProgramFunction &self, local::Fiber &fiber,
+              std::optional<local::ProgramIsolation> isolation) {
+ return self.CreateInvocation(fiber.shared_from_this(), isolation);
+ },
+ py::arg("fiber"), py::arg("isolation") = py::none(),
+ DOCSTRING_PROGRAM_FUNCTION_INVOCATION)
+ .def_prop_ro("isolation", &local::ProgramFunction::isolation)
+ .def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(),
+ py::arg("fiber"), py::arg("isolation") = py::none())
.def("__repr__", &local::ProgramFunction::to_s);
py::class_<local::ProgramModule>(m, "ProgramModule")
.def_prop_ro("exports", &local::ProgramModule::exports)
.def("__repr__", &local::ProgramModule::to_s)
.def_static("load", &local::ProgramModule::Load, py::arg("system"),
- py::arg("path"), py::arg("mmap") = true);
+ py::arg("path"), py::arg("mmap") = true)
+ .def_static(
+ "parameter_provider",
+ [](local::System &system, py::args params) {
+            std::vector<local::BaseProgramParameters *> c_params;
+            c_params.reserve(params.size());
+            for (py::handle h : params) {
+              c_params.push_back(py::cast<local::BaseProgramParameters *>(h));
+ }
+ return local::ProgramModule::ParameterProvider(system, c_params);
+ },
+ py::arg("system"), py::arg("params"));
py::class_<local::ProgramInvocation::Ptr>(m, "ProgramInvocation")
.def("invoke",
[](local::ProgramInvocation::Ptr &self) {
@@ -552,41 +789,97 @@ void BindLocal(py::module_ &m) {
}
return PyRehydrateRef(self.get(), std::move(ref));
},
- "Gets the i'th result");
+ "Gets the i'th result")
+ .def("__repr__", [](local::ProgramInvocation::Ptr &self) {
+ if (!self) return std::string("ProgramInvocation(INVALID)");
+ return self->to_s();
+ });
+
+  py::class_<local::BaseProgramParameters>(m, "BaseProgramParameters");
+  py::class_<local::StaticProgramParameters, local::BaseProgramParameters>(
+      m, "StaticProgramParameters")
+ .def(
+          py::init<local::System &, std::string_view, iree_host_size_t>(),
+ py::arg("system"), py::arg("parameter_scope"),
+ py::arg("max_concurrent_operations") =
+ IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS)
+ .def(
+ "load",
+ [](local::StaticProgramParameters &self,
+ std::filesystem::path file_path, std::string_view format,
+ bool readable, bool writable, bool mmap) {
+ local::StaticProgramParameters::LoadOptions options;
+ options.format = format;
+ options.readable = readable;
+ options.writable = writable;
+ options.mmap = mmap;
+ self.Load(file_path, options);
+ },
+ py::arg("file_path"), py::arg("format") = std::string_view(),
+ py::arg("readable") = true, py::arg("writable") = false,
+ py::arg("mmap") = true);
struct DevicesSet {
- DevicesSet(local::Scope &scope) : scope(scope) {}
- local::Scope &scope;
+    DevicesSet(py::object fiber_obj, std::optional<size_t> index = {})
+ : fiber_obj(std::move(fiber_obj)), index(index) {}
+ py::object KeepAlive(local::ScopedDevice device) {
+ py::object device_obj = py::cast(device);
+ py::detail::keep_alive(/*nurse=*/device_obj.ptr(),
+ /*patient=*/fiber_obj.ptr());
+ return device_obj;
+ }
+    local::Fiber &fiber() { return py::cast<local::Fiber &>(fiber_obj); }
+ py::object fiber_obj;
+    std::optional<size_t> index;
};
- py::class_(m, "Scope")
- .def("__repr__", &local::Scope::to_s)
- .def_prop_ro("raw_devices", &local::Scope::raw_devices,
- py::rv_policy::reference_internal)
+  py::class_<local::Fiber>(m, "Fiber")
+ .def("__repr__", &local::Fiber::to_s)
+ .def_prop_ro(
+ "raw_devices",
+ [](local::Fiber &self) {
+            std::vector<local::Device *> devices;
+ devices.reserve(self.raw_devices().size());
+ for (auto it : self.raw_devices()) {
+ devices.push_back(it.second);
+ }
+ return devices;
+ },
+ py::rv_policy::reference_internal)
.def(
"raw_device",
- [](local::Scope &self, int index) { return self.raw_device(index); },
+ [](local::Fiber &self, int index) { return self.raw_device(index); },
py::rv_policy::reference_internal)
.def(
"raw_device",
- [](local::Scope &self, std::string_view name) {
+ [](local::Fiber &self, std::string_view name) {
return self.raw_device(name);
},
py::rv_policy::reference_internal)
- .def_prop_ro(
- "devices", [](local::Scope &self) { return DevicesSet(self); },
- py::rv_policy::reference_internal)
- .def_prop_ro("device_names", &local::Scope::device_names)
- .def_prop_ro("named_devices", &local::Scope::named_devices,
- py::rv_policy::reference_internal)
+ .def_prop_ro("devices",
+ [](py::object self) { return DevicesSet(std::move(self)); })
+ .def_prop_ro("devices_dict",
+ [](py::handle self_obj) {
+                     local::Fiber &self = py::cast<local::Fiber &>(self_obj);
+ py::dict d;
+ for (auto &it : self.raw_devices()) {
+ py::object scoped_device =
+ py::cast(self.device(it.second));
+ py::detail::keep_alive(/*nurse=*/scoped_device.ptr(),
+ /*patient=*/self_obj.ptr());
+ d[py::cast(it.first)] = scoped_device;
+ }
+ return d;
+ })
+ .def_prop_ro("device_names", &local::Fiber::device_names)
.def(
"device",
- [](local::Scope &self, py::args args) {
+ [](local::Fiber &self, py::args args) {
return CastDeviceAffinity(self, args);
},
py::rv_policy::reference_internal);
py::class_<local::ScopedDevice>(m, "ScopedDevice")
- .def_prop_ro("scope", &local::ScopedDevice::scope,
+ .def_prop_ro("fiber", &local::ScopedDevice::fiber,
py::rv_policy::reference)
.def_prop_ro("affinity", &local::ScopedDevice::affinity,
py::rv_policy::reference_internal)
@@ -601,25 +894,35 @@ void BindLocal(py::module_ &m) {
.def("__repr__", &local::ScopedDevice::to_s);
py::class_<DevicesSet>(m, "_ScopeDevicesSet")
+ .def("__iter__",
+ [](DevicesSet &self) { return DevicesSet(self.fiber_obj, 0); })
+ .def("__next__",
+ [](DevicesSet &self) {
+ auto &fiber = self.fiber();
+ if (!self.index || *self.index >= fiber.raw_devices().size()) {
+ // Blurgh: Exception as flow control is not cheap. There is a
+ // very obnoxious way to make this not be exception based but
+ // this is a minority path.
+ throw py::stop_iteration();
+ }
+ return self.KeepAlive(fiber.device((*self.index)++));
+ })
.def("__len__",
- [](DevicesSet &self) { return self.scope.raw_devices().size(); })
- .def(
- "__getitem__",
- [](DevicesSet &self, int index) { return self.scope.device(index); },
- py::rv_policy::reference_internal)
- .def(
- "__getitem__",
- [](DevicesSet &self, std::string_view name) {
- return self.scope.device(name);
- },
- py::rv_policy::reference_internal)
- .def(
- "__getattr__",
- [](DevicesSet &self, std::string_view name) {
- return self.scope.device(name);
- },
- py::rv_policy::reference_internal);
+ [](DevicesSet &self) { return self.fiber().raw_devices().size(); })
+ .def("__getitem__",
+ [](DevicesSet &self, size_t index) {
+ return self.KeepAlive(self.fiber().device(index));
+ })
+ .def("__getitem__",
+ [](DevicesSet &self, std::string_view name) {
+ return self.KeepAlive(self.fiber().device(name));
+ })
+ .def("__getattr__",
+ [](DevicesSet &self, std::string_view name) -> py::object {
+ return self.KeepAlive(self.fiber().device(name));
+ });
+ ;
py::class_