Skip to content

Commit

Permalink
Generalise args to ConsumerStoreArgs
Browse files Browse the repository at this point in the history
Not totally sure these are correct yet, but it's compiling...
  • Loading branch information
joeatodd committed Nov 19, 2024
1 parent 2257ac1 commit 4720c36
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ class CollectiveEpilogue<
auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
auto l_offset = l_coord;

// TODO(joe): Need to make this a template argument, and/or work out what it needs to be
// for consistency with the MMA configuration. In other words: Each thread in our work-group
// has accumulators for certain elements of C. We need to be applying the epilogue for *those*
// elements specifically. For some epilogues this probably makes no difference, but e.g.
// PerRowBias, it's presumably important.
using EpilogueTile = Shape<_64, _16>;
bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Expand All @@ -299,10 +305,37 @@ class CollectiveEpilogue<
make_coord(m_offset, n_offset, 0),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
make_stride(Int<get<0>(MmaAtomShape{})>{}, Int<get<1>(MmaAtomShape{})>{}, _1{}));
// TODO(joe): I replaced tiled_r2s from the sm90 version with tiled_r2g on the assumption
// that this would give me the equivalent thread layout that I need for XE epilogue (which
// doesn't use shared memory. Is this the correct/best approach?
// Furthermore, I am using `Copy_Atom<CopyOpR2G,Element>` instead of `CopyAtomC` which isn't
// defined here.
// Also replaced make_tiled_copy_C_atom with make_tiled_copy_C here because that's for pipelined stuff
// Need to understand that a bit better
TiledCopy tiled_copy_C_atom = make_tiled_copy_C(Copy_Atom<CopyOpR2G,ElementD>{}, tiled_mma);
// (t)hread-partition for (r)egister to (s)mem copy (tRS_)
TiledCopy tiled_r2g = make_tiled_copy_S(Copy_Atom<CopyOpR2G,ElementD>{}, tiled_copy_C_atom);
ThrCopy thread_r2g = tiled_r2g.get_slice(thread_idx);
// Because Nvidia uses shared memory, they are not tied to using the same accumulator values
// for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
// sure that we are operating on the same values.

// OOB predication for tile quantization "residue"
// Absolute coordinate tensors (dynamic)
Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N)
// Is this likely to be wrong, since we have redefined m_coord (?) for subgroup scale logic
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_r2g.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)

Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N)
Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N)

Tensor rw_coord = tOuti(_,_,_,l_coord);
Tensor mD_crd = make_identity_tensor(make_shape(M,N));
Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(m_coord, n_coord));

// Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate
auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n)
auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n)

// Get the fusion callbacks
constexpr bool RefSrc = true;
auto residue_mn = make_coord(M, N);
Expand All @@ -311,12 +344,12 @@ class CollectiveEpilogue<
TileShapeMNK{},
tile_coord_mnkl,
tiled_mma,
SubgroupTileShape{}, // Epilogue tile
EpilogueTile{},
params.xe_load_c,
cD,
residue_mn,
cD,
residue_mn,
residue_cD,
tRS_cD,
residue_tRS_cD,
trC,
thread_idx,
};
Expand Down

0 comments on commit 4720c36

Please sign in to comment.