Skip to content

[UR][NativeCPU] Refactor UR Native CPU reference counting #19200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: sycl
Choose a base branch
from
10 changes: 5 additions & 5 deletions unified-runtime/source/adapters/level_zero/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ Behavior Summary:
SysMan initialization is skipped.
*/
ur_adapter_handle_t_::ur_adapter_handle_t_()
: handle_base(), logger(logger::get_logger("level_zero")) {
: handle_base(), logger(logger::get_logger("level_zero")), RefCount(0) {
ZeInitDriversResult = ZE_RESULT_ERROR_UNINITIALIZED;
ZeInitResult = ZE_RESULT_ERROR_UNINITIALIZED;
ZesResult = ZE_RESULT_ERROR_UNINITIALIZED;
Expand Down Expand Up @@ -675,7 +675,7 @@ ur_result_t urAdapterGet(
}
*Adapters = GlobalAdapter;

if (GlobalAdapter->RefCount++ == 0) {
if (GlobalAdapter->getRefCount().retain() == 0) {
adapterStateInit();
}
}
Expand All @@ -692,7 +692,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {

// NOTE: This does not require guarding with a mutex; the instant the ref
// count hits zero, both Get and Retain are UB.
if (--GlobalAdapter->RefCount == 0) {
if (GlobalAdapter->getRefCount().release()) {
auto result = adapterStateTeardown();
#ifdef UR_STATIC_LEVEL_ZERO
// Given static linking of the L0 Loader, we must delay the loader's
Expand All @@ -711,7 +711,7 @@ ur_result_t urAdapterRelease([[maybe_unused]] ur_adapter_handle_t Adapter) {

ur_result_t urAdapterRetain([[maybe_unused]] ur_adapter_handle_t Adapter) {
assert(GlobalAdapter && GlobalAdapter == Adapter);
GlobalAdapter->RefCount++;
GlobalAdapter->getRefCount().retain();

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -740,7 +740,7 @@ ur_result_t urAdapterGetInfo(ur_adapter_handle_t, ur_adapter_info_t PropName,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_BACKEND_LEVEL_ZERO);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(GlobalAdapter->RefCount.load());
return ReturnValue(GlobalAdapter->getRefCount().getCount());
case UR_ADAPTER_INFO_VERSION: {
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
uint32_t adapterVersion = 2;
Expand Down
8 changes: 6 additions & 2 deletions unified-runtime/source/adapters/level_zero/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
//===----------------------------------------------------------------------===//
#pragma once

#include "common/ur_ref_count.hpp"
#include "logger/ur_logger.hpp"
#include "ur_interface_loader.hpp"
#include <atomic>
#include <loader/ur_loader.hpp>
#include <loader/ze_loader.h>
#include <optional>
Expand All @@ -26,7 +26,6 @@ class ur_legacy_sink;

struct ur_adapter_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter> {
ur_adapter_handle_t_();
std::atomic<uint32_t> RefCount = 0;

zes_pfnDriverGetDeviceByUuidExp_t getDeviceByUUIdFunctionPtr = nullptr;
zes_pfnDriverGet_t getSysManDriversFunctionPtr = nullptr;
Expand All @@ -45,6 +44,11 @@ struct ur_adapter_handle_t_ : ur::handle_base<ur::level_zero::ddi_getter> {
ZeCache<Result<PlatformVec>> PlatformCache;
logger::Logger &logger;
HMODULE processHandle = nullptr;

ur::RefCount &getRefCount() noexcept { return RefCount; }

private:
ur::RefCount RefCount;
};

extern ur_adapter_handle_t_ *GlobalAdapter;
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ ur_result_t urEnqueueUSMFreeExp(
}

size_t size = umfPoolMallocUsableSize(hPool, Mem);
(*Event)->RefCount.increment();
(*Event)->getRefCount().retain();
usmPool->AsyncPool.insert(Mem, size, *Event, Queue);

// Signal that USM free event was finished
Expand Down
10 changes: 5 additions & 5 deletions unified-runtime/source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,13 +842,13 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,

ur_result_t
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t CommandBuffer) {
CommandBuffer->RefCount.increment();
CommandBuffer->getRefCount().retain();
return UR_RESULT_SUCCESS;
}

ur_result_t
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t CommandBuffer) {
if (!CommandBuffer->RefCount.decrementAndTest())
if (!CommandBuffer->getRefCount().release())
return UR_RESULT_SUCCESS;

UR_CALL(waitForOngoingExecution(CommandBuffer));
Expand Down Expand Up @@ -1643,7 +1643,7 @@ ur_result_t enqueueImmediateAppendPath(
if (CommandBuffer->CurrentSubmissionEvent) {
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent));
}
(*Event)->RefCount.increment();
(*Event)->getRefCount().retain();
CommandBuffer->CurrentSubmissionEvent = *Event;

UR_CALL(Queue->executeCommandList(CommandListHelper, false, false));
Expand Down Expand Up @@ -1726,7 +1726,7 @@ ur_result_t enqueueWaitEventPath(ur_exp_command_buffer_handle_t CommandBuffer,
if (CommandBuffer->CurrentSubmissionEvent) {
UR_CALL(urEventReleaseInternal(CommandBuffer->CurrentSubmissionEvent));
}
(*Event)->RefCount.increment();
(*Event)->getRefCount().retain();
CommandBuffer->CurrentSubmissionEvent = *Event;

UR_CALL(Queue->executeCommandList(SignalCommandList, false /*IsBlocking*/,
Expand Down Expand Up @@ -1850,7 +1850,7 @@ urCommandBufferGetInfoExp(ur_exp_command_buffer_handle_t hCommandBuffer,

switch (propName) {
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
return ReturnValue(uint32_t{hCommandBuffer->RefCount.load()});
return ReturnValue(uint32_t{hCommandBuffer->getRefCount().getCount()});
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
ur_exp_command_buffer_desc_t Descriptor{};
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "common.hpp"

#include "common/ur_ref_count.hpp"
#include "context.hpp"
#include "kernel.hpp"
#include "queue.hpp"
Expand Down Expand Up @@ -149,4 +150,9 @@ struct ur_exp_command_buffer_handle_t_ : public ur_object {
// Track handle objects to free when command-buffer is destroyed.
std::vector<std::unique_ptr<ur_exp_command_buffer_command_handle_t_>>
CommandHandles;

ur::RefCount &getRefCount() noexcept { return RefCount; }

private:
ur::RefCount RefCount;
};
54 changes: 7 additions & 47 deletions unified-runtime/source/adapters/level_zero/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <level_zero/ze_intel_gpu.h>
#include <umf_pools/disjoint_pool_config_parser.hpp>

#include "common/ur_ref_count.hpp"
#include "logger/ur_logger.hpp"
#include "ur_interface_loader.hpp"

Expand Down Expand Up @@ -220,55 +221,9 @@ void zeParseError(ze_result_t ZeError, const char *&ErrorString);
#define ZE_CALL_NOCHECK_NAME(ZeName, ZeArgs, callName) \
ZeCall().doCall(ZeName ZeArgs, callName, #ZeArgs, false)

// This wrapper around std::atomic limits the operations available on a
// reference counter and makes the permitted operations more transparent in
// terms of thread-safety in the plugin. increment() and load() do not need a
// mutex guard around them since the underlying data is already atomic.
// decrementAndTest() is used to guard code that needs to be executed when an
// object's ref count becomes zero after a release. It also doesn't need a
// mutex guard, because the decrement is atomic and only one thread can reach a
// ref count equal to zero, i.e. only a single thread can pass through this
// check.
struct ReferenceCounter {
  // A freshly constructed counter starts at 1.
  ReferenceCounter() : RefCount{1} {}

  // Reset the counter to the initial value.
  void reset() { RefCount = 1; }

  // Used when retaining an object.
  void increment() { RefCount++; }

  // Returns the current count without modifying it. Intended for the
  // ur*GetInfo* methods where the ref count value is requested. Declared
  // const so the count can also be read through a const reference or object.
  uint32_t load() const { return RefCount.load(); }

  // This method guards code that needs to be executed when an object's ref
  // count becomes zero after a release. It is important to note that only a
  // single thread can pass through this check. This is true for several
  // reasons:
  // 1. The decrement operation is executed atomically.
  // 2. It is not allowed to retain an object after its refcount reaches zero.
  // 3. It is not allowed to release an object more times than the value of
  // the ref count.
  // 2. and 3. basically mean that an object can't be used at all as soon as
  // its refcount reaches zero. Using this check guarantees that the code for
  // deleting an object and releasing its resources is executed once, by a
  // single thread, and no mutexes are needed to guard access to this object in
  // the scope after this check. Of course, if other objects are accessed in
  // that code (not the one being deleted), then access to those objects must
  // be guarded, for example with a mutex.
  bool decrementAndTest() { return --RefCount == 0; }

private:
  std::atomic<uint32_t> RefCount;
};

// Base class to store common data
struct ur_object : ur::handle_base<ur::level_zero::ddi_getter> {
ur_object() : handle_base(), RefCount{} {}

// Must be atomic to prevent data race when incrementing/decrementing.
ReferenceCounter RefCount;
ur_object() : handle_base() {}

// This mutex protects accesses to all the non-const member variables.
// Exclusive access is required to modify any of these members.
Expand Down Expand Up @@ -303,6 +258,11 @@ struct MemAllocRecord : ur_object {
// TODO: this should go away when memory isolation issue is fixed in the Level
// Zero runtime.
ur_context_handle_t Context;

ur::RefCount &getRefCount() noexcept { return RefCount; }

private:
ur::RefCount RefCount;
};

extern usm::DisjointPoolAllConfigs DisjointPoolConfigInstance;
Expand Down
6 changes: 3 additions & 3 deletions unified-runtime/source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ ur_result_t urContextRetain(

/// [in] handle of the context to get a reference of.
ur_context_handle_t Context) {
Context->RefCount.increment();
Context->getRefCount().retain();
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -113,7 +113,7 @@ ur_result_t urContextGetInfo(
case UR_CONTEXT_INFO_NUM_DEVICES:
return ReturnValue(uint32_t(Context->Devices.size()));
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(uint32_t{Context->RefCount.load()});
return ReturnValue(uint32_t{Context->getRefCount().getCount()});
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
// 2D USM memcpy is supported.
return ReturnValue(uint8_t{UseMemcpy2DOperations});
Expand Down Expand Up @@ -251,7 +251,7 @@ ur_device_handle_t ur_context_handle_t_::getRootDevice() const {
// from the list of tracked contexts.
ur_result_t ContextReleaseHelper(ur_context_handle_t Context) {

if (!Context->RefCount.decrementAndTest())
if (!Context->getRefCount().release())
return UR_RESULT_SUCCESS;

if (IndirectAccessTrackingEnabled) {
Expand Down
5 changes: 5 additions & 0 deletions unified-runtime/source/adapters/level_zero/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "queue.hpp"
#include "usm.hpp"

#include "common/ur_ref_count.hpp"
#include <umf_helpers.hpp>

struct l0_command_list_cache_info {
Expand Down Expand Up @@ -358,6 +359,8 @@ struct ur_context_handle_t_ : ur_object {
// Get handle to the L0 context
ze_context_handle_t getZeHandle() const;

ur::RefCount &getRefCount() noexcept { return RefCount; }

private:
enum EventFlags {
EVENT_FLAG_HOST_VISIBLE = UR_BIT(0),
Expand Down Expand Up @@ -404,6 +407,8 @@ struct ur_context_handle_t_ : ur_object {

return &EventCaches[index];
}

ur::RefCount RefCount;
};

// Helper function to release the context, a caller must lock the platform-level
Expand Down
6 changes: 3 additions & 3 deletions unified-runtime/source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ ur_result_t urDeviceGetInfo(
return ReturnValue((uint32_t)Device->SubDevices.size());
}
case UR_DEVICE_INFO_REFERENCE_COUNT:
return ReturnValue(uint32_t{Device->RefCount.load()});
return ReturnValue(uint32_t{Device->getRefCount().getCount()});
case UR_DEVICE_INFO_SUPPORTED_PARTITIONS: {
// SYCL spec says: if this SYCL device cannot be partitioned into at least
// two sub devices then the returned vector must be empty.
Expand Down Expand Up @@ -1666,15 +1666,15 @@ ur_result_t urDeviceGetGlobalTimestamps(
ur_result_t urDeviceRetain(ur_device_handle_t Device) {
// The root-device ref-count remains unchanged (always 1).
if (Device->isSubDevice()) {
Device->RefCount.increment();
Device->getRefCount().retain();
}
return UR_RESULT_SUCCESS;
}

ur_result_t urDeviceRelease(ur_device_handle_t Device) {
// Root devices are destroyed during the piTearDown process.
if (Device->isSubDevice()) {
if (Device->RefCount.decrementAndTest()) {
if (Device->getRefCount().release()) {
delete Device;
}
}
Expand Down
6 changes: 6 additions & 0 deletions unified-runtime/source/adapters/level_zero/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "adapters/level_zero/platform.hpp"
#include "common.hpp"
#include "common/ur_ref_count.hpp"
#include <ur/ur.hpp>
#include <ur_ddi.h>
#include <ze_api.h>
Expand Down Expand Up @@ -242,6 +243,11 @@ struct ur_device_handle_t_ : ur_object {

// unique ephemeral identifier of the device in the adapter
std::optional<DeviceId> Id;

ur::RefCount &getRefCount() noexcept { return RefCount; }

private:
ur::RefCount RefCount;
};

inline std::vector<ur_device_handle_t>
Expand Down
14 changes: 7 additions & 7 deletions unified-runtime/source/adapters/level_zero/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ ur_result_t urEventGetInfo(
return ReturnValue(Result);
}
case UR_EVENT_INFO_REFERENCE_COUNT: {
return ReturnValue(Event->RefCount.load());
return ReturnValue(Event->getRefCount().getCount());
}
default:
UR_LOG(ERR, "Unsupported ParamName in urEventGetInfo: ParamName={}(0x{})",
Expand Down Expand Up @@ -874,7 +874,7 @@ ur_result_t
/// [in] handle of the event object
urEventRetain(/** [in] handle of the event object */ ur_event_handle_t Event) {
Event->RefCountExternal++;
Event->RefCount.increment();
Event->getRefCount().retain();

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -1088,7 +1088,7 @@ ur_event_handle_t_::~ur_event_handle_t_() {

ur_result_t urEventReleaseInternal(ur_event_handle_t Event,
bool *isEventDeleted) {
if (!Event->RefCount.decrementAndTest())
if (!Event->getRefCount().release())
return UR_RESULT_SUCCESS;

if (Event->OriginAllocEvent) {
Expand Down Expand Up @@ -1524,7 +1524,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
std::shared_lock<ur_shared_mutex> Lock(CurQueue->LastCommandEvent->Mutex);
this->ZeEventList[0] = CurQueue->LastCommandEvent->ZeEvent;
this->UrEventList[0] = CurQueue->LastCommandEvent;
this->UrEventList[0]->RefCount.increment();
this->UrEventList[0]->getRefCount().retain();
TmpListLength = 1;
} else if (EventListLength > 0) {
this->ZeEventList = new ze_event_handle_t[EventListLength];
Expand Down Expand Up @@ -1660,7 +1660,7 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(
IsInternal, IsMultiDevice));
MultiDeviceZeEvent = MultiDeviceEvent->ZeEvent;
const auto &ZeCommandList = CommandList->first;
EventList[I]->RefCount.increment();
EventList[I]->getRefCount().retain();

// Append a Barrier to wait on the original event while signalling the
// new multi device event.
Expand All @@ -1676,11 +1676,11 @@ ur_result_t ur_ze_event_list_t::createAndRetainUrZeEventList(

this->ZeEventList[TmpListLength] = MultiDeviceZeEvent;
this->UrEventList[TmpListLength] = MultiDeviceEvent;
this->UrEventList[TmpListLength]->RefCount.increment();
this->UrEventList[TmpListLength]->getRefCount().retain();
} else {
this->ZeEventList[TmpListLength] = EventList[I]->ZeEvent;
this->UrEventList[TmpListLength] = EventList[I];
this->UrEventList[TmpListLength]->RefCount.increment();
this->UrEventList[TmpListLength]->getRefCount().retain();
}

if (QueueLock.has_value()) {
Expand Down
Loading
Loading