Skip to content

Commit

Permalink
Add query to retrieve adapter handle from platform.
Browse files Browse the repository at this point in the history
  • Loading branch information
aarongreig committed Nov 21, 2024
1 parent 50f66ae commit 07d446f
Show file tree
Hide file tree
Showing 16 changed files with 126 additions and 17 deletions.
4 changes: 3 additions & 1 deletion include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,8 @@ typedef enum ur_platform_info_t {
///< info needs to be dynamically queried.
UR_PLATFORM_INFO_BACKEND = 6, ///< [::ur_platform_backend_t] The backend of the platform. Identifies the
///< native backend adapter implementing this platform.
UR_PLATFORM_INFO_ADAPTER = 7, ///< [::ur_adapter_handle_t] The adapter handle associated with the
///< platform.
/// @cond
UR_PLATFORM_INFO_FORCE_UINT32 = 0x7fffffff
/// @endcond
Expand All @@ -1112,7 +1114,7 @@ typedef enum ur_platform_info_t {
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `::UR_PLATFORM_INFO_BACKEND < propName`
/// + `::UR_PLATFORM_INFO_ADAPTER < propName`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION
/// + If `propName` is not supported by the adapter.
/// - ::UR_RESULT_ERROR_INVALID_SIZE
Expand Down
36 changes: 28 additions & 8 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,9 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_platform_info_t value)
case UR_PLATFORM_INFO_BACKEND:
os << "UR_PLATFORM_INFO_BACKEND";
break;
case UR_PLATFORM_INFO_ADAPTER:
os << "UR_PLATFORM_INFO_ADAPTER";
break;
default:
os << "unknown enumerator";
break;
Expand Down Expand Up @@ -2073,6 +2076,19 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_platform_in

os << ")";
} break;
case UR_PLATFORM_INFO_ADAPTER: {
const ur_adapter_handle_t *tptr = (const ur_adapter_handle_t *)ptr;
if (sizeof(ur_adapter_handle_t) > size) {
os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_adapter_handle_t) << ")";
return UR_RESULT_ERROR_INVALID_SIZE;
}
os << (const void *)(tptr) << " (";

ur::details::printPtr(os,
*tptr);

os << ")";
} break;
default:
os << "unknown enumerator";
return UR_RESULT_ERROR_INVALID_ENUMERATION;
Expand Down Expand Up @@ -15107,16 +15123,20 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
os << *(params->pnumEventsInWaitList);

os << ", ";
os << ".phEventWaitList = {";
for (size_t i = 0; *(params->pphEventWaitList) != NULL && i < *params->pnumEventsInWaitList; ++i) {
if (i != 0) {
os << ", ";
}
os << ".phEventWaitList = ";
ur::details::printPtr(os, reinterpret_cast<const void *>(*(params->pphEventWaitList)));
if (*(params->pphEventWaitList) != NULL) {
os << " {";
for (size_t i = 0; i < *params->pnumEventsInWaitList; ++i) {
if (i != 0) {
os << ", ";
}

ur::details::printPtr(os,
(*(params->pphEventWaitList))[i]);
ur::details::printPtr(os,
(*(params->pphEventWaitList))[i]);
}
os << "}";
}
os << "}";

os << ", ";
os << ".phEvent = ";
Expand Down
4 changes: 3 additions & 1 deletion scripts/core/platform.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ etors:
- name: BACKEND
value: "6"
desc: "[$x_platform_backend_t] The backend of the platform. Identifies the native backend adapter implementing this platform."

- name: ADAPTER
value: "7"
desc: "[$x_adapter_handle_t] The adapter handle associated with the platform."
--- #--------------------------------------------------------------------------
type: function
desc: "Retrieves various information about platform"
Expand Down
4 changes: 4 additions & 0 deletions source/adapters/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "platform.hpp"
#include "adapter.hpp"
#include "common.hpp"
#include "context.hpp"
#include "device.hpp"
Expand Down Expand Up @@ -41,6 +42,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
case UR_PLATFORM_INFO_BACKEND: {
return ReturnValue(UR_PLATFORM_BACKEND_CUDA);
}
case UR_PLATFORM_INFO_ADAPTER: {
return ReturnValue(&adapter);
}
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
Expand Down
4 changes: 4 additions & 0 deletions source/adapters/hip/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "platform.hpp"
#include "adapter.hpp"
#include "context.hpp"

UR_APIEXPORT ur_result_t UR_APICALL
Expand All @@ -34,6 +35,9 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
case UR_PLATFORM_INFO_EXTENSIONS: {
return ReturnValue("");
}
case UR_PLATFORM_INFO_ADAPTER: {
return ReturnValue(&adapter);
}
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ ur_result_t urPlatformGetInfo(
return ReturnValue(Platform->ZeDriverApiVersion.c_str());
case UR_PLATFORM_INFO_BACKEND:
return ReturnValue(UR_PLATFORM_BACKEND_LEVEL_ZERO);
case UR_PLATFORM_INFO_ADAPTER:
return ReturnValue(GlobalAdapter);
default:
logger::debug("urPlatformGetInfo: unrecognized ParamName");
return UR_RESULT_ERROR_INVALID_VALUE;
Expand Down
1 change: 1 addition & 0 deletions source/adapters/native_cpu/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//
//===----------------------------------------------------------------------===//

#include "adapter.hpp"
#include "common.hpp"
#include "ur_api.h"

Expand Down
13 changes: 13 additions & 0 deletions source/adapters/native_cpu/adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===---------------- adapter.hpp - Native CPU Adapter --------------------===//
//
// Copyright (C) 2024 Intel Corporation
//
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
// Exceptions. See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

struct ur_adapter_handle_t_;

extern ur_adapter_handle_t_ Adapter;
5 changes: 3 additions & 2 deletions source/adapters/native_cpu/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "platform.hpp"
#include "adapter.hpp"
#include "common.hpp"

#include "ur/ur.hpp"
Expand Down Expand Up @@ -75,9 +76,9 @@ urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName,
return ReturnValue("");

case UR_PLATFORM_INFO_BACKEND:
// TODO(alcpz): PR with this enum value at
// https://github.com/oneapi-src/unified-runtime
return ReturnValue(UR_PLATFORM_BACKEND_NATIVE_CPU);
case UR_PLATFORM_INFO_ADAPTER:
return ReturnValue(&Adapter);
default:
DIE_NO_IMPLEMENTATION;
}
Expand Down
3 changes: 3 additions & 0 deletions source/adapters/opencl/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "platform.hpp"
#include "adapter.hpp"

ur_result_t cl_adapter::getPlatformVersion(cl_platform_id Plat,
oclv::OpenCLVersion &Version) {
Expand Down Expand Up @@ -57,6 +58,8 @@ urPlatformGetInfo(ur_platform_handle_t hPlatform, ur_platform_info_t propName,
switch (static_cast<uint32_t>(propName)) {
case UR_PLATFORM_INFO_BACKEND:
return ReturnValue(UR_PLATFORM_BACKEND_OPENCL);
case UR_PLATFORM_INFO_ADAPTER:
return ReturnValue(ur::cl::getAdapter());
case UR_PLATFORM_INFO_NAME:
case UR_PLATFORM_INFO_VENDOR_NAME:
case UR_PLATFORM_INFO_VERSION:
Expand Down
2 changes: 1 addition & 1 deletion source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo(
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (UR_PLATFORM_INFO_BACKEND < propName) {
if (UR_PLATFORM_INFO_ADAPTER < propName) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

Expand Down
33 changes: 33 additions & 0 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,43 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetInfo(
// convert loader handle to platform handle
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;

// this value is needed for converting adapter handles to loader handles
size_t sizeret = 0;
if (pPropSizeRet == NULL) {
pPropSizeRet = &sizeret;
}

// forward to device-platform
result =
pfnGetInfo(hPlatform, propName, propSize, pPropValue, pPropSizeRet);

if (UR_RESULT_SUCCESS != result) {
return result;
}

try {
if (pPropValue != nullptr) {
switch (propName) {
case UR_PLATFORM_INFO_ADAPTER: {
ur_adapter_handle_t *handles =
reinterpret_cast<ur_adapter_handle_t *>(pPropValue);
size_t nelements = *pPropSizeRet / sizeof(ur_adapter_handle_t);
for (size_t i = 0; i < nelements; ++i) {
if (handles[i] != nullptr) {
handles[i] = reinterpret_cast<ur_adapter_handle_t>(
context->factories.ur_adapter_factory.getInstance(
handles[i], dditable));
}
}
} break;
default: {
} break;
}
}
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand Down
2 changes: 1 addition & 1 deletion source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ ur_result_t UR_APICALL urPlatformGet(
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `::UR_PLATFORM_INFO_BACKEND < propName`
/// + `::UR_PLATFORM_INFO_ADAPTER < propName`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION
/// + If `propName` is not supported by the adapter.
/// - ::UR_RESULT_ERROR_INVALID_SIZE
Expand Down
2 changes: 1 addition & 1 deletion source/ur_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ ur_result_t UR_APICALL urPlatformGet(
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `::UR_PLATFORM_INFO_BACKEND < propName`
/// + `::UR_PLATFORM_INFO_ADAPTER < propName`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION
/// + If `propName` is not supported by the adapter.
/// - ::UR_RESULT_ERROR_INVALID_SIZE
Expand Down
26 changes: 24 additions & 2 deletions test/conformance/platform/urPlatformGetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ INSTANTIATE_TEST_SUITE_P(
urPlatformGetInfo, urPlatformGetInfoTest,
::testing::Values(UR_PLATFORM_INFO_NAME, UR_PLATFORM_INFO_VENDOR_NAME,
UR_PLATFORM_INFO_VERSION, UR_PLATFORM_INFO_EXTENSIONS,
UR_PLATFORM_INFO_PROFILE, UR_PLATFORM_INFO_BACKEND),
UR_PLATFORM_INFO_PROFILE, UR_PLATFORM_INFO_BACKEND,
UR_PLATFORM_INFO_ADAPTER),
[](const ::testing::TestParamInfo<ur_platform_info_t> &info) {
std::stringstream ss;
ss << info.param;
Expand All @@ -38,8 +39,29 @@ TEST_P(urPlatformGetInfoTest, Success) {
std::vector<char> name(size);
ASSERT_SUCCESS(
urPlatformGetInfo(platform, info_type, size, name.data(), nullptr));
if (info_type != UR_PLATFORM_INFO_BACKEND) {
switch (info_type) {
case UR_PLATFORM_INFO_NAME:
case UR_PLATFORM_INFO_VENDOR_NAME:
case UR_PLATFORM_INFO_VERSION:
case UR_PLATFORM_INFO_EXTENSIONS:
case UR_PLATFORM_INFO_PROFILE: {
ASSERT_EQ(size, std::strlen(name.data()) + 1);
break;
}
case UR_PLATFORM_INFO_BACKEND: {
ASSERT_EQ(size, sizeof(ur_platform_backend_t));
break;
}
case UR_PLATFORM_INFO_ADAPTER: {
auto queried_adapter =
*reinterpret_cast<ur_adapter_handle_t *>(name.data());
auto adapter_found =
std::find(adapters.begin(), adapters.end(), queried_adapter);
ASSERT_NE(adapter_found, adapters.end());
break;
}
default:
break;
}
}

Expand Down
2 changes: 2 additions & 0 deletions tools/urinfo/urinfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ inline void printPlatformInfos(ur_platform_handle_t hPlatform,
std::cout << prefix;
printPlatformInfo<ur_platform_backend_t>(hPlatform,
UR_PLATFORM_INFO_BACKEND);
std::cout << prefix;
printPlatformInfo<ur_adapter_handle_t>(hPlatform, UR_PLATFORM_INFO_ADAPTER);
}

inline void printDeviceInfos(ur_device_handle_t hDevice,
Expand Down

0 comments on commit 07d446f

Please sign in to comment.