Skip to content

Commit 19922a0

Browse files
committed
[UR][L0] Set pointer kernel arguments only for queue's associated device
Ensure that pointer kernel arguments are set only for the device associated with the queue being used for kernel launch. Previously, arguments were set for all devices in the kernel's device map, which was unnecessary and potentially incorrect when launching on a specific device.
1 parent 2af08ff commit 19922a0

File tree

4 files changed

+98
-18
lines changed

4 files changed

+98
-18
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// UNSUPPORTED-TRACKER: CMPLRLLVM-67039
5+
// UNSUPPORTED: level_zero_v2_adapter
6+
7+
// Test that usm device pointer can be used in a kernel compiled for a context
8+
// with multiple devices.
9+
10+
#include <iostream>
11+
#include <sycl/detail/core.hpp>
12+
#include <sycl/kernel_bundle.hpp>
13+
#include <sycl/platform.hpp>
14+
#include <sycl/usm.hpp>
15+
#include <vector>
16+
17+
using namespace sycl;
18+
19+
class AddIdxKernel;
20+
21+
int main() {
22+
sycl::platform plt;
23+
std::vector<sycl::device> devices = plt.get_devices();
24+
if (devices.size() < 2) {
25+
std::cout << "Need at least 2 GPU devices for this test.\n";
26+
return 0;
27+
}
28+
29+
std::vector<sycl::device> ctx_devices{devices[0], devices[1]};
30+
sycl::context ctx(ctx_devices);
31+
32+
constexpr size_t N = 16;
33+
std::vector<std::vector<int>> results(ctx_devices.size(),
34+
std::vector<int>(N, 0));
35+
36+
// Create a kernel bundle compiled for both devices in the context
37+
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(ctx);
38+
39+
// For each device, create a queue and run a kernel using device USM
40+
for (size_t i = 0; i < ctx_devices.size(); ++i) {
41+
sycl::queue q(ctx, ctx_devices[i]);
42+
int *data = sycl::malloc_device<int>(N, q);
43+
q.fill(data, 1, N).wait();
44+
q.submit([&](sycl::handler &h) {
45+
h.use_kernel_bundle(kb);
46+
h.parallel_for<AddIdxKernel>(
47+
sycl::range<1>(N), [=](sycl::id<1> idx) { data[idx] += idx[0]; });
48+
}).wait();
49+
q.memcpy(results[i].data(), data, N * sizeof(int)).wait();
50+
sycl::free(data, q);
51+
}
52+
53+
for (size_t i = 0; i < ctx_devices.size(); ++i) {
54+
std::cout << "Device " << i << " results: ";
55+
for (size_t j = 0; j < N; ++j) {
56+
if (results[i][j] != 1 + static_cast<int>(j)) {
57+
return -1;
58+
}
59+
std::cout << results[i][j] << " ";
60+
}
61+
}
62+
return 0;
63+
}

unified-runtime/source/adapters/level_zero/command_buffer.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,12 +1004,16 @@ ur_result_t setKernelPendingArguments(
10041004
ze_kernel_handle_t ZeKernel) {
10051005
// If there are any pending arguments set them now.
10061006
for (auto &Arg : PendingArguments) {
1007-
// The ArgValue may be a NULL pointer in which case a NULL value is used for
1008-
// the kernel argument declared as a pointer to global or constant memory.
10091007
char **ZeHandlePtr = nullptr;
1010-
if (Arg.Value) {
1011-
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
1012-
nullptr, 0u));
1008+
if (auto MemObjPtr = std::get_if<ur_mem_handle_t>(&Arg.Value)) {
1009+
ur_mem_handle_t MemObj = *MemObjPtr;
1010+
if (MemObj) {
1011+
UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device,
1012+
nullptr, 0u));
1013+
}
1014+
} else {
1015+
auto Ptr = const_cast<void **>(&std::get<const void *>(Arg.Value));
1016+
ZeHandlePtr = reinterpret_cast<char **>(Ptr);
10131017
}
10141018
ZE2UR_CALL(zeKernelSetArgumentValue,
10151019
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));

unified-runtime/source/adapters/level_zero/kernel.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,20 @@ ur_result_t urEnqueueKernelLaunch(
125125

126126
// If there are any pending arguments set them now.
127127
for (auto &Arg : Kernel->PendingArguments) {
128-
// The ArgValue may be a NULL pointer in which case a NULL value is used for
129-
// the kernel argument declared as a pointer to global or constant memory.
128+
// The Arg.Value can be either a ur_mem_handle_t or a raw pointer
129+
// (const void*). Resolve per-device: for mem handles obtain the device
130+
// specific handle, otherwise pass the raw pointer value.
130131
char **ZeHandlePtr = nullptr;
131-
if (Arg.Value) {
132-
UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
133-
Queue->Device, EventWaitList,
134-
NumEventsInWaitList));
132+
if (auto MemObjPtr = std::get_if<ur_mem_handle_t>(&Arg.Value)) {
133+
ur_mem_handle_t MemObj = *MemObjPtr;
134+
if (MemObj) {
135+
UR_CALL(MemObj->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
136+
Queue->Device, EventWaitList,
137+
NumEventsInWaitList));
138+
}
139+
} else {
140+
auto Ptr = const_cast<void **>(&std::get<const void *>(Arg.Value));
141+
ZeHandlePtr = reinterpret_cast<char **>(Ptr);
135142
}
136143
ZE2UR_CALL(zeKernelSetArgumentValue,
137144
(ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
@@ -733,9 +740,13 @@ ur_result_t urKernelSetArgPointer(
733740
/// value. If null then argument value is considered null.
734741
const void *ArgValue) {
735742

736-
// KernelSetArgValue is expecting a pointer to the argument
737-
UR_CALL(ur::level_zero::urKernelSetArgValue(
738-
Kernel, ArgIndex, sizeof(const void *), nullptr, &ArgValue));
743+
// Instead of setting pointer arguments immediately across all device
744+
// kernels, store them as pending so they can be resolved per-device at
745+
// enqueue time. This ensures the correct handle is used for the device of the
746+
// queue.
747+
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
748+
Kernel->PendingArguments.push_back(
749+
{ArgIndex, sizeof(const void *), ArgValue, ur_mem_handle_t_::unknown});
739750
return UR_RESULT_SUCCESS;
740751
}
741752

@@ -842,9 +853,8 @@ ur_result_t urKernelSetArgMemObj(
842853
return UR_RESULT_ERROR_INVALID_ARGUMENT;
843854
}
844855
}
845-
auto Arg = UrMem ? UrMem : nullptr;
846856
Kernel->PendingArguments.push_back(
847-
{ArgIndex, sizeof(void *), Arg, UrAccessMode});
857+
{ArgIndex, sizeof(const void *), UrMem, UrAccessMode});
848858

849859
return UR_RESULT_SUCCESS;
850860
}

unified-runtime/source/adapters/level_zero/kernel.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#pragma once
1111

1212
#include <unordered_set>
13+
#include <variant>
1314

1415
#include "common.hpp"
1516
#include "common/ur_ref_count.hpp"
@@ -97,8 +98,10 @@ struct ur_kernel_handle_t_ : ur_object {
9798
struct ArgumentInfo {
9899
uint32_t Index;
99100
size_t Size;
100-
// const ur_mem_handle_t_ *Value;
101-
ur_mem_handle_t_ *Value;
101+
// Value may be either a memory object or a raw pointer value (for pointer
102+
// arguments). Resolve at enqueue time per-device to ensure correct handle
103+
// is used for that device.
104+
std::variant<ur_mem_handle_t, const void *> Value;
102105
ur_mem_handle_t_::access_mode_t AccessMode{ur_mem_handle_t_::unknown};
103106
};
104107
// Arguments that still need to be set (with zeKernelSetArgumentValue)

0 commit comments

Comments
 (0)