Skip to content

Commit

Permalink
Pass subgroup args to ConsumerStoreArgs
Browse files Browse the repository at this point in the history
  • Loading branch information
joeatodd committed Dec 3, 2024
1 parent f12b8b0 commit e4cd6fb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
17 changes: 12 additions & 5 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,15 @@ class CollectiveEpilogue<
auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
auto l_offset = l_coord;

// TODO(joe): We likely need to set EpilogueTile equal to the MN shape of the TiledMMA.
using EpilogueTile = Shape<_64, _64>;
using EpilogueTile = decltype(get<0>(params.xe_store_d.get_layoutS_MN()).shape());

auto sg_local_m_coord = get_sub_group_id() / ATOM_N;
auto sg_local_n_coord = get_sub_group_id() % ATOM_N;

auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord;
auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord;
auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord);

bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Expand Down Expand Up @@ -331,11 +338,11 @@ class CollectiveEpilogue<
auto residue_mn = make_coord(M, N);
auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
problem_shape_mnkl,
TileShapeMNK{},
tile_coord_mnkl,
SubgroupTileShape{},
sg_coord,
tiled_mma,
EpilogueTile{},
params.xe_load_c,
params.xe_store_d,
cD,
residue_cD,
tRS_cD,
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/kernel/xe_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class GemmUniversal<
CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue};
epilogue(
problem_shape_MNKL,
subgroup_shape,
subgroup_shape, // TODO(joe): Inconsistency here w/ blk_coord_mnkl
blk_coord_mnkl,
accumulators,
tiled_mma,
Expand Down

0 comments on commit e4cd6fb

Please sign in to comment.