Skip to content

Commit

Permalink
Partial revert of codeplaysoftware#180
Browse files Browse the repository at this point in the history
I have rewritten the construction of `rw_coord` in xe_visitor.hpp in
terms of subgroup tile shape & coords. This simplifies things & allows
to remove the added CtaTileShapeMNK template param.
  • Loading branch information
joeatodd committed Jan 13, 2025
1 parent 05d9b75 commit a1df53e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 18 deletions.
3 changes: 2 additions & 1 deletion include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ class CollectiveEpilogue<
// OOB predication for tile quantization "residue"
// Absolute coordinate tensors (dynamic)
Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N)
Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(sg_m_coord, sg_n_coord));
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)

Expand All @@ -343,7 +344,7 @@ class CollectiveEpilogue<
tiled_mma,
EpilogueTile{},
params.xe_store_d,
rw_coord,
cD,
residue_mn,
tRS_cD,
residue_mn,
Expand Down
11 changes: 5 additions & 6 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ struct FusionCallbacks<
/////////////////////////////////////////////////////////////////////////////////////////////////

template<
class CtaTileShapeMNK,
class StrideAux,
class CopyOpG2R,
template <class> class ActivationFn,
Expand All @@ -185,7 +184,7 @@ template<
using XeLinCombDeEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
XeAuxLoad<CtaTileShapeMNK, ElementAux, StrideAux, CopyOpG2R> // aux
XeAuxLoad<ElementAux, StrideAux, CopyOpG2R> // aux
>;

// Z = Aux
Expand Down Expand Up @@ -216,17 +215,17 @@ struct FusionCallbacks<
EpilogueTile,
CopyOpG2R
> : XeLinCombDeEltAct<
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput_,
ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
> {

using ElementOutput = ElementOutput_;
using ElementCompute = ElementCompute_;

using Impl =
XeLinCombDeEltAct<
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput,
ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
>;
using Operation =
fusion::LinCombDeEltAct<
Expand Down
15 changes: 4 additions & 11 deletions include/cutlass/epilogue/fusion/xe_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ using namespace cutlass::epilogue::fusion;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
class CtaTileShapeMNK,
class Element,
class StrideMNL,
class CopyOpG2R,
Expand Down Expand Up @@ -196,22 +195,16 @@ struct XeAuxLoad {
using TiledMma = decltype(args.tiled_mma);
using MmaAtomShape = typename TiledMma::AtomShape_MNK;

static constexpr auto BLK_M = get<0>(CtaTileShapeMNK{});
static constexpr auto BLK_N = get<1>(CtaTileShapeMNK{});

static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());

static constexpr auto SG_M = BLK_M / ATOM_M;
static constexpr auto SG_N = BLK_N / ATOM_N;
auto SG_M = get<0>(args.tile_shape_mnk);
auto SG_N = get<1>(args.tile_shape_mnk);

static constexpr int FragsM = SG_M / get<0>(MmaAtomShape()); // A frags per sub_group
static constexpr int FragsN = SG_N / get<1>(MmaAtomShape()); // B frags per sub_group

auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
auto m_offset = m_coord * SG_M;
auto n_offset = n_coord * SG_N;
Tensor tOuti = args.tiled_copy.get_pvc_tensor(
make_coord(m_offset, n_offset, 0),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
Expand Down

0 comments on commit a1df53e

Please sign in to comment.