diff --git a/include/ur_api.h b/include/ur_api.h index eb8b07221c..3dd476328f 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -1087,6 +1087,8 @@ typedef enum ur_platform_info_t { ///< info needs to be dynamically queried. UR_PLATFORM_INFO_BACKEND = 6, ///< [::ur_platform_backend_t] The backend of the platform. Identifies the ///< native backend adapter implementing this platform. + UR_PLATFORM_INFO_ADAPTER = 7, ///< [::ur_adapter_handle_t] The adapter handle associated with the + ///< platform. /// @cond UR_PLATFORM_INFO_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -1112,7 +1114,7 @@ typedef enum ur_platform_info_t { /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hPlatform` /// - ::UR_RESULT_ERROR_INVALID_ENUMERATION -/// + `::UR_PLATFORM_INFO_BACKEND < propName` +/// + `::UR_PLATFORM_INFO_ADAPTER < propName` /// - ::UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION /// + If `propName` is not supported by the adapter. /// - ::UR_RESULT_ERROR_INVALID_SIZE diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 8888a74f91..40e0c1793d 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -2024,6 +2024,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_platform_info_t value) case UR_PLATFORM_INFO_BACKEND: os << "UR_PLATFORM_INFO_BACKEND"; break; + case UR_PLATFORM_INFO_ADAPTER: + os << "UR_PLATFORM_INFO_ADAPTER"; + break; default: os << "unknown enumerator"; break; @@ -2077,6 +2080,19 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_platform_in os << ")"; } break; + case UR_PLATFORM_INFO_ADAPTER: { + const ur_adapter_handle_t *tptr = (const ur_adapter_handle_t *)ptr; + if (sizeof(ur_adapter_handle_t) > size) { + os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_adapter_handle_t) << ")"; + return UR_RESULT_ERROR_INVALID_SIZE; + } + os << (const void *)(tptr) << " ("; + + ur::details::printPtr(os, + *tptr); + + os << ")"; + } break; default: os << "unknown enumerator"; return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git a/scripts/core/platform.yml b/scripts/core/platform.yml index 997f4918ee..d4d7ef6a80 100644 --- a/scripts/core/platform.yml +++ b/scripts/core/platform.yml @@ -77,7 +77,9 @@ etors: - name: BACKEND value: "6" desc: "[$x_platform_backend_t] The backend of the platform. Identifies the native backend adapter implementing this platform." - + - name: ADAPTER + value: "7" + desc: "[$x_adapter_handle_t] The adapter handle associated with the platform." --- #-------------------------------------------------------------------------- type: function desc: "Retrieves various information about platform" diff --git a/source/adapters/cuda/platform.cpp b/source/adapters/cuda/platform.cpp index 7ce0bba9e7..20518494f7 100644 --- a/source/adapters/cuda/platform.cpp +++ b/source/adapters/cuda/platform.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "platform.hpp" +#include "adapter.hpp" #include "common.hpp" #include "context.hpp" #include "device.hpp" @@ -41,6 +42,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( case UR_PLATFORM_INFO_BACKEND: { return ReturnValue(UR_PLATFORM_BACKEND_CUDA); } + case UR_PLATFORM_INFO_ADAPTER: { + return ReturnValue(&adapter); + } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; } diff --git a/source/adapters/hip/platform.cpp b/source/adapters/hip/platform.cpp index 007889f138..fa0b07cc82 100644 --- a/source/adapters/hip/platform.cpp +++ b/source/adapters/hip/platform.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "platform.hpp" +#include "adapter.hpp" #include "context.hpp" UR_APIEXPORT ur_result_t UR_APICALL @@ -34,6 +35,9 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName, case UR_PLATFORM_INFO_EXTENSIONS: { return ReturnValue(""); } + case UR_PLATFORM_INFO_ADAPTER: { + return ReturnValue(&adapter); + } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; } diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp index 0237b62863..18a417ff1b 100644 --- a/source/adapters/level_zero/platform.cpp +++ b/source/adapters/level_zero/platform.cpp @@ -95,6 +95,8 @@ ur_result_t urPlatformGetInfo( return ReturnValue(Platform->ZeDriverApiVersion.c_str()); case UR_PLATFORM_INFO_BACKEND: return ReturnValue(UR_PLATFORM_BACKEND_LEVEL_ZERO); + case UR_PLATFORM_INFO_ADAPTER: + return ReturnValue(GlobalAdapter); default: logger::debug("urPlatformGetInfo: unrecognized ParamName"); return UR_RESULT_ERROR_INVALID_VALUE; diff --git a/source/adapters/native_cpu/CMakeLists.txt b/source/adapters/native_cpu/CMakeLists.txt index 56cfc577d8..17467bfdef 100644 --- a/source/adapters/native_cpu/CMakeLists.txt +++ b/source/adapters/native_cpu/CMakeLists.txt @@ -9,6 +9,7 @@ set(TARGET_NAME ur_adapter_native_cpu) add_ur_adapter(${TARGET_NAME} SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/adapter.hpp ${CMAKE_CURRENT_SOURCE_DIR}/adapter.cpp ${CMAKE_CURRENT_SOURCE_DIR}/command_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp diff --git a/source/adapters/native_cpu/adapter.cpp b/source/adapters/native_cpu/adapter.cpp index 2b5b95ccd0..01fffeb01e 100644 --- a/source/adapters/native_cpu/adapter.cpp +++ b/source/adapters/native_cpu/adapter.cpp @@ -8,6 +8,7 @@ // //===----------------------------------------------------------------------===// +#include "adapter.hpp" #include "common.hpp" #include "ur_api.h" diff --git a/source/adapters/native_cpu/adapter.hpp b/source/adapters/native_cpu/adapter.hpp new file mode 100644 index 0000000000..2607aeb542 --- /dev/null +++ b/source/adapters/native_cpu/adapter.hpp @@ -0,0 +1,13 @@ +//===---------------- adapter.hpp - Native CPU Adapter --------------------===// +// +// Copyright (C) 2024 Intel Corporation +// +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM +// Exceptions. See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +struct ur_adapter_handle_t_; + +extern ur_adapter_handle_t_ Adapter; diff --git a/source/adapters/native_cpu/platform.cpp b/source/adapters/native_cpu/platform.cpp index 840f18f8b3..8e55037079 100644 --- a/source/adapters/native_cpu/platform.cpp +++ b/source/adapters/native_cpu/platform.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "platform.hpp" +#include "adapter.hpp" #include "common.hpp" #include "ur/ur.hpp" @@ -75,9 +76,9 @@ urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName, return ReturnValue(""); case UR_PLATFORM_INFO_BACKEND: - // TODO(alcpz): PR with this enum value at - // https://github.com/oneapi-src/unified-runtime return ReturnValue(UR_PLATFORM_BACKEND_NATIVE_CPU); + case UR_PLATFORM_INFO_ADAPTER: + return ReturnValue(&Adapter); default: DIE_NO_IMPLEMENTATION; } diff --git a/source/adapters/opencl/platform.cpp b/source/adapters/opencl/platform.cpp index b6d3a77cee..1400f27cf4 100644 --- a/source/adapters/opencl/platform.cpp +++ b/source/adapters/opencl/platform.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "platform.hpp" +#include "adapter.hpp" ur_result_t cl_adapter::getPlatformVersion(cl_platform_id Plat, oclv::OpenCLVersion &Version) { @@ -57,6 +58,8 @@ urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName, switch (static_cast(propName)) { case UR_PLATFORM_INFO_BACKEND: return ReturnValue(UR_PLATFORM_BACKEND_OPENCL); + case UR_PLATFORM_INFO_ADAPTER: + return ReturnValue(ur::cl::getAdapter()); case UR_PLATFORM_INFO_NAME: case UR_PLATFORM_INFO_VENDOR_NAME: case UR_PLATFORM_INFO_VERSION: diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index b3969de10f..6e96efe6bf 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -280,7 +280,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( return UR_RESULT_ERROR_INVALID_NULL_POINTER; } - if (UR_PLATFORM_INFO_BACKEND < propName) { + if (UR_PLATFORM_INFO_ADAPTER < propName) { return UR_RESULT_ERROR_INVALID_ENUMERATION; } diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 86a6ad95a0..f482eb3560 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -289,10 +289,43 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo( // convert loader handle to platform handle hPlatform = reinterpret_cast(hPlatform)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hPlatform, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_PLATFORM_INFO_ADAPTER: { + ur_adapter_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_adapter_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + context->factories.ur_adapter_factory.getInstance( + handles[i], dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 3340363737..3c6822d613 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -557,7 +557,7 @@ ur_result_t UR_APICALL urPlatformGet( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hPlatform` /// - ::UR_RESULT_ERROR_INVALID_ENUMERATION -/// + `::UR_PLATFORM_INFO_BACKEND < propName` +/// + `::UR_PLATFORM_INFO_ADAPTER < propName` /// - ::UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION /// + If `propName` is not supported by the adapter. /// - ::UR_RESULT_ERROR_INVALID_SIZE diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 853d61472e..7f7eb65d40 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -501,7 +501,7 @@ ur_result_t UR_APICALL urPlatformGet( /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hPlatform` /// - ::UR_RESULT_ERROR_INVALID_ENUMERATION -/// + `::UR_PLATFORM_INFO_BACKEND < propName` +/// + `::UR_PLATFORM_INFO_ADAPTER < propName` /// - ::UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION /// + If `propName` is not supported by the adapter. /// - ::UR_RESULT_ERROR_INVALID_SIZE diff --git a/test/conformance/platform/urPlatformGetInfo.cpp b/test/conformance/platform/urPlatformGetInfo.cpp index 1dc92b26d7..3973b8ee6b 100644 --- a/test/conformance/platform/urPlatformGetInfo.cpp +++ b/test/conformance/platform/urPlatformGetInfo.cpp @@ -19,7 +19,8 @@ INSTANTIATE_TEST_SUITE_P( urPlatformGetInfo, urPlatformGetInfoTest, ::testing::Values(UR_PLATFORM_INFO_NAME, UR_PLATFORM_INFO_VENDOR_NAME, UR_PLATFORM_INFO_VERSION, UR_PLATFORM_INFO_EXTENSIONS, - UR_PLATFORM_INFO_PROFILE, UR_PLATFORM_INFO_BACKEND), + UR_PLATFORM_INFO_PROFILE, UR_PLATFORM_INFO_BACKEND, + UR_PLATFORM_INFO_ADAPTER), [](const ::testing::TestParamInfo &info) { std::stringstream ss; ss << info.param; @@ -38,8 +39,29 @@ TEST_P(urPlatformGetInfoTest, Success) { std::vector name(size); ASSERT_SUCCESS( urPlatformGetInfo(platform, info_type, size, name.data(), nullptr)); - if (info_type != UR_PLATFORM_INFO_BACKEND) { + switch (info_type) { + case UR_PLATFORM_INFO_NAME: + case UR_PLATFORM_INFO_VENDOR_NAME: + case UR_PLATFORM_INFO_VERSION: + case UR_PLATFORM_INFO_EXTENSIONS: + case UR_PLATFORM_INFO_PROFILE: { ASSERT_EQ(size, std::strlen(name.data()) + 1); + break; + } + case UR_PLATFORM_INFO_BACKEND: { + ASSERT_EQ(size, sizeof(ur_platform_backend_t)); + break; + } + case UR_PLATFORM_INFO_ADAPTER: { + auto queried_adapter = + *reinterpret_cast(name.data()); + auto adapter_found = + std::find(adapters.begin(), adapters.end(), queried_adapter); + ASSERT_NE(adapter_found, adapters.end()); + break; + } + default: + break; } } diff --git a/tools/urinfo/urinfo.hpp b/tools/urinfo/urinfo.hpp index 37c7a80328..b1dcb9e57e 100644 --- a/tools/urinfo/urinfo.hpp +++ b/tools/urinfo/urinfo.hpp @@ -45,6 +45,8 @@ inline void printPlatformInfos(ur_platform_handle_t hPlatform, std::cout << prefix; printPlatformInfo(hPlatform, UR_PLATFORM_INFO_BACKEND); + std::cout << prefix; + printPlatformInfo(hPlatform, UR_PLATFORM_INFO_ADAPTER); } inline void printDeviceInfos(ur_device_handle_t hDevice,