Merge pull request #12 from taozha2/zt/debug
refine get_pvc_tensor
taozha2 authored Dec 19, 2024
2 parents b7f4d66 + 4503974 commit 513b72d
Showing 5 changed files with 225 additions and 146 deletions.
33 changes: 33 additions & 0 deletions benchmarks/pvc/benchmarks.hpp
@@ -80,16 +80,49 @@ using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V>;

using PvcGemmBF16BF16FP32_RRR_6 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
float, cutlass::layout::RowMajor,
float, Shape<_8, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T>;

using PvcGemmBF16BF16FP32_RRR_7 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_8, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
XE_2D_U16x16x16_LD_T, XE_2D_U16x32x32_LD_V>;

using PvcGemmBF16BF16FP32_RRR_8 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
cutlass::bfloat16_t, cutlass::layout::ColumnMajor,
float, cutlass::layout::RowMajor,
float, Shape<_8, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
XE_2D_U16x16x16_LD_T, XE_2D_U16x16x16_LD_T>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_6);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_7);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_8);

static void register_benchmarks() {
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_6);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_7);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_8);
}
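
Note on the three new configurations: a row-major operand keeps the existing block-load atoms (XE_2D_U16x8x32_LD_N for A, XE_2D_U16x32x32_LD_V for B), while a column-major operand switches to the transposed load XE_2D_U16x16x16_LD_T. A minimal sketch of that selection rule (the alias names below are illustrative only, not part of this change; requires <type_traits>):

    // Illustrative copy-atom selection for bf16 operands; names are hypothetical.
    template <bool a_row_major, bool b_row_major>
    struct CopyAtomSelection {
      // Row-major A uses the plain 2D block load; column-major A needs the transposed load.
      using GmemTiledCopyA = std::conditional_t<a_row_major, XE_2D_U16x8x32_LD_N, XE_2D_U16x16x16_LD_T>;
      // Row-major B uses the VNNI block load; column-major B needs the transposed load.
      using GmemTiledCopyB = std::conditional_t<b_row_major, XE_2D_U16x32x32_LD_V, XE_2D_U16x16x16_LD_T>;
    };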
6 changes: 6 additions & 0 deletions benchmarks/pvc/input.in
@@ -21,3 +21,9 @@ PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_6 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_RRR_6 --bm_name=bf16_bf16_fp32 --l=32 --m=256 --k=2048 --n=16384
PvcGemmBF16BF16FP32_RRR_7 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_RRR_7 --bm_name=bf16_bf16_fp32 --l=32 --m=128 --k=1024 --n=8192
PvcGemmBF16BF16FP32_RRR_8 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_RRR_8 --bm_name=bf16_bf16_fp32 --l=32 --m=16384 --k=4096 --n=1024
74 changes: 28 additions & 46 deletions examples/sycl/pvc/pvc_gemm.cpp
@@ -254,19 +254,34 @@ struct ExampleRunner {
float cute_time = timer.seconds() / options.iterations;
double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
printf("Cutlass GEMM (A %s, B %s) Performance: [%4.3f]TFlop/s (%6.4f)ms\n\n",
std::is_same_v<LayoutA, cutlass::layout::RowMajor> ? "RowMajor" : "ColumnMajor",
std::is_same_v<LayoutB, cutlass::layout::RowMajor> ? "RowMajor" : "ColumnMajor",
tflops / cute_time, cute_time*1000);
printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000);
}

return;
}

};
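
For reference, the throughput printed above follows directly from the surrounding context lines; a small worked sketch of the arithmetic (the measured time is an assumed value):

    // Worked example of the TFLOP/s line above; the 15 ms timing is assumed.
    double m = 4096, n = 4096, k = 4096, l = 32;
    double cute_time = 0.015;                           // seconds per iteration (assumed)
    double tflops    = (2.0 * m * n * k * l) * 1e-12;   // ~4.40 TFLOP of work per run
    double rate      = tflops / cute_time;              // ~293 TFLOP/s, the value the printf reports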

template<bool a_row_major, bool b_row_major, class a_type, class b_type, class c_type>
static constexpr auto gemm_run(Options const& options) {
int main(int argc, const char** argv)
{
//
// Parse options
//

Options options;

options.parse(argc, argv);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}

//
// Run examples
//
@@ -285,17 +300,17 @@ static constexpr auto gemm_run(Options const& options) {
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputA = a_type; // <- data type of elements in input matrix A
using ElementInputB = b_type; // <- data type of elements in input matrix B
using ElementOutput = c_type; // <- data type of elements in output matrix D
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

using LayoutA = std::conditional_t<a_row_major, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>;
using LayoutB = std::conditional_t<b_row_major, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = std::conditional_t<a_row_major, XE_2D_U16x32x32_LD_N, XE_2D_U16x16x16_LD_T>;
using GmemTiledCopyB = std::conditional_t<b_row_major, XE_2D_U16x32x32_LD_V, XE_2D_U16x16x16_LD_T>;
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;
@@ -350,39 +365,6 @@ static constexpr auto gemm_run(Options const& options) {
ExampleRunner<Gemm> runner;

runner.run(options, hw_info);
}

int main(int argc, const char** argv)
{
//
// Parse options
//

Options options;

options.parse(argc, argv);

if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}

if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return -1;
}

// row major A, row major B
gemm_run<true, true, bfloat16_t, bfloat16_t, float>(options);

// row major A, column major B
gemm_run<true, false, bfloat16_t, bfloat16_t, float>(options);

// column major A, row major B
gemm_run<false, true, bfloat16_t, bfloat16_t, float>(options);

// column major A, column major B
gemm_run<false, false, bfloat16_t, bfloat16_t, float>(options);

return 0;
}
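
With the templated gemm_run removed, the example now compiles a single row-major/row-major GEMM; the other layout combinations remain covered by the benchmark configurations above. If needed, a column-major B variant of the example can be sketched by swapping two aliases (an assumed edit, not part of this change):

    // Sketch: column-major B variant of the example's aliases (assumed, not in the diff).
    using LayoutB        = cutlass::layout::ColumnMajor;   // was cutlass::layout::RowMajor
    using GmemTiledCopyB = XE_2D_U16x16x16_LD_T;            // transposed load replaces XE_2D_U16x32x32_LD_V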
181 changes: 130 additions & 51 deletions include/cute/atom/copy_traits_xe.hpp
@@ -40,56 +40,89 @@ namespace cute {

namespace detail
{
template<class CopyOp>
struct is_transpose : bool_constant<false> {};

template<>
struct is_transpose<XE_2D_U16x16x8_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U16x16x16_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U32x16x2_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U32x16x4_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U32x16x8_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U64x8x1_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U64x8x2_LD_T> : bool_constant<true>{};

template<>
struct is_transpose<XE_2D_U64x8x4_LD_T> : bool_constant<true>{};
struct MKL_Indicator {};
struct NKL_Indicator {};

template <class MNKL_Indicator, class Enable = void>
struct is_MKL_layout {
static constexpr bool value = false;
};

template <class MNKL_Indicator>
struct is_MKL_layout<MNKL_Indicator, std::enable_if_t<std::is_same_v<MNKL_Indicator, MKL_Indicator>>> {
static constexpr bool value = true;
};

template <class MNKL_Indicator, class Enable = void>
struct is_NKL_layout {
static constexpr bool value = false;
};

template <class MNKL_Indicator>
struct is_NKL_layout<MNKL_Indicator, std::enable_if_t<std::is_same_v<MNKL_Indicator, NKL_Indicator>>> {
static constexpr bool value = true;
};

template <class MNKL_Indicator, class Stride>
struct is_transpose_load{
static constexpr bool value = (is_MKL_layout<MNKL_Indicator>::value
&& std::is_same_v<cutlass::detail::StrideToLayoutTagA_t<Stride>, cutlass::layout::ColumnMajor>)
|| (is_NKL_layout<MNKL_Indicator>::value
&& std::is_same_v<cutlass::detail::StrideToLayoutTagB_t<Stride>, cutlass::layout::ColumnMajor>);
};
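
is_transpose_load combines the operand role (MKL for A, NKL for B) with the global-memory stride: a column-major stride on that operand selects the transposed-load path. A compile-time illustration, assuming the usual CUTLASS 3.x stride convention for operand A (row-major A = Stride<int64_t, _1, int64_t>, column-major A = Stride<_1, int64_t, int64_t>):

    // Illustrative checks; the stride aliases are assumptions, not part of the diff.
    using RowMajorStrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
    using ColMajorStrideA = cute::Stride<cute::Int<1>, int64_t, int64_t>;
    static_assert(!is_transpose_load<MKL_Indicator, RowMajorStrideA>::value,
                  "row-major A is loaded directly");
    static_assert( is_transpose_load<MKL_Indicator, ColMajorStrideA>::value,
                  "column-major A takes the transposed (LD_T) path");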

template <class, class Enable = void> constexpr bool has_inst_dtype = false;

template <class T>
constexpr bool has_inst_dtype<T, cute::void_t<typename T::inst_dtype>> = true;

template <class T, class dtype, class Enable = void>
struct size_of_inst {
static constexpr auto value = sizeof(dtype);
};

template <class T, class dtype>
struct size_of_inst<T, dtype, std::enable_if_t<has_inst_dtype<T>>> {
static constexpr auto value = sizeof(typename T::inst_dtype);
};
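
has_inst_dtype is a void_t detection idiom, so size_of_inst reports sizeof(dtype) unless the copy op declares a nested inst_dtype (e.g. a transposed load issued at a wider granularity), in which case the instruction element size wins. A small illustration with made-up op types:

    // Both op types below are hypothetical, used only to show how the trait resolves.
    struct OpWithoutInstDtype {};
    struct OpWithInstDtype    { using inst_dtype = uint32_t; };
    static_assert(size_of_inst<OpWithoutInstDtype, uint16_t>::value == sizeof(uint16_t));
    static_assert(size_of_inst<OpWithInstDtype,    uint16_t>::value == sizeof(uint32_t));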

} // namespace detail end

template <class CopyOp, class... ArgTs> struct XE_2D_LD_Unpack {
template <class CopyOp,
class Indicator_MNK = detail::MKL_Indicator,
class GStride = cute::Stride<int64_t, cute::Int<1>, int64_t>>
struct XE_2D_LD_Unpack {
const void *base_ptr;
uint32_t width;
uint32_t height;
uint32_t pitch;

XE_2D_LD_Unpack(const void *ptr, uint32_t const &w,
uint32_t const &h, uint32_t const &p)
: base_ptr(ptr), width(w), height(h), pitch(p) {}

static constexpr bool is_mkl = detail::is_MKL_layout<Indicator_MNK>::value;
static constexpr bool is_nkl = detail::is_NKL_layout<Indicator_MNK>::value;
static constexpr bool is_transpose = detail::is_transpose_load<Indicator_MNK, GStride>::value;

static_assert(is_mkl != is_nkl);

XE_2D_LD_Unpack(const void *ptr, uint32_t const &y,
uint32_t const &x, uint32_t const &p = 0) : base_ptr(ptr) {
if (is_nkl) {
width = is_transpose ? x : y;
height = is_transpose ? y : x;
pitch = (p == 0 ? width : p);
} else {
width = is_transpose ? y : x;
height = is_transpose ? x : y;
pitch = (p == 0 ? width : p);
}
}
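
The constructor now takes the logical extents (y, x) of the operand view and derives the 2D surface dimensions, swapping width and height whenever the transposed-load path is active and defaulting the pitch to the width. A worked example under assumed extents:

    // Worked example (assumed values): an MKL / A-operand view with y = M = 256 rows, x = K = 32 columns.
    //   non-transposed load (row-major A):  width = 32,  height = 256, pitch = 32
    //   transposed load (column-major A):   width = 256, height = 32,  pitch = 256
    // XE_2D_LD_Unpack<XE_2D_U16x32x32_LD_N> copy_a(ptr, /*y=*/256, /*x=*/32);   // illustrative call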

template <class TraitsArgs>
XE_2D_LD_Unpack(TraitsArgs const &traits) : base_ptr(traits.base_ptr),
width(traits.width), height(traits.height), pitch(traits.pitch) {}

XE_2D_LD_Unpack() {}

using Traits_LD_t = Copy_Traits<CopyOp, ArgTs...>;
using Traits_LD_t = Copy_Traits<CopyOp, Indicator_MNK, GStride>;

template <class TS, class SLayout, class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
@@ -100,19 +133,23 @@ template <class CopyOp, class... ArgTs> struct XE_2D_LD_Unpack {
using dtype = typename Tensor<TD, DLayout>::value_type;

dtype *base_addr = (dtype *)traits.base_ptr;

auto [m, n, l] = src.data().coord_;

auto inst_size = sizeof(dtype);

if constexpr (detail::has_inst_dtype<CopyOp>) {
inst_size = sizeof(typename CopyOp::inst_dtype);
}

int x, y;
auto [coord_0, coord_1, z] = src.data().coord_;
if constexpr (is_mkl ^ is_transpose) {
x = coord_1;
y = coord_0;
} else {
x = coord_0;
y = coord_1;
}

CopyOp::copy(base_addr + l * traits.width * traits.height,
static constexpr auto inst_size = detail::size_of_inst<CopyOp, dtype>::value;

CopyOp::copy(base_addr + z * traits.width * traits.height,
traits.width * sizeof(dtype), traits.height,
traits.pitch * sizeof(dtype),
intel::coord_t{(int)(n * sizeof(dtype) / inst_size), (int)(m)},
intel::coord_t{(int)(x * sizeof(dtype) / inst_size), y},
&*dst.data());
}
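
In the refined copy_unpack, the source block coordinate is remapped before being handed to the hardware block load: when is_mkl XOR is_transpose is false the coordinates keep their order, otherwise they swap, and the x coordinate is rescaled from element units to instruction units. A worked example with assumed values:

    // Assumed values: bf16 data (2 bytes) loaded through a 32-bit transposed instruction.
    int coord_0 = 64, coord_1 = 32;                      // block coordinate from src.data().coord_
    int x = coord_0, y = coord_1;                        // is_mkl ^ is_transpose == false branch
    int inst_size = 4;                                   // sizeof(uint32_t), the op's inst_dtype
    int x_hw = x * (int)sizeof(uint16_t) / inst_size;    // = 32: x expressed in 32-bit elements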

@@ -134,20 +171,62 @@
intel::coord_t{(int)n, (int)m});
}

template <class GCoord, class GShape, class GStride, class Basis = decltype(make_seq<rank(GStride{})>{})>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord,
template <class GShape>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(int m_coord, int n_coord, int l_coord,
GShape const &shape) const {

auto R = rank(GShape{});
static_assert(R == 3, "mismatch rank");

auto t_shape = cute::tuple_cat(make_shape(_1{}), take<1, R>(shape));

auto basis = make_seq<rank(typename CopyOp::Shape_MN{})>{};

if constexpr (is_mkl) {
if constexpr (!is_transpose) {
auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(basis, typename CopyOp::Shape_MN{},
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(t_shape, t_stride));
} else {
auto t_stride = cute::tuple_cat(make_stride(_1{}), transform((basis), typename CopyOp::Shape_MN{},
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(t_shape, t_stride));
}
} else if constexpr (is_nkl) {
if constexpr (!is_transpose) {
auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(reverse(basis), typename CopyOp::Shape_MN{},
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(t_shape, t_stride));
} else {
auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(reverse(basis), typename CopyOp::Shape_MN{},
[&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
make_layout(t_shape, t_stride));
}
}
}
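
get_pvc_tensor builds a coordinate tensor: the iterator carries the starting (m, n, l) block coordinate and the strides are scaled basis elements, so stepping a mode advances the 2D block coordinate by whole copy blocks rather than advancing a linear pointer (the MKL branch keeps E<0>/E<1> in order, the NKL branch reverses them). A self-contained sketch of the same pattern with assumed block sizes:

    // Minimal sketch of the MKL, non-transposed branch above; the shapes and the
    // CopyOp::Shape_MN value of (_32, _32) are assumptions for illustration.
    #include <cute/tensor.hpp>
    using namespace cute;

    auto pvc_coord_tensor_sketch(int m_coord, int n_coord, int l_coord) {
      auto t_shape  = make_shape(_1{}, _8{}, _4{});                       // (_1, n_tiles, k_blocks)
      auto t_stride = make_stride(_1{}, E<0>{} * _32{}, E<1>{} * _32{});  // step m by 32, n by 32
      return make_tensor(make_inttuple_iter(make_coord(m_coord, n_coord, l_coord)),
                         make_layout(t_shape, t_stride));
    }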

template <class GShape, class Direction>
CUTE_HOST_DEVICE constexpr auto get_pvc_tensor_B(int m_coord, int n_coord, int l,
GShape const &shape,
GStride const &stride,
Basis const & basis = {}) const {
Direction const& direction) const {

auto R = rank(GShape{});
static_assert(R == 3 || R == 4, "mismatch rank");
static_assert(R == 3, "mismatch rank");
auto t_shape = cute::tuple_cat(make_shape(_1{}), take<1, R>(shape));
auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(basis, stride, [&](auto i, auto s){
return E<i>{} * s;
}));
return make_tensor(make_inttuple_iter(coord),
make_layout(t_shape, t_stride));


}

template <class T1, class T2, class... TraitsArgs>