diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index b31eed044..0d8ed6a04 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -188,27 +188,18 @@ struct XE_2D_LD_Unpack { static_assert(R == 3, "mismatch rank"); using basis_t = make_seq; - using rvs_basis_t = decltype(reverse(basis_t{})); - using use_basis_t = std::conditional_t; - - if constexpr (is_mkl) { - auto new_shape = cute::tuple_cat(make_shape(_1{}), take(shape)); - auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, typename CopyOp::Shape_MN{}, - [&](auto i, auto s){ - return E{} * s; - })); - return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), - make_layout(new_shape, new_stride)); - } else { - auto new_shape = cute::tuple_cat(make_shape(_1{}), (take(shape))); - auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, reverse(typename CopyOp::Shape_MN{}), - [&](auto i, auto s){ - return E{} * s; - })); - return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), - make_layout(new_shape, new_stride)); - } + using shape_mn = CopyOp::Shape_MN; + using rvs_shape_mn = decltype(reverse(shape_mn{})); + using use_shape_mn = std::conditional_t; + + auto new_shape = cute::tuple_cat(make_shape(_1{}), take(shape)); + auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, use_shape_mn{}, + [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)), + make_layout(new_shape, new_stride)); } template