Skip to content

Commit

Permalink
Simplify ConsumerStoreCallbacks creation
Browse files Browse the repository at this point in the history
  • Loading branch information
joeatodd committed Dec 4, 2024
1 parent b787872 commit b5db222
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,31 +308,22 @@ class CollectiveEpilogue<
make_coord(m_offset, n_offset, 0),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
make_stride(Int<get<0>(MmaAtomShape{})>{}, Int<get<1>(MmaAtomShape{})>{}, _1{}));
// TODO(joe): I replaced tiled_r2s from the sm90 version with tiled_g2r on the assumption
// that this would give me the equivalent thread layout that I need for XE epilogue (which
// doesn't use shared memory). Is this the correct/best approach?

ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx);
// Because Nvidia uses shared memory, they are not tied to using the same accumulator values
// for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be
// sure that we are operating on the same values.
ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx);

// OOB predication for tile quantization "residue"
// Absolute coordinate tensors (dynamic)
Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N)
// Is this likely to be wrong, since we have redefined m_coord (?) for subgroup scale logic?
Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N)
Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)

Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N)
Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N)

Tensor rw_coord = tOuti(_,_,_,l_coord);

// Subtract the local "top left" corner from the global "bottom right" corner to get the max relative coordinate
auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n)
auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n)

// Get the fusion callbacks
constexpr bool RefSrc = true;
auto residue_mn = make_coord(M, N);
Expand All @@ -343,10 +334,10 @@ class CollectiveEpilogue<
tiled_mma,
EpilogueTile{},
params.xe_store_d,
cD,
residue_cD,
rw_coord,
residue_mn,
tRS_cD,
residue_tRS_cD,
residue_mn,
trC,
thread_idx,
};
Expand Down

0 comments on commit b5db222

Please sign in to comment.