From 0ed0037496627a6bf0b2be9ce20913a8509a398d Mon Sep 17 00:00:00 2001 From: "Agarwal, Udit" Date: Fri, 27 Jun 2025 03:37:52 +0200 Subject: [PATCH] [SYCL][NFC] Pass adapter by reference in backend[_impl] --- sycl/source/backend.cpp | 66 ++++++++++++++-------------- sycl/source/context.cpp | 2 +- sycl/source/detail/adapter_impl.hpp | 2 +- sycl/source/detail/allowlist.cpp | 2 +- sycl/source/detail/context_impl.cpp | 6 +-- sycl/source/detail/context_impl.hpp | 4 +- sycl/source/detail/device_impl.hpp | 2 +- sycl/source/detail/platform_impl.cpp | 16 +++---- sycl/source/detail/platform_impl.hpp | 21 ++++----- sycl/source/device.cpp | 2 +- sycl/source/platform.cpp | 2 +- 11 files changed, 60 insertions(+), 65 deletions(-) diff --git a/sycl/source/backend.cpp b/sycl/source/backend.cpp index 091d36344654f..c2c9ff3c107f7 100644 --- a/sycl/source/backend.cpp +++ b/sycl/source/backend.cpp @@ -30,16 +30,16 @@ namespace sycl { inline namespace _V1 { namespace detail { -static const AdapterPtr &getAdapter(backend Backend) { +static const adapter_impl &getAdapter(backend Backend) { switch (Backend) { case backend::opencl: - return ur::getAdapter(); + return *ur::getAdapter(); case backend::ext_oneapi_level_zero: - return ur::getAdapter(); + return *ur::getAdapter(); case backend::ext_oneapi_cuda: - return ur::getAdapter(); + return *ur::getAdapter(); case backend::ext_oneapi_hip: - return ur::getAdapter(); + return *ur::getAdapter(); default: throw sycl::exception( sycl::make_error_code(sycl::errc::runtime), @@ -71,12 +71,12 @@ backend convertUrBackend(ur_backend_t UrBackend) { } platform make_platform(ur_native_handle_t NativeHandle, backend Backend) { - const auto &Adapter = getAdapter(Backend); + const adapter_impl &Adapter = getAdapter(Backend); // Create UR platform first. ur_platform_handle_t UrPlatform = nullptr; - Adapter->call( - NativeHandle, Adapter->getUrAdapter(), nullptr, &UrPlatform); + Adapter.call( + NativeHandle, Adapter.getUrAdapter(), nullptr, &UrPlatform); return detail::createSyclObjFromImpl( platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter)); @@ -84,11 +84,11 @@ platform make_platform(ur_native_handle_t NativeHandle, backend Backend) { __SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle, backend Backend) { - const auto &Adapter = getAdapter(Backend); + const adapter_impl &Adapter = getAdapter(Backend); ur_device_handle_t UrDevice = nullptr; - Adapter->call( - NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice); + Adapter.call( + NativeHandle, Adapter.getUrAdapter(), nullptr, &UrDevice); // Construct the SYCL device from UR device. return detail::createSyclObjFromImpl( @@ -100,7 +100,7 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle, const async_handler &Handler, backend Backend, bool KeepOwnership, const std::vector &DeviceList) { - const auto &Adapter = getAdapter(Backend); + const adapter_impl &Adapter = getAdapter(Backend); ur_context_handle_t UrContext = nullptr; ur_context_native_properties_t Properties{}; @@ -110,8 +110,8 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle, for (const auto &Dev : DeviceList) { DeviceHandles.push_back(detail::getSyclObjImpl(Dev)->getHandleRef()); } - Adapter->call( - NativeHandle, Adapter->getUrAdapter(), DeviceHandles.size(), + Adapter.call( + NativeHandle, Adapter.getUrAdapter(), DeviceHandles.size(), DeviceHandles.data(), &Properties, &UrContext); // Construct the SYCL context from UR context. return detail::createSyclObjFromImpl(context_impl::create( @@ -125,7 +125,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle, const async_handler &Handler, backend Backend) { ur_device_handle_t UrDevice = Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr; - const auto &Adapter = getAdapter(Backend); + const adapter_impl &Adapter = getAdapter(Backend); context_impl &ContextImpl = *getSyclObjImpl(Context); if (PropList.has_property()) { @@ -155,7 +155,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle, // Create UR queue first. ur_queue_handle_t UrQueue = nullptr; - Adapter->call( + Adapter.call( NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties, &UrQueue); // Construct the SYCL queue from UR queue. @@ -171,7 +171,7 @@ __SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle, __SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle, const context &Context, bool KeepOwnership, backend Backend) { - const auto &Adapter = getAdapter(Backend); + const adapter_impl &Adapter = getAdapter(Backend); const auto &ContextImpl = getSyclObjImpl(Context); ur_event_handle_t UrEvent = nullptr; @@ -179,7 +179,7 @@ __SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle, Properties.stype = UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES; Properties.isNativeHandleOwned = !KeepOwnership; - Adapter->call( + Adapter.call( NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrEvent); event Event = detail::createSyclObjFromImpl( event_impl::create_from_handle(UrEvent, Context)); @@ -193,7 +193,7 @@ std::shared_ptr make_kernel_bundle(ur_native_handle_t NativeHandle, const context &TargetContext, bool KeepOwnership, bundle_state State, backend Backend) { - const auto &Adapter = getAdapter(Backend); + const adapter_impl &Adapter = getAdapter(Backend); const auto &ContextImpl = getSyclObjImpl(TargetContext); ur_program_handle_t UrProgram = nullptr; @@ -201,7 +201,7 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES; Properties.isNativeHandleOwned = !KeepOwnership; - Adapter->call( + Adapter.call( NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrProgram); if (UrProgram == nullptr) throw sycl::exception( @@ -214,39 +214,39 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, std::vector ProgramDevices; uint32_t NumDevices = 0; - Adapter->call( + Adapter.call( UrProgram, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(NumDevices), &NumDevices, nullptr); ProgramDevices.resize(NumDevices); - Adapter->call( + Adapter.call( UrProgram, UR_PROGRAM_INFO_DEVICES, sizeof(ur_device_handle_t) * NumDevices, ProgramDevices.data(), nullptr); for (auto &Dev : ProgramDevices) { ur_program_binary_type_t BinaryType; - Adapter->call( + Adapter.call( UrProgram, Dev, UR_PROGRAM_BUILD_INFO_BINARY_TYPE, sizeof(ur_program_binary_type_t), &BinaryType, nullptr); switch (BinaryType) { case (UR_PROGRAM_BINARY_TYPE_NONE): if (State == bundle_state::object) { - auto Res = Adapter->call_nocheck( + auto Res = Adapter.call_nocheck( UrProgram, 1, &Dev, nullptr); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { - Res = Adapter->call_nocheck( + Res = Adapter.call_nocheck( ContextImpl->getHandleRef(), UrProgram, nullptr); } - Adapter->checkUrResult(Res); + Adapter.checkUrResult(Res); } else if (State == bundle_state::executable) { - auto Res = Adapter->call_nocheck( + auto Res = Adapter.call_nocheck( UrProgram, 1, &Dev, nullptr); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { - Res = Adapter->call_nocheck( + Res = Adapter.call_nocheck( ContextImpl->getHandleRef(), UrProgram, nullptr); } - Adapter->checkUrResult(Res); + Adapter.checkUrResult(Res); } break; @@ -259,15 +259,15 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, detail::codeToString(UR_RESULT_ERROR_INVALID_VALUE)); if (State == bundle_state::executable) { ur_program_handle_t UrLinkedProgram = nullptr; - auto Res = Adapter->call_nocheck( + auto Res = Adapter.call_nocheck( ContextImpl->getHandleRef(), 1, &Dev, 1, &UrProgram, nullptr, &UrLinkedProgram); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { - Res = Adapter->call_nocheck( + Res = Adapter.call_nocheck( ContextImpl->getHandleRef(), 1, &UrProgram, nullptr, &UrLinkedProgram); } - Adapter->checkUrResult(Res); + Adapter.checkUrResult(Res); if (UrLinkedProgram != nullptr) { UrProgram = UrLinkedProgram; } @@ -351,7 +351,7 @@ kernel make_kernel(const context &TargetContext, ur_kernel_native_properties_t Properties{}; Properties.stype = UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES; Properties.isNativeHandleOwned = !KeepOwnership; - Adapter->call( + Adapter.call( NativeHandle, ContextImpl->getHandleRef(), UrProgram, &Properties, &UrKernel); diff --git a/sycl/source/context.cpp b/sycl/source/context.cpp index abb7760903316..67013cdcd8094 100644 --- a/sycl/source/context.cpp +++ b/sycl/source/context.cpp @@ -80,7 +80,7 @@ context::context(cl_context ClContext, async_handler AsyncHandler) { Adapter->call( nativeHandle, Adapter->getUrAdapter(), 0, nullptr, nullptr, &hContext); - impl = detail::context_impl::create(hContext, AsyncHandler, Adapter); + impl = detail::context_impl::create(hContext, AsyncHandler, *Adapter); } template diff --git a/sycl/source/detail/adapter_impl.hpp b/sycl/source/detail/adapter_impl.hpp index a1c16e148a13b..51fb2601d42db 100644 --- a/sycl/source/detail/adapter_impl.hpp +++ b/sycl/source/detail/adapter_impl.hpp @@ -107,7 +107,7 @@ class adapter_impl { return UrPlatforms; } - ur_adapter_handle_t getUrAdapter() { return MAdapter; } + ur_adapter_handle_t getUrAdapter() const { return MAdapter; } /// Calls the UR Api, traces the call, and returns the result. /// diff --git a/sycl/source/detail/allowlist.cpp b/sycl/source/detail/allowlist.cpp index 1dbb0a7d6b889..f5d6c86cec268 100644 --- a/sycl/source/detail/allowlist.cpp +++ b/sycl/source/detail/allowlist.cpp @@ -375,7 +375,7 @@ void applyAllowList(std::vector &UrDevices, // Get platform's backend and put it to DeviceDesc DeviceDescT DeviceDesc; platform_impl &PlatformImpl = - platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter); + platform_impl::getOrMakePlatformImpl(UrPlatform, *Adapter); backend Backend = PlatformImpl.getBackend(); for (const auto &SyclBe : getSyclBeMap()) { diff --git a/sycl/source/detail/context_impl.cpp b/sycl/source/detail/context_impl.cpp index 0bf73f191a8d0..cd29257f93519 100644 --- a/sycl/source/detail/context_impl.cpp +++ b/sycl/source/detail/context_impl.cpp @@ -62,7 +62,7 @@ context_impl::context_impl(const std::vector Devices, context_impl::context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, - const AdapterPtr &Adapter, + const adapter_impl &Adapter, const std::vector &DeviceList, bool OwnedByRuntime, private_tag) : MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler), @@ -74,12 +74,12 @@ context_impl::context_impl(ur_context_handle_t UrContext, std::vector DeviceIds; uint32_t DevicesNum = 0; // TODO catch an exception and put it to list of asynchronous exceptions - Adapter->call( + Adapter.call( MContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum, nullptr); DeviceIds.resize(DevicesNum); // TODO catch an exception and put it to list of asynchronous exceptions - Adapter->call( + Adapter.call( MContext, UR_CONTEXT_INFO_DEVICES, sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr); diff --git a/sycl/source/detail/context_impl.hpp b/sycl/source/detail/context_impl.hpp index 24a19f0a9c674..733fa39d8f9f9 100644 --- a/sycl/source/detail/context_impl.hpp +++ b/sycl/source/detail/context_impl.hpp @@ -62,12 +62,12 @@ class context_impl : public std::enable_shared_from_this { /// \param OwnedByRuntime is the flag if ownership is kept by user or /// transferred to runtime context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, - const AdapterPtr &Adapter, + const adapter_impl &Adapter, const std::vector &DeviceList, bool OwnedByRuntime, private_tag); context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, - const AdapterPtr &Adapter, private_tag tag) + const adapter_impl &Adapter, private_tag tag) : context_impl(UrContext, AsyncHandler, Adapter, std::vector{}, /*OwnedByRuntime*/ true, tag) {} diff --git a/sycl/source/detail/device_impl.hpp b/sycl/source/detail/device_impl.hpp index 391bce575d8a7..4cd5fc622082a 100644 --- a/sycl/source/detail/device_impl.hpp +++ b/sycl/source/detail/device_impl.hpp @@ -724,7 +724,7 @@ class device_impl : public std::enable_shared_from_this { CASE(info::device::platform) { return createSyclObjFromImpl( platform_impl::getOrMakePlatformImpl( - get_info_impl(), getAdapter())); + get_info_impl(), *getAdapter())); } CASE(info::device::profile) { diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index 3e71cd500739c..1da5b5a96bf97 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -32,7 +32,7 @@ namespace detail { platform_impl & platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, - const AdapterPtr &Adapter) { + const adapter_impl &Adapter) { std::shared_ptr Result; { const std::lock_guard Guard( @@ -50,8 +50,8 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, // Otherwise make the impl. Our ctor/dtor are private, so std::make_shared // needs a bit of help... struct creator : platform_impl { - creator(ur_platform_handle_t APlatform, const AdapterPtr &AAdapter) - : platform_impl(APlatform, AAdapter) {} + creator(ur_platform_handle_t APlatform, const adapter_impl &AAdapter) + : platform_impl(APlatform, &AAdapter) {} }; Result = std::make_shared(UrPlatform, Adapter); PlatformCache.emplace_back(Result); @@ -62,12 +62,12 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, platform_impl & platform_impl::getPlatformFromUrDevice(ur_device_handle_t UrDevice, - const AdapterPtr &Adapter) { + const adapter_impl &Adapter) { ur_platform_handle_t Plt = nullptr; // TODO catch an exception and put it to list // of asynchronous exceptions - Adapter->call(UrDevice, UR_DEVICE_INFO_PLATFORM, - sizeof(Plt), &Plt, nullptr); + Adapter.call(UrDevice, UR_DEVICE_INFO_PLATFORM, + sizeof(Plt), &Plt, nullptr); return getOrMakePlatformImpl(Plt, Adapter); } @@ -131,7 +131,7 @@ std::vector platform_impl::getAdapterPlatforms(AdapterPtr &Adapter, for (const auto &UrPlatform : UrPlatforms) { platform Platform = detail::createSyclObjFromImpl( - getOrMakePlatformImpl(UrPlatform, Adapter)); + getOrMakePlatformImpl(UrPlatform, *Adapter)); const bool IsBanned = IsBannedPlatform(Platform); bool HasAnyDevices = false; @@ -543,7 +543,7 @@ platform_impl::get_devices(info::device_type DeviceType) const { // The next step is to inflate the filtered UrDevices into SYCL Device // objects. - platform_impl &PlatformImpl = getOrMakePlatformImpl(MPlatform, MAdapter); + platform_impl &PlatformImpl = getOrMakePlatformImpl(MPlatform, *MAdapter); std::transform(UrDevices.begin(), UrDevices.end(), std::back_inserter(Res), [&PlatformImpl](const ur_device_handle_t UrDevice) -> device { return detail::createSyclObjFromImpl( diff --git a/sycl/source/detail/platform_impl.hpp b/sycl/source/detail/platform_impl.hpp index 70e619b8ec9b7..dff52d08a19a0 100644 --- a/sycl/source/detail/platform_impl.hpp +++ b/sycl/source/detail/platform_impl.hpp @@ -39,8 +39,12 @@ class platform_impl : public std::enable_shared_from_this { // // Platforms can only be created under `GlobalHandler`'s ownership via // `platform_impl::getOrMakePlatformImpl` method. - explicit platform_impl(ur_platform_handle_t APlatform, adapter_impl *AAdapter) - : MPlatform(APlatform), MAdapter(AAdapter) { + explicit platform_impl(ur_platform_handle_t APlatform, + const adapter_impl *AAdapter) + : MPlatform(APlatform) { + + MAdapter = const_cast(AAdapter); + // Find out backend of the platform ur_backend_t UrBackend = UR_BACKEND_UNKNOWN; AAdapter->call_nocheck( @@ -137,15 +141,6 @@ class platform_impl : public std::enable_shared_from_this { // \return the Adapter associated with this platform. const AdapterPtr &getAdapter() const { return MAdapter; } - /// Sets the platform implementation to use another adapter. - /// - /// \param AdapterPtr is a pointer to a adapter instance - /// \param Backend is the backend that we want this platform to use - void setAdapter(AdapterPtr &AdapterPtr, backend Backend) { - MAdapter = AdapterPtr; - MBackend = Backend; - } - /// Gets the native handle of the SYCL platform. /// /// \return a native handle. @@ -188,7 +183,7 @@ class platform_impl : public std::enable_shared_from_this { /// \param Adapter is the UR adapter providing the backend for the platform /// \return the platform_impl representing the UR platform static platform_impl &getOrMakePlatformImpl(ur_platform_handle_t UrPlatform, - const AdapterPtr &Adapter); + const adapter_impl &Adapter); /// Queries the cache for the specified platform based on an input device. /// If found, returns the the cached platform_impl, otherwise creates a new @@ -200,7 +195,7 @@ class platform_impl : public std::enable_shared_from_this { /// platform /// \return the platform_impl that contains the input device static platform_impl &getPlatformFromUrDevice(ur_device_handle_t UrDevice, - const AdapterPtr &Adapter); + const adapter_impl &Adapter); context_impl &khr_get_default_context(); diff --git a/sycl/source/device.cpp b/sycl/source/device.cpp index 255bde38ad75e..415ec787ecb60 100644 --- a/sycl/source/device.cpp +++ b/sycl/source/device.cpp @@ -40,7 +40,7 @@ device::device(cl_device_id DeviceId) { Adapter->call( detail::ur::cast(DeviceId), Adapter->getUrAdapter(), nullptr, &Device); - impl = detail::platform_impl::getPlatformFromUrDevice(Device, Adapter) + impl = detail::platform_impl::getPlatformFromUrDevice(Device, *Adapter) .getOrMakeDeviceImpl(Device) .shared_from_this(); __SYCL_OCL_CALL(clRetainDevice, DeviceId); diff --git a/sycl/source/platform.cpp b/sycl/source/platform.cpp index b5942e20e056b..8979cd071f38d 100644 --- a/sycl/source/platform.cpp +++ b/sycl/source/platform.cpp @@ -30,7 +30,7 @@ platform::platform(cl_platform_id PlatformId) { Adapter->call( detail::ur::cast(PlatformId), Adapter->getUrAdapter(), /* pProperties = */ nullptr, &UrPlatform); - impl = detail::platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter) + impl = detail::platform_impl::getOrMakePlatformImpl(UrPlatform, *Adapter) .shared_from_this(); }