Skip to content

Commit

Permalink
Fix B matrix layout (#85)
Browse files Browse the repository at this point in the history
* Use N major for B

* Use N major for B coordinates

* Update include/cute/atom/copy_traits_xe.hpp

Co-authored-by: Mehdi Goli <[email protected]>

* Update include/cute/atom/copy_traits_xe.hpp

Co-authored-by: Mehdi Goli <[email protected]>

* Update include/cute/atom/copy_traits_xe.hpp

Co-authored-by: Mehdi Goli <[email protected]>

* Update include/cute/atom/copy_traits_xe.hpp

Co-authored-by: Mehdi Goli <[email protected]>

* Fix A copy trait layout

* Update include/cute/atom/copy_traits_xe.hpp

Co-authored-by: Mehdi Goli <[email protected]>

---------

Co-authored-by: Mehdi Goli <[email protected]>
  • Loading branch information
aacostadiaz and mehdi-goli authored Jun 19, 2024
1 parent c584a18 commit ef0284f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 30 deletions.
60 changes: 44 additions & 16 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,55 @@

namespace cute
{

template <class IntT>
CUTE_HOST_DEVICE constexpr
auto get_shape_WHD(cute::Stride<Int<1>, IntT, IntT> , cute::Shape<int,int,int> shape_MKL) {
return shape_MKL;
}

template <class IntT>
CUTE_HOST_DEVICE constexpr
auto get_shape_WHD(cute::Stride<IntT, Int<1>, IntT> , cute::Shape<int,int,int> shape_MKL) {
return Shape<int, int, int>(get<1>(shape_MKL), get<0>(shape_MKL), get<2>(shape_MKL));
}

template <class IntT, class TS, class SLayout>
CUTE_HOST_DEVICE constexpr
auto get_coordinates(cute::Stride<Int<1>, IntT, IntT> ,
Tensor<ViewEngine<ArithmeticTupleIterator<TS>>, SLayout> const &src) {
auto [x, y, z] = src.data().coord_;
return make_coord(x, y, z);
}

template <class IntT, class TS, class SLayout>
CUTE_HOST_DEVICE constexpr
auto get_coordinates(cute::Stride<IntT, Int<1>, IntT> ,
Tensor<ViewEngine<ArithmeticTupleIterator<TS>>, SLayout> const &src) {
auto [x, y, z] = src.data().coord_;
return make_coord(y, x, z);
}

template <class CopyOp, class GTensor>
struct XE_2D_LD_Unpack
{
GTensor tensor;

using Copy_Traits = Copy_Traits<CopyOp, GTensor>;

template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const &traits,
Tensor<ViewEngine<ArithmeticTupleIterator<TS>>, SLayout> const &src,
Tensor<TD, DLayout> &dst)
{
static_assert(is_rmem<TD>::value);
int H = size<0>(traits.tensor);
int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType);
auto [y, x, z] = src.data().coord_;
CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data());
static_assert(is_rmem<TD>::value);
auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape());
int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType);
int H = size<1>(shape_whd);
auto [x, y, z] = get_coordinates(traits.tensor.stride(), src);
CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data());
}

template <class GCoord, class GShape, class GStride>
Expand Down Expand Up @@ -105,15 +136,13 @@ struct Copy_Traits<XE_2D_U16x8x16x4x2_LD_N, GTensor>
: XE_2D_LD_Unpack<XE_2D_U16x8x16x4x2_LD_N, GTensor>
{
// Logical thread id to thread idx
using ThrID = Layout<_16>;
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_16, _64>, Stride<_0, _1>>;
using SrcLayout = Layout<Shape<_1, Shape<_1, _1>>>; // one coordinate
// Map from (dst-thr,dst-val) to bit
using DstLayout =
Layout<Shape<_16, Shape<Shape<_8, _4>, Shape<_16, _2>>>,
Stride<_16, Stride<Stride<_512, _4096>, Stride<_1, _256>>>>;
using DstLayout = Layout<Shape<_1, Shape<_64, _1>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
using RefLayout = SrcLayout;
using CopyInternalType = ushort;
};

Expand Down Expand Up @@ -188,14 +217,13 @@ struct Copy_Traits<XE_2D_U16x16x16x2x1_LD_N, GTensor>
: XE_2D_LD_Unpack<XE_2D_U16x16x16x2x1_LD_N, GTensor>
{
// Logical thread id to thread idx
using ThrID = Layout<_16>;
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_16, _64>, Stride<_0, _1>>;
using SrcLayout = Layout<Shape<_1, Shape<_1, _4>>>; // expected 4 coordinates
// Map from (dst-thr,dst-val) to bit
using DstLayout =
Layout<Shape<_16, Shape<_16, _32>>, Stride<_32, Stride<_512, _1>>>;
using DstLayout = Layout<Shape<_1, Shape<_32, _4>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
using RefLayout = SrcLayout;
// 32 bits register file
using CopyInternalType = uint;
};
Expand Down
8 changes: 6 additions & 2 deletions include/cute/util/debug.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ bool
block(int bid)
{
#if defined(CUTLASS_ENABLE_SYCL)
return (syclcompat::get_nd_item<3>().get_group_linear_id()==bid);
using namespace syclcompat;
return (work_group_id::x() + work_group_id::y() * work_group_range::x() +
work_group_id::z() * work_group_range::y() * work_group_range::x() == bid);
#elif defined(__CUDA_ARCH__)
return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid;
#else
Expand All @@ -142,7 +144,9 @@ bool
thread(int tid, int bid)
{
#if defined(CUTLASS_ENABLE_SYCL)
return (syclcompat::get_nd_item<3>().get_global_linear_id()==bid);
using namespace syclcompat;
return (local_id::x() + local_id::y() * local_range::x() +
local_id::z() * local_range::x() * local_range::y() == tid) && block(bid);
#elif defined(__CUDA_ARCH__)
return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid);
#else
Expand Down
16 changes: 7 additions & 9 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ struct CollectiveMma<
auto [M,N,K,L] = problem_shape_MNKL;

Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA));
Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K,N,L), args.dB));
Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB));

typename Params::XE_Copy_A copyA = make_xe_2d_copy<GmemTiledCopyA>(tensorA);
typename Params::XE_Copy_B copyB = make_xe_2d_copy<GmemTiledCopyB>(tensorB);
Expand Down Expand Up @@ -187,14 +187,14 @@ struct CollectiveMma<
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

// Tensor to hold input data
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, Int<1>>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(
Shape<Int<FragsK * get<1>(SubgroupTileShape{}) / FragsN>, Int<FragsN>>{});
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, _1>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<get<1>(SubgroupTileShape{}) / 2>, Int<FragsN>>{});

Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
Tensor tBr_view = make_tensor(static_cast<decltype(tBr) &&>(tBr).data(),
Shape<Int<VecB>, Int<FragsK>, Int<FragsN>>{});
Shape<Int<VecB>, Int<FragsN>, Int<FragsK>>{},
Stride<_1, Int<get<1>(SubgroupTileShape{}) / 2>, Int<VecB>>{});

// Instantiate the M MA object
TiledMma tiled_mma;
Expand All @@ -206,11 +206,9 @@ struct CollectiveMma<
{
// Copy gmem to rmem for the first k_tile
copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
copy(mainloop.gmem_tiled_copy_b, gB(_,k/2,_), tBr);
copy(mainloop.gmem_tiled_copy_b, gB(_,_,k/2), tBr);

for (int kl = 0; kl < FragsK; kl++) {
cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), src_accum);
}
cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum);
}
}
};
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ class GemmUniversal<
make_stride(Int<FragsM>{} * get<0>(MmaAtomShape()),_1{}));
Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(
make_coord(0, n_coord, 0),
make_shape(K, Int<FragsN>{}, L),
make_stride(_1{}, get<1>(MmaAtomShape())));
make_coord(n_coord, 0, 0),
make_shape(Int<FragsN>{}, K / 2, L),
make_stride(get<1>(MmaAtomShape()), _1{}));
// Compute tile residues for predication
auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord
Expand Down

0 comments on commit ef0284f

Please sign in to comment.