diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 9cb95248b..d1380c265 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -308,31 +308,22 @@ class CollectiveEpilogue< make_coord(m_offset, n_offset, 0), make_shape(_, Int{}, Int{}, L), make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); - // TODO(joe): I replaced tiled_r2s from the sm90 version with tiled_g2r on the assumption - // that this would give me the equivalent thread layout that I need for XE epilogue (which - // doesn't use shared memory. Is this the correct/best approach? - ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); // Because Nvidia uses shared memory, they are not tied to using the same accumulator values // for MMA and Epilogue. But because we are operating directly in the accumulators, we need to be // sure that we are operating on the same values. + ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); // OOB predication for tile quantization "residue" // Absolute coordinate tensors (dynamic) Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) - // Is this likely to be wrong, since we have redefined m_coord (?) for subgroup scale logic Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) Tensor tRS_cD_mn = thread_g2r.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) - Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) Tensor rw_coord = tOuti(_,_,_,l_coord); - // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate - auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) - auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) - // Get the fusion callbacks constexpr bool RefSrc = true; auto residue_mn = make_coord(M, N); @@ -343,10 +334,10 @@ class CollectiveEpilogue< tiled_mma, EpilogueTile{}, params.xe_store_d, - cD, - residue_cD, + rw_coord, + residue_mn, tRS_cD, - residue_tRS_cD, + residue_mn, trC, thread_idx, };