Rename tile shapes
aacostadiaz committed Jun 13, 2024
1 parent f5e0a17 commit 8c72fd5
Showing 2 changed files with 21 additions and 18 deletions.
18 changes: 10 additions & 8 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -81,7 +81,7 @@ struct CollectiveMma<
// Type Aliases
//
using DispatchPolicy = MainloopIntelPVCUnpredicated;
- using TileShape = TileShape_;
+ using WorkgroupTileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
@@ -101,13 +101,14 @@ struct CollectiveMma<
static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

using DpasShape = typename TiledMma::Shape_MNK;
- using TileDpasShape = decltype(tile_shape(TiledMma()));
+ using SubgroupTileShape = decltype(tile_shape(TiledMma()));

- static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize;
+ static constexpr uint32_t MaxThreadsPerBlock =
+     cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize;

- static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
- static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
- static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape());
+ static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(DpasShape()); // A frags per sub_group
+ static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(DpasShape()); // B frags per sub_group
+ static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(DpasShape());

// Calculate the vector width based on the amount of registers
// required per work item by dividing the total fragment size by
@@ -186,8 +187,9 @@ struct CollectiveMma<
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

// Tensor to hold input data
- Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileDpasShape{}) * FragsK>, Int<1>>{});
- Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<FragsK * get<1>(TileDpasShape{}) / FragsN>, Int<FragsN>>{});
+ Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, Int<1>>{});
+ Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(
+     Shape<Int<FragsK * get<1>(SubgroupTileShape{}) / FragsN>, Int<FragsN>>{});

Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
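For orientation, here is a minimal standalone sketch (plain C++ with hypothetical tile sizes, not values taken from this commit) of the arithmetic the renamed aliases feed in intel_pvc_mma.hpp: WorkgroupTileShape is the tile a work-group computes, SubgroupTileShape is the tile one sub-group owns, and FragsM/FragsN/FragsK count DPAS-shaped fragments per sub-group.

// Illustrative only: hypothetical tile sizes, not the ones used by the kernel.
#include <cstdio>

int main() {
  constexpr int wg_tile[3] = {256, 256, 32}; // hypothetical WorkgroupTileShape (M, N, K)
  constexpr int sg_tile[3] = {32, 64, 32};   // hypothetical SubgroupTileShape (M, N, K)
  constexpr int dpas[3]    = {8, 16, 16};    // hypothetical DPAS atom shape (M, N, K)
  constexpr int subgroup_size = 16;          // hypothetical SubgroupSize

  // Mirrors FragsM/FragsN/FragsK: DPAS-shaped fragments per sub-group.
  constexpr int frags_m = sg_tile[0] / dpas[0]; // 4
  constexpr int frags_n = sg_tile[1] / dpas[1]; // 4
  constexpr int frags_k = sg_tile[2] / dpas[2]; // 2

  // Mirrors MaxThreadsPerBlock: (work-group tile volume / sub-group tile volume) * SubgroupSize.
  constexpr int max_threads =
      (wg_tile[0] * wg_tile[1] * wg_tile[2]) / (sg_tile[0] * sg_tile[1] * sg_tile[2]) * subgroup_size; // 512

  std::printf("FragsM=%d FragsN=%d FragsK=%d MaxThreadsPerBlock=%d\n",
              frags_m, frags_n, frags_k, max_threads);
}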
21 changes: 11 additions & 10 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -65,7 +65,8 @@ class GemmUniversal<
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
- using TileShape = typename CollectiveMainloop::TileShape;
+ using TileShape = typename CollectiveMainloop::WorkgroupTileShape;
+ using WorkgroupTileShape = TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
@@ -81,7 +82,7 @@
"Intel PVC does not support specializing the tile scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<
- TileScheduler_, ArchTag, TileShape,
+ TileScheduler_, ArchTag, WorkgroupTileShape,
cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
@@ -103,7 +104,7 @@
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
using DpasShape = typename CollectiveMainloop::DpasShape;
- using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
+ using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape;
static constexpr int FragsM = CollectiveMainloop::FragsM;
static constexpr int FragsN = CollectiveMainloop::FragsN;
@@ -174,12 +175,12 @@ class GemmUniversal<
auto M = get<0>(params.problem_shape);
auto N = get<1>(params.problem_shape);
- const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments
- const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
+ const int sg_m = (M - 1) / get<0>(SubgroupTileShape{}) + 1; // sub_groups required to process A fragments
+ const int sg_n = (N - 1) / get<1>(SubgroupTileShape{}) + 1; // sub_groups required to process B fragments
return dim3(
- cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
- cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
+ cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
+ cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
batch_count
);
}
@@ -196,7 +197,7 @@ class GemmUniversal<
(void)smem_buf;
// Preconditions
- CUTE_STATIC_ASSERT(is_static<TileShape>::value);
+ CUTE_STATIC_ASSERT(is_static<WorkgroupTileShape>::value);
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
@@ -214,8 +215,8 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
- constexpr auto workgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K)
- constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
+ constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
+ constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
const int l_coord = BlockIdxZ();
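Similarly, a hedged sketch (again with hypothetical problem and tile sizes) of the launch-geometry arithmetic touched in intel_pvc_gemm.hpp: the grid is sized by ceil-dividing the problem M and N by the work-group tile, and each sub-group's N coordinate is offset inside the work-group tile in SubgroupTileShape-sized steps.

// Illustrative only: hypothetical problem/tile sizes and indices, not values from this commit.
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int M = 4096, N = 4096, batch = 1; // hypothetical problem shape
  constexpr int wg_m = 256, wg_n = 256;        // hypothetical WorkgroupTileShape (M, N)
  constexpr int sg_m = 32,  sg_n = 64;         // hypothetical SubgroupTileShape (M, N)
  constexpr int subgroup_size = 16;            // hypothetical SubgroupSize

  // Mirrors get_grid_shape(): one block per work-group tile of the output, plus the batch dimension.
  const int grid_x = ceil_div(M, wg_m); // 16
  const int grid_y = ceil_div(N, wg_n); // 16
  const int grid_z = batch;

  // Mirrors the coordinate computation in the kernel: m scales with the sub-group tile,
  // and n adds a per-sub-group offset (thread_idx / SubgroupSize) inside the work-group tile.
  const int block_x = 3, block_y = 5, thread_idx = 40; // example block/thread indices
  const int m_coord = block_x * sg_m;                                     // 96
  const int n_coord = block_y * wg_n + thread_idx / subgroup_size * sg_n; // 1408

  std::printf("grid=(%d,%d,%d) m_coord=%d n_coord=%d\n", grid_x, grid_y, grid_z, m_coord, n_coord);
}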
