Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
taozha2 committed Dec 20, 2024
1 parent e0f92f9 commit 5bb970d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# PyCache files
__pycache__/
cutlass_library.egg-info/
build*
build/*
.*
2 changes: 1 addition & 1 deletion include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct XE_2D_LD_Unpack {

XE_2D_LD_Unpack(const void *ptr, uint32_t const &y,
uint32_t const &x, uint32_t const &p = 0) : base_ptr(ptr) {
if (is_nkl) {
if constexpr (is_nkl) {
width = is_transpose ? x : y;
height = is_transpose ? y : x;
pitch = (p == 0 ? width : p);
Expand Down
14 changes: 7 additions & 7 deletions include/cutlass/gemm/collective/xe_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ struct CollectiveMma<
return Params{copyA, copyB, prefetchA, prefetchB};
}

template <class Tensor_t, class Layout_t>
static constexpr auto append_pvc_tensor_with_layout(Tensor_t const &t0, Layout_t const & layout) {
return make_tensor(make_inttuple_iter(t0.data()), append(t0.layout(), layout));
template <class Tensor_t>
static constexpr auto append_pvc_tensor_with_k(Tensor_t const &t0, uint32_t k_shape, uint32_t k_stride) {
return make_tensor(make_inttuple_iter(t0.data()), append(t0.layout(), make_layout(k_shape, E<1>{} * k_stride)));
}

/// Perform a subgroup-scoped matrix multiply-accumulate
Expand Down Expand Up @@ -278,10 +278,10 @@ struct CollectiveMma<
const int l_coord = l_idx;

Tensor block2d_copy_iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(m_coord, 0, l_coord, tCrA_copy_view.shape());
auto copy_iter_a = append_pvc_tensor_with_layout(block2d_copy_iter_a, make_layout(make_shape(k_tile_count), make_stride(E<1>{} *BLK_K)));
auto copy_iter_a = append_pvc_tensor_with_k(block2d_copy_iter_a, k_tile_count, BLK_K);

Tensor block2d_copy_iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor(n_coord, 0, l_coord, tCrB_copy_view.shape());
auto copy_iter_b = append_pvc_tensor_with_layout(block2d_copy_iter_b, make_layout(make_shape(k_tile_count), make_stride(E<1>{} *BLK_K)));
auto copy_iter_b = append_pvc_tensor_with_k(block2d_copy_iter_b, k_tile_count, BLK_K);

const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K));
int prefetch_k = 0;
Expand All @@ -291,14 +291,14 @@ struct CollectiveMma<
(k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA,
l_coord,
make_shape(_1{}, _1{}, _1{}));
auto prefetch_iter_a = append_pvc_tensor_with_layout(blocked_prefetch_iter_a, make_layout(make_shape(k_tile_count), make_stride(E<1>{} *BLK_K)));
auto prefetch_iter_a = append_pvc_tensor_with_k(blocked_prefetch_iter_a, k_tile_count, BLK_K);

Tensor blocked_prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor(
(get_sub_group_id() / ATOM_N / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB,
n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}),
l_coord,
make_shape(_1{}, _1{}, _1{}));
auto prefetch_iter_b = append_pvc_tensor_with_layout(blocked_prefetch_iter_b, make_layout(make_shape(k_tile_count), make_stride(E<0>{} *BLK_K)));
auto prefetch_iter_b = append_pvc_tensor_with_k(blocked_prefetch_iter_b, k_tile_count, BLK_K);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
Expand Down

0 comments on commit 5bb970d

Please sign in to comment.