diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp
index 331acaea29..4d3df34f7f 100644
--- a/include/cute/atom/copy_traits_xe.hpp
+++ b/include/cute/atom/copy_traits_xe.hpp
@@ -40,11 +40,10 @@ namespace cute
   template
   struct Copy_Traits
   {
-    // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
     using ThrID = Layout<_1>;
-    using NumBits = Int; // hacky: does vec of 8
+    using NumBits = Int;
     // Map from (src-thr,src-val) to bit
-    using SrcLayout = Layout>; // TODO: is _1 correct?
+    using SrcLayout = Layout>;
     // Map from (dst-thr,dst-val) to bit
     using DstLayout = Layout>;
     // Reference map from (thr,val) to bit
@@ -61,8 +60,7 @@ namespace cute
     {
       static_assert(is_rmem<TD>::value);
       int H = size<0>(traits.tensor);
-      // int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type);
-      int W = size<1>(traits.tensor) * sizeof(typename TD::value_type); //TODO: inconsistent to give the size in elements but use vector for copy
+      int W = size<1>(traits.tensor) * sizeof(typename TD::value_type);
       auto [y, x, z] = src.data().coord_;
       XE_2D_LOAD::copy(traits.tensor.data() + z, W, H, W, int2_{static_cast<int>(x), static_cast<int>(y)}, &*dst.data());
     }
@@ -71,11 +69,10 @@ namespace cute
   template
   struct Copy_Traits
   {
-    // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
     using ThrID = Layout<_1>;
-    using NumBits = Int; // hacky: does vec of 8
+    using NumBits = Int;
     // Map from (src-thr,src-val) to bit
-    using SrcLayout = Layout>; // TODO: is _1 correct?
+    using SrcLayout = Layout>;
     // Map from (dst-thr,dst-val) to bit
     using DstLayout = Layout>;
     // Reference map from (thr,val) to bit
diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
index 706ef8df7a..5dbc725005 100644
--- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp
+++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -104,6 +104,9 @@ struct CollectiveMma<
   static constexpr int tN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B
   static constexpr int tK = get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per sub_group for Matrix A
 
+  static constexpr uint32_t MaxThreadsPerBlock = tM * tN;
+  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
+
   static constexpr int MM = get<0>(TileShape{}) / tM; // A frags per sub_group
   static constexpr int NN = get<1>(TileShape{}) / tN; // B frags per sub_group
diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
index f4eb30cbb0..889140270c 100644
--- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
+++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -99,11 +99,10 @@ class GemmUniversal<
   // MSVC requires the cast to fix a warning-as-error.
   static constexpr int SharedStorageSize = 0;
 
-  static constexpr uint32_t MaxThreadsPerBlock = 64;
-  // static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::get<0>(TiledMma{}) * cute::get<1>(TiledMma{}));
-  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
-
   static constexpr int SG_SZ = CollectiveMainloop::SG_SZ; // sub_group size
+  static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
+  static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
+  static constexpr int num_sg = MaxThreadsPerBlock / SG_SZ; // number of sub_groups per work group
 
   static constexpr int tM = CollectiveMainloop::tM;
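
Taken together, the second and third files replace a hard-coded work-group size (MaxThreadsPerBlock = 64) with one derived from the tiled MMA's per-sub_group tile, and GemmUniversal now reads both launch-bound constants back from CollectiveMainloop. Below is a minimal standalone sketch of the resulting arithmetic; the concrete values of tM, tN, and SG_SZ are illustrative assumptions, not taken from the patch (the real kernel derives tM/tN from TiledMma::LayoutA_TV / LayoutB_TV via cute::get and shape, and SG_SZ from the mainloop):

#include <cstdint>

// Assumed example values: one DPAS tile of 8 rows (Matrix A) by
// 16 cols (Matrix B) per sub_group, with a sub_group size of 16.
constexpr int tM = 8;
constexpr int tN = 16;
constexpr int SG_SZ = 16;

// What the patch moves into CollectiveMma:
constexpr uint32_t MaxThreadsPerBlock = tM * tN;   // 128 threads per work group
constexpr uint32_t MinBlocksPerMultiprocessor = 1;

// What GemmUniversal now computes instead of hard-coding 64:
constexpr int num_sg = MaxThreadsPerBlock / SG_SZ; // 8 sub_groups per work group

static_assert(MaxThreadsPerBlock % SG_SZ == 0,
              "work-group size must be a whole number of sub_groups");

Under these assumptions the launch bounds scale automatically whenever the DPAS tile shape changes, which is the point of moving the constants out of the kernel and into the collective.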