Skip to content

Commit

Permalink
Merge pull request #180 from FMarno/revert_rw_coord
Browse files Browse the repository at this point in the history
Reverted the earlier change that replaced cD with rw_coord in the consumer store args; the args pass cD again.
  • Loading branch information
FMarno authored Jan 10, 2025
2 parents 096165f + 3e92879 commit fa0fcad
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
4 changes: 2 additions & 2 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ class CollectiveEpilogue<
tile_coord_mnkl,
tiled_mma,
SubgroupTileShape{}, // Epilogue tile
params.xe_load_c,
rw_coord,
params.xe_store_d,
cD,
residue_mn,
cD,
residue_mn,
Expand Down
11 changes: 6 additions & 5 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ struct FusionCallbacks<
/////////////////////////////////////////////////////////////////////////////////////////////////

template<
class CtaTileShapeMNK,
class StrideAux,
class CopyOpG2R,
template <class> class ActivationFn,
Expand All @@ -184,7 +185,7 @@ template<
using XeLinCombDeEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
XeAuxLoad<ElementAux, StrideAux, CopyOpG2R> // aux
XeAuxLoad<CtaTileShapeMNK, ElementAux, StrideAux, CopyOpG2R> // aux
>;

// Z = Aux
Expand Down Expand Up @@ -215,17 +216,17 @@ struct FusionCallbacks<
EpilogueTile,
CopyOpG2R
> : XeLinCombDeEltAct<
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput_,
ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
> {

using ElementOutput = ElementOutput_;
using ElementCompute = ElementCompute_;

using Impl =
XeLinCombDeEltAct<
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput,
ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
>;
using Operation =
fusion::LinCombDeEltAct<
Expand Down
27 changes: 26 additions & 1 deletion include/cutlass/epilogue/fusion/xe_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ using namespace cutlass::epilogue::fusion;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
class CtaTileShapeMNK,
class Element,
class StrideMNL,
class CopyOpG2R,
Expand Down Expand Up @@ -190,9 +191,33 @@ struct XeAuxLoad {
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto xe_copy_aux = params_ptr->xe_load_aux;
Tensor rw_coord = args.cD;
Tensor trAux = make_tensor_like<Element>(args.tCrC);

using TiledMma = decltype(args.tiled_mma);
using MmaAtomShape = typename TiledMma::AtomShape_MNK;

static constexpr auto BLK_M = get<0>(CtaTileShapeMNK{});
static constexpr auto BLK_N = get<1>(CtaTileShapeMNK{});

static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());

static constexpr auto SG_M = BLK_M / ATOM_M;
static constexpr auto SG_N = BLK_N / ATOM_N;

static constexpr int FragsM = SG_M / get<0>(MmaAtomShape()); // A frags per sub_group
static constexpr int FragsN = SG_N / get<1>(MmaAtomShape()); // B frags per sub_group

auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
Tensor tOuti = args.tiled_copy.get_pvc_tensor(
make_coord(m_offset, n_offset, 0),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
make_stride(Int<get<0>(MmaAtomShape{})>{}, Int<get<1>(MmaAtomShape{})>{}, _1{}));
Tensor rw_coord = tOuti(_,_,_,l_coord);

return ConsumerStoreCallbacks(
rw_coord, xe_copy_aux, cute::move(trAux), params_ptr
);
Expand Down

0 comments on commit fa0fcad

Please sign in to comment.