Skip to content

Commit

Permalink
use make_coord
Browse files Browse the repository at this point in the history
  • Loading branch information
taozha2 committed Jan 6, 2025
1 parent 7ec0bd3 commit cdbd09c
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 32 deletions.
12 changes: 6 additions & 6 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ struct XE_2D_LD_Unpack {
intel::coord_t{(int)n, (int)m});
}

template <class GShape>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(int m_coord, int n_coord, int l_coord,
template <class Coord, class GShape>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(Coord const &coord,
GShape const &shape) const {

auto R = rank(GShape{});
Expand All @@ -216,7 +216,7 @@ struct XE_2D_LD_Unpack {
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
return make_tensor(make_inttuple_iter(coord),
make_layout(new_shape, new_stride));
}

Expand Down Expand Up @@ -266,8 +266,8 @@ template <class CopyOp, class... ArgTs> struct XE_2D_ST_Unpack {
intel::coord_t{(int)n, (int)m}, &*src.data());
}

template <class GShape>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(int m_coord, int n_coord, int l_coord,
template <class Coord, class GShape>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(Coord const &coord,
GShape const &shape) const {

auto R = rank(GShape{});
Expand All @@ -282,7 +282,7 @@ template <class CopyOp, class... ArgTs> struct XE_2D_ST_Unpack {
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
return make_tensor(make_inttuple_iter(coord),
make_layout(new_shape, new_stride));
}

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class CollectiveEpilogue<
Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>>{});
Tensor rw_coord = params.xe_store_d.get_pvc_tensor(
m_offset, n_offset, l_offset,
make_coord(m_offset, n_offset, l_offset),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}));

Tensor mD_crd = make_identity_tensor(make_shape(M,N));
Expand Down
16 changes: 8 additions & 8 deletions include/cutlass/gemm/collective/xe_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,26 +262,26 @@ struct CollectiveMma<
#endif
const int l_coord = l_idx;

Tensor block2d_copy_iter_a = gmem_tiled_copy_a.get_pvc_tensor(m_coord, 0, l_coord, tCrA_copy_view.shape());
Tensor block2d_copy_iter_a = gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, l_coord), tCrA_copy_view.shape());
auto copy_iter_a = append_pvc_tensor<1>(block2d_copy_iter_a, k_tile_count, BLK_K);

Tensor block2d_copy_iter_b = gmem_tiled_copy_b.get_pvc_tensor(n_coord, 0, l_coord, tCrB_copy_view.shape());
Tensor block2d_copy_iter_b = gmem_tiled_copy_b.get_pvc_tensor(make_coord(n_coord, 0, l_coord), tCrB_copy_view.shape());
auto copy_iter_b = append_pvc_tensor<1>(block2d_copy_iter_b, k_tile_count, BLK_K);

const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start));
int prefetch_k = 0;

Tensor block2d_prefetch_iter_a = XE_Prefetch_A{}.get_pvc_tensor(
m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
(k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA,
l_coord,
make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
(k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA,
l_coord),
make_shape(_1{}, _1{}, _1{}));
auto prefetch_iter_a = append_pvc_tensor<1>(block2d_prefetch_iter_a, k_tile_count, BLK_K);

Tensor block2d_prefetch_iter_b = XE_Prefetch_B{}.get_pvc_tensor(
(get_sub_group_id() / ATOM_N / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB,
n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}),
l_coord,
make_coord((get_sub_group_id() / ATOM_N / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB,
n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}),
l_coord),
make_shape(_1{}, _1{}, _1{}));
auto prefetch_iter_b = append_pvc_tensor<0>(block2d_prefetch_iter_b, k_tile_count, BLK_K);

Expand Down
4 changes: 2 additions & 2 deletions test/unit/cute/intel_xe/copy_block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load,
auto thr_tile_load_D = thr_copy_load.partition_D(S);
auto fragment = make_fragment_like(thr_tile_load_D);
auto ld_tensor =
load.get_pvc_tensor(m_coord, n_coord, l_coord, fragment.shape());
load.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), fragment.shape());
if constexpr (cute::detail::has_prefetch<CopyOp>)
prefetch(load, ld_tensor);
copy(load, ld_tensor, fragment);
Expand All @@ -67,7 +67,7 @@ void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load,
make_tensor(static_cast<decltype(fragment) &&>(fragment).data(),
thr_copy_store.partition_S(D).shape());
auto st_tensor =
store.get_pvc_tensor(m_coord, n_coord, l_coord, frag_view.shape());
store.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), frag_view.shape());
copy(store, frag_view, st_tensor);

#if 0
Expand Down
4 changes: 2 additions & 2 deletions test/unit/cute/intel_xe/copy_subgroup_block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) {
const int l_coord = BlockIdxZ();

// Copy from GMEM to RMEM and from RMEM to GMEM
auto blk_load_S = tiled_copy_load.get_pvc_tensor(m_coord, n_coord, l_coord,
auto blk_load_S = tiled_copy_load.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord),
fragment.shape());
copy(tiled_copy_load, blk_load_S, fragment);

Expand All @@ -146,7 +146,7 @@ void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) {
}
#endif

auto blk_store_D = tiled_copy_store.get_pvc_tensor(m_coord, n_coord, l_coord,
auto blk_store_D = tiled_copy_store.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord),
fragment.shape());

// only run first subgroup
Expand Down
10 changes: 5 additions & 5 deletions test/unit/cute/intel_xe/gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ struct gemm_device_partition_fragment_abc {

auto k_tile_max = size<2>(gA);
for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) {
Tensor blk_tgA = copy_a.get_pvc_tensor(m_coord, k_tile * sg_tile_k,
l_coord, copy_view_A.shape());
Tensor blk_tgB = copy_b.get_pvc_tensor(n_coord, k_tile * sg_tile_k,
l_coord, copy_view_B.shape());
Tensor blk_tgA = copy_a.get_pvc_tensor(make_coord(m_coord, k_tile * sg_tile_k, l_coord),
copy_view_A.shape());
Tensor blk_tgB = copy_b.get_pvc_tensor(make_coord(n_coord, k_tile * sg_tile_k, l_coord),
copy_view_B.shape());

#if CUTLASS_ENABLE_DEBUG_PRINTS
if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) {
Expand All @@ -204,7 +204,7 @@ struct gemm_device_partition_fragment_abc {
}

Tensor blk_tgC =
copy_c.get_pvc_tensor(m_coord, n_coord, l_coord, fragment_C.shape());
copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), fragment_C.shape());

copy(copy_c, fragment_C, blk_tgC);
}
Expand Down
10 changes: 5 additions & 5 deletions test/unit/cute/intel_xe/gemm_partition_src_dst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ struct gemm_device_partition_sd {
auto k_tile_max = size<3>(tgA);
for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) {

Tensor blk_tgA = copy_a.get_pvc_tensor(m_coord, k_tile * sg_tile_k,
l_coord, fragment_A.shape());
Tensor blk_tgB = copy_b.get_pvc_tensor(n_coord, k_tile * sg_tile_k,
l_coord, fragment_B.shape());
Tensor blk_tgA = copy_a.get_pvc_tensor(make_coord(m_coord, k_tile * sg_tile_k, l_coord),
fragment_A.shape());
Tensor blk_tgB = copy_b.get_pvc_tensor(make_coord(n_coord, k_tile * sg_tile_k, l_coord),
fragment_B.shape());

#if CUTLASS_ENABLE_DEBUG_PRINTS
if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) {
Expand All @@ -205,7 +205,7 @@ struct gemm_device_partition_sd {
}

Tensor blk_tgC =
copy_c.get_pvc_tensor(m_coord, n_coord, l_coord, fragment_C.shape());
copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), fragment_C.shape());

copy(copy_c, fragment_C, blk_tgC);
}
Expand Down
6 changes: 3 additions & 3 deletions test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ struct gemm_device_tiled_copy_abc {

for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) {
Tensor blk_tgA = tiled_copy_A.get_pvc_tensor(
m_coord, k_tile * sg_tile_k, l_coord, tCrA_copy_view.shape());
make_coord(m_coord, k_tile * sg_tile_k, l_coord), tCrA_copy_view.shape());
Tensor blk_tgB = tiled_copy_B.get_pvc_tensor(
n_coord, k_tile * sg_tile_k, l_coord, tCrB_copy_view.shape());
make_coord(n_coord, k_tile * sg_tile_k, l_coord), tCrB_copy_view.shape());

copy(tiled_copy_A, blk_tgA, tCrA_copy_view);
copy(tiled_copy_B, blk_tgB, tCrB_copy_view);
Expand All @@ -185,7 +185,7 @@ struct gemm_device_tiled_copy_abc {
cute::gemm(mma, tiled_copy_A, tiled_copy_B, tCrA, tCrB, tCrC);
}

Tensor blk_tgC = tiled_copy_C.get_pvc_tensor(m_coord, n_coord, l_coord,
Tensor blk_tgC = tiled_copy_C.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord),
tCrC_copy_view.shape());
copy(copy_c, tCrC_copy_view, blk_tgC);
}
Expand Down

0 comments on commit cdbd09c

Please sign in to comment.