Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
taozha2 committed Dec 23, 2024
1 parent 4951d1b commit eadeaf5
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -188,27 +188,18 @@ struct XE_2D_LD_Unpack {
static_assert(R == 3, "mismatch rank");

using basis_t = make_seq<rank(typename CopyOp::Shape_MN{})>;
using rvs_basis_t = decltype(reverse(basis_t{}));
using use_basis_t = std::conditional_t<is_mkl, basis_t, rvs_basis_t>;

if constexpr (is_mkl) {
auto new_shape = cute::tuple_cat(make_shape(_1{}), take<R - 2, R>(shape));
auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, typename CopyOp::Shape_MN{},
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(new_shape, new_stride));

} else {
auto new_shape = cute::tuple_cat(make_shape(_1{}), (take<R - 2, R>(shape)));
auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, reverse(typename CopyOp::Shape_MN{}),
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(new_shape, new_stride));
}
using shape_mn = CopyOp::Shape_MN;
using rvs_shape_mn = decltype(reverse(shape_mn{}));
using use_shape_mn = std::conditional_t<is_mkl, shape_mn, rvs_shape_mn>;

auto new_shape = cute::tuple_cat(make_shape(_1{}), take<R - 2, R>(shape));
auto new_stride = cute::tuple_cat(make_stride(_1{}), transform(basis_t{}, use_shape_mn{},
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(new_shape, new_stride));
}

template <class T1, class T2, class... TraitsArgs>
Expand Down

0 comments on commit eadeaf5

Please sign in to comment.