diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index a41542064..6d27198a9 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -60,7 +60,7 @@ jobs: run: | dnf install -y almalinux-release-devel epel-release yum remove -y openssl-devel zlib-devel || true - yum install -y protobuf-devel protobuf-compiler tmate + yum install -y protobuf-devel protobuf-compiler libnuma-devel tmate - name: Python deps run: | @@ -73,6 +73,12 @@ jobs: key: ${{ env.CACHE_KEY }} restore-keys: linux-build-test-cpp- + - name: Build ROCT/ROCR + run: | + export cache_dir="${{ env.CACHE_DIR }}" + bash build_tools/ci/build_roct_rocr.sh + echo "hsa-runtime64_ROOT=$PWD/rocr-install" >> $GITHUB_ENV + - name: Build packages run: | export cache_dir="${{ env.CACHE_DIR }}" diff --git a/.gitmodules b/.gitmodules index 67a096370..1e9d239c4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -23,3 +23,8 @@ path = third_party/iree url = https://github.com/iree-org/iree.git shallow = true +[submodule "third_party/ROCR-Runtime"] + path = third_party/ROCR-Runtime + url = https://github.com/nod-ai/ROCR-Runtime.git + shallow = true + branch = iree-aie diff --git a/build_tools/ci/build_roct_rocr.sh b/build_tools/ci/build_roct_rocr.sh new file mode 100755 index 000000000..772bb9854 --- /dev/null +++ b/build_tools/ci/build_roct_rocr.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +set -eux -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" + +roct_dir="$(cd $repo_root/third_party/ROCT-Thunk-Interface && pwd)" +rocr_dir="$(cd $repo_root/third_party/ROCR-Runtime && pwd)" + +build_roct_dir="$repo_root/roct-build" +roct_install_dir="$repo_root/roct-install" +mkdir -p "$build_roct_dir" +build_roct_dir="$(cd $build_roct_dir && pwd)" + +build_rocr_dir="$repo_root/rocr-build" +rocr_install_dir="$repo_root/rocr-install" +mkdir -p "$build_rocr_dir" +build_rocr_dir="$(cd $build_rocr_dir && pwd)" + +cache_dir="${cache_dir:-}" + +if [ -z "${cache_dir}" ]; then + cache_dir="${repo_root}/.build-cache" + mkdir -p "${cache_dir}" + cache_dir="$(cd ${cache_dir} && pwd)" +fi +echo "Caching to ${cache_dir}" +mkdir -p "${cache_dir}/ccache" + +if [[ "$OSTYPE" == "msys"* ]]; then + export CC=clang-cl.exe + export CXX=clang-cl.exe +fi +export CCACHE_DIR="${cache_dir}/ccache" +export CCACHE_MAXSIZE="700M" +export CMAKE_C_COMPILER_LAUNCHER=ccache +export CMAKE_CXX_COMPILER_LAUNCHER=ccache + +cd $roct_dir +cmake -GNinja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="$roct_install_dir" \ + -S "$roct_dir" -B "$build_roct_dir" +cmake --build "$build_roct_dir" --target install + +cd $rocr_dir +cmake -GNinja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="$rocr_install_dir" \ + -DCMAKE_PREFIX_PATH="$roct_install_dir" \ + -DIMAGE_SUPPORT=OFF \ + -S "$rocr_dir/src" -B "$build_rocr_dir" +cmake --build "$build_rocr_dir" --target install diff --git a/build_tools/ci/build_test_cpp.sh b/build_tools/ci/build_test_cpp.sh index 6f7a820aa..8ad78c03f 100644 --- a/build_tools/ci/build_test_cpp.sh +++ b/build_tools/ci/build_test_cpp.sh @@ -91,7 +91,7 @@ if [[ "$OSTYPE" != "darwin"* ]]; then -DCMAKE_CXX_COMPILER="${CXX}" \ -DLLVM_TARGET_ARCH=X86 \ -DLLVM_TARGETS_TO_BUILD=X86 \ - -DIREE_EXTERNAL_HAL_DRIVERS=xrt \ + -DIREE_EXTERNAL_HAL_DRIVERS=hsa \ -S $iree_dir -B $build_dir else cmake $CMAKE_ARGS \ diff --git a/iree_runtime_plugin.cmake b/iree_runtime_plugin.cmake index 15a4d07da..3737aa9fd 100644 --- a/iree_runtime_plugin.cmake +++ b/iree_runtime_plugin.cmake @@ -26,5 +26,16 @@ 
if(IREE_AMD_AIE_ENABLE_XRT_DRIVER) include(iree_aie_bootgen) endif() +set(IREE_AMD_AIE_ENABLE_HSA_DRIVER OFF) +if("hsa" IN_LIST IREE_EXTERNAL_HAL_DRIVERS) + message(STATUS "Enabling HSA build because it is an enabled HAL driver") + set(IREE_AMD_AIE_ENABLE_HSA_DRIVER ON) +endif() + +if(IREE_AMD_AIE_ENABLE_HSA_DRIVER) + find_package(hsa-runtime64 CONFIG REQUIRED + NAMES hsa-runtime64 hsa_runtime64) +endif() + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/runtime/src AMD-AIE) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/experimental AMD-AIE-experimental) diff --git a/runtime/src/iree-amd-aie/CMakeLists.txt b/runtime/src/iree-amd-aie/CMakeLists.txt index bfa015081..3d4309203 100644 --- a/runtime/src/iree-amd-aie/CMakeLists.txt +++ b/runtime/src/iree-amd-aie/CMakeLists.txt @@ -8,6 +8,10 @@ if(IREE_AMD_AIE_ENABLE_XRT_DRIVER) add_subdirectory(driver/xrt) endif() +if(IREE_AMD_AIE_ENABLE_HSA_DRIVER) + add_subdirectory(driver/hsa) +endif() + # Flatbuffer schema generation does not require XRT. Moreover the generated # flatbuffer header files are used by the compiler to create artefacts # (.vmfb file), and so the schema sub-directory is required even when not diff --git a/runtime/src/iree-amd-aie/driver/hsa/CMakeLists.txt b/runtime/src/iree-amd-aie/driver/hsa/CMakeLists.txt new file mode 100644 index 000000000..97e99d25b --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/CMakeLists.txt @@ -0,0 +1,92 @@ +# Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +iree_register_external_hal_driver( + NAME + hsa + DRIVER_TARGET + iree-amd-aie::driver::hsa::registration + REGISTER_FN + iree_hal_hsa_driver_module_register +) + +iree_cc_library( + NAME + dynamic_symbols + HDRS + "dynamic_symbols.h" + "status_util.h" + TEXTUAL_HDRS + "dynamic_symbol_tables.h" + SRCS + "dynamic_symbols.c" + "hsa_headers.h" + "status_util.c" + DEPS + hsa-runtime64::hsa-runtime64 + iree::base + iree::base::core_headers + iree::base::internal::dynamic_library + PUBLIC +) + +iree_cc_library( + NAME + hsa + HDRS + "api.h" + SRCS + "api.h" + "event_pool.c" + "event_pool.h" + "event_semaphore.c" + "event_semaphore.h" + "hsa_allocator.c" + "hsa_allocator.h" + "hsa_buffer.c" + "hsa_buffer.h" + "hsa_device.c" + "hsa_device.h" + "hsa_driver.c" + "native_executable.c" + "native_executable.h" + "nop_executable_cache.c" + "nop_executable_cache.h" + "pending_queue_actions.c" + "pending_queue_actions.h" + "pipeline_layout.c" + "pipeline_layout.h" + "queue_command_buffer.c" + "queue_command_buffer.h" + "timepoint_pool.c" + "timepoint_pool.h" + DEPS + hsa-runtime64::hsa-runtime64 + ::dynamic_symbols + iree::base + iree::base::core_headers + iree::base::internal + iree::base::internal::arena + iree::base::internal::atomic_slist + iree::base::internal::event_pool + iree::base::internal::synchronization + iree::base::internal::threading + iree::base::internal::wait_handle + iree::base::internal::flatcc::parsing + iree::hal + iree::hal::utils::collective_batch + iree::hal::utils::deferred_command_buffer + iree::hal::utils::file_transfer + iree::hal::utils::memory_file + iree::hal::utils::resource_set + iree::hal::utils::semaphore_base + iree::schemas::rocm_executable_def_c_fbs + PUBLIC +) + diff --git a/runtime/src/iree-amd-aie/driver/hsa/api.h b/runtime/src/iree-amd-aie/driver/hsa/api.h 
new file mode 100644 index 000000000..b2558cdfd --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/api.h @@ -0,0 +1,108 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// See iree/base/api.h for documentation on the API conventions used. + +#ifndef IREE_EXPERIMENTAL_HSA_API_H_ +#define IREE_EXPERIMENTAL_HSA_API_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif  // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_device_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_hsa_memory_pool_params_t { + // Minimum number of bytes to keep in the pool when trimming with + // iree_hal_device_trim. + uint64_t minimum_capacity; + // Soft maximum number of bytes to keep in the pool. + // When more than this is allocated the extra will be freed at the next + // device synchronization in order to remain under the threshold. + uint64_t release_threshold; +} iree_hal_hsa_memory_pool_params_t; + +typedef struct iree_hal_hsa_memory_pooling_params_t { + // Used exclusively for DEVICE_LOCAL allocations. + iree_hal_hsa_memory_pool_params_t device_local; + // Used for any host-visible/host-local memory types. + iree_hal_hsa_memory_pool_params_t other; +} iree_hal_hsa_memory_pooling_params_t; + +// Parameters configuring an iree_hal_hsa_device_t. +// Must be initialized with iree_hal_hsa_device_params_initialize prior to +// use. +typedef struct iree_hal_hsa_device_params_t { + // Number of queues exposed on the device. + // Each queue acts as a separate synchronization scope where all work executes + // concurrently unless prohibited by semaphores. + iree_host_size_t queue_count; + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; + + // The host and device event pool capacity. + // The HSA driver implements semaphores with host and device events. This + // parameter controls the size of those pools. Larger values make creating + // semaphore timepoints quicker, though with increased memory consumption. + iree_host_size_t event_pool_capacity; + + // Enables tracing of command buffers when IREE tracing is enabled. + // May take advantage of additional extensions for more accurate timing or + // hardware-specific performance counters. + // + // NOTE: tracing has a non-trivial overhead and will skew the timing of + // submissions and introduce false barriers between dispatches. Use this to + // identify slow dispatches and refine from there; be wary of whole-program + // tracing with this enabled. + bool queue_tracing; + + // Parameters for each memory pool used for queue-ordered allocations. + iree_hal_hsa_memory_pooling_params_t memory_pools; +} iree_hal_hsa_device_params_t; + +// Initializes |out_params| to default values.
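+// +// A minimal end-to-end creation sketch (illustrative only; the "hsa" +// identifier and the system allocator are caller choices, and the driver APIs +// used here are declared further below in this header): +// +//   iree_hal_hsa_device_params_t params; +//   iree_hal_hsa_device_params_initialize(&params); +//   iree_hal_hsa_driver_options_t options; +//   iree_hal_hsa_driver_options_initialize(&options); +//   iree_hal_driver_t* driver = NULL; +//   IREE_RETURN_IF_ERROR(iree_hal_hsa_driver_create( +//       IREE_SV("hsa"), &options, &params, iree_allocator_system(), &driver));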
+IREE_API_EXPORT void iree_hal_hsa_device_params_initialize( + iree_hal_hsa_device_params_t* out_params); + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_driver_t +//===----------------------------------------------------------------------===// + +// HSA HAL driver creation options. +typedef struct iree_hal_hsa_driver_options_t { + // The index of the default HSA device to use within the list of available + // devices. + int default_device_index; +} iree_hal_hsa_driver_options_t; + +// Initializes the given |out_options| with default driver creation options. +IREE_API_EXPORT void iree_hal_hsa_driver_options_initialize( + iree_hal_hsa_driver_options_t* out_options); + +// Creates a HSA HAL driver with the given |options|, from which HSA devices +// can be enumerated and created with specific parameters. +// +// |out_driver| must be released by the caller (see iree_hal_driver_release). +IREE_API_EXPORT iree_status_t iree_hal_hsa_driver_create( + iree_string_view_t identifier, const iree_hal_hsa_driver_options_t* options, + const iree_hal_hsa_device_params_t* default_params, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_EXPERIMENTAL_HSA_API_H_ diff --git a/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbol_tables.h b/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbol_tables.h new file mode 100644 index 000000000..2feed7f09 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbol_tables.h @@ -0,0 +1,93 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// HSA symbols +//===----------------------------------------------------------------------===// + +#include <stdint.h> + +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_init) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_shut_down) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_agent_get_info, hsa_agent_t, + hsa_agent_info_t, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_iterate_agents, + hsa_status_t (*)(hsa_agent_t, void *), void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_queue_create, hsa_agent_t, uint32_t, + hsa_queue_type32_t, + void (*)(hsa_status_t, hsa_queue_t *, void *), + void *, uint32_t, uint32_t, hsa_queue_t **) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_wait_scacquire, hsa_signal_t, + hsa_signal_condition_t, hsa_signal_value_t, + uint64_t, hsa_wait_state_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_queue_load_write_index_relaxed, + const hsa_queue_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_create, hsa_signal_value_t, uint32_t, + const hsa_agent_t *, hsa_signal_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_queue_store_write_index_release, + const hsa_queue_t *, uint64_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_queue_add_write_index_relaxed, + const hsa_queue_t *, uint64_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_store_screlease, hsa_signal_t, + hsa_signal_value_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_store_relaxed, hsa_signal_t, + hsa_signal_value_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_add_screlease, hsa_signal_t, + hsa_signal_value_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_wait_acquire, hsa_signal_t, + hsa_signal_condition_t, hsa_signal_value_t, + uint64_t, hsa_wait_state_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_signal_destroy, hsa_signal_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_executable_get_symbol_by_name, + hsa_executable_t, const char *, + const hsa_agent_t *, hsa_executable_symbol_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_executable_symbol_get_info, + hsa_executable_symbol_t, + hsa_executable_symbol_info_t, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_ext_image_create, hsa_agent_t, + const hsa_ext_image_descriptor_t *, const void *, + hsa_access_permission_t, hsa_ext_image_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_executable_create_alt, hsa_profile_t, + hsa_default_float_rounding_mode_t, const char *, + hsa_executable_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_executable_load_agent_code_object, + hsa_executable_t, hsa_agent_t, + hsa_code_object_reader_t, const char *, + hsa_loaded_code_object_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_executable_freeze, hsa_executable_t, + const char *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_executable_destroy, hsa_executable_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_code_object_reader_create_from_memory, + const void *, size_t, hsa_code_object_reader_t *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_agent_iterate_regions, hsa_agent_t, + hsa_status_t (*)(hsa_region_t, void *), void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_agent_iterate_memory_pools, hsa_agent_t, + hsa_status_t (*)(hsa_amd_memory_pool_t, void *), + void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_region_get_info, hsa_region_t, + hsa_region_info_t, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_memory_pool_get_info, + hsa_amd_memory_pool_t, + hsa_amd_memory_pool_info_t, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_memory_allocate, hsa_region_t, size_t, + void **) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_memory_free, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_memory_pool_allocate, +
hsa_amd_memory_pool_t, size_t, uint32_t, void **) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_memory_pool_free, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_memory_async_copy, void *, hsa_agent_t, + const void *, hsa_agent_t, size_t, uint32_t, + const hsa_signal_t *, hsa_signal_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_signal_async_handler, hsa_signal_t, + hsa_signal_condition_t, hsa_signal_value_t, + hsa_amd_signal_handler, void *) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_memory_copy, void *, const void *, size_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_memory_lock_to_pool, void *, size_t, + hsa_agent_t *, int, hsa_amd_memory_pool_t, + uint32_t, void **) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_amd_memory_fill, void *, uint32_t, size_t) +IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_status_string, hsa_status_t, const char **) diff --git a/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbols.c b/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbols.c new file mode 100644 index 000000000..d7032e444 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbols.c @@ -0,0 +1,80 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/hsa/dynamic_symbols.h" + +#include <string.h> + +#include "iree-amd-aie/driver/hsa/status_util.h" +#include "iree/base/api.h" +#include "iree/base/internal/dynamic_library.h" +#include "iree/base/target_platform.h" + +//===----------------------------------------------------------------------===// +// HSA dynamic symbols +//===----------------------------------------------------------------------===// + +static const char* iree_hal_hsa_dylib_names[] = { +#if defined(IREE_PLATFORM_WINDOWS) + "libhsa-runtime64.dll", +#else + "libhsa-runtime64.so", +#endif  // IREE_PLATFORM_WINDOWS +}; + +// Resolves all HSA dynamic symbols in `dynamic_symbol_tables.h`. +static iree_status_t iree_hal_hsa_dynamic_symbols_resolve_all( + iree_hal_hsa_dynamic_symbols_t* syms) { +#define IREE_HAL_HSA_REQUIRED_PFN_DECL(hsa_symbol_name, ...)
\ + { \ + static const char* name = #hsa_symbol_name; \ + IREE_RETURN_IF_ERROR(iree_dynamic_library_lookup_symbol( \ + syms->dylib, name, (void**)&syms->hsa_symbol_name)); \ + } + +#include "iree-amd-aie/driver/hsa/dynamic_symbol_tables.h" // IWYU pragma: keep +#undef IREE_HAL_HSA_REQUIRED_PFN_DECL + return iree_ok_status(); +} + +iree_status_t iree_hal_hsa_dynamic_symbols_initialize( + iree_allocator_t host_allocator, iree_hal_hsa_dynamic_symbols_t* out_syms) { + IREE_ASSERT_ARGUMENT(out_syms); + IREE_TRACE_ZONE_BEGIN(z0); + + memset(out_syms, 0, sizeof(*out_syms)); + iree_status_t status = iree_dynamic_library_load_from_files( + IREE_ARRAYSIZE(iree_hal_hsa_dylib_names), iree_hal_hsa_dylib_names, + IREE_DYNAMIC_LIBRARY_FLAG_NONE, host_allocator, &out_syms->dylib); + if (iree_status_is_not_found(status)) { + iree_status_ignore(status); + status = iree_make_status( + IREE_STATUS_UNAVAILABLE, + "HSA runtime library 'libhsa-runtime64.dll'/'libhsa-runtime64.so' not " + "available;" + "please ensure installed and in dynamic library search path"); + } + if (iree_status_is_ok(status)) { + status = iree_hal_hsa_dynamic_symbols_resolve_all(out_syms); + } + if (!iree_status_is_ok(status)) { + iree_hal_hsa_dynamic_symbols_deinitialize(out_syms); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_hsa_dynamic_symbols_deinitialize( + iree_hal_hsa_dynamic_symbols_t* syms) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_dynamic_library_release(syms->dylib); + memset(syms, 0, sizeof(*syms)); + + IREE_TRACE_ZONE_END(z0); +} diff --git a/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbols.h b/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbols.h new file mode 100644 index 000000000..f6ab2abce --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/dynamic_symbols.h @@ -0,0 +1,57 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_EXPERIMENTAL_HSA_DYNAMIC_SYMBOLS_H_ +#define IREE_EXPERIMENTAL_HSA_DYNAMIC_SYMBOLS_H_ + +#include "iree-amd-aie/driver/hsa/hsa_headers.h" +#include "iree/base/api.h" +#include "iree/base/internal/dynamic_library.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// iree_dynamic_library_t allows dynamically loading a subset of HSA driver API. +// We load all the symbols in `dynamic_symbol_tables.h` and fail if any of the +// symbol is not available. The functions signatures are matching the +// declarations in the HSA headers. + +//===----------------------------------------------------------------------===// +// HSA dynamic symbols +//===----------------------------------------------------------------------===// + +// HSA driver API dynamic symbols. +typedef struct iree_hal_hsa_dynamic_symbols_t { + // The dynamic library handle. + iree_dynamic_library_t* dylib; + + // Concrete HSA symbols defined by including the `dynamic_symbol_tables.h`. +#define IREE_HAL_HSA_REQUIRED_PFN_DECL(hsaSymbolName, ...) \ + hsa_status_t (*hsaSymbolName)(__VA_ARGS__); + +#include "iree-amd-aie/driver/hsa/dynamic_symbol_tables.h" // IWYU pragma: export +#undef IREE_HAL_HSA_REQUIRED_PFN_DECL +} iree_hal_hsa_dynamic_symbols_t; + +// Initializes |out_syms| in-place with dynamically loaded HSA symbols. +// iree_hal_hsa_dynamic_symbols_deinitialize must be used to release the +// library resources. 
+iree_status_t iree_hal_hsa_dynamic_symbols_initialize( + iree_allocator_t host_allocator, iree_hal_hsa_dynamic_symbols_t* out_syms); + +// Deinitializes |syms| by unloading the backing library. All function pointers +// will be invalidated. They _may_ still work if there are other reasons the +// library remains loaded, so be careful. +void iree_hal_hsa_dynamic_symbols_deinitialize( + iree_hal_hsa_dynamic_symbols_t* syms); + +#ifdef __cplusplus +}  // extern "C" +#endif  // __cplusplus + +#endif  // IREE_EXPERIMENTAL_HSA_DYNAMIC_SYMBOLS_H_ diff --git a/runtime/src/iree-amd-aie/driver/hsa/event_pool.c b/runtime/src/iree-amd-aie/driver/hsa/event_pool.c new file mode 100644 index 000000000..a23da0abb --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/event_pool.c @@ -0,0 +1,315 @@ + +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/hsa/event_pool.h" + +#include <stdbool.h> +#include <stddef.h> +#include <string.h> + +#include "iree-amd-aie/driver/hsa/dynamic_symbols.h" +#include "iree-amd-aie/driver/hsa/status_util.h" +#include "iree/base/api.h" +#include "iree/base/internal/atomics.h" +#include "iree/base/internal/synchronization.h" +#include "iree/hal/api.h" + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_event_t +//===----------------------------------------------------------------------===// + +struct iree_hal_hsa_event_t { + // A reference count used to manage resource lifetime. Its value range: + // * 1 - when inside the event pool and to be acquired; + // * >= 1 - when acquired outside of the event pool; + // * 0 - right before releasing back to the pool or destruction. + iree_atomic_ref_count_t ref_count; + + // The allocator used to create the event. + iree_allocator_t host_allocator; + // The symbols used to create and destroy signal objects. + const iree_hal_hsa_dynamic_symbols_t* symbols; + + // The event pool that owns this event. This cannot be NULL. We retain it to + // make sure the event outlives the pool.
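+ // (iree_hal_hsa_event_pool_acquire retains the pool once per event handed + // out; iree_hal_hsa_event_release drops that reference once the event has + // been returned to the pool.)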
+ iree_hal_hsa_event_pool_t* pool; + + hsa_signal_t signal; +}; + + +hsa_signal_t iree_hal_hsa_signal_handle(const iree_hal_hsa_event_t* event) { + return event->signal; +} + +static inline void iree_hal_hsa_event_destroy(iree_hal_hsa_event_t* event) { + iree_allocator_t host_allocator = event->host_allocator; + const iree_hal_hsa_dynamic_symbols_t* symbols = event->symbols; + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count); + IREE_HSA_IGNORE_ERROR(symbols, hsa_signal_destroy(event->signal)); + iree_allocator_free(host_allocator, event); + + IREE_TRACE_ZONE_END(z0); +} + +static inline iree_status_t iree_hal_hsa_event_create( + const iree_hal_hsa_dynamic_symbols_t* symbols, + iree_hal_hsa_event_pool_t* pool, iree_allocator_t host_allocator, + iree_hal_hsa_event_t** out_event) { + IREE_ASSERT_ARGUMENT(symbols); + IREE_ASSERT_ARGUMENT(pool); + IREE_ASSERT_ARGUMENT(out_event); + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hsa_event_t* event = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*event), (void**)&event)); + iree_atomic_ref_count_init(&event->ref_count);  // -> 1 + event->host_allocator = host_allocator; + event->symbols = symbols; + event->pool = pool; + + hsa_signal_value_t signal_value = 1; + uint32_t num_consumers = 0; + const hsa_agent_t* consumers = NULL; + + iree_status_t status = IREE_HSA_RESULT_TO_STATUS( + symbols, + hsa_signal_create(signal_value, num_consumers, consumers, &event->signal), + "hsa_signal_create"); + + if (iree_status_is_ok(status)) { + *out_event = event; + } else { + iree_atomic_ref_count_dec(&event->ref_count);  // -> 0 + iree_hal_hsa_event_destroy(event); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_hsa_event_retain(iree_hal_hsa_event_t* event) { + iree_atomic_ref_count_inc(&event->ref_count); +} + +static void iree_hal_hsa_event_pool_release_event( + iree_hal_hsa_event_pool_t* event_pool, iree_host_size_t event_count, + iree_hal_hsa_event_t** events); + +void iree_hal_hsa_event_release(iree_hal_hsa_event_t* event) { + if (iree_atomic_ref_count_dec(&event->ref_count) == 1) { + iree_hal_hsa_event_pool_t* pool = event->pool; + // Release back to the pool if the reference count becomes 0. + iree_hal_hsa_event_pool_release_event(pool, 1, &event); + // Drop our reference to the pool itself when we return the event to it. + iree_hal_hsa_event_pool_release(pool);  // -1 + } +} + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_event_pool_t +//===----------------------------------------------------------------------===// + +struct iree_hal_hsa_event_pool_t { + // A reference count used to manage resource lifetime. + iree_atomic_ref_count_t ref_count; + + // The allocator used to create the event pool. + iree_allocator_t host_allocator; + // The symbols used to create and destroy signal objects. + const iree_hal_hsa_dynamic_symbols_t* symbols; + + // Guards event-related fields in the pool. We don't expect a performant + // program to frequently allocate events for synchronization purposes; the + // traffic to this pool should be low. So it should be fine to use a mutex to + // guard here. + iree_slim_mutex_t event_mutex; + + // Maximum number of event objects that will be maintained in the pool. + // More events may be allocated at any time, but they will be disposed + // directly when they are no longer needed.
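+ // (The pool struct is over-allocated by available_capacity pointers so the + // trailing available_list below can hold the full capacity inline; see + // iree_hal_hsa_event_pool_allocate.)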
+ iree_host_size_t available_capacity IREE_GUARDED_BY(event_mutex); + // Total number of currently available event objects. + iree_host_size_t available_count IREE_GUARDED_BY(event_mutex); + // The list of available_count event objects. + iree_hal_hsa_event_t* available_list[] IREE_GUARDED_BY(event_mutex); +}; +// + Additional inline allocation for holding events up to the capacity. + +static void iree_hal_hsa_event_pool_free(iree_hal_hsa_event_pool_t* event_pool); + +iree_status_t iree_hal_hsa_event_pool_allocate( + const iree_hal_hsa_dynamic_symbols_t* symbols, + iree_host_size_t available_capacity, iree_allocator_t host_allocator, + iree_hal_hsa_event_pool_t** out_event_pool) { + IREE_ASSERT_ARGUMENT(symbols); + IREE_ASSERT_ARGUMENT(out_event_pool); + *out_event_pool = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hsa_event_pool_t* event_pool = NULL; + iree_host_size_t total_size = + sizeof(*event_pool) + + available_capacity * sizeof(*event_pool->available_list); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, total_size, (void**)&event_pool)); + iree_atomic_ref_count_init(&event_pool->ref_count);  // -> 1 + event_pool->host_allocator = host_allocator; + event_pool->symbols = symbols; + iree_slim_mutex_initialize(&event_pool->event_mutex); + event_pool->available_capacity = available_capacity; + event_pool->available_count = 0; + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < available_capacity; ++i) { + status = iree_hal_hsa_event_create( + symbols, event_pool, host_allocator, + &event_pool->available_list[event_pool->available_count++]); + if (!iree_status_is_ok(status)) { + // The failed create left this slot NULL; drop it from available_count + // so the cleanup in iree_hal_hsa_event_pool_free does not touch it. + --event_pool->available_count; + break; + } + } + + if (iree_status_is_ok(status)) { + *out_event_pool = event_pool; + } else { + iree_hal_hsa_event_pool_free(event_pool); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_hsa_event_pool_free( + iree_hal_hsa_event_pool_t* event_pool) { + iree_allocator_t host_allocator = event_pool->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < event_pool->available_count; ++i) { + iree_hal_hsa_event_t* event = event_pool->available_list[i]; + iree_atomic_ref_count_dec(&event->ref_count);  // -> 0 + iree_hal_hsa_event_destroy(event); + } + IREE_ASSERT_REF_COUNT_ZERO(&event_pool->ref_count); + + iree_slim_mutex_deinitialize(&event_pool->event_mutex); + iree_allocator_free(host_allocator, event_pool); + + IREE_TRACE_ZONE_END(z0); +} + +void iree_hal_hsa_event_pool_retain(iree_hal_hsa_event_pool_t* event_pool) { + iree_atomic_ref_count_inc(&event_pool->ref_count); +} + +void iree_hal_hsa_event_pool_release(iree_hal_hsa_event_pool_t* event_pool) { + if (iree_atomic_ref_count_dec(&event_pool->ref_count) == 1) { + iree_hal_hsa_event_pool_free(event_pool); + } +} + +iree_status_t iree_hal_hsa_event_pool_acquire( + iree_hal_hsa_event_pool_t* event_pool, iree_host_size_t event_count, + iree_hal_hsa_event_t** out_events) { + IREE_ASSERT_ARGUMENT(event_pool); + if (!event_count) return iree_ok_status(); + IREE_ASSERT_ARGUMENT(out_events); + IREE_TRACE_ZONE_BEGIN(z0); + + // We'll try to get what we can from the pool and fall back to initializing + // new iree_hal_hsa_event_t objects. + iree_host_size_t remaining_count = event_count; + + // Try first to grab from the pool.
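+ // Events are taken from the tail of available_list (LIFO); ordering within + // the pool carries no meaning.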
+ iree_slim_mutex_lock(&event_pool->event_mutex); + iree_host_size_t from_pool_count = + iree_min(event_pool->available_count, event_count); + if (from_pool_count > 0) { + iree_host_size_t pool_base_index = + event_pool->available_count - from_pool_count; + memcpy(out_events, &event_pool->available_list[pool_base_index], + from_pool_count * sizeof(*event_pool->available_list)); + event_pool->available_count -= from_pool_count; + remaining_count -= from_pool_count; + } + iree_slim_mutex_unlock(&event_pool->event_mutex); + + // Allocate the rest of the events. + if (remaining_count > 0) { + IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire"); + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < remaining_count; ++i) { + status = iree_hal_hsa_event_create(event_pool->symbols, event_pool, + event_pool->host_allocator, + &out_events[from_pool_count + i]); + if (!iree_status_is_ok(status)) { + // Must release all events we've acquired so far. + iree_hal_hsa_event_pool_release_event(event_pool, from_pool_count + i, + out_events); + IREE_TRACE_ZONE_END(z1); + IREE_TRACE_ZONE_END(z0); + return status; + } + } + IREE_TRACE_ZONE_END(z1); + } + + // Retain a reference to the pool when we pass an event to the caller. When + // the caller returns the event back to the pool they'll release the + // reference. + for (iree_host_size_t i = 0; i < event_count; ++i) { + iree_hal_hsa_event_pool_retain(out_events[i]->pool);  // +1 + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_hsa_event_pool_release_event( + iree_hal_hsa_event_pool_t* event_pool, iree_host_size_t event_count, + iree_hal_hsa_event_t** events) { + IREE_ASSERT_ARGUMENT(event_pool); + if (!event_count) return; + IREE_ASSERT_ARGUMENT(events); + IREE_TRACE_ZONE_BEGIN(z0); + + // We'll try to release all we can back to the pool and then deinitialize + // the ones that won't fit. + iree_host_size_t remaining_count = event_count; + + // Try first to release to the pool. + iree_slim_mutex_lock(&event_pool->event_mutex); + iree_host_size_t to_pool_count = + iree_min(event_pool->available_capacity - event_pool->available_count, + event_count); + if (to_pool_count > 0) { + for (iree_host_size_t i = 0; i < to_pool_count; ++i) { + IREE_ASSERT_REF_COUNT_ZERO(&events[i]->ref_count); + iree_hal_hsa_event_retain(events[i]);  // -> 1 + } + iree_host_size_t pool_base_index = event_pool->available_count; + memcpy(&event_pool->available_list[pool_base_index], events, + to_pool_count * sizeof(*event_pool->available_list)); + event_pool->available_count += to_pool_count; + remaining_count -= to_pool_count; + } + iree_slim_mutex_unlock(&event_pool->event_mutex); + + // Deallocate the rest of the events. We don't bother resetting them as we are + // getting rid of them. + if (remaining_count > 0) { + IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-release"); + for (iree_host_size_t i = 0; i < remaining_count; ++i) { + iree_hal_hsa_event_destroy(events[to_pool_count + i]); + } + IREE_TRACE_ZONE_END(z1); + } + IREE_TRACE_ZONE_END(z0); +} diff --git a/runtime/src/iree-amd-aie/driver/hsa/event_pool.h b/runtime/src/iree-amd-aie/driver/hsa/event_pool.h new file mode 100644 index 000000000..3a4f99e1d --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/event_pool.h @@ -0,0 +1,81 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_EXPERIMENTAL_HSA_EVENT_POOL_H_ +#define IREE_EXPERIMENTAL_HSA_EVENT_POOL_H_ + +#include "iree-amd-aie/driver/hsa/dynamic_symbols.h" +#include "iree/base/api.h" + +#ifdef __cplusplus +extern "C" { +#endif  // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_event_t +//===----------------------------------------------------------------------===// + +// A struct that wraps a signal object with a reference count for lifetime +// management. +// +// iree_hal_hsa_event_t objects cannot be directly created; they should be +// acquired from the event pool and released back to the pool once done. +// +// Thread-safe; multiple threads may retain and release the same event. +typedef struct iree_hal_hsa_event_t iree_hal_hsa_event_t; + +// Returns the underlying hsa_signal_t handle behind |event|. +hsa_signal_t iree_hal_hsa_signal_handle(const iree_hal_hsa_event_t* event); + +// Retains the given |event| by increasing its reference count. +void iree_hal_hsa_event_retain(iree_hal_hsa_event_t* event); + +// Releases the given |event| by decreasing its reference count. +// +// |event| will be returned to its owning pool when the reference count is 0. +void iree_hal_hsa_event_release(iree_hal_hsa_event_t* event); + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_event_pool_t +//===----------------------------------------------------------------------===// + +// A simple pool of iree_hal_hsa_event_t objects to recycle. +// +// Thread-safe; multiple threads may acquire and release events from the pool. +typedef struct iree_hal_hsa_event_pool_t iree_hal_hsa_event_pool_t; + +// Allocates a new event pool with up to |available_capacity| events. +// +// Extra events requested beyond the capacity are directly created and +// destroyed without pooling. +iree_status_t iree_hal_hsa_event_pool_allocate( + const iree_hal_hsa_dynamic_symbols_t* symbols, + iree_host_size_t available_capacity, iree_allocator_t host_allocator, + iree_hal_hsa_event_pool_t** out_event_pool); + +// Retains the given |event_pool| by increasing its reference count. +void iree_hal_hsa_event_pool_retain(iree_hal_hsa_event_pool_t* event_pool); + +// Releases the given |event_pool| by decreasing its reference count. +// +// Once the |event_pool|'s reference count becomes zero, it will be freed. +void iree_hal_hsa_event_pool_release(iree_hal_hsa_event_pool_t* event_pool); + +// Acquires one or more events from the event pool. +// +// Each returned event has an initial reference count of 1. The returned +// signal objects may retain captured states of some queues from previous +// uses; callers should record again to overwrite. +iree_status_t iree_hal_hsa_event_pool_acquire( + iree_hal_hsa_event_pool_t* event_pool, iree_host_size_t event_count, + iree_hal_hsa_event_t** out_events); + +#ifdef __cplusplus +}  // extern "C" +#endif  // __cplusplus + +#endif  // IREE_EXPERIMENTAL_HSA_EVENT_POOL_H_ diff --git a/runtime/src/iree-amd-aie/driver/hsa/event_semaphore.c b/runtime/src/iree-amd-aie/driver/hsa/event_semaphore.c new file mode 100644 index 000000000..445f177f8 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/event_semaphore.c @@ -0,0 +1,545 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree-amd-aie/driver/hsa/event_semaphore.h" + +#include "iree-amd-aie/driver/hsa/dynamic_symbols.h" +#include "iree-amd-aie/driver/hsa/status_util.h" +#include "iree-amd-aie/driver/hsa/timepoint_pool.h" +#include "iree/base/internal/synchronization.h" +#include "iree/base/internal/wait_handle.h" +#include "iree/hal/utils/semaphore_base.h" + +typedef struct iree_hal_hsa_semaphore_t { + // Abstract resource used for injecting reference counting and vtable; + // must be at offset 0. + iree_hal_semaphore_t base; + + // The allocator used to create this semaphore. + iree_allocator_t host_allocator; + // The symbols used to issue HSA API calls. + const iree_hal_hsa_dynamic_symbols_t* symbols; + + // The timepoint pool to acquire timepoint objects. + iree_hal_hsa_timepoint_pool_t* timepoint_pool; + + // The list of pending queue actions that this semaphore needs to advance on + // newly signaled values. + iree_hal_hsa_pending_queue_actions_t* pending_queue_actions; + + // Guards value and status. We expect low contention on semaphores and since + // iree_slim_mutex_t is (effectively) just a CAS, this keeps things simpler + // than trying to make the entire structure lock-free. + iree_slim_mutex_t mutex; + + // Current signaled value. May be IREE_HAL_SEMAPHORE_FAILURE_VALUE to + // indicate that the semaphore has been signaled for failure and + // |failure_status| contains the error. + uint64_t current_value IREE_GUARDED_BY(mutex); + + // OK or the status passed to iree_hal_semaphore_fail. Owned by the semaphore. + iree_status_t failure_status IREE_GUARDED_BY(mutex); +} iree_hal_hsa_semaphore_t; + +static const iree_hal_semaphore_vtable_t iree_hal_hsa_semaphore_vtable; + +static iree_hal_hsa_semaphore_t* iree_hal_hsa_semaphore_cast( + iree_hal_semaphore_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_semaphore_vtable); + return (iree_hal_hsa_semaphore_t*)base_value; +} + +iree_status_t iree_hal_hsa_event_semaphore_create( + uint64_t initial_value, const iree_hal_hsa_dynamic_symbols_t* symbols, + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_hal_hsa_pending_queue_actions_t* pending_queue_actions, + iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { + IREE_ASSERT_ARGUMENT(symbols); + IREE_ASSERT_ARGUMENT(timepoint_pool); + IREE_ASSERT_ARGUMENT(pending_queue_actions); + IREE_ASSERT_ARGUMENT(out_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hsa_semaphore_t* semaphore = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore), + (void**)&semaphore)); + + iree_hal_semaphore_initialize(&iree_hal_hsa_semaphore_vtable, + &semaphore->base); + semaphore->host_allocator = host_allocator; + semaphore->symbols = symbols; + semaphore->timepoint_pool = timepoint_pool; + semaphore->pending_queue_actions = pending_queue_actions; + iree_slim_mutex_initialize(&semaphore->mutex); + semaphore->current_value = initial_value; + semaphore->failure_status = iree_ok_status(); + + *out_semaphore = &semaphore->base; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_hsa_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + iree_allocator_t host_allocator =
semaphore->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_ignore(semaphore->failure_status); + iree_slim_mutex_deinitialize(&semaphore->mutex); + + iree_hal_semaphore_deinitialize(&semaphore->base); + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_hsa_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_slim_mutex_lock(&semaphore->mutex); + + *out_value = semaphore->current_value; + + iree_status_t status = iree_ok_status(); + if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + status = iree_status_clone(semaphore->failure_status); + } + + iree_slim_mutex_unlock(&semaphore->mutex); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hsa_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_slim_mutex_lock(&semaphore->mutex); + + if (new_value <= semaphore->current_value) { + uint64_t current_value IREE_ATTRIBUTE_UNUSED = semaphore->current_value; + iree_slim_mutex_unlock(&semaphore->mutex); + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "semaphore values must be monotonically " + "increasing; current_value=%" PRIu64 + ", new_value=%" PRIu64, + current_value, new_value); + } + + semaphore->current_value = new_value; + + iree_slim_mutex_unlock(&semaphore->mutex); + + // Notify timepoints - note that this must happen outside the lock. + iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK); + + // Advance the pending queue actions if possible. This also must happen + // outside the lock to avoid nesting. + iree_status_t status = iree_hal_hsa_pending_queue_actions_issue( + semaphore->pending_queue_actions); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_hsa_semaphore_fail(iree_hal_semaphore_t* base_semaphore, + iree_status_t status) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + const iree_status_code_t status_code = iree_status_code(status); + + iree_slim_mutex_lock(&semaphore->mutex); + + // Try to set our local status - we only preserve the first failure so only + // do this if we are going from a valid semaphore to a failed one. + if (!iree_status_is_ok(semaphore->failure_status)) { + // Previous status was not OK; drop our new status. + IREE_IGNORE_ERROR(status); + iree_slim_mutex_unlock(&semaphore->mutex); + IREE_TRACE_ZONE_END(z0); + return; + } + + // Signal to our failure sentinel value. + semaphore->current_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE; + semaphore->failure_status = status; + + iree_slim_mutex_unlock(&semaphore->mutex); + + // Notify timepoints - note that this must happen outside the lock. + iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE, + status_code); + IREE_TRACE_ZONE_END(z0); +} + +// Handles host wait timepoints on the host when the |semaphore| timeline +// advances past the given |value|. +// +// Note that this callback is invoked by a host thread.
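+// It only wakes host waiters: it sets the paired iree_event_t that +// iree_hal_hsa_semaphore_wait below blocks on.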
+static iree_status_t iree_hal_hsa_semaphore_timepoint_host_wait_callback( + void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value, + iree_status_code_t status_code) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hsa_timepoint_t* timepoint = (iree_hal_hsa_timepoint_t*)user_data; + iree_event_set(&timepoint->timepoint.host_wait); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Acquires a timepoint to wait for the timeline to reach at least the given +// |min_value| from the host. +static iree_status_t iree_hal_hsa_semaphore_acquire_timepoint_host_wait( + iree_hal_hsa_semaphore_t* semaphore, uint64_t min_value, + iree_timeout_t timeout, iree_hal_hsa_timepoint_t** out_timepoint) { + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_timepoint_pool_acquire_host_wait( + semaphore->timepoint_pool, 1, out_timepoint)); + // Initialize the timepoint with the value and callback, and connect it to + // this semaphore. + iree_hal_semaphore_acquire_timepoint( + &semaphore->base, min_value, timeout, + (iree_hal_semaphore_callback_t){ + .fn = iree_hal_hsa_semaphore_timepoint_host_wait_callback, + .user_data = *out_timepoint, + }, + &(*out_timepoint)->base); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Acquires an iree_hal_hsa_event_t object to wait on the host for the +// timeline to reach at least the given |min_value| on the device. +// Returns true and writes to |out_event| if we can find such an event; +// returns false otherwise. +// The caller should release the |out_event| once done. +static bool iree_hal_hsa_semaphore_acquire_event_host_wait( + iree_hal_hsa_semaphore_t* semaphore, uint64_t min_value, + iree_hal_hsa_event_t** out_event) { + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Scan through the timepoint list and try to find a device event signal to + // wait on. We need to lock with the timepoint list mutex here. + iree_slim_mutex_lock(&semaphore->base.timepoint_mutex); + for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head; + tp != NULL; tp = tp->next) { + iree_hal_hsa_timepoint_t* signal_timepoint = (iree_hal_hsa_timepoint_t*)tp; + if (signal_timepoint->kind == IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_SIGNAL && + signal_timepoint->base.minimum_value >= min_value) { + *out_event = signal_timepoint->timepoint.device_signal; + iree_hal_hsa_event_retain(*out_event); + break; + } + } + iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex); + + IREE_TRACE_ZONE_END(z0); + return *out_event != NULL; +} + +static iree_status_t iree_hal_hsa_semaphore_wait( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_timeout_t timeout) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_slim_mutex_lock(&semaphore->mutex); + if (!iree_status_is_ok(semaphore->failure_status)) { + // Fastest path: failed; return an error to tell callers to query for it. + iree_slim_mutex_unlock(&semaphore->mutex); + IREE_TRACE_ZONE_END(z0); + return iree_status_from_code(IREE_STATUS_ABORTED); + } + if (semaphore->current_value >= value) { + // Fast path: already satisfied. + iree_slim_mutex_unlock(&semaphore->mutex); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + if (iree_timeout_is_immediate(timeout)) { + // Not satisfied but a poll, so can avoid the expensive wait handle work.
+ iree_slim_mutex_unlock(&semaphore->mutex); + IREE_TRACE_ZONE_END(z0); + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + iree_slim_mutex_unlock(&semaphore->mutex); + + iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout); + + // Slow path: try to see if we can have a device signal to wait on. This + // should happen outside of the lock given that acquiring has its own internal + // locks. This is faster than waiting on a host timepoint. + iree_hal_hsa_event_t* wait_event = NULL; + if (iree_hal_hsa_semaphore_acquire_event_host_wait(semaphore, value, + &wait_event)) { + semaphore->symbols->hsa_signal_wait_scacquire( + iree_hal_hsa_signal_handle(wait_event), HSA_SIGNAL_CONDITION_EQ, 0, + UINT64_MAX, HSA_WAIT_STATE_BLOCKED); + + iree_hal_hsa_event_release(wait_event); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + // Slow path: acquire a timepoint. This should happen outside of the lock too + // given that acquiring has its own internal locks. + iree_hal_hsa_timepoint_t* timepoint = NULL; + iree_status_t status = iree_hal_hsa_semaphore_acquire_timepoint_host_wait( + semaphore, value, timeout, &timepoint); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + IREE_TRACE_ZONE_END(z0); + return status; + } + + // Wait until the timepoint resolves. + // If satisfied the timepoint is automatically cleaned up and we are done. If + // the deadline is reached before satisfied then we have to clean it up. + status = iree_wait_one(&timepoint->timepoint.host_wait, deadline_ns); + if (!iree_status_is_ok(status)) { + iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base); + } + iree_hal_hsa_timepoint_pool_release(semaphore->timepoint_pool, 1, &timepoint); + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_hal_hsa_semaphore_multi_wait( + const iree_hal_semaphore_list_t semaphore_list, + iree_hal_wait_mode_t wait_mode, iree_timeout_t timeout, + iree_arena_block_pool_t* block_pool) { + if (semaphore_list.count == 0) return iree_ok_status(); + + if (semaphore_list.count == 1) { + // Fast-path for a single semaphore. + return iree_hal_semaphore_wait(semaphore_list.semaphores[0], + semaphore_list.payload_values[0], timeout); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout); + + // Avoid heap allocations by using the device block pool for the wait set. + iree_arena_allocator_t arena; + iree_arena_initialize(block_pool, &arena); + iree_wait_set_t* wait_set = NULL; + iree_status_t status = iree_wait_set_allocate( + semaphore_list.count, iree_arena_allocator(&arena), &wait_set); + + // Acquire a host wait handle for each semaphore timepoint we are to wait on. + iree_host_size_t timepoint_count = 0; + iree_hal_hsa_timepoint_t** timepoints = NULL; + iree_host_size_t total_timepoint_size = + semaphore_list.count * sizeof(timepoints[0]); + bool needs_wait = true; + // Only proceed if the wait set allocation above succeeded; otherwise its + // failure status would be silently overwritten (and leaked) here. + if (iree_status_is_ok(status)) { + status = iree_arena_allocate(&arena, total_timepoint_size, + (void**)&timepoints); + } + if (iree_status_is_ok(status)) { + memset(timepoints, 0, total_timepoint_size); + for (iree_host_size_t i = 0; i < semaphore_list.count && needs_wait; ++i) { + uint64_t current_value = 0; + status = iree_hal_hsa_semaphore_query(semaphore_list.semaphores[i], + &current_value); + if (!iree_status_is_ok(status)) break; + + if (current_value >= semaphore_list.payload_values[i]) { + // Fast path: already satisfied. + // If in ANY wait mode, this is sufficient and we don't actually need + // to wait.
This also skips acquiring timepoints for any remaining + // semaphores. We still exit normally otherwise so as to clean up + // any timepoints already acquired. + if (wait_mode == IREE_HAL_WAIT_MODE_ANY) needs_wait = false; + } else { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(semaphore_list.semaphores[i]); + + // Slow path: get a native host wait handle for the timepoint. This + // should happen outside of the lock given that acquiring has its own + // internal locks. + iree_hal_hsa_timepoint_t* timepoint = NULL; + status = iree_hal_hsa_semaphore_acquire_timepoint_host_wait( + semaphore, semaphore_list.payload_values[i], timeout, &timepoint); + if (iree_status_is_ok(status)) { + timepoints[timepoint_count++] = timepoint; + status = + iree_wait_set_insert(wait_set, timepoint->timepoint.host_wait); + } + if (!iree_status_is_ok(status)) break; + } + } + } + + // Perform the wait. + if (iree_status_is_ok(status) && needs_wait) { + if (wait_mode == IREE_HAL_WAIT_MODE_ANY) { + status = iree_wait_any(wait_set, deadline_ns, /*out_wake_handle=*/NULL); + } else { + status = iree_wait_all(wait_set, deadline_ns); + } + } + + for (iree_host_size_t i = 0; i < timepoint_count; ++i) { + iree_hal_hsa_timepoint_t* timepoint = timepoints[i]; + iree_hal_semaphore_t* semaphore = timepoint->base.semaphore; + // Cancel if this is still an unresolved host wait. + if (semaphore) { + iree_hal_semaphore_cancel_timepoint(semaphore, &timepoint->base); + } + iree_hal_hsa_timepoint_pool_release(timepoint->pool, 1, &timepoint); + } + iree_wait_set_free(wait_set); + iree_arena_deinitialize(&arena); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Handles device signal timepoints on the host when the |semaphore| timeline +// advances past the given |value|. +// +// Note that this callback is invoked by a host thread after the HSA host +// function callback is triggered in the HSA driver. +static iree_status_t iree_hal_hsa_semaphore_timepoint_device_signal_callback( + void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value, + iree_status_code_t status_code) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hsa_timepoint_t* timepoint = (iree_hal_hsa_timepoint_t*)user_data; + // Just release the timepoint back to the pool. This will decrease the + // reference count of the underlying HSA event internally. + iree_hal_hsa_timepoint_pool_release(timepoint->pool, 1, &timepoint); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Acquires a timepoint to signal the timeline to the given |to_value| from the +// device. +iree_status_t iree_hal_hsa_event_semaphore_acquire_timepoint_device_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t to_value, + hsa_signal_t* out_signal) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + iree_hal_hsa_timepoint_t* signal_timepoint = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_timepoint_pool_acquire_device_signal( + semaphore->timepoint_pool, 1, &signal_timepoint)); + + // Initialize the timepoint with the value and callback, and connect it to + // this semaphore.
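+ // Once the timeline reaches |to_value| the device signal callback above + // runs and returns the timepoint (and its underlying signal) to the pool.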
+ iree_hal_semaphore_acquire_timepoint( + &semaphore->base, to_value, iree_infinite_timeout(), + (iree_hal_semaphore_callback_t){ + .fn = iree_hal_hsa_semaphore_timepoint_device_signal_callback, + .user_data = signal_timepoint, + }, + &signal_timepoint->base); + iree_hal_hsa_event_t* event = signal_timepoint->timepoint.device_signal; + + // Scan through the timepoint list and update device wait timepoints to wait + // for this device signal when possible. We need to lock with the timepoint + // list mutex here. + iree_slim_mutex_lock(&semaphore->base.timepoint_mutex); + for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head; + tp != NULL; tp = tp->next) { + iree_hal_hsa_timepoint_t* wait_timepoint = (iree_hal_hsa_timepoint_t*)tp; + if (wait_timepoint->kind == IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_WAIT && + wait_timepoint->timepoint.device_wait == NULL && + wait_timepoint->base.minimum_value <= to_value) { + iree_hal_hsa_event_retain(event); + wait_timepoint->timepoint.device_wait = event; + } + } + iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex); + + *out_signal = iree_hal_hsa_signal_handle(event); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Handles device wait timepoints on the host when the |semaphore| timeline +// advances past the given |value|. +// +// Note that this callback is invoked by a host thread. +static iree_status_t iree_hal_hsa_semaphore_timepoint_device_wait_callback( + void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value, + iree_status_code_t status_code) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hsa_timepoint_t* timepoint = (iree_hal_hsa_timepoint_t*)user_data; + // Just release the timepoint back to the pool. This will decrease the + // reference count of the underlying HSA event internally. + iree_hal_hsa_timepoint_pool_release(timepoint->pool, 1, &timepoint); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Acquires a timepoint to wait for the timeline to reach at least the given +// |min_value| on the device. +iree_status_t iree_hal_hsa_event_semaphore_acquire_timepoint_device_wait( + iree_hal_semaphore_t* base_semaphore, uint64_t min_value, + hsa_signal_t* out_signal) { + iree_hal_hsa_semaphore_t* semaphore = + iree_hal_hsa_semaphore_cast(base_semaphore); + iree_hal_hsa_timepoint_t* wait_timepoint = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_timepoint_pool_acquire_device_wait( + semaphore->timepoint_pool, 1, &wait_timepoint)); + + // Initialize the timepoint with the value and callback, and connect it to + // this semaphore. + iree_hal_semaphore_acquire_timepoint( + &semaphore->base, min_value, iree_infinite_timeout(), + (iree_hal_semaphore_callback_t){ + .fn = iree_hal_hsa_semaphore_timepoint_device_wait_callback, + .user_data = wait_timepoint, + }, + &wait_timepoint->base); + + iree_hal_hsa_event_t* wait_event = NULL; + if (iree_hal_hsa_semaphore_acquire_event_host_wait(semaphore, min_value, + &wait_event)) { + // We've found an existing signal timepoint to wait on; we don't need a + // standalone wait timepoint anymore. Decrease its refcount before + // overwriting it to return it back to the pool and retain the existing one.
+    iree_hal_hsa_event_release(wait_timepoint->timepoint.device_wait);
+    wait_timepoint->timepoint.device_wait = wait_event;
+  }
+
+  *out_signal =
+      iree_hal_hsa_signal_handle(wait_timepoint->timepoint.device_wait);
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static const iree_hal_semaphore_vtable_t iree_hal_hsa_semaphore_vtable = {
+    .destroy = iree_hal_hsa_semaphore_destroy,
+    .query = iree_hal_hsa_semaphore_query,
+    .signal = iree_hal_hsa_semaphore_signal,
+    .fail = iree_hal_hsa_semaphore_fail,
+    .wait = iree_hal_hsa_semaphore_wait,
+};
diff --git a/runtime/src/iree-amd-aie/driver/hsa/event_semaphore.h b/runtime/src/iree-amd-aie/driver/hsa/event_semaphore.h
new file mode 100644
index 000000000..de1010eec
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/event_semaphore.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_EVENT_SEMAPHORE_H_
+#define IREE_EXPERIMENTAL_HSA_EVENT_SEMAPHORE_H_
+
+#include <stdint.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/pending_queue_actions.h"
+#include "iree-amd-aie/driver/hsa/timepoint_pool.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates an IREE HAL semaphore with the given |initial_value|.
+//
+// The HAL semaphore is backed by iree_event_t or hsa_signal_t objects for
+// different timepoints along the timeline under the hood. Those timepoints
+// will be allocated from the |timepoint_pool|.
+//
+// This semaphore is meant to be used together with pending queue actions; it
+// may advance the given |pending_queue_actions| if new values are signaled.
+//
+// Thread-safe; multiple threads may signal/wait values on the same semaphore.
+iree_status_t iree_hal_hsa_event_semaphore_create(
+    uint64_t initial_value, const iree_hal_hsa_dynamic_symbols_t* symbols,
+    iree_hal_hsa_timepoint_pool_t* timepoint_pool,
+    iree_hal_hsa_pending_queue_actions_t* pending_queue_actions,
+    iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore);
+
+// Acquires a timepoint to signal the timeline to the given |to_value| from the
+// device. The underlying HSA signal is written into |out_signal| for
+// interacting with HSA APIs.
+iree_status_t iree_hal_hsa_event_semaphore_acquire_timepoint_device_signal(
+    iree_hal_semaphore_t* base_semaphore, uint64_t to_value,
+    hsa_signal_t* out_signal);
+
+// Acquires a timepoint to wait for the timeline to reach at least the given
+// |min_value| on the device. The underlying HSA signal is written into
+// |out_signal| for interacting with HSA APIs.
+iree_status_t iree_hal_hsa_event_semaphore_acquire_timepoint_device_wait(
+    iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
+    hsa_signal_t* out_signal);
+
+// Performs a multi-wait on one or more semaphores. Returns
+// IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before
+// |timeout|.
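+//
+// Illustrative usage (the semaphores, payload values, and |block_pool| below
+// are caller-provided assumptions, not part of this API):
+//   iree_hal_semaphore_t* semaphores[2] = {sem_a, sem_b};
+//   uint64_t payload_values[2] = {1ull, 2ull};
+//   iree_hal_semaphore_list_t list = {
+//       .count = 2, .semaphores = semaphores,
+//       .payload_values = payload_values};
+//   iree_status_t status = iree_hal_hsa_semaphore_multi_wait(
+//       list, IREE_HAL_WAIT_MODE_ALL, iree_make_timeout_ms(1000), block_pool);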
+iree_status_t iree_hal_hsa_semaphore_multi_wait(
+    const iree_hal_semaphore_list_t semaphore_list,
+    iree_hal_wait_mode_t wait_mode, iree_timeout_t timeout,
+    iree_arena_block_pool_t* block_pool);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_EVENT_SEMAPHORE_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.c b/runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.c
new file mode 100644
index 000000000..a1a05c256
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.c
@@ -0,0 +1,708 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/hsa_allocator.h"
+
+#include <string.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/hsa_buffer.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+typedef struct iree_hal_hsa_allocator_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+
+  hsa_agent_t hsa_agent;
+
+  hsa_agent_t cpu_agent;
+  hsa_amd_memory_pool_t cpu_pool;
+
+  // Only one memory pool and one region for now.
+  hsa_amd_memory_pool_t buffers_pool;
+  hsa_region_t kernel_argument_pool;
+
+  const iree_hal_hsa_dynamic_symbols_t* symbols;
+
+  iree_allocator_t host_allocator;
+
+  // Whether the GPU and CPU can concurrently access HSA managed data in a
+  // coherent way. If not, we would need to explicitly perform flushing and
+  // invalidation between the GPU and CPU.
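+  // (iree_hal_hsa_allocator_create below currently hard-codes this to true
+  // rather than querying the agent.)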
+ bool supports_concurrent_managed_access; + + IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;) +} iree_hal_hsa_allocator_t; + +static const iree_hal_allocator_vtable_t iree_hal_hsa_allocator_vtable; + +static iree_hal_hsa_allocator_t* iree_hal_hsa_allocator_cast( + iree_hal_allocator_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_allocator_vtable); + return (iree_hal_hsa_allocator_t*)base_value; +} + +static hsa_status_t get_kernarg_memory_region(hsa_region_t region, + void* allocator_untyped) { + iree_hal_hsa_allocator_t* allocator = + (iree_hal_hsa_allocator_t*)(allocator_untyped); + + hsa_region_segment_t segment; + allocator->symbols->hsa_region_get_info(region, HSA_REGION_INFO_SEGMENT, + &segment); + if (HSA_REGION_SEGMENT_GLOBAL != segment) { + return HSA_STATUS_SUCCESS; + } + + hsa_region_global_flag_t flags; + allocator->symbols->hsa_region_get_info(region, HSA_REGION_INFO_GLOBAL_FLAGS, + &flags); + if (flags & HSA_REGION_GLOBAL_FLAG_KERNARG) { + hsa_region_t* ret = (hsa_region_t*)(&(allocator->kernel_argument_pool)); + *ret = region; + return HSA_STATUS_INFO_BREAK; + } + + return HSA_STATUS_SUCCESS; +} + +static hsa_status_t get_fine_grained_memory_pool(hsa_amd_memory_pool_t pool, + void* allocator_untyped) { + iree_hal_hsa_allocator_t* allocator = + (iree_hal_hsa_allocator_t*)(allocator_untyped); + + hsa_amd_segment_t segment; + hsa_status_t status = allocator->symbols->hsa_amd_memory_pool_get_info( + pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT, &segment); + if (status != HSA_STATUS_SUCCESS) { + return status; + } + if (segment != HSA_AMD_SEGMENT_GLOBAL) { + return HSA_STATUS_SUCCESS; + } + + uint32_t flags; + status = allocator->symbols->hsa_amd_memory_pool_get_info( + pool, HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, &flags); + if (status != HSA_STATUS_SUCCESS) { + return status; + } + + bool is_fine_grained = + (flags & (HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED | + HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED)); + bool is_kernel_arg_region = + (flags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT); + + if (is_fine_grained && !is_kernel_arg_region) { + allocator->buffers_pool = pool; + return HSA_STATUS_INFO_BREAK; + } + return HSA_STATUS_SUCCESS; +} + +static hsa_status_t iterate_find_cpu_agent_callback(hsa_agent_t agent, + void* base_allocator) { + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + + hsa_device_type_t type; + hsa_status_t status = allocator->symbols->hsa_agent_get_info( + agent, HSA_AGENT_INFO_DEVICE, &type); + if (status != HSA_STATUS_SUCCESS) { + return status; + } + if (type == HSA_DEVICE_TYPE_CPU) { + allocator->cpu_agent = agent; + } + return HSA_STATUS_SUCCESS; +} + +static hsa_status_t iterate_find_cpu_agent_pool_callback( + hsa_amd_memory_pool_t pool, void* base_allocator) { + iree_hal_hsa_allocator_t* allocator = + (iree_hal_hsa_allocator_t*)(base_allocator); + + hsa_amd_segment_t segment; + hsa_status_t status = allocator->symbols->hsa_amd_memory_pool_get_info( + pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT, &segment); + if (status != HSA_STATUS_SUCCESS) { + return status; + } + if (segment != HSA_AMD_SEGMENT_GLOBAL) { + return HSA_STATUS_SUCCESS; + } + + uint32_t flags; + status = allocator->symbols->hsa_amd_memory_pool_get_info( + pool, HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, &flags); + if (status != HSA_STATUS_SUCCESS) { + return status; + } + + bool is_fine_grained = + (flags & (HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_FINE_GRAINED | + 
HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_EXTENDED_SCOPE_FINE_GRAINED)); + bool is_kernel_arg_region = + (flags & HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT); + + if (is_fine_grained && !is_kernel_arg_region) { + allocator->cpu_pool = pool; + return HSA_STATUS_INFO_BREAK; + } + return HSA_STATUS_SUCCESS; +} + +iree_status_t iree_hal_hsa_allocator_create( + const iree_hal_hsa_dynamic_symbols_t* hsa_symbols, hsa_agent_t agent, + iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { + IREE_ASSERT_ARGUMENT(hsa_symbols); + IREE_ASSERT_ARGUMENT(out_allocator); + IREE_TRACE_ZONE_BEGIN(z0); + + // To support device-local + host-visible memory we need concurrent managed + // access indicating that the host and devices can concurrently access the + // device memory. If we don't have this feature then we fall back to forcing + // all device-local + host-visible memory into host-local + device-visible + // page-locked memory. The compiler tries to avoid this for high-traffic + // buffers except for readback staging buffers. + int supports_concurrent_managed_access = 1; + + IREE_TRACE_ZONE_APPEND_TEXT( + z0, supports_concurrent_managed_access + ? "has CONCURRENT_MANAGED_ACCESS" + : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on " + "device-local + host-visible memory)"); + + iree_hal_hsa_allocator_t* allocator = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*allocator), + (void**)&allocator)); + iree_hal_resource_initialize(&iree_hal_hsa_allocator_vtable, + &allocator->resource); + allocator->hsa_agent = agent; + allocator->symbols = hsa_symbols; + allocator->host_allocator = host_allocator; + allocator->supports_concurrent_managed_access = + supports_concurrent_managed_access != 0; + + hsa_symbols->hsa_agent_iterate_regions(agent, get_kernarg_memory_region, + allocator); + hsa_symbols->hsa_amd_agent_iterate_memory_pools( + agent, get_fine_grained_memory_pool, allocator); + + hsa_symbols->hsa_iterate_agents(&iterate_find_cpu_agent_callback, + (void*)allocator); + hsa_symbols->hsa_amd_agent_iterate_memory_pools( + allocator->cpu_agent, &iterate_find_cpu_agent_pool_callback, + (void*)allocator); + + *out_allocator = (iree_hal_allocator_t*)allocator; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_hsa_allocator_destroy( + iree_hal_allocator_t* IREE_RESTRICT base_allocator) { + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(allocator->host_allocator, allocator); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_allocator_t iree_hal_hsa_allocator_host_allocator( + const iree_hal_allocator_t* IREE_RESTRICT base_allocator) { + iree_hal_hsa_allocator_t* allocator = + (iree_hal_hsa_allocator_t*)base_allocator; + return allocator->host_allocator; +} + +static iree_status_t iree_hal_hsa_allocator_trim( + iree_hal_allocator_t* IREE_RESTRICT base_allocator) { + return iree_ok_status(); +} + +static void iree_hal_hsa_allocator_query_statistics( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_allocator_statistics_t* IREE_RESTRICT out_statistics) { + IREE_STATISTICS({ + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + memcpy(out_statistics, &allocator->statistics, sizeof(*out_statistics)); + }); +} + +static iree_status_t iree_hal_hsa_allocator_query_memory_heaps( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_host_size_t capacity, + 
iree_hal_allocator_memory_heap_t* IREE_RESTRICT heaps, + iree_host_size_t* IREE_RESTRICT out_count) { + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + + iree_host_size_t count = 3; + if (allocator->supports_concurrent_managed_access) { + ++count; // device-local | host-visible + } + if (out_count) *out_count = count; + if (capacity < count) { + // NOTE: lightweight as this is hit in normal pre-sizing usage. + return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); + } + + // Don't think there's a query for these. + // Max allocation size may be much smaller in certain memory types such as + // page-locked memory and it'd be good to enforce that. + const iree_device_size_t max_allocation_size = ~(iree_device_size_t)0; + const iree_device_size_t min_alignment = 64; + + int i = 0; + + // Device-local memory (dispatch resources): + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .allowed_usage = + IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_DISPATCH, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + + if (allocator->supports_concurrent_managed_access) { + // Device-local managed memory with host mapping support: + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + } + + // Write-combined page-locked host-local memory (upload): + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + + // Cached page-locked host-local memory (download): + heaps[i++] = (iree_hal_allocator_memory_heap_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_HOST_COHERENT | + IREE_HAL_MEMORY_TYPE_HOST_CACHED, + .allowed_usage = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH | + IREE_HAL_BUFFER_USAGE_MAPPING, + .max_allocation_size = max_allocation_size, + .min_alignment = min_alignment, + }; + + IREE_ASSERT(i == count); + return iree_ok_status(); +} + +static iree_hal_buffer_compatibility_t +iree_hal_hsa_allocator_query_buffer_compatibility( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_buffer_params_t* IREE_RESTRICT params, + iree_device_size_t* IREE_RESTRICT allocation_size) { + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + + // All buffers can be allocated on the heap. + iree_hal_buffer_compatibility_t compatibility = + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; + + // Buffers are importable in HSA under most cases, though performance may + // vary wildly. We don't fully verify that the buffer parameters are + // self-consistent and just look at whether we can get a device pointer. + if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE; + } + + // Buffers can only be used on the queue if they are device visible. 
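+  // For example, a device-visible buffer requested with TRANSFER usage also
+  // gains QUEUE_TRANSFER compatibility and one with DISPATCH_STORAGE usage
+  // gains QUEUE_DISPATCH, per the checks below.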
+  if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
+    if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+      compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+    }
+    if (iree_any_bit_set(params->usage,
+                         IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
+      compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
+    }
+  }
+
+  if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                                          IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+    // Device-local and host-visible memory is in general much slower than
+    // device-only memory for discrete GPUs, so mark it accordingly.
+    compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE;
+    // If concurrent managed access is not supported then make device-local +
+    // host-visible allocations fall back to host-local + device-visible
+    // page-locked memory. This will be significantly slower for the device to
+    // access but the compiler only uses this type for readback staging buffers
+    // and it's better to function than function fast.
+    if (!allocator->supports_concurrent_managed_access) {
+      params->type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+                        IREE_HAL_MEMORY_TYPE_HOST_VISIBLE);
+      params->type |=
+          IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+    }
+  }
+
+  // We are now optimal.
+  params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
+
+  // Guard against the corner case where the requested buffer size is 0. The
+  // application is unlikely to do anything useful with a 0-byte buffer, but it
+  // can happen in real-world use cases, so we should at least not crash.
+  if (*allocation_size == 0) *allocation_size = 4;
+
+  return compatibility;
+}
+
+static void iree_hal_hsa_buffer_free(
+    const iree_hal_hsa_dynamic_symbols_t* hsa_symbols,
+    iree_hal_hsa_buffer_type_t buffer_type, hsa_device_pointer_t device_ptr,
+    void* host_ptr) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  switch (buffer_type) {
+    case IREE_HAL_HSA_BUFFER_TYPE_DEVICE: {
+      IREE_TRACE_ZONE_APPEND_TEXT(z0, "hsa_amd_memory_pool_free");
+      IREE_HSA_IGNORE_ERROR(hsa_symbols, hsa_amd_memory_pool_free(device_ptr));
+      break;
+    }
+    case IREE_HAL_HSA_BUFFER_TYPE_HOST: {
+      IREE_TRACE_ZONE_APPEND_TEXT(z0, "hsa_amd_memory_pool_free");
+      IREE_HSA_IGNORE_ERROR(hsa_symbols, hsa_amd_memory_pool_free(device_ptr));
+      break;
+    }
+    case IREE_HAL_HSA_BUFFER_TYPE_HOST_REGISTERED: {
+      IREE_TRACE_ZONE_APPEND_TEXT(z0, "host unregister");
+      break;
+    }
+    case IREE_HAL_HSA_BUFFER_TYPE_ASYNC: {
+      IREE_TRACE_ZONE_APPEND_TEXT(z0, "(ignored; async)");
+      break;
+    }
+    case IREE_HAL_HSA_BUFFER_TYPE_EXTERNAL: {
+      IREE_TRACE_ZONE_APPEND_TEXT(z0, "(ignored; external)");
+      break;
+    }
+    case IREE_HAL_HSA_BUFFER_TYPE_KERNEL_ARG: {
+      IREE_HSA_IGNORE_ERROR(hsa_symbols, hsa_memory_free(device_ptr));
+      break;
+    }
+  }
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_hsa_allocator_allocate_buffer(
+    iree_hal_allocator_t* IREE_RESTRICT base_allocator,
+    const iree_hal_buffer_params_t* IREE_RESTRICT params,
+    iree_device_size_t allocation_size,
+    iree_hal_buffer_t** IREE_RESTRICT out_buffer) {
+  iree_hal_hsa_allocator_t* allocator =
+      iree_hal_hsa_allocator_cast(base_allocator);
+
+  // Coerce options into those required by the current device.
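+  // The compatibility query below may rewrite the type bits (e.g. the
+  // managed-memory fallback to host-local + device-visible), so we operate on
+  // a mutable copy of |params|.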
+ iree_hal_buffer_params_t compat_params = *params; + iree_hal_buffer_compatibility_t compatibility = + iree_hal_hsa_allocator_query_buffer_compatibility( + base_allocator, &compat_params, &allocation_size); + + if (!iree_all_bits_set(compatibility, + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE)) { +#if IREE_STATUS_MODE + iree_bitfield_string_temp_t temp0, temp1, temp2; + iree_string_view_t memory_type_str = + iree_hal_memory_type_format(params->type, &temp0); + iree_string_view_t usage_str = + iree_hal_buffer_usage_format(params->usage, &temp1); + iree_string_view_t compatibility_str = + iree_hal_buffer_compatibility_format(compatibility, &temp2); + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "allocator cannot allocate a buffer with the given parameters; " + "memory_type=%.*s, usage=%.*s, compatibility=%.*s", + (int)memory_type_str.size, memory_type_str.data, (int)usage_str.size, + usage_str.data, (int)compatibility_str.size, compatibility_str.data); +#else + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "allocator cannot allocate a buffer with the given parameters"); +#endif // IREE_STATUS_MODE + } + + iree_status_t status = iree_ok_status(); + iree_hal_hsa_buffer_type_t buffer_type = IREE_HAL_HSA_BUFFER_TYPE_DEVICE; + void* host_ptr = NULL; + hsa_device_pointer_t device_ptr = NULL; + IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hsa_buffer_allocate"); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); + + // TODO(muhaawad): Not sure if this is the right way to do kernel arguments + // allocations + if (iree_all_bits_set(compat_params.usage, + IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | + IREE_HAL_BUFFER_USAGE_TRANSFER) && + iree_all_bits_set( + compat_params.access, + IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE) && + iree_all_bits_set(compat_params.type, + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + // Kernel arguments + IREE_HSA_RETURN_IF_ERROR( + allocator->symbols, + hsa_memory_allocate(allocator->kernel_argument_pool, allocation_size, + &host_ptr), + "hsa_memory_allocate"); + buffer_type = IREE_HAL_HSA_BUFFER_TYPE_KERNEL_ARG; + device_ptr = host_ptr; + } else if (iree_all_bits_set(compat_params.type, + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { + // Device local case. + buffer_type = IREE_HAL_HSA_BUFFER_TYPE_DEVICE; + if (iree_all_bits_set(compat_params.type, + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + status = IREE_HSA_RESULT_TO_STATUS( + allocator->symbols, + hsa_amd_memory_pool_allocate(allocator->buffers_pool, allocation_size, + /*flags=*/0, &device_ptr)); + host_ptr = (void*)device_ptr; + + } else { + // Device only. 
+ buffer_type = IREE_HAL_HSA_BUFFER_TYPE_DEVICE; + + status = IREE_HSA_RESULT_TO_STATUS( + allocator->symbols, + hsa_amd_memory_pool_allocate(allocator->buffers_pool, allocation_size, + /*flags=*/0, &device_ptr)); + } + } else { + buffer_type = IREE_HAL_HSA_BUFFER_TYPE_HOST; + status = IREE_HSA_RESULT_TO_STATUS( + allocator->symbols, + hsa_amd_memory_pool_allocate(allocator->buffers_pool, allocation_size, + /*flags=*/0, &host_ptr)); + device_ptr = host_ptr; + } + IREE_TRACE_ZONE_END(z0); + + iree_hal_buffer_t* buffer = NULL; + if (iree_status_is_ok(status)) { + status = iree_hal_hsa_buffer_wrap( + base_allocator, compat_params.type, compat_params.access, + compat_params.usage, allocation_size, + /*byte_offset=*/0, + /*byte_length=*/allocation_size, buffer_type, device_ptr, host_ptr, + iree_hal_buffer_release_callback_null(), + iree_hal_allocator_host_allocator(base_allocator), &buffer); + } + + if (iree_status_is_ok(status)) { + IREE_TRACE_ALLOC_NAMED(IREE_HAL_HSA_ALLOCATOR_ID, + (void*)iree_hal_hsa_buffer_device_pointer(buffer), + allocation_size); + IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc( + &allocator->statistics, compat_params.type, allocation_size)); + *out_buffer = buffer; + } else { + if (!buffer && (device_ptr || host_ptr)) { + iree_hal_hsa_buffer_free(allocator->symbols, buffer_type, device_ptr, + host_ptr); + } else { + iree_hal_buffer_release(buffer); + } + } + return status; +} + +static void iree_hal_hsa_allocator_deallocate_buffer( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + iree_hal_buffer_t* IREE_RESTRICT base_buffer) { + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + + const iree_hal_hsa_buffer_type_t buffer_type = + iree_hal_hsa_buffer_type(base_buffer); + + iree_hal_hsa_buffer_free(allocator->symbols, buffer_type, + iree_hal_hsa_buffer_device_pointer(base_buffer), + iree_hal_hsa_buffer_host_pointer(base_buffer)); + + switch (buffer_type) { + case IREE_HAL_HSA_BUFFER_TYPE_DEVICE: + case IREE_HAL_HSA_BUFFER_TYPE_HOST: { + IREE_TRACE_FREE_NAMED( + IREE_HAL_HSA_ALLOCATOR_ID, + (void*)iree_hal_hsa_buffer_device_pointer(base_buffer)); + IREE_STATISTICS(iree_hal_allocator_statistics_record_free( + &allocator->statistics, iree_hal_buffer_memory_type(base_buffer), + iree_hal_buffer_allocation_size(base_buffer))); + break; + } + default: + // Buffer type not tracked. + break; + } + + iree_hal_buffer_destroy(base_buffer); +} + +static iree_status_t iree_hal_hsa_allocator_import_buffer( + iree_hal_allocator_t* IREE_RESTRICT base_allocator, + const iree_hal_buffer_params_t* IREE_RESTRICT params, + iree_hal_external_buffer_t* IREE_RESTRICT external_buffer, + iree_hal_buffer_release_callback_t release_callback, + iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + iree_hal_hsa_allocator_t* allocator = + iree_hal_hsa_allocator_cast(base_allocator); + // Coerce options into those required by the current device. 
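+  // As with allocation, the compatibility query may rewrite the requested
+  // type bits; the import size always comes from |external_buffer|.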
+  iree_hal_buffer_params_t compat_params = *params;
+  iree_device_size_t allocation_size = external_buffer->size;
+  iree_hal_buffer_compatibility_t compatibility =
+      iree_hal_hsa_allocator_query_buffer_compatibility(
+          base_allocator, &compat_params, &allocation_size);
+  if (!iree_all_bits_set(compatibility,
+                         IREE_HAL_BUFFER_COMPATIBILITY_IMPORTABLE)) {
+#if IREE_STATUS_MODE
+    iree_bitfield_string_temp_t temp0, temp1, temp2;
+    iree_string_view_t memory_type_str =
+        iree_hal_memory_type_format(params->type, &temp0);
+    iree_string_view_t usage_str =
+        iree_hal_buffer_usage_format(params->usage, &temp1);
+    iree_string_view_t compatibility_str =
+        iree_hal_buffer_compatibility_format(compatibility, &temp2);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "allocator cannot import a buffer with the given parameters; "
+        "memory_type=%.*s, usage=%.*s, compatibility=%.*s",
+        (int)memory_type_str.size, memory_type_str.data, (int)usage_str.size,
+        usage_str.data, (int)compatibility_str.size, compatibility_str.data);
+#else
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "allocator cannot import a buffer with the given parameters");
+#endif  // IREE_STATUS_MODE
+  }
+
+  iree_status_t status = iree_ok_status();
+  iree_hal_hsa_buffer_type_t buffer_type = IREE_HAL_HSA_BUFFER_TYPE_DEVICE;
+  void* host_ptr = NULL;
+  hsa_device_pointer_t device_ptr = NULL;
+
+  switch (external_buffer->type) {
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_HOST_ALLOCATION: {
+      // Lock the caller-provided host allocation into the CPU agent's pool;
+      // the agent-visible pointer is returned through |device_ptr|.
+      host_ptr = (void*)external_buffer->handle.host_allocation.ptr;
+      buffer_type = IREE_HAL_HSA_BUFFER_TYPE_HOST_REGISTERED;
+      uint32_t flags = 0;
+      int num_agents = 1;
+      status = IREE_HSA_RESULT_TO_STATUS(
+          allocator->symbols,
+          hsa_amd_memory_lock_to_pool(host_ptr, external_buffer->size,
+                                      &allocator->cpu_agent, num_agents,
+                                      allocator->cpu_pool, flags, &device_ptr),
+          "hsa_amd_memory_lock_to_pool");
+      break;
+    }
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION: {
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "not yet implemented");
+    }
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_FD:
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_OPAQUE_WIN32:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "handle-based imports not yet implemented");
+    default:
+      return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                              "external buffer type not supported");
+  }
+
+  iree_hal_buffer_t* buffer = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_hsa_buffer_wrap(
+        base_allocator, compat_params.type, compat_params.access,
+        compat_params.usage, external_buffer->size,
+        /*byte_offset=*/0,
+        /*byte_length=*/external_buffer->size, buffer_type, device_ptr,
+        host_ptr, release_callback,
+        iree_hal_allocator_host_allocator(base_allocator), &buffer);
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_buffer = buffer;
+  } else {
+    if (!buffer && (device_ptr || host_ptr)) {
+      iree_hal_hsa_buffer_free(allocator->symbols, buffer_type, device_ptr,
+                               host_ptr);
+    } else {
+      iree_hal_buffer_release(buffer);
+    }
+  }
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_allocator_export_buffer(
+    iree_hal_allocator_t* IREE_RESTRICT base_allocator,
+    iree_hal_buffer_t* IREE_RESTRICT buffer,
+    iree_hal_external_buffer_type_t requested_type,
+    iree_hal_external_buffer_flags_t requested_flags,
+    iree_hal_external_buffer_t* IREE_RESTRICT out_external_buffer) {
+  iree_hal_hsa_buffer_type_t buffer_type = iree_hal_hsa_buffer_type(buffer);
+
+  switch (requested_type) {
+    case IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION:
+      switch (buffer_type) {
+        case IREE_HAL_HSA_BUFFER_TYPE_EXTERNAL:
+          out_external_buffer->flags = requested_flags;
+          out_external_buffer->type = requested_type;
out_external_buffer->handle.device_allocation.ptr = + ((uint64_t)(uintptr_t)iree_hal_hsa_buffer_device_pointer(buffer)); + out_external_buffer->size = iree_hal_buffer_allocation_size(buffer); + return iree_ok_status(); + + default: + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "HSA buffer type is not supported for " + "export as an external device allocation"); + } + + default: + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "external buffer type not supported"); + } +} + +static const iree_hal_allocator_vtable_t iree_hal_hsa_allocator_vtable = { + .destroy = iree_hal_hsa_allocator_destroy, + .host_allocator = iree_hal_hsa_allocator_host_allocator, + .trim = iree_hal_hsa_allocator_trim, + .query_statistics = iree_hal_hsa_allocator_query_statistics, + .query_memory_heaps = iree_hal_hsa_allocator_query_memory_heaps, + .query_buffer_compatibility = + iree_hal_hsa_allocator_query_buffer_compatibility, + .allocate_buffer = iree_hal_hsa_allocator_allocate_buffer, + .deallocate_buffer = iree_hal_hsa_allocator_deallocate_buffer, + .import_buffer = iree_hal_hsa_allocator_import_buffer, + .export_buffer = iree_hal_hsa_allocator_export_buffer, +}; diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.h b/runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.h new file mode 100644 index 000000000..272a0f8b8 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_allocator.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_EXPERIMENTAL_HSA_ALLOCATOR_H_ +#define IREE_EXPERIMENTAL_HSA_ALLOCATOR_H_ + +#include "iree-amd-aie/driver/hsa/status_util.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +iree_status_t iree_hal_hsa_allocator_create( + const iree_hal_hsa_dynamic_symbols_t* hsa_symbols, hsa_agent_t agent, + iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_EXPERIMENTAL_HSA_ALLOCATOR_H_ diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_buffer.c b/runtime/src/iree-amd-aie/driver/hsa/hsa_buffer.c new file mode 100644 index 000000000..f4b79fc94 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_buffer.c @@ -0,0 +1,172 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/hsa_buffer.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+typedef struct iree_hal_hsa_buffer_t {
+  iree_hal_buffer_t base;
+  iree_hal_hsa_buffer_type_t type;
+  void* host_ptr;
+  hsa_device_pointer_t device_ptr;
+  iree_hal_buffer_release_callback_t release_callback;
+} iree_hal_hsa_buffer_t;
+
+static const iree_hal_buffer_vtable_t iree_hal_hsa_buffer_vtable;
+
+static iree_hal_hsa_buffer_t* iree_hal_hsa_buffer_cast(
+    iree_hal_buffer_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_buffer_vtable);
+  return (iree_hal_hsa_buffer_t*)base_value;
+}
+
+static const iree_hal_hsa_buffer_t* iree_hal_hsa_buffer_const_cast(
+    const iree_hal_buffer_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_buffer_vtable);
+  return (const iree_hal_hsa_buffer_t*)base_value;
+}
+
+iree_status_t iree_hal_hsa_buffer_wrap(
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_offset, iree_device_size_t byte_length,
+    iree_hal_hsa_buffer_type_t buffer_type, hsa_device_pointer_t device_ptr,
+    void* host_ptr, iree_hal_buffer_release_callback_t release_callback,
+    iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  if (!host_ptr && iree_any_bit_set(allowed_usage,
+                                    IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT |
+                                        IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "mappable buffers require host pointers");
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hsa_buffer_t* buffer = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_buffer_initialize(host_allocator, allocator, &buffer->base,
+                               allocation_size, byte_offset, byte_length,
+                               memory_type, allowed_access, allowed_usage,
+                               &iree_hal_hsa_buffer_vtable, &buffer->base);
+    buffer->type = buffer_type;
+    buffer->host_ptr = host_ptr;
+    buffer->device_ptr = device_ptr;
+    buffer->release_callback = release_callback;
+    *out_buffer = &buffer->base;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_hsa_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+  iree_hal_hsa_buffer_t* buffer = iree_hal_hsa_buffer_cast(base_buffer);
+  iree_allocator_t host_allocator = base_buffer->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  if (buffer->release_callback.fn) {
+    buffer->release_callback.fn(buffer->release_callback.user_data,
+                                base_buffer);
+  }
+  iree_allocator_free(host_allocator, buffer);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_hsa_buffer_map_range(
+    iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_memory_access_t memory_access,
+    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+    iree_hal_buffer_mapping_t* mapping) {
+  iree_hal_hsa_buffer_t* buffer = iree_hal_hsa_buffer_cast(base_buffer);
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(base_buffer),
+      IREE_HAL_MEMORY_TYPE_HOST_VISIBLE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(base_buffer),
+      mapping_mode == IREE_HAL_MAPPING_MODE_PERSISTENT
+          ?
IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT + : IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)); + + uint8_t* data_ptr = (uint8_t*)(buffer->host_ptr) + local_byte_offset; + // If we mapped for discard scribble over the bytes. This is not a mandated + // behavior but it will make debugging issues easier. Alternatively for + // heap buffers we could reallocate them such that ASAN yells, but that + // would only work if the entire buffer was discarded. +#ifndef NDEBUG + if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { + memset(data_ptr, 0xCD, local_byte_length); + } +#endif // !NDEBUG + + mapping->contents = iree_make_byte_span(data_ptr, local_byte_length); + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_buffer_unmap_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { + // Nothing to do today. + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_buffer_invalidate_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + // Nothing to do today. + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_buffer_flush_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + // Nothing to do today. + return iree_ok_status(); +} + +iree_hal_hsa_buffer_type_t iree_hal_hsa_buffer_type( + const iree_hal_buffer_t* base_buffer) { + const iree_hal_hsa_buffer_t* buffer = + iree_hal_hsa_buffer_const_cast(base_buffer); + return buffer->type; +} + +hsa_device_pointer_t iree_hal_hsa_buffer_device_pointer( + const iree_hal_buffer_t* base_buffer) { + const iree_hal_hsa_buffer_t* buffer = + iree_hal_hsa_buffer_const_cast(base_buffer); + return buffer->device_ptr; +} + +void* iree_hal_hsa_buffer_host_pointer(const iree_hal_buffer_t* base_buffer) { + const iree_hal_hsa_buffer_t* buffer = + iree_hal_hsa_buffer_const_cast(base_buffer); + return buffer->host_ptr; +} + +void iree_hal_hsa_buffer_drop_release_callback(iree_hal_buffer_t* base_buffer) { + iree_hal_hsa_buffer_t* buffer = iree_hal_hsa_buffer_cast(base_buffer); + buffer->release_callback = iree_hal_buffer_release_callback_null(); +} + +static const iree_hal_buffer_vtable_t iree_hal_hsa_buffer_vtable = { + .recycle = iree_hal_buffer_recycle, + .destroy = iree_hal_hsa_buffer_destroy, + .map_range = iree_hal_hsa_buffer_map_range, + .unmap_range = iree_hal_hsa_buffer_unmap_range, + .invalidate_range = iree_hal_hsa_buffer_invalidate_range, + .flush_range = iree_hal_hsa_buffer_flush_range, +}; diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_buffer.h b/runtime/src/iree-amd-aie/driver/hsa/hsa_buffer.h new file mode 100644 index 000000000..1ba1d1f39 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_buffer.h @@ -0,0 +1,74 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_BUFFER_H_
+#define IREE_EXPERIMENTAL_HSA_BUFFER_H_
+
+#include "iree-amd-aie/driver/hsa/hsa_headers.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef void* hsa_device_pointer_t;
+
+typedef enum iree_hal_hsa_buffer_type_e {
+  // Device-local buffer.
+  IREE_HAL_HSA_BUFFER_TYPE_DEVICE = 0,
+  // Host-local buffer.
+  IREE_HAL_HSA_BUFFER_TYPE_HOST,
+  // Host-local buffer registered with HSA.
+  IREE_HAL_HSA_BUFFER_TYPE_HOST_REGISTERED,
+  // Asynchronously allocated device-local buffer.
+  IREE_HAL_HSA_BUFFER_TYPE_ASYNC,
+  // Externally registered buffer whose provenance is unknown.
+  // Must be freed by the user.
+  IREE_HAL_HSA_BUFFER_TYPE_EXTERNAL,
+  // Kernel arguments buffer.
+  IREE_HAL_HSA_BUFFER_TYPE_KERNEL_ARG,
+} iree_hal_hsa_buffer_type_t;
+
+// Wraps an HSA allocation in an iree_hal_buffer_t.
+iree_status_t iree_hal_hsa_buffer_wrap(
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_device_size_t byte_offset, iree_device_size_t byte_length,
+    iree_hal_hsa_buffer_type_t buffer_type, hsa_device_pointer_t device_ptr,
+    void* host_ptr, iree_hal_buffer_release_callback_t release_callback,
+    iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer);
+
+// Returns the underlying HSA buffer type.
+iree_hal_hsa_buffer_type_t iree_hal_hsa_buffer_type(
+    const iree_hal_buffer_t* buffer);
+
+// Returns the HSA base pointer for the given |buffer|.
+// This is the base pointer of the entire allocated buffer and must be offset
+// by the buffer's byte_offset and byte_length when used.
+hsa_device_pointer_t iree_hal_hsa_buffer_device_pointer(
+    const iree_hal_buffer_t* buffer);
+
+// Returns the HSA host pointer for the given |buffer|, if available.
+void* iree_hal_hsa_buffer_host_pointer(const iree_hal_buffer_t* buffer);
+
+// Drops the release callback so that when the buffer is destroyed no callback
+// will be made. This is not thread-safe but all callers are expected to be
+// holding an allocation and the earliest the buffer could be destroyed is after
+// this call returns and the caller has released its reference.
+void iree_hal_hsa_buffer_drop_release_callback(iree_hal_buffer_t* buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_BUFFER_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_device.c b/runtime/src/iree-amd-aie/driver/hsa/hsa_device.c
new file mode 100644
index 000000000..94838cff5
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_device.c
@@ -0,0 +1,604 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/hsa_device.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/event_pool.h"
+#include "iree-amd-aie/driver/hsa/event_semaphore.h"
+#include "iree-amd-aie/driver/hsa/hsa_allocator.h"
+#include "iree-amd-aie/driver/hsa/hsa_buffer.h"
+#include "iree-amd-aie/driver/hsa/nop_executable_cache.h"
+#include "iree-amd-aie/driver/hsa/pending_queue_actions.h"
+#include "iree-amd-aie/driver/hsa/pipeline_layout.h"
+#include "iree-amd-aie/driver/hsa/queue_command_buffer.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+#include "iree-amd-aie/driver/hsa/timepoint_pool.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/internal/event_pool.h"
+#include "iree/base/internal/math.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/utils/deferred_command_buffer.h"
+#include "iree/hal/utils/file_transfer.h"
+#include "iree/hal/utils/memory_file.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_device_t
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_hsa_device_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+  iree_string_view_t identifier;
+
+  // Block pool used for command buffers with a larger block size (as command
+  // buffers can contain inlined data uploads).
+  iree_arena_block_pool_t block_pool;
+
+  // Optional driver that owns the HSA symbols. We retain it for our lifetime
+  // to ensure the symbols remain valid.
+  iree_hal_driver_t* driver;
+
+  const iree_hal_hsa_dynamic_symbols_t* hsa_symbols;
+
+  // Parameters used to control device behavior.
+  iree_hal_hsa_device_params_t params;
+
+  // The HSA agent.
+  hsa_agent_t hsa_agent;
+
+  // The queue where we will dispatch work.
+  hsa_queue_t* hsa_dispatch_queue;
+
+  // The host allocator.
+  iree_allocator_t host_allocator;
+
+  // Host/device event pools, used for backing semaphore timepoints.
+  iree_event_pool_t* host_event_pool;
+  iree_hal_hsa_event_pool_t* device_event_pool;
+  // Timepoint pools, shared by various semaphores.
+  iree_hal_hsa_timepoint_pool_t* timepoint_pool;
+
+  // A queue to order device workloads and release them to the GPU when
+  // constraints are met. It buffers submissions and allocations internally
+  // before they are ready. This queue couples with HAL semaphores backed by
+  // iree_event_t and hsa_signal_t objects.
+  iree_hal_hsa_pending_queue_actions_t* pending_queue_actions;
+
+  // Device allocator.
+ iree_hal_allocator_t* device_allocator; +} iree_hal_hsa_device_t; + +static const iree_hal_device_vtable_t iree_hal_hsa_device_vtable; + +static iree_hal_hsa_device_t* iree_hal_hsa_device_cast( + iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_device_vtable); + return (iree_hal_hsa_device_t*)base_value; +} + +static iree_hal_hsa_device_t* iree_hal_hsa_device_cast_unsafe( + iree_hal_device_t* base_value) { + return (iree_hal_hsa_device_t*)base_value; +} + +IREE_API_EXPORT void iree_hal_hsa_device_params_initialize( + iree_hal_hsa_device_params_t* out_params) { + memset(out_params, 0, sizeof(*out_params)); + out_params->arena_block_size = 32 * 1024; + out_params->event_pool_capacity = 32; + out_params->queue_count = 1; + out_params->queue_tracing = false; +} + +static iree_status_t iree_hal_hsa_device_check_params( + const iree_hal_hsa_device_params_t* params) { + if (params->arena_block_size < 4096) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "arena block size too small (< 4096 bytes)"); + } + if (params->queue_count == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one queue is required"); + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_device_create_internal( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_hsa_device_params_t* params, hsa_agent_t agent, + hsa_queue_t* dispatch_queue, const iree_hal_hsa_dynamic_symbols_t* symbols, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_hsa_device_t* device = NULL; + iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&device)); + + iree_hal_resource_initialize(&iree_hal_hsa_device_vtable, &device->resource); + iree_string_view_append_to_buffer( + identifier, &device->identifier, + (char*)device + iree_sizeof_struct(*device)); + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); + device->driver = driver; + iree_hal_driver_retain(device->driver); + device->hsa_symbols = symbols; + device->params = *params; + device->hsa_agent = agent; + device->hsa_dispatch_queue = dispatch_queue; + device->host_allocator = host_allocator; + + iree_status_t status = iree_hal_hsa_pending_queue_actions_create( + symbols, &device->block_pool, host_allocator, + &device->pending_queue_actions); + + if (iree_status_is_ok(status)) { + status = iree_hal_hsa_allocator_create(symbols, agent, host_allocator, + &device->device_allocator); + } + + if (iree_status_is_ok(status)) { + *out_device = (iree_hal_device_t*)device; + } else { + iree_hal_device_release((iree_hal_device_t*)device); + } + return status; +} + +iree_status_t iree_hal_hsa_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_hsa_device_params_t* params, + const iree_hal_hsa_dynamic_symbols_t* symbols, hsa_agent_t agent, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(driver); + IREE_ASSERT_ARGUMENT(params); + IREE_ASSERT_ARGUMENT(symbols); + IREE_ASSERT_ARGUMENT(out_device); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_hal_hsa_device_check_params(params); + + size_t num_queue_packets = 1024; + hsa_queue_type_t queue_type = HSA_QUEUE_TYPE_MULTI; + void* callback = NULL; + void* data = NULL; + uint32_t private_segment_size = 0; + uint32_t group_segment_size = 0; + hsa_queue_t* dispatch_queue; + + 
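+  // Create the dispatch queue up front: HSA_QUEUE_TYPE_MULTI allows multiple
+  // producers and no asynchronous error callback is installed. The
+  // 1024-packet size is a heuristic; HSA requires it to be a power of two.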
+  IREE_HSA_RETURN_IF_ERROR(
+      symbols,
+      hsa_queue_create(agent, num_queue_packets, queue_type, callback, data,
+                       private_segment_size, group_segment_size,
+                       &dispatch_queue),
+      "hsa_queue_create");
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_hsa_device_create_internal(driver, identifier, params,
+                                                 agent, dispatch_queue, symbols,
+                                                 host_allocator, out_device);
+  }
+  if (!iree_status_is_ok(status)) {
+    // Invalid parameters or device creation failure: none of the pools below
+    // have been allocated yet so we can bail immediately.
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  iree_event_pool_t* host_event_pool = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_event_pool_allocate(params->event_pool_capacity,
+                                      host_allocator, &host_event_pool);
+  }
+
+  iree_hal_hsa_event_pool_t* device_event_pool = NULL;
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_hsa_event_pool_allocate(symbols, params->event_pool_capacity,
+                                         host_allocator, &device_event_pool);
+  }
+
+  iree_hal_hsa_timepoint_pool_t* timepoint_pool = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_hsa_timepoint_pool_allocate(
+        host_event_pool, device_event_pool, params->event_pool_capacity,
+        host_allocator, &timepoint_pool);
+  }
+
+  if (iree_status_is_ok(status)) {
+    iree_hal_hsa_device_t* hsa_device = iree_hal_hsa_device_cast(*out_device);
+    hsa_device->host_event_pool = host_event_pool;
+    hsa_device->device_event_pool = device_event_pool;
+    hsa_device->timepoint_pool = timepoint_pool;
+  } else {
+    // Release resources we have acquired after HAL device creation.
+    if (timepoint_pool) iree_hal_hsa_timepoint_pool_free(timepoint_pool);
+    if (device_event_pool) iree_hal_hsa_event_pool_release(device_event_pool);
+    if (host_event_pool) iree_event_pool_free(host_event_pool);
+    // Release other resources via the HAL device.
+    iree_hal_device_release(*out_device);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+const iree_hal_hsa_dynamic_symbols_t* iree_hal_hsa_device_dynamic_symbols(
+    iree_hal_device_t* base_device) {
+  iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast_unsafe(base_device);
+  return device->hsa_symbols;
+}
+
+static void iree_hal_hsa_device_destroy(iree_hal_device_t* base_device) {
+  iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device);
+  iree_allocator_t host_allocator =
+      iree_hal_device_host_allocator(base_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Destroy the pending workload queue.
+  iree_hal_hsa_pending_queue_actions_destroy(
+      (iree_hal_resource_t*)device->pending_queue_actions);
+
+  // There should be no more buffers live that use the allocator.
+  iree_hal_allocator_release(device->device_allocator);
+
+  // Destroy various pools for synchronization.
+  if (device->timepoint_pool) {
+    iree_hal_hsa_timepoint_pool_free(device->timepoint_pool);
+  }
+  if (device->device_event_pool) {
+    iree_hal_hsa_event_pool_release(device->device_event_pool);
+  }
+  if (device->host_event_pool) iree_event_pool_free(device->host_event_pool);
+
+  iree_arena_block_pool_deinitialize(&device->block_pool);
+
+  // Finally, destroy the device.
+ iree_hal_driver_release(device->driver); + + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_string_view_t iree_hal_hsa_device_id( + iree_hal_device_t* base_device) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return device->identifier; +} + +static iree_allocator_t iree_hal_hsa_device_host_allocator( + iree_hal_device_t* base_device) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return device->host_allocator; +} + +static iree_hal_allocator_t* iree_hal_hsa_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return device->device_allocator; +} + +static void iree_hal_hsa_replace_device_allocator( + iree_hal_device_t* base_device, iree_hal_allocator_t* new_allocator) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + iree_hal_allocator_retain(new_allocator); + iree_hal_allocator_release(device->device_allocator); + device->device_allocator = new_allocator; +} + +static void iree_hal_hsa_replace_channel_provider( + iree_hal_device_t* base_device, iree_hal_channel_provider_t* new_provider) { +} + +static iree_status_t iree_hal_hsa_device_trim(iree_hal_device_t* base_device) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "memory pools are not supported"); +} + +static iree_status_t iree_hal_hsa_device_query_i64( + iree_hal_device_t* base_device, iree_string_view_t category, + iree_string_view_t key, int64_t* out_value) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + *out_value = 0; + + if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { + *out_value = + iree_string_view_match_pattern(device->identifier, key) ? 1 : 0; + return iree_ok_status(); + } + + if (iree_string_view_equal(category, IREE_SV("hal.executable.format"))) { + *out_value = iree_string_view_equal(key, IREE_SV("rocm-hsaco-fb")) ? 
1 : 0; + return iree_ok_status(); + } + + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "unknown device configuration key value '%.*s :: %.*s'", + (int)category.size, category.data, (int)key.size, key.data); +} + +static iree_status_t iree_hal_hsa_device_create_channel( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "channel not yet implemented"); +} + +iree_status_t iree_hal_hsa_device_create_queue_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return iree_hal_hsa_queue_command_buffer_create( + base_device, device->hsa_symbols, mode, command_categories, + binding_capacity, device->hsa_dispatch_queue, &device->block_pool, + device->host_allocator, device->device_allocator, out_command_buffer); +} + +static iree_status_t iree_hal_hsa_device_create_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + + return iree_hal_deferred_command_buffer_create( + iree_hal_device_allocator(base_device), mode, command_categories, + binding_capacity, &device->block_pool, + iree_hal_device_host_allocator(base_device), out_command_buffer); +} + +static iree_status_t iree_hal_hsa_device_create_descriptor_set_layout( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_flags_t flags, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return iree_hal_hsa_descriptor_set_layout_create( + flags, binding_count, bindings, device->host_allocator, + out_descriptor_set_layout); +} + +static iree_status_t iree_hal_hsa_device_create_event( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_event_flags_t flags, iree_hal_event_t** out_event) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "event not yet implemented"); +} + +static iree_status_t iree_hal_hsa_device_import_file( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_memory_access_t access, iree_io_file_handle_t* handle, + iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { + if (iree_io_file_handle_type(handle) != + IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { + return iree_make_status( + IREE_STATUS_UNAVAILABLE, + "implementation does not support the external file type"); + } + return iree_hal_memory_file_wrap( + queue_affinity, access, handle, iree_hal_device_allocator(base_device), + iree_hal_device_host_allocator(base_device), out_file); +} + +static iree_status_t iree_hal_hsa_device_create_executable_cache( + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return iree_hal_hsa_nop_executable_cache_create( + 
identifier, device->hsa_symbols, device->hsa_agent, + device->host_allocator, device->device_allocator, out_executable_cache); +} + +static iree_status_t iree_hal_hsa_device_create_pipeline_layout( + iree_hal_device_t* base_device, iree_host_size_t push_constants, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t* const* set_layouts, + iree_hal_pipeline_layout_t** out_pipeline_layout) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return iree_hal_hsa_pipeline_layout_create( + set_layout_count, set_layouts, push_constants, device->host_allocator, + out_pipeline_layout); +} + +static iree_status_t iree_hal_hsa_device_create_semaphore( + iree_hal_device_t* base_device, uint64_t initial_value, + iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return iree_hal_hsa_event_semaphore_create( + initial_value, device->hsa_symbols, device->timepoint_pool, + device->pending_queue_actions, device->host_allocator, out_semaphore); +} + +static iree_hal_semaphore_compatibility_t +iree_hal_hsa_device_query_semaphore_compatibility( + iree_hal_device_t* base_device, iree_hal_semaphore_t* semaphore) { + // TODO: implement HSA semaphores. + return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY; +} + +static iree_status_t iree_hal_hsa_device_queue_alloca( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params, + iree_device_size_t allocation_size, + iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + // NOTE: block on the semaphores here; we could avoid this by properly + // sequencing device work with semaphores. The HSA HAL is not currently + // asynchronous. + IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, + iree_infinite_timeout())); + + iree_status_t status = + iree_hal_allocator_allocate_buffer(iree_hal_device_allocator(base_device), + params, allocation_size, out_buffer); + + // Only signal if not returning a synchronous error - synchronous failure + // indicates that the stream is unchanged (it's not really since we waited + // above, but we at least won't deadlock like this). + if (iree_status_is_ok(status)) { + status = iree_hal_semaphore_list_signal(signal_semaphore_list); + } + return status; +} + +static iree_status_t iree_hal_hsa_device_queue_dealloca( + iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* buffer) { + // NOTE: block on the semaphores here; we could avoid this by properly + // sequencing device work with semaphores. The HSA HAL is not currently + // asynchronous. + IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, + iree_infinite_timeout())); + + // Buffer will be freed when the buffer is released. 
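+  // (Dealloca is therefore a no-op beyond semaphore ordering: the memory is
+  // returned in iree_hal_hsa_allocator_deallocate_buffer once the last
+  // reference to |buffer| is dropped.)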
+
+  // The wait above already completed synchronously so the semaphores can be
+  // signaled immediately; there is no asynchronous work left to sequence.
+  iree_status_t status = iree_hal_semaphore_list_signal(signal_semaphore_list);
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_device_queue_read(
+    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_file_t* source_file, uint64_t source_offset,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, uint32_t flags) {
+  // TODO: expose streaming chunk count/size options.
+  iree_status_t loop_status = iree_ok_status();
+  iree_hal_file_transfer_options_t options = {
+      .loop = iree_loop_inline(&loop_status),
+      .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
+      .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming(
+      base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+      source_file, source_offset, target_buffer, target_offset, length, flags,
+      options));
+  return loop_status;
+}
+
+static iree_status_t iree_hal_hsa_device_queue_write(
+    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
+    iree_hal_file_t* target_file, uint64_t target_offset,
+    iree_device_size_t length, uint32_t flags) {
+  // TODO: expose streaming chunk count/size options.
+  iree_status_t loop_status = iree_ok_status();
+  iree_hal_file_transfer_options_t options = {
+      .loop = iree_loop_inline(&loop_status),
+      .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT,
+      .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT,
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming(
+      base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
+      source_buffer, source_offset, target_file, target_offset, length, flags,
+      options));
+  return loop_status;
+}
+
+static void iree_hal_hsa_device_collect_tracing_context(void* user_data) {}
+
+static iree_status_t iree_hal_hsa_device_queue_execute(
+    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t command_buffer_count,
+    iree_hal_command_buffer_t* const* command_buffers,
+    iree_hal_buffer_binding_table_t const* binding_tables) {
+  iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_hal_hsa_pending_queue_actions_enqueue_execution(
+      base_device, device->hsa_dispatch_queue, device->pending_queue_actions,
+      iree_hal_hsa_device_collect_tracing_context, wait_semaphore_list,
+      signal_semaphore_list, command_buffer_count, command_buffers);
+  if (iree_status_is_ok(status)) {
+    // Try to advance the pending workload queue.
+    status =
+        iree_hal_hsa_pending_queue_actions_issue(device->pending_queue_actions);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_device_queue_flush(
+    iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
+  iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  // Try to advance the pending workload queue.
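+  // Issuing is best-effort: actions still blocked on unsignaled semaphores
+  // simply remain in the pending list for a later flush or signal to pick up.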
+ iree_status_t status = + iree_hal_hsa_pending_queue_actions_issue(device->pending_queue_actions); + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hsa_device_wait_semaphores( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { + iree_hal_hsa_device_t* device = iree_hal_hsa_device_cast(base_device); + return iree_hal_hsa_semaphore_multi_wait(semaphore_list, wait_mode, timeout, + &device->block_pool); +} + +static iree_status_t iree_hal_hsa_device_profiling_begin( + iree_hal_device_t* base_device, + const iree_hal_device_profiling_options_t* options) { + // Unimplemented (and that's ok). + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_device_profiling_flush( + iree_hal_device_t* base_device) { + // Unimplemented (and that's ok). + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_device_profiling_end( + iree_hal_device_t* base_device) { + // Unimplemented (and that's ok). + return iree_ok_status(); +} + +static const iree_hal_device_vtable_t iree_hal_hsa_device_vtable = { + .destroy = iree_hal_hsa_device_destroy, + .id = iree_hal_hsa_device_id, + .host_allocator = iree_hal_hsa_device_host_allocator, + .device_allocator = iree_hal_hsa_device_allocator, + .replace_device_allocator = iree_hal_hsa_replace_device_allocator, + .replace_channel_provider = iree_hal_hsa_replace_channel_provider, + .trim = iree_hal_hsa_device_trim, + .query_i64 = iree_hal_hsa_device_query_i64, + .create_channel = iree_hal_hsa_device_create_channel, + .create_command_buffer = iree_hal_hsa_device_create_command_buffer, + .create_descriptor_set_layout = + iree_hal_hsa_device_create_descriptor_set_layout, + .create_event = iree_hal_hsa_device_create_event, + .create_executable_cache = iree_hal_hsa_device_create_executable_cache, + .import_file = iree_hal_hsa_device_import_file, + .create_pipeline_layout = iree_hal_hsa_device_create_pipeline_layout, + .create_semaphore = iree_hal_hsa_device_create_semaphore, + .query_semaphore_compatibility = + iree_hal_hsa_device_query_semaphore_compatibility, + .queue_alloca = iree_hal_hsa_device_queue_alloca, + .queue_dealloca = iree_hal_hsa_device_queue_dealloca, + .queue_read = iree_hal_hsa_device_queue_read, + .queue_write = iree_hal_hsa_device_queue_write, + .queue_execute = iree_hal_hsa_device_queue_execute, + .queue_flush = iree_hal_hsa_device_queue_flush, + .wait_semaphores = iree_hal_hsa_device_wait_semaphores, + .profiling_begin = iree_hal_hsa_device_profiling_begin, + .profiling_flush = iree_hal_hsa_device_profiling_flush, + .profiling_end = iree_hal_hsa_device_profiling_end, +}; diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_device.h b/runtime/src/iree-amd-aie/driver/hsa/hsa_device.h new file mode 100644 index 000000000..d6988f923 --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_device.h @@ -0,0 +1,50 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_DEVICE_H_
+#define IREE_EXPERIMENTAL_HSA_DEVICE_H_
+
+#include "iree-amd-aie/driver/hsa/api.h"
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates an HSA HAL device on |agent| using |params| and the loaded
+// |symbols| and returns it in |out_device|.
+iree_status_t iree_hal_hsa_device_create(
+    iree_hal_driver_t* driver, iree_string_view_t identifier,
+    const iree_hal_hsa_device_params_t* params,
+    const iree_hal_hsa_dynamic_symbols_t* symbols, hsa_agent_t agent,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device);
+
+// Creates a HSA queue-backed command buffer using resources from the
+// given |base_device|.
+iree_status_t iree_hal_hsa_device_create_queue_command_buffer(
+    iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_host_size_t binding_capacity,
+    iree_hal_command_buffer_t** out_command_buffer);
+
+// Returns the dynamic symbol table from the |device| if it is a HSA device
+// and otherwise returns NULL.
+//
+// WARNING: the symbols are only valid for as long as the device is. Hosting
+// libraries and applications should prefer to either link against HSA
+// themselves or maintain their own dynamic linking support: the IREE runtime
+// only provides the symbols required by the HAL driver and not the entirety of
+// the API.
+const iree_hal_hsa_dynamic_symbols_t* iree_hal_hsa_device_dynamic_symbols(
+    iree_hal_device_t* device);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_DEVICE_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_driver.c b/runtime/src/iree-amd-aie/driver/hsa/hsa_driver.c
new file mode 100644
index 000000000..a4090d3f4
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_driver.c
@@ -0,0 +1,581 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <inttypes.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "iree-amd-aie/driver/hsa/api.h"
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/hsa_device.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+#include "iree/base/api.h"
+#include "iree/base/assert.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+
+// Maximum device name length supported by the HSA HAL driver.
+#define IREE_HAL_HSA_MAX_DEVICE_NAME_LENGTH 64
+
+#define IREE_HAL_HSA_MAX_DEVICES 64
+#define IREE_HAL_HSA_DEVICE_NOT_FOUND IREE_HAL_HSA_MAX_DEVICES
+
+// Utility macro to convert a 1-based iree_hal_device_id_t into a 0-based
+// index into the driver's agent table.
+#define IREE_DEVICE_ID_TO_HSADEVICE(device_id) (int)((device_id) - 1)
+
+typedef struct iree_hal_hsa_driver_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+
+  iree_allocator_t host_allocator;
+
+  // Identifier used for registering the driver in the IREE driver registry.
+  iree_string_view_t identifier;
+  // HSA driver API dynamic symbols to interact with the HSA system.
+  iree_hal_hsa_dynamic_symbols_t hsa_symbols;
+
+  // The default parameters for creating devices using this driver.
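+  // A copy of the caller-provided parameters is captured at driver creation
+  // time so callers may release their options struct immediately after
+  // iree_hal_hsa_driver_create returns.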
+  iree_hal_hsa_device_params_t device_params;
+
+  // The index of the default HSA device to use if multiple ones are available.
+  int default_device_index;
+
+  // Number of GPU agents discovered at driver creation time.
+  int num_gpu_agents;
+
+  // Maps 0-based agent indices (IREE device IDs minus one) to hsa_agent_t.
+  hsa_agent_t agents[IREE_HAL_HSA_MAX_DEVICES];
+} iree_hal_hsa_driver_t;
+
+typedef struct iree_hal_hsa_device_info_t {
+} iree_hal_hsa_device_info_t;
+
+// A struct encapsulating common variables we need while communicating with
+// HSA callbacks.
+typedef struct iree_hal_hsa_callback_package_t {
+  iree_hal_hsa_driver_t* driver;
+  size_t* index;
+  void* return_value;
+} iree_hal_hsa_callback_package_t;
+
+static const iree_hal_driver_vtable_t iree_hal_hsa_driver_vtable;
+
+static iree_hal_hsa_driver_t* iree_hal_hsa_driver_cast(
+    iree_hal_driver_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_driver_vtable);
+  return (iree_hal_hsa_driver_t*)base_value;
+}
+
+IREE_API_EXPORT void iree_hal_hsa_driver_options_initialize(
+    iree_hal_hsa_driver_options_t* out_options) {
+  IREE_ASSERT_ARGUMENT(out_options);
+  memset(out_options, 0, sizeof(*out_options));
+  out_options->default_device_index = 0;
+}
+
+hsa_status_t iterate_count_gpu_agents_callback(hsa_agent_t agent,
+                                               void* base_driver) {
+  iree_hal_hsa_callback_package_t* package =
+      (iree_hal_hsa_callback_package_t*)(base_driver);
+  iree_hal_hsa_driver_t* driver = package->driver;
+  int* count_ptr = (int*)package->return_value;
+  hsa_device_type_t type;
+  hsa_status_t status = driver->hsa_symbols.hsa_agent_get_info(
+      agent, HSA_AGENT_INFO_DEVICE, &type);
+  if (status != HSA_STATUS_SUCCESS) {
+    return status;
+  }
+  if (type == HSA_DEVICE_TYPE_GPU) {
+    *count_ptr = *count_ptr + 1;
+  }
+  return HSA_STATUS_SUCCESS;
+}
+
+hsa_status_t iterate_populate_gpu_agents_callback(hsa_agent_t agent,
+                                                  void* base_driver) {
+  iree_hal_hsa_callback_package_t* package =
+      (iree_hal_hsa_callback_package_t*)(base_driver);
+  iree_hal_hsa_driver_t* driver = package->driver;
+  size_t* index_ptr = package->index;
+  hsa_agent_t* agents_ptr = (hsa_agent_t*)package->return_value;
+
+  hsa_device_type_t type;
+  hsa_status_t status = driver->hsa_symbols.hsa_agent_get_info(
+      agent, HSA_AGENT_INFO_DEVICE, &type);
+  if (status != HSA_STATUS_SUCCESS) {
+    return status;
+  }
+
+  if (type == HSA_DEVICE_TYPE_GPU) {
+    size_t current_index = *index_ptr;
+    agents_ptr[current_index] = agent;
+    *index_ptr = current_index + 1;
+  }
+  return HSA_STATUS_SUCCESS;
+}
+
+// Initializes the HSA system.
+iree_status_t iree_hal_hsa_init(iree_hal_hsa_driver_t* driver) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t status =
+      IREE_HSA_RESULT_TO_STATUS(&driver->hsa_symbols, hsa_init(), "hsa_init");
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Deinitializes the HSA system.
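+// The HSA runtime is reference-counted: each hsa_init() call must be balanced
+// by a matching hsa_shut_down() before the runtime actually unloads.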
+static iree_status_t iree_hal_hsa_shut_down(iree_hal_hsa_driver_t* driver) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_status_t status = IREE_HSA_RESULT_TO_STATUS(
+      &driver->hsa_symbols, hsa_shut_down(), "hsa_shut_down");
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_driver_create_internal(
+    iree_string_view_t identifier, const iree_hal_hsa_driver_options_t* options,
+    const iree_hal_hsa_device_params_t* device_params,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  iree_hal_hsa_driver_t* driver = NULL;
+  iree_host_size_t total_size = iree_sizeof_struct(*driver) + identifier.size;
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void**)&driver));
+
+  iree_hal_resource_initialize(&iree_hal_hsa_driver_vtable, &driver->resource);
+  driver->host_allocator = host_allocator;
+  iree_string_view_append_to_buffer(
+      identifier, &driver->identifier,
+      (char*)driver + iree_sizeof_struct(*driver));
+  driver->default_device_index = options->default_device_index;
+  memcpy(&driver->device_params, device_params, sizeof(driver->device_params));
+  driver->num_gpu_agents = 0;
+
+  // Each step below depends on the previous one having succeeded so chain the
+  // status checks; all failures fall through to the shared release at the end.
+  iree_status_t status = iree_hal_hsa_dynamic_symbols_initialize(
+      host_allocator, &driver->hsa_symbols);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_hsa_init(driver);
+  }
+
+  // Populate HSA agents.
+  // Query the number of available HSA devices.
+  if (iree_status_is_ok(status)) {
+    iree_hal_hsa_callback_package_t symbols_and_device_count = {
+        .driver = driver, .return_value = &driver->num_gpu_agents};
+    status = IREE_HSA_RESULT_TO_STATUS(
+        &driver->hsa_symbols,
+        hsa_iterate_agents(&iterate_count_gpu_agents_callback,
+                           &symbols_and_device_count),
+        "hsa_iterate_agents");
+  }
+
+  if (iree_status_is_ok(status)) {
+    size_t agent_index = 0;
+    iree_hal_hsa_callback_package_t symbols_and_agents = {
+        .driver = driver, .index = &agent_index, .return_value = driver->agents};
+    status = IREE_HSA_RESULT_TO_STATUS(
+        &driver->hsa_symbols,
+        hsa_iterate_agents(&iterate_populate_gpu_agents_callback,
+                           &symbols_and_agents),
+        "hsa_iterate_agents");
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_driver = (iree_hal_driver_t*)driver;
+  } else {
+    iree_hal_driver_release((iree_hal_driver_t*)driver);
+  }
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_hsa_driver_create(
+    iree_string_view_t identifier, const iree_hal_hsa_driver_options_t* options,
+    const iree_hal_hsa_device_params_t* device_params,
+    iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(options);
+  IREE_ASSERT_ARGUMENT(device_params);
+  IREE_ASSERT_ARGUMENT(out_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_hal_hsa_driver_create_internal(
+      identifier, options, device_params, host_allocator, out_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_hsa_driver_destroy(iree_hal_driver_t* base_driver) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+
+  iree_hal_hsa_driver_t* driver = iree_hal_hsa_driver_cast(base_driver);
+  iree_allocator_t host_allocator = driver->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // TODO: shut down HSA and deinitialize the dynamic symbols once teardown
+  // ordering with outstanding device objects is resolved; both are currently
+  // skipped.
+  // iree_hal_hsa_shut_down(driver);
+  // iree_hal_hsa_dynamic_symbols_deinitialize(&driver->hsa_symbols);
+
+  iree_allocator_free(host_allocator, driver);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Translates an hsa_agent_t into the corresponding IREE device ID.
+static iree_hal_device_id_t iree_hsadevice_to_device_id(
+    iree_hal_hsa_driver_t* driver, hsa_agent_t agent) {
+  // Device IDs are 1-based so that 0 can remain IREE_HAL_DEVICE_ID_DEFAULT;
+  // if no agent matches this returns IREE_HAL_HSA_DEVICE_NOT_FOUND.
+  iree_hal_device_id_t device_id = 0;
+  while (device_id != IREE_HAL_HSA_MAX_DEVICES &&
+         driver->agents[device_id++].handle !=
agent.handle) + ; + + return device_id; +} + +static hsa_agent_t iree_device_id_to_hsadevice(iree_hal_hsa_driver_t* driver, + iree_hal_device_id_t device_id) { + return driver->agents[device_id]; +} + +static iree_status_t get_hsa_agent_uuid(iree_hal_hsa_dynamic_symbols_t* syms, + hsa_agent_t agent, + char* out_device_uuid) { + // `HSA_AMD_AGENT_INFO_UUID` is part of the `hsa_amd_agent_info_t` + // However, hsa_agent_get_info expects a hsa_agent_info_t. + hsa_agent_info_t uuid_info = (int)HSA_AMD_AGENT_INFO_UUID; + IREE_HSA_RETURN_IF_ERROR( + syms, hsa_agent_get_info(agent, uuid_info, out_device_uuid), + "hsa_agent_get_info"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_populate_device_info( + iree_hal_hsa_driver_t* driver, hsa_agent_t agent, + iree_hal_hsa_dynamic_symbols_t* syms, uint8_t* buffer_ptr, + uint8_t** out_buffer_ptr, iree_hal_device_info_t* out_device_info) { + *out_buffer_ptr = buffer_ptr; + + char device_name[IREE_HAL_HSA_MAX_DEVICE_NAME_LENGTH]; + + IREE_HSA_RETURN_IF_ERROR( + syms, hsa_agent_get_info(agent, HSA_AGENT_INFO_NAME, device_name), + "hsa_agent_get_info"); + memset(out_device_info, 0, sizeof(*out_device_info)); + + out_device_info->device_id = iree_hsadevice_to_device_id(driver, agent); + + // Maximum UUID is 21 + char device_uuid[21] = {0}; + get_hsa_agent_uuid(syms, agent, device_uuid); + + // HSA UUID is already prefixed with GPU- + char device_path_str[4 + 36 + 1] = {0}; + snprintf(device_path_str, sizeof(device_path_str), + "%c%c%c-" + "%02x%02x%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x-" + "%02x%02x%02x%02x%02x%02x", + device_uuid[0], device_uuid[1], device_uuid[2], + (uint8_t)device_uuid[4], (uint8_t)device_uuid[5], + (uint8_t)device_uuid[6], (uint8_t)device_uuid[7], + (uint8_t)device_uuid[8], (uint8_t)device_uuid[9], + (uint8_t)device_uuid[10], (uint8_t)device_uuid[11], + (uint8_t)device_uuid[12], (uint8_t)device_uuid[13], + (uint8_t)device_uuid[14], (uint8_t)device_uuid[15], + (uint8_t)device_uuid[16], (uint8_t)device_uuid[17], + (uint8_t)device_uuid[18], (uint8_t)device_uuid[19]); + + buffer_ptr += iree_string_view_append_to_buffer( + iree_make_string_view(device_path_str, + IREE_ARRAYSIZE(device_path_str) - 1), + &out_device_info->path, (char*)buffer_ptr); + + iree_string_view_t device_name_str = + iree_make_string_view(device_name, strlen(device_name)); + buffer_ptr += iree_string_view_append_to_buffer( + device_name_str, &out_device_info->name, (char*)buffer_ptr); + + *out_buffer_ptr = buffer_ptr; + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_driver_query_available_devices( + iree_hal_driver_t* base_driver, iree_allocator_t host_allocator, + iree_host_size_t* out_device_info_count, + iree_hal_device_info_t** out_device_infos) { + IREE_ASSERT_ARGUMENT(base_driver); + IREE_ASSERT_ARGUMENT(out_device_info_count); + IREE_ASSERT_ARGUMENT(out_device_infos); + iree_hal_hsa_driver_t* driver = iree_hal_hsa_driver_cast(base_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + // Ensure HSA is initialized before querying it. + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_hsa_init(driver)); + + int device_count = driver->num_gpu_agents; + + // Allocate the return infos and populate with the devices. 
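+  // Single allocation: the device_info structs come first, followed by the
+  // string table that backs each info's path and name views.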
+  iree_hal_device_info_t* device_infos = NULL;
+  // Reserve string-table space for both the path and the name of each device.
+  iree_host_size_t total_size =
+      device_count * (sizeof(iree_hal_device_info_t) +
+                      2 * IREE_HAL_HSA_MAX_DEVICE_NAME_LENGTH * sizeof(char));
+
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
+
+  hsa_agent_t* agents = driver->agents;
+
+  int valid_device_count = 0;
+  if (iree_status_is_ok(status)) {
+    uint8_t* buffer_ptr =
+        (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t);
+    for (iree_host_size_t i = 0; i < device_count; ++i) {
+      hsa_agent_t device = agents[i];
+
+      status = iree_hal_hsa_populate_device_info(
+          driver, device, &driver->hsa_symbols, buffer_ptr, &buffer_ptr,
+          &device_infos[valid_device_count]);
+      if (!iree_status_is_ok(status)) break;
+      valid_device_count++;
+    }
+  }
+  if (iree_status_is_ok(status)) {
+    *out_device_info_count = valid_device_count;
+    *out_device_infos = device_infos;
+  } else {
+    iree_allocator_free(host_allocator, device_infos);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_driver_dump_device_info(
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+    iree_string_builder_t* builder) {
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hsa_driver_select_default_device(
+    iree_hal_driver_t* base_driver, iree_hal_hsa_dynamic_symbols_t* syms,
+    int default_device_index, iree_allocator_t host_allocator,
+    hsa_agent_t* out_device) {
+  iree_hal_device_info_t* device_infos = NULL;
+  iree_host_size_t device_count = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_hsa_driver_query_available_devices(
+      base_driver, host_allocator, &device_count, &device_infos));
+
+  iree_hal_hsa_driver_t* driver = iree_hal_hsa_driver_cast(base_driver);
+
+  iree_status_t status = iree_ok_status();
+  if (device_count == 0) {
+    status = iree_make_status(IREE_STATUS_UNAVAILABLE,
+                              "no compatible HSA devices were found");
+  } else if (default_device_index >= device_count) {
+    status = iree_make_status(IREE_STATUS_NOT_FOUND,
+                              "default device %d not found (of %" PRIhsz
+                              " enumerated)",
+                              default_device_index, device_count);
+  } else {
+    *out_device = iree_device_id_to_hsadevice(driver, default_device_index);
+  }
+  iree_allocator_free(host_allocator, device_infos);
+
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_driver_create_device_by_id(
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
+    iree_host_size_t param_count, const iree_string_pair_t* params,
+    iree_allocator_t host_allocator, iree_hal_device_t** out_device) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+  IREE_ASSERT_ARGUMENT(out_device);
+
+  iree_hal_hsa_driver_t* driver = iree_hal_hsa_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Ensure HSA is initialized before querying it.
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_hsa_init(driver));
+
+  // Use either the specified device (enumerated earlier) or whatever default
+  // one was specified when the driver was created.
+  hsa_agent_t agent;
+  if (device_id == IREE_HAL_DEVICE_ID_DEFAULT) {
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_hsa_driver_select_default_device(
+                base_driver, &driver->hsa_symbols, driver->default_device_index,
+                host_allocator, &agent));
+  } else {
+    agent = iree_device_id_to_hsadevice(driver,
+                                        IREE_DEVICE_ID_TO_HSADEVICE(device_id));
+  }
+
+  iree_string_view_t device_name = iree_make_cstring_view("hip");
+
+  // Attempt to create the device now.
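+  // NOTE: the identifier above is what clients see when enumerating the
+  // device; it is currently a fixed string rather than derived from the
+  // agent name.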
+  iree_status_t status = iree_hal_hsa_device_create(
+      base_driver, device_name, &driver->device_params, &driver->hsa_symbols,
+      agent, host_allocator, out_device);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_driver_create_device_by_uuid(
+    iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
+    char* device_uuid, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  iree_hal_hsa_driver_t* driver = iree_hal_hsa_driver_cast(base_driver);
+
+  // Ensure HSA is initialized before querying it.
+  IREE_RETURN_IF_ERROR(iree_hal_hsa_init(driver));
+
+  // HSA doesn't have an API to do this so we need to scan all devices to
+  // find the one with the matching UUID.
+  int device_count = driver->num_gpu_agents;
+
+  // Iterate over device info searching for the agent with the right UUID.
+  iree_status_t status = iree_ok_status();
+  bool found_device = false;
+  hsa_agent_t device = {0};
+  for (iree_host_size_t i = 0; i < device_count; ++i) {
+    // The UUID string is at most 21 characters ("GPU-", 16 hex digits, NUL).
+    char query_uuid[21] = {0};
+    status =
+        get_hsa_agent_uuid(&driver->hsa_symbols, driver->agents[i], query_uuid);
+    if (!iree_status_is_ok(status)) break;
+    // TODO: skip the "GPU-" prefix before parsing; as written the prefix
+    // characters are fed to the hex parser as well.
+    char query_uuid_stripped[16] = {0};
+    iree_string_view_t query_uuid_sv = iree_make_string_view(query_uuid, 21);
+    if (!iree_string_view_parse_hex_bytes(query_uuid_sv,
+                                          IREE_ARRAYSIZE(query_uuid_stripped),
+                                          (uint8_t*)query_uuid_stripped)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "invalid UUID: '%.*s'", (int)query_uuid_sv.size,
+                              query_uuid_sv.data);
+    }
+    if (memcmp(device_uuid, query_uuid_stripped, sizeof(query_uuid_stripped)) ==
+        0) {
+      // Record the matching agent before leaving the loop.
+      found_device = true;
+      device = driver->agents[i];
+      break;
+    }
+  }
+  if (!iree_status_is_ok(status)) return status;
+  if (!found_device) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "HSA device with UUID "
+                            "%02x%02x%02x%02x-"
+                            "%02x%02x-"
+                            "%02x%02x-"
+                            "%02x%02x-"
+                            "%02x%02x not found",
+                            (uint8_t)device_uuid[4], (uint8_t)device_uuid[5],
+                            (uint8_t)device_uuid[6], (uint8_t)device_uuid[7],
+                            (uint8_t)device_uuid[8], (uint8_t)device_uuid[9],
+                            (uint8_t)device_uuid[10], (uint8_t)device_uuid[11],
+                            (uint8_t)device_uuid[12], (uint8_t)device_uuid[13],
+                            (uint8_t)device_uuid[14], (uint8_t)device_uuid[15]);
+  }
+
+  iree_string_view_t device_name = iree_make_cstring_view("hip");
+
+  // Attempt to create the device now.
+  status = iree_hal_hsa_device_create(
+      base_driver, device_name, &driver->device_params, &driver->hsa_symbols,
+      device, host_allocator, out_device);
+
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_driver_create_device_by_index(
+    iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
+    int device_index, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  iree_hal_hsa_driver_t* driver = iree_hal_hsa_driver_cast(base_driver);
+
+  // Ensure HSA is initialized before querying it.
+  IREE_RETURN_IF_ERROR(iree_hal_hsa_init(driver));
+
+  // Query the number of available HSA devices.
+  int device_count = driver->num_gpu_agents;
+  if (device_index >= device_count) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "device %d not found (of %d enumerated)",
+                            device_index, device_count);
+  }
+
+  hsa_agent_t device = driver->agents[device_index];
+
+  iree_string_view_t device_name = iree_make_cstring_view("hip");
+
+  // Attempt to create the device now.
+  iree_status_t status = iree_hal_hsa_device_create(
+      base_driver, device_name, &driver->device_params, &driver->hsa_symbols,
+      device, host_allocator, out_device);
+
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_driver_create_device_by_path(
+    iree_hal_driver_t* base_driver, iree_string_view_t driver_name,
+    iree_string_view_t device_path, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+  IREE_ASSERT_ARGUMENT(out_device);
+
+  if (iree_string_view_is_empty(device_path)) {
+    return iree_hal_hsa_driver_create_device_by_id(
+        base_driver, IREE_HAL_DEVICE_ID_DEFAULT, param_count, params,
+        host_allocator, out_device);
+  }
+
+  bool found = iree_string_view_consume_prefix(&device_path, IREE_SV("GPU-"));
+
+  if (found) {
+    char device_uuid[16];
+    if (!iree_string_view_parse_hex_bytes(
+            device_path, IREE_ARRAYSIZE(device_uuid), (uint8_t*)device_uuid)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "invalid UUID: '%.*s'", (int)device_path.size,
+                              device_path.data);
+    }
+    return iree_hal_hsa_driver_create_device_by_uuid(
+        base_driver, driver_name, device_uuid, param_count, params,
+        host_allocator, out_device);
+  }
+
+  // Try to parse as a device index or device type.
+  int device_index = -1;
+
+  if (iree_string_view_consume_prefix(&device_path, IREE_SV("GPU")) ||
+      iree_string_view_consume_prefix(&device_path, IREE_SV("gpu"))) {
+    // A bare "GPU"/"gpu" selects device 0; any trailing characters (such as
+    // "gpu1") fall through to the index parsing below.
+    if (iree_string_view_is_empty(device_path)) device_index = 0;
+  }
+
+  if (device_index != -1 ||
+      iree_string_view_atoi_int32(device_path, &device_index)) {
+    return iree_hal_hsa_driver_create_device_by_index(
+        base_driver, driver_name, device_index, param_count, params,
+        host_allocator, out_device);
+  }
+
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path");
+}
+
+static const iree_hal_driver_vtable_t iree_hal_hsa_driver_vtable = {
+    .destroy = iree_hal_hsa_driver_destroy,
+    .query_available_devices = iree_hal_hsa_driver_query_available_devices,
+    .dump_device_info = iree_hal_hsa_driver_dump_device_info,
+    .create_device_by_id = iree_hal_hsa_driver_create_device_by_id,
+    .create_device_by_path = iree_hal_hsa_driver_create_device_by_path,
+};
+
+#undef IREE_HAL_HSA_MAX_DEVICE_NAME_LENGTH
+#undef IREE_HAL_HSA_MAX_DEVICES
+#undef IREE_HAL_HSA_DEVICE_NOT_FOUND
diff --git a/runtime/src/iree-amd-aie/driver/hsa/hsa_headers.h b/runtime/src/iree-amd-aie/driver/hsa/hsa_headers.h
new file mode 100644
index 000000000..84005d414
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/hsa_headers.h
@@ -0,0 +1,18 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_HSA_HEADERS_H_
+#define IREE_EXPERIMENTAL_HSA_HSA_HEADERS_H_
+
+#if defined(IREE_PTR_SIZE_32)
+#error "32-bit not supported on HSA backend"
+#endif  // defined(IREE_PTR_SIZE_32)
+
+#include "hsa/hsa.h"
+#include "hsa/hsa_ext_amd.h"
+
+#endif  // IREE_EXPERIMENTAL_HSA_HSA_HEADERS_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/native_executable.c b/runtime/src/iree-amd-aie/driver/hsa/native_executable.c
new file mode 100644
index 000000000..db8e34eef
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/native_executable.c
@@ -0,0 +1,422 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/native_executable.h"
+
+#include <stddef.h>
+#include <string.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+#include "iree/base/api.h"
+
+// flatcc schemas:
+#include "iree/base/internal/flatcc/parsing.h"
+// Using the existing ROCM schema for now.
+#include "iree/schemas/rocm_executable_def_reader.h"
+#include "iree/schemas/rocm_executable_def_verifier.h"
+
+typedef struct iree_hal_hsa_native_executable_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+
+  iree_allocator_t host_allocator;
+
+  const iree_hal_hsa_dynamic_symbols_t* symbols;
+
+  hsa_executable_t executable;
+
+  uint64_t kernel_object;
+
+  iree_host_size_t entry_point_count;
+  // The list of entry point data, stored as a trailing inline allocation
+  // after the end of this struct.
+  iree_hal_hsa_kernel_info_t entry_points[];
+} iree_hal_hsa_native_executable_t;
+// + Additional inline allocation for holding entry point information.
+
+static const iree_hal_executable_vtable_t iree_hal_hsa_native_executable_vtable;
+
+static iree_hal_hsa_native_executable_t* iree_hal_hsa_native_executable_cast(
+    iree_hal_executable_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_native_executable_vtable);
+  return (iree_hal_hsa_native_executable_t*)base_value;
+}
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so
+// during runtime.
+//
+// There are still some conditions we must be aware of (such as omitted names
+// on functions with internal linkage), however we shouldn't need to bounds
+// check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_hsa_native_executable_flatbuffer_verify(
+    iree_const_byte_span_t flatbuffer_data) {
+  if (!flatbuffer_data.data) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "flatbuffer data is not present");
+  }
+
+  // Run flatcc generated verification. This ensures all pointers are in-bounds
+  // and that we can safely walk the file, but not that the actual contents of
+  // the flatbuffer meet our expectations.
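+  // Content-level checks follow below: entry point names, block sizes, and
+  // shared memory sizes must all be present and consistent with one another.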
+  int verify_ret = iree_hal_rocm_ExecutableDef_verify_as_root(
+      flatbuffer_data.data, flatbuffer_data.data_length);
+  if (verify_ret != flatcc_verify_ok) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "flatbuffer verification failed: %s",
+                            flatcc_verify_error_string(verify_ret));
+  }
+
+  iree_hal_rocm_ExecutableDef_table_t executable_def =
+      iree_hal_rocm_ExecutableDef_as_root(flatbuffer_data.data);
+
+  flatbuffers_string_vec_t entry_points_vec =
+      iree_hal_rocm_ExecutableDef_entry_points_get(executable_def);
+  size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+  if (entry_point_count == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "no entry points present");
+  }
+  for (size_t i = 0; i < entry_point_count; ++i) {
+    if (flatbuffers_string_len(
+            flatbuffers_string_vec_at(entry_points_vec, i)) == 0) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "executable entry point %zu has no name", i);
+    }
+  }
+
+  iree_hal_rocm_BlockSizeDef_vec_t block_sizes_vec =
+      iree_hal_rocm_ExecutableDef_block_sizes_get(executable_def);
+  size_t block_size_count = iree_hal_rocm_BlockSizeDef_vec_len(block_sizes_vec);
+  if (entry_point_count != block_size_count) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "entry points (%zu) and block sizes (%zu) count mismatch",
+        entry_point_count, block_size_count);
+  }
+
+  flatbuffers_uint32_vec_t shared_memory_sizes_vec =
+      iree_hal_rocm_ExecutableDef_shared_memory_sizes_get(executable_def);
+  size_t shared_memory_sizes_count =
+      flatbuffers_uint32_vec_len(shared_memory_sizes_vec);
+  if (entry_point_count != shared_memory_sizes_count) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "entry points (%zu) and shared memory sizes (%zu) count mismatch",
+        entry_point_count, shared_memory_sizes_count);
+  }
+
+  flatbuffers_string_t hsaco_image =
+      iree_hal_rocm_ExecutableDef_hsaco_image_get(executable_def);
+  if (flatbuffers_string_len(hsaco_image) == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "no HSACO image present");
+  }
+
+  return iree_ok_status();
+}
+
+typedef struct iree_hal_hsa_callback_package_t {
+  const iree_hal_hsa_dynamic_symbols_t* symbols;
+  unsigned int* return_value;
+} iree_hal_hsa_callback_package_t;
+
+static hsa_status_t get_lds_size_callback(hsa_amd_memory_pool_t memory_pool,
+                                          void* data) {
+  iree_hal_hsa_callback_package_t* package =
+      (iree_hal_hsa_callback_package_t*)(data);
+
+  hsa_amd_segment_t segment;
+  hsa_status_t status = package->symbols->hsa_amd_memory_pool_get_info(
+      memory_pool, HSA_AMD_MEMORY_POOL_INFO_SEGMENT, &segment);
+  if (status != HSA_STATUS_SUCCESS) {
+    return status;
+  }
+
+  if (segment == HSA_AMD_SEGMENT_GROUP) {
+    unsigned int size = 0;
+    status = package->symbols->hsa_amd_memory_pool_get_info(
+        memory_pool, HSA_AMD_MEMORY_POOL_INFO_SIZE, &size);
+    if (status != HSA_STATUS_SUCCESS) {
+      return status;
+    }
+    *package->return_value = size;
+  }
+  return HSA_STATUS_SUCCESS;
+}
+
+iree_status_t iree_hal_hsa_native_executable_create(
+    const iree_hal_hsa_dynamic_symbols_t* symbols, hsa_agent_t agent,
+    const iree_hal_executable_params_t* executable_params,
+    iree_allocator_t host_allocator, iree_hal_allocator_t* device_allocator,
+    iree_hal_executable_t** out_executable) {
+  IREE_ASSERT_ARGUMENT(device_allocator);
+  IREE_ASSERT_ARGUMENT(symbols);
+  IREE_ASSERT_ARGUMENT(executable_params);
+  IREE_ASSERT_ARGUMENT(out_executable);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  *out_executable = NULL;
+  iree_hal_hsa_native_executable_t* executable = NULL;
+
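+  // Verify the flatbuffer layout up front so the parsing below can assume
+  // all vectors and strings are present and in bounds.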
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hsa_native_executable_flatbuffer_verify(
+              executable_params->executable_data));
+
+  iree_hal_rocm_ExecutableDef_table_t executable_def =
+      iree_hal_rocm_ExecutableDef_as_root(
+          executable_params->executable_data.data);
+
+  flatbuffers_string_vec_t entry_points_vec =
+      iree_hal_rocm_ExecutableDef_entry_points_get(executable_def);
+  iree_hal_rocm_BlockSizeDef_vec_t block_sizes_vec =
+      iree_hal_rocm_ExecutableDef_block_sizes_get(executable_def);
+  flatbuffers_uint32_vec_t shared_memory_sizes_vec =
+      iree_hal_rocm_ExecutableDef_shared_memory_sizes_get(executable_def);
+  flatbuffers_string_t hsaco_image =
+      iree_hal_rocm_ExecutableDef_hsaco_image_get(executable_def);
+  iree_host_size_t entry_point_count =
+      flatbuffers_string_vec_len(entry_points_vec);
+
+  // Calculate the total number of characters across all entry point names.
+  // This is only required when tracing so that we can store copies of the
+  // names as the flatbuffer storing the strings may be released while the
+  // executable is still live.
+  iree_host_size_t total_entry_point_name_chars = 0;
+  IREE_TRACE({
+    for (iree_host_size_t i = 0; i < entry_point_count; i++) {
+      const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
+      total_entry_point_name_chars += flatbuffers_string_len(entry_name);
+    }
+  });
+
+  // Allocate storage for the kernel module.
+  iree_host_size_t total_size =
+      sizeof(*executable) +
+      entry_point_count * sizeof(executable->entry_points[0]) +
+      total_entry_point_name_chars;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_allocator_malloc(host_allocator, total_size, (void**)&executable));
+  IREE_TRACE(
+      char* string_table_buffer =
+          (char*)((char*)executable + sizeof(*executable) +
+                  entry_point_count * sizeof(executable->entry_points[0])));
+
+  iree_hal_resource_initialize(&iree_hal_hsa_native_executable_vtable,
+                               &executable->resource);
+
+  executable->host_allocator = host_allocator;
+  executable->symbols = symbols;
+  executable->entry_point_count = entry_point_count;
+
+  // From here on failures fall through to the shared status check at the
+  // bottom of the function so that the allocation and the trace zone are
+  // always released.
+  iree_status_t status = iree_ok_status();
+
+  // TODO: destroy |code_object_reader| once loading has completed; it is
+  // currently leaked.
+  hsa_code_object_reader_t code_object_reader = {0};
+  size_t hsaco_image_size = flatbuffers_string_len(hsaco_image);
+  status = IREE_HSA_RESULT_TO_STATUS(
+      symbols, hsa_code_object_reader_create_from_memory(
+                   hsaco_image, hsaco_image_size, &code_object_reader));
+
+  hsa_executable_t hsa_executable = {0};
+  if (iree_status_is_ok(status)) {
+    status = IREE_HSA_RESULT_TO_STATUS(
+        symbols, hsa_executable_create_alt(
+                     HSA_PROFILE_FULL, HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT,
+                     NULL, &hsa_executable));
+  }
+  if (iree_status_is_ok(status)) {
+    status = IREE_HSA_RESULT_TO_STATUS(
+        symbols, hsa_executable_load_agent_code_object(
+                     hsa_executable, agent, code_object_reader, NULL, NULL));
+  }
+  if (iree_status_is_ok(status)) {
+    status = IREE_HSA_RESULT_TO_STATUS(
+        symbols, hsa_executable_freeze(hsa_executable, NULL));
+  }
+  // Stash the handle (possibly null on failure) so destroy() can release it.
+  executable->executable = hsa_executable;
+
+  for (iree_host_size_t i = 0;
+       iree_status_is_ok(status) && i < entry_point_count; i++) {
+    const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
+
+    hsa_executable_symbol_t symbol;
+    status = IREE_HSA_RESULT_TO_STATUS(
+        symbols,
+        hsa_executable_get_symbol_by_name(hsa_executable, entry_name, &agent,
+                                          &symbol),
+        "hsa_executable_get_symbol_by_name");
+    if (!iree_status_is_ok(status)) {
+      // Retry with the ".kd" suffix used for kernel descriptor symbols.
+      iree_string_view_t name_view = iree_make_cstring_view(entry_name);
+      iree_string_view_t suffix_view =
iree_make_cstring_view(".kd"); + iree_host_size_t total_length = name_view.size + suffix_view.size; + char* kd_entry_name = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, total_length + 1, + (void**)&kd_entry_name)); + + iree_string_view_t result_view; + iree_host_size_t copied_length = iree_string_view_append_to_buffer( + name_view, &result_view, kd_entry_name); + iree_string_view_append_to_buffer(suffix_view, &result_view, + kd_entry_name + copied_length); + + kd_entry_name[total_length] = '\0'; + + status = IREE_HSA_RESULT_TO_STATUS( + symbols, hsa_executable_get_symbol_by_name( + hsa_executable, kd_entry_name, &agent, &symbol)); + if (!iree_status_is_ok(status)) break; + } + + uint64_t kernel_object; + status = IREE_HSA_RESULT_TO_STATUS( + symbols, + hsa_executable_symbol_get_info( + symbol, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, &kernel_object)); + + uint32_t private_segment_size; + status = IREE_HSA_RESULT_TO_STATUS( + symbols, + hsa_executable_symbol_get_info( + symbol, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, + &private_segment_size)); + if (!iree_status_is_ok(status)) break; + + uint32_t group_segment_size; + status = IREE_HSA_RESULT_TO_STATUS( + symbols, + hsa_executable_symbol_get_info( + symbol, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, + &group_segment_size)); + if (!iree_status_is_ok(status)) break; + + uint32_t kernarg_segment_size; + status = IREE_HSA_RESULT_TO_STATUS( + symbols, + hsa_executable_symbol_get_info( + symbol, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, + &kernarg_segment_size)); + if (!iree_status_is_ok(status)) break; + + uint32_t kernarg_segment_align; + status = IREE_HSA_RESULT_TO_STATUS( + symbols, + hsa_executable_symbol_get_info( + symbol, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_ALIGNMENT, + &kernarg_segment_align)); + if (!iree_status_is_ok(status)) break; + + unsigned int max_shared_memory; + iree_hal_hsa_callback_package_t lds_query_package = { + .symbols = symbols, .return_value = &max_shared_memory}; + status = IREE_HSA_RESULT_TO_STATUS( + symbols, hsa_amd_agent_iterate_memory_pools( + agent, get_lds_size_callback, &lds_query_package)); + + if (shared_memory_sizes_vec[i] > max_shared_memory) { + status = iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "function '%s' requested shared memory size of %u bytes larger " + "than allowed size of %u bytes", + entry_name, shared_memory_sizes_vec[i], max_shared_memory); + } + if (!iree_status_is_ok(status)) break; + + // Package required parameters for kernel launches for each entry point. + iree_hal_hsa_kernel_info_t* kernel_info = &executable->entry_points[i]; + kernel_info->layout = executable_params->pipeline_layouts[i]; + iree_hal_pipeline_layout_retain(kernel_info->layout); + kernel_info->kernel_object = kernel_object; + kernel_info->block_size[0] = block_sizes_vec[i].x; + kernel_info->block_size[1] = block_sizes_vec[i].y; + kernel_info->block_size[2] = block_sizes_vec[i].z; + kernel_info->shared_memory_size = shared_memory_sizes_vec[i]; + + kernel_info->private_segment_size = private_segment_size; + kernel_info->group_segment_size = group_segment_size; + kernel_info->kernarg_segment_size = kernarg_segment_size; + kernel_info->kernarg_segment_align = kernarg_segment_align; + + // Stash the entry point name in the string table for use when tracing. 
+    IREE_TRACE({
+      iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name);
+      memcpy(string_table_buffer, entry_name, entry_name_length);
+      kernel_info->function_name =
+          iree_make_string_view(string_table_buffer, entry_name_length);
+      string_table_buffer += entry_name_length;
+    });
+
+    IREE_TRACE({
+      if (iree_hal_rocm_ExecutableDef_source_locations_is_present(
+              executable_def)) {
+        iree_hal_rocm_FileLineLocDef_vec_t source_locs_vec =
+            iree_hal_rocm_ExecutableDef_source_locations_get(executable_def);
+        iree_hal_rocm_FileLineLocDef_table_t source_loc =
+            iree_hal_rocm_FileLineLocDef_vec_at(source_locs_vec, i);
+        flatbuffers_string_t filename =
+            iree_hal_rocm_FileLineLocDef_filename_get(source_loc);
+        uint32_t line = iree_hal_rocm_FileLineLocDef_line_get(source_loc);
+        kernel_info->source_filename =
+            iree_make_string_view(filename, flatbuffers_string_len(filename));
+        kernel_info->source_line = line;
+      }
+    });
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_executable = (iree_hal_executable_t*)executable;
+  } else {
+    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_hsa_native_executable_destroy(
+    iree_hal_executable_t* base_executable) {
+  iree_hal_hsa_native_executable_t* executable =
+      iree_hal_hsa_native_executable_cast(base_executable);
+  iree_allocator_t host_allocator = executable->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  for (iree_host_size_t i = 0; i < executable->entry_point_count; ++i) {
+    iree_hal_pipeline_layout_release(executable->entry_points[i].layout);
+  }
+  IREE_HSA_IGNORE_ERROR(executable->symbols,
+                        hsa_executable_destroy(executable->executable));
+  iree_allocator_free(host_allocator, executable);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_hsa_native_executable_entry_point_kernel_info(
+    iree_hal_executable_t* base_executable, int32_t entry_point,
+    iree_hal_hsa_kernel_info_t* out_info) {
+  iree_hal_hsa_native_executable_t* executable =
+      iree_hal_hsa_native_executable_cast(base_executable);
+  if (entry_point < 0 ||
+      (iree_host_size_t)entry_point >= executable->entry_point_count) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "entry point ordinal %d out of range; executable "
+                            "only contains %" PRIhsz " entry points",
+                            entry_point, executable->entry_point_count);
+  }
+  memcpy(out_info, &executable->entry_points[entry_point], sizeof(*out_info));
+  return iree_ok_status();
+}
+
+static const iree_hal_executable_vtable_t
+    iree_hal_hsa_native_executable_vtable = {
+        .destroy = iree_hal_hsa_native_executable_destroy,
+};
diff --git a/runtime/src/iree-amd-aie/driver/hsa/native_executable.h b/runtime/src/iree-amd-aie/driver/hsa/native_executable.h
new file mode 100644
index 000000000..7921bfd90
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/native_executable.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_NATIVE_EXECUTABLE_H_
+#define IREE_EXPERIMENTAL_HSA_NATIVE_EXECUTABLE_H_
+
+#include <stdint.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/hsa_headers.h"
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct iree_hal_hsa_kernel_info_t {
+  iree_hal_pipeline_layout_t* layout;
+
+  uint64_t kernel_object;
+
+  uint32_t block_size[3];
+  uint32_t shared_memory_size;
+
+  uint32_t private_segment_size;
+  uint32_t group_segment_size;
+  uint32_t kernarg_segment_size;
+  uint32_t kernarg_segment_align;
+
+  IREE_TRACE(iree_string_view_t function_name;)
+  IREE_TRACE(iree_string_view_t source_filename;)
+  IREE_TRACE(uint32_t source_line;)
+} iree_hal_hsa_kernel_info_t;
+
+// Creates an IREE executable from a HSACO module. The module may contain
+// several kernels that can be extracted along with the associated block size.
+iree_status_t iree_hal_hsa_native_executable_create(
+    const iree_hal_hsa_dynamic_symbols_t* symbols, hsa_agent_t agent,
+    const iree_hal_executable_params_t* executable_params,
+    iree_allocator_t host_allocator, iree_hal_allocator_t* device_allocator,
+    iree_hal_executable_t** out_executable);
+
+// Returns the kernel launch parameters for the given |entry_point| in the
+// |executable|.
+iree_status_t iree_hal_hsa_native_executable_entry_point_kernel_info(
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_hsa_kernel_info_t* out_info);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_NATIVE_EXECUTABLE_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/nop_executable_cache.c b/runtime/src/iree-amd-aie/driver/hsa/nop_executable_cache.c
new file mode 100644
index 000000000..eecbd0be9
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/nop_executable_cache.c
@@ -0,0 +1,111 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/nop_executable_cache.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+
+#include "iree-amd-aie/driver/hsa/hsa_allocator.h"
+#include "iree-amd-aie/driver/hsa/native_executable.h"
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+typedef struct iree_hal_hsa_nop_executable_cache_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+ iree_hal_resource_t resource; + + iree_allocator_t host_allocator; + + iree_hal_allocator_t* device_allocator; + + const iree_hal_hsa_dynamic_symbols_t* symbols; + + hsa_agent_t agent; + +} iree_hal_hsa_nop_executable_cache_t; + +static const iree_hal_executable_cache_vtable_t + iree_hal_hsa_nop_executable_cache_vtable; + +static iree_hal_hsa_nop_executable_cache_t* +iree_hal_hsa_nop_executable_cache_cast( + iree_hal_executable_cache_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_nop_executable_cache_vtable); + return (iree_hal_hsa_nop_executable_cache_t*)base_value; +} + +iree_status_t iree_hal_hsa_nop_executable_cache_create( + iree_string_view_t identifier, + const iree_hal_hsa_dynamic_symbols_t* symbols, + hsa_agent_t agent, iree_allocator_t host_allocator, + iree_hal_allocator_t* device_allocator, + iree_hal_executable_cache_t** out_executable_cache) { + IREE_ASSERT_ARGUMENT(symbols); + IREE_ASSERT_ARGUMENT(out_executable_cache); + IREE_TRACE_ZONE_BEGIN(z0); + + *out_executable_cache = NULL; + iree_hal_hsa_nop_executable_cache_t* executable_cache = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*executable_cache), + (void**)&executable_cache)); + + iree_hal_resource_initialize(&iree_hal_hsa_nop_executable_cache_vtable, + &executable_cache->resource); + executable_cache->host_allocator = host_allocator; + executable_cache->device_allocator = device_allocator; + executable_cache->symbols = symbols; + executable_cache->agent = agent; + + *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_hsa_nop_executable_cache_destroy( + iree_hal_executable_cache_t* base_executable_cache) { + iree_hal_hsa_nop_executable_cache_t* executable_cache = + iree_hal_hsa_nop_executable_cache_cast(base_executable_cache); + iree_allocator_t host_allocator = executable_cache->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_cache); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_hsa_nop_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_caching_mode_t caching_mode, + iree_string_view_t executable_format) { + return iree_string_view_equal(executable_format, + iree_make_cstring_view("HSACO")); +} + +static iree_status_t iree_hal_hsa_nop_executable_cache_prepare_executable( + iree_hal_executable_cache_t* base_executable_cache, + const iree_hal_executable_params_t* executable_params, + iree_hal_executable_t** out_executable) { + iree_hal_hsa_nop_executable_cache_t* executable_cache = + iree_hal_hsa_nop_executable_cache_cast(base_executable_cache); + return iree_hal_hsa_native_executable_create( + executable_cache->symbols, executable_cache->agent, executable_params, + executable_cache->host_allocator, executable_cache->device_allocator, + out_executable); +} + +static const iree_hal_executable_cache_vtable_t + iree_hal_hsa_nop_executable_cache_vtable = { + .destroy = iree_hal_hsa_nop_executable_cache_destroy, + .can_prepare_format = + iree_hal_hsa_nop_executable_cache_can_prepare_format, + .prepare_executable = + iree_hal_hsa_nop_executable_cache_prepare_executable, +}; diff --git a/runtime/src/iree-amd-aie/driver/hsa/nop_executable_cache.h b/runtime/src/iree-amd-aie/driver/hsa/nop_executable_cache.h new file mode 100644 index 000000000..5710626aa --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/nop_executable_cache.h @@ -0,0 
+1,34 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_NOP_EXECUTABLE_CACHE_H_
+#define IREE_EXPERIMENTAL_HSA_NOP_EXECUTABLE_CACHE_H_
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/hsa_headers.h"
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a no-op executable cache that does not cache at all.
+// This is useful to isolate pipeline caching behavior and verify compilation
+// behavior.
+iree_status_t iree_hal_hsa_nop_executable_cache_create(
+    iree_string_view_t identifier,
+    const iree_hal_hsa_dynamic_symbols_t* symbols,
+    hsa_agent_t agent, iree_allocator_t host_allocator,
+    iree_hal_allocator_t* device_allocator,
+    iree_hal_executable_cache_t** out_executable_cache);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_NOP_EXECUTABLE_CACHE_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/pending_queue_actions.c b/runtime/src/iree-amd-aie/driver/hsa/pending_queue_actions.c
new file mode 100644
index 000000000..49a68f92f
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/pending_queue_actions.c
@@ -0,0 +1,955 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/pending_queue_actions.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/event_semaphore.h"
+#include "iree-amd-aie/driver/hsa/hsa_device.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/base/internal/atomic_slist.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/base/internal/threading.h"
+#include "iree/hal/api.h"
+#include "iree/hal/utils/deferred_command_buffer.h"
+#include "iree/hal/utils/resource_set.h"
+
+// The maximum number of hsa_signal_t objects a command buffer can wait on.
+#define IREE_HAL_HSA_MAX_WAIT_EVENT_COUNT 32
+
+//===----------------------------------------------------------------------===//
+// Queue action
+//===----------------------------------------------------------------------===//
+
+typedef enum iree_hal_hsa_queue_action_kind_e {
+  IREE_HAL_HSA_QUEUE_ACTION_TYPE_EXECUTION,
+  // TODO: Add support for queue alloca and dealloca.
+} iree_hal_hsa_queue_action_kind_t;
+
+typedef enum iree_hal_hsa_queue_action_state_e {
+  // The current action is alive: either waiting for execution or executing.
+  IREE_HAL_HSA_QUEUE_ACTION_STATE_ALIVE,
+  // The current action has finished execution and awaits destruction.
+  IREE_HAL_HSA_QUEUE_ACTION_STATE_ZOMBIE,
+} iree_hal_hsa_queue_action_state_t;
+
+// A pending queue action.
+//
+// Note that this struct does not have internal synchronization; it's expected
+// to work together with the pending action queue, which synchronizes accesses.
+typedef struct iree_hal_hsa_queue_action_t {
+  // Intrusive doubly-linked list next entry pointer.
+ struct iree_hal_hsa_queue_action_t* next; + // Intrusive doubly-linked list previous entry pointer. + struct iree_hal_hsa_queue_action_t* prev; + + // The owning pending actions queue. We use its allocators and pools. + // Retained to make sure it outlives the current action. + iree_hal_hsa_pending_queue_actions_t* owning_actions; + + // The current state of this action. When an action is initially created it + // will be alive and enqueued to wait for releasing to the GPU. After done + // execution, it will be flipped into zombie state and enqueued again for + // destruction. + iree_hal_hsa_queue_action_state_t state; + // The callback to run after completing this action and before freeing + // all resources. Can be NULL. + iree_hal_hsa_pending_action_cleanup_callback_t cleanup_callback; + // User data to pass into the callback. + void* callback_user_data; + + iree_hal_hsa_queue_action_kind_t kind; + union { + struct { + iree_host_size_t count; + iree_hal_command_buffer_t** ptr; + } command_buffers; + } payload; + + // The device from which to allocate HSA stream-based command buffers for + // applying deferred command buffers. + iree_hal_device_t* device; + + // The stream to launch main GPU workload. + hsa_queue_t* hsa_queue; + + // Resource set to retain all associated resources by the payload. + iree_hal_resource_set_t* resource_set; + + // Semaphore list to wait on for the payload to start on the GPU. + iree_hal_semaphore_list_t wait_semaphore_list; + // Semaphore list to signal after the payload completes on the GPU. + iree_hal_semaphore_list_t signal_semaphore_list; + + // Scratch fields for analyzing whether actions are ready to issue. + hsa_signal_t signals[IREE_HAL_HSA_MAX_WAIT_EVENT_COUNT]; + iree_host_size_t signal_count; + // Whether the current action is still not ready for releasing to the GPU. + bool is_pending; +} iree_hal_hsa_queue_action_t; + +//===----------------------------------------------------------------------===// +// Queue action list +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_hsa_queue_action_list_t { + iree_hal_hsa_queue_action_t* head; + iree_hal_hsa_queue_action_t* tail; +} iree_hal_hsa_queue_action_list_t; + +// Returns true if the action list is empty. +static inline bool iree_hal_hsa_queue_action_list_is_empty( + const iree_hal_hsa_queue_action_list_t* list) { + return list->head == NULL; +} + +// Pushes |action| on to the end of the given action |list|. +static void iree_hal_hsa_queue_action_list_push_back( + iree_hal_hsa_queue_action_list_t* list, + iree_hal_hsa_queue_action_t* action) { + if (list->tail) { + list->tail->next = action; + } else { + list->head = action; + } + action->next = NULL; + action->prev = list->tail; + list->tail = action; +} + +// Erases |action| from |list|. +static void iree_hal_hsa_queue_action_list_erase( + iree_hal_hsa_queue_action_list_t* list, + iree_hal_hsa_queue_action_t* action) { + iree_hal_hsa_queue_action_t* next = action->next; + iree_hal_hsa_queue_action_t* prev = action->prev; + if (prev) { + prev->next = next; + action->prev = NULL; + } else { + list->head = next; + } + if (next) { + next->prev = prev; + action->next = NULL; + } else { + list->tail = prev; + } +} + +// Takes all actions from |available_list| and moves them into |ready_list|. 
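+// |available_list| is left empty afterwards; ownership of every action in it
+// transfers to |ready_list|.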
+static void iree_hal_hsa_queue_action_list_take_all(
+    iree_hal_hsa_queue_action_list_t* available_list,
+    iree_hal_hsa_queue_action_list_t* ready_list) {
+  IREE_ASSERT_NE(available_list, ready_list);
+  ready_list->head = available_list->head;
+  ready_list->tail = available_list->tail;
+  available_list->head = NULL;
+  available_list->tail = NULL;
+}
+
+// Frees all actions in the given |list|.
+static void iree_hal_hsa_queue_action_list_free_actions(
+    iree_allocator_t host_allocator, iree_hal_hsa_queue_action_list_t* list) {
+  for (iree_hal_hsa_queue_action_t* action = list->head; action != NULL;) {
+    iree_hal_hsa_queue_action_t* next_action = action->next;
+    iree_allocator_free(host_allocator, action);
+    action = next_action;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Ready-list processing
+//===----------------------------------------------------------------------===//
+
+// Ready action atomic slist entry struct.
+typedef struct iree_hal_hsa_atomic_slist_entry_t {
+  iree_hal_hsa_queue_action_t* ready_list_head;
+  iree_atomic_slist_intrusive_ptr_t slist_next;
+} iree_hal_hsa_atomic_slist_entry_t;
+
+// Ready action atomic slist.
+IREE_TYPED_ATOMIC_SLIST_WRAPPER(iree_hal_hsa_ready_action,
+                                iree_hal_hsa_atomic_slist_entry_t,
+                                offsetof(iree_hal_hsa_atomic_slist_entry_t,
+                                         slist_next));
+
+// The ready-list processing worker's working/exiting state.
+//
+// States in the list have increasing priority--meaning ones appearing later
+// can normally overwrite ones appearing earlier without checking, but not the
+// reverse.
+typedef enum iree_hal_hsa_worker_state_e {
+  IREE_HAL_HSA_WORKER_STATE_IDLE_WAITING = 0,      // Worker to main thread
+  IREE_HAL_HSA_WORKER_STATE_WORKLOAD_PENDING = 1,  // Main to worker thread
+  IREE_HAL_HSA_WORKER_STATE_EXIT_REQUESTED = -1,   // Main to worker thread
+  IREE_HAL_HSA_WORKER_STATE_EXIT_COMMITTED = -2,   // Worker to main thread
+  IREE_HAL_HSA_WORKER_STATE_EXIT_ERROR = -3,       // Worker to main thread
+} iree_hal_hsa_worker_state_t;
+
+// The data structure needed by a ready-list processing worker thread to issue
+// ready actions to the GPU.
+//
+// This data structure is shared between the parent thread, which owns the
+// whole pending actions queue, and the worker thread; so proper
+// synchronization is needed to touch it from both sides.
+//
+// The parent thread should push a list of ready actions to ready_worklist,
+// update worker_state, and post state_notification accordingly.
+// The worker thread waits on the state_notification and checks worker_state,
+// and pops from the ready_worklist to process. The worker thread also
+// monitors worker_state and stops processing if requested by the parent
+// thread.
+typedef struct iree_hal_hsa_working_area_t {
+  // Notification from the parent thread to request worker state changes.
+  iree_notification_t state_notification;
+  // Notification to the parent thread to indicate the worker committed
+  // exiting.
+  iree_notification_t exit_notification;
+  iree_hal_hsa_ready_action_slist_t ready_worklist;  // atomic
+  iree_atomic_int32_t worker_state;                  // atomic
+  iree_atomic_intptr_t error_code;                   // atomic
+  // The number of actions that have been issued to the GPU but have not yet
+  // fully completed both execution and cleanup. We don't need this field to
+  // be atomic given it is modified only from the worker thread.
+  int32_t pending_action_count;
+  iree_allocator_t host_allocator;  // const
+} iree_hal_hsa_working_area_t;
+
+static void iree_hal_hsa_working_area_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_hsa_working_area_t* working_area) {
+  iree_notification_initialize(&working_area->state_notification);
+  iree_notification_initialize(&working_area->exit_notification);
+  iree_hal_hsa_ready_action_slist_initialize(&working_area->ready_worklist);
+  iree_atomic_store_int32(&working_area->worker_state,
+                          IREE_HAL_HSA_WORKER_STATE_IDLE_WAITING,
+                          iree_memory_order_release);
+  iree_atomic_store_int32(&working_area->error_code, IREE_STATUS_OK,
+                          iree_memory_order_release);
+  working_area->pending_action_count = 0;
+  working_area->host_allocator = host_allocator;
+}
+
+static void iree_hal_hsa_working_area_deinitialize(
+    iree_hal_hsa_working_area_t* working_area) {
+  iree_hal_hsa_ready_action_slist_deinitialize(&working_area->ready_worklist);
+  iree_notification_deinitialize(&working_area->exit_notification);
+  iree_notification_deinitialize(&working_area->state_notification);
+}
+
+// The main function for the ready-list processing worker thread.
+static int iree_hal_hsa_worker_execute(
+    iree_hal_hsa_working_area_t* working_area);
+
+//===----------------------------------------------------------------------===//
+// Pending queue actions
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_hsa_pending_queue_actions_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+
+  // The host allocator used for creating this struct and the actions in it.
+  iree_allocator_t host_allocator;
+  // The block pool to allocate resource sets from.
+  iree_arena_block_pool_t* block_pool;
+
+  // The symbols used to create and destroy hsa_signal_t objects.
+  const iree_hal_hsa_dynamic_symbols_t* symbols;
+
+  // Non-recursive mutex guarding access to the action list.
+  iree_slim_mutex_t action_mutex;
+
+  // The doubly-linked list of pending actions.
+  iree_hal_hsa_queue_action_list_t action_list IREE_GUARDED_BY(action_mutex);
+
+  // The worker thread that monitors incoming requests and issues ready
+  // actions to the GPU.
+  iree_thread_t* worker_thread;
+  // The worker's working area; the data exchange area with the parent thread.
+  iree_hal_hsa_working_area_t working_area;
+};
+
+static const iree_hal_resource_vtable_t
+    iree_hal_hsa_pending_queue_actions_vtable;
+
+iree_status_t iree_hal_hsa_pending_queue_actions_create(
+    const iree_hal_hsa_dynamic_symbols_t* symbols,
+    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+    iree_hal_hsa_pending_queue_actions_t** out_actions) {
+  IREE_ASSERT_ARGUMENT(symbols);
+  IREE_ASSERT_ARGUMENT(block_pool);
+  IREE_ASSERT_ARGUMENT(out_actions);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hsa_pending_queue_actions_t* actions = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*actions),
+                                (void**)&actions));
+  iree_hal_resource_initialize(&iree_hal_hsa_pending_queue_actions_vtable,
+                               &actions->resource);
+  actions->host_allocator = host_allocator;
+  actions->block_pool = block_pool;
+  actions->symbols = symbols;
+  iree_slim_mutex_initialize(&actions->action_mutex);
+  memset(&actions->action_list, 0, sizeof(actions->action_list));
+
+  // Initialize the working area for the ready-list processing worker.
+  iree_hal_hsa_working_area_t* working_area = &actions->working_area;
+  iree_hal_hsa_working_area_initialize(host_allocator, working_area);
+
+  // Create the ready-list processing worker itself.
+  iree_thread_create_params_t params;
+  memset(&params, 0, sizeof(params));
+  params.name = IREE_SV("deferque_worker");
+  params.create_suspended = false;
+  iree_status_t status = iree_thread_create(
+      (iree_thread_entry_t)iree_hal_hsa_worker_execute, working_area, params,
+      actions->host_allocator, &actions->worker_thread);
+
+  if (iree_status_is_ok(status)) {
+    *out_actions = actions;
+  } else {
+    iree_hal_hsa_pending_queue_actions_destroy((iree_hal_resource_t*)actions);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_hal_hsa_pending_queue_actions_t*
+iree_hal_hsa_pending_queue_actions_cast(iree_hal_resource_t* base_value) {
+  return (iree_hal_hsa_pending_queue_actions_t*)base_value;
+}
+
+static bool iree_hal_hsa_worker_committed_exiting(
+    iree_hal_hsa_working_area_t* working_area);
+
+void iree_hal_hsa_pending_queue_actions_destroy(
+    iree_hal_resource_t* base_actions) {
+  iree_hal_hsa_pending_queue_actions_t* actions =
+      iree_hal_hsa_pending_queue_actions_cast(base_actions);
+  iree_allocator_t host_allocator = actions->host_allocator;
+  iree_hal_hsa_working_area_t* working_area = &actions->working_area;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Request the worker to exit.
+  iree_hal_hsa_worker_state_t prev_state =
+      (iree_hal_hsa_worker_state_t)iree_atomic_exchange_int32(
+          &working_area->worker_state, IREE_HAL_HSA_WORKER_STATE_EXIT_REQUESTED,
+          iree_memory_order_acq_rel);
+  iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+
+  // Check potential exit states from the worker.
+  if (prev_state != IREE_HAL_HSA_WORKER_STATE_EXIT_ERROR) {
+    // Wait until the worker acknowledged exiting.
+    iree_notification_await(
+        &working_area->exit_notification,
+        (iree_condition_fn_t)iree_hal_hsa_worker_committed_exiting,
+        working_area, iree_infinite_timeout());
+  }
+
+  // Now we can delete worker related resources.
+  iree_thread_release(actions->worker_thread);
+  iree_hal_hsa_working_area_deinitialize(working_area);
+
+  iree_slim_mutex_deinitialize(&actions->action_mutex);
+  iree_hal_hsa_queue_action_list_free_actions(host_allocator,
+                                              &actions->action_list);
+  iree_allocator_free(host_allocator, actions);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static const iree_hal_resource_vtable_t
+    iree_hal_hsa_pending_queue_actions_vtable = {
+        .destroy = iree_hal_hsa_pending_queue_actions_destroy,
+};
+
+// Copies the given |in_list| to |out_list| to retain the command buffer list.
+static iree_status_t iree_hal_hsa_copy_command_buffer_list(
+    iree_host_size_t command_buffer_count,
+    iree_hal_command_buffer_t* const* in_list, iree_allocator_t host_allocator,
+    iree_hal_command_buffer_t*** out_list) {
+  *out_list = NULL;
+  if (!command_buffer_count) return iree_ok_status();
+
+  iree_host_size_t total_size = command_buffer_count * sizeof(*in_list);
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_malloc(host_allocator, total_size, (void**)out_list));
+  memcpy((void*)*out_list, in_list, total_size);
+  return iree_ok_status();
+}
+
+// Frees the command buffer list copied by
+// iree_hal_hsa_copy_command_buffer_list().
+static void iree_hal_hsa_free_command_buffer_list(
+    iree_allocator_t host_allocator,
+    iree_hal_command_buffer_t* const* command_buffer_list) {
+  iree_allocator_free(host_allocator, (void*)command_buffer_list);
+}
+
+// Copies the given |in_list| to |out_list| to retain the semaphore and value
+// list.
+static iree_status_t iree_hal_hsa_copy_semaphore_list(
+    iree_hal_semaphore_list_t in_list, iree_allocator_t host_allocator,
+    iree_hal_semaphore_list_t* out_list) {
+  memset(out_list, 0, sizeof(*out_list));
+  if (!in_list.count) return iree_ok_status();
+
+  out_list->count = in_list.count;
+  iree_host_size_t semaphore_size = in_list.count * sizeof(*in_list.semaphores);
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, semaphore_size,
+                                             (void**)&out_list->semaphores));
+  memcpy(out_list->semaphores, in_list.semaphores, semaphore_size);
+
+  iree_host_size_t value_size = in_list.count * sizeof(*in_list.payload_values);
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      host_allocator, value_size, (void**)&out_list->payload_values));
+  memcpy(out_list->payload_values, in_list.payload_values, value_size);
+  return iree_ok_status();
+}
+
+// Frees the semaphore and value list inside |semaphore_list|.
+static void iree_hal_hsa_free_semaphore_list(
+    iree_allocator_t host_allocator,
+    iree_hal_semaphore_list_t* semaphore_list) {
+  iree_allocator_free(host_allocator, semaphore_list->semaphores);
+  iree_allocator_free(host_allocator, semaphore_list->payload_values);
+}
+
+iree_status_t iree_hal_hsa_pending_queue_actions_enqueue_execution(
+    iree_hal_device_t* device, hsa_queue_t* dispatch_queue,
+    iree_hal_hsa_pending_queue_actions_t* actions,
+    iree_hal_hsa_pending_action_cleanup_callback_t cleanup_callback,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t command_buffer_count,
+    iree_hal_command_buffer_t* const* command_buffers) {
+  IREE_ASSERT_ARGUMENT(actions);
+  IREE_ASSERT_ARGUMENT(command_buffer_count == 0 || command_buffers);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hsa_queue_action_t* action = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(actions->host_allocator, sizeof(*action),
+                                (void**)&action));
+
+  action->owning_actions = actions;
+  action->state = IREE_HAL_HSA_QUEUE_ACTION_STATE_ALIVE;
+  action->cleanup_callback = cleanup_callback;
+  action->kind = IREE_HAL_HSA_QUEUE_ACTION_TYPE_EXECUTION;
+  action->device = device;
+
+  action->hsa_queue = dispatch_queue;
+
+  // Initialize scratch fields.
+  action->signal_count = 0;
+  action->is_pending = true;
+
+  // Retain all command buffers and semaphores.
+  iree_hal_resource_set_t* resource_set = NULL;
+  iree_status_t status =
+      iree_hal_resource_set_allocate(actions->block_pool, &resource_set);
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    status = iree_hal_resource_set_insert(resource_set, command_buffer_count,
+                                          command_buffers);
+  }
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    status =
+        iree_hal_resource_set_insert(resource_set, wait_semaphore_list.count,
+                                     wait_semaphore_list.semaphores);
+  }
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    status =
+        iree_hal_resource_set_insert(resource_set, signal_semaphore_list.count,
+                                     signal_semaphore_list.semaphores);
+  }
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    action->resource_set = resource_set;
+  }
+
+  // Copy the command buffer list for later access.
+  // TODO: avoid host allocator malloc; use some pool for the allocation.
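+  // (A pooled variant would presumably draw from the same
+  // iree_arena_block_pool_t already used for resource sets, trading the
+  // per-action malloc/free below for bulk reclamation; sketch only.)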
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    action->payload.command_buffers.count = command_buffer_count;
+    status = iree_hal_hsa_copy_command_buffer_list(
+        command_buffer_count, command_buffers, actions->host_allocator,
+        &action->payload.command_buffers.ptr);
+  }
+
+  // Copy the semaphore and value list for later access.
+  // TODO: avoid host allocator malloc; use some pool for the allocation.
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    status = iree_hal_hsa_copy_semaphore_list(wait_semaphore_list,
+                                              actions->host_allocator,
+                                              &action->wait_semaphore_list);
+  }
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    status = iree_hal_hsa_copy_semaphore_list(signal_semaphore_list,
+                                              actions->host_allocator,
+                                              &action->signal_semaphore_list);
+  }
+
+  if (IREE_LIKELY(iree_status_is_ok(status))) {
+    // Retain the owning pending actions queue to make sure it outlives this
+    // action.
+    iree_hal_resource_retain(actions);
+
+    // Now everything is okay and we can enqueue the action.
+    iree_slim_mutex_lock(&actions->action_mutex);
+    iree_hal_hsa_queue_action_list_push_back(&actions->action_list, action);
+    iree_slim_mutex_unlock(&actions->action_mutex);
+  } else {
+    iree_hal_hsa_free_semaphore_list(actions->host_allocator,
+                                     &action->wait_semaphore_list);
+    iree_hal_hsa_free_semaphore_list(actions->host_allocator,
+                                     &action->signal_semaphore_list);
+    iree_hal_hsa_free_command_buffer_list(actions->host_allocator,
+                                          action->payload.command_buffers.ptr);
+    iree_hal_resource_set_free(resource_set);
+    iree_allocator_free(actions->host_allocator, action);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Releases resources after action completion on the GPU and advances timeline
+// and pending actions queue.
+//
+// This is the HSA host function callback to hsa_amd_signal_async_handler(),
+// invoked by an HSA driver thread. Note that code in this function MUST NOT
+// invoke any GPU API under the hood to avoid potential deadlock.
+static bool iree_hal_hsa_execution_device_signal_host_callback(
+    hsa_signal_value_t IREE_ATTRIBUTE_UNUSED value, void* user_data) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+  iree_hal_hsa_queue_action_t* action = (iree_hal_hsa_queue_action_t*)user_data;
+  IREE_ASSERT_EQ(action->kind, IREE_HAL_HSA_QUEUE_ACTION_TYPE_EXECUTION);
+  IREE_ASSERT_EQ(action->state, IREE_HAL_HSA_QUEUE_ACTION_STATE_ALIVE);
+  iree_hal_hsa_pending_queue_actions_t* actions = action->owning_actions;
+
+  // Flip the action state to zombie and enqueue it again so that we can let
+  // the worker thread clean it up. Note that this is necessary because
+  // cleanup may involve GPU API calls like buffer releasing or unregistering,
+  // so we cannot inline it here.
+  action->state = IREE_HAL_HSA_QUEUE_ACTION_STATE_ZOMBIE;
+  iree_slim_mutex_lock(&actions->action_mutex);
+  iree_hal_hsa_queue_action_list_push_back(&actions->action_list, action);
+  iree_slim_mutex_unlock(&actions->action_mutex);
+
+  // Notify the worker thread again that we have the cleanup action enqueued.
+  // Only overwrite the idle waiting state, which has lower priority.
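+  //
+  // For orientation, the full life of an action is (illustrative summary of
+  // the code in this file, not an additional mechanism):
+  //   enqueue_execution(): allocated, state = ALIVE, appended to action_list
+  //   issue():             wait timepoints resolved, handed to the worker
+  //   worker:              issue_execution() writes barrier packets and
+  //                        registers this async handler
+  //   this callback:       state = ZOMBIE, re-appended to action_list
+  //   issue() + worker:    issue_cleanup() runs the callback and frees it.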
+  iree_hal_hsa_worker_state_t prev_state =
+      IREE_HAL_HSA_WORKER_STATE_IDLE_WAITING;
+  iree_atomic_compare_exchange_strong_int32(
+      &actions->working_area.worker_state, /*expected=*/&prev_state,
+      /*desired=*/IREE_HAL_HSA_WORKER_STATE_WORKLOAD_PENDING,
+      /*order_succ=*/iree_memory_order_acq_rel,
+      /*order_fail=*/iree_memory_order_acquire);
+  iree_notification_post(&actions->working_area.state_notification,
+                         IREE_ALL_WAITERS);
+
+  // Advance semaphore timelines by calling into the host signaling function.
+  // This will internally try to release more workload to the GPU.
+  IREE_IGNORE_ERROR(
+      iree_hal_semaphore_list_signal(action->signal_semaphore_list));
+
+  IREE_TRACE_ZONE_END(z0);
+
+  return false;
+}
+
+// Issues the given kernel dispatch |action| to the GPU.
+static iree_status_t iree_hal_hsa_pending_queue_actions_issue_execution(
+    iree_hal_hsa_queue_action_t* action) {
+  IREE_ASSERT_EQ(action->kind, IREE_HAL_HSA_QUEUE_ACTION_TYPE_EXECUTION);
+  IREE_ASSERT_EQ(action->is_pending, false);
+  const iree_hal_hsa_dynamic_symbols_t* symbols =
+      action->owning_actions->symbols;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // No need to lock given that this action is already detached from the
+  // pending actions list; so only this thread is seeing it now.
+
+  // First wait on all the device hsa_signal_t objects for the dispatch queue.
+  for (iree_host_size_t i = 0; i < action->signal_count; ++i) {
+    symbols->hsa_signal_wait_scacquire(action->signals[i],
+                                       HSA_SIGNAL_CONDITION_EQ, 0, UINT64_MAX,
+                                       HSA_WAIT_STATE_BLOCKED);
+  }
+
+  // Then launch all command buffers to the dispatch queue.
+  for (iree_host_size_t i = 0; i < action->payload.command_buffers.count; ++i) {
+    iree_hal_command_buffer_t* command_buffer =
+        action->payload.command_buffers.ptr[i];
+    iree_hal_command_buffer_t* queue_command_buffer = NULL;
+    iree_hal_command_buffer_mode_t mode =
+        IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
+        IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION |
+        IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_hsa_device_create_queue_command_buffer(
+                action->device, mode, IREE_HAL_COMMAND_CATEGORY_ANY,
+                /*binding_capacity=*/0, &queue_command_buffer));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_resource_set_insert(action->resource_set, 1,
+                                         &queue_command_buffer));
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_hal_deferred_command_buffer_apply(
+                command_buffer, queue_command_buffer,
+                iree_hal_buffer_binding_table_empty()));
+  }
+
+  // Increase the pending action counter. We decrease it once it fully
+  // completes and gets cleaned up.
+  ++action->owning_actions->working_area.pending_action_count;
+
+  // Finally, record hsa_signal_t signals in the dispatch queue.
+  // Zero-initialized so the handle is deterministic even if the signal
+  // semaphore list is empty.
+  hsa_signal_t completion_signal = {0};
+  for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) {
+    // Grab an hsa_signal_t for this semaphore value signaling.
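+    // The packet submission below follows the standard HSA user-mode queue
+    // protocol (sketched here for orientation, not an additional step):
+    //   1. reserve a slot:  write_index = hsa_queue_add_write_index_relaxed()
+    //   2. write a barrier-AND packet at base_address[write_index & (size-1)]
+    //   3. publish the header last with a release store so the packet
+    //      processor never observes a partially initialized packet
+    //   4. ring the doorbell with the write index.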
+ hsa_signal_t signal; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_event_semaphore_acquire_timepoint_device_signal( + action->signal_semaphore_list.semaphores[i], + action->signal_semaphore_list.payload_values[i], &signal)); + symbols->hsa_signal_store_relaxed(signal, 1); + + uint64_t write_index = + symbols->hsa_queue_add_write_index_relaxed(action->hsa_queue, 1); + + size_t queue_mask = action->hsa_queue->size - 1; + + struct hsa_barrier_and_packet_s* barrier_packet = + (hsa_barrier_and_packet_t*)(action->hsa_queue->base_address) + + (write_index & queue_mask); + + memset((void*)barrier_packet, 0, sizeof(hsa_barrier_and_packet_t)); + + uint16_t packet_header = 0; + packet_header |= HSA_PACKET_TYPE_BARRIER_AND << HSA_PACKET_HEADER_TYPE; + packet_header |= HSA_FENCE_SCOPE_AGENT + << HSA_PACKET_HEADER_ACQUIRE_FENCE_SCOPE; + packet_header |= HSA_FENCE_SCOPE_AGENT + << HSA_PACKET_HEADER_RELEASE_FENCE_SCOPE; + packet_header |= 1 << HSA_PACKET_HEADER_BARRIER; + barrier_packet->completion_signal = signal; + + __atomic_store_n(&barrier_packet->header, packet_header, __ATOMIC_RELEASE); + + symbols->hsa_signal_store_screlease(action->hsa_queue->doorbell_signal, + write_index); + + completion_signal = signal; + } + + IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR( + z0, symbols, + hsa_amd_signal_async_handler( + completion_signal, HSA_SIGNAL_CONDITION_EQ, 0, + iree_hal_hsa_execution_device_signal_host_callback, action), + "hsa_amd_signal_async_handler"); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Performs the given cleanup |action| on the CPU. +static iree_status_t iree_hal_hsa_pending_queue_actions_issue_cleanup( + iree_hal_hsa_queue_action_t* action) { + iree_hal_hsa_pending_queue_actions_t* actions = action->owning_actions; + iree_allocator_t host_allocator = actions->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + // Call user provided callback before releasing any resource. + if (action->cleanup_callback) { + action->cleanup_callback(action->callback_user_data); + } + + // Only release resources after callbacks have been issued. + iree_hal_resource_set_free(action->resource_set); + iree_hal_hsa_free_semaphore_list(host_allocator, + &action->wait_semaphore_list); + iree_hal_hsa_free_semaphore_list(host_allocator, + &action->signal_semaphore_list); + + // Drop reference to the pending action queue given now we are done. + iree_hal_resource_release(actions); + + iree_allocator_free(host_allocator, action); + + // Now we fully executed and cleaned up this action. Decrease the pending + // action counter. + --actions->working_area.pending_action_count; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_status_t iree_hal_hsa_pending_queue_actions_issue( + iree_hal_hsa_pending_queue_actions_t* actions) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hsa_queue_action_list_t pending_list = {NULL, NULL}; + iree_hal_hsa_queue_action_list_t ready_list = {NULL, NULL}; + + iree_slim_mutex_lock(&actions->action_mutex); + + if (iree_hal_hsa_queue_action_list_is_empty(&actions->action_list)) { + iree_slim_mutex_unlock(&actions->action_mutex); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + // Scan through the list and categorize actions into pending and ready lists. 
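+  // Decision per action (a summary of the loop below, not extra logic):
+  //   - ZOMBIE actions need cleanup only and are always ready;
+  //   - ALIVE actions are ready once every wait semaphore either already
+  //     reached its payload value or handed us a device-waitable
+  //     hsa_signal_t (stored in action->signals for the worker to wait on).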
+  iree_status_t status = iree_ok_status();
+  iree_hal_hsa_queue_action_t* action = actions->action_list.head;
+  while (action) {
+    iree_hal_hsa_queue_action_t* next_action = action->next;
+    action->next = NULL;
+
+    iree_host_size_t semaphore_count = action->wait_semaphore_list.count;
+    iree_hal_semaphore_t** semaphores = action->wait_semaphore_list.semaphores;
+    uint64_t* values = action->wait_semaphore_list.payload_values;
+
+    action->signal_count = 0;
+    action->is_pending = false;
+
+    // Cleanup actions are immediately ready to release. Otherwise, look at
+    // all wait semaphores to make sure that they are either already ready or
+    // we can wait on a device event.
+    if (action->state == IREE_HAL_HSA_QUEUE_ACTION_STATE_ALIVE) {
+      for (iree_host_size_t i = 0; i < semaphore_count; ++i) {
+        // If this semaphore has already signaled past the desired value, we
+        // can just ignore it.
+        uint64_t value = 0;
+        status = iree_hal_semaphore_query(semaphores[i], &value);
+        if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+        if (value >= values[i]) continue;
+
+        // Try to acquire an hsa_signal_t from a device wait timepoint. If
+        // successful, we can use that hsa_signal_t to wait on the device.
+        // Otherwise, this action is still not ready.
+        hsa_signal_t signal;
+        status = iree_hal_hsa_event_semaphore_acquire_timepoint_device_wait(
+            semaphores[i], values[i], &signal);
+        if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+
+        if (IREE_UNLIKELY(action->signal_count >=
+                          IREE_HAL_HSA_MAX_WAIT_EVENT_COUNT)) {
+          status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+                                    "exceeded max wait hsa_signal_t limit");
+          break;
+        }
+        action->signals[action->signal_count++] = signal;
+      }
+    }
+
+    if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+
+    if (action->is_pending) {
+      iree_hal_hsa_queue_action_list_push_back(&pending_list, action);
+    } else {
+      iree_hal_hsa_queue_action_list_push_back(&ready_list, action);
+    }
+
+    action = next_action;
+  }
+
+  if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+    // Some error happened while processing the current action. Clear its
+    // scratch fields and put it back on the pending list so we don't leak it.
+    action->signal_count = 0;
+    action->is_pending = true;
+    iree_hal_hsa_queue_action_list_push_back(&pending_list, action);
+  }
+
+  // Preserve pending timepoints.
+  actions->action_list = pending_list;
+
+  iree_slim_mutex_unlock(&actions->action_mutex);
+
+  if (ready_list.head == NULL) {
+    // Nothing ready yet. Just return.
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  iree_hal_hsa_atomic_slist_entry_t* entry = NULL;
+  // TODO: avoid host allocator malloc; use some pool for the allocation.
+  if (iree_status_is_ok(status)) {
+    status = iree_allocator_malloc(actions->host_allocator, sizeof(*entry),
+                                   (void**)&entry);
+  }
+
+  if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+    // Release all actions in the ready list to avoid leaking.
+    iree_hal_hsa_queue_action_list_free_actions(actions->host_allocator,
+                                                &ready_list);
+    iree_allocator_free(actions->host_allocator, entry);
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Now push the ready list to the worker and have it issue the actions to
+  // the GPU.
+  entry->ready_list_head = ready_list.head;
+  iree_hal_hsa_ready_action_slist_push(&actions->working_area.ready_worklist,
+                                       entry);
+
+  // We can only overwrite the worker state if the previous state is idle
+  // waiting; we cannot overwrite exit-related states, so we need to perform
+  // an atomic compare-and-exchange here.
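+  //
+  // The intended transitions (illustrative; only the first row is performed
+  // by the CAS below, the others are the cases in which the CAS fails):
+  //   IDLE_WAITING (0)     -> WORKLOAD_PENDING (1)
+  //   WORKLOAD_PENDING (1) -> unchanged (work is already pending)
+  //   EXIT_* (< 0)         -> unchanged (higher priority, never overwritten)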
+  iree_hal_hsa_worker_state_t prev_state =
+      IREE_HAL_HSA_WORKER_STATE_IDLE_WAITING;
+  iree_atomic_compare_exchange_strong_int32(
+      &actions->working_area.worker_state, /*expected=*/&prev_state,
+      /*desired=*/IREE_HAL_HSA_WORKER_STATE_WORKLOAD_PENDING,
+      /*order_succ=*/iree_memory_order_acq_rel,
+      /*order_fail=*/iree_memory_order_acquire);
+  iree_notification_post(&actions->working_area.state_notification,
+                         IREE_ALL_WAITERS);
+
+  // Handle potential error cases from the worker thread.
+  if (prev_state == IREE_HAL_HSA_WORKER_STATE_EXIT_ERROR) {
+    iree_status_code_t code = iree_atomic_load_int32(
+        &actions->working_area.error_code, iree_memory_order_acquire);
+    status = iree_status_from_code(code);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Worker routines
+//===----------------------------------------------------------------------===//
+
+static bool iree_hal_hsa_worker_has_incoming_request(
+    iree_hal_hsa_working_area_t* working_area) {
+  iree_hal_hsa_worker_state_t value = iree_atomic_load_int32(
+      &working_area->worker_state, iree_memory_order_acquire);
+  // These are the only two possible states that are set from the main thread
+  // to the worker thread.
+  return value == IREE_HAL_HSA_WORKER_STATE_WORKLOAD_PENDING ||
+         value == IREE_HAL_HSA_WORKER_STATE_EXIT_REQUESTED;
+}
+
+static bool iree_hal_hsa_worker_committed_exiting(
+    iree_hal_hsa_working_area_t* working_area) {
+  return iree_atomic_load_int32(&working_area->worker_state,
+                                iree_memory_order_acquire) ==
+         IREE_HAL_HSA_WORKER_STATE_EXIT_COMMITTED;
+}
+
+// Processes all ready actions in the given |worklist|.
+static iree_status_t iree_hal_hsa_worker_process_ready_list(
+    iree_allocator_t host_allocator,
+    iree_hal_hsa_ready_action_slist_t* worklist) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_status_t status = iree_ok_status();
+  do {
+    iree_hal_hsa_atomic_slist_entry_t* entry =
+        iree_hal_hsa_ready_action_slist_pop(worklist);
+    if (!entry) break;
+
+    // Process the current batch of ready actions.
+    iree_hal_hsa_queue_action_t* action = entry->ready_list_head;
+    while (action) {
+      iree_hal_hsa_queue_action_t* next_action = action->next;
+      action->next = NULL;
+
+      switch (action->state) {
+        case IREE_HAL_HSA_QUEUE_ACTION_STATE_ALIVE:
+          status = iree_hal_hsa_pending_queue_actions_issue_execution(action);
+          if (iree_status_is_ok(status)) action->signal_count = 0;
+          break;
+        case IREE_HAL_HSA_QUEUE_ACTION_STATE_ZOMBIE:
+          status = iree_hal_hsa_pending_queue_actions_issue_cleanup(action);
+          break;
+      }
+      if (!iree_status_is_ok(status)) break;
+
+      action = next_action;
+    }
+
+    iree_allocator_free(host_allocator, entry);
+  } while (iree_status_is_ok(status));
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// The main function for the ready-list processing worker thread.
+static int iree_hal_hsa_worker_execute(
+    iree_hal_hsa_working_area_t* working_area) {
+  iree_hal_hsa_ready_action_slist_t* worklist = &working_area->ready_worklist;
+
+  while (true) {
+    // Block waiting for incoming requests.
+    iree_notification_await(
+        &working_area->state_notification,
+        (iree_condition_fn_t)iree_hal_hsa_worker_has_incoming_request,
+        working_area, iree_infinite_timeout());
+
+    // Immediately flip the state to idle waiting if and only if the previous
+    // state is workload pending.
+    // We do this before processing the ready list to make sure that we don't
+    // accidentally ignore a new workload pushed after we finish processing
+    // the ready list but before we overwrite the state from this worker
+    // thread. Also we don't want to overwrite other exit states, so we need
+    // to perform an atomic compare-and-exchange here.
+    iree_hal_hsa_worker_state_t prev_state =
+        IREE_HAL_HSA_WORKER_STATE_WORKLOAD_PENDING;
+    iree_atomic_compare_exchange_strong_int32(
+        &working_area->worker_state, /*expected=*/&prev_state,
+        /*desired=*/IREE_HAL_HSA_WORKER_STATE_IDLE_WAITING,
+        /*order_succ=*/iree_memory_order_acq_rel,
+        /*order_fail=*/iree_memory_order_acquire);
+
+    // Check if we received a request to stop processing and exit this thread.
+    bool should_exit = iree_atomic_load_int32(&working_area->worker_state,
+                                              iree_memory_order_acquire) ==
+                       IREE_HAL_HSA_WORKER_STATE_EXIT_REQUESTED;
+
+    // Process the ready list. We still want to do this even when requested to
+    // exit.
+    iree_status_t status = iree_hal_hsa_worker_process_ready_list(
+        working_area->host_allocator, worklist);
+    if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+      IREE_ASSERT(false && "error when processing ready list");
+      iree_atomic_store_int32(&working_area->error_code,
+                              iree_status_code(status),
+                              iree_memory_order_release);
+      // This state has the highest priority so just overwrite.
+      iree_atomic_store_int32(&working_area->worker_state,
+                              IREE_HAL_HSA_WORKER_STATE_EXIT_ERROR,
+                              iree_memory_order_release);
+      iree_notification_post(&working_area->exit_notification,
+                             IREE_ALL_WAITERS);
+      return -1;
+    }
+
+    if (should_exit && working_area->pending_action_count == 0) {
+      // Signal that this thread is committed to exit. This state's priority
+      // is only lower than the error exit state, which we just checked above,
+      // so we can simply overwrite it.
+      iree_atomic_store_int32(&working_area->worker_state,
+                              IREE_HAL_HSA_WORKER_STATE_EXIT_COMMITTED,
+                              iree_memory_order_release);
+      iree_notification_post(&working_area->exit_notification,
+                             IREE_ALL_WAITERS);
+      return 0;
+    }
+  }
+  return 0;
+}
diff --git a/runtime/src/iree-amd-aie/driver/hsa/pending_queue_actions.h b/runtime/src/iree-amd-aie/driver/hsa/pending_queue_actions.h
new file mode 100644
index 000000000..54f274fce
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/pending_queue_actions.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_PENDING_QUEUE_ACTIONS_H_
+#define IREE_EXPERIMENTAL_HSA_PENDING_QUEUE_ACTIONS_H_
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// A data structure to manage pending queue actions.
+typedef struct iree_hal_hsa_pending_queue_actions_t
+    iree_hal_hsa_pending_queue_actions_t;
+
+// Creates a pending actions queue.
+iree_status_t iree_hal_hsa_pending_queue_actions_create(
+    const iree_hal_hsa_dynamic_symbols_t* symbols,
+    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+    iree_hal_hsa_pending_queue_actions_t** out_actions);
+
+// Destroys the pending |actions| queue.
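+//
+// Destruction is cooperative (summarizing the implementation, not adding to
+// it): the worker is asked to exit via the EXIT_REQUESTED state, the caller
+// waits on the exit notification (unless the worker already exited with an
+// error), and only then are the thread and any remaining actions released.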
+void iree_hal_hsa_pending_queue_actions_destroy(iree_hal_resource_t* actions);
+
+// Callback to execute user code after action completion but before resource
+// releasing.
+//
+// Data behind |user_data| must remain alive until the action is released.
+typedef void(IREE_API_PTR* iree_hal_hsa_pending_action_cleanup_callback_t)(
+    void* user_data);
+
+// Enqueues the given list of |command_buffers| that waits on
+// |wait_semaphore_list| and signals |signal_semaphore_list|.
+//
+// |cleanup_callback|, if not NULL, will run after the action completes but
+// before releasing all retained resources.
+iree_status_t iree_hal_hsa_pending_queue_actions_enqueue_execution(
+    iree_hal_device_t* device, hsa_queue_t* dispatch_queue,
+    iree_hal_hsa_pending_queue_actions_t* actions,
+    iree_hal_hsa_pending_action_cleanup_callback_t cleanup_callback,
+    const iree_hal_semaphore_list_t wait_semaphore_list,
+    const iree_hal_semaphore_list_t signal_semaphore_list,
+    iree_host_size_t command_buffer_count,
+    iree_hal_command_buffer_t* const* command_buffers);
+
+// Tries to scan the pending actions and release ready ones to the GPU.
+iree_status_t iree_hal_hsa_pending_queue_actions_issue(
+    iree_hal_hsa_pending_queue_actions_t* actions);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_PENDING_QUEUE_ACTIONS_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/pipeline_layout.c b/runtime/src/iree-amd-aie/driver/hsa/pipeline_layout.c
new file mode 100644
index 000000000..a000de4a0
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/pipeline_layout.c
@@ -0,0 +1,249 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/pipeline_layout.h"
+
+#include <stddef.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_hsa_descriptor_set_layout_t {
+  // Abstract resource used for injecting reference counting and vtable;
+  // must be at offset 0.
+  iree_hal_resource_t resource;
+
+  // The host allocator used for creating this descriptor set layout struct.
+  iree_allocator_t host_allocator;
+
+  // The total number of bindings in this descriptor set.
+ iree_host_size_t binding_count; +} iree_hal_hsa_descriptor_set_layout_t; + +static const iree_hal_descriptor_set_layout_vtable_t + iree_hal_hsa_descriptor_set_layout_vtable; + +static iree_hal_hsa_descriptor_set_layout_t* +iree_hal_hsa_descriptor_set_layout_cast( + iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_descriptor_set_layout_vtable); + return (iree_hal_hsa_descriptor_set_layout_t*)base_value; +} + +static const iree_hal_hsa_descriptor_set_layout_t* +iree_hal_hsa_descriptor_set_layout_const_cast( + const iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_descriptor_set_layout_vtable); + return (const iree_hal_hsa_descriptor_set_layout_t*)base_value; +} + +iree_status_t iree_hal_hsa_descriptor_set_layout_create( + iree_hal_descriptor_set_layout_flags_t flags, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_allocator_t host_allocator, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); + IREE_TRACE_ZONE_BEGIN(z0); + + *out_descriptor_set_layout = NULL; + + iree_hal_hsa_descriptor_set_layout_t* descriptor_set_layout = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*descriptor_set_layout), + (void**)&descriptor_set_layout)); + + iree_hal_resource_initialize(&iree_hal_hsa_descriptor_set_layout_vtable, + &descriptor_set_layout->resource); + descriptor_set_layout->host_allocator = host_allocator; + descriptor_set_layout->binding_count = binding_count; + *out_descriptor_set_layout = + (iree_hal_descriptor_set_layout_t*)descriptor_set_layout; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_host_size_t iree_hal_hsa_descriptor_set_layout_binding_count( + const iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + const iree_hal_hsa_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_hsa_descriptor_set_layout_const_cast(base_descriptor_set_layout); + return descriptor_set_layout->binding_count; +} + +static void iree_hal_hsa_descriptor_set_layout_destroy( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_hsa_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_hsa_descriptor_set_layout_cast(base_descriptor_set_layout); + iree_allocator_t host_allocator = descriptor_set_layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, descriptor_set_layout); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_descriptor_set_layout_vtable_t + iree_hal_hsa_descriptor_set_layout_vtable = { + .destroy = iree_hal_hsa_descriptor_set_layout_destroy, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_pipeline_layout_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_hsa_pipeline_layout_t { + // Abstract resource used for injecting reference counting and vtable; + // must be at offset 0. + iree_hal_resource_t resource; + + // The host allocator used for creating this pipeline layout struct. + iree_allocator_t host_allocator; + + // The kernel argument index for push constants. + // Note that push constants are placed after all normal descriptors. 
+  iree_host_size_t push_constant_base_index;
+  iree_host_size_t push_constant_count;
+
+  iree_host_size_t set_layout_count;
+  // The list of descriptor set layout pointers, pointing to trailing inline
+  // allocation after the end of this struct.
+  struct {
+    iree_hal_descriptor_set_layout_t* set_layout;
+    // Base kernel argument index for this descriptor set.
+    iree_host_size_t base_index;
+  } set_layouts[];
+} iree_hal_hsa_pipeline_layout_t;
+// + Additional inline allocation for holding all descriptor sets.
+
+static const iree_hal_pipeline_layout_vtable_t
+    iree_hal_hsa_pipeline_layout_vtable;
+
+static iree_hal_hsa_pipeline_layout_t* iree_hal_hsa_pipeline_layout_cast(
+    iree_hal_pipeline_layout_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_pipeline_layout_vtable);
+  return (iree_hal_hsa_pipeline_layout_t*)base_value;
+}
+
+static const iree_hal_hsa_pipeline_layout_t*
+iree_hal_hsa_pipeline_layout_const_cast(
+    const iree_hal_pipeline_layout_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_pipeline_layout_vtable);
+  return (const iree_hal_hsa_pipeline_layout_t*)base_value;
+}
+
+iree_status_t iree_hal_hsa_pipeline_layout_create(
+    iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t* const* set_layouts,
+    iree_host_size_t push_constant_count, iree_allocator_t host_allocator,
+    iree_hal_pipeline_layout_t** out_pipeline_layout) {
+  IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts);
+  IREE_ASSERT_ARGUMENT(out_pipeline_layout);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  *out_pipeline_layout = NULL;
+  if (push_constant_count > IREE_HAL_HSA_MAX_PUSH_CONSTANT_COUNT) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "push constant count %" PRIhsz " over the limit of %d",
+        push_constant_count, IREE_HAL_HSA_MAX_PUSH_CONSTANT_COUNT);
+  }
+
+  // Currently the pipeline layout doesn't do anything.
+  // TODO: Handle creating the argument layout at creation time, handling both
+  // push constants and buffers.
+  iree_hal_hsa_pipeline_layout_t* pipeline_layout = NULL;
+  iree_host_size_t total_size =
+      sizeof(*pipeline_layout) +
+      set_layout_count * sizeof(*pipeline_layout->set_layouts);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, total_size,
+                                (void**)&pipeline_layout));
+
+  iree_hal_resource_initialize(&iree_hal_hsa_pipeline_layout_vtable,
+                               &pipeline_layout->resource);
+  pipeline_layout->host_allocator = host_allocator;
+  pipeline_layout->set_layout_count = set_layout_count;
+  iree_host_size_t base_index = 0;
+  for (iree_host_size_t i = 0; i < set_layout_count; ++i) {
+    pipeline_layout->set_layouts[i].set_layout = set_layouts[i];
+    // Copy and retain all descriptor sets so we don't lose them.
+ iree_hal_descriptor_set_layout_retain(set_layouts[i]); + pipeline_layout->set_layouts[i].base_index = base_index; + base_index += + iree_hal_hsa_descriptor_set_layout_binding_count(set_layouts[i]); + } + pipeline_layout->push_constant_base_index = base_index; + pipeline_layout->push_constant_count = push_constant_count; + *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_hsa_pipeline_layout_destroy( + iree_hal_pipeline_layout_t* base_pipeline_layout) { + iree_hal_hsa_pipeline_layout_t* pipeline_layout = + iree_hal_hsa_pipeline_layout_cast(base_pipeline_layout); + iree_allocator_t host_allocator = pipeline_layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < pipeline_layout->set_layout_count; ++i) { + iree_hal_descriptor_set_layout_release( + pipeline_layout->set_layouts[i].set_layout); + } + iree_allocator_free(host_allocator, pipeline_layout); + + IREE_TRACE_ZONE_END(z0); +} + +const iree_hal_descriptor_set_layout_t* +iree_hal_hsa_pipeline_layout_descriptor_set_layout( + const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { + const iree_hal_hsa_pipeline_layout_t* pipeline_layout = + iree_hal_hsa_pipeline_layout_const_cast(base_pipeline_layout); + if (set < pipeline_layout->set_layout_count) { + return pipeline_layout->set_layouts[set].set_layout; + } + return NULL; +} + +iree_host_size_t iree_hal_hsa_pipeline_layout_base_binding_index( + const iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) { + const iree_hal_hsa_pipeline_layout_t* pipeline_layout = + iree_hal_hsa_pipeline_layout_const_cast(base_pipeline_layout); + return pipeline_layout->set_layouts[set].base_index; +} + +static const iree_hal_pipeline_layout_vtable_t + iree_hal_hsa_pipeline_layout_vtable = { + .destroy = iree_hal_hsa_pipeline_layout_destroy, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_hsa_dispatch_layout_t +//===----------------------------------------------------------------------===// + +iree_hal_hsa_dispatch_layout_t iree_hal_hsa_pipeline_layout_dispatch_layout( + const iree_hal_pipeline_layout_t* base_pipeline_layout) { + const iree_hal_hsa_pipeline_layout_t* pipeline_layout = + iree_hal_hsa_pipeline_layout_const_cast(base_pipeline_layout); + iree_hal_hsa_dispatch_layout_t dispatch_params = { + .push_constant_base_index = pipeline_layout->push_constant_base_index, + .push_constant_count = pipeline_layout->push_constant_count, + .total_binding_count = pipeline_layout->push_constant_base_index, + .set_layout_count = pipeline_layout->set_layout_count, + }; + + return dispatch_params; +} diff --git a/runtime/src/iree-amd-aie/driver/hsa/pipeline_layout.h b/runtime/src/iree-amd-aie/driver/hsa/pipeline_layout.h new file mode 100644 index 000000000..b510c0c7c --- /dev/null +++ b/runtime/src/iree-amd-aie/driver/hsa/pipeline_layout.h @@ -0,0 +1,93 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved. +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_PIPELINE_LAYOUT_H_
+#define IREE_EXPERIMENTAL_HSA_PIPELINE_LAYOUT_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// The max number of bindings per descriptor set allowed in the HSA HAL
+// implementation.
+#define IREE_HAL_HSA_MAX_DESCRIPTOR_SET_BINDING_COUNT 16
+
+// The max number of descriptor sets allowed in the HSA HAL implementation.
+//
+// This depends on the general descriptor set planning in IREE and should be
+// adjusted with it.
+#define IREE_HAL_HSA_MAX_DESCRIPTOR_SET_COUNT 4
+
+// The max number of push constants supported by the HSA HAL implementation.
+#define IREE_HAL_HSA_MAX_PUSH_CONSTANT_COUNT 64
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates a descriptor set layout with the given |bindings|.
+//
+// Bindings in a descriptor set map to a list of consecutive kernel arguments
+// in HSA kernels.
+iree_status_t iree_hal_hsa_descriptor_set_layout_create(
+    iree_hal_descriptor_set_layout_flags_t flags,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_layout_binding_t* bindings,
+    iree_allocator_t host_allocator,
+    iree_hal_descriptor_set_layout_t** out_descriptor_set_layout);
+
+// Returns the binding count for the given descriptor set layout.
+iree_host_size_t iree_hal_hsa_descriptor_set_layout_binding_count(
+    const iree_hal_descriptor_set_layout_t* descriptor_set_layout);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_pipeline_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates the pipeline layout with the given |set_layouts| and
+// |push_constant_count|.
+//
+// Bindings in the pipeline map to kernel arguments in HSA kernels, followed
+// by the kernel argument for the push constant data.
+iree_status_t iree_hal_hsa_pipeline_layout_create(
+    iree_host_size_t set_layout_count,
+    iree_hal_descriptor_set_layout_t* const* set_layouts,
+    iree_host_size_t push_constant_count, iree_allocator_t host_allocator,
+    iree_hal_pipeline_layout_t** out_pipeline_layout);
+
+// Returns the total number of sets in the given |pipeline_layout|.
+iree_host_size_t iree_hal_hsa_pipeline_layout_descriptor_set_count(
+    const iree_hal_pipeline_layout_t* pipeline_layout);
+
+// Returns the descriptor set layout of the given |set| in |pipeline_layout|.
+const iree_hal_descriptor_set_layout_t*
+iree_hal_hsa_pipeline_layout_descriptor_set_layout(
+    const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set);
+
+// Returns the base kernel argument index for the given set.
+iree_host_size_t iree_hal_hsa_pipeline_layout_base_binding_index(
+    const iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set);
+
+typedef struct iree_hal_hsa_dispatch_layout_t {
+  iree_host_size_t push_constant_base_index;
+  iree_host_size_t push_constant_count;
+  iree_host_size_t set_layout_count;
+  iree_host_size_t total_binding_count;
+} iree_hal_hsa_dispatch_layout_t;
+
+// Returns dispatch layout parameters in a struct form for the given pipeline
+// layout.
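+//
+// Worked example (hypothetical numbers): for two set layouts with 3 and 5
+// bindings and 4 push constants, the kernel argument order is
+//   args [0, 3)  -> set 0 bindings,
+//   args [3, 8)  -> set 1 bindings,
+//   args [8, 12) -> push constants,
+// so push_constant_base_index = 8, push_constant_count = 4,
+// total_binding_count = 8, and set_layout_count = 2.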
+iree_hal_hsa_dispatch_layout_t iree_hal_hsa_pipeline_layout_dispatch_layout(
+    const iree_hal_pipeline_layout_t* base_pipeline_layout);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_PIPELINE_LAYOUT_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/queue_command_buffer.c b/runtime/src/iree-amd-aie/driver/hsa/queue_command_buffer.c
new file mode 100644
index 000000000..26b2032f8
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/queue_command_buffer.c
@@ -0,0 +1,586 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/queue_command_buffer.h"
+
+#include "iree-amd-aie/driver/hsa/hsa_buffer.h"
+#include "iree-amd-aie/driver/hsa/native_executable.h"
+#include "iree-amd-aie/driver/hsa/pipeline_layout.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+// #include "iree-amd-aie/driver/hsa/tracing.h"
+#include "iree/hal/utils/resource_set.h"
+
+typedef struct iree_hal_hsa_queue_command_buffer_t {
+  iree_hal_command_buffer_t base;
+  iree_allocator_t host_allocator;
+  iree_hal_allocator_t* device_allocator;
+
+  const iree_hal_hsa_dynamic_symbols_t* hsa_symbols;
+
+  // The queue on which we will dispatch work.
+  hsa_queue_t* hsa_queue;
+
+  // A resource set to maintain references to all resources used within the
+  // command buffer. Reset on each begin.
+  iree_hal_resource_set_t* resource_set;
+
+  // Staging arena used for host->device transfers.
+  // Used for when we need HSA to be able to reference memory as it performs
+  // asynchronous operations.
+  iree_arena_allocator_t arena;
+
+  int32_t push_constants[IREE_HAL_HSA_MAX_PUSH_CONSTANT_COUNT];
+
+  // The currently bound descriptor sets.
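+  // Each entry holds the raw device pointer recorded when a descriptor set
+  // is pushed; for example (hypothetically) descriptor_sets[0].bindings[2]
+  // is the device address for set 0, binding 2, with the buffer byte offset
+  // and binding offset already applied.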
+ struct { + hsa_device_pointer_t + bindings[IREE_HAL_HSA_MAX_DESCRIPTOR_SET_BINDING_COUNT]; + } descriptor_sets[IREE_HAL_HSA_MAX_DESCRIPTOR_SET_COUNT]; +} iree_hal_hsa_queue_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_hsa_queue_command_buffer_vtable; + +static iree_hal_hsa_queue_command_buffer_t* +iree_hal_hsa_queue_command_buffer_cast(iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_hsa_queue_command_buffer_vtable); + return (iree_hal_hsa_queue_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_hsa_queue_command_buffer_create( + iree_hal_device_t* device, + const iree_hal_hsa_dynamic_symbols_t* hsa_symbols, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_host_size_t binding_capacity, hsa_queue_t* queue, + iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, + iree_hal_allocator_t* device_allocator, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(hsa_symbols); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + + if (binding_capacity > 0) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect command buffers not yet implemented"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hsa_queue_command_buffer_t* command_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*command_buffer), + (void**)&command_buffer)); + + iree_hal_command_buffer_initialize( + device_allocator, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + binding_capacity, (uint8_t*)command_buffer + sizeof(*command_buffer), + &iree_hal_hsa_queue_command_buffer_vtable, &command_buffer->base); + command_buffer->host_allocator = host_allocator; + command_buffer->hsa_symbols = hsa_symbols; + command_buffer->hsa_queue = queue; + command_buffer->device_allocator = device_allocator; + iree_arena_initialize(block_pool, &command_buffer->arena); + + iree_status_t status = + iree_hal_resource_set_allocate(block_pool, &command_buffer->resource_set); + + *out_command_buffer = &command_buffer->base; + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_hsa_queue_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_hsa_queue_command_buffer_t* command_buffer = + iree_hal_hsa_queue_command_buffer_cast(base_command_buffer); + iree_allocator_t host_allocator = command_buffer->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_resource_set_free(command_buffer->resource_set); + iree_arena_deinitialize(&command_buffer->arena); + iree_allocator_free(host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +bool iree_hal_hsa_queue_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer) { + return iree_hal_resource_is(&command_buffer->resource, + &iree_hal_hsa_queue_command_buffer_vtable); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_hsa_queue_command_buffer_t* command_buffer = + iree_hal_hsa_queue_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + // Reset the arena as there should be nothing using it now that we've + // dispatched all our operations inline. 
+  // NOTE: the resource set may contain resources we need to drop as we don't
+  // need to keep them live any longer than it takes to schedule the
+  // operations. In practice this queue command buffer is strictly used to
+  // perform inline execution/replay of deferred command buffers that already
+  // retain the resources.
+  iree_arena_reset(&command_buffer->arena);
+  iree_hal_resource_set_free(command_buffer->resource_set);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_allocate(command_buffer->arena.block_pool,
+                                         &command_buffer->resource_set));
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static void iree_hal_hsa_queue_command_buffer_begin_debug_group(
+    iree_hal_command_buffer_t* base_command_buffer, iree_string_view_t label,
+    iree_hal_label_color_t label_color,
+    const iree_hal_label_location_t* location) {}
+
+static void iree_hal_hsa_queue_command_buffer_end_debug_group(
+    iree_hal_command_buffer_t* base_command_buffer) {}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_execution_barrier(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_hal_execution_barrier_flags_t flags,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) ||
+      iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "barrier involving host not yet supported");
+  }
+
+  if (flags != IREE_HAL_EXECUTION_BARRIER_FLAG_NONE) {
+    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                            "non-zero barrier flag not yet supported");
+  }
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Nothing to do for barriers between memory operations or dispatches--HSA
+  // stream semantics guarantee execution and memory visibility in program
+  // order.
+ + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "event not yet supported"); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t buffer_ref) { + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t target_ref, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_hsa_queue_command_buffer_t* command_buffer = + iree_hal_hsa_queue_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + hsa_device_pointer_t target_device_buffer = + iree_hal_hsa_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(target_ref.buffer)); + iree_device_size_t target_offset = + iree_hal_buffer_byte_offset(target_ref.buffer) + target_ref.offset; + hsa_device_pointer_t dst = (uint8_t*)target_device_buffer + target_offset; + size_t num_elements = target_ref.length / pattern_length; + + switch (pattern_length) { + case 4: { + IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->hsa_symbols, + hsa_amd_memory_fill(dst, *(const uint32_t*)(pattern), num_elements), + "hsa_amd_memory_fill"); + break; + } + case 2: { + uint16_t* dst_ptr = (uint16_t*)dst; + uint16_t pattern_value = *(const uint16_t*)pattern; + for (size_t i = 0; i < num_elements; ++i) { + memcpy(dst_ptr + i, &pattern_value, sizeof(uint16_t)); + } + break; + } + case 1: { + uint8_t* dst_ptr = (uint8_t*)dst; + uint8_t pattern_value = *(const uint8_t*)pattern; + for (size_t i = 0; i < num_elements; ++i) { + memcpy(dst_ptr + i, &pattern_value, sizeof(uint8_t)); + } + break; + } + default: + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INTERNAL, + "unsupported fill pattern length"); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_hsa_queue_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref) { + iree_hal_hsa_queue_command_buffer_t* command_buffer = + iree_hal_hsa_queue_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + // Allocate scratch space in the arena for the data and copy it in. 
+  // The update buffer API requires that the command buffer capture the host
+  // memory at the time the method is called in case the caller wants to reuse
+  // the memory. Because HSA memcpys are async, without the copy the reused
+  // memory could change before the queue reaches the copy operation and we
+  // would read the wrong data.
+  const uint8_t* src = (const uint8_t*)source_buffer + source_offset;
+  if (command_buffer->arena.block_pool) {
+    uint8_t* storage = NULL;
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_arena_allocate(&command_buffer->arena, target_ref.length,
+                                (void**)&storage));
+    memcpy(storage, src, target_ref.length);
+    src = storage;
+  }
+
+  // Issue the copy using the scratch memory as the source.
+  hsa_device_pointer_t target_device_buffer =
+      iree_hal_hsa_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(target_ref.buffer));
+  hsa_device_pointer_t dst = (uint8_t*)target_device_buffer +
+                             iree_hal_buffer_byte_offset(target_ref.buffer) +
+                             target_ref.offset;
+
+  IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, command_buffer->hsa_symbols,
+      hsa_memory_copy(dst, (void*)src, target_ref.length), "hsa_memory_copy");
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_copy_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref) {
+  iree_hal_hsa_queue_command_buffer_t* command_buffer =
+      iree_hal_hsa_queue_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  hsa_device_pointer_t target_device_buffer =
+      iree_hal_hsa_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(target_ref.buffer));
+  iree_device_size_t target_offset =
+      iree_hal_buffer_byte_offset(target_ref.buffer) + target_ref.offset;
+  hsa_device_pointer_t source_device_buffer =
+      iree_hal_hsa_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(source_ref.buffer));
+  iree_device_size_t source_offset =
+      iree_hal_buffer_byte_offset(source_ref.buffer) + source_ref.offset;
+  hsa_device_pointer_t dst = (uint8_t*)target_device_buffer + target_offset;
+  hsa_device_pointer_t src = (uint8_t*)source_device_buffer + source_offset;
+
+  IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, command_buffer->hsa_symbols,
+      hsa_memory_copy(dst, src, target_ref.length), "hsa_memory_copy");
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_collective(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+    iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref,
+    iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "collectives not yet supported");
+}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_push_constants(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
+    const void* values, iree_host_size_t values_length) {
+  iree_hal_hsa_queue_command_buffer_t* command_buffer =
+      iree_hal_hsa_queue_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_host_size_t constant_base_index = offset / sizeof(int32_t);
+  for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) {
+    command_buffer->push_constants[i + constant_base_index] =
+        ((uint32_t*)values)[i];
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_push_descriptor_set(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
+    iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings) {
+  if (binding_count > IREE_HAL_HSA_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
+    return iree_make_status(
+        IREE_STATUS_RESOURCE_EXHAUSTED,
+        "exceeded available binding slots for push "
+        "descriptor set #%" PRIu32 "; requested %" PRIhsz " vs. maximum %d",
+        set, binding_count, IREE_HAL_HSA_MAX_DESCRIPTOR_SET_BINDING_COUNT);
+  }
+
+  iree_hal_hsa_queue_command_buffer_t* command_buffer =
+      iree_hal_hsa_queue_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  hsa_device_pointer_t* current_bindings =
+      command_buffer->descriptor_sets[set].bindings;
+  for (iree_host_size_t i = 0; i < binding_count; i++) {
+    const iree_hal_buffer_ref_t* binding = &bindings[i];
+    hsa_device_pointer_t device_ptr = NULL;
+    if (binding->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                           &binding->buffer));
+
+      hsa_device_pointer_t device_buffer = iree_hal_hsa_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(binding->buffer));
+      iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+      device_ptr = (uint8_t*)device_buffer + offset + binding->offset;
+    }
+    current_bindings[binding->ordinal] = device_ptr;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_dispatch(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+    iree_hal_dispatch_flags_t flags) {
+  iree_hal_hsa_queue_command_buffer_t* command_buffer =
+      iree_hal_hsa_queue_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Lookup kernel parameters used for side-channeling additional launch
+  // information from the compiler.
+  iree_hal_hsa_kernel_info_t kernel_info;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hsa_native_executable_entry_point_kernel_info(
+              executable, entry_point, &kernel_info));
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                       &executable));
+
+  iree_hal_hsa_dispatch_layout_t dispatch_layout =
+      iree_hal_hsa_pipeline_layout_dispatch_layout(kernel_info.layout);
+
+  // The total number of descriptors across all descriptor sets.
+  iree_host_size_t descriptor_count = dispatch_layout.total_binding_count;
+  // The total number of push constants.
+  iree_host_size_t push_constant_count = dispatch_layout.push_constant_count;
+  // We append push constants to the end of descriptors to form a linear chain
+  // of kernel arguments.
+  iree_host_size_t kernel_params_count =
+      descriptor_count + push_constant_count;
+  iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+  // Each kernel_params[i] is itself a pointer to the corresponding element in
+  // the *second* inline allocation at the end of the same buffer allocation.
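+  // Illustrative layout (an assumption drawn from the code below, not spelled
+  // out in the original change), e.g. for 2 bindings and 1 push constant:
+  //   params_ptr  -> [ &payload[0], &payload[1], &payload[2] ]
+  //   payload_ptr -> [ binding0,    binding1,    constant0   ]
+  // The packet's kernarg_address ends up pointing at the payload array.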
+  iree_host_size_t total_size = kernel_params_length * 2;
+
+  iree_hal_buffer_params_t buffer_params = {
+      .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
+               IREE_HAL_BUFFER_USAGE_TRANSFER,
+      .access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE,
+      .type =
+          IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+  };
+
+  iree_device_size_t kern_arg_allocation_size = total_size;
+  iree_hal_buffer_t* kern_arg_allocation_buffer = NULL;
+  iree_status_t result = iree_hal_allocator_allocate_buffer(
+      command_buffer->device_allocator, buffer_params, kern_arg_allocation_size,
+      &kern_arg_allocation_buffer);
+  if (!iree_status_is_ok(result)) {
+    IREE_TRACE_ZONE_END(z0);
+    return result;
+  }
+  uint8_t* storage_base =
+      (uint8_t*)iree_hal_hsa_buffer_host_pointer(kern_arg_allocation_buffer);
+
+  void** params_ptr = (void**)storage_base;
+
+  // Set up kernel arguments to point to the payload slots.
+  hsa_device_pointer_t* payload_ptr =
+      (hsa_device_pointer_t*)((uint8_t*)params_ptr + kernel_params_length);
+  for (size_t i = 0; i < kernel_params_count; i++) {
+    params_ptr[i] = &payload_ptr[i];
+  }
+
+  // Copy descriptors from all sets to the payload area for later access.
+  iree_host_size_t set_count = dispatch_layout.set_layout_count;
+  for (iree_host_size_t i = 0; i < set_count; ++i) {
+    // TODO: cache this information in the kernel info to avoid recomputation.
+    iree_host_size_t binding_count =
+        iree_hal_hsa_descriptor_set_layout_binding_count(
+            iree_hal_hsa_pipeline_layout_descriptor_set_layout(
+                kernel_info.layout, i));
+    iree_host_size_t index =
+        iree_hal_hsa_pipeline_layout_base_binding_index(kernel_info.layout, i);
+    memcpy(payload_ptr + index, command_buffer->descriptor_sets[i].bindings,
+           binding_count * sizeof(hsa_device_pointer_t));
+  }
+
+  // Append the push constants to the kernel arguments.
+  iree_host_size_t base_index = dispatch_layout.push_constant_base_index;
+  // As noted above, each kernel parameter points at an hsa_device_pointer_t,
+  // which has the size of a pointer on the target machine, while a push
+  // constant is only a 32-bit value; on 64-bit machines we therefore write
+  // one element at a time instead of bulk-copying.
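+  // (Hypothetical example: with base_index == 2 and push constants {7, 9} the
+  // loop below stores 7 through params_ptr[2] and 9 through params_ptr[3],
+  // i.e. into payload slots 2 and 3.)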
+  for (iree_host_size_t i = 0; i < push_constant_count; i++) {
+    *((uint32_t*)params_ptr[base_index + i]) =
+        command_buffer->push_constants[i];
+  }
+
+  // Make room for the packet.
+  uint64_t write_index =
+      command_buffer->hsa_symbols->hsa_queue_add_write_index_relaxed(
+          command_buffer->hsa_queue, 1);
+
+  // Create the packet.
+  size_t queue_mask = command_buffer->hsa_queue->size - 1;
+
+  hsa_kernel_dispatch_packet_t* packet =
+      (hsa_kernel_dispatch_packet_t*)(command_buffer->hsa_queue->base_address) +
+      (write_index & queue_mask);
+
+  hsa_signal_value_t signal_value = 1;
+  uint32_t num_consumers = 0;
+  const hsa_agent_t* consumers = NULL;
+  iree_status_t status = IREE_HSA_RESULT_TO_STATUS(
+      command_buffer->hsa_symbols,
+      hsa_signal_create(signal_value, num_consumers, consumers,
+                        &packet->completion_signal),
+      "hsa_signal_create");
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  uint16_t packet_dimensions = 3;
+  packet->setup |= packet_dimensions
+                   << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
+
+  packet->grid_size_x = kernel_info.block_size[0] * workgroup_x;
+  packet->grid_size_y = kernel_info.block_size[1] * workgroup_y;
+  packet->grid_size_z = kernel_info.block_size[2] * workgroup_z;
+
+  packet->workgroup_size_x = kernel_info.block_size[0];
+  packet->workgroup_size_y = kernel_info.block_size[1];
+  packet->workgroup_size_z = kernel_info.block_size[2];
+
+  packet->kernarg_address = *params_ptr;
+  packet->kernel_object = kernel_info.kernel_object;
+  packet->private_segment_size = kernel_info.private_segment_size;
+  packet->group_segment_size = kernel_info.group_segment_size;
+
+  uint16_t header = 0;
+  header |= HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_ACQUIRE_FENCE_SCOPE;
+  header |= HSA_FENCE_SCOPE_SYSTEM << HSA_PACKET_HEADER_RELEASE_FENCE_SCOPE;
+  header |= HSA_PACKET_TYPE_KERNEL_DISPATCH << HSA_PACKET_HEADER_TYPE;
+
+  __atomic_store_n(&packet->header, header, __ATOMIC_RELEASE);
+  // TODO(muhaawad): We don't need a completion signal here anymore
+  // since we have fences that make sure everything is completed.
+  // We might still add completion signals and use them within the semaphores
+  // instead of inserting a barrier packet each time.
+  command_buffer->hsa_symbols->hsa_signal_store_screlease(
+      command_buffer->hsa_queue->doorbell_signal, write_index);
+
+  command_buffer->hsa_symbols->hsa_signal_wait_acquire(
+      packet->completion_signal, HSA_SIGNAL_CONDITION_LT, 1, UINT64_MAX,
+      HSA_WAIT_STATE_BLOCKED);
+
+  status =
+      IREE_HSA_RESULT_TO_STATUS(command_buffer->hsa_symbols,
+                                hsa_signal_destroy(packet->completion_signal));
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static iree_status_t iree_hal_hsa_queue_command_buffer_dispatch_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "need HSA implementation of dispatch indirect");
+}
+
+static const iree_hal_command_buffer_vtable_t
+    iree_hal_hsa_queue_command_buffer_vtable = {
+        .destroy = iree_hal_hsa_queue_command_buffer_destroy,
+        .begin = iree_hal_hsa_queue_command_buffer_begin,
+        .end = iree_hal_hsa_queue_command_buffer_end,
+        .begin_debug_group =
+            iree_hal_hsa_queue_command_buffer_begin_debug_group,
+        .end_debug_group = iree_hal_hsa_queue_command_buffer_end_debug_group,
+        .execution_barrier =
+            iree_hal_hsa_queue_command_buffer_execution_barrier,
+        .signal_event = iree_hal_hsa_queue_command_buffer_signal_event,
+        .reset_event = iree_hal_hsa_queue_command_buffer_reset_event,
+        .wait_events = iree_hal_hsa_queue_command_buffer_wait_events,
+        .discard_buffer = iree_hal_hsa_queue_command_buffer_discard_buffer,
+        .fill_buffer = iree_hal_hsa_queue_command_buffer_fill_buffer,
+        .update_buffer = iree_hal_hsa_queue_command_buffer_update_buffer,
+        .copy_buffer = iree_hal_hsa_queue_command_buffer_copy_buffer,
+        .collective = iree_hal_hsa_queue_command_buffer_collective,
+        .push_constants = iree_hal_hsa_queue_command_buffer_push_constants,
+        .push_descriptor_set =
+            iree_hal_hsa_queue_command_buffer_push_descriptor_set,
+        .dispatch = iree_hal_hsa_queue_command_buffer_dispatch,
+        .dispatch_indirect =
+            iree_hal_hsa_queue_command_buffer_dispatch_indirect,
+};
diff --git a/runtime/src/iree-amd-aie/driver/hsa/queue_command_buffer.h b/runtime/src/iree-amd-aie/driver/hsa/queue_command_buffer.h
new file mode 100644
index 000000000..e37e5c2a5
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/queue_command_buffer.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_QUEUE_COMMAND_BUFFER_H_
+#define IREE_EXPERIMENTAL_HSA_QUEUE_COMMAND_BUFFER_H_
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/hsa_headers.h"
+// #include "iree-amd-aie/driver/hsa/tracing.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a command buffer that immediately issues commands against the given
+// HSA |queue|. Access to |queue| must be synchronized by the user.
+//
+// If |block_pool| is non-NULL then the command buffer will retain copies of
+// input data until reset.
+// If NULL then the caller must ensure the lifetime of input data outlives the
+// command buffer.
+//
+// This command buffer is used to replay deferred command buffers. When
+// replaying, the scratch data required for things like buffer updates is
+// retained by the source deferred command buffer, and as such the
+// |block_pool| can be NULL to avoid a double copy.
+iree_status_t iree_hal_hsa_queue_command_buffer_create(
+    iree_hal_device_t* device,
+    const iree_hal_hsa_dynamic_symbols_t* hsa_symbols,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree_host_size_t binding_capacity, hsa_queue_t* queue,
+    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
+    iree_hal_allocator_t* device_allocator,
+    iree_hal_command_buffer_t** out_command_buffer);
+
+// Returns true if |command_buffer| is an HSA queue-based command buffer.
+bool iree_hal_hsa_queue_command_buffer_isa(
+    iree_hal_command_buffer_t* command_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_QUEUE_COMMAND_BUFFER_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/registration/CMakeLists.txt b/runtime/src/iree-amd-aie/driver/hsa/registration/CMakeLists.txt
new file mode 100644
index 000000000..cb252a455
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/registration/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+#
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+  NAME
+    registration
+  HDRS
+    "driver_module.h"
+  SRCS
+    "driver_module.c"
+  DEPS
+    iree::base
+    iree::base::core_headers
+    iree-amd-aie::driver::hsa
+    iree::hal
+  DEFINES
+    "IREE_HAVE_HAL_HSA_DRIVER_MODULE=1"
+  PUBLIC
+)
diff --git a/runtime/src/iree-amd-aie/driver/hsa/registration/driver_module.c b/runtime/src/iree-amd-aie/driver/hsa/registration/driver_module.c
new file mode 100644
index 000000000..954605faa
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/registration/driver_module.c
@@ -0,0 +1,73 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/registration/driver_module.h"
+
+#include <inttypes.h>
+#include <stddef.h>
+
+#include "iree-amd-aie/driver/hsa/api.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/status.h"
+
+static iree_status_t iree_hal_hsa_driver_factory_enumerate(
+    void* self, iree_host_size_t* out_driver_info_count,
+    const iree_hal_driver_info_t** out_driver_infos) {
+  IREE_ASSERT_ARGUMENT(out_driver_info_count);
+  IREE_ASSERT_ARGUMENT(out_driver_infos);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  static const iree_hal_driver_info_t driver_infos[1] = {{
+      .driver_name = IREE_SVL("hsa"),
+      .full_name = IREE_SVL("HSA HAL driver (via dylib)"),
+  }};
+  *out_driver_info_count = IREE_ARRAYSIZE(driver_infos);
+  *out_driver_infos = driver_infos;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_hsa_driver_factory_try_create(
+    void* self, iree_string_view_t driver_name, iree_allocator_t host_allocator,
+    iree_hal_driver_t** out_driver) {
+  IREE_ASSERT_ARGUMENT(out_driver);
+
+  if (!iree_string_view_equal(driver_name, IREE_SV("hsa"))) {
+    return iree_make_status(IREE_STATUS_UNAVAILABLE,
+                            "no driver '%.*s' is provided by this factory",
+                            (int)driver_name.size, driver_name.data);
+  }
+
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hsa_driver_options_t driver_options;
+  iree_hal_hsa_driver_options_initialize(&driver_options);
+
+  iree_hal_hsa_device_params_t device_params;
+  iree_hal_hsa_device_params_initialize(&device_params);
+
+  driver_options.default_device_index = 0;
+
+  iree_status_t status = iree_hal_hsa_driver_create(
+      driver_name, &driver_options, &device_params, host_allocator, out_driver);
+
+  IREE_TRACE_ZONE_END(z0);
+
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t
+iree_hal_hsa_driver_module_register(iree_hal_driver_registry_t* registry) {
+  static const iree_hal_driver_factory_t factory = {
+      .self = NULL,
+      .enumerate = iree_hal_hsa_driver_factory_enumerate,
+      .try_create = iree_hal_hsa_driver_factory_try_create,
+  };
+  return iree_hal_driver_registry_register_factory(registry, &factory);
+}
diff --git a/runtime/src/iree-amd-aie/driver/hsa/registration/driver_module.h b/runtime/src/iree-amd-aie/driver/hsa/registration/driver_module.h
new file mode 100644
index 000000000..ba18779e5
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/registration/driver_module.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_REGISTRATION_DRIVER_MODULE_H_
+#define IREE_EXPERIMENTAL_HSA_REGISTRATION_DRIVER_MODULE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Registers the HSA HAL driver to the given |registry|.
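+//
+// Minimal usage sketch (assumed typical IREE flow; not part of this change):
+//   IREE_RETURN_IF_ERROR(iree_hal_hsa_driver_module_register(
+//       iree_hal_driver_registry_default()));
+//   // ...then create the driver/device by name ("hsa") via the registry.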
+IREE_API_EXPORT iree_status_t
+iree_hal_hsa_driver_module_register(iree_hal_driver_registry_t* registry);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_REGISTRATION_DRIVER_MODULE_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/status_util.c b/runtime/src/iree-amd-aie/driver/hsa/status_util.c
new file mode 100644
index 000000000..967609440
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/status_util.c
@@ -0,0 +1,191 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/status_util.h"
+
+#include <stddef.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree/base/status.h"
+
+// The list of HSA error strings with their corresponding IREE error state
+// classification.
+#define IREE_HSA_ERROR_LIST(IREE_HSA_MAP_ERROR)                               \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_SUCCESS", IREE_STATUS_OK)                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_INFO_BREAK", IREE_STATUS_INTERNAL)           \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR", IREE_STATUS_UNKNOWN)                 \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_ARGUMENT",                     \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_QUEUE_CREATION",               \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_ALLOCATION",                   \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_AGENT",                        \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_REGION",                       \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_SIGNAL",                       \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_QUEUE",                        \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_OUT_OF_RESOURCES",                     \
+                     IREE_STATUS_RESOURCE_EXHAUSTED)                          \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_PACKET_FORMAT",                \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_RESOURCE_FREE",                        \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_NOT_INITIALIZED",                      \
+                     IREE_STATUS_FAILED_PRECONDITION)                         \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_REFCOUNT_OVERFLOW",                    \
+                     IREE_STATUS_RESOURCE_EXHAUSTED)                          \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS",               \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_INDEX",                        \
+                     IREE_STATUS_INVALID_ARGUMENT)                            \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_ISA", IREE_STATUS_INTERNAL)    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_ISA_NAME",                     \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_CODE_OBJECT",                  \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_EXECUTABLE",                   \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_FROZEN_EXECUTABLE",                    \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_SYMBOL_NAME",                  \
+                     IREE_STATUS_NOT_FOUND)                                   \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED",             \
+                     IREE_STATUS_ALREADY_EXISTS)                              \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_VARIABLE_UNDEFINED",                   \
+                     IREE_STATUS_NOT_FOUND)                                   \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_EXCEPTION", IREE_STATUS_INTERNAL)      \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_CODE_SYMBOL",                  \
+                     IREE_STATUS_NOT_FOUND)                                   \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_EXECUTABLE_SYMBOL",            \
+                     IREE_STATUS_NOT_FOUND)                                   \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_FILE", IREE_STATUS_INTERNAL)   \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER",           \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_CACHE", IREE_STATUS_INTERNAL)  \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_WAVEFRONT",                    \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP",                 \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_INVALID_RUNTIME_STATE",                \
+                     IREE_STATUS_INTERNAL)                                    \
+  IREE_HSA_MAP_ERROR("HSA_STATUS_ERROR_FATAL", IREE_STATUS_INTERNAL)
+
+// TODO(muhaawad): Not sure if there is an HSA way of doing this.
+const char* hsa_status_to_string(hsa_status_t status) {
+  switch (status) {
+    case HSA_STATUS_SUCCESS:
+      return "HSA_STATUS_SUCCESS";
+    case HSA_STATUS_INFO_BREAK:
+      return "HSA_STATUS_INFO_BREAK";
+    case HSA_STATUS_ERROR:
+      return "HSA_STATUS_ERROR";
+    case HSA_STATUS_ERROR_INVALID_ARGUMENT:
+      return "HSA_STATUS_ERROR_INVALID_ARGUMENT";
+    case HSA_STATUS_ERROR_INVALID_QUEUE_CREATION:
+      return "HSA_STATUS_ERROR_INVALID_QUEUE_CREATION";
+    case HSA_STATUS_ERROR_INVALID_ALLOCATION:
+      return "HSA_STATUS_ERROR_INVALID_ALLOCATION";
+    case HSA_STATUS_ERROR_INVALID_AGENT:
+      return "HSA_STATUS_ERROR_INVALID_AGENT";
+    case HSA_STATUS_ERROR_INVALID_REGION:
+      return "HSA_STATUS_ERROR_INVALID_REGION";
+    case HSA_STATUS_ERROR_INVALID_SIGNAL:
+      return "HSA_STATUS_ERROR_INVALID_SIGNAL";
+    case HSA_STATUS_ERROR_INVALID_QUEUE:
+      return "HSA_STATUS_ERROR_INVALID_QUEUE";
+    case HSA_STATUS_ERROR_OUT_OF_RESOURCES:
+      return "HSA_STATUS_ERROR_OUT_OF_RESOURCES";
+    case HSA_STATUS_ERROR_INVALID_PACKET_FORMAT:
+      return "HSA_STATUS_ERROR_INVALID_PACKET_FORMAT";
+    case HSA_STATUS_ERROR_RESOURCE_FREE:
+      return "HSA_STATUS_ERROR_RESOURCE_FREE";
+    case HSA_STATUS_ERROR_NOT_INITIALIZED:
+      return "HSA_STATUS_ERROR_NOT_INITIALIZED";
+    case HSA_STATUS_ERROR_REFCOUNT_OVERFLOW:
+      return "HSA_STATUS_ERROR_REFCOUNT_OVERFLOW";
+    case HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS:
+      return "HSA_STATUS_ERROR_INCOMPATIBLE_ARGUMENTS";
+    case HSA_STATUS_ERROR_INVALID_INDEX:
+      return "HSA_STATUS_ERROR_INVALID_INDEX";
+    case HSA_STATUS_ERROR_INVALID_ISA:
+      return "HSA_STATUS_ERROR_INVALID_ISA";
+    case HSA_STATUS_ERROR_INVALID_ISA_NAME:
+      return "HSA_STATUS_ERROR_INVALID_ISA_NAME";
+    case HSA_STATUS_ERROR_INVALID_CODE_OBJECT:
+      return "HSA_STATUS_ERROR_INVALID_CODE_OBJECT";
+    case HSA_STATUS_ERROR_INVALID_EXECUTABLE:
+      return "HSA_STATUS_ERROR_INVALID_EXECUTABLE";
+    case HSA_STATUS_ERROR_FROZEN_EXECUTABLE:
+      return "HSA_STATUS_ERROR_FROZEN_EXECUTABLE";
+    case HSA_STATUS_ERROR_INVALID_SYMBOL_NAME:
+      return "HSA_STATUS_ERROR_INVALID_SYMBOL_NAME";
+    case HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED:
+      return "HSA_STATUS_ERROR_VARIABLE_ALREADY_DEFINED";
+    case HSA_STATUS_ERROR_VARIABLE_UNDEFINED:
+      return "HSA_STATUS_ERROR_VARIABLE_UNDEFINED";
+    case HSA_STATUS_ERROR_EXCEPTION:
+      return "HSA_STATUS_ERROR_EXCEPTION";
+    case HSA_STATUS_ERROR_INVALID_CODE_SYMBOL:
+      return "HSA_STATUS_ERROR_INVALID_CODE_SYMBOL";
+    case HSA_STATUS_ERROR_INVALID_EXECUTABLE_SYMBOL:
+      return "HSA_STATUS_ERROR_INVALID_EXECUTABLE_SYMBOL";
+    case HSA_STATUS_ERROR_INVALID_FILE:
+      return "HSA_STATUS_ERROR_INVALID_FILE";
+    case HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER:
+      return "HSA_STATUS_ERROR_INVALID_CODE_OBJECT_READER";
+    case HSA_STATUS_ERROR_INVALID_CACHE:
+      return "HSA_STATUS_ERROR_INVALID_CACHE";
+    case HSA_STATUS_ERROR_INVALID_WAVEFRONT:
+      return "HSA_STATUS_ERROR_INVALID_WAVEFRONT";
+    case HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP:
+      return "HSA_STATUS_ERROR_INVALID_SIGNAL_GROUP";
+    case HSA_STATUS_ERROR_INVALID_RUNTIME_STATE:
+      return "HSA_STATUS_ERROR_INVALID_RUNTIME_STATE";
+    case HSA_STATUS_ERROR_FATAL:
+      return "HSA_STATUS_ERROR_FATAL";
+    default:
+      return "Unknown HSA_STATUS";
+  }
+}
+
+// Converts HSA |error_name| to the corresponding IREE status code.
+static iree_status_code_t iree_hal_hsa_error_name_to_status_code(
+    const char* error_name) {
+#define IREE_HSA_ERROR_TO_IREE_STATUS(hsa_error, iree_status)   \
+  if (strncmp(error_name, hsa_error, strlen(hsa_error)) == 0) { \
+    return iree_status;                                         \
+  }
+  IREE_HSA_ERROR_LIST(IREE_HSA_ERROR_TO_IREE_STATUS)
+#undef IREE_HSA_ERROR_TO_IREE_STATUS
+  return IREE_STATUS_UNKNOWN;
+}
+
+iree_status_t iree_hal_hsa_result_to_status(
+    const iree_hal_hsa_dynamic_symbols_t* syms, hsa_status_t result,
+    const char* file, uint32_t line) {
+  if (IREE_LIKELY(result == HSA_STATUS_SUCCESS)) {
+    return iree_ok_status();
+  }
+
+  const char* error_name = hsa_status_to_string(result);
+
+  const char* error_string = NULL;
+  hsa_status_t status_string_result =
+      syms->hsa_status_string(result, &error_string);
+  if (status_string_result != HSA_STATUS_SUCCESS) {
+    error_string = "unknown error";
+  }
+
+  return iree_make_status_with_location(
+      file, line, iree_hal_hsa_error_name_to_status_code(error_name),
+      "HSA driver error '%s' (%d): %s", error_name, result, error_string);
+}
diff --git a/runtime/src/iree-amd-aie/driver/hsa/status_util.h b/runtime/src/iree-amd-aie/driver/hsa/status_util.h
new file mode 100644
index 000000000..4f2007078
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/status_util.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_STATUS_UTIL_H_
+#define IREE_EXPERIMENTAL_HSA_STATUS_UTIL_H_
+
+#include <stdint.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Converts a hsa_status_t to an iree_status_t.
+//
+// Usage:
+//   iree_status_t status = IREE_HSA_RESULT_TO_STATUS(hsa_symbols,
+//                                                    hsaDoThing(...));
+#define IREE_HSA_RESULT_TO_STATUS(syms, expr, ...) \
+  iree_hal_hsa_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__)
+
+// IREE_RETURN_IF_ERROR but implicitly converts the hsa_status_t return value
+// to an iree_status_t.
+//
+// Usage:
+//   IREE_HSA_RETURN_IF_ERROR(hsa_symbols, hsaDoThing(...), "message");
+#define IREE_HSA_RETURN_IF_ERROR(syms, expr, ...)                            \
+  IREE_RETURN_IF_ERROR(iree_hal_hsa_result_to_status((syms), ((syms)->expr), \
+                                                     __FILE__, __LINE__),    \
+                       __VA_ARGS__)
+
+// IREE_RETURN_IF_ERROR but ends the current zone and implicitly converts the
+// hsa_status_t return value to an iree_status_t.
+//
+// Usage:
+//   IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR(zone_id, hsa_symbols,
+//                                         hsaDoThing(...), "message");
+#define IREE_HSA_RETURN_AND_END_ZONE_IF_ERROR(zone_id, syms, expr, ...) \
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(                                    \
+      zone_id,                                                          \
+      iree_hal_hsa_result_to_status((syms), ((syms)->expr), __FILE__,   \
+                                    __LINE__),                          \
+      __VA_ARGS__)
+
+// IREE_IGNORE_ERROR but implicitly converts the hsa_status_t return value to
+// an iree_status_t.
+//
+// Usage:
+//   IREE_HSA_IGNORE_ERROR(hsa_symbols, hsaDoThing(...));
+#define IREE_HSA_IGNORE_ERROR(syms, expr)                                 \
+  IREE_IGNORE_ERROR(iree_hal_hsa_result_to_status((syms), ((syms)->expr), \
+                                                  __FILE__, __LINE__))
+
+// Converts a hsa_status_t to an iree_status_t object.
+iree_status_t iree_hal_hsa_result_to_status(
+    const iree_hal_hsa_dynamic_symbols_t* syms, hsa_status_t result,
+    const char* file, uint32_t line);
+
+const char* hsa_status_to_string(hsa_status_t status);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_EXPERIMENTAL_HSA_STATUS_UTIL_H_
diff --git a/runtime/src/iree-amd-aie/driver/hsa/timepoint_pool.c b/runtime/src/iree-amd-aie/driver/hsa/timepoint_pool.c
new file mode 100644
index 000000000..78d5c0a1a
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/timepoint_pool.c
@@ -0,0 +1,353 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-amd-aie/driver/hsa/timepoint_pool.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <string.h>
+
+#include "iree-amd-aie/driver/hsa/dynamic_symbols.h"
+#include "iree-amd-aie/driver/hsa/event_pool.h"
+#include "iree-amd-aie/driver/hsa/status_util.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/atomics.h"
+#include "iree/base/internal/event_pool.h"
+#include "iree/base/internal/synchronization.h"
+#include "iree/hal/api.h"
+#include "iree/hal/utils/semaphore_base.h"
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_timepoint_t
+//===----------------------------------------------------------------------===//
+
+static iree_status_t iree_hal_hsa_timepoint_allocate(
+    iree_hal_hsa_timepoint_pool_t* pool, iree_allocator_t host_allocator,
+    iree_hal_hsa_timepoint_t** out_timepoint) {
+  IREE_ASSERT_ARGUMENT(pool);
+  IREE_ASSERT_ARGUMENT(out_timepoint);
+  *out_timepoint = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hsa_timepoint_t* timepoint = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator, sizeof(*timepoint),
+                                (void**)&timepoint));
+  // iree_allocator_malloc zeros out the whole struct.
+  timepoint->host_allocator = host_allocator;
+  timepoint->pool = pool;
+
+  *out_timepoint = timepoint;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Clears all data fields in the given |timepoint| except the original host
+// allocator and owning pool.
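+// (This happens when a timepoint is recycled back into the pool so that stale
+// kind/payload state never leaks into the next acquisition.)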
+static void iree_hal_hsa_timepoint_clear(iree_hal_hsa_timepoint_t* timepoint) {
+  iree_allocator_t host_allocator = timepoint->host_allocator;
+  iree_hal_hsa_timepoint_pool_t* pool = timepoint->pool;
+  memset(timepoint, 0, sizeof(*timepoint));
+  timepoint->host_allocator = host_allocator;
+  timepoint->pool = pool;
+}
+
+static void iree_hal_hsa_timepoint_free(iree_hal_hsa_timepoint_t* timepoint) {
+  iree_allocator_t host_allocator = timepoint->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_ASSERT(timepoint->kind == IREE_HAL_HSA_TIMEPOINT_KIND_NONE);
+  iree_allocator_free(host_allocator, timepoint);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_timepoint_pool_t
+//===----------------------------------------------------------------------===//
+
+struct iree_hal_hsa_timepoint_pool_t {
+  // The allocator used to create the timepoint pool.
+  iree_allocator_t host_allocator;
+
+  // The pool to acquire host events.
+  iree_event_pool_t* host_event_pool;
+  // The pool to acquire device events. Internally synchronized.
+  iree_hal_hsa_event_pool_t* device_event_pool;
+
+  // Note that the above pools are internally synchronized, so we don't and
+  // shouldn't use the following mutex to guard access to them.
+
+  // Guards timepoint-related fields in this pool. We don't expect a
+  // performant program to frequently allocate timepoints for synchronization
+  // purposes; the traffic to this pool should be low, so it should be fine to
+  // guard with a mutex here.
+  iree_slim_mutex_t timepoint_mutex;
+
+  // Maximum number of timepoint objects that will be maintained in the pool.
+  // More timepoints may be allocated at any time, but they will be disposed
+  // directly when they are no longer needed.
+  iree_host_size_t available_capacity IREE_GUARDED_BY(timepoint_mutex);
+  // Total number of currently available timepoint objects.
+  iree_host_size_t available_count IREE_GUARDED_BY(timepoint_mutex);
+  // The list of available_count timepoint objects.
+  iree_hal_hsa_timepoint_t* available_list[] IREE_GUARDED_BY(timepoint_mutex);
+};
+// + Additional inline allocation for holding timepoints up to the capacity.
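+// (I.e. the pool is created as a single allocation of sizeof(pool) +
+// capacity * sizeof(iree_hal_hsa_timepoint_t*); see
+// iree_hal_hsa_timepoint_pool_allocate below.)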
+ +iree_status_t iree_hal_hsa_timepoint_pool_allocate( + iree_event_pool_t* host_event_pool, + iree_hal_hsa_event_pool_t* device_event_pool, + iree_host_size_t available_capacity, iree_allocator_t host_allocator, + iree_hal_hsa_timepoint_pool_t** out_timepoint_pool) { + IREE_ASSERT_ARGUMENT(host_event_pool); + IREE_ASSERT_ARGUMENT(device_event_pool); + IREE_ASSERT_ARGUMENT(out_timepoint_pool); + *out_timepoint_pool = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hsa_timepoint_pool_t* timepoint_pool = NULL; + iree_host_size_t total_size = + sizeof(*timepoint_pool) + + available_capacity * sizeof(*timepoint_pool->available_list); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, total_size, + (void**)&timepoint_pool)); + timepoint_pool->host_allocator = host_allocator; + timepoint_pool->host_event_pool = host_event_pool; + timepoint_pool->device_event_pool = device_event_pool; + + iree_slim_mutex_initialize(&timepoint_pool->timepoint_mutex); + timepoint_pool->available_capacity = available_capacity; + timepoint_pool->available_count = 0; + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < available_capacity; ++i) { + status = iree_hal_hsa_timepoint_allocate( + timepoint_pool, host_allocator, + &timepoint_pool->available_list[timepoint_pool->available_count++]); + if (!iree_status_is_ok(status)) break; + } + + if (iree_status_is_ok(status)) { + *out_timepoint_pool = timepoint_pool; + } else { + iree_hal_hsa_timepoint_pool_free(timepoint_pool); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_hsa_timepoint_pool_free( + iree_hal_hsa_timepoint_pool_t* timepoint_pool) { + iree_allocator_t host_allocator = timepoint_pool->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < timepoint_pool->available_count; ++i) { + iree_hal_hsa_timepoint_free(timepoint_pool->available_list[i]); + } + iree_slim_mutex_deinitialize(&timepoint_pool->timepoint_mutex); + iree_allocator_free(host_allocator, timepoint_pool); + + IREE_TRACE_ZONE_END(z0); +} + +// Acquires |timepoint_count| timepoints from the given |timepoint_pool|. +// The |out_timepoints| needs to be further initialized with proper kind and +// payload values. +static iree_status_t iree_hal_hsa_timepoint_pool_acquire_internal( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints) { + IREE_ASSERT_ARGUMENT(timepoint_pool); + if (!timepoint_count) return iree_ok_status(); + IREE_ASSERT_ARGUMENT(out_timepoints); + IREE_TRACE_ZONE_BEGIN(z0); + + // We'll try to get what we can from the pool and fall back to initializing + // new iree_hal_hsa_timepoint_t objects. + iree_host_size_t remaining_count = timepoint_count; + + // Try first to grab from the pool. + iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex); + iree_host_size_t from_pool_count = + iree_min(timepoint_pool->available_count, timepoint_count); + if (from_pool_count > 0) { + iree_host_size_t pool_base_index = + timepoint_pool->available_count - from_pool_count; + memcpy(out_timepoints, &timepoint_pool->available_list[pool_base_index], + from_pool_count * sizeof(*timepoint_pool->available_list)); + timepoint_pool->available_count -= from_pool_count; + remaining_count -= from_pool_count; + } + iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex); + + // Allocate the rest of the timepoints. 
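+  // (These overflow timepoints are not tracked by the pool; on release they
+  // are freed directly if the pool is already at capacity--see
+  // iree_hal_hsa_timepoint_pool_release.)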
+ if (remaining_count > 0) { + IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-acquire"); + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < remaining_count; ++i) { + status = iree_hal_hsa_timepoint_allocate( + timepoint_pool, timepoint_pool->host_allocator, + &out_timepoints[from_pool_count + i]); + if (!iree_status_is_ok(status)) { + // Must release all timepoints we've acquired so far. + iree_hal_hsa_timepoint_pool_release(timepoint_pool, from_pool_count + i, + out_timepoints); + IREE_TRACE_ZONE_END(z1); + IREE_TRACE_ZONE_END(z0); + return status; + } + } + IREE_TRACE_ZONE_END(z1); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_status_t iree_hal_hsa_timepoint_pool_acquire_host_wait( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Acquire host events to wrap up. This should happen before acquiring the + // timepoints to avoid nested locks. + iree_event_t* host_events = iree_alloca( + timepoint_count * sizeof((*out_timepoints)->timepoint.host_wait)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_event_pool_acquire(timepoint_pool->host_event_pool, + timepoint_count, host_events)); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_timepoint_pool_acquire_internal( + timepoint_pool, timepoint_count, out_timepoints)); + for (iree_host_size_t i = 0; i < timepoint_count; ++i) { + out_timepoints[i]->kind = IREE_HAL_HSA_TIMEPOINT_KIND_HOST_WAIT; + out_timepoints[i]->timepoint.host_wait = host_events[i]; + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_status_t iree_hal_hsa_timepoint_pool_acquire_device_signal( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Acquire device events to wrap up. This should happen before acquiring the + // timepoints to avoid nested locks. + iree_hal_hsa_event_t** device_events = iree_alloca( + timepoint_count * sizeof((*out_timepoints)->timepoint.device_signal)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_event_pool_acquire(timepoint_pool->device_event_pool, + timepoint_count, device_events)); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hsa_timepoint_pool_acquire_internal( + timepoint_pool, timepoint_count, out_timepoints)); + for (iree_host_size_t i = 0; i < timepoint_count; ++i) { + out_timepoints[i]->kind = IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_SIGNAL; + out_timepoints[i]->timepoint.device_signal = device_events[i]; + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_status_t iree_hal_hsa_timepoint_pool_acquire_device_wait( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Acquire device events to wrap up. This should happen before acquiring the + // timepoints to avoid nested locks. 
+  iree_hal_hsa_event_t** device_events = iree_alloca(
+      timepoint_count * sizeof((*out_timepoints)->timepoint.device_wait));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hsa_event_pool_acquire(timepoint_pool->device_event_pool,
+                                          timepoint_count, device_events));
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hsa_timepoint_pool_acquire_internal(
+              timepoint_pool, timepoint_count, out_timepoints));
+  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
+    out_timepoints[i]->kind = IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_WAIT;
+    out_timepoints[i]->timepoint.device_wait = device_events[i];
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+void iree_hal_hsa_timepoint_pool_release(
+    iree_hal_hsa_timepoint_pool_t* timepoint_pool,
+    iree_host_size_t timepoint_count, iree_hal_hsa_timepoint_t** timepoints) {
+  IREE_ASSERT_ARGUMENT(timepoint_pool);
+  if (!timepoint_count) return;
+  IREE_ASSERT_ARGUMENT(timepoints);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Release the wrapped host/device events. This should happen before
+  // acquiring the timepoint pool's lock given that the host/device event pool
+  // has its own internal lock too.
+  // TODO: Release in batch to avoid lock overhead from separate calls.
+  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
+    switch (timepoints[i]->kind) {
+      case IREE_HAL_HSA_TIMEPOINT_KIND_HOST_WAIT:
+        iree_event_pool_release(timepoint_pool->host_event_pool, 1,
+                                &timepoints[i]->timepoint.host_wait);
+        break;
+      case IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_SIGNAL:
+        iree_hal_hsa_event_release(timepoints[i]->timepoint.device_signal);
+        break;
+      case IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_WAIT:
+        iree_hal_hsa_event_release(timepoints[i]->timepoint.device_wait);
+        break;
+      default:
+        break;
+    }
+  }
+
+  // We'll try to release all we can back to the pool and then deinitialize
+  // the ones that won't fit.
+  iree_host_size_t remaining_count = timepoint_count;
+
+  // Try first to release to the pool.
+  iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex);
+  iree_host_size_t to_pool_count = iree_min(
+      timepoint_pool->available_capacity - timepoint_pool->available_count,
+      timepoint_count);
+  if (to_pool_count > 0) {
+    for (iree_host_size_t i = 0; i < to_pool_count; ++i) {
+      iree_hal_hsa_timepoint_clear(timepoints[i]);
+    }
+    iree_host_size_t pool_base_index = timepoint_pool->available_count;
+    memcpy(&timepoint_pool->available_list[pool_base_index], timepoints,
+           to_pool_count * sizeof(*timepoint_pool->available_list));
+    timepoint_pool->available_count += to_pool_count;
+    remaining_count -= to_pool_count;
+  }
+  iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex);
+
+  // Deallocate the rest of the timepoints; we clear them first only so that
+  // the kind assertion in iree_hal_hsa_timepoint_free holds.
+  if (remaining_count > 0) {
+    IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-release");
+    for (iree_host_size_t i = 0; i < remaining_count; ++i) {
+      iree_hal_hsa_timepoint_clear(timepoints[to_pool_count + i]);
+      iree_hal_hsa_timepoint_free(timepoints[to_pool_count + i]);
+    }
+    IREE_TRACE_ZONE_END(z1);
+  }
+  IREE_TRACE_ZONE_END(z0);
+}
diff --git a/runtime/src/iree-amd-aie/driver/hsa/timepoint_pool.h b/runtime/src/iree-amd-aie/driver/hsa/timepoint_pool.h
new file mode 100644
index 000000000..e1ad69012
--- /dev/null
+++ b/runtime/src/iree-amd-aie/driver/hsa/timepoint_pool.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_HSA_TIMEPOINT_POOL_H_
+#define IREE_EXPERIMENTAL_HSA_TIMEPOINT_POOL_H_
+
+#include "iree-amd-aie/driver/hsa/event_pool.h"
+#include "iree/base/api.h"
+#include "iree/base/internal/event_pool.h"
+#include "iree/hal/utils/semaphore_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_timepoint_t
+//===----------------------------------------------------------------------===//
+
+// Forward declaration of the timepoint pool.
+typedef struct iree_hal_hsa_timepoint_pool_t iree_hal_hsa_timepoint_pool_t;
+
+// An enum to identify the timepoint kind in iree_hal_hsa_timepoint_t objects.
+typedef enum iree_hal_hsa_timepoint_kind_e {
+  // None; for uninitialized timepoint objects.
+  IREE_HAL_HSA_TIMEPOINT_KIND_NONE = 0,
+  // A timepoint waited by the host.
+  IREE_HAL_HSA_TIMEPOINT_KIND_HOST_WAIT,
+  // A timepoint signaled by the device.
+  IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_SIGNAL,
+  // A timepoint waited by the device.
+  IREE_HAL_HSA_TIMEPOINT_KIND_DEVICE_WAIT,
+} iree_hal_hsa_timepoint_kind_t;
+
+// An object that wraps a host iree_event_t or device iree_hal_hsa_event_t to
+// represent wait/signal of a timepoint on a timeline.
+//
+// iree_hal_hsa_timepoint_t objects cannot be directly created; they should be
+// acquired from the timepoint pool and released back to the pool once done.
+//
+// Thread-compatible; a timepoint is typically only accessed by one thread.
+typedef struct iree_hal_hsa_timepoint_t {
+  // Base timepoint structure providing intrusive linked list pointers and
+  // timepoint callback mechanisms.
+  iree_hal_semaphore_timepoint_t base;
+
+  // The allocator used to create the timepoint.
+  iree_allocator_t host_allocator;
+
+  // The timepoint pool that owns this timepoint.
+  iree_hal_hsa_timepoint_pool_t* pool;
+
+  iree_hal_hsa_timepoint_kind_t kind;
+  union {
+    iree_event_t host_wait;
+    iree_hal_hsa_event_t* device_signal;
+    // The device event to wait. NULL means no device event available to wait
+    // for this timepoint at the moment.
+    iree_hal_hsa_event_t* device_wait;
+  } timepoint;
+} iree_hal_hsa_timepoint_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_hsa_timepoint_pool_t
+//===----------------------------------------------------------------------===//
+
+// A simple pool of iree_hal_hsa_timepoint_t objects to recycle.
+//
+// Thread-safe; multiple threads may acquire and release timepoints from the
+// pool.
+typedef struct iree_hal_hsa_timepoint_pool_t iree_hal_hsa_timepoint_pool_t;
+
+// Allocates a new timepoint pool with up to |available_capacity| timepoints.
+//
+// Extra timepoint requests beyond the capacity are directly created and
+// destroyed without pooling.
+iree_status_t iree_hal_hsa_timepoint_pool_allocate(
+    iree_event_pool_t* host_event_pool,
+    iree_hal_hsa_event_pool_t* device_event_pool,
+    iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+    iree_hal_hsa_timepoint_pool_t** out_timepoint_pool);
+
+// Deallocates a timepoint pool and destroys all timepoints.
+//
+// All timepoints that were acquired from the pool must have already been
+// released back to it prior to deallocation.
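+//
+// A typical lifecycle sketch (hypothetical caller, not part of this change):
+//   iree_hal_hsa_timepoint_pool_t* pool = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_hsa_timepoint_pool_allocate(
+//       host_event_pool, device_event_pool, /*available_capacity=*/32,
+//       host_allocator, &pool));
+//   // ... acquire/release timepoints ...
+//   iree_hal_hsa_timepoint_pool_free(pool);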
+void iree_hal_hsa_timepoint_pool_free( + iree_hal_hsa_timepoint_pool_t* timepoint_pool); + +// Acquires one or more timepoints from the timepoint pool. +// +// |out_timepoints| are owned by the caller and must be kept live until the +// timepoints have been reached, or cancelled by the caller. +iree_status_t iree_hal_hsa_timepoint_pool_acquire_host_wait( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints); +iree_status_t iree_hal_hsa_timepoint_pool_acquire_device_signal( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints); +iree_status_t iree_hal_hsa_timepoint_pool_acquire_device_wait( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, + iree_hal_hsa_timepoint_t** out_timepoints); + +// Releases one or more timepoints back to the timepoint pool. +void iree_hal_hsa_timepoint_pool_release( + iree_hal_hsa_timepoint_pool_t* timepoint_pool, + iree_host_size_t timepoint_count, iree_hal_hsa_timepoint_t** timepoints); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_EXPERIMENTAL_HSA_TIMEPOINT_POOL_H_ diff --git a/third_party/ROCR-Runtime b/third_party/ROCR-Runtime new file mode 160000 index 000000000..cb957298f --- /dev/null +++ b/third_party/ROCR-Runtime @@ -0,0 +1 @@ +Subproject commit cb957298f13a8a0c8bb30090713d5dc9121e2d52