Skip to content

Commit

Permalink
[SYCL][ABI-Break] Improve Queue fill (intel#13788)
Browse files Browse the repository at this point in the history
Changed the `queue.fill()` implementation to make use of the native
functions for a specific backend. Also, unified the implementation with
the one for memset, since it is just an 8-bit subset operation of fill.

In the CUDA case, both memset and fill are currently calling
`urEnqueueUSMFill` which depending on the size of the filling pattern
calls either `cuMemsetD8Async`, `cuMemsetD16Async`, `cuMemsetD32Async`
or `commonMemSetLargePattern`. Before this patch memset was using the
same thing, just beforehand setting patternSize always to 1 byte which
resulted in calling `cuMemsetD8Async`. In other backends, the behaviour
is analogous.

The fill method was just invoking a `parallel_for` to fill the memory
with the pattern which was making this operation quite slow.
  • Loading branch information
konradkusiak97 authored Jul 5, 2024
1 parent b026de4 commit 0ccb0b7
Show file tree
Hide file tree
Showing 44 changed files with 266 additions and 198 deletions.
2 changes: 2 additions & 0 deletions sycl/doc/design/CommandGraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,8 @@ The types of commands which are unsupported, and lead to this exception are:
This corresponds to a memory buffer write command.
* `handler::copy(src, dest)` or `handler::memcpy(dest, src)` - Where both `src` and
`dest` are USM pointers. This corresponds to a USM copy command.
* `handler::fill(ptr, pattern, count)` - This corresponds to a USM memory
fill command.
* `handler::memset(ptr, value, numBytes)` - This corresponds to a USM memory
fill command.
* `handler::prefetch()`.
Expand Down
20 changes: 10 additions & 10 deletions sycl/include/sycl/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class CG {
getAuxiliaryResources() const {
return {};
}
virtual void clearAuxiliaryResources(){};
virtual void clearAuxiliaryResources() {};

virtual ~CG() = default;

Expand Down Expand Up @@ -247,11 +247,11 @@ class CGCopy : public CG {
/// "Fill memory" command group class.
class CGFill : public CG {
public:
std::vector<char> MPattern;
std::vector<unsigned char> MPattern;
AccessorImplHost *MPtr;

CGFill(std::vector<char> Pattern, void *Ptr, CG::StorageInitHelper CGData,
detail::code_location loc = {})
CGFill(std::vector<unsigned char> Pattern, void *Ptr,
CG::StorageInitHelper CGData, detail::code_location loc = {})
: CG(Fill, std::move(CGData), std::move(loc)),
MPattern(std::move(Pattern)), MPtr((AccessorImplHost *)Ptr) {}
AccessorImplHost *getReqToFill() { return MPtr; }
Expand Down Expand Up @@ -289,18 +289,18 @@ class CGCopyUSM : public CG {

/// "Fill USM" command group class.
class CGFillUSM : public CG {
std::vector<char> MPattern;
std::vector<unsigned char> MPattern;
void *MDst;
size_t MLength;

public:
CGFillUSM(std::vector<char> Pattern, void *DstPtr, size_t Length,
CGFillUSM(std::vector<unsigned char> Pattern, void *DstPtr, size_t Length,
CG::StorageInitHelper CGData, detail::code_location loc = {})
: CG(FillUSM, std::move(CGData), std::move(loc)),
MPattern(std::move(Pattern)), MDst(DstPtr), MLength(Length) {}
void *getDst() { return MDst; }
size_t getLength() { return MLength; }
int getFill() { return MPattern[0]; }
const std::vector<unsigned char> &getPattern() { return MPattern; }
};

/// "Prefetch USM" command group class.
Expand Down Expand Up @@ -378,14 +378,14 @@ class CGCopy2DUSM : public CG {

/// "Fill 2D USM" command group class.
class CGFill2DUSM : public CG {
std::vector<char> MPattern;
std::vector<unsigned char> MPattern;
void *MDst;
size_t MPitch;
size_t MWidth;
size_t MHeight;

public:
CGFill2DUSM(std::vector<char> Pattern, void *DstPtr, size_t Pitch,
CGFill2DUSM(std::vector<unsigned char> Pattern, void *DstPtr, size_t Pitch,
size_t Width, size_t Height, CG::StorageInitHelper CGData,
detail::code_location loc = {})
: CG(Fill2DUSM, std::move(CGData), std::move(loc)),
Expand All @@ -395,7 +395,7 @@ class CGFill2DUSM : public CG {
size_t getPitch() const { return MPitch; }
size_t getWidth() const { return MWidth; }
size_t getHeight() const { return MHeight; }
const std::vector<char> &getPattern() const { return MPattern; }
const std::vector<unsigned char> &getPattern() const { return MPattern; }
};

/// "Memset 2D USM" command group class.
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/pi.def
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ _PI_API(piextUSMHostAlloc)
_PI_API(piextUSMDeviceAlloc)
_PI_API(piextUSMSharedAlloc)
_PI_API(piextUSMFree)
_PI_API(piextUSMEnqueueMemset)
_PI_API(piextUSMEnqueueFill)
_PI_API(piextUSMEnqueueMemcpy)
_PI_API(piextUSMEnqueuePrefetch)
_PI_API(piextUSMEnqueueMemAdvise)
Expand Down
27 changes: 14 additions & 13 deletions sycl/include/sycl/detail/pi.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,10 @@
// _pi_virtual_mem_granularity_info enum, _pi_virtual_mem_info enum and
// pi_virtual_access_flags bit flags.
// 15.55 Added piextEnqueueNativeCommand as well as associated types and enums
// 16.56 Replaced piextUSMEnqueueMemset with piextUSMEnqueueFill

#define _PI_H_VERSION_MAJOR 15
#define _PI_H_VERSION_MINOR 55
#define _PI_H_VERSION_MAJOR 16
#define _PI_H_VERSION_MINOR 56

#define _PI_STRING_HELPER(a) #a
#define _PI_CONCAT(a, b) _PI_STRING_HELPER(a.b)
Expand Down Expand Up @@ -2174,22 +2175,22 @@ __SYCL_EXPORT pi_result piextUSMPitchedAlloc(
/// \param ptr is the memory to be freed
__SYCL_EXPORT pi_result piextUSMFree(pi_context context, void *ptr);

/// USM Memset API
/// USM Fill API
///
/// \param queue is the queue to submit to
/// \param ptr is the ptr to memset
/// \param value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// \param count is the size in bytes to memset
/// \param ptr is the ptr to fill
/// \param pattern is the ptr with the bytes of the pattern to set
/// \param patternSize is the size in bytes of the pattern to set
/// \param count is the size in bytes to fill
/// \param num_events_in_waitlist is the number of events to wait on
/// \param events_waitlist is an array of events to wait on
/// \param event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueMemset(pi_queue queue, void *ptr,
pi_int32 value, size_t count,
pi_uint32 num_events_in_waitlist,
const pi_event *events_waitlist,
pi_event *event);
__SYCL_EXPORT pi_result piextUSMEnqueueFill(pi_queue queue, void *ptr,
const void *pattern,
size_t patternSize, size_t count,
pi_uint32 num_events_in_waitlist,
const pi_event *events_waitlist,
pi_event *event);

/// USM Memcpy API
///
Expand Down
21 changes: 16 additions & 5 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2826,10 +2826,14 @@ class __SYCL_EXPORT handler {
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
static_assert(is_device_copyable<T>::value,
"Pattern must be device copyable");
parallel_for<__usmfill<T>>(range<1>(Count), [=](id<1> Index) {
T *CastedPtr = static_cast<T *>(Ptr);
CastedPtr[Index] = Pattern;
});
if (getDeviceBackend() == backend::ext_oneapi_level_zero) {
parallel_for<__usmfill<T>>(range<1>(Count), [=](id<1> Index) {
T *CastedPtr = static_cast<T *>(Ptr);
CastedPtr[Index] = Pattern;
});
} else {
this->fill_impl(Ptr, &Pattern, sizeof(T), Count);
}
}

/// Prevents any commands submitted afterward to this queue from executing
Expand Down Expand Up @@ -3297,7 +3301,7 @@ class __SYCL_EXPORT handler {
/// Length to copy or fill (for USM operations).
size_t MLength = 0;
/// Pattern that is used to fill memory object in case command type is fill.
std::vector<char> MPattern;
std::vector<unsigned char> MPattern;
/// Storage for a lambda or function object.
std::unique_ptr<detail::HostKernelBase> MHostKernel;
/// Storage for lambda/function when using HostTask
Expand Down Expand Up @@ -3442,6 +3446,10 @@ class __SYCL_EXPORT handler {
// Helper function for getting a loose bound on work-items.
id<2> computeFallbackKernelBounds(size_t Width, size_t Height);

// Function to get information about the backend for which the code is
// compiled for
backend getDeviceBackend() const;

// Common function for launching a 2D USM memcpy kernel to avoid redefinitions
// of the kernel from copy and memcpy.
template <typename T>
Expand Down Expand Up @@ -3553,6 +3561,9 @@ class __SYCL_EXPORT handler {
});
}

// Implementation of USM fill using command for native fill.
void fill_impl(void *Dest, const void *Value, size_t ValueSize, size_t Count);

// Implementation of ext_oneapi_memcpy2d using command for native 2D memcpy.
void ext_oneapi_memcpy2d_impl(void *Dest, size_t DestPitch, const void *Src,
size_t SrcPitch, size_t Width, size_t Height);
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/cuda/pi_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,12 +930,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/hip/pi_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,12 +933,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
23 changes: 11 additions & 12 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,23 +957,22 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

/// USM Memset API
/// USM Fill API
///
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to memset
/// @param Value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// @param Count is the size in bytes to memset
/// @param Ptr is the ptr to fill
/// \param Pattern is the ptr with the bytes of the pattern to set
/// \param PatternSize is the size in bytes of the pattern to set
/// @param Count is the size in bytes to fill
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/native_cpu/pi_native_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,12 +933,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -889,12 +889,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
14 changes: 7 additions & 7 deletions sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3913,11 +3913,12 @@ inline pi_result piEnqueueMemBufferFill(pi_queue Queue, pi_mem Buffer,
return PI_SUCCESS;
}

inline pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,
pi_int32 Value, size_t Count,
pi_uint32 NumEventsInWaitList,
const pi_event *EventsWaitList,
pi_event *OutEvent) {
inline pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr,
const void *Pattern, size_t PatternSize,
size_t Count,
pi_uint32 NumEventsInWaitList,
const pi_event *EventsWaitList,
pi_event *OutEvent) {
PI_ASSERT(Queue, PI_ERROR_INVALID_QUEUE);
if (!Ptr) {
return PI_ERROR_INVALID_VALUE;
Expand All @@ -3929,8 +3930,7 @@ inline pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,

ur_event_handle_t *UREvent = reinterpret_cast<ur_event_handle_t *>(OutEvent);

size_t PatternSize = 1;
HANDLE_ERRORS(urEnqueueUSMFill(UrQueue, Ptr, PatternSize, &Value, Count,
HANDLE_ERRORS(urEnqueueUSMFill(UrQueue, Ptr, PatternSize, Pattern, Count,
NumEventsInWaitList, UrEventsWaitList,
UREvent));

Expand Down
36 changes: 18 additions & 18 deletions sycl/plugins/unified_runtime/pi_unified_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,24 +442,24 @@ __SYCL_EXPORT pi_result piQueueGetInfo(pi_queue Queue, pi_queue_info ParamName,
ParamValueSizeRet);
}

/// USM Memset API
/// USM Fill API
///
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to memset
/// @param Value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// @param Count is the size in bytes to memset
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,
pi_int32 Value, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
/// \param queue is the queue to submit to
/// \param ptr is the ptr to fill
/// \param pattern is the ptr with the bytes of the pattern to set
/// \param patternSize is the size in bytes of the pattern to set
/// \param count is the size in bytes to fill
/// \param num_events_in_waitlist is the number of events to wait on
/// \param events_waitlist is an array of events to wait on
/// \param event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr,
const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

__SYCL_EXPORT pi_result piEnqueueMemBufferCopyRect(
Expand Down Expand Up @@ -1598,7 +1598,7 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
_PI_API(piEnqueueMemBufferMap)
_PI_API(piEnqueueMemUnmap)
_PI_API(piEnqueueMemBufferFill)
_PI_API(piextUSMEnqueueMemset)
_PI_API(piextUSMEnqueueFill)
_PI_API(piEnqueueMemBufferCopyRect)
_PI_API(piEnqueueMemBufferCopy)
_PI_API(piextUSMEnqueueMemcpy)
Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,10 @@ class node_impl {
sycl::detail::CGFillUSM *FillUSM =
static_cast<sycl::detail::CGFillUSM *>(MCommandGroup.get());
Stream << "Dst: " << FillUSM->getDst()
<< " Length: " << FillUSM->getLength()
<< " Pattern: " << FillUSM->getFill() << "\\n";
<< " Length: " << FillUSM->getLength() << " Pattern: ";
for (auto byte : FillUSM->getPattern())
Stream << byte;
Stream << "\\n";
}
break;
case sycl::detail::CG::CGTYPE::PrefetchUSM:
Expand Down
Loading

0 comments on commit 0ccb0b7

Please sign in to comment.