Skip to content

Commit

Permalink
Fixed block size calculation
Browse files: browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Apr 26, 2024
1 parent 8f00614 commit 38776ee
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
13 changes: 5 additions & 8 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ namespace cute
template <class GTensor>
struct Copy_Traits<XE_2D_LOAD, GTensor>
{
// using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
using ThrID = Layout<_1>;
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>; // hacky: does vec of 8
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBits>>; // TODO: is _1 correct?
using SrcLayout = Layout<Shape<_1, NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBits>>;
// Reference map from (thr,val) to bit
Expand All @@ -61,8 +60,7 @@ namespace cute
{
static_assert(is_rmem<TD>::value);
int H = size<0>(traits.tensor);
// int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type);
int W = size<1>(traits.tensor) * sizeof(typename TD::value_type); //TODO: inconsistent to give the size in elements but use vector for copy
int W = size<1>(traits.tensor) * sizeof(typename TD::value_type);
auto [y, x, z] = src.data().coord_;
XE_2D_LOAD::copy(traits.tensor.data() + z, W, H, W, int2_{static_cast<int>(x), static_cast<int>(y)}, &*dst.data());
}
Expand All @@ -71,11 +69,10 @@ namespace cute
template <class GTensor>
struct Copy_Traits<XE_2D_SAVE, GTensor>
{
// using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
using ThrID = Layout<_1>;
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>; // hacky: does vec of 8
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBits>>; // TODO: is _1 correct?
using SrcLayout = Layout<Shape<_1, NumBits>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBits>>;
// Reference map from (thr,val) to bit
Expand Down
3 changes: 3 additions & 0 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ struct CollectiveMma<
static constexpr int tN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B
static constexpr int tK = get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per sub_group for Matrix A

static constexpr uint32_t MaxThreadsPerBlock = tM * tN;
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

static constexpr int MM = get<0>(TileShape{}) / tM; // A frags per sub_group
static constexpr int NN = get<1>(TileShape{}) / tN; // B frags per sub_group

Expand Down
7 changes: 3 additions & 4 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,10 @@ class GemmUniversal<
// MSVC requires the cast to fix a warning-as-error.
static constexpr int SharedStorageSize = 0;
static constexpr uint32_t MaxThreadsPerBlock = 64;
// static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::get<0>(TiledMma{}) * cute::get<1>(TiledMma{}));
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr int SG_SZ = CollectiveMainloop::SG_SZ; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
static constexpr int num_sg = MaxThreadsPerBlock / SG_SZ; // number of sub_groups per work group
static constexpr int tM = CollectiveMainloop::tM;
Expand Down

0 comments on commit 38776ee

Please sign in to comment.