From 2afcfa13ca0b7985bbf58c7b0d3adb7af863d720 Mon Sep 17 00:00:00 2001
From: Callum Fare
Date: Thu, 6 Feb 2025 17:23:14 +0000
Subject: [PATCH] Update to latest offload API

---
 source/adapters/offload/context.cpp |  35 ++++++++
 source/adapters/offload/context.hpp |  23 ++++++
 source/adapters/offload/enqueue.cpp |  49 +++++++++++
 source/adapters/offload/event.cpp   |  29 +++++++
 source/adapters/offload/kernel.cpp  |  75 +++++++++++++++++
 source/adapters/offload/kernel.hpp  |   9 ++
 source/adapters/offload/program.cpp | 124 ++++++++++++++++++++++++++++
 source/adapters/offload/queue.cpp   |  42 ++++++++++
 source/adapters/offload/usm.cpp     |  61 ++++++++++++++
 9 files changed, 447 insertions(+)
 create mode 100644 source/adapters/offload/context.cpp
 create mode 100644 source/adapters/offload/context.hpp
 create mode 100644 source/adapters/offload/enqueue.cpp
 create mode 100644 source/adapters/offload/event.cpp
 create mode 100644 source/adapters/offload/kernel.cpp
 create mode 100644 source/adapters/offload/kernel.hpp
 create mode 100644 source/adapters/offload/program.cpp
 create mode 100644 source/adapters/offload/queue.cpp
 create mode 100644 source/adapters/offload/usm.cpp

diff --git a/source/adapters/offload/context.cpp b/source/adapters/offload/context.cpp
new file mode 100644
index 0000000000..01d015038c
--- /dev/null
+++ b/source/adapters/offload/context.cpp
@@ -0,0 +1,35 @@
+#include "context.hpp"
+#include <ur_api.h>
+
+// Creates a context wrapping a single device; multi-device contexts are not
+// supported by this adapter yet.
+UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
+    uint32_t DeviceCount, const ur_device_handle_t *phDevices,
+    const ur_context_properties_t *, ur_context_handle_t *phContext) {
+  if (DeviceCount == 0) {
+    // Nothing to dereference in phDevices.
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+  if (DeviceCount > 1) {
+    return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  }
+
+  *phContext = new ur_context_handle_t_(*phDevices);
+  return UR_RESULT_SUCCESS;
+}
+
+// Adds a reference to the context.
+UR_APIEXPORT ur_result_t UR_APICALL
+urContextRetain(ur_context_handle_t hContext) {
+  hContext->RefCount++;
+  return UR_RESULT_SUCCESS;
+}
+
+// Drops a reference and destroys the context when the last one is released.
+UR_APIEXPORT ur_result_t UR_APICALL
+urContextRelease(ur_context_handle_t hContext) {
+  if (--hContext->RefCount == 0) {
+    delete hContext;
+  }
+  return UR_RESULT_SUCCESS;
+}
diff --git a/source/adapters/offload/context.hpp b/source/adapters/offload/context.hpp
new file mode 100644
index 0000000000..9483ec1b4a
--- /dev/null
+++ b/source/adapters/offload/context.hpp
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <OffloadAPI.h>
+#include <atomic>
+#include <unordered_map>
+#include <ur_api.h>
+
+// Adapter-side context object: holds one retained device, a reference count
+// and the allocation type recorded for each USM pointer.
+struct ur_context_handle_t_ {
+  ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
+    urDeviceRetain(Device);
+  }
+  ~ur_context_handle_t_() {
+    urDeviceRelease(Device);
+  }
+
+  ur_device_handle_t Device;
+  // Starts at 1 for the reference handed out by urContextCreate; a
+  // default-initialized atomic would make urContextRelease misbehave.
+  std::atomic_uint32_t RefCount{1};
+  std::unordered_map<void *, ol_alloc_type_t> AllocTypeMap;
+};
diff --git a/source/adapters/offload/enqueue.cpp b/source/adapters/offload/enqueue.cpp
new file mode 100644
index 0000000000..71775312bd
--- /dev/null
+++ b/source/adapters/offload/enqueue.cpp
@@ -0,0 +1,49 @@
+#include <OffloadAPI.h>
+#include <assert.h>
+#include <ur_api.h>
+
+#include "ur2offload.hpp"
+
+// Launches a one-dimensional kernel: pGlobalWorkSize[0] groups of size 1.
+// The wait list, global offset and local size are currently ignored.
+UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
+    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
+    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
+    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
+    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
+  // Ignore wait list, offset and local size for now
+  (void)numEventsInWaitList;
+  (void)phEventWaitList;
+  (void)pGlobalWorkOffset;
+  (void)pLocalWorkSize;
+
+  // An assert would vanish in release builds and silently drop dimensions.
+  if (workDim != 1) {
+    return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  }
+
+  // Value-initialize so no field of the launch descriptor is left unset.
+  ol_kernel_launch_size_args_t LaunchArgs{};
+  LaunchArgs.Dimensions = workDim;
+  LaunchArgs.NumGroupsX = pGlobalWorkSize[0];
+  LaunchArgs.NumGroupsY = 1;
+  LaunchArgs.NumGroupsZ = 1;
+  LaunchArgs.GroupSizeX = 1;
+  LaunchArgs.GroupSizeY = 1;
+  LaunchArgs.GroupSizeZ = 1;
+
+  ol_event_handle_t EventOut;
+  auto Ret =
+      olEnqueueKernelLaunch(reinterpret_cast<ol_queue_handle_t>(hQueue),
+                            reinterpret_cast<ol_kernel_handle_t>(hKernel),
+                            &LaunchArgs, &EventOut);
+
+  if (Ret != OL_SUCCESS) {
+    return offloadResultToUR(Ret);
+  }
+
+  if (phEvent) {
+    *phEvent = reinterpret_cast<ur_event_handle_t>(EventOut);
+  }
+  return UR_RESULT_SUCCESS;
+}
diff --git a/source/adapters/offload/event.cpp b/source/adapters/offload/event.cpp
new file mode 100644
index 0000000000..5b719d1f50
--- /dev/null
+++ b/source/adapters/offload/event.cpp
@@ -0,0 +1,29 @@
+#include <OffloadAPI.h>
+#include <ur_api.h>
+
+#include "ur2offload.hpp"
+
+// Waits on each event in order and stops at the first failure.
+UR_APIEXPORT ur_result_t UR_APICALL
+urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
+  for (uint32_t i = 0; i < numEvents; i++) {
+    auto Res =
+        olWaitEvent(reinterpret_cast<ol_event_handle_t>(phEventWaitList[i]));
+    if (Res != OL_SUCCESS) {
+      return offloadResultToUR(Res);
+    }
+  }
+  return UR_RESULT_SUCCESS;
+}
+
+// Forwards the retain to the underlying offload event.
+UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
+  auto OffloadEvent = reinterpret_cast<ol_event_handle_t>(hEvent);
+  return offloadResultToUR(olRetainEvent(OffloadEvent));
+}
+
+// Forwards the release to the underlying offload event.
+UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
+  auto OffloadEvent = reinterpret_cast<ol_event_handle_t>(hEvent);
+  return offloadResultToUR(olReleaseEvent(OffloadEvent));
+}
diff --git a/source/adapters/offload/kernel.cpp b/source/adapters/offload/kernel.cpp
new file mode 100644
index 0000000000..16fa9ff81a
--- /dev/null
+++ b/source/adapters/offload/kernel.cpp
@@ -0,0 +1,75 @@
+#include "kernel.hpp"
+#include "ur2offload.hpp"
+#include <OffloadAPI.h>
+#include <ur/ur.hpp>
+#include <ur_api.h>
+
+// Creates the named kernel from a program; the offload kernel handle is
+// returned directly as the UR kernel handle.
+UR_APIEXPORT ur_result_t UR_APICALL
+urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
+               ur_kernel_handle_t *phKernel) {
+  ol_kernel_handle_t OffloadKernel;
+
+  auto Res = olCreateKernel(reinterpret_cast<ol_program_handle_t>(hProgram),
+                            pKernelName, &OffloadKernel);
+
+  if (Res != OL_SUCCESS) {
+    return offloadResultToUR(Res);
+  }
+
+  *phKernel = reinterpret_cast<ur_kernel_handle_t>(OffloadKernel);
+
+  return UR_RESULT_SUCCESS;
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
+  return offloadResultToUR(
+      olRetainKernel(reinterpret_cast<ol_kernel_handle_t>(hKernel)));
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL
+urKernelRelease(ur_kernel_handle_t hKernel) {
+  return offloadResultToUR(
+      olReleaseKernel(reinterpret_cast<ol_kernel_handle_t>(hKernel)));
+}
+
+// Execution info is accepted and ignored.
+UR_APIEXPORT ur_result_t UR_APICALL
+urKernelSetExecInfo(ur_kernel_handle_t, ur_kernel_exec_info_t, size_t,
+                    const ur_kernel_exec_info_properties_t *, const void *) {
+  return UR_RESULT_SUCCESS;
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
+    ur_kernel_handle_t hKernel, uint32_t argIndex,
+    const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {
+  // setKernelArg is expecting a pointer to our argument
+  return offloadResultToUR(
+      olSetKernelArgValue(reinterpret_cast<ol_kernel_handle_t>(hKernel),
+                          argIndex, sizeof(pArgValue), &pArgValue));
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
+    ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
+    const ur_kernel_arg_value_properties_t *, const void *pArgValue) {
+  // Spell the const removal as a named cast instead of a C-style cast.
+  return offloadResultToUR(
+      olSetKernelArgValue(reinterpret_cast<ol_kernel_handle_t>(hKernel),
+                          argIndex, argSize, const_cast<void *>(pArgValue)));
+}
+
+// Only the compile-time work-group size query is supported; it reports
+// {0, 0, 0}.
+UR_APIEXPORT ur_result_t UR_APICALL
+urKernelGetGroupInfo(ur_kernel_handle_t, ur_device_handle_t,
+                     ur_kernel_group_info_t propName, size_t propSize,
+                     void *pPropValue, size_t *pPropSizeRet) {
+  UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);
+
+  if (propName == UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE) {
+    size_t GroupSize[3] = {0, 0, 0};
+    return ReturnValue(GroupSize, 3);
+  }
+  return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
+}
diff --git a/source/adapters/offload/kernel.hpp b/source/adapters/offload/kernel.hpp
new file mode 100644
index 0000000000..c275acd570
--- /dev/null
+++ b/source/adapters/offload/kernel.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <OffloadAPI.h>
+#include <ur_api.h>
+
+// Wrapper holding the offload kernel handle.
+struct ur_kernel_handle_t_ {
+  ol_kernel_handle_t OffloadKernel;
+};
diff --git a/source/adapters/offload/program.cpp b/source/adapters/offload/program.cpp
new file mode 100644
index 0000000000..14d1e0cf7c
--- /dev/null
+++ b/source/adapters/offload/program.cpp
@@ -0,0 +1,124 @@
+#include <OffloadAPI.h>
+#include <cstdio>
+#include <cuda.h>
+#include <ur/ur.hpp>
+#include <ur_api.h>
+
+#include "context.hpp"
+#include "ur2offload.hpp"
+
+// Workaround for Offload not supporting PTX binaries: link the PTX image so it
+// ends up as CUBIN. On success CuBin/CuBinSize describe the linked image and
+// the caller must pass State to cuLinkDestroy once done with it; on failure
+// any created link state has already been destroyed.
+static ur_result_t linkPtxToCubin(const uint8_t *Ptx, size_t PtxSize,
+                                  CUlinkState &State, void *&CuBin,
+                                  size_t &CuBinSize) {
+  if (cuLinkCreate(0, nullptr, nullptr, &State) != CUDA_SUCCESS) {
+    return UR_RESULT_ERROR_UNKNOWN;
+  }
+
+  if (cuLinkAddData(State, CU_JIT_INPUT_PTX, const_cast<uint8_t *>(Ptx),
+                    PtxSize, nullptr, 0, nullptr, nullptr) != CUDA_SUCCESS ||
+      cuLinkComplete(State, &CuBin, &CuBinSize) != CUDA_SUCCESS) {
+    cuLinkDestroy(State);
+    return UR_RESULT_ERROR_INVALID_BINARY;
+  }
+  return UR_RESULT_SUCCESS;
+}
+
+// Creates a program for a single device from a device binary. Binaries for
+// the CUDA backend are linked to CUBIN first.
+UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
+    ur_context_handle_t hContext, uint32_t numDevices,
+    ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
+    const ur_program_properties_t *, ur_program_handle_t *phProgram) {
+  if (numDevices == 0) {
+    // phDevices[0] is read below.
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
+  if (numDevices > 1) {
+    return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  }
+
+  // Look up the device's backend; return early instead of branching on an
+  // unset value if either query fails.
+  ur_platform_handle_t DevicePlatform = nullptr;
+  auto UrRes =
+      urDeviceGetInfo(phDevices[0], UR_DEVICE_INFO_PLATFORM,
+                      sizeof(ur_platform_handle_t), &DevicePlatform, nullptr);
+  if (UrRes != UR_RESULT_SUCCESS) {
+    return UrRes;
+  }
+  ur_platform_backend_t PlatformBackend;
+  UrRes = urPlatformGetInfo(DevicePlatform, UR_PLATFORM_INFO_BACKEND,
+                            sizeof(ur_platform_backend_t), &PlatformBackend,
+                            nullptr);
+  if (UrRes != UR_RESULT_SUCCESS) {
+    return UrRes;
+  }
+
+  uint8_t *RealBinary = const_cast<uint8_t *>(ppBinaries[0]);
+  size_t RealLength = pLengths[0];
+  bool DidLink = false;
+  CUlinkState State = nullptr;
+  if (PlatformBackend == UR_PLATFORM_BACKEND_CUDA) {
+    void *CuBin = nullptr;
+    size_t CuBinSize = 0;
+    UrRes = linkPtxToCubin(ppBinaries[0], pLengths[0], State, CuBin, CuBinSize);
+    if (UrRes != UR_RESULT_SUCCESS) {
+      return UrRes;
+    }
+    RealBinary = static_cast<uint8_t *>(CuBin);
+    RealLength = CuBinSize;
+    DidLink = true;
+    fprintf(stderr, "Performed CUDA bin workaround (size = %zu)\n", RealLength);
+  }
+
+  ol_program_handle_t OffloadProgram;
+  auto Res =
+      olCreateProgram(reinterpret_cast<ol_device_handle_t>(hContext->Device),
+                      RealBinary, RealLength, &OffloadProgram);
+
+  // Program owns the linked module now
+  if (DidLink) {
+    cuLinkDestroy(State);
+  }
+
+  if (Res != OL_SUCCESS) {
+    return offloadResultToUR(Res);
+  }
+
+  *phProgram = reinterpret_cast<ur_program_handle_t>(OffloadProgram);
+
+  return UR_RESULT_SUCCESS;
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,
+                                                   ur_program_handle_t,
+                                                   const char *) {
+  // Do nothing, program is built upon creation
+  return UR_RESULT_SUCCESS;
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
+                                                      uint32_t,
+                                                      ur_device_handle_t *,
+                                                      const char *) {
+  // Do nothing, program is built upon creation
+  return UR_RESULT_SUCCESS;
+}
+
+// Forwards the retain to the underlying offload program.
+UR_APIEXPORT ur_result_t UR_APICALL
+urProgramRetain(ur_program_handle_t hProgram) {
+  auto OffloadProgram = reinterpret_cast<ol_program_handle_t>(hProgram);
+  return offloadResultToUR(olRetainProgram(OffloadProgram));
+}
+
+// Forwards the release to the underlying offload program.
+UR_APIEXPORT ur_result_t UR_APICALL
+urProgramRelease(ur_program_handle_t hProgram) {
+  auto OffloadProgram = reinterpret_cast<ol_program_handle_t>(hProgram);
+  return offloadResultToUR(olReleaseProgram(OffloadProgram));
+}
diff --git a/source/adapters/offload/queue.cpp b/source/adapters/offload/queue.cpp
new file mode 100644
index 0000000000..cc40a8ece9
--- /dev/null
+++ b/source/adapters/offload/queue.cpp
@@ -0,0 +1,42 @@
+#include <OffloadAPI.h>
+#include <assert.h>
+#include <ur_api.h>
+
+#include "context.hpp"
+#include "ur2offload.hpp"
+
+// Creates an offload queue on hDevice, which is expected to be the context's
+// device; the offload queue handle doubles as the UR queue handle.
+UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate(
+    [[maybe_unused]] ur_context_handle_t hContext, ur_device_handle_t hDevice,
+    const ur_queue_properties_t *, ur_queue_handle_t *phQueue) {
+
+  assert(hContext->Device == hDevice);
+
+  ol_queue_handle_t OffloadQueue;
+  auto Res = olCreateQueue(reinterpret_cast<ol_device_handle_t>(hDevice),
+                           &OffloadQueue);
+  if (Res != OL_SUCCESS) {
+    return offloadResultToUR(Res);
+  }
+
+  *phQueue = reinterpret_cast<ur_queue_handle_t>(OffloadQueue);
+
+  return UR_RESULT_SUCCESS;
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) {
+  auto OffloadQueue = reinterpret_cast<ol_queue_handle_t>(hQueue);
+  return offloadResultToUR(olRetainQueue(OffloadQueue));
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) {
+  auto OffloadQueue = reinterpret_cast<ol_queue_handle_t>(hQueue);
+  return offloadResultToUR(olReleaseQueue(OffloadQueue));
+}
+
+// Calls olFinishQueue on the queue and maps its result to a UR result.
+UR_APIEXPORT ur_result_t UR_APICALL urQueueFinish(ur_queue_handle_t hQueue) {
+  auto OffloadQueue = reinterpret_cast<ol_queue_handle_t>(hQueue);
+  return offloadResultToUR(olFinishQueue(OffloadQueue));
+}
diff --git a/source/adapters/offload/usm.cpp b/source/adapters/offload/usm.cpp
new file mode 100644
index 0000000000..d71e6e26e2
--- /dev/null
+++ b/source/adapters/offload/usm.cpp
@@ -0,0 +1,61 @@
+#include <OffloadAPI.h>
+#include <ur/ur.hpp>
+#include <ur_api.h>
+
+#include "context.hpp"
+#include "ur2offload.hpp"
+
+// Allocates size bytes of the given kind on the context's device and records
+// the kind in AllocTypeMap so that urUSMFree can pass it to olMemFree.
+static ur_result_t allocUSM(ur_context_handle_t hContext, ol_alloc_type_t Type,
+                            size_t size, void **ppMem) {
+  auto Res = olMemAlloc(reinterpret_cast<ol_device_handle_t>(hContext->Device),
+                        Type, size, ppMem);
+  if (Res != OL_SUCCESS) {
+    return offloadResultToUR(Res);
+  }
+
+  hContext->AllocTypeMap.insert_or_assign(*ppMem, Type);
+  return UR_RESULT_SUCCESS;
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
+                                                   const ur_usm_desc_t *,
+                                                   ur_usm_pool_handle_t,
+                                                   size_t size, void **ppMem) {
+  return allocUSM(hContext, OL_ALLOC_TYPE_HOST, size, ppMem);
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
+    ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *,
+    ur_usm_pool_handle_t, size_t size, void **ppMem) {
+  return allocUSM(hContext, OL_ALLOC_TYPE_DEVICE, size, ppMem);
+}
+
+UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
+    ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *,
+    ur_usm_pool_handle_t, size_t size, void **ppMem) {
+  return allocUSM(hContext, OL_ALLOC_TYPE_SHARED, size, ppMem);
+}
+
+// Frees a pointer previously returned by one of the allocation entry points
+// above. Pointers that are not in the map are rejected.
+UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
+                                              void *pMem) {
+  auto AllocType = hContext->AllocTypeMap.find(pMem);
+  if (AllocType == hContext->AllocTypeMap.end()) {
+    return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
+  }
+
+  auto Res = olMemFree(reinterpret_cast<ol_device_handle_t>(hContext->Device),
+                       AllocType->second, pMem);
+  if (Res != OL_SUCCESS) {
+    return offloadResultToUR(Res);
+  }
+
+  // Forget the pointer once it is freed so a second free is reported as
+  // invalid and a later allocation reusing the address starts from a clean
+  // entry.
+  hContext->AllocTypeMap.erase(AllocType);
+  return UR_RESULT_SUCCESS;
+}