Skip to content

Commit

Permalink
fix example compiler errors
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Dec 3, 2024
1 parent 52a7668 commit 98bd9f2
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 9 deletions.
8 changes: 4 additions & 4 deletions examples/sycl/pvc/pvc_gemm_with_epilogue_lincombdeeltact.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ struct ExampleRunner {
block_ref_D.reset(M * N * L);
block_Aux.reset(M * N * L);

initialize_block(block_A, 2023);
initialize_block(block_B, 2022);
initialize_block(block_C, 2021);
cutlass::reference::device::BlockFillScaler(block_Aux.get(), block_Aux.size(), 8.f);
initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
initialize_block(block_C, seed + 2021);
initialize_block(block_Aux, seed + 2020);
}

void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class CollectiveEpilogue<
TileShapeMNK{},
tile_coord_mnkl,
tiled_mma,
SubgroupTileShape{}, // Epilogue Tile
SubgroupTileShape{}, // Epilogue tile
params.xe_load_c,
rw_coord,
residue_mn,
Expand Down
11 changes: 7 additions & 4 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ using XeLinCombDeEltAct =
template <
class GmemLayoutTagAux,
template <class> class ActivationFn,
class ElementOutput,
class ElementCompute,
class ElementOutput_,
class ElementCompute_,
class ElementAux,
class ElementSource,
class ElementScalar,
Expand All @@ -200,17 +200,20 @@ template <
struct FusionCallbacks<
epilogue::IntelPVCEpilogue,
fusion::LinCombDeEltAct<
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute,
GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_,
ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
>,
CtaTileShapeMNK,
EpilogueTile,
CopyOpG2R
> : XeLinCombDeEltAct<
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle
> {

using ElementOutput = ElementOutput_;
using ElementCompute = ElementCompute_;

using Impl =
XeLinCombDeEltAct<
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
Expand Down
72 changes: 72 additions & 0 deletions tools/util/include/cutlass/util/reference/device/tensor_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,30 @@ __global__ void
}
}

template <template <class> class BinaryOp, typename Element>
#if defined (CUTLASS_ENABLE_SYCL)
void
#else
__global__ void
#endif
BlockElementwiseOp(
Element *ptr_dst,
Element const *ptr_A,
Element const *ptr_B,
size_t capacity) {

size_t idx = ThreadIdxX() + BlockDimX() * BlockIdxX();
BinaryOp<Element> bin_op{};

for (; idx < capacity; idx += GridDimX() * BlockDimX()) {

Element a = cutlass::ReferenceFactory<Element>::get(ptr_A, idx);
Element b = cutlass::ReferenceFactory<Element>::get(ptr_B, idx);

cutlass::ReferenceFactory<Element>::get(ptr_dst, idx) = bin_op(a, b);
}
}

} // namespace kernel


Expand Down Expand Up @@ -299,6 +323,54 @@ bool BlockCompareRelativelyEqual(

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Performs an elementwise function of two blocks
template <template <class> class BinaryOp, typename Element>
void BlockElementwiseOp(
Element *ptr_dst,
Element const *ptr_A,
Element const *ptr_B,
size_t capacity,
int grid_size = 0,
int block_size = 0) {


if (!grid_size || !block_size) {
#if defined (CUTLASS_ENABLE_SYCL)
block_size = 128;
grid_size = (capacity + block_size - 1) / block_size;
grid_size = (grid_size < 64 ? grid_size : 64); // limit grid size to avoid out_of_resources runtime error.
#else
// if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
&grid_size,
&block_size,
reinterpret_cast<void const *>(kernel::BlockCompareEqual<Element>));

if (result != cudaSuccess) {
throw std::runtime_error("Failed to query occupancy.");
}

// Limit block size. This has the effect of increasing the number of items processed by a
// single thread and reduces the impact of initialization overhead.
block_size = (block_size < 128 ? block_size : 128);
#endif
}

#if defined(CUTLASS_ENABLE_SYCL)
const auto sycl_block = syclcompat::dim3(block_size, 1, 1);
const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1);
syclcompat::launch<kernel::BlockElementwiseOp<BinaryOp, Element>>(
sycl_grid, sycl_block, ptr_dst, ptr_A, ptr_B, capacity);
syclcompat::wait(); // is this needed?
#else
dim3 grid(grid_size, 1, 1);
dim3 block(block_size, 1, 1);
kernel::BlockCompareEqual<Element><<< grid, block >>>(device_equal_flag, ptr_A, ptr_B, capacity);
#endif
}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // device
} // reference
} // cutlass

0 comments on commit 98bd9f2

Please sign in to comment.