@@ -62,6 +62,18 @@ struct DefaultGemmConfigurationToCutlass3Types {
6262 static_assert (sizeof (ElementA) == 0 , " No valid DefaultGemmConfigurationToCutlass3Types configuration exists." );
6363};
6464
65+ // This type is only intended to demonstrate porting 2.x kernels to 3.0
66+ template <
67+ class OperatorClass , class ArchTag ,
68+ class ElementA , class LayoutA ,
69+ class ElementB , class LayoutB ,
70+ class ElementC , class LayoutC ,
71+ class ElementAccumulator ,
72+ class ElementOutput >
73+ struct XeDefaultGemmConfigurationToCutlass3Types {
74+ static_assert (sizeof (ElementA) == 0 , " No valid XeDefaultGemmConfigurationToCutlass3Types configuration exists." );
75+ };
76+
6577// /////////////////////////////////////////////////////////////////////////////
6678
6779namespace detail {
@@ -1486,6 +1498,141 @@ struct DefaultGemmConfigurationToCutlass3Types<
14861498 >::CollectiveOp;
14871499};
14881500
1501+ // /////////////////////////////////////////////////////////////////////////////
1502+
1503+ // Intel XE MMA F32BF16
1504+ // ElementC - > void
1505+ // ElementCompute and ElementOutput different in LinearCombination
1506+ template <typename LayoutA, typename LayoutB, typename LayoutC, typename ElementOutput>
1507+ struct DefaultGemmConfigurationToCutlass3Types <
1508+ arch::OpClassTensorOp, arch::IntelXe,
1509+ bfloat16_t , LayoutA,
1510+ bfloat16_t , LayoutB,
1511+ void , LayoutC,
1512+ ElementOutput>
1513+ {
1514+ using TileShape = Shape<_256, _256, _32>;
1515+
1516+ using TiledMma =
1517+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
1518+ Layout<TileShape>,
1519+ Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
1520+
1521+ // A
1522+ static constexpr int kAlignmentA = 32 ;
1523+ using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA<
1524+ bfloat16_t , LayoutA, kAlignmentA , 32 >;
1525+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
1526+
1527+ // B
1528+ static constexpr int kAlignmentB = 32 ;
1529+ using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB<
1530+ bfloat16_t , LayoutB, kAlignmentB , 32 >;
1531+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
1532+
1533+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
1534+ cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1535+ cute::bfloat16_t , LayoutA, 1 ,
1536+ cute::bfloat16_t , LayoutB, 1 ,
1537+ float ,
1538+ TileShape, Shape<_1, _1, _1>,
1539+ cutlass::gemm::collective::StageCountAuto,
1540+ cutlass::gemm::collective::KernelScheduleAuto
1541+ >::CollectiveOp;
1542+
1543+ // using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float>;
1544+ using EpilogueOp = epilogue::fusion::LinearCombination<cute::bfloat16_t , float >;
1545+
1546+
1547+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
1548+ epilogue::IntelXeXMX16,
1549+ EpilogueOp,
1550+ TileShape,
1551+ decltype (tile_shape(TiledMma()))
1552+ >;
1553+
1554+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
1555+ cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1556+ TileShape, Shape<_1, _1, _1>,
1557+ cutlass::epilogue::collective::EpilogueTileAuto,
1558+ float , float ,
1559+ void , LayoutC, 1 ,
1560+ cute::bfloat16_t , LayoutC, 1 ,
1561+ cutlass::epilogue::collective::EpilogueScheduleAuto,
1562+ EpilogueOp
1563+ >::CollectiveOp;
1564+ };
1565+
1566+ // /////////////////////////////////////////////////////////////////////////////
1567+
1568+ // Intel XE MMA F32BF16
1569+ // D=Ax B + C; => BF16=BF16xBF16+BF16 <=>BF16=FP32+BF16
1570+ // ElementAccumulator and ElementC are different types.
1571+ template <
1572+ typename LayoutA,
1573+ typename LayoutB,
1574+ typename LayoutC,
1575+ typename ElementAccumulator,
1576+ typename ElementOutput>
1577+ struct XeDefaultGemmConfigurationToCutlass3Types <
1578+ arch::OpClassTensorOp, arch::IntelXe,
1579+ bfloat16_t , LayoutA,
1580+ bfloat16_t , LayoutB,
1581+ bfloat16_t , LayoutC,
1582+ ElementAccumulator,
1583+ ElementOutput>
1584+ {
1585+ using TileShape = Shape<_256, _256, _32>;
1586+
1587+ using TiledMma =
1588+ typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
1589+ Layout<TileShape>,
1590+ Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
1591+
1592+ // A
1593+ static constexpr int kAlignmentA = 32 ;
1594+ using DefaultOperandA = detail::DefaultGemm_TensorOpXe_OperandA<
1595+ bfloat16_t , LayoutA, kAlignmentA , 32 >;
1596+ using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
1597+
1598+ // B
1599+ static constexpr int kAlignmentB = 32 ;
1600+ using DefaultOperandB = detail::DefaultGemm_TensorOpXe_OperandB<
1601+ bfloat16_t , LayoutB, kAlignmentB , 32 >;
1602+ using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
1603+
1604+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
1605+ cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1606+ cute::bfloat16_t , LayoutA, 1 ,
1607+ cute::bfloat16_t , LayoutB, 1 ,
1608+ ElementAccumulator,
1609+ TileShape, Shape<_1, _1, _1>,
1610+ cutlass::gemm::collective::StageCountAuto,
1611+ cutlass::gemm::collective::KernelScheduleAuto
1612+ >::CollectiveOp;
1613+
1614+ using EpilogueOp = epilogue::fusion::LinearCombination<ElementOutput, float >;
1615+
1616+ using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks<
1617+ epilogue::IntelXeXMX16,
1618+ EpilogueOp,
1619+ TileShape,
1620+ decltype (tile_shape(TiledMma()))
1621+ >;
1622+
1623+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
1624+ cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp,
1625+ TileShape, Shape<_1, _1, _1>,
1626+ cutlass::epilogue::collective::EpilogueTileAuto,
1627+ ElementAccumulator, float ,
1628+ bfloat16_t , LayoutC, 1 ,
1629+ ElementOutput, LayoutC, 1 ,
1630+ cutlass::epilogue::collective::EpilogueScheduleAuto,
1631+ EpilogueOp
1632+ >::CollectiveOp;
1633+ };
1634+
1635+
14891636// /////////////////////////////////////////////////////////////////////////////
14901637
14911638namespace detail {
0 commit comments