From bdcb4bb5ccd3539a6844b018ea2ff18cdc04f49a Mon Sep 17 00:00:00 2001 From: Joe Todd Date: Fri, 8 Nov 2024 15:47:52 +0000 Subject: [PATCH] Generalise args to ConsumerStoreArgs Not totally sure these are correct yet, but it's compiling... --- .../epilogue/collective/xe_epilogue.hpp | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index a5d00a0c0..8142cc64b 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -291,6 +291,12 @@ class CollectiveEpilogue< auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; auto l_offset = l_coord; + // TODO(joe): Need to make this a template argument, and/or work out what it needs to be + // for consistency with the MMA configuration. In other words: Each thread in our work-group + // has accumulators for certain elements of C. We need to be applying the epilogue for *those* + // elements specifically. For some epilogues this probably makes no difference, but e.g. + // PerRowBias, it's presumably important. + using EpilogueTile = Shape<_64, _16>; bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); Tensor trC = make_tensor(Shape>{}); @@ -299,10 +305,37 @@ class CollectiveEpilogue< make_coord(m_offset, n_offset, 0), make_shape(_, Int{}, Int{}, L), make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); + // TODO(joe): I replaced tiled_r2s from the sm90 version with tiled_r2g on the assumption + // that this would give me the equivalent thread layout that I need for XE epilogue (which + // doesn't use shared memory. Is this the correct/best approach? + // Furthermore, I am using `Copy_Atom` instead of `CopyAtomC` which isn't + // defined here. + // Also replaced make_tiled_copy_C_atom with make_tiled_copy_C here because that's for pipelined stuff + // Need to understand that a bit better + TiledCopy tiled_copy_C_atom = make_tiled_copy_C(Copy_Atom{}, tiled_mma); + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2g = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_r2g = tiled_r2g.get_slice(thread_idx); + // Because Nvidia uses shared memory, they are not tied to using the same accumulator values + // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be + // sure that we are operating on the same values. + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + // Is this likely to be wrong, since we have redefined m_coord (?) for subgroup scale logic + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_r2g.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) + + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2G,R2G_M,R2G_N,EPI_M,EPI_N) Tensor rw_coord = tOuti(_,_,_,l_coord); - Tensor mD_crd = make_identity_tensor(make_shape(M,N)); - Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(m_coord, n_coord)); + + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + // Get the fusion callbacks constexpr bool RefSrc = true; auto residue_mn = make_coord(M, N); @@ -311,12 +344,12 @@ class CollectiveEpilogue< TileShapeMNK{}, tile_coord_mnkl, tiled_mma, - SubgroupTileShape{}, // Epilogue tile + EpilogueTile{}, params.xe_load_c, cD, - residue_mn, - cD, - residue_mn, + residue_cD, + tRS_cD, + residue_tRS_cD, trC, thread_idx, };