Skip to content

Commit

Permalink
Intel gpu backend gemm pipeline (#89)
Browse files Browse the repository at this point in the history
Enable the prefetch by copy atom.

---------

Co-authored-by: Mehdi Goli <[email protected]>
  • Loading branch information
Jiaxingla and mehdi-goli authored Aug 2, 2024
1 parent 8a65ebc commit 0b5c911
Show file tree
Hide file tree
Showing 10 changed files with 634 additions and 365 deletions.
36 changes: 2 additions & 34 deletions benchmarks/common/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,42 +320,10 @@ struct PvcBenchmarkRunner : BenchmarkRunner<Gemm> {

using ProblemShapeType = typename Base::ProblemShapeType;

cutlass::DeviceAllocation<ElementB> block_B_vnni;

template <typename T>
void vnni_matrix(
T* dst, const T* src,
int batch, int numRows, int numCols, int factor)
{
for (int b = 0; b < batch; b++) {
for (int r = 0; r < numRows / factor; r++) {
for (int c = 0; c < numCols; c++) {
for (int k = 0; k < factor; k++) {
dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] =
src[((b * (numRows / factor) + r) * factor + k) * numCols + c];
}
}
}
}
}

void initialize(const ProblemShapeType& problem_size) override {
Base::initialize(problem_size);

auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;

block_B_vnni.reset(Base::block_B.size());

std::vector<ElementB> b(K * N * L);
std::vector<ElementB> b_vnni(b.size());

Base::block_B.copy_to_host(b.data());
vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2);

block_B_vnni.copy_from_host(b_vnni.data());
}

void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) override {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};

Expand All @@ -364,7 +332,7 @@ struct PvcBenchmarkRunner : BenchmarkRunner<Gemm> {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{Base::block_A.get(), Base::stride_A, block_B_vnni.get(), Base::stride_B},
{Base::block_A.get(), Base::stride_A, Base::block_B.get(), Base::stride_B},
{
{options.alpha, options.beta},
Base::block_C.get(), Base::stride_C, Base::block_D.get(), Base::stride_D
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;
using TileShape = Shape<_256, _256, _32>;

using TiledMma = TiledMMA<
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V;

using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
Expand Down
36 changes: 7 additions & 29 deletions examples/sycl/pvc/pvc_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,6 @@ static void fill_matrix(std::vector<T> &vector)
return static_cast<T>( (rand() / double(RAND_MAX)) );
});
}

template <typename T>
static void vnni_matrix(
T* dst, const T* src,
int batch, int numRows, int numCols, int factor)
{
for (int b = 0; b < batch; b++) {
for (int r = 0; r < numRows / factor; r++) {
for (int c = 0; c < numCols; c++) {
for (int k = 0; k < factor; k++) {
dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] =
src[((b * (numRows / factor) + r) * factor + k) * numCols + c];
}
}
}
}
}

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -89,7 +71,7 @@ struct Options {
Options():
help(false),
error(false),
m(4096), n(4096), k(4096), l(1), iterations(100),
m(4096), n(4096), k(4096), l(1), iterations(20),
alpha(1.f), beta(0.f)
{ }

Expand All @@ -108,7 +90,7 @@ struct Options {
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations, 100);
cmd.get_cmd_line_argument("iterations", iterations, 20);
}

/// Prints the usage statement.
Expand Down Expand Up @@ -170,7 +152,6 @@ struct ExampleRunner {

cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementB> block_B_vnni;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementOutput> block_D;
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
Expand Down Expand Up @@ -231,7 +212,6 @@ struct ExampleRunner {

block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_B_vnni.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
Expand All @@ -247,11 +227,9 @@ struct ExampleRunner {
fill_matrix(a);
fill_matrix(b);
fill_matrix(c);
vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2);

syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA));
syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC));
syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC));
}
Expand All @@ -272,7 +250,7 @@ struct ExampleRunner {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), stride_A, block_B_vnni.get(), stride_B},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
Expand Down Expand Up @@ -362,14 +340,14 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V;

// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;
using TileShape = Shape<_256, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile
Tile<_32,_64,_32>>; // Subgroup level-tile

using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
Expand Down
30 changes: 4 additions & 26 deletions examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,6 @@ static void fill_matrix(std::vector<T> &vector)
});
}

template <typename T>
static void vnni_matrix(
T* dst, const T* src,
int batch, int numRows, int numCols, int factor)
{
for (int b = 0; b < batch; b++) {
for (int r = 0; r < numRows / factor; r++) {
for (int c = 0; c < numCols; c++) {
for (int k = 0; k < factor; k++) {
dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] =
src[((b * (numRows / factor) + r) * factor + k) * numCols + c];
}
}
}
}
}

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -171,7 +154,6 @@ struct ExampleRunner {

cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementB> block_B_vnni;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementOutput> block_D;
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
Expand Down Expand Up @@ -238,7 +220,6 @@ struct ExampleRunner {

block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_B_vnni.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
Expand All @@ -247,18 +228,15 @@ struct ExampleRunner {
// available through SYCL.
std::vector<ElementA> a(K * M * L);
std::vector<ElementB> b(K * N * L);
std::vector<ElementB> b_vnni(b.size());
std::vector<ElementC> c(M * N * L);
std::vector<ElementC> d(M * N * L, ElementC{0});

fill_matrix(a);
fill_matrix(b);
fill_matrix(c);
vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2);

syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA));
syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC));
syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC));
}
Expand All @@ -271,7 +249,7 @@ struct ExampleRunner {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), stride_A, block_B_vnni.get(), stride_B},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
Expand Down Expand Up @@ -361,12 +339,12 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V;

// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;
using TileShape = Shape<_256, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile

Expand Down
Loading

0 comments on commit 0b5c911

Please sign in to comment.