Introduce index reduce kernels for small sizes (#1345)
Introduce index reduce kernels for small sizes.

---------

Co-authored-by: Yutao Xu <[email protected]>
yucai-intel and xytintel authored Feb 19, 2025
1 parent a14d1ea commit e4ce4df
Showing 2 changed files with 226 additions and 27 deletions.
248 changes: 224 additions & 24 deletions src/ATen/native/xpu/sycl/Indexing.cpp
@@ -304,6 +304,23 @@ static ptrdiff_t getSliceSize(
return dstSliceSize;
}

template <typename scalar_t>
bool indexShouldBeMajor(
TensorInfo<scalar_t, unsigned int>& info,
int sliceDim) {
// The stride between adjacent slices (e.g., between element #0 of slice #100
// and element #0 of slice #101).
unsigned int sliceStride = info.strides[sliceDim];

for (const auto i : c10::irange(info.dims)) {
if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) {
return true;
}
}

return false;
}
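// Example: for a contiguous tensor of shape {S, N} (strides {N, 1}, N > 1),
// indexing dim 0 gives sliceStride == N and dim 1 has stride 1 < N, so the
// index dimension should be major; indexing the innermost dim 1 instead gives
// sliceStride == 1, no other dimension has a smaller stride, and the
// function returns false.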

template <typename ValType>
struct IndexAddScalarFunctor {
void operator()(
@@ -1152,6 +1169,76 @@ void put_kernel(
});
}

template <typename T, typename IndicesType, typename IndexType, typename func_t>
struct IndexFuncSmallIndexFunctor {
void operator()(sycl::nd_item<1> item) const {
// To avoid reloading the index being copied, load it once and use it for
// all of the points it selects, so it is reused as much as possible. This
// kernel is chosen when the number of chosen indices is small, since
// re-accessing indices in addition to src elements can be slow.
for (IndexType srcIndex = 0; srcIndex < indices_.sizes[0]; ++srcIndex) {
// Look up the destination slice selected by the current index value
IndexType dstIndex =
indices_.data[IndexToOffset<const IndicesType, IndexType>::get(
srcIndex, indices_)];
SYCL_KERNEL_ASSERT(dstIndex < dstAddDimSize_);

// We stride over the output ignoring the indexed dimension
// (innerSize), whose offset calculation is handled differently
for (IndexType linearIndex = item.get_group(0) * item.get_local_range(0) +
item.get_local_id(0);
linearIndex < innerSize_;
linearIndex += item.get_group_range(0) * item.get_local_range(0)) {
IndexType dstOffset =
IndexToOffset<T, IndexType>::get(linearIndex, dst_);
dstOffset += dstIndex * dst_.strides[dstAddDim_];

IndexType srcOffset =
IndexToOffset<const T, IndexType>::get(linearIndex, src_);
srcOffset += srcIndex * src_.strides[srcAddDim_];

T val = src_.data[srcOffset] * alpha_;
op_(dst_.data, dstOffset, dstNumel_, &val);
}
}
}

IndexFuncSmallIndexFunctor(
TensorInfo<T, IndexType> dst,
TensorInfo<const T, IndexType> src,
TensorInfo<const IndicesType, IndexType> indices,
int dstAddDim,
int srcAddDim,
IndexType innerSize,
int64_t dstAddDimSize,
int64_t dstNumel,
func_t op,
T alpha)
: dst_(dst),
src_(src),
indices_(indices),
dstAddDim_(dstAddDim),
srcAddDim_(srcAddDim),
innerSize_(innerSize),
dstAddDimSize_(dstAddDimSize),
dstNumel_(dstNumel),
op_(op),
alpha_(alpha) {}

private:
TensorInfo<T, IndexType> dst_;
TensorInfo<const T, IndexType> src_;
TensorInfo<const IndicesType, IndexType> indices_;
int dstAddDim_;
int srcAddDim_;
IndexType innerSize_;
int64_t dstAddDimSize_;
int64_t dstNumel_;
func_t op_;
T alpha_;
};
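
// For instance, a reduction along dim 0 with indices {3, 1} makes every
// work-item of this functor walk srcIndex 0..1: source slice 0 is
// accumulated into dst slice 3 and source slice 1 into dst slice 1 via op_.
// The grid-stride loop over linearIndex parallelizes only within a slice
// (innerSize elements), which is why this kernel is reserved for small
// index counts.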

template <
typename T,
typename IndicesType,
@@ -1180,7 +1267,7 @@ struct IndexFuncLargeIndexFunctor {
IndexType dstIndex =
indices_.data[IndexToOffset<const IndicesType, IndexType>::get(
srcIndex, indices_)];
CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize_);
SYCL_KERNEL_ASSERT(dstIndex < dstAddDimSize_);

IndexType dstOffset =
IndexToOffset<T, IndexType>::get(elementInSlice, dst_);
@@ -1301,13 +1388,140 @@ void index_reduce_func_xpu_template(
uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
uint64_t sourceTotalSize = source.numel();
uint64_t selfReduceDimSize = self_.size(dim);
// uint64_t numIndex = index.numel();
uint64_t numIndex = index.numel();
uint64_t selfNumel = self_.numel();
if (sliceSize == 0) {
return;
}
bool indContig = index.is_contiguous();

#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, FUNC_T) \
IndexFuncSmallIndexFunctor<TENSOR_TYPE, INDICES_TYPE, TYPE, FUNC_T>( \
selfInfo, \
sourceInfo, \
indexInfo, \
selfReduceDim, \
sourceReduceDim, \
sliceSize, \
selfReduceDimSize, \
selfNumel, \
reduce_func, \
alpha_value);

#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, IDX_IS_MAJOR, FUNC_T) \
IndexFuncLargeIndexFunctor< \
TENSOR_TYPE, \
INDICES_TYPE, \
TYPE, \
IDX_IS_MAJOR, \
FUNC_T>( \
selfInfo, \
sourceInfo, \
indexInfo, \
selfReduceDim, \
sourceReduceDim, \
sourceTotalSize, \
(IDX_IS_MAJOR) ? sliceSize : numIndex, \
selfReduceDimSize, \
selfNumel, \
reduce_func, \
alpha_value);
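
// SMALL_INDEX instantiates the functor in which every work-item loops over
// all indices, while LARGE_INDEX assigns one (index, slice-element) pair
// per work-item; its sixth argument is the inner extent used to decompose
// the flat id: sliceSize when the index dimension is major, numIndex
// otherwise.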

int ssc = syclMaxDSSNum();

if (canUse32BitIndexMath(result) && canUse32BitIndexMath(source) &&
canUse32BitIndexMath(index)) {
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
result.scalar_type(),
"index_reduce",
[&] {
TensorInfo<scalar_t, unsigned int> selfInfo =
getTensorInfo<scalar_t, unsigned int>(self_);
int selfReduceDim = selfInfo.collapseDims(dim);
selfInfo.reduceDim(selfReduceDim);
auto alpha_value = (scalar_t)1;
AT_DISPATCH_INDEX_TYPES(
index.scalar_type(), "index_reduce_xpu", [&]() {
auto sourceInfo =
getTensorInfo<const scalar_t, unsigned int>(source_);
int sourceReduceDim = sourceInfo.collapseDims(dim);
sourceInfo.reduceDim(sourceReduceDim);

auto indexInfo =
getTensorInfo<const index_t, unsigned int>(index);
indexInfo.collapseDims();
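// collapseDims() folds dimensions that are contiguous in memory into
// one, so the per-element offset arithmetic in the kernels touches as
// few dimensions as possible.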

{
// A reasonable threshold for choosing the small-index kernel, in which
// every work-item iterates over all of the indices
if (numIndex <= 16) {
auto caller =
SMALL_INDEX(scalar_t, index_t, unsigned int, func_t);
int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller);
size_t num_wg = std::min(
ceil_div(sliceSize, (uint64_t)128), (uint64_t)(ssc * 8));
size_t wg_size = std::min(sliceSize, (uint64_t)128);
sycl_kernel_submit(
num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller);
} else {
bool indexIsMajor =
indexShouldBeMajor(selfInfo, selfReduceDim);

if (indContig) {
if (indexIsMajor) {
auto caller = LARGE_INDEX(
scalar_t, index_t, unsigned int, true, func_t);
int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller);
size_t num_wg = std::min(
ceil_div(sourceTotalSize, (uint64_t)128),
(uint64_t)(ssc * 8));
size_t wg_size =
(sourceTotalSize < defaultMaxGroupThreads)
? sourceTotalSize
: defaultMaxGroupThreads;
sycl_kernel_submit(
num_wg * wg_size,
wg_size,
getCurrentSYCLQueue(),
caller);
} else {
auto caller = LARGE_INDEX(
scalar_t, index_t, unsigned int, false, func_t);
int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller);
size_t num_wg = std::min(
ceil_div(sourceTotalSize, (uint64_t)128),
(uint64_t)(ssc * 8));
size_t wg_size =
(sourceTotalSize < defaultMaxGroupThreads)
? sourceTotalSize
: defaultMaxGroupThreads;
sycl_kernel_submit(
num_wg * wg_size,
wg_size,
getCurrentSYCLQueue(),
caller);
}
} else {
auto caller = LARGE_INDEX(
scalar_t, index_t, unsigned int, true, func_t);
int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller);
size_t num_wg = std::min(
ceil_div(sourceTotalSize, (uint64_t)128),
(uint64_t)(ssc * 8));
size_t wg_size = (sourceTotalSize < defaultMaxGroupThreads)
? sourceTotalSize
: defaultMaxGroupThreads;
sycl_kernel_submit(
num_wg * wg_size,
wg_size,
getCurrentSYCLQueue(),
caller);
}
}
});
});
} else {
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
Expand All @@ -1330,28 +1544,12 @@ void index_reduce_func_xpu_template(
TensorInfo<const index_t, uint64_t> indexInfo =
getTensorInfo<const index_t, uint64_t>(index);
indexInfo.collapseDims();
auto caller = IndexFuncLargeIndexFunctor<
scalar_t,
index_t,
uint64_t,
true,
func_t>(
selfInfo,
sourceInfo,
indexInfo,
selfReduceDim,
sourceReduceDim,
sourceTotalSize,
sliceSize,
selfReduceDimSize,
selfNumel,
reduce_func,
alpha_value);
auto caller =
LARGE_INDEX(scalar_t, index_t, uint64_t, true, func_t);
int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller);
int sgc = syclMaxNumSubGroups();
size_t num_wg = std::min(
ceil_div(sourceTotalSize, (uint64_t)128),
(uint64_t)(sgc * 8));
(uint64_t)(ssc * 8));
size_t wg_size = (sourceTotalSize < defaultMaxGroupThreads)
? sourceTotalSize
: defaultMaxGroupThreads;
Expand All @@ -1360,6 +1558,9 @@ void index_reduce_func_xpu_template(
});
});
}

#undef SMALL_INDEX
#undef LARGE_INDEX
}
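
// At the ATen level this template backs Tensor::index_reduce_. A minimal
// sketch of a call that would take the small-index path (numIndex <= 16),
// assuming an XPU-enabled build:
//
//   auto self = at::ones({16, 1024}, at::kXPU);
//   auto source = at::randn({2, 1024}, at::kXPU);
//   auto index = at::tensor({3, 7}, at::kLong).to(at::kXPU);
//   self.index_reduce_(/*dim=*/0, index, source, "prod");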

struct IndexReduceMultiplyFunctor {
Expand Down Expand Up @@ -1550,8 +1751,7 @@ struct IndexSelectSparse3Functor {
index_t count,
index_t offset,
index_t first_match) const {
index_t* RESTRICT ptr_res_dim_indices_out =
ptr_res_dim_indices_ + offset;
index_t* RESTRICT ptr_res_dim_indices_out = ptr_res_dim_indices_ + offset;
const index_t* RESTRICT ptr_argsort_dim_indices_in =
ptr_argsort_dim_indices_ + first_match;
index_t* RESTRICT ptr_selected_dim_indices_out =
5 changes: 2 additions & 3 deletions src/comm/DeviceProperties.h
@@ -120,9 +120,8 @@ static inline int64_t syclMaxNumSubGroups(

static inline int64_t syclMaxDSSNum(
at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
// TODO: We need to got this info from DPC++ Runtime
// Hardcode to 32 for ATS
int64_t dss_num = 32;
int64_t dss_num =
syclMaxComputeUnitSize(dev_id) / syclGpuEUCountPerSubslice(dev_id);
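// e.g., a hypothetical device reporting 512 compute units with 8 EUs per
// subslice yields dss_num = 64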
return dss_num;
}

