Skip to content

Commit

Permalink
Update to latest offload API
Browse files Browse the repository at this point in the history
  • Loading branch information
callumfare committed Feb 6, 2025
1 parent 2e13226 commit 2afcfa1
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 0 deletions.
28 changes: 28 additions & 0 deletions source/adapters/offload/context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "context.hpp"
#include <ur_api.h>

UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
const ur_context_properties_t *, ur_context_handle_t *phContext) {
if (DeviceCount > 1) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

auto Ctx = new ur_context_handle_t_(*phDevices);
*phContext = Ctx;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urContextRetain(ur_context_handle_t hContext) {
hContext->RefCount++;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
if (--hContext->RefCount == 0) {
delete hContext;
}
return UR_RESULT_SUCCESS;
}
19 changes: 19 additions & 0 deletions source/adapters/offload/context.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <atomic>
#include <unordered_map>
#include <ur_api.h>
#include <OffloadAPI.h>

struct ur_context_handle_t_ {
ur_context_handle_t_(ur_device_handle_t hDevice) : Device{hDevice} {
urDeviceRetain(Device);
}
~ur_context_handle_t_() {
urDeviceRelease(Device);
}

ur_device_handle_t Device;
std::atomic_uint32_t RefCount;
std::unordered_map<void*, ol_alloc_type_t> AllocTypeMap;
};
42 changes: 42 additions & 0 deletions source/adapters/offload/enqueue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include <assert.h>
#include <OffloadAPI.h>
#include <ur_api.h>

#include "ur2offload.hpp"

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
// Ignore wait list for now
(void)numEventsInWaitList;
(void)phEventWaitList;
//

assert(workDim == 1);

ol_kernel_launch_size_args_t LaunchArgs;
LaunchArgs.Dimensions = workDim;
LaunchArgs.NumGroupsX = pGlobalWorkSize[0];
LaunchArgs.NumGroupsY = 1;
LaunchArgs.NumGroupsZ = 1;
LaunchArgs.GroupSizeX = 1;
LaunchArgs.GroupSizeY = 1;
LaunchArgs.GroupSizeZ = 1;

ol_event_handle_t EventOut;
auto Ret =
olEnqueueKernelLaunch(reinterpret_cast<ol_queue_handle_t>(hQueue),
reinterpret_cast<ol_kernel_handle_t>(hKernel),
&LaunchArgs, &EventOut);

if (Ret != OL_SUCCESS) {
return offloadResultToUR(Ret);
}

if (phEvent) {
*phEvent = reinterpret_cast<ur_event_handle_t>(EventOut);
}
return UR_RESULT_SUCCESS;
}
23 changes: 23 additions & 0 deletions source/adapters/offload/event.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <OffloadAPI.h>
#include <ur_api.h>

#include "ur2offload.hpp"

UR_APIEXPORT ur_result_t UR_APICALL
urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
// TODO: Check for errors
for (uint32_t i = 0; i < numEvents; i++) {
olWaitEvent(reinterpret_cast<ol_event_handle_t>(phEventWaitList[i]));
}
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
auto OffloadEvent = reinterpret_cast<ol_event_handle_t>(hEvent);
return offloadResultToUR(olRetainEvent(OffloadEvent));
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
auto OffloadEvent = reinterpret_cast<ol_event_handle_t>(hEvent);
return offloadResultToUR(olReleaseEvent(OffloadEvent));
}
69 changes: 69 additions & 0 deletions source/adapters/offload/kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "kernel.hpp"
#include "ur2offload.hpp"
#include <OffloadAPI.h>
#include <ur/ur.hpp>
#include <ur_api.h>

UR_APIEXPORT ur_result_t UR_APICALL
urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
ur_kernel_handle_t *phKernel) {
ol_kernel_handle_t OffloadKernel;

auto Res = olCreateKernel(reinterpret_cast<ol_program_handle_t>(hProgram),
pKernelName, &OffloadKernel);

if (Res != OL_SUCCESS) {
return offloadResultToUR(Res);
}

*phKernel = reinterpret_cast<ur_kernel_handle_t>(OffloadKernel);

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
return offloadResultToUR(
olRetainKernel(reinterpret_cast<ol_kernel_handle_t>(hKernel)));
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelRelease(ur_kernel_handle_t hKernel) {
return offloadResultToUR(
olReleaseKernel(reinterpret_cast<ol_kernel_handle_t>(hKernel)));
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelSetExecInfo(ur_kernel_handle_t, ur_kernel_exec_info_t, size_t,
const ur_kernel_exec_info_properties_t *, const void *) {
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer(
ur_kernel_handle_t hKernel, uint32_t argIndex,
const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) {
// setKernelArg is expecting a pointer to our argument
return offloadResultToUR(
olSetKernelArgValue(reinterpret_cast<ol_kernel_handle_t>(hKernel),
argIndex, sizeof(pArgValue), &pArgValue));
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {
return offloadResultToUR(
olSetKernelArgValue(reinterpret_cast<ol_kernel_handle_t>(hKernel),
argIndex, argSize, (void *)pArgValue));
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
ur_kernel_group_info_t propName, size_t propSize,
void *pPropValue, size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

if (propName == UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE) {
size_t GroupSize[3] = {0, 0, 0};
return ReturnValue(GroupSize, 3);
}
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
}
6 changes: 6 additions & 0 deletions source/adapters/offload/kernel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <ur_api.h>
#include <OffloadAPI.h>

struct ur_kernel_handle_t_ {
ol_kernel_handle_t OffloadKernel;
};
94 changes: 94 additions & 0 deletions source/adapters/offload/program.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include <OffloadAPI.h>
#include <ur/ur.hpp>
#include <ur_api.h>
#include <cuda.h>

#include "context.hpp"
#include "ur2offload.hpp"

UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
ur_context_handle_t hContext, uint32_t numDevices,
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
const ur_program_properties_t *, ur_program_handle_t *phProgram) {
if (numDevices > 1) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}


// Workaround for Offload not supporting PTX binaries. Force CUDA programs
// to be linked so they end up as CUBIN.
uint8_t *RealBinary;
size_t RealLength;
ur_platform_handle_t DevicePlatform;
bool DidLink = false;
CUlinkState State;
urDeviceGetInfo(phDevices[0], UR_DEVICE_INFO_PLATFORM,
sizeof(ur_platform_handle_t), &DevicePlatform, nullptr);
ur_platform_backend_t PlatformBackend;
urPlatformGetInfo(DevicePlatform, UR_PLATFORM_INFO_BACKEND,
sizeof(ur_platform_backend_t), &PlatformBackend, nullptr);
if (PlatformBackend == UR_PLATFORM_BACKEND_CUDA) {
cuLinkCreate(0, nullptr, nullptr, &State);

cuLinkAddData(State, CU_JIT_INPUT_PTX, (char *)(ppBinaries[0]), pLengths[0],
nullptr, 0, nullptr, nullptr);

void *CuBin = nullptr;
size_t CuBinSize = 0;
cuLinkComplete(State, &CuBin, &CuBinSize);
RealBinary = (uint8_t*) CuBin;
RealLength = CuBinSize;
DidLink = true;
fprintf(stderr, "Performed CUDA bin workaround (size = %lu)\n", RealLength);

} else {
RealBinary = const_cast<uint8_t *>(ppBinaries[0]);
RealLength = pLengths[0];
}

ol_program_handle_t OffloadProgram;
auto Res =
olCreateProgram(reinterpret_cast<ol_device_handle_t>(hContext->Device),
RealBinary, RealLength, &OffloadProgram);

// Program owns the linked module now
if (DidLink) {
cuLinkDestroy(State);
}

if (Res != OL_SUCCESS) {
return offloadResultToUR(Res);
}

*phProgram = reinterpret_cast<ur_program_handle_t>(OffloadProgram);

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,
ur_program_handle_t,
const char *) {
// Do nothing, program is built upon creation
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
// Do nothing, program is built upon creation
return UR_RESULT_SUCCESS;
}


UR_APIEXPORT ur_result_t UR_APICALL
urProgramRetain(ur_program_handle_t hProgram) {
auto OffloadProgram = reinterpret_cast<ol_program_handle_t>(hProgram);
return offloadResultToUR(olRetainProgram(OffloadProgram));
}

UR_APIEXPORT ur_result_t UR_APICALL
urProgramRelease(ur_program_handle_t hProgram) {
auto OffloadProgram = reinterpret_cast<ol_program_handle_t>(hProgram);
return offloadResultToUR(olReleaseProgram(OffloadProgram));
}
39 changes: 39 additions & 0 deletions source/adapters/offload/queue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <OffloadAPI.h>
#include <ur/ur.hpp>
#include <ur_api.h>

#include "context.hpp"
#include "ur2offload.hpp"

UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate(
[[maybe_unused]] ur_context_handle_t hContext, ur_device_handle_t hDevice,
const ur_queue_properties_t *, ur_queue_handle_t *phQueue) {

assert(hContext->Device == hDevice);

ol_queue_handle_t OffloadQueue;
auto Res = olCreateQueue(reinterpret_cast<ol_device_handle_t>(hDevice),
&OffloadQueue);
if (Res != OL_SUCCESS) {
return offloadResultToUR(Res);
}

*phQueue = reinterpret_cast<ur_queue_handle_t>(OffloadQueue);

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) {
auto OffloadQueue = reinterpret_cast<ol_queue_handle_t>(hQueue);
return offloadResultToUR(olRetainQueue(OffloadQueue));
}

UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) {
auto OffloadQueue = reinterpret_cast<ol_queue_handle_t>(hQueue);
return offloadResultToUR(olReleaseQueue(OffloadQueue));
}

UR_APIEXPORT ur_result_t UR_APICALL urQueueFinish(ur_queue_handle_t hQueue) {
auto OffloadQueue = reinterpret_cast<ol_queue_handle_t>(hQueue);
return offloadResultToUR(olFinishQueue(OffloadQueue));
}
61 changes: 61 additions & 0 deletions source/adapters/offload/usm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include <OffloadAPI.h>
#include <ur/ur.hpp>
#include <ur_api.h>

#include "context.hpp"
#include "ur2offload.hpp"

UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(ur_context_handle_t hContext,
const ur_usm_desc_t *,
ur_usm_pool_handle_t,
size_t size, void **ppMem) {
auto Res = olMemAlloc(reinterpret_cast<ol_device_handle_t>(hContext->Device),
OL_ALLOC_TYPE_HOST, size, ppMem);

if (Res != OL_SUCCESS) {
return offloadResultToUR(Res);
}

hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_HOST);
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *,
ur_usm_pool_handle_t, size_t size, void **ppMem) {
auto Res = olMemAlloc(reinterpret_cast<ol_device_handle_t>(hContext->Device),
OL_ALLOC_TYPE_DEVICE, size, ppMem);

if (Res != OL_SUCCESS) {
return offloadResultToUR(Res);
}

hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_DEVICE);
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
ur_context_handle_t hContext, ur_device_handle_t, const ur_usm_desc_t *,
ur_usm_pool_handle_t, size_t size, void **ppMem) {
auto Res = olMemAlloc(reinterpret_cast<ol_device_handle_t>(hContext->Device),
OL_ALLOC_TYPE_SHARED, size, ppMem);

if (Res != OL_SUCCESS) {
return offloadResultToUR(Res);
}

hContext->AllocTypeMap.insert_or_assign(*ppMem, OL_ALLOC_TYPE_SHARED);
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
void *pMem) {
auto AllocType = hContext->AllocTypeMap.find(pMem);
if (AllocType == hContext->AllocTypeMap.end()) {
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
}

return offloadResultToUR(
olMemFree(reinterpret_cast<ol_device_handle_t>(hContext->Device),
AllocType->second, pMem));
}

0 comments on commit 2afcfa1

Please sign in to comment.