Skip to content

Commit 21d2651

Browse files
committed
Fix for void ElementC in epilogue.
TODO: Handle void ElementD. Check whether we need to handle the assumption that the accumulator type is the same as the compute type. Signed-off-by: Chawla, Amit K <[email protected]>
1 parent 9acfcd5 commit 21d2651

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

include/cutlass/epilogue/collective/builders/xe_builder.inl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,7 @@ template <
195195
//TODO(Codeplay): Should FusionCallbacks use DispatchPolicy IntelXeGroupEpilogue for group gemm? That does not work.
196196
using FusionCallbacks = typename detail::FusionOpInfo<FusionOpOrCallbacks>::template FusionCallbacks<
197197
IntelXeXMX16, TileShape_MNK, TileShape_MNK, CopyOpG2R>;
198-
198+
static_assert(cute::is_same_v<ElementAccumulator, ElementCompute>, "Assumption is that in Epilogue, ElementAccumulator and ElementCompute are same");
199199
using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
200200
DispatchPolicy,
201201
TileShape_MNK,

include/cutlass/epilogue/collective/xe_array_epilogue.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,6 @@ class CollectiveEpilogue<
9090
using DispatchPolicy = IntelXeXMX16Group;
9191
using CtaTileMNK = CtaTileMNK_;
9292
using FusionCallbacks = FusionCallbacks_;
93-
using ElementC = ElementC_;
9493
using StrideC = StrideC_;
9594
using InternalStrideC = cute::remove_pointer_t<StrideC>;
9695
using ElementD = ElementD_;
@@ -110,6 +109,7 @@ class CollectiveEpilogue<
110109
using ElementOutput = ElementD;
111110
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
112111
using ElementAccumulator = ElementCompute;
112+
using ElementC = conditional_t<cute::is_void_v<CopyOpG2R>, ElementAccumulator, ElementC_>;
113113
using ElementSource = typename FusionCallbacks::ElementSource;
114114
using ElementScalar = typename FusionCallbacks::ElementScalar;
115115
static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest;
@@ -139,7 +139,7 @@ class CollectiveEpilogue<
139139
Layout<CopyThreadShape>{},
140140
make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{}))));
141141
private:
142-
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
142+
constexpr static bool is_source_supported = not cute::is_void_v<ElementC_>;
143143
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
144144

145145
public:

include/cutlass/epilogue/collective/xe_epilogue.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,6 @@ class CollectiveEpilogue<
8989
using DispatchPolicy = IntelXeXMX16;
9090
using CtaTileMNK = CtaTileMNK_;
9191
using FusionCallbacks = FusionCallbacks_;
92-
using ElementC = ElementC_;
9392
using StrideC = StrideC_;
9493
using ElementD = ElementD_;
9594
using StrideD = StrideD_;
@@ -107,6 +106,7 @@ class CollectiveEpilogue<
107106
using ElementOutput = ElementD;
108107
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
109108
using ElementAccumulator = ElementCompute;
109+
using ElementC = conditional_t<cute::is_void_v<CopyOpG2R>, ElementAccumulator, ElementC_>;
110110
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
111111

112112
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
@@ -129,7 +129,7 @@ class CollectiveEpilogue<
129129
using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom<Trait_D, ElementD>{}, Layout<CopyThreadShape>{}, val_layout_store_D{}));
130130

131131
private:
132-
constexpr static bool is_source_supported = not cute::is_void_v<ElementC> && not cute::is_void_v<CopyOpG2R>;
132+
constexpr static bool is_source_supported = not cute::is_void_v<ElementC_> && not cute::is_void_v<CopyOpG2R>;
133133
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD> && not cute::is_void_v<CopyOpR2G>;
134134

135135
constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();

0 commit comments

Comments
 (0)