Skip to content

Commit db77438

Browse files
joyalbinamitchawla1
authored and committed
Adding UT for more dtype combinations in epilogue
1. ElementC is 'void'; ElementCompute and ElementOutput differ in LinearCombination. 2. ElementAccumulator and ElementC have different types: D = A x B + C, i.e. BF16 = BF16 x BF16 + BF16 is evaluated as BF16 = FP32 (accumulator) + BF16 (C).
1 parent 0a301f8 commit db77438

File tree

4 files changed

+284
-0
lines changed

4 files changed

+284
-0
lines changed

test/unit/gemm/device/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ if(CUTLASS_ENABLE_SYCL)
3434
xe_gemm_bf16_bf16_bf16_tensor_op_bf16.cpp
3535
xe_gemm_fp16_fp16_fp16_tensor_op_fp16.cpp
3636
xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp
37+
xe_gemm_bf16_bf16_fp32_tensor_op_bf16.cpp
3738
xe_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
3839
xe_gemm_fp16_fp16_fp16_tensor_op_fp32.cpp
3940
xe_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp

test/unit/gemm/device/default_gemm_configuration.hpp

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ struct DefaultGemmConfigurationToCutlass3Types {
6262
static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists.");
6363
};
6464

65+
// This type is only intended to demonstrate porting 2.x kernels to 3.0
66+
/// Primary (unspecialized) template for GEMM test configurations that name the
/// accumulator element type and the output element type as separate parameters.
///
/// Only partial specializations are meant to be used: instantiating this
/// primary template trips the static_assert below.
template <
  typename OperatorClass, typename ArchTag,
  typename ElementA, typename LayoutA,
  typename ElementB, typename LayoutB,
  typename ElementC, typename LayoutC,
  typename ElementAccumulator,
  typename ElementOutput
>
struct XeDefaultGemmConfigurationToCutlass3Types {
  // The condition depends on a template parameter, so it is only evaluated
  // when this primary template is actually instantiated.
  static_assert(sizeof(ElementA) == 0, "No valid XeDefaultGemmConfigurationToCutlass3Types configuration exists.");
};
76+
6577
///////////////////////////////////////////////////////////////////////////////
6678

6779
namespace detail {
@@ -1486,6 +1498,141 @@ struct DefaultGemmConfigurationToCutlass3Types<
14861498
>::CollectiveOp;
14871499
};
14881500

1501+
///////////////////////////////////////////////////////////////////////////////
1502+
1503+
// Intel XE MMA F32BF16
1504+
// ElementC -> void
1505+
// ElementCompute and ElementOutput different in LinearCombination
1506+
template <typename LayoutA, typename LayoutB, typename LayoutC, typename ElementOutput>
1507+
struct DefaultGemmConfigurationToCutlass3Types<
1508+
arch::OpClassTensorOp, arch::IntelXe,
1509+
bfloat16_t, LayoutA,
1510+
bfloat16_t, LayoutB,
1511+
void, LayoutC,
1512+
ElementOutput>
1513+
{
1514+
using TileShape = Shape<_256, _256, _32>;
1515+
1516+
using TiledMma =
1517+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
1518+
Layout<TileShape>,
1519+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
1520+
1521+
// A
1522+
static constexpr int kAlignmentA = 32;
1523+
using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA<
1524+
bfloat16_t, LayoutA, kAlignmentA, 32>;
1525+
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
1526+
1527+
// B
1528+
static constexpr int kAlignmentB = 32;
1529+
using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB<
1530+
bfloat16_t, LayoutB, kAlignmentB, 32>;
1531+
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
1532+
1533+
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
1534+
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1535+
cute::bfloat16_t, LayoutA, 1,
1536+
cute::bfloat16_t, LayoutB, 1,
1537+
float,
1538+
TileShape, Shape<_1, _1, _1>,
1539+
cutlass::gemm::collective::StageCountAuto,
1540+
cutlass::gemm::collective::KernelScheduleAuto
1541+
>::CollectiveOp;
1542+
1543+
//using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;
1544+
using EpilogueOp = epilogue::fusion::LinearCombination<cute::bfloat16_t, float>;
1545+
1546+
1547+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
1548+
epilogue::IntelXeXMX16,
1549+
EpilogueOp,
1550+
TileShape,
1551+
decltype(tile_shape(TiledMma()))
1552+
>;
1553+
1554+
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
1555+
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1556+
TileShape, Shape<_1, _1, _1>,
1557+
cutlass::epilogue::collective::EpilogueTileAuto,
1558+
float, float,
1559+
void, LayoutC, 1,
1560+
cute::bfloat16_t, LayoutC, 1,
1561+
cutlass::epilogue::collective::EpilogueScheduleAuto,
1562+
EpilogueOp
1563+
>::CollectiveOp;
1564+
};
1565+
1566+
///////////////////////////////////////////////////////////////////////////////
1567+
1568+
// Intel XE MMA F32BF16
1569+
// D = A x B + C: BF16 = BF16 x BF16 + BF16, evaluated as BF16 = FP32 (accumulator) + BF16 (C)
1570+
// ElementAccumulator and ElementC are different types.
1571+
template <
1572+
typename LayoutA,
1573+
typename LayoutB,
1574+
typename LayoutC,
1575+
typename ElementAccumulator,
1576+
typename ElementOutput>
1577+
struct XeDefaultGemmConfigurationToCutlass3Types<
1578+
arch::OpClassTensorOp, arch::IntelXe,
1579+
bfloat16_t, LayoutA,
1580+
bfloat16_t, LayoutB,
1581+
bfloat16_t, LayoutC,
1582+
ElementAccumulator,
1583+
ElementOutput>
1584+
{
1585+
using TileShape = Shape<_256, _256, _32>;
1586+
1587+
using TiledMma =
1588+
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
1589+
Layout<TileShape>,
1590+
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
1591+
1592+
// A
1593+
static constexpr int kAlignmentA = 32;
1594+
using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA<
1595+
bfloat16_t, LayoutA, kAlignmentA, 32>;
1596+
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
1597+
1598+
// B
1599+
static constexpr int kAlignmentB = 32;
1600+
using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB<
1601+
bfloat16_t, LayoutB, kAlignmentB, 32>;
1602+
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
1603+
1604+
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
1605+
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1606+
cute::bfloat16_t, LayoutA, 1,
1607+
cute::bfloat16_t, LayoutB, 1,
1608+
ElementAccumulator,
1609+
TileShape, Shape<_1, _1, _1>,
1610+
cutlass::gemm::collective::StageCountAuto,
1611+
cutlass::gemm::collective::KernelScheduleAuto
1612+
>::CollectiveOp;
1613+
1614+
using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;
1615+
1616+
using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
1617+
epilogue::IntelXeXMX16,
1618+
EpilogueOp,
1619+
TileShape,
1620+
decltype(tile_shape(TiledMma()))
1621+
>;
1622+
1623+
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
1624+
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1625+
TileShape, Shape<_1, _1, _1>,
1626+
cutlass::epilogue::collective::EpilogueTileAuto,
1627+
ElementAccumulator, float,
1628+
bfloat16_t, LayoutC, 1,
1629+
ElementOutput, LayoutC, 1,
1630+
cutlass::epilogue::collective::EpilogueScheduleAuto,
1631+
EpilogueOp
1632+
>::CollectiveOp;
1633+
};
1634+
1635+
14891636
///////////////////////////////////////////////////////////////////////////////
14901637

14911638
namespace detail {

test/unit/gemm/device/xe_gemm_bf16_bf16_bf16_tensor_op_fp32.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,51 @@ TEST(XE_Device_Gemm_bf16n_bf16n_bf16t_tensor_op_f32, 256x256x32) {
8585
EXPECT_TRUE(test::gemm::device::TestXe<Gemm>());
8686
}
8787

88+
89+
// ElementC ---> void
90+
// ElementOutput != ElementCompute in LinearCombination
91+
92+
template <typename LayoutA, typename LayoutB>
93+
struct XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void {
94+
using Config =
95+
gemm::device::DefaultGemmConfigurationToCutlass3Types<
96+
arch::OpClassTensorOp, arch::IntelXe,
97+
cute::bfloat16_t, LayoutA,
98+
cute::bfloat16_t, LayoutB,
99+
void, layout::RowMajor,
100+
cute::bfloat16_t>;
101+
102+
using Gemm = gemm::device::GemmUniversalAdapter<
103+
gemm::kernel::GemmUniversal<
104+
cute::Shape<int,int,int,int>,
105+
typename Config::CollectiveMainloop,
106+
typename Config::CollectiveEpilogue>>;
107+
};
108+
109+
// A row-major, B row-major; C is void.
TEST(XE_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_void, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void<layout::RowMajor, layout::RowMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
114+
115+
// A column-major, B row-major; C is void.
TEST(XE_Device_Gemm_bf16n_bf16t_bf16t_tensor_op_f32_void, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void<layout::ColumnMajor, layout::RowMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
120+
121+
// A row-major, B column-major; C is void.
TEST(XE_Device_Gemm_bf16t_bf16n_bf16t_tensor_op_f32_void, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void<layout::RowMajor, layout::ColumnMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
126+
127+
// A column-major, B column-major; C is void.
TEST(XE_Device_Gemm_bf16n_bf16n_bf16t_tensor_op_f32_void, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_bf16_tensor_op_f32_void<layout::ColumnMajor, layout::ColumnMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
132+
133+
88134
}
89135
} // namespace cutlass
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
/*! \file
33+
\brief Tests for Xe GEMM with bf16 A/B operands, fp32 accumulation, and a bf16 C (source) tensor
34+
*/
35+
36+
37+
#include "cutlass/cutlass.h"
38+
39+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
40+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
41+
#include "default_gemm_configuration.hpp"
42+
43+
#include "gemm_testbed_3x.hpp"
44+
45+
namespace cutlass {
46+
namespace {
47+
template <typename LayoutA, typename LayoutB>
48+
struct XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16 {
49+
using Config =
50+
gemm::device::XeDefaultGemmConfigurationToCutlass3Types<
51+
arch::OpClassTensorOp, arch::IntelXe,
52+
cute::bfloat16_t, LayoutA,
53+
cute::bfloat16_t, LayoutB,
54+
cute::bfloat16_t, layout::RowMajor,
55+
float,
56+
cute::bfloat16_t>;
57+
58+
using Gemm = gemm::device::GemmUniversalAdapter<
59+
gemm::kernel::GemmUniversal<
60+
cute::Shape<int,int,int,int>,
61+
typename Config::CollectiveMainloop,
62+
typename Config::CollectiveEpilogue>>;
63+
};
64+
65+
// A row-major, B row-major; bf16 C with float accumulation.
TEST(XE_Device_Gemm_bf16t_bf16t_f32t_tensor_op_bf16, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16<layout::RowMajor, layout::RowMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
70+
71+
// A column-major, B row-major; bf16 C with float accumulation.
TEST(XE_Device_Gemm_bf16n_bf16t_f32t_tensor_op_bf16, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16<layout::ColumnMajor, layout::RowMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
76+
77+
// A row-major, B column-major; bf16 C with float accumulation.
TEST(XE_Device_Gemm_bf16t_bf16n_f32t_tensor_op_bf16, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16<layout::RowMajor, layout::ColumnMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
82+
83+
// A column-major, B column-major; bf16 C with float accumulation.
TEST(XE_Device_Gemm_bf16n_bf16n_f32t_tensor_op_bf16, 256x256x32) {
  using GemmUnderTest =
      XE_Device_Gemm_bf16_bf16_f32_tensor_op_bf16<layout::ColumnMajor, layout::ColumnMajor>::Gemm;
  EXPECT_TRUE(test::gemm::device::TestXe<GemmUnderTest>());
}
88+
89+
}
90+
} // namespace cutlass

0 commit comments

Comments
 (0)