diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 27d63ae71..babcf0245 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -1,4 +1,3 @@ -#!/bin/bash # Copyright 2024 Advanced Micro Devices, Inc # # Licensed under the Apache License v2.0 with LLVM Exceptions. @@ -22,7 +21,7 @@ permissions: env: IREE_REPO_DIR: ${{ github.workspace }}/iree - BUILD_DIR: ${{ github.workspace }}/libshortfin/build + LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/ jobs: build-and-test: @@ -47,16 +46,15 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - depth: 1 - name: Initalize IREE submodules run : | cd ${{ env.IREE_REPO_DIR }} - git submodule update --init -- third_party/benchmark - git submodule update --init -- third_party/cpuinfo/ - git submodule update --init -- third_party/flatcc - git submodule update --init -- third_party/googletest - git submodule update --init -- third_party/hip-build-deps/ + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ - name: Build IREE runtime run: | @@ -80,16 +78,18 @@ jobs: - name: Setup Python uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 with: - python-version: "3.11" + python-version: "3.12" cache: "pip" - name: Install Python packages # TODO: Switch to `pip install -r requirements.txt -e libshortfin/`. - run: pip install nanobind typing_extensions + run: | + pip install nanobind + pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt - - name: Build libshortfin + - name: Build libshortfin (full) run: | - mkdir ${{ env.BUILD_DIR }} - cd ${{ env.BUILD_DIR }} + mkdir ${{ env.LIBSHORTFIN_DIR }}/build + cd ${{ env.LIBSHORTFIN_DIR }}/build cmake -GNinja \ -DCMAKE_C_COMPILER=clang-18 \ -DCMAKE_CXX_COMPILER=clang++-18 \ @@ -98,8 +98,25 @@ jobs: -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ .. cmake --build . --target all + pip install -v -e . + + - name: Test libshortfin (full) + run: | + cd ${{ env.LIBSHORTFIN_DIR }}/build + ctest --timeout 30 --output-on-failure + cd ${{ env.LIBSHORTFIN_DIR }} + pytest -s -v -m "not requires_amd_gpu" - - name: Test libshortfin + - name: Build libshortfin (host-only) run: | - cd ${{ env.BUILD_DIR }} - cmake --build . --target test + mkdir ${{ env.LIBSHORTFIN_DIR }}/build-host-only + cd ${{ env.LIBSHORTFIN_DIR }}/build-host-only + cmake -GNinja \ + -DCMAKE_C_COMPILER=clang-18 \ + -DCMAKE_CXX_COMPILER=clang++-18 \ + -DCMAKE_LINKER_TYPE=LLD \ + -DCMAKE_PREFIX_PATH=${{ env.IREE_REPO_DIR }}/build/lib/cmake/IREE \ + -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \ + -DSHORTFIN_HAVE_AMDGPU=OFF \ + .. + cmake --build . --target all diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml new file mode 100644 index 000000000..14aa26bda --- /dev/null +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -0,0 +1,170 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: CI - libshortfin - ASan + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + paths: + - '.github/workflows/ci_linux_x64_asan-libshortfin.yml' + - 'libshortfin/**' + +permissions: + contents: read + +env: + PYENV_ROOT: ${{ github.workspace }}/pyenv + PYENV_REF: 9ecd803bffaffb949fbdd8c70cb086227f6a3202 # v2.4.10 + PYTHON_VER: 3.12.3 + CACHE_ASAN_VER: 1 + CACHE_DEPS_VER: 1 + IREE_SOURCE_DIR: ${{ github.workspace }}/iree + LIBSHORTFIN_DIR: ${{ github.workspace }}/libshortfin/ + + +jobs: + setup-python-asan: + name: Setup Python ASan + runs-on: ubuntu-24.04 + + steps: + - name: Cache Python ASan + id: cache-python-asan + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }} + lookup-only: 'true' + + - name: Install dependencies + if: steps.cache-python-asan.outputs.cache-hit != 'true' + run: | + sudo apt update + sudo apt install clang lld cmake ninja-build + sudo apt install build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev + + - name: Checkout pyenv + if: steps.cache-python-asan.outputs.cache-hit != 'true' + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + repository: pyenv/pyenv + ref: ${{ env.PYENV_REF }} + path: ${{ env.PYENV_ROOT }} + + - name: Install pyenv & Python + if: steps.cache-python-asan.outputs.cache-hit != 'true' + run: | + cd ${{ env.PYENV_ROOT }} + src/configure && make -C src + export PATH=${{ env.PYENV_ROOT }}/bin:$PATH && eval "$(pyenv init -)" + CC=clang-18 CXX=clang++-18 LDFLAGS="-lstdc++" PYTHON_CONFIGURE_OPTS="--with-address-sanitizer" pyenv install -v -g ${{ env.PYTHON_VER }} + pyenv global ${{ env.PYTHON_VER }}-debug + + + build-and-test: + name: Build and test libshortfin + needs: [setup-python-asan] + runs-on: ubuntu-24.04 + + steps: + - name: Install dependencies + run: | + sudo apt update + sudo apt install clang lld cmake ninja-build + + - name: Checkout repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + submodules: false + + - name: Checkout IREE repo + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + repository: iree-org/iree + path: ${{ env.IREE_SOURCE_DIR }} + submodules: false + + - name: Initalize IREE submodules + run : | + cd ${{ env.IREE_SOURCE_DIR }} + git submodule update --init --depth 1 -- third_party/benchmark + git submodule update --init --depth 1 -- third_party/cpuinfo/ + git submodule update --init --depth 1 -- third_party/flatcc + git submodule update --init --depth 1 -- third_party/googletest + git submodule update --init --depth 1 -- third_party/hip-build-deps/ + + - name: Restore Python dependencies cache + id: cache-python-deps-restore + uses: actions/cache/restore@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ runner.os }}-python-deps-${{ hashFiles('libshortfin/requirements-tests.txt') }}-v${{ env.CACHE_DEPS_VER }} + + - name: Restore Python ASan cache + id: cache-python-asan + if: steps.cache-python-deps-restore.outputs.cache-hit != 'true' + uses: actions/cache/restore@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: ${{ env.PYENV_ROOT }} + key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF 
}}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }}
+
+      - name: Set path
+        run:
+          echo "${{ env.PYENV_ROOT }}/bin" >> $GITHUB_PATH
+
+      - name: Install Python dependencies
+        if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'
+        run: |
+          eval "$(pyenv init -)"
+          pip install -r ${{ env.LIBSHORTFIN_DIR }}/requirements-tests.txt
+
+      - name: Save Python dependencies cache
+        if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'
+        id: cache-python-deps-save
+        uses: actions/cache/save@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
+        with:
+          path: ${{ env.PYENV_ROOT }}
+          key: ${{ steps.cache-python-deps-restore.outputs.cache-primary-key }}
+
+      - name: Build libshortfin
+        env:
+          # TODO(#151): Don't ignore ODR violations
+          ASAN_OPTIONS: detect_odr_violation=0
+        run: |
+          eval "$(pyenv init -)"
+          mkdir ${{ env.LIBSHORTFIN_DIR }}/build
+          cd ${{ env.LIBSHORTFIN_DIR }}/build
+          cmake -GNinja \
+            -DCMAKE_BUILD_TYPE=Debug \
+            -DCMAKE_C_COMPILER=clang-18 \
+            -DCMAKE_CXX_COMPILER=clang++-18 \
+            -DCMAKE_LINKER_TYPE=LLD \
+            -DSHORTFIN_BUNDLE_DEPS=ON \
+            -DSHORTFIN_IREE_SOURCE_DIR=${{ env.IREE_SOURCE_DIR }} \
+            -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON \
+            -DSHORTFIN_ENABLE_ASAN=ON \
+            ..
+          cmake --build . --target all
+          pip install -v -e .
+
+      - name: Run ctest
+        if: ${{ !cancelled() }}
+        env:
+          CTEST_OUTPUT_ON_FAILURE: 1
+        run: |
+          cd ${{ env.LIBSHORTFIN_DIR }}/build
+          ctest --timeout 30 --output-on-failure
+
+      - name: Run pytest
+        if: ${{ !cancelled() }}
+        run: |
+          eval "$(pyenv init -)"
+          cd ${{ env.LIBSHORTFIN_DIR }}
+          pytest -m "not requires_amd_gpu"
diff --git a/docs/amdgpu_kernel_optimization_guide.md b/docs/amdgpu_kernel_optimization_guide.md
index bf597cd94..09c5b59f9 100644
--- a/docs/amdgpu_kernel_optimization_guide.md
+++ b/docs/amdgpu_kernel_optimization_guide.md
@@ -4,7 +4,7 @@
 Author: Jakub Kuderski @kuhar
 Date: 2024-06-24
 
-Last Update: 2024-08-14
+Last Update: 2024-08-22
 
 ## Introduction
@@ -280,6 +280,11 @@
 at once. A sequence of up to 4 adjacent `global_load_dwordx4` instructions
 (implicitly) forms a *clause* that translates to a single data fabric
 transaction.
 
+> [!TIP]
+> To achieve peak L1 bandwidth, make sure that your memory access engages all
+> four L1 cache sets. That is, at the level of the workgroup, you should be
+> loading 4 cache lines (128 B) that each map to a different cache set.
+
 > [!TIP]
 > For data that is 'streamed' and does not need to be cached, consider
 > using *non-temporal* loads/stores. This disables coherency and invalidates
diff --git a/libshortfin/CMakeLists.txt b/libshortfin/CMakeLists.txt
index 20571005c..c5a2f0f6a 100644
--- a/libshortfin/CMakeLists.txt
+++ b/libshortfin/CMakeLists.txt
@@ -18,6 +18,8 @@ project(
   VERSION 0.9
   LANGUAGES C CXX)
+set(SOVERSION 1)
+
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_CXX_STANDARD 20)
 # https://discourse.cmake.org/t/cmake-3-28-cmake-cxx-compiler-clang-scan-deps-notfound-not-found/9244/3
@@ -33,8 +35,30 @@ endif()
 option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF)
 option(SHORTFIN_BUILD_TESTS "Builds C++ tests" ON)
 option(SHORTFIN_BUNDLE_DEPS "Download dependencies instead of using system libraries" OFF)
+
 set(SHORTFIN_IREE_SOURCE_DIR "" CACHE FILEPATH "Path to IREE source")
+# Enabling ASAN. Note that this will work best if building in a completely
+# bundled fashion and with an ASAN rigged CPython. Otherwise, various LD_PRELOAD
+# hacks are needed. This is merely a developer convenience: people are more
+# than welcome to set flags themselves.
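+# Example configure step with ASAN enabled (flags illustrative, mirroring the
+# ASan CI workflow above):
+#   cmake -GNinja -DSHORTFIN_ENABLE_ASAN=ON -DSHORTFIN_BUNDLE_DEPS=ON \
+#     -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON ..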
+option(SHORTFIN_ENABLE_ASAN "Enable ASAN" OFF) +if(SHORTFIN_ENABLE_ASAN) + add_compile_options(-fsanitize=address) + add_link_options(-fsanitize=address) + + # Enable more ASAN checks. + add_compile_definitions(IREE_SANITIZER_ADDRESS) +endif() + +option(SHORTFIN_SYSTEMS_AMDGPU "Builds for AMD GPU systems" ON) +message(STATUS "libshortfin supported systems:") +if(SHORTFIN_SYSTEMS_AMDGPU) + message(STATUS " - AMD GPU") + add_compile_definitions("SHORTFIN_HAVE_AMDGPU") +endif() +message(STATUS " - Host") + include(FetchContent) # Includes. @@ -92,6 +116,7 @@ if (NOT SHORTFIN_IREE_SOURCE_DIR AND SHORTFIN_BUNDLE_DEPS) # TODO: We shouldn't have to pull googletest when we are not building tests. # This needs to be fixed with IREE. GIT_SUBMODULES "third_party/benchmark third_party/cpuinfo third_party/flatcc third_party/hip-build-deps third_party/googletest" + GIT_SHALLOW TRUE ) FetchContent_GetProperties(iree) if(NOT iree_POPULATED) @@ -110,7 +135,9 @@ if(SHORTFIN_IREE_SOURCE_DIR) set(IREE_HAL_DRIVER_DEFAULTS OFF) set(IREE_HAL_DRIVER_LOCAL_SYNC ON) set(IREE_HAL_DRIVER_LOCAL_TASK ON) - set(IREE_HAL_DRIVER_HIP ON) + if(SHORTFIN_SYSTEMS_AMDGPU) + set(IREE_HAL_DRIVER_HIP ON) + endif() add_subdirectory(${SHORTFIN_IREE_SOURCE_DIR} shortfin_iree SYSTEM EXCLUDE_FROM_ALL) else() # Try to find iree using find_package diff --git a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py b/libshortfin/bindings/python/_shortfin/asyncio_bridge.py index 0ef214527..63ded30e9 100644 --- a/libshortfin/bindings/python/_shortfin/asyncio_bridge.py +++ b/libshortfin/bindings/python/_shortfin/asyncio_bridge.py @@ -5,9 +5,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio -from collections.abc import Callable -from contextvars import Context -from typing_extensions import Unpack from . import lib as sfl diff --git a/libshortfin/bindings/python/array_binding.cc b/libshortfin/bindings/python/array_binding.cc index b7a3fb752..9858c2350 100644 --- a/libshortfin/bindings/python/array_binding.cc +++ b/libshortfin/bindings/python/array_binding.cc @@ -12,6 +12,53 @@ using namespace shortfin::array; namespace shortfin::python { +namespace { +static const char DOCSTRING_STORAGE_DATA[] = R"(Access raw binary contents. + +Accessing `foo = storage.data` is equivalent to `storage.data.map(read=True)`. +The returned object is a context manager that will close on exit. + +Assigning `storage.data = array.array("f", [1.0])` will copy that raw data +from the source object using the buffer protocol. The source data must be +less than or equal to the length of the storage object. Note that the entire +storage is mapped as write-only/discardable, and writing less than the storage +bytes leaves any unwritten contents in an undefined state. + +As with `map`, this will only work on buffers that are host visible, which +includes all host buffers and device buffers created with the necessary access. +)"; + +static const char DOCSTRING_STORAGE_MAP[] = + R"(Create a mapping of the buffer contents in host memory. + +Support kwargs of: + +read: Enables read access to the mapped memory. +write: Enables write access to the mapped memory and will flush upon close + (for non-unified memory systems). +discard: Indicates that the entire memory map should be treated as if it will + be overwritten. Initial contents will be undefined. + +Mapping memory for access from the host requires a compatible buffer that has +been created with host visibility (which includes host buffers). 
+ +The returned mapping object is a context manager that will close/flush on +exit. Alternatively, the `close()` method can be invoked explicitly. +)"; + +// Does in-place creation of a mapping object and stores a pointer to the +// contained array::mapping C++ object. +py::object CreateMappingObject(mapping **out_cpp_mapping) { + py::object py_mapping = py::inst_alloc(py::type()); + mapping *cpp_mapping = py::inst_ptr(py_mapping); + new (cpp_mapping) mapping(); + py::inst_mark_ready(py_mapping); + *out_cpp_mapping = cpp_mapping; + return py_mapping; +} + +} // namespace + void BindArray(py::module_ &m) { py::class_(m, "DType") .def_prop_ro("is_boolean", &DType::is_boolean) @@ -52,6 +99,7 @@ void BindArray(py::module_ &m) { m.attr("complex64") = DType::complex64(); m.attr("complex128") = DType::complex128(); + // storage py::class_(m, "storage") .def_static( "allocate_host", @@ -75,8 +123,83 @@ void BindArray(py::module_ &m) { PyBufferReleaser py_view_releaser(py_view); self.Fill(py_view.buf, py_view.len); }) + .def("copy_from", [](storage &self, storage &src) { self.CopyFrom(src); }) + .def( + "map", + [](storage &self, bool read, bool write, bool discard) { + int access = 0; + if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; + if (write) access |= IREE_HAL_MEMORY_ACCESS_WRITE; + if (discard) access |= IREE_HAL_MEMORY_ACCESS_DISCARD; + if (!access) { + throw std::invalid_argument( + "One of the access flags must be set"); + } + mapping *cpp_mapping = nullptr; + py::object py_mapping = CreateMappingObject(&cpp_mapping); + self.MapExplicit( + *cpp_mapping, + static_cast(access)); + return py_mapping; + }, + py::kw_only(), py::arg("read") = false, py::arg("write") = false, + py::arg("discard") = false, DOCSTRING_STORAGE_MAP) + // The 'data' prop is a short-hand for accessing the backing storage + // in a one-shot manner (as for reading or writing). Getting the attribute + // will map for read and return a memory view (equiv to map(read=True)). + // On write, it will accept an object implementing the buffer protocol + // and write/discard the backing storage. 
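+      // Illustrative Python usage (object name `buf` hypothetical), mirroring
+      // the docstrings above:
+      //   buf.data = array.array("f", [1.0] * 4)   # write via buffer protocol
+      //   with buf.map(read=True) as m:
+      //       contents = bytes(m)                  # read via explicit mapping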
+ .def_prop_rw( + "data", + [](storage &self) { + mapping *cpp_mapping = nullptr; + py::object py_mapping = CreateMappingObject(&cpp_mapping); + *cpp_mapping = self.MapRead(); + return py_mapping; + }, + [](storage &self, py::handle buffer_obj) { + PyBufferRequest src_info(buffer_obj, PyBUF_SIMPLE); + auto dest_data = self.MapWriteDiscard(); + if (src_info.view().len > dest_data.size()) { + throw std::invalid_argument( + fmt::format("Cannot write {} bytes into buffer of {} bytes", + src_info.view().len, dest_data.size())); + } + std::memcpy(dest_data.data(), src_info.view().buf, + src_info.view().len); + }, + DOCSTRING_STORAGE_DATA) .def("__repr__", &storage::to_s); + // mapping + auto mapping_class = py::class_(m, "mapping"); + mapping_class.def("close", &mapping::reset) + .def_prop_ro("valid", [](mapping &self) -> bool { return self; }) + .def("__enter__", [](py::object self_obj) { return self_obj; }) + .def( + "__exit__", + [](mapping &self, py::handle exc_type, py::handle exc_value, + py::handle exc_tb) { self.reset(); }, + py::arg("exc_type").none(), py::arg("exc_value").none(), + py::arg("exc_tb").none()); + struct MappingBufferHandler { + int operator()(mapping &self, Py_buffer *view, int flags) { + view->buf = self.data(); + view->len = self.size(); + view->readonly = self.writable(); + view->itemsize = 1; + view->format = (char *)"B"; // Byte + view->ndim = 1; + view->shape = nullptr; + view->strides = nullptr; + view->suboffsets = nullptr; + view->internal = nullptr; + return 0; + } + }; + BindBufferProtocol(mapping_class); + + // base_array and subclasses py::class_(m, "base_array") .def_prop_ro("dtype", &base_array::dtype) .def_prop_ro("shape", &base_array::shape); @@ -94,40 +217,33 @@ void BindArray(py::module_ &m) { std::span shape, DType dtype) { return custom_new_keep_alive( py_type, /*keep_alive=*/device.scope(), - device_array::allocate(device, shape, dtype)); + device_array::for_device(device, shape, dtype)); + }) + .def_static("for_device", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), /*keep_alive=*/device.scope(), + device_array::for_device(device, shape, dtype)); }) + .def_static("for_host", + [](local::ScopedDevice &device, std::span shape, + DType dtype) { + return custom_new_keep_alive( + py::type(), /*keep_alive=*/device.scope(), + device_array::for_host(device, shape, dtype)); + }) + .def("for_transfer", + [](device_array &self) { + return custom_new_keep_alive( + py::type(), + /*keep_alive=*/self.device().scope(), self.for_transfer()); + }) .def_prop_ro("device", &device_array::device, py::rv_policy::reference_internal) .def_prop_ro("storage", &device_array::storage, py::rv_policy::reference_internal) .def("__repr__", &device_array::to_s); - py::class_(m, "host_array") - .def("__init__", [](py::args, py::kwargs) {}) - .def_static("__new__", - [](py::handle py_type, class storage storage, - std::span shape, DType dtype) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/storage.scope(), storage, shape, - dtype); - }) - .def_static("__new__", - [](py::handle py_type, local::ScopedDevice &device, - std::span shape, DType dtype) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/device.scope(), - host_array::allocate(device, shape, dtype)); - }) - .def_static("__new__", - [](py::handle py_type, device_array &device_array) { - return custom_new_keep_alive( - py_type, /*keep_alive=*/device_array.device().scope(), - host_array::for_transfer(device_array)); - }) - 
.def_prop_ro("device", &host_array::device, - py::rv_policy::reference_internal) - .def_prop_ro("storage", &host_array::storage, - py::rv_policy::reference_internal) - .def("__repr__", &host_array::to_s); } } // namespace shortfin::python diff --git a/libshortfin/bindings/python/lib_ext.cc b/libshortfin/bindings/python/lib_ext.cc index 15070afaa..6072caa04 100644 --- a/libshortfin/bindings/python/lib_ext.cc +++ b/libshortfin/bindings/python/lib_ext.cc @@ -13,7 +13,9 @@ #include "shortfin/local/program.h" #include "shortfin/local/scope.h" #include "shortfin/local/system.h" +#if defined(SHORTFIN_HAVE_AMDGPU) #include "shortfin/local/systems/amdgpu.h" +#endif // SHORTFIN_HAVE_AMDGPU #include "shortfin/local/systems/host.h" #include "shortfin/support/globals.h" #include "shortfin/support/logging.h" @@ -94,22 +96,19 @@ class PyWorkerExtension : public local::Worker::Extension { py::gil_scoped_acquire g; loop_.reset(); - // Scrub thread state if not donated. - if (worker().options().owned_thread) { - PyThreadState_Clear(PyThreadState_Get()); - } else { - // Otherwise, juse reset the event loop. - refs_->asyncio_set_event_loop(py::none()); - refs_->asyncio_set_running_loop(py::none()); - } + // reset the event loop. + refs_->asyncio_set_event_loop(py::none()); + refs_->asyncio_set_running_loop(py::none()); } // And destroy our thread state (if not donated). - // TODO: PyThreadState_Delete seems like it should be used here, but I - // couldn't find that being done and I couldn't find a way to use it - // with the GIL/thread state correct. if (worker().options().owned_thread) { - PyThreadState_Swap(nullptr); + // Ordinarily PyGILState_Ensure must be balanced with PyGILState_Release, + // by PyThreadState_DeleteCurrent() implicitly releases it as part of + // its cleanup process. + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyThreadState_Clear(PyThreadState_Get()); + PyThreadState_DeleteCurrent(); } } @@ -150,29 +149,33 @@ class PyProcess : public local::detail::BaseProcess { std::bind(&PyProcess::RunOnWorker, self_object)); } static void RunOnWorker(py::handle self_handle) { - { - py::gil_scoped_acquire g; - // Steal the reference back from ScheduleOnWorker. Important: this is - // very likely the last reference to the process. So self must not be - // touched after self_object goes out of scope. - py::object self_object = py::steal(self_handle); - PyProcess *self = py::cast(self_handle); - // We assume that the run method either returns None (def) or a coroutine - // (async def). - auto coro = self_object.attr("run")(); - if (!coro.is_none()) { - auto task = self->refs_->asyncio_create_task(coro); - // Capture the self object to avoid lifetime hazzard with PyProcess - // going away before done. - task.attr("add_done_callback")( - py::cpp_function([self_object](py::handle future) { - PyProcess *done_self = py::cast(self_object); - done_self->Terminate(); - })); - } else { - // Synchronous termination. - self->Terminate(); - } + py::gil_scoped_acquire g; + // Steal the reference back from ScheduleOnWorker. Important: this is + // very likely the last reference to the process. So self must not be + // touched after self_object goes out of scope. + py::object self_object = py::steal(self_handle); + PyProcess *self = py::cast(self_handle); + // We assume that the run method either returns None (def) or a coroutine + // (async def). 
+ auto coro = self_object.attr("run")(); + if (!coro.is_none()) { + auto task = self->refs_->asyncio_create_task(coro); + // Capture the self object to avoid lifetime hazzard with PyProcess + // going away before done. + task.attr("add_done_callback")( + py::cpp_function([self_object](py::handle future) { + PyProcess *done_self = py::cast(self_object); + done_self->Terminate(); + // The result of the process future doesn't matter to us, but it + // may be carrying an exception and this is our only chance to + // bubble it. If it is, this will throw and be handled by the + // last chance exception handler in the worker. + // TODO: Route process termination and exceptions to a supervisor. + future.attr("result")(); + })); + } else { + // Synchronous termination. + self->Terminate(); } } @@ -238,7 +241,9 @@ NB_MODULE(lib, m) { auto local_m = m.def_submodule("local"); BindLocal(local_m); BindHostSystem(local_m); +#if defined(SHORTFIN_HAVE_AMDGPU) BindAMDGPUSystem(local_m); +#endif // SHORTFIN_HAVE_AMDGPU auto array_m = m.def_submodule("array"); BindArray(array_m); @@ -341,6 +346,8 @@ void BindLocal(py::module_ &m) { return self.CreateWorker(options); }, py::arg("name"), py::rv_policy::reference_internal) + .def_prop_ro("init_worker", &local::System::init_worker, + py::rv_policy::reference_internal) .def( "run", [refs](local::System &self, py::object coro) { @@ -709,6 +716,7 @@ void BindHostSystem(py::module_ &global_m) { py::class_(m, "HostCPUDevice"); } +#if defined(SHORTFIN_HAVE_AMDGPU) void BindAMDGPUSystem(py::module_ &global_m) { auto m = global_m.def_submodule("amdgpu", "AMDGPU system config"); py::class_(m, "AMDGPUDevice"); } +#endif // SHORTFIN_HAVE_AMDGPU } // namespace shortfin::python diff --git a/libshortfin/bindings/python/shortfin/array.py b/libshortfin/bindings/python/shortfin/array.py index e99595554..049fe9ed7 100644 --- a/libshortfin/bindings/python/shortfin/array.py +++ b/libshortfin/bindings/python/shortfin/array.py @@ -37,7 +37,6 @@ base_array = _sfl.array.base_array device_array = _sfl.array.device_array -host_array = _sfl.array.host_array storage = _sfl.array.storage DType = _sfl.array.DType @@ -73,7 +72,6 @@ # Classes. "base_array", "device_array", - "host_array", "storage", "DType", ] diff --git a/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py b/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py new file mode 100644 index 000000000..2cff38342 --- /dev/null +++ b/libshortfin/bindings/python/shortfin/interop/fastapi/__init__.py @@ -0,0 +1,113 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio + +try: + from fastapi import Request, Response + from fastapi.responses import StreamingResponse +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Shortfin fastapi interop requires fastapi to be installed" + ) from e + + +class FastAPIResponder: + """Bridge between FastAPI and shortfin that can be used to send out of band + responses back to a waiting FastAPI async request. + + This isn't really shortfin specific and can be used to bridge to any non + webserver owned loop. + + It is typically used by putting it in a Message that is sent to some processing + queue. Then return/awaiting it from an API callback. 
Example:
+
+    ```
+    @app.get("/predict")
+    async def predict(value: int, request: Request):
+        message = RequestMessage(value, FastAPIResponder(request))
+        system.request_writer(message)
+        return await message.responder.response
+    ```
+
+    See: examples/python/fastapi/server.py
+    """
+
+    def __init__(self, request: Request):
+        super().__init__()
+        self.request = request
+        # Capture the running loop so that we can send responses back.
+        self._loop = asyncio.get_running_loop()
+        self.response = asyncio.Future(loop=self._loop)
+        self._responded = False
+        self._streaming_queue: asyncio.Queue | None = None
+        self.is_disconnected = False
+
+    def close_with_error(self):
+        # Called in a failsafe fashion as part of exception handlers seeking to
+        # shutdown the response. If not yet responded, this will respond with
+        # a status code of 500. If streaming, then None will be streamed.
+        if self._responded:
+            if self._streaming_queue:
+                self.stream_part(None)
+        else:
+            self.send_response(Response(status_code=500))
+
+    def send_response(self, response: Response):
+        """Sends a response back for this transaction.
+
+        This is intended for sending single part responses back. See
+        start_response() for sending back a streaming, multi-part response.
+        """
+        assert not self._responded, "Response already sent"
+        if self._loop.is_closed():
+            raise IOError("Web server is shut down")
+        self._responded = True
+        self._loop.call_soon_threadsafe(self.response.set_result, response)
+
+    def start_response(self, **kwargs):
+        """Starts a streaming response, passing the given kwargs to the
+        fastapi.responses.StreamingResponse constructor.
+
+        This is appropriate to use for generating a sparse response stream as is
+        typical of chat apps. As it will hop threads for each part, other means should
+        be used for bulk transfer (i.e. by scheduling on the webserver loop
+        directly).
+        """
+        assert not self._responded, "Response already sent"
+        if self._loop.is_closed():
+            raise IOError("Web server is shut down")
+        self._responded = True
+        self._streaming_queue = asyncio.Queue()
+
+        async def gen(request, streaming_queue):
+            while True:
+                if await request.is_disconnected():
+                    self.is_disconnected = True
+                part = await streaming_queue.get()
+                if part is None:
+                    break
+                yield part
+
+        def start(request, streaming_queue, response_future):
+            response = StreamingResponse(gen(request, streaming_queue), **kwargs)
+            response_future.set_result(response)
+
+        self._loop.call_soon_threadsafe(
+            start, self.request, self._streaming_queue, self.response
+        )
+
+    def stream_part(self, content: bytes | None):
+        """Streams content to a response started with start_response().
+
+        Streaming must be ended by sending None.
+        """
+        assert self._streaming_queue is not None, "start_response() not called"
+        if self._loop.is_closed():
+            raise IOError("Web server is shut down")
+        self._loop.call_soon_threadsafe(self._streaming_queue.put_nowait, content)
+        if content is None:
+            self._streaming_queue = None
diff --git a/libshortfin/bindings/python/utils.h b/libshortfin/bindings/python/utils.h
index 24e2a6642..dab4423f8 100644
--- a/libshortfin/bindings/python/utils.h
+++ b/libshortfin/bindings/python/utils.h
@@ -14,10 +14,10 @@ namespace shortfin::python {
 // Casts any of int, str, local::Device, DeviceAffinity to a DeviceAffinity.
 // If the object is a sequence, then the affinity is constructed from the union.
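+// For example, the Python-level `scope.device(0)` call (see
+// examples/python/mobilenet_server/inference_system.py) routes through this
+// helper; a sequence argument such as `[0, 1]` (illustrative) unions the
+// affinities of its members.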
-inline local::ScopedDevice CastDeviceAffinity(local::Scope &scope, +inline local::ScopedDevice CastDeviceAffinity(local::Scope& scope, py::handle object) { if (py::isinstance(object)) { - return scope.device(py::cast(object)); + return scope.device(py::cast(object)); } else if (py::isinstance(object)) { return local::ScopedDevice(scope, py::cast(object)); } else if (py::isinstance(object)) { @@ -39,4 +39,58 @@ inline local::ScopedDevice CastDeviceAffinity(local::Scope &scope, py::repr(object).c_str())); } +// For a bound class, binds the buffer protocol. This will result in a call +// to handler like: +// HandlerFunctor(self, Py_buffer *view, int flags) +// This is a low level callback and must not raise any exceptions. If +// error conditions are warranted the usual PyErr_SetString approach must be +// used (and -1 returned). Return 0 on success. +template +void BindBufferProtocol(py::handle clazz) { + PyBufferProcs buffer_procs; + memset(&buffer_procs, 0, sizeof(buffer_procs)); + buffer_procs.bf_getbuffer = + // It is not legal to raise exceptions from these callbacks. + +[](PyObject* raw_self, Py_buffer* view, int flags) noexcept -> int { + if (view == NULL) { + PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer"); + return -1; + } + + // Cast must succeed due to invariants. + auto& self = py::cast(py::handle(raw_self)); + + Py_INCREF(raw_self); + view->obj = raw_self; + HandlerFunctor handler; + return handler(self, view, flags); + }; + buffer_procs.bf_releasebuffer = + +[](PyObject* raw_self, Py_buffer* view) noexcept -> void {}; + auto heap_type = reinterpret_cast(clazz.ptr()); + assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && + "must be heap type"); + heap_type->as_buffer = buffer_procs; +} + +// Represents a Py_buffer obtained via PyObject_GetBuffer() and terminated via +// PyBuffer_Release(). +class PyBufferRequest { + public: + PyBufferRequest(py::handle& exporter, int flags) { + int rc = PyObject_GetBuffer(exporter.ptr(), &view_, flags); + if (rc != 0) { + throw py::python_error(); + } + } + ~PyBufferRequest() { PyBuffer_Release(&view_); } + PyBufferRequest(const PyBufferRequest&) = delete; + void operator=(const PyBufferRequest&) = delete; + + Py_buffer& view() { return view_; } + + private: + Py_buffer view_; +}; + } // namespace shortfin::python diff --git a/libshortfin/examples/python/fastapi/server.py b/libshortfin/examples/python/fastapi/server.py new file mode 100644 index 000000000..66ab37b75 --- /dev/null +++ b/libshortfin/examples/python/fastapi/server.py @@ -0,0 +1,133 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import asyncio +import traceback +from contextlib import asynccontextmanager +import json +import threading +import sys + +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse +import shortfin as sf +from shortfin.interop.fastapi import FastAPIResponder +import uvicorn + + +class RequestMessage(sf.Message): + def __init__(self, request_value: int, responder: FastAPIResponder): + super().__init__() + self.request_value = request_value + self.responder = responder + + +class System: + def __init__(self): + self.ls = sf.host.CPUSystemBuilder().create_system() + # TODO: Come up with an easier bootstrap thing than manually + # running a thread. 
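+        # The shortfin system loop runs on this dedicated thread; results are
+        # handed back to the uvicorn event loop by FastAPIResponder via
+        # call_soon_threadsafe (see shortfin.interop.fastapi above).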
+ self.t = threading.Thread(target=lambda: self.ls.run(self.run())) + self.request_queue = self.ls.create_queue("request") + self.request_writer = self.request_queue.writer() + + def start(self): + self.t.start() + + def shutdown(self): + self.request_queue.close() + + async def run(self): + print("*** Sytem Running ***") + request_reader = self.request_queue.reader() + while request := await request_reader(): + try: + responder = request.responder + if request.request_value == 0: + raise ValueError("Something broke") + elif request.request_value > 20: + responder.send_response(Response(status_code=400)) + elif request.request_value == 1: + # Send a single response. + responder.send_response( + JSONResponse({"answer": request.request_value}) + ) + else: + # Stream responses from 0..value + responder.start_response() + for i in range(request.request_value + 1): + if responder.is_disconnected: + continue + responder.stream_part( + (json.dumps({"answer": i}) + "\n\0").encode() + ) + await asyncio.sleep(0.01) + responder.stream_part(None) + except Exception as e: + responder.close_with_error() + traceback.print_exc() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + system.start() + yield + print("Shutting down shortfin") + system.shutdown() + + +system = System() +app = FastAPI(lifespan=lifespan) + + +@app.get("/predict") +async def predict(value: int, request: Request): + message = RequestMessage(value, FastAPIResponder(request)) + system.request_writer(message) + return await message.responder.response + + +@app.get("/health") +async def health() -> Response: + return Response(status_code=200) + + +def main(argv): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="Root path to use for installing behind path based proxy.", + ) + parser.add_argument( + "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" + ) + parser.add_argument( + "--testing-mock-service", + action="store_true", + help="Enable the mock testing service", + ) + parser.add_argument( + "--device-uri", type=str, default="local-task", help="Device URI to serve on" + ) + + args = parser.parse_args(argv) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=args.timeout_keep_alive, + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/libshortfin/examples/python/http/http_server.py b/libshortfin/examples/python/http/http_server.py deleted file mode 100644 index 43b62f06d..000000000 --- a/libshortfin/examples/python/http/http_server.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import argparse -import asyncio -from contextlib import asynccontextmanager -import threading -import sys - -from fastapi import FastAPI, Request, Response -from fastapi.responses import JSONResponse, StreamingResponse -import shortfin as sf -import uvicorn - - -class FastAPIResponder(sf.Message): - """Bridge between FastAPI and shortfin that can be put on a queue and used to - send a response back at an arbitrary point. - - This object is constructed in a FastAPI handler, capturing the current event loop - used by the web server. 
Then it can be put on a shortfin Queue and once within - a shortfin worker, an arbitrary worker can call `send_response` to send a simple - FastAPI response back to the webserver loop and onto the client. - - """ - - def __init__(self, request: Request): - super().__init__() - self.request = request - # Capture the running loop so that we can send responses back. - self._loop = asyncio.get_running_loop() - self.response = asyncio.Future(loop=self._loop) - self._responded = False - self._streaming_queue: asyncio.Queue | None = None - self.is_disconnected = False - - def send_response(self, response: Response): - """Sends a response back for this transaction. - - This is intended for sending single part responses back. See - start_response() for sending back a streaming, multi-part response. - """ - assert not self._responded, "Response already sent" - if self._loop.is_closed(): - raise IOError("Web server is shut down") - self._responded = True - self._loop.call_soon_threadsafe(self.response.set_result, response) - - def start_response(self, **kwargs): - """Starts a streaming response, passing the given kwargs to the - fastapi.responses.StreamingResponse constructor. - - This is appropriate to use for generating a sparse response stream as is - typical of chat apps. As it will hop threads for each part, other means should - be used for bulk transfer (i.e. by scheduling on the webserver loop - directly). - """ - assert not self._responded, "Response already sent" - if self._loop.is_closed(): - raise IOError("Web server is shut down") - self._responded = True - self._streaming_queue = asyncio.Queue() - - async def gen(): - while True: - if await self.request.is_disconnected(): - self.is_disconnected = True - part = await self._streaming_queue.get() - if part is None: - break - yield part - - def start(): - response = StreamingResponse(gen(), **kwargs) - self.response.set_result(response) - - self._loop.call_soon_threadsafe(start) - - def stream_part(self, content: bytes | None): - """Streams content to a response started with start_response(). - - Streaming must be ended by sending None. - """ - assert self._streaming_queue is not None, "start_response() not called" - if self._loop.is_closed(): - raise IOError("Web server is shut down") - self._loop.call_soon_threadsafe(self._streaming_queue.put_nowait, content) - - -class System: - def __init__(self): - self.ls = sf.host.CPUSystemBuilder().create_system() - # TODO: Come up with an easier bootstrap thing than manually - # running a thread. 
- self.t = threading.Thread(target=lambda: self.ls.run(self.run())) - self.request_queue = self.ls.create_queue("request") - self.request_writer = self.request_queue.writer() - - def start(self): - self.t.start() - - def shutdown(self): - self.request_queue.close() - - async def run(self): - print("*** Sytem Running ***") - request_reader = self.request_queue.reader() - while responder := await request_reader(): - print("Got request:", responder) - # Can send a single response: - # request.send_response(JSONResponse({"answer": 42})) - # Or stream: - responder.start_response() - for i in range(20): - if responder.is_disconnected: - print("Cancelled!") - break - responder.stream_part(f"Iteration {i}\n".encode()) - await asyncio.sleep(0.2) - else: - responder.stream_part(None) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - system.start() - yield - print("Shutting down shortfin") - system.shutdown() - - -system = System() -app = FastAPI(lifespan=lifespan) - - -@app.get("/predict") -async def predict(request: Request): - transaction = FastAPIResponder(request) - system.request_writer(transaction) - return await transaction.response - - -def main(argv): - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default=None) - parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="Root path to use for installing behind path based proxy.", - ) - parser.add_argument( - "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" - ) - parser.add_argument( - "--testing-mock-service", - action="store_true", - help="Enable the mock testing service", - ) - parser.add_argument( - "--device-uri", type=str, default="local-task", help="Device URI to serve on" - ) - - args = parser.parse_args(argv) - - uvicorn.run( - app, - host=args.host, - port=args.port, - log_level="debug", - timeout_keep_alive=args.timeout_keep_alive, - ) - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/libshortfin/examples/python/mobilenet_server/inference_system.py b/libshortfin/examples/python/mobilenet_server/inference_system.py new file mode 100644 index 000000000..8ae7773db --- /dev/null +++ b/libshortfin/examples/python/mobilenet_server/inference_system.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +from pathlib import Path +import sys + +import shortfin as sf +import shortfin.array as sfnp + +MAX_BATCH = 1 + + +class InferenceRequest(sf.Message): + def __init__(self, raw_image_data): + super().__init__() + self.raw_image_data = raw_image_data + + +class InferenceProcess(sf.Process): + def __init__(self, program, request_queue, **kwargs): + super().__init__(**kwargs) + self.program = program + self.request_reader = request_queue.reader() + self.device = self.scope.device(0) + self.device_input = sfnp.device_array( + self.device, [MAX_BATCH, 3, 224, 224], sfnp.float32 + ) + self.host_staging = self.device_input.for_transfer() + + async def run(self): + print(f"Inference process: {self.pid}") + while request := await self.request_reader(): + print(f"[{self.pid}] Got request {request}") + # TODO: Should really be taking a slice and writing that. For now, + # just writing to the backing storage is the best we have API + # support for. 
Generally, APIs on storage should be mirrored onto + # the array. + self.host_staging.storage.data = request.raw_image_data + print(self.host_staging) + self.device_input.storage.copy_from(self.host_staging.storage) + print(self.device_input) + + +class Main: + def __init__(self, lsys: sf.System, home_dir: Path): + self.processes_per_worker = 1 + self.lsys = lsys + self.home_dir = home_dir + self.request_queue = lsys.create_queue("request") + self.program_module = self.lsys.load_module(home_dir / "model.vmfb") + print(f"Loaded: {self.program_module}") + self.processes = [] + + async def start_scope(self, scope): + # Note that currently, program load is synchronous. But we do it + # in a task so we can await it in the future and let program loads + # overlap. + program = scope.load_unbound_program([self.program_module]) + for _ in range(self.processes_per_worker): + self.processes.append( + InferenceProcess(program, self.request_queue, scope=scope).launch() + ) + + async def main(self): + devices = self.lsys.devices + print( + f"System created with {len(devices)} devices:\n " + f"{' '.join(repr(d) for d in devices)}" + ) + # We create a physical worker and initial scope for each device. + # This isn't a hard requirement and there are advantages to other + # topologies. + initializers = [] + for device in devices: + worker = self.lsys.create_worker(f"device-{device.name}") + scope = self.lsys.create_scope(worker, devices=[device]) + initializers.append(self.start_scope(scope)) + + # Run all initializers in parallel. These launch inference processes. + print("Waiting for initializers") + await asyncio.gather(*initializers) + + # Wait for inference processes to end. + print(f"Running {len(self.processes)} inference processes") + await asyncio.gather(*self.processes) + print("Inference processors completed") + + +def run_cli(home_dir: Path, argv): + def client(): + # Create a random image. + print("Preparing requests...") + writer = main.request_queue.writer() + + # Dumb way to prepare some data to feed [1, 3, 224, 224] f32. + import array + + dummy_data = array.array( + "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) + ) + # dummy_data = array.array("f", [0.2] * (3 * 224 * 224)) + message = InferenceRequest(dummy_data) + writer(message) + + # Done. + writer.close() + + lsys = sf.host.CPUSystemBuilder().create_system() + main = Main(lsys, home_dir) + lsys.init_worker.call_threadsafe(client) + lsys.run(main.main()) + + +if __name__ == "__main__": + home_dir = Path(__file__).resolve().parent + run_cli(home_dir, sys.argv[1:]) diff --git a/libshortfin/examples/python/mobilenet_server/server.py b/libshortfin/examples/python/mobilenet_server/server.py deleted file mode 100644 index c8f6484bf..000000000 --- a/libshortfin/examples/python/mobilenet_server/server.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import asyncio -from pathlib import Path - -import shortfin as sf - - -class InferenceProcess(sf.Process): - def __init__(self, program, **kwargs): - super().__init__(**kwargs) - self.program = program - - async def run(self): - print(f"Inference process: {self.pid}") - - -class Main: - def __init__(self, lsys: sf.System, home_dir: Path): - self.processes_per_worker = 4 - self.lsys = lsys - self.home_dir = home_dir - self.program_module = self.lsys.load_module(home_dir / "model.vmfb") - print(f"Loaded: {self.program_module}") - self.processes = [] - - async def initialize(self, scope): - # Note that currently, program load is synchronous. But we do it - # in a task so we can await it in the future and let program loads - # overlap. - program = scope.load_unbound_program([self.program_module]) - for _ in range(self.processes_per_worker): - self.processes.append(InferenceProcess(program, scope=scope).launch()) - - async def main(self): - devices = self.lsys.devices - print( - f"System created with {len(devices)} devices:\n " - f"{' '.join(repr(d) for d in devices)}" - ) - # We create a physical worker and initial scope for each device. - # This isn't a hard requirement and there are advantages to other - # topologies. - initializers = [] - for device in devices: - worker = self.lsys.create_worker(f"device-{device.name}") - scope = self.lsys.create_scope(worker, devices=[device]) - initializers.append(self.initialize(scope)) - - # Run all initializers in parallel. These launch inference processes. - await asyncio.gather(*initializers) - - # Wait for inference processes to end. - await asyncio.gather(*self.processes) - - -def run_server(home_dir: Path): - lsys = sf.host.CPUSystemBuilder().create_system() - main = Main(lsys, home_dir) - lsys.run(main.main()) - - -if __name__ == "__main__": - home_dir = Path(__file__).resolve().parent - run_server(home_dir) diff --git a/libshortfin/pyproject.toml b/libshortfin/pyproject.toml index 5185be707..e868b4264 100644 --- a/libshortfin/pyproject.toml +++ b/libshortfin/pyproject.toml @@ -13,6 +13,9 @@ addopts = [ "-ra", "--import-mode=importlib", ] +markers = [ + "requires_amd_gpu: tests that require and AMD GPU (deselect with '-m \"not requires_amd_gpu\"')", +] testpaths = [ "tests", ] diff --git a/libshortfin/requirements-tests.txt b/libshortfin/requirements-tests.txt new file mode 100644 index 000000000..50bdd9831 --- /dev/null +++ b/libshortfin/requirements-tests.txt @@ -0,0 +1,5 @@ +nanobind==2.0.0 +pytest +requests +fastapi +uvicorn diff --git a/libshortfin/setup.py b/libshortfin/setup.py index f3fc4e9b9..4f4074d2f 100644 --- a/libshortfin/setup.py +++ b/libshortfin/setup.py @@ -5,9 +5,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from distutils.core import setup, Extension +import sys +import shutil +import subprocess import os from pathlib import Path +from distutils.command.build import build as _build from setuptools.command.build_ext import build_ext as _build_ext +from setuptools.command.build_py import build_py as _build_py # This file can be generated into the build directory to allow an arbitrary @@ -18,35 +23,45 @@ CPP_PREBUILT_SOURCE_DIR = "@libshortfin_SOURCE_DIR@" CPP_PREBUILT_BINARY_DIR = "@libshortfin_BINARY_DIR@" +SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) + def is_cpp_prebuilt(): return CPP_PREBUILT == "TRUE" -def native_build(): - if is_cpp_prebuilt(): - print("setup.py running in pre-built mode from:") - print(f" SOURCE_DIR = 
{CPP_PREBUILT_SOURCE_DIR}") - print(f" BINARY_DIR = {CPP_PREBUILT_BINARY_DIR}") - return Path(CPP_PREBUILT_SOURCE_DIR), Path(CPP_PREBUILT_BINARY_DIR) - raise RuntimeError("Packaging currently only supported in pre-built mode") - +if is_cpp_prebuilt(): + print("setup.py running in pre-built mode:", file=sys.stderr) + SOURCE_DIR = Path(CPP_PREBUILT_SOURCE_DIR) + BINARY_DIR = Path(CPP_PREBUILT_BINARY_DIR) +else: + print("setup.py running in cmake build mode:", file=sys.stderr) + # setup.py is in the source directory. + SOURCE_DIR = Path(SETUPPY_DIR) + BINARY_DIR = Path(os.path.join(SETUPPY_DIR, "build", "b")) -source_dir, binary_dir = native_build() +print(f" SOURCE_DIR = {SOURCE_DIR}", file=sys.stderr) +print(f" BINARY_DIR = {BINARY_DIR}", file=sys.stderr) # Due to a quirk of setuptools, that package_dir map must only contain # paths relative to the directory containing setup.py. Why? No one knows. -current_dir = Path(__file__).resolve().parent -rel_source_dir = source_dir.relative_to(current_dir, walk_up=True) -rel_binary_dir = binary_dir.relative_to(current_dir, walk_up=True) +REL_SOURCE_DIR = SOURCE_DIR.relative_to(SETUPPY_DIR, walk_up=True) +REL_BINARY_DIR = BINARY_DIR.relative_to(SETUPPY_DIR, walk_up=True) -class BuiltExtension(Extension): +class CMakeExtension(Extension): def __init__(self, name, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) +class CustomBuild(_build): + def run(self): + self.run_command("build_py") + self.run_command("build_ext") + self.run_command("build_scripts") + + class NoopBuildExtension(_build_ext): def build_extension(self, ext): ... @@ -55,8 +70,127 @@ def copy_extensions_to_source(self, *args, **kwargs): ... -python_src_dir = rel_source_dir / "bindings" / "python" -python_bin_dir = rel_binary_dir / "bindings" / "python" +def maybe_nuke_cmake_cache(cmake_build_dir): + # From run to run under pip, we can end up with different paths to ninja, + # which isn't great and will confuse cmake. Detect if the location of + # ninja changes and force a cache flush. + ninja_path = "" + try: + import ninja + except ModuleNotFoundError: + pass + else: + ninja_path = ninja.__file__ + expected_stamp_contents = f"{sys.executable}\n{ninja_path}" + + # In order to speed things up on CI and not rebuild everything, we nuke + # the CMakeCache.txt file if the path to the Python interpreter changed. + # Ideally, CMake would let us reconfigure this dynamically... but it does + # not (and gets very confused). + PYTHON_STAMP_FILE = os.path.join(cmake_build_dir, "python_stamp.txt") + if os.path.exists(PYTHON_STAMP_FILE): + with open(PYTHON_STAMP_FILE, "rt") as f: + actual_stamp_contents = f.read() + if actual_stamp_contents == expected_stamp_contents: + # All good. + return + + # Mismatch or not found. Clean it. + cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt") + if os.path.exists(cmake_cache_file): + print("Removing CMakeCache.txt because Python version changed", file=sys.stderr) + os.remove(cmake_cache_file) + + # And write. + with open(PYTHON_STAMP_FILE, "wt") as f: + f.write(expected_stamp_contents) + + +class CMakeBuildPy(_build_py): + def run(self): + # The super-class handles the pure python build. + super().run() + + # Build using cmake if not in prebuild mode. + if not is_cpp_prebuilt(): + + # Build extension using cmake. 
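+            # Roughly equivalent to running, from BINARY_DIR (paths and flags
+            # abbreviated; see cmake_args below):
+            #   cmake -GNinja -DSHORTFIN_BUILD_PYTHON_BINDINGS=ON <SOURCE_DIR>
+            #   cmake --build .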
+ print("*****************************", file=sys.stderr) + print("* Building libshortfin *", file=sys.stderr) + print("*****************************", file=sys.stderr) + + cfg = os.getenv("SHORTFIN_CMAKE_BUILD_TYPE", "Release") + + CMAKE_BUILD_DIR = BINARY_DIR + + # Configure CMake. + os.makedirs(BINARY_DIR, exist_ok=True) + maybe_nuke_cmake_cache(CMAKE_BUILD_DIR) + print(f"CMake build dir: {CMAKE_BUILD_DIR}", file=sys.stderr) + cmake_args = [ + "-GNinja", + "--log-level=VERBOSE", + "-DSHORTFIN_BUNDLE_DEPS=ON", + f"-DCMAKE_BUILD_TYPE={cfg}", + "-DSHORTFIN_BUILD_PYTHON_BINDINGS=ON", + # TODO: This shouldn't be hardcoded... but shortfin doesn't + # compile without it. + "-DCMAKE_C_COMPILER=clang", + "-DCMAKE_CXX_COMPILER=clang++", + ] + + # Only do a from-scratch configure if not already configured. + cmake_cache_file = os.path.join(CMAKE_BUILD_DIR, "CMakeCache.txt") + if not os.path.exists(cmake_cache_file): + print(f"Configuring with: {cmake_args}", file=sys.stderr) + subprocess.check_call( + ["cmake", SOURCE_DIR] + cmake_args, cwd=CMAKE_BUILD_DIR + ) + else: + print(f"Not re-configing (already configured)", file=sys.stderr) + + # Build. + subprocess.check_call(["cmake", "--build", "."], cwd=CMAKE_BUILD_DIR) + print("Build complete.", file=sys.stderr) + + # We only take _shortfin_default from the build. + target_dir = os.path.join( + os.path.abspath(self.build_lib), "_shortfin_default" + ) + print(f"Building in target: {target_dir}", file=sys.stderr) + os.makedirs(target_dir, exist_ok=True) + print("Copying build to target.", file=sys.stderr) + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + shutil.copytree( + os.path.join( + CMAKE_BUILD_DIR, + "bindings", + "python", + "_shortfin_default", + ), + target_dir, + symlinks=False, + ) + + +PYTHON_SOURCE_DIR = REL_SOURCE_DIR / "bindings" / "python" +PYTHON_BINARY_DIR = REL_BINARY_DIR / "bindings" / "python" + +# We need some directories to exist before setup. +def populate_built_package(abs_dir): + """Makes sure that a directory and __init__.py exist. + + This needs to unfortunately happen before any of the build process + takes place so that setuptools can plan what needs to be built. + We do this for any built packages (vs pure source packages). + """ + os.makedirs(abs_dir, exist_ok=True) + with open(os.path.join(abs_dir, "__init__.py"), "wt"): + pass + + +populate_built_package(os.path.join(PYTHON_BINARY_DIR / "_shortfin_default")) setup( name="shortfin", @@ -71,16 +205,18 @@ def copy_extensions_to_source(self, *args, **kwargs): ], zip_safe=False, package_dir={ - "_shortfin": str(python_src_dir / "_shortfin"), - "_shortfin_default": str(python_bin_dir / "_shortfin_default"), + "_shortfin": str(PYTHON_SOURCE_DIR / "_shortfin"), + "_shortfin_default": str(PYTHON_BINARY_DIR / "_shortfin_default"), # TODO: Conditionally map additional native library variants. - "shortfin": str(python_src_dir / "shortfin"), + "shortfin": str(PYTHON_SOURCE_DIR / "shortfin"), }, ext_modules=[ - BuiltExtension("_shortfin_default.lib"), + CMakeExtension("_shortfin_default.lib") # TODO: Conditionally map additional native library variants. 
], cmdclass={ + "build": CustomBuild, "build_ext": NoopBuildExtension, + "build_py": CMakeBuildPy, }, ) diff --git a/libshortfin/src/CMakeLists.txt b/libshortfin/src/CMakeLists.txt index 1a69094e0..de31643e3 100644 --- a/libshortfin/src/CMakeLists.txt +++ b/libshortfin/src/CMakeLists.txt @@ -13,6 +13,11 @@ target_include_directories( $) +set(_INIT_INTERNAL_DEPS) +if(SHORTFIN_SYSTEMS_AMDGPU) + list(APPEND _INIT_INTERNAL_DEPS shortfin_systems_amdgpu) +endif() + shortfin_public_library( NAME shortfin @@ -20,6 +25,8 @@ shortfin_public_library( shortfin_array shortfin_local shortfin_support - shortfin_systems_amdgpu shortfin_systems_host + ${_INIT_INTERNAL_DEPS} ) + +set_target_properties(shortfin PROPERTIES VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} SOVERSION ${SOVERSION}) diff --git a/libshortfin/src/shortfin/array/CMakeLists.txt b/libshortfin/src/shortfin/array/CMakeLists.txt index da22e3cc0..0e9360363 100644 --- a/libshortfin/src/shortfin/array/CMakeLists.txt +++ b/libshortfin/src/shortfin/array/CMakeLists.txt @@ -10,13 +10,25 @@ shortfin_cc_component( HDRS array.h api.h + dims.h dtype.h storage.h SRCS array.cc dtype.cc storage.cc + xtensor_bridge.cc COMPONENTS shortfin_local shortfin_support + DEPS + xtensor +) + +shortfin_gtest_test( + NAME shortfin_array_test + SRCS + array_test.cc + dims_test.cc + dtype_test.cc ) diff --git a/libshortfin/src/shortfin/array/api.h b/libshortfin/src/shortfin/array/api.h index e7f73ede4..baa8a55ea 100644 --- a/libshortfin/src/shortfin/array/api.h +++ b/libshortfin/src/shortfin/array/api.h @@ -8,7 +8,9 @@ #define SHORTFIN_ARRAY_API_H #include "shortfin/array/array.h" +#include "shortfin/array/dims.h" #include "shortfin/array/dtype.h" #include "shortfin/array/storage.h" +#include "shortfin/array/xtensor_bridge.h" #endif // SHORTFIN_ARRAY_API_H diff --git a/libshortfin/src/shortfin/array/array.cc b/libshortfin/src/shortfin/array/array.cc index 74d20e47e..1d6d7cc5a 100644 --- a/libshortfin/src/shortfin/array/array.cc +++ b/libshortfin/src/shortfin/array/array.cc @@ -6,29 +6,56 @@ #include "shortfin/array/array.h" +#include + #include "fmt/core.h" #include "fmt/ranges.h" +#include "shortfin/array/xtensor_bridge.h" namespace shortfin::array { +template class InlinedDims; + // -------------------------------------------------------------------------- // // device_array // -------------------------------------------------------------------------- // -std::string device_array::to_s() const { - return fmt::format("device_array([{}], dtype='{}', {})", - fmt::join(shape(), ", "), dtype().name(), - storage_.device().to_s()); -} +const mapping device_array::data() const { return storage_.MapRead(); } -// -------------------------------------------------------------------------- // -// host_array -// -------------------------------------------------------------------------- // +mapping device_array::data() { return storage_.MapRead(); } + +mapping device_array::data_rw() { return storage_.MapReadWrite(); } + +mapping device_array::data_w() { return storage_.MapWriteDiscard(); } -std::string host_array::to_s() const { - return fmt::format("host_array([{}], dtype='{}', {})", +std::optional device_array::map_memory_for_xtensor() { + if (storage_.is_mappable_for_read_write()) { + return storage_.MapReadWrite(); + } else if (storage_.is_mappable_for_read()) { + return storage_.MapRead(); + } + return {}; +} + +std::string device_array::to_s() const { + std::string contents; + const char *contents_prefix = " "; + if (!storage_.is_mappable_for_read()) { + contents 
= ""; + } else { + auto maybe_contents = contents_to_s(); + if (maybe_contents) { + contents = std::move(*maybe_contents); + contents_prefix = "\n"; + } else { + contents = ""; + } + } + + return fmt::format("device_array([{}], dtype='{}', device={}({})) ={}{}", fmt::join(shape(), ", "), dtype().name(), - storage_.device().to_s()); + storage_.device().to_s(), storage_.formatted_memory_type(), + contents_prefix, contents); } } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/array.h b/libshortfin/src/shortfin/array/array.h index 1e0ea80d2..c3ab6e302 100644 --- a/libshortfin/src/shortfin/array/array.h +++ b/libshortfin/src/shortfin/array/array.h @@ -12,8 +12,10 @@ #include #include +#include "shortfin/array/dims.h" #include "shortfin/array/dtype.h" #include "shortfin/array/storage.h" +#include "shortfin/array/xtensor_bridge.h" #include "shortfin/support/api.h" namespace shortfin::array { @@ -28,129 +30,98 @@ class SHORTFIN_API base_array { // a value type because the Dims union is otherwise not copy/movable. base_array(const base_array &other) : base_array(other.shape(), other.dtype()) {} - base_array(base_array &&other) : rank_(other.rank_), dtype_(other.dtype_) { - // Custom move the dims to avoid an additional allocation. This could just - // be a memcpy on most impls, but this is the "right way". - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - new (&shape_.dynamic_dims) Dims(); - shape_.dynamic_dims = std::move(other.shape_.dynamic_dims); - } else { - // Inline allocation. - new (&shape_.inline_dims) Dims(); - shape_.inline_dims = other.shape_.inline_dims; - } - other.rank_ = 0; - } - virtual ~base_array() { ClearDims(); } + base_array(base_array &&other) + : dtype_(other.dtype_), shape_(std::move(other.shape_)) {} + virtual ~base_array() = default; + virtual std::string to_s() const = 0; DType dtype() const { return dtype_; } // Access shape. - void set_shape(std::span shape) { - ClearDims(); - rank_ = shape.size(); - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - new (&shape_.dynamic_dims) std::unique_ptr(new size_t[rank_]); - std::copy(shape.begin(), shape.end(), shape_.dynamic_dims.get()); - } else { - // Inline allocation. - new (&shape_.inline_dims) Dims(); - std::copy(shape.begin(), shape.end(), shape_.inline_dims.begin()); - } - } - std::span shape() const { - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - return std::span(shape_.dynamic_dims.get(), rank_); - } else { - // Inline allocation. - return std::span(&shape_.inline_dims.front(), rank_); - } - } - std::span mutable_shape() { - if (rank_ > MAX_INLINE_RANK) { - // Dynamic allocation. - return std::span(shape_.dynamic_dims.get(), rank_); - } else { - // Inline allocation. - return std::span(&shape_.inline_dims.front(), rank_); - } - } + void set_shape(std::span shape) { shape_.set(shape); } + std::span shape() const { return shape_.span(); } + std::span mutable_shape() { return shape_.span(); } - private: - static constexpr size_t MAX_INLINE_RANK = 6; - union Dims { - Dims() {} - ~Dims() {} - std::array inline_dims; - std::unique_ptr dynamic_dims; - }; - - // Clears shape, setting the rank to zero and deleting any non-inline - // dimension storage. - void ClearDims() { - if (rank_ > MAX_INLINE_RANK) { - shape_.dynamic_dims.~unique_ptr(); - } - rank_ = 0; - } + // Sometimes we need to access the raw shape container (i.e. for adapters, + // etc). 
+ Dims &shape_container() { return shape_; } + const Dims &shape_container() const { return shape_; } - size_t rank_ = 0; + private: DType dtype_; Dims shape_; }; -// View over some device allocation, modeled as a dense C-order nd array. -class SHORTFIN_API device_array final : public base_array { +class SHORTFIN_API device_array + : public base_array, + public poly_xt_mixin { public: device_array(class storage storage, std::span shape, DType dtype) : base_array(shape, dtype), storage_(std::move(storage)) {} - static device_array allocate(local::ScopedDevice &device, - std::span shape, DType dtype) { + class storage &storage() { return storage_; } + local::ScopedDevice &device() { return storage_.device(); } + + // Allocate an array on the device. + static device_array for_device(local::ScopedDevice &device, + std::span shape, DType dtype) { return device_array( storage::AllocateDevice(device, dtype.compute_dense_nd_size(shape)), shape, dtype); } - class storage &storage() { return storage_; } - local::ScopedDevice &device() { return storage_.device(); } - std::string to_s() const; - - private: - class storage storage_; -}; - -// View over some host allocation, registered for transfer to/from the -// device. -// These arrays can either be allocated directly or ::for_transfer with -// a corresponding device_array. -class SHORTFIN_API host_array final : public base_array { - public: - host_array(class storage storage, std::span shape, DType dtype) - : base_array(shape, dtype), storage_(std::move(storage)) {} - - static host_array allocate(local::ScopedDevice &device, - std::span shape, DType dtype) { - return host_array( + // Allocates a host array that is registered by the device. This can include + // arrays that are visible from different combinations of host/device. + static device_array for_host(local::ScopedDevice &device, + std::span shape, DType dtype) { + return device_array( storage::AllocateHost(device, dtype.compute_dense_nd_size(shape)), shape, dtype); } - // Allocates a host array for transfer to/from the given device array. - static host_array for_transfer(device_array &with_device_array) { - return allocate(with_device_array.storage().device(), - with_device_array.shape(), with_device_array.dtype()); + // Allocates a host array for transfer to/from this array. + device_array for_transfer() { + return for_host(storage().device(), shape(), dtype()); } - class storage &storage() { return storage_; } - local::ScopedDevice &device() { return storage_.device(); } - std::string to_s() const; + // Untyped access to the backing data. The array must be mappable. Specific + // access modes: + // * data(): Read-only access to the data. + // * data_rw(): Read/write access to the data. + // * data_w(): Write-only access to the data with discard (initial contents + // are undefined.) + const mapping data() const; + mapping data(); + // Map the array's data for read-write untyped access. + mapping data_rw(); + // Map the array's data for write-only untyped access. + mapping data_w(); + + // Maps memory for bridging to xtensor. If mapping is unsupported, return {}. + std::optional map_memory_for_xtensor(); + + // Typed access to the backing data. 
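As a quick illustration of the access modes above, a minimal sketch (assuming a host-mappable float32 device_array named `ary`; this mirrors the pattern used in array_test.cc below):

    {
      // Write-only, discard mapping: contents are undefined until filled.
      auto w = ary.typed_data_w<float>();
      std::fill(w.begin(), w.end(), 1.0f);
    }  // mapping unmaps here when `w` goes out of scope
    {
      // Read-only, untyped mapping of the same storage.
      mapping r = ary.data();
      const uint8_t *bytes = r.data();
      iree_device_size_t len = r.size();
      (void)bytes; (void)len;
    }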
+ template + typed_mapping typed_data() { + return typed_mapping(data()); + } + template + typed_mapping typed_data() const { + return typed_mapping(data()); + } + template + typed_mapping typed_data_rw() { + return typed_mapping(data_rw()); + } + template + typed_mapping typed_data_w() { + return typed_mapping(data_w()); + } - private: + std::string to_s() const override; + + protected: class storage storage_; }; diff --git a/libshortfin/src/shortfin/array/array_test.cc b/libshortfin/src/shortfin/array/array_test.cc new file mode 100644 index 000000000..2c435b292 --- /dev/null +++ b/libshortfin/src/shortfin/array/array_test.cc @@ -0,0 +1,62 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include +#include + +#include "shortfin/array/api.h" +#include "shortfin/local/systems/host.h" + +using namespace shortfin; +using namespace shortfin::local; +using namespace shortfin::array; + +namespace { + +class DeviceArrayTest : public testing::Test { + protected: + DeviceArrayTest() {} + + void SetUp() override { + system = systems::HostCPUSystemBuilder().CreateSystem(); + scope = system->CreateScope(system->init_worker(), system->devices()); + device = scope->device(0); + } + void TearDown() override { + system->Shutdown(); + system.reset(); + } + + SystemPtr system; + std::shared_ptr scope; + ScopedDevice device; +}; + +TEST_F(DeviceArrayTest, contents_to_s_valid) { + device_array ary1 = device_array::for_host( + device, std::to_array({2, 3}), DType::float32()); + { + auto map = ary1.typed_data_w(); + std::fill(map.begin(), map.end(), 42.0); + } + + std::optional contents = ary1.contents_to_s(); + ASSERT_TRUE(contents); + EXPECT_EQ(*contents, "{{ 42., 42., 42.},\n { 42., 42., 42.}}"); +} + +TEST_F(DeviceArrayTest, contents_to_s_invalid) { + device_array ary1 = device_array::for_host( + device, std::to_array({2, 3}), DType::opaque32()); + // No xtensor adaptor for opaque32. + std::optional contents = ary1.contents_to_s(); + ASSERT_FALSE(contents); +} + +} // namespace diff --git a/libshortfin/src/shortfin/array/dims.h b/libshortfin/src/shortfin/array/dims.h new file mode 100644 index 000000000..529aebc42 --- /dev/null +++ b/libshortfin/src/shortfin/array/dims.h @@ -0,0 +1,256 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_ARRAY_DIMS_H +#define SHORTFIN_ARRAY_DIMS_H + +#include +#include +#include + +#include "shortfin/support/api.h" + +namespace shortfin::array { + +// Vector-alike for storing inlined dims. Note that this has a template +// signature identical to std::vector because xtensor specializes on this +// exact signature. See the concrete size_t instantiation below. 
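For reference, a small sketch of how the inline/dynamic switch behaves (grounded in the dims tests added below; MAX_INLINE_RANK is 6 in this patch):

    Dims small(3, 1);    // rank 3: stored in the inline std::array, no heap allocation
    Dims large(12, 1);   // rank 12: spills to a heap-allocated buffer
    small.resize(8);     // growing past the inline capacity switches to dynamic storage
    size_t new_shape[] = {4, 5};
    small.set(new_shape);  // set() re-derives the storage kind from the new rank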
+template > +class SHORTFIN_API InlinedDims { + public: + using element_type = T; + using value_type = T; + using allocator_type = Alloc; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = value_type &; + using const_reference = const value_type &; + using pointer = value_type *; + using const_pointer = const value_type *; + + class iterator { + public: + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T *; + using reference = T &; + using iterator_category = std::random_access_iterator_tag; + iterator(pointer p) : p(p) {} + iterator &operator++() { + p++; + return *this; + } + iterator &operator++(int) { + p++; + return *this; + } + bool operator==(iterator other) const { return p == other.p; } + bool operator!=(iterator other) const { return p != other.p; } + reference operator*() { return *p; } + + private: + pointer p; + }; + class const_iterator { + public: + using difference_type = std::ptrdiff_t; + using value_type = const T; + using pointer = const T *; + using reference = const T &; + using iterator_category = std::random_access_iterator_tag; + + const_iterator(pointer p) : p(p) {} + const_iterator &operator++() { + p++; + return *this; + } + const_iterator &operator++(int) { + p++; + return *this; + } + bool operator==(const_iterator other) const { return p == other.p; } + bool operator!=(const_iterator other) const { return p != other.p; } + reference operator*() { return *p; } + + private: + pointer p; + }; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + InlinedDims() { new (&dims_.inline_dims) InlineTy(); } + InlinedDims(size_type count, T value = T()) : size_(count) { + if (size_ > MAX_INLINE_RANK) { + // Dynamic allocation. + new (&dims_.dynamic_dims) DynamicTy(new element_type[size_]); + std::fill(dims_.dynamic_dims.get(), dims_.dynamic_dims.get() + size_, + value); + } else { + // Inline allocation. + new (&dims_.inline_dims) InlineTy(); + std::fill(dims_.inline_dims.begin(), dims_.inline_dims.end(), value); + } + } + InlinedDims(const InlinedDims &other) { + new (&dims_.inline_dims) InlineTy(); + set(other.span()); + } + InlinedDims(InlinedDims &&other) : size_(other.size_) { + // Custom move the dims to avoid an additional allocation. This could just + // be a memcpy on most impls, but this is the "right way". + if (size_ > MAX_INLINE_RANK) { + // Dynamic allocation. + new (&dims_.dynamic_dims) DynamicTy(); + dims_.dynamic_dims = std::move(other.dims_.dynamic_dims); + } else { + // Inline allocation. + new (&dims_.inline_dims) InlineTy(); + dims_.inline_dims = other.dims_.inline_dims; + } + other.size_ = 0; + } + InlinedDims &operator=(const InlinedDims &other) { + set(other.span()); + return *this; + } + ~InlinedDims() { clear(); } + + T *data() { + if (size_ > MAX_INLINE_RANK) { + return dims_.dynamic_dims.get(); + } else { + return &dims_.inline_dims.front(); + } + } + const T *data() const { + if (size_ > MAX_INLINE_RANK) { + return dims_.dynamic_dims.get(); + } else { + return &dims_.inline_dims.front(); + } + } + std::size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + + // Clears shape, setting the rank to zero and deleting any non-inline + // dimension storage. 
+ void clear() { + if (size_ > MAX_INLINE_RANK) { + dims_.dynamic_dims.~unique_ptr(); + } else { + dims_.inline_dims.~array(); + } + size_ = 0; + } + + void set(std::span dims) { + clear(); + size_ = dims.size(); + if (size_ > MAX_INLINE_RANK) { + // Dynamic allocation. + new (&dims_.dynamic_dims) DynamicTy(new element_type[size_]); + std::copy(dims.begin(), dims.end(), dims_.dynamic_dims.get()); + } else { + // Inline allocation. + new (&dims_.inline_dims) InlineTy(); + std::copy(dims.begin(), dims.end(), dims_.inline_dims.begin()); + } + } + + // Container access. + iterator begin() { return iterator(data()); } + iterator end() { return iterator(data() + size()); } + const_iterator begin() const { return const_iterator(data()); } + const_iterator end() const { return const_iterator(data() + size()); } + const_iterator cbegin() const { return const_iterator(data()); } + const_iterator cend() const { return const_iterator(data() + size()); } + + void resize(size_type count) { resize_impl(count, value_type()); } + void resize(size_type count, value_type value) { resize_impl(count, value); } + + reference operator[](std::size_t idx) { return *(data() + idx); } + const_reference operator[](std::size_t idx) const { return *(data() + idx); } + + reference front() { return *data(); } + const_reference front() const { return *data(); } + reference back() { return *(data() + size() - 1); } + const_reference back() const { return *(data() + size() - 1); } + + // Access as a span. + std::span span() { return std::span(data(), size_); } + std::span span() const { return std::span(data(), size_); } + + private: + void resize_impl(size_type count, value_type value) { + if (count == size()) return; + if (size() > MAX_INLINE_RANK) { + // Currently dynamically allocated. + if (count < size()) { + // Truncate. + if (count < MAX_INLINE_RANK) { + // Switch to inlined. + InlineTy new_array; + for (std::size_t i = 0; i < count; ++i) + new_array[i] = dims_.dynamic_dims[i]; + dims_.dynamic_dims.~unique_ptr(); + new (&dims_.inline_dims) InlineTy(new_array); + size_ = count; + } else { + // Stay dynamic and just truncate. + size_ = count; + } + } else { + // Expand and stay dynamic. + DynamicTy new_array(new element_type[count]); + for (std::size_t i = 0; i < size_; ++i) + new_array[i] = dims_.dynamic_dims[i]; + for (std::size_t i = size_; i < count; ++i) new_array[i] = value; + dims_.dynamic_dims = std::move(new_array); + size_ = count; + } + } else { + // Currently inlined. + if (count < size()) { + // Truncate. + size_ = count; + } else if (count < MAX_INLINE_RANK) { + // Stay inlined and initialize new items. + for (std::size_t i = size_; i < count; ++i) + dims_.inline_dims[i] = value; + size_ = count; + } else { + // Need to switch to dynamic size. 
+ DynamicTy new_array(new element_type[count]); + for (std::size_t i = 0; i < size_; ++i) + new_array[i] = dims_.inline_dims[i]; + for (std::size_t i = size_; i < count; ++i) new_array[i] = value; + dims_.inline_dims.~array(); + new (&dims_.dynamic_dims) DynamicTy(std::move(new_array)); + size_ = count; + } + } + } + + static constexpr size_t MAX_INLINE_RANK = 6; + using InlineTy = std::array; + using DynamicTy = std::unique_ptr; + union _D { + _D() {} + ~_D() {} + InlineTy inline_dims; + DynamicTy dynamic_dims; + }; + + std::size_t size_ = 0; + _D dims_; +}; + +extern template class InlinedDims; +using Dims = InlinedDims; + +} // namespace shortfin::array + +#endif // SHORTFIN_ARRAY_DIMS_H diff --git a/libshortfin/src/shortfin/array/dims_test.cc b/libshortfin/src/shortfin/array/dims_test.cc new file mode 100644 index 000000000..287e2fa9a --- /dev/null +++ b/libshortfin/src/shortfin/array/dims_test.cc @@ -0,0 +1,148 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/array/dims.h" + +#include +#include + +#include + +namespace shortfin::array { + +TEST(array_dims, empty) { + Dims dims; + EXPECT_TRUE(dims.empty()); + EXPECT_EQ(dims.size(), 0); +} + +TEST(array_dims, inline_init) { + Dims dims(3, 42); + EXPECT_EQ(dims.size(), 3); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 42); + } + + Dims copy(dims); + EXPECT_EQ(dims.size(), copy.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), copy.begin())); + EXPECT_TRUE(std::equal(dims.cbegin(), dims.cend(), copy.begin())); + + Dims move = std::move(copy); + EXPECT_EQ(dims.size(), move.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), move.begin())); + + Dims assign; + assign = dims; + EXPECT_EQ(dims.size(), assign.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), assign.begin())); + + EXPECT_EQ(*dims.data(), *assign.data()); + + assign.clear(); + EXPECT_TRUE(assign.empty()); +} + +TEST(array_dims, dynamic_init) { + Dims dims(12, 42); + EXPECT_EQ(dims.size(), 12); + for (size_t i = 0; i < 12; ++i) { + EXPECT_EQ(dims[i], 42); + } + + Dims copy(dims); + EXPECT_EQ(dims.size(), copy.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), copy.begin())); + EXPECT_TRUE(std::equal(dims.cbegin(), dims.cend(), copy.begin())); + + Dims move = std::move(copy); + EXPECT_EQ(dims.size(), move.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), move.begin())); + + Dims assign; + assign = dims; + EXPECT_EQ(dims.size(), assign.size()); + EXPECT_TRUE(std::equal(dims.begin(), dims.end(), assign.begin())); + + EXPECT_EQ(*dims.data(), *assign.data()); + + assign.clear(); + EXPECT_TRUE(assign.empty()); +} + +TEST(array_dims, resize_same_size) { + Dims dims(3, 64); + dims.resize(3, 32); + EXPECT_EQ(dims.size(), 3); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +TEST(array_dims, resize_inline_to_inline) { + Dims dims(3, 64); + dims.resize(5, 32); + EXPECT_EQ(dims.size(), 5); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } + for (size_t i = 3; i < 5; ++i) { + EXPECT_EQ(dims[i], 32); + } +} + +TEST(array_dims, resize_inline_to_dynamic) { + Dims dims(3, 64); + dims.resize(12, 32); + EXPECT_EQ(dims.size(), 12); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } + for (size_t i = 3; i < 12; ++i) { + EXPECT_EQ(dims[i], 32); + } +} + +TEST(array_dims, 
resize_inline_truncate) { + Dims dims(5, 64); + dims.resize(2, 32); + EXPECT_EQ(dims.size(), 2); + for (size_t i = 0; i < 2; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +TEST(array_dims, resize_dynamic_to_dynamic) { + Dims dims(12, 64); + dims.resize(15, 32); + EXPECT_EQ(dims.size(), 15); + for (size_t i = 0; i < 12; ++i) { + EXPECT_EQ(dims[i], 64); + } + for (size_t i = 12; i < 15; ++i) { + EXPECT_EQ(dims[i], 32); + } +} + +TEST(array_dims, resize_truncate_to_inline) { + Dims dims(12, 64); + dims.resize(3, 32); + EXPECT_EQ(dims.size(), 3); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +TEST(array_dims, resize_truncate_to_dynamic) { + Dims dims(12, 64); + dims.resize(10, 32); + EXPECT_EQ(dims.size(), 10); + for (size_t i = 0; i < 10; ++i) { + EXPECT_EQ(dims[i], 64); + } +} + +} // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/dtype_test.cc b/libshortfin/src/shortfin/array/dtype_test.cc new file mode 100644 index 000000000..f1dc0477a --- /dev/null +++ b/libshortfin/src/shortfin/array/dtype_test.cc @@ -0,0 +1,34 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/array/dtype.h" + +#include +#include + +#include + +namespace shortfin::array { + +TEST(array_dtype, basics) { + EXPECT_EQ(DType::complex64().name(), "complex64"); + EXPECT_EQ(static_cast(DType::complex64()), + IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64); + EXPECT_TRUE(DType::complex64() == DType::complex64()); + EXPECT_TRUE(DType::complex64() != DType::complex128()); +} + +TEST(array_dtype, compure_dense_nd_size) { + // 0d special case. + EXPECT_EQ(DType::float32().compute_dense_nd_size({}), 4); + // 0 extent special case. 
+ EXPECT_EQ(DType::float32().compute_dense_nd_size(std::array{0, 4}), + 0); + EXPECT_EQ(DType::float32().compute_dense_nd_size(std::array{2, 4}), + 32); +} + +} // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/storage.cc b/libshortfin/src/shortfin/array/storage.cc index 6eb37267d..fa9e0f4b8 100644 --- a/libshortfin/src/shortfin/array/storage.cc +++ b/libshortfin/src/shortfin/array/storage.cc @@ -14,6 +14,10 @@ namespace shortfin::array { using namespace local; using namespace local::detail; +// -------------------------------------------------------------------------- // +// storage +// -------------------------------------------------------------------------- // + namespace detail { void ThrowIllegalDeviceAffinity(Device *first, Device *second) { throw std::invalid_argument(fmt::format( @@ -22,6 +26,15 @@ void ThrowIllegalDeviceAffinity(Device *first, Device *second) { } } // namespace detail +storage::storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, + local::detail::TimelineResource::Ref timeline_resource) + : timeline_resource_(std::move(timeline_resource)), + buffer_(std::move(buffer)), + device_(device) { + logging::construct("array::storage", this); +} +storage::~storage() { logging::destruct("array::storage", this); } + storage storage::AllocateDevice(ScopedDevice &device, iree_device_size_t allocation_size) { if (!device.raw_device()) { @@ -99,7 +112,81 @@ void storage::Fill(const void *pattern, iree_host_size_t pattern_length) { } void storage::CopyFrom(storage &source_storage) { - // TODO + device_.scope().scheduler().AppendCommandBuffer( + device_, TransactionType::TRANSFER, [&](Account &account) { + // Must depend on the source's mutation dependencies to avoid + // read-before-write hazard. + account.active_deps_extend( + source_storage.timeline_resource_->mutation_barrier()); + // And depend on our own use and mutations dependencies. + account.active_deps_extend(timeline_resource_->use_barrier()); + account.active_deps_extend(timeline_resource_->mutation_barrier()); + + SHORTFIN_THROW_IF_ERROR(iree_hal_command_buffer_copy_buffer( + account.active_command_buffer(), + /*source_ref=*/ + iree_hal_make_buffer_ref(source_storage.buffer_, 0, byte_length()), + /*target_ref=*/ + iree_hal_make_buffer_ref(buffer_, 0, byte_length()))); + + // And move our own mutation barrier to the current pending timeline + // value. 
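At the array level, this primitive backs the usual staging pattern. A sketch under the APIs added in this change (assumes a ScopedDevice named `device`; synchronization and error handling omitted):

    auto dst = device_array::for_device(device, std::to_array<size_t>({2, 3}),
                                        DType::float32());
    auto staging = dst.for_transfer();  // host array with matching shape/dtype
    {
      auto w = staging.typed_data_w<float>();
      std::fill(w.begin(), w.end(), 0.0f);
    }
    dst.storage().CopyFrom(staging.storage());  // enqueued with the barriers shown above
    // device.OnSync() (or an equivalent wait) is needed before reading `dst` back.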
+ timeline_resource_->set_mutation_barrier( + account.timeline_sem(), account.timeline_idle_timepoint()); + }); +} + +bool storage::is_mappable_for_read() const { + return (iree_hal_buffer_allowed_usage(buffer_) & + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && + (iree_hal_buffer_allowed_access(buffer_) & + IREE_HAL_MEMORY_ACCESS_READ); +} + +bool storage::is_mappable_for_read_write() const { + return (iree_hal_buffer_allowed_usage(buffer_) & + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && + (iree_hal_buffer_allowed_access(buffer_) & + (IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE)); +} + +void storage::MapExplicit(mapping &mapping, iree_hal_memory_access_t access) { + assert(access != IREE_HAL_MEMORY_ACCESS_NONE); + mapping.reset(); + SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_map_range( + buffer_, IREE_HAL_MAPPING_MODE_SCOPED, access, + /*byte_offset=*/0, byte_length(), &mapping.mapping_)); + mapping.access_ = access; + mapping.timeline_resource_ = timeline_resource_; +} + +iree_hal_memory_type_t storage::memory_type() const { + return iree_hal_buffer_memory_type(buffer_); +} +iree_hal_memory_access_t storage::memory_access() const { + return iree_hal_buffer_allowed_access(buffer_); +} +iree_hal_buffer_usage_t storage::buffer_usage() const { + return iree_hal_buffer_allowed_usage(buffer_); +} + +// Formatted type and access. +std::string storage::formatted_memory_type() const { + iree_bitfield_string_temp_t temp; + auto sv = iree_hal_memory_type_format(memory_type(), &temp); + return std::string(sv.data, sv.size); +} + +std::string storage::formatted_memory_access() const { + iree_bitfield_string_temp_t temp; + auto sv = iree_hal_memory_access_format(memory_access(), &temp); + return std::string(sv.data, sv.size); +} + +std::string storage::formatted_buffer_usage() const { + iree_bitfield_string_temp_t temp; + auto sv = iree_hal_buffer_usage_format(buffer_usage(), &temp); + return std::string(sv.data, sv.size); } std::string storage::to_s() const { @@ -107,4 +194,27 @@ std::string storage::to_s() const { byte_length()); } +// -------------------------------------------------------------------------- // +// mapping +// -------------------------------------------------------------------------- // + +mapping::mapping() { + logging::construct("array::mapping", this); + std::memset(&mapping_, 0, sizeof(mapping_)); +} + +mapping::~mapping() noexcept { + logging::destruct("array::mapping", this); + reset(); +} + +void mapping::reset() noexcept { + if (*this) { + // Crash the process on failure to unmap. We don't have a good mitigation, + IREE_CHECK_OK(iree_hal_buffer_unmap_range(&mapping_)); + access_ = IREE_HAL_MEMORY_ACCESS_NONE; + timeline_resource_.reset(); + } +} + } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/storage.h b/libshortfin/src/shortfin/array/storage.h index 10d0313fe..0db73d28f 100644 --- a/libshortfin/src/shortfin/array/storage.h +++ b/libshortfin/src/shortfin/array/storage.h @@ -14,9 +14,65 @@ namespace shortfin::array { +// Access to mapped memory. +// Mappings are moveable but not copyable. When default constructed or moved +// from, they will not be valid and have nullptr semantics. 
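A short sketch of those semantics (assuming `buf` is a host-mappable storage instance):

    mapping m = buf.MapRead();
    if (m) {                       // valid: access is not IREE_HAL_MEMORY_ACCESS_NONE
      const uint8_t *bytes = m.data();
      iree_device_size_t len = m.size();
      (void)bytes; (void)len;
    }
    mapping m2 = std::move(m);     // m is now invalid; bool(m) == false
    m2.reset();                    // unmaps now (the destructor would do the same)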
+class SHORTFIN_API mapping { + public: + mapping(); + mapping(const mapping &) = delete; + mapping &operator=(const mapping &) = delete; + mapping &operator=(mapping &&other) { + timeline_resource_ = std::move(other.timeline_resource_); + access_ = other.access_; + mapping_ = other.mapping_; + other.access_ = IREE_HAL_MEMORY_ACCESS_NONE; + std::memset(&other.mapping_, 0, sizeof(other.mapping_)); + return *this; + } + mapping(mapping &&other) + : timeline_resource_(std::move(other.timeline_resource_)), + access_(other.access_), + mapping_(other.mapping_) { + other.access_ = IREE_HAL_MEMORY_ACCESS_NONE; + std::memset(&other.mapping_, 0, sizeof(other.mapping_)); + } + ~mapping() noexcept; + + // Whether the mapping is valid. + operator bool() const { return access_ != IREE_HAL_MEMORY_ACCESS_NONE; } + + // Resets the mapping, making it invalid (if not already so); + void reset() noexcept; + + // Access the mapped data. The mapping must be valid or else it is UB. + const uint8_t *data() const { + assert(*this && "mapping is not valid"); + return mapping_.contents.data; + } + uint8_t *data() { + assert(*this && "mapping is not valid"); + return mapping_.contents.data; + } + + // The size of the mapped data. Will return 0 if the mapping is not valid. + iree_device_size_t size() const { return mapping_.contents.data_length; } + + bool readable() const { return access_ & IREE_HAL_MEMORY_ACCESS_READ; } + bool writable() const { return access_ & IREE_HAL_MEMORY_ACCESS_WRITE; } + + private: + // See note on storage::timeline_resource_. Must be declared first. + local::detail::TimelineResource::Ref timeline_resource_; + iree_hal_memory_access_t access_ = IREE_HAL_MEMORY_ACCESS_NONE; + iree_hal_buffer_mapping_t mapping_; + friend class storage; +}; + // Array storage backed by an IREE buffer of some form. class SHORTFIN_API storage { public: + ~storage(); local::ScopedDevice &device() { return device_; } local::Scope &scope() { return device_.scope(); } const local::ScopedDevice &device() const { return device_; } @@ -29,9 +85,9 @@ class SHORTFIN_API storage { // Allocates host storage, compatible with the given device affinity. // By default, if there are any affinity bits set in the device, then - // the storage will be device visible and have permitted usage for transfers. - // This default policy can be overriden based on device defaults or explicit - // options. + // the storage will be device visible and have permitted usage for + // transfers. This default policy can be overriden based on device defaults + // or explicit options. static storage AllocateHost(local::ScopedDevice &device, iree_device_size_t allocation_size); @@ -53,17 +109,106 @@ class SHORTFIN_API storage { return iree_hal_buffer_byte_length(buffer_.get()); } + // Memory type and access. + iree_hal_memory_type_t memory_type() const; + iree_hal_memory_access_t memory_access() const; + iree_hal_buffer_usage_t buffer_usage() const; + + // Formatted type and access. + std::string formatted_memory_type() const; + std::string formatted_memory_access() const; + std::string formatted_buffer_usage() const; + + // Whether the buffer supports host mappable memory. + bool is_mappable_for_read() const; + bool is_mappable_for_read_write() const; + + // Maps the memory for access from a host pointer using a scoped mapping. + void MapExplicit(mapping &mapping, iree_hal_memory_access_t access); + + // Maps the memory for read/write access, preserving any contents. 
+ mapping MapReadWrite() { + mapping m; + MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE); + return m; + } + + // Maps the memory for discard write. This is used if populating an initial + // buffer. + mapping MapWriteDiscard() { + mapping m; + MapExplicit(m, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE); + return m; + } + + // Maps the memory for read-only access. + mapping MapRead() { + mapping m; + MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ); + return m; + } + + const mapping MapRead() const { + mapping m; + const_cast(this)->MapExplicit(m, IREE_HAL_MEMORY_ACCESS_READ); + return m; + } + std::string to_s() const; + // Access raw buffer. This must not be retained apart from the storage for + // any length of time that may extend its lifetime (as the storage keeps + // underlying device references alive as needed). + operator iree_hal_buffer_t *() { return buffer_; } + private: storage(local::ScopedDevice device, iree::hal_buffer_ptr buffer, - local::detail::TimelineResource::Ref timeline_resource) - : buffer_(std::move(buffer)), - device_(device), - timeline_resource_(std::move(timeline_resource)) {} + local::detail::TimelineResource::Ref timeline_resource); + // The timeline resource holds the back reference to the owning scope, + // which keeps all devices alive. Buffers must be destroyed before devices, + // so this must be declared first. + local::detail::TimelineResource::Ref timeline_resource_; iree::hal_buffer_ptr buffer_; local::ScopedDevice device_; - local::detail::TimelineResource::Ref timeline_resource_; +}; + +// Wraps an untyped mapping, providing typed access. +template +class typed_mapping { + public: + using span_type = std::span; + using const_span_type = std::span; + + typed_mapping(mapping untyped_mapping) + : untyped_mapping_(std::move(untyped_mapping)) {} + typed_mapping(const typed_mapping &) = delete; + typed_mapping &operator=(const typed_mapping &) = delete; + + iree_device_size_t size() const noexcept { + return untyped_mapping_.size() / sizeof(EltTy); + } + bool empty() const noexcept { return size() == 0; } + EltTy *data() noexcept { + return reinterpret_cast(untyped_mapping_.data()); + } + EltTy *data() const noexcept { + return reinterpret_cast(untyped_mapping_.data()); + } + + span_type span() { return span_type(data(), size()); } + const_span_type span() const { return const_span_type(data(), size()); } + + span_type::iterator begin() { return span().begin(); } + span_type::iterator end() { return span().end(); } + + const_span_type::iterator begin() const { return span().begin(); } + const_span_type::iterator end() const { return span().end(); } + + const_span_type::iterator cbegin() const { return span().begin(); } + const_span_type::iterator cend() const { return span().end(); } + + private: + mapping untyped_mapping_; }; } // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/xtensor_bridge.cc b/libshortfin/src/shortfin/array/xtensor_bridge.cc new file mode 100644 index 000000000..0dc00f9c7 --- /dev/null +++ b/libshortfin/src/shortfin/array/xtensor_bridge.cc @@ -0,0 +1,93 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/array/xtensor_bridge.h" + +#include + +namespace shortfin::array { + +namespace { + +template +class typed_xt_methods final : public poly_xt_methods { + public: + using xt_specific_t = + decltype(xt::adapt(static_cast(nullptr), Dims())); + // Our specific adaptor type must fit within the memory allocation of the + // generic adaptor type. + static_assert(sizeof(xt_specific_t) <= sizeof(xt_generic_t)); + + xt_specific_t &adaptor() { + return *reinterpret_cast(adaptor_storage); + } + + static void concrete_inplace_new(uint8_t *inst_storage, void *array_memory, + size_t array_memory_size, Dims &dims) { + // We rely on the fact that the typed_xt_methods specialization has the + // exact same memory layout as the base class. + static_assert(sizeof(typed_xt_methods) == sizeof(poly_xt_methods)); + + typed_xt_methods *methods = + reinterpret_cast(inst_storage); + new (methods) typed_xt_methods(); + new (methods->adaptor_storage) + xt_specific_t(xt::adapt(static_cast(array_memory), dims)); + } + + void inplace_destruct_this() override { + adaptor().~xt_specific_t(); + this->~typed_xt_methods(); + } + + std::string contents_to_s() override { + std::stringstream out; + out << adaptor(); + return out.str(); + } +}; +} // namespace + +bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype, + void *array_memory, size_t array_memory_size, + Dims &dims) { +#define POLY_XT_CASE(et, cpp_type) \ + case et: \ + typed_xt_methods::concrete_inplace_new( \ + inst_storage, array_memory, array_memory_size, dims); \ + return true + + switch (static_cast(dtype)) { + // Hot comparisons first. + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_32, float); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_32, int32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_32, int32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_32, uint32_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_64, int64_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_64, int64_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_64, uint64_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_8, int8_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_8, int8_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_8, uint8_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_INT_16, int16_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_SINT_16, int16_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_UINT_16, uint16_t); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_64, double); + POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_BOOL_8, bool); + // TODO: float16 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_FLOAT_16, TODO); + // TODO: bfloat16 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_BFLOAT_16, TODO); + // TODO: complex64 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, TODO); + // TODO: complex128 + // POLY_XT_CASE(IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, TODO); + } + + return false; +} + +} // namespace shortfin::array diff --git a/libshortfin/src/shortfin/array/xtensor_bridge.h b/libshortfin/src/shortfin/array/xtensor_bridge.h new file mode 100644 index 000000000..a3243e03b --- /dev/null +++ b/libshortfin/src/shortfin/array/xtensor_bridge.h @@ -0,0 +1,160 @@ +// Copyright 2024 Advanced Micro Devices, Inc +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_ARRAY_XTENSOR_BRIDGE_H +#define SHORTFIN_ARRAY_XTENSOR_BRIDGE_H + +#include + +#include +#include +#include +#include + +#include "shortfin/array/dims.h" +#include "shortfin/array/dtype.h" +#include "shortfin/array/storage.h" + +namespace shortfin::array { + +// Polymorphic trampoline methods to a backing typed, xarray adaptor. This +// allows xtensor facilities to be used in a dtype agnostic fashion. +class SHORTFIN_API poly_xt_methods { + public: + // Prints the contents of the array. + virtual std::string contents_to_s() = 0; + + protected: + // Since we adapt from a pointer-based container with Dims, just pick one + // as a generic version so that we can reserve space in the class for it. + using xt_generic_t = + decltype(xt::adapt(static_cast(nullptr), Dims())); + + // Placement new an appropriate subclass into the provided storage area, + // which must be sized to hold the base class (subclasses are statically + // asserted to be the same size). The appropriate subclass will also placement + // new an appropriate xtensor adaptor into the adaptor_storage field. It is + // statically asserted that the type specific adaptor will fit into the + // storage area reserved. + // Returns true if an appropriate instance is instantiated. False if no + // implementation for the dtype exists. + static bool inplace_new(uint8_t *inst_storage, DType dtype, + void *array_memory, size_t array_memory_size, + Dims &dims); + + // When instantiated via inplace_new, destorys the instance, calling both + // the type specific adaptor destructor and the subclass destructor. + virtual void inplace_destruct_this() = 0; + + uint8_t adaptor_storage[sizeof(xt_generic_t)]; + + template + friend class poly_xt_mixin; +}; + +// Polymorphic xtensor array mixin. Since xt::array is static on element type, +// this class provides a bridge that will polymorphically manage a specialized +// xarray adaptor for a base_array derived class. +// +// This is designed to use via CRTP on a subclass of base_array. +// +// Access is indirected through a heap allocated poly_xt_methods subclass that +// is initialized on-demand by mapping the device memory and constructing an +// appropriate typed subclass. This is done through two layers of generic +// storage (one contained here for the poly_xt_methods subclass and one +// on that class for the concrete xtensor adaptor it contains). The overhead +// on the base_array instance if the xtensor bridge is not used is one pointer. +// On first use, it is a heap allocation and a switch on dtype. +template +class SHORTFIN_API poly_xt_mixin { + public: + poly_xt_mixin() = default; + // Don't copy the poly instance: if it is needed on the copy, it will be + // re-allocated. + poly_xt_mixin(const poly_xt_mixin &other) {} + + std::optional contents_to_s() { + auto *m = optional_xt_methods(); + if (!m) return {}; + return m->contents_to_s(); + } + + std::optional contents_to_s() const { + return const_cast(this)->contents_to_s(); + } + + // Access (potentially instantiating) the polymorphic xt methods trampoline + // for this array. If no xtensor adaptor can be created or if the memory + // is not accessible to the host, returns nullptr. The returned pointer + // must not outlive the creating array. 
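Concretely, for a float32 array the methods object ends up wrapping roughly the following. This is illustrative only: the real instantiation goes through poly_xt_methods::inplace_new and the per-dtype typed_xt_methods specialization, and it assumes xtensor's xadapt/xio headers and an array `ary` whose storage is host mappable:

    auto mem = ary.map_memory_for_xtensor();  // host mapping, or {} if unmappable
    if (mem) {
      auto adaptor = xt::adapt(reinterpret_cast<float *>(mem->data()),
                               ary.shape_container());  // Dims works here because it
                                                        // matches std::vector's signature
      std::stringstream out;
      out << adaptor;  // same text that contents_to_s() returns
    }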
+ poly_xt_methods *optional_xt_methods() { + if (poly_) { + return poly_->methods(); + } + DType dtype = derived_this()->dtype(); + auto inst = std::make_unique(); + // CRTP derived class must provide a memory mapping via its + // map_memory_for_xtensor() method. + // This must be typed as MemoryTy and have data() and size() accessors. + std::optional mapping = derived_this()->map_memory_for_xtensor(); + if (!mapping) { + return nullptr; + } + inst->memory = std::move(*mapping); + void *data = static_cast(inst->memory.data()); + size_t data_size = inst->memory.size(); + if (!poly_xt_methods::inplace_new(inst->methods_storage, dtype, data, + data_size, + derived_this()->shape_container())) { + return nullptr; + } + poly_ = std::move(inst); + return poly_.get()->methods(); + } + + // Accesses (potentially instantiating) the polymorphic xt methods trampoline. + // If it cannot be created, throws a std::logic_error. The returned reference + // must not outlive the creating array. + poly_xt_methods &xt_methods() { + auto m = optional_xt_methods(); + if (!m) { + throw std::logic_error(fmt::format( + "No xtensor specialization registered for dtype {} or storage type", + derived_this()->dtype().name())); + } + return *m; + } + + protected: + ~poly_xt_mixin() { + if (poly_) { + // Need to in-place destruct the adaptor and then the methods itself. + poly_->methods()->inplace_destruct_this(); + } + } + + private: + struct PolyInstance { + MemoryTy memory; + uint8_t methods_storage[sizeof(poly_xt_methods)]; + poly_xt_methods *methods() { + return reinterpret_cast(methods_storage); + } + }; + + const DerivedArrayTy *derived_this() const { + return static_cast(this); + } + DerivedArrayTy *derived_this() { return static_cast(this); } + + // If the polymorphic accessor has been instantiated, it will be constructed + // here. 
+ std::unique_ptr poly_; +}; + +} // namespace shortfin::array + +#endif // SHORTFIN_ARRAY_XTENSOR_BRIDGE_H diff --git a/libshortfin/src/shortfin/local/process.cc b/libshortfin/src/shortfin/local/process.cc index 4fd395368..b40b8ce87 100644 --- a/libshortfin/src/shortfin/local/process.cc +++ b/libshortfin/src/shortfin/local/process.cc @@ -54,10 +54,7 @@ void detail::BaseProcess::Launch() { ScheduleOnWorker(); } -void detail::BaseProcess::ScheduleOnWorker() { - logging::info("ScheduleOnWorker()"); - Terminate(); -} +void detail::BaseProcess::ScheduleOnWorker() { Terminate(); } void detail::BaseProcess::Terminate() { int deallocate_pid; diff --git a/libshortfin/src/shortfin/local/scheduler.cc b/libshortfin/src/shortfin/local/scheduler.cc index 64e4247e6..c5a9fc062 100644 --- a/libshortfin/src/shortfin/local/scheduler.cc +++ b/libshortfin/src/shortfin/local/scheduler.cc @@ -30,6 +30,9 @@ void Account::Initialize() { void Account::Reset() { active_tx_type_ = TransactionType::NONE; + // if (active_command_buffer_) { + // iree_hal_command_buffer_end(active_command_buffer_); + // } active_command_buffer_.reset(); } @@ -67,10 +70,17 @@ CompletionEvent Account::OnSync() { // TimelineResource // -------------------------------------------------------------------------- // -TimelineResource::TimelineResource(iree_allocator_t host_allocator, - size_t semaphore_capacity) { - SHORTFIN_THROW_IF_ERROR(iree_hal_fence_create( - semaphore_capacity, host_allocator, use_barrier_fence_.for_output())); +TimelineResource::TimelineResource(std::shared_ptr scope, + size_t semaphore_capacity) + : scope_(std::move(scope)) { + logging::construct("TimelineResource", this); + SHORTFIN_THROW_IF_ERROR( + iree_hal_fence_create(semaphore_capacity, scope_->host_allocator(), + use_barrier_fence_.for_output())); +} + +TimelineResource::~TimelineResource() { + logging::destruct("TimelineResource", this); } void TimelineResource::use_barrier_insert(iree_hal_semaphore_t *sem, @@ -83,6 +93,19 @@ void TimelineResource::use_barrier_insert(iree_hal_semaphore_t *sem, // Scheduler // -------------------------------------------------------------------------- // +Scheduler::Scheduler(System &system) : system_(system) { + logging::construct("Scheduler", this); +} + +Scheduler::~Scheduler() { + logging::destruct("Scheduler", this); + + // Explicitly reset account state prior to implicit destruction. 
+ for (auto &account : accounts_) { + account.Reset(); + } +} + void Scheduler::Initialize(std::span devices) { for (Device *device : devices) { accounts_.emplace_back(*this, device); diff --git a/libshortfin/src/shortfin/local/scheduler.h b/libshortfin/src/shortfin/local/scheduler.h index 057bfbd9f..2f606ced3 100644 --- a/libshortfin/src/shortfin/local/scheduler.h +++ b/libshortfin/src/shortfin/local/scheduler.h @@ -83,13 +83,35 @@ class SHORTFIN_API TimelineResource { Ref() : res_(nullptr) {} explicit Ref(TimelineResource *res) : res_(res) { res_->Retain(); } Ref(const Ref &other) : res_(other.res_) { res_->Retain(); } - void operator=(const Ref &other) = delete; - Ref(Ref &&other) : res_(other.res_) { other.res_ = nullptr; } - ~Ref() { - if (res_) res_->Release(); + Ref &operator=(const Ref &other) { + if (other.res_ != res_) { + reset(); + if (other.res_) { + other.res_->Retain(); + res_ = other.res_; + } + } + return *this; + } + Ref &operator=(Ref &&other) { + if (other.res_ != res_) { + reset(); + res_ = other.res_; + other.res_ = nullptr; + } + return *this; } + Ref(Ref &&other) : res_(other.res_) { other.res_ = nullptr; } + ~Ref() { reset(); } TimelineResource *operator->() { return res_; } + void reset() { + if (res_) { + res_->Release(); + res_ = nullptr; + } + } + private: TimelineResource *res_; }; @@ -121,13 +143,18 @@ class SHORTFIN_API TimelineResource { } private: - TimelineResource(iree_allocator_t host_allocator, size_t semaphore_capacity); + TimelineResource(std::shared_ptr scope, size_t semaphore_capacity); + ~TimelineResource(); void Retain() { refcnt_++; } void Release() { if (--refcnt_ == 0) delete this; } int refcnt_ = 0; + + // Back reference to the owning scope. + std::shared_ptr scope_; + // Non-owning mutation barrier semaphore and timepoint. The fact that this // is a single semaphore is an implementation detail that may be generalized // in the future should it be necessary to track multiple write sources. @@ -171,11 +198,13 @@ class SHORTFIN_API Account { void Initialize(); void Reset(); Scheduler &scheduler_; + iree::hal_semaphore_ptr sem_; + iree::hal_fence_ptr active_deps_; + iree::hal_command_buffer_ptr active_command_buffer_; + Device *device_; iree_hal_device_t *hal_device_; TransactionType active_tx_type_ = TransactionType::NONE; - iree::hal_fence_ptr active_deps_; - iree::hal_command_buffer_ptr active_command_buffer_; iree_hal_queue_affinity_t active_queue_affinity_bits_; // Timepoint at which this device is considered idle, inclusive of any @@ -193,14 +222,14 @@ class SHORTFIN_API Account { // an eventual submission would submit a duplicate timepoint). This // timepoint is only valid for the local sem_. uint64_t idle_timepoint_ = 0; - iree::hal_semaphore_ptr sem_; friend class Scheduler; }; // Handles scheduling state for a scope. class SHORTFIN_API Scheduler { public: - Scheduler(System &system) : system_(system) {} + Scheduler(System &system); + ~Scheduler(); TransactionMode transaction_mode() const { return tx_mode_; } @@ -224,9 +253,9 @@ class SHORTFIN_API Scheduler { // Gets a fresh TimelineResource which can be used for tracking resource // read/write and setting barriers. Note that these are all allocated fresh // on each call today but may be pooled in the future. 
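The net effect of the new Ref assignment operators, sketched with refs handed out by Scope::NewTimelineResource() (direct construction stays private):

    auto a = scope->NewTimelineResource();
    auto b = scope->NewTimelineResource();
    b = a;                  // releases b's previous resource, retains a's
    auto c = std::move(a);  // steals the reference; `a` now holds nullptr
    c.reset();              // drops the reference explicitly (the destructor would too)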
- TimelineResource::Ref NewTimelineResource(iree_allocator_t host_allocator) { + TimelineResource::Ref NewTimelineResource(std::shared_ptr scope) { return TimelineResource::Ref( - new TimelineResource(host_allocator, semaphore_count_)); + new TimelineResource(std::move(scope), semaphore_count_)); } System &system() { return system_; } diff --git a/libshortfin/src/shortfin/local/scope.cc b/libshortfin/src/shortfin/local/scope.cc index f0eb9ca77..39784f196 100644 --- a/libshortfin/src/shortfin/local/scope.cc +++ b/libshortfin/src/shortfin/local/scope.cc @@ -21,10 +21,11 @@ namespace shortfin::local { Scope::Scope(std::shared_ptr system, Worker &worker, std::span> devices) - : host_allocator_(system->host_allocator()), - scheduler_(*system), - system_(std::move(system)), + : system_(std::move(system)), + host_allocator_(system_->host_allocator()), + scheduler_(*system_), worker_(worker) { + logging::construct("Scope", this); for (auto &it : devices) { AddDevice(it.first, it.second); } @@ -33,17 +34,18 @@ Scope::Scope(std::shared_ptr system, Worker &worker, Scope::Scope(std::shared_ptr system, Worker &worker, std::span devices) - : host_allocator_(system->host_allocator()), - scheduler_(*system), - system_(std::move(system)), + : system_(std::move(system)), + host_allocator_(system_->host_allocator()), + scheduler_(*system_), worker_(worker) { + logging::construct("Scope", this); for (auto *device : devices) { AddDevice(device->address().logical_device_class, device); } Initialize(); } -Scope::~Scope() = default; +Scope::~Scope() { logging::destruct("Scope", this); } std::string Scope::to_s() const { return fmt::format("Scope(worker='{}', devices=[{}])", worker_.name(), diff --git a/libshortfin/src/shortfin/local/scope.h b/libshortfin/src/shortfin/local/scope.h index fb6d74fd4..0cb566b89 100644 --- a/libshortfin/src/shortfin/local/scope.h +++ b/libshortfin/src/shortfin/local/scope.h @@ -28,19 +28,25 @@ class SHORTFIN_API Worker; // needed to do thing with some slice of device queues. class SHORTFIN_API ScopedDevice { public: + ScopedDevice() = default; ScopedDevice(Scope &scope, DeviceAffinity affinity) - : scope_(scope), affinity_(affinity) {} + : scope_(&scope), affinity_(affinity) {} ScopedDevice(Scope &scope, Device *device) - : scope_(scope), affinity_(device) {} + : scope_(&scope), affinity_(device) {} + ScopedDevice(const ScopedDevice &other) + : scope_(other.scope_), affinity_(other.affinity_) {} - Scope &scope() const { return scope_; } + Scope &scope() const { + assert(scope_ && "scope must not be null"); + return *scope_; + } DeviceAffinity affinity() const { return affinity_; } Device *raw_device() const { return affinity_.device(); } std::string to_s() const { return affinity().to_s(); } bool operator==(const ScopedDevice &other) const { - return (&scope_ == &other.scope_) && affinity_ == other.affinity_; + return (scope_ == other.scope_) && affinity_ == other.affinity_; } // Returns a future which will be satisfied when the primary device timeline @@ -49,7 +55,7 @@ class SHORTFIN_API ScopedDevice { CompletionEvent OnSync(bool flush = true); private: - Scope &scope_; + Scope *scope_ = nullptr; DeviceAffinity affinity_; }; @@ -85,6 +91,9 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { // All scopes are created as shared pointers. std::shared_ptr shared_ptr() { return shared_from_this(); } + // The host allocator. + iree_allocator_t host_allocator() { return host_allocator_; } + // The worker that this scope is bound to. 
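With the switch from a reference member to a pointer, ScopedDevice becomes default-constructible and copyable; a small sketch (assuming a shared_ptr<Scope> named `scope`):

    ScopedDevice unbound;                  // null scope; calling scope() would assert
    ScopedDevice d0 = scope->device(0);
    bool same = (d0 == scope->device(0));  // compares the scope pointer and affinity
    (void)unbound; (void)same;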
Worker &worker() { return worker_; } @@ -120,7 +129,7 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { } detail::Scheduler &scheduler() { return scheduler_; } detail::TimelineResource::Ref NewTimelineResource() { - return scheduler().NewTimelineResource(host_allocator_); + return scheduler().NewTimelineResource(shared_ptr()); } // Loads a program from a list of modules onto the devices managed by this @@ -135,19 +144,19 @@ class SHORTFIN_API Scope : public std::enable_shared_from_this { void AddDevice(std::string_view device_class, Device *device); void Initialize(); // Called after all devices are added. - iree_allocator_t host_allocator_; + // Back reference to owning system. + std::shared_ptr system_; string_interner interner_; + iree_allocator_t host_allocator_; + detail::Scheduler scheduler_; + Worker &worker_; + // Map of `` to the count of that class contained. std::unordered_map device_class_count_; // Ordered devices. std::vector devices_; // Map of `` to Device. std::unordered_map named_devices_; - detail::Scheduler scheduler_; - - // Back reference to owning system. - std::shared_ptr system_; - Worker &worker_; }; } // namespace shortfin::local diff --git a/libshortfin/src/shortfin/local/system.cc b/libshortfin/src/shortfin/local/system.cc index 28c8c9654..2eaf3eaf7 100644 --- a/libshortfin/src/shortfin/local/system.cc +++ b/libshortfin/src/shortfin/local/system.cc @@ -19,6 +19,7 @@ namespace shortfin::local { System::System(iree_allocator_t host_allocator) : host_allocator_(host_allocator) { + logging::construct("System", this); SHORTFIN_THROW_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, host_allocator_, vm_instance_.for_output())); @@ -27,6 +28,7 @@ System::System(iree_allocator_t host_allocator) } System::~System() { + logging::destruct("System", this); bool needs_shutdown = false; { iree::slim_mutex_lock_guard guard(lock_); @@ -40,6 +42,21 @@ System::~System() { "explicitly for maximum stability."); Shutdown(); } + + // Orderly destruction of heavy-weight objects. + // Shutdown order is important so we don't leave it to field ordering. + vm_instance_.reset(); + + // Devices. + devices_.clear(); + named_devices_.clear(); + retained_devices_.clear(); + + // HAL drivers. + hal_drivers_.clear(); + + // If support for logging refs was compiled in, report now. + iree::detail::LogLiveRefs(); } void System::Shutdown() { @@ -63,20 +80,7 @@ void System::Shutdown() { } } blocking_executor_.Kill(); - local_workers.clear(); - - // Orderly destruction of heavy-weight objects. - // Shutdown order is important so we don't leave it to field ordering. - vm_instance_.reset(); - - // Devices. - devices_.clear(); - named_devices_.clear(); - retained_devices_.clear(); - - // HAL drivers. 
- hal_drivers_.clear(); } std::shared_ptr System::CreateScope(Worker &worker, @@ -180,7 +184,7 @@ void System::InitializeHalDriver(std::string_view moniker, throw std::logic_error(fmt::format( "Cannot register multiple hal drivers with moniker '{}'", moniker)); } - slot.reset(driver.release()); + slot = std::move(driver); } void System::InitializeHalDevice(std::unique_ptr device) { diff --git a/libshortfin/src/shortfin/local/system.h b/libshortfin/src/shortfin/local/system.h index 3a5bbfd86..cb5c70808 100644 --- a/libshortfin/src/shortfin/local/system.h +++ b/libshortfin/src/shortfin/local/system.h @@ -42,9 +42,40 @@ class SystemBuilder; // on some form of factory that constructs one to suit both the system being // executed on and any preferences on which resources should be accessible. // -// As the root of the hierarchy and the owner of numerous ancillary resources, -// we declare that System is always managed via a shared_ptr, as this -// simplifies many aspects of system management. +// Ownership +// --------- +// There are three levels of ownership, all rooted on the System: +// 1. System: The System class, all drivers, devices, workers, and executors. +// There will only ever be one (or a small number if doing something multi +// tenant), and all owning references to the System are via +// `std::shared_ptr`. Every object in the system must either be +// a managed child of the system or own a system reference. +// 2. Scope: Binds any number of devices to a coherent schedule, rooted on +// a Worker. Scopes are independent of the system and there are generally +// as many as needed logical concurrency in the application. Each scope +// holds a system reference by way of a `std::shared_ptr`. These +// are still heavy-weight objects mostly created at initialization time +// and are therefore managed held as a `std::shared_ptr` by anything +// that depends on them. +// 3. TimelineResource: Any resource in the system (i.e. buffer, +// synchronization, object, etc) will hold a unique TimelineResource. These +// are light-weight objects managed via intrusive reference counting by +// their contained `TimelineResource::Ref` class. Each `TimelineResource` +// maintains a `std::shared_ptr` back reference to its owning +// scope. +// +// Leaf objects can have any lifetime that they wish, so long as they maintain +// an appropriate ownership reference into the System hierarchy above. This +// includes any application managed objects like arrays, storage, processes, +// messages, queues, etc. +// +// Lifetime debug logging can be enabled via compiler defines: +// SHORTFIN_LOG_LIFETIMES=1 : Enables constructor/destructor and this pointer +// logging for the primary objects in the system hierarchy. +// SHORTFIN_IREE_LOG_RC=1 : Enables the application view of IREE object +// reference counting, showing steal/retain/release and the number of +// references the application holds for each object. Also will log any +// outstanding references when the System is deallocated. 
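A sketch of that chain end to end, following the fixture in array_test.cc (namespaces shortfin::local and shortfin::array assumed):

    auto system = systems::HostCPUSystemBuilder().CreateSystem();  // owning SystemPtr
    auto scope = system->CreateScope(system->init_worker(), system->devices());
    {
      ScopedDevice device = scope->device(0);
      auto ary = device_array::for_host(device, std::to_array<size_t>({2}),
                                        DType::float32());
      // ary's storage holds a TimelineResource::Ref -> shared_ptr<Scope> ->
      // shared_ptr<System>, so the array alone keeps the whole chain alive.
    }
    system->Shutdown();  // explicit shutdown before the last reference drops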
class SHORTFIN_API System : public std::enable_shared_from_this { public: System(iree_allocator_t host_allocator); diff --git a/libshortfin/src/shortfin/local/worker.h b/libshortfin/src/shortfin/local/worker.h index 585e92d90..52f5e5948 100644 --- a/libshortfin/src/shortfin/local/worker.h +++ b/libshortfin/src/shortfin/local/worker.h @@ -73,6 +73,8 @@ class SHORTFIN_API Worker { Worker(Options options); Worker(const Worker &) = delete; + Worker &operator=(const Worker &) = delete; + Worker(Worker &&) = delete; ~Worker(); const Options &options() const { return options_; } diff --git a/libshortfin/src/shortfin/support/CMakeLists.txt b/libshortfin/src/shortfin/support/CMakeLists.txt index ec481d2d5..cbe6df89b 100644 --- a/libshortfin/src/shortfin/support/CMakeLists.txt +++ b/libshortfin/src/shortfin/support/CMakeLists.txt @@ -31,7 +31,7 @@ shortfin_cc_component( ) shortfin_gtest_test( - NAME support_test + NAME shortfin_support_test SRCS # Order is specific: lower level tests before higher level. iree_helpers_test.cc diff --git a/libshortfin/src/shortfin/support/blocking_executor.cc b/libshortfin/src/shortfin/support/blocking_executor.cc index fc739ec0c..fde3cc593 100644 --- a/libshortfin/src/shortfin/support/blocking_executor.cc +++ b/libshortfin/src/shortfin/support/blocking_executor.cc @@ -59,9 +59,18 @@ void BlockingExecutor::Kill(bool wait, iree_timeout_t warn_timeout) { iree::slim_mutex_lock_guard g(control_mu_); last_live_thread_count = live_thread_count_; total_thread_count = created_thread_count_; + // If transitioned to 0 live threads, there is a short period of time + // that can exist between the scan of the free list above and a task + // getting scheduled. Therefore, the first time we hit this condition, + // enter the inhibited state, which denies further scheduling. Then + // the next time we encounter no live threads, that will be a true + // count. if (live_thread_count_ == 0) { - inhibit_ = true; - break; + if (inhibit_) { + break; + } else { + inhibit_ = true; + } } } diff --git a/libshortfin/src/shortfin/support/blocking_executor_test.cc b/libshortfin/src/shortfin/support/blocking_executor_test.cc index 78f99cf4a..92a9b31f5 100644 --- a/libshortfin/src/shortfin/support/blocking_executor_test.cc +++ b/libshortfin/src/shortfin/support/blocking_executor_test.cc @@ -13,7 +13,13 @@ namespace shortfin { -TEST(BlockingExecutor, concurrent_tasks) { +class BlockingExecutorTest : public testing::Test { + protected: + void SetUp() override {} + void TearDown() override { iree::detail::LogLiveRefs(); } +}; + +TEST_F(BlockingExecutorTest, concurrent_tasks) { { std::atomic tasks_run{0}; @@ -33,7 +39,7 @@ TEST(BlockingExecutor, concurrent_tasks) { } } -TEST(BlockingExecutor, inhibit_when_shutdown) { +TEST_F(BlockingExecutorTest, inhibit_when_shutdown) { { std::atomic tasks_run{0}; @@ -46,6 +52,7 @@ TEST(BlockingExecutor, inhibit_when_shutdown) { } executor.Kill(/*wait=*/true); + logging::info("Killed"); // New work should be inhibited. 
try { @@ -57,7 +64,7 @@ TEST(BlockingExecutor, inhibit_when_shutdown) { } } -TEST(BlockingExecutor, warn_deadline) { +TEST_F(BlockingExecutorTest, warn_deadline) { { std::atomic tasks_run{0}; @@ -75,7 +82,7 @@ TEST(BlockingExecutor, warn_deadline) { } } -TEST(BlockingExecutor, threads_recycle) { +TEST_F(BlockingExecutorTest, threads_recycle) { { std::atomic tasks_run{0}; diff --git a/libshortfin/src/shortfin/support/iree_concurrency.h b/libshortfin/src/shortfin/support/iree_concurrency.h index 6ccd1792e..28ef1e99b 100644 --- a/libshortfin/src/shortfin/support/iree_concurrency.h +++ b/libshortfin/src/shortfin/support/iree_concurrency.h @@ -18,8 +18,15 @@ namespace shortfin::iree { namespace detail { struct thread_ptr_helper { - static void retain(iree_thread_t *obj) { iree_thread_retain(obj); } - static void release(iree_thread_t *obj) { iree_thread_release(obj); } + static void steal(iree_thread_t *obj) { LogIREESteal("iree_thread_t", obj); } + static void retain(iree_thread_t *obj) { + LogIREERetain("iree_thread_t", obj); + iree_thread_retain(obj); + } + static void release(iree_thread_t *obj) { + LogIREERelease("iree_thread_t", obj); + iree_thread_release(obj); + } }; }; // namespace detail diff --git a/libshortfin/src/shortfin/support/iree_helpers.cc b/libshortfin/src/shortfin/support/iree_helpers.cc index 8344377b4..d518e99c3 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.cc +++ b/libshortfin/src/shortfin/support/iree_helpers.cc @@ -6,8 +6,81 @@ #include "shortfin/support/iree_helpers.h" +#include + +#include +#include + +#include "shortfin/support/iree_concurrency.h" +#include "shortfin/support/logging.h" + namespace shortfin::iree { +namespace detail { + +#if SHORTFIN_IREE_LOG_RC + +slim_mutex log_mutex; +std::unordered_map app_ref_counts; + +void LogIREERetain(const char *type_name, void *ptr) { + slim_mutex_lock_guard g(log_mutex); + std::string key = fmt::format("{}({})", type_name, ptr); + int &rc = app_ref_counts[key]; + rc += 1; + if (rc == 1) { + logging::info("IREE new {}", key); + } else { + logging::info("IREE retain {} = {}", key, rc); + } +} + +void LogIREERelease(const char *type_name, void *ptr) { + slim_mutex_lock_guard g(log_mutex); + std::string key = fmt::format("{}({})", type_name, ptr); + int &rc = app_ref_counts[key]; + rc -= 1; + if (rc == 0) { + logging::info("IREE delete {}", key); + } else { + logging::info("IREE release {} = {}", key, rc); + } +} + +void LogIREESteal(const char *type_name, void *ptr) { + slim_mutex_lock_guard g(log_mutex); + std::string key = fmt::format("{}({})", type_name, ptr); + int &rc = app_ref_counts[key]; + rc += 1; + if (rc == 1) { + logging::info("IREE steal {}", key); + } else { + logging::info("IREE retain {} = {}", key, rc); + } +} + +void SHORTFIN_API LogLiveRefs() { + slim_mutex_lock_guard g(log_mutex); + bool logged_banner = false; + for (auto &it : app_ref_counts) { + if (it.second == 0) continue; + if (it.second < 0) { + logging::error("Shortfin IREE negative reference count: {} = {}", + it.first, it.second); + continue; + } + if (!logged_banner) { + logged_banner = true; + logging::warn("Shortfin visible live IREE refs remain:"); + } + logging::warn(" Live IREE ref {} = {}", it.first, it.second); + } +} + +#endif + +} // namespace detail + error::error(std::string message, iree_status_t failing_status) : message_(std::move(message)), failing_status_(failing_status) { message_.append(": "); @@ -19,7 +92,7 @@ void error::AppendStatus() const noexcept { status_appended_ = false; iree_allocator_t allocator = 
iree_allocator_system(); - char* status_buffer = nullptr; + char *status_buffer = nullptr; iree_host_size_t length = 0; if (iree_status_to_string(failing_status_, &allocator, &status_buffer, &length)) { diff --git a/libshortfin/src/shortfin/support/iree_helpers.h b/libshortfin/src/shortfin/support/iree_helpers.h index 8cbe368fd..c77ddbaa8 100644 --- a/libshortfin/src/shortfin/support/iree_helpers.h +++ b/libshortfin/src/shortfin/support/iree_helpers.h @@ -17,6 +17,10 @@ #include "iree/vm/api.h" #include "shortfin/support/api.h" +#if !defined(SHORTFIN_IREE_LOG_RC) +#define SHORTFIN_IREE_LOG_RC 0 +#endif + namespace shortfin { // -------------------------------------------------------------------------- // @@ -36,59 +40,142 @@ namespace iree { namespace detail { +#if SHORTFIN_IREE_LOG_RC +void SHORTFIN_API LogIREERetain(const char *type_name, void *ptr); +void SHORTFIN_API LogIREERelease(const char *type_name, void *ptr); +void SHORTFIN_API LogIREESteal(const char *type_name, void *ptr); +void SHORTFIN_API LogLiveRefs(); +#else +inline void LogIREERetain(const char *type_name, void *ptr) {} +inline void LogIREERelease(const char *type_name, void *ptr) {} +inline void LogIREESteal(const char *type_name, void *ptr) {} +inline void LogLiveRefs() {} +#endif + struct hal_buffer_ptr_helper { - static void retain(iree_hal_buffer_t *obj) { iree_hal_buffer_retain(obj); } - static void release(iree_hal_buffer_t *obj) { iree_hal_buffer_release(obj); } + static void steal(iree_hal_buffer_t *obj) { + LogIREESteal("iree_hal_buffer_t", obj); + } + static void retain(iree_hal_buffer_t *obj) { + LogIREERetain("iree_hal_buffer_t", obj); + iree_hal_buffer_retain(obj); + } + static void release(iree_hal_buffer_t *obj) { + LogIREERelease("iree_hal_buffer_t", obj); + iree_hal_buffer_release(obj); + } }; struct hal_command_buffer_helper { + static void steal(iree_hal_command_buffer_t *obj) { + LogIREESteal("iree_hal_command_buffer_t", obj); + } static void retain(iree_hal_command_buffer_t *obj) { + LogIREERetain("iree_hal_command_buffer_t", obj); iree_hal_command_buffer_retain(obj); } static void release(iree_hal_command_buffer_t *obj) { + LogIREERelease("iree_hal_command_buffer_t", obj); iree_hal_command_buffer_release(obj); } }; struct hal_device_ptr_helper { - static void retain(iree_hal_device_t *obj) { iree_hal_device_retain(obj); } - static void release(iree_hal_device_t *obj) { iree_hal_device_release(obj); } + static void steal(iree_hal_device_t *obj) { + LogIREESteal("iree_hal_device_t", obj); + } + static void retain(iree_hal_device_t *obj) { + LogIREERetain("iree_hal_device_t", obj); + iree_hal_device_retain(obj); + } + static void release(iree_hal_device_t *obj) { + LogIREERelease("iree_hal_device_t", obj); + iree_hal_device_release(obj); + } }; struct hal_driver_ptr_helper { - static void retain(iree_hal_driver_t *obj) { iree_hal_driver_retain(obj); } - static void release(iree_hal_driver_t *obj) { iree_hal_driver_release(obj); } + static void steal(iree_hal_driver_t *obj) { + LogIREESteal("iree_hal_driver_t", obj); + } + static void retain(iree_hal_driver_t *obj) { + LogIREERetain("iree_hal_driver_t", obj); + iree_hal_driver_retain(obj); + } + static void release(iree_hal_driver_t *obj) { + LogIREERelease("iree_hal_driver_t", obj); + iree_hal_driver_release(obj); + } }; struct hal_fence_ptr_helper { - static void retain(iree_hal_fence_t *obj) { iree_hal_fence_retain(obj); } - static void release(iree_hal_fence_t *obj) { iree_hal_fence_release(obj); } + static void steal(iree_hal_fence_t *obj) { + 
LogIREESteal("iree_hal_fence_t", obj); + } + static void retain(iree_hal_fence_t *obj) { + LogIREERetain("iree_hal_fence_t", obj); + iree_hal_fence_retain(obj); + } + static void release(iree_hal_fence_t *obj) { + LogIREERelease("iree_hal_fence_t", obj); + iree_hal_fence_release(obj); + } }; struct hal_semaphore_ptr_helper { + static void steal(iree_hal_semaphore_t *obj) { + LogIREESteal("iree_hal_semaphore_t", obj); + } static void retain(iree_hal_semaphore_t *obj) { + LogIREERetain("iree_hal_semaphore_t", obj); iree_hal_semaphore_retain(obj); } static void release(iree_hal_semaphore_t *obj) { + LogIREERelease("iree_hal_semaphore_t", obj); iree_hal_semaphore_release(obj); } }; struct vm_context_ptr_helper { - static void retain(iree_vm_context_t *obj) { iree_vm_context_retain(obj); } - static void release(iree_vm_context_t *obj) { iree_vm_context_release(obj); } + static void steal(iree_vm_context_t *obj) { + LogIREESteal("iree_vm_context_t", obj); + } + static void retain(iree_vm_context_t *obj) { + LogIREERetain("iree_vm_context_t", obj); + iree_vm_context_retain(obj); + } + static void release(iree_vm_context_t *obj) { + LogIREERelease("iree_vm_context_t", obj); + iree_vm_context_release(obj); + } }; struct vm_instance_ptr_helper { - static void retain(iree_vm_instance_t *obj) { iree_vm_instance_retain(obj); } + static void steal(iree_vm_instance_t *obj) { + LogIREESteal("iree_vm_instance_t", obj); + } + static void retain(iree_vm_instance_t *obj) { + LogIREERetain("iree_vm_instance_t", obj); + iree_vm_instance_retain(obj); + } static void release(iree_vm_instance_t *obj) { + LogIREERelease("iree_vm_instance_t", obj); iree_vm_instance_release(obj); } }; struct vm_module_ptr_helper { - static void retain(iree_vm_module_t *obj) { iree_vm_module_retain(obj); } - static void release(iree_vm_module_t *obj) { iree_vm_module_release(obj); } + static void steal(iree_vm_module_t *obj) { + LogIREESteal("iree_vm_module_t", obj); + } + static void retain(iree_vm_module_t *obj) { + LogIREERetain("iree_vm_module_t", obj); + iree_vm_module_retain(obj); + } + static void release(iree_vm_module_t *obj) { + LogIREERelease("iree_vm_module_t", obj); + iree_vm_module_release(obj); + } }; }; // namespace detail @@ -105,41 +192,60 @@ class object_ptr { } } object_ptr(object_ptr &&other) : ptr(other.ptr) { other.ptr = nullptr; } + object_ptr &operator=(const object_ptr &other) = delete; object_ptr &operator=(object_ptr &&other) { + reset(); ptr = other.ptr; other.ptr = nullptr; return *this; } - ~object_ptr() { - if (ptr) { - Helper::release(ptr); - } - } + ~object_ptr() { reset(); } // Constructs a new object_ptr by transferring ownership of a raw // pointer. - static object_ptr steal_reference(T *owned) { return object_ptr(owned); } + static object_ptr steal_reference(T *owned) { + Helper::steal(owned); + return object_ptr(owned); + } + // Constructs a new object_ptr by retaining a raw pointer. static object_ptr borrow_reference(T *owned) { Helper::retain(owned); return object_ptr(owned); } operator T *() const noexcept { return ptr; } + class Assignment { + public: + explicit Assignment(object_ptr *assign) : assign(assign) {} + ~Assignment() { + if (assign->ptr) { + Helper::steal(assign->ptr); + } + } + + constexpr operator T **() noexcept { + return reinterpret_cast(&assign->ptr); + } + + private: + object_ptr *assign = nullptr; + }; + // Releases any current reference held by this instance and returns a // pointer to the raw backing pointer. 
This is typically used for passing // to out parameters which are expected to store a new owned pointer directly. - T **for_output() { + constexpr Assignment for_output() noexcept { reset(); - return &ptr; + return Assignment(this); } operator bool() const { return ptr != nullptr; } T *get() const { return ptr; } - void reset(T *other = nullptr) { + void reset() { if (ptr) { Helper::release(ptr); } - ptr = other; + ptr = nullptr; } T *release() { T *ret = ptr; @@ -151,6 +257,8 @@ class object_ptr { // Assumes the reference count for owned_ptr. object_ptr(T *owned_ptr) : ptr(owned_ptr) {} T *ptr = nullptr; + + friend class Assignment; }; using hal_buffer_ptr = diff --git a/libshortfin/src/shortfin/support/iree_helpers_test.cc b/libshortfin/src/shortfin/support/iree_helpers_test.cc index a13b81b72..bf059ee98 100644 --- a/libshortfin/src/shortfin/support/iree_helpers_test.cc +++ b/libshortfin/src/shortfin/support/iree_helpers_test.cc @@ -29,6 +29,7 @@ struct iree_dummy_t { }; struct dummy_ptr_helper { + static void steal(iree_dummy_t *obj) {} static void retain(iree_dummy_t *obj) { obj->retain_count++; } static void release(iree_dummy_t *obj) { obj->release_count++; } }; diff --git a/libshortfin/src/shortfin/support/logging.h b/libshortfin/src/shortfin/support/logging.h index 55bd36347..337ebacae 100644 --- a/libshortfin/src/shortfin/support/logging.h +++ b/libshortfin/src/shortfin/support/logging.h @@ -9,6 +9,10 @@ #include "spdlog/spdlog.h" +#if !defined(SHORTFIN_LOG_LIFETIMES) +#define SHORTFIN_LOG_LIFETIMES 0 +#endif + namespace shortfin::logging { // TODO: Re-export doesn't really work like this. Need to define API @@ -18,6 +22,22 @@ using spdlog::error; using spdlog::info; using spdlog::warn; +#if SHORTFIN_LOG_LIFETIMES +template +inline void construct(const char* type_name, T* inst) { + info("new {}({})", type_name, static_cast(inst)); +} +template +inline void destruct(const char* type_name, T* inst) { + info("delete {}({})", type_name, static_cast(inst)); +} +#else +template +inline void construct(const char *type_name, T *) {} +template +inline void destruct(const char *type_name, T *) {} +#endif + } // namespace shortfin::logging #endif // SHORTFIN_SUPPORT_LOGGING_H diff --git a/libshortfin/tests/amdgpu_system_test.py b/libshortfin/tests/amdgpu_system_test.py index 74ea69af2..4b887ea54 100644 --- a/libshortfin/tests/amdgpu_system_test.py +++ b/libshortfin/tests/amdgpu_system_test.py @@ -4,8 +4,11 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest -def test_create_host_cpu_system(): + +@pytest.mark.requires_amd_gpu +def test_create_amd_gpu_system(): from _shortfin import lib as sfl sc = sfl.local.amdgpu.SystemBuilder() @@ -15,3 +18,5 @@ def test_create_host_cpu_system(): print(f" DEVICE: {device_name} = {ls.device(device_name)}") print(ls.devices) + print("Shutting down") + ls.shutdown() diff --git a/libshortfin/tests/array_test.py b/libshortfin/tests/array_test.py index 41cf51aa8..9f53da1c3 100644 --- a/libshortfin/tests/array_test.py +++ b/libshortfin/tests/array_test.py @@ -8,13 +8,15 @@ import pytest import time -from _shortfin import lib as sfl +import shortfin as sf @pytest.fixture def lsys(): - sc = sfl.local.host.CPUSystemBuilder() - return sc.create_system() + sc = sf.host.CPUSystemBuilder() + lsys = sc.create_system() + yield lsys + lsys.shutdown() @pytest.fixture @@ -25,32 +27,79 @@ def scope(lsys): def test_storage(scope): - storage = sfl.array.storage.allocate_device(scope.device(0), 32) + storage = sf.array.storage.allocate_host(scope.device(0), 32) print(storage) - ary = sfl.array.device_array(storage, [2, 4], sfl.array.float32) + ary = sf.array.device_array(storage, [2, 4], sf.array.float32) print(ary) print(ary.shape) assert ary.shape == [2, 4] - assert ary.dtype == sfl.array.float32 + assert ary.dtype == sf.array.float32 + + print("ARY.DEVICE=", ary.device, ary.device.__class__) + print("SCOPE.DEVICE=", scope.device(0)) + print("EQ:", ary.device == scope.device(0)) + assert ary.device == scope.device(0) + # Mapping API contract. + with storage.map(read=True) as m: + assert m.valid + mv = memoryview(m) + assert len(mv) == 32 + assert not m.valid + + storage.data = array.array("f", [1.234534523] * 8) + print("WRITTEN:", ary) + + read_back = array.array("f") + read_back.frombytes(storage.data) + print("READ BACK:", read_back) + + +@pytest.mark.parametrize( + "dtype,code,py_value,expected_repr", + [ + (sf.array.int8, "b", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + (sf.array.int16, "h", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + (sf.array.int32, "i", 42, "{{42, 42, 42, 42},\n {42, 42, 42, 42}}"), + ( + sf.array.float32, + "f", + 42.0, + "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", + ), + ( + sf.array.float64, + "d", + 42.0, + "{{ 42., 42., 42., 42.},\n { 42., 42., 42., 42.}}", + ), + ], +) +def test_xtensor_types(scope, dtype, code, py_value, expected_repr): + ary = sf.array.device_array.for_host(scope.device(0), [2, 4], dtype) + ary.storage.data = array.array(code, [py_value] * 8) + r = repr(ary) + print(r) + assert expected_repr in r, f"Expected '{expected_repr}' in '{r}'" + def test_device_array(scope): - ary1 = sfl.array.device_array(scope.device(0), [32, 1, 4], sfl.array.float32) + ary1 = sf.array.device_array(scope.device(0), [32, 1, 4], sf.array.float32) print(ary1) assert ary1.shape == [32, 1, 4] - assert ary1.dtype == sfl.array.float32 + assert ary1.dtype == sf.array.float32 assert scope.device(0) == ary1.device - hary1 = sfl.array.host_array(ary1) + hary1 = sf.array.device_array.for_transfer(ary1) print(hary1) - assert isinstance(hary1, sfl.array.host_array) + assert isinstance(hary1, sf.array.device_array) assert hary1.shape == ary1.shape assert hary1.dtype == ary1.dtype assert hary1.device == ary1.device def test_device_array_fill(scope): - ary1 = sfl.array.device_array(scope.device(0), [32, 1, 4], sfl.array.int32) - ary1.storage.fill(array.array("i", [0])) + ary1 = sf.array.device_array(scope.device(0), [32, 1, 4], 
sf.array.int32) + ary1.storage.fill(array.array("i", [42])) # TODO: Transfer to host and verify. diff --git a/libshortfin/tests/examples_test.py b/libshortfin/tests/examples/async_test.py similarity index 93% rename from libshortfin/tests/examples_test.py rename to libshortfin/tests/examples/async_test.py index 54b815cc4..1595d7d8e 100644 --- a/libshortfin/tests/examples_test.py +++ b/libshortfin/tests/examples/async_test.py @@ -11,7 +11,7 @@ import subprocess import sys -project_dir = Path(__file__).resolve().parent.parent +project_dir = Path(__file__).resolve().parent.parent.parent example_dir = project_dir / "examples" / "python" diff --git a/libshortfin/tests/examples/fastapi_test.py b/libshortfin/tests/examples/fastapi_test.py new file mode 100644 index 000000000..f19c1c12f --- /dev/null +++ b/libshortfin/tests/examples/fastapi_test.py @@ -0,0 +1,109 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from contextlib import closing +import os +from pathlib import Path +import pytest +import requests +import socket +import subprocess +import sys +import time + +project_dir = Path(__file__).resolve().parent.parent.parent +example_dir = project_dir / "examples" / "python" + + +@pytest.fixture(scope="session") +def server(): + runner = ServerRunner([]) + yield runner + print("Sending kill signal") + runner.process.terminate() + print("Waiting for server to exit") + runner.process.wait(20) + + +# Test error first to make sure it doesn't mess up the server. +def test_error_response(server): + resp = requests.get(f"{server.url}/predict?value=0") + assert resp.status_code == 500 + + +def test_single_response(server): + resp = requests.get(f"{server.url}/predict?value=1") + resp.raise_for_status() + full_contents = resp.content + print(full_contents) + assert full_contents == b'{"answer":1}' + + +def test_stream_response(server): + resp = requests.get(f"{server.url}/predict?value=20") + resp.raise_for_status() + full_contents = resp.content + print(full_contents) + exp_contents = ("".join(['{"answer": %s}\n\x00' % i for i in range(21)])).encode() + assert full_contents == exp_contents + + +class ServerRunner: + def __init__(self, args): + port = str(find_free_port()) + self.url = "http://localhost:" + port + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.process = subprocess.Popen( + [ + sys.executable, + str(example_dir / "fastapi" / "server.py"), + "--port=" + port, + ] + + args, + env=env, + # TODO: Have a more robust way of forking a subprocess. + cwd=str(example_dir), + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_ready() + + def _wait_for_ready(self): + start = time.time() + while True: + try: + if requests.get(f"{self.url}/health").status_code == 200: + return + except Exception as e: + if self.process.poll() is not None: + raise RuntimeError("API server processs terminated") from e + time.sleep(1.0) + if (time.time() - start) > 30: + raise RuntimeError("Timeout waiting for server start") + + def __del__(self): + try: + process = self.process + except AttributeError: + pass + else: + process.terminate() + process.wait() + + +def find_free_port(): + """This tries to find a free port to run a server on for the test. + + Race conditions are possible - the port can be acquired between when this + runs and when the server starts. 
+ + https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number + """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("localhost", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] diff --git a/libshortfin/tests/local_scope_test.py b/libshortfin/tests/local_scope_test.py index de3598711..9f56e7833 100644 --- a/libshortfin/tests/local_scope_test.py +++ b/libshortfin/tests/local_scope_test.py @@ -13,7 +13,9 @@ @pytest.fixture def lsys(): sc = sfl.local.host.CPUSystemBuilder() - return sc.create_system() + ls = sc.create_system() + yield ls + ls.shutdown() @pytest.fixture diff --git a/libshortfin/tests/smoke_test.py b/libshortfin/tests/smoke_test.py deleted file mode 100644 index a066d6eb7..000000000 --- a/libshortfin/tests/smoke_test.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - - -def test_sfl_import(): - from _shortfin import lib as sfl - - sfl.initialize() diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 54b301160..78240d614 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -31,7 +31,7 @@ def main(): parser.add_argument( "--output-config", help="Output file path for exported config file", - default="/tmp/batch_llama_v1.json", + default="tmp/batch_llama_v1.json", ) parser.add_argument( "--bs", @@ -50,6 +50,7 @@ def main(): hp = configs.LlamaHParams.from_gguf_props(dataset.properties) llama_config = LlamaModelConfig(hp) + llama_config.static_tables = False # Rely on the compiler for hoisting tables. llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" model = PagedLlamaModelV1(dataset.root_theta, llama_config) diff --git a/sharktank/sharktank/examples/sharding/export_gemm.py b/sharktank/sharktank/examples/sharding/export_gemm.py new file mode 100644 index 000000000..7a4322e38 --- /dev/null +++ b/sharktank/sharktank/examples/sharding/export_gemm.py @@ -0,0 +1,107 @@ +import sys +from typing import List +import argparse +import torch +from torch import Tensor +from sharktank import ops +from shark_turbine import aot + + +def export_gemm( + mlir_path: str, + device_count: int, + m: int, + n: int, + k: int, + with_alpha: bool, + with_beta: bool, +): + class GemmModule(torch.nn.Module): + def forward(self, *args, **kwargs): + return ops.gemm(*args, **kwargs) + + a = torch.empty(m, k, dtype=torch.float32) + b = torch.empty(k, n, dtype=torch.float32) + c = torch.empty(m, n, dtype=torch.float32) + sharded_a = ops.reshard_split(a, dim=0, count=device_count) + sharded_b = ops.replicate(b, count=device_count) + sharded_c = ops.reshard_split(c, dim=0, count=device_count) + gemm_module = GemmModule() + kwargs = { + "a": sharded_a, + "b": sharded_b, + "c": sharded_c, + } + # Need to pass alpha and beta not as numbers, but as tensors since + # the IREE FX importer does not support ConstantArgument. 
+    if with_alpha:
+        kwargs["alpha"] = torch.tensor(2.0, dtype=torch.float32)
+    if with_beta:
+        kwargs["beta"] = torch.tensor(3.0, dtype=torch.float32)
+    torch_exported = torch.export.export(gemm_module, args=(), kwargs=kwargs)
+    export_output = aot.export(torch_exported)
+    export_output.save_mlir(mlir_path)
+
+
+def export_gemm_cli(argv: List[str]):
+    parser = argparse.ArgumentParser(
+        description="""
+Export sharded GEMM to MLIR.
+alpha * a @ b + beta * c
+a is MxK matrix.
+b is KxN matrix.
+c is MxN matrix.
+The sharded/split dimension is M.
+a and c will be split across dimension 0 (M).
+b will be replicated on all devices.
+For n devices the exported function will have signature
+(a0, a1, ..., an, b0, b1, ..., bn, c0, c1, ..., cn) -> (r0, r1, ..., rn),
+where ai and ci are the respective shards on the i-th device.
+bi is equal to b, but on the i-th device.
+The caller must place the shards on the expected devices.
+
+The result is also split along dimension M,
+where ri is on the i-th device.
+
+Support for --with-alpha and --with-beta is under construction.
+
+Example usage:
+python export_gemm.py --device_count=2 --m=10 --k=20 --n=30 \\
+    --mlir=sharded-gemm.mlir""",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument(
+        "--mlir", help="Path to the exported program.", type=str, required=True
+    )
+    parser.add_argument(
+        "--device_count", help="Number of shards/devices", type=int, required=True
+    )
+    parser.add_argument("--m", help="M", type=int, default=512)
+    parser.add_argument("--n", help="N", type=int, default=512)
+    parser.add_argument("--k", help="K", type=int, default=512)
+    parser.add_argument(
+        "--with-alpha",
+        help="Have alpha as an argument to the function signature",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--with-beta",
+        help="Have beta as an argument to the function signature",
+        default=False,
+        action="store_true",
+    )
+    args = parser.parse_args(args=argv[1:])
+    export_gemm(
+        mlir_path=args.mlir,
+        device_count=args.device_count,
+        m=args.m,
+        n=args.n,
+        k=args.k,
+        with_alpha=args.with_alpha,
+        with_beta=args.with_beta,
+    )
+
+
+if __name__ == "__main__":
+    export_gemm_cli(sys.argv)
diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 91f700789..d253af617 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -28,7 +28,8 @@ def __init__(
         theta: Theta,
         *,
         context_length: int,
-        static_context_mask: bool = True,
+        static_tables: bool = True,
+        static_context_mask: bool = False,
         device: Optional[torch.device] = None,
         activation_dtype: torch.dtype = torch.float32,
         attention_dtype: torch.dtype = torch.float32,
@@ -39,7 +40,7 @@ def __init__(
         self.attention_dtype = attention_dtype
         self.context_length = context_length
-        if static_context_mask:
+        if static_tables:
             self.register_buffer(
                 "causal_context_mask", self.generate_causal_context_mask()
             )
@@ -66,10 +67,12 @@ def _maximally_negative_value(self, dtype):
     def generate_causal_context_mask(self) -> torch.Tensor:
         context_length = self.context_length
+        unary_broadcast_ones = torch.ones([1, 1], dtype=torch.bool, device=self.device)
+        context_broadcast_ones = unary_broadcast_ones.expand(
+            context_length, context_length
+        )
         causal_context_mask = torch.triu(
-            torch.ones(
-                [context_length, context_length], dtype=torch.bool, device=self.device
-            ),
+            context_broadcast_ones,
             diagonal=1,
         )[None, None, :, :]
         return causal_context_mask
@@ -114,9 +117,11 @@ def attention_mask(
scenarios can benefit from managing this in different ways. """ if causal_context_mask is None: + # Try to use the statically generated. causal_context_mask = self.causal_context_mask if causal_context_mask is None: - causal_context_mask = self._generate_causal_context_mask() + # Fallback to dynamically generated. + causal_context_mask = self.generate_causal_context_mask() # Combine the causal context mask and input mask. dtype = self.attention_dtype diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 755392522..18984713d 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -21,14 +21,29 @@ def __init__( max_seqlen: int, device: Optional[torch.device] = None, use_hf: bool = False, + static_tables: bool = True, ): super().__init__() + # Force static_tables until compiler limitations are solved. + # See https://github.com/nod-ai/sharktank/issues/156 + static_tables = True self.device = device + self.rope_dimension_count = rope_dimension_count + self.max_seqlen = max_seqlen self.use_hf = use_hf - self._table = self._create_rotary_embed_table( - max_seqlen=max_seqlen, - dim=rope_dimension_count, - ) + if static_tables: + self.register_buffer( + "static_rotary_embed_table", self._create_rotary_embed_table() + ) + else: + self.static_rotary_embed_table = None + + @property + def rotary_embed_table(self): + if self.static_rotary_embed_table is None: + return self._create_rotary_embed_table() + else: + return self.static_rotary_embed_table def forward(self, *, xq: torch.Tensor, xk: torch.Tensor, start_index: int): # xq_, xk_ shape: bs, sl, _, dim @@ -80,7 +95,7 @@ def create_ordering_tensor(dim): _, sl, _, dim = xq_.shape # Offset the table based on starting position. - freqs_cis = self._table[start_index : start_index + sl, :] + freqs_cis = self.rotary_embed_table[start_index : start_index + sl, :] assert freqs_cis.shape[-1] == dim assert ( freqs_cis.shape[0] >= sl @@ -139,7 +154,7 @@ def compute_batch_mask( ) + start_positions.unsqueeze(1) # Broadcast lookup to [b, ...]. self.trace_tensor("rope.positions_seq", positions_seq) - freqs_cis = self._table[positions_seq] + freqs_cis = self.rotary_embed_table[positions_seq] # Unsqueeze a unit dim for attention heads. broadcast_freqs_cis = freqs_cis.unsqueeze(2) @@ -167,10 +182,10 @@ def apply_batched_mask( def _create_rotary_embed_table( self, - max_seqlen: int, - dim: int, theta_value: float = 10000.0, ): + dim = self.rope_dimension_count + max_seqlen = self.max_seqlen freqs = 1.0 / ( theta_value ** (torch.arange(0, dim, 2, device=self.device)[: (dim // 2)].float() / dim) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index ea3170122..984fc6524 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -52,6 +52,14 @@ class LlamaModelConfig: # rotary embedding). use_hf: bool = False + # If true, then the model may pre-initialize certain tables during + # init. This can be better for eager execution but when capturing a program, + # it is often better to preserve the calculation explicitly and rely on + # the compiler to transform it to an initialization time step. This can + # be the difference of many gigabytes of static data being embedded in + # the program and not. 
+ static_tables: bool = True + def create_kv_cache(self) -> BaseKVCache: hp = self.hp if self.kv_cache_type == "direct": @@ -110,6 +118,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): super().__init__( theta, context_length=config.hp.context_length, + static_tables=config.static_tables, device=config.device, activation_dtype=config.activation_dtype, attention_dtype=config.attention_dtype, @@ -131,6 +140,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): max_seqlen=hp.context_length, device=self.device, use_hf=self.use_hf, + static_tables=config.static_tables, ), ) self.add_module( @@ -500,7 +510,7 @@ def transact_cache_paged( xv_cache_update, ], transformer_block_index=self.block_index, - seq_positions=start_positions + 1, + seq_positions=start_positions, page_ids=seq_block_ids, ) diff --git a/sharktank/sharktank/ops/_registry.py b/sharktank/sharktank/ops/_registry.py index 66fa034f3..c519af75b 100644 --- a/sharktank/sharktank/ops/_registry.py +++ b/sharktank/sharktank/ops/_registry.py @@ -17,8 +17,10 @@ from ..types import PrimitiveTensor, QuantizedTensor __all__ = [ + "AllOfExprs", "AllOfType", "AnyOfType", + "IsOfType", "overridable", "SignatureDispatcher", "BoolTypeExpr", @@ -62,6 +64,29 @@ def __call__(self, *args: type) -> bool: return self._expr(*args) +class AllOfExprs(BoolTypeExpr): + """Returns True if all types match their respective boolean type expression. + + ```python + # True. int == int and str in (float, str). + AllOfExprs(IsOfType(int), IsOfType(float, str))(int, str) + + # False. str is not in (int, float). + AllOfExprs(IsOfType(int), IsOfType(int, float))(int, str) + ``` + """ + + def __init__(self, *exprs: BoolTypeExpr): + self._exprs = exprs + + def expr(*types: type): + if len(types) < len(self._exprs): + return False + return all([e(t) for e, t in zip(self._exprs, types)]) + + super().__init__(expr) + + class AllOfType(BoolTypeExpr): """Returns True if all of the types are from a set of types. @@ -109,6 +134,9 @@ def expr(*types: type): super().__init__(expr) +IsOfType = AllOfType + + class SignatureDispatcher: """Replaces an overridable function with a tensor type base dispatcher. @@ -201,7 +229,7 @@ def _is_type_expr_target( ): if len(override_type_spec) > 1: raise TypeError( - "Override with multiple arguments not allowed when using BoolTypeExpr." + f"Override with multiple arguments not allowed when using BoolTypeExpr. Type spec: {override_type_spec}" ) return True return False diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index ed11eb01a..4ca60cc49 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -7,15 +7,16 @@ # This file contains overrides of the standard ops for normal torch and # generic primitive/quantized types. 
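Before the new `AllOfExprs` expression is put to work by the elementwise override in the next hunk, here is a standalone sketch of its dispatch predicate. The two classes below are local stand-ins written for illustration (they are not the sharktank registry classes themselves) and reproduce the truth table from the docstring above:

# Local stand-ins illustrating the AllOfExprs / IsOfType predicate semantics
# added to _registry.py; runnable without sharktank installed.
class IsOfType:
    def __init__(self, *types):
        self._types = types

    def __call__(self, t):
        # True when the single argument type is one of the allowed types.
        return issubclass(t, self._types)

class AllOfExprs:
    def __init__(self, *exprs):
        self._exprs = exprs

    def __call__(self, *types):
        # Each positional type must satisfy its corresponding expression.
        if len(types) < len(self._exprs):
            return False
        return all(e(t) for e, t in zip(self._exprs, types))

# Reproduces the docstring examples:
print(AllOfExprs(IsOfType(int), IsOfType(float, str))(int, str))   # True
print(AllOfExprs(IsOfType(int), IsOfType(int, float))(int, str))   # False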
-from typing import Optional, List, Sequence +from typing import Optional, List, Sequence, Union import torch from torch import Tensor, dtype import torch.nn.functional as F +from numbers import Number -from ..types import PrimitiveTensor, QuantizedTensor -from ..types.tensors import unbox_tensor -from ._registry import AllOfType +from ..types import PrimitiveTensor, QuantizedTensor, InferenceTensor +from ..types.tensors import unbox_tensor, AnyTensor +from ._registry import AllOfType, AllOfExprs, IsOfType from .signatures import * @@ -60,7 +61,6 @@ def conv2d_default( conv2d.override(Tensor, Tensor, Tensor, auto_dequant=True)(conv2d_default) conv2d.override(Tensor, Tensor, auto_dequant=True)(conv2d_default) - # Elementwise @elementwise.override(Tensor) def elementwise_unary(operator, x): @@ -68,10 +68,15 @@ def elementwise_unary(operator, x): return operator(x) -@elementwise.override(Tensor, Tensor) +@elementwise.override( + AllOfExprs( + IsOfType(Tensor, PrimitiveTensor), IsOfType(Tensor, PrimitiveTensor, Number) + ) +) def elementwise_binary(operator, x, y): x = unbox_tensor(x) - y = unbox_tensor(y) + if isinstance(y, PrimitiveTensor): + y = unbox_tensor(y) return operator(x, y) @@ -94,6 +99,31 @@ def equal_default(a, b) -> bool: return torch.equal(unbox_tensor(a), unbox_tensor(b)) +@gemm.override(AllOfType(Tensor, InferenceTensor)) +def gemm( + a: AnyTensor, + b: AnyTensor, + c: Optional[AnyTensor], + alpha: Optional[Union[Number, AnyTensor]], + beta: Optional[Union[Number, AnyTensor]], + transa: bool, + transb: bool, +) -> bool: + if transa: + a = a.T + if transb: + b = b.T + res = matmul(a, b) + if alpha is not None: + res = alpha * res + if c is not None: + if beta is not None: + res = res + beta * c + else: + res = res + c + return res + + # Group norm. 
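The `gemm` override above builds the BLAS expression `alpha * a @ b + beta * c` out of existing ops (`matmul` plus elementwise scale and add). The following plain-PyTorch reference, using the same operand values as the `GemmTest` added later in this patch, shows the intended semantics; it is a sketch for checking results, not the sharktank implementation:

# Reference semantics for gemm(a, b, c, alpha, beta, transa, transb) in plain torch.
import torch

def gemm_reference(a, b, c=None, alpha=None, beta=None, transa=False, transb=False):
    if transa:
        a = a.T
    if transb:
        b = b.T
    res = a @ b
    if alpha is not None:
        res = alpha * res
    if c is not None:
        res = res + (beta * c if beta is not None else c)
    return res

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
c = torch.tensor([[9, 10], [11, 12]])
# Same call shape as GemmTest: alpha=2, beta=3, b transposed.
print(gemm_reference(a, b, c, alpha=2, beta=3, transb=True))
# tensor([[ 61,  76], [111, 142]])  ->  2 * (a @ b.T) + 3 * c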
@group_norm_affine.override(Tensor, Tensor, Tensor) def group_norm_affine_default(input, weight, bias, *, num_groups, eps): diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 4fb85d1ff..b1ef57090 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -8,6 +8,7 @@ from torch import Tensor from typing import List, Optional, Sequence import itertools +from numbers import Number from ..types import ( AnyTensor, @@ -248,6 +249,22 @@ def split_elementwise_binary( return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) +@elementwise.override(SplitPrimitiveTensor, Number) +def elementwise_binary_split_lhs_scalar_rhs( + operator, x: SplitPrimitiveTensor, y: Number +): + pt_xs = [unbox_tensor(pt) for pt in x.shards] + partials = [operator(pt_x, y) for pt_x in pt_xs] + return SplitPrimitiveTensor(shard_dim=x.shard_dim, shape=x.shape, ts=partials) + + +@elementwise.override(SplitPrimitiveTensor, Tensor) +def elementwise_binary_split_lhs_tensor_rhs( + operator, x: SplitPrimitiveTensor, y: Tensor +): + return elementwise(operator, x, replicate(y, count=x.shard_count)) + + @elementwise.override(ReplicatedTensor, SplitPrimitiveTensor) def elementwise_binary_replicated_lhs_sharder_rhs( operator, x: ReplicatedTensor, y: SplitPrimitiveTensor @@ -264,8 +281,9 @@ def elementwise_binary_replicated_lhs_sharder_rhs( @elementwise.override(SplitPrimitiveTensor, ReplicatedTensor) def elementwise_binary_split_lhs_replicated_rhs( - operator, x: ReplicatedTensor, y: SplitPrimitiveTensor + operator, x: SplitPrimitiveTensor, y: ReplicatedTensor ): + assert len(y.shape) > 0, "0-rank not supported" if x.shard_count != y.shard_count: raise ValueError( f"Operands' number of shards not equal ({x.shard_count} != {y.shard_count})" diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 7595578ef..07ae56bb1 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -12,6 +12,7 @@ import numbers from torch import Tensor, dtype from ..types import AnyTensor, ShardedTensor, Theta, sharding +from numbers import Number from ._registry import * @@ -22,6 +23,7 @@ "elementwise", "embedding_lookup", "equal", + "gemm", "group_norm_affine", "layer_norm", "interpolate", @@ -210,6 +212,45 @@ def _equal_trampoline(d: SignatureDispatcher, a: AnyTensor, b: AnyTensor): d.fail(tensors) +@overridable +def gemm( + a: AnyTensor, + b: AnyTensor, + c: Optional[AnyTensor] = None, + alpha: Optional[Union[Number, AnyTensor]] = None, + beta: Optional[Union[Number, AnyTensor]] = None, + transa: bool = False, + transb: bool = False, +): + """GEMM as defined by BLAS. + `alpha*a*b + beta*c` + If `c` is None it is the zero-filed tensor. 
+ """ + raise NotImplementedError + + +@gemm.trampoline +def _gemm_trampoline( + d: SignatureDispatcher, + a: AnyTensor, + b: AnyTensor, + c: Optional[AnyTensor] = None, + alpha: Optional[Union[Number, AnyTensor]] = None, + beta: Optional[Union[Number, AnyTensor]] = None, + transa: bool = False, + transb: bool = False, +): + tensors = (a, b, c) + for override in d.find_overrides(tensors): + result = override( + a=a, b=b, c=c, alpha=alpha, beta=beta, transa=transa, transb=transb + ) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + @overridable def group_norm_affine( input: AnyTensor, weight: AnyTensor, bias: AnyTensor, *, num_groups: int, eps: float diff --git a/sharktank/sharktank/types/gguf_interop/base.py b/sharktank/sharktank/types/gguf_interop/base.py index 315a0cd84..a343e333c 100644 --- a/sharktank/sharktank/types/gguf_interop/base.py +++ b/sharktank/sharktank/types/gguf_interop/base.py @@ -125,9 +125,10 @@ def load_file(gguf_path: Union[str, os.PathLike]) -> Dataset: # Extract tensors. tensors: dict[str, InferenceTensor] = {} for tensor in reader.tensors: + shape = [int(d) for d in tensor.shape] gguf_tensor = _wrap_tensor( name=tensor.name, - logical_shape=list(tensor.shape), + logical_shape=list(shape), type_name=tensor.tensor_type.name, data=tensor.data, # type: ignore ) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 6acb9e8d2..b48fb1b52 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -284,6 +284,21 @@ def __add__(self, rhs): return elementwise(torch.add, self, rhs) + def __radd__(self, lhs): + # Assumes commutative addition due to torch.elementwise not handling numbers on + # the lhs. + return self.__add__(lhs) + + def __mul__(self, rhs): + from ..ops import elementwise + + return elementwise(torch.mul, self, rhs) + + def __rmul__(self, lhs): + # Assumes commutative multiplication due to torch.elementwise not handling + # numbers on the lhs. + return self.__mul__(lhs) + REGISTERED_INFERENCE_TENSOR_CLASSES: dict[str, Type[InferenceTensor]] = {} diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py new file mode 100644 index 000000000..3953b951b --- /dev/null +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -0,0 +1,288 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import torch +import torch.nn as nn +from sharktank.models.llama.llama import ( + PagedLlamaAttentionBlock, + PagedKVCache, + DirectKVCache, +) +from sharktank.models.llama.testing import * +from sharktank.layers.rotary_embedding import RotaryEmbeddingLayer +from sharktank.layers import causal_llm + + +class KVCacheTest(unittest.TestCase): + def setUp(self): + self.block_count = 5 + self.seq_len = 16 + self.head_count = 32 + self.head_dim = 128 + self.ffn_dim = 11008 + self.head_count_kv = 32 + self.block_seq_stride = 16 + self.rms_epsilon = 1e-5 + self.rope_dimension_count = 128 + self.max_seq_len = 4096 + self.start_positions = torch.tensor([8]) + self.bs = 1 + self.device = "cpu" + self.attention_dtype = torch.float32 + self.attention_block_theta = make_attention_block_theta( + feature_dim=self.head_count * self.head_dim, + ffn_dim=self.ffn_dim, + dtype=self.attention_dtype, + ) + self.paged_kv_cache = PagedKVCache( + transformer_block_count=self.head_count, + attn_head_count=self.head_count, + attn_head_dim=self.head_dim, + cache_partition_count=2, # One for each of K/V. + block_seq_stride=self.block_seq_stride, + device=self.device, + dtype=self.attention_dtype, + ) + self.direct_kv_cache = DirectKVCache( + block_seq_stride=self.block_seq_stride, + transformer_block_count=self.head_count, + attn_head_count=self.head_count, + attn_head_dim=self.head_dim, + seq_length=self.max_seq_len, + device=self.device, + dtype=self.attention_dtype, + ) + self.attention_embedding = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seq_len, + device=self.device, + use_hf=False, + ) + self.paged_attn_blocks = nn.ModuleList( + [ + PagedLlamaAttentionBlock( + self.attention_block_theta, + block_index=n, + cache=self.paged_kv_cache, + head_count=self.head_count, + head_dim=self.head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + use_hf=False, + ) + for n in range(self.block_count) + ] + ) + self.direct_attn_blocks = nn.ModuleList( + [ + PagedLlamaAttentionBlock( + theta=self.attention_block_theta, + block_index=n, + cache=self.direct_kv_cache, + head_count=self.head_count, + head_dim=self.head_dim, + head_count_kv=self.head_count_kv, + rms_epsilon=self.rms_epsilon, + use_hf=False, + ) + for n in range(self.block_count) + ] + ) + self.paged_cache_state = self.paged_kv_cache.allocate(page_count=128) + self.paged_seq_block_ids = torch.tensor( + [ + [127], + ] + ) + self.direct_cache_state = self.direct_kv_cache.allocate(bs=1) + self.direct_seq_block_ids = torch.tensor( + [ + [0], + ] + ) + self.embedding_batch_mask = self.attention_embedding.compute_batch_mask( + self.start_positions, batch_seq_len=1 + ) + self.model = causal_llm.BaseCausalLMModel( + self.attention_block_theta, context_length=self.max_seq_len + ) + self.prefill_attention_mask = self.model.attention_mask( + self.model.input_mask(self.start_positions, self.seq_len) + ) + + def testDirectAndPagedKVCachePrefill(self): + torch.set_default_dtype(torch.float32) + + paged_input_tensor = make_rand_torch( + (1, self.seq_len, self.head_count * self.head_dim), + dtype=self.attention_dtype, + ) + direct_input_tensor = paged_input_tensor.detach().clone() + # Iterate over paged attention blocks. 
+ for block_idx, paged_block in enumerate(self.paged_attn_blocks): + paged_input_tensor = paged_block( + paged_input_tensor, + embedding=self.attention_embedding, + start_index=0, + attention_mask=self.prefill_attention_mask, + cache_state=self.paged_cache_state, + seq_block_ids=self.paged_seq_block_ids, + ) + # Iterate over direct attention blocks. + for block_idx, direct_block in enumerate(self.direct_attn_blocks): + direct_input_tensor = direct_block( + direct_input_tensor, + embedding=self.attention_embedding, + start_index=0, + attention_mask=self.prefill_attention_mask, + cache_state=self.direct_cache_state, + seq_block_ids=self.direct_seq_block_ids, + ) + page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) + index_written = self.start_positions.item() + """ + Getting the value of the paged_seq_block_ids, which is the page id we are writing + the K/V cache into. + """ + page_id = self.paged_seq_block_ids[0][0].item() + """ + direct_cache_state is a list of num_transformer_blocks * 2 (one for K and one for V), + so here we index into the first transformer block's keys with self.direct_cache_state[0] + and the first transformer block's values with self.direct_cache_state[1]. Each row + in direct_cache_state is a tensor of [bs, seq_len , attn_heads, attn_dim], so we make sure + the first 8 (start_position) tensors starting at sequence 0 of the seq_len are written to. + """ + updated_direct_cache_state = self.direct_cache_state[0][ + :, :index_written + ].squeeze(0) + """ + paged_cache_state is a list of a single tensor that represents a flattened page table. + Indexing into self.paged_cache_state[0] and unflattening the page table columns to a 6D tensor of: + * transformer block + * cache partition (K or V cache) + * block sequence stride (number of sequence positions per block) + * attention heads + * attention dimensionality + allows us to access the cache partitions for a certain transformer block and sequence in a + certain page_id. For example, page_table[page_id][0, 0, :index_written] lets us access the + first transformer block's K cache for the first 8 (start_positions) tensors starting at + sequence 0. + """ + updated_paged_cache_state = page_table[page_id][0, 0, :index_written] + assert updated_direct_cache_state.shape == updated_paged_cache_state.shape + torch.testing.assert_close( + updated_direct_cache_state, updated_paged_cache_state + ) + + paged_prefill_attn_output = paged_input_tensor + direct_prefill_attn_output = direct_input_tensor + assert paged_prefill_attn_output.shape == direct_prefill_attn_output.shape + torch.testing.assert_close( + paged_prefill_attn_output, direct_prefill_attn_output + ) + + @unittest.skip( + "Bug in Windows decode test for paged_decode_attn_output vs. 
direct_decode_attn_output" + ) + def testDirectAndPagedKVCacheDecode(self): + torch.set_default_dtype(torch.float32) + self.start_positions.add_(1) + assert self.direct_seq_block_ids.shape[1] == self.paged_seq_block_ids.shape[1] + decode_attention_mask = self.model.decode_attention_mask( + self.model.input_mask( + self.start_positions, self.direct_seq_block_ids.shape[1] * self.seq_len + ) + ) + + token_paged_input_tensor = make_rand_torch( + (1, 1, self.head_count * self.head_dim), dtype=self.attention_dtype + ) + token_direct_input_tensor = token_paged_input_tensor.detach().clone() + + xk_temp = torch.empty( + [ + self.bs, + self.max_seq_len, + self.head_count_kv, + self.head_dim, + ], + dtype=self.attention_dtype, + device=self.device, + ) + xv_temp = torch.empty( + [ + self.bs, + self.max_seq_len, + self.head_count_kv, + self.head_dim, + ], + dtype=self.attention_dtype, + device=self.device, + ) + + # Iterate over paged attention blocks. + for block_idx, paged_block in enumerate(self.paged_attn_blocks): + token_paged_input_tensor = paged_block( + token_paged_input_tensor, + start_positions=self.start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=self.embedding_batch_mask, + attention_mask=decode_attention_mask, + cache_state=self.paged_cache_state, + seq_block_ids=self.paged_seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + + # Iterate over direct attention blocks. + for block_idx, direct_block in enumerate(self.direct_attn_blocks): + token_direct_input_tensor = direct_block( + token_direct_input_tensor, + start_positions=self.start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=self.embedding_batch_mask, + attention_mask=decode_attention_mask, + cache_state=self.direct_cache_state, + seq_block_ids=self.direct_seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + + page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) + index_written = self.start_positions.item() + page_id = self.paged_seq_block_ids[0][0].item() + updated_direct_cache_state_keys = self.direct_cache_state[0][ + :, index_written + ].squeeze(0) + updated_paged_cache_state_keys = page_table[page_id][0, 0, index_written] + updated_direct_cache_state_values = self.direct_cache_state[1][ + :, index_written + ].squeeze(0) + updated_paged_cache_state_values = page_table[page_id][0, 1, index_written] + assert ( + updated_direct_cache_state_keys.shape + == updated_paged_cache_state_keys.shape + ) + torch.testing.assert_close( + updated_direct_cache_state_keys, updated_paged_cache_state_keys + ) + assert ( + updated_direct_cache_state_values.shape + == updated_paged_cache_state_values.shape + ) + torch.testing.assert_close( + updated_direct_cache_state_values, updated_paged_cache_state_values + ) + + paged_decode_attn_output = token_paged_input_tensor + direct_decode_attn_output = token_direct_input_tensor + assert paged_decode_attn_output.shape == direct_decode_attn_output.shape + torch.testing.assert_close(paged_decode_attn_output, direct_decode_attn_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py index 54469d40a..24a5f91b1 100644 --- a/sharktank/tests/ops/ops_test.py +++ b/sharktank/tests/ops/ops_test.py @@ -90,6 +90,18 @@ def testQuantizedTensorRhs(self): ... 
+class GemmTest(unittest.TestCase): + def testGemm(self): + a = torch.tensor([[1, 2], [3, 4]]) + b = torch.tensor([[5, 6], [7, 8]]) + c = torch.tensor([[9, 10], [11, 12]]) + alpha = 2 + beta = 3 + expected = alpha * a @ b.T + beta * c + result = ops.gemm(a, b, c, alpha, beta, False, True) + torch.testing.assert_close(result, expected) + + class MatmulTest(unittest.TestCase): def tearDown(self): ops._registry._test_enable_last_op_dispatch(False) diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py index a098fd8be..34e5ebca7 100644 --- a/sharktank/tests/ops/sharded_test.py +++ b/sharktank/tests/ops/sharded_test.py @@ -337,6 +337,25 @@ def testNotEqualSharded(self): assert not ops.equal(b_sharded, a_sharded) +class GemmTest(unittest.TestCase): + def testShardedParallelDim(self): + a = torch.rand(4, 3) + b = torch.rand(5, 3) + c = torch.rand(4, 5) + alpha = 2 + beta = 3 + shard_count = 2 + expected = ops.gemm(a, b, c, alpha, beta, False, True) + sharded_a = ops.reshard_split(a, dim=0, count=shard_count) + sharded_c = ops.reshard_split(c, dim=0, count=shard_count) + sharded_result = ops.gemm(sharded_a, b, sharded_c, alpha, beta, False, True) + assert isinstance(sharded_result, SplitPrimitiveTensor) + assert sharded_result.shard_count == 2 + assert sharded_result.shard_dim == 0 + actual = ops.unshard(sharded_result) + torch.testing.assert_close(actual, expected) + + class InterpolateTest(unittest.TestCase): def testInterpolateSplitChannelDim(self): batches = 2