diff --git a/include/imex/Conversion/ConvertToSPIRV/ConvertToSPIRV.h b/include/imex/Conversion/ConvertToSPIRV/ConvertToSPIRV.h new file mode 100644 index 000000000..770b654ce --- /dev/null +++ b/include/imex/Conversion/ConvertToSPIRV/ConvertToSPIRV.h @@ -0,0 +1,36 @@ +//===- ConvertToSPIRV.h - Converts everything to SPIR-V dialect - *-C++ -*-===// +// +// Copyright 2025 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef IMEX_CONVERSION_CONVERTTOSPIRV_H +#define IMEX_CONVERSION_CONVERTTOSPIRV_H + +#include "imex/Utils/XeCommon.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class ConversionTarget; +class SPIRVTypeConverter; +class Pass; +class Operation; +class RewritePatternSet; +template class OperationPass; +} // namespace mlir + +namespace imex { +#define GEN_PASS_DECL_CONVERTTOSPIRV +#include "imex/Conversion/Passes.h.inc" + +std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>> +createConvertToSPIRVPass(); + +} // namespace imex +#endif diff --git a/include/imex/Conversion/Passes.h b/include/imex/Conversion/Passes.h index 1efd47ab0..044923b35 100644 --- a/include/imex/Conversion/Passes.h +++ b/include/imex/Conversion/Passes.h @@ -18,6 +18,7 @@ #include "mlir/Pass/Pass.h" #include +#include #include #include #include diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td index e4c0cecca..297a2f659 100644 --- a/include/imex/Conversion/Passes.td +++ b/include/imex/Conversion/Passes.td @@ -258,6 +258,43 @@ memref, arith and math. let dependentDialects = ["::mlir::spirv::SPIRVDialect"]; } + +//===----------------------------------------------------------------------===// +// ConvertToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertToSPIRV : Pass<"imex-convert-to-spirv", "::mlir::ModuleOp"> { + let summary = "Convert to SPIR-V dialect by using all the 'to SPIR-V' conversion patterns from all the dialect conversions"; + let description = [{ + This is a one-shot pass to convert to SPIR-V dialect by using all the 'to SPIR-V' conversion patterns from all the dialect conversions. + It includes the GPU to SPIR-V conversion as well as other dialects like SCF, Math, Arith etc. 
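+
+    A typical use is in a pipeline that first attaches a SPIR-V target to the
+    GPU modules and then finalizes the converted SPIR-V modules. A minimal
+    sketch, using the pass names from the pipeline files updated in this
+    change (capability and extension lists shortened):
+
+      builtin.module(
+        spirv-attach-target{ver=v1.0 caps=Addresses,Int64,Kernel}
+        imex-convert-to-spirv{use-64bit-index=true}
+        gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce))
+        gpu-module-to-binary
+      )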
+ + }]; + let dependentDialects = ["::mlir::spirv::SPIRVDialect"]; + let constructor = "imex::createConvertToSPIRVPass()"; + let options = [ + // arith, cf, func, tensor to SPIR-V options + Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", + "bool", /*default=*/"true", + "Emulate narrower scalar types with 32-bit ones if not supported by " + "the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width">, + // gpu, Index to SPIR-V options + Option<"use64bitIndex", "use-64bit-index", + "bool", /*default=*/"false", + "Use 64-bit integers to convert index types">, + // memref to SPIR-V options + Option<"boolNumBits", "bool-num-bits", + "int", /*default=*/"8", + "The number of bits to store a boolean value">, + + ]; +} + + //===----------------------------------------------------------------------===// // GPUToGPUX //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index e02db3694..8182a7417 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(ArithToVC) +add_subdirectory(ConvertToSPIRV) add_subdirectory(NDArrayToLinalg) add_subdirectory(DropRegions) add_subdirectory(RegionParallelLoopToGpu) diff --git a/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/lib/Conversion/ConvertToSPIRV/CMakeLists.txt new file mode 100644 index 000000000..cf841cc7b --- /dev/null +++ b/lib/Conversion/ConvertToSPIRV/CMakeLists.txt @@ -0,0 +1,24 @@ +add_imex_conversion_library(IMEXConvertToSPIRV + ConvertToSPIRVPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/imex/Conversion/ConvertToSPIRV + + DEPENDS + IMEXConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithToSPIRV + MLIRControlFlowToSPIRV + MLIRFuncToSPIRV + MLIRGPUDialect + MLIRGPUToSPIRV + MLIRIR + MLIRMathToSPIRV + MLIRPass + MLIRSCFToSPIRV + MLIRSPIRVDialect + MLIRSPIRVConversion + MLIRSupport + MLIRTransforms + ) diff --git a/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp new file mode 100644 index 000000000..df2d7538b --- /dev/null +++ b/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp @@ -0,0 +1,354 @@ +//===-------- ConvertToSPIRV.cpp - one shot convert-to-spirv pass --------===// +// +// Copyright 2025 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Generate a convert-to-spirv pass. Add all `to-spirv` patterns from all the +// dialects. 
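+// The pass clones each gpu.module (the original is still referenced by the
+// host-side gpu.launch_func ops), configures an SPIRVTypeConverter from the
+// pass options (index width, boolean width, narrow-type and unsupported-float
+// emulation), remaps memref memory spaces to SPIR-V storage classes, and runs
+// a full conversion with the collected patterns. For OpenCL (Kernel
+// capability) targets, the original gpu.func is afterwards replaced by an
+// empty func.func stub carrying the gpu.kernel attribute.
+//===----------------------------------------------------------------------===//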
+ +#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" +#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h" +#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" +#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" +#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" +#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" +#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h" +#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" +#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" +#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" +#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/DebugLog.h" +#include + +#include "imex/Conversion/ConvertToSPIRV/ConvertToSPIRV.h" +#include "imex/Conversion/Passes.h" + +namespace imex { +#define GEN_PASS_DEF_CONVERTTOSPIRV +#include "imex/Conversion/Passes.h.inc" +} // namespace imex + +using namespace mlir; +using namespace imex; +namespace imex { +// This op: +// vector.create_mask %maskVal : vector +// is lowered to: +// if maskVal < 0 +// mask = 0 +// else if maskVal < vWidth +// mask = (1 << maskVal) - 1 +// else +// mask = all ones +class VectorMaskConversionPattern final + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::vector::CreateMaskOp vMaskOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::VectorType vTy = vMaskOp.getVectorType(); + if (vTy.getRank() != 1) + return mlir::failure(); + + auto vWidth = vTy.getNumElements(); + assert(vWidth <= 64 && "vector.create_mask supports vector widths <= 64"); + auto vWidthConst = rewriter.create( + vMaskOp.getLoc(), rewriter.getI64IntegerAttr(vWidth)); + auto maskVal = adaptor.getOperands()[0]; + maskVal = rewriter.create( + vMaskOp.getLoc(), rewriter.getI64Type(), maskVal); + + // maskVal < vWidth + auto cmp = rewriter.create( + vMaskOp.getLoc(), mlir::arith::CmpIPredicate::slt, maskVal, + vWidthConst); + auto one = rewriter.create( + vMaskOp.getLoc(), rewriter.getI64IntegerAttr(1)); + auto shift = rewriter.create( + vMaskOp.getLoc(), one, maskVal); + auto mask1 = + rewriter.create(vMaskOp.getLoc(), shift, one); + auto mask2 = rewriter.create( + vMaskOp.getLoc(), rewriter.getI64IntegerAttr(-1)); // all ones + mlir::Value sel = rewriter.create(vMaskOp.getLoc(), + cmp, mask1, mask2); + + // maskVal < 0 + auto zero = rewriter.create( + vMaskOp.getLoc(), rewriter.getI64IntegerAttr(0)); + auto cmp2 = rewriter.create( + vMaskOp.getLoc(), mlir::arith::CmpIPredicate::slt, maskVal, zero); + sel = rewriter.create(vMaskOp.getLoc(), cmp2, zero, + sel); + + sel = rewriter.create( + vMaskOp.getLoc(), rewriter.getIntegerType(vWidth), sel); + auto res = rewriter.create( + vMaskOp.getLoc(), mlir::VectorType::get({vWidth}, rewriter.getI1Type()), + sel); + vMaskOp->replaceAllUsesWith(res); + rewriter.eraseOp(vMaskOp); + return 
mlir::success(); + } +}; + +// This pattern converts vector.from_elements op to SPIR-V CompositeInsertOp +class VectorFromElementsConversionPattern final + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::vector::FromElementsOp fromElementsOp, + OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::VectorType vecTy = fromElementsOp.getType(); + if (vecTy.getRank() > 1) + return rewriter.notifyMatchFailure(fromElementsOp, + "rank > 1 vectors are not supported"); + + mlir::Type spirvVecTy = getTypeConverter()->convertType(vecTy); + if (!spirvVecTy) + return mlir::failure(); + + // if the vector is just constructed from one element + if (mlir::isa(spirvVecTy)) { + rewriter.replaceOp(fromElementsOp, adaptor.getElements()[0]); + return mlir::success(); + } + + auto loc = fromElementsOp.getLoc(); + mlir::Value result = rewriter.create(loc, spirvVecTy); + for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) { + result = rewriter.create(loc, val, result, + idx); + } + rewriter.replaceOp(fromElementsOp, result); + return mlir::success(); + } +}; + +void populateIMEXVectorToSPIRVPatterns(mlir::SPIRVTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns) { + patterns + .add( + typeConverter, patterns.getContext()); +} +} // namespace imex + +namespace { + +// Populate upstream conversion patterns for each dialect. +void populateUpstreamConvertToSPIRVPatterns( + const SPIRVTypeConverter &typeConverter, + ScfToSPIRVContext &scfToSPIRVContext, RewritePatternSet &patterns) { + arith::populateCeilFloorDivExpandOpsPatterns(patterns); + arith::populateArithToSPIRVPatterns(typeConverter, patterns); + populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns); + populateComplexToSPIRVPatterns(typeConverter, patterns); + populateFuncToSPIRVPatterns(typeConverter, patterns); + populateGPUToSPIRVPatterns(typeConverter, patterns); + index::populateIndexToSPIRVPatterns(typeConverter, patterns); + populateMathToSPIRVPatterns(typeConverter, patterns); + populateMemRefToSPIRVPatterns(typeConverter, patterns); + populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); + populateTensorToSPIRVPatterns(typeConverter, + /*byteCountThreshold=*/64, patterns); + ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); + mlir::populateVectorToSPIRVPatterns(typeConverter, patterns); +} + +struct ConvertToSPIRVPass + : public imex::impl::ConvertToSPIRVBase { + void runOnOperation() override; + +private: + // Queries the target environment from 'targets' attribute of the given + // `moduleOp`. + spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp); + + // Queries the target environment from 'targets' attribute of the given + // `moduleOp` or returns target environment as returned by + // `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'. + spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp); + // Map memRef memory space to SPIR-V storage class. 
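+  // For kernel (OpenCL) target environments the OpenCL storage-class mapping
+  // is used, otherwise the Vulkan mapping; the op is then walked to verify
+  // that no memref with an unconverted memory space remains.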
+ void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr); + // bool mapMemorySpace; +}; + +spirv::TargetEnvAttr +ConvertToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) { + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = dyn_cast(targetAttr)) + return spirvTargetEnvAttr; + } + return {}; +} + +spirv::TargetEnvAttr +ConvertToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) { + if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp)) + return targetEnvAttr; + return spirv::lookupTargetEnvOrDefault(moduleOp); +} + +// Map memRef memory space to SPIR-V storage class. +void ConvertToSPIRVPass::mapToMemRef(Operation *op, + spirv::TargetEnvAttr &targetAttr) { + spirv::TargetEnv targetEnv(targetAttr); + bool targetEnvSupportsKernelCapability = + targetEnv.allows(spirv::Capability::Kernel); + spirv::MemorySpaceToStorageClassMap memorySpaceMap = + targetEnvSupportsKernelCapability + ? spirv::mapMemorySpaceToOpenCLStorageClass + : spirv::mapMemorySpaceToVulkanStorageClass; + spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); + spirv::convertMemRefTypesAndAttrs(op, converter); + + // Check if there are any illegal ops remaining. + std::unique_ptr target = + spirv::getMemorySpaceToStorageClassTarget(*op->getContext()); + + op->walk([&target, this](Operation *childOp) { + if (target->isIllegal(childOp)) { + childOp->emitOpError("failed to legalize memory space"); + signalPassFailure(); // Now this works because it's a member function + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); +} + +void ConvertToSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + Operation *op = getOperation(); + + SmallVector gpuModules; + OpBuilder builder(context); + + auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) { + auto targetAttr = lookupTargetEnvOrDefault(moduleOp); + spirv::TargetEnv targetEnv(targetAttr); + return targetEnv.allows(spirv::Capability::Kernel); + }; + + op->walk([&](gpu::GPUModuleOp moduleOp) { + // Clone each GPU kernel module for conversion, given that the GPU + // launch op still needs the original GPU kernel module. + // For Vulkan Shader capabilities, we insert the newly converted SPIR-V + // module right after the original GPU module, as that's the expectation + // of the in-tree SPIR-V CPU runner (the Vulkan runner does not use this + // pass). For OpenCL Kernel capabilities, we insert the newly converted + // SPIR-V module inside the original GPU module, as that's the expectaion + // of the normal GPU compilation pipeline. + if (targetEnvSupportsKernelCapability(moduleOp)) { + builder.setInsertionPointToStart(moduleOp.getBody()); + } else { + builder.setInsertionPoint(moduleOp.getOperation()); + } + gpuModules.push_back(builder.clone(*moduleOp.getOperation())); + }); + + // Run conversion for each gpu module independently as they can have + // different TargetEnv attributes. 
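+  // For each cloned module: look up its target environment, build the
+  // conversion target and the type converter from the pass options, remap
+  // memref memory spaces, and apply a full conversion with the upstream and
+  // IMEX patterns populated below.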
+ for (Operation *gpuModule : gpuModules) { + // Configure conversion target + auto castedGPUModule = mlir::dyn_cast(*gpuModule); + spirv::TargetEnvAttr targetAttr = lookupTargetEnvOrDefault(castedGPUModule); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + // Set up type converter with SPIR-V type conversion and pass options + SPIRVConversionOptions options; + options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; + options.use64bitIndex = this->use64bitIndex; + options.boolNumBits = this->boolNumBits; + SPIRVTypeConverter typeConverter(targetAttr, options); + + // Upstream SPIRVTypeConverter does not add conversion for + // UnrankedMemRefType. + // Conversion logic is the same as ranked dynamic memref type for OpenCL + // Kernel. unranked memref type is converted to a spirv pointer type + // with converted spirv scalar element type and spirv storage class. + // Only scalar element type is currently supported. + // Also vulkan should be handled differently but out of scope since this + // conversion pass is for lowering to OpenCL spirv kernel only. + typeConverter.addConversion( + [&](mlir::UnrankedMemRefType type) -> std::optional { + auto attr = mlir::dyn_cast_or_null( + type.getMemorySpace()); + if (!attr) + return nullptr; + mlir::spirv::StorageClass storageClass = attr.getValue(); + + mlir::Type elementType = type.getElementType(); + auto scalarType = + mlir::dyn_cast(elementType); + if (!scalarType) + return nullptr; + mlir::Type arrayElemType = typeConverter.convertType(scalarType); + return mlir::spirv::PointerType::get(arrayElemType, storageClass); + }); + + // Add all to-SPIRV conversion patterns + RewritePatternSet patterns(context); + // Upstream patterns + ScfToSPIRVContext scfToSPIRVContext; + mapToMemRef(gpuModule, targetAttr); + populateUpstreamConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext, + patterns); + // IMEX patterns + imex::populateIMEXVectorToSPIRVPatterns(typeConverter, patterns); + // Apply conversion + if (failed(applyFullConversion(gpuModule, *target, std::move(patterns)))) { + signalPassFailure(); + } + } + // For OpenCL, the gpu.func op in the original gpu.module op needs to be + // replaced with an empty func.func op with the same arguments as the + // gpu.func op. The func.func op needs gpu.kernel attribute set. 
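+  // The replacement is an empty stub of the form (illustrative; the argument
+  // types mirror the original gpu.func signature):
+  //
+  //   func.func @kernel_name(%arg0: memref<...>) attributes {gpu.kernel} {
+  //     return
+  //   }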
+ op->walk([&](gpu::GPUModuleOp moduleOp) { + if (targetEnvSupportsKernelCapability(moduleOp)) { + moduleOp.walk([&](gpu::GPUFuncOp funcOp) { + builder.setInsertionPoint(funcOp); + auto newFuncOp = + func::FuncOp::create(builder, funcOp.getLoc(), funcOp.getName(), + funcOp.getFunctionType()); + auto entryBlock = newFuncOp.addEntryBlock(); + builder.setInsertionPointToEnd(entryBlock); + func::ReturnOp::create(builder, funcOp.getLoc()); + newFuncOp->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + funcOp.erase(); + }); + } + }); +} + +} // namespace + +std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>> +imex::createConvertToSPIRVPass() { + return std::make_unique(); +} diff --git a/test/Integration/Dialect/Gpu/AsyncTests.mlir b/test/Integration/Dialect/Gpu/AsyncTests.mlir index 059059628..6dbc4fe5f 100644 --- a/test/Integration/Dialect/Gpu/AsyncTests.mlir +++ b/test/Integration/Dialect/Gpu/AsyncTests.mlir @@ -1,40 +1,29 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @eltwise_add attributes {gpu.container_module} { - func.func @fillRandom(%arg0: memref<4194304xf32>, %arg0_gpu: memref<4194304xf32>) -> () { - %S0L = arith.constant 10.0 : f32 - %S0H = arith.constant 50.0 : f32 - %false = arith.constant 0 : i1 - - %arg0_random = memref.cast %arg0 : memref<4194304xf32> to memref<*xf32> - call @fillResource1DRandomF32(%arg0_random, %S0L, %S0H, %false) : (memref<*xf32>, f32, f32, i1) -> () - - memref.copy %arg0, %arg0_gpu : memref<4194304xf32> to memref<4194304xf32> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @eltwise_add attributes {gpu.container_module} { + func.func @fillRandom(%arg0: memref<4194304xf32>, %arg1: memref<4194304xf32>) { + %cst = arith.constant 1.000000e+01 : f32 + %cst_0 = arith.constant 5.000000e+01 : f32 + %false = arith.constant false + %cast = memref.cast %arg0 : memref<4194304xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + memref.copy %arg0, %arg1 : memref<4194304xf32> to memref<4194304xf32> return } - - func.func @fillZeros(%res: memref<4194304xf32>, %res_gpu: memref<4194304xf32>) -> () { - %c0 = arith.constant 0.0 : f32 - - %res_zeros = memref.cast %res : memref<4194304xf32> to memref<*xf32> - call @fillResource1DF32(%res_zeros, %c0) : (memref<*xf32>, f32) -> () - - memref.copy %res, %res_gpu : memref<4194304xf32> to memref<4194304xf32> - + func.func @fillZeros(%arg0: memref<4194304xf32>, %arg1: memref<4194304xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %cast = memref.cast %arg0 : memref<4194304xf32> to memref<*xf32> + call @fillResource1DF32(%cast, %cst) : (memref<*xf32>, f32) -> () + memref.copy %arg0, %arg1 : memref<4194304xf32> to memref<4194304xf32> return } - - gpu.module @eltwiseAdd_kernel attributes 
{spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @eltwiseAdd_kernel { gpu.func @eltwiseAdd_kernel(%arg0: memref<4194304xf32>, %arg1: memref<4194304xf32>, %arg2: memref<4194304xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %global_id_x = gpu.global_id x - %cst = arith.constant 0.5 : f32 + %global_id_x = gpu.global_id x + %cst = arith.constant 5.000000e-01 : f32 %0 = memref.load %arg0[%global_id_x] : memref<4194304xf32> %1 = memref.load %arg1[%global_id_x] : memref<4194304xf32> %2 = arith.addf %0, %1 : f32 @@ -43,144 +32,113 @@ module @eltwise_add attributes {gpu.container_module} { gpu.return } } - // compute CPU reference (takes minutes) func.func @cpu_reference(%arg0: memref<4194304xf32>, %arg1: memref<4194304xf32>, %arg2: memref<4194304xf32>) { %c4194304 = arith.constant 4194304 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %cst = arith.constant 5.000000e-01 : f32 - scf.for %i = %c0 to %c4194304 step %c1 { - %0 = memref.load %arg0[%i] : memref<4194304xf32> - %1 = memref.load %arg1[%i] : memref<4194304xf32> + scf.for %arg3 = %c0 to %c4194304 step %c1 { + %0 = memref.load %arg0[%arg3] : memref<4194304xf32> + %1 = memref.load %arg1[%arg3] : memref<4194304xf32> %2 = arith.addf %0, %1 : f32 %3 = arith.addf %2, %cst : f32 - memref.store %3, %arg2[%i] : memref<4194304xf32> + memref.store %3, %arg2[%arg3] : memref<4194304xf32> } return } - func.func @main() { %c1 = arith.constant 1 : index %c512 = arith.constant 512 : index %c8192 = arith.constant 8192 : index - - %arg0 = memref.alloc() : memref<4194304xf32> - %arg0_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillRandom(%arg0, %arg0_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %arg1 = memref.alloc() : memref<4194304xf32> - %arg1_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillRandom(%arg1, %arg1_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %arg2 = memref.alloc() : memref<4194304xf32> - %arg2_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillRandom(%arg2, %arg2_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %arg3 = memref.alloc() : memref<4194304xf32> - %arg3_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillRandom(%arg3, %arg3_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %res0 = memref.alloc() : memref<4194304xf32> - %res0_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillZeros(%res0, %res0_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %res1 = memref.alloc() : memref<4194304xf32> - %res1_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillZeros(%res1, %res1_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %res2 = memref.alloc() : memref<4194304xf32> - %res2_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillZeros(%res2, %res2_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - - %res = memref.alloc() : memref<4194304xf32> - %res_gpu = gpu.alloc host_shared () : memref<4194304xf32> - call @fillZeros(%res, %res_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - // Test1: Two async launches followed by sync launch that // waits for events returned by the two async launches - - %e1 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg0_gpu: memref<4194304xf32>, %arg1_gpu: memref<4194304xf32>, %res0_gpu: memref<4194304xf32>) - %e2 = gpu.launch_func async 
@eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg2_gpu: memref<4194304xf32>, %arg3_gpu: memref<4194304xf32>, %res1_gpu: memref<4194304xf32>) - gpu.launch_func [%e1, %e2] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%res0_gpu: memref<4194304xf32>, %res1_gpu: memref<4194304xf32>, %res_gpu: memref<4194304xf32>) - - call @cpu_reference(%arg0, %arg1, %res0) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%arg2, %arg3, %res1) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%res0, %res1, %res) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - - %cast_res = memref.cast %res : memref<4194304xf32> to memref<*xf32> - %cast_res_gpu = memref.cast %res_gpu : memref<4194304xf32> to memref<*xf32> - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_res, %cast_res_gpu) : (memref<*xf32>, memref<*xf32>) -> () - - call @fillZeros(%res0, %res0_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - call @fillZeros(%res1, %res1_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - call @fillZeros(%res, %res_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - // Test2: An async launch followed by another async launch and // finally a sync launch. Each launch waits on the event // from the preceeding launch. - - %e3 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg0_gpu: memref<4194304xf32>, %arg1_gpu: memref<4194304xf32>, %res0_gpu: memref<4194304xf32>) - %e4 = gpu.launch_func async [%e3] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg2_gpu: memref<4194304xf32>, %res0_gpu: memref<4194304xf32>, %res1_gpu: memref<4194304xf32>) - gpu.launch_func [%e4] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg3_gpu: memref<4194304xf32>, %res1_gpu: memref<4194304xf32>, %res_gpu: memref<4194304xf32>) - - call @cpu_reference(%arg0, %arg1, %res0) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%arg2, %res0, %res1) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%arg3, %res1, %res) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - - %cast_res_0 = memref.cast %res : memref<4194304xf32> to memref<*xf32> - %cast_res_gpu_0 = memref.cast %res_gpu : memref<4194304xf32> to memref<*xf32> - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_res_0, %cast_res_gpu_0) : (memref<*xf32>, memref<*xf32>) -> () - - call @fillZeros(%res0, %res0_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - call @fillZeros(%res1, %res1_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - call @fillZeros(%res, %res_gpu) : (memref<4194304xf32>, memref<4194304xf32>) -> () - // Test3: An async launch followed by two async launches and // finally a sync launch. The event from the first async launch // is passed to the subsequent two async launches which wait on // the same event. The last sync launch waits from two events // from the preceeding two async launches. 
- - %e5 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg0_gpu: memref<4194304xf32>, %arg1_gpu: memref<4194304xf32>, %res0_gpu: memref<4194304xf32>) - %e6 = gpu.launch_func async [%e5] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg2_gpu: memref<4194304xf32>, %res0_gpu: memref<4194304xf32>, %res1_gpu: memref<4194304xf32>) - %e7 = gpu.launch_func async [%e5] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%arg3_gpu: memref<4194304xf32>, %res0_gpu: memref<4194304xf32>, %res2_gpu: memref<4194304xf32>) - gpu.launch_func [%e6, %e7] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%res1_gpu: memref<4194304xf32>, %res2_gpu: memref<4194304xf32>, %res_gpu: memref<4194304xf32>) - - call @cpu_reference(%arg0, %arg1, %res0) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%arg2, %res0, %res1) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%arg3, %res0, %res2) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - call @cpu_reference(%res1, %res2, %res) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () - - %cast_res_1 = memref.cast %res : memref<4194304xf32> to memref<*xf32> - %cast_res_gpu_1 = memref.cast %res_gpu : memref<4194304xf32> to memref<*xf32> - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_res_1, %cast_res_gpu_1) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %arg0 : memref<4194304xf32> - memref.dealloc %arg1 : memref<4194304xf32> - memref.dealloc %arg2 : memref<4194304xf32> - memref.dealloc %arg3 : memref<4194304xf32> - memref.dealloc %res0 : memref<4194304xf32> - memref.dealloc %res1 : memref<4194304xf32> - memref.dealloc %res : memref<4194304xf32> - - gpu.dealloc %arg0_gpu : memref<4194304xf32> - gpu.dealloc %arg1_gpu : memref<4194304xf32> - gpu.dealloc %arg2_gpu : memref<4194304xf32> - gpu.dealloc %arg3_gpu : memref<4194304xf32> - gpu.dealloc %res0_gpu : memref<4194304xf32> - gpu.dealloc %res1_gpu : memref<4194304xf32> - gpu.dealloc %res_gpu : memref<4194304xf32> - + %alloc = memref.alloc() : memref<4194304xf32> + %memref = gpu.alloc () : memref<4194304xf32> + call @fillRandom(%alloc, %memref) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_0 = memref.alloc() : memref<4194304xf32> + %memref_1 = gpu.alloc () : memref<4194304xf32> + call @fillRandom(%alloc_0, %memref_1) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_2 = memref.alloc() : memref<4194304xf32> + %memref_3 = gpu.alloc () : memref<4194304xf32> + call @fillRandom(%alloc_2, %memref_3) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_4 = memref.alloc() : memref<4194304xf32> + %memref_5 = gpu.alloc () : memref<4194304xf32> + call @fillRandom(%alloc_4, %memref_5) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_6 = memref.alloc() : memref<4194304xf32> + %memref_7 = gpu.alloc () : memref<4194304xf32> + call @fillZeros(%alloc_6, %memref_7) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_8 = memref.alloc() : memref<4194304xf32> + %memref_9 = gpu.alloc () : memref<4194304xf32> + call @fillZeros(%alloc_8, %memref_9) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_10 = memref.alloc() : memref<4194304xf32> + %memref_11 = gpu.alloc () : memref<4194304xf32> + call 
@fillZeros(%alloc_10, %memref_11) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %alloc_12 = memref.alloc() : memref<4194304xf32> + %memref_13 = gpu.alloc () : memref<4194304xf32> + call @fillZeros(%alloc_12, %memref_13) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %0 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref : memref<4194304xf32>, %memref_1 : memref<4194304xf32>, %memref_7 : memref<4194304xf32>) + %1 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_3 : memref<4194304xf32>, %memref_5 : memref<4194304xf32>, %memref_9 : memref<4194304xf32>) + gpu.launch_func [%0, %1] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_7 : memref<4194304xf32>, %memref_9 : memref<4194304xf32>, %memref_13 : memref<4194304xf32>) + call @cpu_reference(%alloc, %alloc_0, %alloc_6) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_2, %alloc_4, %alloc_8) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_6, %alloc_8, %alloc_12) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + %cast = memref.cast %alloc_12 : memref<4194304xf32> to memref<*xf32> + %cast_14 = memref.cast %memref_13 : memref<4194304xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_14) : (memref<*xf32>, memref<*xf32>) -> () + call @fillZeros(%alloc_6, %memref_7) : (memref<4194304xf32>, memref<4194304xf32>) -> () + call @fillZeros(%alloc_8, %memref_9) : (memref<4194304xf32>, memref<4194304xf32>) -> () + call @fillZeros(%alloc_12, %memref_13) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %2 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref : memref<4194304xf32>, %memref_1 : memref<4194304xf32>, %memref_7 : memref<4194304xf32>) + %3 = gpu.launch_func async [%2] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_3 : memref<4194304xf32>, %memref_7 : memref<4194304xf32>, %memref_9 : memref<4194304xf32>) + gpu.launch_func [%3] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_5 : memref<4194304xf32>, %memref_9 : memref<4194304xf32>, %memref_13 : memref<4194304xf32>) + call @cpu_reference(%alloc, %alloc_0, %alloc_6) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_2, %alloc_6, %alloc_8) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_4, %alloc_8, %alloc_12) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + %cast_15 = memref.cast %alloc_12 : memref<4194304xf32> to memref<*xf32> + %cast_16 = memref.cast %memref_13 : memref<4194304xf32> to memref<*xf32> + call @printAllcloseF32(%cast_15, %cast_16) : (memref<*xf32>, memref<*xf32>) -> () + call @fillZeros(%alloc_6, %memref_7) : (memref<4194304xf32>, memref<4194304xf32>) -> () + call @fillZeros(%alloc_8, %memref_9) : (memref<4194304xf32>, memref<4194304xf32>) -> () + call @fillZeros(%alloc_12, %memref_13) : (memref<4194304xf32>, memref<4194304xf32>) -> () + %4 = gpu.launch_func async @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref : 
memref<4194304xf32>, %memref_1 : memref<4194304xf32>, %memref_7 : memref<4194304xf32>) + %5 = gpu.launch_func async [%4] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_3 : memref<4194304xf32>, %memref_7 : memref<4194304xf32>, %memref_9 : memref<4194304xf32>) + %6 = gpu.launch_func async [%4] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_5 : memref<4194304xf32>, %memref_7 : memref<4194304xf32>, %memref_11 : memref<4194304xf32>) + gpu.launch_func [%5, %6] @eltwiseAdd_kernel::@eltwiseAdd_kernel blocks in (%c8192, %c1, %c1) threads in (%c512, %c1, %c1) args(%memref_9 : memref<4194304xf32>, %memref_11 : memref<4194304xf32>, %memref_13 : memref<4194304xf32>) + call @cpu_reference(%alloc, %alloc_0, %alloc_6) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_2, %alloc_6, %alloc_8) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_4, %alloc_6, %alloc_10) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + call @cpu_reference(%alloc_8, %alloc_10, %alloc_12) : (memref<4194304xf32>, memref<4194304xf32>, memref<4194304xf32>) -> () + %cast_17 = memref.cast %alloc_12 : memref<4194304xf32> to memref<*xf32> + %cast_18 = memref.cast %memref_13 : memref<4194304xf32> to memref<*xf32> + call @printAllcloseF32(%cast_17, %cast_18) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4194304xf32> + memref.dealloc %alloc_0 : memref<4194304xf32> + memref.dealloc %alloc_2 : memref<4194304xf32> + memref.dealloc %alloc_4 : memref<4194304xf32> + memref.dealloc %alloc_6 : memref<4194304xf32> + memref.dealloc %alloc_8 : memref<4194304xf32> + memref.dealloc %alloc_12 : memref<4194304xf32> + gpu.dealloc %memref : memref<4194304xf32> + gpu.dealloc %memref_1 : memref<4194304xf32> + gpu.dealloc %memref_3 : memref<4194304xf32> + gpu.dealloc %memref_5 : memref<4194304xf32> + gpu.dealloc %memref_7 : memref<4194304xf32> + gpu.dealloc %memref_9 : memref<4194304xf32> + gpu.dealloc %memref_13 : memref<4194304xf32> return } - func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @fillResource1DF32(memref<*xf32>, f32) attributes {llvm.emit_c_interface} func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir b/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir index 3318148f6..6291f56a8 100644 --- a/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir +++ b/test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir @@ -1,11 +1,7 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime 
--filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @eltwise_add attributes {gpu.container_module} { memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01> @@ -13,24 +9,24 @@ module @eltwise_add attributes {gpu.container_module} { %c20 = arith.constant 20 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<10x20xbf16> - memref.copy %arg1, %memref : memref<10x20xbf16> to memref<10x20xbf16> - %memref_0 = gpu.alloc host_shared () : memref<10x20xbf16> - memref.copy %arg0, %memref_0 : memref<10x20xbf16> to memref<10x20xbf16> - %memref_1 = gpu.alloc host_shared () : memref<10x20xbf16> + %memref = gpu.alloc () : memref<10x20xbf16> + gpu.memcpy %memref, %arg1 : memref<10x20xbf16>, memref<10x20xbf16> + %memref_0 = gpu.alloc () : memref<10x20xbf16> + gpu.memcpy %memref_0, %arg0 : memref<10x20xbf16>, memref<10x20xbf16> + %memref_1 = gpu.alloc () : memref<10x20xbf16> gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<10x20xbf16>, %memref : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>) %alloc = memref.alloc() : memref<10x20xbf16> - memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16> + gpu.memcpy %alloc, %memref_1 : memref<10x20xbf16>, memref<10x20xbf16> gpu.dealloc %memref_1 : memref<10x20xbf16> gpu.dealloc %memref_0 : memref<10x20xbf16> gpu.dealloc %memref : memref<10x20xbf16> return %alloc : memref<10x20xbf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %block_id_x = gpu.block_id x %block_id_y = gpu.block_id y - %cst = arith.constant 0.5 : bf16 + %cst = arith.constant 5.000000e-01 : bf16 %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16> %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16> %2 = arith.addf %0, %1 : bf16 @@ -49,5 +45,5 @@ module @eltwise_add attributes {gpu.container_module} { call @printMemrefBF16(%cast) : (memref<*xbf16>) -> () return } - func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/Gpu/EltwiseAdd_BF16_single_elem_vector.mlir b/test/Integration/Dialect/Gpu/EltwiseAdd_BF16_single_elem_vector.mlir index bf5c6b026..95d7c9a18 100644 --- a/test/Integration/Dialect/Gpu/EltwiseAdd_BF16_single_elem_vector.mlir +++ b/test/Integration/Dialect/Gpu/EltwiseAdd_BF16_single_elem_vector.mlir @@ -1,11 +1,7 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable 
%imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @eltwise_add attributes {gpu.container_module} { memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01> @@ -13,31 +9,31 @@ module @eltwise_add attributes {gpu.container_module} { %c20 = arith.constant 20 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<10x20xbf16> - memref.copy %arg1, %memref : memref<10x20xbf16> to memref<10x20xbf16> - %memref_0 = gpu.alloc host_shared () : memref<10x20xbf16> - memref.copy %arg0, %memref_0 : memref<10x20xbf16> to memref<10x20xbf16> - %memref_1 = gpu.alloc host_shared () : memref<10x20xbf16> + %memref = gpu.alloc () : memref<10x20xbf16> + gpu.memcpy %memref, %arg1 : memref<10x20xbf16>, memref<10x20xbf16> + %memref_0 = gpu.alloc () : memref<10x20xbf16> + gpu.memcpy %memref_0, %arg0 : memref<10x20xbf16>, memref<10x20xbf16> + %memref_1 = gpu.alloc () : memref<10x20xbf16> gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<10x20xbf16>, %memref : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>) %alloc = memref.alloc() : memref<10x20xbf16> - memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16> + gpu.memcpy %alloc, %memref_1 : memref<10x20xbf16>, memref<10x20xbf16> gpu.dealloc %memref_1 : memref<10x20xbf16> gpu.dealloc %memref_0 : memref<10x20xbf16> gpu.dealloc %memref : memref<10x20xbf16> return %alloc : memref<10x20xbf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %block_id_x = gpu.block_id x %block_id_y = gpu.block_id y - %cst = arith.constant dense<0.5> : vector<1xbf16> + %cst = arith.constant dense<5.000000e-01> : vector<1xbf16> %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16> %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16> - %vec_0 = vector.from_elements %0 : vector<1xbf16> - %vec_1 = vector.from_elements %1 : vector<1xbf16> - %2 = arith.addf %vec_0, %vec_1 : vector<1xbf16> - %3 = arith.addf %2, %cst : vector<1xbf16> - vector.store %3, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>, vector<1xbf16> + %2 = vector.from_elements %0 : vector<1xbf16> + %3 = vector.from_elements %1 : vector<1xbf16> + %4 = arith.addf %2, %3 : vector<1xbf16> + %5 = arith.addf %4, %cst : vector<1xbf16> + vector.store %5, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>, vector<1xbf16> gpu.return } } @@ -51,5 +47,5 @@ module @eltwise_add attributes {gpu.container_module} { call @printMemrefBF16(%cast) : (memref<*xbf16>) -> () return } - func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/Gpu/ceil_floor_BF16.mlir 
b/test/Integration/Dialect/Gpu/ceil_floor_BF16.mlir index f50f123b9..7188294a0 100644 --- a/test/Integration/Dialect/Gpu/ceil_floor_BF16.mlir +++ b/test/Integration/Dialect/Gpu/ceil_floor_BF16.mlir @@ -1,11 +1,7 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ +// RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @eltwise_add attributes {gpu.container_module} { memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01> @@ -13,20 +9,20 @@ module @eltwise_add attributes {gpu.container_module} { %c20 = arith.constant 20 : index %c10 = arith.constant 10 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<10x20xbf16> - memref.copy %arg1, %memref : memref<10x20xbf16> to memref<10x20xbf16> - %memref_0 = gpu.alloc host_shared () : memref<10x20xbf16> - memref.copy %arg0, %memref_0 : memref<10x20xbf16> to memref<10x20xbf16> - %memref_1 = gpu.alloc host_shared () : memref<10x20xbf16> + %memref = gpu.alloc () : memref<10x20xbf16> + gpu.memcpy %memref, %arg1 : memref<10x20xbf16>, memref<10x20xbf16> + %memref_0 = gpu.alloc () : memref<10x20xbf16> + gpu.memcpy %memref_0, %arg0 : memref<10x20xbf16>, memref<10x20xbf16> + %memref_1 = gpu.alloc () : memref<10x20xbf16> gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<10x20xbf16>, %memref : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>) %alloc = memref.alloc() : memref<10x20xbf16> - memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16> + gpu.memcpy %alloc, %memref_1 : memref<10x20xbf16>, memref<10x20xbf16> gpu.dealloc %memref_1 : memref<10x20xbf16> gpu.dealloc %memref_0 : memref<10x20xbf16> gpu.dealloc %memref : memref<10x20xbf16> return %alloc : memref<10x20xbf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %block_id_x = gpu.block_id x %block_id_y = gpu.block_id y @@ -49,5 +45,5 @@ module @eltwise_add attributes {gpu.container_module} { call @printMemrefBF16(%cast) : (memref<*xbf16>) -> () return } - func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/Gpu/gpu-to-llvm.pp b/test/Integration/Dialect/Gpu/gpu-to-llvm.pp index 11f17cbbf..db39a1b24 100644 
--- a/test/Integration/Dialect/Gpu/gpu-to-llvm.pp +++ b/test/Integration/Dialect/Gpu/gpu-to-llvm.pp @@ -1,31 +1,26 @@ -// gpu dialect with intel intrinsic functions (func dialect) to +// gpu dialect to // llvm dialect (for host code) and // spirv dialect (for device code) lowering pipeline. // Ready for imex runner starting from GPU dialect. builtin.module( xegpu-vector-linearize + canonicalize cse - gpu.module(convert-math-to-vc{enable-high-precision-interim-calculation=true}) reconcile-unrealized-casts - bf16-to-gpu - imex-convert-gpu-to-spirv - spirv.module(spirv-lower-abi-attrs - spirv-update-vce) + gpu.module(math-extend-to-supported-types{target-type=f32}) + gpu.module(arith-emulate-unsupported-floats{source-types=bf16 target-type=f32}) + spirv-attach-target{ver=v1.0 caps=Addresses,BFloat16TypeKHR,Float16Buffer,Int64,Int16,Int8,Kernel,Linkage,Vector16,GenericPointer,Groups,Float16,Float64,AtomicFloat32AddEXT,ExpectAssumeKHR,SubgroupDispatch,VectorComputeINTEL,VectorAnyINTEL,Bfloat16ConversionINTEL exts=SPV_EXT_shader_atomic_float_add,SPV_KHR_bfloat16,SPV_KHR_expect_assume,SPV_INTEL_vector_compute,SPV_INTEL_bfloat16_conversion} + imex-convert-to-spirv{use-64bit-index=true} + gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce)) func.func(llvm-request-c-wrappers) - serialize-spirv convert-vector-to-scf - convert-gpu-to-gpux convert-scf-to-cf + func.func(gpu-async-region) expand-strided-metadata + gpu-to-llvm{use-bare-pointers-for-kernels=true} finalize-memref-to-llvm - convert-cf-to-llvm - convert-vector-to-llvm - convert-index-to-llvm - convert-arith-to-llvm - convert-func-to-llvm - convert-math-to-llvm - convert-gpux-to-llvm - convert-index-to-llvm + convert-to-llvm + gpu-module-to-binary lower-affine reconcile-unrealized-casts) // End diff --git a/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir b/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir index 0ea9f5c37..95df1ba81 100644 --- a/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeGPU/SG/gemm_4kx4kx4k_f16_f16_f16.mlir @@ -29,8 +29,8 @@ module @gemm attributes {gpu.container_module} { return %C : memref<4096x4096xf16> } - gpu.module @test_kernel { - gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) kernel { + gpu.module @test_kernel { + gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) kernel { // constants %c256 = arith.constant 256 : index %c512 = arith.constant 512 : index diff --git a/test/Integration/Dialect/XeGPU/SG/xegpu-to-llvm.pp b/test/Integration/Dialect/XeGPU/SG/xegpu-to-llvm.pp index 91d2c71bb..795bdf430 100644 --- a/test/Integration/Dialect/XeGPU/SG/xegpu-to-llvm.pp +++ b/test/Integration/Dialect/XeGPU/SG/xegpu-to-llvm.pp @@ -9,6 +9,8 @@ loop-invariant-code-motion cse xegpu-vector-linearize + canonicalize + cse convert-xegpu-to-xevm convert-gpu-to-llvm-spv{use-64bit-index=true} convert-xevm-to-llvm diff --git a/test/Integration/Dialect/XeGPU/VC/atomic_rmw.mlir b/test/Integration/Dialect/XeGPU/VC/atomic_rmw.mlir index 819a19510..1a1b2fc70 100644 --- a/test/Integration/Dialect/XeGPU/VC/atomic_rmw.mlir +++ b/test/Integration/Dialect/XeGPU/VC/atomic_rmw.mlir @@ -1,51 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#scatter = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%a: memref<16xf32>) -> memref<16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<16xf32>) -> memref<16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %a_gpu = gpu.alloc host_shared () : memref<16xf32> - memref.copy %a, %a_gpu : memref<16xf32> to memref<16xf32> - %out = gpu.alloc host_shared () : memref<16xf32> - gpu.launch_func @test_kernel::@test_atomic_rmw blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%a_gpu: memref<16xf32>, %out : memref<16xf32>) - return %a_gpu : memref<16xf32> + %memref = gpu.alloc () : memref<16xf32> + gpu.memcpy %memref, %arg0 : memref<16xf32>, memref<16xf32> + %memref_0 = gpu.alloc () : memref<16xf32> + gpu.launch_func @test_kernel::@test_atomic_rmw blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf32>, %memref_0 : memref<16xf32>) + %alloc = memref.alloc() : memref<16xf32> + gpu.memcpy %alloc, %memref : memref<16xf32>, memref<16xf32> + gpu.dealloc %memref : memref<16xf32> + return %alloc : memref<16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_atomic_rmw(%input: memref<16xf32>, %mem: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf32> - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> - %in_tdesc = xegpu.create_tdesc %input, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter> - %atomic_rmw = xegpu.atomic_rmw addf %in_tdesc, %mask, %cst : !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1>, vector<16xf32> -> vector<16xf32> - %out_tdesc = xegpu.create_tdesc %mem, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter> - xegpu.store %atomic_rmw, %out_tdesc, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1> + gpu.module @test_kernel { + gpu.func @test_atomic_rmw(%arg0: memref<16xf32>, %arg1: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01, 1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01]> : vector<16xf32> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.atomic_rmw addf %0, %cst_0, %cst : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>, vector<16xf32> -> vector<16xf32> + %2 = xegpu.create_tdesc %arg1, %cst_1 : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + xegpu.store %1, %2, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %a = memref.alloc() : memref<16xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index + %cst = arith.constant 1.000000e+00 : f32 %c16 = arith.constant 16 : index - %c1_f32 = arith.constant 1.0 : f32 - scf.for %i = %c0 to %c16 step %c1 { - memref.store %c1_f32, %a[%i] : memref<16xf32> + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<16xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + memref.store %cst, %alloc[%arg0] : memref<16xf32> } - - %B = call @test(%a) : (memref<16xf32>) -> memref<16xf32> - %cast = memref.cast %B : memref<16xf32> to memref<*xf32> //CHECK: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + %0 = call @test(%alloc) : (memref<16xf32>) -> memref<16xf32> + %cast = memref.cast %0 : memref<16xf32> to memref<*xf32> call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_0_fp32.mlir b/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_0_fp32.mlir index 2e42128e0..3328e41a0 100644 --- a/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_0_fp32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_0_fp32.mlir @@ -1,91 +1,76 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { - func.func @reduce_test(%a: memref<16x512xf32>) -> memref<512xf32> attributes {llvm.emit_c_interface} { + func.func @reduce_test(%arg0: memref<16x512xf32>) -> memref<512xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - - %a_gpu = gpu.alloc host_shared () : memref<16x512xf32> - memref.copy %a, %a_gpu : memref<16x512xf32> to memref<16x512xf32> - %b_gpu = gpu.alloc host_shared () : memref<512xf32> - - gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c1, %c32, %c1) threads in (%c1, %c1, %c1) args(%a_gpu : memref<16x512xf32>, %b_gpu : memref<512xf32>) - - gpu.dealloc %a_gpu : memref<16x512xf32> - return %b_gpu : 
memref<512xf32>
+    %memref = gpu.alloc () : memref<16x512xf32>
+    gpu.memcpy %memref, %arg0 : memref<16x512xf32>, memref<16x512xf32>
+    %memref_0 = gpu.alloc () : memref<512xf32>
+    gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c1, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x512xf32>, %memref_0 : memref<512xf32>)
+    gpu.dealloc %memref : memref<16x512xf32>
+    %alloc = memref.alloc() : memref<512xf32>
+    gpu.memcpy %alloc, %memref_0 : memref<512xf32>, memref<512xf32>
+    gpu.dealloc %memref_0 : memref<512xf32>
+    return %alloc : memref<512xf32>
   }
-
-  gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
     // the kernel is a 16x32 block reduction. each thread is assigned with a 16x32 block, and do reduction along dim-0 independently.
-    gpu.func @reduce_dim_1(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+  gpu.module @kernel {
+    gpu.func @reduce_dim_1(%arg0: memref<16x512xf32>, %arg1: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
       %c1 = arith.constant 1 : index
       %c16 = arith.constant 16 : index
       %c32 = arith.constant 32 : index
-      %acc = arith.constant dense<0.0> : vector<16xf32>
-
-      %block_id_x = gpu.block_id x
-      %block_id_y = gpu.block_id y
-
-      %m = arith.muli %block_id_x, %c16 : index
-      %n = arith.muli %block_id_y, %c16 : index
-      %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<16x512xf32> -> !xegpu.tensor_desc<16x16xf32>
-      %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32>
-
-      %2 = vector.multi_reduction <add>, %1, %acc [0]: vector<16x16xf32> to vector<16xf32>
-
-      %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<16xf32>
-      xegpu.store_nd %2, %3: vector<16xf32>, !xegpu.tensor_desc<16xf32>
+      %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %0 = arith.muli %block_id_x, %c16 : index
+      %1 = arith.muli %block_id_y, %c16 : index
+      %2 = xegpu.create_nd_tdesc %arg0[%0, %1] : memref<16x512xf32> -> !xegpu.tensor_desc<16x16xf32>
+      %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32>
+      %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf32> to vector<16xf32>
+      %5 = xegpu.create_nd_tdesc %arg1[%1] : memref<512xf32> -> !xegpu.tensor_desc<16xf32>
+      xegpu.store_nd %4, %5 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
       gpu.return
     }
   }
-
   func.func @main() attributes {llvm.emit_c_interface} {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
     %c16 = arith.constant 16 : index
-    %c32 = arith.constant 32 : index
     %c512 = arith.constant 512 : index
-    %c0_f32 = arith.constant 0.0 : f32
-    %c32_f32 = arith.constant 32.0 : f32
-    %c100_f32 = arith.constant 100.0 : f32
-    %a = memref.alloc() : memref<16x512xf32>
-    %b_ref = memref.alloc() : memref<512xf32>
-    // intialize matrix A ; A[i, j] = j
-    scf.for %i = %c0 to %c16 step %c1 {
-      scf.for %j = %c0 to %c512 step %c1 {
-        %t = index.castu %j : index to i16
-        %u = arith.uitofp %t : i16 to f32
-        %v = arith.divf %u, %c100_f32 : f32
-        memref.store %v, %a[%i, %j] : memref<16x512xf32>
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant 1.000000e+02 : f32
+    %alloc = memref.alloc() : memref<16x512xf32>
+    %alloc_1 = memref.alloc() : memref<512xf32>
+    scf.for %arg0 = %c0 to %c16 step %c1 {
+      scf.for %arg1 = %c0 to %c512 step %c1 {
+        %1 = index.castu %arg1 : index to i16
+        %2 =
arith.uitofp %1 : i16 to f32 + %3 = arith.divf %2, %cst_0 : f32 + memref.store %3, %alloc[%arg0, %arg1] : memref<16x512xf32> } } - - scf.for %j = %c0 to %c512 step %c1 { - %sum = scf.for %i = %c0 to %c16 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %val = memref.load %a[%i, %j] : memref<16x512xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c512 step %c1 { + %1 = scf.for %arg1 = %c0 to %c16 step %c1 iter_args(%arg2 = %cst) -> (f32) { + %2 = memref.load %alloc[%arg1, %arg0] : memref<16x512xf32> + %3 = arith.addf %arg2, %2 : f32 + scf.yield %3 : f32 } - memref.store %sum, %b_ref[%j] : memref<512xf32> + memref.store %1, %alloc_1[%arg0] : memref<512xf32> } - - %b = call @reduce_test(%a) : (memref<16x512xf32>) -> memref<512xf32> - %cast_b = memref.cast %b : memref<512xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<512xf32> to memref<*xf32> // call @printMemrefF32(%cast_b): (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<16x512xf32> - memref.dealloc %b_ref : memref<512xf32> + %0 = call @reduce_test(%alloc) : (memref<16x512xf32>) -> memref<512xf32> + %cast = memref.cast %0 : memref<512xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<512xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<16x512xf32> + memref.dealloc %alloc_1 : memref<512xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_1_fp32.mlir b/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_1_fp32.mlir index 380950565..d7ddf56ba 100644 --- a/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/block_reduce_dim_1_fp32.mlir @@ -1,95 +1,80 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { - func.func @reduce_test(%a: memref<16x512xf32>) -> memref<512xf32> attributes {llvm.emit_c_interface} { + func.func @reduce_test(%arg0: memref<16x512xf32>) -> memref<512xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - - %a_gpu = gpu.alloc host_shared () : memref<16x512xf32> - memref.copy %a, %a_gpu : memref<16x512xf32> to memref<16x512xf32> - %b_gpu = gpu.alloc host_shared () : memref<512xf32> - - gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c1, %c32, %c1) threads in (%c1, %c1, %c1) args(%a_gpu : memref<16x512xf32>, %b_gpu : 
memref<512xf32>)
-
-    gpu.dealloc %a_gpu : memref<16x512xf32>
-    return %b_gpu : memref<512xf32>
+    %memref = gpu.alloc () : memref<16x512xf32>
+    gpu.memcpy %memref, %arg0 : memref<16x512xf32>, memref<16x512xf32>
+    %memref_0 = gpu.alloc () : memref<512xf32>
+    gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c1, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x512xf32>, %memref_0 : memref<512xf32>)
+    gpu.dealloc %memref : memref<16x512xf32>
+    %alloc = memref.alloc() : memref<512xf32>
+    gpu.memcpy %alloc, %memref_0 : memref<512xf32>, memref<512xf32>
+    gpu.dealloc %memref_0 : memref<512xf32>
+    return %alloc : memref<512xf32>
   }
-
-  gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
     // the kernel is a 16x32 block reduction. each thread is assigned with a 16x32 block, and do reduction along dim-0 independently.
-    gpu.func @reduce_dim_1(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+  gpu.module @kernel {
+    gpu.func @reduce_dim_1(%arg0: memref<16x512xf32>, %arg1: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %c1 = arith.constant 1 : index
      %c16 = arith.constant 16 : index
      %c32 = arith.constant 32 : index
-      %acc = arith.constant dense<0.0> : vector<16xf32>
-
-      %block_id_x = gpu.block_id x
-      %block_id_y = gpu.block_id y
-
-      %m = arith.muli %block_id_x, %c16 : index
-      %n = arith.muli %block_id_y, %c16 : index
-      %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<16x512xf32> -> !xegpu.tensor_desc<16x16xf32>
-      %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32>
-      %2 = vector.multi_reduction <add>, %1, %acc [1]: vector<16x16xf32> to vector<16xf32>
-      %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<16xf32>
-      xegpu.store_nd %2, %3: vector<16xf32>, !xegpu.tensor_desc<16xf32>
+      %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %0 = arith.muli %block_id_x, %c16 : index
+      %1 = arith.muli %block_id_y, %c16 : index
+      %2 = xegpu.create_nd_tdesc %arg0[%0, %1] : memref<16x512xf32> -> !xegpu.tensor_desc<16x16xf32>
+      %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32>
+      %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf32> to vector<16xf32>
+      %5 = xegpu.create_nd_tdesc %arg1[%1] : memref<512xf32> -> !xegpu.tensor_desc<16xf32>
+      xegpu.store_nd %4, %5 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
      gpu.return
    }
  }
-
  func.func @main() attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
-    %c32 = arith.constant 32 : index
    %c512 = arith.constant 512 : index
-    %c0_f32 = arith.constant 0.0 : f32
-    %c32_f32 = arith.constant 32.0 : f32
-    %c100_f32 = arith.constant 100.0 : f32
-    %a = memref.alloc() : memref<16x512xf32>
-    %b_ref = memref.alloc() : memref<512xf32>
-    // intialize matrix A ; A[i, j] = j
-    scf.for %i = %c0 to %c16 step %c1 {
-      scf.for %j = %c0 to %c512 step %c1 {
-        %t = index.castu %j : index to i16
-        %u = arith.uitofp %t : i16 to f32
-        %v = arith.divf %u, %c100_f32 : f32
-        memref.store %v, %a[%i, %j] : memref<16x512xf32>
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant 1.000000e+02 : f32
+    %alloc = memref.alloc() : memref<16x512xf32>
+    %alloc_1 = memref.alloc() : memref<512xf32>
+    scf.for %arg0 = %c0 to %c16 step %c1 {
+      scf.for %arg1
= %c0 to %c512 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f32 + %3 = arith.divf %2, %cst_0 : f32 + memref.store %3, %alloc[%arg0, %arg1] : memref<16x512xf32> } } - - scf.for %j = %c0 to %c512 step %c16 { - scf.for %i = %c0 to %c16 step %c1 { - %sum = scf.for %k = %c0 to %c16 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %idx = arith.addi %j, %k : index - %val = memref.load %a[%i, %idx] : memref<16x512xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c512 step %c16 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %cst) -> (f32) { + %3 = arith.addi %arg0, %arg2 : index + %4 = memref.load %alloc[%arg1, %3] : memref<16x512xf32> + %5 = arith.addf %arg3, %4 : f32 + scf.yield %5 : f32 } - %m = arith.addi %j, %i : index - memref.store %sum, %b_ref[%m] : memref<512xf32> + %2 = arith.addi %arg0, %arg1 : index + memref.store %1, %alloc_1[%2] : memref<512xf32> } } - - - - %b = call @reduce_test(%a) : (memref<16x512xf32>) -> memref<512xf32> - %cast_b = memref.cast %b : memref<512xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<512xf32> to memref<*xf32> // call @printMemrefF32(%cast_b): (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<16x512xf32> - memref.dealloc %b_ref : memref<512xf32> + %0 = call @reduce_test(%alloc) : (memref<16x512xf32>) -> memref<512xf32> + %cast = memref.cast %0 : memref<512xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<512xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<16x512xf32> + memref.dealloc %alloc_1 : memref<512xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/ceil_floor_f32.mlir b/test/Integration/Dialect/XeGPU/VC/ceil_floor_f32.mlir index bc5b3428a..183eb70a0 100644 --- a/test/Integration/Dialect/XeGPU/VC/ceil_floor_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/ceil_floor_f32.mlir @@ -1,32 +1,31 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<0.0> + memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index - - %memref = 
gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %arg0, %memref : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %arg1, %memref_1 : memref<8x16xf32> to memref<8x16xf32> - %memref_2 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %memref_2 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<8x16xf32>, memref<8x16xf32> + %memref_1 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>, %memref_1 : memref<8x16xf32>) gpu.dealloc %memref : memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x16xf32>, memref<8x16xf32> gpu.dealloc %memref_1 : memref<8x16xf32> - return %memref_2 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %thread_id_x = gpu.thread_id x + %thread_id_x = gpu.thread_id x cf.br ^bb1 - ^bb1: + ^bb1: // pred: ^bb0 %0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> %2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32> @@ -40,42 +39,37 @@ module @gemm attributes {gpu.container_module} { } } func.func @main() attributes {llvm.emit_c_interface} { - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - %A = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - - %B = memref.alloc() : memref<8x16xf32> - %B_random = memref.cast %B : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - // calculate the result C matrix - %c16 = arith.constant 16 : index - %c8 = arith.constant 8 : index - %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %ref = memref.alloc() : memref<8x16xf32> - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %a = memref.load %A[%i, %j] : memref<8x16xf32> - %b = memref.load %B[%i, %j] : memref<8x16xf32> - %a_ceiled = math.ceil %a : f32 - %b_floored = math.floor %b : f32 - %c = arith.addf %a_ceiled, %b_floored : f32 - memref.store %c, %ref[%i, %j] : memref<8x16xf32> + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %false = arith.constant false + %cst = arith.constant -5.000000e-01 : f32 + %cst_0 = arith.constant 5.000000e-01 : f32 + %alloc = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> 
+ call @fillResource1DRandomF32(%cast, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %alloc_1 = memref.alloc() : memref<8x16xf32> + %cast_2 = memref.cast %alloc_1 : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast_2, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %alloc_3 = memref.alloc() : memref<8x16xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : memref<8x16xf32> + %2 = memref.load %alloc_1[%arg0, %arg1] : memref<8x16xf32> + %3 = math.ceil %1 : f32 + %4 = math.floor %2 : f32 + %5 = arith.addf %3, %4 : f32 + memref.store %5, %alloc_3[%arg0, %arg1] : memref<8x16xf32> } } - - %C = call @test(%A, %B) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> - - %C_cast = memref.cast %C : memref<8x16xf32> to memref<*xf32> - %ref_cast = memref.cast %ref : memref<8x16xf32> to memref<*xf32> // call @printMemrefF32(%C_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%ref_cast, %C_cast) : (memref<*xf32>, memref<*xf32>) -> () + %0 = call @test(%alloc, %alloc_1) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> + %cast_4 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_3 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_5, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/dynamic_memref.vc.mlir b/test/Integration/Dialect/XeGPU/VC/dynamic_memref.vc.mlir index 1e7ff9e1d..e28d544aa 100644 --- a/test/Integration/Dialect/XeGPU/VC/dynamic_memref.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/dynamic_memref.vc.mlir @@ -1,52 +1,49 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref_0 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref_0 : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - %memref_0_cast = memref.cast %memref_0 : memref<8x16xf32> to memref - %memref_1_cast = memref.cast %memref_1 : memref<8x16xf32> to memref - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref, %memref_1_cast : memref) - gpu.dealloc %memref_0 : memref<8x16xf32> - return %memref_1 : memref<8x16xf32> + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, 
%arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + %cast = memref.cast %memref : memref<8x16xf32> to memref + %cast_1 = memref.cast %memref_0 : memref<8x16xf32> to memref + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%cast : memref, %cast_1 : memref) + gpu.dealloc %memref : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0 : memref, %arg1: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref, %arg1: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %1 = xegpu.create_nd_tdesc %arg0[0, 0], shape: [%c8, %c16], strides: [%c16, %c1] : memref -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - %6 = xegpu.create_nd_tdesc %arg1[0, 0], shape: [%c8, %c16], strides: [%c16, %c1] : memref -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %2, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[0, 0], shape : [%c8, %c16], strides : [%c16, %c1] : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg1[0, 0], shape : [%c8, %c16], strides : [%c16, %c1] : memref -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %1, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32> - %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32> - %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32> // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<8x16xf32> + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst_0, %cst, %false) : (memref<*xf32>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x16xf32>) -> memref<8x16xf32> + %cast_1 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_2, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x16xf32> return } func.func private 
@printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_1d.mlir b/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_1d.mlir index 362794b09..da5574d02 100644 --- a/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_1d.mlir +++ b/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_1d.mlir @@ -1,68 +1,62 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__Aconstant_8x32xf32 : memref<8x32xf32> = dense<1.0> - memref.global "private" @__Bconstant_8x32xf32 : memref<8x32xf32> = dense<2.0> + memref.global "private" @__Aconstant_8x32xf32 : memref<8x32xf32> = dense<1.000000e+00> + memref.global "private" @__Bconstant_8x32xf32 : memref<8x32xf32> = dense<2.000000e+00> func.func @test(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf32>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - %c0_f32 = arith.constant 0.0 : f32 - - %A = gpu.alloc host_shared () : memref<8x32xf32> - memref.copy %arg0, %A : memref<8x32xf32> to memref<8x32xf32> - %B = gpu.alloc host_shared () : memref<8x32xf32> - memref.copy %arg1, %B : memref<8x32xf32> to memref<8x32xf32> - - %C = gpu.alloc host_shared () : memref<8x32xf32> - %C_unranked = memref.cast %C : memref<8x32xf32> to memref<*xf32> - call @fillResource1DF32(%C_unranked, %c0_f32) : (memref<*xf32>, f32) -> () - - %A_strided_dynamic = memref.subview %A[%c0, %c0][%c8, %c16][%c1, %c1] : memref<8x32xf32> to memref> - %B_strided_dynamic = memref.subview %B[%c0, %c0][%c8, %c16][%c1, %c1] : memref<8x32xf32> to memref> - %C_strided_dynamic = memref.subview %C[%c0, %c0][%c8, %c16][%c1, %c1] : memref<8x32xf32> to memref> - + %cst = arith.constant 0.000000e+00 : f32 + %memref = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref, %arg0 : memref<8x32xf32>, memref<8x32xf32> + %memref_0 = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref_0, %arg1 : memref<8x32xf32>, memref<8x32xf32> + %meref_host = memref.alloc() : memref<8x32xf32> + %cast_host = memref.cast %meref_host : memref<8x32xf32> to memref<*xf32> + call @fillResource1DF32(%cast_host, %cst) : (memref<*xf32>, f32) -> () + %memref_1 = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref_1, %meref_host : memref<8x32xf32>, memref<8x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%A_strided_dynamic : memref>, %B_strided_dynamic : memref>, %C_strided_dynamic : memref>, %c8 : index, %c16 : index, %c32 : index, %c1 : index) - gpu.dealloc %A : memref<8x32xf32> - 
gpu.dealloc %B : memref<8x32xf32> - return %C : memref<8x32xf32> + %subview = memref.subview %memref[%c0, %c0] [%c8, %c16] [%c1, %c1] : memref<8x32xf32> to memref> + %subview_2 = memref.subview %memref_0[%c0, %c0] [%c8, %c16] [%c1, %c1] : memref<8x32xf32> to memref> + %subview_3 = memref.subview %memref_1[%c0, %c0] [%c8, %c16] [%c1, %c1] : memref<8x32xf32> to memref> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%subview : memref>, %subview_2 : memref>, %subview_3 : memref>, %c8 : index, %c16 : index, %c32 : index, %c1 : index) + memref.dealloc %meref_host : memref<8x32xf32> + gpu.dealloc %memref : memref<8x32xf32> + gpu.dealloc %memref_0 : memref<8x32xf32> + %alloc = memref.alloc() : memref<8x32xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x32xf32>, memref<8x32xf32> + gpu.dealloc %memref_1 : memref<8x32xf32> + return %alloc : memref<8x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref>, %arg1: memref>, %arg2: memref>, %shape_x : index, %shape_y : index, %stride_x : index, %stride_y : index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %thread_id_x = gpu.thread_id x - - %0 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0], shape: [%shape_x, %shape_y], strides: [%stride_x, %stride_y] : memref> -> !xegpu.tensor_desc<16xf32> + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref>, %arg1: memref>, %arg2: memref>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %thread_id_x = gpu.thread_id x + %0 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0], shape : [%arg3, %arg4], strides : [%arg5, %arg6] : memref> -> !xegpu.tensor_desc<16xf32> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> - %2 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0], shape: [%shape_x, %shape_y], strides: [%stride_x, %stride_y] : memref> -> !xegpu.tensor_desc<16xf32> + %2 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0], shape : [%arg3, %arg4], strides : [%arg5, %arg6] : memref> -> !xegpu.tensor_desc<16xf32> %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> %4 = arith.addf %3, %1 : vector<16xf32> - %5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0], shape: [%shape_x, %shape_y], strides: [%stride_x, %stride_y] : memref> -> !xegpu.tensor_desc<16xf32> + %5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0], shape : [%arg3, %arg4], strides : [%arg5, %arg6] : memref> -> !xegpu.tensor_desc<16xf32> xegpu.store_nd %4, %5 : vector<16xf32>, !xegpu.tensor_desc<16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - // Allocate/get regular row major memrefs - %A = memref.get_global @__Aconstant_8x32xf32 : memref<8x32xf32> - %B = memref.get_global @__Bconstant_8x32xf32 : memref<8x32xf32> - - %result = call @test(%A, %B) : (memref<8x32xf32>, memref<8x32xf32>) -> memref<8x32xf32> - - %result_cast = memref.cast %result : memref<8x32xf32> to memref<*xf32> - call @printMemrefF32(%result_cast) : (memref<*xf32>) -> () // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} // CHECK-NEXT:[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - + %0 = memref.get_global @__Aconstant_8x32xf32 : memref<8x32xf32> + %1 = memref.get_global @__Bconstant_8x32xf32 : memref<8x32xf32> + %2 = call @test(%0, %1) : 
(memref<8x32xf32>, memref<8x32xf32>) -> memref<8x32xf32> + %cast = memref.cast %2 : memref<8x32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } func.func private @fillResource1DF32(memref<*xf32>, f32) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_2d.mlir b/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_2d.mlir index a80dd4a50..1d56aa769 100644 --- a/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_2d.mlir +++ b/test/Integration/Dialect/XeGPU/VC/dynamic_strided_memref_2d.mlir @@ -1,14 +1,11 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__Aconstant_32x64xf16 : memref<32x64xf16> = dense<1.0> - memref.global "private" @__Bconstant_32x64xf16 : memref<32x64xf16> = dense<2.0> + memref.global "private" @__Aconstant_32x64xf16 : memref<32x64xf16> = dense<1.000000e+00> + memref.global "private" @__Bconstant_32x64xf16 : memref<32x64xf16> = dense<2.000000e+00> func.func @test(%arg0: memref<32x64xf16>, %arg1: memref<32x64xf16>) -> memref<32x64xf32> attributes {llvm.emit_c_interface} { %c64 = arith.constant 64 : index %c32 = arith.constant 32 : index @@ -16,72 +13,63 @@ module @gemm attributes {gpu.container_module} { %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %c0_f32 = arith.constant 0.0 : f32 - - - %A = gpu.alloc host_shared () : memref<32x64xf16> - memref.copy %arg0, %A : memref<32x64xf16> to memref<32x64xf16> - %B = gpu.alloc host_shared () : memref<32x64xf16> - memref.copy %arg1, %B : memref<32x64xf16> to memref<32x64xf16> - - %C = gpu.alloc host_shared () : memref<32x64xf32> - %C_unranked = memref.cast %C : memref<32x64xf32> to memref<*xf32> - call @fillResource1DF32(%C_unranked, %c0_f32) : (memref<*xf32>, f32) -> () - - %A_strided_dynamic = memref.subview %A[%c0, %c0][%c32, %c32][%c1, %c1] : memref<32x64xf16> to memref> - %B_strided_dynamic = memref.subview %B[%c0, %c0][%c32, %c32][%c1, %c1] : memref<32x64xf16> to memref> - %C_strided_dynamic = memref.subview %C[%c0, %c0][%c32, %c32][%c1, %c1] : memref<32x64xf32> to memref> - - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%A_strided_dynamic : memref>, %B_strided_dynamic : memref>, %C_strided_dynamic : memref>, %c32 : index, %c32 : index, %c64 : index, %c1 : index) - gpu.dealloc %A : memref<32x64xf16> - gpu.dealloc %B : memref<32x64xf16> - return %C : memref<32x64xf32> + %cst = arith.constant 0.000000e+00 : f32 + %memref = gpu.alloc () : memref<32x64xf16> + gpu.memcpy %memref, %arg0 : memref<32x64xf16>, memref<32x64xf16> + %memref_0 = gpu.alloc () : memref<32x64xf16> + 
gpu.memcpy %memref_0, %arg1 : memref<32x64xf16>, memref<32x64xf16> + %meref_host = memref.alloc() : memref<32x64xf32> + %cast_host = memref.cast %meref_host : memref<32x64xf32> to memref<*xf32> + call @fillResource1DF32(%cast_host, %cst) : (memref<*xf32>, f32) -> () + %memref_1 = gpu.alloc () : memref<32x64xf32> + gpu.memcpy %memref_1, %meref_host : memref<32x64xf32>, memref<32x64xf32> + %subview = memref.subview %memref[%c0, %c0] [%c32, %c32] [%c1, %c1] : memref<32x64xf16> to memref> + %subview_2 = memref.subview %memref_0[%c0, %c0] [%c32, %c32] [%c1, %c1] : memref<32x64xf16> to memref> + %subview_3 = memref.subview %memref_1[%c0, %c0] [%c32, %c32] [%c1, %c1] : memref<32x64xf32> to memref> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%subview : memref>, %subview_2 : memref>, %subview_3 : memref>, %c32 : index, %c32 : index, %c64 : index, %c1 : index) + memref.dealloc %meref_host : memref<32x64xf32> + gpu.dealloc %memref : memref<32x64xf16> + gpu.dealloc %memref_0 : memref<32x64xf16> + %alloc = memref.alloc() : memref<32x64xf32> + gpu.memcpy %alloc, %memref_1 : memref<32x64xf32>, memref<32x64xf32> + gpu.dealloc %memref_1 : memref<32x64xf32> + return %alloc : memref<32x64xf32> } - -gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref>, %B: memref>, %C: memref>, %shape_x : index, %shape_y : index, %stride_x : index, %stride_y : index) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref>, %arg1: memref>, %arg2: memref>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c8 = arith.constant 8 : index - %cst = arith.constant dense<1.0> : vector<8x16xf16> - %0 = gpu.block_id x - %1 = gpu.block_id y - - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - - %4 = xegpu.create_nd_tdesc %C[%2, %3], shape: [%shape_x, %shape_y], strides: [%stride_x, %stride_y] : memref> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - - %6 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %A0 = xegpu.create_nd_tdesc %A[%2, %arg3], shape: [%shape_x, %shape_y], strides: [%stride_x, %stride_y] : memref> -> !xegpu.tensor_desc<8x16xf16> - %A0_val = xegpu.load_nd %A0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - - %B0 = xegpu.create_nd_tdesc %B[%arg3, %3], shape: [%shape_x, %shape_y], strides: [%stride_x, %stride_y] : memref> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %B0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - - %A0_preop = arith.addf %A0_val, %cst : vector<8x16xf16> - - %dpas0 = xegpu.dpas %A0_preop, %B0_val , %arg4: vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %dpas0 : vector<8x16xf32> + %cst = arith.constant dense<1.000000e+00> : vector<8x16xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, 
%1], shape : [%arg3, %arg4], strides : [%arg5, %arg6] : memref> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg7 = %c0 to %c32 step %c16 iter_args(%arg8 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg7], shape : [%arg3, %arg4], strides : [%arg5, %arg6] : memref> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %7 = xegpu.create_nd_tdesc %arg1[%arg7, %1], shape : [%arg3, %arg4], strides : [%arg5, %arg6] : memref> -> !xegpu.tensor_desc<16x16xf16> + %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = arith.addf %6, %cst : vector<8x16xf16> + %10 = xegpu.dpas %9, %8, %arg8 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %10 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { // Allocate/get regular row major memrefs - %A = memref.get_global @__Aconstant_32x64xf16 : memref<32x64xf16> - %B = memref.get_global @__Bconstant_32x64xf16 : memref<32x64xf16> - - %result = call @test(%A, %B) : (memref<32x64xf16>, memref<32x64xf16>) -> memref<32x64xf32> - %result_cast = memref.cast %result : memref<32x64xf32> to memref<*xf32> - call @printMemrefF32(%result_cast) : (memref<*xf32>) -> () // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} // CHECK-NEXT:[128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + %0 = memref.get_global @__Aconstant_32x64xf16 : memref<32x64xf16> + %1 = memref.get_global @__Bconstant_32x64xf16 : memref<32x64xf16> + %2 = call @test(%0, %1) : (memref<32x64xf16>, memref<32x64xf16>) -> memref<32x64xf32> + %cast = memref.cast %2 : memref<32x64xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } func.func private @fillResource1DF32(memref<*xf32>, f32) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/eltwise_add_1_d.mlir b/test/Integration/Dialect/XeGPU/VC/eltwise_add_1_d.mlir index 7be97369b..f40a7cbcb 100644 --- a/test/Integration/Dialect/XeGPU/VC/eltwise_add_1_d.mlir +++ b/test/Integration/Dialect/XeGPU/VC/eltwise_add_1_d.mlir @@ -1,77 +1,71 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_512xf32 : memref<512xf32> = 
dense<0.0> + memref.global "private" constant @__constant_512xf32 : memref<512xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<512xf32>, %arg1: memref<512xf32>) -> memref<512xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - - %memref = gpu.alloc host_shared () : memref<512xf32> - memref.copy %arg0, %memref : memref<512xf32> to memref<512xf32> - %memref_1 = gpu.alloc host_shared () : memref<512xf32> - memref.copy %arg1, %memref_1 : memref<512xf32> to memref<512xf32> - %memref_2 = gpu.alloc host_shared () : memref<512xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1) args(%memref : memref<512xf32>, %memref_1 : memref<512xf32>, %memref_2 : memref<512xf32>) + %memref = gpu.alloc () : memref<512xf32> + gpu.memcpy %memref, %arg0 : memref<512xf32>, memref<512xf32> + %memref_0 = gpu.alloc () : memref<512xf32> + gpu.memcpy %memref_0, %arg1 : memref<512xf32>, memref<512xf32> + %memref_1 = gpu.alloc () : memref<512xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c32, %c1, %c1) args(%memref : memref<512xf32>, %memref_0 : memref<512xf32>, %memref_1 : memref<512xf32>) gpu.dealloc %memref : memref<512xf32> + gpu.dealloc %memref_0 : memref<512xf32> + %alloc = memref.alloc() : memref<512xf32> + gpu.memcpy %alloc, %memref_1 : memref<512xf32>, memref<512xf32> gpu.dealloc %memref_1 : memref<512xf32> - return %memref_2 : memref<512xf32> + return %alloc : memref<512xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<512xf32>, %arg1: memref<512xf32>, %arg2: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %thread_id_x = gpu.thread_id x + %thread_id_x = gpu.thread_id x %c16 = arith.constant 16 : index cf.br ^bb1 - ^bb1: - %t = arith.muli %thread_id_x, %c16 : index - %0 = xegpu.create_nd_tdesc %arg1[%t]: memref<512xf32> -> !xegpu.tensor_desc<16xf32> - %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> - %2 = xegpu.create_nd_tdesc %arg0[%t] : memref<512xf32> -> !xegpu.tensor_desc<16xf32> - %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> - %4 = arith.addf %3, %1 : vector<16xf32> - %5 = xegpu.create_nd_tdesc %arg2[%t] : memref<512xf32> -> !xegpu.tensor_desc<16xf32> - xegpu.store_nd %4, %5 : vector<16xf32>, !xegpu.tensor_desc<16xf32> + ^bb1: // pred: ^bb0 + %0 = arith.muli %thread_id_x, %c16 : index + %1 = xegpu.create_nd_tdesc %arg1[%0] : memref<512xf32> -> !xegpu.tensor_desc<16xf32> + %2 = xegpu.load_nd %1 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> + %3 = xegpu.create_nd_tdesc %arg0[%0] : memref<512xf32> -> !xegpu.tensor_desc<16xf32> + %4 = xegpu.load_nd %3 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> + %5 = arith.addf %4, %2 : vector<16xf32> + %6 = xegpu.create_nd_tdesc %arg2[%0] : memref<512xf32> -> !xegpu.tensor_desc<16xf32> + xegpu.store_nd %5, %6 : vector<16xf32>, !xegpu.tensor_desc<16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -1. : f32 - %cf_upper = arith.constant 1. 
: f32 - - %A = memref.alloc() : memref<512xf32> - %A_random = memref.cast %A : memref<512xf32> to memref<*xf32> - call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - - %B = memref.alloc() : memref<512xf32> - %B_random = memref.cast %B : memref<512xf32> to memref<*xf32> - call @fillResource1DRandomF32(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - // calculate the result of C vector - %c512 = arith.constant 512 : index - %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %ref = memref.alloc() : memref<512xf32> - scf.for %i = %c0 to %c512 step %c1 { - %a = memref.load %A[%i] : memref<512xf32> - %b = memref.load %B[%i] : memref<512xf32> - %c = arith.addf %a, %b : f32 - memref.store %c, %ref[%i] : memref<512xf32> + %c1 = arith.constant 1 : index + %c512 = arith.constant 512 : index + %false = arith.constant false + %cst = arith.constant -1.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() : memref<512xf32> + %cast = memref.cast %alloc : memref<512xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %alloc_1 = memref.alloc() : memref<512xf32> + %cast_2 = memref.cast %alloc_1 : memref<512xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast_2, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %alloc_3 = memref.alloc() : memref<512xf32> + scf.for %arg0 = %c0 to %c512 step %c1 { + %1 = memref.load %alloc[%arg0] : memref<512xf32> + %2 = memref.load %alloc_1[%arg0] : memref<512xf32> + %3 = arith.addf %1, %2 : f32 + memref.store %3, %alloc_3[%arg0] : memref<512xf32> } - - %C = call @test(%A, %B) : (memref<512xf32>, memref<512xf32>) -> memref<512xf32> - - %C_cast = memref.cast %C : memref<512xf32> to memref<*xf32> - %ref_cast = memref.cast %ref : memref<512xf32> to memref<*xf32> - call @printMemrefF32(%ref_cast) : (memref<*xf32>) -> () - call @printMemrefF32(%C_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%ref_cast, %C_cast) : (memref<*xf32>, memref<*xf32>) -> () + %0 = call @test(%alloc, %alloc_1) : (memref<512xf32>, memref<512xf32>) -> memref<512xf32> + %cast_4 = memref.cast %0 : memref<512xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_3 : memref<512xf32> to memref<*xf32> + call @printMemrefF32(%cast_5) : (memref<*xf32>) -> () + call @printMemrefF32(%cast_4) : (memref<*xf32>) -> () + call @printAllcloseF32(%cast_5, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/exp_f32.vc.mlir b/test/Integration/Dialect/XeGPU/VC/exp_f32.vc.mlir index 1c0900df3..5bdc6d205 100644 --- a/test/Integration/Dialect/XeGPU/VC/exp_f32.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/exp_f32.vc.mlir @@ -1,85 +1,85 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: 
--entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<8x16xf32>) -> (memref<8x16xf32>, memref<8x16xf32>) attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x16xf32>) -> (memref<8x16xf32>, memref<8x16xf32>) attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref : memref<8x16xf32> to memref<8x16xf32> - - %memref_2 = gpu.alloc host_shared () : memref<8x16xf32> - %memref_3 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @module0::@test_exp_larger_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_2 : memref<8x16xf32>) - gpu.launch_func @module1::@test_exp_generic_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_3 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + %memref_1 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @module0::@test_exp_larger_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>) + gpu.launch_func @module1::@test_exp_generic_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_1 : memref<8x16xf32>) + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + %alloc_2 = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc_2, %memref_1 : memref<8x16xf32>, memref<8x16xf32> gpu.dealloc %memref : memref<8x16xf32> - return %memref_2, %memref_3 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + gpu.dealloc %memref_1 : memref<8x16xf32> + return %alloc, %alloc_2 : memref<8x16xf32>, memref<8x16xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_exp_larger_vec(%A: memref<8x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_exp_larger_vec(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // load A tile - %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // take exp - %t6 = math.exp %val0 : vector<8x16xf32> // store - %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %t6, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = math.exp %1 : vector<8x16xf32> + %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %2, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> 
gpu.return } } - gpu.module @module1 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_exp_generic_vec(%A: memref<8x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module1 { + gpu.func @test_exp_generic_vec(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // load A tile - %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - // extract the loaded vector into 16xf32 vectors - %v0 = vector.extract %val0[0] : vector<16xf32> from vector<8x16xf32> - %v1 = vector.extract %val0[1] : vector<16xf32> from vector<8x16xf32> - %v2 = vector.extract %val0[2] : vector<16xf32> from vector<8x16xf32> - %v3 = vector.extract %val0[3] : vector<16xf32> from vector<8x16xf32> - %v4 = vector.extract %val0[4] : vector<16xf32> from vector<8x16xf32> - %v5 = vector.extract %val0[5] : vector<16xf32> from vector<8x16xf32> - %v6 = vector.extract %val0[6] : vector<16xf32> from vector<8x16xf32> - %v7 = vector.extract %val0[7] : vector<16xf32> from vector<8x16xf32> // do generic size exp - %v0_exp = math.exp %v0 : vector<16xf32> - %v1_exp = math.exp %v1 : vector<16xf32> - %v2_exp = math.exp %v2 : vector<16xf32> - %v3_exp = math.exp %v3 : vector<16xf32> - %v4_exp = math.exp %v4 : vector<16xf32> - %v5_exp = math.exp %v5 : vector<16xf32> - %v6_exp = math.exp %v6 : vector<16xf32> - %v7_exp = math.exp %v7 : vector<16xf32> - %v0_exp_cast = vector.shape_cast %v0_exp : vector<16xf32> to vector<1x16xf32> - %v1_exp_cast = vector.shape_cast %v1_exp : vector<16xf32> to vector<1x16xf32> - %v2_exp_cast = vector.shape_cast %v2_exp : vector<16xf32> to vector<1x16xf32> - %v3_exp_cast = vector.shape_cast %v3_exp : vector<16xf32> to vector<1x16xf32> - %v4_exp_cast = vector.shape_cast %v4_exp : vector<16xf32> to vector<1x16xf32> - %v5_exp_cast = vector.shape_cast %v5_exp : vector<16xf32> to vector<1x16xf32> - %v6_exp_cast = vector.shape_cast %v6_exp : vector<16xf32> to vector<1x16xf32> - %v7_exp_cast = vector.shape_cast %v7_exp : vector<16xf32> to vector<1x16xf32> // construct 4x16xf32 vector from the smaller ones - %t0 = vector.shuffle %v0_exp_cast, %v1_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32> - %t1 = vector.shuffle %v2_exp_cast, %v3_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32> - %t2 = vector.shuffle %v4_exp_cast, %v5_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32> - %t3 = vector.shuffle %v6_exp_cast, %v7_exp_cast [0, 1] : vector<1x16xf32>, vector<1x16xf32> - %t4 = vector.shuffle %t0, %t1 [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32> - %t5 = vector.shuffle %t2, %t3 [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32> - %t6 = vector.shuffle %t4, %t5 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x16xf32>, vector<4x16xf32> // store - %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %t6, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = vector.extract %1[0] : vector<16xf32> from vector<8x16xf32> + %3 = vector.extract %1[1] : 
vector<16xf32> from vector<8x16xf32> + %4 = vector.extract %1[2] : vector<16xf32> from vector<8x16xf32> + %5 = vector.extract %1[3] : vector<16xf32> from vector<8x16xf32> + %6 = vector.extract %1[4] : vector<16xf32> from vector<8x16xf32> + %7 = vector.extract %1[5] : vector<16xf32> from vector<8x16xf32> + %8 = vector.extract %1[6] : vector<16xf32> from vector<8x16xf32> + %9 = vector.extract %1[7] : vector<16xf32> from vector<8x16xf32> + %10 = math.exp %2 : vector<16xf32> + %11 = math.exp %3 : vector<16xf32> + %12 = math.exp %4 : vector<16xf32> + %13 = math.exp %5 : vector<16xf32> + %14 = math.exp %6 : vector<16xf32> + %15 = math.exp %7 : vector<16xf32> + %16 = math.exp %8 : vector<16xf32> + %17 = math.exp %9 : vector<16xf32> + %18 = vector.shape_cast %10 : vector<16xf32> to vector<1x16xf32> + %19 = vector.shape_cast %11 : vector<16xf32> to vector<1x16xf32> + %20 = vector.shape_cast %12 : vector<16xf32> to vector<1x16xf32> + %21 = vector.shape_cast %13 : vector<16xf32> to vector<1x16xf32> + %22 = vector.shape_cast %14 : vector<16xf32> to vector<1x16xf32> + %23 = vector.shape_cast %15 : vector<16xf32> to vector<1x16xf32> + %24 = vector.shape_cast %16 : vector<16xf32> to vector<1x16xf32> + %25 = vector.shape_cast %17 : vector<16xf32> to vector<1x16xf32> + %26 = vector.shuffle %18, %19 [0, 1] : vector<1x16xf32>, vector<1x16xf32> + %27 = vector.shuffle %20, %21 [0, 1] : vector<1x16xf32>, vector<1x16xf32> + %28 = vector.shuffle %22, %23 [0, 1] : vector<1x16xf32>, vector<1x16xf32> + %29 = vector.shuffle %24, %25 [0, 1] : vector<1x16xf32>, vector<1x16xf32> + %30 = vector.shuffle %26, %27 [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32> + %31 = vector.shuffle %28, %29 [0, 1, 2, 3] : vector<2x16xf32>, vector<2x16xf32> + %32 = vector.shuffle %30, %31 [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x16xf32>, vector<4x16xf32> + %33 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %32, %33 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } @@ -89,39 +89,39 @@ module @gemm attributes {gpu.container_module} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %rand_lower = arith.constant -1.0 : f32 - %rand_upper = arith.constant 1.0 : f32 - %gen_int = arith.constant 0 : i1 - %A = memref.alloc() : memref<8x16xf32> - %Out_cpu = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%A_random, %rand_lower, %rand_upper, %gen_int) : (memref<*xf32>, f32, f32, i1) -> () // run GPU version - %Out_gpu_large, %Out_gpu_generic = call @test(%A) : (memref<8x16xf32>) -> (memref<8x16xf32>, memref<8x16xf32>) - %Out_gpu_generic_cast = memref.cast %Out_gpu_generic : memref<8x16xf32> to memref<*xf32> - %Out_gpu_large_cast = memref.cast %Out_gpu_large : memref<8x16xf32> to memref<*xf32> // run CPU version - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %a0 = memref.load %A[%i, %j] : memref<8x16xf32> - %vexp = math.exp %a0: f32 - memref.store %vexp, %Out_cpu[%i, %j] : memref<8x16xf32> + %cst = arith.constant -1.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x16xf32> + %alloc_1 = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %0:2 = call @test(%alloc) : (memref<8x16xf32>) -> 
(memref<8x16xf32>, memref<8x16xf32>) + %cast_2 = memref.cast %0#1 : memref<8x16xf32> to memref<*xf32> + %cast_3 = memref.cast %0#0 : memref<8x16xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : memref<8x16xf32> + %2 = math.exp %1 : f32 + memref.store %2, %alloc_1[%arg0, %arg1] : memref<8x16xf32> } } - %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32> // print GPU and CPU outs // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%Out_gpu_generic_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () - call @printAllcloseF32(%Out_gpu_large_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () // dealloc - memref.dealloc %A : memref<8x16xf32> - memref.dealloc %Out_cpu : memref<8x16xf32> // gpu dealloc - gpu.dealloc %Out_gpu_generic : memref<8x16xf32> - gpu.dealloc %Out_gpu_large : memref<8x16xf32> + %cast_4 = memref.cast %alloc_1 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_2, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_3, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x16xf32> + memref.dealloc %alloc_1 : memref<8x16xf32> + memref.dealloc %0#1 : memref<8x16xf32> + memref.dealloc %0#0 : memref<8x16xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/flash_attention_fwd.mlir b/test/Integration/Dialect/XeGPU/VC/flash_attention_fwd.mlir index 476b4f4f8..0c8e1af43 100644 --- a/test/Integration/Dialect/XeGPU/VC/flash_attention_fwd.mlir +++ b/test/Integration/Dialect/XeGPU/VC/flash_attention_fwd.mlir @@ -1,1052 +1,762 @@ -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @flash_attention attributes {gpu.container_module} { - gpu.module @flash_attention_fwd attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @flash_attention_fwd( - %Q : memref, - %K : memref, - %V : memref, - %Out : memref, - %sm_scale : f32, - %stride_qz : index, %stride_qh : index, %stride_qm : index, %stride_qk : index, - %stride_kz : index, %stride_kh : index, %stride_kn : index, %stride_kk : index, - %stride_vz : index, %stride_vh : index, %stride_vk : index, %stride_vn : index, - %stride_oz : index, %stride_oh : index, %stride_om : index, %stride_on : index, - %Z : index, %H : index, - %N_CTX : index, - %BLOCK_M : index, - %BLOCK_DMODEL : index, - %BLOCK_N : index - ) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { +// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @flash_attention attributes {gpu.container_module} { + gpu.module @flash_attention_fwd { + gpu.func @flash_attention_fwd(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: f32, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: index, %arg11: index, %arg12: index, %arg13: index, %arg14: index, %arg15: index, %arg16: index, %arg17: index, %arg18: index, %arg19: index, %arg20: index, %arg21: index, %arg22: index, %arg23: index, %arg24: index, %arg25: index, %arg26: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %start_m = gpu.block_id x - %off_hz = gpu.block_id y - %sg_id = gpu.subgroup_id : index - // memref sizes in x dim - %size_x_t0 = arith.muli %Z, %H : index - %size_x = arith.muli %size_x_t0, %N_CTX : index - // calculate the WG x offset of the q tile. This is equal to off_hz * N_CTX + start_m * BLOCK_M - %wg_x_offset = arith.muli %off_hz, %N_CTX : index - %offset_m = arith.muli %start_m, %BLOCK_M : index - %wg_q_x_offset = arith.addi %wg_x_offset, %offset_m : index - // for k and v offsets are off_zh * N_CTX because inside the K loop we will consume N_CTX length // this is eqaul to wg_x_offset - // compute the SG x offset for the q tile. // wg_q_offset + sg_x_slice_size * sg_id - %sg_x_slice_size = arith.divui %BLOCK_M, %c8 : index - %sg_q_x_offset_t0 = arith.muli %sg_id, %sg_x_slice_size : index - %sg_q_x_offset = arith.addi %wg_q_x_offset, %sg_q_x_offset_t0 : index - // init tile for 16x64 Q tiles - %q_tile_init_0 = xegpu.create_nd_tdesc %Q[%sg_q_x_offset, %c0], shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<16x16xf16> - %q_tile_init_1 = xegpu.update_nd_offset %q_tile_init_0, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %q_tile_init_2 = xegpu.update_nd_offset %q_tile_init_1, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %q_tile_init_3 = xegpu.update_nd_offset %q_tile_init_2, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - // init tile for 64x64 K tiles. We do this in 4 stages of 16x64 tiles to reduce register pressure. 
// k is reused by all SGs - %k_tile_slice_0_0_init = xegpu.create_nd_tdesc %K [%wg_x_offset, %c0], shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_0_1_init = xegpu.update_nd_offset %k_tile_slice_0_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_0_2_init = xegpu.update_nd_offset %k_tile_slice_0_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_0_3_init = xegpu.update_nd_offset %k_tile_slice_0_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - - %k_tile_slice_1_0_init = xegpu.update_nd_offset %k_tile_slice_0_0_init, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_1_1_init = xegpu.update_nd_offset %k_tile_slice_1_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_1_2_init = xegpu.update_nd_offset %k_tile_slice_1_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_1_3_init = xegpu.update_nd_offset %k_tile_slice_1_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - - %k_tile_slice_2_0_init = xegpu.update_nd_offset %k_tile_slice_1_0_init, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_2_1_init = xegpu.update_nd_offset %k_tile_slice_2_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_2_2_init = xegpu.update_nd_offset %k_tile_slice_2_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_2_3_init = xegpu.update_nd_offset %k_tile_slice_2_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - - %k_tile_slice_3_0_init = xegpu.update_nd_offset %k_tile_slice_2_0_init, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_3_1_init = xegpu.update_nd_offset %k_tile_slice_3_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_3_2_init = xegpu.update_nd_offset %k_tile_slice_3_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_3_3_init = xegpu.update_nd_offset %k_tile_slice_3_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - // same for V tiles - %v_tile_slice_0_0_init = xegpu.create_nd_tdesc %V [%wg_x_offset, %c0], shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_0_1_init = xegpu.update_nd_offset %v_tile_slice_0_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_0_2_init = xegpu.update_nd_offset %v_tile_slice_0_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_0_3_init = xegpu.update_nd_offset %v_tile_slice_0_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - - %v_tile_slice_1_0_init = xegpu.update_nd_offset %v_tile_slice_0_0_init, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_1_1_init = xegpu.update_nd_offset %v_tile_slice_1_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_1_2_init = xegpu.update_nd_offset %v_tile_slice_1_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_1_3_init = xegpu.update_nd_offset %v_tile_slice_1_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - - %v_tile_slice_2_0_init = xegpu.update_nd_offset %v_tile_slice_1_0_init, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_2_1_init = xegpu.update_nd_offset %v_tile_slice_2_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_2_2_init = xegpu.update_nd_offset %v_tile_slice_2_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_2_3_init = xegpu.update_nd_offset %v_tile_slice_2_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - - %v_tile_slice_3_0_init = xegpu.update_nd_offset 
%v_tile_slice_2_0_init, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_3_1_init = xegpu.update_nd_offset %v_tile_slice_3_0_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_3_2_init = xegpu.update_nd_offset %v_tile_slice_3_1_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_3_3_init = xegpu.update_nd_offset %v_tile_slice_3_2_init, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - // k preftech // prefetch 16x32 tiles in 4x2 layout to cover 64x64 // x offset for prefetch is same as for q tiles. This means that WGs assigned to same bacth also colloborate on prefetching // the K, V tiles. // We also tried WGs prefetching from the begining of the K, V tiles but that did not work well because multiple // WGs compete to prefetch the same data. + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = gpu.subgroup_id : index + %1 = arith.muli %arg21, %arg22 : index + %2 = arith.muli %1, %arg23 : index + %3 = arith.muli %block_id_y, %arg23 : index + %4 = arith.muli %block_id_x, %arg24 : index + %5 = arith.addi %3, %4 : index + %6 = arith.divui %arg24, %c8 : index + %7 = arith.muli %0, %6 : index + %8 = arith.addi %5, %7 : index + %9 = xegpu.create_nd_tdesc %arg0[%8, %c0], shape : [%2, %arg25], strides : [%arg25, %c1] : memref -> !xegpu.tensor_desc<16x16xf16> + %10 = xegpu.update_nd_offset %9, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %13 = xegpu.create_nd_tdesc %arg1[%3, %c0], shape : [%2, %arg25], strides : [%arg25, %c1] : memref -> !xegpu.tensor_desc<16x16xf16> + %14 = xegpu.update_nd_offset %13, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %15 = xegpu.update_nd_offset %14, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %16 = xegpu.update_nd_offset %15, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %17 = xegpu.update_nd_offset %13, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %18 = xegpu.update_nd_offset %17, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %19 = xegpu.update_nd_offset %18, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %20 = xegpu.update_nd_offset %19, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %21 = xegpu.update_nd_offset %17, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %22 = xegpu.update_nd_offset %21, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %23 = xegpu.update_nd_offset %22, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %24 = xegpu.update_nd_offset %23, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %25 = xegpu.update_nd_offset %21, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %26 = xegpu.update_nd_offset %25, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %27 = xegpu.update_nd_offset %26, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %28 = xegpu.update_nd_offset %27, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %29 = xegpu.create_nd_tdesc %arg2[%3, %c0], shape : [%2, %arg25], strides : [%arg25, %c1] : memref -> !xegpu.tensor_desc<16x16xf16> + %30 = xegpu.update_nd_offset %29, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %31 = xegpu.update_nd_offset %30, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %32 = xegpu.update_nd_offset %31, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %33 = xegpu.update_nd_offset %29, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %34 = xegpu.update_nd_offset %33, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %35 = xegpu.update_nd_offset %34, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %36 = xegpu.update_nd_offset %35, [%c0, 
%c16] : !xegpu.tensor_desc<16x16xf16> + %37 = xegpu.update_nd_offset %33, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %38 = xegpu.update_nd_offset %37, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %39 = xegpu.update_nd_offset %38, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %40 = xegpu.update_nd_offset %39, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %41 = xegpu.update_nd_offset %37, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %42 = xegpu.update_nd_offset %41, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %43 = xegpu.update_nd_offset %42, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %44 = xegpu.update_nd_offset %43, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> %c2 = arith.constant 2 : index - %sg_layout_x = arith.divui %sg_id, %c2 : index - %sg_layout_y = arith.remui %sg_id, %c2 : index - - %prefetch_offset_x_t0 = arith.muli %sg_layout_x, %c16 : index - %prefetch_offset_x = arith.addi %wg_q_x_offset, %prefetch_offset_x_t0 : index - %prefetch_offset_y = arith.muli %sg_layout_y, %c32 : index - - %k_prefetch_iter0 = xegpu.create_nd_tdesc %K [%prefetch_offset_x, %prefetch_offset_y], shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<16x32xf16> - xegpu.prefetch_nd %k_prefetch_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x32xf16> - %k_prefetch_iter1 = xegpu.update_nd_offset %k_prefetch_iter0, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - xegpu.prefetch_nd %k_prefetch_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x32xf16> - %k_prefetch_iter2 = xegpu.update_nd_offset %k_prefetch_iter1, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - xegpu.prefetch_nd %k_prefetch_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x32xf16> - %k_prefetch_iter3 = xegpu.update_nd_offset %k_prefetch_iter2, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - // V prefetch is similar to K - %v_prefetch_iter0 = xegpu.create_nd_tdesc %V [%prefetch_offset_x, %prefetch_offset_y], shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<16x32xf16> - xegpu.prefetch_nd %v_prefetch_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x32xf16> - %v_prefetch_iter1 = xegpu.update_nd_offset %v_prefetch_iter0, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - xegpu.prefetch_nd %v_prefetch_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x32xf16> - %v_prefetch_iter2 = xegpu.update_nd_offset %v_prefetch_iter1, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - xegpu.prefetch_nd %v_prefetch_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x32xf16> - %v_prefetch_iter3 = xegpu.update_nd_offset %v_prefetch_iter2, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - - // initialize m, l and acc - %m_i_row_0_in_flat = arith.constant dense<0xFF800000> : vector<8xf32> // -inf - %m_i_row_1_in_flat = arith.constant dense<0xFF800000> : vector<8xf32> // -inf - %l_i_row_0_in_flat = arith.constant dense<1.0> : vector<8xf32> // 1.0 - %l_i_row_1_in_flat = arith.constant dense<1.0> : vector<8xf32> // 1.0 - %m_i_row_0_in = vector.shape_cast %m_i_row_0_in_flat : vector<8xf32> to vector<8x1xf32> - %m_i_row_1_in = vector.shape_cast %m_i_row_1_in_flat : vector<8xf32> to 
vector<8x1xf32> - %l_i_row_0_in = vector.shape_cast %l_i_row_0_in_flat : vector<8xf32> to vector<8x1xf32> - %l_i_row_1_in = vector.shape_cast %l_i_row_1_in_flat : vector<8xf32> to vector<8x1xf32> - %zero = arith.constant dense<0.0> : vector<128xf32> - %zero_dpas = vector.shape_cast %zero : vector<128xf32> to vector<8x16xf32> - // softmax scaling // %qk_scale_8 = spirv.CompositeConstruct %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale : (f32, f32, f32, f32, f32, f32, f32, f32) -> vector<8xf32> // %qk_scale_16 = spirv.CompositeConstruct %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale,%sm_scale, %sm_scale, %sm_scale, %sm_scale,%sm_scale, %sm_scale, %sm_scale, %sm_scale : (f32, f32, f32, f32,f32, f32, f32, f32,f32, f32, f32, f32,f32, f32, f32, f32 ) -> vector<16xf32> // FIXME: value 0.5 is hard coded. need to take it from %sm_scale - %qk_scale_8 = arith.constant dense<0.5> : vector<8xf32> - %qk_scale_16 = arith.constant dense<0.5> : vector<16xf32> - %qk_scale_8x1 = vector.shape_cast %qk_scale_8 : vector<8xf32> to vector<8x1xf32> - %qk_scale_1x16 = vector.shape_cast %qk_scale_16 : vector<16xf32> to vector<1x16xf32> - %qk_scale_8x16 = vector.shuffle %qk_scale_1x16, %qk_scale_1x16 [0, 0, 0, 0, 0, 0, 0, 0] : vector<1x16xf32>, vector<1x16xf32> - - // load Q tiles - %q_block_value_0 = xegpu.load_nd %q_tile_init_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> - %q_block_value_1 = xegpu.load_nd %q_tile_init_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> - %q_block_value_2 = xegpu.load_nd %q_tile_init_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> - %q_block_value_3 = xegpu.load_nd %q_tile_init_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> - - %q_block_value_0_flat = vector.shape_cast %q_block_value_0 : vector<16x16xf16> to vector<256xf16> - %q_block_value_1_flat = vector.shape_cast %q_block_value_1 : vector<16x16xf16> to vector<256xf16> - %q_block_value_2_flat = vector.shape_cast %q_block_value_2 : vector<16x16xf16> to vector<256xf16> - %q_block_value_3_flat = vector.shape_cast %q_block_value_3 : vector<16x16xf16> to vector<256xf16> - - %q_block_value_0_0_t0 = vector.extract_strided_slice %q_block_value_0_flat { offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_0_0 = vector.shape_cast %q_block_value_0_0_t0 : vector<128xf16> to vector<8x16xf16> - - %q_block_value_1_0_t0 = vector.extract_strided_slice %q_block_value_0_flat { offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_1_0 = vector.shape_cast %q_block_value_1_0_t0 : vector<128xf16> to vector<8x16xf16> - - %q_block_value_0_1_t0 = vector.extract_strided_slice %q_block_value_1_flat { offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_0_1 = vector.shape_cast %q_block_value_0_1_t0 : vector<128xf16> to vector<8x16xf16> - - %q_block_value_1_1_t0 = vector.extract_strided_slice %q_block_value_1_flat { offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_1_1 = vector.shape_cast %q_block_value_1_1_t0 : 
vector<128xf16> to vector<8x16xf16> - // ---- - %q_block_value_0_2_t0 = vector.extract_strided_slice %q_block_value_2_flat { offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_0_2 = vector.shape_cast %q_block_value_0_2_t0 : vector<128xf16> to vector<8x16xf16> - - %q_block_value_1_2_t0 = vector.extract_strided_slice %q_block_value_2_flat { offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_1_2 = vector.shape_cast %q_block_value_1_2_t0 : vector<128xf16> to vector<8x16xf16> - - %q_block_value_0_3_t0 = vector.extract_strided_slice %q_block_value_3_flat { offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_0_3 = vector.shape_cast %q_block_value_0_3_t0 : vector<128xf16> to vector<8x16xf16> - - %q_block_value_1_3_t0 = vector.extract_strided_slice %q_block_value_3_flat { offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> - %q_block_value_1_3 = vector.shape_cast %q_block_value_1_3_t0 : vector<128xf16> to vector<8x16xf16> - + %45 = arith.divui %0, %c2 : index + %46 = arith.remui %0, %c2 : index + %47 = arith.muli %45, %c16 : index + %48 = arith.addi %5, %47 : index + %49 = arith.muli %46, %c32 : index + %50 = xegpu.create_nd_tdesc %arg1[%48, %49], shape : [%2, %arg25], strides : [%arg25, %c1] : memref -> !xegpu.tensor_desc<16x32xf16> + xegpu.prefetch_nd %50 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16> + %51 = xegpu.update_nd_offset %50, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + xegpu.prefetch_nd %51 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16> + %52 = xegpu.update_nd_offset %51, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + xegpu.prefetch_nd %52 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16> + %53 = xegpu.update_nd_offset %52, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + %54 = xegpu.create_nd_tdesc %arg2[%48, %49], shape : [%2, %arg25], strides : [%arg25, %c1] : memref -> !xegpu.tensor_desc<16x32xf16> + xegpu.prefetch_nd %54 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16> + %55 = xegpu.update_nd_offset %54, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + xegpu.prefetch_nd %55 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16> + %56 = xegpu.update_nd_offset %55, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + xegpu.prefetch_nd %56 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x32xf16> + %57 = xegpu.update_nd_offset %56, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + %cst = arith.constant dense<0xFF800000> : vector<8xf32> + %cst_0 = arith.constant dense<0xFF800000> : vector<8xf32> + %cst_1 = arith.constant dense<1.000000e+00> : vector<8xf32> + %cst_2 = arith.constant dense<1.000000e+00> : vector<8xf32> + %58 = vector.shape_cast %cst : vector<8xf32> to vector<8x1xf32> + %59 = vector.shape_cast %cst_0 : vector<8xf32> to vector<8x1xf32> + %60 = vector.shape_cast %cst_1 : vector<8xf32> to vector<8x1xf32> + %61 = vector.shape_cast %cst_2 : vector<8xf32> to vector<8x1xf32> + %cst_3 = arith.constant dense<0.000000e+00> : vector<128xf32> + %62 = vector.shape_cast %cst_3 : vector<128xf32> to 
vector<8x16xf32> + %cst_4 = arith.constant dense<5.000000e-01> : vector<8xf32> + %cst_5 = arith.constant dense<5.000000e-01> : vector<16xf32> + %63 = vector.shape_cast %cst_4 : vector<8xf32> to vector<8x1xf32> + %64 = vector.shape_cast %cst_5 : vector<16xf32> to vector<1x16xf32> + %65 = vector.shuffle %64, %64 [0, 0, 0, 0, 0, 0, 0, 0] : vector<1x16xf32>, vector<1x16xf32> + %66 = xegpu.load_nd %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %67 = xegpu.load_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %68 = xegpu.load_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %69 = xegpu.load_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %70 = vector.shape_cast %66 : vector<16x16xf16> to vector<256xf16> + %71 = vector.shape_cast %67 : vector<16x16xf16> to vector<256xf16> + %72 = vector.shape_cast %68 : vector<16x16xf16> to vector<256xf16> + %73 = vector.shape_cast %69 : vector<16x16xf16> to vector<256xf16> + %74 = vector.extract_strided_slice %70 {offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %75 = vector.shape_cast %74 : vector<128xf16> to vector<8x16xf16> + %76 = vector.extract_strided_slice %70 {offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %77 = vector.shape_cast %76 : vector<128xf16> to vector<8x16xf16> + %78 = vector.extract_strided_slice %71 {offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %79 = vector.shape_cast %78 : vector<128xf16> to vector<8x16xf16> + %80 = vector.extract_strided_slice %71 {offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %81 = vector.shape_cast %80 : vector<128xf16> to vector<8x16xf16> + %82 = vector.extract_strided_slice %72 {offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %83 = vector.shape_cast %82 : vector<128xf16> to vector<8x16xf16> + %84 = vector.extract_strided_slice %72 {offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %85 = vector.shape_cast %84 : vector<128xf16> to vector<8x16xf16> + %86 = vector.extract_strided_slice %73 {offsets = [0], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %87 = vector.shape_cast %86 : vector<128xf16> to vector<8x16xf16> + %88 = vector.extract_strided_slice %73 {offsets = [128], sizes = [128], strides = [1]} : vector<256xf16> to vector<128xf16> + %89 = vector.shape_cast %88 : vector<128xf16> to vector<8x16xf16> xegpu.alloc_nbarrier 16 - %nbarrier_id = arith.constant 1 : i8 - %num_threads = arith.constant 8 : i8 - %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier - - // inner loop. 
This loop iterate over K and V tiles and update the accumulator by computing softmax(q*k^T)*v - %result:46 = scf.for %k = %c0 to %N_CTX step %BLOCK_N iter_args - ( - %acc_in_0_0 = %zero_dpas, - %acc_in_0_1 = %zero_dpas, - %acc_in_0_2 = %zero_dpas, - %acc_in_0_3 = %zero_dpas, - %acc_in_1_0 = %zero_dpas, - %acc_in_1_1 = %zero_dpas, - %acc_in_1_2 = %zero_dpas, - %acc_in_1_3 = %zero_dpas, - - %k_tile_slice_0_0 = %k_tile_slice_0_0_init, - %k_tile_slice_0_1 = %k_tile_slice_0_1_init, - %k_tile_slice_0_2 = %k_tile_slice_0_2_init, - %k_tile_slice_0_3 = %k_tile_slice_0_3_init, - %k_tile_slice_1_0 = %k_tile_slice_1_0_init, - %k_tile_slice_1_1 = %k_tile_slice_1_1_init, - %k_tile_slice_1_2 = %k_tile_slice_1_2_init, - %k_tile_slice_1_3 = %k_tile_slice_1_3_init, - %k_tile_slice_2_0 = %k_tile_slice_2_0_init, - %k_tile_slice_2_1 = %k_tile_slice_2_1_init, - %k_tile_slice_2_2 = %k_tile_slice_2_2_init, - %k_tile_slice_2_3 = %k_tile_slice_2_3_init, - %k_tile_slice_3_0 = %k_tile_slice_3_0_init, - %k_tile_slice_3_1 = %k_tile_slice_3_1_init, - %k_tile_slice_3_2 = %k_tile_slice_3_2_init, - %k_tile_slice_3_3 = %k_tile_slice_3_3_init, - - %v_tile_slice_0_0 = %v_tile_slice_0_0_init, - %v_tile_slice_0_1 = %v_tile_slice_0_1_init, - %v_tile_slice_0_2 = %v_tile_slice_0_2_init, - %v_tile_slice_0_3 = %v_tile_slice_0_3_init, - %v_tile_slice_1_0 = %v_tile_slice_1_0_init, - %v_tile_slice_1_1 = %v_tile_slice_1_1_init, - %v_tile_slice_1_2 = %v_tile_slice_1_2_init, - %v_tile_slice_1_3 = %v_tile_slice_1_3_init, - %v_tile_slice_2_0 = %v_tile_slice_2_0_init, - %v_tile_slice_2_1 = %v_tile_slice_2_1_init, - %v_tile_slice_2_2 = %v_tile_slice_2_2_init, - %v_tile_slice_2_3 = %v_tile_slice_2_3_init, - %v_tile_slice_3_0 = %v_tile_slice_3_0_init, - %v_tile_slice_3_1 = %v_tile_slice_3_1_init, - %v_tile_slice_3_2 = %v_tile_slice_3_2_init, - %v_tile_slice_3_3 = %v_tile_slice_3_3_init, - /// prefetch - %k_prefetch_tile = %k_prefetch_iter3, - %v_prefetch_tile = %v_prefetch_iter3, - - %m_i_row_0 = %m_i_row_0_in, - %m_i_row_1 = %m_i_row_1_in, - %l_i_row_0 = %l_i_row_0_in, - %l_i_row_1 = %l_i_row_1_in - ) - -> ( - vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, - !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, - - !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, - - !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, - vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32> - - ) { - xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier - // k prefetch - xegpu.prefetch_nd %k_prefetch_tile : !xegpu.tensor_desc<16x32xf16> - %k_prefetch_tile_new = 
xegpu.update_nd_offset %k_prefetch_tile, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - xegpu.compile_hint // V prefetch - xegpu.prefetch_nd %v_prefetch_tile : !xegpu.tensor_desc<16x32xf16> - %v_prefetch_tile_new = xegpu.update_nd_offset %v_prefetch_tile, [%BLOCK_N, %c0] : !xegpu.tensor_desc<16x32xf16> - - xegpu.compile_hint - // load first 16x64 K slice - %k_value_slice_0_0 = xegpu.load_nd %k_tile_slice_0_0 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_0_1 = xegpu.load_nd %k_tile_slice_0_1 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_0_2 = xegpu.load_nd %k_tile_slice_0_2 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_0_3 = xegpu.load_nd %k_tile_slice_0_3 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %k_tile_slice_0_0_new = xegpu.update_nd_offset %k_tile_slice_0_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_0_1_new = xegpu.update_nd_offset %k_tile_slice_0_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_0_2_new = xegpu.update_nd_offset %k_tile_slice_0_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_0_3_new = xegpu.update_nd_offset %k_tile_slice_0_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint - - - // compute first 16x16 of Q * K^T using DPAS - %qk_out_0_0_t0 = xegpu.dpas %q_block_value_0_0, %k_value_slice_0_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_0_t0 = xegpu.dpas %q_block_value_1_0, %k_value_slice_0_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_0_t1 = xegpu.dpas %q_block_value_0_1, %k_value_slice_0_1, %qk_out_0_0_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_0_t1 = xegpu.dpas %q_block_value_1_1, %k_value_slice_0_1, %qk_out_1_0_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_0_t2 = xegpu.dpas %q_block_value_0_2, %k_value_slice_0_2, %qk_out_0_0_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_0_t2 = xegpu.dpas %q_block_value_1_2, %k_value_slice_0_2, %qk_out_1_0_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_0 = xegpu.dpas %q_block_value_0_3, %k_value_slice_0_3, %qk_out_0_0_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_0 = xegpu.dpas %q_block_value_1_3, %k_value_slice_0_3, %qk_out_1_0_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - xegpu.compile_hint - // load second 16x64 K slice - %k_value_slice_1_0 = xegpu.load_nd %k_tile_slice_1_0 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_1_1 = xegpu.load_nd 
%k_tile_slice_1_1 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_1_2 = xegpu.load_nd %k_tile_slice_1_2 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_1_3 = xegpu.load_nd %k_tile_slice_1_3 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %k_tile_slice_1_0_new = xegpu.update_nd_offset %k_tile_slice_1_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_1_1_new = xegpu.update_nd_offset %k_tile_slice_1_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_1_2_new = xegpu.update_nd_offset %k_tile_slice_1_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_1_3_new = xegpu.update_nd_offset %k_tile_slice_1_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint - // compute second 16x16 of Q * K^T using DPAS - %qk_out_0_1_t0 = xegpu.dpas %q_block_value_0_0, %k_value_slice_1_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_1_t1 = xegpu.dpas %q_block_value_0_1, %k_value_slice_1_1, %qk_out_0_1_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_1_t2 = xegpu.dpas %q_block_value_0_2, %k_value_slice_1_2, %qk_out_0_1_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_1 = xegpu.dpas %q_block_value_0_3, %k_value_slice_1_3, %qk_out_0_1_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %qk_out_1_1_t0 = xegpu.dpas %q_block_value_1_0, %k_value_slice_1_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_1_t1 = xegpu.dpas %q_block_value_1_1, %k_value_slice_1_1, %qk_out_1_1_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_1_t2 = xegpu.dpas %q_block_value_1_2, %k_value_slice_1_2, %qk_out_1_1_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_1 = xegpu.dpas %q_block_value_1_3, %k_value_slice_1_3, %qk_out_1_1_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - xegpu.compile_hint - // load third 16x64 K slice - %k_value_slice_2_0 = xegpu.load_nd %k_tile_slice_2_0 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_2_1 = xegpu.load_nd %k_tile_slice_2_1 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_2_2 = xegpu.load_nd %k_tile_slice_2_2 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_2_3 = xegpu.load_nd %k_tile_slice_2_3 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = 
#xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %k_tile_slice_2_0_new = xegpu.update_nd_offset %k_tile_slice_2_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_2_1_new = xegpu.update_nd_offset %k_tile_slice_2_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_2_2_new = xegpu.update_nd_offset %k_tile_slice_2_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_2_3_new = xegpu.update_nd_offset %k_tile_slice_2_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint - // compute third 16x16 of Q * K^T using DPAS - %qk_out_0_2_t0 = xegpu.dpas %q_block_value_0_0, %k_value_slice_2_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_2_t1 = xegpu.dpas %q_block_value_0_1, %k_value_slice_2_1, %qk_out_0_2_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_2_t2 = xegpu.dpas %q_block_value_0_2, %k_value_slice_2_2, %qk_out_0_2_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_2 = xegpu.dpas %q_block_value_0_3, %k_value_slice_2_3, %qk_out_0_2_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %qk_out_1_2_t0 = xegpu.dpas %q_block_value_1_0, %k_value_slice_2_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_2_t1 = xegpu.dpas %q_block_value_1_1, %k_value_slice_2_1, %qk_out_1_2_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_2_t2 = xegpu.dpas %q_block_value_1_2, %k_value_slice_2_2, %qk_out_1_2_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_2 = xegpu.dpas %q_block_value_1_3, %k_value_slice_2_3, %qk_out_1_2_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - xegpu.compile_hint - // load forth 16x64 K slice - %k_value_slice_3_0 = xegpu.load_nd %k_tile_slice_3_0 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_3_1 = xegpu.load_nd %k_tile_slice_3_1 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_3_2 = xegpu.load_nd %k_tile_slice_3_2 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %k_value_slice_3_3 = xegpu.load_nd %k_tile_slice_3_3 {transpose_bit_width = 32 : i32, transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %k_tile_slice_3_0_new = xegpu.update_nd_offset %k_tile_slice_3_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_3_1_new = xegpu.update_nd_offset %k_tile_slice_3_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_3_2_new = xegpu.update_nd_offset %k_tile_slice_3_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %k_tile_slice_3_3_new = xegpu.update_nd_offset %k_tile_slice_3_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint - // compute forth 16x16 of Q * K^T using DPAS - 
%qk_out_0_3_t0 = xegpu.dpas %q_block_value_0_0, %k_value_slice_3_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_3_t1 = xegpu.dpas %q_block_value_0_1, %k_value_slice_3_1, %qk_out_0_3_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_3_t2 = xegpu.dpas %q_block_value_0_2, %k_value_slice_3_2, %qk_out_0_3_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_0_3 = xegpu.dpas %q_block_value_0_3, %k_value_slice_3_3, %qk_out_0_3_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %qk_out_1_3_t0 = xegpu.dpas %q_block_value_1_0, %k_value_slice_3_0, %zero_dpas : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_3_t1 = xegpu.dpas %q_block_value_1_1, %k_value_slice_3_1, %qk_out_1_3_t0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_3_t2 = xegpu.dpas %q_block_value_1_2, %k_value_slice_3_2, %qk_out_1_3_t1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %qk_out_1_3 = xegpu.dpas %q_block_value_1_3, %k_value_slice_3_3, %qk_out_1_3_t2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - xegpu.compile_hint - // process row 0 of QK_out // do max reduction on qk_out row 0 - %qk_out_max_0_t0 = arith.maximumf %qk_out_0_0, %qk_out_0_1 fastmath : vector<8x16xf32> - %qk_out_max_0_t1 = arith.maximumf %qk_out_0_2, %qk_out_0_3 fastmath : vector<8x16xf32> - %qk_out_max_0_t2 = arith.maximumf %qk_out_max_0_t0, %qk_out_max_0_t1 fastmath : vector<8x16xf32> - %qk_out_max_0_t4 = vector.extract_strided_slice %qk_out_max_0_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 0]} : vector<8x16xf32> to vector<8x8xf32> - %qk_out_max_0_t5 = vector.extract_strided_slice %qk_out_max_0_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32> - %qk_out_max_0_t6 = arith.maximumf %qk_out_max_0_t4, %qk_out_max_0_t5 fastmath : vector<8x8xf32> - %qk_out_max_0_t7 = vector.extract_strided_slice %qk_out_max_0_t6 {sizes = [8, 4], strides = [1, 1], offsets = [0, 0]} : vector<8x8xf32> to vector<8x4xf32> - %qk_out_max_0_t8 = vector.extract_strided_slice %qk_out_max_0_t6 {sizes = [8, 4], strides = [1, 1], offsets = [0, 4]} : vector<8x8xf32> to vector<8x4xf32> - %qk_out_max_0_t9 = arith.maximumf %qk_out_max_0_t7, %qk_out_max_0_t8 fastmath : vector<8x4xf32> - %qk_out_max_0_t10 = vector.extract_strided_slice %qk_out_max_0_t9 {sizes = [8, 2], strides = [1, 1], offsets = [0, 0]} : vector<8x4xf32> to vector<8x2xf32> - %qk_out_max_0_t11 = vector.extract_strided_slice %qk_out_max_0_t9 {sizes = [8, 2], strides = [1, 1], offsets = [0, 2]} : vector<8x4xf32> to vector<8x2xf32> - %qk_out_max_0_t12 = arith.maximumf %qk_out_max_0_t10, %qk_out_max_0_t11 fastmath : vector<8x2xf32> - %qk_out_max_0_t13 = vector.extract_strided_slice %qk_out_max_0_t12 {sizes = [8, 1], strides = [1, 1], offsets = [0, 0]} : vector<8x2xf32> to vector<8x1xf32> - %qk_out_max_0_t14 = vector.extract_strided_slice %qk_out_max_0_t12 {sizes = [8, 1], strides = [1, 1], offsets = [0, 1]} : vector<8x2xf32> to vector<8x1xf32> - %qk_out_max_0 = arith.maximumf %qk_out_max_0_t13, %qk_out_max_0_t14 fastmath : vector<8x1xf32> // scale - %qk_out_max_0_scaled = arith.mulf %qk_out_max_0, %qk_scale_8x1 : vector<8x1xf32> // find m_ij_row_0 - %m_ij_row_0 = arith.maximumf %qk_out_max_0_scaled, %m_i_row_0 fastmath : vector<8x1xf32> // scale qk row 0 by qk_scale - 
%qk_out_0_0_scaled = arith.mulf %qk_out_0_0, %qk_scale_8x16 : vector<8x16xf32> - %qk_out_0_1_scaled = arith.mulf %qk_out_0_1, %qk_scale_8x16 : vector<8x16xf32> - %qk_out_0_2_scaled = arith.mulf %qk_out_0_2, %qk_scale_8x16 : vector<8x16xf32> - %qk_out_0_3_scaled = arith.mulf %qk_out_0_3, %qk_scale_8x16 : vector<8x16xf32> // broadcast m_ij_row_0 to 8x16 - %m_ij_row_0_broadcasted_t1 = vector.shape_cast %m_ij_row_0 : vector<8x1xf32> to vector<8xf32> - %m_ij_row_0_broadcasted_t2 = vector.shuffle %m_ij_row_0_broadcasted_t1, %m_ij_row_0_broadcasted_t1 - [ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, - 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4, - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, - 6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7] : vector<8xf32>, vector<8xf32> - %m_ij_row_0_broadcasted = vector.shape_cast %m_ij_row_0_broadcasted_t2 : vector<128xf32> to vector<8x16xf32> // center qk_out by m_ij_row_0 - %qk_out_0_0_centered = arith.subf %qk_out_0_0_scaled, %m_ij_row_0_broadcasted : vector<8x16xf32> - %qk_out_0_1_centered = arith.subf %qk_out_0_1_scaled, %m_ij_row_0_broadcasted : vector<8x16xf32> - %qk_out_0_2_centered = arith.subf %qk_out_0_2_scaled, %m_ij_row_0_broadcasted : vector<8x16xf32> - %qk_out_0_3_centered = arith.subf %qk_out_0_3_scaled, %m_ij_row_0_broadcasted : vector<8x16xf32> // take exp - %qk_out_0_0_exp = math.exp %qk_out_0_0_centered : vector<8x16xf32> - %qk_out_0_1_exp = math.exp %qk_out_0_1_centered : vector<8x16xf32> - %qk_out_0_2_exp = math.exp %qk_out_0_2_centered : vector<8x16xf32> - %qk_out_0_3_exp = math.exp %qk_out_0_3_centered : vector<8x16xf32> // do a sum reduction on exp output - %l_ij_row_0_t0 = arith.addf %qk_out_0_0_exp, %qk_out_0_1_exp : vector<8x16xf32> - %l_ij_row_0_t1 = arith.addf %qk_out_0_2_exp, %qk_out_0_3_exp : vector<8x16xf32> - %l_ij_row_0_t2 = arith.addf %l_ij_row_0_t0, %l_ij_row_0_t1 : vector<8x16xf32> - %l_ij_row_0_t3 = vector.extract_strided_slice %l_ij_row_0_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 0]} : vector<8x16xf32> to vector<8x8xf32> - %l_ij_row_0_t4 = vector.extract_strided_slice %l_ij_row_0_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32> - %l_ij_row_0_t5 = arith.addf %l_ij_row_0_t3, %l_ij_row_0_t4 : vector<8x8xf32> - %l_ij_row_0_t6 = vector.extract_strided_slice %l_ij_row_0_t5 {sizes = [8, 4], strides = [1, 1], offsets = [0, 0]} : vector<8x8xf32> to vector<8x4xf32> - %l_ij_row_0_t7 = vector.extract_strided_slice %l_ij_row_0_t5 {sizes = [8, 4], strides = [1, 1], offsets = [0, 4]} : vector<8x8xf32> to vector<8x4xf32> - %l_ij_row_0_t8 = arith.addf %l_ij_row_0_t6, %l_ij_row_0_t7 : vector<8x4xf32> - %l_ij_row_0_t9 = vector.extract_strided_slice %l_ij_row_0_t8 {sizes = [8, 2], strides = [1, 1], offsets = [0, 0]} : vector<8x4xf32> to vector<8x2xf32> - %l_ij_row_0_t10 = vector.extract_strided_slice %l_ij_row_0_t8 {sizes = [8, 2], strides = [1, 1], offsets = [0, 2]} : vector<8x4xf32> to vector<8x2xf32> - %l_ij_row_0_t11 = arith.addf %l_ij_row_0_t9, %l_ij_row_0_t10 : vector<8x2xf32> - %l_ij_row_0_t12 = vector.extract_strided_slice %l_ij_row_0_t11 {sizes = [8, 1], strides = [1, 1], offsets = [0, 0]} : vector<8x2xf32> to vector<8x1xf32> - %l_ij_row_0_t13 = vector.extract_strided_slice %l_ij_row_0_t11 {sizes = [8, 1], strides = [1, 1], offsets = [0, 1]} : vector<8x2xf32> to vector<8x1xf32> - %l_ij_row_0 = arith.addf %l_ij_row_0_t12, %l_ij_row_0_t13 : vector<8x1xf32> // compute alpha - %alpha_row_0_t1 = 
arith.subf %m_i_row_0, %m_ij_row_0 : vector<8x1xf32> - %alpha_row_0 = math.exp %alpha_row_0_t1 : vector<8x1xf32> // update l_i - %l_i_row_0_new_t1 = arith.mulf %l_i_row_0, %alpha_row_0 : vector<8x1xf32> - %l_i_row_0_new = arith.addf %l_i_row_0_new_t1, %l_ij_row_0 : vector<8x1xf32> // update acc - %alpha_row_0_broadcasted_t1 = vector.shape_cast %alpha_row_0 : vector<8x1xf32> to vector<8xf32> - %alpha_row_0_broadcasted_t2 = vector.shuffle %alpha_row_0_broadcasted_t1, %alpha_row_0_broadcasted_t1 - [ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, - 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4, - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, - 6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7] : vector<8xf32>, vector<8xf32> - %alpha_row_0_broadcasted = vector.shape_cast %alpha_row_0_broadcasted_t2 : vector<128xf32> to vector<8x16xf32> - %acc_in_0_0_updated = arith.mulf %acc_in_0_0, %alpha_row_0_broadcasted : vector<8x16xf32> - %acc_in_0_1_updated = arith.mulf %acc_in_0_1, %alpha_row_0_broadcasted : vector<8x16xf32> - %acc_in_0_2_updated = arith.mulf %acc_in_0_2, %alpha_row_0_broadcasted : vector<8x16xf32> - %acc_in_0_3_updated = arith.mulf %acc_in_0_3, %alpha_row_0_broadcasted : vector<8x16xf32> - - xegpu.compile_hint - // process row 1 of QK_out // do max reduction on qk_out row 1 - %qk_out_max_1_t0 = arith.maximumf %qk_out_1_0, %qk_out_1_1 fastmath : vector<8x16xf32> - %qk_out_max_1_t1 = arith.maximumf %qk_out_1_2, %qk_out_1_3 fastmath : vector<8x16xf32> - %qk_out_max_1_t2 = arith.maximumf %qk_out_max_1_t0, %qk_out_max_1_t1 fastmath : vector<8x16xf32> - %qk_out_max_1_t4 = vector.extract_strided_slice %qk_out_max_1_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 0]} : vector<8x16xf32> to vector<8x8xf32> - %qk_out_max_1_t5 = vector.extract_strided_slice %qk_out_max_1_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32> - %qk_out_max_1_t6 = arith.maximumf %qk_out_max_1_t4, %qk_out_max_1_t5 fastmath : vector<8x8xf32> - %qk_out_max_1_t7 = vector.extract_strided_slice %qk_out_max_1_t6 {sizes = [8, 4], strides = [1, 1], offsets = [0, 0]} : vector<8x8xf32> to vector<8x4xf32> - %qk_out_max_1_t8 = vector.extract_strided_slice %qk_out_max_1_t6 {sizes = [8, 4], strides = [1, 1], offsets = [0, 4]} : vector<8x8xf32> to vector<8x4xf32> - %qk_out_max_1_t9 = arith.maximumf %qk_out_max_1_t7, %qk_out_max_1_t8 fastmath : vector<8x4xf32> - %qk_out_max_1_t10 = vector.extract_strided_slice %qk_out_max_1_t9 {sizes = [8, 2], strides = [1, 1], offsets = [0, 0]} : vector<8x4xf32> to vector<8x2xf32> - %qk_out_max_1_t11 = vector.extract_strided_slice %qk_out_max_1_t9 {sizes = [8, 2], strides = [1, 1], offsets = [0, 2]} : vector<8x4xf32> to vector<8x2xf32> - %qk_out_max_1_t12 = arith.maximumf %qk_out_max_1_t10, %qk_out_max_1_t11 fastmath : vector<8x2xf32> - %qk_out_max_1_t13 = vector.extract_strided_slice %qk_out_max_1_t12 {sizes = [8, 1], strides = [1, 1], offsets = [0, 0]} : vector<8x2xf32> to vector<8x1xf32> - %qk_out_max_1_t14 = vector.extract_strided_slice %qk_out_max_1_t12 {sizes = [8, 1], strides = [1, 1], offsets = [0, 1]} : vector<8x2xf32> to vector<8x1xf32> - %qk_out_max_1 = arith.maximumf %qk_out_max_1_t13, %qk_out_max_1_t14 fastmath : vector<8x1xf32> // scale - %qk_out_max_1_scaled = arith.mulf %qk_out_max_1, %qk_scale_8x1 : vector<8x1xf32> // find m_ij_row_0 - %m_ij_row_1 = arith.maximumf %qk_out_max_1_scaled, %m_i_row_1 fastmath : vector<8x1xf32> // scale qk row 0 by qk_scale - 
%qk_out_1_0_scaled = arith.mulf %qk_out_1_0, %qk_scale_8x16 : vector<8x16xf32> - %qk_out_1_1_scaled = arith.mulf %qk_out_1_1, %qk_scale_8x16 : vector<8x16xf32> - %qk_out_1_2_scaled = arith.mulf %qk_out_1_2, %qk_scale_8x16 : vector<8x16xf32> - %qk_out_1_3_scaled = arith.mulf %qk_out_1_3, %qk_scale_8x16 : vector<8x16xf32> // broadcast m_ij_row_0 to 8x16 - %m_ij_row_1_broadcasted_t1 = vector.shape_cast %m_ij_row_1 : vector<8x1xf32> to vector<8xf32> - %m_ij_row_1_broadcasted_t2 = vector.shuffle %m_ij_row_1_broadcasted_t1, %m_ij_row_1_broadcasted_t1 - [ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, - 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4, - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, - 6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7] : vector<8xf32>, vector<8xf32> - %m_ij_row_1_broadcasted = vector.shape_cast %m_ij_row_1_broadcasted_t2 : vector<128xf32> to vector<8x16xf32> // center qk_out by m_ij_row_0 - %qk_out_1_0_centered = arith.subf %qk_out_1_0_scaled, %m_ij_row_1_broadcasted : vector<8x16xf32> - %qk_out_1_1_centered = arith.subf %qk_out_1_1_scaled, %m_ij_row_1_broadcasted : vector<8x16xf32> - %qk_out_1_2_centered = arith.subf %qk_out_1_2_scaled, %m_ij_row_1_broadcasted : vector<8x16xf32> - %qk_out_1_3_centered = arith.subf %qk_out_1_3_scaled, %m_ij_row_1_broadcasted : vector<8x16xf32> // take exp - %qk_out_1_0_exp = math.exp %qk_out_1_0_centered : vector<8x16xf32> - %qk_out_1_1_exp = math.exp %qk_out_1_1_centered : vector<8x16xf32> - %qk_out_1_2_exp = math.exp %qk_out_1_2_centered : vector<8x16xf32> - %qk_out_1_3_exp = math.exp %qk_out_1_3_centered : vector<8x16xf32> // do a sum reduction on exp output - %l_ij_row_1_t0 = arith.addf %qk_out_1_0_exp, %qk_out_1_1_exp : vector<8x16xf32> - %l_ij_row_1_t1 = arith.addf %qk_out_1_2_exp, %qk_out_1_3_exp : vector<8x16xf32> - %l_ij_row_1_t2 = arith.addf %l_ij_row_1_t0, %l_ij_row_1_t1 : vector<8x16xf32> - %l_ij_row_1_t3 = vector.extract_strided_slice %l_ij_row_1_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 0]} : vector<8x16xf32> to vector<8x8xf32> - %l_ij_row_1_t4 = vector.extract_strided_slice %l_ij_row_1_t2 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32> - %l_ij_row_1_t5 = arith.addf %l_ij_row_1_t3, %l_ij_row_1_t4 : vector<8x8xf32> - %l_ij_row_1_t6 = vector.extract_strided_slice %l_ij_row_1_t5 {sizes = [8, 4], strides = [1, 1], offsets = [0, 0]} : vector<8x8xf32> to vector<8x4xf32> - %l_ij_row_1_t7 = vector.extract_strided_slice %l_ij_row_1_t5 {sizes = [8, 4], strides = [1, 1], offsets = [0, 4]} : vector<8x8xf32> to vector<8x4xf32> - %l_ij_row_1_t8 = arith.addf %l_ij_row_1_t6, %l_ij_row_1_t7 : vector<8x4xf32> - %l_ij_row_1_t9 = vector.extract_strided_slice %l_ij_row_1_t8 {sizes = [8, 2], strides = [1, 1], offsets = [0, 0]} : vector<8x4xf32> to vector<8x2xf32> - %l_ij_row_1_t10 = vector.extract_strided_slice %l_ij_row_1_t8 {sizes = [8, 2], strides = [1, 1], offsets = [0, 2]} : vector<8x4xf32> to vector<8x2xf32> - %l_ij_row_1_t11 = arith.addf %l_ij_row_1_t9, %l_ij_row_1_t10 : vector<8x2xf32> - %l_ij_row_1_t12 = vector.extract_strided_slice %l_ij_row_1_t11 {sizes = [8, 1], strides = [1, 1], offsets = [0, 0]} : vector<8x2xf32> to vector<8x1xf32> - %l_ij_row_1_t13 = vector.extract_strided_slice %l_ij_row_1_t11 {sizes = [8, 1], strides = [1, 1], offsets = [0, 1]} : vector<8x2xf32> to vector<8x1xf32> - %l_ij_row_1 = arith.addf %l_ij_row_1_t12, %l_ij_row_1_t13 : vector<8x1xf32> // compute alpha - %alpha_row_1_t1 = 
arith.subf %m_i_row_1, %m_ij_row_1 : vector<8x1xf32> - %alpha_row_1 = math.exp %alpha_row_1_t1 : vector<8x1xf32> // update l_i - %l_i_row_1_new_t1 = arith.mulf %l_i_row_1, %alpha_row_1 : vector<8x1xf32> - %l_i_row_1_new = arith.addf %l_i_row_1_new_t1, %l_ij_row_1 : vector<8x1xf32> // update acc - %alpha_row_1_broadcasted_t1 = vector.shape_cast %alpha_row_1 : vector<8x1xf32> to vector<8xf32> - %alpha_row_1_broadcasted_t2 = vector.shuffle %alpha_row_1_broadcasted_t1, %alpha_row_1_broadcasted_t1 - [ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, - 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4, - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, - 6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7] : vector<8xf32>, vector<8xf32> - %alpha_row_1_broadcasted = vector.shape_cast %alpha_row_1_broadcasted_t2 : vector<128xf32> to vector<8x16xf32> - %acc_in_1_0_updated = arith.mulf %acc_in_1_0, %alpha_row_1_broadcasted : vector<8x16xf32> - %acc_in_1_1_updated = arith.mulf %acc_in_1_1, %alpha_row_1_broadcasted : vector<8x16xf32> - %acc_in_1_2_updated = arith.mulf %acc_in_1_2, %alpha_row_1_broadcasted : vector<8x16xf32> - %acc_in_1_3_updated = arith.mulf %acc_in_1_3, %alpha_row_1_broadcasted : vector<8x16xf32> - // convert qk_out_tile to A format for DPAS for p * v computation - %qk_out_0_0_f16 = arith.truncf %qk_out_0_0_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_0_1_f16 = arith.truncf %qk_out_0_1_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_0_2_f16 = arith.truncf %qk_out_0_2_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_0_3_f16 = arith.truncf %qk_out_0_3_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_1_0_f16 = arith.truncf %qk_out_1_0_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_1_1_f16 = arith.truncf %qk_out_1_1_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_1_2_f16 = arith.truncf %qk_out_1_2_exp : vector<8x16xf32> to vector<8x16xf16> - %qk_out_1_3_f16 = arith.truncf %qk_out_1_3_exp : vector<8x16xf32> to vector<8x16xf16> - - xegpu.compile_hint // load first 16x64 V slices - %v_val_slice_0_0 = xegpu.load_nd %v_tile_slice_0_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_0_1 = xegpu.load_nd %v_tile_slice_0_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_0_2 = xegpu.load_nd %v_tile_slice_0_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_0_3 = xegpu.load_nd %v_tile_slice_0_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %v_tile_slice_0_0_new = xegpu.update_nd_offset %v_tile_slice_0_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_0_1_new = xegpu.update_nd_offset %v_tile_slice_0_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_0_2_new = xegpu.update_nd_offset %v_tile_slice_0_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_0_3_new = xegpu.update_nd_offset %v_tile_slice_0_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint - - - xegpu.compile_hint // compute first iteration update of 16x64 of P * V - %pv_out_0_0_iter0 = 
xegpu.dpas %qk_out_0_0_f16, %v_val_slice_0_0, %acc_in_0_0_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_0_iter0 = xegpu.dpas %qk_out_1_0_f16, %v_val_slice_0_0, %acc_in_1_0_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_1_iter0 = xegpu.dpas %qk_out_0_0_f16, %v_val_slice_0_1, %acc_in_0_1_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_1_iter0 = xegpu.dpas %qk_out_1_0_f16, %v_val_slice_0_1, %acc_in_1_1_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_2_iter0 = xegpu.dpas %qk_out_0_0_f16, %v_val_slice_0_2, %acc_in_0_2_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_2_iter0 = xegpu.dpas %qk_out_1_0_f16, %v_val_slice_0_2, %acc_in_1_2_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_3_iter0 = xegpu.dpas %qk_out_0_0_f16, %v_val_slice_0_3, %acc_in_0_3_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_3_iter0 = xegpu.dpas %qk_out_1_0_f16, %v_val_slice_0_3, %acc_in_1_3_updated : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - xegpu.compile_hint - // load second 16x64 V slices - %v_val_slice_1_0 = xegpu.load_nd %v_tile_slice_1_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_1_1 = xegpu.load_nd %v_tile_slice_1_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_1_2 = xegpu.load_nd %v_tile_slice_1_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_1_3 = xegpu.load_nd %v_tile_slice_1_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %v_tile_slice_1_0_new = xegpu.update_nd_offset %v_tile_slice_1_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_1_1_new = xegpu.update_nd_offset %v_tile_slice_1_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_1_2_new = xegpu.update_nd_offset %v_tile_slice_1_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_1_3_new = xegpu.update_nd_offset %v_tile_slice_1_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint // compute second iteration update of 16x64 of P * V - %pv_out_0_0_iter1 = xegpu.dpas %qk_out_0_1_f16, %v_val_slice_1_0, %pv_out_0_0_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_0_iter1 = xegpu.dpas %qk_out_1_1_f16, %v_val_slice_1_0, %pv_out_1_0_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_1_iter1 = xegpu.dpas %qk_out_0_1_f16, %v_val_slice_1_1, %pv_out_0_1_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_1_iter1 = xegpu.dpas %qk_out_1_1_f16, %v_val_slice_1_1, %pv_out_1_1_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_2_iter1 = xegpu.dpas %qk_out_0_1_f16, %v_val_slice_1_2, %pv_out_0_2_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> 
vector<8x16xf32> - %pv_out_1_2_iter1 = xegpu.dpas %qk_out_1_1_f16, %v_val_slice_1_2, %pv_out_1_2_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_3_iter1 = xegpu.dpas %qk_out_0_1_f16, %v_val_slice_1_3, %pv_out_0_3_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_3_iter1 = xegpu.dpas %qk_out_1_1_f16, %v_val_slice_1_3, %pv_out_1_3_iter0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - xegpu.compile_hint - // load third 16x64 V slices - %v_val_slice_2_0 = xegpu.load_nd %v_tile_slice_2_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_2_1 = xegpu.load_nd %v_tile_slice_2_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_2_2 = xegpu.load_nd %v_tile_slice_2_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_2_3 = xegpu.load_nd %v_tile_slice_2_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %v_tile_slice_2_0_new = xegpu.update_nd_offset %v_tile_slice_2_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_2_1_new = xegpu.update_nd_offset %v_tile_slice_2_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_2_2_new = xegpu.update_nd_offset %v_tile_slice_2_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_2_3_new = xegpu.update_nd_offset %v_tile_slice_2_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint // compute third iteration update of 16x64 of P * V - %pv_out_0_0_iter2 = xegpu.dpas %qk_out_0_2_f16, %v_val_slice_2_0, %pv_out_0_0_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_0_iter2 = xegpu.dpas %qk_out_1_2_f16, %v_val_slice_2_0, %pv_out_1_0_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_1_iter2 = xegpu.dpas %qk_out_0_2_f16, %v_val_slice_2_1, %pv_out_0_1_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_1_iter2 = xegpu.dpas %qk_out_1_2_f16, %v_val_slice_2_1, %pv_out_1_1_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_2_iter2 = xegpu.dpas %qk_out_0_2_f16, %v_val_slice_2_2, %pv_out_0_2_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_2_iter2 = xegpu.dpas %qk_out_1_2_f16, %v_val_slice_2_2, %pv_out_1_2_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_3_iter2 = xegpu.dpas %qk_out_0_2_f16, %v_val_slice_2_3, %pv_out_0_3_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_3_iter2 = xegpu.dpas %qk_out_1_2_f16, %v_val_slice_2_3, %pv_out_1_3_iter1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - xegpu.compile_hint - // load forth 16x64 V slices - %v_val_slice_3_0 = xegpu.load_nd %v_tile_slice_3_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_3_1 = xegpu.load_nd 
%v_tile_slice_3_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_3_2 = xegpu.load_nd %v_tile_slice_3_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %v_val_slice_3_3 = xegpu.load_nd %v_tile_slice_3_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - xegpu.compile_hint // update offsets - %v_tile_slice_3_0_new = xegpu.update_nd_offset %v_tile_slice_3_0, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_3_1_new = xegpu.update_nd_offset %v_tile_slice_3_1, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_3_2_new = xegpu.update_nd_offset %v_tile_slice_3_2, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - %v_tile_slice_3_3_new = xegpu.update_nd_offset %v_tile_slice_3_3, [%BLOCK_N , %c0] : !xegpu.tensor_desc<16x16xf16> - xegpu.compile_hint // compute third iteration update of 16x64 of P * V - %pv_out_0_0_iter3 = xegpu.dpas %qk_out_0_3_f16, %v_val_slice_3_0, %pv_out_0_0_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_0_iter3 = xegpu.dpas %qk_out_1_3_f16, %v_val_slice_3_0, %pv_out_1_0_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_1_iter3 = xegpu.dpas %qk_out_0_3_f16, %v_val_slice_3_1, %pv_out_0_1_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_1_iter3 = xegpu.dpas %qk_out_1_3_f16, %v_val_slice_3_1, %pv_out_1_1_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_2_iter3 = xegpu.dpas %qk_out_0_3_f16, %v_val_slice_3_2, %pv_out_0_2_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_2_iter3 = xegpu.dpas %qk_out_1_3_f16, %v_val_slice_3_2, %pv_out_1_2_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_0_3_iter3 = xegpu.dpas %qk_out_0_3_f16, %v_val_slice_3_3, %pv_out_0_3_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %pv_out_1_3_iter3 = xegpu.dpas %qk_out_1_3_f16, %v_val_slice_3_3, %pv_out_1_3_iter2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - xegpu.compile_hint - - xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier - - scf.yield - %pv_out_0_0_iter3, %pv_out_0_1_iter3, %pv_out_0_2_iter3, %pv_out_0_3_iter3, - %pv_out_1_0_iter3, %pv_out_1_1_iter3, %pv_out_1_2_iter3, %pv_out_1_3_iter3, - %k_tile_slice_0_0_new, %k_tile_slice_0_1_new, %k_tile_slice_0_2_new, %k_tile_slice_0_3_new, - %k_tile_slice_1_0_new, %k_tile_slice_1_1_new, %k_tile_slice_1_2_new, %k_tile_slice_1_3_new, - %k_tile_slice_2_0_new, %k_tile_slice_2_1_new, %k_tile_slice_2_2_new, %k_tile_slice_2_3_new, - %k_tile_slice_3_0_new, %k_tile_slice_3_1_new, %k_tile_slice_3_2_new, %k_tile_slice_3_3_new, - - %v_tile_slice_0_0_new, %v_tile_slice_0_1_new, %v_tile_slice_0_2_new, %v_tile_slice_0_3_new, - %v_tile_slice_1_0_new, %v_tile_slice_1_1_new, %v_tile_slice_1_2_new, %v_tile_slice_1_3_new, - %v_tile_slice_2_0_new, %v_tile_slice_2_1_new, %v_tile_slice_2_2_new, %v_tile_slice_2_3_new, - %v_tile_slice_3_0_new, %v_tile_slice_3_1_new, %v_tile_slice_3_2_new, %v_tile_slice_3_3_new, - - %k_prefetch_tile_new, %v_prefetch_tile_new, - %m_ij_row_0, %m_ij_row_1, %l_i_row_0_new, 
%l_i_row_1_new : - vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, - - !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, - !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, - !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, - vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32> - + %c1_i8 = arith.constant 1 : i8 + %c8_i8 = arith.constant 8 : i8 + %90 = xegpu.init_nbarrier %c1_i8, %c8_i8 : i8, i8 -> !xegpu.nbarrier + %91:46 = scf.for %arg27 = %c0 to %arg23 step %arg26 iter_args(%arg28 = %62, %arg29 = %62, %arg30 = %62, %arg31 = %62, %arg32 = %62, %arg33 = %62, %arg34 = %62, %arg35 = %62, %arg36 = %13, %arg37 = %14, %arg38 = %15, %arg39 = %16, %arg40 = %17, %arg41 = %18, %arg42 = %19, %arg43 = %20, %arg44 = %21, %arg45 = %22, %arg46 = %23, %arg47 = %24, %arg48 = %25, %arg49 = %26, %arg50 = %27, %arg51 = %28, %arg52 = %29, %arg53 = %30, %arg54 = %31, %arg55 = %32, %arg56 = %33, %arg57 = %34, %arg58 = %35, %arg59 = %36, %arg60 = %37, %arg61 = %38, %arg62 = %39, %arg63 = %40, %arg64 = %41, %arg65 = %42, %arg66 = %43, %arg67 = %44, %arg68 = %53, %arg69 = %57, %arg70 = %58, %arg71 = %59, %arg72 = %60, %arg73 = %61) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32>) { + xegpu.nbarrier_arrive %90 : !xegpu.nbarrier + xegpu.prefetch_nd %arg68 : !xegpu.tensor_desc<16x32xf16> + %130 = xegpu.update_nd_offset 
%arg68, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + xegpu.compile_hint + xegpu.prefetch_nd %arg69 : !xegpu.tensor_desc<16x32xf16> + %131 = xegpu.update_nd_offset %arg69, [%arg26, %c0] : !xegpu.tensor_desc<16x32xf16> + xegpu.compile_hint + %132 = xegpu.load_nd %arg36 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %133 = xegpu.load_nd %arg37 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %134 = xegpu.load_nd %arg38 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %135 = xegpu.load_nd %arg39 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %136 = xegpu.update_nd_offset %arg36, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %137 = xegpu.update_nd_offset %arg37, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %138 = xegpu.update_nd_offset %arg38, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %139 = xegpu.update_nd_offset %arg39, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %140 = xegpu.dpas %75, %132, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %141 = xegpu.dpas %77, %132, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %142 = xegpu.dpas %79, %133, %140 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %143 = xegpu.dpas %81, %133, %141 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %144 = xegpu.dpas %83, %134, %142 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %145 = xegpu.dpas %85, %134, %143 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %146 = xegpu.dpas %87, %135, %144 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %147 = xegpu.dpas %89, %135, %145 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %148 = xegpu.load_nd %arg40 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %149 = xegpu.load_nd %arg41 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %150 = xegpu.load_nd %arg42 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %151 = xegpu.load_nd %arg43 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %152 = xegpu.update_nd_offset %arg40, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %153 = xegpu.update_nd_offset %arg41, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %154 = 
xegpu.update_nd_offset %arg42, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %155 = xegpu.update_nd_offset %arg43, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %156 = xegpu.dpas %75, %148, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %157 = xegpu.dpas %79, %149, %156 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %158 = xegpu.dpas %83, %150, %157 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %159 = xegpu.dpas %87, %151, %158 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %160 = xegpu.dpas %77, %148, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %161 = xegpu.dpas %81, %149, %160 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %162 = xegpu.dpas %85, %150, %161 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %163 = xegpu.dpas %89, %151, %162 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %164 = xegpu.load_nd %arg44 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %165 = xegpu.load_nd %arg45 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %166 = xegpu.load_nd %arg46 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %167 = xegpu.load_nd %arg47 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %168 = xegpu.update_nd_offset %arg44, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %169 = xegpu.update_nd_offset %arg45, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %170 = xegpu.update_nd_offset %arg46, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %171 = xegpu.update_nd_offset %arg47, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %172 = xegpu.dpas %75, %164, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %173 = xegpu.dpas %79, %165, %172 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %174 = xegpu.dpas %83, %166, %173 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %175 = xegpu.dpas %87, %167, %174 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %176 = xegpu.dpas %77, %164, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %177 = xegpu.dpas %81, %165, %176 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %178 = xegpu.dpas %85, %166, %177 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %179 = xegpu.dpas %89, %167, %178 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %180 = xegpu.load_nd %arg48 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %181 = 
xegpu.load_nd %arg49 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %182 = xegpu.load_nd %arg50 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %183 = xegpu.load_nd %arg51 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %184 = xegpu.update_nd_offset %arg48, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %185 = xegpu.update_nd_offset %arg49, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %186 = xegpu.update_nd_offset %arg50, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %187 = xegpu.update_nd_offset %arg51, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %188 = xegpu.dpas %75, %180, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %189 = xegpu.dpas %79, %181, %188 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %190 = xegpu.dpas %83, %182, %189 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %191 = xegpu.dpas %87, %183, %190 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %192 = xegpu.dpas %77, %180, %62 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %193 = xegpu.dpas %81, %181, %192 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %194 = xegpu.dpas %85, %182, %193 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %195 = xegpu.dpas %89, %183, %194 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %196 = arith.maximumf %146, %159 fastmath : vector<8x16xf32> + %197 = arith.maximumf %175, %191 fastmath : vector<8x16xf32> + %198 = arith.maximumf %196, %197 fastmath : vector<8x16xf32> + %199 = vector.extract_strided_slice %198 {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %200 = vector.extract_strided_slice %198 {offsets = [0, 8], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %201 = arith.maximumf %199, %200 fastmath : vector<8x8xf32> + %202 = vector.extract_strided_slice %201 {offsets = [0, 0], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %203 = vector.extract_strided_slice %201 {offsets = [0, 4], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %204 = arith.maximumf %202, %203 fastmath : vector<8x4xf32> + %205 = vector.extract_strided_slice %204 {offsets = [0, 0], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %206 = vector.extract_strided_slice %204 {offsets = [0, 2], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %207 = arith.maximumf %205, %206 fastmath : vector<8x2xf32> + %208 = vector.extract_strided_slice %207 {offsets = [0, 0], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %209 = vector.extract_strided_slice %207 {offsets = [0, 1], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %210 = arith.maximumf %208, %209 fastmath : vector<8x1xf32> + %211 = arith.mulf %210, %63 : vector<8x1xf32> + %212 = arith.maximumf 
%211, %arg70 fastmath : vector<8x1xf32> + %213 = arith.mulf %146, %65 : vector<8x16xf32> + %214 = arith.mulf %159, %65 : vector<8x16xf32> + %215 = arith.mulf %175, %65 : vector<8x16xf32> + %216 = arith.mulf %191, %65 : vector<8x16xf32> + %217 = vector.shape_cast %212 : vector<8x1xf32> to vector<8xf32> + %218 = vector.shuffle %217, %217 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xf32>, vector<8xf32> + %219 = vector.shape_cast %218 : vector<128xf32> to vector<8x16xf32> + %220 = arith.subf %213, %219 : vector<8x16xf32> + %221 = arith.subf %214, %219 : vector<8x16xf32> + %222 = arith.subf %215, %219 : vector<8x16xf32> + %223 = arith.subf %216, %219 : vector<8x16xf32> + %224 = math.exp %220 : vector<8x16xf32> + %225 = math.exp %221 : vector<8x16xf32> + %226 = math.exp %222 : vector<8x16xf32> + %227 = math.exp %223 : vector<8x16xf32> + %228 = arith.addf %224, %225 : vector<8x16xf32> + %229 = arith.addf %226, %227 : vector<8x16xf32> + %230 = arith.addf %228, %229 : vector<8x16xf32> + %231 = vector.extract_strided_slice %230 {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %232 = vector.extract_strided_slice %230 {offsets = [0, 8], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %233 = arith.addf %231, %232 : vector<8x8xf32> + %234 = vector.extract_strided_slice %233 {offsets = [0, 0], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %235 = vector.extract_strided_slice %233 {offsets = [0, 4], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %236 = arith.addf %234, %235 : vector<8x4xf32> + %237 = vector.extract_strided_slice %236 {offsets = [0, 0], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %238 = vector.extract_strided_slice %236 {offsets = [0, 2], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %239 = arith.addf %237, %238 : vector<8x2xf32> + %240 = vector.extract_strided_slice %239 {offsets = [0, 0], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %241 = vector.extract_strided_slice %239 {offsets = [0, 1], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %242 = arith.addf %240, %241 : vector<8x1xf32> + %243 = arith.subf %arg70, %212 : vector<8x1xf32> + %244 = math.exp %243 : vector<8x1xf32> + %245 = arith.mulf %arg72, %244 : vector<8x1xf32> + %246 = arith.addf %245, %242 : vector<8x1xf32> + %247 = vector.shape_cast %244 : vector<8x1xf32> to vector<8xf32> + %248 = vector.shuffle %247, %247 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xf32>, vector<8xf32> + %249 = vector.shape_cast %248 : vector<128xf32> to vector<8x16xf32> + %250 = arith.mulf %arg28, %249 : vector<8x16xf32> + %251 = arith.mulf %arg29, %249 : vector<8x16xf32> + %252 = arith.mulf %arg30, %249 : vector<8x16xf32> + %253 = 
arith.mulf %arg31, %249 : vector<8x16xf32> + xegpu.compile_hint + %254 = arith.maximumf %147, %163 fastmath : vector<8x16xf32> + %255 = arith.maximumf %179, %195 fastmath : vector<8x16xf32> + %256 = arith.maximumf %254, %255 fastmath : vector<8x16xf32> + %257 = vector.extract_strided_slice %256 {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %258 = vector.extract_strided_slice %256 {offsets = [0, 8], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %259 = arith.maximumf %257, %258 fastmath : vector<8x8xf32> + %260 = vector.extract_strided_slice %259 {offsets = [0, 0], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %261 = vector.extract_strided_slice %259 {offsets = [0, 4], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %262 = arith.maximumf %260, %261 fastmath : vector<8x4xf32> + %263 = vector.extract_strided_slice %262 {offsets = [0, 0], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %264 = vector.extract_strided_slice %262 {offsets = [0, 2], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %265 = arith.maximumf %263, %264 fastmath : vector<8x2xf32> + %266 = vector.extract_strided_slice %265 {offsets = [0, 0], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %267 = vector.extract_strided_slice %265 {offsets = [0, 1], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %268 = arith.maximumf %266, %267 fastmath : vector<8x1xf32> + %269 = arith.mulf %268, %63 : vector<8x1xf32> + %270 = arith.maximumf %269, %arg71 fastmath : vector<8x1xf32> + %271 = arith.mulf %147, %65 : vector<8x16xf32> + %272 = arith.mulf %163, %65 : vector<8x16xf32> + %273 = arith.mulf %179, %65 : vector<8x16xf32> + %274 = arith.mulf %195, %65 : vector<8x16xf32> + %275 = vector.shape_cast %270 : vector<8x1xf32> to vector<8xf32> + %276 = vector.shuffle %275, %275 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xf32>, vector<8xf32> + %277 = vector.shape_cast %276 : vector<128xf32> to vector<8x16xf32> + %278 = arith.subf %271, %277 : vector<8x16xf32> + %279 = arith.subf %272, %277 : vector<8x16xf32> + %280 = arith.subf %273, %277 : vector<8x16xf32> + %281 = arith.subf %274, %277 : vector<8x16xf32> + %282 = math.exp %278 : vector<8x16xf32> + %283 = math.exp %279 : vector<8x16xf32> + %284 = math.exp %280 : vector<8x16xf32> + %285 = math.exp %281 : vector<8x16xf32> + %286 = arith.addf %282, %283 : vector<8x16xf32> + %287 = arith.addf %284, %285 : vector<8x16xf32> + %288 = arith.addf %286, %287 : vector<8x16xf32> + %289 = vector.extract_strided_slice %288 {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %290 = vector.extract_strided_slice %288 {offsets = [0, 8], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %291 = arith.addf %289, %290 : vector<8x8xf32> + %292 = vector.extract_strided_slice %291 {offsets = [0, 0], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to vector<8x4xf32> + %293 = vector.extract_strided_slice %291 {offsets = [0, 4], sizes = [8, 4], strides = [1, 1]} : vector<8x8xf32> to 
vector<8x4xf32> + %294 = arith.addf %292, %293 : vector<8x4xf32> + %295 = vector.extract_strided_slice %294 {offsets = [0, 0], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %296 = vector.extract_strided_slice %294 {offsets = [0, 2], sizes = [8, 2], strides = [1, 1]} : vector<8x4xf32> to vector<8x2xf32> + %297 = arith.addf %295, %296 : vector<8x2xf32> + %298 = vector.extract_strided_slice %297 {offsets = [0, 0], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %299 = vector.extract_strided_slice %297 {offsets = [0, 1], sizes = [8, 1], strides = [1, 1]} : vector<8x2xf32> to vector<8x1xf32> + %300 = arith.addf %298, %299 : vector<8x1xf32> + %301 = arith.subf %arg71, %270 : vector<8x1xf32> + %302 = math.exp %301 : vector<8x1xf32> + %303 = arith.mulf %arg73, %302 : vector<8x1xf32> + %304 = arith.addf %303, %300 : vector<8x1xf32> + %305 = vector.shape_cast %302 : vector<8x1xf32> to vector<8xf32> + %306 = vector.shuffle %305, %305 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xf32>, vector<8xf32> + %307 = vector.shape_cast %306 : vector<128xf32> to vector<8x16xf32> + %308 = arith.mulf %arg32, %307 : vector<8x16xf32> + %309 = arith.mulf %arg33, %307 : vector<8x16xf32> + %310 = arith.mulf %arg34, %307 : vector<8x16xf32> + %311 = arith.mulf %arg35, %307 : vector<8x16xf32> + %312 = arith.truncf %224 : vector<8x16xf32> to vector<8x16xf16> + %313 = arith.truncf %225 : vector<8x16xf32> to vector<8x16xf16> + %314 = arith.truncf %226 : vector<8x16xf32> to vector<8x16xf16> + %315 = arith.truncf %227 : vector<8x16xf32> to vector<8x16xf16> + %316 = arith.truncf %282 : vector<8x16xf32> to vector<8x16xf16> + %317 = arith.truncf %283 : vector<8x16xf32> to vector<8x16xf16> + %318 = arith.truncf %284 : vector<8x16xf32> to vector<8x16xf16> + %319 = arith.truncf %285 : vector<8x16xf32> to vector<8x16xf16> + xegpu.compile_hint + %320 = xegpu.load_nd %arg52 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %321 = xegpu.load_nd %arg53 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %322 = xegpu.load_nd %arg54 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %323 = xegpu.load_nd %arg55 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %324 = xegpu.update_nd_offset %arg52, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %325 = xegpu.update_nd_offset %arg53, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %326 = xegpu.update_nd_offset %arg54, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %327 = xegpu.update_nd_offset %arg55, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + xegpu.compile_hint + %328 = xegpu.dpas %312, %320, %250 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %329 = xegpu.dpas %316, %320, %308 : vector<8x16xf16>, 
vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %330 = xegpu.dpas %312, %321, %251 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %331 = xegpu.dpas %316, %321, %309 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %332 = xegpu.dpas %312, %322, %252 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %333 = xegpu.dpas %316, %322, %310 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %334 = xegpu.dpas %312, %323, %253 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %335 = xegpu.dpas %316, %323, %311 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %336 = xegpu.load_nd %arg56 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %337 = xegpu.load_nd %arg57 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %338 = xegpu.load_nd %arg58 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %339 = xegpu.load_nd %arg59 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %340 = xegpu.update_nd_offset %arg56, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %341 = xegpu.update_nd_offset %arg57, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %342 = xegpu.update_nd_offset %arg58, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %343 = xegpu.update_nd_offset %arg59, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %344 = xegpu.dpas %313, %336, %328 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %345 = xegpu.dpas %317, %336, %329 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %346 = xegpu.dpas %313, %337, %330 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %347 = xegpu.dpas %317, %337, %331 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %348 = xegpu.dpas %313, %338, %332 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %349 = xegpu.dpas %317, %338, %333 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %350 = xegpu.dpas %313, %339, %334 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %351 = xegpu.dpas %317, %339, %335 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %352 = xegpu.load_nd %arg60 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %353 = xegpu.load_nd %arg61 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %354 = xegpu.load_nd %arg62 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %355 = xegpu.load_nd %arg63 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + 
xegpu.compile_hint + %356 = xegpu.update_nd_offset %arg60, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %357 = xegpu.update_nd_offset %arg61, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %358 = xegpu.update_nd_offset %arg62, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %359 = xegpu.update_nd_offset %arg63, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %360 = xegpu.dpas %314, %352, %344 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %361 = xegpu.dpas %318, %352, %345 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %362 = xegpu.dpas %314, %353, %346 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %363 = xegpu.dpas %318, %353, %347 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %364 = xegpu.dpas %314, %354, %348 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %365 = xegpu.dpas %318, %354, %349 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %366 = xegpu.dpas %314, %355, %350 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %367 = xegpu.dpas %318, %355, %351 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %368 = xegpu.load_nd %arg64 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %369 = xegpu.load_nd %arg65 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %370 = xegpu.load_nd %arg66 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %371 = xegpu.load_nd %arg67 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.compile_hint + %372 = xegpu.update_nd_offset %arg64, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %373 = xegpu.update_nd_offset %arg65, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %374 = xegpu.update_nd_offset %arg66, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + %375 = xegpu.update_nd_offset %arg67, [%arg26, %c0] : !xegpu.tensor_desc<16x16xf16> + xegpu.compile_hint + %376 = xegpu.dpas %315, %368, %360 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %377 = xegpu.dpas %319, %368, %361 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %378 = xegpu.dpas %315, %369, %362 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %379 = xegpu.dpas %319, %369, %363 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %380 = xegpu.dpas %315, %370, %364 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %381 = xegpu.dpas %319, %370, %365 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %382 = xegpu.dpas %315, %371, %366 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %383 = xegpu.dpas %319, %371, %367 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + xegpu.nbarrier_wait %90 : !xegpu.nbarrier + scf.yield %376, %378, %380, %382, %377, %379, %381, %383, %136, %137, %138, %139, %152, %153, %154, %155, %168, %169, %170, %171, %184, %185, 
%186, %187, %324, %325, %326, %327, %340, %341, %342, %343, %356, %357, %358, %359, %372, %373, %374, %375, %130, %131, %212, %270, %246, %304 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32>, vector<8x1xf32> } // divide acc output by l_i - %l_i_row_0_broadcast_t1 = vector.shape_cast %result#44 : vector<8x1xf32> to vector<8xf32> - %l_i_row_0_broadcast_t2 = vector.shuffle %l_i_row_0_broadcast_t1, %l_i_row_0_broadcast_t1 - [ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, - 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4, - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, - 6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7] : vector<8xf32>, vector<8xf32> - %l_i_row_0_broadcast = vector.shape_cast %l_i_row_0_broadcast_t2 : vector<128xf32> to vector<8x16xf32> - - %l_i_row_1_broadcast_t1 = vector.shape_cast %result#45 : vector<8x1xf32> to vector<8xf32> - %l_i_row_1_broadcast_t2 = vector.shuffle %l_i_row_1_broadcast_t1, %l_i_row_1_broadcast_t1 - [ 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, - 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4, - 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5, - 6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6, - 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7] : vector<8xf32>, vector<8xf32> - %l_i_row_1_broadcast = vector.shape_cast %l_i_row_1_broadcast_t2 : vector<128xf32> to vector<8x16xf32> - %o_val_final_0_0_t = arith.divf %result#0, %l_i_row_0_broadcast : vector<8x16xf32> - %o_val_final_0_1_t = arith.divf %result#1, %l_i_row_0_broadcast : vector<8x16xf32> - %o_val_final_0_2_t = arith.divf %result#2, %l_i_row_0_broadcast : vector<8x16xf32> - %o_val_final_0_3_t = arith.divf %result#3, %l_i_row_0_broadcast : vector<8x16xf32> - %o_val_final_1_0_t = arith.divf %result#4, %l_i_row_1_broadcast : vector<8x16xf32> - %o_val_final_1_1_t = arith.divf %result#5, %l_i_row_1_broadcast : vector<8x16xf32> - %o_val_final_1_2_t = arith.divf %result#6, %l_i_row_1_broadcast : vector<8x16xf32> - %o_val_final_1_3_t = arith.divf %result#7, %l_i_row_1_broadcast : vector<8x16xf32> - - %o_val_final_0_0 = arith.truncf %o_val_final_0_0_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_0_1 = arith.truncf %o_val_final_0_1_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_0_2 = arith.truncf 
%o_val_final_0_2_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_0_3 = arith.truncf %o_val_final_0_3_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_1_0 = arith.truncf %o_val_final_1_0_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_1_1 = arith.truncf %o_val_final_1_1_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_1_2 = arith.truncf %o_val_final_1_2_t : vector<8x16xf32> to vector<8x16xf16> - %o_val_final_1_3 = arith.truncf %o_val_final_1_3_t : vector<8x16xf32> to vector<8x16xf16> - // O tile, max size is 8x32 - %o_tile_init_0_0 = xegpu.create_nd_tdesc %Out [%sg_q_x_offset, %c0], shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref -> !xegpu.tensor_desc<8x32xf16> - %o_tile_init_0_1 = xegpu.update_nd_offset %o_tile_init_0_0, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - %o_tile_init_1_0 = xegpu.update_nd_offset %o_tile_init_0_0, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - %o_tile_init_1_1 = xegpu.update_nd_offset %o_tile_init_1_0, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - - %o_val_8x32_0_0_t1 = vector.shuffle %o_val_final_0_0, %o_val_final_0_1 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %o_val_8x32_0_0_t2 = vector.shape_cast %o_val_8x32_0_0_t1 : vector<16x16xf16> to vector<256xf16> - %o_val_8x32_0_0_t3 = vector.shape_cast %o_val_8x32_0_0_t2 : vector<256xf16> to vector<8x32xf16> - xegpu.store_nd %o_val_8x32_0_0_t3, %o_tile_init_0_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %92 = vector.shape_cast %91#44 : vector<8x1xf32> to vector<8xf32> + %93 = vector.shuffle %92, %92 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xf32>, vector<8xf32> + %94 = vector.shape_cast %93 : vector<128xf32> to vector<8x16xf32> + %95 = vector.shape_cast %91#45 : vector<8x1xf32> to vector<8xf32> + %96 = vector.shuffle %95, %95 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] : vector<8xf32>, vector<8xf32> + %97 = vector.shape_cast %96 : vector<128xf32> to vector<8x16xf32> + %98 = arith.divf %91#0, %94 : vector<8x16xf32> + %99 = arith.divf %91#1, %94 : vector<8x16xf32> + %100 = arith.divf %91#2, %94 : vector<8x16xf32> + %101 = arith.divf %91#3, %94 : vector<8x16xf32> + %102 = arith.divf %91#4, %97 : vector<8x16xf32> + %103 = arith.divf %91#5, %97 : vector<8x16xf32> + %104 = arith.divf %91#6, %97 : vector<8x16xf32> + %105 = arith.divf %91#7, %97 : vector<8x16xf32> + %106 = arith.truncf %98 : vector<8x16xf32> to vector<8x16xf16> + %107 = arith.truncf %99 : vector<8x16xf32> to vector<8x16xf16> + %108 = arith.truncf %100 : vector<8x16xf32> to vector<8x16xf16> + %109 = arith.truncf %101 : vector<8x16xf32> to vector<8x16xf16> + %110 = arith.truncf %102 : vector<8x16xf32> to vector<8x16xf16> + %111 = arith.truncf %103 : vector<8x16xf32> to vector<8x16xf16> + 
%112 = arith.truncf %104 : vector<8x16xf32> to vector<8x16xf16> + %113 = arith.truncf %105 : vector<8x16xf32> to vector<8x16xf16> + %114 = xegpu.create_nd_tdesc %arg3[%8, %c0], shape : [%2, %arg25], strides : [%arg25, %c1] : memref -> !xegpu.tensor_desc<8x32xf16> + %115 = xegpu.update_nd_offset %114, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %116 = xegpu.update_nd_offset %114, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + %117 = xegpu.update_nd_offset %116, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %118 = vector.shuffle %106, %107 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %119 = vector.shape_cast %118 : vector<16x16xf16> to vector<256xf16> + %120 = vector.shape_cast %119 : vector<256xf16> to vector<8x32xf16> + xegpu.store_nd %120, %114 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %o_val_8x32_0_1_t1 = vector.shuffle %o_val_final_0_2, %o_val_final_0_3 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %o_val_8x32_0_1_t2 = vector.shape_cast %o_val_8x32_0_1_t1 : vector<16x16xf16> to vector<256xf16> - %o_val_8x32_0_1_t3 = vector.shape_cast %o_val_8x32_0_1_t2 : vector<256xf16> to vector<8x32xf16> - xegpu.store_nd %o_val_8x32_0_1_t3, %o_tile_init_0_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %121 = vector.shuffle %108, %109 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %122 = vector.shape_cast %121 : vector<16x16xf16> to vector<256xf16> + %123 = vector.shape_cast %122 : vector<256xf16> to vector<8x32xf16> + xegpu.store_nd %123, %115 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %o_val_8x32_1_0_t1 = vector.shuffle %o_val_final_1_0, %o_val_final_1_1 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %o_val_8x32_1_0_t2 = vector.shape_cast %o_val_8x32_1_0_t1 : vector<16x16xf16> to vector<256xf16> - %o_val_8x32_1_0_t3 = vector.shape_cast %o_val_8x32_1_0_t2 : vector<256xf16> to vector<8x32xf16> - xegpu.store_nd %o_val_8x32_1_0_t3, %o_tile_init_1_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %124 = vector.shuffle %110, %111 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %125 = vector.shape_cast %124 : vector<16x16xf16> to vector<256xf16> + %126 = vector.shape_cast %125 : vector<256xf16> to vector<8x32xf16> + xegpu.store_nd %126, %116 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %o_val_8x32_1_1_t1 = vector.shuffle %o_val_final_1_2, %o_val_final_1_3 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %o_val_8x32_1_1_t2 = vector.shape_cast %o_val_8x32_1_1_t1 : vector<16x16xf16> to vector<256xf16> - %o_val_8x32_1_1_t3 = vector.shape_cast %o_val_8x32_1_1_t2 : vector<256xf16> to vector<8x32xf16> - xegpu.store_nd %o_val_8x32_1_1_t3, %o_tile_init_1_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %127 = 
vector.shuffle %112, %113 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %128 = vector.shape_cast %127 : vector<16x16xf16> to vector<256xf16> + %129 = vector.shape_cast %128 : vector<256xf16> to vector<8x32xf16> + xegpu.store_nd %129, %117 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - gpu.return } } - - func.func @gpu_impl(%q : memref, %k : memref, %v : memref, - %o : memref, %Z : index, %H : index, %N_CTX : index, %D_HEAD : index, - %sm_scale : f32) -> memref { - + func.func @gpu_impl(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: f32) -> memref { + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c1_i64 = arith.constant 1 : i64 - - %Z_H_N_t0 = arith.muli %Z, %H : index - %Z_H_N = arith.muli %Z_H_N_t0, %N_CTX : index - // %Z_i64 = index.castu %Z : index to i64 // %H_i64 = index.castu %H : index to i64 // %N_CTX_i64 = index.castu %N_CTX : index to i64 // %D_HEAD_i64 = index.castu %D_HEAD : index to i64 - //strides - %stride_1 = arith.muli %N_CTX, %D_HEAD : index - %stride_2 = arith.muli %stride_1, %H : index - - %q_gpu = gpu.alloc host_shared (%Z_H_N, %D_HEAD) : memref - %k_gpu = gpu.alloc host_shared (%Z_H_N, %D_HEAD) : memref - %v_gpu = gpu.alloc host_shared (%Z_H_N, %D_HEAD) : memref - %o_gpu = gpu.alloc host_shared (%Z_H_N, %D_HEAD) : memref // %m_gpu = gpu.alloc host_shared (%Z, %H, %N_CTX) : memref - // copy from CPU to - memref.copy %q, %q_gpu : memref to memref - memref.copy %k, %k_gpu : memref to memref - memref.copy %v, %v_gpu : memref to memref - memref.copy %o, %o_gpu : memref to memref // memref.copy %m, %m_gpu : memref to memref - // GPU params - %BLOCK_M = arith.constant 128 : index - %BLOCK_N = arith.constant 64 : index - %N_CTX_i64 = index.castu %N_CTX : index to i64 - %BLOCK_M_i64 = index.castu %BLOCK_M : index to i64 // do a ceiling div to figure out blocks_x // blocks_x = (N_CTX + BLOCKS_M - 1) / BLOCKS_M - %blocks_x_t1 = arith.subi %BLOCK_M_i64, %c1_i64 : i64 - %blocks_x_t2 = arith.addi %N_CTX_i64, %blocks_x_t1 : i64 - %blocks_x_i64 = arith.divui %blocks_x_t2, %BLOCK_M_i64 : i64 - %blocks_x = index.castu %blocks_x_i64 : i64 to index - %blocks_y = arith.muli %Z, %H : index // %blocks_x = arith.constant 32 : index - // %BLOCK_M_i64 = index.castu %BLOCK_M : index to i64 // %BLOCK_N_i64 = index.castu %BLOCK_N : index to i64 - // launch GPU func - gpu.launch_func @flash_attention_fwd::@flash_attention_fwd blocks in (%blocks_x, %blocks_y, %c1) - threads in (%c8, %c1, %c1) args( - %q_gpu : memref, %k_gpu : memref, %v_gpu : memref, %o_gpu : memref, - %sm_scale : f32, - %stride_2 : index, %stride_1 : index, %D_HEAD : index, %c1 : index, - %stride_2 : index, %stride_1 : index, %D_HEAD : index, %c1 : index, - %stride_2 : index, %stride_1 : index, %D_HEAD : index, %c1 : index, - %stride_2 : index, %stride_1 : index, %D_HEAD : index, %c1 : index, - %Z : index, %H : index, %N_CTX : index, %BLOCK_M : index, %D_HEAD : index, %BLOCK_N : index - ) - // copy output to CPU - memref.copy %o_gpu, %o : memref to memref - - gpu.dealloc %q_gpu : memref - gpu.dealloc %k_gpu : memref - gpu.dealloc %v_gpu : memref - gpu.dealloc %o_gpu : memref // gpu.dealloc %m_gpu : memref - - return %o : memref + %0 = arith.muli %arg4, %arg5 : index + %1 = arith.muli %0, %arg6 : index + %2 = 
arith.muli %arg6, %arg7 : index + %3 = arith.muli %2, %arg5 : index + %memref = gpu.alloc (%1, %arg7) : memref + %memref_0 = gpu.alloc (%1, %arg7) : memref + %memref_1 = gpu.alloc (%1, %arg7) : memref + %memref_2 = gpu.alloc (%1, %arg7) : memref + gpu.memcpy %memref, %arg0 : memref, memref + gpu.memcpy %memref_0, %arg1 : memref, memref + gpu.memcpy %memref_1, %arg2 : memref, memref + gpu.memcpy %memref_2, %arg3 : memref, memref + %4 = index.castu %arg6 : index to i64 + %5 = index.castu %c128 : index to i64 + %6 = arith.subi %5, %c1_i64 : i64 + %7 = arith.addi %4, %6 : i64 + %8 = arith.divui %7, %5 : i64 + %9 = index.castu %8 : i64 to index + %10 = arith.muli %arg4, %arg5 : index + gpu.launch_func @flash_attention_fwd::@flash_attention_fwd blocks in (%9, %10, %c1) threads in (%c8, %c1, %c1) args(%memref : memref, %memref_0 : memref, %memref_1 : memref, %memref_2 : memref, %arg8 : f32, %3 : index, %2 : index, %arg7 : index, %c1 : index, %3 : index, %2 : index, %arg7 : index, %c1 : index, %3 : index, %2 : index, %arg7 : index, %c1 : index, %3 : index, %2 : index, %arg7 : index, %c1 : index, %arg4 : index, %arg5 : index, %arg6 : index, %c128 : index, %arg7 : index, %c64 : index) + gpu.memcpy %arg3, %memref_2 : memref, memref + gpu.dealloc %memref : memref + gpu.dealloc %memref_0 : memref + gpu.dealloc %memref_1 : memref + gpu.dealloc %memref_2 : memref + return %arg3 : memref } - - func.func @cpu_impl(%Q : memref, %K : memref, %V : memref, - %o : memref, %Z : index, %H : index, %N_CTX : index, %D_HEAD : index, - %sm_scale : f32) -> memref { + func.func @cpu_impl(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: f32) -> memref { + %cst = arith.constant 0xFF800000 : f32 + %cst_0 = arith.constant 1.44269502 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c0_f32 = arith.constant 0.0 : f32 - %Z_H = arith.muli %Z, %H : index - %BLOCK_N = arith.constant 64 : index - %log2e = arith.constant 1.442695040888963 : f32 - // buffer - %qk_buffer = memref.alloc(%N_CTX, %N_CTX) : memref - scf.for %zh = %c0 to %Z_H step %c1 { // reset memref - scf.for %i = %c0 to %N_CTX step %c1 { - scf.for %j = %c0 to %N_CTX step %c1 { - memref.store %c0_f32, %qk_buffer[%i, %j] : memref + %cst_1 = arith.constant 0.000000e+00 : f32 + %0 = arith.muli %arg4, %arg5 : index + %alloc = memref.alloc(%arg6, %arg6) : memref + scf.for %arg9 = %c0 to %0 step %c1 { + scf.for %arg10 = %c0 to %arg6 step %c1 { + scf.for %arg11 = %c0 to %arg6 step %c1 { + memref.store %cst_1, %alloc[%arg10, %arg11] : memref } } - %x_offset = arith.muli %N_CTX, %zh : index // compute p = q*k^T - scf.for %i = %c0 to %N_CTX step %c1 { - scf.for %j = %c0 to %N_CTX step %c1 { - %qk_init = arith.constant 0.0 : f32 - %result = scf.for %k = %c0 to %D_HEAD step %c1 iter_args(%qk = %qk_init) -> f32 { - %zh_i = arith.addi %i, %x_offset : index - %zh_j = arith.addi %j, %x_offset : index - %q_val = memref.load %Q [%zh_i, %k] : memref - %k_val = memref.load %K [%zh_j, %k] : memref - %q_val_f32 = arith.extf %q_val : f16 to f32 - %k_val_f32 = arith.extf %k_val : f16 to f32 - %t = arith.mulf %q_val_f32, %k_val_f32 : f32 - %t1 = arith.addf %qk, %t : f32 - scf.yield %t1 : f32 + %1 = arith.muli %arg6, %arg9 : index + scf.for %arg10 = %c0 to %arg6 step %c1 { + scf.for %arg11 = %c0 to %arg6 step %c1 { + %2 = scf.for %arg12 = %c0 to %arg7 step %c1 iter_args(%arg13 = %cst_1) -> (f32) { + %4 = arith.addi %arg10, %1 : index + %5 = arith.addi %arg11, %1 : index + %6 = memref.load %arg0[%4, 
%arg12] : memref + %7 = memref.load %arg1[%5, %arg12] : memref + %8 = arith.extf %6 : f16 to f32 + %9 = arith.extf %7 : f16 to f32 + %10 = arith.mulf %8, %9 : f32 + %11 = arith.addf %arg13, %10 : f32 + scf.yield %11 : f32 } - %scaled = arith.mulf %result, %sm_scale : f32 - memref.store %scaled, %qk_buffer [%i, %j] : memref + %3 = arith.mulf %2, %arg8 : f32 + memref.store %3, %alloc[%arg10, %arg11] : memref } } // compute the softmax - scf.for %i = %c0 to %N_CTX step %c1 { - %qk_row_max_in = arith.constant 0xFF800000 : f32 // max reduce - %qk_row_max = scf.for %j = %c0 to %N_CTX step %c1 iter_args(%curr = %qk_row_max_in) -> f32 { - %qk_val = memref.load %qk_buffer [%i, %j] : memref - %new_max = arith.maximumf %curr, %qk_val : f32 - scf.yield %new_max : f32 + scf.for %arg10 = %c0 to %arg6 step %c1 { + %2 = scf.for %arg11 = %c0 to %arg6 step %c1 iter_args(%arg12 = %cst) -> (f32) { + %4 = memref.load %alloc[%arg10, %arg11] : memref + %5 = arith.maximumf %arg12, %4 : f32 + scf.yield %5 : f32 } // center by max and exp - scf.for %j = %c0 to %N_CTX step %c1 { - %qk_val = memref.load %qk_buffer [%i, %j] : memref - %t = arith.subf %qk_val, %qk_row_max : f32 // scale by log2e to emulate exp2 - %t1 = arith.mulf %t, %log2e : f32 - %t2 = math.exp2 %t1 : f32 - memref.store %t2, %qk_buffer [%i, %j] : memref + scf.for %arg11 = %c0 to %arg6 step %c1 { + %4 = memref.load %alloc[%arg10, %arg11] : memref + %5 = arith.subf %4, %2 : f32 + %6 = arith.mulf %5, %cst_0 : f32 + %7 = math.exp2 %6 : f32 + memref.store %7, %alloc[%arg10, %arg11] : memref } // take sum of row - %qk_row_sum_in = arith.constant 0.0 : f32 - %qk_row_sum = scf.for %j = %c0 to %N_CTX step %c1 iter_args(%curr = %qk_row_sum_in) -> f32 { - %qk_val = memref.load %qk_buffer [%i, %j] : memref - %sum_new = arith.addf %curr, %qk_val : f32 - scf.yield %sum_new : f32 + %3 = scf.for %arg11 = %c0 to %arg6 step %c1 iter_args(%arg12 = %cst_1) -> (f32) { + %4 = memref.load %alloc[%arg10, %arg11] : memref + %5 = arith.addf %arg12, %4 : f32 + scf.yield %5 : f32 } // div by sum - scf.for %j = %c0 to %N_CTX step %c1 { - %qk_val = memref.load %qk_buffer [%i, %j] : memref - %t = arith.divf %qk_val, %qk_row_sum : f32 - memref.store %t, %qk_buffer [%i, %j] : memref + scf.for %arg11 = %c0 to %arg6 step %c1 { + %4 = memref.load %alloc[%arg10, %arg11] : memref + %5 = arith.divf %4, %3 : f32 + memref.store %5, %alloc[%arg10, %arg11] : memref } } // compute p*v - scf.for %i = %c0 to %N_CTX step %c1 { - scf.for %j = %c0 to %D_HEAD step %c1 { - %pv_init = arith.constant 0.0 : f32 - %result = scf.for %k = %c0 to %N_CTX step %c1 iter_args (%pv = %pv_init) -> f32 { - %qk_val = memref.load %qk_buffer [%i, %k] : memref - %qk_val_f16 = arith.truncf %qk_val : f32 to f16 - %zh_k = arith.addi %k, %x_offset : index - %v_val = memref.load %V [%zh_k, %j] : memref - %qk_val_f32 = arith.extf %qk_val_f16 : f16 to f32 - %v_val_f32 = arith.extf %v_val : f16 to f32 - %t = arith.mulf %qk_val_f32, %v_val_f32 : f32 - %t1 = arith.addf %t, %pv : f32 - scf.yield %t1 : f32 + scf.for %arg10 = %c0 to %arg6 step %c1 { + scf.for %arg11 = %c0 to %arg7 step %c1 { + %2 = scf.for %arg12 = %c0 to %arg6 step %c1 iter_args(%arg13 = %cst_1) -> (f32) { + %5 = memref.load %alloc[%arg10, %arg12] : memref + %6 = arith.truncf %5 : f32 to f16 + %7 = arith.addi %arg12, %1 : index + %8 = memref.load %arg2[%7, %arg11] : memref + %9 = arith.extf %6 : f16 to f32 + %10 = arith.extf %8 : f16 to f32 + %11 = arith.mulf %9, %10 : f32 + %12 = arith.addf %11, %arg13 : f32 + scf.yield %12 : f32 } - %zh_i = arith.addi %i, 
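// NOTE (illustrative aside, not part of the test above): the CPU reference implements the
// softmax exponential with exp2 rather than exp, using exp(x) = exp2(x * log2(e)); the f32
// constant 1.44269502 above is log2(e). This mirrors the exp2-based math used by the GPU
// kernel. A minimal sketch of the identity (hypothetical helper, not in the diff):
//   func.func @exp_via_exp2(%x: f32) -> f32 {
//     %log2e = arith.constant 1.442695e+00 : f32
//     %t = arith.mulf %x, %log2e : f32
//     %r = math.exp2 %t : f32            // equals math.exp %x
//     return %r : f32
//   }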
%x_offset : index - %pv_f16 = arith.truncf %result : f32 to f16 - memref.store %pv_f16, %o [%zh_i, %j] : memref + %3 = arith.addi %arg10, %1 : index + %4 = arith.truncf %2 : f32 to f16 + memref.store %4, %arg3[%3, %arg11] : memref } } - } - - memref.dealloc %qk_buffer : memref - - return %o : memref + memref.dealloc %alloc : memref + return %arg3 : memref } - - - func.func @init_2d_dynamic_memref_to_const_f16(%m : memref, - %d0 : index, %d1 : index, %value : f16) -> () { + func.func @init_2d_dynamic_memref_to_const_f16(%arg0: memref, %arg1: index, %arg2: index, %arg3: f16) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - scf.for %arg0 = %c0 to %d0 step %c1 { - scf.for %arg1 = %c0 to %d1 step %c1 { - memref.store %value, %m [%arg0, %arg1] : memref + scf.for %arg4 = %c0 to %arg1 step %c1 { + scf.for %arg5 = %c0 to %arg2 step %c1 { + memref.store %arg3, %arg0[%arg4, %arg5] : memref } } return } - - func.func @main() attributes {llvm.emit_c_interface} { + func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %magic = arith.constant 0.625 : f32 - %c0_f16 = arith.constant 0.0 : f16 - %c1_f32 = arith.constant 0.5 : f32 - %Z = arith.constant 2 : index // number of batches - %H = arith.constant 2 : index // number of heads - %N_CTX = arith.constant 4096 : index // sequence len - %D_HEAD = arith.constant 64 : index // head dim - %sm_scale = arith.constant 0.5 : f32 // softmax scale - // random number generator params - %rand_low = arith.constant -1.0 : f32 - %rand_high = arith.constant 1.0 : f32 - %gen_int = arith.constant 0 : i1 - // xegpu only supports 2d memrefs. So we collapse the first 3 dims of the inputs // Z x H x N_CTX x D_HEAD -> (Z * H * N_CTX) x D_HEAD - %Z_H_N_t0 = arith.muli %Z, %H : index - %Z_H_N = arith.muli %Z_H_N_t0, %N_CTX : index - // allocate q, k, v, o - %q = memref.alloc(%Z_H_N, %D_HEAD) : memref - %k = memref.alloc(%Z_H_N, %D_HEAD) : memref - %v = memref.alloc(%Z_H_N, %D_HEAD) : memref - %o = memref.alloc(%Z_H_N, %D_HEAD) : memref - %o_cpu = memref.alloc(%Z_H_N, %D_HEAD) : memref - %o_cpu_f32 = memref.alloc(%Z_H_N, %D_HEAD) : memref // FIXME : m is unused for now // %m = memref.alloc(%Z, %H, %N_CTX) : memref - // initialize q, k, v - %q_random = memref.cast %q : memref to memref<*xf16> - %k_random = memref.cast %k : memref to memref<*xf16> - %v_random = memref.cast %v : memref to memref<*xf16> // Option 1: fill with random numbers // call @fillResource1DRandomF16(%q_random, %rand_low, %rand_high, %gen_int) : (memref<*xf16>, f32, f32, i1) -> () // call @fillResource1DRandomF16(%k_random, %rand_low, %rand_high, %gen_int) : (memref<*xf16>, f32, f32, i1) -> () // call @fillResource1DRandomF16(%v_random, %rand_low, %rand_high, %gen_int) : (memref<*xf16>, f32, f32, i1) -> () // Option 2: fill with some magic constant for validation - call @fillResource1DF16(%q_random, %magic) : (memref<*xf16>, f32) -> () - call @fillResource1DF16(%k_random, %magic) : (memref<*xf16>, f32) -> () - call @fillResource1DF16(%v_random, %magic) : (memref<*xf16>, f32) -> () - // // initialize output to 0.0 // %o_random = memref.collapse_shape %o [[0, 1, 2, 3]] : memref into memref - call @init_2d_dynamic_memref_to_const_f16(%o, %Z_H_N, %D_HEAD, %c0_f16) - : (memref, index, index, f16) -> () - call @init_2d_dynamic_memref_to_const_f16(%o_cpu, %Z_H_N, %D_HEAD, %c0_f16) - : (memref, index, index, f16) -> () - // initialize m to 1.0 (FIXME : masking is not used) // %c1_f32 = arith.constant 1.0 : f32 // %m_random = 
memref.collapse_shape %m [[0, 1, 2]] : memref into memref // call @fillResource1DF32(%m_random, %c1_f32) : (memref, f32) -> () - // run fused version - %out = call @gpu_impl( %q, %k, %v, %o, %Z, %H, %N_CTX, %D_HEAD, %sm_scale) : - (memref, memref, memref, memref, - index, index, index, index, f32) -> memref - // run cpu version - %out_cpu = call @cpu_impl( %q, %k, %v, %o_cpu, %Z, %H, %N_CTX, %D_HEAD, %sm_scale) : - (memref, memref, memref, memref, - index, index, index, index, f32) -> memref - - %out_cast = memref.cast %out : memref to memref<*xf16> - %q_cast = memref.cast %q : memref to memref<*xf16> - %out_cpu_cast = memref.cast %out_cpu : memref to memref<*xf16> // call @printMemrefF16(%q_cast) : (memref<*xf16>) -> () // call @printMemrefF16(%out_cast) : (memref<*xf16>) -> () // call @printMemrefF16(%out_cpu_cast) : (memref<*xf16>) -> () // call @printMaxErrorF16(%out_cast, %out_cpu_cast) : (memref<*xf16>, memref<*xf16>) -> () // sign extend CPU output to f32 - scf.for %i = %c0 to %Z_H_N step %c1 { - scf.for %j = %c0 to %D_HEAD step %c1 { - %o_val = memref.load %o_cpu [%i, %j] : memref - %o_val_f32 = arith.extf %o_val : f16 to f32 - memref.store %o_val_f32, %o_cpu_f32 [%i, %j] : memref + %cst = arith.constant 6.250000e-01 : f32 + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 5.000000e-01 : f32 + %c2 = arith.constant 2 : index + %c4096 = arith.constant 4096 : index + %c64 = arith.constant 64 : index + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc(%c16384, %c64) : memref + %alloc_2 = memref.alloc(%c16384, %c64) : memref + %alloc_3 = memref.alloc(%c16384, %c64) : memref + %alloc_4 = memref.alloc(%c16384, %c64) : memref + %alloc_5 = memref.alloc(%c16384, %c64) : memref + %alloc_6 = memref.alloc(%c16384, %c64) : memref + %cast = memref.cast %alloc : memref to memref<*xf16> + %cast_7 = memref.cast %alloc_2 : memref to memref<*xf16> + %cast_8 = memref.cast %alloc_3 : memref to memref<*xf16> + call @fillResource1DF16(%cast, %cst) : (memref<*xf16>, f32) -> () + call @fillResource1DF16(%cast_7, %cst) : (memref<*xf16>, f32) -> () + call @fillResource1DF16(%cast_8, %cst) : (memref<*xf16>, f32) -> () + call @init_2d_dynamic_memref_to_const_f16(%alloc_4, %c16384, %c64, %cst_0) : (memref, index, index, f16) -> () + call @init_2d_dynamic_memref_to_const_f16(%alloc_5, %c16384, %c64, %cst_0) : (memref, index, index, f16) -> () + %0 = call @gpu_impl(%alloc, %alloc_2, %alloc_3, %alloc_4, %c2, %c2, %c4096, %c64, %cst_1) : (memref, memref, memref, memref, index, index, index, index, f32) -> memref + %1 = call @cpu_impl(%alloc, %alloc_2, %alloc_3, %alloc_5, %c2, %c2, %c4096, %c64, %cst_1) : (memref, memref, memref, memref, index, index, index, index, f32) -> memref + %cast_9 = memref.cast %0 : memref to memref<*xf16> + scf.for %arg0 = %c0 to %c16384 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + %2 = memref.load %alloc_5[%arg0, %arg1] : memref + %3 = arith.extf %2 : f16 to f32 + memref.store %3, %alloc_6[%arg0, %arg1] : memref } } - %out_cpu_f32_cast = memref.cast %o_cpu_f32 : memref to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%out_cast, %out_cpu_f32_cast) : (memref<*xf16>, memref<*xf32>) -> () - - - memref.dealloc %q : memref - memref.dealloc %k : memref - memref.dealloc %v : memref - memref.dealloc %o : memref - memref.dealloc %o_cpu : memref - memref.dealloc %o_cpu_f32 : memref // memref.dealloc %m : memref - + %cast_10 = memref.cast %alloc_6 : memref to memref<*xf32> + call @printAllcloseF16(%cast_9, %cast_10) : (memref<*xf16>, 
memref<*xf32>) -> () + memref.dealloc %alloc : memref + memref.dealloc %alloc_2 : memref + memref.dealloc %alloc_3 : memref + memref.dealloc %alloc_4 : memref + memref.dealloc %alloc_5 : memref + memref.dealloc %alloc_6 : memref return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} @@ -1056,5 +766,4 @@ module @flash_attention attributes {gpu.container_module} { func.func private @fillResource1DF16(memref<*xf16>, f32) attributes {llvm.emit_c_interface} func.func private @fillResource1DF32(memref<*xf32>, f32) attributes {llvm.emit_c_interface} func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/fmax_f32.vc.mlir b/test/Integration/Dialect/XeGPU/VC/fmax_f32.vc.mlir index 1bccb9b21..f012e7012 100644 --- a/test/Integration/Dialect/XeGPU/VC/fmax_f32.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/fmax_f32.vc.mlir @@ -1,110 +1,105 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<8x32xf16>, %B: memref<16x32xf16> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x32xf16>, %arg1: memref<16x32xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %A, %memref : memref<8x32xf16> to memref<8x32xf16> - memref.copy %B, %memref_1 : memref<16x32xf16> to memref<16x32xf16> - %memref_2 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<16x32xf16>, %memref_2 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<16x32xf16>, memref<16x32xf16> + %memref_1 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<16x32xf16>, %memref_1 : memref<8x16xf32>) gpu.dealloc %memref : memref<8x32xf16> - gpu.dealloc %memref_1 : memref<16x32xf16> - return %memref_2 : memref<8x16xf32> + gpu.dealloc %memref_0 : memref<16x32xf16> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_1 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, 
api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<8x32xf16>, %B: memref<16x32xf16>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<8x32xf16>, %arg1: memref<16x32xf16>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index - %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %a_tile1 = xegpu.create_nd_tdesc %A [%c0, %c16] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> // load A tiles - %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %val1 = xegpu.load_nd %a_tile1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %b_tile0 = xegpu.create_nd_tdesc %B [%c0, %c0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %b_tile1 = xegpu.create_nd_tdesc %B [%c0, %c16] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> // load B tiles - %val2 = xegpu.load_nd %b_tile0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %val3 = xegpu.load_nd %b_tile1 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> // do DPAS - %val4 = xegpu.dpas %val0, %val2 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - %val5 = xegpu.dpas %val1, %val3 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> // take fmax - %val6 = arith.maximumf %val4, %val5 fastmath : vector<8x16xf32> // store fmax - %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %val6, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg0[%c0, %c16] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %4 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %5 = xegpu.create_nd_tdesc %arg1[%c0, %c16] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %6 = xegpu.load_nd %4 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %7 = xegpu.load_nd %5 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %8 = xegpu.dpas %2, %6 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %9 = xegpu.dpas %3, %7 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %10 = arith.maximumf %8, %9 fastmath : vector<8x16xf32> + %11 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %10, %11 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { // init constants + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 5.000000e-01 : f32 + %cst_1 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - - %A = memref.alloc() : memref<8x32xf16> - %B = memref.alloc() : memref<16x32xf16> - %Out_cpu = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x32xf16> to memref<*xf16> - %B_random = memref.cast %B : memref<16x32xf16> to 
memref<*xf16> - - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // run GPU version - %Out_gpu = call @test(%A, %B) : (memref<8x32xf16>, memref<16x32xf16>) -> memref<8x16xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<8x16xf32> to memref<*xf32> // run CPU version - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %v0_init = arith.constant 0.0 : f32 - %v1_init = arith.constant 0.0 : f32 - %result:2 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init, %v1 = %v1_init) -> (f32, f32){ - %1 = arith.addi %k, %c16 : index - %2 = arith.addi %j, %c16 : index - %a0 = memref.load %A[%i, %k] : memref<8x32xf16> - %a1 = memref.load %A[%i, %1] : memref<8x32xf16> - %b0 = memref.load %B[%k, %j] : memref<16x32xf16> - %b1 = memref.load %B[%k, %2] : memref<16x32xf16> - %a0_f32 = arith.extf %a0 : f16 to f32 - %a1_f32 = arith.extf %a1 : f16 to f32 - %b0_f32 = arith.extf %b0 : f16 to f32 - %b1_f32 = arith.extf %b1 : f16 to f32 - %t0 = arith.mulf %a0_f32, %b0_f32 : f32 - %t1 = arith.mulf %a1_f32, %b1_f32 : f32 - %v0_new = arith.addf %v0, %t0 : f32 - %v1_new = arith.addf %v1, %t1 : f32 - scf.yield %v0_new, %v1_new : f32, f32 + %alloc = memref.alloc() : memref<8x32xf16> + %alloc_2 = memref.alloc() : memref<16x32xf16> + %alloc_3 = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + %cast_4 = memref.cast %alloc_2 : memref<16x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_0, %false) : (memref<*xf16>, f32, f32, i1) -> () + call @fillResource1DRandomF16(%cast_4, %cst_1, %cst_0, %false) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc, %alloc_2) : (memref<8x32xf16>, memref<16x32xf16>) -> memref<8x16xf32> + %cast_5 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1:2 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %cst, %arg4 = %cst) -> (f32, f32) { + %3 = arith.addi %arg2, %c16 : index + %4 = arith.addi %arg1, %c16 : index + %5 = memref.load %alloc[%arg0, %arg2] : memref<8x32xf16> + %6 = memref.load %alloc[%arg0, %3] : memref<8x32xf16> + %7 = memref.load %alloc_2[%arg2, %arg1] : memref<16x32xf16> + %8 = memref.load %alloc_2[%arg2, %4] : memref<16x32xf16> + %9 = arith.extf %5 : f16 to f32 + %10 = arith.extf %6 : f16 to f32 + %11 = arith.extf %7 : f16 to f32 + %12 = arith.extf %8 : f16 to f32 + %13 = arith.mulf %9, %11 : f32 + %14 = arith.mulf %10, %12 : f32 + %15 = arith.addf %arg3, %13 : f32 + %16 = arith.addf %arg4, %14 : f32 + scf.yield %15, %16 : f32, f32 } - %vmax = arith.maximumf %result#0, %result#1 : f32 - memref.store %vmax, %Out_cpu[%i, %j] : memref<8x16xf32> + %2 = arith.maximumf %1#0, %1#1 : f32 + memref.store %2, %alloc_3[%arg0, %arg1] : memref<8x16xf32> } } - %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32> // print GPU and CPU outs - call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () - call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () // dealloc - memref.dealloc %A : memref<8x32xf16> - memref.dealloc %B : memref<16x32xf16> - 
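// NOTE (illustrative aside, not part of the test above): the fmax test computes, per output
// element and for k in [0, 16),
//   Out[i][j] = max( sum_k A[i][k] * B[k][j], sum_k A[i][k+16] * B[k][j+16] ).
// The GPU kernel realizes this with two xegpu.dpas calls followed by arith.maximumf; the
// scalar loop above reproduces it with two f32 accumulators per (i, j). The B tiles are
// loaded with the packed (VNNI) layout, vector<16x16xf16> -> vector<8x16x2xf16>, which is
// the B-operand shape consumed by xegpu.dpas in this kernel.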
memref.dealloc %Out_cpu : memref<8x16xf32> // gpu dealloc - gpu.dealloc %Out_gpu : memref<8x16xf32> + %cast_6 = memref.cast %alloc_3 : memref<8x16xf32> to memref<*xf32> + call @printMemrefF32(%cast_6) : (memref<*xf32>) -> () + call @printMemrefF32(%cast_5) : (memref<*xf32>) -> () + call @printAllcloseF32(%cast_5, %cast_6) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x32xf16> + memref.dealloc %alloc_2 : memref<16x32xf16> + memref.dealloc %alloc_3 : memref<8x16xf32> + memref.dealloc %0 : memref<8x16xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_4_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_4_f32.mlir index b45f812d9..3e5deebc6 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_4_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_4_f32.mlir @@ -1,82 +1,66 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#scatter = #xegpu.scatter_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16x4xf32>) -> memref<16x4xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - - %in = gpu.alloc host_shared () : memref<16x4xf32> - memref.copy %arg0, %in : memref<16x4xf32> to memref<16x4xf32> - - %out = gpu.alloc host_shared () : memref<16x4xf32> - - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%in : memref<16x4xf32>, %out : memref<16x4xf32>) - - gpu.dealloc %in : memref<16x4xf32> - return %out : memref<16x4xf32> + %memref = gpu.alloc () : memref<16x4xf32> + gpu.memcpy %memref, %arg0 : memref<16x4xf32>, memref<16x4xf32> + %memref_0 = gpu.alloc () : memref<16x4xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x4xf32>, %memref_0 : memref<16x4xf32>) + gpu.dealloc %memref : memref<16x4xf32> + %alloc = memref.alloc() : memref<16x4xf32> + gpu.memcpy %alloc, %memref_0 : memref<16x4xf32>, memref<16x4xf32> + gpu.dealloc %memref_0 : memref<16x4xf32> + return %alloc : memref<16x4xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_copy(%a: memref<16x4xf32>, %b: memref<16x4xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 
60]> : vector<16xindex> - + gpu.module @test_kernel { // load from a using load_gather - %a_cast = memref.reinterpret_cast %a to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> - %a_tdesc = xegpu.create_tdesc %a_cast, %offsets : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #scatter> - xegpu.prefetch %a_tdesc : !xegpu.tensor_desc<16x4xf32, #scatter> - %data = xegpu.load %a_tdesc, %mask : !xegpu.tensor_desc<16x4xf32, #scatter>, vector<16xi1> -> vector<16x4xf32> - // store to b using store_nd, used to check the implicit order issues with load_gather and store_scatter. // %c0 = arith.constant 0 : index // %b_tdesc = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x4xf32> -> !xegpu.tensor_desc<16x4xf32> // xegpu.store_nd %data, %b_tdesc : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32> - // store to b using store_scatter - %b_cast = memref.reinterpret_cast %b to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> - %b_tdesc = xegpu.create_tdesc %b_cast, %offsets : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #scatter> - xegpu.store %data, %b_tdesc, %mask : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #scatter>, vector<16xi1> + gpu.func @test_copy(%arg0: memref<16x4xf32>, %arg1: memref<16x4xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense : vector<16xi1> + %cst_0 = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xindex> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + %0 = xegpu.create_tdesc %reinterpret_cast, %cst_0 : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + xegpu.prefetch %0 : !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + %1 = xegpu.load %0, %cst : !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x4xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + %2 = xegpu.create_tdesc %reinterpret_cast_1, %cst_0 : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %1, %2, %cst : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index - %c1 = arith.constant 1: index - %c4 = arith.constant 4: index - %c16 = arith.constant 16: index - %A = memref.alloc() : memref<16x4xf32> - - scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c4 step %c1 { - %mul = arith.muli %i, %c4 : index - %add = arith.addi %mul, %j : index - %i32 = index.castu %add : index to i32 - %f32 = arith.sitofp %i32 : i32 to f32 - memref.store %f32, %A[%i, %j] : memref<16x4xf32> + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %alloc = memref.alloc() : memref<16x4xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c4 step %c1 { + %1 = arith.muli %arg0, %c4 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i32 + %4 = arith.sitofp %3 : i32 to f32 + memref.store %4, %alloc[%arg0, %arg1] : memref<16x4xf32> } } - - %B = call @test(%A) : (memref<16x4xf32>) -> memref<16x4xf32> - %A_cast = memref.cast %A : memref<16x4xf32> to memref<*xf32> - %B_cast = memref.cast %B : 
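// NOTE (illustrative aside, not part of the test above): in these gather/scatter tests the
// 16xN memref is reinterpreted as a flat 1-D buffer and each of the 16 lanes owns one row,
// so the per-lane element offsets follow offsets[i] = i * chunk_size (0, 4, 8, ..., 60 for
// chunk_size = 4; 0, 8, 16, ..., 120 for chunk_size = 8), with all 16 mask lanes enabled.
// One way to derive such an offset vector, assuming a recent MLIR that provides vector.step
// (hypothetical sketch, not in the diff):
//   %step  = vector.step : vector<16xindex>                  // [0, 1, ..., 15]
//   %chunk = arith.constant dense<4> : vector<16xindex>
//   %offs  = arith.muli %step, %chunk : vector<16xindex>     // [0, 4, ..., 60]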
memref<16x4xf32> to memref<*xf32> - // call @printMemrefF32(%A_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () - //CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<16x4xf32> + %0 = call @test(%alloc) : (memref<16x4xf32>) -> memref<16x4xf32> + %cast = memref.cast %alloc : memref<16x4xf32> to memref<*xf32> + %cast_0 = memref.cast %0 : memref<16x4xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<16x4xf32> return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_8_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_8_f32.mlir index da0d4a1c3..814e52f81 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_8_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_chunk_8_f32.mlir @@ -1,82 +1,66 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#scatter = #xegpu.scatter_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16x8xf32>) -> memref<16x8xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - - %in = gpu.alloc host_shared () : memref<16x8xf32> - memref.copy %arg0, %in : memref<16x8xf32> to memref<16x8xf32> - - %out = gpu.alloc host_shared () : memref<16x8xf32> - - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%in : memref<16x8xf32>, %out : memref<16x8xf32>) - - gpu.dealloc %in : memref<16x8xf32> - return %out : memref<16x8xf32> + %memref = gpu.alloc () : memref<16x8xf32> + gpu.memcpy %memref, %arg0 : memref<16x8xf32>, memref<16x8xf32> + %memref_0 = gpu.alloc () : memref<16x8xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x8xf32>, %memref_0 : memref<16x8xf32>) + gpu.dealloc %memref : memref<16x8xf32> + %alloc = memref.alloc() : memref<16x8xf32> + gpu.memcpy %alloc, %memref_0 : memref<16x8xf32>, memref<16x8xf32> + gpu.dealloc %memref_0 : memref<16x8xf32> + return %alloc : memref<16x8xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_copy(%a: memref<16x8xf32>, %b: memref<16x8xf32>) kernel attributes {VectorComputeFunctionINTEL, 
spirv.entry_point_abi = #spirv.entry_point_abi<>} { - - %mask = arith.constant dense<[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]> : vector<16xi1> - %offsets = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> - + gpu.module @test_kernel { // load from a using load_gather - %a_cast = memref.reinterpret_cast %a to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> - %a_tdesc = xegpu.create_tdesc %a_cast, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #scatter> - xegpu.prefetch %a_tdesc : !xegpu.tensor_desc<16x8xf32, #scatter> - %data = xegpu.load %a_tdesc, %mask : !xegpu.tensor_desc<16x8xf32, #scatter>, vector<16xi1> -> vector<16x8xf32> - // store to b using store_nd, used to check the implicit order issues with load_gather and store_scatter. // %c0 = arith.constant 0 : index // %b_tdesc = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x8xf32> -> !xegpu.tensor_desc<16x8xf32> // xegpu.store_nd %data, %b_tdesc : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32> - // store to b using store_scatter - %b_cast = memref.reinterpret_cast %b to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> - %b_tdesc = xegpu.create_tdesc %b_cast, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #scatter> - xegpu.store %data, %b_tdesc, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #scatter>, vector<16xi1> + gpu.func @test_copy(%arg0: memref<16x8xf32>, %arg1: memref<16x8xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense : vector<16xi1> + %cst_0 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> + %0 = xegpu.create_tdesc %reinterpret_cast, %cst_0 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + xegpu.prefetch %0 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + %1 = xegpu.load %0, %cst : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x8xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> + %2 = xegpu.create_tdesc %reinterpret_cast_1, %cst_0 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %1, %2, %cst : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index - %c1 = arith.constant 1: index - %c8 = arith.constant 8: index - %c16 = arith.constant 16: index - %A = memref.alloc() : memref<16x8xf32> - - scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c8 step %c1 { - %mul = arith.muli %i, %c8 : index - %add = arith.addi %mul, %j : index - %i32 = index.castu %add : index to i32 - %f32 = arith.sitofp %i32 : i32 to f32 - memref.store %f32, %A[%i, %j] : memref<16x8xf32> + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %alloc = memref.alloc() : memref<16x8xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %1 = arith.muli %arg0, %c8 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i32 + %4 = 
arith.sitofp %3 : i32 to f32 + memref.store %4, %alloc[%arg0, %arg1] : memref<16x8xf32> } } - - %B = call @test(%A) : (memref<16x8xf32>) -> memref<16x8xf32> - %A_cast = memref.cast %A : memref<16x8xf32> to memref<*xf32> - %B_cast = memref.cast %B : memref<16x8xf32> to memref<*xf32> - // call @printMemrefF32(%A_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () - //CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<16x8xf32> + %0 = call @test(%alloc) : (memref<16x8xf32>) -> memref<16x8xf32> + %cast = memref.cast %alloc : memref<16x8xf32> to memref<*xf32> + %cast_0 = memref.cast %0 : memref<16x8xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<16x8xf32> return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f16.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f16.mlir index bb4f20378..c64aef572 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f16.mlir @@ -1,65 +1,52 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#scatter = #xegpu.scatter_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16xf16>) -> memref<16xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - - %in = gpu.alloc host_shared () : memref<16xf16> - memref.copy %arg0, %in : memref<16xf16> to memref<16xf16> - - %out = gpu.alloc host_shared () : memref<16xf16> - - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%in : memref<16xf16>, %out : memref<16xf16>) - - gpu.dealloc %in : memref<16xf16> - return %out : memref<16xf16> + %memref = gpu.alloc () : memref<16xf16> + gpu.memcpy %memref, %arg0 : memref<16xf16>, memref<16xf16> + %memref_0 = gpu.alloc () : memref<16xf16> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf16>, %memref_0 : memref<16xf16>) + gpu.dealloc %memref : memref<16xf16> + %alloc = memref.alloc() : memref<16xf16> + gpu.memcpy %alloc, %memref_0 : memref<16xf16>, memref<16xf16> + gpu.dealloc %memref_0 : memref<16xf16> + return %alloc : memref<16xf16> } - - gpu.module @test_kernel attributes 
{spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_copy(%a: memref<16xf16>, %b: memref<16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> - + gpu.module @test_kernel { // load from a using load_gather - %a_tdesc = xegpu.create_tdesc %a, %offsets : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #scatter> - %data = xegpu.load %a_tdesc, %mask : !xegpu.tensor_desc<16xf16, #scatter>, vector<16xi1> -> vector<16xf16> - // %v1 = vector.extract %data[4]: f16 from vector<16xf16> // gpu.printf "\ndata[4] : %f.\n" %v1: f16 - // store to b using store_scatter - %b_tdesc = xegpu.create_tdesc %b, %offsets : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #scatter> - xegpu.store %data, %b_tdesc, %mask : vector<16xf16>, !xegpu.tensor_desc<16xf16, #scatter>, vector<16xi1> + gpu.func @test_copy(%arg0: memref<16xf16>, %arg1: memref<16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense : vector<16xi1> + %cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg0, %cst_0 : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.load %0, %cst : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf16> + %2 = xegpu.create_tdesc %arg1, %cst_0 : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>> + xegpu.store %1, %2, %cst : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index - %c1 = arith.constant 1: index - %c16 = arith.constant 16: index - %A = memref.alloc() : memref<16xf16> - - scf.for %i = %c0 to %c16 step %c1 { - %i32 = index.castu %i : index to i32 - %f16 = arith.sitofp %i32 : i32 to f16 - memref.store %f16, %A[%i] : memref<16xf16> + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %alloc = memref.alloc() : memref<16xf16> + scf.for %arg0 = %c0 to %c16 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = arith.sitofp %1 : i32 to f16 + memref.store %2, %alloc[%arg0] : memref<16xf16> } - - %B = call @test(%A) : (memref<16xf16>) -> memref<16xf16> - %B_cast = memref.cast %B : memref<16xf16> to memref<*xf16> //CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] - call @printMemrefF16(%B_cast) : (memref<*xf16>) -> () - memref.dealloc %A : memref<16xf16> + %0 = call @test(%alloc) : (memref<16xf16>) -> memref<16xf16> + %cast = memref.cast %0 : memref<16xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () + memref.dealloc %alloc : memref<16xf16> return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f32.mlir index 2f14106b9..97867111f 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/load_global_no_chunk_f32.mlir @@ -1,70 
+1,55 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#scatter = #xegpu.scatter_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16xf32>) -> memref<16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - - %in = gpu.alloc host_shared () : memref<16xf32> - memref.copy %arg0, %in : memref<16xf32> to memref<16xf32> - - %out = gpu.alloc host_shared () : memref<16xf32> - - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%in : memref<16xf32>, %out : memref<16xf32>) - - gpu.dealloc %in : memref<16xf32> - return %out : memref<16xf32> + %memref = gpu.alloc () : memref<16xf32> + gpu.memcpy %memref, %arg0 : memref<16xf32>, memref<16xf32> + %memref_0 = gpu.alloc () : memref<16xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf32>, %memref_0 : memref<16xf32>) + gpu.dealloc %memref : memref<16xf32> + %alloc = memref.alloc() : memref<16xf32> + gpu.memcpy %alloc, %memref_0 : memref<16xf32>, memref<16xf32> + gpu.dealloc %memref_0 : memref<16xf32> + return %alloc : memref<16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_copy(%a: memref<16xf32>, %b: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> - + gpu.module @test_kernel { // load from a using load_gather - %a_tdesc = xegpu.create_tdesc %a, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter> - xegpu.prefetch %a_tdesc : !xegpu.tensor_desc<16xf32, #scatter> - %data = xegpu.load %a_tdesc, %mask : !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1> -> vector<16xf32> - // store to b using store_scatter - %b_tdesc = xegpu.create_tdesc %b, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter> - xegpu.store %data, %b_tdesc, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1> + gpu.func @test_copy(%arg0: memref<16xf32>, %arg1: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense : vector<16xi1> + %cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg0, %cst_0 : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, 
#xegpu.scatter_tdesc_attr<>> + xegpu.prefetch %0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.load %0, %cst : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + %2 = xegpu.create_tdesc %arg1, %cst_0 : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + xegpu.store %1, %2, %cst : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index - %c1 = arith.constant 1: index - %c16 = arith.constant 16: index - %A = memref.alloc() : memref<16xf32> - - scf.for %i = %c0 to %c16 step %c1 { - %i32 = index.castu %i : index to i32 - %f32 = arith.sitofp %i32 : i32 to f32 - memref.store %f32, %A[%i] : memref<16xf32> + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %alloc = memref.alloc() : memref<16xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = arith.sitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0] : memref<16xf32> } - - %B = call @test(%A) : (memref<16xf32>) -> memref<16xf32> - %A_cast = memref.cast %A : memref<16xf32> to memref<*xf32> - %B_cast = memref.cast %B : memref<16xf32> to memref<*xf32> - // call @printMemrefF32(%A_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () - //CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<16xf32> + %0 = call @test(%alloc) : (memref<16xf32>) -> memref<16xf32> + %cast = memref.cast %alloc : memref<16xf32> to memref<*xf32> + %cast_0 = memref.cast %0 : memref<16xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<16xf32> return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_4_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_4_f32.mlir index 52a27dc7e..36c37de8e 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_4_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_4_f32.mlir @@ -1,41 +1,34 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#scatter = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16x4xf32> attributes 
{llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16x4xf32> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16x4xf32>) - return %out : memref<16x4xf32> + %memref = gpu.alloc () : memref<16x4xf32> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x4xf32>) + %alloc = memref.alloc() : memref<16x4xf32> + gpu.memcpy %alloc, %memref : memref<16x4xf32>, memref<16x4xf32> + gpu.dealloc %memref : memref<16x4xf32> + return %alloc : memref<16x4xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16x4xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst = arith.constant dense<[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], + gpu.module @test_kernel { + gpu.func @test_store_scatter(%arg0: memref<16x4xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.], [48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63.]]> : vector<4x16xf32> - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xindex> - %cast = memref.reinterpret_cast %mem to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> - %5 = xegpu.create_tdesc %cast, %offsets : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #scatter> - %trans = vector.transpose %cst, [1, 0] : vector<4x16xf32> to vector<16x4xf32> - xegpu.store %trans, %5, %mask : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #scatter>, vector<16xi1> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xindex> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + %0 = xegpu.create_tdesc %reinterpret_cast, %cst_1 : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + %1 = vector.transpose %cst, [1, 0] : vector<4x16xf32> to vector<16x4xf32> + xegpu.store %1, %0, %cst_0 : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<16x4xf32> - %cast = memref.cast %B : memref<16x4xf32> to memref<*xf32> - //CHECK: [0, 16, 32, 48], //CHECK: [1, 17, 33, 49], //CHECK: [2, 18, 34, 50], @@ -52,9 +45,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [13, 29, 45, 61], //CHECK: [14, 30, 46, 62], //CHECK: [15, 31, 47, 63] + %0 = call @test() : () -> memref<16x4xf32> + %cast = memref.cast %0 : memref<16x4xf32> to memref<*xf32> call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git 
a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_8_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_8_f32.mlir index 89ba7873d..e351aa72f 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_8_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_chunk_8_f32.mlir @@ -1,23 +1,20 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - - -#scatter = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16x8xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16x8xf32> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16x8xf32>) - return %out : memref<16x8xf32> + %memref = gpu.alloc () : memref<16x8xf32> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x8xf32>) + %alloc = memref.alloc() : memref<16x8xf32> + gpu.memcpy %alloc, %memref : memref<16x8xf32>, memref<16x8xf32> + gpu.dealloc %memref : memref<16x8xf32> + return %alloc : memref<16x8xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16x8xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_store_scatter(%arg0: memref<16x8xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %cst = arith.constant dense<[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [ 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.], @@ -26,23 +23,16 @@ module @gemm attributes {gpu.container_module} { [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95.], [ 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., 111.], [112., 113., 114., 115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127.]]> : vector<8x16xf32> - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> - - %cast = memref.reinterpret_cast %mem to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> - %5 = xegpu.create_tdesc %cast, %offsets : 
memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #scatter>
-      %trans = vector.transpose %cst, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
-      xegpu.store %trans, %5, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #scatter>, vector<16xi1>
-
+      %cst_0 = arith.constant dense<true> : vector<16xi1>
+      %cst_1 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+      %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32>
+      %0 = xegpu.create_tdesc %reinterpret_cast, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>
+      %1 = vector.transpose %cst, [1, 0] : vector<8x16xf32> to vector<16x8xf32>
+      xegpu.store %1, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
       gpu.return
     }
   }
-
   func.func @main() attributes {llvm.emit_c_interface} {
-    %B = call @test() : () -> memref<16x8xf32>
-    %cast = memref.cast %B : memref<16x8xf32> to memref<*xf32>
-
     //CHECK: 0, 16, 32, 48, 64, 80, 96, 112
     //CHECK: 1, 17, 33, 49, 65, 81, 97, 113
     //CHECK: 2, 18, 34, 50, 66, 82, 98, 114
@@ -59,9 +49,10 @@ module @gemm attributes {gpu.container_module} {
     //CHECK: 13, 29, 45, 61, 77, 93, 109, 125
     //CHECK: 14, 30, 46, 62, 78, 94, 110, 126
     //CHECK: 15, 31, 47, 63, 79, 95, 111, 127
+    %0 = call @test() : () -> memref<16x8xf32>
+    %cast = memref.cast %0 : memref<16x8xf32> to memref<*xf32>
     call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
     return
   }
-
   func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
 }
diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f16.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f16.mlir
index 72d0e9c04..ee2ee3be4 100644
--- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f16.mlir
+++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f16.mlir
@@ -1,40 +1,34 @@
-// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
-// RUN: --runner mlir-runner -e main --entry-point-result=void \
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
-// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
-// RUN: --runner mlir-runner -e main --entry-point-result=void \
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
-
-#scatter = #xegpu.scatter_tdesc_attr
+// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
+// RUN: --runner mlir-runner -e main \
+// RUN: --entry-point-result=void \
+// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck
 module @gemm attributes {gpu.container_module} {
   func.func @test() -> memref<16xf16> attributes {llvm.emit_c_interface} {
     %c1 = arith.constant 1 : index
-    %out = gpu.alloc host_shared () : memref<16xf16>
-    gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16xf16>)
-    return %out : memref<16xf16>
+    %memref = gpu.alloc () : memref<16xf16>
+    gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf16>)
+    %alloc = memref.alloc() : memref<16xf16>
+    gpu.memcpy %alloc, %memref : memref<16xf16>, memref<16xf16>
+    gpu.dealloc %memref : memref<16xf16>
+    return %alloc : memref<16xf16>
   }
-
-  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
-    gpu.func @test_store_scatter(%mem: memref<16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+  gpu.module @test_kernel {
+    gpu.func @test_store_scatter(%arg0: memref<16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
       %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf16>
-
-      %mask = arith.constant dense<1> : vector<16xi1>
-      %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
-      %tdesc = xegpu.create_tdesc %mem, %offsets : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #scatter>
-      xegpu.store %cst, %tdesc, %mask : vector<16xf16>, !xegpu.tensor_desc<16xf16, #scatter>, vector<16xi1>
-
+      %cst_0 = arith.constant dense<true> : vector<16xi1>
+      %cst_1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+      %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>
+      xegpu.store %cst, %0, %cst_0 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
       gpu.return
     }
   }
-
   func.func @main() attributes {llvm.emit_c_interface} {
-    %B = call @test() : () -> memref<16xf16>
-    %cast = memref.cast %B : memref<16xf16> to memref<*xf16>
     //CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+    %0 = call @test() : () -> memref<16xf16>
+    %cast = memref.cast %0 : memref<16xf16> to memref<*xf16>
     call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
     return
   }
-
   func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
 }
diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f32.mlir
index a17b6e0a4..9b1a34d81 100644
--- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f32.mlir
+++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_global/store_global_no_chunk_f32.mlir
@@ -1,40 +1,34 @@
-// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
-// RUN: --runner mlir-runner -e main --entry-point-result=void \
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
-// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
-// RUN: --runner mlir-runner -e main --entry-point-result=void \
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
-
-#scatter = #xegpu.scatter_tdesc_attr
+// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
+// RUN: --runner mlir-runner -e main \
+// RUN: --entry-point-result=void \
+// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck
 module @gemm attributes {gpu.container_module} {
   func.func @test() -> memref<16xf32> attributes {llvm.emit_c_interface} {
     %c1 = arith.constant 1 : index
-    %out = gpu.alloc host_shared () : memref<16xf32>
-    gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16xf32>)
-    return %out : memref<16xf32>
+    %memref = gpu.alloc () : memref<16xf32>
+    gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf32>)
+    %alloc = memref.alloc() : memref<16xf32>
+    gpu.memcpy %alloc, %memref : memref<16xf32>, memref<16xf32>
+    gpu.dealloc %memref : memref<16xf32>
+    return %alloc : memref<16xf32>
   }
-
-  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
-    gpu.func @test_store_scatter(%mem: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+  gpu.module @test_kernel {
+    gpu.func @test_store_scatter(%arg0: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
       %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf32>
-
-      %mask = arith.constant dense<1> : vector<16xi1>
-      %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
-      %tdesc = xegpu.create_tdesc %mem, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #scatter>
-      xegpu.store %cst, %tdesc, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #scatter>, vector<16xi1>
-
+      %cst_0 = arith.constant dense<true> : vector<16xi1>
+      %cst_1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+      %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+      xegpu.store %cst, %0, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
       gpu.return
     }
   }
-
   func.func @main() attributes {llvm.emit_c_interface} {
-    %B = call @test() : () -> memref<16xf32>
-    %cast = memref.cast %B : memref<16xf32> to memref<*xf32>
     //CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+    %0 = call @test() : () -> memref<16xf32>
+    %cast = memref.cast %0 : memref<16xf32> to memref<*xf32>
     call @printMemrefF32(%cast) : (memref<*xf32>) -> ()
     return
   }
-
   func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
 }
diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_4_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_4_f32.mlir
index da97a0769..ad1c33255 100644
--- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_4_f32.mlir
+++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_4_f32.mlir
@@ -1,51 +1,41 @@
-// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
-// RUN: --runner mlir-runner -e main --entry-point-result=void \
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
-// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \
-// RUN: --runner mlir-runner -e main --entry-point-result=void \
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
-
-#global = #xegpu.scatter_tdesc_attr
-#slm = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16x4xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16x4xf32> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16x4xf32>) - return %out : memref<16x4xf32> + %memref = gpu.alloc () : memref<16x4xf32> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x4xf32>) + %alloc = memref.alloc() : memref<16x4xf32> + gpu.memcpy %alloc, %memref : memref<16x4xf32>, memref<16x4xf32> + gpu.dealloc %memref : memref<16x4xf32> + return %alloc : memref<16x4xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16x4xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + // store the cst into slm and load it back; + // load from slm + // store data to global memory + gpu.func @test_store_scatter(%arg0: memref<16x4xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %cst = arith.constant dense<[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.], [48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63.]]> : vector<4x16xf32> - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xindex> - - // store the cst into slm and load it back; - %slm = memref.alloc() : memref<64xf32, 3> - %slm_tdesc = xegpu.create_tdesc %slm, %offsets : memref<64xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #slm> - %trans = vector.transpose %cst, [1, 0] : vector<4x16xf32> to vector<16x4xf32> - xegpu.store %trans, %slm_tdesc, %mask : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #slm>, vector<16xi1> - // load from slm - %data = xegpu.load %slm_tdesc, %mask : !xegpu.tensor_desc<16x4xf32, #slm>, vector<16xi1> -> vector<16x4xf32> - - // store data to global memory - %cast = memref.reinterpret_cast %mem to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> - %5 = xegpu.create_tdesc %cast, %offsets : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #global> - xegpu.store %data, %5, %mask : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #global>, vector<16xi1> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xindex> + %alloc = memref.alloc() : memref<64xf32, 3> + %0 = xegpu.create_tdesc %alloc, %cst_1 : memref<64xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + %1 = vector.transpose %cst, [1, 0] : vector<4x16xf32> to vector<16x4xf32> + xegpu.store %1, 
%0, %cst_0 : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x4xf32> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64], strides: [1] : memref<16x4xf32> to memref<64xf32> + %3 = xegpu.create_tdesc %reinterpret_cast, %cst_1 : memref<64xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %2, %3, %cst_0 : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<16x4xf32> - %cast = memref.cast %B : memref<16x4xf32> to memref<*xf32> - //CHECK: [0, 16, 32, 48], //CHECK: [1, 17, 33, 49], //CHECK: [2, 18, 34, 50], @@ -62,9 +52,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [13, 29, 45, 61], //CHECK: [14, 30, 46, 62], //CHECK: [15, 31, 47, 63] + %0 = call @test() : () -> memref<16x4xf32> + %cast = memref.cast %0 : memref<16x4xf32> to memref<*xf32> call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32.mlir index 9321c1f35..df8e30461 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32.mlir @@ -1,25 +1,24 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - - -#global = #xegpu.scatter_tdesc_attr -#slm = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16x8xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16x8xf32> - %slm = memref.alloc() : memref<128xf32, 3> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16x8xf32>, %slm : memref<128xf32, 3>) - return %out : memref<16x8xf32> + %memref = gpu.alloc () : memref<16x8xf32> + %alloc = memref.alloc() : memref<128xf32, 3> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x8xf32>, %alloc : memref<128xf32, 3>) + %alloc_0 = memref.alloc() : memref<16x8xf32> + gpu.memcpy %alloc_0, %memref : memref<16x8xf32>, memref<16x8xf32> + gpu.dealloc %memref : memref<16x8xf32> + return 
%alloc_0 : memref<16x8xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16x8xf32>, %slm: memref<128xf32, 3>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + // store the cst into slm + // load from slm + // store data to global memory + gpu.func @test_store_scatter(%arg0: memref<16x8xf32>, %arg1: memref<128xf32, 3>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %cst = arith.constant dense<[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [ 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [ 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.], @@ -28,30 +27,19 @@ module @gemm attributes {gpu.container_module} { [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95.], [ 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., 111.], [112., 113., 114., 115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127.]]> : vector<8x16xf32> - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> - - // store the cst into slm - %slm_tdesc = xegpu.create_tdesc %slm, %offsets : memref<128xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #slm> - %trans = vector.transpose %cst, [1, 0] : vector<8x16xf32> to vector<16x8xf32> - xegpu.store %trans, %slm_tdesc, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #slm>, vector<16xi1> - - // load from slm - %data = xegpu.load %slm_tdesc, %mask : !xegpu.tensor_desc<16x8xf32, #slm>, vector<16xi1> -> vector<16x8xf32> - - // store data to global memory - %cast = memref.reinterpret_cast %mem to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> - %5 = xegpu.create_tdesc %cast, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #global> - xegpu.store %data, %5, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #global>, vector<16xi1> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg1, %cst_1 : memref<128xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + %1 = vector.transpose %cst, [1, 0] : vector<8x16xf32> to vector<16x8xf32> + xegpu.store %1, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x8xf32> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> + %3 = xegpu.create_tdesc %reinterpret_cast, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %2, %3, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<16x8xf32> - %cast = memref.cast %B : memref<16x8xf32> to memref<*xf32> - //CHECK: 0, 16, 32, 48, 64, 80, 96, 112 //CHECK: 1, 17, 33, 
49, 65, 81, 97, 113 //CHECK: 2, 18, 34, 50, 66, 82, 98, 114 @@ -68,9 +56,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: 13, 29, 45, 61, 77, 93, 109, 125 //CHECK: 14, 30, 46, 62, 78, 94, 110, 126 //CHECK: 15, 31, 47, 63, 79, 95, 111, 127 + %0 = call @test() : () -> memref<16x8xf32> + %cast = memref.cast %0 : memref<16x8xf32> to memref<*xf32> call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32_mask.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32_mask.mlir index a9ab37bf2..28653227d 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32_mask.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_chunk_8_f32_mask.mlir @@ -1,25 +1,24 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - - -#global = #xegpu.scatter_tdesc_attr -#slm = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16x8xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16x8xf32> - %slm = memref.alloc() : memref<128xf32, 3> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16x8xf32>, %slm : memref<128xf32, 3>) - return %out : memref<16x8xf32> + %memref = gpu.alloc () : memref<16x8xf32> + %alloc = memref.alloc() : memref<128xf32, 3> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x8xf32>, %alloc : memref<128xf32, 3>) + %alloc_0 = memref.alloc() : memref<16x8xf32> + gpu.memcpy %alloc_0, %memref : memref<16x8xf32>, memref<16x8xf32> + gpu.dealloc %memref : memref<16x8xf32> + return %alloc_0 : memref<16x8xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16x8xf32>, %slm: memref<128xf32, 3>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + // store the cst into slm + // load from slm + // store data to global memory + gpu.func @test_store_scatter(%arg0: memref<16x8xf32>, %arg1: memref<128xf32, 3>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %cst = arith.constant dense<[[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], [ 16., 17., 18., 19., 
20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31.], [ 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47.], @@ -28,31 +27,19 @@ module @gemm attributes {gpu.container_module} { [ 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95.], [ 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., 111.], [112., 113., 114., 115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127.]]> : vector<8x16xf32> - - %mask = arith.constant dense<[1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]> : vector<16xi1> - %offsets = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> - - // store the cst into slm - %slm_tdesc = xegpu.create_tdesc %slm, %offsets : memref<128xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #slm> - %trans = vector.transpose %cst, [1, 0] : vector<8x16xf32> to vector<16x8xf32> - xegpu.store %trans, %slm_tdesc, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #slm>, vector<16xi1> - - // load from slm - %data = xegpu.load %slm_tdesc, %mask : !xegpu.tensor_desc<16x8xf32, #slm>, vector<16xi1> -> vector<16x8xf32> - - // store data to global memory - %cast = memref.reinterpret_cast %mem to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> - %5 = xegpu.create_tdesc %cast, %offsets : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #global> - xegpu.store %data, %5, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #global>, vector<16xi1> + %cst_0 = arith.constant dense<[true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false]> : vector<16xi1> + %cst_1 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg1, %cst_1 : memref<128xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + %1 = vector.transpose %cst, [1, 0] : vector<8x16xf32> to vector<16x8xf32> + xegpu.store %1, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x8xf32> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] : memref<16x8xf32> to memref<128xf32> + %3 = xegpu.create_tdesc %reinterpret_cast, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %2, %3, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<16x8xf32> - %cast = memref.cast %B : memref<16x8xf32> to memref<*xf32> - - //CHECK: 0, 16, 32, 48, 64, 80, 96, 112 //CHECK: 0, 0, 0, 0, 0, 0, 0, 0 //CHECK: 2, 18, 34, 50, 66, 82, 98, 114 @@ -69,9 +56,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: 0, 0, 0, 0, 0, 0, 0, 0 //CHECK: 14, 30, 46, 62, 78, 94, 110, 126 //CHECK: 0, 0, 0, 0, 0, 0, 0, 0 + %0 = call @test() : () -> memref<16x8xf32> + %cast = memref.cast %0 : memref<16x8xf32> to memref<*xf32> call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f16.mlir 
b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f16.mlir index d26632e71..18625ac03 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f16.mlir @@ -1,48 +1,40 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#global = #xegpu.scatter_tdesc_attr -#slm = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16xf16> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16xf16>) - return %out : memref<16xf16> + %memref = gpu.alloc () : memref<16xf16> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf16>) + %alloc = memref.alloc() : memref<16xf16> + gpu.memcpy %alloc, %memref : memref<16xf16>, memref<16xf16> + gpu.dealloc %memref : memref<16xf16> + return %alloc : memref<16xf16> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf16> - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> - + gpu.module @test_kernel { // store the cst into slm and load it back; - %slm = memref.alloc() : memref<16xf16, 3> - %slm_tdesc = xegpu.create_tdesc %slm, %offsets : memref<16xf16, 3>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #slm> - xegpu.store %cst, %slm_tdesc, %mask : vector<16xf16>, !xegpu.tensor_desc<16xf16, #slm>, vector<16xi1> - %data = xegpu.load %slm_tdesc, %mask : !xegpu.tensor_desc<16xf16, #slm>, vector<16xi1> -> vector<16xf16> - // store data to global memory - %tdesc = xegpu.create_tdesc %mem, %offsets : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #global> - xegpu.store %data, %tdesc, %mask : vector<16xf16>, !xegpu.tensor_desc<16xf16, #global>, vector<16xi1> - + gpu.func @test_store_scatter(%arg0: memref<16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 
11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf16> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %alloc = memref.alloc() : memref<16xf16, 3> + %0 = xegpu.create_tdesc %alloc, %cst_1 : memref<16xf16, 3>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr> + xegpu.store %cst, %0, %cst_0 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf16> + %2 = xegpu.create_tdesc %arg0, %cst_1 : memref<16xf16>, vector<16xindex> -> !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>> + xegpu.store %1, %2, %cst_0 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<16xf16> - %cast = memref.cast %B : memref<16xf16> to memref<*xf16> //CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + %0 = call @test() : () -> memref<16xf16> + %cast = memref.cast %0 : memref<16xf16> to memref<*xf16> call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f32.mlir index 975d854ca..bb9faa6f8 100644 --- a/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gather_scatter_slm/store_load_slm_no_chunk_f32.mlir @@ -1,48 +1,39 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#global = #xegpu.scatter_tdesc_attr -#slm = #xegpu.scatter_tdesc_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/../xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test() -> memref<16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %out = gpu.alloc host_shared () : memref<16xf32> - gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%out : memref<16xf32>) - return %out : memref<16xf32> + %memref = gpu.alloc () : memref<16xf32> + gpu.launch_func @test_kernel::@test_store_scatter blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16xf32>) + %alloc = memref.alloc() : memref<16xf32> + gpu.memcpy %alloc, %memref : memref<16xf32>, memref<16xf32> + gpu.dealloc %memref : memref<16xf32> + return %alloc : memref<16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = 
#spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_store_scatter(%mem: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf32> - - %mask = arith.constant dense<1> : vector<16xi1> - %offsets = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> - + gpu.module @test_kernel { // store the cst into slm and load it back; - %slm = memref.alloc() : memref<16xf32, 3> - %slm_tdesc = xegpu.create_tdesc %slm, %offsets : memref<16xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #slm> - xegpu.store %cst, %slm_tdesc, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #slm>, vector<16xi1> - %data = xegpu.load %slm_tdesc, %mask : !xegpu.tensor_desc<16xf32, #slm>, vector<16xi1> -> vector<16xf32> - - %tdesc = xegpu.create_tdesc %mem, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #global> - xegpu.store %cst, %tdesc, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #global>, vector<16xi1> - + gpu.func @test_store_scatter(%arg0: memref<16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]> : vector<16xf32> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %alloc = memref.alloc() : memref<16xf32, 3> + %0 = xegpu.create_tdesc %alloc, %cst_1 : memref<16xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %cst, %0, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32> + %2 = xegpu.create_tdesc %arg0, %cst_1 : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + xegpu.store %cst, %2, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<16xf32> - %cast = memref.cast %B : memref<16xf32> to memref<*xf32> //CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + %0 = call @test() : () -> memref<16xf32> + %cast = memref.cast %0 : memref<16xf32> to memref<*xf32> call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1016x1016_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1016x1016_f16_f16_f32.mlir index 08b098fed..354dff635 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1016x1016_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1016x1016_f16_f16_f32.mlir @@ -1,109 +1,105 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_1024x1016xf16 : memref<1024x1016xf16> = dense<1.0> - memref.global "private" @__constant_1016x1016xf16_ : memref<1016x1016xf16> = dense<1.0> - memref.global "private" @__constant_1024x1016xf32 : memref<1024x1016xf32> = dense<0.0> + memref.global "private" @__constant_1024x1016xf16 : memref<1024x1016xf16> = dense<1.000000e+00> + memref.global "private" @__constant_1016x1016xf16_ : memref<1016x1016xf16> = dense<1.000000e+00> + memref.global "private" @__constant_1024x1016xf32 : memref<1024x1016xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<1024x1016xf16>, %arg1: memref<1016x1016xf16>) -> memref<1024x1016xf32> attributes {llvm.emit_c_interface} { %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1024x1016xf16> - memref.copy %arg0, %memref : memref<1024x1016xf16> to memref<1024x1016xf16> - %memref_0 = gpu.alloc host_shared () : memref<1016x1016xf16> - memref.copy %arg1, %memref_0 : memref<1016x1016xf16> to memref<1016x1016xf16> - %memref_1 = gpu.alloc host_shared () : memref<1024x1016xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1016xf16>, %memref_0 : memref<1016x1016xf16>, %memref_1 : memref<1024x1016xf32>) + %memref = gpu.alloc () : memref<1024x1016xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1016xf16>, memref<1024x1016xf16> + %memref_0 = gpu.alloc () : memref<1016x1016xf16> + gpu.memcpy %memref_0, %arg1 : memref<1016x1016xf16>, memref<1016x1016xf16> + %memref_1 = gpu.alloc () : memref<1024x1016xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1016xf16>, %memref_0 : memref<1016x1016xf16>, %memref_1 : memref<1024x1016xf32>) gpu.dealloc %memref : memref<1024x1016xf16> gpu.dealloc %memref_0 : memref<1016x1016xf16> - return %memref_1 : memref<1024x1016xf32> + %alloc = memref.alloc() : memref<1024x1016xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1016xf32>, memref<1024x1016xf32> + gpu.dealloc %memref_1 : memref<1024x1016xf32> + return %alloc : memref<1024x1016xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<1024x1016xf16>, %arg1: memref<1016x1016xf16>, %arg2: memref<1024x1016xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c8 = arith.constant 8 : index %c1024 = arith.constant 1024 : index %c1016 = arith.constant 1016 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1016xf32> 
-> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // each work-group has 1 subgroup. the subgroup caculates a [8x16 = 8x1024 * 1024x16] block - %6 = scf.for %arg3 = %c0 to %c1016 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1016xf16> -> !xegpu.tensor_desc<8x16xf16> - %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1016x1016xf16> -> !xegpu.tensor_desc<16x16xf16> - %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %10 = xegpu.load_nd %8 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %11 : vector<8x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1016xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg3 = %c0 to %c1016 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1016xf16> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1016x1016xf16> -> !xegpu.tensor_desc<16x16xf16> + %7 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %8 = xegpu.load_nd %6 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = xegpu.dpas %7, %8, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %9 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 1.000000e+02 : f16 + %c128_i16 = arith.constant 128 : i16 + %c1016 = arith.constant 1016 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1024x1016xf16 : memref<1024x1016xf16> %1 = memref.get_global @__constant_1016x1016xf16_ : memref<1016x1016xf16> - %ref = memref.get_global @__constant_1024x1016xf32 : memref<1024x1016xf32> - %init = arith.constant 0.0 : f16 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %c1016 = arith.constant 1016 : index // fill the top-left block 128x128 // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 + %2 = memref.get_global @__constant_1024x1016xf32 : memref<1024x1016xf32> scf.for %arg0 = %c0 to %c128 step %c1 { scf.for %arg1 = %c0 to %c128 step %c1 { - %int0 = arith.index_cast %arg0 : index to i16 - %int1 = arith.index_cast %arg1 : index to i16 - %c128_i16 = arith.constant 128 : i16 - %idx0 = arith.muli %int0, %c128_i16 : i16 - %idx1 = arith.addi %int1, %idx0 : i16 - %fp = arith.uitofp %idx1 : i16 to f16 - %cst100 = arith.constant 100.0 : f16 - %val0 = arith.divf %fp, %cst100 : f16 - %cst1 = arith.constant 1.0 : f16 - %val1 = arith.addf %val0, %cst1 : f16 - memref.store %val0, %0[%arg0, %arg1] : memref<1024x1016xf16> - memref.store %val1, 
%1[%arg0, %arg1] : memref<1016x1016xf16> + %4 = arith.index_cast %arg0 : index to i16 + %5 = arith.index_cast %arg1 : index to i16 + %6 = arith.muli %4, %c128_i16 : i16 + %7 = arith.addi %5, %6 : i16 + %8 = arith.uitofp %7 : i16 to f16 + %9 = arith.divf %8, %cst_0 : f16 + %10 = arith.addf %9, %cst : f16 + memref.store %9, %0[%arg0, %arg1] : memref<1024x1016xf16> + memref.store %10, %1[%arg0, %arg1] : memref<1016x1016xf16> } } // caculate the result C matrix scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1016 step %c1 { - %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1016xf32> - %res = scf.for %arg2 = %c0 to %c1016 step %c1 iter_args(%arg3 = %acc) -> f32 { - %a = memref.load %0[%arg0, %arg2] : memref<1024x1016xf16> - %b = memref.load %1[%arg2, %arg1] : memref<1016x1016xf16> - %c = arith.mulf %a, %b : f16 - %cc = arith.extf %c : f16 to f32 - %ccc = arith.addf %cc, %arg3 : f32 - scf.yield %ccc : f32 + %4 = memref.load %2[%arg0, %arg1] : memref<1024x1016xf32> + %5 = scf.for %arg2 = %c0 to %c1016 step %c1 iter_args(%arg3 = %4) -> (f32) { + %6 = memref.load %0[%arg0, %arg2] : memref<1024x1016xf16> + %7 = memref.load %1[%arg2, %arg1] : memref<1016x1016xf16> + %8 = arith.mulf %6, %7 : f16 + %9 = arith.extf %8 : f16 to f32 + %10 = arith.addf %9, %arg3 : f32 + scf.yield %10 : f32 } - memref.store %res, %ref[%arg0, %arg1] : memref<1024x1016xf32> + memref.store %5, %2[%arg0, %arg1] : memref<1024x1016xf32> } } - - %2 = call @test(%0, %1) : (memref<1024x1016xf16>, memref<1016x1016xf16>) -> memref<1024x1016xf32> - %cast = memref.cast %2 : memref<1024x1016xf32> to memref<*xf32> // call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %ref : memref<1024x1016xf32> to memref<*xf32> // call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %3 = call @test(%0, %1) : (memref<1024x1016xf16>, memref<1016x1016xf16>) -> memref<1024x1016xf32> + %cast = memref.cast %3 : memref<1024x1016xf32> to memref<*xf32> + %cast_1 = memref.cast %2 : memref<1024x1016xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xbf16.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xbf16.mlir index 84ce9b0c4..2893268a1 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xbf16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xbf16.mlir @@ -1,107 +1,103 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm 
attributes {gpu.container_module} { - memref.global "private" @__constant_1024x1024xbf16 : memref<1024x1024xbf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xbf16_ : memref<1024x1024xbf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0> + memref.global "private" @__constant_1024x1024xbf16 : memref<1024x1024xbf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xbf16_ : memref<1024x1024xbf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %arg0, %memref : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %memref_0 = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %arg1, %memref_0 : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xbf16>, %memref_0 : memref<1024x1024xbf16>, %memref_1 : memref<1024x1024xf32>) + %memref = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_0 = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xbf16>, %memref_0 : memref<1024x1024xbf16>, %memref_1 : memref<1024x1024xf32>) gpu.dealloc %memref : memref<1024x1024xbf16> gpu.dealloc %memref_0 : memref<1024x1024xbf16> - return %memref_1 : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c8 = arith.constant 8 : index %c1024 = arith.constant 1024 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // each work-group has 1 subgroup. 
the subgroup caculates a [8x16 = 8x1024 * 1024x16] block - %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16> - %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16> - %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16> - %10 = xegpu.load_nd %8 {packed} : !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16> - %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %11 : vector<8x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16> + %6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16> + %7 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16> + %8 = xegpu.load_nd %6 <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<8x16x2xbf16> + %9 = xegpu.dpas %7, %8, %arg4 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %9 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : bf16 + %cst_0 = arith.constant 1.000000e+02 : bf16 + %c128_i16 = arith.constant 128 : i16 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1024x1024xbf16 : memref<1024x1024xbf16> %1 = memref.get_global @__constant_1024x1024xbf16_ : memref<1024x1024xbf16> - %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> - %init = arith.constant 0.0 : bf16 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index // fill the top-left block 128x128 // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 + %2 = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> scf.for %arg0 = %c0 to %c128 step %c1 { scf.for %arg1 = %c0 to %c128 step %c1 { - %int0 = arith.index_cast %arg0 : index to i16 - %int1 = arith.index_cast %arg1 : index to i16 - %c128_i16 = arith.constant 128 : i16 - %idx0 = arith.muli %int0, %c128_i16 : i16 - %idx1 = arith.addi %int1, %idx0 : i16 - %fp = arith.uitofp %idx1 : i16 to bf16 - %cst100 = arith.constant 100.0 : bf16 - %val0 = arith.divf %fp, %cst100 : bf16 - %cst1 = arith.constant 1.0 : bf16 - %val1 = arith.addf %val0, %cst1 : bf16 - memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xbf16> - memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xbf16> + %4 = arith.index_cast %arg0 : index to i16 + %5 = arith.index_cast %arg1 : index to i16 + %6 = arith.muli %4, %c128_i16 : i16 + %7 = arith.addi %5, 
%6 : i16 + %8 = arith.uitofp %7 : i16 to bf16 + %9 = arith.divf %8, %cst_0 : bf16 + %10 = arith.addf %9, %cst : bf16 + memref.store %9, %0[%arg0, %arg1] : memref<1024x1024xbf16> + memref.store %10, %1[%arg0, %arg1] : memref<1024x1024xbf16> } } // caculate the result C matrix scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1024 step %c1 { - %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32> - %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 { - %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xbf16> - %b = memref.load %1[%arg2, %arg1] : memref<1024x1024xbf16> - %c = arith.mulf %a, %b : bf16 - %cc = arith.extf %c : bf16 to f32 - %ccc = arith.addf %cc, %arg3 : f32 - scf.yield %ccc : f32 + %4 = memref.load %2[%arg0, %arg1] : memref<1024x1024xf32> + %5 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %4) -> (f32) { + %6 = memref.load %0[%arg0, %arg2] : memref<1024x1024xbf16> + %7 = memref.load %1[%arg2, %arg1] : memref<1024x1024xbf16> + %8 = arith.mulf %6, %7 : bf16 + %9 = arith.extf %8 : bf16 to f32 + %10 = arith.addf %9, %arg3 : f32 + scf.yield %10 : f32 } - memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %5, %2[%arg0, %arg1] : memref<1024x1024xf32> } } - - %2 = call @test(%0, %1) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>) -> memref<1024x1024xf32> - %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %3 = call @test(%0, %1) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>) -> memref<1024x1024xf32> + %cast = memref.cast %3 : memref<1024x1024xf32> to memref<*xf32> + %cast_1 = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.mlir index f85f6b837..ba738dfa3 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.mlir @@ -1,107 +1,103 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> - memref.global "private" 
@__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0> + memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg0, %memref : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_0 = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg1, %memref_0 : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) gpu.dealloc %memref : memref<1024x1024xf16> gpu.dealloc %memref_0 : memref<1024x1024xf16> - return %memref_1 : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c8 = arith.constant 8 : index %c1024 = arith.constant 1024 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // each work-group has 1 subgroup. 
the subgroup caculates a [8x16 = 8x1024 * 1024x16] block - %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> - %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> - %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %10 = xegpu.load_nd %8 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %11 : vector<8x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %7 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %8 = xegpu.load_nd %6 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = xegpu.dpas %7, %8, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %9 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 1.000000e+02 : f16 + %c128_i16 = arith.constant 128 : i16 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1024x1024xf16 : memref<1024x1024xf16> %1 = memref.get_global @__constant_1024x1024xf16_ : memref<1024x1024xf16> - %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> - %init = arith.constant 0.0 : f16 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index // fill the top-left block 128x128 // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 + %2 = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> scf.for %arg0 = %c0 to %c128 step %c1 { scf.for %arg1 = %c0 to %c128 step %c1 { - %int0 = arith.index_cast %arg0 : index to i16 - %int1 = arith.index_cast %arg1 : index to i16 - %c128_i16 = arith.constant 128 : i16 - %idx0 = arith.muli %int0, %c128_i16 : i16 - %idx1 = arith.addi %int1, %idx0 : i16 - %fp = arith.uitofp %idx1 : i16 to f16 - %cst100 = arith.constant 100.0 : f16 - %val0 = arith.divf %fp, %cst100 : f16 - %cst1 = arith.constant 1.0 : f16 - %val1 = arith.addf %val0, %cst1 : f16 - memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xf16> - memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xf16> + %4 = arith.index_cast %arg0 : index to i16 + %5 = arith.index_cast %arg1 : index to i16 + %6 = arith.muli %4, %c128_i16 : i16 + %7 = arith.addi %5, %6 : i16 + %8 = arith.uitofp %7 : 
i16 to f16 + %9 = arith.divf %8, %cst_0 : f16 + %10 = arith.addf %9, %cst : f16 + memref.store %9, %0[%arg0, %arg1] : memref<1024x1024xf16> + memref.store %10, %1[%arg0, %arg1] : memref<1024x1024xf16> } } // caculate the result C matrix scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1024 step %c1 { - %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32> - %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 { - %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> - %b = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> - %c = arith.mulf %a, %b : f16 - %cc = arith.extf %c : f16 to f32 - %ccc = arith.addf %cc, %arg3 : f32 - scf.yield %ccc : f32 + %4 = memref.load %2[%arg0, %arg1] : memref<1024x1024xf32> + %5 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %4) -> (f32) { + %6 = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> + %7 = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> + %8 = arith.mulf %6, %7 : f16 + %9 = arith.extf %8 : f16 to f32 + %10 = arith.addf %9, %arg3 : f32 + scf.yield %10 : f32 } - memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %5, %2[%arg0, %arg1] : memref<1024x1024xf32> } } - - %2 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> - %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %3 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> + %cast = memref.cast %3 : memref<1024x1024xf32> to memref<*xf32> + %cast_1 = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.using.updateoffset.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.using.updateoffset.mlir index afd170bc0..622465c09 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.using.updateoffset.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_1024x1024xf16.using.updateoffset.mlir @@ -1,109 +1,105 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> - 
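// In the kernel below, the A and B tensor descriptors are created once before
// the K loop and carried through scf.for iter_args; each iteration advances
// them with xegpu.update_nd_offset ([%c0, %c16] for A, [%c16, %c0] for B)
// instead of re-creating a descriptor per step, which is the variant this
// test exercises.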
memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0> + memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg0, %memref : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_0 = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg1, %memref_0 : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) gpu.dealloc %memref : memref<1024x1024xf16> gpu.dealloc %memref_0 : memref<1024x1024xf16> - return %memref_1 : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c8 = arith.constant 8 : index %c1024 = arith.constant 1024 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // each work-group has 1 subgroup. 
the subgroup caculates a [8x16 = 8x1024 * 1024x16] block - %7 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> - %8 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> - %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5, %subA = %7, %subB = %8) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>) { - %9 = xegpu.load_nd %subA : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %10 = xegpu.load_nd %subB {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %12 = xegpu.update_nd_offset %subA, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %13 = xegpu.update_nd_offset %subB, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - scf.yield %11, %12, %13: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %5) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>) { + %7 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %8 = xegpu.load_nd %arg6 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = xegpu.dpas %7, %8, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %10 = xegpu.update_nd_offset %arg5, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %11 = xegpu.update_nd_offset %arg6, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + scf.yield %9, %10, %11 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16> } - xegpu.store_nd %6#0, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %6#0, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 1.000000e+02 : f16 + %c128_i16 = arith.constant 128 : i16 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1024x1024xf16 : memref<1024x1024xf16> %1 = memref.get_global @__constant_1024x1024xf16_ : memref<1024x1024xf16> - %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> - %init = arith.constant 0.0 : f16 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index // fill the top-left block 128x128 // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 + %2 = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> scf.for %arg0 = %c0 to %c128 step %c1 { scf.for %arg1 = %c0 to %c128 step %c1 { - %int0 = arith.index_cast %arg0 : index to i16 - %int1 = 
arith.index_cast %arg1 : index to i16 - %c128_i16 = arith.constant 128 : i16 - %idx0 = arith.muli %int0, %c128_i16 : i16 - %idx1 = arith.addi %int1, %idx0 : i16 - %fp = arith.uitofp %idx1 : i16 to f16 - %cst100 = arith.constant 100.0 : f16 - %val0 = arith.divf %fp, %cst100 : f16 - %cst1 = arith.constant 1.0 : f16 - %val1 = arith.addf %val0, %cst1 : f16 - memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xf16> - memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xf16> + %4 = arith.index_cast %arg0 : index to i16 + %5 = arith.index_cast %arg1 : index to i16 + %6 = arith.muli %4, %c128_i16 : i16 + %7 = arith.addi %5, %6 : i16 + %8 = arith.uitofp %7 : i16 to f16 + %9 = arith.divf %8, %cst_0 : f16 + %10 = arith.addf %9, %cst : f16 + memref.store %9, %0[%arg0, %arg1] : memref<1024x1024xf16> + memref.store %10, %1[%arg0, %arg1] : memref<1024x1024xf16> } } // caculate the result C matrix scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1024 step %c1 { - %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32> - %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 { - %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> - %b = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> - %c = arith.mulf %a, %b : f16 - %cc = arith.extf %c : f16 to f32 - %ccc = arith.addf %cc, %arg3 : f32 - scf.yield %ccc : f32 + %4 = memref.load %2[%arg0, %arg1] : memref<1024x1024xf32> + %5 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %4) -> (f32) { + %6 = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> + %7 = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> + %8 = arith.mulf %6, %7 : f16 + %9 = arith.extf %8 : f16 to f32 + %10 = arith.addf %9, %arg3 : f32 + scf.yield %10 : f32 } - memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %5, %2[%arg0, %arg1] : memref<1024x1024xf32> } } - - %2 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> - %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %3 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> + %cast = memref.cast %3 : memref<1024x1024xf32> to memref<*xf32> + %cast_1 = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir index 5c0e2e9d8..f19e136c9 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_256x256x256_bf16_bf16_f32.mlir @@ -1,37 +1,30 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<256x256xbf16>, %B: memref<256x256xbf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<256x256xbf16>, %arg1: memref<256x256xbf16>, %arg2: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - - %A_gpu = gpu.alloc host_shared () : memref<256x256xbf16> - memref.copy %A, %A_gpu : memref<256x256xbf16> to memref<256x256xbf16> - %B_gpu = gpu.alloc host_shared () : memref<256x256xbf16> - memref.copy %B, %B_gpu : memref<256x256xbf16> to memref<256x256xbf16> - %C_gpu = gpu.alloc host_shared () : memref<256x256xf32> - memref.copy %C, %C_gpu : memref<256x256xf32> to memref<256x256xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<256x256xbf16>, %B_gpu : memref<256x256xbf16>, %C_gpu : memref<256x256xf32>) - gpu.dealloc %A_gpu : memref<256x256xbf16> - gpu.dealloc %B_gpu : memref<256x256xbf16> - return %C_gpu : memref<256x256xf32> + %memref = gpu.alloc () : memref<256x256xbf16> + gpu.memcpy %memref, %arg0 : memref<256x256xbf16>, memref<256x256xbf16> + %memref_0 = gpu.alloc () : memref<256x256xbf16> + gpu.memcpy %memref_0, %arg1 : memref<256x256xbf16>, memref<256x256xbf16> + %memref_1 = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_1, %arg2 : memref<256x256xf32>, memref<256x256xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<256x256xbf16>, %memref_0 : memref<256x256xbf16>, %memref_1 : memref<256x256xf32>) + gpu.dealloc %memref : memref<256x256xbf16> + gpu.dealloc %memref_0 : memref<256x256xbf16> + %alloc = memref.alloc() : memref<256x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_1 : memref<256x256xf32> + return %alloc : memref<256x256xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<256x256xbf16>, %B: memref<256x256xbf16>, %C: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // constants + gpu.func @test_kernel(%arg0: memref<256x256xbf16>, %arg1: memref<256x256xbf16>, %arg2: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c256 = arith.constant 256 : index %c512 = arith.constant 512 : index %c1024 = arith.constant 1024 : index @@ -48,10 +41,7 @@ module @gemm attributes {gpu.container_module} { 
%c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 // get IDs - %wg_id_x = gpu.block_id x - %wg_id_y = gpu.block_id y // %sg_id = gpu.subgroup_id : index - // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout // C sg tile size is 32x64 // SG layout for one C tile update @@ -62,17 +52,8 @@ module @gemm attributes {gpu.container_module} { // --> y means cols // | // V x means rows - // get unique sg ID in global context - %global_sg_id_x = gpu.global_id x - %global_sg_id_y = gpu.global_id y - %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index - %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index - // compute SG C tile offsets in x and y dims - %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index - %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index - // each SG needs to do the follwoing compute to update its 32x64 sub tile // (32x256)x(256x64)=(32x64) // DPAS size is (8x16)x(16x16)=(8x16) @@ -80,14 +61,7 @@ module @gemm attributes {gpu.container_module} { // So we need to (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) // this will require 32 DPAS ops (4x2x2) inside the K loop - // WG tiles are 256x256 so there offsets are same for A, B and C - %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index - %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index - - %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index - %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index - // prefetching A and B slice within the 256x256 WG tile // // prefetch the entire 256x32 slice of A WG tile, this means each subgroups needs to prefetch 8x32 slice @@ -99,22 +73,11 @@ module @gemm attributes {gpu.container_module} { // SG 4 -> slice 4 // .... // SG 31 -> slice 31 - %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index - %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index // create A preftech tiles and prefetch // stage 1 - %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<256x256xbf16> -> !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) - %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 3 - %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> - // stage 4 - %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - // prefetch the entire 32x256 slice of B WG tile, we still use the prefetch size 8x32. // SGs have 8x4 layout. In this case 8 subgroups must do a colloborative prefetch of 32x64 tile. 
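// For reference, the subgroup-local prefetch offsets used here reduce to the
// arithmetic below (the B-tile sharing layout continues right after this
// block). Each of the 32 subgroups in the 8x4 layout prefetches one 8x32
// slice of the 256x32 A region, and groups of 8 subgroups cooperatively cover
// each 32x64 B region with 8x32 blocks. The two helper functions are
// illustrative only, not part of this patch; their names and parameters
// mirror the removed named values.
func.func @a_prefetch_offset_x(%local_sg_id_x: index, %local_sg_id_y: index,
                               %wg_tile_offset_x: index) -> index {
  %c4 = arith.constant 4 : index
  %c8 = arith.constant 8 : index
  // linear subgroup id in the 8x4 layout
  %0 = arith.muli %local_sg_id_x, %c4 : index
  %local_sg_id = arith.addi %0, %local_sg_id_y : index
  // each subgroup owns one 8-row slice of the A prefetch region
  %1 = arith.muli %local_sg_id, %c8 : index
  %x = arith.addi %1, %wg_tile_offset_x : index
  return %x : index
}
func.func @b_prefetch_offsets(%local_sg_id_x: index, %local_sg_id_y: index,
                              %wg_tile_offset_y: index) -> (index, index) {
  %c2 = arith.constant 2 : index
  %c8 = arith.constant 8 : index
  %c32 = arith.constant 32 : index
  %c64 = arith.constant 64 : index
  // pairs of subgroup rows share one 8-row block of the 32x256 B slice
  %0 = arith.divui %local_sg_id_x, %c2 : index
  %x = arith.muli %0, %c8 : index
  // subgroup column selects a 64-wide band; odd/even row selects its 32-wide half
  %1 = arith.muli %local_sg_id_y, %c64 : index
  %2 = arith.remui %local_sg_id_x, %c2 : index
  %3 = arith.muli %2, %c32 : index
  %4 = arith.addi %1, %3 : index
  %y = arith.addi %wg_tile_offset_y, %4 : index
  return %x, %y : index, index
}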
// this because the B tile arrangement within the 32x256 slice is as follows @@ -130,144 +93,111 @@ module @gemm attributes {gpu.container_module} { // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | // For example, SGs 0,4,8,12,16,20,24,28 share the data in left 32x64 tile of B slice. - // calculate the x offsets and y offsets within the 32x256 slice // XeTLA like co-operative prefetch for B - %B_sg_prefetch_offset_x_temp0 = arith.divui %local_sg_id_x, %c2 : index - %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index - - %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index - %B_sg_prefetch_offset_y_temp1 = arith.remui %local_sg_id_x, %c2 : index - %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index - - %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index - %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index - - // create B prefetch tiles and prefetch - %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<256x256xbf16> -> !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) - %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 3 - %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 4 - %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> - - %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<256x256xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %A_sg_init_tile_1 = xegpu.update_nd_offset %A_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - //create B tiles - %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<256x256xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - - %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> // ************************* // - // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix - %zero_vec = arith.constant dense<0.0> : vector<128xf32> - %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_3 = 
vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - // Multi nbarrier implementation, // one set nbarrier is used to sync subgroups with same sg_id_x (local_sg_id_x) // second set nbarrier us used to sync subgroups with same sg_id_y (local_sg_id_y) // In this case wg_size = 8,4 (wg_size_x = 8; wg_size_y = 4) // So in Y-direction we need 4 nbarrier (to sync subgroups with same sg_id_y) // In X-direction we need 8 nbarrier (to sync subgroups with same sg_id_x) - %c_wg_size_x = arith.constant 8 : index - %c_wg_size_y = arith.constant 4 : index - %num_nbarrier = arith.addi %c_wg_size_y, %c_wg_size_x : index // 8+4=12 - xegpu.alloc_nbarrier 12 // = 12 - // First set of nbarriers work across coloumns, we have 4 coloums of subgroups, // Hnece 4 nbrrier // Each nbarrier has 8 producers and consumers // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) - // %nbarrier_role = arith.constant 0 : i8 - %nbarrier_threads_y = arith.constant 8 : i8 - %nbarrier_id_y = arith.index_cast %local_sg_id_y : index to i8 - %nbarrier_y = xegpu.init_nbarrier %nbarrier_id_y, %nbarrier_threads_y : i8, i8 -> !xegpu.nbarrier - // Second set of barriers work on across rows of subgroups, // we have 8 rows of subgroups. 
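// Illustrative only, not part of this patch: the barrier id assignment used
// here, written out with named values. The first set of named barriers
// (ids 0..3, one per subgroup column) has 8 producers/consumers each; the
// second set (ids 4..11, one per subgroup row) has 4 each, which is why 12
// barriers are allocated with xegpu.alloc_nbarrier.
func.func @nbarrier_ids(%local_sg_id_x: index, %local_sg_id_y: index) -> (index, index) {
  // column barriers reuse the subgroup's y id directly: 0..3
  %c_wg_size_y = arith.constant 4 : index
  // row barriers start after the column set: 4 + local_sg_id_x -> 4..11
  %id_x = arith.addi %c_wg_size_y, %local_sg_id_x : index
  return %local_sg_id_y, %id_x : index, index
}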
Hnece, 8 nbarrier // Each nbarrier has 4 producers and consumers // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) - // We already have 4 (=%c_wg_size_y) nbarriers with id 0-3, // Now the next set of barrier id would start from 4, hence, - %nbarrier_threads_x = arith.constant 4 : i8 - %index_nbarrier_id_x = arith.addi %c_wg_size_y, %local_sg_id_x : index - %nbarrier_id_x = arith.index_cast %index_nbarrier_id_x : index to i8 - %nbarrier_x = xegpu.init_nbarrier %nbarrier_id_x, %nbarrier_threads_x : i8, i8 -> !xegpu.nbarrier - // K loop advances in 32 steps - %k_loop_result:24 = scf.for %k = %c0 to %c256 step %c32 iter_args ( - %A_tile_0 = %A_sg_init_tile_0, - %A_tile_1 = %A_sg_init_tile_1, - - %B_tile_0 = %B_sg_init_tile_0, - %B_tile_1 = %B_sg_init_tile_1, - %B_tile_2 = %B_sg_init_tile_2, - %B_tile_3 = %B_sg_init_tile_3, - - %c_val_0_0 = %c_init_val_0_0, - %c_val_0_1 = %c_init_val_0_1, - %c_val_0_2 = %c_init_val_0_2, - %c_val_0_3 = %c_init_val_0_3, - %c_val_1_0 = %c_init_val_1_0, - %c_val_1_1 = %c_init_val_1_1, - %c_val_1_2 = %c_init_val_1_2, - %c_val_1_3 = %c_init_val_1_3, - %c_val_2_0 = %c_init_val_2_0, - %c_val_2_1 = %c_init_val_2_1, - %c_val_2_2 = %c_init_val_2_2, - %c_val_2_3 = %c_init_val_2_3, - %c_val_3_0 = %c_init_val_3_0, - %c_val_3_1 = %c_init_val_3_1, - %c_val_3_2 = %c_init_val_3_2, - %c_val_3_3 = %c_init_val_3_3, - - %A_prefetch_tile = %A_sg_prefetch_tile_iter2, - %B_prefetch_tile = %B_sg_prefetch_tile_iter2 - ) -> - (!xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> - ) - { // all SGs must arrive here first - %every_8th_iter = arith.remui %k, %c32 : index - %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 - %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 - scf.if %every_8th_iter_cond { - xegpu.nbarrier_arrive %nbarrier_y : !xegpu.nbarrier - xegpu.nbarrier_arrive %nbarrier_x : !xegpu.nbarrier + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %global_id_x = gpu.global_id x + %global_id_y = gpu.global_id y + %0 = arith.remui %global_id_x, %c8 : index + %1 = arith.remui %global_id_y, %c4 : index + %2 = arith.muli %global_id_x, %c32 : index + %3 = arith.muli %global_id_y, %c64 : index + %4 = arith.muli %block_id_x, %c256 : index + %5 = arith.muli %block_id_y, %c256 : index + %6 = arith.muli %0, %c4 : index + %7 = arith.addi %6, %1 : index + %8 = arith.muli %7, %c8 : index + %9 = arith.addi %8, %4 : index + %10 = xegpu.create_nd_tdesc %arg0[%9, %c0] : memref<256x256xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : 
!xegpu.tensor_desc<8x32xbf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %13 = xegpu.update_nd_offset %12, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + %14 = arith.divui %0, %c2 : index + %15 = arith.muli %14, %c8 : index + %16 = arith.muli %1, %c64 : index + %17 = arith.remui %0, %c2 : index + %18 = arith.muli %17, %c32 : index + %19 = arith.addi %16, %18 : index + %20 = arith.addi %5, %19 : index + %21 = xegpu.create_nd_tdesc %arg1[%15, %20] : memref<256x256xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %22 = xegpu.update_nd_offset %21, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %23 = xegpu.update_nd_offset %22, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %24 = xegpu.update_nd_offset %23, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + %25 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<256x256xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %26 = xegpu.update_nd_offset %25, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %27 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<256x256xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %28 = xegpu.update_nd_offset %27, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %29 = xegpu.update_nd_offset %27, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %30 = xegpu.update_nd_offset %29, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %31 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %32 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %33 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %34 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %35 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %36 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %37 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %38 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %39 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %40 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %41 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %42 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %43 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %44 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %45 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %46 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %c8_0 = arith.constant 8 : index + %c4_1 = arith.constant 4 : index + %47 = arith.addi %c4_1, %c8_0 : index + xegpu.alloc_nbarrier 12 + %c8_i8 = arith.constant 8 : i8 + %48 = arith.index_cast %1 : index to i8 + %49 = xegpu.init_nbarrier %48, %c8_i8 : i8, i8 -> !xegpu.nbarrier + %c4_i8 = arith.constant 4 : i8 + %50 = arith.addi %c4_1, %0 : 
index + %51 = arith.index_cast %50 : index to i8 + %52 = xegpu.init_nbarrier %51, %c4_i8 : i8, i8 -> !xegpu.nbarrier + %53:24 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %25, %arg5 = %26, %arg6 = %27, %arg7 = %28, %arg8 = %29, %arg9 = %30, %arg10 = %31, %arg11 = %32, %arg12 = %33, %arg13 = %34, %arg14 = %35, %arg15 = %36, %arg16 = %37, %arg17 = %38, %arg18 = %39, %arg19 = %40, %arg20 = %41, %arg21 = %42, %arg22 = %43, %arg23 = %44, %arg24 = %45, %arg25 = %46, %arg26 = %12, %arg27 = %23) -> (!xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16>) { + %70 = arith.remui %arg3, %c32 : index + %71 = arith.index_cast %70 : index to i32 + %72 = arith.cmpi eq, %71, %c0_i32 : i32 + scf.if %72 { + xegpu.nbarrier_arrive %49 : !xegpu.nbarrier + xegpu.nbarrier_arrive %52 : !xegpu.nbarrier } - // Load smaller load (16 registers) with cache line size width : 64 bytes, 32 elements // Although maximum load size supported is 2KB or 32 registers, we use smaller loads, for 2 main reasons: // ** 1. Hide load latency: we do smaller load means for B, we do 4 loads, we set up the loads and dpas orderring in @@ -276,247 +206,159 @@ module @gemm attributes {gpu.container_module} { // // ** 2. 
Reduce the impact of L3 miss: Larger load means more cache lines to be loaded, more chance of potential L3 miss // which could increase the load time - // load B tiles - %b_val_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - %b_val_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - %b_val_2 = xegpu.load_nd %B_tile_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - %b_val_3 = xegpu.load_nd %B_tile_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - // load A tiles - %a_val_0 = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> - %a_val_1 = xegpu.load_nd %A_tile_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> - - xegpu.compile_hint - // prefetch A and B tiles - xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> - - xegpu.compile_hint - // advance A and B prefetch tiles - %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> // advance A and B tiles - %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - - %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - - %b_val_0_flat = vector.shape_cast %b_val_0 : vector<2x8x16x2xbf16> to vector<512xbf16> - %b_val_1_flat = vector.shape_cast %b_val_1 : vector<2x8x16x2xbf16> to vector<512xbf16> - %b_val_2_flat = vector.shape_cast %b_val_2 : vector<2x8x16x2xbf16> to vector<512xbf16> - %b_val_3_flat = vector.shape_cast %b_val_3 : vector<2x8x16x2xbf16> to vector<512xbf16> - // b[0,0], b[0,1] - %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_0_1_flat = vector.extract_strided_slice 
%b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xbf16> to vector<8x16x2xbf16> - // b[0,2], b[0,3] - %b_val_0_2_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_0_3_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xbf16> to vector<8x16x2xbf16> - // b[1,0], b[1,1] - %b_val_1_0_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_1_1_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xbf16> to vector<8x16x2xbf16> - // b[1,2], b[1,3] - %b_val_1_2_flat = vector.extract_strided_slice %b_val_3_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xbf16> to vector<8x16x2xbf16> - - // xegpu.compile_hint - %a_val_0_flat = vector.shape_cast %a_val_0 : vector<2x16x16xbf16> to vector<512xbf16> - %a_val_1_flat = vector.shape_cast %a_val_1 : vector<2x16x16xbf16> to vector<512xbf16> - - %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_0_0 = vector.shape_cast %a_val_0_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_0_1_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xbf16> to vector<8x16xbf16> - %a_val_1_1_flat = vector.extract_strided_slice %a_val_0_flat {offsets = [384], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_2_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_3_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_2_1 = 
vector.shape_cast %a_val_2_1_flat : vector<128xbf16> to vector<8x16xbf16> - %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [384], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xbf16> to vector<8x16xbf16> - - // do DPAS + // barrier wait + %73 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %74 = xegpu.load_nd %arg7 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %75 = xegpu.load_nd %arg8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %76 = xegpu.load_nd %arg9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %77 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> + %78 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> xegpu.compile_hint - - %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - 
%new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - + xegpu.prefetch_nd %arg26 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %arg27 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> xegpu.compile_hint - - %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - // barrier wait - scf.if %every_8th_iter_cond { - xegpu.nbarrier_wait %nbarrier_y : !xegpu.nbarrier - xegpu.nbarrier_wait %nbarrier_x : !xegpu.nbarrier + %79 = xegpu.update_nd_offset %arg26, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + %80 = xegpu.update_nd_offset %arg27, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + %81 = 
xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %82 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %83 = xegpu.update_nd_offset %arg6, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %84 = xegpu.update_nd_offset %arg7, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %85 = xegpu.update_nd_offset %arg8, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %86 = xegpu.update_nd_offset %arg9, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %87 = vector.shape_cast %73 : vector<2x8x16x2xbf16> to vector<512xbf16> + %88 = vector.shape_cast %74 : vector<2x8x16x2xbf16> to vector<512xbf16> + %89 = vector.shape_cast %75 : vector<2x8x16x2xbf16> to vector<512xbf16> + %90 = vector.shape_cast %76 : vector<2x8x16x2xbf16> to vector<512xbf16> + %91 = vector.extract_strided_slice %87 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %92 = vector.shape_cast %91 : vector<256xbf16> to vector<8x16x2xbf16> + %93 = vector.extract_strided_slice %87 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %94 = vector.shape_cast %93 : vector<256xbf16> to vector<8x16x2xbf16> + %95 = vector.extract_strided_slice %88 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %96 = vector.shape_cast %95 : vector<256xbf16> to vector<8x16x2xbf16> + %97 = vector.extract_strided_slice %88 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %98 = vector.shape_cast %97 : vector<256xbf16> to vector<8x16x2xbf16> + %99 = vector.extract_strided_slice %89 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %100 = vector.shape_cast %99 : vector<256xbf16> to vector<8x16x2xbf16> + %101 = vector.extract_strided_slice %89 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %102 = vector.shape_cast %101 : vector<256xbf16> to vector<8x16x2xbf16> + %103 = vector.extract_strided_slice %90 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %104 = vector.shape_cast %103 : vector<256xbf16> to vector<8x16x2xbf16> + %105 = vector.extract_strided_slice %90 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %106 = vector.shape_cast %105 : vector<256xbf16> to vector<8x16x2xbf16> + %107 = vector.shape_cast %77 : vector<2x16x16xbf16> to vector<512xbf16> + %108 = vector.shape_cast %78 : vector<2x16x16xbf16> to vector<512xbf16> + %109 = vector.extract_strided_slice %107 {offsets = [0], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %110 = vector.shape_cast %109 : vector<128xbf16> to vector<8x16xbf16> + %111 = vector.extract_strided_slice %107 {offsets = [128], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %112 = vector.shape_cast %111 : vector<128xbf16> to vector<8x16xbf16> + %113 = vector.extract_strided_slice %107 {offsets = [256], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %114 = vector.shape_cast %113 : vector<128xbf16> to vector<8x16xbf16> + %115 = vector.extract_strided_slice %107 {offsets = [384], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %116 = vector.shape_cast %115 : vector<128xbf16> to vector<8x16xbf16> + %117 = vector.extract_strided_slice %108 {offsets = [0], 
sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %118 = vector.shape_cast %117 : vector<128xbf16> to vector<8x16xbf16> + %119 = vector.extract_strided_slice %108 {offsets = [128], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %120 = vector.shape_cast %119 : vector<128xbf16> to vector<8x16xbf16> + %121 = vector.extract_strided_slice %108 {offsets = [256], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %122 = vector.shape_cast %121 : vector<128xbf16> to vector<8x16xbf16> + %123 = vector.extract_strided_slice %108 {offsets = [384], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %124 = vector.shape_cast %123 : vector<128xbf16> to vector<8x16xbf16> + xegpu.compile_hint + %125 = xegpu.dpas %110, %92, %arg10 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %126 = xegpu.dpas %112, %92, %arg14 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %127 = xegpu.dpas %118, %92, %arg18 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %128 = xegpu.dpas %120, %92, %arg22 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %129 = xegpu.dpas %110, %94, %arg11 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %130 = xegpu.dpas %112, %94, %arg15 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %131 = xegpu.dpas %118, %94, %arg19 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %132 = xegpu.dpas %120, %94, %arg23 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %133 = xegpu.dpas %110, %96, %arg12 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %134 = xegpu.dpas %112, %96, %arg16 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %135 = xegpu.dpas %118, %96, %arg20 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %136 = xegpu.dpas %120, %96, %arg24 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %137 = xegpu.dpas %110, %98, %arg13 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %138 = xegpu.dpas %112, %98, %arg17 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %139 = xegpu.dpas %118, %98, %arg21 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %140 = xegpu.dpas %120, %98, %arg25 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %141 = xegpu.dpas %114, %100, %125 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %142 = xegpu.dpas %116, %100, %126 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %143 = xegpu.dpas %122, %100, %127 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %144 = xegpu.dpas %124, %100, %128 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %145 = xegpu.dpas %114, %102, %129 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %146 = xegpu.dpas %116, %102, %130 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %147 = xegpu.dpas %122, %102, %131 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %148 = xegpu.dpas %124, %102, %132 : vector<8x16xbf16>, vector<8x16x2xbf16>, 
vector<8x16xf32> -> vector<8x16xf32> + %149 = xegpu.dpas %114, %104, %133 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %150 = xegpu.dpas %116, %104, %134 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %151 = xegpu.dpas %122, %104, %135 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %152 = xegpu.dpas %124, %104, %136 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %153 = xegpu.dpas %114, %106, %137 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %154 = xegpu.dpas %116, %106, %138 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %155 = xegpu.dpas %122, %106, %139 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %156 = xegpu.dpas %124, %106, %140 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.if %72 { + xegpu.nbarrier_wait %49 : !xegpu.nbarrier + xegpu.nbarrier_wait %52 : !xegpu.nbarrier } - - scf.yield %next_A_tile_0, %next_A_tile_1, %next_B_tile_0, %next_B_tile_1, %next_B_tile_2, %next_B_tile_3, - %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, - %next_A_prefetch_tile, %next_B_prefetch_tile - : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> + scf.yield %81, %82, %83, %84, %85, %86, %141, %145, %149, %153, %142, %146, %150, %154, %143, %147, %151, %155, %144, %148, %152, %156, %79, %80 : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> } - // each SG needs to store the result of K loop into a 32x64 tile in C matrix. This is organized in 8x16 DPAS tiles // in the layout of 4x4x8x16. The max store size HW supoprt in f32 is 8x16. 
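The 4x4x8x16 store layout described in the comment above can be made concrete with a short index sketch. The standalone Python snippet below is illustrative only (it is not part of the patch); it enumerates the per-store offsets inside one subgroup's 32x64 C tile, matching how the store descriptors that follow are advanced by [0, 16] across columns and [8, 0] across rows.

# Illustrative only: the 16 8x16 store tiles covering one 32x64 C sub-tile.
SG_TILE_ROWS, SG_TILE_COLS = 32, 64   # per-subgroup C tile
STORE_ROWS, STORE_COLS = 8, 16        # max f32 store size supported by the HW

offsets = [(r, c)
           for r in range(0, SG_TILE_ROWS, STORE_ROWS)
           for c in range(0, SG_TILE_COLS, STORE_COLS)]

assert len(offsets) == 16             # a 4x4 grid of 8x16 stores
print(offsets)                        # [(0, 0), (0, 16), ..., (24, 48)]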
- - %c_sg_tile_00 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#6, %c_sg_tile_00 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %54 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#6, %54 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_01 = xegpu.update_nd_offset %c_sg_tile_00, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#7, %c_sg_tile_01 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_02 = xegpu.update_nd_offset %c_sg_tile_01, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#8, %c_sg_tile_02 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_03 = xegpu.update_nd_offset %c_sg_tile_02, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#9, %c_sg_tile_03 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_10 = xegpu.update_nd_offset %c_sg_tile_00, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#10, %c_sg_tile_10 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %55 = xegpu.update_nd_offset %54, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#7, %55 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %56 = xegpu.update_nd_offset %55, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#8, %56 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %57 = xegpu.update_nd_offset %56, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#9, %57 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %58 = xegpu.update_nd_offset %54, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#10, %58 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_11 = xegpu.update_nd_offset %c_sg_tile_01, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#11, %c_sg_tile_11 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_12 = xegpu.update_nd_offset %c_sg_tile_02, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#12, %c_sg_tile_12 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_13 = xegpu.update_nd_offset %c_sg_tile_03, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#13, %c_sg_tile_13 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = 
#xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_20 = xegpu.update_nd_offset %c_sg_tile_10, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#14, %c_sg_tile_20 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %59 = xegpu.update_nd_offset %55, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#11, %59 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %60 = xegpu.update_nd_offset %56, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#12, %60 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %61 = xegpu.update_nd_offset %57, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#13, %61 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %62 = xegpu.update_nd_offset %58, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#14, %62 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_21 = xegpu.update_nd_offset %c_sg_tile_11, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#15, %c_sg_tile_21 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_22 = xegpu.update_nd_offset %c_sg_tile_12, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#16, %c_sg_tile_22 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_23 = xegpu.update_nd_offset %c_sg_tile_13, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#17, %c_sg_tile_23 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_30 = xegpu.update_nd_offset %c_sg_tile_20, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#18, %c_sg_tile_30 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %63 = xegpu.update_nd_offset %59, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#15, %63 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %64 = xegpu.update_nd_offset %60, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#16, %64 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %65 = xegpu.update_nd_offset %61, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#17, %65 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %66 = xegpu.update_nd_offset %62, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#18, %66 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_31 = xegpu.update_nd_offset 
%c_sg_tile_21, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#19, %c_sg_tile_31 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_32 = xegpu.update_nd_offset %c_sg_tile_22, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#20, %c_sg_tile_32 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_33 = xegpu.update_nd_offset %c_sg_tile_23, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#21, %c_sg_tile_33 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + %67 = xegpu.update_nd_offset %63, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#19, %67 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %68 = xegpu.update_nd_offset %64, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#20, %68 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %69 = xegpu.update_nd_offset %65, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#21, %69 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : bf16 - %c2_f16 = arith.constant 2.0 : bf16 %c256 = arith.constant 256 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant 0.0 : f32 - %cf_upper = arith.constant 1.0 : f32 - - %A = memref.alloc() : memref<256x256xbf16> - %B = memref.alloc() : memref<256x256xbf16> - %C = memref.alloc() : memref<256x256xf32> - %C_ref = memref.alloc() : memref<256x256xf32> - // Use one of the two options to initialize the A matrix // Option 1: intialize matrix A ; A[i, j] = j // scf.for %i = %c0 to %c256 step %c1 { @@ -529,9 +371,6 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) - %A_random = memref.cast %A : memref<256x256xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // Use one of the two options below to initialize the B matrix // Option 1: make matrix B an identity matrix // scf.for %i = %c0 to %c256 step %c1 { @@ -539,7 +378,6 @@ module @gemm attributes {gpu.container_module} { // %i_i32 = index.castu %i : index to i32 // %j_i32 = index.castu %j : index to i32 // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - // scf.if %i_j_same { // memref.store %cf_1, %B[%i, %j] : memref<256x256xbf16> // } else { @@ -548,36 +386,39 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) - %B_random = memref.cast %B : memref<256x256xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f16 = arith.constant 0.0 : bf16 - 
%c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + %false = arith.constant false + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() : memref<256x256xbf16> + %alloc_1 = memref.alloc() : memref<256x256xbf16> + %alloc_2 = memref.alloc() : memref<256x256xf32> + %alloc_3 = memref.alloc() : memref<256x256xf32> + %cast = memref.cast %alloc : memref<256x256xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + %cast_4 = memref.cast %alloc_1 : memref<256x256xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast_4, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<256x256xf32> + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<256x256xf32> } } - // Run GPU - %2 = call @test(%A, %B, %C) : (memref<256x256xbf16>, memref<256x256xbf16>, memref<256x256xf32>) -> memref<256x256xf32> - %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - // Run CPU - %A_cast = memref.cast %A : memref<256x256xbf16> to memref<*xbf16> - %B_cast = memref.cast %B : memref<256x256xbf16> to memref<*xbf16> - %C_ref_cast = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> - call @gemmBF16BF16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<256x256xbf16> - memref.dealloc %B : memref<256x256xbf16> - memref.dealloc %C : memref<256x256xf32> - memref.dealloc %C_ref : memref<256x256xf32> - gpu.dealloc %2 : memref<256x256xf32> + %0 = call @test(%alloc, %alloc_1, %alloc_2) : (memref<256x256xbf16>, memref<256x256xbf16>, memref<256x256xf32>) -> memref<256x256xf32> + %cast_5 = memref.cast %0 : memref<256x256xf32> to memref<*xf32> + %cast_6 = memref.cast %alloc : memref<256x256xbf16> to memref<*xbf16> + %cast_7 = memref.cast %alloc_1 : memref<256x256xbf16> to memref<*xbf16> + %cast_8 = memref.cast %alloc_3 : memref<256x256xf32> to memref<*xf32> + call @gemmBF16BF16F32(%cast_6, %cast_7, %cast_8) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_5, %cast_8) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<256x256xbf16> + memref.dealloc %alloc_1 : memref<256x256xbf16> + memref.dealloc %alloc_2 : memref<256x256xf32> + memref.dealloc %alloc_3 : memref<256x256xf32> + memref.dealloc %0 : memref<256x256xf32> return } func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} @@ -586,5 +427,4 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @gemmBF16BF16F32(memref<*xbf16>, memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir index 
dde050221..954ad9de2 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_bf16_bf16_f32_xetla_like_load_store_prefetch.mlir @@ -1,43 +1,33 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xbf16>, %arg1: memref<4096x4096xbf16>, %arg2: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - // Explicit memory copy to and from host - %A_gpu = gpu.alloc () : memref<4096x4096xbf16> - gpu.memcpy %A_gpu, %A : memref<4096x4096xbf16>, memref<4096x4096xbf16> - %B_gpu = gpu.alloc () : memref<4096x4096xbf16> - gpu.memcpy %B_gpu, %B : memref<4096x4096xbf16>, memref<4096x4096xbf16> - %C_gpu = gpu.alloc () : memref<4096x4096xf32> - gpu.memcpy %C_gpu, %C : memref<4096x4096xf32>, memref<4096x4096xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>) - %C_host = memref.alloc() : memref<4096x4096xf32> - gpu.memcpy %C_host, %C_gpu : memref<4096x4096xf32>, memref<4096x4096xf32> - gpu.dealloc %A_gpu : memref<4096x4096xbf16> - gpu.dealloc %B_gpu : memref<4096x4096xbf16> - gpu.dealloc %C_gpu : memref<4096x4096xf32> - return %C_host : memref<4096x4096xf32> // ******************************************* - + %memref = gpu.alloc () : memref<4096x4096xbf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xbf16>, memref<4096x4096xbf16> + %memref_0 = gpu.alloc () : memref<4096x4096xbf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xbf16>, memref<4096x4096xbf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf32> + gpu.memcpy %memref_1, %arg2 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<4096x4096xbf16>, %memref_0 : memref<4096x4096xbf16>, %memref_1 : memref<4096x4096xf32>) + %alloc = memref.alloc() : memref<4096x4096xf32> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf32>, memref<4096x4096xf32> 
+ gpu.dealloc %memref : memref<4096x4096xbf16> + gpu.dealloc %memref_0 : memref<4096x4096xbf16> + gpu.dealloc %memref_1 : memref<4096x4096xf32> + return %alloc : memref<4096x4096xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // constants + gpu.func @test_kernel(%arg0: memref<4096x4096xbf16>, %arg1: memref<4096x4096xbf16>, %arg2: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c256 = arith.constant 256 : index %c512 = arith.constant 512 : index %c1024 = arith.constant 1024 : index @@ -55,10 +45,7 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 // get IDs - %wg_id_x = gpu.block_id x - %wg_id_y = gpu.block_id y // %sg_id = gpu.subgroup_id : index - // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout // C sg tile size is 32x64 // SG layout for one C tile update @@ -69,17 +56,8 @@ module @gemm attributes {gpu.container_module} { // --> y means cols // | // V x means rows - // get unique sg ID in global context - %global_sg_id_x = gpu.global_id x - %global_sg_id_y = gpu.global_id y - %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index - %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index - // compute SG C tile offsets in x and y dims - %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index - %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index - // each SG needs to do the follwoing compute to update its 32x64 sub tile // (32x4096)x(4096x64)=(32x64) // DPAS size is (8x16)x(16x16)=(8x16) @@ -87,14 +65,7 @@ module @gemm attributes {gpu.container_module} { // So we need to (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) // this will require 32 DPAS ops (4x2x2) inside the K loop - // WG tiles are 256x256 so there offsets are same for A, B and C - %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index - %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index - - %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index - %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index - // prefetching A and B slice within the 256x256 WG tile // // prefetch the entire 256x32 slice of A WG tile, this means each subgroups needs to prefetch 8x32 slice @@ -106,22 +77,11 @@ module @gemm attributes {gpu.container_module} { // SG 4 -> slice 4 // .... 
// SG 31 -> slice 31 - %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index - %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index // create A preftech tiles and prefetch // stage 1 - %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) - %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 3 - %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> - // stage 4 - %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - // prefetch the entire 32x256 slice of B WG tile, we still use the prefetch size 8x32. // SGs have 8x4 layout. In this case 8 subgroups must do a colloborative prefetch of 32x64 tile. // this because the B tile arrangement within the 32x256 slice is as follows @@ -137,144 +97,111 @@ module @gemm attributes {gpu.container_module} { // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | // For example, SGs 0,4,8,12,16,20,24,28 share the data in left 32x64 tile of B slice. 
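The co-operative B prefetch described above reduces to a small amount of index arithmetic, which the arith chain in the rewritten kernel (divui/remui by 2, muli by 8, 32 and 64) implements. The Python sketch below is illustrative only and computes the same offsets within the 32x256 B slice (the workgroup tile offset is added separately); it checks that the eight subgroups sharing one sg_y jointly cover a 32x64 block with 8x32 prefetches.

# Illustrative only: per-subgroup B-prefetch offsets within the 32x256 slice.
# Mirrors: x = (sg_x // 2) * 8,  y = sg_y * 64 + (sg_x % 2) * 32
def b_prefetch_offset(sg_x: int, sg_y: int) -> tuple[int, int]:
    return (sg_x // 2) * 8, sg_y * 64 + (sg_x % 2) * 32

# Subgroups 0, 4, 8, ..., 28 (row-major id = sg_x * 4 + sg_y with sg_y == 0)
# cooperatively cover the left 32x64 block of the slice.
cover = {b_prefetch_offset(sg_x, 0) for sg_x in range(8)}
assert cover == {(x, y) for x in (0, 8, 16, 24) for y in (0, 32)}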
- // calculate the x offsets and y offsets within the 32x256 slice // XeTLA like co-operative prefetch for B - %B_sg_prefetch_offset_x_temp0 = arith.divui %local_sg_id_x, %c2 : index - %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index - - %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index - %B_sg_prefetch_offset_y_temp1 = arith.remui %local_sg_id_x, %c2 : index - %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index - - %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index - %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index - - // create B prefetch tiles and prefetch - %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) - %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 3 - %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> // stage 4 - %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> - - %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %A_sg_init_tile_1 = xegpu.update_nd_offset %A_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - //create B tiles - %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - - %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> // ************************* // - // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix - %zero_vec = arith.constant dense<0.0> : vector<128xf32> - %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - 
%c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - // Multi nbarrier implementation, // one set nbarrier is used to sync subgroups with same sg_id_x (local_sg_id_x) // second set nbarrier us used to sync subgroups with same sg_id_y (local_sg_id_y) // In this case wg_size = 8,4 (wg_size_x = 8; wg_size_y = 4) // So in Y-direction we need 4 nbarrier (to sync subgroups with same sg_id_y) // In X-direction we need 8 nbarrier (to sync subgroups with same sg_id_x) - %c_wg_size_x = arith.constant 8 : index - %c_wg_size_y = arith.constant 4 : index - %num_nbarrier = arith.addi %c_wg_size_y, %c_wg_size_x : index // 8+4=12 - xegpu.alloc_nbarrier 12 // = 12 - // First set of nbarriers work across coloumns, we have 4 coloums of subgroups, // Hnece 4 nbrrier // Each nbarrier has 8 producers and consumers // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) - // %nbarrier_role = arith.constant 0 : i8 - %nbarrier_threads_y = arith.constant 8 : i8 - %nbarrier_id_y = arith.index_cast %local_sg_id_y : index to i8 - %nbarrier_y = xegpu.init_nbarrier %nbarrier_id_y, %nbarrier_threads_y : i8, i8 -> !xegpu.nbarrier - // Second set of barriers work on across rows of subgroups, // we have 8 rows of subgroups. 
Hnece, 8 nbarrier // Each nbarrier has 4 producers and consumers // nbarrier type is Producer_Consumer (https://gfxspecs.intel.com/Predator/Home/Index/57499) - // We already have 4 (=%c_wg_size_y) nbarriers with id 0-3, // Now the next set of barrier id would start from 4, hence, - %nbarrier_threads_x = arith.constant 4 : i8 - %index_nbarrier_id_x = arith.addi %c_wg_size_y, %local_sg_id_x : index - %nbarrier_id_x = arith.index_cast %index_nbarrier_id_x : index to i8 - %nbarrier_x = xegpu.init_nbarrier %nbarrier_id_x, %nbarrier_threads_x : i8, i8 -> !xegpu.nbarrier - // K loop advances in 32 steps - %k_loop_result:24 = scf.for %k = %c0 to %c4096 step %c32 iter_args ( - %A_tile_0 = %A_sg_init_tile_0, - %A_tile_1 = %A_sg_init_tile_1, - - %B_tile_0 = %B_sg_init_tile_0, - %B_tile_1 = %B_sg_init_tile_1, - %B_tile_2 = %B_sg_init_tile_2, - %B_tile_3 = %B_sg_init_tile_3, - - %c_val_0_0 = %c_init_val_0_0, - %c_val_0_1 = %c_init_val_0_1, - %c_val_0_2 = %c_init_val_0_2, - %c_val_0_3 = %c_init_val_0_3, - %c_val_1_0 = %c_init_val_1_0, - %c_val_1_1 = %c_init_val_1_1, - %c_val_1_2 = %c_init_val_1_2, - %c_val_1_3 = %c_init_val_1_3, - %c_val_2_0 = %c_init_val_2_0, - %c_val_2_1 = %c_init_val_2_1, - %c_val_2_2 = %c_init_val_2_2, - %c_val_2_3 = %c_init_val_2_3, - %c_val_3_0 = %c_init_val_3_0, - %c_val_3_1 = %c_init_val_3_1, - %c_val_3_2 = %c_init_val_3_2, - %c_val_3_3 = %c_init_val_3_3, - - %A_prefetch_tile = %A_sg_prefetch_tile_iter2, - %B_prefetch_tile = %B_sg_prefetch_tile_iter2 - ) -> - (!xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> - ) - { // all SGs must arrive here first - %every_8th_iter = arith.remui %k, %c32 : index - %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 - %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 - scf.if %every_8th_iter_cond { - xegpu.nbarrier_arrive %nbarrier_y : !xegpu.nbarrier - xegpu.nbarrier_arrive %nbarrier_x : !xegpu.nbarrier + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %global_id_x = gpu.global_id x + %global_id_y = gpu.global_id y + %0 = arith.remui %global_id_x, %c8 : index + %1 = arith.remui %global_id_y, %c4 : index + %2 = arith.muli %global_id_x, %c32 : index + %3 = arith.muli %global_id_y, %c64 : index + %4 = arith.muli %block_id_x, %c256 : index + %5 = arith.muli %block_id_y, %c256 : index + %6 = arith.muli %0, %c4 : index + %7 = arith.addi %6, %1 : index + %8 = arith.muli %7, %c8 : index + %9 = arith.addi %8, %4 : index + %10 = xegpu.create_nd_tdesc %arg0[%9, %c0] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : 
!xegpu.tensor_desc<8x32xbf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %13 = xegpu.update_nd_offset %12, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + %14 = arith.divui %0, %c2 : index + %15 = arith.muli %14, %c8 : index + %16 = arith.muli %1, %c64 : index + %17 = arith.remui %0, %c2 : index + %18 = arith.muli %17, %c32 : index + %19 = arith.addi %16, %18 : index + %20 = arith.addi %5, %19 : index + %21 = xegpu.create_nd_tdesc %arg1[%15, %20] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %22 = xegpu.update_nd_offset %21, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %23 = xegpu.update_nd_offset %22, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + %24 = xegpu.update_nd_offset %23, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + %25 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %26 = xegpu.update_nd_offset %25, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %27 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<4096x4096xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %28 = xegpu.update_nd_offset %27, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %29 = xegpu.update_nd_offset %27, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %30 = xegpu.update_nd_offset %29, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %31 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %32 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %33 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %34 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %35 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %36 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %37 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %38 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %39 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %40 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %41 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %42 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %43 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %44 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %45 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %46 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %c8_0 = arith.constant 8 : index + %c4_1 = arith.constant 4 : index + %47 = arith.addi %c4_1, %c8_0 : index + xegpu.alloc_nbarrier 12 + %c8_i8 = arith.constant 8 : i8 + %48 = arith.index_cast %1 : index to i8 + %49 = xegpu.init_nbarrier %48, %c8_i8 : i8, i8 -> !xegpu.nbarrier + %c4_i8 = arith.constant 4 : i8 + %50 = arith.addi 
%c4_1, %0 : index + %51 = arith.index_cast %50 : index to i8 + %52 = xegpu.init_nbarrier %51, %c4_i8 : i8, i8 -> !xegpu.nbarrier + %53:24 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %25, %arg5 = %26, %arg6 = %27, %arg7 = %28, %arg8 = %29, %arg9 = %30, %arg10 = %31, %arg11 = %32, %arg12 = %33, %arg13 = %34, %arg14 = %35, %arg15 = %36, %arg16 = %37, %arg17 = %38, %arg18 = %39, %arg19 = %40, %arg20 = %41, %arg21 = %42, %arg22 = %43, %arg23 = %44, %arg24 = %45, %arg25 = %46, %arg26 = %12, %arg27 = %23) -> (!xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16>) { + %70 = arith.remui %arg3, %c32 : index + %71 = arith.index_cast %70 : index to i32 + %72 = arith.cmpi eq, %71, %c0_i32 : i32 + scf.if %72 { + xegpu.nbarrier_arrive %49 : !xegpu.nbarrier + xegpu.nbarrier_arrive %52 : !xegpu.nbarrier } - // Load smaller load (16 registers) with cache line size width : 64 bytes, 32 elements // Although maximum load size supported is 2KB or 32 registers, we use smaller loads, for 2 main reasons: // ** 1. Hide load latency: we do smaller load means for B, we do 4 loads, we set up the loads and dpas orderring in @@ -283,247 +210,159 @@ module @gemm attributes {gpu.container_module} { // // ** 2. 
Reduce the impact of L3 miss: Larger load means more cache lines to be loaded, more chance of potential L3 miss // which could increase the load time - // load B tiles - %b_val_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - %b_val_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - %b_val_2 = xegpu.load_nd %B_tile_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - %b_val_3 = xegpu.load_nd %B_tile_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> - // load A tiles - %a_val_0 = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> - %a_val_1 = xegpu.load_nd %A_tile_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> - - xegpu.compile_hint - // prefetch A and B tiles - xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> - xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xbf16> - - xegpu.compile_hint - // advance A and B prefetch tiles - %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> - %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> // advance A and B tiles - %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - - %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> - - %b_val_0_flat = vector.shape_cast %b_val_0 : vector<2x8x16x2xbf16> to vector<512xbf16> - %b_val_1_flat = vector.shape_cast %b_val_1 : vector<2x8x16x2xbf16> to vector<512xbf16> - %b_val_2_flat = vector.shape_cast %b_val_2 : vector<2x8x16x2xbf16> to vector<512xbf16> - %b_val_3_flat = vector.shape_cast %b_val_3 : vector<2x8x16x2xbf16> to vector<512xbf16> - // b[0,0], b[0,1] - %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_0_1_flat = vector.extract_strided_slice 
%b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xbf16> to vector<8x16x2xbf16> - // b[0,2], b[0,3] - %b_val_0_2_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_0_3_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xbf16> to vector<8x16x2xbf16> - // b[1,0], b[1,1] - %b_val_1_0_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_1_1_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xbf16> to vector<8x16x2xbf16> - // b[1,2], b[1,3] - %b_val_1_2_flat = vector.extract_strided_slice %b_val_3_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xbf16> to vector<8x16x2xbf16> - %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : - vector<512xbf16> to vector<256xbf16> - %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xbf16> to vector<8x16x2xbf16> - - // xegpu.compile_hint - %a_val_0_flat = vector.shape_cast %a_val_0 : vector<2x16x16xbf16> to vector<512xbf16> - %a_val_1_flat = vector.shape_cast %a_val_1 : vector<2x16x16xbf16> to vector<512xbf16> - - %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_0_0 = vector.shape_cast %a_val_0_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_0_1_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xbf16> to vector<8x16xbf16> - %a_val_1_1_flat = vector.extract_strided_slice %a_val_0_flat {offsets = [384], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_2_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_3_0_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xbf16> to vector<8x16xbf16> - - %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_2_1 = 
vector.shape_cast %a_val_2_1_flat : vector<128xbf16> to vector<8x16xbf16> - %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [384], sizes = [128], strides = [1]} : - vector<512xbf16> to vector<128xbf16> - %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xbf16> to vector<8x16xbf16> - - // do DPAS + // barrier wait + %73 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %74 = xegpu.load_nd %arg7 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %75 = xegpu.load_nd %arg8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %76 = xegpu.load_nd %arg9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x8x16x2xbf16> + %77 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> + %78 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xbf16> xegpu.compile_hint - - %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - 
%new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - + xegpu.prefetch_nd %arg26 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> + xegpu.prefetch_nd %arg27 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xbf16> xegpu.compile_hint - - %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> - - // barrier wait - scf.if %every_8th_iter_cond { - xegpu.nbarrier_wait %nbarrier_y : !xegpu.nbarrier - xegpu.nbarrier_wait %nbarrier_x : !xegpu.nbarrier + %79 = xegpu.update_nd_offset %arg26, [%c0, %c32] : !xegpu.tensor_desc<8x32xbf16> + %80 = xegpu.update_nd_offset %arg27, [%c32, %c0] : !xegpu.tensor_desc<8x32xbf16> + %81 = 
xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %82 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %83 = xegpu.update_nd_offset %arg6, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %84 = xegpu.update_nd_offset %arg7, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %85 = xegpu.update_nd_offset %arg8, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %86 = xegpu.update_nd_offset %arg9, [%c32, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr> + %87 = vector.shape_cast %73 : vector<2x8x16x2xbf16> to vector<512xbf16> + %88 = vector.shape_cast %74 : vector<2x8x16x2xbf16> to vector<512xbf16> + %89 = vector.shape_cast %75 : vector<2x8x16x2xbf16> to vector<512xbf16> + %90 = vector.shape_cast %76 : vector<2x8x16x2xbf16> to vector<512xbf16> + %91 = vector.extract_strided_slice %87 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %92 = vector.shape_cast %91 : vector<256xbf16> to vector<8x16x2xbf16> + %93 = vector.extract_strided_slice %87 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %94 = vector.shape_cast %93 : vector<256xbf16> to vector<8x16x2xbf16> + %95 = vector.extract_strided_slice %88 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %96 = vector.shape_cast %95 : vector<256xbf16> to vector<8x16x2xbf16> + %97 = vector.extract_strided_slice %88 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %98 = vector.shape_cast %97 : vector<256xbf16> to vector<8x16x2xbf16> + %99 = vector.extract_strided_slice %89 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %100 = vector.shape_cast %99 : vector<256xbf16> to vector<8x16x2xbf16> + %101 = vector.extract_strided_slice %89 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %102 = vector.shape_cast %101 : vector<256xbf16> to vector<8x16x2xbf16> + %103 = vector.extract_strided_slice %90 {offsets = [0], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %104 = vector.shape_cast %103 : vector<256xbf16> to vector<8x16x2xbf16> + %105 = vector.extract_strided_slice %90 {offsets = [256], sizes = [256], strides = [1]} : vector<512xbf16> to vector<256xbf16> + %106 = vector.shape_cast %105 : vector<256xbf16> to vector<8x16x2xbf16> + %107 = vector.shape_cast %77 : vector<2x16x16xbf16> to vector<512xbf16> + %108 = vector.shape_cast %78 : vector<2x16x16xbf16> to vector<512xbf16> + %109 = vector.extract_strided_slice %107 {offsets = [0], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %110 = vector.shape_cast %109 : vector<128xbf16> to vector<8x16xbf16> + %111 = vector.extract_strided_slice %107 {offsets = [128], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %112 = vector.shape_cast %111 : vector<128xbf16> to vector<8x16xbf16> + %113 = vector.extract_strided_slice %107 {offsets = [256], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %114 = vector.shape_cast %113 : vector<128xbf16> to vector<8x16xbf16> + %115 = vector.extract_strided_slice %107 {offsets = [384], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %116 = vector.shape_cast %115 : vector<128xbf16> to vector<8x16xbf16> + %117 = vector.extract_strided_slice %108 {offsets = [0], 
sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %118 = vector.shape_cast %117 : vector<128xbf16> to vector<8x16xbf16> + %119 = vector.extract_strided_slice %108 {offsets = [128], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %120 = vector.shape_cast %119 : vector<128xbf16> to vector<8x16xbf16> + %121 = vector.extract_strided_slice %108 {offsets = [256], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %122 = vector.shape_cast %121 : vector<128xbf16> to vector<8x16xbf16> + %123 = vector.extract_strided_slice %108 {offsets = [384], sizes = [128], strides = [1]} : vector<512xbf16> to vector<128xbf16> + %124 = vector.shape_cast %123 : vector<128xbf16> to vector<8x16xbf16> + xegpu.compile_hint + %125 = xegpu.dpas %110, %92, %arg10 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %126 = xegpu.dpas %112, %92, %arg14 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %127 = xegpu.dpas %118, %92, %arg18 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %128 = xegpu.dpas %120, %92, %arg22 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %129 = xegpu.dpas %110, %94, %arg11 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %130 = xegpu.dpas %112, %94, %arg15 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %131 = xegpu.dpas %118, %94, %arg19 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %132 = xegpu.dpas %120, %94, %arg23 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %133 = xegpu.dpas %110, %96, %arg12 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %134 = xegpu.dpas %112, %96, %arg16 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %135 = xegpu.dpas %118, %96, %arg20 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %136 = xegpu.dpas %120, %96, %arg24 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %137 = xegpu.dpas %110, %98, %arg13 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %138 = xegpu.dpas %112, %98, %arg17 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %139 = xegpu.dpas %118, %98, %arg21 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %140 = xegpu.dpas %120, %98, %arg25 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %141 = xegpu.dpas %114, %100, %125 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %142 = xegpu.dpas %116, %100, %126 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %143 = xegpu.dpas %122, %100, %127 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %144 = xegpu.dpas %124, %100, %128 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %145 = xegpu.dpas %114, %102, %129 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %146 = xegpu.dpas %116, %102, %130 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %147 = xegpu.dpas %122, %102, %131 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %148 = xegpu.dpas %124, %102, %132 : vector<8x16xbf16>, vector<8x16x2xbf16>, 
vector<8x16xf32> -> vector<8x16xf32> + %149 = xegpu.dpas %114, %104, %133 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %150 = xegpu.dpas %116, %104, %134 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %151 = xegpu.dpas %122, %104, %135 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %152 = xegpu.dpas %124, %104, %136 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %153 = xegpu.dpas %114, %106, %137 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %154 = xegpu.dpas %116, %106, %138 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %155 = xegpu.dpas %122, %106, %139 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + %156 = xegpu.dpas %124, %106, %140 : vector<8x16xbf16>, vector<8x16x2xbf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.if %72 { + xegpu.nbarrier_wait %49 : !xegpu.nbarrier + xegpu.nbarrier_wait %52 : !xegpu.nbarrier } - - scf.yield %next_A_tile_0, %next_A_tile_1, %next_B_tile_0, %next_B_tile_1, %next_B_tile_2, %next_B_tile_3, - %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, - %next_A_prefetch_tile, %next_B_prefetch_tile - : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> + scf.yield %81, %82, %83, %84, %85, %86, %141, %145, %149, %153, %142, %146, %150, %154, %143, %147, %151, %155, %144, %148, %152, %156, %79, %80 : !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xbf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xbf16>, !xegpu.tensor_desc<8x32xbf16> } - // each SG needs to store the result of K loop into a 32x64 tile in C matrix. This is organized in 8x16 DPAS tiles // in the layout of 4x4x8x16. The max store size HW supoprt in f32 is 8x16. 
- - %c_sg_tile_00 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<4096x4096xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#6, %c_sg_tile_00 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %54 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<4096x4096xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#6, %54 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_01 = xegpu.update_nd_offset %c_sg_tile_00, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#7, %c_sg_tile_01 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_02 = xegpu.update_nd_offset %c_sg_tile_01, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#8, %c_sg_tile_02 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_03 = xegpu.update_nd_offset %c_sg_tile_02, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#9, %c_sg_tile_03 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_10 = xegpu.update_nd_offset %c_sg_tile_00, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#10, %c_sg_tile_10 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %55 = xegpu.update_nd_offset %54, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#7, %55 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %56 = xegpu.update_nd_offset %55, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#8, %56 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %57 = xegpu.update_nd_offset %56, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#9, %57 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %58 = xegpu.update_nd_offset %54, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#10, %58 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_11 = xegpu.update_nd_offset %c_sg_tile_01, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#11, %c_sg_tile_11 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_12 = xegpu.update_nd_offset %c_sg_tile_02, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#12, %c_sg_tile_12 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_13 = xegpu.update_nd_offset %c_sg_tile_03, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#13, %c_sg_tile_13 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, 
l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_20 = xegpu.update_nd_offset %c_sg_tile_10, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#14, %c_sg_tile_20 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %59 = xegpu.update_nd_offset %55, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#11, %59 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %60 = xegpu.update_nd_offset %56, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#12, %60 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %61 = xegpu.update_nd_offset %57, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#13, %61 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %62 = xegpu.update_nd_offset %58, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#14, %62 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_21 = xegpu.update_nd_offset %c_sg_tile_11, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#15, %c_sg_tile_21 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_22 = xegpu.update_nd_offset %c_sg_tile_12, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#16, %c_sg_tile_22 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_23 = xegpu.update_nd_offset %c_sg_tile_13, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#17, %c_sg_tile_23 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_30 = xegpu.update_nd_offset %c_sg_tile_20, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#18, %c_sg_tile_30 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %63 = xegpu.update_nd_offset %59, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#15, %63 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %64 = xegpu.update_nd_offset %60, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#16, %64 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %65 = xegpu.update_nd_offset %61, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#17, %65 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %66 = xegpu.update_nd_offset %62, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#18, %66 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> xegpu.compile_hint - - %c_sg_tile_31 = 
xegpu.update_nd_offset %c_sg_tile_21, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#19, %c_sg_tile_31 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_32 = xegpu.update_nd_offset %c_sg_tile_22, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#20, %c_sg_tile_32 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - %c_sg_tile_33 = xegpu.update_nd_offset %c_sg_tile_23, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#21, %c_sg_tile_33 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + %67 = xegpu.update_nd_offset %63, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#19, %67 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %68 = xegpu.update_nd_offset %64, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#20, %68 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %69 = xegpu.update_nd_offset %65, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %53#21, %69 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : bf16 - %c2_f16 = arith.constant 2.0 : bf16 %c4096 = arith.constant 4096 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant 0.0 : f32 - %cf_upper = arith.constant 1.0 : f32 - - %A = memref.alloc() : memref<4096x4096xbf16> - %B = memref.alloc() : memref<4096x4096xbf16> - %C = memref.alloc() : memref<4096x4096xf32> - %C_ref = memref.alloc() : memref<4096x4096xf32> - // Use one of the two options to initialize the A matrix // Option 1: intialize matrix A ; A[i, j] = j // scf.for %i = %c0 to %c4096 step %c1 { @@ -536,9 +375,6 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) - %A_random = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // Use one of the two options below to initialize the B matrix // Option 1: make matrix B an identity matrix // scf.for %i = %c0 to %c4096 step %c1 { @@ -546,7 +382,6 @@ module @gemm attributes {gpu.container_module} { // %i_i32 = index.castu %i : index to i32 // %j_i32 = index.castu %j : index to i32 // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - // scf.if %i_j_same { // memref.store %cf_1, %B[%i, %j] : memref<4096x4096xbf16> // } else { @@ -555,35 +390,38 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (0.0, 1.0) - %B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 - 
%c0_f16 = arith.constant 0.0 : bf16 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + %false = arith.constant false + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() : memref<4096x4096xbf16> + %alloc_1 = memref.alloc() : memref<4096x4096xbf16> + %alloc_2 = memref.alloc() : memref<4096x4096xf32> + %alloc_3 = memref.alloc() : memref<4096x4096xf32> + %cast = memref.cast %alloc : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + %cast_4 = memref.cast %alloc_1 : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast_4, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<4096x4096xf32> + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<4096x4096xf32> } } - // Run GPU - %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> - %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - // Run CPU. - %A_cast = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> - %B_cast = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> - %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - call @gemmBF16BF16F32(%A_cast, %B_cast, %C_cast) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %C_cast) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xbf16> - memref.dealloc %B : memref<4096x4096xbf16> - memref.dealloc %C : memref<4096x4096xf32> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_1, %alloc_2) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + %cast_5 = memref.cast %0 : memref<4096x4096xf32> to memref<*xf32> + %cast_6 = memref.cast %alloc : memref<4096x4096xbf16> to memref<*xbf16> + %cast_7 = memref.cast %alloc_1 : memref<4096x4096xbf16> to memref<*xbf16> + %cast_8 = memref.cast %alloc_3 : memref<4096x4096xf32> to memref<*xf32> + call @gemmBF16BF16F32(%cast_6, %cast_7, %cast_8) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_5, %cast_8) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xbf16> + memref.dealloc %alloc_1 : memref<4096x4096xbf16> + memref.dealloc %alloc_2 : memref<4096x4096xf32> + memref.dealloc %alloc_3 : memref<4096x4096xf32> return } func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} @@ -592,5 +430,4 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomBF16(memref<*xbf16>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @gemmBF16BF16F32(memref<*xbf16>, memref<*xbf16>, memref<*xf32>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir index b1675e7cb..660ea58b3 100644 --- 
a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_dpas_sized_loads_f16_f16_f32.mlir @@ -1,36 +1,31 @@ -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %A, %A_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %B, %B_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32> - memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf32>) - gpu.dealloc %A_gpu : memref<4096x4096xf16> - gpu.dealloc %B_gpu : memref<4096x4096xf16> - return %C_gpu : memref<4096x4096xf32> + %memref = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_0 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf32> + gpu.memcpy %memref_1, %arg2 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<4096x4096xf16>, %memref_0 : memref<4096x4096xf16>, %memref_1 : memref<4096x4096xf32>) + gpu.dealloc %memref : memref<4096x4096xf16> + gpu.dealloc %memref_0 : memref<4096x4096xf16> + %alloc = memref.alloc() : memref<4096x4096xf32> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.dealloc %memref_1 : memref<4096x4096xf32> + return %alloc : memref<4096x4096xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: 
memref<4096x4096xf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // constants + gpu.func @test_kernel(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c256 = arith.constant 256 : index %c32 = arith.constant 32 : index %c4096 = arith.constant 4096 : index @@ -43,10 +38,7 @@ module @gemm attributes {gpu.container_module} { %c24 = arith.constant 24 : index %c0 = arith.constant 0 : index // get IDs - %wg_id_x = gpu.block_id x - %wg_id_y = gpu.block_id y // %sg_id = gpu.subgroup_id : index - // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout // C sg tile size is 32x64 // SG layout for one C tile update @@ -57,17 +49,8 @@ module @gemm attributes {gpu.container_module} { // --> y means cols // | // V x means rows - // get unique sg ID in global context - %global_sg_id_x = gpu.global_id x - %global_sg_id_y = gpu.global_id y - %local_sg_id_x = arith.remsi %global_sg_id_x, %c8 : index - %local_sg_id_y = arith.remsi %global_sg_id_y, %c4 : index - // compute SG C tile offsets in x and y dims - %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index - %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index - // each SG needs to do the follwoing compute to update its 32x64 sub tile // (32x4096)x(4096x64)=(32x64) // DPAS size is (8x16)x(16x16)=(8x16) @@ -75,14 +58,7 @@ module @gemm attributes {gpu.container_module} { // So we need to (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) // this will require 32 DPAS ops (4x2x2) inside the K loop - // WG tiles are 256x256 so there offsets are same for A, B and C - %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index - %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index - - %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index - %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index - // prefetching A and B slice within the 256x256 WG tile // // prefetch the entire 256x32 slice of A WG tile, this means each subgroups needs to prefetch 8x32 slice @@ -94,21 +70,11 @@ module @gemm attributes {gpu.container_module} { // SG 4 -> slice 4 // .... 
// SG 31 -> slice 31 - %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index - %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index // create A preftech tiles and prefetch // stage 1 - %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) - %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch within K loop - %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - // prefetch the entire 32x256 slice of B WG tile, we still use the prefetch size 8x32. // SGs have 8x4 layout. In this case 8 subgroups must do a colloborative prefetch of 32x64 tile. // this because the B tile arrangement within the 32x256 slice is as follows @@ -124,281 +90,224 @@ module @gemm attributes {gpu.container_module} { // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | // For example, SGs 0,4,8,12,16,20,24,28 share the data in left 32x64 tile of B slice. 
- // calculate the x offsets and y offsets within the 32x256 slice - %B_sg_prefetch_offset_x_temp0 = arith.remsi %local_sg_id_x, %c4 : index - %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index - %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index - %B_sg_prefetch_offset_y_temp1 = arith.divsi %local_sg_id_x, %c4 : index - %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index - %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index - %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index - // create B prefetch tiles and prefetch - %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) - %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch inside K loop - %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - - // create A tiles - %A_sg_init_tile_0_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_1_0 = xegpu.update_nd_offset %A_sg_init_tile_0_0, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_2_0 = xegpu.update_nd_offset %A_sg_init_tile_1_0, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_3_0 = xegpu.update_nd_offset %A_sg_init_tile_2_0, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_0_1 = xegpu.update_nd_offset %A_sg_init_tile_0_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_1_1 = xegpu.update_nd_offset %A_sg_init_tile_0_1, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_2_1 = xegpu.update_nd_offset %A_sg_init_tile_1_1, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - %A_sg_init_tile_3_1 = xegpu.update_nd_offset %A_sg_init_tile_2_1, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - //create B tiles - %B_sg_init_tile_0_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_0_1 = xegpu.update_nd_offset %B_sg_init_tile_0_0, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_0_2 = xegpu.update_nd_offset %B_sg_init_tile_0_1, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_0_3 = xegpu.update_nd_offset %B_sg_init_tile_0_2, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_1_0 = xegpu.update_nd_offset %B_sg_init_tile_0_0, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_1_1 = xegpu.update_nd_offset %B_sg_init_tile_1_0, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_1_2 = 
xegpu.update_nd_offset %B_sg_init_tile_1_1, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - %B_sg_init_tile_1_3 = xegpu.update_nd_offset %B_sg_init_tile_1_2, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> - // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix - %zero_vec = arith.constant dense<0.0> : vector<128xf32> - %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - - + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %global_id_x = gpu.global_id x + %global_id_y = gpu.global_id y + %0 = arith.remsi %global_id_x, %c8 : index + %1 = arith.remsi %global_id_y, %c4 : index + %2 = arith.muli %global_id_x, %c32 : index + %3 = arith.muli %global_id_y, %c64 : index + %4 = arith.muli %block_id_x, %c256 : index + %5 = arith.muli %block_id_y, %c256 : index + %6 = arith.muli %0, %c4 : index + %7 = arith.addi %6, %1 : index + %8 = arith.muli %7, %c8 : index + %9 = arith.addi %8, %4 : index + %10 = xegpu.create_nd_tdesc %arg0[%9, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %13 = xegpu.update_nd_offset %12, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %14 = arith.remsi %0, %c4 : index + %15 = arith.muli %14, %c8 : index + %16 = arith.muli %1, %c64 : index + %17 = arith.divsi %0, %c4 : index + %18 = arith.muli %17, %c32 : index + %19 = arith.addi %16, %18 : index + %20 = arith.addi %5, %19 : index + %21 = xegpu.create_nd_tdesc %arg1[%15, %20] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %22 = 
xegpu.update_nd_offset %21, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %23 = xegpu.update_nd_offset %22, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %24 = xegpu.update_nd_offset %23, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %25 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x16xf16> + %26 = xegpu.update_nd_offset %25, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %27 = xegpu.update_nd_offset %26, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %28 = xegpu.update_nd_offset %27, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %29 = xegpu.update_nd_offset %25, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %30 = xegpu.update_nd_offset %29, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %31 = xegpu.update_nd_offset %30, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %32 = xegpu.update_nd_offset %31, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %33 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<16x16xf16> + %34 = xegpu.update_nd_offset %33, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %35 = xegpu.update_nd_offset %34, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %36 = xegpu.update_nd_offset %35, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %37 = xegpu.update_nd_offset %33, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16> + %38 = xegpu.update_nd_offset %37, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %39 = xegpu.update_nd_offset %38, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %40 = xegpu.update_nd_offset %39, [%c0, %c16] : !xegpu.tensor_desc<16x16xf16> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %41 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %42 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %43 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %44 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %45 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %46 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %47 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %48 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %49 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %50 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %51 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %52 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %53 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %54 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %55 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %56 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> xegpu.alloc_nbarrier 16 - %nbarrier_id = arith.constant 1 : i8 - %num_threads = arith.constant 8 : i8 - %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier // K loop advances in 32 steps - %k_loop_result:34 = scf.for %k = %c0 to %c4096 step %c32 iter_args ( - %A_tile_0_0 = %A_sg_init_tile_0_0, - %A_tile_1_0 = %A_sg_init_tile_1_0, - %A_tile_2_0 = %A_sg_init_tile_2_0, - %A_tile_3_0 = %A_sg_init_tile_3_0, - %A_tile_0_1 = %A_sg_init_tile_0_1, - %A_tile_1_1 = %A_sg_init_tile_1_1, - %A_tile_2_1 = %A_sg_init_tile_2_1, - 
%A_tile_3_1 = %A_sg_init_tile_3_1, - - %B_tile_0_0 = %B_sg_init_tile_0_0, - %B_tile_0_1 = %B_sg_init_tile_0_1, - %B_tile_0_2 = %B_sg_init_tile_0_2, - %B_tile_0_3 = %B_sg_init_tile_0_3, - %B_tile_1_0 = %B_sg_init_tile_1_0, - %B_tile_1_1 = %B_sg_init_tile_1_1, - %B_tile_1_2 = %B_sg_init_tile_1_2, - %B_tile_1_3 = %B_sg_init_tile_1_3, - - %c_val_0_0 = %c_init_val_0_0, - %c_val_0_1 = %c_init_val_0_1, - %c_val_0_2 = %c_init_val_0_2, - %c_val_0_3 = %c_init_val_0_3, - %c_val_1_0 = %c_init_val_1_0, - %c_val_1_1 = %c_init_val_1_1, - %c_val_1_2 = %c_init_val_1_2, - %c_val_1_3 = %c_init_val_1_3, - %c_val_2_0 = %c_init_val_2_0, - %c_val_2_1 = %c_init_val_2_1, - %c_val_2_2 = %c_init_val_2_2, - %c_val_2_3 = %c_init_val_2_3, - %c_val_3_0 = %c_init_val_3_0, - %c_val_3_1 = %c_init_val_3_1, - %c_val_3_2 = %c_init_val_3_2, - %c_val_3_3 = %c_init_val_3_3, - - %A_prefetch_tile = %A_sg_prefetch_tile_iter3, - %B_prefetch_tile = %B_sg_prefetch_tile_iter3 - ) -> - (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, - !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> - ) - { // all SGs must arrive here first - xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier // load A tiles - %a_val_0_0 = xegpu.load_nd %A_tile_0_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_1_0 = xegpu.load_nd %A_tile_1_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_2_0 = xegpu.load_nd %A_tile_2_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_3_0 = xegpu.load_nd %A_tile_3_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_0_1 = xegpu.load_nd %A_tile_0_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_1_1 = xegpu.load_nd %A_tile_1_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_2_1 = xegpu.load_nd %A_tile_2_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %a_val_3_1 = xegpu.load_nd %A_tile_3_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - // load B tiles - %b_val_0_0 = xegpu.load_nd %B_tile_0_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_0_1 = 
xegpu.load_nd %B_tile_0_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_0_2 = xegpu.load_nd %B_tile_0_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_0_3 = xegpu.load_nd %B_tile_0_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_1_0 = xegpu.load_nd %B_tile_1_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_1_1 = xegpu.load_nd %B_tile_1_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_1_2 = xegpu.load_nd %B_tile_1_2 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %b_val_1_3 = xegpu.load_nd %B_tile_1_3 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - // prefetch A and B tiles - xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - // + %c1_i8 = arith.constant 1 : i8 + %c8_i8 = arith.constant 8 : i8 + %57 = xegpu.init_nbarrier %c1_i8, %c8_i8 : i8, i8 -> !xegpu.nbarrier + %58:34 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %25, %arg5 = %26, %arg6 = %27, %arg7 = %28, %arg8 = %29, %arg9 = %30, %arg10 = %31, %arg11 = %32, %arg12 = %33, %arg13 = %34, %arg14 = %35, %arg15 = %36, %arg16 = %37, %arg17 = %38, %arg18 = %39, %arg19 = %40, %arg20 = %41, %arg21 = %42, %arg22 = %43, %arg23 = %44, %arg24 = %45, %arg25 = %46, %arg26 = %47, %arg27 = %48, %arg28 = %49, %arg29 = %50, %arg30 = %51, %arg31 = %52, %arg32 = %53, %arg33 = %54, %arg34 = %55, %arg35 = %56, %arg36 = %13, %arg37 = %24) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16>) { + xegpu.nbarrier_arrive %57 : !xegpu.nbarrier + %75 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %76 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> 
vector<8x16xf16> + %77 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %78 = xegpu.load_nd %arg7 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %79 = xegpu.load_nd %arg8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %80 = xegpu.load_nd %arg9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %81 = xegpu.load_nd %arg10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %82 = xegpu.load_nd %arg11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %83 = xegpu.load_nd %arg12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %84 = xegpu.load_nd %arg13 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %85 = xegpu.load_nd %arg14 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %86 = xegpu.load_nd %arg15 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %87 = xegpu.load_nd %arg16 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %88 = xegpu.load_nd %arg17 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %89 = xegpu.load_nd %arg18 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %90 = xegpu.load_nd %arg19 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + xegpu.prefetch_nd %arg36 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %arg37 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - // advance A and B prefetch tiles - %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> // advance A and B tiles - %next_A_tile_0_0 = xegpu.update_nd_offset %A_tile_0_0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - %next_A_tile_1_0 = xegpu.update_nd_offset %A_tile_1_0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - %next_A_tile_2_0 = xegpu.update_nd_offset %A_tile_2_0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - %next_A_tile_3_0 = xegpu.update_nd_offset %A_tile_3_0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - %next_A_tile_0_1 = xegpu.update_nd_offset %A_tile_0_1, [%c0, %c32] : 
!xegpu.tensor_desc<8x16xf16> - %next_A_tile_1_1 = xegpu.update_nd_offset %A_tile_1_1, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - %next_A_tile_2_1 = xegpu.update_nd_offset %A_tile_2_1, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - %next_A_tile_3_1 = xegpu.update_nd_offset %A_tile_3_1, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> - - %next_B_tile_0_0 = xegpu.update_nd_offset %B_tile_0_0, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_0_1 = xegpu.update_nd_offset %B_tile_0_1, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_0_2 = xegpu.update_nd_offset %B_tile_0_2, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_0_3 = xegpu.update_nd_offset %B_tile_0_3, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_1_0 = xegpu.update_nd_offset %B_tile_1_0, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_1_1 = xegpu.update_nd_offset %B_tile_1_1, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_1_2 = xegpu.update_nd_offset %B_tile_1_2, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - %next_B_tile_1_3 = xegpu.update_nd_offset %B_tile_1_3, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> - + %91 = xegpu.update_nd_offset %arg36, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %92 = xegpu.update_nd_offset %arg37, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %93 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %94 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %95 = xegpu.update_nd_offset %arg6, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %96 = xegpu.update_nd_offset %arg7, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %97 = xegpu.update_nd_offset %arg8, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %98 = xegpu.update_nd_offset %arg9, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %99 = xegpu.update_nd_offset %arg10, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %100 = xegpu.update_nd_offset %arg11, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16> + %101 = xegpu.update_nd_offset %arg12, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %102 = xegpu.update_nd_offset %arg13, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %103 = xegpu.update_nd_offset %arg14, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %104 = xegpu.update_nd_offset %arg15, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %105 = xegpu.update_nd_offset %arg16, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %106 = xegpu.update_nd_offset %arg17, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %107 = xegpu.update_nd_offset %arg18, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> + %108 = xegpu.update_nd_offset %arg19, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16> xegpu.compile_hint - // do DPAS - %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - 
%new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, 
vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - + %109 = xegpu.dpas %75, %83, %arg20 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %110 = xegpu.dpas %75, %84, %arg21 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %111 = xegpu.dpas %75, %85, %arg22 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %112 = xegpu.dpas %75, %86, %arg23 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %113 = xegpu.dpas %76, %83, %arg24 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %114 = xegpu.dpas %76, %84, %arg25 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %115 = xegpu.dpas %76, %85, %arg26 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %116 = xegpu.dpas %76, %86, %arg27 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %117 = xegpu.dpas %77, %83, %arg28 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %118 = xegpu.dpas %77, %84, %arg29 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %119 = xegpu.dpas %77, %85, %arg30 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %120 = xegpu.dpas %77, %86, %arg31 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %121 = xegpu.dpas %78, %83, %arg32 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %122 = xegpu.dpas %78, %84, %arg33 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %123 = xegpu.dpas %78, %85, %arg34 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %124 = xegpu.dpas %78, %86, %arg35 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %125 = xegpu.dpas %79, %87, %109 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %126 = xegpu.dpas %79, %88, %110 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %127 = xegpu.dpas %79, %89, %111 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %128 = xegpu.dpas %79, %90, %112 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %129 = xegpu.dpas %80, %87, %113 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %130 = xegpu.dpas %80, %88, %114 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %131 = xegpu.dpas %80, %89, %115 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %132 = xegpu.dpas %80, %90, %116 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %133 = xegpu.dpas %81, %87, %117 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %134 = xegpu.dpas %81, %88, %118 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %135 = xegpu.dpas %81, %89, %119 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %136 = xegpu.dpas %81, %90, %120 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %137 = xegpu.dpas %82, %87, %121 : 
vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %138 = xegpu.dpas %82, %88, %122 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %139 = xegpu.dpas %82, %89, %123 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %140 = xegpu.dpas %82, %90, %124 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> xegpu.compile_hint // barrier wait - xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier - - scf.yield %next_A_tile_0_0, %next_A_tile_1_0, %next_A_tile_2_0, %next_A_tile_3_0, %next_A_tile_0_1, %next_A_tile_1_1, %next_A_tile_2_1, %next_A_tile_3_1, - %next_B_tile_0_0, %next_B_tile_0_1, %next_B_tile_0_2, %next_B_tile_0_3, %next_B_tile_1_0, %next_B_tile_1_1, %next_B_tile_1_2, %next_B_tile_1_3, - %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, - %next_A_prefetch_tile, %next_B_prefetch_tile - : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, - !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>,!xegpu.tensor_desc<16x16xf16>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + xegpu.nbarrier_wait %57 : !xegpu.nbarrier + scf.yield %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136, %137, %138, %139, %140, %91, %92 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> } - // each SG needs to write to 32x64 C tile. // DPAS output size is 8x16. So each SG needs to write 16 (4x4) DPAS outputs. // create 16 address descriptions to cover 8x16 tiles in 4x4 layout within the 32x64 SG C tile. 
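// The 16 descriptors form a 4x4 grid of 8x16 blocks inside the 32x64 SG C tile. A minimal standalone Python sketch (illustrative only, not part of the MLIR test or of this patch) of the block origins that the update_nd_offset chain below produces, relative to the SG's C tile offset:
def c_store_origins(sg_off_x, sg_off_y):
    # 4 rows of DPAS outputs, 8 apart in x (rows); 4 columns, 16 apart in y (cols).
    return [(sg_off_x + r * 8, sg_off_y + c * 16) for r in range(4) for c in range(4)]
origins = c_store_origins(0, 0)
assert len(origins) == 16                       # 32/8 = 4 tile rows, 64/16 = 4 tile cols
assert origins[0] == (0, 0) and origins[-1] == (24, 48)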
// advance 8 in x dim and, advance 16 in y dim // row 1 - %c_sg_tile_0_0 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<4096x4096xf32> -> !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_0_1 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_0_2 = xegpu.update_nd_offset %c_sg_tile_0_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_0_3 = xegpu.update_nd_offset %c_sg_tile_0_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> // row 2 - %c_sg_tile_1_0 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_1_1 = xegpu.update_nd_offset %c_sg_tile_1_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_1_2 = xegpu.update_nd_offset %c_sg_tile_1_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_1_3 = xegpu.update_nd_offset %c_sg_tile_1_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> // row 3 - %c_sg_tile_2_0 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c16, %c0] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_2_1 = xegpu.update_nd_offset %c_sg_tile_2_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_2_2 = xegpu.update_nd_offset %c_sg_tile_2_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_2_3 = xegpu.update_nd_offset %c_sg_tile_2_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> // row 4 - %c_sg_tile_3_0 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c24, %c0] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_3_1 = xegpu.update_nd_offset %c_sg_tile_3_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_3_2 = xegpu.update_nd_offset %c_sg_tile_3_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> - %c_sg_tile_3_3 = xegpu.update_nd_offset %c_sg_tile_3_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> // do store_nd - xegpu.store_nd %k_loop_result#16, %c_sg_tile_0_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#17, %c_sg_tile_0_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#18, %c_sg_tile_0_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#19, %c_sg_tile_0_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#20, %c_sg_tile_1_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#21, %c_sg_tile_1_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#22, %c_sg_tile_1_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#23, %c_sg_tile_1_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#24, %c_sg_tile_2_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#25, %c_sg_tile_2_1 {l1_hint = #xegpu.cache_hint, l2_hint = 
#xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#26, %c_sg_tile_2_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#27, %c_sg_tile_2_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#28, %c_sg_tile_3_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#29, %c_sg_tile_3_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#30, %c_sg_tile_3_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %k_loop_result#31, %c_sg_tile_3_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %59 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<4096x4096xf32> -> !xegpu.tensor_desc<8x16xf32> + %60 = xegpu.update_nd_offset %59, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %61 = xegpu.update_nd_offset %60, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %62 = xegpu.update_nd_offset %61, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %63 = xegpu.update_nd_offset %59, [%c8, %c0] : !xegpu.tensor_desc<8x16xf32> + %64 = xegpu.update_nd_offset %63, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %65 = xegpu.update_nd_offset %64, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %66 = xegpu.update_nd_offset %65, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %67 = xegpu.update_nd_offset %59, [%c16, %c0] : !xegpu.tensor_desc<8x16xf32> + %68 = xegpu.update_nd_offset %67, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %69 = xegpu.update_nd_offset %68, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %70 = xegpu.update_nd_offset %69, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %71 = xegpu.update_nd_offset %59, [%c24, %c0] : !xegpu.tensor_desc<8x16xf32> + %72 = xegpu.update_nd_offset %71, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %73 = xegpu.update_nd_offset %72, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + %74 = xegpu.update_nd_offset %73, [%c0, %c16] : !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#16, %59 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#17, %60 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#18, %61 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#19, %62 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#20, %63 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#21, %64 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#22, %65 
<{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#23, %66 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#24, %67 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#25, %68 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#26, %69 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#27, %70 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#28, %71 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#29, %72 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#30, %73 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %58#31, %74 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 5.000000e-01 : f32 + %cst_1 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : f16 - %c2_f16 = arith.constant 2.0 : f16 %c4096 = arith.constant 4096 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<4096x4096xf16> - %B = memref.alloc() : memref<4096x4096xf16> - %C = memref.alloc() : memref<4096x4096xf32> - %C_ref = memref.alloc() : memref<4096x4096xf32> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 // Use one of the two options to initialize the A matrix // Option 1: intialize matrix A ; A[i, j] = j // scf.for %i = %c0 to %c4096 step %c1 { @@ -411,10 +320,6 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %A_random = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - // Use one of the two options below to initialize the B matrix // Option 1: make matrix B an identity matrix // scf.for %i = %c0 to %c4096 step %c1 { @@ -422,7 +327,6 @@ module @gemm attributes {gpu.container_module} { // %i_i32 = index.castu %i : index to i32 // %j_i32 = index.castu %j : index to i32 // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - // scf.if %i_j_same { // memref.store %cf_1, %B[%i, %j] : memref<4096x4096xf16> // } else { @@ -431,35 +335,35 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %B_random = 
memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + %alloc = memref.alloc() : memref<4096x4096xf16> + %alloc_2 = memref.alloc() : memref<4096x4096xf16> + %alloc_3 = memref.alloc() : memref<4096x4096xf32> + %alloc_4 = memref.alloc() : memref<4096x4096xf32> + %cast = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_0, %false) : (memref<*xf16>, f32, f32, i1) -> () + %cast_5 = memref.cast %alloc_2 : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_5, %cst_1, %cst_0, %false) : (memref<*xf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<4096x4096xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<4096x4096xf32> } } - // Run GPU. - %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> - %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - // Run CPU. - %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - %C_ref_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - call @gemmF16F16F32(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %C_ref_cast) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xf16> - memref.dealloc %B : memref<4096x4096xf16> - memref.dealloc %C : memref<4096x4096xf32> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + %cast_6 = memref.cast %0 : memref<4096x4096xf32> to memref<*xf32> + %cast_7 = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + %cast_8 = memref.cast %alloc_2 : memref<4096x4096xf16> to memref<*xf16> + %cast_9 = memref.cast %alloc_4 : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F32(%cast_7, %cast_8, %cast_9) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_6, %cast_9) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xf16> + memref.dealloc %alloc_2 : memref<4096x4096xf16> + memref.dealloc %alloc_3 : memref<4096x4096xf32> + memref.dealloc %alloc_4 : memref<4096x4096xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} @@ -467,5 +371,4 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @gemmF16F16F32(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir index 02f744a72..f7aa989aa 
100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16.mlir @@ -1,35 +1,31 @@ -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %A, %A_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %B, %B_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %C, %C_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf16>) - gpu.dealloc %A_gpu : memref<4096x4096xf16> - gpu.dealloc %B_gpu : memref<4096x4096xf16> - return %C_gpu : memref<4096x4096xf16> + %memref = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_0 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_1, %arg2 : memref<4096x4096xf16>, memref<4096x4096xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<4096x4096xf16>, %memref_0 : memref<4096x4096xf16>, %memref_1 : memref<4096x4096xf16>) + gpu.dealloc %memref : memref<4096x4096xf16> + gpu.dealloc %memref_0 : memref<4096x4096xf16> + %alloc = memref.alloc() : memref<4096x4096xf16> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf16>, memref<4096x4096xf16> + gpu.dealloc %memref_1 : memref<4096x4096xf16> + return %alloc : memref<4096x4096xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, 
%C: memref<4096x4096xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // constants + gpu.func @test_kernel(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c256 = arith.constant 256 : index %c512 = arith.constant 512 : index %c128 = arith.constant 128 : index @@ -45,10 +41,7 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 // get IDs - %wg_id_x = gpu.block_id x - %wg_id_y = gpu.block_id y // %sg_id = gpu.subgroup_id : index - // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout // C sg tile size is 32x64 // SG layout for one C tile update @@ -59,17 +52,8 @@ module @gemm attributes {gpu.container_module} { // --> y means cols // | // V x means rows - // get unique sg ID in global context - %global_sg_id_x = gpu.global_id x - %global_sg_id_y = gpu.global_id y - %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index - %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index - // compute SG C tile offsets in x and y dims - %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index - %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index - // each SG needs to do the follwoing compute to update its 32x64 sub tile // (32x4096)x(4096x64)=(32x64) // DPAS size is (8x16)x(16x16)=(8x16) @@ -77,14 +61,7 @@ module @gemm attributes {gpu.container_module} { // So we need to (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) // this will require 32 DPAS ops (4x2x2) inside the K loop - // WG tiles are 256x256 so there offsets are same for A, B and C - %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index - %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index - - %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index - %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index - // prefetching A and B slice within the 256x256 WG tile // // prefetch the entire 256x32 slice of A WG tile, this means each subgroups needs to prefetch 8x32 slice @@ -96,21 +73,11 @@ module @gemm attributes {gpu.container_module} { // SG 4 -> slice 4 // .... 
// SG 31 -> slice 31 - %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index - %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index // create A preftech tiles and prefetch // stage 1 - %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) - %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch within K loop - %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - // prefetch the entire 32x256 slice of B WG tile, we still use the prefetch size 8x32. // SGs have 8x4 layout. In this case 8 subgroups must do a colloborative prefetch of 32x64 tile. // this because the B tile arrangement within the 32x256 slice is as follows @@ -126,335 +93,265 @@ module @gemm attributes {gpu.container_module} { // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | // For example, SGs 0,4,8,12,16,20,24,28 share the data in left 32x64 tile of B slice. 
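// The offset arithmetic below implements this collaborative layout. As a cross-check, a small standalone Python sketch (illustrative only, not part of the MLIR test or of this patch) reproduces the mapping and verifies that the 32 sub-groups cover the 32x256 B slice exactly once:
def b_prefetch_offsets(local_sg_id_x, local_sg_id_y, wg_tile_offset_y=0):
    # local_sg_id_x in [0, 8) and local_sg_id_y in [0, 4), per the 8x4 SG layout above.
    x = (local_sg_id_x % 4) * 8                                              # row offset in the 32-row slice
    y = wg_tile_offset_y + local_sg_id_y * 64 + (local_sg_id_x // 4) * 32    # column offset in the 256-col slice
    return x, y
covered = set()
for sg_x in range(8):
    for sg_y in range(4):
        x, y = b_prefetch_offsets(sg_x, sg_y)
        covered.update((x + i, y + j) for i in range(8) for j in range(32))  # one 8x32 prefetch per SG
assert covered == {(i, j) for i in range(32) for j in range(256)}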
- // calculate the x offsets and y offsets within the 32x256 slice - %B_sg_prefetch_offset_x_temp0 = arith.remui %local_sg_id_x, %c4 : index - %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index - %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index - %B_sg_prefetch_offset_y_temp1 = arith.divui %local_sg_id_x, %c4 : index - %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index - %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index - %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index - // create B prefetch tiles and prefetch - %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) - %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch inside K loop - %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - - // two 32x16 A tiles from 256x32 WG slice - %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %A_sg_init_tile_1 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c16] : memref<4096x4096xf16> - //create B tiles - %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_1, [%c0, %c16] : !xegpu.tensor_desc<32x16xf16> // %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c16] : !xegpu.tensor_desc<32x16xf16> - // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix - %zero_vec = arith.constant dense<0.0> : vector<128xf32> - %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_0 = 
vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - - + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %global_id_x = gpu.global_id x + %global_id_y = gpu.global_id y + %0 = arith.remui %global_id_x, %c8 : index + %1 = arith.remui %global_id_y, %c4 : index + %2 = arith.muli %global_id_x, %c32 : index + %3 = arith.muli %global_id_y, %c64 : index + %4 = arith.muli %block_id_x, %c256 : index + %5 = arith.muli %block_id_y, %c256 : index + %6 = arith.muli %0, %c4 : index + %7 = arith.addi %6, %1 : index + %8 = arith.muli %7, %c8 : index + %9 = arith.addi %8, %4 : index + %10 = xegpu.create_nd_tdesc %arg0[%9, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %13 = xegpu.update_nd_offset %12, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %14 = arith.remui %0, %c4 : index + %15 = arith.muli %14, %c8 : index + %16 = arith.muli %1, %c64 : index + %17 = arith.divui %0, %c4 : index + %18 = arith.muli %17, %c32 : index + %19 = arith.addi %16, %18 : index + %20 = arith.addi %5, %19 : index + %21 = xegpu.create_nd_tdesc %arg1[%15, %20] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %22 = xegpu.update_nd_offset %21, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %23 = xegpu.update_nd_offset %22, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %24 = xegpu.update_nd_offset %23, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %25 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %26 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %27 = xegpu.update_nd_offset %26, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %28 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %29 = vector.shape_cast %cst : vector<128xf32> to 
vector<8x16xf32> + %30 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %31 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %32 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %33 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %34 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %35 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %36 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %37 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %38 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %39 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %40 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %41 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %42 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %43 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> xegpu.alloc_nbarrier 16 - %nbarrier_id = arith.constant 1 : i8 - %num_threads = arith.constant 32 : i8 - %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier // K loop advances in 32 steps - %k_loop_result:21 = scf.for %k = %c0 to %c4096 step %c32 iter_args ( - %A_tile_0 = %A_sg_init_tile_0, // %A_tile_1 = %A_sg_init_tile_1, - - %B_tile_0 = %B_sg_init_tile_0, - %B_tile_1 = %B_sg_init_tile_1, // %B_tile_2 = %B_sg_init_tile_2, // %B_tile_3 = %B_sg_init_tile_3, - - %c_val_0_0 = %c_init_val_0_0, - %c_val_0_1 = %c_init_val_0_1, - %c_val_0_2 = %c_init_val_0_2, - %c_val_0_3 = %c_init_val_0_3, - %c_val_1_0 = %c_init_val_1_0, - %c_val_1_1 = %c_init_val_1_1, - %c_val_1_2 = %c_init_val_1_2, - %c_val_1_3 = %c_init_val_1_3, - %c_val_2_0 = %c_init_val_2_0, - %c_val_2_1 = %c_init_val_2_1, - %c_val_2_2 = %c_init_val_2_2, - %c_val_2_3 = %c_init_val_2_3, - %c_val_3_0 = %c_init_val_3_0, - %c_val_3_1 = %c_init_val_3_1, - %c_val_3_2 = %c_init_val_3_2, - %c_val_3_3 = %c_init_val_3_3, - - %A_prefetch_tile = %A_sg_prefetch_tile_iter3, - %B_prefetch_tile = %B_sg_prefetch_tile_iter3 - ) -> - (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> - ) - { // all SGs must arrive here first - %every_8th_iter = arith.remui %k, %c256 : index - %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 - %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 - scf.if %every_8th_iter_cond { - xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier + %c1_i8 = arith.constant 1 : i8 + %c32_i8 = arith.constant 32 : i8 + %44 = xegpu.init_nbarrier %c1_i8, %c32_i8 : i8, i8 -> !xegpu.nbarrier + %45:21 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %25, %arg5 = %26, %arg6 = %27, %arg7 = %28, %arg8 = %29, %arg9 = %30, %arg10 = %31, %arg11 = %32, %arg12 = %33, %arg13 = %34, %arg14 = %35, %arg15 = %36, %arg16 = %37, %arg17 = %38, %arg18 = %39, %arg19 = %40, %arg20 = %41, %arg21 = %42, %arg22 = %43, %arg23 = %13, %arg24 = %24) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, 
!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16>) { + %78 = arith.remui %arg3, %c256 : index + %79 = arith.index_cast %78 : index to i32 + %80 = arith.cmpi eq, %79, %c0_i32 : i32 + scf.if %80 { + xegpu.nbarrier_arrive %44 : !xegpu.nbarrier } // load A tiles - %a_val = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - %a_val_0 = vector.extract %a_val [0] : vector<32x16xf16> from vector<2x32x16xf16> - %a_val_1 = vector.extract %a_val [1] : vector<32x16xf16> from vector<2x32x16xf16> - // load B tiles - %b_val_arr_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> - %b_val_arr_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> - - %b_val_0 = vector.extract %b_val_arr_0 [0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_1 = vector.extract %b_val_arr_0 [1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_2 = vector.extract %b_val_arr_1 [0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_3 = vector.extract %b_val_arr_1 [1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - - xegpu.compile_hint - // prefetch A and B tiles - xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - // - xegpu.compile_hint - // advance A and B prefetch tiles - %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> // advance A and B tiles - %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16> - - %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16> // %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16> - - xegpu.compile_hint - %a_val_0_flat = vector.shape_cast %a_val_0 : vector<32x16xf16> to vector<512xf16> - %a_val_1_flat = vector.shape_cast %a_val_1 : vector<32x16xf16> to vector<512xf16> - %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_0_0 = vector.shape_cast %a_val_0_0_flat : 
vector<128xf16> to vector<8x16xf16> - %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_2_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_3_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [384], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_0_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_1_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [128], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_2_1 = vector.shape_cast %a_val_2_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [384], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xf16> to vector<8x16xf16> - - - %b_val_0_flat = vector.shape_cast %b_val_0 : vector<16x16x2xf16> to vector<512xf16> - %b_val_1_flat = vector.shape_cast %b_val_1 : vector<16x16x2xf16> to vector<512xf16> - %b_val_2_flat = vector.shape_cast %b_val_2 : vector<16x16x2xf16> to vector<512xf16> - %b_val_3_flat = vector.shape_cast %b_val_3 : vector<16x16x2xf16> to vector<512xf16> - %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_1_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_1_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_2_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_2_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_3_flat = vector.extract_strided_slice %b_val_3_flat { offsets 
= [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xf16> to vector<8x16x2xf16> - // do DPAS + // barrier wait + %81 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %82 = vector.extract %81[0] : vector<32x16xf16> from vector<2x32x16xf16> + %83 = vector.extract %81[1] : vector<32x16xf16> from vector<2x32x16xf16> + %84 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> + %85 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> + %86 = vector.extract %84[0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %87 = vector.extract %84[1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %88 = vector.extract %85[0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %89 = vector.extract %85[1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> xegpu.compile_hint - %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : 
vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.prefetch_nd %arg23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %arg24 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - + %90 = xegpu.update_nd_offset %arg23, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %91 = xegpu.update_nd_offset %arg24, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %92 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %93 = xegpu.update_nd_offset %arg5, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %94 = xegpu.update_nd_offset %arg6, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, 
#xegpu.block_tdesc_attr> xegpu.compile_hint - // barrier wait - scf.if %every_8th_iter_cond { - xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier + %95 = vector.shape_cast %82 : vector<32x16xf16> to vector<512xf16> + %96 = vector.shape_cast %83 : vector<32x16xf16> to vector<512xf16> + %97 = vector.extract_strided_slice %95 {offsets = [0], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %98 = vector.shape_cast %97 : vector<128xf16> to vector<8x16xf16> + %99 = vector.extract_strided_slice %95 {offsets = [128], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %100 = vector.shape_cast %99 : vector<128xf16> to vector<8x16xf16> + %101 = vector.extract_strided_slice %95 {offsets = [256], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %102 = vector.shape_cast %101 : vector<128xf16> to vector<8x16xf16> + %103 = vector.extract_strided_slice %95 {offsets = [384], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %104 = vector.shape_cast %103 : vector<128xf16> to vector<8x16xf16> + %105 = vector.extract_strided_slice %96 {offsets = [0], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %106 = vector.shape_cast %105 : vector<128xf16> to vector<8x16xf16> + %107 = vector.extract_strided_slice %96 {offsets = [128], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %108 = vector.shape_cast %107 : vector<128xf16> to vector<8x16xf16> + %109 = vector.extract_strided_slice %96 {offsets = [256], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %110 = vector.shape_cast %109 : vector<128xf16> to vector<8x16xf16> + %111 = vector.extract_strided_slice %96 {offsets = [384], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %112 = vector.shape_cast %111 : vector<128xf16> to vector<8x16xf16> + %113 = vector.shape_cast %86 : vector<16x16x2xf16> to vector<512xf16> + %114 = vector.shape_cast %87 : vector<16x16x2xf16> to vector<512xf16> + %115 = vector.shape_cast %88 : vector<16x16x2xf16> to vector<512xf16> + %116 = vector.shape_cast %89 : vector<16x16x2xf16> to vector<512xf16> + %117 = vector.extract_strided_slice %113 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %118 = vector.shape_cast %117 : vector<256xf16> to vector<8x16x2xf16> + %119 = vector.extract_strided_slice %113 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %120 = vector.shape_cast %119 : vector<256xf16> to vector<8x16x2xf16> + %121 = vector.extract_strided_slice %114 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %122 = vector.shape_cast %121 : vector<256xf16> to vector<8x16x2xf16> + %123 = vector.extract_strided_slice %114 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %124 = vector.shape_cast %123 : vector<256xf16> to vector<8x16x2xf16> + %125 = vector.extract_strided_slice %115 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %126 = vector.shape_cast %125 : vector<256xf16> to vector<8x16x2xf16> + %127 = vector.extract_strided_slice %115 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %128 = vector.shape_cast %127 : vector<256xf16> to vector<8x16x2xf16> + %129 = vector.extract_strided_slice %116 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %130 = vector.shape_cast %129 : vector<256xf16> to vector<8x16x2xf16> + %131 = 
vector.extract_strided_slice %116 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %132 = vector.shape_cast %131 : vector<256xf16> to vector<8x16x2xf16> + xegpu.compile_hint + %133 = xegpu.dpas %98, %118, %arg7 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %134 = xegpu.dpas %106, %120, %133 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %135 = xegpu.dpas %100, %118, %arg11 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %136 = xegpu.dpas %108, %120, %135 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %137 = xegpu.dpas %102, %118, %arg15 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %138 = xegpu.dpas %110, %120, %137 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %139 = xegpu.dpas %104, %118, %arg19 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %140 = xegpu.dpas %112, %120, %139 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %141 = xegpu.dpas %98, %122, %arg8 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %142 = xegpu.dpas %106, %124, %141 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %143 = xegpu.dpas %100, %122, %arg12 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %144 = xegpu.dpas %108, %124, %143 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %145 = xegpu.dpas %102, %122, %arg16 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %146 = xegpu.dpas %110, %124, %145 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %147 = xegpu.dpas %104, %122, %arg20 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %148 = xegpu.dpas %112, %124, %147 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %149 = xegpu.dpas %98, %126, %arg9 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %150 = xegpu.dpas %106, %128, %149 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %151 = xegpu.dpas %100, %126, %arg13 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %152 = xegpu.dpas %108, %128, %151 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %153 = xegpu.dpas %102, %126, %arg17 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %154 = xegpu.dpas %110, %128, %153 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %155 = xegpu.dpas %104, %126, %arg21 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %156 = xegpu.dpas %112, %128, %155 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %157 = xegpu.dpas %98, %130, %arg10 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %158 = xegpu.dpas %106, %132, %157 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %159 = xegpu.dpas %100, %130, %arg14 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %160 = xegpu.dpas %108, %132, %159 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %161 = xegpu.dpas %102, %130, %arg18 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %162 = 
xegpu.dpas %110, %132, %161 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %163 = xegpu.dpas %104, %130, %arg22 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %164 = xegpu.dpas %112, %132, %163 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + scf.if %80 { + xegpu.nbarrier_wait %44 : !xegpu.nbarrier } - - scf.yield %next_A_tile_0, %next_B_tile_0, %next_B_tile_1, - %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, - %next_A_prefetch_tile, %next_B_prefetch_tile - : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + scf.yield %92, %93, %94, %134, %142, %150, %158, %136, %144, %152, %160, %138, %146, %154, %162, %140, %148, %156, %164, %90, %91 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> } - // trunc to f16 - %c_result_0_0_f16 = arith.truncf %k_loop_result#3 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_1_f16 = arith.truncf %k_loop_result#4 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_2_f16 = arith.truncf %k_loop_result#5 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_3_f16 = arith.truncf %k_loop_result#6 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_0_f16 = arith.truncf %k_loop_result#7 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_1_f16 = arith.truncf %k_loop_result#8 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_2_f16 = arith.truncf %k_loop_result#9 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_3_f16 = arith.truncf %k_loop_result#10 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_0_f16 = arith.truncf %k_loop_result#11 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_1_f16 = arith.truncf %k_loop_result#12 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_2_f16 = arith.truncf %k_loop_result#13 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_3_f16 = arith.truncf %k_loop_result#14 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_0_f16 = arith.truncf %k_loop_result#15 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_1_f16 = arith.truncf %k_loop_result#16 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_2_f16 = arith.truncf %k_loop_result#17 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_3_f16 = arith.truncf %k_loop_result#18 : vector<8x16xf32> to vector<8x16xf16> - // each SG needs to write to 32x64 C tile. // DPAS output size is 8x16. 
So each SG needs to write 16 (4x4) DPAS outputs. // create 16 address descriptions to cover 8x16 tiles in 4x4 layout within the 32x64 SG C tile. // advance 8 in x dim and, advance 16 in y dim // row 1 - %c_sg_tile_0_0 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_0_1 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_0_2 = xegpu.update_nd_offset %c_sg_tile_0_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_0_3 = xegpu.update_nd_offset %c_sg_tile_0_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> // row 2 - %c_sg_tile_1_0 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_1_1 = xegpu.update_nd_offset %c_sg_tile_1_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_1_2 = xegpu.update_nd_offset %c_sg_tile_1_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_1_3 = xegpu.update_nd_offset %c_sg_tile_1_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> // row 3 - %c_sg_tile_2_0 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c16, %c0] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_2_1 = xegpu.update_nd_offset %c_sg_tile_2_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_2_2 = xegpu.update_nd_offset %c_sg_tile_2_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_2_3 = xegpu.update_nd_offset %c_sg_tile_2_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> // row 4 - %c_sg_tile_3_0 = xegpu.update_nd_offset %c_sg_tile_0_0, [%c24, %c0] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_3_1 = xegpu.update_nd_offset %c_sg_tile_3_0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_3_2 = xegpu.update_nd_offset %c_sg_tile_3_1, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - %c_sg_tile_3_3 = xegpu.update_nd_offset %c_sg_tile_3_2, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> - - // do store_nd - xegpu.store_nd %c_result_0_0_f16, %c_sg_tile_0_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_0_1_f16, %c_sg_tile_0_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_0_2_f16, %c_sg_tile_0_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_0_3_f16, %c_sg_tile_0_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_1_0_f16, %c_sg_tile_1_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_1_1_f16, %c_sg_tile_1_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_1_2_f16, %c_sg_tile_1_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_1_3_f16, %c_sg_tile_1_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_2_0_f16, %c_sg_tile_2_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : 
vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_2_1_f16, %c_sg_tile_2_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_2_2_f16, %c_sg_tile_2_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_2_3_f16, %c_sg_tile_2_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_3_0_f16, %c_sg_tile_3_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_3_1_f16, %c_sg_tile_3_1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_3_2_f16, %c_sg_tile_3_2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %c_result_3_3_f16, %c_sg_tile_3_3 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + %46 = arith.truncf %45#3 : vector<8x16xf32> to vector<8x16xf16> + %47 = arith.truncf %45#4 : vector<8x16xf32> to vector<8x16xf16> + %48 = arith.truncf %45#5 : vector<8x16xf32> to vector<8x16xf16> + %49 = arith.truncf %45#6 : vector<8x16xf32> to vector<8x16xf16> + %50 = arith.truncf %45#7 : vector<8x16xf32> to vector<8x16xf16> + %51 = arith.truncf %45#8 : vector<8x16xf32> to vector<8x16xf16> + %52 = arith.truncf %45#9 : vector<8x16xf32> to vector<8x16xf16> + %53 = arith.truncf %45#10 : vector<8x16xf32> to vector<8x16xf16> + %54 = arith.truncf %45#11 : vector<8x16xf32> to vector<8x16xf16> + %55 = arith.truncf %45#12 : vector<8x16xf32> to vector<8x16xf16> + %56 = arith.truncf %45#13 : vector<8x16xf32> to vector<8x16xf16> + %57 = arith.truncf %45#14 : vector<8x16xf32> to vector<8x16xf16> + %58 = arith.truncf %45#15 : vector<8x16xf32> to vector<8x16xf16> + %59 = arith.truncf %45#16 : vector<8x16xf32> to vector<8x16xf16> + %60 = arith.truncf %45#17 : vector<8x16xf32> to vector<8x16xf16> + %61 = arith.truncf %45#18 : vector<8x16xf32> to vector<8x16xf16> + %62 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x16xf16> + %63 = xegpu.update_nd_offset %62, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %64 = xegpu.update_nd_offset %63, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %65 = xegpu.update_nd_offset %64, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %66 = xegpu.update_nd_offset %62, [%c8, %c0] : !xegpu.tensor_desc<8x16xf16> + %67 = xegpu.update_nd_offset %66, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %68 = xegpu.update_nd_offset %67, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %69 = xegpu.update_nd_offset %68, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %70 = xegpu.update_nd_offset %62, [%c16, %c0] : !xegpu.tensor_desc<8x16xf16> + %71 = xegpu.update_nd_offset %70, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %72 = xegpu.update_nd_offset %71, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %73 = xegpu.update_nd_offset %72, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %74 = xegpu.update_nd_offset %62, [%c24, %c0] : !xegpu.tensor_desc<8x16xf16> + %75 = xegpu.update_nd_offset %74, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> 
+ %76 = xegpu.update_nd_offset %75, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + %77 = xegpu.update_nd_offset %76, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %46, %62 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %47, %63 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %48, %64 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %49, %65 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %50, %66 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %51, %67 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %52, %68 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %53, %69 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %54, %70 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %55, %71 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %56, %72 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %57, %73 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %58, %74 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %59, %75 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %60, %76 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %61, %77 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 5.000000e-01 : f32 + %cst_1 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : f16 - %c2_f16 = arith.constant 2.0 : f16 %c4096 = arith.constant 4096 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<4096x4096xf16> - %B = memref.alloc() : memref<4096x4096xf16> - %C = memref.alloc() : memref<4096x4096xf16> - %C_ref = memref.alloc() : memref<4096x4096xf32> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - 
%cf_upper = arith.constant 0.5 : f32 // Use one of the two options to initialize the A matrix // Option 1: initialize matrix A ; A[i, j] = j // scf.for %i = %c0 to %c4096 step %c1 { @@ -467,10 +364,6 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %A_random = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - // Use one of the two options below to initialize the B matrix // Option 1: make matrix B an identity matrix // scf.for %i = %c0 to %c4096 step %c1 { @@ -478,7 +371,6 @@ module @gemm attributes {gpu.container_module} { // %i_i32 = index.castu %i : index to i32 // %j_i32 = index.castu %j : index to i32 // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - // scf.if %i_j_same { // memref.store %cf_1, %B[%i, %j] : memref<4096x4096xf16> // } else { @@ -487,45 +379,38 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %B_random = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - // initialize matrix C and C_ref ; C[i, j] = 0 - %c0_f16 = arith.constant 0.0 : f16 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f16, %C[%i, %j] : memref<4096x4096xf16> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + %cst_2 = arith.constant 0.000000e+00 : f16 + %alloc = memref.alloc() : memref<4096x4096xf16> + %alloc_3 = memref.alloc() : memref<4096x4096xf16> + %alloc_4 = memref.alloc() : memref<4096x4096xf16> + %alloc_5 = memref.alloc() : memref<4096x4096xf32> + %cast = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_0, %false) : (memref<*xf16>, f32, f32, i1) -> () + %cast_6 = memref.cast %alloc_3 : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_6, %cst_1, %cst_0, %false) : (memref<*xf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst_2, %alloc_4[%arg0, %arg1] : memref<4096x4096xf16> + memref.store %cst, %alloc_5[%arg0, %arg1] : memref<4096x4096xf32> } } - // Run GPU. - %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - // Run CPU.
- %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - %C_ref_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - call @gemmF16F16F16(%A_cast, %B_cast, %C_ref_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - - - %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> - %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () - - %C_row_0_gpu = memref.subview %2[0, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset:0>> - %C_row_0_cast_gpu = memref.cast %C_row_0_gpu : memref<1x4096xf16, strided<[4096, 1], offset: 0>> to memref<*xf16> // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %C_ref_cast) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xf16> - memref.dealloc %B : memref<4096x4096xf16> - memref.dealloc %C : memref<4096x4096xf16> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_3, %alloc_4) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %cast_7 = memref.cast %0 : memref<4096x4096xf16> to memref<*xf16> + %cast_8 = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + %cast_9 = memref.cast %alloc_3 : memref<4096x4096xf16> to memref<*xf16> + %cast_10 = memref.cast %alloc_5 : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%cast_8, %cast_9, %cast_10) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_7, %cast_10) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xf16> + memref.dealloc %alloc_3 : memref<4096x4096xf16> + memref.dealloc %alloc_4 : memref<4096x4096xf16> + memref.dealloc %alloc_5 : memref<4096x4096xf32> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} @@ -533,5 +418,4 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir index 7ebef8539..609cd405e 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_8x32xf16_stores.mlir @@ -1,35 +1,31 @@ -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ 
-// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %A, %A_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %B, %B_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %C, %C_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf16>) - gpu.dealloc %A_gpu : memref<4096x4096xf16> - gpu.dealloc %B_gpu : memref<4096x4096xf16> - return %C_gpu : memref<4096x4096xf16> + %memref = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_0 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_1, %arg2 : memref<4096x4096xf16>, memref<4096x4096xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<4096x4096xf16>, %memref_0 : memref<4096x4096xf16>, %memref_1 : memref<4096x4096xf16>) + gpu.dealloc %memref : memref<4096x4096xf16> + gpu.dealloc %memref_0 : memref<4096x4096xf16> + %alloc = memref.alloc() : memref<4096x4096xf16> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf16>, memref<4096x4096xf16> + gpu.dealloc %memref_1 : memref<4096x4096xf16> + return %alloc : memref<4096x4096xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // constants + gpu.func @test_kernel(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c256 = arith.constant 256 : index %c512 = arith.constant 512 : index %c128 = arith.constant 128 : index @@ -45,10 +41,7 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 // get IDs - %wg_id_x = gpu.block_id x - %wg_id_y = gpu.block_id y // %sg_id = gpu.subgroup_id : index - // each C wg tile is 256x256 and 32 SGs update 
it in 8x4 layout // C sg tile size is 32x64 // SG layout for one C tile update @@ -59,17 +52,8 @@ module @gemm attributes {gpu.container_module} { // --> y means cols // | // V x means rows - // get unique sg ID in global context - %global_sg_id_x = gpu.global_id x - %global_sg_id_y = gpu.global_id y - %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index - %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index - // compute SG C tile offsets in x and y dims - %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index - %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index - // each SG needs to do the following compute to update its 32x64 sub tile // (32x4096)x(4096x64)=(32x64) // DPAS size is (8x16)x(16x16)=(8x16) @@ -77,14 +61,7 @@ module @gemm attributes {gpu.container_module} { // So we need (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) // this will require 32 DPAS ops (4x2x2) inside the K loop - // WG tiles are 256x256 so their offsets are the same for A, B and C - %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index - %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index - - %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index - %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index - // prefetching A and B slice within the 256x256 WG tile // // prefetch the entire 256x32 slice of A WG tile, this means each subgroup needs to prefetch an 8x32 slice // SG 0 -> slice 0 // SG 1 -> slice 1 // SG 2 -> slice 2 // SG 3 -> slice 3 // SG 4 -> slice 4 // .... // SG 31 -> slice 31 - %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index - %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index // create A prefetch tiles and prefetch // stage 1 - %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) - %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch within K loop - %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - // prefetch the entire 32x256 slice of B WG tile, we still use the prefetch size 8x32. // SGs have 8x4 layout. In this case 8 subgroups must do a collaborative prefetch of 32x64 tile. // this is because the B tile arrangement within the 32x256 slice is as follows @@ -126,355 +93,272 @@ module @gemm attributes {gpu.container_module} { // | 8 | 24|| 9 | 25 || 10 | 26 || 11| 27 | // | 12 | 28|| 13 | 29 || 14 | 30 || 15| 31 | // For example, SGs 0,4,8,12,16,20,24,28 share the data in the left 32x64 tile of the B slice.
- // calculate the x offsets and y offsets within the 32x256 slice - %B_sg_prefetch_offset_x_temp0 = arith.remui %local_sg_id_x, %c4 : index - %B_sg_prefetch_offset_x = arith.muli %B_sg_prefetch_offset_x_temp0, %c8 : index - %B_sg_prefetch_offset_y_temp0 = arith.muli %local_sg_id_y, %c64 : index - %B_sg_prefetch_offset_y_temp1 = arith.divui %local_sg_id_x, %c4 : index - %B_sg_prefetch_offset_y_temp2 = arith.muli %B_sg_prefetch_offset_y_temp1, %c32 : index - %B_sg_prefetch_offset_y_temp3 = arith.addi %B_sg_prefetch_offset_y_temp0, %B_sg_prefetch_offset_y_temp2 : index - %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_temp3 : index - // create B prefetch tiles and prefetch - %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) - %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch inside K loop - %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - - // two 32x16 A tiles from 256x32 WG slice - %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %A_sg_init_tile_1 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c16] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16> - //create B tiles - %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_1, [%c0, %c16] : !xegpu.tensor_desc<32x16xf16> // %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c16] : !xegpu.tensor_desc<32x16xf16> - // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix - %zero_vec = arith.constant dense<0.0> : vector<128xf32> - %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to 
vector<8x16xf32> - %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - - + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %global_id_x = gpu.global_id x + %global_id_y = gpu.global_id y + %0 = arith.remui %global_id_x, %c8 : index + %1 = arith.remui %global_id_y, %c4 : index + %2 = arith.muli %global_id_x, %c32 : index + %3 = arith.muli %global_id_y, %c64 : index + %4 = arith.muli %block_id_x, %c256 : index + %5 = arith.muli %block_id_y, %c256 : index + %6 = arith.muli %0, %c4 : index + %7 = arith.addi %6, %1 : index + %8 = arith.muli %7, %c8 : index + %9 = arith.addi %8, %4 : index + %10 = xegpu.create_nd_tdesc %arg0[%9, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %13 = xegpu.update_nd_offset %12, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %14 = arith.remui %0, %c4 : index + %15 = arith.muli %14, %c8 : index + %16 = arith.muli %1, %c64 : index + %17 = arith.divui %0, %c4 : index + %18 = arith.muli %17, %c32 : index + %19 = arith.addi %16, %18 : index + %20 = arith.addi %5, %19 : index + %21 = xegpu.create_nd_tdesc %arg1[%15, %20] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %22 = xegpu.update_nd_offset %21, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %23 = xegpu.update_nd_offset %22, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %24 = xegpu.update_nd_offset %23, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %25 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %26 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %27 = xegpu.update_nd_offset %26, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %28 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %29 = 
vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %30 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %31 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %32 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %33 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %34 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %35 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %36 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %37 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %38 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %39 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %40 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %41 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %42 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %43 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> xegpu.alloc_nbarrier 16 - %nbarrier_id = arith.constant 1 : i8 - %num_threads = arith.constant 32 : i8 - %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier // K loop advances in 32 steps - %k_loop_result:21 = scf.for %k = %c0 to %c4096 step %c32 iter_args ( - %A_tile_0 = %A_sg_init_tile_0, // %A_tile_1 = %A_sg_init_tile_1, - - %B_tile_0 = %B_sg_init_tile_0, - %B_tile_1 = %B_sg_init_tile_1, // %B_tile_2 = %B_sg_init_tile_2, // %B_tile_3 = %B_sg_init_tile_3, - - %c_val_0_0 = %c_init_val_0_0, - %c_val_0_1 = %c_init_val_0_1, - %c_val_0_2 = %c_init_val_0_2, - %c_val_0_3 = %c_init_val_0_3, - %c_val_1_0 = %c_init_val_1_0, - %c_val_1_1 = %c_init_val_1_1, - %c_val_1_2 = %c_init_val_1_2, - %c_val_1_3 = %c_init_val_1_3, - %c_val_2_0 = %c_init_val_2_0, - %c_val_2_1 = %c_init_val_2_1, - %c_val_2_2 = %c_init_val_2_2, - %c_val_2_3 = %c_init_val_2_3, - %c_val_3_0 = %c_init_val_3_0, - %c_val_3_1 = %c_init_val_3_1, - %c_val_3_2 = %c_init_val_3_2, - %c_val_3_3 = %c_init_val_3_3, - - %A_prefetch_tile = %A_sg_prefetch_tile_iter3, - %B_prefetch_tile = %B_sg_prefetch_tile_iter3 - ) -> - (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> - ) - { // all SGs must arrive here first - %every_8th_iter = arith.remui %k, %c256 : index - %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 - %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 - scf.if %every_8th_iter_cond { - xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier + %c1_i8 = arith.constant 1 : i8 + %c32_i8 = arith.constant 32 : i8 + %44 = xegpu.init_nbarrier %c1_i8, %c32_i8 : i8, i8 -> !xegpu.nbarrier + %45:21 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %25, %arg5 = %26, %arg6 = %27, %arg7 = %28, %arg8 = %29, %arg9 = %30, %arg10 = %31, %arg11 = %32, %arg12 = %33, %arg13 = %34, %arg14 = %35, %arg15 = %36, %arg16 = %37, %arg17 = %38, %arg18 = %39, %arg19 = %40, %arg20 = %41, %arg21 = %42, %arg22 = %43, %arg23 = %13, %arg24 = %24) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, 
#xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16>) { + %94 = arith.remui %arg3, %c256 : index + %95 = arith.index_cast %94 : index to i32 + %96 = arith.cmpi eq, %95, %c0_i32 : i32 + scf.if %96 { + xegpu.nbarrier_arrive %44 : !xegpu.nbarrier } // load A tiles - %a_val = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - %a_val_0 = vector.extract %a_val [0] : vector<32x16xf16> from vector<2x32x16xf16> - %a_val_1 = vector.extract %a_val [1] : vector<32x16xf16> from vector<2x32x16xf16> - // load B tiles - %b_val_arr_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> - %b_val_arr_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> - - %b_val_0 = vector.extract %b_val_arr_0 [0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_1 = vector.extract %b_val_arr_0 [1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_2 = vector.extract %b_val_arr_1 [0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_3 = vector.extract %b_val_arr_1 [1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - - xegpu.compile_hint - // prefetch A and B tiles - xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - // - xegpu.compile_hint - // advance A and B prefetch tiles - %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> // advance A and B tiles - %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16> - - %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16> // %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16> - - xegpu.compile_hint - %a_val_0_flat = vector.shape_cast %a_val_0 : vector<32x16xf16> to vector<512xf16> - %a_val_1_flat = vector.shape_cast %a_val_1 : vector<32x16xf16> to vector<512xf16> - %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_0_0 = vector.shape_cast 
%a_val_0_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_2_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_3_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [384], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_0_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_1_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [128], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_2_1 = vector.shape_cast %a_val_2_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [384], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xf16> to vector<8x16xf16> - - - %b_val_0_flat = vector.shape_cast %b_val_0 : vector<16x16x2xf16> to vector<512xf16> - %b_val_1_flat = vector.shape_cast %b_val_1 : vector<16x16x2xf16> to vector<512xf16> - %b_val_2_flat = vector.shape_cast %b_val_2 : vector<16x16x2xf16> to vector<512xf16> - %b_val_3_flat = vector.shape_cast %b_val_3 : vector<16x16x2xf16> to vector<512xf16> - %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_1_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_1_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_2_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_2_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_3_flat = vector.extract_strided_slice 
%b_val_3_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xf16> to vector<8x16x2xf16> - // do DPAS + // barrier wait + %97 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %98 = vector.extract %97[0] : vector<32x16xf16> from vector<2x32x16xf16> + %99 = vector.extract %97[1] : vector<32x16xf16> from vector<2x32x16xf16> + %100 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> + %101 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> + %102 = vector.extract %100[0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %103 = vector.extract %100[1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %104 = vector.extract %101[0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %105 = vector.extract %101[1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> xegpu.compile_hint - %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1 = xegpu.dpas %a_val_2_1, 
%b_val_1_1, %new_c_val_2_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.prefetch_nd %arg23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %arg24 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - + %106 = xegpu.update_nd_offset %arg23, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %107 = xegpu.update_nd_offset %arg24, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %108 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %109 = xegpu.update_nd_offset %arg5, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %110 = xegpu.update_nd_offset %arg6, [%c32, %c0] : 
!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> xegpu.compile_hint - // barrier wait - scf.if %every_8th_iter_cond { - xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier + %111 = vector.shape_cast %98 : vector<32x16xf16> to vector<512xf16> + %112 = vector.shape_cast %99 : vector<32x16xf16> to vector<512xf16> + %113 = vector.extract_strided_slice %111 {offsets = [0], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %114 = vector.shape_cast %113 : vector<128xf16> to vector<8x16xf16> + %115 = vector.extract_strided_slice %111 {offsets = [128], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %116 = vector.shape_cast %115 : vector<128xf16> to vector<8x16xf16> + %117 = vector.extract_strided_slice %111 {offsets = [256], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %118 = vector.shape_cast %117 : vector<128xf16> to vector<8x16xf16> + %119 = vector.extract_strided_slice %111 {offsets = [384], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %120 = vector.shape_cast %119 : vector<128xf16> to vector<8x16xf16> + %121 = vector.extract_strided_slice %112 {offsets = [0], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %122 = vector.shape_cast %121 : vector<128xf16> to vector<8x16xf16> + %123 = vector.extract_strided_slice %112 {offsets = [128], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %124 = vector.shape_cast %123 : vector<128xf16> to vector<8x16xf16> + %125 = vector.extract_strided_slice %112 {offsets = [256], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %126 = vector.shape_cast %125 : vector<128xf16> to vector<8x16xf16> + %127 = vector.extract_strided_slice %112 {offsets = [384], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %128 = vector.shape_cast %127 : vector<128xf16> to vector<8x16xf16> + %129 = vector.shape_cast %102 : vector<16x16x2xf16> to vector<512xf16> + %130 = vector.shape_cast %103 : vector<16x16x2xf16> to vector<512xf16> + %131 = vector.shape_cast %104 : vector<16x16x2xf16> to vector<512xf16> + %132 = vector.shape_cast %105 : vector<16x16x2xf16> to vector<512xf16> + %133 = vector.extract_strided_slice %129 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %134 = vector.shape_cast %133 : vector<256xf16> to vector<8x16x2xf16> + %135 = vector.extract_strided_slice %129 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %136 = vector.shape_cast %135 : vector<256xf16> to vector<8x16x2xf16> + %137 = vector.extract_strided_slice %130 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %138 = vector.shape_cast %137 : vector<256xf16> to vector<8x16x2xf16> + %139 = vector.extract_strided_slice %130 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %140 = vector.shape_cast %139 : vector<256xf16> to vector<8x16x2xf16> + %141 = vector.extract_strided_slice %131 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %142 = vector.shape_cast %141 : vector<256xf16> to vector<8x16x2xf16> + %143 = vector.extract_strided_slice %131 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %144 = vector.shape_cast %143 : vector<256xf16> to vector<8x16x2xf16> + %145 = vector.extract_strided_slice %132 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %146 = vector.shape_cast %145 : 
vector<256xf16> to vector<8x16x2xf16> + %147 = vector.extract_strided_slice %132 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %148 = vector.shape_cast %147 : vector<256xf16> to vector<8x16x2xf16> + xegpu.compile_hint + %149 = xegpu.dpas %114, %134, %arg7 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %150 = xegpu.dpas %122, %136, %149 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %151 = xegpu.dpas %116, %134, %arg11 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %152 = xegpu.dpas %124, %136, %151 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %153 = xegpu.dpas %118, %134, %arg15 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %154 = xegpu.dpas %126, %136, %153 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %155 = xegpu.dpas %120, %134, %arg19 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %156 = xegpu.dpas %128, %136, %155 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %157 = xegpu.dpas %114, %138, %arg8 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %158 = xegpu.dpas %122, %140, %157 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %159 = xegpu.dpas %116, %138, %arg12 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %160 = xegpu.dpas %124, %140, %159 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %161 = xegpu.dpas %118, %138, %arg16 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %162 = xegpu.dpas %126, %140, %161 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %163 = xegpu.dpas %120, %138, %arg20 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %164 = xegpu.dpas %128, %140, %163 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %165 = xegpu.dpas %114, %142, %arg9 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %166 = xegpu.dpas %122, %144, %165 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %167 = xegpu.dpas %116, %142, %arg13 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %168 = xegpu.dpas %124, %144, %167 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %169 = xegpu.dpas %118, %142, %arg17 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %170 = xegpu.dpas %126, %144, %169 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %171 = xegpu.dpas %120, %142, %arg21 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %172 = xegpu.dpas %128, %144, %171 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %173 = xegpu.dpas %114, %146, %arg10 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %174 = xegpu.dpas %122, %148, %173 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %175 = xegpu.dpas %116, %146, %arg14 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %176 = xegpu.dpas %124, %148, %175 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %177 = xegpu.dpas %118, %146, %arg18 : vector<8x16xf16>, vector<8x16x2xf16>, 
vector<8x16xf32> -> vector<8x16xf32> + %178 = xegpu.dpas %126, %148, %177 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %179 = xegpu.dpas %120, %146, %arg22 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %180 = xegpu.dpas %128, %148, %179 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + scf.if %96 { + xegpu.nbarrier_wait %44 : !xegpu.nbarrier } - - scf.yield %next_A_tile_0, %next_B_tile_0, %next_B_tile_1, - %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, - %next_A_prefetch_tile, %next_B_prefetch_tile - : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + scf.yield %108, %109, %110, %150, %158, %166, %174, %152, %160, %168, %176, %154, %162, %170, %178, %156, %164, %172, %180, %106, %107 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> } - // trunc all DPAS output tiles to f16 - %c_result_0_0_f16 = arith.truncf %k_loop_result#3 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_1_f16 = arith.truncf %k_loop_result#4 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_2_f16 = arith.truncf %k_loop_result#5 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_3_f16 = arith.truncf %k_loop_result#6 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_0_f16 = arith.truncf %k_loop_result#7 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_1_f16 = arith.truncf %k_loop_result#8 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_2_f16 = arith.truncf %k_loop_result#9 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_3_f16 = arith.truncf %k_loop_result#10 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_0_f16 = arith.truncf %k_loop_result#11 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_1_f16 = arith.truncf %k_loop_result#12 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_2_f16 = arith.truncf %k_loop_result#13 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_3_f16 = arith.truncf %k_loop_result#14 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_0_f16 = arith.truncf %k_loop_result#15 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_1_f16 = arith.truncf %k_loop_result#16 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_2_f16 = arith.truncf %k_loop_result#17 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_3_f16 = arith.truncf %k_loop_result#18 : vector<8x16xf32> to vector<8x16xf16> - // each SG needs to store 
the result of K loop into a 32x64 tile in C matrix. This is organized in 8x16 DPAS tiles // in the layout of 4x4x8x16. The max store size the HW supports in f16 is 8x32. So we combine two 8x16 DPAS tiles // horizontally using vector.shuffle to get the required store size. The store layout will then be 4x2x8x32 i.e. // we have 8 stores of size 8x32 in the layout 4x2. - - %c_result_8x32_0_0_t1 = vector.shuffle %c_result_0_0_f16, %c_result_0_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_0_0_t2 = vector.shape_cast %c_result_8x32_0_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_0_0 = vector.shape_cast %c_result_8x32_0_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_00 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_0_0, %c_sg_tile_00 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %46 = arith.truncf %45#3 : vector<8x16xf32> to vector<8x16xf16> + %47 = arith.truncf %45#4 : vector<8x16xf32> to vector<8x16xf16> + %48 = arith.truncf %45#5 : vector<8x16xf32> to vector<8x16xf16> + %49 = arith.truncf %45#6 : vector<8x16xf32> to vector<8x16xf16> + %50 = arith.truncf %45#7 : vector<8x16xf32> to vector<8x16xf16> + %51 = arith.truncf %45#8 : vector<8x16xf32> to vector<8x16xf16> + %52 = arith.truncf %45#9 : vector<8x16xf32> to vector<8x16xf16> + %53 = arith.truncf %45#10 : vector<8x16xf32> to vector<8x16xf16> + %54 = arith.truncf %45#11 : vector<8x16xf32> to vector<8x16xf16> + %55 = arith.truncf %45#12 : vector<8x16xf32> to vector<8x16xf16> + %56 = arith.truncf %45#13 : vector<8x16xf32> to vector<8x16xf16> + %57 = arith.truncf %45#14 : vector<8x16xf32> to vector<8x16xf16> + %58 = arith.truncf %45#15 : vector<8x16xf32> to vector<8x16xf16> + %59 = arith.truncf %45#16 : vector<8x16xf32> to vector<8x16xf16> + %60 = arith.truncf %45#17 : vector<8x16xf32> to vector<8x16xf16> + %61 = arith.truncf %45#18 : vector<8x16xf32> to vector<8x16xf16> + %62 = vector.shuffle %46, %47 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %63 = vector.shape_cast %62 : vector<16x16xf16> to vector<256xf16> + %64 = vector.shape_cast %63 : vector<256xf16> to vector<8x32xf16> + %65 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %64, %65 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_0_1_t1 = vector.shuffle %c_result_0_2_f16, %c_result_0_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_0_1_t2 = vector.shape_cast %c_result_8x32_0_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_0_1 = vector.shape_cast %c_result_8x32_0_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_01 = xegpu.update_nd_offset %c_sg_tile_00, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_0_1, %c_sg_tile_01 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %66 = vector.shuffle %48, %49 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %67 = vector.shape_cast %66 : vector<16x16xf16> to vector<256xf16> + %68 = 
vector.shape_cast %67 : vector<256xf16> to vector<8x32xf16> + %69 = xegpu.update_nd_offset %65, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %68, %69 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_1_0_t1 = vector.shuffle %c_result_1_0_f16, %c_result_1_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_1_0_t2 = vector.shape_cast %c_result_8x32_1_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_1_0 = vector.shape_cast %c_result_8x32_1_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_10 = xegpu.update_nd_offset %c_sg_tile_00, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_1_0, %c_sg_tile_10 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %70 = vector.shuffle %50, %51 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %71 = vector.shape_cast %70 : vector<16x16xf16> to vector<256xf16> + %72 = vector.shape_cast %71 : vector<256xf16> to vector<8x32xf16> + %73 = xegpu.update_nd_offset %65, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %72, %73 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - - %c_result_8x32_1_1_t1 = vector.shuffle %c_result_1_2_f16, %c_result_1_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_1_1_t2 = vector.shape_cast %c_result_8x32_1_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_1_1 = vector.shape_cast %c_result_8x32_1_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_11 = xegpu.update_nd_offset %c_sg_tile_01, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_1_1, %c_sg_tile_11 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %74 = vector.shuffle %52, %53 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %75 = vector.shape_cast %74 : vector<16x16xf16> to vector<256xf16> + %76 = vector.shape_cast %75 : vector<256xf16> to vector<8x32xf16> + %77 = xegpu.update_nd_offset %69, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %76, %77 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_2_0_t1 = vector.shuffle %c_result_2_0_f16, %c_result_2_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_2_0_t2 = vector.shape_cast %c_result_8x32_2_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_2_0 = vector.shape_cast %c_result_8x32_2_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_20 = xegpu.update_nd_offset %c_sg_tile_10, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_2_0, %c_sg_tile_20 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %78 = vector.shuffle %54, %55 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %79 = vector.shape_cast %78 : vector<16x16xf16> to 
vector<256xf16> + %80 = vector.shape_cast %79 : vector<256xf16> to vector<8x32xf16> + %81 = xegpu.update_nd_offset %73, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %80, %81 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_2_1_t1 = vector.shuffle %c_result_2_2_f16, %c_result_2_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_2_1_t2 = vector.shape_cast %c_result_8x32_2_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_2_1 = vector.shape_cast %c_result_8x32_2_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_21 = xegpu.update_nd_offset %c_sg_tile_11, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_2_1, %c_sg_tile_21 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %82 = vector.shuffle %56, %57 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %83 = vector.shape_cast %82 : vector<16x16xf16> to vector<256xf16> + %84 = vector.shape_cast %83 : vector<256xf16> to vector<8x32xf16> + %85 = xegpu.update_nd_offset %77, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %84, %85 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_3_0_t1 = vector.shuffle %c_result_3_0_f16, %c_result_3_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_3_0_t2 = vector.shape_cast %c_result_8x32_3_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_3_0 = vector.shape_cast %c_result_8x32_3_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_30 = xegpu.update_nd_offset %c_sg_tile_20, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_3_0, %c_sg_tile_30 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %86 = vector.shuffle %58, %59 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %87 = vector.shape_cast %86 : vector<16x16xf16> to vector<256xf16> + %88 = vector.shape_cast %87 : vector<256xf16> to vector<8x32xf16> + %89 = xegpu.update_nd_offset %81, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %88, %89 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_3_1_t1 = vector.shuffle %c_result_3_2_f16, %c_result_3_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_3_1_t2 = vector.shape_cast %c_result_8x32_3_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_3_1 = vector.shape_cast %c_result_8x32_3_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_31 = xegpu.update_nd_offset %c_sg_tile_21, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_3_1, %c_sg_tile_31 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> - + %90 = vector.shuffle %60, %61 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %91 = vector.shape_cast %90 : 
vector<16x16xf16> to vector<256xf16> + %92 = vector.shape_cast %91 : vector<256xf16> to vector<8x32xf16> + %93 = xegpu.update_nd_offset %85, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %92, %93 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : f16 - %c2_f16 = arith.constant 2.0 : f16 %c4096 = arith.constant 4096 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - %A = memref.alloc() : memref<4096x4096xf16> - %B = memref.alloc() : memref<4096x4096xf16> - %C = memref.alloc() : memref<4096x4096xf16> - %C_ref = memref.alloc() : memref<4096x4096xf32> - // Use one of the two options to initialize the A matrix // Option 1: initialize matrix A ; A[i, j] = j // scf.for %i = %c0 to %c4096 step %c1 { @@ -487,9 +371,6 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %A_random = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // Use one of the two options below to initialize the B matrix // Option 1: make matrix B an identity matrix // scf.for %i = %c0 to %c4096 step %c1 { @@ -497,7 +378,6 @@ module @gemm attributes {gpu.container_module} { // %i_i32 = index.castu %i : index to i32 // %j_i32 = index.castu %j : index to i32 // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - // scf.if %i_j_same { // memref.store %cf_1, %B[%i, %j] : memref<4096x4096xf16> // } else { @@ -506,42 +386,41 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %B_random = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // initialize matrix C and C_ref ; C[i, j] = 0 - %c0_f16 = arith.constant 0.0 : f16 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f16, %C[%i, %j] : memref<4096x4096xf16> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + %cst_0 = arith.constant 0.000000e+00 : f16 + %false = arith.constant false + %cst_1 = arith.constant -5.000000e-01 : f32 + %cst_2 = arith.constant 5.000000e-01 : f32 + %alloc = memref.alloc() : memref<4096x4096xf16> + %alloc_3 = memref.alloc() : memref<4096x4096xf16> + %alloc_4 = memref.alloc() : memref<4096x4096xf16> + %alloc_5 = memref.alloc() : memref<4096x4096xf32> + %cast = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_2, %false) : (memref<*xf16>, f32, f32, i1) -> () + %cast_6 = memref.cast %alloc_3 : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_6, %cst_1, %cst_2, %false) : (memref<*xf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst_0, %alloc_4[%arg0, %arg1] : memref<4096x4096xf16> + memref.store %cst, %alloc_5[%arg0, %arg1] : 
memref<4096x4096xf32> } } - // Run GPU. - %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> // Run CPU. - %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - - %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> - %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () - - %C_row_0_gpu = memref.subview %2[0, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset:0>> - %C_row_0_cast_gpu = memref.cast %C_row_0_gpu : memref<1x4096xf16, strided<[4096, 1], offset: 0>> to memref<*xf16> // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %C_cast) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xf16> - memref.dealloc %B : memref<4096x4096xf16> - memref.dealloc %C : memref<4096x4096xf16> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_3, %alloc_4) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %cast_7 = memref.cast %0 : memref<4096x4096xf16> to memref<*xf16> + %cast_8 = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + %cast_9 = memref.cast %alloc_3 : memref<4096x4096xf16> to memref<*xf16> + %cast_10 = memref.cast %alloc_5 : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%cast_8, %cast_9, %cast_10) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_7, %cast_10) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xf16> + memref.dealloc %alloc_3 : memref<4096x4096xf16> + memref.dealloc %alloc_4 : memref<4096x4096xf16> + memref.dealloc %alloc_5 : memref<4096x4096xf32> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} @@ -549,5 +428,4 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir index e0de737c8..215d12703 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_4kx4kx4k_f16_f16_f16_w_simple_B_prefetch.mlir @@ -1,35 +1,31 @@ -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) -> memref<4096x4096xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %A, %A_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %B, %B_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %C, %C_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf16>) - gpu.dealloc %A_gpu : memref<4096x4096xf16> - gpu.dealloc %B_gpu : memref<4096x4096xf16> - return %C_gpu : memref<4096x4096xf16> + %memref = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_0 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_1, %arg2 : memref<4096x4096xf16>, memref<4096x4096xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<4096x4096xf16>, %memref_0 : memref<4096x4096xf16>, %memref_1 : memref<4096x4096xf16>) + gpu.dealloc %memref : memref<4096x4096xf16> + gpu.dealloc %memref_0 : memref<4096x4096xf16> + %alloc = memref.alloc() : memref<4096x4096xf16> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf16>, memref<4096x4096xf16> + gpu.dealloc %memref_1 : memref<4096x4096xf16> + return %alloc : memref<4096x4096xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // constants + gpu.func @test_kernel(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c256 = arith.constant 256 : index %c512 = arith.constant 512 : index %c128 = arith.constant 128 : index @@ -45,10 +41,7 @@ module @gemm 
attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 // get IDs - %wg_id_x = gpu.block_id x - %wg_id_y = gpu.block_id y // %sg_id = gpu.subgroup_id : index - // each C wg tile is 256x256 and 32 SGs update it in 8x4 layout // C sg tile size is 32x64 // SG layout for one C tile update @@ -59,17 +52,8 @@ module @gemm attributes {gpu.container_module} { // --> y means cols // | // V x means rows - // get unique sg ID in global context - %global_sg_id_x = gpu.global_id x - %global_sg_id_y = gpu.global_id y - %local_sg_id_x = arith.remui %global_sg_id_x, %c8 : index - %local_sg_id_y = arith.remui %global_sg_id_y, %c4 : index - // compute SG C tile offsets in x and y dims - %C_sg_tile_offset_x = arith.muli %global_sg_id_x, %c32 : index - %C_sg_tile_offset_y = arith.muli %global_sg_id_y, %c64 : index - // each SG needs to do the following compute to update its 32x64 sub tile // (32x4096)x(4096x64)=(32x64) // DPAS size is (8x16)x(16x16)=(8x16) @@ -77,14 +61,7 @@ module @gemm attributes {gpu.container_module} { // So we need (4x2) A tiles of size (8x16) and (2x4) B tiles of size (16x16) // tiled compute for a SG is (4x2x8x16)x(2x4x16x16)=(4x4x8x16) // this will require 32 DPAS ops (4x2x2) inside the K loop - // WG tiles are 256x256 so their offsets are the same for A, B and C - %wg_tile_offset_x = arith.muli %wg_id_x, %c256 : index - %wg_tile_offset_y = arith.muli %wg_id_y, %c256 : index - - %local_sg_id_temp = arith.muli %local_sg_id_x, %c4 : index - %local_sg_id = arith.addi %local_sg_id_temp, %local_sg_id_y : index - // prefetching A and B slice within the 256x256 WG tile // // prefetch the entire 256x32 slice of A WG tile, this means each subgroup needs to prefetch an 8x32 slice @@ -96,21 +73,11 @@ module @gemm attributes {gpu.container_module} { // SG 4 -> slice 4 // .... // SG 31 -> slice 31 - %A_sg_prefetch_offset_x_temp = arith.muli %local_sg_id, %c8 : index - %A_sg_prefetch_offset_x = arith.addi %A_sg_prefetch_offset_x_temp, %wg_tile_offset_x : index // create A prefetch tiles and prefetch // stage 1 - %A_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %A[%A_sg_prefetch_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the y direction and prefetch next 8x32 tile) - %A_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter0, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %A_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter1, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %A_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch within K loop - %A_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %A_sg_prefetch_tile_iter2, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - // ---- Simpler prefetch scheme for B prefetch ---- // Original SG layout is 8x4. And we need to prefetch a 32x256 slice of B. Best prefetch size for the data type // is 8x32. This makes the prefetch layout 4x8. 
To avoid complex prefetching address calculation, we convert the @@ -120,357 +87,274 @@ module @gemm attributes {gpu.container_module} { // | 8 | 9 | 10| 11| 12| 13| 14| 15| // | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | // | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | - // calculate the linear index of the SG - %linear_sg_id_t0 = arith.muli %local_sg_id_x, %c4 : index - %linear_sg_id = arith.muli %linear_sg_id_t0, %local_sg_id_y : index // convert layout to 4x8 from 8x4 - %sg_id_4x8_x = arith.divui %linear_sg_id, %c8 : index - %sg_id_4x8_y = arith.remui %linear_sg_id, %c8 : index // compute address for 8x32 slice - %B_sg_prefetch_offset_x = arith.muli %sg_id_4x8_x, %c8 : index - %B_sg_prefetch_offset_y_t0 = arith.muli %sg_id_4x8_y, %c32 : index - %B_sg_prefetch_offset_y = arith.addi %wg_tile_offset_y, %B_sg_prefetch_offset_y_t0 : index - // create B prefetch tiles and prefetch - %B_sg_prefetch_tile_iter0 = xegpu.create_nd_tdesc %B[%B_sg_prefetch_offset_x, %B_sg_prefetch_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 2 (move 32 elements in the x direction and prefetch next 8x32 tile) - %B_sg_prefetch_tile_iter1 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter0, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // stage 3 - %B_sg_prefetch_tile_iter2 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter1, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.prefetch_nd %B_sg_prefetch_tile_iter2 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // compute the next tile to prefetch inside K loop - %B_sg_prefetch_tile_iter3 = xegpu.update_nd_offset %B_sg_prefetch_tile_iter2, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> - - // two 32x16 A tiles from 256x32 WG slice - %A_sg_init_tile_0 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %A_sg_init_tile_1 = xegpu.create_nd_tdesc %A[%C_sg_tile_offset_x, %c16] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16> - //create B tiles - %B_sg_init_tile_0 = xegpu.create_nd_tdesc %B[%c0, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %B_sg_init_tile_1 = xegpu.update_nd_offset %B_sg_init_tile_0, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr< array_length = 2>> // %B_sg_init_tile_2 = xegpu.update_nd_offset %B_sg_init_tile_1, [%c0, %c16] : !xegpu.tensor_desc<32x16xf16> // %B_sg_init_tile_3 = xegpu.update_nd_offset %B_sg_init_tile_2, [%c0, %c16] : !xegpu.tensor_desc<32x16xf16> - // init 16 C tiles of size 8x16 each is initialized to 0.0 assuming a zero C matrix - %zero_vec = arith.constant dense<0.0> : vector<128xf32> - %c_init_val_0_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_0_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_1 = vector.shape_cast %zero_vec : 
vector<128xf32> to vector<8x16xf32> - %c_init_val_1_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_1_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_2_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_0 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_1 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_2 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - %c_init_val_3_3 = vector.shape_cast %zero_vec : vector<128xf32> to vector<8x16xf32> - - + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %global_id_x = gpu.global_id x + %global_id_y = gpu.global_id y + %0 = arith.remui %global_id_x, %c8 : index + %1 = arith.remui %global_id_y, %c4 : index + %2 = arith.muli %global_id_x, %c32 : index + %3 = arith.muli %global_id_y, %c64 : index + %4 = arith.muli %block_id_x, %c256 : index + %5 = arith.muli %block_id_y, %c256 : index + %6 = arith.muli %0, %c4 : index + %7 = arith.addi %6, %1 : index + %8 = arith.muli %7, %c8 : index + %9 = arith.addi %8, %4 : index + %10 = xegpu.create_nd_tdesc %arg0[%9, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %11 = xegpu.update_nd_offset %10, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %12 = xegpu.update_nd_offset %11, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %13 = xegpu.update_nd_offset %12, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %14 = arith.muli %0, %c4 : index + %15 = arith.muli %14, %1 : index + %16 = arith.divui %15, %c8 : index + %17 = arith.remui %15, %c8 : index + %18 = arith.muli %16, %c8 : index + %19 = arith.muli %17, %c32 : index + %20 = arith.addi %5, %19 : index + %21 = xegpu.create_nd_tdesc %arg1[%18, %20] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %22 = xegpu.update_nd_offset %21, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %23 = xegpu.update_nd_offset %22, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.prefetch_nd %23 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x32xf16> + %24 = xegpu.update_nd_offset %23, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %25 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %26 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %27 = xegpu.update_nd_offset %26, [%c0, %c32] : 
!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %28 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %29 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %30 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %31 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %32 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %33 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %34 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %35 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %36 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %37 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %38 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %39 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %40 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %41 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %42 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %43 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> xegpu.alloc_nbarrier 16 - %nbarrier_id = arith.constant 1 : i8 - %num_threads = arith.constant 32 : i8 - %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier // K loop advances in 32 steps - %k_loop_result:21 = scf.for %k = %c0 to %c4096 step %c32 iter_args ( - %A_tile_0 = %A_sg_init_tile_0, // %A_tile_1 = %A_sg_init_tile_1, - - %B_tile_0 = %B_sg_init_tile_0, - %B_tile_1 = %B_sg_init_tile_1, // %B_tile_2 = %B_sg_init_tile_2, // %B_tile_3 = %B_sg_init_tile_3, - - %c_val_0_0 = %c_init_val_0_0, - %c_val_0_1 = %c_init_val_0_1, - %c_val_0_2 = %c_init_val_0_2, - %c_val_0_3 = %c_init_val_0_3, - %c_val_1_0 = %c_init_val_1_0, - %c_val_1_1 = %c_init_val_1_1, - %c_val_1_2 = %c_init_val_1_2, - %c_val_1_3 = %c_init_val_1_3, - %c_val_2_0 = %c_init_val_2_0, - %c_val_2_1 = %c_init_val_2_1, - %c_val_2_2 = %c_init_val_2_2, - %c_val_2_3 = %c_init_val_2_3, - %c_val_3_0 = %c_init_val_3_0, - %c_val_3_1 = %c_init_val_3_1, - %c_val_3_2 = %c_init_val_3_2, - %c_val_3_3 = %c_init_val_3_3, - - %A_prefetch_tile = %A_sg_prefetch_tile_iter3, - %B_prefetch_tile = %B_sg_prefetch_tile_iter3 - ) -> - (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> - ) - { // all SGs must arrive here first - %every_8th_iter = arith.remui %k, %c256 : index - %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 - %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 - scf.if %every_8th_iter_cond { - xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier + %c1_i8 = arith.constant 1 : i8 + %c32_i8 = arith.constant 32 : i8 + %44 = xegpu.init_nbarrier %c1_i8, %c32_i8 : i8, i8 -> !xegpu.nbarrier + %45:21 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %25, %arg5 = %26, %arg6 = %27, %arg7 = %28, %arg8 = %29, %arg9 = %30, %arg10 = %31, %arg11 = %32, %arg12 = %33, %arg13 = %34, %arg14 = %35, %arg15 = %36, %arg16 = %37, %arg17 
= %38, %arg18 = %39, %arg19 = %40, %arg20 = %41, %arg21 = %42, %arg22 = %43, %arg23 = %13, %arg24 = %24) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16>) { + %94 = arith.remui %arg3, %c256 : index + %95 = arith.index_cast %94 : index to i32 + %96 = arith.cmpi eq, %95, %c0_i32 : i32 + scf.if %96 { + xegpu.nbarrier_arrive %44 : !xegpu.nbarrier } // load A tiles - %a_val = xegpu.load_nd %A_tile_0 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - %a_val_0 = vector.extract %a_val [0] : vector<32x16xf16> from vector<2x32x16xf16> - %a_val_1 = vector.extract %a_val [1] : vector<32x16xf16> from vector<2x32x16xf16> - // load B tiles - %b_val_arr_0 = xegpu.load_nd %B_tile_0 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> - %b_val_arr_1 = xegpu.load_nd %B_tile_1 {packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> - - %b_val_0 = vector.extract %b_val_arr_0 [0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_1 = vector.extract %b_val_arr_0 [1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_2 = vector.extract %b_val_arr_1 [0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - %b_val_3 = vector.extract %b_val_arr_1 [1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> - - xegpu.compile_hint - // prefetch A and B tiles // xegpu.prefetch_nd %A_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> // xegpu.prefetch_nd %B_prefetch_tile {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x32xf16> - // - xegpu.compile_hint - // advance A and B prefetch tiles - %next_A_prefetch_tile = xegpu.update_nd_offset %A_prefetch_tile, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - %next_B_prefetch_tile = xegpu.update_nd_offset %B_prefetch_tile, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> // advance A and B tiles - %next_A_tile_0 = xegpu.update_nd_offset %A_tile_0, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %next_A_tile_1 = xegpu.update_nd_offset %A_tile_1, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16> - - %next_B_tile_0 = xegpu.update_nd_offset %B_tile_0, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %next_B_tile_1 = xegpu.update_nd_offset %B_tile_1, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> // %next_B_tile_2 = xegpu.update_nd_offset %B_tile_2, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16> // %next_B_tile_3 = xegpu.update_nd_offset %B_tile_3, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16> - - xegpu.compile_hint - %a_val_0_flat = vector.shape_cast %a_val_0 : vector<32x16xf16> to vector<512xf16> - %a_val_1_flat = vector.shape_cast %a_val_1 : vector<32x16xf16> to 
vector<512xf16> - %a_val_0_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_0_0 = vector.shape_cast %a_val_0_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_1_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [128], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_1_0 = vector.shape_cast %a_val_1_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_2_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_2_0 = vector.shape_cast %a_val_2_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_3_0_flat = vector.extract_strided_slice %a_val_0_flat { offsets = [384], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_3_0 = vector.shape_cast %a_val_3_0_flat : vector<128xf16> to vector<8x16xf16> - %a_val_0_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [0], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_0_1 = vector.shape_cast %a_val_0_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_1_1_flat = vector.extract_strided_slice %a_val_1_flat {offsets = [128], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_1_1 = vector.shape_cast %a_val_1_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_2_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [256], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_2_1 = vector.shape_cast %a_val_2_1_flat : vector<128xf16> to vector<8x16xf16> - %a_val_3_1_flat = vector.extract_strided_slice %a_val_1_flat { offsets = [384], sizes = [128], strides = [1]} : - vector<512xf16> to vector<128xf16> - %a_val_3_1 = vector.shape_cast %a_val_3_1_flat : vector<128xf16> to vector<8x16xf16> - - - %b_val_0_flat = vector.shape_cast %b_val_0 : vector<16x16x2xf16> to vector<512xf16> - %b_val_1_flat = vector.shape_cast %b_val_1 : vector<16x16x2xf16> to vector<512xf16> - %b_val_2_flat = vector.shape_cast %b_val_2 : vector<16x16x2xf16> to vector<512xf16> - %b_val_3_flat = vector.shape_cast %b_val_3 : vector<16x16x2xf16> to vector<512xf16> - %b_val_0_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_0 = vector.shape_cast %b_val_0_0_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_0_flat = vector.extract_strided_slice %b_val_0_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_0 = vector.shape_cast %b_val_1_0_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_1_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_1 = vector.shape_cast %b_val_0_1_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_1_flat = vector.extract_strided_slice %b_val_1_flat { offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_1 = vector.shape_cast %b_val_1_1_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_2_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_2 = vector.shape_cast %b_val_0_2_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_2_flat = vector.extract_strided_slice %b_val_2_flat { offsets = [256], sizes = [256], 
strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_2 = vector.shape_cast %b_val_1_2_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_0_3_flat = vector.extract_strided_slice %b_val_3_flat { offsets = [0], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_0_3 = vector.shape_cast %b_val_0_3_flat : vector<256xf16> to vector<8x16x2xf16> - %b_val_1_3_flat = vector.extract_strided_slice %b_val_3_flat {offsets = [256], sizes = [256], strides = [1]} : - vector<512xf16> to vector<256xf16> - %b_val_1_3 = vector.shape_cast %b_val_1_3_flat : vector<256xf16> to vector<8x16x2xf16> - // do DPAS + // barrier wait + %97 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %98 = vector.extract %97[0] : vector<32x16xf16> from vector<2x32x16xf16> + %99 = vector.extract %97[1] : vector<32x16xf16> from vector<2x32x16xf16> + %100 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> + %101 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16x2xf16> + %102 = vector.extract %100[0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %103 = vector.extract %100[1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %104 = vector.extract %101[0] : vector<16x16x2xf16> from vector<2x16x16x2xf16> + %105 = vector.extract %101[1] : vector<16x16x2xf16> from vector<2x16x16x2xf16> xegpu.compile_hint - %new_c_val_0_0_temp = xegpu.dpas %a_val_0_0, %b_val_0_0, %c_val_0_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_0 = xegpu.dpas %a_val_0_1, %b_val_1_0, %new_c_val_0_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0_temp = xegpu.dpas %a_val_1_0, %b_val_0_0, %c_val_1_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_0 = xegpu.dpas %a_val_1_1, %b_val_1_0, %new_c_val_1_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0_temp = xegpu.dpas %a_val_2_0, %b_val_0_0, %c_val_2_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_0 = xegpu.dpas %a_val_2_1, %b_val_1_0, %new_c_val_2_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0_temp = xegpu.dpas %a_val_3_0, %b_val_0_0, %c_val_3_0 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_0 = xegpu.dpas %a_val_3_1, %b_val_1_0, %new_c_val_3_0_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_1_temp = xegpu.dpas %a_val_0_0, %b_val_0_1, %c_val_0_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_1 = xegpu.dpas %a_val_0_1, %b_val_1_1, %new_c_val_0_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1_temp = xegpu.dpas %a_val_1_0, %b_val_0_1, %c_val_1_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_1 = xegpu.dpas %a_val_1_1, %b_val_1_1, %new_c_val_1_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - 
%new_c_val_2_1_temp = xegpu.dpas %a_val_2_0, %b_val_0_1, %c_val_2_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_1 = xegpu.dpas %a_val_2_1, %b_val_1_1, %new_c_val_2_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1_temp = xegpu.dpas %a_val_3_0, %b_val_0_1, %c_val_3_1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_1 = xegpu.dpas %a_val_3_1, %b_val_1_1, %new_c_val_3_1_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_2_temp = xegpu.dpas %a_val_0_0, %b_val_0_2, %c_val_0_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_2 = xegpu.dpas %a_val_0_1, %b_val_1_2, %new_c_val_0_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2_temp = xegpu.dpas %a_val_1_0, %b_val_0_2, %c_val_1_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_2 = xegpu.dpas %a_val_1_1, %b_val_1_2, %new_c_val_1_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2_temp = xegpu.dpas %a_val_2_0, %b_val_0_2, %c_val_2_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_2 = xegpu.dpas %a_val_2_1, %b_val_1_2, %new_c_val_2_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2_temp = xegpu.dpas %a_val_3_0, %b_val_0_2, %c_val_3_2 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_2 = xegpu.dpas %a_val_3_1, %b_val_1_2, %new_c_val_3_2_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - %new_c_val_0_3_temp = xegpu.dpas %a_val_0_0, %b_val_0_3, %c_val_0_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_0_3 = xegpu.dpas %a_val_0_1, %b_val_1_3, %new_c_val_0_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3_temp = xegpu.dpas %a_val_1_0, %b_val_0_3, %c_val_1_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_1_3 = xegpu.dpas %a_val_1_1, %b_val_1_3, %new_c_val_1_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3_temp = xegpu.dpas %a_val_2_0, %b_val_0_3, %c_val_2_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_2_3 = xegpu.dpas %a_val_2_1, %b_val_1_3, %new_c_val_2_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - %new_c_val_3_3_temp = xegpu.dpas %a_val_3_0, %b_val_0_3, %c_val_3_3 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> xegpu.compile_hint - %new_c_val_3_3 = xegpu.dpas %a_val_3_1, %b_val_1_3, %new_c_val_3_3_temp : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - + %106 = xegpu.update_nd_offset %arg23, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + %107 = xegpu.update_nd_offset %arg24, [%c32, %c0] : !xegpu.tensor_desc<8x32xf16> + %108 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %109 = xegpu.update_nd_offset %arg5, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %110 = xegpu.update_nd_offset %arg6, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> xegpu.compile_hint - // barrier wait - scf.if 
%every_8th_iter_cond { - xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier + %111 = vector.shape_cast %98 : vector<32x16xf16> to vector<512xf16> + %112 = vector.shape_cast %99 : vector<32x16xf16> to vector<512xf16> + %113 = vector.extract_strided_slice %111 {offsets = [0], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %114 = vector.shape_cast %113 : vector<128xf16> to vector<8x16xf16> + %115 = vector.extract_strided_slice %111 {offsets = [128], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %116 = vector.shape_cast %115 : vector<128xf16> to vector<8x16xf16> + %117 = vector.extract_strided_slice %111 {offsets = [256], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %118 = vector.shape_cast %117 : vector<128xf16> to vector<8x16xf16> + %119 = vector.extract_strided_slice %111 {offsets = [384], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %120 = vector.shape_cast %119 : vector<128xf16> to vector<8x16xf16> + %121 = vector.extract_strided_slice %112 {offsets = [0], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %122 = vector.shape_cast %121 : vector<128xf16> to vector<8x16xf16> + %123 = vector.extract_strided_slice %112 {offsets = [128], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %124 = vector.shape_cast %123 : vector<128xf16> to vector<8x16xf16> + %125 = vector.extract_strided_slice %112 {offsets = [256], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %126 = vector.shape_cast %125 : vector<128xf16> to vector<8x16xf16> + %127 = vector.extract_strided_slice %112 {offsets = [384], sizes = [128], strides = [1]} : vector<512xf16> to vector<128xf16> + %128 = vector.shape_cast %127 : vector<128xf16> to vector<8x16xf16> + %129 = vector.shape_cast %102 : vector<16x16x2xf16> to vector<512xf16> + %130 = vector.shape_cast %103 : vector<16x16x2xf16> to vector<512xf16> + %131 = vector.shape_cast %104 : vector<16x16x2xf16> to vector<512xf16> + %132 = vector.shape_cast %105 : vector<16x16x2xf16> to vector<512xf16> + %133 = vector.extract_strided_slice %129 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %134 = vector.shape_cast %133 : vector<256xf16> to vector<8x16x2xf16> + %135 = vector.extract_strided_slice %129 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %136 = vector.shape_cast %135 : vector<256xf16> to vector<8x16x2xf16> + %137 = vector.extract_strided_slice %130 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %138 = vector.shape_cast %137 : vector<256xf16> to vector<8x16x2xf16> + %139 = vector.extract_strided_slice %130 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %140 = vector.shape_cast %139 : vector<256xf16> to vector<8x16x2xf16> + %141 = vector.extract_strided_slice %131 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %142 = vector.shape_cast %141 : vector<256xf16> to vector<8x16x2xf16> + %143 = vector.extract_strided_slice %131 {offsets = [256], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %144 = vector.shape_cast %143 : vector<256xf16> to vector<8x16x2xf16> + %145 = vector.extract_strided_slice %132 {offsets = [0], sizes = [256], strides = [1]} : vector<512xf16> to vector<256xf16> + %146 = vector.shape_cast %145 : vector<256xf16> to vector<8x16x2xf16> + %147 = vector.extract_strided_slice %132 {offsets = [256], sizes = [256], 
strides = [1]} : vector<512xf16> to vector<256xf16> + %148 = vector.shape_cast %147 : vector<256xf16> to vector<8x16x2xf16> + xegpu.compile_hint + %149 = xegpu.dpas %114, %134, %arg7 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %150 = xegpu.dpas %122, %136, %149 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %151 = xegpu.dpas %116, %134, %arg11 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %152 = xegpu.dpas %124, %136, %151 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %153 = xegpu.dpas %118, %134, %arg15 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %154 = xegpu.dpas %126, %136, %153 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %155 = xegpu.dpas %120, %134, %arg19 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %156 = xegpu.dpas %128, %136, %155 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %157 = xegpu.dpas %114, %138, %arg8 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %158 = xegpu.dpas %122, %140, %157 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %159 = xegpu.dpas %116, %138, %arg12 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %160 = xegpu.dpas %124, %140, %159 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %161 = xegpu.dpas %118, %138, %arg16 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %162 = xegpu.dpas %126, %140, %161 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %163 = xegpu.dpas %120, %138, %arg20 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %164 = xegpu.dpas %128, %140, %163 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %165 = xegpu.dpas %114, %142, %arg9 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %166 = xegpu.dpas %122, %144, %165 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %167 = xegpu.dpas %116, %142, %arg13 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %168 = xegpu.dpas %124, %144, %167 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %169 = xegpu.dpas %118, %142, %arg17 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %170 = xegpu.dpas %126, %144, %169 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %171 = xegpu.dpas %120, %142, %arg21 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %172 = xegpu.dpas %128, %144, %171 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %173 = xegpu.dpas %114, %146, %arg10 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %174 = xegpu.dpas %122, %148, %173 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %175 = xegpu.dpas %116, %146, %arg14 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %176 = xegpu.dpas %124, %148, %175 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %177 = xegpu.dpas %118, %146, %arg18 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %178 = xegpu.dpas %126, %148, %177 : vector<8x16xf16>, vector<8x16x2xf16>, 
vector<8x16xf32> -> vector<8x16xf32> + %179 = xegpu.dpas %120, %146, %arg22 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + %180 = xegpu.dpas %128, %148, %179 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + xegpu.compile_hint + scf.if %96 { + xegpu.nbarrier_wait %44 : !xegpu.nbarrier } - - scf.yield %next_A_tile_0, %next_B_tile_0, %next_B_tile_1, - %new_c_val_0_0, %new_c_val_0_1, %new_c_val_0_2, %new_c_val_0_3, %new_c_val_1_0, %new_c_val_1_1, %new_c_val_1_2, %new_c_val_1_3, %new_c_val_2_0, %new_c_val_2_1, %new_c_val_2_2, %new_c_val_2_3, %new_c_val_3_0, %new_c_val_3_1, %new_c_val_3_2, %new_c_val_3_3, - %next_A_prefetch_tile, %next_B_prefetch_tile - : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, - vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>,vector<8x16xf32>, - !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + scf.yield %108, %109, %110, %150, %158, %166, %174, %152, %160, %168, %176, %154, %162, %170, %178, %156, %164, %172, %180, %106, %107 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xegpu.tensor_desc<8x32xf16>, !xegpu.tensor_desc<8x32xf16> } - // trunc all DPAS output tiles to f16 - %c_result_0_0_f16 = arith.truncf %k_loop_result#3 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_1_f16 = arith.truncf %k_loop_result#4 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_2_f16 = arith.truncf %k_loop_result#5 : vector<8x16xf32> to vector<8x16xf16> - %c_result_0_3_f16 = arith.truncf %k_loop_result#6 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_0_f16 = arith.truncf %k_loop_result#7 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_1_f16 = arith.truncf %k_loop_result#8 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_2_f16 = arith.truncf %k_loop_result#9 : vector<8x16xf32> to vector<8x16xf16> - %c_result_1_3_f16 = arith.truncf %k_loop_result#10 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_0_f16 = arith.truncf %k_loop_result#11 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_1_f16 = arith.truncf %k_loop_result#12 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_2_f16 = arith.truncf %k_loop_result#13 : vector<8x16xf32> to vector<8x16xf16> - %c_result_2_3_f16 = arith.truncf %k_loop_result#14 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_0_f16 = arith.truncf %k_loop_result#15 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_1_f16 = arith.truncf %k_loop_result#16 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_2_f16 = arith.truncf %k_loop_result#17 : vector<8x16xf32> to vector<8x16xf16> - %c_result_3_3_f16 = arith.truncf %k_loop_result#18 : vector<8x16xf32> to vector<8x16xf16> - // each SG needs to store the result of K loop into a 32x64 tile in C matrix. 
This is organized in 8x16 DPAS tiles // in the layout of 4x4x8x16. The max store size HW supoprt in f16 is 8x32. So we combine two 8x16 DPAS tiles // horizontally using vector.shuffle to get the required store size. The store layout then will 4x2x8x32 i.e. // we have 8 stores of size 8x32 in the layout 4x2. - - %c_result_8x32_0_0_t1 = vector.shuffle %c_result_0_0_f16, %c_result_0_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_0_0_t2 = vector.shape_cast %c_result_8x32_0_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_0_0 = vector.shape_cast %c_result_8x32_0_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_00 = xegpu.create_nd_tdesc %C[%C_sg_tile_offset_x, %C_sg_tile_offset_y] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_0_0, %c_sg_tile_00 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %46 = arith.truncf %45#3 : vector<8x16xf32> to vector<8x16xf16> + %47 = arith.truncf %45#4 : vector<8x16xf32> to vector<8x16xf16> + %48 = arith.truncf %45#5 : vector<8x16xf32> to vector<8x16xf16> + %49 = arith.truncf %45#6 : vector<8x16xf32> to vector<8x16xf16> + %50 = arith.truncf %45#7 : vector<8x16xf32> to vector<8x16xf16> + %51 = arith.truncf %45#8 : vector<8x16xf32> to vector<8x16xf16> + %52 = arith.truncf %45#9 : vector<8x16xf32> to vector<8x16xf16> + %53 = arith.truncf %45#10 : vector<8x16xf32> to vector<8x16xf16> + %54 = arith.truncf %45#11 : vector<8x16xf32> to vector<8x16xf16> + %55 = arith.truncf %45#12 : vector<8x16xf32> to vector<8x16xf16> + %56 = arith.truncf %45#13 : vector<8x16xf32> to vector<8x16xf16> + %57 = arith.truncf %45#14 : vector<8x16xf32> to vector<8x16xf16> + %58 = arith.truncf %45#15 : vector<8x16xf32> to vector<8x16xf16> + %59 = arith.truncf %45#16 : vector<8x16xf32> to vector<8x16xf16> + %60 = arith.truncf %45#17 : vector<8x16xf32> to vector<8x16xf16> + %61 = arith.truncf %45#18 : vector<8x16xf32> to vector<8x16xf16> + %62 = vector.shuffle %46, %47 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %63 = vector.shape_cast %62 : vector<16x16xf16> to vector<256xf16> + %64 = vector.shape_cast %63 : vector<256xf16> to vector<8x32xf16> + %65 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<4096x4096xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %64, %65 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_0_1_t1 = vector.shuffle %c_result_0_2_f16, %c_result_0_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_0_1_t2 = vector.shape_cast %c_result_8x32_0_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_0_1 = vector.shape_cast %c_result_8x32_0_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_01 = xegpu.update_nd_offset %c_sg_tile_00, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_0_1, %c_sg_tile_01 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %66 = vector.shuffle %48, %49 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %67 = vector.shape_cast %66 : vector<16x16xf16> to vector<256xf16> + %68 = vector.shape_cast %67 : vector<256xf16> to vector<8x32xf16> + 
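// Illustrative sketch only (not part of the patch): the stores above pack two 8x16 f16
// accumulator tiles into a single 8x32 store, as the removed comment describes. The
// vector.shuffle mask [0, 8, 1, 9, ...] interleaves the rows of the two operands
// (l0, r0, l1, r1, ..., l7, r7), so flattening the 16x16 result row-major and viewing it
// as 8x32 places row i of the left tile next to row i of the right tile. A minimal
// standalone version of the idiom (the function name is hypothetical):
func.func @combine_two_8x16_into_8x32(%left: vector<8x16xf16>, %right: vector<8x16xf16>) -> vector<8x32xf16> {
  // rows become l0, r0, l1, r1, ..., l7, r7
  %interleaved = vector.shuffle %left, %right [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16>
  // row-major flatten, then reinterpret as 8 rows of 32 contiguous f16 elements
  %flat = vector.shape_cast %interleaved : vector<16x16xf16> to vector<256xf16>
  %wide = vector.shape_cast %flat : vector<256xf16> to vector<8x32xf16>
  return %wide : vector<8x32xf16>
}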
%69 = xegpu.update_nd_offset %65, [%c0, %c32] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %68, %69 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_1_0_t1 = vector.shuffle %c_result_1_0_f16, %c_result_1_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_1_0_t2 = vector.shape_cast %c_result_8x32_1_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_1_0 = vector.shape_cast %c_result_8x32_1_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_10 = xegpu.update_nd_offset %c_sg_tile_00, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_1_0, %c_sg_tile_10 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %70 = vector.shuffle %50, %51 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %71 = vector.shape_cast %70 : vector<16x16xf16> to vector<256xf16> + %72 = vector.shape_cast %71 : vector<256xf16> to vector<8x32xf16> + %73 = xegpu.update_nd_offset %65, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %72, %73 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - - %c_result_8x32_1_1_t1 = vector.shuffle %c_result_1_2_f16, %c_result_1_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_1_1_t2 = vector.shape_cast %c_result_8x32_1_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_1_1 = vector.shape_cast %c_result_8x32_1_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_11 = xegpu.update_nd_offset %c_sg_tile_01, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_1_1, %c_sg_tile_11 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %74 = vector.shuffle %52, %53 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %75 = vector.shape_cast %74 : vector<16x16xf16> to vector<256xf16> + %76 = vector.shape_cast %75 : vector<256xf16> to vector<8x32xf16> + %77 = xegpu.update_nd_offset %69, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %76, %77 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_2_0_t1 = vector.shuffle %c_result_2_0_f16, %c_result_2_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_2_0_t2 = vector.shape_cast %c_result_8x32_2_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_2_0 = vector.shape_cast %c_result_8x32_2_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_20 = xegpu.update_nd_offset %c_sg_tile_10, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_2_0, %c_sg_tile_20 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %78 = vector.shuffle %54, %55 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %79 = vector.shape_cast %78 : vector<16x16xf16> to vector<256xf16> + %80 = vector.shape_cast %79 : vector<256xf16> to 
vector<8x32xf16> + %81 = xegpu.update_nd_offset %73, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %80, %81 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_2_1_t1 = vector.shuffle %c_result_2_2_f16, %c_result_2_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_2_1_t2 = vector.shape_cast %c_result_8x32_2_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_2_1 = vector.shape_cast %c_result_8x32_2_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_21 = xegpu.update_nd_offset %c_sg_tile_11, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_2_1, %c_sg_tile_21 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %82 = vector.shuffle %56, %57 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %83 = vector.shape_cast %82 : vector<16x16xf16> to vector<256xf16> + %84 = vector.shape_cast %83 : vector<256xf16> to vector<8x32xf16> + %85 = xegpu.update_nd_offset %77, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %84, %85 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_3_0_t1 = vector.shuffle %c_result_3_0_f16, %c_result_3_1_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_3_0_t2 = vector.shape_cast %c_result_8x32_3_0_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_3_0 = vector.shape_cast %c_result_8x32_3_0_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_30 = xegpu.update_nd_offset %c_sg_tile_20, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_3_0, %c_sg_tile_30 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %86 = vector.shuffle %58, %59 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %87 = vector.shape_cast %86 : vector<16x16xf16> to vector<256xf16> + %88 = vector.shape_cast %87 : vector<256xf16> to vector<8x32xf16> + %89 = xegpu.update_nd_offset %81, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %88, %89 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> xegpu.compile_hint - - %c_result_8x32_3_1_t1 = vector.shuffle %c_result_3_2_f16, %c_result_3_3_f16 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> - %c_result_8x32_3_1_t2 = vector.shape_cast %c_result_8x32_3_1_t1 : vector<16x16xf16> to vector<256xf16> - %c_result_8x32_3_1 = vector.shape_cast %c_result_8x32_3_1_t2 : vector<256xf16> to vector<8x32xf16> - %c_sg_tile_31 = xegpu.update_nd_offset %c_sg_tile_21, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %c_result_8x32_3_1, %c_sg_tile_31 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> - + %90 = vector.shuffle %60, %61 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %91 = vector.shape_cast %90 : vector<16x16xf16> to vector<256xf16> + %92 = vector.shape_cast %91 : 
vector<256xf16> to vector<8x32xf16> + %93 = xegpu.update_nd_offset %85, [%c8, %c0] : !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %92, %93 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : f16 - %c2_f16 = arith.constant 2.0 : f16 %c4096 = arith.constant 4096 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %c_gen_int = arith.constant 1 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - %A = memref.alloc() : memref<4096x4096xf16> - %B = memref.alloc() : memref<4096x4096xf16> - %C = memref.alloc() : memref<4096x4096xf16> - %C_ref = memref.alloc() : memref<4096x4096xf32> - // Use one of the two options to initialize the A matrix // Option 1: intialize matrix A ; A[i, j] = j // scf.for %i = %c0 to %c4096 step %c1 { @@ -483,9 +367,6 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %A_random = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // Use one of the two options below to initialize the B matrix // Option 1: make matrix B an identity matrix // scf.for %i = %c0 to %c4096 step %c1 { @@ -493,7 +374,6 @@ module @gemm attributes {gpu.container_module} { // %i_i32 = index.castu %i : index to i32 // %j_i32 = index.castu %j : index to i32 // %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - // scf.if %i_j_same { // memref.store %cf_1, %B[%i, %j] : memref<4096x4096xf16> // } else { @@ -502,44 +382,41 @@ module @gemm attributes {gpu.container_module} { // } // } // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5) - %B_random = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f16 = arith.constant 0.0 : f16 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f16, %C[%i, %j] : memref<4096x4096xf16> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + %cst_0 = arith.constant 0.000000e+00 : f16 + %true = arith.constant true + %cst_1 = arith.constant -5.000000e-01 : f32 + %cst_2 = arith.constant 5.000000e-01 : f32 + %alloc = memref.alloc() : memref<4096x4096xf16> + %alloc_3 = memref.alloc() : memref<4096x4096xf16> + %alloc_4 = memref.alloc() : memref<4096x4096xf16> + %alloc_5 = memref.alloc() : memref<4096x4096xf32> + %cast = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_2, %true) : (memref<*xf16>, f32, f32, i1) -> () + %cast_6 = memref.cast %alloc_3 : memref<4096x4096xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_6, %cst_1, %cst_2, %true) : (memref<*xf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst_0, %alloc_4[%arg0, %arg1] : memref<4096x4096xf16> + memref.store %cst, %alloc_5[%arg0, %arg1] : memref<4096x4096xf32> } } - // Run GPU. 
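// Illustrative sketch only (not part of the patch): the @test wrappers updated elsewhere
// in this diff replace host_shared allocations plus memref.copy with plain device
// allocations, explicit gpu.memcpy transfers, and a host-side copy of the result before
// the device buffer is freed. A minimal version of that driver pattern, with hypothetical
// sizes and kernel name:
func.func @run(%host_in: memref<128x128xf16>) -> memref<128x128xf32> {
  %c1 = arith.constant 1 : index
  // copy the input to a device buffer
  %dev_in = gpu.alloc () : memref<128x128xf16>
  gpu.memcpy %dev_in, %host_in : memref<128x128xf16>, memref<128x128xf16>
  %dev_out = gpu.alloc () : memref<128x128xf32>
  gpu.launch_func @kernels::@k blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%dev_in : memref<128x128xf16>, %dev_out : memref<128x128xf32>)
  // copy the result back to host memory, then release the device buffers
  %host_out = memref.alloc() : memref<128x128xf32>
  gpu.memcpy %host_out, %dev_out : memref<128x128xf32>, memref<128x128xf32>
  gpu.dealloc %dev_in : memref<128x128xf16>
  gpu.dealloc %dev_out : memref<128x128xf32>
  return %host_out : memref<128x128xf32>
}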
- %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> - %cast_C = memref.cast %2 : memref<4096x4096xf16> to memref<*xf16> - // Run CPU. - %A_cast = memref.cast %A : memref<4096x4096xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<4096x4096xf16> to memref<*xf16> - %C_cast = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - call @gemmF16F16F16(%A_cast, %B_cast, %C_cast) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () - - - %C_row_0 = memref.subview %C_ref[0, 0][1, 4096][1, 1] : memref<4096x4096xf32> to memref<1x4096xf32, strided<[4096, 1], offset:0>> - %C_row_0_cast = memref.cast %C_row_0 : memref<1x4096xf32, strided<[4096, 1], offset: 0>> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () - - %C_row_0_gpu = memref.subview %2[0, 0][1, 4096][1, 1] : memref<4096x4096xf16> to memref<1x4096xf16, strided<[4096, 1], offset:0>> - %C_row_0_cast_gpu = memref.cast %C_row_0_gpu : memref<1x4096xf16, strided<[4096, 1], offset: 0>> to memref<*xf16> // call @printMemrefF16(%C_row_0_cast_gpu) : (memref<*xf16>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_C, %C_cast) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xf16> - memref.dealloc %B : memref<4096x4096xf16> - memref.dealloc %C : memref<4096x4096xf16> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_3, %alloc_4) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf16>) -> memref<4096x4096xf16> + %cast_7 = memref.cast %0 : memref<4096x4096xf16> to memref<*xf16> + %cast_8 = memref.cast %alloc : memref<4096x4096xf16> to memref<*xf16> + %cast_9 = memref.cast %alloc_3 : memref<4096x4096xf16> to memref<*xf16> + %cast_10 = memref.cast %alloc_5 : memref<4096x4096xf32> to memref<*xf32> + call @gemmF16F16F16(%cast_8, %cast_9, %cast_10) : (memref<*xf16>, memref<*xf16>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_7, %cast_10) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xf16> + memref.dealloc %alloc_3 : memref<4096x4096xf16> + memref.dealloc %alloc_4 : memref<4096x4096xf16> + memref.dealloc %alloc_5 : memref<4096x4096xf32> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} @@ -547,5 +424,4 @@ module @gemm attributes {gpu.container_module} { func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} func.func private @gemmF16F16F16(memref<*xf16>, memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} - } diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_with_extract_e2e.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_with_extract_e2e.mlir index 001478827..429a60288 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_with_extract_e2e.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_with_extract_e2e.mlir @@ -1,191 +1,181 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0> + memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %memref = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg0, %memref : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_0 = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg1, %memref_0 : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_gemm blocks in (%c32, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_gemm blocks in (%c32, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) gpu.dealloc %memref : memref<1024x1024xf16> gpu.dealloc %memref_0 : memref<1024x1024xf16> - return %memref_1 : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - -gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_gemm(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %0 = arith.muli %block_id_x, %c32 : index - %1 = arith.muli %block_id_y, %c32 : index - %2 = arith.addi %0, %c0 : index - %3 = arith.addi %1, %c0 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %c16 = arith.constant 16 : index - %5 = arith.addi %1, %c16 : 
index - %6 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %c8 = arith.constant 8 : index - %7 = arith.addi %0, %c8 : index - %8 = xegpu.create_nd_tdesc %arg2[%7, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %9 = xegpu.create_nd_tdesc %arg2[%7, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %10 = arith.addi %0, %c16 : index - %11 = xegpu.create_nd_tdesc %arg2[%10, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %12 = xegpu.create_nd_tdesc %arg2[%10, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %c24 = arith.constant 24 : index - %13 = arith.addi %0, %c24 : index - %14 = xegpu.create_nd_tdesc %arg2[%13, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %15 = xegpu.create_nd_tdesc %arg2[%13, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %16 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> - %17 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> - %18 = xegpu.load_nd %16 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> -> vector<32x16xf32> - %19 = xegpu.load_nd %17 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf32, #xegpu.block_tdesc_attr> -> vector<32x16xf32> - %20 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %21 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %22:4 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %20, %arg5 = %21, %arg6 = %18, %arg7 = %19) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<32x16xf32>, vector<32x16xf32>) { - %31 = vector.extract_strided_slice %arg6 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %32 = vector.extract_strided_slice %arg6 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %33 = vector.extract_strided_slice %arg6 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %34 = vector.extract_strided_slice %arg6 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %35 = vector.extract_strided_slice %arg7 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %36 = vector.extract_strided_slice %arg7 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %37 = vector.extract_strided_slice %arg7 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %38 = vector.extract_strided_slice %arg7 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %39 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - %40 = vector.extract %39[0] : 
vector<32x16xf16> from vector<2x32x16xf16> - %41 = vector.extract %39[1] : vector<32x16xf16> from vector<2x32x16xf16> - %42 = vector.extract_strided_slice %40 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %43 = vector.extract_strided_slice %40 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %44 = vector.extract_strided_slice %40 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %45 = vector.extract_strided_slice %40 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %46 = vector.extract_strided_slice %41 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %47 = vector.extract_strided_slice %41 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %48 = vector.extract_strided_slice %41 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - %49 = vector.extract_strided_slice %41 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> - - %50 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - - %51 = vector.extract %50[0] : vector<32x16xf16> from vector<2x32x16xf16> - %52 = vector.extract %50[1] : vector<32x16xf16> from vector<2x32x16xf16> - - %53 = vector.extract_strided_slice %51 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - %54 = vector.extract_strided_slice %51 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - %55 = vector.extract_strided_slice %52 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - %56 = vector.extract_strided_slice %52 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> - - %61 = xegpu.dpas %42, %53, %31 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %62 = xegpu.dpas %46, %54, %61 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %63 = xegpu.dpas %42, %55, %35 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %64 = xegpu.dpas %46, %56, %63 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %65 = xegpu.dpas %43, %53, %32 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %66 = xegpu.dpas %47, %54, %65 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %67 = xegpu.dpas %43, %55, %36 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %68 = xegpu.dpas %47, %56, %67 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %69 = xegpu.dpas %44, %53, %33 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %70 = xegpu.dpas %48, %54, %69 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %71 = xegpu.dpas %44, %55, %37 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %72 = xegpu.dpas %48, %56, %71 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %73 = xegpu.dpas %45, %53, %34 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %74 = xegpu.dpas %49, %54, 
%73 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %75 = xegpu.dpas %45, %55, %38 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %76 = xegpu.dpas %49, %56, %75 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %77 = vector.shuffle %62, %66 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> - %78 = vector.shuffle %70, %74 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> - %79 = vector.shuffle %77, %78 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x16xf32>, vector<16x16xf32> - %80 = vector.shuffle %64, %68 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> - %81 = vector.shuffle %72, %76 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> - %82 = vector.shuffle %80, %81 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x16xf32>, vector<16x16xf32> - %83 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %84 = xegpu.update_nd_offset %arg5, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - scf.yield %83, %84, %79, %82 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<32x16xf32>, vector<32x16xf32> + gpu.module @test_kernel { + gpu.func @test_gemm(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c32 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = arith.addi %0, %c0 : index + %3 = arith.addi %1, %c0 : index + %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %c16 = arith.constant 16 : index + %5 = arith.addi %1, %c16 : index + %6 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %c8 = arith.constant 8 : index + %7 = arith.addi %0, %c8 : index + %8 = xegpu.create_nd_tdesc %arg2[%7, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %9 = xegpu.create_nd_tdesc %arg2[%7, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %10 = arith.addi %0, %c16 : index + %11 = xegpu.create_nd_tdesc %arg2[%10, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %12 = xegpu.create_nd_tdesc %arg2[%10, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %c24 = arith.constant 24 : index + %13 = arith.addi %0, %c24 : index + %14 = xegpu.create_nd_tdesc %arg2[%13, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %15 = xegpu.create_nd_tdesc %arg2[%13, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %16 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<32x16xf32> + %17 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<1024x1024xf32> -> !xegpu.tensor_desc<32x16xf32> + %18 = xegpu.load_nd %16 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> 
: !xegpu.tensor_desc<32x16xf32> -> vector<32x16xf32> + %19 = xegpu.load_nd %17 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf32> -> vector<32x16xf32> + %20 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %21 = xegpu.create_nd_tdesc %arg1[%c0, %3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %22:4 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %20, %arg5 = %21, %arg6 = %18, %arg7 = %19) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<32x16xf32>, vector<32x16xf32>) { + %31 = vector.extract_strided_slice %arg6 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %32 = vector.extract_strided_slice %arg6 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %33 = vector.extract_strided_slice %arg6 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %34 = vector.extract_strided_slice %arg6 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %35 = vector.extract_strided_slice %arg7 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %36 = vector.extract_strided_slice %arg7 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %37 = vector.extract_strided_slice %arg7 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %38 = vector.extract_strided_slice %arg7 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %39 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %40 = vector.extract %39[0] : vector<32x16xf16> from vector<2x32x16xf16> + %41 = vector.extract %39[1] : vector<32x16xf16> from vector<2x32x16xf16> + %42 = vector.extract_strided_slice %40 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %43 = vector.extract_strided_slice %40 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %44 = vector.extract_strided_slice %40 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %45 = vector.extract_strided_slice %40 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %46 = vector.extract_strided_slice %41 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %47 = vector.extract_strided_slice %41 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %48 = vector.extract_strided_slice %41 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %49 = vector.extract_strided_slice %41 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %50 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %51 = vector.extract %50[0] : vector<32x16xf16> from vector<2x32x16xf16> + %52 
= vector.extract %50[1] : vector<32x16xf16> from vector<2x32x16xf16> + %53 = vector.extract_strided_slice %51 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> + %54 = vector.extract_strided_slice %51 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> + %55 = vector.extract_strided_slice %52 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> + %56 = vector.extract_strided_slice %52 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> + %57 = xegpu.dpas %42, %53, %31 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %58 = xegpu.dpas %46, %54, %57 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %59 = xegpu.dpas %42, %55, %35 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %60 = xegpu.dpas %46, %56, %59 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %61 = xegpu.dpas %43, %53, %32 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %62 = xegpu.dpas %47, %54, %61 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %63 = xegpu.dpas %43, %55, %36 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %64 = xegpu.dpas %47, %56, %63 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %65 = xegpu.dpas %44, %53, %33 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %66 = xegpu.dpas %48, %54, %65 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %67 = xegpu.dpas %44, %55, %37 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %68 = xegpu.dpas %48, %56, %67 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %69 = xegpu.dpas %45, %53, %34 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %70 = xegpu.dpas %49, %54, %69 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %71 = xegpu.dpas %45, %55, %38 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %72 = xegpu.dpas %49, %56, %71 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %73 = vector.shuffle %58, %62 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> + %74 = vector.shuffle %66, %70 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> + %75 = vector.shuffle %73, %74 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x16xf32>, vector<16x16xf32> + %76 = vector.shuffle %60, %64 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> + %77 = vector.shuffle %68, %72 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> + %78 = vector.shuffle %76, %77 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16x16xf32>, vector<16x16xf32> + %79 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %80 = xegpu.update_nd_offset %arg5, [%c32, %c0] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + scf.yield %79, %80, %75, %78 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, 
!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<32x16xf32>, vector<32x16xf32> + } + %23 = vector.extract_strided_slice %22#2 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %24 = vector.extract_strided_slice %22#2 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %25 = vector.extract_strided_slice %22#2 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %26 = vector.extract_strided_slice %22#2 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %27 = vector.extract_strided_slice %22#3 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %28 = vector.extract_strided_slice %22#3 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %29 = vector.extract_strided_slice %22#3 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + %30 = vector.extract_strided_slice %22#3 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> + xegpu.store_nd %23, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %27, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %24, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %28, %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %25, %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %29, %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %26, %14 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %30, %15 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return } - %23 = vector.extract_strided_slice %22#2 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %24 = vector.extract_strided_slice %22#2 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %25 = vector.extract_strided_slice %22#2 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %26 = vector.extract_strided_slice %22#2 {offsets = [24, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %27 = vector.extract_strided_slice %22#3 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %28 = vector.extract_strided_slice %22#3 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %29 = vector.extract_strided_slice %22#3 {offsets = [16, 0], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - %30 = vector.extract_strided_slice %22#3 {offsets = [24, 0], 
sizes = [8, 16], strides = [1, 1]} : vector<32x16xf32> to vector<8x16xf32> - xegpu.store_nd %23, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %27, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %24, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %28, %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %25, %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %29, %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %26, %14 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %30, %15 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - gpu.return } -} - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 1.000000e+02 : f16 + %c128_i16 = arith.constant 128 : i16 + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1024x1024xf16 : memref<1024x1024xf16> %1 = memref.get_global @__constant_1024x1024xf16_ : memref<1024x1024xf16> - %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> - %init = arith.constant 0.0 : f16 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index // fill the top-left block 128x128 // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 + %2 = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> scf.for %arg0 = %c0 to %c128 step %c1 { scf.for %arg1 = %c0 to %c128 step %c1 { - %int0 = arith.index_cast %arg0 : index to i16 - %int1 = arith.index_cast %arg1 : index to i16 - %c128_i16 = arith.constant 128 : i16 - %idx0 = arith.muli %int0, %c128_i16 : i16 - %idx1 = arith.addi %int1, %idx0 : i16 - %fp = arith.uitofp %idx1 : i16 to f16 - %cst100 = arith.constant 100.0 : f16 - %val0 = arith.divf %fp, %cst100 : f16 - %cst1 = arith.constant 1.0 : f16 - %val1 = arith.addf %val0, %cst1 : f16 - memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xf16> - memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xf16> + %4 = arith.index_cast %arg0 : index to i16 + %5 = arith.index_cast %arg1 : index to i16 + %6 = arith.muli %4, %c128_i16 : i16 + %7 = arith.addi %5, %6 : i16 + %8 = arith.uitofp %7 : i16 to f16 + %9 = arith.divf %8, %cst_0 : f16 + %10 = arith.addf %9, %cst : f16 + memref.store %9, %0[%arg0, %arg1] : memref<1024x1024xf16> + 
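// Worked example (not part of the patch) of the fill above: for %arg0 = 1 and %arg1 = 2
// the linear index is 1 * 128 + 2 = 130, so A[1, 2] = 130 / 100 = 1.30 and
// B[1, 2] = 1.30 + 1.00 = 2.30 (both up to f16 rounding), matching the comment
// "start from 0.0, increase 0.01 per element; B matrix = A matrix + 1.0".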
memref.store %10, %1[%arg0, %arg1] : memref<1024x1024xf16> } } // caculate the result C matrix scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1024 step %c1 { - %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32> - %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 { - %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> - %b = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> - %c = arith.mulf %a, %b : f16 - %cc = arith.extf %c : f16 to f32 - %ccc = arith.addf %cc, %arg3 : f32 - scf.yield %ccc : f32 + %4 = memref.load %2[%arg0, %arg1] : memref<1024x1024xf32> + %5 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %4) -> (f32) { + %6 = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> + %7 = memref.load %1[%arg2, %arg1] : memref<1024x1024xf16> + %8 = arith.mulf %6, %7 : f16 + %9 = arith.extf %8 : f16 to f32 + %10 = arith.addf %9, %arg3 : f32 + scf.yield %10 : f32 } - memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %5, %2[%arg0, %arg1] : memref<1024x1024xf32> } } - - %2 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> - %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32> //call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %3 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> + %cast = memref.cast %3 : memref<1024x1024xf32> to memref<*xf32> + %cast_1 = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeGPU/VC/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir index f04cadff9..619a98624 100644 --- a/test/Integration/Dialect/XeGPU/VC/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir @@ -1,109 +1,104 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.0> - memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.0> + memref.global "private" 
@__constant_1024x1024xf16 : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf16_ : memref<1024x1024xf16> = dense<0.000000e+00> + memref.global "private" @__constant_1024x1024xf32 : memref<1024x1024xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg0, %memref : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_0 = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %arg1, %memref_0 : memref<1024x1024xf16> to memref<1024x1024xf16> - %memref_1 = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c128, %c64, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) gpu.dealloc %memref : memref<1024x1024xf16> gpu.dealloc %memref_0 : memref<1024x1024xf16> - return %memref_1 : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c8 = arith.constant 8 : index %c1024 = arith.constant 1024 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // each work-group has 1 subgroup. 
the subgroup caculates a [8x16 = 8x1024 * 1024x16] block - %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> - %8 = xegpu.create_nd_tdesc %arg1[%3, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> - %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %10 = xegpu.load_nd %8 {transpose_bit_width = 32:i32, transpose = array} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %11 : vector<8x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.create_nd_tdesc %arg1[%1, %arg3] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %7 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %8 = xegpu.load_nd %6 <{transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = xegpu.dpas %7, %8, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %9 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 1.000000e+02 : f16 + %c128_i16 = arith.constant 128 : i16 + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1024x1024xf16 : memref<1024x1024xf16> %1 = memref.get_global @__constant_1024x1024xf16_ : memref<1024x1024xf16> - %ref = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> - %init = arith.constant 0.0 : f16 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index // fill the top-left block 128x128 // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 + %2 = memref.get_global @__constant_1024x1024xf32 : memref<1024x1024xf32> scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1024 step %c1 { - %int0 = arith.index_cast %arg0 : index to i16 - %int1 = arith.index_cast %arg1 : index to i16 - %c128_i16 = arith.constant 128 : i16 - %idx0 = arith.muli %int0, %c128_i16 : i16 - %idx1 = arith.addi %int1, %idx0 : i16 - %fp = arith.uitofp %idx1 : i16 to f16 - %cst100 = arith.constant 100.0 : f16 - %val0 = arith.divf %fp, %cst100 : f16 - %cst1 = arith.constant 1.0 : f16 - %val1 = arith.addf %val0, %cst1 : f16 - memref.store %val0, %0[%arg0, %arg1] : memref<1024x1024xf16> - memref.store %val1, %1[%arg0, %arg1] : memref<1024x1024xf16> + %4 = arith.index_cast %arg0 : index to i16 + %5 = arith.index_cast %arg1 : index to i16 + %6 = arith.muli %4, %c128_i16 : i16 + %7 = 
arith.addi %5, %6 : i16 + %8 = arith.uitofp %7 : i16 to f16 + %9 = arith.divf %8, %cst_0 : f16 + %10 = arith.addf %9, %cst : f16 + memref.store %9, %0[%arg0, %arg1] : memref<1024x1024xf16> + memref.store %10, %1[%arg0, %arg1] : memref<1024x1024xf16> } } // caculate the result C matrix scf.for %arg0 = %c0 to %c1024 step %c1 { scf.for %arg1 = %c0 to %c1024 step %c1 { - %acc = memref.load %ref[%arg0, %arg1] : memref<1024x1024xf32> - %res = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %acc) -> f32 { - %a = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> - %b = memref.load %1[%arg1, %arg2] : memref<1024x1024xf16> - %t1 = arith.extf %a : f16 to f32 - %t2 = arith.extf %b : f16 to f32 - %c = arith.mulf %t1, %t2 : f32 // %cc = arith.extf %c : f16 to f32 - %ccc = arith.addf %c, %arg3 : f32 - scf.yield %ccc : f32 + %4 = memref.load %2[%arg0, %arg1] : memref<1024x1024xf32> + %5 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %4) -> (f32) { + %6 = memref.load %0[%arg0, %arg2] : memref<1024x1024xf16> + %7 = memref.load %1[%arg1, %arg2] : memref<1024x1024xf16> + %8 = arith.extf %6 : f16 to f32 + %9 = arith.extf %7 : f16 to f32 + %10 = arith.mulf %8, %9 : f32 + %11 = arith.addf %10, %arg3 : f32 + scf.yield %11 : f32 } - memref.store %res, %ref[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %5, %2[%arg0, %arg1] : memref<1024x1024xf32> } } - - %2 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> - %cast = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %3 = call @test(%0, %1) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> memref<1024x1024xf32> + %cast = memref.cast %3 : memref<1024x1024xf32> to memref<*xf32> + %cast_1 = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir b/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir index b99b04142..76633cded 100644 --- a/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_1d_vector_shuffle.vc.mlir @@ -1,84 +1,65 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { 
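  // The kernel below loads the 8x32 input as two 8x16 tiles, flattens each tile to a
  // 128-element vector, and uses a 1-D vector.shuffle to interleave them 16 elements at a
  // time (shuffle indices 0-127 address the first operand, 128-255 the second). The
  // interleaved 256-element vector, reshaped back to 8x32, reproduces the original row
  // layout, so the host-side comparison against the untouched input is expected to
  // report [ALLCLOSE: TRUE].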
func.func @test(%arg0: memref<8x32xf16>) -> memref<8x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - memref.copy %arg0, %memref : memref<8x32xf16> to memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<8x32xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<8x32xf16>) + %memref = gpu.alloc () : memref<8x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<8x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<8x32xf16>) gpu.dealloc %memref : memref<8x32xf16> - return %memref_1 : memref<8x32xf16> + %alloc = memref.alloc() : memref<8x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<8x32xf16>, memref<8x32xf16> + gpu.dealloc %memref_0 : memref<8x32xf16> + return %alloc : memref<8x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref<8x32xf16>, %arg1: memref<8x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<8x32xf16>, %arg1: memref<8x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> %1 = xegpu.create_nd_tdesc %arg0[0, 16] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.load_nd %0 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} - : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %3 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} - : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %2 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> %4 = vector.shape_cast %2 : vector<8x16xf16> to vector<128xf16> %5 = vector.shape_cast %3 : vector<8x16xf16> to vector<128xf16> - %8 = vector.shuffle %4, %5 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, - 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, - 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, - 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, - 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, - 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, - 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, - 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, - 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 
126, 127, - 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] - : vector<128xf16>, vector<128xf16> - %11 = vector.shape_cast %8 : vector<256xf16> to vector<8x32xf16> - %6 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %11, %6 : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %6 = vector.shuffle %4, %5 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] : vector<128xf16>, vector<128xf16> + %7 = vector.shape_cast %6 : vector<256xf16> to vector<8x32xf16> + %8 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %7, %8 : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<8x32xf16> - %A_zeros = memref.cast %A : memref<8x32xf16> to memref<*xf16> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - call @fillResource1DRandomF16(%A_zeros, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - %B = call @test(%A) : (memref<8x32xf16>) -> memref<8x32xf16> - %A_cast = memref.cast %A : memref<8x32xf16> to memref<*xf16> // call @printMemrefF16(%A_cast): (memref<*xf16>) -> () // call @printMemrefF16(%B_cast): (memref<*xf16>) -> () - %B_copy = memref.alloc() : memref<8x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index %c32 = arith.constant 32 : index - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %v = memref.load %B[%i, %j] : memref<8x32xf16> - %v_f32 = arith.extf %v : f16 to f32 - memref.store %v_f32, %B_copy[%i, %j] : memref<8x32xf32> + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x32xf16> + %cast = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %false) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x32xf16>) -> memref<8x32xf16> + %cast_1 = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + %alloc_2 = memref.alloc() : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step 
%c1 { + %1 = memref.load %0[%arg0, %arg1] : memref<8x32xf16> + %2 = arith.extf %1 : f16 to f32 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<8x32xf32> } } - %B_cast = memref.cast %B_copy : memref<8x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> () - - memref.dealloc %A : memref<8x32xf16> - memref.dealloc %B_copy : memref<8x32xf32> - gpu.dealloc %B : memref<8x32xf16> + %cast_3 = memref.cast %alloc_2 : memref<8x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast_1, %cast_3) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x32xf16> + memref.dealloc %alloc_2 : memref<8x32xf32> + memref.dealloc %0 : memref<8x32xf16> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir b/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir index 834a41e7a..09a974de6 100644 --- a/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_2d_vector_shuffle.vc.mlir @@ -1,68 +1,64 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<8x32xf16>) -> memref<8x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - memref.copy %arg0, %memref : memref<8x32xf16> to memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<8x32xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<8x32xf16>) + %memref = gpu.alloc () : memref<8x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<8x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<8x32xf16>) gpu.dealloc %memref : memref<8x32xf16> - return %memref_1 : memref<8x32xf16> + %alloc = memref.alloc() : memref<8x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<8x32xf16>, memref<8x32xf16> + gpu.dealloc %memref_0 : memref<8x32xf16> + return %alloc : memref<8x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref<8x32xf16>, %arg1: memref<8x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: 
memref<8x32xf16>, %arg1: memref<8x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> %1 = xegpu.create_nd_tdesc %arg0[0, 16] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.load_nd %0 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} - : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %3 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} - : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %8 = vector.shuffle %2, %3 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] - : vector<8x16xf16>, vector<8x16xf16> - %9 = vector.shape_cast %8 : vector<16x16xf16> to vector<256xf16> - %11 = vector.shape_cast %9 : vector<256xf16> to vector<8x32xf16> - %6 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %11, %6 : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + %2 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %4 = vector.shuffle %2, %3 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x16xf16>, vector<8x16xf16> + %5 = vector.shape_cast %4 : vector<16x16xf16> to vector<256xf16> + %6 = vector.shape_cast %5 : vector<256xf16> to vector<8x32xf16> + %7 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %6, %7 : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<8x32xf16> - %A_zeros = memref.cast %A : memref<8x32xf16> to memref<*xf16> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - call @fillResource1DRandomF16(%A_zeros, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - %B = call @test(%A) : (memref<8x32xf16>) -> memref<8x32xf16> - %A_cast = memref.cast %A : memref<8x32xf16> to memref<*xf16> // call @printMemrefF16(%A_cast): (memref<*xf16>) -> () // call @printMemrefF16(%B_cast): (memref<*xf16>) -> () - %B_copy = memref.alloc() : memref<8x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index %c32 = arith.constant 32 : index - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %v = memref.load %B[%i, %j] : memref<8x32xf16> - %v_f32 = arith.extf %v : f16 to f32 - memref.store %v_f32, %B_copy[%i, %j] : memref<8x32xf32> + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x32xf16> + %cast = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %false) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x32xf16>) -> memref<8x32xf16> + %cast_1 = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + %alloc_2 = memref.alloc() : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = memref.load %0[%arg0, %arg1] : memref<8x32xf16> + %2 = arith.extf %1 : f16 to f32 + memref.store %2, 
%alloc_2[%arg0, %arg1] : memref<8x32xf32> } } - %B_cast = memref.cast %B_copy : memref<8x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> () - - memref.dealloc %A : memref<8x32xf16> - memref.dealloc %B_copy : memref<8x32xf32> - gpu.dealloc %B : memref<8x32xf16> + %cast_3 = memref.cast %alloc_2 : memref<8x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast_1, %cast_3) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x32xf16> + memref.dealloc %alloc_2 : memref<8x32xf32> + memref.dealloc %0 : memref<8x32xf16> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir b/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir index 796c65da8..5c9fc54d1 100644 --- a/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/large_stores_8x32xf16_w_constant_vector_shuffle.vc.mlir @@ -1,51 +1,47 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test() -> memref<8x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref_1 = gpu.alloc host_shared () : memref<8x32xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_1 : memref<8x32xf16>) - return %memref_1 : memref<8x32xf16> + %memref = gpu.alloc () : memref<8x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>) + %alloc = memref.alloc() : memref<8x32xf16> + gpu.memcpy %alloc, %memref : memref<8x32xf16>, memref<8x32xf16> + gpu.dealloc %memref : memref<8x32xf16> + return %alloc : memref<8x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg1: memref<8x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %t0 = arith.constant dense<1.0> : vector<192xf16> - %t1 = arith.constant dense<2.0> : vector<64xf16> - %t3 = vector.shape_cast %t0 : vector<192xf16> to vector<12x16xf16> - %t4 = vector.shape_cast %t1 : vector<64xf16> to vector<4x16xf16> - + gpu.module @test_kernel { // do a vector shuffle with two constant vectors that have different number of elements. 
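    // For 2-D operands, vector.shuffle selects along the leading dimension: indices 0-11
    // pick rows of the 12x16 splat of 1.0 and indices 12-15 pick rows of the 4x16 splat of
    // 2.0. The index list [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 9, 13, 10, 14, 11, 15] therefore
    // produces a 16x16 vector whose first nine rows are 1.0 and whose last seven rows
    // alternate 2.0 and 1.0; reshaped to 8x32, this gives four all-ones rows followed by
    // four rows of sixteen 1.0s then sixteen 2.0s, matching the CHECK lines below.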
- %8 = vector.shuffle %t3, %t4 [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 9, 13, 10, 14, 11, 15] - : vector<12x16xf16>, vector<4x16xf16> - %9 = vector.shape_cast %8: vector<16x16xf16> to vector<256xf16> - %11 = vector.shape_cast %9 : vector<256xf16> to vector<8x32xf16> - %6 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16> - xegpu.store_nd %11, %6 : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> + gpu.func @test_kernel(%arg0: memref<8x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<1.000000e+00> : vector<192xf16> + %cst_0 = arith.constant dense<2.000000e+00> : vector<64xf16> + %0 = vector.shape_cast %cst : vector<192xf16> to vector<12x16xf16> + %1 = vector.shape_cast %cst_0 : vector<64xf16> to vector<4x16xf16> + %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 9, 13, 10, 14, 11, 15] : vector<12x16xf16>, vector<4x16xf16> + %3 = vector.shape_cast %2 : vector<16x16xf16> to vector<256xf16> + %4 = vector.shape_cast %3 : vector<256xf16> to vector<8x32xf16> + %5 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16> + xegpu.store_nd %4, %5 : vector<8x32xf16>, !xegpu.tensor_desc<8x32xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %B = call @test() : () -> memref<8x32xf16> - %B_copy = memref.alloc() : memref<8x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index %c32 = arith.constant 32 : index - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %v = memref.load %B[%i, %j] : memref<8x32xf16> - %v_f32 = arith.extf %v : f16 to f32 - memref.store %v_f32, %B_copy[%i, %j] : memref<8x32xf32> + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = call @test() : () -> memref<8x32xf16> + %alloc = memref.alloc() : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = memref.load %0[%arg0, %arg1] : memref<8x32xf16> + %2 = arith.extf %1 : f16 to f32 + memref.store %2, %alloc[%arg0, %arg1] : memref<8x32xf32> } } - %B_cast = memref.cast %B_copy : memref<8x32xf32> to memref<*xf32> - call @printMemrefF32(%B_cast): (memref<*xf32>) -> () // CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], // CHECK-NEXT: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], // CHECK-NEXT: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], @@ -54,9 +50,10 @@ module @gemm attributes {gpu.container_module} { // CHECK-NEXT: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], // CHECK-NEXT: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], // CHECK-NEXT: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] - - memref.dealloc %B_copy : memref<8x32xf32> - gpu.dealloc %B : memref<8x32xf16> + %cast = memref.cast %alloc : memref<8x32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x32xf32> + gpu.dealloc %0 : memref<8x32xf16> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load-transpose-f16.mlir 
b/test/Integration/Dialect/XeGPU/VC/load-transpose-f16.mlir index 5c564e28a..4335e1d26 100644 --- a/test/Integration/Dialect/XeGPU/VC/load-transpose-f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load-transpose-f16.mlir @@ -1,28 +1,23 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#slm = #xegpu.scatter_tdesc_attr -#blk_slm = #xegpu.block_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16x32xf16>) -> memref<8x64xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %memref = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %arg0, %memref : memref<16x32xf16> to memref<16x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<8x64xf16> - gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_1 : memref<8x64xf16>) - + %memref = gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<16x32xf16>, memref<16x32xf16> + %memref_0 = gpu.alloc () : memref<8x64xf16> + gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_0 : memref<8x64xf16>) gpu.dealloc %memref : memref<16x32xf16> - return %memref_1 : memref<8x64xf16> + %alloc = memref.alloc() : memref<8x64xf16> + gpu.memcpy %alloc, %memref_0 : memref<8x64xf16>, memref<8x64xf16> + gpu.dealloc %memref_0 : memref<8x64xf16> + return %alloc : memref<8x64xf16> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { // this example is to illustrate an example of using slm to do the transpose. 
// the high level logic is equivalent to the following code: // %data = xegpu.load_nd %in : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> @@ -32,16 +27,14 @@ module @gemm attributes {gpu.container_module} { gpu.func @test_transpose(%arg0: memref<16x32xf16>, %arg1: memref<8x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %id = gpu.subgroup_id : index - %y = arith.muli %id, %c8 : index - %in = xegpu.create_nd_tdesc %arg0[0, %y] : memref<16x32xf16> -> !xegpu.tensor_desc<16x8xf16> - %data = xegpu.load_nd %in : !xegpu.tensor_desc<16x8xf16> -> vector<16x8xf16> - %transposed = vector.transpose %data, [1, 0] : vector<16x8xf16> to vector<8x16xf16> - - %y2 = arith.muli %id, %c16 : index - %out = xegpu.create_nd_tdesc %arg1[0, %y2]: memref<8x64xf16> -> !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %transposed, %out : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> - + %0 = gpu.subgroup_id : index + %1 = arith.muli %0, %c8 : index + %2 = xegpu.create_nd_tdesc %arg0[0, %1] : memref<16x32xf16> -> !xegpu.tensor_desc<16x8xf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16x8xf16> -> vector<16x8xf16> + %4 = vector.transpose %3, [1, 0] : vector<16x8xf16> to vector<8x16xf16> + %5 = arith.muli %0, %c16 : index + %6 = xegpu.create_nd_tdesc %arg1[0, %5] : memref<8x64xf16> -> !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %4, %6 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> gpu.return } } @@ -50,22 +43,16 @@ module @gemm attributes {gpu.container_module} { %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %0 = memref.alloc() : memref<16x32xf16> - - scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %mul = arith.muli %i, %c32 : index - %add = arith.addi %mul, %j : index - %int = arith.index_cast %add : index to i16 - %fp = arith.uitofp %int : i16 to f16 - memref.store %fp, %0[%i, %j] : memref<16x32xf16> + %alloc = memref.alloc() : memref<16x32xf16> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = arith.index_cast %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store %4, %alloc[%arg0, %arg1] : memref<16x32xf16> } } - - %2 = call @test(%0) : (memref<16x32xf16>) -> memref<8x64xf16> - %cast = memref.cast %2: memref<8x64xf16> to memref<*xf16> - //CHECK: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480, 8, 40, 72, 104, 136, 168, 200, 232, 264, 296, 328, 360, 392, 424, 456, 488, 16, 48, 80, 112, 144, 176, 208, 240, 272, 304, 336, 368, 400, 432, 464, 496, 24, 56, 88, 120, 152, 184, 216, 248, 280, 312, 344, 376, 408, 440, 472, 504] //CHECK: [1, 33, 65, 97, 129, 161, 193, 225, 257, 289, 321, 353, 385, 417, 449, 481, 9, 41, 73, 105, 137, 169, 201, 233, 265, 297, 329, 361, 393, 425, 457, 489, 17, 49, 81, 113, 145, 177, 209, 241, 273, 305, 337, 369, 401, 433, 465, 497, 25, 57, 89, 121, 153, 185, 217, 249, 281, 313, 345, 377, 409, 441, 473, 505] //CHECK: [2, 34, 66, 98, 130, 162, 194, 226, 258, 290, 322, 354, 386, 418, 450, 482, 10, 42, 74, 106, 138, 170, 202, 234, 266, 298, 330, 362, 394, 426, 458, 490, 18, 50, 82, 114, 146, 178, 210, 242, 274, 306, 338, 370, 402, 434, 466, 498, 26, 58, 90, 122, 154, 186, 218, 250, 282, 314, 346, 378, 410, 442, 474, 506] @@ -74,10 +61,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [5, 37, 69, 101, 133, 165, 197, 229, 
261, 293, 325, 357, 389, 421, 453, 485, 13, 45, 77, 109, 141, 173, 205, 237, 269, 301, 333, 365, 397, 429, 461, 493, 21, 53, 85, 117, 149, 181, 213, 245, 277, 309, 341, 373, 405, 437, 469, 501, 29, 61, 93, 125, 157, 189, 221, 253, 285, 317, 349, 381, 413, 445, 477, 509] //CHECK: [6, 38, 70, 102, 134, 166, 198, 230, 262, 294, 326, 358, 390, 422, 454, 486, 14, 46, 78, 110, 142, 174, 206, 238, 270, 302, 334, 366, 398, 430, 462, 494, 22, 54, 86, 118, 150, 182, 214, 246, 278, 310, 342, 374, 406, 438, 470, 502, 30, 62, 94, 126, 158, 190, 222, 254, 286, 318, 350, 382, 414, 446, 478, 510] //CHECK: [7, 39, 71, 103, 135, 167, 199, 231, 263, 295, 327, 359, 391, 423, 455, 487, 15, 47, 79, 111, 143, 175, 207, 239, 271, 303, 335, 367, 399, 431, 463, 495, 23, 55, 87, 119, 151, 183, 215, 247, 279, 311, 343, 375, 407, 439, 471, 503, 31, 63, 95, 127, 159, 191, 223, 255, 287, 319, 351, 383, 415, 447, 479, 511] - call @printMemrefF16(%cast): (memref<*xf16>) -> () - + %0 = call @test(%alloc) : (memref<16x32xf16>) -> memref<8x64xf16> + %cast = memref.cast %0 : memref<8x64xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load1d-f32.mlir b/test/Integration/Dialect/XeGPU/VC/load1d-f32.mlir index 1bc339f3a..ef0cc3808 100644 --- a/test/Integration/Dialect/XeGPU/VC/load1d-f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load1d-f32.mlir @@ -1,45 +1,39 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0]> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01, 1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01, 1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01, 2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01, 2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01, 2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> func.func @test(%arg0: memref<32xf32>) -> memref<32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32xf32> - memref.copy %arg0, %memref : memref<32xf32> to memref<32xf32> - %memref_1 = gpu.alloc host_shared () : memref<32xf32> - 
gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32xf32>, %memref_1 : memref<32xf32>) + %memref = gpu.alloc () : memref<32xf32> + gpu.memcpy %memref, %arg0 : memref<32xf32>, memref<32xf32> + %memref_0 = gpu.alloc () : memref<32xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32xf32>, %memref_0 : memref<32xf32>) gpu.dealloc %memref : memref<32xf32> - return %memref_1 : memref<32xf32> + %alloc = memref.alloc() : memref<32xf32> + gpu.memcpy %alloc, %memref_0 : memref<32xf32>, memref<32xf32> + gpu.dealloc %memref_0 : memref<32xf32> + return %alloc : memref<32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<32xf32>, %arg1: memref<32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c4 = arith.constant 4: index + %c4 = arith.constant 4 : index %0 = xegpu.create_nd_tdesc %arg0[%c4] : memref<32xf32> -> !xegpu.tensor_desc<16xf32> - %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> - + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> %2 = xegpu.create_nd_tdesc %arg1[%c4] : memref<32xf32> -> !xegpu.tensor_desc<16xf32> - xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32> + xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_32xf32 : memref<32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0) : (memref<32xf32>) -> memref<32xf32> - %cast = memref.cast %2: memref<32xf32> to memref<*xf32> //CHECK: [0, 0, 0, 0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF32(%cast): (memref<*xf32>) -> () + %1 = call @test(%0) : (memref<32xf32>) -> memref<32xf32> + %cast = memref.cast %1 : memref<32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load1d-slm-f32.mlir b/test/Integration/Dialect/XeGPU/VC/load1d-slm-f32.mlir index 6fedf9a5c..a091d048a 100644 --- a/test/Integration/Dialect/XeGPU/VC/load1d-slm-f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load1d-slm-f32.mlir @@ -1,56 +1,43 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#slm = #xegpu.block_tdesc_attr module @gemm attributes {gpu.container_module} { - 
memref.global "private" constant @__constant_8x16xf32 : memref<32xf32> = dense<[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0]> - + memref.global "private" constant @__constant_8x16xf32 : memref<32xf32> = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01, 1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01, 1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01, 2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01, 2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01, 2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> func.func @test(%arg0: memref<32xf32>) -> memref<32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32xf32> - memref.copy %arg0, %memref : memref<32xf32> to memref<32xf32> - %memref_1 = gpu.alloc host_shared () : memref<32xf32> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32xf32>, %memref_1 : memref<32xf32>) - + %memref = gpu.alloc () : memref<32xf32> + gpu.memcpy %memref, %arg0 : memref<32xf32>, memref<32xf32> + %memref_0 = gpu.alloc () : memref<32xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32xf32>, %memref_0 : memref<32xf32>) gpu.dealloc %memref : memref<32xf32> - return %memref_1 : memref<32xf32> + %alloc = memref.alloc() : memref<32xf32> + gpu.memcpy %alloc, %memref_0 : memref<32xf32>, memref<32xf32> + gpu.dealloc %memref_0 : memref<32xf32> + return %alloc : memref<32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<32xf32>, %arg1: memref<32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c4 = arith.constant 4: index + %c4 = arith.constant 4 : index %0 = xegpu.create_nd_tdesc %arg0[%c4] : memref<32xf32> -> !xegpu.tensor_desc<16xf32> - %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> - - %slm = memref.alloc() : memref<16xf32, 3> - %2 = xegpu.create_nd_tdesc %slm[0] : memref<16xf32, 3> -> !xegpu.tensor_desc<16xf32, #slm> - xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #slm> - - %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xf32, #slm> -> vector<16xf32> - + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> + %alloc = memref.alloc() : memref<16xf32, 3> + %2 = xegpu.create_nd_tdesc %alloc[0] : memref<16xf32, 3> -> !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr> + xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xf32, #xegpu.block_tdesc_attr> -> vector<16xf32> %4 = xegpu.create_nd_tdesc %arg1[3] : memref<32xf32> -> !xegpu.tensor_desc<16xf32> - xegpu.store_nd %3, %4 : vector<16xf32>, !xegpu.tensor_desc<16xf32> + xegpu.store_nd %3, %4 : vector<16xf32>, !xegpu.tensor_desc<16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_8x16xf32 : memref<32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0) : 
(memref<32xf32>) -> memref<32xf32> - - %cast = memref.cast %2: memref<32xf32> to memref<*xf32> - //CHECK: [0, 0, 0, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF32(%cast): (memref<*xf32>) -> () + %1 = call @test(%0) : (memref<32xf32>) -> memref<32xf32> + %cast = memref.cast %1 : memref<32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-padding-f32.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-padding-f32.mlir index 87c39cb0f..c04e5da2f 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-padding-f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-padding-f32.mlir @@ -1,49 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.0> - func.func @test(%arg0: memref<8x16xf32>,%arg1:index)attributes {llvm.emit_c_interface} { + memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.000000e+00> + func.func @test(%arg0: memref<8x16xf32>, %arg1: index) attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref_0 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %arg0, %memref_0 : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @test_kernel::@test_padding_f32 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %arg1:index) - %cast1 = memref.cast %memref_1 : memref<8x16xf32> to memref<*xf32> - call @printMemrefF32(%cast1) : (memref<*xf32>) -> () + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_padding_f32 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>, %arg1 : index) + %result = memref.alloc() : memref<8x16xf32> + gpu.memcpy %result, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + gpu.dealloc %memref : memref<8x16xf32> + // Print the result + %cast = memref.cast %result : memref<8x16xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () + memref.dealloc %result : memref<8x16xf32> return } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - - gpu.func @test_padding_f32(%arg0: memref<8x16xf32>, %arg1: 
memref<8x16xf32>, %arg3:index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %0 = xegpu.create_nd_tdesc %arg0[%arg3, %arg3] - : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %1 = xegpu.create_nd_tdesc %arg1[0, 0] - : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - xegpu.store_nd %3,%1 : vector<8x16xf32>,!xegpu.tensor_desc<8x16xf32> + gpu.module @test_kernel { + gpu.func @test_padding_f32(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg2: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[%arg2, %arg2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + xegpu.store_nd %2, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } - } func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.get_global @__constant_8x16xf32 : memref<8x16xf32> - %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - call @test(%0, %c1) : (memref<8x16xf32>, index)-> () - call @test(%0, %c2) : (memref<8x16xf32>, index)-> () + %c1 = arith.constant 1 : index + %0 = memref.get_global @__constant_8x16xf32 : memref<8x16xf32> + call @test(%0, %c1) : (memref<8x16xf32>, index) -> () + call @test(%0, %c2) : (memref<8x16xf32>, index) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } - // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} // CHECK-SAME: rank = 2 offset = 0 sizes = [8, 16] strides = [16, 1] data = // CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-padding.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-padding.mlir index 54ee099f0..0e1f58b22 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-padding.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-padding.mlir @@ -1,52 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { // memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.0> - memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.0> - - func.func @test(%arg0: memref<8x16xf32>,%arg3:index) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<1.000000e+00> + func.func @test(%arg0: memref<8x16xf32>, %arg1: index) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index 
- %memref = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %arg0, %memref : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @test_kernel::@test_padding blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %arg3:index) - + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_padding blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>, %arg1 : index) gpu.dealloc %memref : memref<8x16xf32> - return %memref_1 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_padding(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>,%arg3:index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %0 = xegpu.create_nd_tdesc %arg0[%arg3, %arg3] - : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.create_nd_tdesc %arg1[0, 0] - : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - xegpu.store_nd %3,%2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.module @test_kernel { + gpu.func @test_padding(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg2: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[%arg2, %arg2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + xegpu.store_nd %2, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.get_global @__constant_8x16xf32 : memref<8x16xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0, %c1) : (memref<8x16xf32>, index) -> memref<8x16xf32> - %3 = call @test(%0, %c2) : (memref<8x16xf32>, index) -> memref<8x16xf32> - %c7 = arith.constant 7 : index - %vector_0 = vector.load %2[%c7,%c0] :memref<8x16xf32>, vector<16xf32> + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c7 = arith.constant 7 : index + %0 = memref.get_global @__constant_8x16xf32 : memref<8x16xf32> // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) - vector.print %vector_0 : vector<16xf32> - - %vector_1 = vector.load %3[%c0,%c0] :memref<8x16xf32>, vector<16xf32> // CHECK: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0 ) - vector.print %vector_1 : vector<16xf32> + %1 = call @test(%0, %c1) : (memref<8x16xf32>, index) -> memref<8x16xf32> + %2 = call @test(%0, %c2) : (memref<8x16xf32>, index) -> memref<8x16xf32> + %3 = vector.load %1[%c7, %c0] : memref<8x16xf32>, vector<16xf32> + vector.print %3 : vector<16xf32> + %4 = vector.load %2[%c0, %c0] : memref<8x16xf32>, vector<16xf32> + vector.print %4 : vector<16xf32> return } } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-bf16.mlir 
b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-bf16.mlir index 0b1ba7ff7..d87dcbad4 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-bf16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-bf16.mlir @@ -1,49 +1,41 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { // memref.global "private" constant @__constant_8x16xbf16 : memref<16x32xbf16> = dense<1.0> - memref.global "private" constant @__constant_8x16xbf16 : memref<16x32xbf16> = dense<1.0> - + memref.global "private" constant @__constant_8x16xbf16 : memref<16x32xbf16> = dense<1.000000e+00> func.func @test(%arg0: memref<16x32xbf16>) -> memref<16x32xbf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<16x32xbf16> - memref.copy %arg0, %memref : memref<16x32xbf16> to memref<16x32xbf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xbf16> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xbf16>, %memref_1 : memref<16x32xbf16>) - + %memref = gpu.alloc () : memref<16x32xbf16> + gpu.memcpy %memref, %arg0 : memref<16x32xbf16>, memref<16x32xbf16> + %memref_0 = gpu.alloc () : memref<16x32xbf16> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xbf16>, %memref_0 : memref<16x32xbf16>) gpu.dealloc %memref : memref<16x32xbf16> - return %memref_1 : memref<16x32xbf16> + %alloc = memref.alloc() : memref<16x32xbf16> + gpu.memcpy %alloc, %memref_0 : memref<16x32xbf16>, memref<16x32xbf16> + gpu.dealloc %memref_0 : memref<16x32xbf16> + return %alloc : memref<16x32xbf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> - %2 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> - %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16> - xegpu.store_nd %3,%2 : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16> + %1 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xbf16> -> !xegpu.tensor_desc<8x16xbf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16> + xegpu.store_nd %2, %1 : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_8x16xbf16 : memref<16x32xbf16> 
- %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0) : (memref<16x32xbf16>) -> memref<16x32xbf16> - - %cast = memref.cast %2: memref<16x32xbf16> to memref<*xbf16> - //CHECK-COUNT-2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK-COUNT-8: [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK-COUNT-6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefBF16(%cast): (memref<*xbf16>) -> () + %1 = call @test(%0) : (memref<16x32xbf16>) -> memref<16x32xbf16> + %cast = memref.cast %1 : memref<16x32xbf16> to memref<*xbf16> + call @printMemrefBF16(%cast) : (memref<*xbf16>) -> () return } - func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-transpose.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-transpose.mlir index c26cb3d46..94b898cdb 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-transpose.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-transpose.mlir @@ -1,54 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<8x32xf16>) -> memref<16x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - memref.copy %arg0, %memref : memref<8x32xf16> to memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf16> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<16x32xf16>) - + %memref = gpu.alloc () : memref<8x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<16x32xf16> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<16x32xf16>) gpu.dealloc %memref : memref<8x32xf16> - return %memref_1 : memref<16x32xf16> + %alloc = memref.alloc() : memref<16x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<16x32xf16>, memref<16x32xf16> + gpu.dealloc %memref_0 : memref<16x32xf16> + return %alloc : memref<16x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<8x32xf16>, %arg1: memref<16x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = 
xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %3 = xegpu.load_nd %0 {transpose = array, transpose_bit_width = 32:i32} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> - %4 = vector.shape_cast %3 : vector<8x8x2xf16> to vector<8x16xf16> - xegpu.store_nd %4, %2 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %2 = xegpu.load_nd %0 <{transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %3 = vector.shape_cast %2 : vector<8x8x2xf16> to vector<8x16xf16> + xegpu.store_nd %3, %1 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<8x32xf16> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %m = arith.muli %i, %c16 : index - %a = arith.addi %m, %j : index - %t = index.castu %a : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %0[%i, %j] : memref<8x32xf16> + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<8x32xf16> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = arith.muli %arg0, %c16 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store %4, %alloc[%arg0, %arg1] : memref<8x32xf16> } } - - - %2 = call @test(%0) : (memref<8x32xf16>) -> memref<16x32xf16> - - %cast = memref.cast %2: memref<16x32xf16> to memref<*xf16> - //CHECK-COUNT-2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK: [0, 0, 0, 1, 16, 17, 32, 33, 48, 49, 64, 65, 80, 81, 96, 97, 112, 113, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], //CHECK: [0, 0, 2, 3, 18, 19, 34, 35, 50, 51, 66, 67, 82, 83, 98, 99, 114, 115, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -59,9 +51,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [0, 0, 12, 13, 28, 29, 44, 45, 60, 61, 76, 77, 92, 93, 108, 109, 124, 125, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], //CHECK: [0, 0, 14, 15, 30, 31, 46, 47, 62, 63, 78, 79, 94, 95, 110, 111, 126, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], //CHECK-COUNT-6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF16(%cast): (memref<*xf16>) -> () + %0 = call @test(%alloc) : (memref<8x32xf16>) -> memref<16x32xf16> + %cast = memref.cast %0 : memref<16x32xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-vnni.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-vnni.mlir index 6bafcb353..6165bfaca 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-vnni.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16-vnni.mlir @@ -1,62 +1,55 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s 
--pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<8x32xf16>) -> memref<16x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - memref.copy %arg0, %memref : memref<8x32xf16> to memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf16> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<16x32xf16>) - + %memref = gpu.alloc () : memref<8x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<16x32xf16> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<16x32xf16>) gpu.dealloc %memref : memref<8x32xf16> - return %memref_1 : memref<16x32xf16> + %alloc = memref.alloc() : memref<16x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<16x32xf16>, memref<16x32xf16> + gpu.dealloc %memref_0 : memref<16x32xf16> + return %alloc : memref<16x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<8x32xf16>, %arg1: memref<16x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.load_nd %0 {packed} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + %1 = xegpu.load_nd %0 <{packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> %2 = vector.shape_cast %1 : vector<4x16x2xf16> to vector<4x32xf16> %3 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<4x32xf16> - xegpu.store_nd %2, %3 : vector<4x32xf16>, !xegpu.tensor_desc<4x32xf16> + xegpu.store_nd %2, %3 : vector<4x32xf16>, !xegpu.tensor_desc<4x32xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<8x32xf16> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %m = arith.muli %i, %c16 : index - %a = arith.addi %m, %j : index - %t = index.castu %a : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %0[%i, %j] : memref<8x32xf16> + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<8x32xf16> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = arith.muli %arg0, %c16 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store 
%4, %alloc[%arg0, %arg1] : memref<8x32xf16> } } - - - %2 = call @test(%0) : (memref<8x32xf16>) -> memref<16x32xf16> - - %cast = memref.cast %2: memref<16x32xf16> to memref<*xf16> - //CHECK: [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31] //CHECK: [32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63] //CHECK: [64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95] //CHECK: [96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127] //CHECK-COUNT-12: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF16(%cast): (memref<*xf16>) -> () + %0 = call @test(%alloc) : (memref<8x32xf16>) -> memref<16x32xf16> + %cast = memref.cast %0 : memref<16x32xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16.mlir index cceb0f0fe..ae2796b67 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f16.mlir @@ -1,49 +1,41 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { // memref.global "private" constant @__constant_8x16xf16 : memref<16x32xf16> = dense<1.0> - memref.global "private" constant @__constant_8x16xf16 : memref<16x32xf16> = dense<1.0> - + memref.global "private" constant @__constant_8x16xf16 : memref<16x32xf16> = dense<1.000000e+00> func.func @test(%arg0: memref<16x32xf16>) -> memref<16x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %arg0, %memref : memref<16x32xf16> to memref<16x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf16> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_1 : memref<16x32xf16>) - + %memref = gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<16x32xf16>, memref<16x32xf16> + %memref_0 = gpu.alloc () : memref<16x32xf16> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_0 : memref<16x32xf16>) gpu.dealloc %memref : memref<16x32xf16> - return %memref_1 : memref<16x32xf16> + %alloc = 
memref.alloc() : memref<16x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<16x32xf16>, memref<16x32xf16> + gpu.dealloc %memref_0 : memref<16x32xf16> + return %alloc : memref<16x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<16x32xf16>, %arg1: memref<16x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %2 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - xegpu.store_nd %3,%2 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + xegpu.store_nd %2, %1 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_8x16xf16 : memref<16x32xf16> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0) : (memref<16x32xf16>) -> memref<16x32xf16> - - %cast = memref.cast %2: memref<16x32xf16> to memref<*xf16> - //CHECK-COUNT-2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK-COUNT-8: [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK-COUNT-6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF16(%cast): (memref<*xf16>) -> () + %1 = call @test(%0) : (memref<16x32xf16>) -> memref<16x32xf16> + %cast = memref.cast %1 : memref<16x32xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32-transpose.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32-transpose.mlir index bd3a77bb5..94134c5e1 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32-transpose.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32-transpose.mlir @@ -1,55 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16x16xf32>) -> memref<16x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<16x16xf32> - memref.copy %arg0, 
%memref : memref<16x16xf32> to memref<16x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf32> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x16xf32>, %memref_1 : memref<16x32xf32>) - + %memref = gpu.alloc () : memref<16x16xf32> + gpu.memcpy %memref, %arg0 : memref<16x16xf32>, memref<16x16xf32> + %memref_0 = gpu.alloc () : memref<16x32xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x16xf32>, %memref_0 : memref<16x32xf32>) gpu.dealloc %memref : memref<16x16xf32> - return %memref_1 : memref<16x32xf32> + %alloc = memref.alloc() : memref<16x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<16x32xf32>, memref<16x32xf32> + gpu.dealloc %memref_0 : memref<16x32xf32> + return %alloc : memref<16x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<16x16xf32>, %arg1: memref<16x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf32> -> !xegpu.tensor_desc<16x8xf32> - %1 = xegpu.load_nd %0 {transpose = array} : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32> + %1 = xegpu.load_nd %0 <{transpose = array}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32> %2 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %1, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + xegpu.store_nd %1, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<16x16xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index // input matrix is [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], ...] 
- scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c8 step %c1 { - %m = arith.muli %i, %c8 : index - %a = arith.addi %m, %j : index - %t = index.castu %a : index to i32 - %val = arith.uitofp %t : i32 to f32 - memref.store %val, %0[%i, %j] : memref<16x16xf32> + %alloc = memref.alloc() : memref<16x16xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c8 step %c1 { + %1 = arith.muli %arg0, %c8 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i32 + %4 = arith.uitofp %3 : i32 to f32 + memref.store %4, %alloc[%arg0, %arg1] : memref<16x16xf32> } } - - - %2 = call @test(%0) : (memref<16x16xf32>) -> memref<16x32xf32> - - %cast = memref.cast %2: memref<16x32xf32> to memref<*xf32> - //CHECK-COUNT-2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK: [0, 0, 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], //CHECK: [0, 0, 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 105, 113, 121, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -60,9 +51,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [0, 0, 6, 14, 22, 30, 38, 46, 54, 62, 70, 78, 86, 94, 102, 110, 118, 126, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], //CHECK: [0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], //CHECK-COUNT-6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF32(%cast): (memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<16x16xf32>) -> memref<16x32xf32> + %cast = memref.cast %0 : memref<16x32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32.mlir b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32.mlir index bcd5ee6be..077614637 100644 --- a/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load2d-ugm-f32.mlir @@ -1,49 +1,41 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { // memref.global "private" constant @__constant_8x16xf32 : memref<16x32xf32> = dense<1.0> - memref.global "private" constant @__constant_8x16xf32 : memref<16x32xf32> = dense<1.0> - + memref.global "private" constant @__constant_8x16xf32 : memref<16x32xf32> = dense<1.000000e+00> func.func @test(%arg0: memref<16x32xf32>) -> memref<16x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<16x32xf32> - memref.copy %arg0, %memref : 
memref<16x32xf32> to memref<16x32xf32> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf32> - gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xf32>, %memref_1 : memref<16x32xf32>) - + %memref = gpu.alloc () : memref<16x32xf32> + gpu.memcpy %memref, %arg0 : memref<16x32xf32>, memref<16x32xf32> + %memref_0 = gpu.alloc () : memref<16x32xf32> + gpu.launch_func @test_kernel::@test_copy blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xf32>, %memref_0 : memref<16x32xf32>) gpu.dealloc %memref : memref<16x32xf32> - return %memref_1 : memref<16x32xf32> + %alloc = memref.alloc() : memref<16x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<16x32xf32>, memref<16x32xf32> + gpu.dealloc %memref_0 : memref<16x32xf32> + return %alloc : memref<16x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_copy(%arg0: memref<16x32xf32>, %arg1: memref<16x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %3 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - xegpu.store_nd %3,%2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.create_nd_tdesc %arg1[2, 2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + xegpu.store_nd %2, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_8x16xf32 : memref<16x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0) : (memref<16x32xf32>) -> memref<16x32xf32> - - %cast = memref.cast %2: memref<16x32xf32> to memref<*xf32> - //CHECK-COUNT-2: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK-COUNT-8: [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] //CHECK-COUNT-6: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - call @printMemrefF32(%cast): (memref<*xf32>) -> () + %1 = call @test(%0) : (memref<16x32xf32>) -> memref<16x32xf32> + %cast = memref.cast %1 : memref<16x32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/load_store_1x16xf16.mlir b/test/Integration/Dialect/XeGPU/VC/load_store_1x16xf16.mlir index e6bbcdc40..213cbc57a 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_store_1x16xf16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_store_1x16xf16.mlir @@ -1,57 +1,50 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: 
%python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<1x32xf16>) -> memref<1x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1x32xf16> - memref.copy %arg0, %memref : memref<1x32xf16> to memref<1x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<1x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x32xf16>, %memref_1 : memref<1x32xf32>) + %memref = gpu.alloc () : memref<1x32xf16> + gpu.memcpy %memref, %arg0 : memref<1x32xf16>, memref<1x32xf16> + %memref_0 = gpu.alloc () : memref<1x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x32xf16>, %memref_0 : memref<1x32xf32>) gpu.dealloc %memref : memref<1x32xf16> - return %memref_1 : memref<1x32xf32> + %alloc = memref.alloc() : memref<1x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<1x32xf32>, memref<1x32xf32> + gpu.dealloc %memref_0 : memref<1x32xf32> + return %alloc : memref<1x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<1x32xf16>, %arg1: memref<1x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %src_tdesc_0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x16xf16> - %src_tdesc_1 = xegpu.create_nd_tdesc %arg0[0, 16] : memref<1x32xf16> -> !xegpu.tensor_desc<1x16xf16> - - %src_loaded_0 = xegpu.load_nd %src_tdesc_0 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16> - %src_loaded_1 = xegpu.load_nd %src_tdesc_1 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16> - - %src_loaded_0_f32 = arith.extf %src_loaded_0: vector<1x16xf16> to vector<1x16xf32> - %src_loaded_1_f32 = arith.extf %src_loaded_1: vector<1x16xf16> to vector<1x16xf32> - - %dest_tdesc_0 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<1x32xf32> -> !xegpu.tensor_desc<1x16xf32> - %dest_tdesc_1 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<1x32xf32> -> !xegpu.tensor_desc<1x16xf32> - - xegpu.store_nd %src_loaded_0_f32, %dest_tdesc_0 : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32> - xegpu.store_nd %src_loaded_1_f32, %dest_tdesc_1 : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32> - + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x16xf16> + %1 = xegpu.create_nd_tdesc %arg0[0, 16] : memref<1x32xf16> -> !xegpu.tensor_desc<1x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16> + %4 = arith.extf %2 : vector<1x16xf16> to vector<1x16xf32> + %5 = arith.extf %3 : vector<1x16xf16> to vector<1x16xf32> + %6 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<1x32xf32> -> !xegpu.tensor_desc<1x16xf32> + %7 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<1x32xf32> -> !xegpu.tensor_desc<1x16xf32> + xegpu.store_nd %4, %6 : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32> + xegpu.store_nd %5, 
%7 : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<1x32xf16> // 1x32 to ensure surface pitch >= 64 - %A_random = memref.cast %A : memref<1x32xf16> to memref<*xf16> - %c_gen_int = arith.constant 1 : i1 - %cf_lower = arith.constant -2.0 : f32 - %cf_upper = arith.constant 2.0 : f32 - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<1x32xf16>) -> memref<1x32xf32> - %A_cast = memref.cast %A : memref<1x32xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<1x32xf32> to memref<*xf32> // call @printMemrefF16(%A_cast) : (memref<*xf16>) -> () // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> () + %cst = arith.constant 2.000000e+00 : f32 + %cst_0 = arith.constant -2.000000e+00 : f32 + %true = arith.constant true + %alloc = memref.alloc() : memref<1x32xf16> + %cast = memref.cast %alloc : memref<1x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %true) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<1x32xf16>) -> memref<1x32xf32> + %cast_1 = memref.cast %alloc : memref<1x32xf16> to memref<*xf16> + %cast_2 = memref.cast %0 : memref<1x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast_1, %cast_2) : (memref<*xf16>, memref<*xf32>) -> () return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load_store_non_pow2.mlir b/test/Integration/Dialect/XeGPU/VC/load_store_non_pow2.mlir index aef0152f3..64d0fc4b9 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_store_non_pow2.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_store_non_pow2.mlir @@ -1,64 +1,54 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @loadstore attributes {gpu.container_module} { - func.func @test(%A: memref<8x16xf32>, %B: memref<8x16xf32> ) -> (memref<8x16xf32>) attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref : memref<8x16xf32> to memref<8x16xf32> - memref.copy %B, %memref_1 : memref<8x16xf32> to memref<8x16xf32> - gpu.launch_func @module::@test blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_1 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x16xf32> + %memref_0 = 
gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<8x16xf32>, memref<8x16xf32> + gpu.launch_func @module::@test blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>) gpu.dealloc %memref : memref<8x16xf32> - return %memref_1 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - - gpu.module @module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test(%A: memref<8x16xf32>, %B: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module { + gpu.func @test(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index // load A tile - %a_tile = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %val_a = xegpu.load_nd %a_tile : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // store to B tile - %b_tile = xegpu.create_nd_tdesc %B [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %val_a, %b_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %1, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_0_f32 = arith.constant 0.0 : f32 - %cf_2_f32 = arith.constant 2.0 : f32 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<8x16xf32> - %B = memref.alloc() : memref<8x16xf32> // TRY 8x15. While it can encode vector type to 120f32 for intrinsics, the result is wrong. 
- // fill A with 2, B with 0 - %A_nonzero = memref.cast %A : memref<8x16xf32> to memref<*xf32> - %B_zeros = memref.cast %B : memref<8x16xf32> to memref<*xf32> - call @fillResource1DF32(%A_nonzero, %cf_2_f32) : (memref<*xf32>, f32) -> () - call @fillResource1DF32(%B_zeros, %cf_0_f32) : (memref<*xf32>, f32) -> () // Load from A, store to B - %2 = call @test(%A, %B) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> - - %B_filled = memref.cast %2 : memref<8x16xf32> to memref<*xf32> // call @printMemrefF32(%A_nonzero) : (memref<*xf32>) -> () // call @printMemrefF32(%B_filled) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_nonzero, %B_filled) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<8x16xf32> - memref.dealloc %B : memref<8x16xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %alloc = memref.alloc() : memref<8x16xf32> + %alloc_1 = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<8x16xf32> to memref<*xf32> + call @fillResource1DF32(%cast, %cst_0) : (memref<*xf32>, f32) -> () + call @fillResource1DF32(%cast_2, %cst) : (memref<*xf32>, f32) -> () + %0 = call @test(%alloc, %alloc_1) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> + %cast_3 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x16xf32> + memref.dealloc %alloc_1 : memref<8x16xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_bf16_tile.mlir b/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_bf16_tile.mlir index bc0a88512..18d13fc4c 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_bf16_tile.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_bf16_tile.mlir @@ -1,32 +1,31 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_8x32xbf16 : memref<8x32xbf16> = dense<0.0> + memref.global "private" constant @__constant_8x32xbf16 : memref<8x32xbf16> = dense<0.000000e+00> func.func @test(%arg0: memref<8x32xbf16>, %arg1: memref<8x32xbf16>) -> memref<8x32xbf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index - - %memref = gpu.alloc host_shared () : memref<8x32xbf16> - memref.copy %arg0, %memref : memref<8x32xbf16> to memref<8x32xbf16> - %memref_1 = gpu.alloc host_shared () : memref<8x32xbf16> - memref.copy %arg1, %memref_1 : memref<8x32xbf16> to memref<8x32xbf16> 
- %memref_2 = gpu.alloc host_shared () : memref<8x32xbf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x32xbf16>, %memref_1 : memref<8x32xbf16>, %memref_2 : memref<8x32xbf16>) + %memref = gpu.alloc () : memref<8x32xbf16> + gpu.memcpy %memref, %arg0 : memref<8x32xbf16>, memref<8x32xbf16> + %memref_0 = gpu.alloc () : memref<8x32xbf16> + gpu.memcpy %memref_0, %arg1 : memref<8x32xbf16>, memref<8x32xbf16> + %memref_1 = gpu.alloc () : memref<8x32xbf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x32xbf16>, %memref_0 : memref<8x32xbf16>, %memref_1 : memref<8x32xbf16>) gpu.dealloc %memref : memref<8x32xbf16> + gpu.dealloc %memref_0 : memref<8x32xbf16> + %alloc = memref.alloc() : memref<8x32xbf16> + gpu.memcpy %alloc, %memref_1 : memref<8x32xbf16>, memref<8x32xbf16> gpu.dealloc %memref_1 : memref<8x32xbf16> - return %memref_2 : memref<8x32xbf16> + return %alloc : memref<8x32xbf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<8x32xbf16>, %arg1: memref<8x32xbf16>, %arg2: memref<8x32xbf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %thread_id_x = gpu.thread_id x + %thread_id_x = gpu.thread_id x cf.br ^bb1 - ^bb1: + ^bb1: // pred: ^bb0 %0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32xbf16> -> vector<32xbf16> %2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x32xbf16> -> !xegpu.tensor_desc<32xbf16> @@ -38,45 +37,40 @@ module @gemm attributes {gpu.container_module} { } } func.func @main() attributes {llvm.emit_c_interface} { - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - %A = memref.alloc() : memref<8x32xbf16> - %A_random = memref.cast %A : memref<8x32xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - - %B = memref.alloc() : memref<8x32xbf16> - %B_random = memref.cast %B : memref<8x32xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // calculate the result C matrix - %c32 = arith.constant 32 : index - %c8 = arith.constant 8 : index - %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %ref = memref.alloc() : memref<8x32xf32> - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %a = memref.load %A[%i, %j] : memref<8x32xbf16> - %b = memref.load %B[%i, %j] : memref<8x32xbf16> - %a_ext = arith.extf %a : bf16 to f32 - %b_ext = arith.extf %b : bf16 to f32 - %c = arith.addf %a_ext, %b_ext : f32 - %c_trunc = arith.truncf %c : f32 to bf16 - %c_ext = arith.extf %c_trunc : bf16 to f32 - memref.store %c_ext, %ref[%i, %j] : memref<8x32xf32> + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c32 = arith.constant 32 : index + %false = arith.constant false + %cst = arith.constant -5.000000e-01 : f32 + %cst_0 = arith.constant 5.000000e-01 : f32 + %alloc = memref.alloc() : memref<8x32xbf16> + %cast = memref.cast %alloc : memref<8x32xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> 
() + %alloc_1 = memref.alloc() : memref<8x32xbf16> + %cast_2 = memref.cast %alloc_1 : memref<8x32xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast_2, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + %alloc_3 = memref.alloc() : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : memref<8x32xbf16> + %2 = memref.load %alloc_1[%arg0, %arg1] : memref<8x32xbf16> + %3 = arith.extf %1 : bf16 to f32 + %4 = arith.extf %2 : bf16 to f32 + %5 = arith.addf %3, %4 : f32 + %6 = arith.truncf %5 : f32 to bf16 + %7 = arith.extf %6 : bf16 to f32 + memref.store %7, %alloc_3[%arg0, %arg1] : memref<8x32xf32> } } - - %C = call @test(%A, %B) : (memref<8x32xbf16>, memref<8x32xbf16>) -> memref<8x32xbf16> - - %C_cast = memref.cast %C : memref<8x32xbf16> to memref<*xbf16> - %ref_cast = memref.cast %ref : memref<8x32xf32> to memref<*xf32> //call @printMemrefBF16(%C_cast) : (memref<*xbf16>) -> () //call @printMemrefF32(%ref_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseBF16(%C_cast, %ref_cast) : (memref<*xbf16>, memref<*xf32>) -> () + %0 = call @test(%alloc, %alloc_1) : (memref<8x32xbf16>, memref<8x32xbf16>) -> memref<8x32xbf16> + %cast_4 = memref.cast %0 : memref<8x32xbf16> to memref<*xbf16> + %cast_5 = memref.cast %alloc_3 : memref<8x32xf32> to memref<*xf32> + call @printAllcloseBF16(%cast_4, %cast_5) : (memref<*xbf16>, memref<*xf32>) -> () return } func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_tile.mlir b/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_tile.mlir index f92d7dc20..f27196979 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_tile.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_store_with_1d_tile.mlir @@ -1,32 +1,31 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<0.0> + memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<0.000000e+00> func.func @test(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index - - %memref = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %arg0, %memref : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %arg1, %memref_1 : memref<8x16xf32> to memref<8x16xf32> - %memref_2 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) 
threads in (%c8, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %memref_2 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<8x16xf32>, memref<8x16xf32> + %memref_1 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>, %memref_1 : memref<8x16xf32>) gpu.dealloc %memref : memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x16xf32>, memref<8x16xf32> gpu.dealloc %memref_1 : memref<8x16xf32> - return %memref_2 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %thread_id_x = gpu.thread_id x + %thread_id_x = gpu.thread_id x cf.br ^bb1 - ^bb1: + ^bb1: // pred: ^bb0 %0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> %2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32> @@ -38,40 +37,35 @@ module @gemm attributes {gpu.container_module} { } } func.func @main() attributes {llvm.emit_c_interface} { - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - %A = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - - %B = memref.alloc() : memref<8x16xf32> - %B_random = memref.cast %B : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - // calculate the result C matrix - %c16 = arith.constant 16 : index - %c8 = arith.constant 8 : index - %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index - %ref = memref.alloc() : memref<8x16xf32> - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %a = memref.load %A[%i, %j] : memref<8x16xf32> - %b = memref.load %B[%i, %j] : memref<8x16xf32> - %c = arith.addf %a, %b : f32 - memref.store %c, %ref[%i, %j] : memref<8x16xf32> + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %false = arith.constant false + %cst = arith.constant -5.000000e-01 : f32 + %cst_0 = arith.constant 5.000000e-01 : f32 + %alloc = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %alloc_1 = memref.alloc() : memref<8x16xf32> + %cast_2 = memref.cast %alloc_1 : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast_2, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %alloc_3 = memref.alloc() : memref<8x16xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = 
memref.load %alloc[%arg0, %arg1] : memref<8x16xf32> + %2 = memref.load %alloc_1[%arg0, %arg1] : memref<8x16xf32> + %3 = arith.addf %1, %2 : f32 + memref.store %3, %alloc_3[%arg0, %arg1] : memref<8x16xf32> } } - - %C = call @test(%A, %B) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> - - %C_cast = memref.cast %C : memref<8x16xf32> to memref<*xf32> - %ref_cast = memref.cast %ref : memref<8x16xf32> to memref<*xf32> // call @printMemrefF32(%C_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%ref_cast, %C_cast) : (memref<*xf32>, memref<*xf32>) -> () + %0 = call @test(%alloc, %alloc_1) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> + %cast_4 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_3 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_5, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load_with_block_array_16_16_2.vc.mlir b/test/Integration/Dialect/XeGPU/VC/load_with_block_array_16_16_2.vc.mlir index 7e6c8a85e..c99eea25f 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_with_block_array_16_16_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_with_block_array_16_16_2.vc.mlir @@ -1,68 +1,62 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_16x32xf16 : memref<16x32xf16> = dense<5.000000e-01> func.func @test(%arg0: memref<16x32xf16>) -> memref<16x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %arg0, %memref : memref<16x32xf16> to memref<16x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_1 : memref<16x32xf32>) + %memref = gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<16x32xf16>, memref<16x32xf16> + %memref_0 = gpu.alloc () : memref<16x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_0 : memref<16x32xf32>) gpu.dealloc %memref : memref<16x32xf16> - return %memref_1 : memref<16x32xf32> + %alloc = memref.alloc() : memref<16x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<16x32xf32>, memref<16x32xf32> + gpu.dealloc %memref_0 : memref<16x32xf32> + return %alloc : memref<16x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, 
#spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref<16x32xf16>, %arg1: memref<16x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> - %1 = xegpu.load_nd %0 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> - %3 = arith.extf %1: vector<2x16x16xf16> to vector<2x16x16xf32> - %4 = vector.extract %3[0]: vector<16x16xf32> from vector<2x16x16xf32> - %5 = vector.extract %3[1]: vector<16x16xf32> from vector<2x16x16xf32> - %6 = vector.shape_cast %4: vector<16x16xf32> to vector<2x8x16xf32> - %7 = vector.shape_cast %5: vector<16x16xf32> to vector<2x8x16xf32> - - %8 = vector.extract %6[0]: vector<8x16xf32> from vector<2x8x16xf32> - %9 = vector.extract %6[1]: vector<8x16xf32> from vector<2x8x16xf32> - - %10 = vector.extract %7[0]: vector<8x16xf32> from vector<2x8x16xf32> - %11 = vector.extract %7[1]: vector<8x16xf32> from vector<2x8x16xf32> - - %12 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %13 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %14 = xegpu.create_nd_tdesc %arg1[8, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %15 = xegpu.create_nd_tdesc %arg1[8, 16] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - xegpu.store_nd %8, %12 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %10, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %9, %14 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %11, %15 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + gpu.module @test_kernel { // %16 = vector.extract %4[0, 0]: f32 from vector<16x16xf32> // %17 = vector.extract %5[0, 0]: f32 from vector<16x16xf32> // gpu.printf "\narray 0: %f, array 1: %f.\n" %16, %17: f32, f32 + gpu.func @test_kernel(%arg0: memref<16x32xf16>, %arg1: memref<16x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %1 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> + %2 = arith.extf %1 : vector<2x16x16xf16> to vector<2x16x16xf32> + %3 = vector.extract %2[0] : vector<16x16xf32> from vector<2x16x16xf32> + %4 = vector.extract %2[1] : vector<16x16xf32> from vector<2x16x16xf32> + %5 = vector.shape_cast %3 : vector<16x16xf32> to vector<2x8x16xf32> + %6 = vector.shape_cast %4 : vector<16x16xf32> to vector<2x8x16xf32> + %7 = vector.extract %5[0] : vector<8x16xf32> from vector<2x8x16xf32> + %8 = vector.extract %5[1] : vector<8x16xf32> from vector<2x8x16xf32> + %9 = vector.extract %6[0] : vector<8x16xf32> from vector<2x8x16xf32> + %10 = vector.extract %6[1] : vector<8x16xf32> from vector<2x8x16xf32> + %11 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %12 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %13 = xegpu.create_nd_tdesc %arg1[8, 0] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %14 = xegpu.create_nd_tdesc %arg1[8, 16] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %7, %11 : vector<8x16xf32>, 
!xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %9, %12 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %8, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %10, %14 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<16x32xf16> - %A_random = memref.cast %A : memref<16x32xf16> to memref<*xf16> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<16x32xf16>) -> memref<16x32xf32> - %A_cast = memref.cast %A : memref<16x32xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<16x32xf32> to memref<*xf32> //call @printMemrefF32(%cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> () + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<16x32xf16> + %cast = memref.cast %alloc : memref<16x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %false) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<16x32xf16>) -> memref<16x32xf32> + %cast_1 = memref.cast %alloc : memref<16x32xf16> to memref<*xf16> + %cast_2 = memref.cast %0 : memref<16x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast_1, %cast_2) : (memref<*xf16>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load_with_block_array_32_16_2.vc.mlir b/test/Integration/Dialect/XeGPU/VC/load_with_block_array_32_16_2.vc.mlir index 0ec7c69f5..d78cc8551 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_with_block_array_32_16_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_with_block_array_32_16_2.vc.mlir @@ -1,88 +1,73 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg0, %memref : memref<32x32xf16> to memref<32x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_1 : memref<32x32xf32>) + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : 
memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf32>) gpu.dealloc %memref : memref<32x32xf16> - return %memref_1 : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref_0 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %1 = xegpu.load_nd %0 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} - : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> - %3 = arith.extf %1: vector<2x32x16xf16> to vector<2x32x16xf32> - %4 = vector.extract %3[0]: vector<32x16xf32> from vector<2x32x16xf32> - %5 = vector.extract %3[1]: vector<32x16xf32> from vector<2x32x16xf32> - %6 = vector.shape_cast %4: vector<32x16xf32> to vector<4x8x16xf32> - %7 = vector.shape_cast %5: vector<32x16xf32> to vector<4x8x16xf32> - - %10 = vector.extract %6[0]: vector<8x16xf32> from vector<4x8x16xf32> - %11 = vector.extract %6[1]: vector<8x16xf32> from vector<4x8x16xf32> - %12 = vector.extract %6[2]: vector<8x16xf32> from vector<4x8x16xf32> - %13 = vector.extract %6[3]: vector<8x16xf32> from vector<4x8x16xf32> - - %14 = vector.extract %7[0]: vector<8x16xf32> from vector<4x8x16xf32> - %15 = vector.extract %7[1]: vector<8x16xf32> from vector<4x8x16xf32> - %16 = vector.extract %7[2]: vector<8x16xf32> from vector<4x8x16xf32> - %17 = vector.extract %7[3]: vector<8x16xf32> from vector<4x8x16xf32> - - %20 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %21 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - %22 = xegpu.create_nd_tdesc %arg1[8, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %23 = xegpu.create_nd_tdesc %arg1[8, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - %24 = xegpu.create_nd_tdesc %arg1[16, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %25 = xegpu.create_nd_tdesc %arg1[16, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - %26 = xegpu.create_nd_tdesc %arg1[24, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %27 = xegpu.create_nd_tdesc %arg1[24, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - - xegpu.store_nd %10, %20 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %14, %21 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - xegpu.store_nd %11, %22 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %15, %23 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - xegpu.store_nd %12, %24 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %16, %25 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - - xegpu.store_nd %13, %26 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %17, %27 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + gpu.module @test_kernel { // %30 = vector.extract %4[0, 0]: f32 from vector<32x16xf32> // %31 = vector.extract %5[0, 0]: f32 from 
vector<32x16xf32> // gpu.printf "\narray 0: %f, array 1: %f.\n" %30, %31: f32, f32 - + gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %1 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %2 = arith.extf %1 : vector<2x32x16xf16> to vector<2x32x16xf32> + %3 = vector.extract %2[0] : vector<32x16xf32> from vector<2x32x16xf32> + %4 = vector.extract %2[1] : vector<32x16xf32> from vector<2x32x16xf32> + %5 = vector.shape_cast %3 : vector<32x16xf32> to vector<4x8x16xf32> + %6 = vector.shape_cast %4 : vector<32x16xf32> to vector<4x8x16xf32> + %7 = vector.extract %5[0] : vector<8x16xf32> from vector<4x8x16xf32> + %8 = vector.extract %5[1] : vector<8x16xf32> from vector<4x8x16xf32> + %9 = vector.extract %5[2] : vector<8x16xf32> from vector<4x8x16xf32> + %10 = vector.extract %5[3] : vector<8x16xf32> from vector<4x8x16xf32> + %11 = vector.extract %6[0] : vector<8x16xf32> from vector<4x8x16xf32> + %12 = vector.extract %6[1] : vector<8x16xf32> from vector<4x8x16xf32> + %13 = vector.extract %6[2] : vector<8x16xf32> from vector<4x8x16xf32> + %14 = vector.extract %6[3] : vector<8x16xf32> from vector<4x8x16xf32> + %15 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %16 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %17 = xegpu.create_nd_tdesc %arg1[8, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %18 = xegpu.create_nd_tdesc %arg1[8, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %19 = xegpu.create_nd_tdesc %arg1[16, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %20 = xegpu.create_nd_tdesc %arg1[16, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %21 = xegpu.create_nd_tdesc %arg1[24, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %22 = xegpu.create_nd_tdesc %arg1[24, 16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %7, %15 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %11, %16 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %8, %17 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %12, %18 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %9, %19 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %13, %20 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %10, %21 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %14, %22 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<32x32xf16> - %A_random = memref.cast %A : memref<32x32xf16> to memref<*xf16> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<32x32xf16>) -> memref<32x32xf32> - %A_cast = memref.cast %A : memref<32x32xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<32x32xf32> to memref<*xf32> // call @printMemrefF32(%B_cast): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%A_cast, 
%B_cast) : (memref<*xf16>, memref<*xf32>) -> () + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<32x32xf16> + %cast = memref.cast %alloc : memref<32x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %false) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<32x32xf16>) -> memref<32x32xf32> + %cast_1 = memref.cast %alloc : memref<32x32xf16> to memref<*xf16> + %cast_2 = memref.cast %0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast_1, %cast_2) : (memref<*xf16>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/load_with_block_array_8_16_2.vc.mlir b/test/Integration/Dialect/XeGPU/VC/load_with_block_array_8_16_2.vc.mlir index cf3898668..e9d6fd6ff 100644 --- a/test/Integration/Dialect/XeGPU/VC/load_with_block_array_8_16_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/load_with_block_array_8_16_2.vc.mlir @@ -1,51 +1,50 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { memref.global "private" constant @__constant_8x32xf16 : memref<8x32xf16> = dense<5.000000e-01> func.func @test(%arg0: memref<8x32xf16>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - memref.copy %arg0, %memref : memref<8x32xf16> to memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<8x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<8x32xf32>) + %memref = gpu.alloc () : memref<8x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<8x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<8x32xf32>) gpu.dealloc %memref : memref<8x32xf16> - return %memref_1 : memref<8x32xf32> + %alloc = memref.alloc() : memref<8x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x32xf32>, memref<8x32xf32> + gpu.dealloc %memref_0 : memref<8x32xf32> + return %alloc : memref<8x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<8x32xf16>, %arg1: memref<8x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> 
!xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> %2 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %3 = xegpu.load_nd %0 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> -> vector<2x8x16xf16> - %4 = vector.extract %3[0]: vector<8x16xf16> from vector<2x8x16xf16> - %5 = vector.extract %3[1]: vector<8x16xf16> from vector<2x8x16xf16> - %8 = arith.extf %4: vector<8x16xf16> to vector<8x16xf32> - %9 = arith.extf %5: vector<8x16xf16> to vector<8x16xf32> - xegpu.store_nd %8, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %9, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> -> vector<2x8x16xf16> + %4 = vector.extract %3[0] : vector<8x16xf16> from vector<2x8x16xf16> + %5 = vector.extract %3[1] : vector<8x16xf16> from vector<2x8x16xf16> + %6 = arith.extf %4 : vector<8x16xf16> to vector<8x16xf32> + %7 = arith.extf %5 : vector<8x16xf16> to vector<8x16xf32> + xegpu.store_nd %6, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %7, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<8x32xf16> - %A_random = memref.cast %A : memref<8x32xf16> to memref<*xf16> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<8x32xf16>) -> memref<8x32xf32> - %A_cast = memref.cast %A : memref<8x32xf16> to memref<*xf16> - %B_cast = memref.cast %B : memref<8x32xf32> to memref<*xf32> // call @printMemrefF32(%cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> () + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x32xf16> + %cast = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %false) : (memref<*xf16>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x32xf16>) -> memref<8x32xf32> + %cast_1 = memref.cast %alloc : memref<8x32xf16> to memref<*xf16> + %cast_2 = memref.cast %0 : memref<8x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast_1, %cast_2) : (memref<*xf16>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/loadgather2d_masked_f32.mlir b/test/Integration/Dialect/XeGPU/VC/loadgather2d_masked_f32.mlir index ab5431bcc..0ce69d631 100644 --- a/test/Integration/Dialect/XeGPU/VC/loadgather2d_masked_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/loadgather2d_masked_f32.mlir @@ -1,75 +1,62 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner 
-e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_1_3x16xf32 : memref<3x16xf32> = dense<1.0> - memref.global "private" constant @__constant_3_3x16xf32 : memref<3x16xf32> = dense<3.0> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_1_3x16xf32 : memref<3x16xf32> = dense<1.000000e+00> + memref.global "private" constant @__constant_3_3x16xf32 : memref<3x16xf32> = dense<3.000000e+00> func.func @test(%arg0: memref<3x16xf32>, %arg1: memref<3x16xf32>) -> memref<3x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<3x16xf32> - memref.copy %arg0, %memref : memref<3x16xf32> to memref<3x16xf32> - %memref1 = gpu.alloc host_shared () : memref<3x16xf32> - memref.copy %arg1, %memref1 : memref<3x16xf32> to memref<3x16xf32> - // Spirv has no lowering for memref.subview - %in = memref.reinterpret_cast %memref to offset: [0], sizes: [48], strides: [1] : memref<3x16xf32> to memref<48xf32> - %out = memref.reinterpret_cast %memref1 to offset: [0], sizes: [48], strides: [1] : memref<3x16xf32> to memref<48xf32> - - %memref_dyn = memref.cast %in : memref<48xf32> to memref - %memref1_dyn = memref.cast %out : memref<48xf32> to memref - - gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_dyn : memref, %memref1_dyn : memref) + %memref = gpu.alloc () : memref<3x16xf32> + gpu.memcpy %memref, %arg0 : memref<3x16xf32>, memref<3x16xf32> + %memref_0 = gpu.alloc () : memref<3x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<3x16xf32>, memref<3x16xf32> + %reinterpret_cast = memref.reinterpret_cast %memref to offset: [0], sizes: [48], strides: [1] : memref<3x16xf32> to memref<48xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %memref_0 to offset: [0], sizes: [48], strides: [1] : memref<3x16xf32> to memref<48xf32> + %cast = memref.cast %reinterpret_cast : memref<48xf32> to memref + %cast_2 = memref.cast %reinterpret_cast_1 : memref<48xf32> to memref + gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%cast : memref, %cast_2 : memref) gpu.dealloc %memref : memref<3x16xf32> - return %memref1 : memref<3x16xf32> + %alloc = memref.alloc() : memref<3x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<3x16xf32>, memref<3x16xf32> + gpu.dealloc %memref_0 : memref<3x16xf32> + return %alloc : memref<3x16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_scattered(%arg0: memref, %arg1: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { // This test emulates 2D load with user defined padding // We load rows with %row_mask that has 0's to not cross the boundary. 
// We pad the values that were not loaded (as per %row_mask) with %user_val. // We store full (padded) rows with %store_mask. %c0 = arith.constant 0 : index - %store_mask = arith.constant dense<[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]> : vector<16xi1> - %user_val = arith.constant dense<22.33> : vector<16xf32> - %row_mask = arith.constant dense<[1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0]> : vector<16xi1> - %offset_step = arith.constant dense<16>: vector<16xindex> - // Spirv has no lowering for memref.reinterpret_cast with different sizes (doesn't work: memref<3x16xf32> to memref<16xf32>) // Each row has a tdesc with offsets that determine linearized memref's values to be loaded - %offsets_row1 = arith.constant dense<[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]> : vector<16xindex> - %row_1_in_td = xegpu.create_tdesc %arg0, %offsets_row1 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %row_1_out_td = xegpu.create_tdesc %arg1, %offsets_row1 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %row_1_loaded = xegpu.load %row_1_in_td, %row_mask : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> - %row_1_store = arith.select %row_mask, %row_1_loaded, %user_val : vector<16xi1>, vector<16xf32> - xegpu.store %row_1_store, %row_1_out_td, %store_mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> - - %row_2_in_td = xegpu.update_offset %row_1_in_td, %offset_step : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> - %row_2_out_td = xegpu.update_offset %row_1_out_td, %offset_step : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> - %row_2_loaded = xegpu.load %row_2_in_td, %row_mask : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> - %row_2_store = arith.select %row_mask, %row_2_loaded, %user_val : vector<16xi1>, vector<16xf32> - xegpu.store %row_2_store, %row_2_out_td, %store_mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> - // The entire row is out of bounds - %row_3_out_td = xegpu.update_offset %row_2_out_td, %offset_step : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> - xegpu.store %user_val, %row_3_out_td, %store_mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + %cst = arith.constant dense : vector<16xi1> + %cst_0 = arith.constant dense<2.233000e+01> : vector<16xf32> + %cst_1 = arith.constant dense<[true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, false]> : vector<16xi1> + %cst_2 = arith.constant dense<16> : vector<16xindex> + %cst_3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg0, %cst_3 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.create_tdesc %arg1, %cst_3 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %2 = xegpu.load %0, %cst_1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + %3 = arith.select %cst_1, %2, %cst_0 : vector<16xi1>, vector<16xf32> + xegpu.store %3, %1, %cst : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + %4 = xegpu.update_offset %0, %cst_2 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> + %5 = xegpu.update_offset %1, 
%cst_2 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> + %6 = xegpu.load %4, %cst_1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + %7 = arith.select %cst_1, %6, %cst_0 : vector<16xi1>, vector<16xf32> + xegpu.store %7, %5, %cst : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + %8 = xegpu.update_offset %5, %cst_2 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex> + xegpu.store %cst_0, %8, %cst : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_1_3x16xf32 : memref<3x16xf32> %1 = memref.get_global @__constant_3_3x16xf32 : memref<3x16xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %2 = call @test(%0, %1) : (memref<3x16xf32>, memref<3x16xf32>) -> memref<3x16xf32> %cast = memref.cast %2 : memref<3x16xf32> to memref<*xf32> // CHECK: Unranked Memref base@ = 0x{{.*}} rank = 2 offset = 0 sizes = [3, 16] strides = [16, 1] data = diff --git a/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_f32.mlir b/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_f32.mlir index d1d7a4ba5..1df42c327 100644 --- a/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_f32.mlir @@ -1,55 +1,44 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_1_4x16xf32 : memref<4x16xf32> = dense<1.1> - memref.global "private" constant @__constant_3_4x16xf32 : memref<4x16xf32> = dense<3.0> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_1_4x16xf32 : memref<4x16xf32> = dense<1.100000e+00> + memref.global "private" constant @__constant_3_4x16xf32 : memref<4x16xf32> = dense<3.000000e+00> func.func @test(%arg0: memref<4x16xf32>, %arg1: memref<4x16xf32>) -> memref<4x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<4x16xf32> - memref.copy %arg0, %memref : memref<4x16xf32> to memref<4x16xf32> - %memref1 = gpu.alloc host_shared () : memref<4x16xf32> - memref.copy %arg1, %memref1 : memref<4x16xf32> to memref<4x16xf32> - - %in = memref.reinterpret_cast %memref to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> - %out = memref.reinterpret_cast %memref1 to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> - - %memref_dyn = memref.cast %in : memref<64xf32> to memref - %memref1_dyn = 
memref.cast %out : memref<64xf32> to memref - - gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_dyn : memref, %memref1_dyn : memref) + %memref = gpu.alloc () : memref<4x16xf32> + gpu.memcpy %memref, %arg0 : memref<4x16xf32>, memref<4x16xf32> + %memref_0 = gpu.alloc () : memref<4x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<4x16xf32>, memref<4x16xf32> + %reinterpret_cast = memref.reinterpret_cast %memref to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %memref_0 to offset: [0], sizes: [64], strides: [1] : memref<4x16xf32> to memref<64xf32> + %cast = memref.cast %reinterpret_cast : memref<64xf32> to memref + %cast_2 = memref.cast %reinterpret_cast_1 : memref<64xf32> to memref + gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%cast : memref, %cast_2 : memref) gpu.dealloc %memref : memref<4x16xf32> - return %memref1 : memref<4x16xf32> + %alloc = memref.alloc() : memref<4x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<4x16xf32>, memref<4x16xf32> + gpu.dealloc %memref_0 : memref<4x16xf32> + return %alloc : memref<4x16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_scattered(%in: memref, %out: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // We have 16 work items, each accesses 2 elements: {chunk_size = 2}, hence 16x2 tensor. // Valid offsets (%offsets for which %mask is 1) should not exceed 16*2=32. - %offsets = arith.constant dense<[0,4,8,12,16,20,24,28,32,34,38,42,46,50,54,58]> : vector<16xindex> - %mask = arith.constant dense<[1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0]> : vector<16xi1> - %tdesc_in = xegpu.create_tdesc %in, %offsets : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> - %tdesc_out = xegpu.create_tdesc %out, %offsets : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> - %loaded = xegpu.load %tdesc_in, %mask : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> - xegpu.store %loaded, %tdesc_out, %mask : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + gpu.func @test_scattered(%arg0: memref, %arg1: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 34, 38, 42, 46, 50, 54, 58]> : vector<16xindex> + %cst_0 = arith.constant dense<[true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false]> : vector<16xi1> + %0 = xegpu.create_tdesc %arg0, %cst : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + %1 = xegpu.create_tdesc %arg1, %cst : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xf32> + xegpu.store %2, %1, %cst_0 : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_1_4x16xf32 : memref<4x16xf32> %1 = memref.get_global @__constant_3_4x16xf32 : memref<4x16xf32> - %c0 = 
arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %2 = call @test(%0, %1) : (memref<4x16xf32>, memref<4x16xf32>) -> memref<4x16xf32> %cast = memref.cast %2 : memref<4x16xf32> to memref<*xf32> // CHECK: Unranked Memref base@ = 0x{{.*}} rank = 2 offset = 0 sizes = [4, 16] strides = [16, 1] data = diff --git a/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_i32.mlir b/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_i32.mlir index 2ebed221a..3e71721d5 100644 --- a/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_i32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/loadgather_chunk_size_i32.mlir @@ -1,55 +1,44 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_1_4x16xi32 : memref<4x16xi32> = dense<[[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1], [2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2], [3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3], [4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4]]> + memref.global "private" constant @__constant_1_4x16xi32 : memref<4x16xi32> = dense<[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]> memref.global "private" constant @__constant_3_4x16xi32 : memref<4x16xi32> = dense<8> - func.func @test(%arg0: memref<4x16xi32>, %arg1: memref<4x16xi32>) -> memref<4x16xi32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<4x16xi32> - memref.copy %arg0, %memref : memref<4x16xi32> to memref<4x16xi32> - %memref1 = gpu.alloc host_shared () : memref<4x16xi32> - memref.copy %arg1, %memref1 : memref<4x16xi32> to memref<4x16xi32> - - %in = memref.reinterpret_cast %memref to offset: [0], sizes: [64], strides: [1] : memref<4x16xi32> to memref<64xi32> - %out = memref.reinterpret_cast %memref1 to offset: [0], sizes: [64], strides: [1] : memref<4x16xi32> to memref<64xi32> - - %memref_dyn = memref.cast %in : memref<64xi32> to memref - %memref1_dyn = memref.cast %out : memref<64xi32> to memref - - gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_dyn : memref, %memref1_dyn : memref) + %memref = gpu.alloc () : memref<4x16xi32> + gpu.memcpy %memref, %arg0 : memref<4x16xi32>, memref<4x16xi32> + %memref_0 = gpu.alloc () : memref<4x16xi32> + gpu.memcpy %memref_0, %arg1 : memref<4x16xi32>, memref<4x16xi32> + %reinterpret_cast = memref.reinterpret_cast %memref to offset: [0], sizes: [64], strides: [1] : memref<4x16xi32> to memref<64xi32> + %reinterpret_cast_1 = memref.reinterpret_cast %memref_0 to offset: [0], 
sizes: [64], strides: [1] : memref<4x16xi32> to memref<64xi32> + %cast = memref.cast %reinterpret_cast : memref<64xi32> to memref + %cast_2 = memref.cast %reinterpret_cast_1 : memref<64xi32> to memref + gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%cast : memref, %cast_2 : memref) gpu.dealloc %memref : memref<4x16xi32> - return %memref1 : memref<4x16xi32> + %alloc = memref.alloc() : memref<4x16xi32> + gpu.memcpy %alloc, %memref_0 : memref<4x16xi32>, memref<4x16xi32> + gpu.dealloc %memref_0 : memref<4x16xi32> + return %alloc : memref<4x16xi32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_scattered(%in: memref, %out: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { // We have 16 work items, each accesses 2 elements: {chunk_size = 2}, hence 16x2 tensor. // Valid offsets (%offsets for which %mask is 1) should not exceed 16*2=32. - %offsets = arith.constant dense<[0,4,8,12,16,20,24,28,32,34,38,42,46,50,54,58]> : vector<16xindex> - %mask = arith.constant dense<[1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0]> : vector<16xi1> - %tdesc_in = xegpu.create_tdesc %in, %offsets : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr> - %tdesc_out = xegpu.create_tdesc %out, %offsets : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr> - %loaded = xegpu.load %tdesc_in, %mask : !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xi32> - xegpu.store %loaded, %tdesc_out, %mask : vector<16x2xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + gpu.func @test_scattered(%arg0: memref, %arg1: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 34, 38, 42, 46, 50, 54, 58]> : vector<16xindex> + %cst_0 = arith.constant dense<[true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false]> : vector<16xi1> + %0 = xegpu.create_tdesc %arg0, %cst : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr> + %1 = xegpu.create_tdesc %arg1, %cst : memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x2xi32> + xegpu.store %2, %1, %cst_0 : vector<16x2xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %0 = memref.get_global @__constant_1_4x16xi32 : memref<4x16xi32> %1 = memref.get_global @__constant_3_4x16xi32 : memref<4x16xi32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %2 = call @test(%0, %1) : (memref<4x16xi32>, memref<4x16xi32>) -> memref<4x16xi32> %cast = memref.cast %2 : memref<4x16xi32> to memref<*xi32> // CHECK: Unranked Memref base@ = 0x{{.*}} rank = 2 offset = 0 sizes = [4, 16] strides = [16, 1] data = diff --git a/test/Integration/Dialect/XeGPU/VC/loadgather_f32.mlir b/test/Integration/Dialect/XeGPU/VC/loadgather_f32.mlir index 49243399b..0c8138ad0 100644 --- a/test/Integration/Dialect/XeGPU/VC/loadgather_f32.mlir +++ 
b/test/Integration/Dialect/XeGPU/VC/loadgather_f32.mlir @@ -1,49 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_1x16xf32 : memref<1x16xf32> = dense<1.1> - memref.global "private" constant @__constant_3x16xf32 : memref<1x16xf32> = dense<3.0> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_1x16xf32 : memref<1x16xf32> = dense<1.100000e+00> + memref.global "private" constant @__constant_3x16xf32 : memref<1x16xf32> = dense<3.000000e+00> func.func @test(%arg0: memref<1x16xf32>, %arg1: memref<1x16xf32>) -> memref<1x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1x16xf32> - memref.copy %arg0, %memref : memref<1x16xf32> to memref<1x16xf32> - %memref1 = gpu.alloc host_shared () : memref<1x16xf32> - memref.copy %arg1, %memref1 : memref<1x16xf32> to memref<1x16xf32> - gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x16xf32>, %memref1 : memref<1x16xf32>) + %memref = gpu.alloc () : memref<1x16xf32> + gpu.memcpy %memref, %arg0 : memref<1x16xf32>, memref<1x16xf32> + %memref_0 = gpu.alloc () : memref<1x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<1x16xf32>, memref<1x16xf32> + gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x16xf32>, %memref_0 : memref<1x16xf32>) gpu.dealloc %memref : memref<1x16xf32> - return %memref1 : memref<1x16xf32> + %alloc = memref.alloc() : memref<1x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<1x16xf32>, memref<1x16xf32> + gpu.dealloc %memref_0 : memref<1x16xf32> + return %alloc : memref<1x16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_scattered(%arg0: memref<1x16xf32>, %arg1: memref<1x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index - %offsets = arith.constant dense<[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]> : vector<16xindex> - %mask = arith.constant dense<[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]> : vector<16xi1> - %1 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> - %2 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> - %tdesc1 = xegpu.create_tdesc %1, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %tdesc2 = xegpu.create_tdesc %2, 
%offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %loaded = xegpu.load %tdesc1, %mask : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> - xegpu.store %loaded, %tdesc2, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %cst_0 = arith.constant dense : vector<16xi1> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> + %0 = xegpu.create_tdesc %reinterpret_cast, %cst : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.create_tdesc %reinterpret_cast_1, %cst : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + xegpu.store %2, %1, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1x16xf32 : memref<1x16xf32> %1 = memref.get_global @__constant_3x16xf32 : memref<1x16xf32> - %c0 = arith.constant 0 : index %2 = call @test(%0, %1) : (memref<1x16xf32>, memref<1x16xf32>) -> memref<1x16xf32> - %vector_0 = vector.load %2[%c0,%c0] :memref<1x16xf32>, vector<16xf32> // CHECK: ( 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1 ) - vector.print %vector_0 : vector<16xf32> + %3 = vector.load %2[%c0, %c0] : memref<1x16xf32>, vector<16xf32> + vector.print %3 : vector<16xf32> return } } diff --git a/test/Integration/Dialect/XeGPU/VC/loadgather_masked_f32.mlir b/test/Integration/Dialect/XeGPU/VC/loadgather_masked_f32.mlir index 0db4bdd5f..d01614352 100644 --- a/test/Integration/Dialect/XeGPU/VC/loadgather_masked_f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/loadgather_masked_f32.mlir @@ -1,49 +1,46 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - memref.global "private" constant @__constant_1x16xf32 : memref<1x16xf32> = dense<1.0> - memref.global "private" constant @__constant_3x16xf32 : memref<1x16xf32> = dense<3.3> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +module @gemm attributes {gpu.container_module} { + memref.global "private" constant @__constant_1x16xf32 : memref<1x16xf32> = dense<1.000000e+00> + memref.global "private" constant 
@__constant_3x16xf32 : memref<1x16xf32> = dense<3.300000e+00> func.func @test(%arg0: memref<1x16xf32>, %arg1: memref<1x16xf32>) -> memref<1x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<1x16xf32> - memref.copy %arg0, %memref : memref<1x16xf32> to memref<1x16xf32> - %memref1 = gpu.alloc host_shared () : memref<1x16xf32> - memref.copy %arg1, %memref1 : memref<1x16xf32> to memref<1x16xf32> - gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x16xf32>, %memref1 : memref<1x16xf32>) + %memref = gpu.alloc () : memref<1x16xf32> + gpu.memcpy %memref, %arg0 : memref<1x16xf32>, memref<1x16xf32> + %memref_0 = gpu.alloc () : memref<1x16xf32> + gpu.memcpy %memref_0, %arg1 : memref<1x16xf32>, memref<1x16xf32> + gpu.launch_func @test_kernel::@test_scattered blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x16xf32>, %memref_0 : memref<1x16xf32>) gpu.dealloc %memref : memref<1x16xf32> - return %memref1 : memref<1x16xf32> + %alloc = memref.alloc() : memref<1x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<1x16xf32>, memref<1x16xf32> + gpu.dealloc %memref_0 : memref<1x16xf32> + return %alloc : memref<1x16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_scattered(%arg0: memref<1x16xf32>, %arg1: memref<1x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index - %offsets = arith.constant dense<[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]> : vector<16xindex> - %mask = arith.constant dense<[1,1,1,0,1,1,1,1,0,1,1,1,1,0,1,1]> : vector<16xi1> - %1 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> - %2 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> - %tdesc1 = xegpu.create_tdesc %1, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %tdesc2 = xegpu.create_tdesc %2, %offsets : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> - %loaded = xegpu.load %tdesc1, %mask : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> - xegpu.store %loaded, %tdesc2, %mask : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> + %cst_0 = arith.constant dense<[true, true, true, false, true, true, true, true, false, true, true, true, true, false, true, true]> : vector<16xi1> + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [16], strides: [1] : memref<1x16xf32> to memref<16xf32> + %0 = xegpu.create_tdesc %reinterpret_cast, %cst : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.create_tdesc %reinterpret_cast_1, %cst : memref<16xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %2 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + xegpu.store %2, %1, %cst_0 : 
vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index %0 = memref.get_global @__constant_1x16xf32 : memref<1x16xf32> %1 = memref.get_global @__constant_3x16xf32 : memref<1x16xf32> - %c0 = arith.constant 0 : index %2 = call @test(%0, %1) : (memref<1x16xf32>, memref<1x16xf32>) -> memref<1x16xf32> - %vector_0 = vector.load %2[%c0,%c0] :memref<1x16xf32>, vector<16xf32> // CHECK: ( 1, 1, 1, 3.3, 1, 1, 1, 1, 3.3, 1, 1, 1, 1, 3.3, 1, 1 ) - vector.print %vector_0 : vector<16xf32> + %3 = vector.load %2[%c0, %c0] : memref<1x16xf32>, vector<16xf32> + vector.print %3 : vector<16xf32> return } } diff --git a/test/Integration/Dialect/XeGPU/VC/optimize_transpose.mlir b/test/Integration/Dialect/XeGPU/VC/optimize_transpose.mlir index bf273cb5f..9d64e35c3 100644 --- a/test/Integration/Dialect/XeGPU/VC/optimize_transpose.mlir +++ b/test/Integration/Dialect/XeGPU/VC/optimize_transpose.mlir @@ -1,28 +1,28 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %memref = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %arg0, %memref : memref<256x256xf16> to memref<256x256xf16> - %memref_0 = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %arg1, %memref_0 : memref<256x256xf16> to memref<256x256xf16> - %memref_1 = gpu.alloc host_shared () : memref<256x256xf32> - memref.copy %arg2, %memref_1 : memref<256x256xf32> to memref<256x256xf32> + %memref = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref, %arg0 : memref<256x256xf16>, memref<256x256xf16> + %memref_0 = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x256xf16>, memref<256x256xf16> + %memref_1 = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_1, %arg2 : memref<256x256xf32>, memref<256x256xf32> gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c8, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<256x256xf16>, %memref_0 : memref<256x256xf16>, %memref_1 : memref<256x256xf32>) gpu.dealloc %memref : memref<256x256xf16> gpu.dealloc %memref_0 : memref<256x256xf16> - return %memref_1 : memref<256x256xf32> + %alloc = memref.alloc() : memref<256x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_1 : memref<256x256xf32> + return %alloc : memref<256x256xf32> } - gpu.module @test_kernel attributes {spirv.target_env = 
#spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index @@ -34,34 +34,34 @@ module @gemm attributes {gpu.container_module} { %1 = arith.muli %block_id_y, %c32 : index %2 = arith.addi %0, %c0 : index %3 = arith.addi %1, %c0 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> %5 = arith.addi %1, %c16 : index - %6 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %6 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> %c8 = arith.constant 8 : index %7 = arith.addi %0, %c8 : index - %8 = xegpu.create_nd_tdesc %arg2[%7, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %9 = xegpu.create_nd_tdesc %arg2[%7, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %10 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> - %11 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> - %12 = xegpu.load_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> -> vector<16x16xf32> - %13 = xegpu.load_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> -> vector<16x16xf32> - %14 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> - %15 = xegpu.create_nd_tdesc %arg1[%3, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %16 = xegpu.create_nd_tdesc %arg1[%3, %c16] : memref<256x256xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %17:5 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %14, %arg5 = %15, %arg6 = %16, %arg7 = %12, %arg8 = %13) -> (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<16x16xf32>, vector<16x16xf32>) { + %8 = xegpu.create_nd_tdesc %arg2[%7, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + %9 = xegpu.create_nd_tdesc %arg2[%7, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + %10 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32> + %11 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32> + %12 = xegpu.load_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32> + %13 = xegpu.load_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32> + %14 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %15 = xegpu.create_nd_tdesc 
%arg1[%3, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<32x16xf16> + %16 = xegpu.create_nd_tdesc %arg1[%3, %c16] : memref<256x256xf16> -> !xegpu.tensor_desc<32x16xf16> + %17:5 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %14, %arg5 = %15, %arg6 = %16, %arg7 = %12, %arg8 = %13) -> (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16>, !xegpu.tensor_desc<32x16xf16>, vector<16x16xf32>, vector<16x16xf32>) { %22 = vector.extract_strided_slice %arg7 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %23 = vector.extract_strided_slice %arg7 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %24 = vector.extract_strided_slice %arg8 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %25 = vector.extract_strided_slice %arg8 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> - %26 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> + %26 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> %27 = vector.extract %26[0] : vector<16x16xf16> from vector<2x16x16xf16> %28 = vector.extract %26[1] : vector<16x16xf16> from vector<2x16x16xf16> %29 = vector.extract_strided_slice %27 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %30 = vector.extract_strided_slice %27 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %31 = vector.extract_strided_slice %28 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %32 = vector.extract_strided_slice %28 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> - %33 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<32x16xf16> - %34 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<32x16xf16> + %33 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16> -> vector<32x16xf16> + %34 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16> -> vector<32x16xf16> %35 = vector.transpose %33, [1, 0] : vector<32x16xf16> to vector<16x32xf16> %36 = vector.shape_cast %35 {packed} : vector<16x32xf16> to vector<512xf16> %37 = vector.shuffle %36, %36 [0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47, 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63, 64, 96, 65, 97, 66, 98, 67, 99, 68, 100, 69, 101, 70, 102, 71, 103, 72, 104, 73, 105, 74, 106, 75, 107, 76, 108, 77, 109, 78, 110, 79, 111, 80, 112, 81, 113, 82, 114, 83, 115, 84, 116, 85, 117, 86, 118, 87, 119, 88, 120, 89, 121, 90, 122, 91, 123, 92, 124, 93, 125, 94, 126, 95, 127, 128, 160, 129, 161, 130, 162, 131, 
163, 132, 164, 133, 165, 134, 166, 135, 167, 136, 168, 137, 169, 138, 170, 139, 171, 140, 172, 141, 173, 142, 174, 143, 175, 144, 176, 145, 177, 146, 178, 147, 179, 148, 180, 149, 181, 150, 182, 151, 183, 152, 184, 153, 185, 154, 186, 155, 187, 156, 188, 157, 189, 158, 190, 159, 191, 192, 224, 193, 225, 194, 226, 195, 227, 196, 228, 197, 229, 198, 230, 199, 231, 200, 232, 201, 233, 202, 234, 203, 235, 204, 236, 205, 237, 206, 238, 207, 239, 208, 240, 209, 241, 210, 242, 211, 243, 212, 244, 213, 245, 214, 246, 215, 247, 216, 248, 217, 249, 218, 250, 219, 251, 220, 252, 221, 253, 222, 254, 223, 255, 256, 288, 257, 289, 258, 290, 259, 291, 260, 292, 261, 293, 262, 294, 263, 295, 264, 296, 265, 297, 266, 298, 267, 299, 268, 300, 269, 301, 270, 302, 271, 303, 272, 304, 273, 305, 274, 306, 275, 307, 276, 308, 277, 309, 278, 310, 279, 311, 280, 312, 281, 313, 282, 314, 283, 315, 284, 316, 285, 317, 286, 318, 287, 319, 320, 352, 321, 353, 322, 354, 323, 355, 324, 356, 325, 357, 326, 358, 327, 359, 328, 360, 329, 361, 330, 362, 331, 363, 332, 364, 333, 365, 334, 366, 335, 367, 336, 368, 337, 369, 338, 370, 339, 371, 340, 372, 341, 373, 342, 374, 343, 375, 344, 376, 345, 377, 346, 378, 347, 379, 348, 380, 349, 381, 350, 382, 351, 383, 384, 416, 385, 417, 386, 418, 387, 419, 388, 420, 389, 421, 390, 422, 391, 423, 392, 424, 393, 425, 394, 426, 395, 427, 396, 428, 397, 429, 398, 430, 399, 431, 400, 432, 401, 433, 402, 434, 403, 435, 404, 436, 405, 437, 406, 438, 407, 439, 408, 440, 409, 441, 410, 442, 411, 443, 412, 444, 413, 445, 414, 446, 415, 447, 448, 480, 449, 481, 450, 482, 451, 483, 452, 484, 453, 485, 454, 486, 455, 487, 456, 488, 457, 489, 458, 490, 459, 491, 460, 492, 461, 493, 462, 494, 463, 495, 464, 496, 465, 497, 466, 498, 467, 499, 468, 500, 469, 501, 470, 502, 471, 503, 472, 504, 473, 505, 474, 506, 475, 507, 476, 508, 477, 509, 478, 510, 479, 511] {packed} : vector<512xf16>, vector<512xf16> @@ -84,37 +84,38 @@ module @gemm attributes {gpu.container_module} { %54 = xegpu.dpas %32, %46, %53 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> %55 = vector.shuffle %48, %52 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> %56 = vector.shuffle %50, %54 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8x16xf32>, vector<8x16xf32> - %57 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> - %58 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %59 = xegpu.update_nd_offset %arg6, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - scf.yield %57, %58, %59, %55, %56 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<16x16xf32>, vector<16x16xf32> + %57 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %58 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16> + %59 = xegpu.update_nd_offset %arg6, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16> + scf.yield %57, %58, %59, %55, %56 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16>, !xegpu.tensor_desc<32x16xf16>, vector<16x16xf32>, vector<16x16xf32> } %18 = vector.extract_strided_slice %17#3 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %19 = vector.extract_strided_slice 
%17#3 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %20 = vector.extract_strided_slice %17#4 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %21 = vector.extract_strided_slice %17#4 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> - xegpu.store_nd %18, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %20, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %19, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %21, %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + xegpu.store_nd %18, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %20, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %19, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %21, %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f16 - %cst_0 = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 %alloc = memref.alloc() : memref<256x256xf16> - %alloc_1 = memref.alloc() : memref<256x256xf16> - %alloc_2 = memref.alloc() : memref<256x256xf32> + %alloc_2 = memref.alloc() : memref<256x256xf16> %alloc_3 = memref.alloc() : memref<256x256xf32> + %alloc_4 = memref.alloc() : memref<256x256xf32> scf.for %arg0 = %c0 to %c256 step %c1 { scf.for %arg1 = %c0 to %c256 step %c1 { %1 = index.castu %arg1 : index to i16 %2 = arith.uitofp %1 : i16 to f16 - memref.store %2, %alloc_1[%arg0, %arg1] : memref<256x256xf16> + memref.store %2, %alloc_2[%arg0, %arg1] : memref<256x256xf16> } } scf.for %arg0 = %c0 to %c256 step %c1 { @@ -123,43 +124,42 @@ module @gemm attributes {gpu.container_module} { %2 = index.castu %arg1 : index to i32 %3 = arith.cmpi eq, %1, %2 : i32 scf.if %3 { - memref.store %cst_0, %alloc[%arg0, %arg1] : memref<256x256xf16> + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<256x256xf16> } else { - memref.store %cst, %alloc[%arg0, %arg1] : memref<256x256xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<256x256xf16> } } } - %cst_4 = arith.constant 0.000000e+00 : f32 scf.for %arg0 = %c0 to %c256 step %c1 { scf.for %arg1 = %c0 to %c256 step %c1 { - memref.store %cst_4, %alloc_2[%arg0, %arg1] : memref<256x256xf32> - memref.store %cst_4, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + 
memref.store %cst, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } scf.for %arg0 = %c0 to %c256 step %c1 { scf.for %arg1 = %c0 to %c256 step %c1 { - %1 = memref.load %alloc_3[%arg0, %arg1] : memref<256x256xf32> + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<256x256xf32> %2 = scf.for %arg2 = %c0 to %c256 step %c1 iter_args(%arg3 = %1) -> (f32) { %3 = memref.load %alloc[%arg0, %arg2] : memref<256x256xf16> - %4 = memref.load %alloc_1[%arg1, %arg2] : memref<256x256xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<256x256xf16> %5 = arith.mulf %3, %4 : f16 %6 = arith.extf %5 : f16 to f32 %7 = arith.addf %6, %arg3 : f32 scf.yield %7 : f32 } - memref.store %2, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } - %0 = call @test(%alloc, %alloc_1, %alloc_2) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> %cast = memref.cast %0 : memref<256x256xf32> to memref<*xf32> - %cast_5 = memref.cast %alloc_3 : memref<256x256xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] + %cast_5 = memref.cast %alloc_4 : memref<256x256xf32> to memref<*xf32> call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () // call @printMemrefF32(%cast) : (memref<*xf32>) -> () memref.dealloc %alloc : memref<256x256xf16> - memref.dealloc %alloc_1 : memref<256x256xf16> - memref.dealloc %alloc_2 : memref<256x256xf32> + memref.dealloc %alloc_2 : memref<256x256xf16> memref.dealloc %alloc_3 : memref<256x256xf32> + memref.dealloc %alloc_4 : memref<256x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/optimize_transpose_array_length.mlir b/test/Integration/Dialect/XeGPU/VC/optimize_transpose_array_length.mlir index c44f3c7b8..7485a1a56 100644 --- a/test/Integration/Dialect/XeGPU/VC/optimize_transpose_array_length.mlir +++ b/test/Integration/Dialect/XeGPU/VC/optimize_transpose_array_length.mlir @@ -1,28 +1,28 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %memref = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %arg0, %memref : memref<256x256xf16> to memref<256x256xf16> - %memref_0 = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %arg1, %memref_0 : memref<256x256xf16> 
to memref<256x256xf16> - %memref_1 = gpu.alloc host_shared () : memref<256x256xf32> - memref.copy %arg2, %memref_1 : memref<256x256xf32> to memref<256x256xf32> + %memref = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref, %arg0 : memref<256x256xf16>, memref<256x256xf16> + %memref_0 = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x256xf16>, memref<256x256xf16> + %memref_1 = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_1, %arg2 : memref<256x256xf32>, memref<256x256xf32> gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<256x256xf16>, %memref_0 : memref<256x256xf16>, %memref_1 : memref<256x256xf32>) gpu.dealloc %memref : memref<256x256xf16> gpu.dealloc %memref_0 : memref<256x256xf16> - return %memref_1 : memref<256x256xf32> + %alloc = memref.alloc() : memref<256x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_1 : memref<256x256xf32> + return %alloc : memref<256x256xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index @@ -34,32 +34,32 @@ module @gemm attributes {gpu.container_module} { %1 = arith.muli %block_id_y, %c32 : index %2 = arith.addi %0, %c0 : index %3 = arith.addi %1, %c0 : index - %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> %5 = arith.addi %1, %c16 : index - %6 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + %6 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> %c8 = arith.constant 8 : index %7 = arith.addi %0, %c8 : index - %8 = xegpu.create_nd_tdesc %arg2[%7, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %9 = xegpu.create_nd_tdesc %arg2[%7, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - %10 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> - %11 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> - %12 = xegpu.load_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> -> vector<16x16xf32> - %13 = xegpu.load_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32, #xegpu.block_tdesc_attr> -> vector<16x16xf32> + %8 = xegpu.create_nd_tdesc %arg2[%7, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + %9 = xegpu.create_nd_tdesc %arg2[%7, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + %10 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32> + %11 = xegpu.create_nd_tdesc %arg2[%2, %5] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32> + %12 = xegpu.load_nd %10 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, 
l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32> + %13 = xegpu.load_nd %11 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf32> -> vector<16x16xf32> %14 = vector.extract_strided_slice %12 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %15 = vector.extract_strided_slice %12 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %16 = vector.extract_strided_slice %13 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> %17 = vector.extract_strided_slice %13 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf32> to vector<8x16xf32> - %18 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> - %19 = xegpu.create_nd_tdesc %arg1[%3, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - %20:6 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %18, %arg5 = %19, %arg6 = %14, %arg7 = %16, %arg8 = %15, %arg9 = %17) -> (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { - %21 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> + %18 = xegpu.create_nd_tdesc %arg0[%2, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %19 = xegpu.create_nd_tdesc %arg1[%3, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + %20:6 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %18, %arg5 = %19, %arg6 = %14, %arg7 = %16, %arg8 = %15, %arg9 = %17) -> (!xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { + %21 = xegpu.load_nd %arg4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> %22 = vector.extract %21[0] : vector<16x16xf16> from vector<2x16x16xf16> %23 = vector.extract %21[1] : vector<16x16xf16> from vector<2x16x16xf16> %24 = vector.extract_strided_slice %22 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %25 = vector.extract_strided_slice %22 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %26 = vector.extract_strided_slice %23 {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> %27 = vector.extract_strided_slice %23 {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16> - %28 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %28 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> %29 = vector.extract %28[0] : vector<32x16xf16> from vector<2x32x16xf16> %30 = vector.extract %28[1] : vector<32x16xf16> from vector<2x32x16xf16> 
%31 = vector.extract_strided_slice %29 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> @@ -78,32 +78,33 @@ module @gemm attributes {gpu.container_module} { %44 = xegpu.dpas %27, %37, %43 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> %45 = xegpu.dpas %25, %36, %arg9 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> %46 = xegpu.dpas %27, %38, %45 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> - %47 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> - %48 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> - scf.yield %47, %48, %40, %42, %44, %46 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + %47 = xegpu.update_nd_offset %arg4, [%c0, %c32] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %48 = xegpu.update_nd_offset %arg5, [%c0, %c32] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> + scf.yield %47, %48, %40, %42, %44, %46 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> } - xegpu.store_nd %20#2, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %20#3, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %20#4, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> - xegpu.store_nd %20#5, %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + xegpu.store_nd %20#2, %4 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %20#3, %6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %20#4, %8 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %20#5, %9 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c256 = arith.constant 256 : index - %cst = arith.constant 0.000000e+00 : f16 - %cst_0 = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 %alloc = memref.alloc() : memref<256x256xf16> - %alloc_1 = memref.alloc() : memref<256x256xf16> - %alloc_2 = memref.alloc() : memref<256x256xf32> + %alloc_2 = memref.alloc() : memref<256x256xf16> %alloc_3 = memref.alloc() : memref<256x256xf32> + %alloc_4 = memref.alloc() : memref<256x256xf32> scf.for %arg0 = %c0 to %c256 step 
%c1 { scf.for %arg1 = %c0 to %c256 step %c1 { %1 = index.castu %arg1 : index to i16 %2 = arith.uitofp %1 : i16 to f16 - memref.store %2, %alloc_1[%arg0, %arg1] : memref<256x256xf16> + memref.store %2, %alloc_2[%arg0, %arg1] : memref<256x256xf16> } } scf.for %arg0 = %c0 to %c256 step %c1 { @@ -112,42 +113,41 @@ module @gemm attributes {gpu.container_module} { %2 = index.castu %arg1 : index to i32 %3 = arith.cmpi eq, %1, %2 : i32 scf.if %3 { - memref.store %cst_0, %alloc[%arg0, %arg1] : memref<256x256xf16> + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<256x256xf16> } else { - memref.store %cst, %alloc[%arg0, %arg1] : memref<256x256xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<256x256xf16> } } } - %cst_4 = arith.constant 0.000000e+00 : f32 scf.for %arg0 = %c0 to %c256 step %c1 { scf.for %arg1 = %c0 to %c256 step %c1 { - memref.store %cst_4, %alloc_2[%arg0, %arg1] : memref<256x256xf32> - memref.store %cst_4, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } scf.for %arg0 = %c0 to %c256 step %c1 { scf.for %arg1 = %c0 to %c256 step %c1 { - %1 = memref.load %alloc_3[%arg0, %arg1] : memref<256x256xf32> + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<256x256xf32> %2 = scf.for %arg2 = %c0 to %c256 step %c1 iter_args(%arg3 = %1) -> (f32) { %3 = memref.load %alloc[%arg0, %arg2] : memref<256x256xf16> - %4 = memref.load %alloc_1[%arg1, %arg2] : memref<256x256xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<256x256xf16> %5 = arith.mulf %3, %4 : f16 %6 = arith.extf %5 : f16 to f32 %7 = arith.addf %6, %arg3 : f32 scf.yield %7 : f32 } - memref.store %2, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } - %0 = call @test(%alloc, %alloc_1, %alloc_2) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> %cast = memref.cast %0 : memref<256x256xf32> to memref<*xf32> - %cast_5 = memref.cast %alloc_3 : memref<256x256xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] + %cast_5 = memref.cast %alloc_4 : memref<256x256xf32> to memref<*xf32> call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () memref.dealloc %alloc : memref<256x256xf16> - memref.dealloc %alloc_1 : memref<256x256xf16> - memref.dealloc %alloc_2 : memref<256x256xf32> + memref.dealloc %alloc_2 : memref<256x256xf16> memref.dealloc %alloc_3 : memref<256x256xf32> + memref.dealloc %alloc_4 : memref<256x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/preop_dpas.mlir b/test/Integration/Dialect/XeGPU/VC/preop_dpas.mlir index 3b3f126ec..76d63ca9c 100644 --- a/test/Integration/Dialect/XeGPU/VC/preop_dpas.mlir +++ b/test/Integration/Dialect/XeGPU/VC/preop_dpas.mlir @@ -1,103 +1,88 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner 
--requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_32x32xf16 : memref<32x32xf16> = dense<1.0> - memref.global "private" @__Bconstant_32x32xf16 : memref<32x32xf16> = dense<2.0> + memref.global "private" @__constant_32x32xf16 : memref<32x32xf16> = dense<1.000000e+00> + memref.global "private" @__Bconstant_32x32xf16 : memref<32x32xf16> = dense<2.000000e+00> func.func @test(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { - %c64 = arith.constant 64 : index + %c16 = arith.constant 16 : index %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg0, %memref : memref<32x32xf16> to memref<32x32xf16> - %memref_0 = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg1, %memref_0 : memref<32x32xf16> to memref<32x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>, %memref_1 : memref<32x32xf32>) + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<32x32xf16>, memref<32x32xf16> + %memref_1 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>, %memref_1 : memref<32x32xf32>) gpu.dealloc %memref : memref<32x32xf16> gpu.dealloc %memref_0 : memref<32x32xf16> - return %memref_1 : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_1 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref_1 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - -gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>, %C: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - - + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c8 = arith.constant 8 : index - %cst = arith.constant dense<1.0> : vector<8x16xf16> - %0 = gpu.block_id x - %1 = gpu.block_id y - - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - - %4 = xegpu.create_nd_tdesc %C[%2, %3] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - - %6 = scf.for %arg3 = %c0 to %c32 step %c16 
iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %A0 = xegpu.create_nd_tdesc %A[%2, %arg3] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %A0_val = xegpu.load_nd %A0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - - %B0 = xegpu.create_nd_tdesc %B[%arg3, %3] : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %B0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - - %A0_preop = arith.addf %A0_val, %cst : vector<8x16xf16> - - %dpas0 = xegpu.dpas %A0_preop, %B0_val , %arg4: vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %dpas0 : vector<8x16xf32> + %cst = arith.constant dense<1.000000e+00> : vector<8x16xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %7 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %8 = xegpu.load_nd %7 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %9 = arith.addf %6, %cst : vector<8x16xf16> + %10 = xegpu.dpas %9, %8, %arg4 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %10 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - - %A = memref.get_global @__constant_32x32xf16 : memref<32x32xf16> - %B = memref.get_global @__Bconstant_32x32xf16 : memref<32x32xf16> - %C_ref = memref.alloc() : memref<32x32xf32> - // caculate the result C matrix - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %acc = arith.constant 0.0 : f32 - %res = scf.for %k = %c0 to %c32 step %c1 iter_args(%acc1 = %acc) -> f32 { - %a = memref.load %A[%i, %k] : memref<32x32xf16> - %b = memref.load %B[%k, %j] : memref<32x32xf16> // adjust for preop in GPU kernel, where we add 1 between load and dpas - %cst1 = arith.constant 1.0 : f16 - %a_adj = arith.addf %a, %cst1 : f16 - %c = arith.mulf %a_adj, %b : f16 - %cc = arith.extf %c : f16 to f32 - %ccc = arith.addf %cc, %acc1 : f32 - scf.yield %ccc : f32 + %0 = memref.get_global @__constant_32x32xf16 : memref<32x32xf16> + %1 = memref.get_global @__Bconstant_32x32xf16 : memref<32x32xf16> + %alloc = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %3 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst_0) -> (f32) { + %4 = memref.load %0[%arg0, %arg2] : memref<32x32xf16> + %5 = memref.load %1[%arg2, %arg1] : memref<32x32xf16> + %6 = arith.addf %4, %cst : f16 + %7 = arith.mulf %6, %5 : f16 + %8 = arith.extf %7 : f16 to f32 + %9 = arith.addf %8, %arg3 : f32 + scf.yield %9 : f32 } - memref.store %res, %C_ref[%i, %j] : memref<32x32xf32> + memref.store %3, %alloc[%arg0, %arg1] : 
memref<32x32xf32> } } - - %2 = call @test(%A, %B) : (memref<32x32xf16>, memref<32x32xf16>) -> memref<32x32xf32> + %2 = call @test(%0, %1) : (memref<32x32xf16>, memref<32x32xf16>) -> memref<32x32xf32> %cast = memref.cast %2 : memref<32x32xf32> to memref<*xf32> // call @printMemrefF32(%cast) : (memref<*xf32>) -> () - %cast_ref = memref.cast %C_ref : memref<32x32xf32> to memref<*xf32> // call @printMaxErrorF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () // call @printMemrefF32(%cast_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %cast_1 = memref.cast %alloc : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/ranked_dynamic_memref.vc.mlir b/test/Integration/Dialect/XeGPU/VC/ranked_dynamic_memref.vc.mlir index 8e6d08dbf..98066c495 100644 --- a/test/Integration/Dialect/XeGPU/VC/ranked_dynamic_memref.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/ranked_dynamic_memref.vc.mlir @@ -1,55 +1,49 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index %c1 = arith.constant 1 : index - %memref_0 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref_0 : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - %memref_0_cast = memref.cast %memref_0 : memref<8x16xf32> to memref - %memref_1_cast = memref.cast %memref_1 : memref<8x16xf32> to memref - %dim0 = arith.constant 8 : index - %dim1 = arith.constant 16 : index - %stride0 = arith.constant 16 : index - %stride1 = arith.constant 1 : index - %x = arith.constant 0 : index - %y = arith.constant 0 : index - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref, %memref_1_cast : memref, %dim0 : index, %dim1 : index, %stride0 : index, %stride1 : index, %x : index, %y : index) - gpu.dealloc %memref_0 : memref<8x16xf32> - return %memref_1 : memref<8x16xf32> + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + %cast = memref.cast %memref : memref<8x16xf32> to memref + %cast_1 = memref.cast %memref_0 : 
memref<8x16xf32> to memref + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%cast : memref, %cast_1 : memref, %c8 : index, %c16 : index, %c16 : index, %c1 : index, %c0 : index, %c0 : index) + gpu.dealloc %memref : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0 : memref, %arg1: memref, %dim0: index, %dim1: index, %stride0: index, %stride1: index, %x: index, %y: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %1 = xegpu.create_nd_tdesc %arg0[%x, %y], shape: [%dim0, %dim1], strides: [%stride0, %stride1] : memref -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - %6 = xegpu.create_nd_tdesc %arg1[%x, %y], shape: [%dim0, %dim1], strides: [%stride0, %stride1] : memref -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %2, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref, %arg1: memref, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = xegpu.create_nd_tdesc %arg0[%arg6, %arg7], shape : [%arg2, %arg3], strides : [%arg4, %arg5] : memref -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 <{l1_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg1[%arg6, %arg7], shape : [%arg2, %arg3], strides : [%arg4, %arg5] : memref -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %1, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32> - %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32> - %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32> // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<8x16xf32> + %cst = arith.constant 5.000000e-01 : f32 + %cst_0 = arith.constant -5.000000e-01 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst_0, %cst, %false) : (memref<*xf32>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x16xf32>) -> memref<8x16xf32> + %cast_1 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_2, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x16xf32> return } 
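// A minimal illustrative sketch (not part of the test above) of the dynamic-shape form of
// xegpu.create_nd_tdesc that this kernel exercises, where the buffer extent and strides are
// passed as SSA values instead of being encoded in the memref type. %src, %d0 and %d1 are
// assumed placeholders for a row-major buffer of runtime size:
//
//   %td = xegpu.create_nd_tdesc %src[%c0, %c0], shape : [%d0, %d1], strides : [%d1, %c1]
//       : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
//
// The shape/strides operands describe the whole underlying allocation, while the result type
// fixes the 8x16 tile that each load_nd/store_nd moves.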
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir b/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir index 962a75e29..9916a34c2 100644 --- a/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir +++ b/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir @@ -1,25 +1,24 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { func.func @test(%A: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index - %A_gpu = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %A, %A_gpu : memref<32x32xf16> to memref<32x32xf16> - %B_gpu = gpu.alloc host_shared () : memref<32x32xf16> + %A_gpu = gpu.alloc (): memref<32x32xf16> + gpu.memcpy %A_gpu, %A : memref<32x32xf16>, memref<32x32xf16> + %B_gpu = gpu.alloc () : memref<32x32xf16> gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c2, %c1) args(%A_gpu : memref<32x32xf16>, %B_gpu : memref<32x32xf16>) + %host_result = memref.alloc () : memref<32x32xf16> + gpu.memcpy %host_result, %B_gpu : memref<32x32xf16>, memref<32x32xf16> gpu.dealloc %A_gpu : memref<32x32xf16> - return %B_gpu : memref<32x32xf16> + gpu.dealloc %B_gpu : memref<32x32xf16> + return %host_result : memref<32x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -79,6 +78,7 @@ module @gemm attributes {gpu.container_module} { // CHECK: [ALLCLOSE: TRUE] call @printAllcloseF16(%cast, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () memref.dealloc %A : memref<32x32xf16> + memref.dealloc %B : memref<32x32xf16> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/strided_memref_1d.mlir b/test/Integration/Dialect/XeGPU/VC/strided_memref_1d.mlir index fdd4f516f..213899b1d 100644 --- a/test/Integration/Dialect/XeGPU/VC/strided_memref_1d.mlir +++ b/test/Integration/Dialect/XeGPU/VC/strided_memref_1d.mlir @@ -1,65 +1,59 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main 
\ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__Aconstant_8x32xf32 : memref<8x32xf32> = dense<1.0> - memref.global "private" @__Bconstant_8x32xf32 : memref<8x32xf32> = dense<2.0> + memref.global "private" @__Aconstant_8x32xf32 : memref<8x32xf32> = dense<1.000000e+00> + memref.global "private" @__Bconstant_8x32xf32 : memref<8x32xf32> = dense<2.000000e+00> func.func @test(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf32>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index - %c0_f32 = arith.constant 0.0 : f32 - - %A = gpu.alloc host_shared () : memref<8x32xf32> - memref.copy %arg0, %A : memref<8x32xf32> to memref<8x32xf32> - %B = gpu.alloc host_shared () : memref<8x32xf32> - memref.copy %arg1, %B : memref<8x32xf32> to memref<8x32xf32> - - %C = gpu.alloc host_shared () : memref<8x32xf32> - %C_unranked = memref.cast %C : memref<8x32xf32> to memref<*xf32> - call @fillResource1DF32(%C_unranked, %c0_f32) : (memref<*xf32>, f32) -> () - // Create the strided memrefs from A, B, C : first 16 elements of each row - %A_strided = memref.subview %A[0, 0][8, 16][1, 1] : memref<8x32xf32> to memref<8x16xf32, strided<[32,1], offset: 0>> - %B_strided = memref.subview %B[0, 0][8, 16][1, 1] : memref<8x32xf32> to memref<8x16xf32, strided<[32,1], offset: 0>> - %C_strided = memref.subview %C[0, 0][8, 16][1, 1] : memref<8x32xf32> to memref<8x16xf32, strided<[32,1], offset: 0>> - - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%A_strided : memref<8x16xf32, strided<[32,1], offset: 0>>, %B_strided : memref<8x16xf32, strided<[32,1], offset: 0>>, %C_strided : memref<8x16xf32, strided<[32,1], offset: 0>>) - gpu.dealloc %A : memref<8x32xf32> - gpu.dealloc %B : memref<8x32xf32> - return %C : memref<8x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + %memref = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref, %arg0 : memref<8x32xf32>, memref<8x32xf32> + %memref_0 = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref_0, %arg1 : memref<8x32xf32>, memref<8x32xf32> + %memref_host = memref.alloc() : memref<8x32xf32> + %cast_host = memref.cast %memref_host : memref<8x32xf32> to memref<*xf32> + call @fillResource1DF32(%cast_host, %cst) : (memref<*xf32>, f32) -> () + %memref_1 = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref_1, %memref_host : memref<8x32xf32>, memref<8x32xf32> + %subview = memref.subview %memref[0, 0] [8, 16] [1, 1] : memref<8x32xf32> to memref<8x16xf32, strided<[32, 1]>> + %subview_2 = memref.subview %memref_0[0, 0] [8, 16] [1, 1] : memref<8x32xf32> to memref<8x16xf32, strided<[32, 1]>> + %subview_3 = memref.subview %memref_1[0, 0] [8, 16] [1, 1] : memref<8x32xf32> to memref<8x16xf32, strided<[32, 1]>> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%subview : memref<8x16xf32, strided<[32, 1]>>, %subview_2 : memref<8x16xf32, strided<[32, 1]>>, 
%subview_3 : memref<8x16xf32, strided<[32, 1]>>) + memref.dealloc %memref_host : memref<8x32xf32> + gpu.dealloc %memref : memref<8x32xf32> + gpu.dealloc %memref_0 : memref<8x32xf32> + %alloc = memref.alloc() : memref<8x32xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x32xf32>, memref<8x32xf32> + gpu.dealloc %memref_1 : memref<8x32xf32> + return %alloc : memref<8x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0: memref<8x16xf32, strided<[32,1], offset: 0>>, %arg1: memref<8x16xf32, strided<[32,1], offset: 0>>, %arg2: memref<8x16xf32, strided<[32,1], offset: 0>>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %thread_id_x = gpu.thread_id x - - %0 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x16xf32, strided<[32,1], offset: 0>> -> !xegpu.tensor_desc<16xf32> + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<8x16xf32, strided<[32, 1]>>, %arg1: memref<8x16xf32, strided<[32, 1]>>, %arg2: memref<8x16xf32, strided<[32, 1]>>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %thread_id_x = gpu.thread_id x + %0 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x16xf32, strided<[32, 1]>> -> !xegpu.tensor_desc<16xf32> %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> - %2 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x16xf32, strided<[32,1], offset: 0>> -> !xegpu.tensor_desc<16xf32> + %2 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x16xf32, strided<[32, 1]>> -> !xegpu.tensor_desc<16xf32> %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xf32> -> vector<16xf32> %4 = arith.addf %3, %1 : vector<16xf32> - %5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0] : memref<8x16xf32, strided<[32,1], offset: 0>> -> !xegpu.tensor_desc<16xf32> + %5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0] : memref<8x16xf32, strided<[32, 1]>> -> !xegpu.tensor_desc<16xf32> xegpu.store_nd %4, %5 : vector<16xf32>, !xegpu.tensor_desc<16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - // Allocate/get regular row major memrefs - %A = memref.get_global @__Aconstant_8x32xf32 : memref<8x32xf32> - %B = memref.get_global @__Bconstant_8x32xf32 : memref<8x32xf32> - - %result = call @test(%A, %B) : (memref<8x32xf32>, memref<8x32xf32>) -> memref<8x32xf32> - - %result_cast = memref.cast %result : memref<8x32xf32> to memref<*xf32> - call @printMemrefF32(%result_cast) : (memref<*xf32>) -> () // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} // CHECK-NEXT:[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - + %0 = memref.get_global @__Aconstant_8x32xf32 : memref<8x32xf32> + %1 = memref.get_global @__Bconstant_8x32xf32 : memref<8x32xf32> + %2 = call @test(%0, %1) : (memref<8x32xf32>, memref<8x32xf32>) -> memref<8x32xf32> + %cast = memref.cast %2 : memref<8x32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } func.func private @fillResource1DF32(memref<*xf32>, f32) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/strided_memref_2d.mlir b/test/Integration/Dialect/XeGPU/VC/strided_memref_2d.mlir index 791315f22..91dcf4bc0 100644 --- a/test/Integration/Dialect/XeGPU/VC/strided_memref_2d.mlir +++ b/test/Integration/Dialect/XeGPU/VC/strided_memref_2d.mlir @@ -1,85 +1,73 @@ -// RUN: %python_executable %imex_runner 
--requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__Aconstant_32x64xf16 : memref<32x64xf16> = dense<1.0> - memref.global "private" @__Bconstant_32x64xf16 : memref<32x64xf16> = dense<2.0> + memref.global "private" @__Aconstant_32x64xf16 : memref<32x64xf16> = dense<1.000000e+00> + memref.global "private" @__Bconstant_32x64xf16 : memref<32x64xf16> = dense<2.000000e+00> func.func @test(%arg0: memref<32x64xf16>, %arg1: memref<32x64xf16>) -> memref<32x64xf32> attributes {llvm.emit_c_interface} { - %c64 = arith.constant 64 : index %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index - %c0_f32 = arith.constant 0.0 : f32 - + %cst = arith.constant 0.000000e+00 : f32 %c1 = arith.constant 1 : index - %A = gpu.alloc host_shared () : memref<32x64xf16> - memref.copy %arg0, %A : memref<32x64xf16> to memref<32x64xf16> - %B = gpu.alloc host_shared () : memref<32x64xf16> - memref.copy %arg1, %B : memref<32x64xf16> to memref<32x64xf16> - - %C = gpu.alloc host_shared () : memref<32x64xf32> - %C_unranked = memref.cast %C : memref<32x64xf32> to memref<*xf32> - call @fillResource1DF32(%C_unranked, %c0_f32) : (memref<*xf32>, f32) -> () - // Create the strided memrefs from A, B, C : first 32 elements of each row - %A_strided = memref.subview %A[0, 0][32, 32][1, 1] : memref<32x64xf16> to memref<32x32xf16, strided<[64,1], offset: 0>> - %B_strided = memref.subview %B[0, 0][32, 32][1, 1] : memref<32x64xf16> to memref<32x32xf16, strided<[64,1], offset: 0>> - %C_strided = memref.subview %C[0, 0][32, 32][1, 1] : memref<32x64xf32> to memref<32x32xf32, strided<[64,1], offset: 0>> - - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%A_strided : memref<32x32xf16, strided<[64,1], offset: 0>>, %B_strided : memref<32x32xf16, strided<[64,1], offset: 0>>, %C_strided : memref<32x32xf32, strided<[64,1], offset: 0>>) - gpu.dealloc %A : memref<32x64xf16> - gpu.dealloc %B : memref<32x64xf16> - return %C : memref<32x64xf32> + %memref = gpu.alloc () : memref<32x64xf16> + gpu.memcpy %memref, %arg0 : memref<32x64xf16>, memref<32x64xf16> + %memref_0 = gpu.alloc () : memref<32x64xf16> + gpu.memcpy %memref_0, %arg1 : memref<32x64xf16>, memref<32x64xf16> + %memref_host = memref.alloc() : memref<32x64xf32> + %cast_host = memref.cast %memref_host : memref<32x64xf32> to memref<*xf32> + call @fillResource1DF32(%cast_host, %cst) : (memref<*xf32>, f32) -> () + %memref_1 = gpu.alloc () : memref<32x64xf32> + gpu.memcpy %memref_1, %memref_host : memref<32x64xf32>, memref<32x64xf32> + %subview = memref.subview %memref[0, 0] [32, 32] [1, 1] : memref<32x64xf16> to memref<32x32xf16, strided<[64, 1]>> + %subview_2 = memref.subview %memref_0[0, 0] [32, 32] [1, 1] : memref<32x64xf16> to 
memref<32x32xf16, strided<[64, 1]>> + %subview_3 = memref.subview %memref_1[0, 0] [32, 32] [1, 1] : memref<32x64xf32> to memref<32x32xf32, strided<[64, 1]>> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%subview : memref<32x32xf16, strided<[64, 1]>>, %subview_2 : memref<32x32xf16, strided<[64, 1]>>, %subview_3 : memref<32x32xf32, strided<[64, 1]>>) + memref.dealloc %memref_host : memref<32x64xf32> + gpu.dealloc %memref : memref<32x64xf16> + gpu.dealloc %memref_0 : memref<32x64xf16> + %alloc = memref.alloc() : memref<32x64xf32> + gpu.memcpy %alloc, %memref_1 : memref<32x64xf32>, memref<32x64xf32> + gpu.dealloc %memref_1 : memref<32x64xf32> + return %alloc : memref<32x64xf32> } - -gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x32xf16, strided<[64,1], offset: 0>>, %B: memref<32x32xf16, strided<[64,1], offset: 0>>, %C: memref<32x32xf32, strided<[64,1], offset: 0>>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<32x32xf16, strided<[64, 1]>>, %arg1: memref<32x32xf16, strided<[64, 1]>>, %arg2: memref<32x32xf32, strided<[64, 1]>>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c8 = arith.constant 8 : index - %cst = arith.constant dense<1.0> : vector<8x16xf16> - %0 = gpu.block_id x - %1 = gpu.block_id y - - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - - %4 = xegpu.create_nd_tdesc %C[%2, %3] : memref<32x32xf32, strided<[64,1], offset: 0>> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - - %6 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - %A0 = xegpu.create_nd_tdesc %A[%2, %arg3] : memref<32x32xf16, strided<[64,1], offset: 0>> -> !xegpu.tensor_desc<8x16xf16> - %A0_val = xegpu.load_nd %A0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - - %B0 = xegpu.create_nd_tdesc %B[%arg3, %3] : memref<32x32xf16, strided<[64,1], offset: 0>> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %B0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - - %A0_preop = arith.addf %A0_val, %cst : vector<8x16xf16> - - %dpas0 = xegpu.dpas %A0_preop, %B0_val , %arg4: vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - scf.yield %dpas0 : vector<8x16xf32> + %cst = arith.constant dense<1.000000e+00> : vector<8x16xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<32x32xf32, strided<[64, 1]>> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<32x32xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %7 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<32x32xf16, strided<[64, 1]>> -> 
!xegpu.tensor_desc<16x16xf16> + %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = arith.addf %6, %cst : vector<8x16xf16> + %10 = xegpu.dpas %9, %8, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %10 : vector<8x16xf32> } - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { // Allocate/get regular row major memrefs - %A = memref.get_global @__Aconstant_32x64xf16 : memref<32x64xf16> - %B = memref.get_global @__Bconstant_32x64xf16 : memref<32x64xf16> - - %result = call @test(%A, %B) : (memref<32x64xf16>, memref<32x64xf16>) -> memref<32x64xf32> - %result_cast = memref.cast %result : memref<32x64xf32> to memref<*xf32> - call @printMemrefF32(%result_cast) : (memref<*xf32>) -> () // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} // CHECK-NEXT:[128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + %0 = memref.get_global @__Aconstant_32x64xf16 : memref<32x64xf16> + %1 = memref.get_global @__Bconstant_32x64xf16 : memref<32x64xf16> + %2 = call @test(%0, %1) : (memref<32x64xf16>, memref<32x64xf16>) -> memref<32x64xf32> + %cast = memref.cast %2 : memref<32x64xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } func.func private @fillResource1DF32(memref<*xf32>, f32) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16-simplified.mlir b/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16-simplified.mlir index 38d7ec78e..d280178d6 100644 --- a/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16-simplified.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16-simplified.mlir @@ -1,28 +1,23 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#slm = #xegpu.scatter_tdesc_attr -#blk_slm = #xegpu.block_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16x32xf16>) -> memref<8x64xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %memref = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %arg0, %memref : memref<16x32xf16> to memref<16x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<8x64xf16> - gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_1 : memref<8x64xf16>) - + %memref = 
gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<16x32xf16>, memref<16x32xf16> + %memref_0 = gpu.alloc () : memref<8x64xf16> + gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_0 : memref<8x64xf16>) gpu.dealloc %memref : memref<16x32xf16> - return %memref_1 : memref<8x64xf16> + %alloc = memref.alloc() : memref<8x64xf16> + gpu.memcpy %alloc, %memref_0 : memref<8x64xf16>, memref<8x64xf16> + gpu.dealloc %memref_0 : memref<8x64xf16> + return %alloc : memref<8x64xf16> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { // this example is to illustrate an example of using slm to do the transpose. // the high level logic is equivalent to the following code: // %data = xegpu.load_nd %in : !xegpu.tensor_desc<16x32xf16> -> vector<16x8xf16> @@ -33,45 +28,36 @@ module @gemm attributes {gpu.container_module} { %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index - - %id = gpu.subgroup_id : index - %y_in = arith.muli %id, %c8 : index - - %in = xegpu.create_nd_tdesc %arg0[0, %y_in] : memref<16x32xf16> -> !xegpu.tensor_desc<16x8xf16> // original load is %data = xegpu.load_nd %in : !xegpu.tensor_desc<16x8xf16> -> vector<16x8xf16> // now it is transformed to use packed attribute - %data = xegpu.load_nd %in {packed} : !xegpu.tensor_desc<16x8xf16> -> vector<8x8x2xf16> - - %shapecast = vector.shape_cast %data : vector<8x8x2xf16> to vector<128xf16> - %data32b = vector.bitcast %shapecast : vector<128xf16> to vector<64xf32> - %cast = vector.shape_cast %data32b : vector<64xf32> to vector<8x8xf32> - // the following code uses slm to do the transpose. 
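// Illustrative aside, not part of the patch (function name assumed): a pure-vector
// sketch of what this kernel computes per subgroup, given the packed 8x8x2xf16 load
// above. In effect, the scattered store followed by the 1d block load through SLM
// plays the role of the shape_cast on the transposed f32 tile shown here.
func.func @logical_transpose(%packed: vector<8x8x2xf16>) -> vector<8x16xf16> {
  // view the VNNI-packed 16x8 f16 tile as an 8x8 tile of f32 "pairs"
  %flat = vector.shape_cast %packed : vector<8x8x2xf16> to vector<128xf16>
  %pairs = vector.bitcast %flat : vector<128xf16> to vector<64xf32>
  %tile = vector.shape_cast %pairs : vector<64xf32> to vector<8x8xf32>
  // swap rows and columns at f32 granularity; each element keeps its two f16s together
  %t = vector.transpose %tile, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
  // reinterpret the transposed pairs as the final 8x16 f16 result
  %t_flat = vector.shape_cast %t : vector<8x8xf32> to vector<64xf32>
  %t_f16 = vector.bitcast %t_flat : vector<64xf32> to vector<128xf16>
  %out = vector.shape_cast %t_f16 : vector<128xf16> to vector<8x16xf16>
  return %out : vector<8x16xf16>
}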
It contains 3 steps: // step1: store the data into slm using store scatter - %slm = memref.alloc() : memref<256xf32, 3> - - %base = arith.muli %id, %c64 : index - %baseVec = vector.broadcast %base : index to vector<8xindex> - %staticOff = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56]> : vector<8xindex> - %offsets = arith.addi %baseVec, %staticOff : vector<8xindex> - %slm_desc = xegpu.create_tdesc %slm, %offsets : memref<256xf32, 3>, vector<8xindex> -> !xegpu.tensor_desc<8x8xf32, #slm> - - %mask = arith.constant dense<[1, 1, 1, 1, 1, 1, 1, 1]> : vector<8xi1> - %trans = vector.transpose %cast, [1, 0] : vector<8x8xf32> to vector<8x8xf32> - xegpu.store %trans, %slm_desc, %mask : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32, #slm>, vector<8xi1> - // step2: load from slm using 1d block load - %off = arith.muli %id, %c64 : index - %slm_1d_desc = xegpu.create_nd_tdesc %slm[%off] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #blk_slm> - %data_1d = xegpu.load_nd %slm_1d_desc : !xegpu.tensor_desc<64xf32, #blk_slm> -> vector<64xf32> - // step3: simply do the shape cast to get the final result - %bitcast = vector.bitcast %data_1d : vector<64xf32> to vector<128xf16> - %transposed = vector.shape_cast %bitcast : vector<128xf16> to vector<8x16xf16> - - %out_y = arith.muli %id, %c16 : index - %out = xegpu.create_nd_tdesc %arg1[0, %out_y]: memref<8x64xf16> -> !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %transposed, %out : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + %0 = gpu.subgroup_id : index + %1 = arith.muli %0, %c8 : index + %2 = xegpu.create_nd_tdesc %arg0[0, %1] : memref<16x32xf16> -> !xegpu.tensor_desc<16x8xf16> + %3 = xegpu.load_nd %2 <{packed}> : !xegpu.tensor_desc<16x8xf16> -> vector<8x8x2xf16> + %4 = vector.shape_cast %3 : vector<8x8x2xf16> to vector<128xf16> + %5 = vector.bitcast %4 : vector<128xf16> to vector<64xf32> + %6 = vector.shape_cast %5 : vector<64xf32> to vector<8x8xf32> + %alloc = memref.alloc() : memref<256xf32, 3> + %7 = arith.muli %0, %c64 : index + %8 = vector.broadcast %7 : index to vector<8xindex> + %cst = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56]> : vector<8xindex> + %9 = arith.addi %8, %cst : vector<8xindex> + %10 = xegpu.create_tdesc %alloc, %9 : memref<256xf32, 3>, vector<8xindex> -> !xegpu.tensor_desc<8x8xf32, #xegpu.scatter_tdesc_attr> + %cst_0 = arith.constant dense : vector<8xi1> + %11 = vector.transpose %6, [1, 0] : vector<8x8xf32> to vector<8x8xf32> + xegpu.store %11, %10, %cst_0 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32, #xegpu.scatter_tdesc_attr>, vector<8xi1> + %12 = arith.muli %0, %c64 : index + %13 = xegpu.create_nd_tdesc %alloc[%12] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + %14 = xegpu.load_nd %13 : !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> -> vector<64xf32> + %15 = vector.bitcast %14 : vector<64xf32> to vector<128xf16> + %16 = vector.shape_cast %15 : vector<128xf16> to vector<8x16xf16> + %17 = arith.muli %0, %c16 : index + %18 = xegpu.create_nd_tdesc %arg1[0, %17] : memref<8x64xf16> -> !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %16, %18 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> gpu.return } } @@ -80,23 +66,16 @@ module @gemm attributes {gpu.container_module} { %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %0 = memref.alloc() : memref<16x32xf16> - - scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %mul = arith.muli %i, %c32 : index - %add = arith.addi %mul, %j : index - %int = arith.index_cast %add 
: index to i16 - %fp = arith.uitofp %int : i16 to f16 - memref.store %fp, %0[%i, %j] : memref<16x32xf16> + %alloc = memref.alloc() : memref<16x32xf16> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = arith.index_cast %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store %4, %alloc[%arg0, %arg1] : memref<16x32xf16> } } - - %2 = call @test(%0) : (memref<16x32xf16>) -> memref<8x64xf16> - %cast = memref.cast %2: memref<8x64xf16> to memref<*xf16> - - //CHECK: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480, 8, 40, 72, 104, 136, 168, 200, 232, 264, 296, 328, 360, 392, 424, 456, 488, 16, 48, 80, 112, 144, 176, 208, 240, 272, 304, 336, 368, 400, 432, 464, 496, 24, 56, 88, 120, 152, 184, 216, 248, 280, 312, 344, 376, 408, 440, 472, 504] //CHECK: [1, 33, 65, 97, 129, 161, 193, 225, 257, 289, 321, 353, 385, 417, 449, 481, 9, 41, 73, 105, 137, 169, 201, 233, 265, 297, 329, 361, 393, 425, 457, 489, 17, 49, 81, 113, 145, 177, 209, 241, 273, 305, 337, 369, 401, 433, 465, 497, 25, 57, 89, 121, 153, 185, 217, 249, 281, 313, 345, 377, 409, 441, 473, 505] //CHECK: [2, 34, 66, 98, 130, 162, 194, 226, 258, 290, 322, 354, 386, 418, 450, 482, 10, 42, 74, 106, 138, 170, 202, 234, 266, 298, 330, 362, 394, 426, 458, 490, 18, 50, 82, 114, 146, 178, 210, 242, 274, 306, 338, 370, 402, 434, 466, 498, 26, 58, 90, 122, 154, 186, 218, 250, 282, 314, 346, 378, 410, 442, 474, 506] @@ -105,9 +84,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [5, 37, 69, 101, 133, 165, 197, 229, 261, 293, 325, 357, 389, 421, 453, 485, 13, 45, 77, 109, 141, 173, 205, 237, 269, 301, 333, 365, 397, 429, 461, 493, 21, 53, 85, 117, 149, 181, 213, 245, 277, 309, 341, 373, 405, 437, 469, 501, 29, 61, 93, 125, 157, 189, 221, 253, 285, 317, 349, 381, 413, 445, 477, 509] //CHECK: [6, 38, 70, 102, 134, 166, 198, 230, 262, 294, 326, 358, 390, 422, 454, 486, 14, 46, 78, 110, 142, 174, 206, 238, 270, 302, 334, 366, 398, 430, 462, 494, 22, 54, 86, 118, 150, 182, 214, 246, 278, 310, 342, 374, 406, 438, 470, 502, 30, 62, 94, 126, 158, 190, 222, 254, 286, 318, 350, 382, 414, 446, 478, 510] //CHECK: [7, 39, 71, 103, 135, 167, 199, 231, 263, 295, 327, 359, 391, 423, 455, 487, 15, 47, 79, 111, 143, 175, 207, 239, 271, 303, 335, 367, 399, 431, 463, 495, 23, 55, 87, 119, 151, 183, 215, 247, 279, 311, 343, 375, 407, 439, 471, 503, 31, 63, 95, 127, 159, 191, 223, 255, 287, 319, 351, 383, 415, 447, 479, 511] - call @printMemrefF16(%cast): (memref<*xf16>) -> () + %0 = call @test(%alloc) : (memref<16x32xf16>) -> memref<8x64xf16> + %cast = memref.cast %0 : memref<8x64xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16.mlir b/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16.mlir index 3b56128a7..289e34bb0 100644 --- a/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f16.mlir @@ -1,28 +1,23 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#slm = #xegpu.scatter_tdesc_attr -#blk_slm = #xegpu.block_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<16x32xf16>) -> memref<8x64xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %memref = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %arg0, %memref : memref<16x32xf16> to memref<16x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<8x64xf16> - gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_1 : memref<8x64xf16>) - + %memref = gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<16x32xf16>, memref<16x32xf16> + %memref_0 = gpu.alloc () : memref<8x64xf16> + gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1) args(%memref : memref<16x32xf16>, %memref_0 : memref<8x64xf16>) gpu.dealloc %memref : memref<16x32xf16> - return %memref_1 : memref<8x64xf16> + %alloc = memref.alloc() : memref<8x64xf16> + gpu.memcpy %alloc, %memref_0 : memref<8x64xf16>, memref<8x64xf16> + gpu.dealloc %memref_0 : memref<8x64xf16> + return %alloc : memref<8x64xf16> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { // this example is to illustrate an example of using slm to do the transpose. // the high level logic is equivalent to the following code: // %data = xegpu.load_nd %in : !xegpu.tensor_desc<16x32xf16> -> vector<16x8xf16> @@ -33,47 +28,39 @@ module @gemm attributes {gpu.container_module} { %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index - - %id = gpu.subgroup_id : index - %y_in = arith.muli %id, %c8 : index - - %in = xegpu.create_nd_tdesc %arg0[0, %y_in] : memref<16x32xf16> -> !xegpu.tensor_desc<16x8xf16> // original load is %data = xegpu.load_nd %in : !xegpu.tensor_desc<16x8xf16> -> vector<16x8xf16> // now it is transformed to use packed attribute - %data = xegpu.load_nd %in {packed} : !xegpu.tensor_desc<16x8xf16> -> vector<8x8x2xf16> - %shapecast = vector.shape_cast %data : vector<8x8x2xf16> to vector<128xf16> - %data32b = vector.bitcast %shapecast : vector<128xf16> to vector<64xf32> - %tmp = vector.shape_cast %data32b : vector<64xf32> to vector<8x8xf32> - %pad = arith.constant dense<0.0> : vector<8x8xf32> - %comb = vector.shuffle %tmp, %pad[0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> - %cast = vector.shape_cast %comb : vector<16x8xf32> to vector<8x16xf32> - // the following code uses slm to do the transpose. 
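// Illustrative aside, not part of the patch (names assumed): the host-side staging
// idiom used by @test above and throughout these updated tests, shown in isolation.
// Device buffers are allocated explicitly, inputs are staged in with gpu.memcpy, and
// the result is copied back to a plain host memref before the device buffer is freed;
// the kernel launch itself is elided.
func.func @host_staging(%host_in: memref<16x32xf16>) -> memref<8x64xf16> {
  %dev_in = gpu.alloc () : memref<16x32xf16>
  gpu.memcpy %dev_in, %host_in : memref<16x32xf16>, memref<16x32xf16>
  %dev_out = gpu.alloc () : memref<8x64xf16>
  // ... gpu.launch_func reading %dev_in and writing %dev_out would go here ...
  %host_out = memref.alloc() : memref<8x64xf16>
  gpu.memcpy %host_out, %dev_out : memref<8x64xf16>, memref<8x64xf16>
  gpu.dealloc %dev_in : memref<16x32xf16>
  gpu.dealloc %dev_out : memref<8x64xf16>
  return %host_out : memref<8x64xf16>
}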
It contains 3 steps: // step1: store the data into slm using store scatter - %slm = memref.alloc() : memref<256xf32, 3> - %mask = arith.constant dense<[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]> : vector<16xi1> - - %base = arith.muli %id, %c64 : index - %baseVec = vector.broadcast %base : index to vector<16xindex> - %staticOff = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> - %offsets = arith.addi %baseVec, %staticOff : vector<16xindex> - - %slm_desc = xegpu.create_tdesc %slm, %offsets : memref<256xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #slm> - %trans = vector.transpose %cast, [1, 0] : vector<8x16xf32> to vector<16x8xf32> - xegpu.store %trans, %slm_desc, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #slm>, vector<16xi1> - // step2: load from slm using 1d block load - %off = arith.muli %id, %c64 : index - %slm_1d_desc = xegpu.create_nd_tdesc %slm[%off] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #blk_slm> - %data_1d = xegpu.load_nd %slm_1d_desc : !xegpu.tensor_desc<64xf32, #blk_slm> -> vector<64xf32> - // step3: simply do the shape cast to get the final result - %bitcast = vector.bitcast %data_1d : vector<64xf32> to vector<128xf16> - %transposed = vector.shape_cast %bitcast : vector<128xf16> to vector<8x16xf16> - - %out_y = arith.muli %id, %c16 : index - %out = xegpu.create_nd_tdesc %arg1[0, %out_y]: memref<8x64xf16> -> !xegpu.tensor_desc<8x16xf16> - xegpu.store_nd %transposed, %out : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + %0 = gpu.subgroup_id : index + %1 = arith.muli %0, %c8 : index + %2 = xegpu.create_nd_tdesc %arg0[0, %1] : memref<16x32xf16> -> !xegpu.tensor_desc<16x8xf16> + %3 = xegpu.load_nd %2 <{packed}> : !xegpu.tensor_desc<16x8xf16> -> vector<8x8x2xf16> + %4 = vector.shape_cast %3 : vector<8x8x2xf16> to vector<128xf16> + %5 = vector.bitcast %4 : vector<128xf16> to vector<64xf32> + %6 = vector.shape_cast %5 : vector<64xf32> to vector<8x8xf32> + %cst = arith.constant dense<0.000000e+00> : vector<8x8xf32> + %7 = vector.shuffle %6, %cst [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> + %8 = vector.shape_cast %7 : vector<16x8xf32> to vector<8x16xf32> + %alloc = memref.alloc() : memref<256xf32, 3> + %cst_0 = arith.constant dense<[true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false]> : vector<16xi1> + %9 = arith.muli %0, %c64 : index + %10 = vector.broadcast %9 : index to vector<16xindex> + %cst_1 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + %11 = arith.addi %10, %cst_1 : vector<16xindex> + %12 = xegpu.create_tdesc %alloc, %11 : memref<256xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + %13 = vector.transpose %8, [1, 0] : vector<8x16xf32> to vector<16x8xf32> + xegpu.store %13, %12, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %14 = arith.muli %0, %c64 : index + %15 = xegpu.create_nd_tdesc %alloc[%14] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + %16 = xegpu.load_nd %15 : !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> -> vector<64xf32> + %17 = vector.bitcast %16 : vector<64xf32> to vector<128xf16> + %18 = vector.shape_cast %17 : vector<128xf16> to vector<8x16xf16> + %19 = arith.muli %0, %c16 : index + %20 = xegpu.create_nd_tdesc %arg1[0, %19] : memref<8x64xf16> -> !xegpu.tensor_desc<8x16xf16> + 
xegpu.store_nd %18, %20 : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> gpu.return } } @@ -82,21 +69,16 @@ module @gemm attributes {gpu.container_module} { %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %0 = memref.alloc() : memref<16x32xf16> - - scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %mul = arith.muli %i, %c32 : index - %add = arith.addi %mul, %j : index - %int = arith.index_cast %add : index to i16 - %fp = arith.uitofp %int : i16 to f16 - memref.store %fp, %0[%i, %j] : memref<16x32xf16> + %alloc = memref.alloc() : memref<16x32xf16> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = arith.index_cast %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store %4, %alloc[%arg0, %arg1] : memref<16x32xf16> } } - %2 = call @test(%0) : (memref<16x32xf16>) -> memref<8x64xf16> - %cast = memref.cast %2: memref<8x64xf16> to memref<*xf16> - //CHECK: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480, 8, 40, 72, 104, 136, 168, 200, 232, 264, 296, 328, 360, 392, 424, 456, 488, 16, 48, 80, 112, 144, 176, 208, 240, 272, 304, 336, 368, 400, 432, 464, 496, 24, 56, 88, 120, 152, 184, 216, 248, 280, 312, 344, 376, 408, 440, 472, 504] //CHECK: [1, 33, 65, 97, 129, 161, 193, 225, 257, 289, 321, 353, 385, 417, 449, 481, 9, 41, 73, 105, 137, 169, 201, 233, 265, 297, 329, 361, 393, 425, 457, 489, 17, 49, 81, 113, 145, 177, 209, 241, 273, 305, 337, 369, 401, 433, 465, 497, 25, 57, 89, 121, 153, 185, 217, 249, 281, 313, 345, 377, 409, 441, 473, 505] //CHECK: [2, 34, 66, 98, 130, 162, 194, 226, 258, 290, 322, 354, 386, 418, 450, 482, 10, 42, 74, 106, 138, 170, 202, 234, 266, 298, 330, 362, 394, 426, 458, 490, 18, 50, 82, 114, 146, 178, 210, 242, 274, 306, 338, 370, 402, 434, 466, 498, 26, 58, 90, 122, 154, 186, 218, 250, 282, 314, 346, 378, 410, 442, 474, 506] @@ -105,9 +87,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [5, 37, 69, 101, 133, 165, 197, 229, 261, 293, 325, 357, 389, 421, 453, 485, 13, 45, 77, 109, 141, 173, 205, 237, 269, 301, 333, 365, 397, 429, 461, 493, 21, 53, 85, 117, 149, 181, 213, 245, 277, 309, 341, 373, 405, 437, 469, 501, 29, 61, 93, 125, 157, 189, 221, 253, 285, 317, 349, 381, 413, 445, 477, 509] //CHECK: [6, 38, 70, 102, 134, 166, 198, 230, 262, 294, 326, 358, 390, 422, 454, 486, 14, 46, 78, 110, 142, 174, 206, 238, 270, 302, 334, 366, 398, 430, 462, 494, 22, 54, 86, 118, 150, 182, 214, 246, 278, 310, 342, 374, 406, 438, 470, 502, 30, 62, 94, 126, 158, 190, 222, 254, 286, 318, 350, 382, 414, 446, 478, 510] //CHECK: [7, 39, 71, 103, 135, 167, 199, 231, 263, 295, 327, 359, 391, 423, 455, 487, 15, 47, 79, 111, 143, 175, 207, 239, 271, 303, 335, 367, 399, 431, 463, 495, 23, 55, 87, 119, 151, 183, 215, 247, 279, 311, 343, 375, 407, 439, 471, 503, 31, 63, 95, 127, 159, 191, 223, 255, 287, 319, 351, 383, 415, 447, 479, 511] - call @printMemrefF16(%cast): (memref<*xf16>) -> () + %0 = call @test(%alloc) : (memref<16x32xf16>) -> memref<8x64xf16> + %cast = memref.cast %0 : memref<8x64xf16> to memref<*xf16> + call @printMemrefF16(%cast) : (memref<*xf16>) -> () return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f32.mlir b/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f32.mlir index e9c0a45ff..535529b99 100644 --- 
a/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose-via-slm-f32.mlir @@ -1,28 +1,23 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck -#slm = #xegpu.scatter_tdesc_attr -#blk_slm = #xegpu.block_tdesc_attr module @gemm attributes {gpu.container_module} { func.func @test(%arg0: memref<8x32xf32>) -> memref<16x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %memref = gpu.alloc host_shared () : memref<8x32xf32> - memref.copy %arg0, %memref : memref<8x32xf32> to memref<8x32xf32> - %memref_1 = gpu.alloc host_shared () : memref<16x16xf32> - gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c2, %c1, %c1) args(%memref : memref<8x32xf32>, %memref_1 : memref<16x16xf32>) - + %memref = gpu.alloc () : memref<8x32xf32> + gpu.memcpy %memref, %arg0 : memref<8x32xf32>, memref<8x32xf32> + %memref_0 = gpu.alloc () : memref<16x16xf32> + gpu.launch_func @test_kernel::@test_transpose blocks in (%c1, %c1, %c1) threads in (%c2, %c1, %c1) args(%memref : memref<8x32xf32>, %memref_0 : memref<16x16xf32>) gpu.dealloc %memref : memref<8x32xf32> - return %memref_1 : memref<16x16xf32> + %alloc = memref.alloc() : memref<16x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<16x16xf32>, memref<16x16xf32> + gpu.dealloc %memref_0 : memref<16x16xf32> + return %alloc : memref<16x16xf32> } - - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { // this example is to illustrate an example of using slm to do the transpose. // the high level logic is equivalent to the following code: // %data = xegpu.load_nd %in : !xegpu.tensor_desc<8x32xf32> -> vector<8x32xf32> @@ -35,44 +30,36 @@ module @gemm attributes {gpu.container_module} { %c16 = arith.constant 16 : index %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index - - %id = gpu.subgroup_id : index - %y = arith.muli %id, %c16 : index - %in = xegpu.create_nd_tdesc %arg0[0, %y] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %data = xegpu.load_nd %in : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - // the following code uses slm to do the transpose. 
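// Illustrative aside, not part of the patch (function name assumed): how the per-lane
// SLM offsets for the scattered store below are formed. Each of the 16 lanes owns one
// 8-element f32 chunk (one row of the transposed tile), lanes are spaced 8 elements
// apart, and the two subgroups are spaced 128 elements apart in the 256-element buffer.
func.func @slm_offsets(%sg_id: index) -> vector<16xindex> {
  %c128 = arith.constant 128 : index
  // base of this subgroup's 128-element slice
  %base = arith.muli %sg_id, %c128 : index
  %base_vec = vector.broadcast %base : index to vector<16xindex>
  // lane i starts at element 8*i within the slice
  %lane_starts = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
  %offsets = arith.addi %base_vec, %lane_starts : vector<16xindex>
  return %offsets : vector<16xindex>
}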
It contains 3 steps: // step1: store the data into slm using store scatter - %slm = memref.alloc() : memref<256xf32, 3> - %mask = arith.constant dense<1> : vector<16xi1> - - %base = arith.muli %id, %c128 : index - %baseVec = vector.broadcast %base : index to vector<16xindex> - %staticOff = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> - %offsets = arith.addi %baseVec, %staticOff : vector<16xindex> - - %slm_desc = xegpu.create_tdesc %slm, %offsets : memref<256xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #slm> - %trans = vector.transpose %data, [1, 0] : vector<8x16xf32> to vector<16x8xf32> - xegpu.store %trans, %slm_desc, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #slm>, vector<16xi1> - // step2: load from slm using 1d block load - %base1 = arith.addi %base, %c0 : index - %base2 = arith.addi %base, %c64 : index - %slm_1d_desc_0 = xegpu.create_nd_tdesc %slm[%base1] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #blk_slm> - %slm_1d_desc_1 = xegpu.create_nd_tdesc %slm[%base2] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #blk_slm> - %data_1d_0 = xegpu.load_nd %slm_1d_desc_0 : !xegpu.tensor_desc<64xf32, #blk_slm> -> vector<64xf32> - %data_1d_1 = xegpu.load_nd %slm_1d_desc_1 : !xegpu.tensor_desc<64xf32, #blk_slm> -> vector<64xf32> - // step3: simply do the shape cast to get the final result - %transposed_0 = vector.shape_cast %data_1d_0 : vector<64xf32> to vector<8x8xf32> - %transposed_1 = vector.shape_cast %data_1d_1 : vector<64xf32> to vector<8x8xf32> - - %y2 = arith.muli %id, %c8 : index - %out_0 = xegpu.create_nd_tdesc %arg1[0, %y2]: memref<16x16xf32> -> !xegpu.tensor_desc<8x8xf32> - %out_1 = xegpu.create_nd_tdesc %arg1[8, %y2]: memref<16x16xf32> -> !xegpu.tensor_desc<8x8xf32> - xegpu.store_nd %transposed_0, %out_0 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32> - xegpu.store_nd %transposed_1, %out_1 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32> - + %0 = gpu.subgroup_id : index + %1 = arith.muli %0, %c16 : index + %2 = xegpu.create_nd_tdesc %arg0[0, %1] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %alloc = memref.alloc() : memref<256xf32, 3> + %cst = arith.constant dense : vector<16xi1> + %4 = arith.muli %0, %c128 : index + %5 = vector.broadcast %4 : index to vector<16xindex> + %cst_0 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex> + %6 = arith.addi %5, %cst_0 : vector<16xindex> + %7 = xegpu.create_tdesc %alloc, %6 : memref<256xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + %8 = vector.transpose %3, [1, 0] : vector<8x16xf32> to vector<16x8xf32> + xegpu.store %8, %7, %cst : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %9 = arith.addi %4, %c0 : index + %10 = arith.addi %4, %c64 : index + %11 = xegpu.create_nd_tdesc %alloc[%9] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + %12 = xegpu.create_nd_tdesc %alloc[%10] : memref<256xf32, 3> -> !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + %13 = xegpu.load_nd %11 : !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> -> vector<64xf32> + %14 = xegpu.load_nd %12 : !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> -> vector<64xf32> + %15 = vector.shape_cast %13 : vector<64xf32> to vector<8x8xf32> + %16 = vector.shape_cast %14 : vector<64xf32> to vector<8x8xf32> + %17 = arith.muli %0, %c8 : index + 
%18 = xegpu.create_nd_tdesc %arg1[0, %17] : memref<16x16xf32> -> !xegpu.tensor_desc<8x8xf32> + %19 = xegpu.create_nd_tdesc %arg1[8, %17] : memref<16x16xf32> -> !xegpu.tensor_desc<8x8xf32> + xegpu.store_nd %15, %18 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32> + xegpu.store_nd %16, %19 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32> gpu.return } } @@ -81,21 +68,16 @@ module @gemm attributes {gpu.container_module} { %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c32 = arith.constant 32 : index - - %0 = memref.alloc() : memref<8x32xf32> - - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %mul = arith.muli %i, %c32 : index - %add = arith.addi %mul, %j : index - %int = arith.index_cast %add : index to i32 - %fp = arith.uitofp %int : i32 to f32 - memref.store %fp, %0[%i, %j] : memref<8x32xf32> + %alloc = memref.alloc() : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = arith.index_cast %2 : index to i32 + %4 = arith.uitofp %3 : i32 to f32 + memref.store %4, %alloc[%arg0, %arg1] : memref<8x32xf32> } } - - %2 = call @test(%0) : (memref<8x32xf32>) -> memref<16x16xf32> - //CHECK: [0, 32, 64, 96, 128, 160, 192, 224, 16, 48, 80, 112, 144, 176, 208, 240] //CHECK: [1, 33, 65, 97, 129, 161, 193, 225, 17, 49, 81, 113, 145, 177, 209, 241] //CHECK: [2, 34, 66, 98, 130, 162, 194, 226, 18, 50, 82, 114, 146, 178, 210, 242] @@ -112,11 +94,10 @@ module @gemm attributes {gpu.container_module} { //CHECK: [13, 45, 77, 109, 141, 173, 205, 237, 29, 61, 93, 125, 157, 189, 221, 253] //CHECK: [14, 46, 78, 110, 142, 174, 206, 238, 30, 62, 94, 126, 158, 190, 222, 254] //CHECK: [15, 47, 79, 111, 143, 175, 207, 239, 31, 63, 95, 127, 159, 191, 223, 255] - %cast = memref.cast %2: memref<16x16xf32> to memref<*xf32> - call @printMemrefF32(%cast): (memref<*xf32>) -> () - + %0 = call @test(%alloc) : (memref<8x32xf32>) -> memref<16x16xf32> + %cast = memref.cast %0 : memref<16x16xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf16.mlir b/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf16.mlir index 8ce9efd45..7985086d0 100644 --- a/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf16.mlir @@ -1,66 +1,62 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index 
%c2 = arith.constant 2 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg0, %memref : memref<32x32xf16> to memref<32x32xf16> - %B = gpu.alloc host_shared () : memref<32x32xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %B : memref<32x32xf16>) + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>) gpu.dealloc %memref : memref<32x32xf16> - return %B : memref<32x32xf16> + %alloc = memref.alloc() : memref<32x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf16>, memref<32x32xf16> + gpu.dealloc %memref_0 : memref<32x32xf16> + return %alloc : memref<32x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg0[%2, %3] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %6 = xegpu.create_nd_tdesc %arg1[%3, %2] : memref<32x32xf16> -> !xegpu.tensor_desc<16x8xf16> - %7 = vector.transpose %5, [1, 0]: vector<8x16xf16> to vector<16x8xf16> - xegpu.store_nd %7, %6 : vector<16x8xf16>, !xegpu.tensor_desc<16x8xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg0[%0, %1] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %4 = xegpu.create_nd_tdesc %arg1[%1, %0] : memref<32x32xf16> -> !xegpu.tensor_desc<16x8xf16> + %5 = vector.transpose %3, [1, 0] : vector<8x16xf16> to vector<16x8xf16> + xegpu.store_nd %5, %4 : vector<16x8xf16>, !xegpu.tensor_desc<16x8xf16> gpu.return } } - - func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<32x32xf16> - %ref = memref.alloc() : memref<32x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %int = arith.index_cast %j : index to i32 - %fp = arith.uitofp %int : i32 to f16 - memref.store %fp, %0[%i, %j] : memref<32x32xf16> - %fp_32 = arith.extf %fp : f16 to f32 - memref.store %fp_32, %ref[%j, %i] : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf16> + %alloc_0 = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.index_cast %arg1 : index to i32 + %2 = arith.uitofp %1 : i32 to f16 + memref.store 
%2, %alloc[%arg0, %arg1] : memref<32x32xf16> + %3 = arith.extf %2 : f16 to f32 + memref.store %3, %alloc_0[%arg1, %arg0] : memref<32x32xf32> } } - - %2 = call @test(%0) : (memref<32x32xf16>) -> memref<32x32xf16> - %res = memref.cast %2 : memref<32x32xf16> to memref<*xf16> - %cast_ref = memref.cast %ref : memref<32x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%res, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<32x32xf16>) -> memref<32x32xf16> + %cast = memref.cast %0 : memref<32x32xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () return } func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf32.mlir b/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf32.mlir index 84800dd8a..8eaf19fd3 100644 --- a/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose_8x16xf32.mlir @@ -1,65 +1,61 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x32xf32> - memref.copy %arg0, %memref : memref<32x32xf32> to memref<32x32xf32> - %B = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf32>, %B : memref<32x32xf32>) + %memref = gpu.alloc () : memref<32x32xf32> + gpu.memcpy %memref, %arg0 : memref<32x32xf32>, memref<32x32xf32> + %memref_0 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf32>, %memref_0 : memref<32x32xf32>) gpu.dealloc %memref : memref<32x32xf32> - return %B : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref_0 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { 
%c0 = arith.constant 0 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %4 = xegpu.create_nd_tdesc %arg0[%2, %3] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - %6 = xegpu.create_nd_tdesc %arg1[%3, %2] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32> - %7 = vector.transpose %5, [1, 0]: vector<8x16xf32> to vector<16x8xf32> - xegpu.store_nd %7, %6 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg0[%0, %1] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = xegpu.create_nd_tdesc %arg1[%1, %0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32> + %5 = vector.transpose %3, [1, 0] : vector<8x16xf32> to vector<16x8xf32> + xegpu.store_nd %5, %4 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32> gpu.return } } - - func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<32x32xf32> - %ref = memref.alloc() : memref<32x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %int = arith.index_cast %j : index to i32 - %fp = arith.uitofp %int : i32 to f32 - memref.store %fp, %0[%i, %j] : memref<32x32xf32> - memref.store %fp, %ref[%j, %i] : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf32> + %alloc_0 = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.index_cast %arg1 : index to i32 + %2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0, %arg1] : memref<32x32xf32> + memref.store %2, %alloc_0[%arg1, %arg0] : memref<32x32xf32> } } - - %2 = call @test(%0) : (memref<32x32xf32>) -> memref<32x32xf32> - %res = memref.cast %2 : memref<32x32xf32> to memref<*xf32> - %cast_ref = memref.cast %ref : memref<32x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%res, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<32x32xf32>) -> memref<32x32xf32> + %cast = memref.cast %0 : memref<32x32xf32> to memref<*xf32> + %cast_1 = memref.cast %alloc_0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } // func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf16.mlir b/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf16.mlir index 953492409..de50c8034 100644 --- a/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf16.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf16.mlir @@ -1,65 +1,61 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ 
-// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg0, %memref : memref<32x32xf16> to memref<32x32xf16> - %B = gpu.alloc host_shared () : memref<32x32xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c4, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %B : memref<32x32xf16>) + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c4, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>) gpu.dealloc %memref : memref<32x32xf16> - return %B : memref<32x32xf16> + %alloc = memref.alloc() : memref<32x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf16>, memref<32x32xf16> + gpu.dealloc %memref_0 : memref<32x32xf16> + return %alloc : memref<32x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c8 = arith.constant 8 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c8 : index - %4 = xegpu.create_nd_tdesc %arg0[%2, %3] : memref<32x32xf16> -> !xegpu.tensor_desc<8x8xf16> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x8xf16> -> vector<8x8xf16> - %6 = xegpu.create_nd_tdesc %arg1[%3, %2] : memref<32x32xf16> -> !xegpu.tensor_desc<8x8xf16> - %7 = vector.transpose %5, [1, 0]: vector<8x8xf16> to vector<8x8xf16> - xegpu.store_nd %7, %6 : vector<8x8xf16>, !xegpu.tensor_desc<8x8xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c8 : index + %2 = xegpu.create_nd_tdesc %arg0[%0, %1] : memref<32x32xf16> -> !xegpu.tensor_desc<8x8xf16> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x8xf16> -> vector<8x8xf16> + %4 = xegpu.create_nd_tdesc %arg1[%1, %0] : memref<32x32xf16> -> !xegpu.tensor_desc<8x8xf16> + %5 = vector.transpose %3, [1, 0] : vector<8x8xf16> to vector<8x8xf16> + xegpu.store_nd %5, %4 : vector<8x8xf16>, !xegpu.tensor_desc<8x8xf16> gpu.return } } - - func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<32x32xf16> - %ref = memref.alloc() : memref<32x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index // A matrix: row-major, start from 0.0, 
increase 0.01 per element // B matrix: A matrix + 1.0 - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %int = arith.index_cast %j : index to i32 - %fp = arith.uitofp %int : i32 to f16 - memref.store %fp, %0[%i, %j] : memref<32x32xf16> - %fp_32 = arith.extf %fp : f16 to f32 - memref.store %fp_32, %ref[%j, %i] : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf16> + %alloc_0 = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.index_cast %arg1 : index to i32 + %2 = arith.uitofp %1 : i32 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<32x32xf16> + %3 = arith.extf %2 : f16 to f32 + memref.store %3, %alloc_0[%arg1, %arg0] : memref<32x32xf32> } } - - %2 = call @test(%0) : (memref<32x32xf16>) -> memref<32x32xf16> - %res = memref.cast %2 : memref<32x32xf16> to memref<*xf16> - %cast_ref = memref.cast %ref : memref<32x32xf32> to memref<*xf32> // call @printMemreff16(%cast_ref) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%res, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<32x32xf16>) -> memref<32x32xf16> + %cast = memref.cast %0 : memref<32x32xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf32.mlir b/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf32.mlir index 543aae0a4..dc1f3120f 100644 --- a/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf32.mlir +++ b/test/Integration/Dialect/XeGPU/VC/transpose_8x8xf32.mlir @@ -1,63 +1,59 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @transpose attributes {gpu.container_module} { func.func @test(%arg0: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c4 = arith.constant 4 : index %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x32xf32> - memref.copy %arg0, %memref : memref<32x32xf32> to memref<32x32xf32> - %B = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c4, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf32>, %B : memref<32x32xf32>) + %memref = gpu.alloc () : memref<32x32xf32> + gpu.memcpy %memref, %arg0 : memref<32x32xf32>, memref<32x32xf32> + %memref_0 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c4, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf32>, %memref_0 : memref<32x32xf32>) gpu.dealloc %memref : 
memref<32x32xf32> - return %B : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref_0 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_kernel { gpu.func @test_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c8 = arith.constant 8 : index - %0 = gpu.block_id x - %1 = gpu.block_id y - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c8 : index - %4 = xegpu.create_nd_tdesc %arg0[%2, %3] : memref<32x32xf32> -> !xegpu.tensor_desc<8x8xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32> - %6 = xegpu.create_nd_tdesc %arg1[%3, %2] : memref<32x32xf32> -> !xegpu.tensor_desc<8x8xf32> - %7 = vector.transpose %5, [1, 0]: vector<8x8xf32> to vector<8x8xf32> - xegpu.store_nd %7, %6 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c8 : index + %2 = xegpu.create_nd_tdesc %arg0[%0, %1] : memref<32x32xf32> -> !xegpu.tensor_desc<8x8xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32> + %4 = xegpu.create_nd_tdesc %arg1[%1, %0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x8xf32> + %5 = vector.transpose %3, [1, 0] : vector<8x8xf32> to vector<8x8xf32> + xegpu.store_nd %5, %4 : vector<8x8xf32>, !xegpu.tensor_desc<8x8xf32> gpu.return } } - - func.func @main() attributes {llvm.emit_c_interface} { - %0 = memref.alloc() : memref<32x32xf32> - %ref = memref.alloc() : memref<32x32xf32> - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index // A matrix: row-major, start from 0.0, increase 0.01 per element // B matrix: A matrix + 1.0 - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %int = arith.index_cast %j : index to i32 - %fp = arith.uitofp %int : i32 to f32 - memref.store %fp, %0[%i, %j] : memref<32x32xf32> - memref.store %fp, %ref[%j, %i] : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf32> + %alloc_0 = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.index_cast %arg1 : index to i32 + %2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0, %arg1] : memref<32x32xf32> + memref.store %2, %alloc_0[%arg1, %arg0] : memref<32x32xf32> } } - - %2 = call @test(%0) : (memref<32x32xf32>) -> memref<32x32xf32> - %res = memref.cast %2 : memref<32x32xf32> to memref<*xf32> - %cast_ref = memref.cast %ref : memref<32x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%res, %cast_ref) : (memref<*xf32>, memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<32x32xf32>) -> memref<32x32xf32> + %cast = memref.cast %0 : memref<32x32xf32> to memref<*xf32> + %cast_1 = memref.cast %alloc_0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} diff --git 
a/test/Integration/Dialect/XeGPU/VC/unranked_memref.vc.mlir b/test/Integration/Dialect/XeGPU/VC/unranked_memref.vc.mlir deleted file mode 100644 index 30007d902..000000000 --- a/test/Integration/Dialect/XeGPU/VC/unranked_memref.vc.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - func.func @test(%A : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { - %c1 = arith.constant 1 : index - %memref_0 = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref_0 : memref<8x16xf32> to memref<8x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - %memref_0_cast = memref.cast %memref_0 : memref<8x16xf32> to memref<*xf32> - %memref_1_cast = memref.cast %memref_1 : memref<8x16xf32> to memref<*xf32> - %dim0 = arith.constant 8 : index - %dim1 = arith.constant 16 : index - %stride0 = arith.constant 16 : index - %stride1 = arith.constant 1 : index - %x = arith.constant 0 : index - %y = arith.constant 0 : index - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref<*xf32>, %memref_1_cast : memref<*xf32>, %dim0 : index, %dim1 : index, %stride0 : index, %stride1 : index, %x : index, %y : index) - gpu.dealloc %memref_0 : memref<8x16xf32> - return %memref_1 : memref<8x16xf32> - } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%arg0 : memref<*xf32>, %arg1: memref<*xf32>, %dim0: index, %dim1: index, %stride0: index, %stride1: index, %x: index, %y: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %ranked0 = memref.cast %arg0 : memref<*xf32> to memref - %ranked1 = memref.cast %arg1 : memref<*xf32> to memref - %1 = xegpu.create_nd_tdesc %ranked0[%x, %y], [%dim0, %dim1], [%stride0, %stride1] : memref -> !xegpu.tensor_desc<8x16xf32> - %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - %6 = xegpu.create_nd_tdesc %ranked1[%x, %y], [%dim0, %dim1], [%stride0, %stride1] : memref -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %2, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - gpu.return - } - } - func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant -0.5 : f32 - %cf_upper = arith.constant 0.5 : f32 - - call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> () - - %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32> - %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32> - %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32> - // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> 
() - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<8x16xf32> - return - } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} - func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface} - func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} -} diff --git a/test/Integration/Dialect/XeGPU/VC/vector_broadcast_1.mlir b/test/Integration/Dialect/XeGPU/VC/vector_broadcast_1.mlir index 9a6972e21..1d3d1300e 100644 --- a/test/Integration/Dialect/XeGPU/VC/vector_broadcast_1.mlir +++ b/test/Integration/Dialect/XeGPU/VC/vector_broadcast_1.mlir @@ -1,133 +1,107 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_32x32xf16 : memref<32x32xf16> = dense<1.0> - memref.global "private" @__constant_B32x32xf16 : memref<32x32xf16> = dense<2.0> - memref.global "private" @__constant_1x32xf16 : memref<1x32xf16> = dense<10.0> - func.func @test(%A: memref<32x32xf16>, %B: memref<32x32xf16>, %bcast : memref<1x32xf16> ) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { + memref.global "private" @__constant_32x32xf16 : memref<32x32xf16> = dense<1.000000e+00> + memref.global "private" @__constant_B32x32xf16 : memref<32x32xf16> = dense<2.000000e+00> + memref.global "private" @__constant_1x32xf16 : memref<1x32xf16> = dense<1.000000e+01> + func.func @test(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<1x32xf16>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index + %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index - %memref = gpu.alloc host_shared () : memref<32x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<32x32xf16> - %memref_2 = gpu.alloc host_shared () : memref<1x32xf16> - memref.copy %A, %memref : memref<32x32xf16> to memref<32x32xf16> - memref.copy %B, %memref_1 : memref<32x32xf16> to memref<32x32xf16> - memref.copy %bcast, %memref_2 : memref<1x32xf16> to memref<1x32xf16> - %memref_3 = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @module0::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_1 : memref<32x32xf16>, %memref_3 : memref<32x32xf32>, %memref_2 : memref<1x32xf16>) + %memref = gpu.alloc () : memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + %memref_1 = gpu.alloc () : memref<1x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<32x32xf16>, memref<32x32xf16> + 
gpu.memcpy %memref_1, %arg2 : memref<1x32xf16>, memref<1x32xf16> + %memref_2 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @module0::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>, %memref_2 : memref<32x32xf32>, %memref_1 : memref<1x32xf16>) gpu.dealloc %memref : memref<32x32xf16> - gpu.dealloc %memref_1 : memref<32x32xf16> - gpu.dealloc %memref_2 : memref<1x32xf16> - return %memref_3 : memref<32x32xf32> + gpu.dealloc %memref_0 : memref<32x32xf16> + gpu.dealloc %memref_1 : memref<1x32xf16> + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_2 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref_2 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>, %Out: memref<32x32xf32>, %bcast : memref<1x32xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>, %arg3: memref<1x32xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c8 = arith.constant 8 : index - - %0 = gpu.block_id x - %1 = gpu.block_id y - - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - - %4 = xegpu.create_nd_tdesc %Out[%2, %3] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - - %6 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - // load A tile - %a_tile0 = xegpu.create_nd_tdesc %A [%2, %arg3] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %A0_val = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - // load B tile - %b_tile0 = xegpu.create_nd_tdesc %B [%arg3, %3] : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %b_tile0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - // load B cast - %bcast_tile = xegpu.create_nd_tdesc %bcast [%c0, %c0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x32xf16> - %val3 = xegpu.load_nd %bcast_tile : !xegpu.tensor_desc<1x32xf16> -> vector<1x32xf16> - // extract first 16 elems - %val5 = vector.extract_strided_slice %val3 {offsets = [0, 0], strides = [1, 1], sizes = [1, 16]} - : vector<1x32xf16> to vector<1x16xf16> // broadcast over row dim - %val6 = vector.broadcast %val5 : vector<1x16xf16> to vector<8x16xf16> // add to A - %A0_val8 = arith.addf %A0_val, %val6 : vector<8x16xf16> - // do DPAS - %dpas = xegpu.dpas %A0_val8, %B0_val, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - scf.yield %dpas : vector<8x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg4 = %c0 to %c32 step %c16 iter_args(%arg5 = %3) -> 
(vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg4] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %7 = xegpu.create_nd_tdesc %arg1[%arg4, %1] : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x32xf16> + %10 = xegpu.load_nd %9 : !xegpu.tensor_desc<1x32xf16> -> vector<1x32xf16> + %11 = vector.extract_strided_slice %10 {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf16> to vector<1x16xf16> + %12 = vector.broadcast %11 : vector<1x16xf16> to vector<8x16xf16> + %13 = arith.addf %6, %12 : vector<8x16xf16> + %14 = xegpu.dpas %13, %8, %arg5 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %14 : vector<8x16xf32> } // store - - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { // init constants + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %c1_f32 = arith.constant 1.0 : f32 // random init - %lower = arith.constant -1.0 : f32 - %upper = arith.constant 1.0 : f32 - %false = arith.constant 0 : i1 - %A = memref.get_global @__constant_32x32xf16 : memref<32x32xf16> - %B = memref.get_global @__constant_B32x32xf16 : memref<32x32xf16> - %bcast = memref.get_global @__constant_1x32xf16 : memref<1x32xf16> - - %Out_cpu = memref.alloc() : memref<32x32xf32> - - %A_random = memref.cast %A : memref<32x32xf16> to memref<*xf16> - %B_random = memref.cast %B : memref<32x32xf16> to memref<*xf16> - %bcast_random = memref.cast %bcast : memref<1x32xf16> to memref<*xf16> - // run GPU version - %Out_gpu = call @test(%A, %B, %bcast) : (memref<32x32xf16>, memref<32x32xf16>, memref<1x32xf16>) -> memref<32x32xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<32x32xf32> to memref<*xf32> // run CPU version - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %v0_init = arith.constant 0.0 : f32 - %result:1 = scf.for %k = %c0 to %c32 step %c1 iter_args(%v0 = %v0_init) -> f32 { - %a0 = memref.load %A[%i, %k] : memref<32x32xf16> - %b0 = memref.load %B[%k, %j] : memref<32x32xf16> - %bcast_val = memref.load %bcast[%c0, %k] : memref<1x32xf16> - %t1 = arith.addf %a0, %bcast_val : f16 - %a0_f32 = arith.extf %t1 : f16 to f32 - %b0_f32 = arith.extf %b0 : f16 to f32 - %t0 = arith.mulf %a0_f32, %b0_f32 : f32 - %v0_new = arith.addf %v0, %t0 : f32 - scf.yield %v0_new : f32 + %0 = memref.get_global @__constant_32x32xf16 : memref<32x32xf16> + %1 = memref.get_global @__constant_B32x32xf16 : memref<32x32xf16> + %2 = memref.get_global @__constant_1x32xf16 : memref<1x32xf16> + %alloc = memref.alloc() : memref<32x32xf32> + %3 = call @test(%0, %1, %2) : (memref<32x32xf16>, memref<32x32xf16>, memref<1x32xf16>) -> memref<32x32xf32> + %cast = memref.cast %3 : memref<32x32xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %4 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst) -> (f32) { + %5 = memref.load %0[%arg0, %arg2] : memref<32x32xf16> + %6 = memref.load %1[%arg2, %arg1] : memref<32x32xf16> + %7 = 
memref.load %2[%c0, %arg2] : memref<1x32xf16> + %8 = arith.addf %5, %7 : f16 + %9 = arith.extf %8 : f16 to f32 + %10 = arith.extf %6 : f16 to f32 + %11 = arith.mulf %9, %10 : f32 + %12 = arith.addf %arg3, %11 : f32 + scf.yield %12 : f32 } // only update the first 8x8 of the result, next 8x8 is value 1 - memref.store %result#0, %Out_cpu[%i, %j] : memref<32x32xf32> + memref.store %4, %alloc[%arg0, %arg1] : memref<32x32xf32> } } - %Out_cpu_cast = memref.cast %Out_cpu : memref<32x32xf32> to memref<*xf32> // print GPU and CPU outs // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () + %cast_0 = memref.cast %alloc : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/vector_broadcast_2.mlir b/test/Integration/Dialect/XeGPU/VC/vector_broadcast_2.mlir index 576a6b7c1..1ff791e59 100644 --- a/test/Integration/Dialect/XeGPU/VC/vector_broadcast_2.mlir +++ b/test/Integration/Dialect/XeGPU/VC/vector_broadcast_2.mlir @@ -1,135 +1,109 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_32x32xf16 : memref<32x32xf16> = dense<1.0> - memref.global "private" @__constant_B32x32xf16 : memref<32x32xf16> = dense<2.0> - memref.global "private" @__constant_1x32xf16 : memref<1x32xf16> = dense<10.0> - func.func @test(%A: memref<32x32xf16>, %B: memref<32x32xf16>, %bcast : memref<1x32xf16> ) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { + memref.global "private" @__constant_32x32xf16 : memref<32x32xf16> = dense<1.000000e+00> + memref.global "private" @__constant_B32x32xf16 : memref<32x32xf16> = dense<2.000000e+00> + memref.global "private" @__constant_1x32xf16 : memref<1x32xf16> = dense<1.000000e+01> + func.func @test(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<1x32xf16>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index + %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index - %memref = gpu.alloc host_shared () : memref<32x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<32x32xf16> - %memref_2 = gpu.alloc host_shared () : memref<1x32xf16> - memref.copy %A, %memref : memref<32x32xf16> to memref<32x32xf16> - memref.copy %B, %memref_1 : memref<32x32xf16> to memref<32x32xf16> - memref.copy %bcast, %memref_2 : memref<1x32xf16> to memref<1x32xf16> - 
%memref_3 = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @module0::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_1 : memref<32x32xf16>, %memref_3 : memref<32x32xf32>, %memref_2 : memref<1x32xf16>) + %memref = gpu.alloc () : memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + %memref_1 = gpu.alloc () : memref<1x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<32x32xf16>, memref<32x32xf16> + gpu.memcpy %memref_1, %arg2 : memref<1x32xf16>, memref<1x32xf16> + %memref_2 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @module0::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>, %memref_2 : memref<32x32xf32>, %memref_1 : memref<1x32xf16>) gpu.dealloc %memref : memref<32x32xf16> - gpu.dealloc %memref_1 : memref<32x32xf16> - gpu.dealloc %memref_2 : memref<1x32xf16> - return %memref_3 : memref<32x32xf32> + gpu.dealloc %memref_0 : memref<32x32xf16> + gpu.dealloc %memref_1 : memref<1x32xf16> + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_2 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref_2 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>, %Out: memref<32x32xf32>, %bcast : memref<1x32xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>, %arg3: memref<1x32xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c8 = arith.constant 8 : index - - %0 = gpu.block_id x - %1 = gpu.block_id y - - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - - %4 = xegpu.create_nd_tdesc %Out[%2, %3] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - - %6 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) { - // load A tile - %a_tile0 = xegpu.create_nd_tdesc %A [%2, %arg3] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %A0_val = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - // load B tile - %b_tile0 = xegpu.create_nd_tdesc %B [%arg3, %3] : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %b_tile0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - // load B cast - %bcast_tile = xegpu.create_nd_tdesc %bcast [%c0, %c0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x32xf16> - %val3 = xegpu.load_nd %bcast_tile : !xegpu.tensor_desc<1x32xf16> -> vector<1x32xf16> - // extract first 8 elems - %val5 = vector.extract_strided_slice %val3 {offsets = [0, 0], strides = [1, 1], sizes = [1, 8]} - : vector<1x32xf16> to vector<1x8xf16> // reshape and broadcast over col dim - %val6 = vector.shape_cast %val5 : vector<1x8xf16> to vector<8xf16> - %t = vector.shape_cast %val6 : vector<8xf16> to vector<8x1xf16> - %val7 = 
vector.broadcast %t : vector<8x1xf16> to vector<8x16xf16> // add to A - %A0_val8 = arith.addf %A0_val, %val7 : vector<8x16xf16> - // do DPAS - %dpas = xegpu.dpas %A0_val8, %B0_val, %arg4 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> - - scf.yield %dpas : vector<8x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %4 = scf.for %arg4 = %c0 to %c32 step %c16 iter_args(%arg5 = %3) -> (vector<8x16xf32>) { + %5 = xegpu.create_nd_tdesc %arg0[%0, %arg4] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %7 = xegpu.create_nd_tdesc %arg1[%arg4, %1] : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %9 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x32xf16> + %10 = xegpu.load_nd %9 : !xegpu.tensor_desc<1x32xf16> -> vector<1x32xf16> + %11 = vector.extract_strided_slice %10 {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]} : vector<1x32xf16> to vector<1x8xf16> + %12 = vector.shape_cast %11 : vector<1x8xf16> to vector<8xf16> + %13 = vector.shape_cast %12 : vector<8xf16> to vector<8x1xf16> + %14 = vector.broadcast %13 : vector<8x1xf16> to vector<8x16xf16> + %15 = arith.addf %6, %14 : vector<8x16xf16> + %16 = xegpu.dpas %15, %8, %arg5 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %16 : vector<8x16xf32> } // store - - xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { // init constants + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %c1_f32 = arith.constant 1.0 : f32 // random init - %lower = arith.constant -1.0 : f32 - %upper = arith.constant 1.0 : f32 - %false = arith.constant 0 : i1 - %A = memref.get_global @__constant_32x32xf16 : memref<32x32xf16> - %B = memref.get_global @__constant_B32x32xf16 : memref<32x32xf16> - %bcast = memref.get_global @__constant_1x32xf16 : memref<1x32xf16> - - %Out_cpu = memref.alloc() : memref<32x32xf32> - - %A_random = memref.cast %A : memref<32x32xf16> to memref<*xf16> - %B_random = memref.cast %B : memref<32x32xf16> to memref<*xf16> - %bcast_random = memref.cast %bcast : memref<1x32xf16> to memref<*xf16> - // run GPU version - %Out_gpu = call @test(%A, %B, %bcast) : (memref<32x32xf16>, memref<32x32xf16>, memref<1x32xf16>) -> memref<32x32xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<32x32xf32> to memref<*xf32> // run CPU version - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %v0_init = arith.constant 0.0 : f32 - %result:1 = scf.for %k = %c0 to %c32 step %c1 iter_args(%v0 = %v0_init) -> f32 { - %a0 = memref.load %A[%i, %k] : memref<32x32xf16> - %b0 = memref.load %B[%k, %j] : memref<32x32xf16> - %bcast_val = memref.load %bcast[%c0, %i] : memref<1x32xf16> - %t1 = arith.addf %a0, %bcast_val : f16 - %a0_f32 = arith.extf %t1 : f16 to f32 - %b0_f32 = arith.extf %b0 : f16 to 
f32 - %t0 = arith.mulf %a0_f32, %b0_f32 : f32 - %v0_new = arith.addf %v0, %t0 : f32 - scf.yield %v0_new : f32 + %0 = memref.get_global @__constant_32x32xf16 : memref<32x32xf16> + %1 = memref.get_global @__constant_B32x32xf16 : memref<32x32xf16> + %2 = memref.get_global @__constant_1x32xf16 : memref<1x32xf16> + %alloc = memref.alloc() : memref<32x32xf32> + %3 = call @test(%0, %1, %2) : (memref<32x32xf16>, memref<32x32xf16>, memref<1x32xf16>) -> memref<32x32xf32> + %cast = memref.cast %3 : memref<32x32xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %4 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %cst) -> (f32) { + %5 = memref.load %0[%arg0, %arg2] : memref<32x32xf16> + %6 = memref.load %1[%arg2, %arg1] : memref<32x32xf16> + %7 = memref.load %2[%c0, %arg0] : memref<1x32xf16> + %8 = arith.addf %5, %7 : f16 + %9 = arith.extf %8 : f16 to f32 + %10 = arith.extf %6 : f16 to f32 + %11 = arith.mulf %9, %10 : f32 + %12 = arith.addf %arg3, %11 : f32 + scf.yield %12 : f32 } // only update the first 8x8 of the result, next 8x8 is value 1 - memref.store %result#0, %Out_cpu[%i, %j] : memref<32x32xf32> + memref.store %4, %alloc[%arg0, %arg1] : memref<32x32xf32> } } - %Out_cpu_cast = memref.cast %Out_cpu : memref<32x32xf32> to memref<*xf32> // print GPU and CPU outs // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () + %cast_0 = memref.cast %alloc : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_1.vc.mlir b/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_1.vc.mlir index a494cca84..de696855d 100644 --- a/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_1.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_1.vc.mlir @@ -1,140 +1,120 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module} { - memref.global "private" @__constant_8x32xf16 : memref<8x32xf16> = dense<1.0> - memref.global "private" @__constant_16x32xf16 : memref<16x32xf16> = dense<2.0> +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck - func.func @test(%A: memref<8x32xf16>, %B: memref<16x32xf16> ) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { +module @gemm attributes {gpu.container_module} { + memref.global "private" @__constant_8x32xf16 : memref<8x32xf16> = dense<1.000000e+00> + memref.global "private" 
@__constant_16x32xf16 : memref<16x32xf16> = dense<2.000000e+00> + func.func @test(%arg0: memref<8x32xf16>, %arg1: memref<16x32xf16>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<16x32xf16> - memref.copy %A, %memref : memref<8x32xf16> to memref<8x32xf16> - memref.copy %B, %memref_1 : memref<16x32xf16> to memref<16x32xf16> - %memref_2 = gpu.alloc host_shared () : memref<8x32xf32> - gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_1 : memref<16x32xf16>, %memref_2 : memref<8x32xf32>) + %memref = gpu.alloc () : memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<16x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<16x32xf16>, memref<16x32xf16> + %memref_1 = gpu.alloc () : memref<8x32xf32> + gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<16x32xf16>, %memref_1 : memref<8x32xf32>) gpu.dealloc %memref : memref<8x32xf16> - gpu.dealloc %memref_1 : memref<16x32xf16> - return %memref_2 : memref<8x32xf32> + gpu.dealloc %memref_0 : memref<16x32xf16> + %alloc = memref.alloc() : memref<8x32xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x32xf32>, memref<8x32xf32> + gpu.dealloc %memref_1 : memref<8x32xf32> + return %alloc : memref<8x32xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<8x32xf16>, %B: memref<16x32xf16>, %C: memref<8x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_kernel(%arg0: memref<8x32xf16>, %arg1: memref<16x32xf16>, %arg2: memref<8x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // load A tile - %A0 = xegpu.create_nd_tdesc %A[%c0, %c0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %A1 = xegpu.create_nd_tdesc %A[%c0, %c16] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %A0_val = xegpu.load_nd %A0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - %A1_val = xegpu.load_nd %A1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - // load B tile - %B0 = xegpu.create_nd_tdesc %B[%c0, %c0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %B1 = xegpu.create_nd_tdesc %B[%c0, %c16] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %B0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - %B1_val = xegpu.load_nd %B1 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - // do DPAS - %dpas0 = xegpu.dpas %A0_val, %B0_val : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - %dpas1 = xegpu.dpas %A1_val, %B1_val : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - // extract second 8x8 - %val5_0 = vector.extract_strided_slice %dpas0 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32> - %val5_1 = vector.extract_strided_slice %dpas1 {sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} : vector<8x16xf32> to vector<8x8xf32> - - %cst_8x8_flat = arith.constant dense<1.0> : vector<64xf32> - %cst_8x8 = vector.shape_cast %cst_8x8_flat : vector<64xf32> to vector<8x8xf32> 
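// Editor's note: the comments removed in this hunk describe how two 8x8 result tiles are merged
// into one 8x16 tile: shuffle the rows of both 8x8 vectors into an interleaved 16x8 vector, then
// shape_cast it to 8x16 so each output row is [row_of_a | row_of_b]. A minimal standalone sketch
// of just that combine step (the function and operand names are illustrative):
func.func @combine_8x8_pair(%a: vector<8x8xf32>, %b: vector<8x8xf32>) -> vector<8x16xf32> {
  // Interleave outer-dimension rows: a0, b0, a1, b1, ... giving 16 rows of 8 elements.
  %interleaved = vector.shuffle %a, %b
      [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]
      : vector<8x8xf32>, vector<8x8xf32>
  // Reinterpret the row-major 16x8 data as 8x16: row i becomes a's row i followed by b's row i.
  %combined = vector.shape_cast %interleaved : vector<16x8xf32> to vector<8x16xf32>
  return %combined : vector<8x16xf32>
}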
// shift the first half to left and use %cst_8x8 as the second half - - %val6_0 = vector.shuffle %val5_0, %cst_8x8 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> - %val6_1 = vector.shuffle %val5_1, %cst_8x8 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> - - %val7_0 = vector.shape_cast %val6_0 : vector<16x8xf32> to vector<8x16xf32> - %val7_1 = vector.shape_cast %val6_1 : vector<16x8xf32> to vector<8x16xf32> - // store - %out_tile_0 = xegpu.create_nd_tdesc %C [%c0, %c0] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> - %out_tile_1 = xegpu.create_nd_tdesc %C [%c0, %c16] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - xegpu.store_nd %val7_0, %out_tile_0 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %val7_1, %out_tile_1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> - + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg0[%c0, %c16] : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %4 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %5 = xegpu.create_nd_tdesc %arg1[%c0, %c16] : memref<16x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %6 = xegpu.load_nd %4 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %7 = xegpu.load_nd %5 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %8 = xegpu.dpas %2, %6 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %9 = xegpu.dpas %3, %7 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %10 = vector.extract_strided_slice %8 {offsets = [0, 8], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %11 = vector.extract_strided_slice %9 {offsets = [0, 8], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32> + %cst = arith.constant dense<1.000000e+00> : vector<64xf32> + %12 = vector.shape_cast %cst : vector<64xf32> to vector<8x8xf32> + %13 = vector.shuffle %10, %12 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> + %14 = vector.shuffle %11, %12 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> + %15 = vector.shape_cast %13 : vector<16x8xf32> to vector<8x16xf32> + %16 = vector.shape_cast %14 : vector<16x8xf32> to vector<8x16xf32> + %17 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> + %18 = xegpu.create_nd_tdesc %arg2[%c0, %c16] : memref<8x32xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %15, %17 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %16, %18 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { // init constants + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c1_f32 = arith.constant 1.0 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 %c24 = arith.constant 24 : index - %c32 = arith.constant 32 : index - // random init - %lower = arith.constant -1.0 : f32 - %upper = arith.constant 1.0 : f32 - %false = arith.constant 0 : i1 - %A = memref.get_global @__constant_8x32xf16 : memref<8x32xf16> - %B =memref.get_global @__constant_16x32xf16 : 
memref<16x32xf16> - %Out_cpu = memref.alloc() : memref<8x32xf32> // run GPU version - %Out_gpu = call @test(%A, %B) : (memref<8x32xf16>, memref<16x32xf16>) -> memref<8x32xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<8x32xf32> to memref<*xf32> // run CPU version - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c8 to %c16 step %c1 { - %v0_init = arith.constant 0.0 : f32 - %result:1 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init) -> f32 { - %a0 = memref.load %A[%i, %k] : memref<8x32xf16> - %b0 = memref.load %B[%k, %j] : memref<16x32xf16> - %a0_f32 = arith.extf %a0 : f16 to f32 - %b0_f32 = arith.extf %b0 : f16 to f32 - %t0 = arith.mulf %a0_f32, %b0_f32 : f32 - %v0_new = arith.addf %v0, %t0 : f32 - scf.yield %v0_new : f32 + %0 = memref.get_global @__constant_8x32xf16 : memref<8x32xf16> + %1 = memref.get_global @__constant_16x32xf16 : memref<16x32xf16> + %alloc = memref.alloc() : memref<8x32xf32> + %2 = call @test(%0, %1) : (memref<8x32xf16>, memref<16x32xf16>) -> memref<8x32xf32> + %cast = memref.cast %2 : memref<8x32xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c8 to %c16 step %c1 { + %3 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %cst) -> (f32) { + %5 = memref.load %0[%arg0, %arg2] : memref<8x32xf16> + %6 = memref.load %1[%arg2, %arg1] : memref<16x32xf16> + %7 = arith.extf %5 : f16 to f32 + %8 = arith.extf %6 : f16 to f32 + %9 = arith.mulf %7, %8 : f32 + %10 = arith.addf %arg3, %9 : f32 + scf.yield %10 : f32 } // only update the 8x8 of first half of 8x32 of the result, next 8x8 is value 1 - %shifted_j = arith.subi %j, %c8 : index - memref.store %result#0, %Out_cpu[%i, %shifted_j] : memref<8x32xf32> - memref.store %c1_f32, %Out_cpu[%i, %j] : memref<8x32xf32> + %4 = arith.subi %arg1, %c8 : index + memref.store %3, %alloc[%arg0, %4] : memref<8x32xf32> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<8x32xf32> } } - // run CPU version - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c24 to %c32 step %c1 { - %v0_init = arith.constant 0.0 : f32 - %result:1 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init) -> f32 { - %a0 = memref.load %A[%i, %k] : memref<8x32xf16> - %b0 = memref.load %B[%k, %j] : memref<16x32xf16> - %a0_f32 = arith.extf %a0 : f16 to f32 - %b0_f32 = arith.extf %b0 : f16 to f32 - %t0 = arith.mulf %a0_f32, %b0_f32 : f32 - %v0_new = arith.addf %v0, %t0 : f32 - scf.yield %v0_new : f32 + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c24 to %c32 step %c1 { + %3 = scf.for %arg2 = %c0 to %c16 step %c1 iter_args(%arg3 = %cst) -> (f32) { + %5 = memref.load %0[%arg0, %arg2] : memref<8x32xf16> + %6 = memref.load %1[%arg2, %arg1] : memref<16x32xf16> + %7 = arith.extf %5 : f16 to f32 + %8 = arith.extf %6 : f16 to f32 + %9 = arith.mulf %7, %8 : f32 + %10 = arith.addf %arg3, %9 : f32 + scf.yield %10 : f32 } // only update the 8x8 of second half of 8x32 of the result, next 8x8 is value 1 - %shifted_j = arith.subi %j, %c8 : index - memref.store %result#0, %Out_cpu[%i, %shifted_j] : memref<8x32xf32> - memref.store %c1_f32, %Out_cpu[%i, %j] : memref<8x32xf32> + %4 = arith.subi %arg1, %c8 : index + memref.store %3, %alloc[%arg0, %4] : memref<8x32xf32> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<8x32xf32> } } - %Out_cpu_cast = memref.cast %Out_cpu : memref<8x32xf32> to memref<*xf32> - // print GPU and CPU outs // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call 
@printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () - + %cast_1 = memref.cast %alloc : memref<8x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_2.vc.mlir b/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_2.vc.mlir index d46c9f130..9ea32a0e5 100644 --- a/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_2.vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/vector_extract_strided_slice_2.vc.mlir @@ -1,42 +1,40 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<32x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<32x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<32x16xf32> - memref.copy %A, %memref : memref<32x16xf32> to memref<32x16xf32> - %memref_1 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x16xf32>, %memref_1 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<32x16xf32> + gpu.memcpy %memref, %arg0 : memref<32x16xf32>, memref<32x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x16xf32>, %memref_0 : memref<8x16xf32>) gpu.dealloc %memref : memref<32x16xf32> - return %memref_1 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_kernel(%arg0: memref<32x16xf32>, %arg1: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // load tile - %tile = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<32x16xf32> -> !xegpu.tensor_desc<32x8xf32, #xegpu.block_tdesc_attr> - %value = xegpu.load_nd %tile : !xegpu.tensor_desc<32x8xf32, #xegpu.block_tdesc_attr> -> vector<2x32x8xf32> // 
extract the bottom 8x8 part of first 32x8 block - %sub_tile0 = vector.extract_strided_slice %value { offsets = [0, 24], strides = [1, 1], sizes = [1, 8] } : vector<2x32x8xf32> to vector<1x8x8xf32> // extract the bottom 8x8 part of second 32x8 block - %sub_tile1 = vector.extract_strided_slice %value { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] } : vector<2x32x8xf32> to vector<1x8x8xf32> // combine these two 8x8 tiles into a single 8x16 tile - %t1 = vector.shape_cast %sub_tile0 : vector<1x8x8xf32> to vector<8x8xf32> - %t2 = vector.shape_cast %sub_tile1 : vector<1x8x8xf32> to vector<8x8xf32> - %t3 = vector.shuffle %t1, %t2 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> - %t4 = vector.shape_cast %t3 : vector<16x8xf32> to vector<8x16xf32> - // store the result - %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %t4, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<32x16xf32> -> !xegpu.tensor_desc<32x8xf32, #xegpu.block_tdesc_attr> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x8xf32, #xegpu.block_tdesc_attr> -> vector<2x32x8xf32> + %2 = vector.extract_strided_slice %1 {offsets = [0, 24], sizes = [1, 8], strides = [1, 1]} : vector<2x32x8xf32> to vector<1x8x8xf32> + %3 = vector.extract_strided_slice %1 {offsets = [1, 24], sizes = [1, 8], strides = [1, 1]} : vector<2x32x8xf32> to vector<1x8x8xf32> + %4 = vector.shape_cast %2 : vector<1x8x8xf32> to vector<8x8xf32> + %5 = vector.shape_cast %3 : vector<1x8x8xf32> to vector<8x8xf32> + %6 = vector.shuffle %4, %5 [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8x8xf32>, vector<8x8xf32> + %7 = vector.shape_cast %6 : vector<16x8xf32> to vector<8x16xf32> + %8 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %7, %8 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } @@ -47,39 +45,37 @@ module @gemm attributes {gpu.container_module} { %c32 = arith.constant 32 : index %c16 = arith.constant 16 : index %c24 = arith.constant 24 : index - %c1_f32 = arith.constant 1.0 : f32 - %A = memref.alloc() : memref<32x16xf32> - %Out_cpu = memref.alloc() : memref<8x16xf32> // fill A with values form 0, 1, ...., 511 - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %t1 = arith.muli %i, %c16 : index - %val = arith.addi %t1, %j : index - %val_i32 = arith.index_cast %val : index to i32 - %val_f32 = arith.sitofp %val_i32 : i32 to f32 - %cond = arith.cmpi "sge", %i, %c24 : index // only store the bottom 8x16 into Out_cpu - scf.if %cond { - %i_cpu = arith.subi %i, %c24 : index - memref.store %val_f32, %Out_cpu[%i_cpu, %j] : memref<8x16xf32> + %alloc = memref.alloc() : memref<32x16xf32> + %alloc_0 = memref.alloc() : memref<8x16xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = arith.muli %arg0, %c16 : index + %2 = arith.addi %1, %arg1 : index + %3 = arith.index_cast %2 : index to i32 + %4 = arith.sitofp %3 : i32 to f32 + %5 = arith.cmpi sge, %arg0, %c24 : index + scf.if %5 { + %6 = arith.subi %arg0, %c24 : index + memref.store %4, %alloc_0[%6, %arg1] : memref<8x16xf32> } - memref.store %val_f32, %A[%i, %j] : memref<32x16xf32> + memref.store %4, %alloc[%arg0, %arg1] : memref<32x16xf32> } } // run GPU version - %Out_gpu = call @test(%A) : (memref<32x16xf32>) -> memref<8x16xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<8x16xf32> to 
memref<*xf32> - %A_cast = memref.cast %A : memref<32x16xf32> to memref<*xf32> - %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32> // print GPU and CPU outs // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () // dealloc - memref.dealloc %A : memref<32x16xf32> // gpu dealloc - gpu.dealloc %Out_gpu : memref<8x16xf32> + %0 = call @test(%alloc) : (memref<32x16xf32>) -> memref<8x16xf32> + %cast = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + %cast_1 = memref.cast %alloc_0 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_1) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<32x16xf32> + memref.dealloc %0 : memref<8x16xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/vector_insert_1.mlir b/test/Integration/Dialect/XeGPU/VC/vector_insert_1.mlir index d4539ceb0..775fe9333 100644 --- a/test/Integration/Dialect/XeGPU/VC/vector_insert_1.mlir +++ b/test/Integration/Dialect/XeGPU/VC/vector_insert_1.mlir @@ -1,36 +1,35 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<8x16xf32> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref : memref<8x16xf32> to memref<8x16xf32> - %memref_2 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_2 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>) gpu.dealloc %memref : memref<8x16xf32> - return %memref_2 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<8x16xf32>, %Out: memref<8x16xf32>) kernel attributes 
{VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_kernel(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // load tile - %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // extract row at pos 2 - %a_row = vector.extract %val0 [2] : vector<16xf32> from vector<8x16xf32> // insert row at pos 7 - %val3 = vector.insert %a_row, %val0 [7] : vector<16xf32> into vector<8x16xf32> // store - %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %val3, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %2 = vector.extract %1[2] : vector<16xf32> from vector<8x16xf32> + %3 = vector.insert %2, %1 [7] : vector<16xf32> into vector<8x16xf32> + %4 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %3, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } @@ -42,44 +41,39 @@ module @gemm attributes {gpu.container_module} { %c7 = arith.constant 7 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c1_f32 = arith.constant 1.0 : f32 - %c2_f32 = arith.constant 2.0 : f32 - %cst = arith.constant 2.0 : f32 // random init - %lower = arith.constant -3.0 : f32 - %upper = arith.constant 3.0 : f32 - %false = arith.constant 0 : i1 - %A = memref.alloc() : memref<8x16xf32> - %Out_cpu = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%A_random, %lower, %upper, %false) : (memref<*xf32>, f32, f32, i1) -> () // run GPU version - %Out_gpu = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<8x16xf32> to memref<*xf32> - // run CPU version - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %v = memref.load %A[%i, %j] : memref<8x16xf32> - memref.store %v, %Out_cpu[%i, %j] : memref<8x16xf32> + %cst = arith.constant -3.000000e+00 : f32 + %cst_0 = arith.constant 3.000000e+00 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x16xf32> + %alloc_1 = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst, %cst_0, %false) : (memref<*xf32>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x16xf32>) -> memref<8x16xf32> + %cast_2 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : memref<8x16xf32> + memref.store %1, %alloc_1[%arg0, %arg1] : memref<8x16xf32> } } - scf.for %i = %c0 to %c16 step %c1 { - %v = memref.load %A[%c2, %i] : memref<8x16xf32> - memref.store %v, %Out_cpu[%c7, %i] : memref<8x16xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + %1 = memref.load %alloc[%c2, %arg0] : memref<8x16xf32> + memref.store %1, %alloc_1[%c7, %arg0] : memref<8x16xf32> } - - %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32> // print GPU and 
CPU outs // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () // dealloc - memref.dealloc %A : memref<8x16xf32> - memref.dealloc %Out_cpu : memref<8x16xf32> // gpu dealloc - gpu.dealloc %Out_gpu : memref<8x16xf32> + %cast_3 = memref.cast %alloc_1 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_2, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x16xf32> + memref.dealloc %alloc_1 : memref<8x16xf32> + memref.dealloc %0 : memref<8x16xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/vector_insert_2.mlir b/test/Integration/Dialect/XeGPU/VC/vector_insert_2.mlir index 600db74dd..d95342cd9 100644 --- a/test/Integration/Dialect/XeGPU/VC/vector_insert_2.mlir +++ b/test/Integration/Dialect/XeGPU/VC/vector_insert_2.mlir @@ -1,36 +1,35 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<8x16xf32> ) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %memref = gpu.alloc host_shared () : memref<8x16xf32> - memref.copy %A, %memref : memref<8x16xf32> to memref<8x16xf32> - %memref_2 = gpu.alloc host_shared () : memref<8x16xf32> - gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_2 : memref<8x16xf32>) + %memref = gpu.alloc () : memref<8x16xf32> + gpu.memcpy %memref, %arg0 : memref<8x16xf32>, memref<8x16xf32> + %memref_0 = gpu.alloc () : memref<8x16xf32> + gpu.launch_func @module0::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_0 : memref<8x16xf32>) gpu.dealloc %memref : memref<8x16xf32> - return %memref_2 : memref<8x16xf32> + %alloc = memref.alloc() : memref<8x16xf32> + gpu.memcpy %alloc, %memref_0 : memref<8x16xf32>, memref<8x16xf32> + gpu.dealloc %memref_0 : memref<8x16xf32> + return %alloc : memref<8x16xf32> } - - gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<8x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @module0 { + gpu.func @test_kernel(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, 
spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // load tile - %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // define const vector - %cst = arith.constant dense<1.23> : vector<16xf32> // insert row at pos 7 - %val3 = vector.insert %cst, %val0 [7] : vector<16xf32> into vector<8x16xf32> // store - %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - xegpu.store_nd %val3, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %cst = arith.constant dense<1.230000e+00> : vector<16xf32> + %2 = vector.insert %cst, %1 [7] : vector<16xf32> into vector<8x16xf32> + %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %2, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.return } } @@ -38,47 +37,42 @@ module @gemm attributes {gpu.container_module} { // init constants %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index %c7 = arith.constant 7 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c1_f32 = arith.constant 1.0 : f32 - %c2_f32 = arith.constant 2.0 : f32 - %cst = arith.constant 1.23 : f32 // random init - %lower = arith.constant -3.0 : f32 - %upper = arith.constant 3.0 : f32 - %false = arith.constant 0 : i1 - %A = memref.alloc() : memref<8x16xf32> - %Out_cpu = memref.alloc() : memref<8x16xf32> - %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32> - call @fillResource1DRandomF32(%A_random, %lower, %upper, %false) : (memref<*xf32>, f32, f32, i1) -> () // run GPU version - %Out_gpu = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32> - %Out_gpu_cast = memref.cast %Out_gpu : memref<8x16xf32> to memref<*xf32> - // run CPU version - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c16 step %c1 { - %v = memref.load %A[%i, %j] : memref<8x16xf32> - memref.store %v, %Out_cpu[%i, %j] : memref<8x16xf32> + %cst = arith.constant 1.230000e+00 : f32 + %cst_0 = arith.constant -3.000000e+00 : f32 + %cst_1 = arith.constant 3.000000e+00 : f32 + %false = arith.constant false + %alloc = memref.alloc() : memref<8x16xf32> + %alloc_2 = memref.alloc() : memref<8x16xf32> + %cast = memref.cast %alloc : memref<8x16xf32> to memref<*xf32> + call @fillResource1DRandomF32(%cast, %cst_0, %cst_1, %false) : (memref<*xf32>, f32, f32, i1) -> () + %0 = call @test(%alloc) : (memref<8x16xf32>) -> memref<8x16xf32> + %cast_3 = memref.cast %0 : memref<8x16xf32> to memref<*xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c16 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : memref<8x16xf32> + memref.store %1, %alloc_2[%arg0, %arg1] : memref<8x16xf32> } } - scf.for %i = %c0 to %c16 step %c1 { - memref.store %cst, %Out_cpu[%c7, %i] : memref<8x16xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + memref.store %cst, %alloc_2[%c7, %arg0] : memref<8x16xf32> } - - %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32> // print GPU and CPU outs // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> () // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call 
@printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> () // dealloc - memref.dealloc %A : memref<8x16xf32> - memref.dealloc %Out_cpu : memref<8x16xf32> // gpu dealloc - gpu.dealloc %Out_gpu : memref<8x16xf32> + %cast_4 = memref.cast %alloc_2 : memref<8x16xf32> to memref<*xf32> + call @printAllcloseF32(%cast_3, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x16xf32> + memref.dealloc %alloc_2 : memref<8x16xf32> + memref.dealloc %0 : memref<8x16xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp b/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp index 8ba93d194..981740106 100644 --- a/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp +++ b/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp @@ -1,3 +1,8 @@ +// gpu dialect with intel intrinsic functions (func dialect) to +// llvm dialect (for host code) and +// spirv dialect (for device code) lowering pipeline. +// Ready for imex runner starting from GPU dialect. + builtin.module( cse gpu.module( @@ -16,23 +21,21 @@ canonicalize cse reconcile-unrealized-casts - bf16-to-gpu - imex-convert-gpu-to-spirv - spirv.module(spirv-lower-abi-attrs - spirv-update-vce) + gpu.module(math-extend-to-supported-types{target-type=f32}) + gpu.module(arith-emulate-unsupported-floats{source-types=bf16 target-type=f32}) + spirv-attach-target{ver=v1.0 caps=Addresses,BFloat16TypeKHR,Float16Buffer,Int64,Int16,Int8,Kernel,Linkage,Vector16,GenericPointer,Groups,Float16,Float64,AtomicFloat32AddEXT,ExpectAssumeKHR,SubgroupDispatch,VectorComputeINTEL,VectorAnyINTEL,Bfloat16ConversionINTEL exts=SPV_EXT_shader_atomic_float_add,SPV_KHR_bfloat16,SPV_KHR_expect_assume,SPV_INTEL_vector_compute,SPV_INTEL_bfloat16_conversion} + imex-convert-to-spirv{use-64bit-index=true} + gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce)) func.func(llvm-request-c-wrappers) - serialize-spirv convert-vector-to-scf - convert-gpu-to-gpux convert-scf-to-cf + func.func(gpu-async-region) expand-strided-metadata finalize-memref-to-llvm - convert-cf-to-llvm - convert-vector-to-llvm - convert-index-to-llvm - convert-arith-to-llvm - convert-func-to-llvm - convert-math-to-llvm - convert-gpux-to-llvm + gpu-to-llvm{use-bare-pointers-for-kernels=true} + convert-to-llvm lower-affine - reconcile-unrealized-casts) + reconcile-unrealized-casts + gpu-module-to-binary) + +// End diff --git a/test/Integration/Dialect/XeGPU/VC/xegpu-to-vc.mlir b/test/Integration/Dialect/XeGPU/VC/xegpu-to-vc.mlir index f4eced016..b44712f49 100644 --- a/test/Integration/Dialect/XeGPU/VC/xegpu-to-vc.mlir +++ b/test/Integration/Dialect/XeGPU/VC/xegpu-to-vc.mlir @@ -1,76 +1,61 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck -module @gemm attributes {gpu.container_module, 
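// Editor's note: the rewritten xegpu-to-func-vc.pp above drops the old
// imex-convert-gpu-to-spirv / serialize-spirv / convert-gpux-to-llvm flow in favor of the
// spirv-attach-target + imex-convert-to-spirv + gpu-module-to-binary sequence. A reduced sketch
// of just the device-side portion, assuming it is handed to an mlir-opt-like driver via
// --pass-pipeline (the driver choice, the comma-separated pipeline syntax, and the trimmed
// capability list are illustrative, not taken from this patch):
//   builtin.module(
//     spirv-attach-target{ver=v1.0 caps=Addresses,Int64,Kernel,VectorComputeINTEL},
//     imex-convert-to-spirv{use-64bit-index=true},
//     gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce)),
//     gpu-module-to-binary)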
-spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + +module @gemm attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { memref.global "private" constant @__constant_32x32xf16 : memref<32x32xf16> = dense<5.000000e-01> memref.global "private" constant @__Bconstant_32x32xf16 : memref<32x32xf16> = dense<1.099610e+00> func.func @test(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>) -> memref<32x32xf32> { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c2 = arith.constant 2 : index - %memref_0 = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg0, %memref_0 : memref<32x32xf16> to memref<32x32xf16> - %memref_1 = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %arg1, %memref_1 : memref<32x32xf16> to memref<32x32xf16> - %memref_c = gpu.alloc host_shared () : memref<32x32xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<32x32xf16>, %memref_1 : memref<32x32xf16>, %memref_c : memref<32x32xf32>) - %result = memref.alloc() : memref<32x32xf32> - memref.copy %memref_c, %result: memref<32x32xf32> to memref<32x32xf32> + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<32x32xf16>, memref<32x32xf16> + %memref_1 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c4, %c2, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>, %memref_1 : memref<32x32xf32>) + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_1 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref : memref<32x32xf16> gpu.dealloc %memref_0 : memref<32x32xf16> - gpu.dealloc %memref_1 : memref<32x32xf16> - gpu.dealloc %memref_c :memref<32x32xf32> - - return %result : memref<32x32xf32> + gpu.dealloc %memref_1 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } gpu.module @test_kernel { - gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>}{ + gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c128 = arith.constant 128 : index %c8 = arith.constant 8 : index - - %0 = gpu.block_id x - %1 = gpu.block_id y - - %2 = arith.muli %0, %c8 : index - %3 = arith.muli %1, %c16 : index - %128 = arith.muli %c8, %c16 : index - %256 = arith.muli %128, %c2 : index - %x = arith.muli %256, %0 : index - %y = arith.muli %128, %1 : index - - %c_index = arith.addi %x, %y : index - %arg02 = memref.reinterpret_cast %arg2 to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf32> to memref<1024xf32> - %C0 = xegpu.create_nd_tdesc %arg02[%c_index] : memref<1024xf32> -> !xegpu.tensor_desc<128xf32> - %5 = xegpu.load_nd %C0 : 
!xegpu.tensor_desc<128xf32> -> vector<128xf32> - - %arg00 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf16> to memref<1024xf16> - - %6 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %5) -> (vector<128xf32>) { - %a_index = arith.addi %x, %arg3 : index - %A0 = xegpu.create_nd_tdesc %arg00[%a_index]: memref<1024xf16> -> !xegpu.tensor_desc<128xf16> - %A0_val = xegpu.load_nd %A0 : !xegpu.tensor_desc<128xf16> -> vector<128xf16> - - %B0 = xegpu.create_nd_tdesc %arg1[%arg3, %3] {boundary_check = true} : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> - %B0_val = xegpu.load_nd %B0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> - - %A0_cast = vector.shape_cast %A0_val : vector<128xf16> to vector<8x16xf16> - - %dpas0 = xegpu.dpas %A0_cast, %B0_val : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> - %dpas0_cast = vector.shape_cast %dpas0: vector<8x16xf32> to vector<128xf32> - - scf.yield %dpas0_cast : vector<128xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c8 : index + %1 = arith.muli %block_id_y, %c16 : index + %2 = arith.muli %c8, %c16 : index + %3 = arith.muli %2, %c2 : index + %4 = arith.muli %3, %block_id_x : index + %5 = arith.muli %2, %block_id_y : index + %6 = arith.addi %4, %5 : index + %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf32> to memref<1024xf32> + %7 = xegpu.create_nd_tdesc %reinterpret_cast[%6] : memref<1024xf32> -> !xegpu.tensor_desc<128xf32> + %8 = xegpu.load_nd %7 : !xegpu.tensor_desc<128xf32> -> vector<128xf32> + %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf16> to memref<1024xf16> + %9 = scf.for %arg3 = %c0 to %c32 step %c16 iter_args(%arg4 = %8) -> (vector<128xf32>) { + %10 = arith.addi %4, %arg3 : index + %11 = xegpu.create_nd_tdesc %reinterpret_cast_0[%10] : memref<1024xf16> -> !xegpu.tensor_desc<128xf16> + %12 = xegpu.load_nd %11 : !xegpu.tensor_desc<128xf16> -> vector<128xf16> + %13 = xegpu.create_nd_tdesc %arg1[%arg3, %1] {boundary_check = true} : memref<32x32xf16> -> !xegpu.tensor_desc<16x16xf16> + %14 = xegpu.load_nd %13 <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %15 = vector.shape_cast %12 : vector<128xf16> to vector<8x16xf16> + %16 = xegpu.dpas %15, %14 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %17 = vector.shape_cast %16 : vector<8x16xf32> to vector<128xf32> + scf.yield %17 : vector<128xf32> } - xegpu.store_nd %6, %C0 : vector<128xf32>, !xegpu.tensor_desc<128xf32> - + xegpu.store_nd %9, %7 : vector<128xf32>, !xegpu.tensor_desc<128xf32> gpu.return } } diff --git a/test/Integration/Dialect/XeTile/batch_gemm.mlir b/test/Integration/Dialect/XeTile/batch_gemm.mlir index 3d2a31e01..b207827e4 100644 --- a/test/Integration/Dialect/XeTile/batch_gemm.mlir +++ b/test/Integration/Dialect/XeTile/batch_gemm.mlir @@ -1,34 +1,31 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: 
--runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +#map = affine_map<() -> (0)> +#map1 = affine_map<() -> (96)> +#map2 = affine_map<() -> (3)> +#map3 = affine_map<() -> (2)> module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<2x3x128x96xf16>, %B: memref<2x3x256x96xf16>) -> memref<2x3x128x256xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<2x3x128x96xf16>, %arg1: memref<2x3x256x96xf16>) -> memref<2x3x128x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %A_gpu = gpu.alloc host_shared () : memref<2x3x128x96xf16> - memref.copy %A, %A_gpu : memref<2x3x128x96xf16> to memref<2x3x128x96xf16> - %B_gpu = gpu.alloc host_shared () : memref<2x3x256x96xf16> - memref.copy %B, %B_gpu : memref<2x3x256x96xf16> to memref<2x3x256x96xf16> - %C_gpu = gpu.alloc host_shared () : memref<2x3x128x256xf32> - gpu.launch_func @b2x3_m128_n256_k96::@b2x3_m128_n256_k96 - blocks in (%c1, %c1, %c1) - threads in (%c4, %c8, %c1) - args(%A_gpu : memref<2x3x128x96xf16>, - %B_gpu : memref<2x3x256x96xf16>, - %C_gpu : memref<2x3x128x256xf32>) - gpu.dealloc %A_gpu : memref<2x3x128x96xf16> - gpu.dealloc %B_gpu : memref<2x3x256x96xf16> - return %C_gpu : memref<2x3x128x256xf32> + %memref = gpu.alloc () : memref<2x3x128x96xf16> + gpu.memcpy %memref, %arg0 : memref<2x3x128x96xf16>, memref<2x3x128x96xf16> + %memref_0 = gpu.alloc () : memref<2x3x256x96xf16> + gpu.memcpy %memref_0, %arg1 : memref<2x3x256x96xf16>, memref<2x3x256x96xf16> + %memref_1 = gpu.alloc () : memref<2x3x128x256xf32> + gpu.launch_func @b2x3_m128_n256_k96::@b2x3_m128_n256_k96 blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%memref : memref<2x3x128x96xf16>, %memref_0 : memref<2x3x256x96xf16>, %memref_1 : memref<2x3x128x256xf32>) + gpu.dealloc %memref : memref<2x3x128x96xf16> + gpu.dealloc %memref_0 : memref<2x3x256x96xf16> + %alloc = memref.alloc() : memref<2x3x128x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<2x3x128x256xf32>, memref<2x3x128x256xf32> + gpu.dealloc %memref_1 : memref<2x3x128x256xf32> + return %alloc : memref<2x3x128x256xf32> } - - gpu.module @b2x3_m128_n256_k96 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @b2x3_m128_n256_k96 { gpu.func @b2x3_m128_n256_k96(%arg0: memref<2x3x128x96xf16>, %arg1: memref<2x3x256x96xf16>, %arg2: memref<2x3x128x256xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %block_id_x = gpu.block_id x %block_id_y = gpu.block_id y @@ -85,27 +82,27 @@ module @gemm attributes {gpu.container_module} { %21 = xetile.init_tile %arg1[%arg3, %arg4, %11, %c0] : memref<2x3x256x96xf16> -> !xetile.tile<32x32xf16> %22 = xetile.init_tile %arg0[%arg3, %arg4, %16, %c0] : memref<2x3x128x96xf16> -> !xetile.tile<4x32xf16> xetile.prefetch_tile %22 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> - %23 = xetile.update_tile_offset %22, [%c0, %c32] : !xetile.tile<4x32xf16> + %23 = xetile.update_tile_offset %22, [%c0, %c32] : !xetile.tile<4x32xf16> xetile.prefetch_tile %23 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> - 
%24 = xetile.update_tile_offset %23, [%c0, %c32] : !xetile.tile<4x32xf16> + %24 = xetile.update_tile_offset %23, [%c0, %c32] : !xetile.tile<4x32xf16> %25 = xetile.init_tile %arg1[%arg3, %arg4, %17, %c0] : memref<2x3x256x96xf16> -> !xetile.tile<8x32xf16> xetile.prefetch_tile %25 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> - %26 = xetile.update_tile_offset %25, [%c0, %c32] : !xetile.tile<8x32xf16> + %26 = xetile.update_tile_offset %25, [%c0, %c32] : !xetile.tile<8x32xf16> xetile.prefetch_tile %26 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> - %27 = xetile.update_tile_offset %26, [%c0, %c32] : !xetile.tile<8x32xf16> + %27 = xetile.update_tile_offset %26, [%c0, %c32] : !xetile.tile<8x32xf16> %28 = xetile.init_tile %arg0[%arg3, %arg4, %18, %c0] : memref<2x3x128x96xf16> -> !xetile.tile<4x32xf16> xetile.prefetch_tile %28 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> - %29 = xetile.update_tile_offset %28, [%c0, %c32] : !xetile.tile<4x32xf16> + %29 = xetile.update_tile_offset %28, [%c0, %c32] : !xetile.tile<4x32xf16> xetile.prefetch_tile %29 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> - %30 = xetile.update_tile_offset %29, [%c0, %c32] : !xetile.tile<4x32xf16> + %30 = xetile.update_tile_offset %29, [%c0, %c32] : !xetile.tile<4x32xf16> %31 = xetile.init_tile %arg1[%arg3, %arg4, %19, %c0] : memref<2x3x256x96xf16> -> !xetile.tile<8x32xf16> xetile.prefetch_tile %31 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> - %32 = xetile.update_tile_offset %31, [%c0, %c32] : !xetile.tile<8x32xf16> + %32 = xetile.update_tile_offset %31, [%c0, %c32] : !xetile.tile<8x32xf16> xetile.prefetch_tile %32 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> - %33 = xetile.update_tile_offset %32, [%c0, %c32] : !xetile.tile<8x32xf16> + %33 = xetile.update_tile_offset %32, [%c0, %c32] : !xetile.tile<8x32xf16> %34:8 = scf.for %arg5 = %c0 to %c96 step %c32 iter_args(%arg6 = %cst, %arg7 = %20, %arg8 = %21, %arg9 = %24, %arg10 = %27, %arg11 = %30, %arg12 = %33, %arg13 = %c0) -> (vector<32x32xf32>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<4x32xf16>, !xetile.tile<8x32xf16>, !xetile.tile<4x32xf16>, !xetile.tile<8x32xf16>, index) { - %36 = xetile.load_tile %arg7 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %37 = xetile.load_tile %arg8 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %36 = xetile.load_tile %arg7 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %37 = xetile.load_tile %arg8 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> %38 = arith.cmpi eq, %arg13, %c10 : index %39 = arith.select %38, %c0, %arg13 : index scf.if %38 { @@ -118,27 +115,26 @@ module @gemm attributes {gpu.container_module} { xetile.prefetch_tile %arg11 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> xetile.prefetch_tile %arg12 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> xegpu.compile_hint - %41 = xetile.update_tile_offset %arg9, [%c0, %c32] : !xetile.tile<4x32xf16> - %42 = xetile.update_tile_offset %arg10, [%c0, %c32] : !xetile.tile<8x32xf16> - %43 = xetile.update_tile_offset %arg11, [%c0, %c32] : !xetile.tile<4x32xf16> - %44 = xetile.update_tile_offset %arg12, [%c0, %c32] : !xetile.tile<8x32xf16> + 
%41 = xetile.update_tile_offset %arg9, [%c0, %c32] : !xetile.tile<4x32xf16> + %42 = xetile.update_tile_offset %arg10, [%c0, %c32] : !xetile.tile<8x32xf16> + %43 = xetile.update_tile_offset %arg11, [%c0, %c32] : !xetile.tile<4x32xf16> + %44 = xetile.update_tile_offset %arg12, [%c0, %c32] : !xetile.tile<8x32xf16> %45 = vector.transpose %37, [1, 0] : vector<32x32xf16> to vector<32x32xf16> xegpu.compile_hint - %46 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16> - %47 = xetile.update_tile_offset %arg8, [%c0, %c32] : !xetile.tile<32x32xf16> + %46 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16> + %47 = xetile.update_tile_offset %arg8, [%c0, %c32] : !xetile.tile<32x32xf16> xegpu.compile_hint %48 = xetile.tile_mma %36, %45, %arg6 : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32> xegpu.compile_hint scf.yield %48, %46, %47, %41, %42, %43, %44, %40 : vector<32x32xf32>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<4x32xf16>, !xetile.tile<8x32xf16>, !xetile.tile<4x32xf16>, !xetile.tile<8x32xf16>, index - } {lowerBoundMap = affine_map<() -> (0)>, operandSegmentSizes = array, step = 32 : index, upperBoundMap = affine_map<() -> (96)>} + } {lowerBoundMap = #map, operandSegmentSizes = array, step = 32 : index, upperBoundMap = #map1} %35 = xetile.init_tile %arg2[%arg3, %arg4, %9, %11] : memref<2x3x128x256xf32> -> !xetile.tile<32x32xf32> xetile.store_tile %34#0, %35 : vector<32x32xf32>, !xetile.tile<32x32xf32> - } {lowerBoundMap = affine_map<() -> (0)>, operandSegmentSizes = array, step = 1 : index, syn.parall_level = 2 : i64, upperBoundMap = affine_map<() -> (3)>} - } {lowerBoundMap = affine_map<() -> (0)>, operandSegmentSizes = array, step = 1 : index, syn.parall_level = 2 : i64, upperBoundMap = affine_map<() -> (2)>} + } {lowerBoundMap = #map, operandSegmentSizes = array, step = 1 : index, syn.parall_level = 2 : i64, upperBoundMap = #map2} + } {lowerBoundMap = #map, operandSegmentSizes = array, step = 1 : index, syn.parall_level = 2 : i64, upperBoundMap = #map3} gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -147,59 +143,53 @@ module @gemm attributes {gpu.container_module} { %c96 = arith.constant 96 : index %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index - %cf_1 = arith.constant 1.0 : f16 - %cf_3 = arith.constant 3.0 : f16 - %cf_96 = arith.constant 96.0 : f32 - %A = memref.alloc() : memref<2x3x128x96xf16> - %B = memref.alloc() : memref<2x3x256x96xf16> - %C_ref = memref.alloc() : memref<2x3x128x256xf32> - // The batch contains 6 gemms, fill the first A/B matrices with ones, // the second A/B matrices with twos, and so. The output should be: // first matrix filled with 1*1*96, the second one with 2*2*96, and so on. 
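// Worked example of the reference values, derived from the initialization loops below:
// for batch indices (i0, i1) every element of A and B is v = i0*3 + i1 + 1 in f16,
// so every element of C_ref[i0, i1, :, :] is v*v*96 (the K dimension is 96),
// i.e. 96, 384, 864, 1536, 2400 and 3456 for the six batches in order.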
- scf.for %i0 = %c0 to %c2 step %c1 { - scf.for %i1 = %c0 to %c3 step %c1 { - %i0_i16 = index.castu %i0 : index to i16 - %i0_f16 = arith.uitofp %i0_i16 : i16 to f16 - %i1_i16 = index.castu %i1 : index to i16 - %i1_f16 = arith.uitofp %i1_i16 : i16 to f16 - %v0 = arith.mulf %i0_f16, %cf_3 : f16 - %v1 = arith.addf %v0, %i1_f16 : f16 - %v = arith.addf %v1, %cf_1 : f16 - - scf.for %i2 = %c0 to %c128 step %c1 { - scf.for %i3 = %c0 to %c96 step %c1 { - memref.store %v, %A[%i0, %i1, %i2, %i3] : memref<2x3x128x96xf16> + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 3.000000e+00 : f16 + %cst_1 = arith.constant 9.600000e+01 : f32 + %alloc = memref.alloc() : memref<2x3x128x96xf16> + %alloc_2 = memref.alloc() : memref<2x3x256x96xf16> + %alloc_3 = memref.alloc() : memref<2x3x128x256xf32> + scf.for %arg0 = %c0 to %c2 step %c1 { + scf.for %arg1 = %c0 to %c3 step %c1 { + %1 = index.castu %arg0 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + %3 = index.castu %arg1 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + %5 = arith.mulf %2, %cst_0 : f16 + %6 = arith.addf %5, %4 : f16 + %7 = arith.addf %6, %cst : f16 + scf.for %arg2 = %c0 to %c128 step %c1 { + scf.for %arg3 = %c0 to %c96 step %c1 { + memref.store %7, %alloc[%arg0, %arg1, %arg2, %arg3] : memref<2x3x128x96xf16> } } - - scf.for %i2 = %c0 to %c256 step %c1 { - scf.for %i3 = %c0 to %c96 step %c1 { - memref.store %v, %B[%i0, %i1, %i2, %i3] : memref<2x3x256x96xf16> + scf.for %arg2 = %c0 to %c256 step %c1 { + scf.for %arg3 = %c0 to %c96 step %c1 { + memref.store %7, %alloc_2[%arg0, %arg1, %arg2, %arg3] : memref<2x3x256x96xf16> } } - - %r0 = arith.extf %v : f16 to f32 - %r1 = arith.mulf %r0, %r0 : f32 - %r = arith.mulf %r1, %cf_96 : f32 - - scf.for %i2 = %c0 to %c128 step %c1 { - scf.for %i3 = %c0 to %c256 step %c1 { - memref.store %r, %C_ref[%i0, %i1, %i2, %i3] : memref<2x3x128x256xf32> + %8 = arith.extf %7 : f16 to f32 + %9 = arith.mulf %8, %8 : f32 + %10 = arith.mulf %9, %cst_1 : f32 + scf.for %arg2 = %c0 to %c128 step %c1 { + scf.for %arg3 = %c0 to %c256 step %c1 { + memref.store %10, %alloc_3[%arg0, %arg1, %arg2, %arg3] : memref<2x3x128x256xf32> } } } } - - %C = call @test(%A, %B) : (memref<2x3x128x96xf16>, memref<2x3x256x96xf16>) -> memref<2x3x128x256xf32> - %cast_C = memref.cast %C : memref<2x3x128x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<2x3x128x256xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<2x3x128x96xf16> - memref.dealloc %B : memref<2x3x256x96xf16> - memref.dealloc %C_ref : memref<2x3x128x256xf32> + %0 = call @test(%alloc, %alloc_2) : (memref<2x3x128x96xf16>, memref<2x3x256x96xf16>) -> memref<2x3x128x256xf32> + %cast = memref.cast %0 : memref<2x3x128x256xf32> to memref<*xf32> + %cast_4 = memref.cast %alloc_3 : memref<2x3x128x256xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<2x3x128x96xf16> + memref.dealloc %alloc_2 : memref<2x3x256x96xf16> + memref.dealloc %alloc_3 : memref<2x3x128x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir b/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir index 69e924851..bd4de1b1c 100644 --- a/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir +++ 
b/test/Integration/Dialect/XeTile/block_broadcast_dim_0_fp32.mlir @@ -1,66 +1,57 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { func.func @broadcast_test() -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - - %b_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @kernel::@softmax_dim_0 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%b_gpu : memref<1024x1024xf32>) - return %b_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @kernel::@softmax_dim_0 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf32>) + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - - gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // the kernel is a 16x32 block broadcast. each thread is assigned with a 16x32 block, and broadcast value from vector<1x32xf32> to vector<16x32xf32> along dim-0 independently. 
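// A rough sketch of what each thread produces, based on the kernel body below:
// the constant row dense<3.0> (vector<1x32xf32>) is replicated across 16 rows by the
// dim-0 broadcast into a vector<16x32xf32> and stored at [16*block_id_x, 32*block_id_y];
// with the (64, 32, 1) launch grid these 16x32 tiles cover the whole 1024x1024 output,
// so the result should match the host-side reference filled with 3.0.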
- gpu.func @softmax_dim_0(%b: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @kernel { + gpu.func @softmax_dim_0(%arg0: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - - %1 = arith.constant dense<3.0>: vector<1x32xf32> - %2 = xetile.broadcast %1 [0]: vector<1x32xf32> -> vector<16x32xf32> - - %3 = xetile.init_tile %b[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - xetile.store_tile %2, %3: vector<16x32xf32>, !xetile.tile<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %cst = arith.constant dense<3.000000e+00> : vector<1x32xf32> + %2 = xetile.broadcast %cst [0] : vector<1x32xf32> -> vector<16x32xf32> + %3 = xetile.init_tile %arg0[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + xetile.store_tile %2, %3 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 3.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %b_ref = memref.alloc() : memref<1024x1024xf32> - // compute b for reference // step 1: exp - %val = arith.constant 3.0: f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %val, %b_ref[%i, %j] : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc[%arg0, %arg1] : memref<1024x1024xf32> } } - - %b = call @broadcast_test() : () -> memref<1024x1024xf32> - %cast_b = memref.cast %b : memref<1024x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_b) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %b_ref : memref<1024x1024xf32> + %0 = call @broadcast_test() : () -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_0 = memref.cast %alloc : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir b/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir index fae86e136..bc4a98e15 100644 --- a/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_broadcast_dim_1_fp32.mlir @@ -1,66 +1,57 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { func.func @broadcast_test() -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - - %b_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - gpu.launch_func @kernel::@softmax_dim_0 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%b_gpu : memref<1024x1024xf32>) - return %b_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @kernel::@softmax_dim_0 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf32>) + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - - gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // the kernel is a 16x32 block broadcast. each thread is assigned with a 16x32 block, and broadcast value from vector<16x1xf32> to vector<16x32xf32> along dim-1 independently. - gpu.func @softmax_dim_0(%b: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @kernel { + gpu.func @softmax_dim_0(%arg0: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - - %1 = arith.constant dense<3.0>: vector<16x1xf32> - %2 = xetile.broadcast %1 [1]: vector<16x1xf32> -> vector<16x32xf32> - - %3 = xetile.init_tile %b[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - xetile.store_tile %2, %3: vector<16x32xf32>, !xetile.tile<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %cst = arith.constant dense<3.000000e+00> : vector<16x1xf32> + %2 = xetile.broadcast %cst [1] : vector<16x1xf32> -> vector<16x32xf32> + %3 = xetile.init_tile %arg0[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + xetile.store_tile %2, %3 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 3.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %b_ref = memref.alloc() : memref<1024x1024xf32> - // compute b for reference // step 1: exp - %val = arith.constant 3.0: f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %val, %b_ref[%i, %j] : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for 
%arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc[%arg0, %arg1] : memref<1024x1024xf32> } } - - %b = call @broadcast_test() : () -> memref<1024x1024xf32> - %cast_b = memref.cast %b : memref<1024x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_b) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %b_ref : memref<1024x1024xf32> + %0 = call @broadcast_test() : () -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_0 = memref.cast %alloc : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_0) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir b/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir index 3e5835fa3..dc0871072 100644 --- a/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_reduce_dim_0_fp32.mlir @@ -1,88 +1,75 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { - func.func @reduce_test(%a: memref<16x1024xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { + func.func @reduce_test(%arg0: memref<16x1024xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - - %a_gpu = gpu.alloc host_shared () : memref<16x1024xf32> - memref.copy %a, %a_gpu : memref<16x1024xf32> to memref<16x1024xf32> - %b_gpu = gpu.alloc host_shared () : memref<1x1024xf32> - - gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c1, %c32, %c1) threads in (%c1, %c1, %c1) args(%a_gpu : memref<16x1024xf32>, %b_gpu : memref<1x1024xf32>) - - gpu.dealloc %a_gpu : memref<16x1024xf32> - return %b_gpu : memref<1x1024xf32> + %memref = gpu.alloc () : memref<16x1024xf32> + gpu.memcpy %memref, %arg0 : memref<16x1024xf32>, memref<16x1024xf32> + %memref_0 = gpu.alloc () : memref<1x1024xf32> + gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c1, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<16x1024xf32>, %memref_0 : memref<1x1024xf32>) + gpu.dealloc %memref : memref<16x1024xf32> + %alloc = memref.alloc() : memref<1x1024xf32> + gpu.memcpy %alloc, %memref_0 : memref<1x1024xf32>, memref<1x1024xf32> + gpu.dealloc %memref_0 : memref<1x1024xf32> + return %alloc : memref<1x1024xf32> } - - 
gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // the kernel is a 16x32 block reduction. each thread is assigned with a 16x32 block, and do reduction along dim-0 independently. - gpu.func @reduce_dim_1(%a: memref<16x1024xf32>, %b: memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @kernel { + gpu.func @reduce_dim_1(%arg0: memref<16x1024xf32>, %arg1: memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - %1 = xetile.init_tile %a[%m, %n] : memref<16x1024xf32> -> !xetile.tile<16x32xf32> - %2 = xetile.load_tile %1: !xetile.tile<16x32xf32> -> vector<16x32xf32> - %4 = xetile.reduction , %2 [0]: vector<16x32xf32> -> vector<1x32xf32> - %5 = xetile.init_tile %b[0, %n] : memref<1x1024xf32> -> !xetile.tile<1x32xf32> - xetile.store_tile %4, %5: vector<1x32xf32>, !xetile.tile<1x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<16x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.reduction , %3 [0] : vector<16x32xf32> -> vector<1x32xf32> + %5 = xetile.init_tile %arg1[0, %1] : memref<1x1024xf32> -> !xetile.tile<1x32xf32> + xetile.store_tile %4, %5 : vector<1x32xf32>, !xetile.tile<1x32xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %c0_f32 = arith.constant 0.0 : f32 - %c32_f32 = arith.constant 32.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<16x1024xf32> - %b_ref = memref.alloc() : memref<1024xf32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c16 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %u = arith.uitofp %t : i16 to f32 - %v = arith.divf %u, %c100_f32 : f32 - memref.store %v, %a[%i, %j] : memref<16x1024xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+02 : f32 + %alloc = memref.alloc() : memref<16x1024xf32> + %alloc_1 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c16 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f32 + %3 = arith.divf %2, %cst_0 : f32 + memref.store %3, %alloc[%arg0, %arg1] : memref<16x1024xf32> } } - - scf.for %j = %c0 to %c1024 step %c1 { - %sum = scf.for %i = %c0 to %c16 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %val = memref.load %a[%i, %j] : memref<16x1024xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = scf.for %arg1 = %c0 to %c16 step %c1 iter_args(%arg2 = %cst) -> (f32) { + %2 = memref.load %alloc[%arg1, %arg0] : memref<16x1024xf32> + %3 = arith.addf %arg2, %2 : f32 + scf.yield %3 : f32 } - memref.store %sum, %b_ref[%j] : memref<1024xf32> + memref.store %1, %alloc_1[%arg0] : memref<1024xf32> } - - %b = 
call @reduce_test(%a) : (memref<16x1024xf32>) -> memref<1x1024xf32> - %cast_b = memref.cast %b : memref<1x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_b): (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<16x1024xf32> - memref.dealloc %b_ref : memref<1024xf32> + %0 = call @reduce_test(%alloc) : (memref<16x1024xf32>) -> memref<1x1024xf32> + %cast = memref.cast %0 : memref<1x1024xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<16x1024xf32> + memref.dealloc %alloc_1 : memref<1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir b/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir index 48263b323..5a70d59ad 100644 --- a/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_reduce_dim_1_fp32.mlir @@ -1,90 +1,75 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { - func.func @reduce_test(%a: memref<1024x32xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { + func.func @reduce_test(%arg0: memref<1024x32xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - - %a_gpu = gpu.alloc host_shared () : memref<1024x32xf32> - memref.copy %a, %a_gpu : memref<1024x32xf32> to memref<1024x32xf32> - %b_gpu = gpu.alloc host_shared () : memref<1x1024xf32> - - gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c64, %c1, %c1) threads in (%c1, %c1, %c1) args(%a_gpu : memref<1024x32xf32>, %b_gpu : memref<1x1024xf32>) - - gpu.dealloc %a_gpu : memref<1024x32xf32> - return %b_gpu : memref<1x1024xf32> + %memref = gpu.alloc () : memref<1024x32xf32> + gpu.memcpy %memref, %arg0 : memref<1024x32xf32>, memref<1024x32xf32> + %memref_0 = gpu.alloc () : memref<1x1024xf32> + gpu.launch_func @kernel::@reduce_dim_1 blocks in (%c64, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x32xf32>, %memref_0 : memref<1x1024xf32>) + gpu.dealloc %memref : memref<1024x32xf32> + %alloc = memref.alloc() : memref<1x1024xf32> + gpu.memcpy %alloc, %memref_0 : memref<1x1024xf32>, memref<1x1024xf32> + gpu.dealloc %memref_0 : memref<1x1024xf32> + return %alloc : memref<1x1024xf32> } - - gpu.module 
@kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // the kernel is a 16x32 block reduction. each thread is assigned with a 16x32 block, and do reduction along dim-1 independently. - gpu.func @reduce_dim_1(%a: memref<1024x32xf32>, %b: memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c = arith.constant dense<3.2>: vector<16xf32> + gpu.module @kernel { + gpu.func @reduce_dim_1(%arg0: memref<1024x32xf32>, %arg1: memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %cst = arith.constant dense<3.200000e+00> : vector<16xf32> %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - - %1 = xetile.init_tile %a[%m, %n] : memref<1024x32xf32> -> !xetile.tile<16x32xf32> - %2 = xetile.load_tile %1: !xetile.tile<16x32xf32> -> vector<16x32xf32> - - %4 = xetile.reduction , %2 [1]: vector<16x32xf32> -> vector<16x1xf32> - %5 = xetile.init_tile %b[0, %m] : memref<1x1024xf32> -> !xetile.tile<1x16xf32> - %cast = vector.shape_cast %4: vector<16x1xf32> to vector<1x16xf32> - xetile.store_tile %cast, %5: vector<1x16xf32>, !xetile.tile<1x16xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<1024x32xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.reduction , %3 [1] : vector<16x32xf32> -> vector<16x1xf32> + %5 = xetile.init_tile %arg1[0, %0] : memref<1x1024xf32> -> !xetile.tile<1x16xf32> + %6 = vector.shape_cast %4 : vector<16x1xf32> to vector<1x16xf32> + xetile.store_tile %6, %5 : vector<1x16xf32>, !xetile.tile<1x16xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %c0_f32 = arith.constant 0.0 : f32 - %c32_f32 = arith.constant 32.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<1024x32xf32> - %b_ref = memref.alloc() : memref<1024xf32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %t = index.castu %j : index to i16 - %u = arith.uitofp %t : i16 to f32 - %v = arith.divf %u, %c100_f32 : f32 - memref.store %v, %a[%i, %j] : memref<1024x32xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+02 : f32 + %alloc = memref.alloc() : memref<1024x32xf32> + %alloc_1 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f32 + %3 = arith.divf %2, %cst_0 : f32 + memref.store %3, %alloc[%arg0, %arg1] : memref<1024x32xf32> } } - - scf.for %i = %c0 to %c1024 step %c1 { - %sum = scf.for %j = %c0 to %c32 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %val = memref.load %a[%i, %j] : memref<1024x32xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = scf.for %arg1 = %c0 to %c32 step %c1 iter_args(%arg2 = %cst) -> 
(f32) { + %2 = memref.load %alloc[%arg0, %arg1] : memref<1024x32xf32> + %3 = arith.addf %arg2, %2 : f32 + scf.yield %3 : f32 } - memref.store %sum, %b_ref[%i] : memref<1024xf32> + memref.store %1, %alloc_1[%arg0] : memref<1024xf32> } - - %b = call @reduce_test(%a) : (memref<1024x32xf32>) -> memref<1x1024xf32> - %cast_b = memref.cast %b : memref<1x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<1024x32xf32> - memref.dealloc %b_ref : memref<1024xf32> + %0 = call @reduce_test(%alloc) : (memref<1024x32xf32>) -> memref<1x1024xf32> + %cast = memref.cast %0 : memref<1x1024xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x32xf32> + memref.dealloc %alloc_1 : memref<1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir b/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir index d537f7e6e..99a6e85e0 100644 --- a/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_softmax_dim_0_fp32.mlir @@ -1,109 +1,95 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @softmax attributes {gpu.container_module} { - func.func @block_softmax_test(%a: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @block_softmax_test(%arg0: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - - %a_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %a, %a_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %b_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - - gpu.launch_func @kernel::@block_softmax_dim_0 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%a_gpu : memref<1024x1024xf32>, %b_gpu : memref<1024x1024xf32>) - - gpu.dealloc %a_gpu : memref<1024x1024xf32> - return %b_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_0 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @kernel::@block_softmax_dim_0 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf32>, %memref_0 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf32> + %alloc = 
memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_0 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_0 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - - gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // the kernel is a 16x32 block softmax. each thread is assigned with a 16x32 block, and do softmax along dim-0 independently. - gpu.func @block_softmax_dim_0(%a: memref<1024x1024xf32>, %b: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @kernel { + gpu.func @block_softmax_dim_0(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - - %1 = xetile.init_tile %a[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %2 = xetile.load_tile %1: !xetile.tile<16x32xf32> -> vector<16x32xf32> - %3 = math.exp %2: vector<16x32xf32> - %4 = xetile.reduction , %3 [0]: vector<16x32xf32> -> vector<1x32xf32> - %5 = xetile.broadcast %4 [0]: vector<1x32xf32> -> vector<16x32xf32> - %6 = arith.divf %3, %5: vector<16x32xf32> - %7 = xetile.init_tile %b[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - xetile.store_tile %6, %7: vector<16x32xf32>, !xetile.tile<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = math.exp %3 : vector<16x32xf32> + %5 = xetile.reduction , %4 [0] : vector<16x32xf32> -> vector<1x32xf32> + %6 = xetile.broadcast %5 [0] : vector<1x32xf32> -> vector<16x32xf32> + %7 = arith.divf %4, %6 : vector<16x32xf32> + %8 = xetile.init_tile %arg1[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + xetile.store_tile %7, %8 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c16 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %c0_f32 = arith.constant 0.0 : f32 - %c64_f32 = arith.constant 64.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<1024x1024xf32> - %b_ref = memref.alloc() : memref<1024x1024xf32> - %s = memref.alloc() : memref<1024xf32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %u = arith.uitofp %t : i16 to f32 - %v = arith.divf %u, %c100_f32 : f32 - memref.store %v, %a[%i, %j] : memref<1024x1024xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 6.400000e+01 : f32 + %cst_1 = arith.constant 1.000000e+02 : f32 + %alloc = memref.alloc() : memref<1024x1024xf32> + %alloc_2 = memref.alloc() : memref<1024x1024xf32> + %alloc_3 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f32 + %3 = arith.divf %2, %cst_1 : f32 + 
memref.store %3, %alloc[%arg0, %arg1] : memref<1024x1024xf32> } } - // compute b for reference // step 1: exp - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %val = memref.load %a[%i, %j] : memref<1024x1024xf32> - %exp = math.exp %val : f32 - memref.store %exp, %b_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : memref<1024x1024xf32> + %2 = math.exp %1 : f32 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> } } - // step 2: sum and div along dim-0 - scf.for %j = %c0 to %c1024 step %c1 { - %sum = scf.for %i = %c0 to %c1024 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %val = memref.load %b_ref[%i, %j] : memref<1024x1024xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = scf.for %arg1 = %c0 to %c1024 step %c1 iter_args(%arg2 = %cst) -> (f32) { + %3 = memref.load %alloc_2[%arg1, %arg0] : memref<1024x1024xf32> + %4 = arith.addf %arg2, %3 : f32 + scf.yield %4 : f32 } - %avg = arith.divf %sum, %c64_f32: f32 - memref.store %avg, %s[%j] : memref<1024xf32> - - scf.for %i = %c0 to %c1024 step %c1 { - %val = memref.load %b_ref[%i, %j] : memref<1024x1024xf32> - %div = arith.divf %val, %avg: f32 - memref.store %div, %b_ref[%i, %j] : memref<1024x1024xf32> + %2 = arith.divf %1, %cst_0 : f32 + memref.store %2, %alloc_3[%arg0] : memref<1024xf32> + scf.for %arg1 = %c0 to %c1024 step %c1 { + %3 = memref.load %alloc_2[%arg1, %arg0] : memref<1024x1024xf32> + %4 = arith.divf %3, %2 : f32 + memref.store %4, %alloc_2[%arg1, %arg0] : memref<1024x1024xf32> } } - - %b = call @block_softmax_test(%a) : (memref<1024x1024xf32>) -> memref<1024x1024xf32> - %cast_b = memref.cast %b : memref<1024x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_b) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<1024x1024xf32> - memref.dealloc %b_ref : memref<1024x1024xf32> + %0 = call @block_softmax_test(%alloc) : (memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_4 = memref.cast %alloc_2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf32> + memref.dealloc %alloc_2 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir b/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir index 5bf08f1d4..1ca5e425d 100644 --- a/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir +++ b/test/Integration/Dialect/XeTile/block_softmax_dim_1_fp32.mlir @@ -1,110 +1,95 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner 
--requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @block_softmax attributes {gpu.container_module} { - func.func @block_softmax_test(%a: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @block_softmax_test(%arg0: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - - %a_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %a, %a_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %b_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - - gpu.launch_func @kernel::@block_softmax_dim_1 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%a_gpu : memref<1024x1024xf32>, %b_gpu : memref<1024x1024xf32>) - - gpu.dealloc %a_gpu : memref<1024x1024xf32> - return %b_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_0 = gpu.alloc () : memref<1024x1024xf32> + gpu.launch_func @kernel::@block_softmax_dim_1 blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf32>, %memref_0 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_0 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_0 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - - gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { // the kernel is a 16x32 block softmax. each thread is assigned with a 16x32 block, and do softmax along dim-1 independently. 
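// Worked example, assuming the A[i, j] = i / 100.0 initialization used in @main below:
// all 32 values of a block row share the same value x, so after exp the dim-1 reduction
// yields 32*exp(x) and every output element becomes exp(x) / (32*exp(x)) = 1/32 = 0.03125;
// the host reference reaches the same value by summing exp(x) over a full 1024-wide row,
// dividing that sum by 32.0 and then dividing each exp(x) by the result.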
- gpu.func @block_softmax_dim_1(%a: memref<1024x1024xf32>, %b: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @kernel { + gpu.func @block_softmax_dim_1(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - - %1 = xetile.init_tile %a[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %2 = xetile.load_tile %1: !xetile.tile<16x32xf32> -> vector<16x32xf32> - %3 = math.exp %2: vector<16x32xf32> - %4 = xetile.reduction , %3 [1]: vector<16x32xf32> -> vector<16x1xf32> - %5 = xetile.broadcast %4 [1]: vector<16x1xf32> -> vector<16x32xf32> - %6 = arith.divf %3, %5: vector<16x32xf32> - %7 = xetile.init_tile %b[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - xetile.store_tile %6, %7: vector<16x32xf32>, !xetile.tile<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = math.exp %3 : vector<16x32xf32> + %5 = xetile.reduction , %4 [1] : vector<16x32xf32> -> vector<16x1xf32> + %6 = xetile.broadcast %5 [1] : vector<16x1xf32> -> vector<16x32xf32> + %7 = arith.divf %4, %6 : vector<16x32xf32> + %8 = xetile.init_tile %arg1[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + xetile.store_tile %7, %8 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c16 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %c0_f32 = arith.constant 0.0 : f32 - %c32_f32 = arith.constant 32.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<1024x1024xf32> - %b_ref = memref.alloc() : memref<1024x1024xf32> - %s = memref.alloc() : memref<1024xf32> - // intialize matrix A ; A[i, j] = i - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %i : index to i16 - %u = arith.uitofp %t : i16 to f32 - %v = arith.divf %u, %c100_f32 : f32 - memref.store %v, %a[%i, %j] : memref<1024x1024xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 3.200000e+01 : f32 + %cst_1 = arith.constant 1.000000e+02 : f32 + %alloc = memref.alloc() : memref<1024x1024xf32> + %alloc_2 = memref.alloc() : memref<1024x1024xf32> + %alloc_3 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i16 + %2 = arith.uitofp %1 : i16 to f32 + %3 = arith.divf %2, %cst_1 : f32 + memref.store %3, %alloc[%arg0, %arg1] : memref<1024x1024xf32> } } - // compute b for reference // step 1: exp - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %val = memref.load %a[%i, %j] : memref<1024x1024xf32> - %exp = math.exp %val : f32 - memref.store %exp, %b_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc[%arg0, %arg1] : 
memref<1024x1024xf32> + %2 = math.exp %1 : f32 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> } } - // step 2: sum and div along dim-1 - scf.for %i = %c0 to %c1024 step %c1 { - %sum = scf.for %j = %c0 to %c1024 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %val = memref.load %b_ref[%i, %j] : memref<1024x1024xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = scf.for %arg1 = %c0 to %c1024 step %c1 iter_args(%arg2 = %cst) -> (f32) { + %3 = memref.load %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> + %4 = arith.addf %arg2, %3 : f32 + scf.yield %4 : f32 } - %avg = arith.divf %sum, %c32_f32: f32 - memref.store %avg, %s[%i] : memref<1024xf32> - - scf.for %j = %c0 to %c1024 step %c1 { - %val = memref.load %b_ref[%i, %j] : memref<1024x1024xf32> - %div = arith.divf %val, %avg: f32 - memref.store %div, %b_ref[%i, %j] : memref<1024x1024xf32> + %2 = arith.divf %1, %cst_0 : f32 + memref.store %2, %alloc_3[%arg0] : memref<1024xf32> + scf.for %arg1 = %c0 to %c1024 step %c1 { + %3 = memref.load %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> + %4 = arith.divf %3, %2 : f32 + memref.store %4, %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> } } - - - %b = call @block_softmax_test(%a) : (memref<1024x1024xf32>) -> memref<1024x1024xf32> - %cast_b = memref.cast %b : memref<1024x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_b) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_b_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<1024x1024xf32> - memref.dealloc %b_ref : memref<1024x1024xf32> + %0 = call @block_softmax_test(%alloc) : (memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_4 = memref.cast %alloc_2 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf32> + memref.dealloc %alloc_2 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/convert_layout.mlir b/test/Integration/Dialect/XeTile/convert_layout.mlir index 532d46bd7..f3a1876af 100644 --- a/test/Integration/Dialect/XeTile/convert_layout.mlir +++ b/test/Integration/Dialect/XeTile/convert_layout.mlir @@ -1,81 +1,72 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @conv_layout attributes {gpu.container_module} { - func.func @convert_layout(%a: memref<64x64xf32>, %b: memref<64x64xf32>) -> 
memref<64x64xf32> attributes {llvm.emit_c_interface} { + func.func @convert_layout(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>) -> memref<64x64xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - - %a_gpu = gpu.alloc host_shared () : memref<64x64xf32> - memref.copy %a, %a_gpu : memref<64x64xf32> to memref<64x64xf32> - %b_gpu = gpu.alloc host_shared () : memref<64x64xf32> - memref.copy %b, %b_gpu : memref<64x64xf32> to memref<64x64xf32> - %c_gpu = gpu.alloc host_shared () : memref<64x64xf32> - - gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%a_gpu : memref<64x64xf32>, %b_gpu : memref<64x64xf32>, %c_gpu : memref<64x64xf32>) - - gpu.dealloc %a_gpu : memref<64x64xf32> - gpu.dealloc %b_gpu : memref<64x64xf32> - return %c_gpu : memref<64x64xf32> + %memref = gpu.alloc () : memref<64x64xf32> + gpu.memcpy %memref, %arg0 : memref<64x64xf32>, memref<64x64xf32> + %memref_0 = gpu.alloc () : memref<64x64xf32> + gpu.memcpy %memref_0, %arg1 : memref<64x64xf32>, memref<64x64xf32> + %memref_1 = gpu.alloc () : memref<64x64xf32> + gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<64x64xf32>, %memref_0 : memref<64x64xf32>, %memref_1 : memref<64x64xf32>) + gpu.dealloc %memref : memref<64x64xf32> + gpu.dealloc %memref_0 : memref<64x64xf32> + %alloc = memref.alloc() : memref<64x64xf32> + gpu.memcpy %alloc, %memref_1 : memref<64x64xf32>, memref<64x64xf32> + gpu.dealloc %memref_1 : memref<64x64xf32> + return %alloc : memref<64x64xf32> } - -gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_convert_layout(%arg0 : memref<64x64xf32>, %arg1 : memref<64x64xf32>, %arg2 : memref<64x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c1 = arith.constant 1 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c1 : index - %n = arith.muli %block_id_y, %c1 : index - %init_tile_1 = xetile.init_tile %arg0[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> - %load_tile_1 = xetile.load_tile %init_tile_1: !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> - %init_tile_2 = xetile.init_tile %arg1[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> - %load_tile_2 = xetile.load_tile %init_tile_2: !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> - %convert_layout = xetile.convert_layout %load_tile_1 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf32> - %add = arith.addf %load_tile_2, %convert_layout {map = #xetile.wg_map} : vector<64x64xf32> - %init_store_tile = xetile.init_tile %arg2[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> - xetile.store_tile %add, %init_store_tile : vector<64x64xf32>, !xetile.tile<64x64xf32, #xetile.tile_attr>> - gpu.return + gpu.module @kernel { + gpu.func @test_convert_layout(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c1 = arith.constant 1 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c1 : index + %1 = arith.muli %block_id_y, %c1 : index + %2 
= xetile.init_tile %arg0[%0, %1] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> + %4 = xetile.init_tile %arg1[%0, %1] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + %5 = xetile.load_tile %4 : !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> + %6 = xetile.convert_layout %3 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf32> + %7 = arith.addf %5, %6 {map = #xetile.wg_map} : vector<64x64xf32> + %8 = xetile.init_tile %arg2[%0, %1] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + xetile.store_tile %7, %8 : vector<64x64xf32>, !xetile.tile<64x64xf32, #xetile.tile_attr>> + gpu.return + } } -} - -func.func @main() attributes {llvm.emit_c_interface} { + func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - %c1_f32 = arith.constant 1.0 : f32 - %c2_f32 = arith.constant 2.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<64x64xf32> - %b = memref.alloc() : memref<64x64xf32> - %c_ref = memref.alloc() : memref<64x64xf32> - - // intialize matrix A, B ; A[i, j] = 1 - scf.for %i = %c0 to %c64 step %c1 { - scf.for %j = %c0 to %c64 step %c1 { - memref.store %c1_f32, %a[%i, %j] : memref<64x64xf32> - memref.store %c1_f32, %b[%i, %j] : memref<64x64xf32> - memref.store %c2_f32, %c_ref[%i, %j] : memref<64x64xf32> + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %alloc = memref.alloc() : memref<64x64xf32> + %alloc_1 = memref.alloc() : memref<64x64xf32> + %alloc_2 = memref.alloc() : memref<64x64xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + memref.store %cst, %alloc[%arg0, %arg1] : memref<64x64xf32> + memref.store %cst, %alloc_1[%arg0, %arg1] : memref<64x64xf32> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<64x64xf32> } } - - %c = call @convert_layout(%a, %b) : (memref<64x64xf32>, memref<64x64xf32>) -> memref<64x64xf32> - %cast_c = memref.cast %c : memref<64x64xf32> to memref<*xf32> - %cast_c_ref = memref.cast %c_ref :memref<64x64xf32> to memref<*xf32> //call @printMemrefF32(%cast_c): (memref<*xf32>) -> () //call @printMemrefF32(%cast_c_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_c, %cast_c_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<64x64xf32> - memref.dealloc %b : memref<64x64xf32> + %0 = call @convert_layout(%alloc, %alloc_1) : (memref<64x64xf32>, memref<64x64xf32>) -> memref<64x64xf32> + %cast = memref.cast %0 : memref<64x64xf32> to memref<*xf32> + %cast_3 = memref.cast %alloc_2 : memref<64x64xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<64x64xf32> + memref.dealloc %alloc_1 : memref<64x64xf32> + memref.dealloc %alloc_2 : memref<64x64xf32> + memref.dealloc %0 : memref<64x64xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/convert_layout_gemm_fp16.mlir b/test/Integration/Dialect/XeTile/convert_layout_gemm_fp16.mlir index 461811fcb..7fcd24bf2 100644 --- a/test/Integration/Dialect/XeTile/convert_layout_gemm_fp16.mlir +++ b/test/Integration/Dialect/XeTile/convert_layout_gemm_fp16.mlir @@ -1,131 +1,104 @@ -// RUN: %python_executable %imex_runner 
--requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#wg_map_a_coop = #xetile.wg_map -#wg_map_a = #xetile.wg_map -#wg_map_b = #xetile.wg_map -#wg_map_c = #xetile.wg_map +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @conv_layout attributes {gpu.container_module} { - func.func @test_convert_layout_gemm(%a: memref<8x32xf16>, %b: memref<32x32xf16>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { + func.func @test_convert_layout_gemm(%arg0: memref<8x32xf16>, %arg1: memref<32x32xf16>) -> memref<8x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - - %a_gpu = gpu.alloc host_shared () : memref<8x32xf16> - memref.copy %a, %a_gpu : memref<8x32xf16> to memref<8x32xf16> - %b_gpu = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %b, %b_gpu : memref<32x32xf16> to memref<32x32xf16> - %c_gpu = gpu.alloc host_shared () : memref<8x32xf32> - - gpu.launch_func @kernel::@test_convert_layout_gemm blocks in (%c1, %c1, %c1) threads in (%c2, %c1, %c1) args(%a_gpu : memref<8x32xf16>, %b_gpu : memref<32x32xf16>, %c_gpu : memref<8x32xf32>) - - gpu.dealloc %a_gpu : memref<8x32xf16> - gpu.dealloc %b_gpu : memref<32x32xf16> - return %c_gpu : memref<8x32xf32> + %memref = gpu.alloc () : memref<8x32xf16> + gpu.memcpy %memref, %arg0 : memref<8x32xf16>, memref<8x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref_0, %arg1 : memref<32x32xf16>, memref<32x32xf16> + %memref_1 = gpu.alloc () : memref<8x32xf32> + gpu.launch_func @kernel::@test_convert_layout_gemm blocks in (%c1, %c1, %c1) threads in (%c2, %c1, %c1) args(%memref : memref<8x32xf16>, %memref_0 : memref<32x32xf16>, %memref_1 : memref<8x32xf32>) + gpu.dealloc %memref : memref<8x32xf16> + gpu.dealloc %memref_0 : memref<32x32xf16> + %alloc = memref.alloc() : memref<8x32xf32> + gpu.memcpy %alloc, %memref_1 : memref<8x32xf32>, memref<8x32xf32> + gpu.dealloc %memref_1 : memref<8x32xf32> + return %alloc : memref<8x32xf32> } - - gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @kernel { // This test performs a simple matrix multiplication of 8x32xf16 and 32x32xf16 with a workgroup of 2 threads, resulting in an 8x32xf32 matrix. // Each thread computes an 8x16xf32 matrix, i.e., 8x32xf16 * 32x16xf16. A is shared; each thread loads 8x16xf16 from memory and uses convert // layout to share the data.
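  // A hedged shape sketch of the cooperative layout described above; the exact
  // wg_map parameters are elided in this rendering of the diff, so the split of
  // the result into two 8x16 halves along dim-1 is taken from the comment rather
  // than read off the attributes:
  //
  //   thread t in {0, 1}:
  //     %a_sg : vector<8x32xf16>    // all of A, shared via xetile.convert_layout
  //     %b_sg : vector<32x16xf16>   // the 16-column slice of B assigned to thread t
  //     %c_sg = xetile.tile_mma %a_sg, %b_sg
  //             : vector<8x32xf16>, vector<32x16xf16> -> vector<8x16xf32>
  //
  //   Concatenating the two %c_sg tiles along dim-1 gives the full 8x32xf32 result.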
- gpu.func @test_convert_layout_gemm(%arg0 : memref<8x32xf16>, %arg1 : memref<32x32xf16>, %arg2 : memref<8x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.func @test_convert_layout_gemm(%arg0: memref<8x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<8x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c1 = arith.constant 1 : index - %id_x = gpu.block_id x - %id_y = gpu.block_id y - %m = arith.muli %id_x, %c1 : index - %n = arith.muli %id_y, %c1 : index - - %a_tile = xetile.init_tile %arg0[%m, %n] : memref<8x32xf16> -> !xetile.tile<8x32xf16, #xetile.tile_attr> - %a_coop = xetile.load_tile %a_tile: !xetile.tile<8x32xf16, #xetile.tile_attr> -> vector<8x32xf16> - %a = xetile.convert_layout %a_coop {wg_map_result = #wg_map_a, wg_map_source = #wg_map_a_coop} : vector<8x32xf16> - - %b_tile = xetile.init_tile %arg1[%m, %n] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr> - %b = xetile.load_tile %b_tile: !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<32x32xf16> - - %c = xetile.tile_mma %a, %b {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c} : vector<8x32xf16>, vector<32x32xf16> -> vector<8x32xf32> - - %init_store_tile = xetile.init_tile %arg2[%m, %n] : memref<8x32xf32> -> !xetile.tile<8x32xf32, #xetile.tile_attr> - xetile.store_tile %c, %init_store_tile : vector<8x32xf32>, !xetile.tile<8x32xf32, #xetile.tile_attr> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c1 : index + %1 = arith.muli %block_id_y, %c1 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<8x32xf16> -> !xetile.tile<8x32xf16, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<8x32xf16, #xetile.tile_attr>> -> vector<8x32xf16> + %4 = xetile.convert_layout %3 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<8x32xf16> + %5 = xetile.init_tile %arg1[%0, %1] : memref<32x32xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr>> + %6 = xetile.load_tile %5 : !xetile.tile<32x32xf16, #xetile.tile_attr>> -> vector<32x32xf16> + %7 = xetile.tile_mma %4, %6 {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<8x32xf16>, vector<32x32xf16> -> vector<8x32xf32> + %8 = xetile.init_tile %arg2[%0, %1] : memref<8x32xf32> -> !xetile.tile<8x32xf32, #xetile.tile_attr>> + xetile.store_tile %7, %8 : vector<8x32xf32>, !xetile.tile<8x32xf32, #xetile.tile_attr>> gpu.return } - } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %c1_f16 = arith.constant 1.0 : f16 - %c2_f32 = arith.constant 2.0 : f32 - %c0_f32 = arith.constant 0.0 : f32 - %c100_f16 = arith.constant 100.0 : f16 - %a = memref.alloc() : memref<8x32xf16> - %b = memref.alloc() : memref<32x32xf16> - %c_ref = memref.alloc() : memref<8x32xf32> - - // intialize matrix A; - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %m = arith.muli %i, %c32 : index - %add = arith.addi %m, %j : index - %t = index.castu %add : index to i16 - %v = arith.uitofp %t : i16 to f16 - %d = arith.divf %v, %c100_f16 : f16 - memref.store %d, %a[%i, %j] : memref<8x32xf16> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+02 : f16 + %alloc = memref.alloc() : memref<8x32xf16> + %alloc_1 = memref.alloc() : memref<32x32xf16> + 
%alloc_2 = memref.alloc() : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + %5 = arith.divf %4, %cst_0 : f16 + memref.store %5, %alloc[%arg0, %arg1] : memref<8x32xf16> } } - // intialize matrix B; - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %m = arith.muli %i, %c32 : index - %add = arith.addi %m, %j : index - %t = index.castu %add : index to i16 - %v = arith.uitofp %t : i16 to f16 - %d = arith.divf %v, %c100_f16 : f16 - memref.store %d, %b[%i, %j] : memref<32x32xf16> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + %5 = arith.divf %4, %cst_0 : f16 + memref.store %5, %alloc_1[%arg0, %arg1] : memref<32x32xf16> } } - // intialize matrix c_ref; - scf.for %i = %c0 to %c8 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - memref.store %c0_f32, %c_ref[%i, %j] : memref<8x32xf32> - scf.for %k = %c0 to %c32 step %c1 { - %cv = memref.load %c_ref[%i, %j] : memref<8x32xf32> - %av = memref.load %a[%i, %k] : memref<8x32xf16> - %bv = memref.load %b[%k, %j] : memref<32x32xf16> - - %a_f32 = arith.extf %av : f16 to f32 - %b_f32 = arith.extf %bv : f16 to f32 - %m = arith.mulf %a_f32, %b_f32 : f32 - - %acc = arith.addf %cv, %m : f32 - memref.store %acc, %c_ref[%i, %j] : memref<8x32xf32> + scf.for %arg0 = %c0 to %c8 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<8x32xf32> + scf.for %arg2 = %c0 to %c32 step %c1 { + %1 = memref.load %alloc_2[%arg0, %arg1] : memref<8x32xf32> + %2 = memref.load %alloc[%arg0, %arg2] : memref<8x32xf16> + %3 = memref.load %alloc_1[%arg2, %arg1] : memref<32x32xf16> + %4 = arith.extf %2 : f16 to f32 + %5 = arith.extf %3 : f16 to f32 + %6 = arith.mulf %4, %5 : f32 + %7 = arith.addf %1, %6 : f32 + memref.store %7, %alloc_2[%arg0, %arg1] : memref<8x32xf32> } } } - - %c = call @test_convert_layout_gemm(%a, %b) : (memref<8x32xf16>, memref<32x32xf16>) -> memref<8x32xf32> - %cast_c = memref.cast %c : memref<8x32xf32> to memref<*xf32> - %cast_c_ref = memref.cast %c_ref :memref<8x32xf32> to memref<*xf32> - call @printMemrefF32(%cast_c): (memref<*xf32>) -> () - call @printMemrefF32(%cast_c_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_c, %cast_c_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<8x32xf16> - memref.dealloc %b : memref<32x32xf16> - memref.dealloc %c_ref : memref<8x32xf32> + %0 = call @test_convert_layout_gemm(%alloc, %alloc_1) : (memref<8x32xf16>, memref<32x32xf16>) -> memref<8x32xf32> + %cast = memref.cast %0 : memref<8x32xf32> to memref<*xf32> + %cast_3 = memref.cast %alloc_2 : memref<8x32xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () + call @printMemrefF32(%cast_3) : (memref<*xf32>) -> () + call @printAllcloseF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<8x32xf16> + memref.dealloc %alloc_1 : memref<32x32xf16> + memref.dealloc %alloc_2 : memref<8x32xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/convert_layout_optimal_f16.mlir 
b/test/Integration/Dialect/XeTile/convert_layout_optimal_f16.mlir index 217a547b2..52d0475d7 100644 --- a/test/Integration/Dialect/XeTile/convert_layout_optimal_f16.mlir +++ b/test/Integration/Dialect/XeTile/convert_layout_optimal_f16.mlir @@ -1,81 +1,68 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @conv_layout attributes {gpu.container_module} { - func.func @convert_layout(%a: memref<64x64xf16>, %b: memref<64x64xf16>) -> memref<64x64xf16> attributes {llvm.emit_c_interface} { + func.func @convert_layout(%arg0: memref<64x64xf16>, %arg1: memref<64x64xf16>) -> memref<64x64xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - - %a_gpu = gpu.alloc host_shared () : memref<64x64xf16> - memref.copy %a, %a_gpu : memref<64x64xf16> to memref<64x64xf16> - %b_gpu = gpu.alloc host_shared () : memref<64x64xf16> - memref.copy %b, %b_gpu : memref<64x64xf16> to memref<64x64xf16> - %c_gpu = gpu.alloc host_shared () : memref<64x64xf16> - - gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%a_gpu : memref<64x64xf16>, %b_gpu : memref<64x64xf16>, %c_gpu : memref<64x64xf16>) - - gpu.dealloc %a_gpu : memref<64x64xf16> - gpu.dealloc %b_gpu : memref<64x64xf16> - return %c_gpu : memref<64x64xf16> + %memref = gpu.alloc () : memref<64x64xf16> + gpu.memcpy %memref, %arg0 : memref<64x64xf16>, memref<64x64xf16> + %memref_0 = gpu.alloc () : memref<64x64xf16> + gpu.memcpy %memref_0, %arg1 : memref<64x64xf16>, memref<64x64xf16> + %memref_1 = gpu.alloc () : memref<64x64xf16> + gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<64x64xf16>, %memref_0 : memref<64x64xf16>, %memref_1 : memref<64x64xf16>) + gpu.dealloc %memref : memref<64x64xf16> + gpu.dealloc %memref_0 : memref<64x64xf16> + %alloc = memref.alloc() : memref<64x64xf16> + gpu.memcpy %alloc, %memref_1 : memref<64x64xf16>, memref<64x64xf16> + gpu.dealloc %memref_1 : memref<64x64xf16> + return %alloc : memref<64x64xf16> } - -gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_convert_layout(%arg0 : memref<64x64xf16>, %arg1 : memref<64x64xf16>, %arg2 : memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c1 = arith.constant 1 : index - %m = gpu.block_id x - %n = gpu.block_id y - %init_tile_1 = xetile.init_tile %arg0[%m, %n] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> - %load_tile_1 = xetile.load_tile %init_tile_1: !xetile.tile<64x64xf16, #xetile.tile_attr>> -> 
vector<64x64xf16> - - %convert = xetile.convert_layout %load_tile_1 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf16> - - %init_tile_2 = xetile.init_tile %arg1[%m, %n] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> - %load_tile_2 = xetile.load_tile %init_tile_2: !xetile.tile<64x64xf16, #xetile.tile_attr>> -> vector<64x64xf16> - - %add = arith.addf %load_tile_2, %convert {map = #xetile.wg_map} : vector<64x64xf16> - %init_store_tile = xetile.init_tile %arg2[%m, %n] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> - xetile.store_tile %add, %init_store_tile : vector<64x64xf16>, !xetile.tile<64x64xf16, #xetile.tile_attr>> - gpu.return + gpu.module @kernel { + gpu.func @test_convert_layout(%arg0: memref<64x64xf16>, %arg1: memref<64x64xf16>, %arg2: memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c1 = arith.constant 1 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = xetile.init_tile %arg0[%block_id_x, %block_id_y] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> + %1 = xetile.load_tile %0 : !xetile.tile<64x64xf16, #xetile.tile_attr>> -> vector<64x64xf16> + %2 = xetile.convert_layout %1 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf16> + %3 = xetile.init_tile %arg1[%block_id_x, %block_id_y] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> + %4 = xetile.load_tile %3 : !xetile.tile<64x64xf16, #xetile.tile_attr>> -> vector<64x64xf16> + %5 = arith.addf %4, %2 {map = #xetile.wg_map} : vector<64x64xf16> + %6 = xetile.init_tile %arg2[%block_id_x, %block_id_y] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> + xetile.store_tile %5, %6 : vector<64x64xf16>, !xetile.tile<64x64xf16, #xetile.tile_attr>> + gpu.return + } } -} - -func.func @main() attributes {llvm.emit_c_interface} { + func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - %c1_f16 = arith.constant 1.0 : f16 - %c2_f32 = arith.constant 2.0 : f32 - %a = memref.alloc() : memref<64x64xf16> - %b = memref.alloc() : memref<64x64xf16> - %c_ref = memref.alloc() : memref<64x64xf32> - - // intialize matrix A, B ; A[i, j] = 1 - scf.for %i = %c0 to %c64 step %c1 { - scf.for %j = %c0 to %c64 step %c1 { - memref.store %c1_f16, %a[%i, %j] : memref<64x64xf16> - memref.store %c1_f16, %b[%i, %j] : memref<64x64xf16> - memref.store %c2_f32, %c_ref[%i, %j] : memref<64x64xf32> + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant 2.000000e+00 : f32 + %alloc = memref.alloc() : memref<64x64xf16> + %alloc_1 = memref.alloc() : memref<64x64xf16> + %alloc_2 = memref.alloc() : memref<64x64xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + memref.store %cst, %alloc[%arg0, %arg1] : memref<64x64xf16> + memref.store %cst, %alloc_1[%arg0, %arg1] : memref<64x64xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<64x64xf32> } } - - %c = call @convert_layout(%a, %b) : (memref<64x64xf16>, memref<64x64xf16>) -> memref<64x64xf16> - %cast_c = memref.cast %c : memref<64x64xf16> to memref<*xf16> - %cast_c_ref = memref.cast %c_ref :memref<64x64xf32> to memref<*xf32> // call @printMemrefF32(%cast_c): (memref<*xf32>) -> () // call @printMemrefF32(%cast_c_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_c, %cast_c_ref) : (memref<*xf16>, 
memref<*xf32>) -> () - memref.dealloc %a : memref<64x64xf16> - memref.dealloc %b : memref<64x64xf16> + %0 = call @convert_layout(%alloc, %alloc_1) : (memref<64x64xf16>, memref<64x64xf16>) -> memref<64x64xf16> + %cast = memref.cast %0 : memref<64x64xf16> to memref<*xf16> + %cast_3 = memref.cast %alloc_2 : memref<64x64xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_3) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<64x64xf16> + memref.dealloc %alloc_1 : memref<64x64xf16> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/convert_layout_optimal_f32.mlir b/test/Integration/Dialect/XeTile/convert_layout_optimal_f32.mlir index 67d88e90d..769483fe4 100644 --- a/test/Integration/Dialect/XeTile/convert_layout_optimal_f32.mlir +++ b/test/Integration/Dialect/XeTile/convert_layout_optimal_f32.mlir @@ -1,82 +1,68 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @conv_layout attributes {gpu.container_module} { - func.func @convert_layout(%a: memref<64x64xf32>, %b: memref<64x64xf32>) -> memref<64x64xf32> attributes {llvm.emit_c_interface} { + func.func @convert_layout(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>) -> memref<64x64xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - - %a_gpu = gpu.alloc host_shared () : memref<64x64xf32> - memref.copy %a, %a_gpu : memref<64x64xf32> to memref<64x64xf32> - %b_gpu = gpu.alloc host_shared () : memref<64x64xf32> - memref.copy %b, %b_gpu : memref<64x64xf32> to memref<64x64xf32> - %c_gpu = gpu.alloc host_shared () : memref<64x64xf32> - - gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%a_gpu : memref<64x64xf32>, %b_gpu : memref<64x64xf32>, %c_gpu : memref<64x64xf32>) - - gpu.dealloc %a_gpu : memref<64x64xf32> - gpu.dealloc %b_gpu : memref<64x64xf32> - return %c_gpu : memref<64x64xf32> + %memref = gpu.alloc () : memref<64x64xf32> + gpu.memcpy %memref, %arg0 : memref<64x64xf32>, memref<64x64xf32> + %memref_0 = gpu.alloc () : memref<64x64xf32> + gpu.memcpy %memref_0, %arg1 : memref<64x64xf32>, memref<64x64xf32> + %memref_1 = gpu.alloc () : memref<64x64xf32> + gpu.launch_func @kernel::@test_convert_layout blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<64x64xf32>, %memref_0 : memref<64x64xf32>, %memref_1 : memref<64x64xf32>) + gpu.dealloc %memref : memref<64x64xf32> + gpu.dealloc %memref_0 : memref<64x64xf32> + %alloc = memref.alloc() : memref<64x64xf32> + gpu.memcpy %alloc, %memref_1 : memref<64x64xf32>, memref<64x64xf32> + gpu.dealloc %memref_1 
: memref<64x64xf32> + return %alloc : memref<64x64xf32> } - -gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_convert_layout(%arg0 : memref<64x64xf32>, %arg1 : memref<64x64xf32>, %arg2 : memref<64x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c1 = arith.constant 1 : index - %m = gpu.block_id x - %n = gpu.block_id y - %init_tile_1 = xetile.init_tile %arg0[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> - %load_tile_1 = xetile.load_tile %init_tile_1: !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> - - %convert = xetile.convert_layout %load_tile_1 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf32> - - %init_tile_2 = xetile.init_tile %arg1[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> - %load_tile_2 = xetile.load_tile %init_tile_2: !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> - - %add = arith.addf %load_tile_2, %convert {map = #xetile.wg_map} : vector<64x64xf32> - %init_store_tile = xetile.init_tile %arg2[%m, %n] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> - xetile.store_tile %add, %init_store_tile : vector<64x64xf32>, !xetile.tile<64x64xf32, #xetile.tile_attr>> - gpu.return + gpu.module @kernel { + gpu.func @test_convert_layout(%arg0: memref<64x64xf32>, %arg1: memref<64x64xf32>, %arg2: memref<64x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c1 = arith.constant 1 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = xetile.init_tile %arg0[%block_id_x, %block_id_y] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + %1 = xetile.load_tile %0 : !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> + %2 = xetile.convert_layout %1 {wg_map_result = #xetile.wg_map, wg_map_source = #xetile.wg_map} : vector<64x64xf32> + %3 = xetile.init_tile %arg1[%block_id_x, %block_id_y] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + %4 = xetile.load_tile %3 : !xetile.tile<64x64xf32, #xetile.tile_attr>> -> vector<64x64xf32> + %5 = arith.addf %4, %2 {map = #xetile.wg_map} : vector<64x64xf32> + %6 = xetile.init_tile %arg2[%block_id_x, %block_id_y] : memref<64x64xf32> -> !xetile.tile<64x64xf32, #xetile.tile_attr>> + xetile.store_tile %5, %6 : vector<64x64xf32>, !xetile.tile<64x64xf32, #xetile.tile_attr>> + gpu.return + } } -} - -func.func @main() attributes {llvm.emit_c_interface} { + func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - %c1_f32 = arith.constant 1.0 : f32 - %c2_f32 = arith.constant 2.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<64x64xf32> - %b = memref.alloc() : memref<64x64xf32> - %c_ref = memref.alloc() : memref<64x64xf32> - - // intialize matrix A, B ; A[i, j] = 1 - scf.for %i = %c0 to %c64 step %c1 { - scf.for %j = %c0 to %c64 step %c1 { - memref.store %c1_f32, %a[%i, %j] : memref<64x64xf32> - memref.store %c1_f32, %b[%i, %j] : memref<64x64xf32> - memref.store %c2_f32, %c_ref[%i, %j] : memref<64x64xf32> + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %alloc = memref.alloc() : memref<64x64xf32> + %alloc_1 = memref.alloc() : memref<64x64xf32> + %alloc_2 = memref.alloc() : memref<64x64xf32> + 
scf.for %arg0 = %c0 to %c64 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + memref.store %cst, %alloc[%arg0, %arg1] : memref<64x64xf32> + memref.store %cst, %alloc_1[%arg0, %arg1] : memref<64x64xf32> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<64x64xf32> } } - - %c = call @convert_layout(%a, %b) : (memref<64x64xf32>, memref<64x64xf32>) -> memref<64x64xf32> - %cast_c = memref.cast %c : memref<64x64xf32> to memref<*xf32> - %cast_c_ref = memref.cast %c_ref :memref<64x64xf32> to memref<*xf32> // call @printMemrefF32(%cast_c): (memref<*xf32>) -> () // call @printMemrefF32(%cast_c_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_c, %cast_c_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<64x64xf32> - memref.dealloc %b : memref<64x64xf32> + %0 = call @convert_layout(%alloc, %alloc_1) : (memref<64x64xf32>, memref<64x64xf32>) -> memref<64x64xf32> + %cast = memref.cast %0 : memref<64x64xf32> to memref<*xf32> + %cast_3 = memref.cast %alloc_2 : memref<64x64xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<64x64xf32> + memref.dealloc %alloc_1 : memref<64x64xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir b/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir index e213a0c51..80b7e949a 100644 --- a/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir +++ b/test/Integration/Dialect/XeTile/eltwise_int_ops.mlir @@ -1,77 +1,61 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + module @eltwise_int attributes {gpu.container_module} { memref.global "private" constant @__constant_5_1024x1024xi32 : memref<1024x1024xi32> = dense<5> memref.global "private" constant @__constant_2_1024x1024xi32 : memref<1024x1024xi32> = dense<2> - func.func @eltwise_int_test(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - - %arg0_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> - memref.copy %arg0, %arg0_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> - - %arg1_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> - memref.copy %arg1, %arg1_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> - - %result = gpu.alloc host_shared () : memref<1024x1024xi32> - - gpu.launch_func @eltwise_int::@eltwise_int blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%arg0_gpu : memref<1024x1024xi32>, %arg1_gpu : memref<1024x1024xi32>, %result : memref<1024x1024xi32>) - - gpu.dealloc %arg0_gpu : 
memref<1024x1024xi32> - gpu.dealloc %arg1_gpu : memref<1024x1024xi32> - return %result : memref<1024x1024xi32> - + %memref = gpu.alloc () : memref<1024x1024xi32> + gpu.memcpy %memref, %arg0 : memref<1024x1024xi32>, memref<1024x1024xi32> + %memref_0 = gpu.alloc () : memref<1024x1024xi32> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xi32>, memref<1024x1024xi32> + %memref_1 = gpu.alloc () : memref<1024x1024xi32> + gpu.launch_func @eltwise_int::@eltwise_int blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xi32>, %memref_0 : memref<1024x1024xi32>, %memref_1 : memref<1024x1024xi32>) + gpu.dealloc %memref : memref<1024x1024xi32> + gpu.dealloc %memref_0 : memref<1024x1024xi32> + %alloc = memref.alloc() : memref<1024x1024xi32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xi32>, memref<1024x1024xi32> + gpu.dealloc %memref_1 : memref<1024x1024xi32> + return %alloc : memref<1024x1024xi32> } - - gpu.module @eltwise_int attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @eltwise_int { gpu.func @eltwise_int(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>, %arg2: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - - %1 = xetile.init_tile %arg0[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> - %2 = xetile.load_tile %1: !xetile.tile<16x32xi32> -> vector<16x32xi32> - %3 = xetile.init_tile %arg1[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> - %4 = xetile.load_tile %3: !xetile.tile<16x32xi32> -> vector<16x32xi32> - %result_add = arith.addi %2, %4: vector<16x32xi32> //=7 - %result_sub = arith.subi %2, %4: vector<16x32xi32> //=3 - %result_mul = arith.muli %result_add, %result_sub: vector<16x32xi32> //=21 - %result_sdiv = arith.divsi %result_mul, %result_add: vector<16x32xi32> //=3 - %result_udiv = arith.divui %result_mul, %result_add: vector<16x32xi32> //=3 - %result_srem = arith.remsi %result_sdiv, %result_mul: vector<16x32xi32> //=3 - %result_urem = arith.remui %result_udiv, %result_srem: vector<16x32xi32> //=0 - %result = arith.addi %result_srem, %result_urem: vector<16x32xi32> //=3 - %store_tile = xetile.init_tile %arg2[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> - xetile.store_tile %result, %store_tile: vector<16x32xi32>, !xetile.tile<16x32xi32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xi32> -> vector<16x32xi32> + %4 = xetile.init_tile %arg1[%0, %1] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + %5 = xetile.load_tile %4 : !xetile.tile<16x32xi32> -> vector<16x32xi32> + %6 = arith.addi %3, %5 : vector<16x32xi32> + %7 = arith.subi %3, %5 : vector<16x32xi32> + %8 = arith.muli %6, %7 : vector<16x32xi32> + %9 = arith.divsi %8, %6 : vector<16x32xi32> + %10 = arith.divui %8, %6 : vector<16x32xi32> + %11 = arith.remsi %9, %8 : vector<16x32xi32> + %12 = arith.remui %10, %11 : vector<16x32xi32> + %13 = arith.addi %11, %12 : vector<16x32xi32> + %14 = xetile.init_tile %arg2[%0, %1] 
: memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + xetile.store_tile %13, %14 : vector<16x32xi32>, !xetile.tile<16x32xi32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { - %A = memref.get_global @__constant_5_1024x1024xi32 : memref<1024x1024xi32> - %B = memref.get_global @__constant_2_1024x1024xi32 : memref<1024x1024xi32> - - %c0_i32 = arith.constant 0 : i32 - - %result = call @eltwise_int_test(%A, %B) : (memref<1024x1024xi32>, memref<1024x1024xi32>) -> memref<1024x1024xi32> - %result_cast = memref.cast %result : memref<1024x1024xi32> to memref<*xi32> // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}} // CHECK-COUNT-1048576: 3 - call @printMemrefI32(%result_cast) : (memref<*xi32>) -> () - + %0 = memref.get_global @__constant_5_1024x1024xi32 : memref<1024x1024xi32> + %1 = memref.get_global @__constant_2_1024x1024xi32 : memref<1024x1024xi32> + %2 = call @eltwise_int_test(%0, %1) : (memref<1024x1024xi32>, memref<1024x1024xi32>) -> memref<1024x1024xi32> + %cast = memref.cast %2 : memref<1024x1024xi32> to memref<*xi32> + call @printMemrefI32(%cast) : (memref<*xi32>) -> () return } func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir b/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir index 8b6fd3997..e653609f1 100644 --- a/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir +++ b/test/Integration/Dialect/XeTile/fallback/narrow_tile_one_elem_wide.mlir @@ -1,40 +1,34 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @narrow_tile attributes {gpu.container_module} { - func.func @test(%A: memref<64x1xf32>) -> memref<64x1xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<64x1xf32>) -> memref<64x1xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %A_gpu = gpu.alloc host_shared() : memref<64x1xf32> - memref.copy %A, %A_gpu : memref<64x1xf32> to memref<64x1xf32> - %B_gpu = gpu.alloc host_shared() : memref<64x1xf32> - gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<64x1xf32>, %B_gpu : memref<64x1xf32>) - %B = memref.alloc() : memref<64x1xf32> - memref.copy %B_gpu, %B : memref<64x1xf32> to memref<64x1xf32> - gpu.dealloc %A_gpu : memref<64x1xf32> - gpu.dealloc %B_gpu : memref<64x1xf32> - return %B : memref<64x1xf32> + %memref = gpu.alloc () : memref<64x1xf32> + gpu.memcpy %memref, %arg0 : memref<64x1xf32>, memref<64x1xf32> + %memref_0 = gpu.alloc () : memref<64x1xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : 
memref<64x1xf32>, %memref_0 : memref<64x1xf32>) + %alloc = memref.alloc() : memref<64x1xf32> + gpu.memcpy %alloc, %memref_0 : memref<64x1xf32>, memref<64x1xf32> + gpu.dealloc %memref : memref<64x1xf32> + gpu.dealloc %memref_0 : memref<64x1xf32> + return %alloc : memref<64x1xf32> } - gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_module { gpu.func @test_scf_for(%arg0: memref<64x1xf32>, %arg1: memref<64x1xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst0 = arith.constant 0 : index - %cst16 = arith.constant 16 : index - %cst64 = arith.constant 64 : index - %0 = xetile.init_tile %arg0 [0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr> - %1 = xetile.init_tile %arg1 [0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr> - %out:2 = scf.for %k = %cst0 to %cst64 step %cst16 - iter_args(%a_tile = %0, %b_tile = %1) - -> (!xetile.tile<16x1xf32, #xetile.tile_attr>, !xetile.tile<16x1xf32, #xetile.tile_attr>) { - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x1xf32, #xetile.tile_attr> -> vector<16x1xf32> - xetile.store_tile %a_value, %b_tile : vector<16x1xf32>, !xetile.tile<16x1xf32, #xetile.tile_attr> - %a_next_tile = xetile.update_tile_offset %a_tile, [%cst16, %cst0] : !xetile.tile<16x1xf32, #xetile.tile_attr> - %b_next_tile = xetile.update_tile_offset %b_tile, [%cst16, %cst0] : !xetile.tile<16x1xf32, #xetile.tile_attr> - scf.yield %a_next_tile, %b_next_tile : !xetile.tile<16x1xf32, #xetile.tile_attr>, !xetile.tile<16x1xf32, #xetile.tile_attr> + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %0 = xetile.init_tile %arg0[0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr<>> + %1 = xetile.init_tile %arg1[0, 0] : memref<64x1xf32> -> !xetile.tile<16x1xf32, #xetile.tile_attr<>> + %2:2 = scf.for %arg2 = %c0 to %c64 step %c16 iter_args(%arg3 = %0, %arg4 = %1) -> (!xetile.tile<16x1xf32, #xetile.tile_attr<>>, !xetile.tile<16x1xf32, #xetile.tile_attr<>>) { + %3 = xetile.load_tile %arg3 : !xetile.tile<16x1xf32, #xetile.tile_attr<>> -> vector<16x1xf32> + xetile.store_tile %3, %arg4 : vector<16x1xf32>, !xetile.tile<16x1xf32, #xetile.tile_attr<>> + %4 = xetile.update_tile_offset %arg3, [%c16, %c0] : !xetile.tile<16x1xf32, #xetile.tile_attr<>> + %5 = xetile.update_tile_offset %arg4, [%c16, %c0] : !xetile.tile<16x1xf32, #xetile.tile_attr<>> + scf.yield %4, %5 : !xetile.tile<16x1xf32, #xetile.tile_attr<>>, !xetile.tile<16x1xf32, #xetile.tile_attr<>> } gpu.return } @@ -43,20 +37,19 @@ module @narrow_tile attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - - %A = memref.alloc() : memref<64x1xf32> + %alloc = memref.alloc() : memref<64x1xf32> scf.for %arg0 = %c0 to %c64 step %c1 { - %0 = index.castu %arg0 : index to i32 - %val = arith.uitofp %0 : i32 to f32 - memref.store %val, %A[%arg0, %c0] : memref<64x1xf32> + %1 = index.castu %arg0 : index to i32 + %2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0, %c0] : memref<64x1xf32> } - %C = call @test(%A) : (memref<64x1xf32>) -> memref<64x1xf32> - %cast_A = memref.cast %A : memref<64x1xf32> to memref<*xf32> - %cast_C = memref.cast %C : memref<64x1xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, 
memref<*xf32>) -> () //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<64x1xf32>) -> memref<64x1xf32> + %cast = memref.cast %alloc : memref<64x1xf32> to memref<*xf32> + %cast_0 = memref.cast %0 : memref<64x1xf32> to memref<*xf32> + call @printAllcloseF32(%cast_0, %cast) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir b/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir index 8f45c295c..d68afae98 100644 --- a/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir +++ b/test/Integration/Dialect/XeTile/fallback/narrow_tile_two_elem_wide.mlir @@ -1,40 +1,34 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @narrow_tile attributes {gpu.container_module} { - func.func @test(%A: memref<64x2xf32>) -> memref<64x2xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<64x2xf32>) -> memref<64x2xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %A_gpu = gpu.alloc host_shared() : memref<64x2xf32> - memref.copy %A, %A_gpu : memref<64x2xf32> to memref<64x2xf32> - %B_gpu = gpu.alloc host_shared() : memref<64x2xf32> - gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<64x2xf32>, %B_gpu : memref<64x2xf32>) - %B = memref.alloc() : memref<64x2xf32> - memref.copy %B_gpu, %B : memref<64x2xf32> to memref<64x2xf32> - gpu.dealloc %A_gpu : memref<64x2xf32> - gpu.dealloc %B_gpu : memref<64x2xf32> - return %B : memref<64x2xf32> + %memref = gpu.alloc () : memref<64x2xf32> + gpu.memcpy %memref, %arg0 : memref<64x2xf32>, memref<64x2xf32> + %memref_0 = gpu.alloc () : memref<64x2xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<64x2xf32>, %memref_0 : memref<64x2xf32>) + %alloc = memref.alloc() : memref<64x2xf32> + gpu.memcpy %alloc, %memref_0 : memref<64x2xf32>, memref<64x2xf32> + gpu.dealloc %memref : memref<64x2xf32> + gpu.dealloc %memref_0 : memref<64x2xf32> + return %alloc : memref<64x2xf32> } - gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_module { gpu.func @test_scf_for(%arg0: memref<64x2xf32>, %arg1: memref<64x2xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst0 = arith.constant 0 : index - 
%cst16 = arith.constant 16 : index - %cst64 = arith.constant 64 : index - %0 = xetile.init_tile %arg0 [0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr> - %1 = xetile.init_tile %arg1 [0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr> - %out:2 = scf.for %k = %cst0 to %cst64 step %cst16 - iter_args(%a_tile = %0, %b_tile = %1) - -> (!xetile.tile<16x2xf32, #xetile.tile_attr>, !xetile.tile<16x2xf32, #xetile.tile_attr>) { - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x2xf32, #xetile.tile_attr> -> vector<16x2xf32> - xetile.store_tile %a_value, %b_tile : vector<16x2xf32>, !xetile.tile<16x2xf32, #xetile.tile_attr> - %a_next_tile = xetile.update_tile_offset %a_tile, [%cst16, %cst0] : !xetile.tile<16x2xf32, #xetile.tile_attr> - %b_next_tile = xetile.update_tile_offset %b_tile, [%cst16, %cst0] : !xetile.tile<16x2xf32, #xetile.tile_attr> - scf.yield %a_next_tile, %b_next_tile : !xetile.tile<16x2xf32, #xetile.tile_attr>, !xetile.tile<16x2xf32, #xetile.tile_attr> + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %0 = xetile.init_tile %arg0[0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr<>> + %1 = xetile.init_tile %arg1[0, 0] : memref<64x2xf32> -> !xetile.tile<16x2xf32, #xetile.tile_attr<>> + %2:2 = scf.for %arg2 = %c0 to %c64 step %c16 iter_args(%arg3 = %0, %arg4 = %1) -> (!xetile.tile<16x2xf32, #xetile.tile_attr<>>, !xetile.tile<16x2xf32, #xetile.tile_attr<>>) { + %3 = xetile.load_tile %arg3 : !xetile.tile<16x2xf32, #xetile.tile_attr<>> -> vector<16x2xf32> + xetile.store_tile %3, %arg4 : vector<16x2xf32>, !xetile.tile<16x2xf32, #xetile.tile_attr<>> + %4 = xetile.update_tile_offset %arg3, [%c16, %c0] : !xetile.tile<16x2xf32, #xetile.tile_attr<>> + %5 = xetile.update_tile_offset %arg4, [%c16, %c0] : !xetile.tile<16x2xf32, #xetile.tile_attr<>> + scf.yield %4, %5 : !xetile.tile<16x2xf32, #xetile.tile_attr<>>, !xetile.tile<16x2xf32, #xetile.tile_attr<>> } gpu.return } @@ -43,21 +37,20 @@ module @narrow_tile attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - - %A = memref.alloc() : memref<64x2xf32> + %alloc = memref.alloc() : memref<64x2xf32> scf.for %arg0 = %c0 to %c64 step %c1 { - %0 = index.castu %arg0 : index to i32 - %val = arith.uitofp %0 : i32 to f32 - memref.store %val, %A[%arg0, %c0] : memref<64x2xf32> - memref.store %val, %A[%arg0, %c1] : memref<64x2xf32> + %1 = index.castu %arg0 : index to i32 + %2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0, %c0] : memref<64x2xf32> + memref.store %2, %alloc[%arg0, %c1] : memref<64x2xf32> } - %C = call @test(%A) : (memref<64x2xf32>) -> memref<64x2xf32> - %cast_A = memref.cast %A : memref<64x2xf32> to memref<*xf32> - %cast_C = memref.cast %C : memref<64x2xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<64x2xf32>) -> memref<64x2xf32> + %cast = memref.cast %alloc : memref<64x2xf32> to memref<*xf32> + %cast_0 = memref.cast %0 : memref<64x2xf32> to memref<*xf32> + call @printAllcloseF32(%cast_0, %cast) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/fallback/slm.mlir 
b/test/Integration/Dialect/XeTile/fallback/slm.mlir index e3c93d78c..1178ccd74 100644 --- a/test/Integration/Dialect/XeTile/fallback/slm.mlir +++ b/test/Integration/Dialect/XeTile/fallback/slm.mlir @@ -1,53 +1,45 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @narrow_tile attributes {gpu.container_module} { - func.func @test(%A: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<32x32xf32>) -> memref<32x32xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %A_gpu = gpu.alloc host_shared() : memref<32x32xf32> - memref.copy %A, %A_gpu : memref<32x32xf32> to memref<32x32xf32> - %B_gpu = gpu.alloc host_shared() : memref<32x32xf32> - gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<32x32xf32>, %B_gpu : memref<32x32xf32>) - %B = memref.alloc() : memref<32x32xf32> - memref.copy %B_gpu, %B : memref<32x32xf32> to memref<32x32xf32> - gpu.dealloc %A_gpu : memref<32x32xf32> - gpu.dealloc %B_gpu : memref<32x32xf32> - return %B : memref<32x32xf32> + %memref = gpu.alloc () : memref<32x32xf32> + gpu.memcpy %memref, %arg0 : memref<32x32xf32>, memref<32x32xf32> + %memref_0 = gpu.alloc () : memref<32x32xf32> + gpu.launch_func @test_module::@test_scf_for blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x32xf32>, %memref_0 : memref<32x32xf32>) + %alloc = memref.alloc() : memref<32x32xf32> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> + gpu.dealloc %memref : memref<32x32xf32> + gpu.dealloc %memref_0 : memref<32x32xf32> + return %alloc : memref<32x32xf32> } - gpu.module @test_module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @test_module { gpu.func @test_scf_for(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %cst0 = arith.constant 0 : index - %cst8 = arith.constant 8 : index - %cst16 = arith.constant 16 : index - %cst32 = arith.constant 32 : index - %0 = xetile.init_tile %arg0 [0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr> - %1 = xetile.init_tile %arg1 [0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr> - %slm = memref.alloc() : memref<512xi8, 3> - %view = memref.view %slm[%cst0][] : memref<512xi8, 3> to memref<8x16xf32, 3> - %slm_tile = xetile.init_tile %view[0, 0] : memref<8x16xf32, 3> -> !xetile.tile<8x16xf32, #xetile.tile_attr> - %out:2 = scf.for %j = %cst0 to %cst32 step %cst8 - iter_args(%a_tile = %0, %b_tile = %1) - -> 
(!xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr>) { - %out:2 = scf.for %k = %cst0 to %cst32 step %cst16 - iter_args(%c_tile = %a_tile, %d_tile = %b_tile) - -> (!xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr>) { - %c_value = xetile.load_tile %c_tile : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> - xetile.store_tile %c_value, %slm_tile : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> - %d_value = xetile.load_tile %slm_tile : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> - xetile.store_tile %d_value, %d_tile : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> - %c_next_tile = xetile.update_tile_offset %c_tile, [%cst0, %cst16] : !xetile.tile<8x16xf32, #xetile.tile_attr> - %d_next_tile = xetile.update_tile_offset %d_tile, [%cst0, %cst16] : !xetile.tile<8x16xf32, #xetile.tile_attr> - scf.yield %c_next_tile, %d_next_tile : !xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %0 = xetile.init_tile %arg0[0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr<>> + %1 = xetile.init_tile %arg1[0, 0] : memref<32x32xf32> -> !xetile.tile<8x16xf32, #xetile.tile_attr<>> + %alloc = memref.alloc() : memref<512xi8, 3> + %view = memref.view %alloc[%c0][] : memref<512xi8, 3> to memref<8x16xf32, 3> + %2 = xetile.init_tile %view[0, 0] : memref<8x16xf32, 3> -> !xetile.tile<8x16xf32, #xetile.tile_attr> + %3:2 = scf.for %arg2 = %c0 to %c32 step %c8 iter_args(%arg3 = %0, %arg4 = %1) -> (!xetile.tile<8x16xf32, #xetile.tile_attr<>>, !xetile.tile<8x16xf32, #xetile.tile_attr<>>) { + %4:2 = scf.for %arg5 = %c0 to %c32 step %c16 iter_args(%arg6 = %arg3, %arg7 = %arg4) -> (!xetile.tile<8x16xf32, #xetile.tile_attr<>>, !xetile.tile<8x16xf32, #xetile.tile_attr<>>) { + %7 = xetile.load_tile %arg6 : !xetile.tile<8x16xf32, #xetile.tile_attr<>> -> vector<8x16xf32> + xetile.store_tile %7, %2 : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %8 = xetile.load_tile %2 : !xetile.tile<8x16xf32, #xetile.tile_attr> -> vector<8x16xf32> + xetile.store_tile %8, %arg7 : vector<8x16xf32>, !xetile.tile<8x16xf32, #xetile.tile_attr<>> + %9 = xetile.update_tile_offset %arg6, [%c0, %c16] : !xetile.tile<8x16xf32, #xetile.tile_attr<>> + %10 = xetile.update_tile_offset %arg7, [%c0, %c16] : !xetile.tile<8x16xf32, #xetile.tile_attr<>> + scf.yield %9, %10 : !xetile.tile<8x16xf32, #xetile.tile_attr<>>, !xetile.tile<8x16xf32, #xetile.tile_attr<>> } - %a_next_tile = xetile.update_tile_offset %a_tile, [%cst8, %cst0] : !xetile.tile<8x16xf32, #xetile.tile_attr> - %b_next_tile = xetile.update_tile_offset %b_tile, [%cst8, %cst0] : !xetile.tile<8x16xf32, #xetile.tile_attr> - scf.yield %a_next_tile, %b_next_tile : !xetile.tile<8x16xf32, #xetile.tile_attr>, !xetile.tile<8x16xf32, #xetile.tile_attr> + %5 = xetile.update_tile_offset %arg3, [%c8, %c0] : !xetile.tile<8x16xf32, #xetile.tile_attr<>> + %6 = xetile.update_tile_offset %arg4, [%c8, %c0] : !xetile.tile<8x16xf32, #xetile.tile_attr<>> + scf.yield %5, %6 : !xetile.tile<8x16xf32, #xetile.tile_attr<>>, !xetile.tile<8x16xf32, #xetile.tile_attr<>> } gpu.return } @@ -56,24 +48,23 @@ module @narrow_tile attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - - %A = memref.alloc() : memref<32x32xf32> + %alloc = 
memref.alloc() : memref<32x32xf32> scf.for %arg0 = %c0 to %c32 step %c1 { scf.for %arg1 = %c0 to %c32 step %c1 { - %0 = index.castu %arg0 : index to i32 - %1 = index.castu %arg1 : index to i32 - %2 = arith.addi %0, %1 : i32 - %val = arith.uitofp %2 : i32 to f32 - memref.store %val, %A[%arg0, %arg1] : memref<32x32xf32> + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.addi %1, %2 : i32 + %4 = arith.uitofp %3 : i32 to f32 + memref.store %4, %alloc[%arg0, %arg1] : memref<32x32xf32> } } - %C = call @test(%A) : (memref<32x32xf32>) -> memref<32x32xf32> - %cast_A = memref.cast %A : memref<32x32xf32> to memref<*xf32> - %cast_C = memref.cast %C : memref<32x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_A) : (memref<*xf32>, memref<*xf32>) -> () //call @printMemrefF32(%cast_A) : (memref<*xf32>) -> () //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + %0 = call @test(%alloc) : (memref<32x32xf32>) -> memref<32x32xf32> + %cast = memref.cast %alloc : memref<32x32xf32> to memref<*xf32> + %cast_0 = memref.cast %0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF32(%cast_0, %cast) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp b/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp index 844e31c98..2a8cd5c76 100644 --- a/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp +++ b/test/Integration/Dialect/XeTile/fallback/xetile-fallback-to-func-vc.pp @@ -1,40 +1,42 @@ +// gpu dialect with subgroup level XeTile dialect to +// llvm dialect (for host code) and +// spirv dialect (for device code) lowering pipeline. +// Ready for imex runner starting from GPU dialect. 
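// Illustrative invocation of this pipeline file (assumed from the RUN lines of the
// updated tests above; not itself part of the pipeline):
//   %python_executable %imex_runner --requires=mlir-levelzero-runtime -i <test>.mlir \
//     --pass-pipeline-file=%p/xetile-fallback-to-func-vc.pp \
//     --runner mlir-runner -e main --entry-point-result=void \
//     --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck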
+ builtin.module( cse - gpu.module(xetile-init-duplicate - xetile-canonicalization - xetile-blockop-fallback - xetile-blocking - cse - convert-xetile-to-xegpu - cse - imex-xegpu-hoist-transpose - imex-xegpu-apply-vnni-transformation - imex-xegpu-optimize-transpose - cse + gpu.module(xetile-init-duplicate, + xetile-canonicalization, + xetile-blockop-fallback, + xetile-blocking, + cse, + convert-xetile-to-xegpu, + cse, + imex-xegpu-hoist-transpose, + imex-xegpu-apply-vnni-transformation, + imex-xegpu-optimize-transpose) + cse + gpu.module(convert-math-to-vc{enable-high-precision-interim-calculation=true}, convert-xegpu-to-vc) cse xegpu-vector-linearize canonicalize cse reconcile-unrealized-casts - bf16-to-gpu - cse - imex-convert-gpu-to-spirv - spirv.module(spirv-lower-abi-attrs - spirv-update-vce) + gpu.module(math-extend-to-supported-types{target-type=f32}) + gpu.module(arith-emulate-unsupported-floats{source-types=bf16 target-type=f32}) + spirv-attach-target{ver=v1.0 caps=Addresses,BFloat16TypeKHR,Float16Buffer,Int64,Int16,Int8,Kernel,Linkage,Vector16,GenericPointer,Groups,Float16,Float64,AtomicFloat32AddEXT,ExpectAssumeKHR,SubgroupDispatch,VectorComputeINTEL,VectorAnyINTEL,Bfloat16ConversionINTEL exts=SPV_EXT_shader_atomic_float_add,SPV_KHR_bfloat16,SPV_KHR_expect_assume,SPV_INTEL_vector_compute,SPV_INTEL_bfloat16_conversion} + imex-convert-to-spirv{use-64bit-index=true} + gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce)) func.func(llvm-request-c-wrappers) - serialize-spirv convert-vector-to-scf - convert-gpu-to-gpux convert-scf-to-cf + func.func(gpu-async-region) expand-strided-metadata + gpu-to-llvm{use-bare-pointers-for-kernels=true} finalize-memref-to-llvm - convert-cf-to-llvm - convert-vector-to-llvm - convert-index-to-llvm - convert-arith-to-llvm - convert-func-to-llvm - convert-math-to-llvm - convert-gpux-to-llvm + convert-to-llvm + gpu-module-to-binary lower-affine reconcile-unrealized-casts) +// End diff --git a/test/Integration/Dialect/XeTile/gemm_k_oob.mlir b/test/Integration/Dialect/XeTile/gemm_k_oob.mlir index 0de2e9b93..920423431 100644 --- a/test/Integration/Dialect/XeTile/gemm_k_oob.mlir +++ b/test/Integration/Dialect/XeTile/gemm_k_oob.mlir @@ -1,35 +1,30 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck #map = affine_map<() -> (0)> #map1 = affine_map<() -> (100)> module @gemm attributes {gpu.container_module} { - func.func @gemm_k_oob_entry( - %A: memref<128x100xf16>, %B: memref<256x100xf16>) -> - memref<128x256xf32> attributes {llvm.emit_c_interface} { + func.func @gemm_k_oob_entry(%arg0: memref<128x100xf16>, %arg1: memref<256x100xf16>) -> memref<128x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = 
arith.constant 4 : index %c8 = arith.constant 8 : index - %A_gpu = gpu.alloc host_shared () : memref<128x100xf16> - memref.copy %A, %A_gpu : memref<128x100xf16> to memref<128x100xf16> - %B_gpu = gpu.alloc host_shared () : memref<256x100xf16> - memref.copy %B, %B_gpu : memref<256x100xf16> to memref<256x100xf16> - %GEMM_gpu = gpu.alloc host_shared () : memref<128x256xf32> - gpu.launch_func @gemm_k_oob::@gemm_k_oob blocks in (%c2, %c1, %c1) threads in (%c4, %c8, %c1) - args(%A_gpu : memref<128x100xf16>, %B_gpu : memref<256x100xf16>, %GEMM_gpu : memref<128x256xf32>) - gpu.dealloc %A_gpu : memref<128x100xf16> - gpu.dealloc %B_gpu : memref<256x100xf16> - return %GEMM_gpu : memref<128x256xf32> + %memref = gpu.alloc () : memref<128x100xf16> + gpu.memcpy %memref, %arg0 : memref<128x100xf16>, memref<128x100xf16> + %memref_0 = gpu.alloc () : memref<256x100xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x100xf16>, memref<256x100xf16> + %memref_1 = gpu.alloc () : memref<128x256xf32> + gpu.launch_func @gemm_k_oob::@gemm_k_oob blocks in (%c2, %c1, %c1) threads in (%c4, %c8, %c1) args(%memref : memref<128x100xf16>, %memref_0 : memref<256x100xf16>, %memref_1 : memref<128x256xf32>) + gpu.dealloc %memref : memref<128x100xf16> + gpu.dealloc %memref_0 : memref<256x100xf16> + %alloc = memref.alloc() : memref<128x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<128x256xf32>, memref<128x256xf32> + gpu.dealloc %memref_1 : memref<128x256xf32> + return %alloc : memref<128x256xf32> } - - gpu.module @gemm_k_oob attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @gemm_k_oob { gpu.func @gemm_k_oob(%arg0: memref<128x100xf16>, %arg1: memref<256x100xf16>, %arg2: memref<128x256xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { cf.br ^bb1 ^bb1: // pred: ^bb0 @@ -55,10 +50,10 @@ module @gemm attributes {gpu.container_module} { %8 = xetile.init_tile %arg0[%5, %c0] : memref<128x100xf16> -> !xetile.tile<32x32xf16> %9 = xetile.init_tile %arg1[%7, %c0] : memref<256x100xf16> -> !xetile.tile<32x32xf16> %10:3 = scf.for %arg3 = %c0 to %c100 step %c32 iter_args(%arg4 = %cst_0, %arg5 = %8, %arg6 = %9) -> (vector<32x32xf32>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>) { - %12 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16> - %13 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> - %14 = xetile.load_tile %arg5 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %15 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %12 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16> + %13 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> + %14 = xetile.load_tile %arg5 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %15 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> %16 = vector.transpose %15, [1, 0] : vector<32x32xf16> to vector<32x32xf16> xegpu.compile_hint %17 = math.exp %14 : vector<32x32xf16> @@ -76,49 +71,46 @@ module @gemm attributes {gpu.container_module} { gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 738.90564 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c100 = arith.constant 100 : index %c256 = arith.constant 256 : index %c128 = 
arith.constant 128 : index - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<128x100xf16> - %B = memref.alloc() : memref<256x100xf16> - %GEMM_ref = memref.alloc() : memref<128x256xf32> // intialize matrix A with ones - scf.for %i = %c0 to %c128 step %c1 { - scf.for %j = %c0 to %c100 step %c1 { - memref.store %cf_1, %A[%i, %j] : memref<128x100xf16> + %cst_0 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<128x100xf16> + %alloc_1 = memref.alloc() : memref<256x100xf16> + %alloc_2 = memref.alloc() : memref<128x256xf32> + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c100 step %c1 { + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<128x100xf16> } } // intialize matrix B with ones - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c100 step %c1 { - memref.store %cf_1, %B[%i, %j] : memref<256x100xf16> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c100 step %c1 { + memref.store %cst_0, %alloc_1[%arg0, %arg1] : memref<256x100xf16> } } // intialize matrix GEMM_ref - %cf_result = arith.constant 738.90560989 : f32 - scf.for %i = %c0 to %c128 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - memref.store %cf_result, %GEMM_ref[%i, %j] : memref<128x256xf32> + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<128x256xf32> } } - - %GEMM = call @gemm_k_oob_entry(%A, %B) : (memref<128x100xf16>, memref<256x100xf16>) -> memref<128x256xf32> - %cast_GEMM = memref.cast %GEMM : memref<128x256xf32> to memref<*xf32> - %cast_GEMM_ref = memref.cast %GEMM_ref : memref<128x256xf32> to memref<*xf32> // CHECK: Max absolute error 0. // CHECK: Max relative error 0.00 - call @printMaxErrorF32(%cast_GEMM, %cast_GEMM_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<128x100xf16> - memref.dealloc %B : memref<256x100xf16> - memref.dealloc %GEMM_ref : memref<128x256xf32> + %0 = call @gemm_k_oob_entry(%alloc, %alloc_1) : (memref<128x100xf16>, memref<256x100xf16>) -> memref<128x256xf32> + %cast = memref.cast %0 : memref<128x256xf32> to memref<*xf32> + %cast_3 = memref.cast %alloc_2 : memref<128x256xf32> to memref<*xf32> + call @printMaxErrorF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<128x100xf16> + memref.dealloc %alloc_1 : memref<256x100xf16> + memref.dealloc %alloc_2 : memref<128x256xf32> return } - func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/gemm_output_f16.mlir b/test/Integration/Dialect/XeTile/gemm_output_f16.mlir index 454688d53..96f229b3d 100644 --- a/test/Integration/Dialect/XeTile/gemm_output_f16.mlir +++ b/test/Integration/Dialect/XeTile/gemm_output_f16.mlir @@ -1,40 +1,32 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp 
\ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck #map = affine_map<() -> (0)> #map1 = affine_map<() -> (96)> module @gemm_output_f16 attributes {gpu.container_module} { - func.func @gemm_output_f16_entry( - %A: memref<128x96xf16>, %B: memref<256x96xf16>, %POSTOP: memref<128x256xf16>) -> - memref<128x256xf16> attributes {llvm.emit_c_interface} { + func.func @gemm_output_f16_entry(%arg0: memref<128x96xf16>, %arg1: memref<256x96xf16>, %arg2: memref<128x256xf16>) -> memref<128x256xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %A_gpu = gpu.alloc host_shared () : memref<128x96xf16> - memref.copy %A, %A_gpu : memref<128x96xf16> to memref<128x96xf16> - %B_gpu = gpu.alloc host_shared () : memref<256x96xf16> - memref.copy %B, %B_gpu : memref<256x96xf16> to memref<256x96xf16> - %POSTOP_gpu = gpu.alloc host_shared () : memref<128x256xf16> - memref.copy %POSTOP, %POSTOP_gpu : memref<128x256xf16> to memref<128x256xf16> - %OUTPUT_gpu = gpu.alloc host_shared () : memref<128x256xf16> - gpu.launch_func @gemm_output_f16::@gemm_output_f16 blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) - args(%A_gpu : memref<128x96xf16>, - %B_gpu : memref<256x96xf16>, - %POSTOP_gpu : memref<128x256xf16>, - %OUTPUT_gpu : memref<128x256xf16>) - gpu.dealloc %A_gpu : memref<128x96xf16> - gpu.dealloc %B_gpu : memref<256x96xf16> - gpu.dealloc %POSTOP_gpu : memref<128x256xf16> - return %OUTPUT_gpu : memref<128x256xf16> + %memref = gpu.alloc () : memref<128x96xf16> + gpu.memcpy %memref, %arg0 : memref<128x96xf16>, memref<128x96xf16> + %memref_0 = gpu.alloc () : memref<256x96xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x96xf16>, memref<256x96xf16> + %memref_1 = gpu.alloc () : memref<128x256xf16> + gpu.memcpy %memref_1, %arg2 : memref<128x256xf16>, memref<128x256xf16> + %memref_2 = gpu.alloc () : memref<128x256xf16> + gpu.launch_func @gemm_output_f16::@gemm_output_f16 blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%memref : memref<128x96xf16>, %memref_0 : memref<256x96xf16>, %memref_1 : memref<128x256xf16>, %memref_2 : memref<128x256xf16>) + gpu.dealloc %memref : memref<128x96xf16> + gpu.dealloc %memref_0 : memref<256x96xf16> + gpu.dealloc %memref_1 : memref<128x256xf16> + %alloc = memref.alloc() : memref<128x256xf16> + gpu.memcpy %alloc, %memref_2 : memref<128x256xf16>, memref<128x256xf16> + gpu.dealloc %memref_2 : memref<128x256xf16> + return %alloc : memref<128x256xf16> } - - gpu.module @gemm_output_f16 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @gemm_output_f16 { gpu.func @gemm_output_f16(%arg0: memref<128x96xf16>, %arg1: memref<256x96xf16>, %arg2: memref<128x256xf16>, %arg3: memref<128x256xf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { cf.br ^bb1 ^bb1: // pred: ^bb0 @@ -59,10 +51,10 @@ module @gemm_output_f16 attributes {gpu.container_module} { %8 = xetile.init_tile %arg0[%5, %c0] : memref<128x96xf16> -> !xetile.tile<32x32xf16> %9 = xetile.init_tile %arg1[%7, %c0] : memref<256x96xf16> -> !xetile.tile<32x32xf16> %10:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %8, 
%arg7 = %9) -> (vector<32x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>) { - %15 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16> - %16 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16> - %17 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %18 = xetile.load_tile %arg7 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %15 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16> + %16 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16> + %17 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %18 = xetile.load_tile %arg7 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> %19 = vector.transpose %18, [1, 0] : vector<32x32xf16> to vector<32x32xf16> xegpu.compile_hint xegpu.compile_hint @@ -71,64 +63,61 @@ module @gemm_output_f16 attributes {gpu.container_module} { scf.yield %20, %16, %15 : vector<32x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16> } {lowerBoundMap = #map, operandSegmentSizes = array, step = 32 : index, upperBoundMap = #map1} %11 = xetile.init_tile %arg2[%5, %7] : memref<128x256xf16> -> !xetile.tile<32x32xf16> - %12 = xetile.load_tile %11 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %12 = xetile.load_tile %11 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> %13 = arith.addf %10#0, %12 : vector<32x32xf16> %14 = xetile.init_tile %arg3[%5, %7] : memref<128x256xf16> -> !xetile.tile<32x32xf16> xetile.store_tile %13, %14 : vector<32x32xf16>, !xetile.tile<32x32xf16> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+02 : f16 + %cst_0 = arith.constant 4.000000e+00 : f16 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c96 = arith.constant 96 : index %c256 = arith.constant 256 : index %c128 = arith.constant 128 : index - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<128x96xf16> - %B = memref.alloc() : memref<256x96xf16> - %POSTOP = memref.alloc() : memref<128x256xf16> - %OUTPUT_ref = memref.alloc() : memref<128x256xf16> // intialize matrix A with ones - scf.for %i = %c0 to %c128 step %c1 { - scf.for %j = %c0 to %c96 step %c1 { - memref.store %cf_1, %A[%i, %j] : memref<128x96xf16> + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<128x96xf16> + %alloc_2 = memref.alloc() : memref<256x96xf16> + %alloc_3 = memref.alloc() : memref<128x256xf16> + %alloc_4 = memref.alloc() : memref<128x256xf16> + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c96 step %c1 { + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<128x96xf16> } } // intialize matrix B with ones - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c96 step %c1 { - memref.store %cf_1, %B[%i, %j] : memref<256x96xf16> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c96 step %c1 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<256x96xf16> } } // intialize matrix POSTOP (second operand of the postop) and OUTPUT_ref. 
- %cf_4 = arith.constant 4.0 : f16 - %cf_result = arith.constant 100.0 : f16 - scf.for %i = %c0 to %c128 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - memref.store %cf_4, %POSTOP[%i, %j] : memref<128x256xf16> - memref.store %cf_result, %OUTPUT_ref[%i, %j] : memref<128x256xf16> + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + memref.store %cst_0, %alloc_3[%arg0, %arg1] : memref<128x256xf16> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<128x256xf16> } } - - %OUTPUT = call @gemm_output_f16_entry(%A, %B, %POSTOP) : (memref<128x96xf16>, memref<256x96xf16>, memref<128x256xf16>) -> memref<128x256xf16> - %cast_OUTPUT = memref.cast %OUTPUT : memref<128x256xf16> to memref<*xf16> - %cast_OUTPUT_ref = memref.cast %OUTPUT_ref : memref<128x256xf16> to memref<*xf16> // TODO: investigate why printAllcloseF16 was returning false even when the // tensors are identical. It looks like an issue when comparing f16 values. // For now using printMaxErrorF16. // call @printAllcloseF16(%cast_OUTPUT, %cast_OUTPUT_ref) : (memref<*xf16>, memref<*xf16>) -> () // CHECK: Max absolute error 0 // CHECK: Max relative error 0 - call @printMaxErrorF16(%cast_OUTPUT, %cast_OUTPUT_ref) : (memref<*xf16>, memref<*xf16>) -> () - memref.dealloc %A : memref<128x96xf16> - memref.dealloc %B : memref<256x96xf16> - memref.dealloc %POSTOP : memref<128x256xf16> - memref.dealloc %OUTPUT_ref : memref<128x256xf16> + %0 = call @gemm_output_f16_entry(%alloc, %alloc_2, %alloc_3) : (memref<128x96xf16>, memref<256x96xf16>, memref<128x256xf16>) -> memref<128x256xf16> + %cast = memref.cast %0 : memref<128x256xf16> to memref<*xf16> + %cast_5 = memref.cast %alloc_4 : memref<128x256xf16> to memref<*xf16> + call @printMaxErrorF16(%cast, %cast_5) : (memref<*xf16>, memref<*xf16>) -> () + memref.dealloc %alloc : memref<128x96xf16> + memref.dealloc %alloc_2 : memref<256x96xf16> + memref.dealloc %alloc_3 : memref<128x256xf16> + memref.dealloc %alloc_4 : memref<128x256xf16> return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @printMaxErrorF16(memref<*xf16>, memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/load_store_non_pow2.mlir b/test/Integration/Dialect/XeTile/load_store_non_pow2.mlir index efbc8b75b..f55694c24 100644 --- a/test/Integration/Dialect/XeTile/load_store_non_pow2.mlir +++ b/test/Integration/Dialect/XeTile/load_store_non_pow2.mlir @@ -1,77 +1,60 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<384x64xf32>, %B: memref<384x64xf32>) -> 
memref<384x64xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<384x64xf32>, %arg1: memref<384x64xf32>) -> memref<384x64xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %A_gpu = gpu.alloc host_shared () : memref<384x64xf32> - memref.copy %A, %A_gpu : memref<384x64xf32> to memref<384x64xf32> - %B_gpu = gpu.alloc host_shared () : memref<384x64xf32> - memref.copy %B, %B_gpu : memref<384x64xf32> to memref<384x64xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<384x64xf32>, %B_gpu : memref<384x64xf32>) - gpu.dealloc %A_gpu : memref<384x64xf32> - return %B_gpu : memref<384x64xf32> + %memref = gpu.alloc () : memref<384x64xf32> + gpu.memcpy %memref, %arg0 : memref<384x64xf32>, memref<384x64xf32> + %memref_0 = gpu.alloc () : memref<384x64xf32> + gpu.memcpy %memref_0, %arg1 : memref<384x64xf32>, memref<384x64xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<384x64xf32>, %memref_0 : memref<384x64xf32>) + gpu.dealloc %memref : memref<384x64xf32> + %alloc = memref.alloc() : memref<384x64xf32> + gpu.memcpy %alloc, %memref_0 : memref<384x64xf32>, memref<384x64xf32> + gpu.dealloc %memref_0 : memref<384x64xf32> + return %alloc : memref<384x64xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<384x64xf32>, %B: memref<384x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { /// canonicalize + gpu.func @test_kernel(%arg0: memref<384x64xf32>, %arg1: memref<384x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - - %a_tile = xetile.init_tile %A[%c0, %c0] : memref<384x64xf32> -> !xetile.tile<384x64xf32> - %b_tile = xetile.init_tile %B[%c0, %c0] : memref<384x64xf32> -> !xetile.tile<384x64xf32> - - %a_value = xetile.load_tile %a_tile : !xetile.tile<384x64xf32> -> vector<384x64xf32> - xetile.store_tile %a_value, %b_tile : vector<384x64xf32>, !xetile.tile<384x64xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = xetile.init_tile %arg0[%c0, %c0] : memref<384x64xf32> -> !xetile.tile<384x64xf32> + %1 = xetile.init_tile %arg1[%c0, %c0] : memref<384x64xf32> -> !xetile.tile<384x64xf32> + %2 = xetile.load_tile %0 : !xetile.tile<384x64xf32> -> vector<384x64xf32> + xetile.store_tile %2, %1 : vector<384x64xf32>, !xetile.tile<384x64xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_0_f32 = arith.constant 0.0 : f32 - %cf_2_f32 = arith.constant 2.0 : f32 - %cf_1 = arith.constant 1.0 : f16 // TRY 385x64 - %A = memref.alloc() : memref<384x64xf32> - %B = memref.alloc() : memref<384x64xf32> - // fill A with 2, B with 0 - %A_nonzero = 
memref.cast %A : memref<384x64xf32> to memref<*xf32> - %B_zeros = memref.cast %B : memref<384x64xf32> to memref<*xf32> - call @fillResource1DF32(%A_nonzero, %cf_2_f32) : (memref<*xf32>, f32) -> () - call @fillResource1DF32(%B_zeros, %cf_0_f32) : (memref<*xf32>, f32) -> () // Load from A, store to B - %2 = call @test(%A, %B) : (memref<384x64xf32>, memref<384x64xf32>) -> memref<384x64xf32> - - %B_filled = memref.cast %2 : memref<384x64xf32> to memref<*xf32> // call @printMemrefF32(%A_nonzero) : (memref<*xf32>) -> () // call @printMemrefF32(%B_filled) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%A_nonzero, %B_filled) : (memref<*xf32>, memref<*xf32>) -> () - - memref.dealloc %A : memref<384x64xf32> - memref.dealloc %B : memref<384x64xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 2.000000e+00 : f32 + %alloc = memref.alloc() : memref<384x64xf32> + %alloc_1 = memref.alloc() : memref<384x64xf32> + %cast = memref.cast %alloc : memref<384x64xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<384x64xf32> to memref<*xf32> + call @fillResource1DF32(%cast, %cst_0) : (memref<*xf32>, f32) -> () + call @fillResource1DF32(%cast_2, %cst) : (memref<*xf32>, f32) -> () + %0 = call @test(%alloc, %alloc_1) : (memref<384x64xf32>, memref<384x64xf32>) -> memref<384x64xf32> + %cast_3 = memref.cast %0 : memref<384x64xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<384x64xf32> + memref.dealloc %alloc_1 : memref<384x64xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir b/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir index 9040b1d81..0d715f743 100644 --- a/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir +++ b/test/Integration/Dialect/XeTile/sg_add_scattered_ops.mlir @@ -1,62 +1,51 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
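// Sketch of the host-side staging pattern the wrapper below follows (value names are
// illustrative; the authoritative code is func.func @test in this file): device buffers
// are allocated with gpu.alloc, filled and drained with gpu.memcpy, then freed.
//   %a_dev = gpu.alloc () : memref<1024xf32>
//   gpu.memcpy %a_dev, %a_host : memref<1024xf32>, memref<1024xf32>
//   gpu.launch_func @test_kernel::@add_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(...)
//   %out_host = memref.alloc() : memref<1024xf32>
//   gpu.memcpy %out_host, %c_dev : memref<1024xf32>, memref<1024xf32>
//   gpu.dealloc %a_dev : memref<1024xf32>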
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024xf32>, %B: memref<1024xf32>) -> memref<1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) -> memref<1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %A_gpu = gpu.alloc host_shared () : memref<1024xf32> - memref.copy %A, %A_gpu : memref<1024xf32> to memref<1024xf32> - %B_gpu = gpu.alloc host_shared () : memref<1024xf32> - memref.copy %B, %B_gpu : memref<1024xf32> to memref<1024xf32> - %C_gpu = gpu.alloc host_shared () : memref<1024xf32> - gpu.launch_func @test_kernel::@add_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024xf32>, %B_gpu : memref<1024xf32>, %C_gpu : memref<1024xf32>) - gpu.dealloc %A_gpu : memref<1024xf32> - gpu.dealloc %B_gpu : memref<1024xf32> - return %C_gpu : memref<1024xf32> + %memref = gpu.alloc () : memref<1024xf32> + gpu.memcpy %memref, %arg0 : memref<1024xf32>, memref<1024xf32> + %memref_0 = gpu.alloc () : memref<1024xf32> + gpu.memcpy %memref_0, %arg1 : memref<1024xf32>, memref<1024xf32> + %memref_1 = gpu.alloc () : memref<1024xf32> + gpu.launch_func @test_kernel::@add_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024xf32>, %memref_0 : memref<1024xf32>, %memref_1 : memref<1024xf32>) + gpu.dealloc %memref : memref<1024xf32> + gpu.dealloc %memref_0 : memref<1024xf32> + %alloc = memref.alloc() : memref<1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024xf32>, memref<1024xf32> + gpu.dealloc %memref_1 : memref<1024xf32> + return %alloc : memref<1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @add_kernel(%A: memref<1024xf32>, %B: memref<1024xf32>, %C: memref<1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @add_kernel(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %indices = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]>: vector<1x32xindex> - %offsets = arith.constant dense<32>: vector<1x32xindex> - %mask = arith.constant dense: vector<1x32xi1> - - %a_init_tile = xetile.init_tile %A, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> - %b_init_tile = xetile.init_tile %B, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> - %c_init_tile = xetile.init_tile %C, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> // %c_init_tile = xetile.init_tile %C[0, 0] : memref<1024xf32> -> !xetile.tile<1x32xf32> - - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_tile = %c_init_tile) - -> (!xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>) { - // load A and B tiles - %a_value = xetile.load %a_tile, %mask : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> - %b_value = xetile.load %b_tile, %mask : !xetile.tile<1x32xf32, 
#xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> - %c_value = arith.addf %a_value, %b_value : vector<1x32xf32> - // xetile.store_tile %c_value, %c_tile : vector<1x32xf32>, !xetile.tile<1x32xf32> - xetile.store %c_value, %c_tile, %mask : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> - - %a_next_tile = xetile.update_tile_offset %a_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> - %b_next_tile = xetile.update_tile_offset %b_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> - %c_next_tile = xetile.update_tile_offset %c_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> // %c_next_tile = xetile.update_tile_offset %c_tile, [%c0, %c32] : !xetile.tile<1x32xf32> - - scf.yield %a_next_tile, %b_next_tile, %c_next_tile - : !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr> + %cst = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x32xindex> + %cst_0 = arith.constant dense<32> : vector<1x32xindex> + %cst_1 = arith.constant dense : vector<1x32xi1> + %0 = xetile.init_tile %arg0, %cst : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1, %cst : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %2 = xetile.init_tile %arg2, %cst : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %3:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %2) -> (!xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>) { + %4 = xetile.load %arg4, %cst_1 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %5 = xetile.load %arg5, %cst_1 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %6 = arith.addf %4, %5 : vector<1x32xf32> + xetile.store %6, %arg6, %cst_1 : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> + %7 = xetile.update_tile_offset %arg4, %cst_0 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + %8 = xetile.update_tile_offset %arg5, %cst_0 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + %9 = xetile.update_tile_offset %arg6, %cst_0 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + scf.yield %7, %8, %9 : !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr> } gpu.return } @@ -65,35 +54,32 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %A = memref.alloc() : memref<1024xf32> - %B = memref.alloc() : memref<1024xf32> - %C_ref = memref.alloc() : memref<1024xf32> // intialize matrix A ; - scf.for %i = %c0 to %c1024 step %c1 { - %t = index.castu %i : index to i32 - %val = arith.uitofp %t : i32 to f32 - memref.store %val, %A[%i] : memref<1024xf32> - memref.store %val, %B[%i] : memref<1024xf32> + %alloc = memref.alloc() : memref<1024xf32> + %alloc_0 = memref.alloc() : memref<1024xf32> + %alloc_1 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + 
%2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0] : memref<1024xf32> + memref.store %2, %alloc_0[%arg0] : memref<1024xf32> } - // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - %a_val = memref.load %A[%i] : memref<1024xf32> - %b_val = memref.load %B[%i] : memref<1024xf32> - %c_val = arith.addf %a_val, %b_val : f32 - memref.store %c_val, %C_ref[%i] : memref<1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc[%arg0] : memref<1024xf32> + %2 = memref.load %alloc_0[%arg0] : memref<1024xf32> + %3 = arith.addf %1, %2 : f32 + memref.store %3, %alloc_1[%arg0] : memref<1024xf32> } - %2 = call @test(%A, %B) : (memref<1024xf32>, memref<1024xf32>) -> memref<1024xf32> - %cast_C = memref.cast %2 : memref<1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024xf32> - memref.dealloc %B : memref<1024xf32> - memref.dealloc %C_ref : memref<1024xf32> + %0 = call @test(%alloc, %alloc_0) : (memref<1024xf32>, memref<1024xf32>) -> memref<1024xf32> + %cast = memref.cast %0 : memref<1024xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024xf32> + memref.dealloc %alloc_0 : memref<1024xf32> + memref.dealloc %alloc_1 : memref<1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_atomic_rmw.mlir b/test/Integration/Dialect/XeTile/sg_atomic_rmw.mlir index ae71ada63..c10d42271 100644 --- a/test/Integration/Dialect/XeTile/sg_atomic_rmw.mlir +++ b/test/Integration/Dialect/XeTile/sg_atomic_rmw.mlir @@ -1,57 +1,46 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024xf32>, %B: memref<1024xf32>) -> memref<1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) -> memref<1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %A_gpu = gpu.alloc host_shared () : memref<1024xf32> - memref.copy %A, %A_gpu : memref<1024xf32> to memref<1024xf32> - %B_gpu = gpu.alloc host_shared () : memref<1024xf32> - memref.copy %B, %B_gpu : memref<1024xf32> to memref<1024xf32> - %C_gpu = gpu.alloc host_shared () : memref<1024xf32> - gpu.launch_func @test_kernel::@add_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024xf32>, %B_gpu : memref<1024xf32>, %C_gpu : memref<1024xf32>) - gpu.dealloc %B_gpu : memref<1024xf32> - return %A_gpu : memref<1024xf32> + %memref = gpu.alloc () : memref<1024xf32> + gpu.memcpy %memref, %arg0 : memref<1024xf32>, memref<1024xf32> + %memref_0 = gpu.alloc () : memref<1024xf32> + gpu.memcpy %memref_0, %arg1 : memref<1024xf32>, memref<1024xf32> + %memref_1 = gpu.alloc () : memref<1024xf32> + gpu.launch_func @test_kernel::@add_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024xf32>, %memref_0 : memref<1024xf32>, %memref_1 : memref<1024xf32>) + gpu.dealloc %memref_0 : memref<1024xf32> + %alloc = memref.alloc() : memref<1024xf32> + gpu.memcpy %alloc, %memref : memref<1024xf32>, memref<1024xf32> + gpu.dealloc %memref : memref<1024xf32> + return %alloc : memref<1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @add_kernel(%A: memref<1024xf32>, %B: memref<1024xf32>, %C: memref<1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @add_kernel(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %indices = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]>: vector<1x32xindex> - %offsets = arith.constant dense<32>: vector<1x32xindex> - %mask = arith.constant dense: vector<1x32xi1> - - %a_init_tile = xetile.init_tile %A, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> - %b_init_tile = xetile.init_tile %B, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> - %c_init_tile = xetile.init_tile %C, %indices : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> - - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_tile = %c_init_tile) - -> (!xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>) { - // load A and B tiles - %b_value = xetile.load %b_tile, %mask : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> - %c_value = xetile.atomic_rmw addf %b_value, %a_tile : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr> -> vector<1x32xf32> - - xetile.store %c_value, %c_tile, %mask : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, 
vector<1x32xi1> - - %a_next_tile = xetile.update_tile_offset %a_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> - %b_next_tile = xetile.update_tile_offset %b_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> - %c_next_tile = xetile.update_tile_offset %c_tile, %offsets : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> - - scf.yield %a_next_tile, %b_next_tile, %c_next_tile - : !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr> + %cst = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x32xindex> + %cst_0 = arith.constant dense<32> : vector<1x32xindex> + %cst_1 = arith.constant dense : vector<1x32xi1> + %0 = xetile.init_tile %arg0, %cst : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %1 = xetile.init_tile %arg1, %cst : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %2 = xetile.init_tile %arg2, %cst : memref<1024xf32>, vector<1x32xindex> -> !xetile.tile<1x32xf32, #xetile.tile_attr> + %3:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %2) -> (!xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>) { + %4 = xetile.load %arg5, %cst_1 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> -> vector<1x32xf32> + %5 = xetile.atomic_rmw addf %4, %arg4 : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr> -> vector<1x32xf32> + xetile.store %5, %arg6, %cst_1 : vector<1x32xf32>, !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xi1> + %6 = xetile.update_tile_offset %arg4, %cst_0 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + %7 = xetile.update_tile_offset %arg5, %cst_0 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + %8 = xetile.update_tile_offset %arg6, %cst_0 : !xetile.tile<1x32xf32, #xetile.tile_attr>, vector<1x32xindex> + scf.yield %6, %7, %8 : !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr>, !xetile.tile<1x32xf32, #xetile.tile_attr> } gpu.return } @@ -60,36 +49,33 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %A = memref.alloc() : memref<1024xf32> - %B = memref.alloc() : memref<1024xf32> - %C_ref = memref.alloc() : memref<1024xf32> // intialize matrix A ; - scf.for %i = %c0 to %c1024 step %c1 { - %t = index.castu %i : index to i32 - %val = arith.uitofp %t : i32 to f32 - memref.store %val, %A[%i] : memref<1024xf32> - memref.store %val, %B[%i] : memref<1024xf32> + %alloc = memref.alloc() : memref<1024xf32> + %alloc_0 = memref.alloc() : memref<1024xf32> + %alloc_1 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0] : memref<1024xf32> + memref.store %2, %alloc_0[%arg0] : memref<1024xf32> } - // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - %a_val = memref.load %A[%i] : memref<1024xf32> - %b_val = memref.load %B[%i] : memref<1024xf32> - %c_val = arith.addf %a_val, %b_val : f32 - memref.store %c_val, %C_ref[%i] : memref<1024xf32> + scf.for %arg0 = %c0 to 
%c1024 step %c1 { + %1 = memref.load %alloc[%arg0] : memref<1024xf32> + %2 = memref.load %alloc_0[%arg0] : memref<1024xf32> + %3 = arith.addf %1, %2 : f32 + memref.store %3, %alloc_1[%arg0] : memref<1024xf32> } - %2 = call @test(%A, %B) : (memref<1024xf32>, memref<1024xf32>) -> memref<1024xf32> - %cast_C = memref.cast %2 : memref<1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024xf32> to memref<*xf32> //call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () //call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024xf32> - memref.dealloc %B : memref<1024xf32> - memref.dealloc %C_ref : memref<1024xf32> + %0 = call @test(%alloc, %alloc_0) : (memref<1024xf32>, memref<1024xf32>) -> memref<1024xf32> + %cast = memref.cast %0 : memref<1024xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024xf32> + memref.dealloc %alloc_0 : memref<1024xf32> + memref.dealloc %alloc_1 : memref<1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_coop_transpose.mlir b/test/Integration/Dialect/XeTile/sg_coop_transpose.mlir index b46367dfb..b33f1238b 100644 --- a/test/Integration/Dialect/XeTile/sg_coop_transpose.mlir +++ b/test/Integration/Dialect/XeTile/sg_coop_transpose.mlir @@ -1,52 +1,45 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index - %A_gpu = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %A, %A_gpu : memref<32x32xf16> to memref<32x32xf16> - %B_gpu = gpu.alloc host_shared () : memref<32x32xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c2, %c1) args(%A_gpu : memref<32x32xf16>, %B_gpu : memref<32x32xf16>) - gpu.dealloc %A_gpu : memref<32x32xf16> - return %B_gpu : memref<32x32xf16> + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c2, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>) + gpu.dealloc %memref : 
memref<32x32xf16> + %alloc = memref.alloc() : memref<32x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf16>, memref<32x32xf16> + gpu.dealloc %memref_0 : memref<32x32xf16> + return %alloc : memref<32x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - - %tid_x = gpu.thread_id x - %tid_y = gpu.thread_id y - %m = arith.muli %tid_x, %c8 : index - %n = arith.muli %tid_y, %c16 : index - - %a_tile = xetile.init_tile %A[%m, %n] : memref<32x32xf16> -> !xetile.tile<8x16xf16> - %a = xetile.load_tile %a_tile : !xetile.tile<8x16xf16> -> vector<8x16xf16> - - %slm = memref.alloc() : memref<2048xi8, 3> - %view = memref.view %slm[%c0][] : memref<2048xi8, 3> to memref<32x32xf16, 3> - - %trans_slm = memref.transpose %view (i, j) -> (j, i) : memref<32x32xf16, 3> to memref<32x32xf16, strided<[1, 32], offset:0>, 3> - %st_tile = xetile.init_tile %trans_slm[%m, %n] : memref<32x32xf16, strided<[1, 32], offset:0>, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> - xetile.store_tile %a, %st_tile : vector<8x16xf16>, !xetile.tile<8x16xf16, #xetile.tile_attr> + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %0 = arith.muli %thread_id_x, %c8 : index + %1 = arith.muli %thread_id_y, %c16 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<32x32xf16> -> !xetile.tile<8x16xf16> + %3 = xetile.load_tile %2 : !xetile.tile<8x16xf16> -> vector<8x16xf16> + %alloc = memref.alloc() : memref<2048xi8, 3> + %view = memref.view %alloc[%c0][] : memref<2048xi8, 3> to memref<32x32xf16, 3> + %transpose = memref.transpose %view (d0, d1) -> (d1, d0) : memref<32x32xf16, 3> to memref<32x32xf16, strided<[1, 32]>, 3> + %4 = xetile.init_tile %transpose[%0, %1] : memref<32x32xf16, strided<[1, 32]>, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> + xetile.store_tile %3, %4 : vector<8x16xf16>, !xetile.tile<8x16xf16, #xetile.tile_attr> gpu.barrier - - %ld_tile = xetile.init_tile %view[%m, %n] : memref<32x32xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> - %d = xetile.load_tile %ld_tile : !xetile.tile<8x16xf16, #xetile.tile_attr> -> vector<8x16xf16> - - %b_tile = xetile.init_tile %B[%m, %n] : memref<32x32xf16> -> !xetile.tile<8x16xf16> - xetile.store_tile %d, %b_tile: vector<8x16xf16>, !xetile.tile<8x16xf16> + %5 = xetile.init_tile %view[%0, %1] : memref<32x32xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> + %6 = xetile.load_tile %5 : !xetile.tile<8x16xf16, #xetile.tile_attr> -> vector<8x16xf16> + %7 = xetile.init_tile %arg1[%0, %1] : memref<32x32xf16> -> !xetile.tile<8x16xf16> + xetile.store_tile %6, %7 : vector<8x16xf16>, !xetile.tile<8x16xf16> gpu.return } } @@ -54,32 +47,28 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<32x32xf16> - %Ref = memref.alloc() : memref<32x32xf32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c32 
step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %mul = arith.muli %i, %c32 : index - %add = arith.addi %mul, %j : index - %t = index.castu %add : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<32x32xf16> - %t32 = index.castu %add : index to i32 - %val32 = arith.uitofp %t32 : i32 to f32 - memref.store %val32, %Ref[%j, %i] : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf16> + %alloc_0 = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store %4, %alloc[%arg0, %arg1] : memref<32x32xf16> + %5 = index.castu %2 : index to i32 + %6 = arith.uitofp %5 : i32 to f32 + memref.store %6, %alloc_0[%arg1, %arg0] : memref<32x32xf32> } } - - %B = call @test(%A) : (memref<32x32xf16>) -> memref<32x32xf16> - %cast = memref.cast %B : memref<32x32xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_ref = memref.cast %Ref : memref<32x32xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<32x32xf16> + %0 = call @test(%alloc) : (memref<32x32xf16>) -> memref<32x32xf16> + %cast = memref.cast %0 : memref<32x32xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<32x32xf16> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_copy_via_slm.mlir b/test/Integration/Dialect/XeTile/sg_copy_via_slm.mlir index 8dc487e05..e18c24b24 100644 --- a/test/Integration/Dialect/XeTile/sg_copy_via_slm.mlir +++ b/test/Integration/Dialect/XeTile/sg_copy_via_slm.mlir @@ -1,51 +1,45 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : This test simply load a tile from A and store it to SLM, and load it back from SLM // and store it to B, to verify the correctness of SLM support. 
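// Sketch of how the kernel below models SLM: a raw byte buffer in address space 3 is
// re-viewed with the element type before XeTile tiles are created on it (value names
// are illustrative; sizes follow this test):
//   %slm  = memref.alloc() : memref<8192xi8, 3>
//   %view = memref.view %slm[%c0][] : memref<8192xi8, 3> to memref<64x64xf16, 3>
//   %tile = xetile.init_tile %view[%m, %n] : memref<64x64xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr>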
module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<64x64xf16>) -> memref<64x64xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<64x64xf16>) -> memref<64x64xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %A_gpu = gpu.alloc host_shared () : memref<64x64xf16> - memref.copy %A, %A_gpu : memref<64x64xf16> to memref<64x64xf16> - %B_gpu = gpu.alloc host_shared () : memref<64x64xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<64x64xf16>, %B_gpu : memref<64x64xf16>) - gpu.dealloc %A_gpu : memref<64x64xf16> - return %B_gpu : memref<64x64xf16> + %memref = gpu.alloc () : memref<64x64xf16> + gpu.memcpy %memref, %arg0 : memref<64x64xf16>, memref<64x64xf16> + %memref_0 = gpu.alloc () : memref<64x64xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<64x64xf16>, %memref_0 : memref<64x64xf16>) + gpu.dealloc %memref : memref<64x64xf16> + %alloc = memref.alloc() : memref<64x64xf16> + gpu.memcpy %alloc, %memref_0 : memref<64x64xf16>, memref<64x64xf16> + gpu.dealloc %memref_0 : memref<64x64xf16> + return %alloc : memref<64x64xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<64x64xf16>, %B: memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<64x64xf16>, %arg1: memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - - %tid_x = gpu.thread_id x - %tid_y = gpu.thread_id y - %m = arith.muli %tid_x, %c8 : index - %n = arith.muli %tid_y, %c16 : index - - %a_tile = xetile.init_tile %A[%m, %n] : memref<64x64xf16> -> !xetile.tile<8x16xf16> - %a = xetile.load_tile %a_tile : !xetile.tile<8x16xf16> -> vector<8x16xf16> - - %slm = memref.alloc() : memref<8192xi8, 3> - %view = memref.view %slm[%c0][] : memref<8192xi8, 3> to memref<64x64xf16, 3> - %st_tile = xetile.init_tile %view[%m, %n] : memref<64x64xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> - xetile.store_tile %a, %st_tile : vector<8x16xf16>, !xetile.tile<8x16xf16, #xetile.tile_attr> - - %ld_tile = xetile.init_tile %view[%m, %n] : memref<64x64xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> - %d = xetile.load_tile %ld_tile : !xetile.tile<8x16xf16, #xetile.tile_attr> -> vector<8x16xf16> - - %b_tile = xetile.init_tile %B[%m, %n] : memref<64x64xf16> -> !xetile.tile<8x16xf16> - xetile.store_tile %d, %b_tile: vector<8x16xf16>, !xetile.tile<8x16xf16> + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %0 = arith.muli %thread_id_x, %c8 : index + %1 = arith.muli %thread_id_y, %c16 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<64x64xf16> -> !xetile.tile<8x16xf16> + %3 = xetile.load_tile %2 : !xetile.tile<8x16xf16> -> vector<8x16xf16> + %alloc = memref.alloc() : memref<8192xi8, 3> + %view = memref.view %alloc[%c0][] : memref<8192xi8, 3> to memref<64x64xf16, 3> + %4 = xetile.init_tile %view[%0, %1] : memref<64x64xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> + xetile.store_tile %3, %4 : vector<8x16xf16>, 
!xetile.tile<8x16xf16, #xetile.tile_attr> + %5 = xetile.init_tile %view[%0, %1] : memref<64x64xf16, 3> -> !xetile.tile<8x16xf16, #xetile.tile_attr> + %6 = xetile.load_tile %5 : !xetile.tile<8x16xf16, #xetile.tile_attr> -> vector<8x16xf16> + %7 = xetile.init_tile %arg1[%0, %1] : memref<64x64xf16> -> !xetile.tile<8x16xf16> + xetile.store_tile %6, %7 : vector<8x16xf16>, !xetile.tile<8x16xf16> gpu.return } } @@ -53,29 +47,25 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<64x64xf16> - %Ref = memref.alloc() : memref<64x64xf32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c64 step %c1 { - scf.for %j = %c0 to %c64 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<64x64xf16> - %val_f32 = arith.uitofp %t : i16 to f32 - memref.store %val_f32, %Ref[%i, %j] : memref<64x64xf32> + %alloc = memref.alloc() : memref<64x64xf16> + %alloc_0 = memref.alloc() : memref<64x64xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<64x64xf16> + %3 = arith.uitofp %1 : i16 to f32 + memref.store %3, %alloc_0[%arg0, %arg1] : memref<64x64xf32> } } - - %B = call @test(%A) : (memref<64x64xf16>) -> memref<64x64xf16> - %cast = memref.cast %B : memref<64x64xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_ref = memref.cast %Ref : memref<64x64xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<64x64xf16> + %0 = call @test(%alloc) : (memref<64x64xf16>) -> memref<64x64xf16> + %cast = memref.cast %0 : memref<64x64xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<64x64xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<64x64xf16> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gather_scatter_slm.mlir b/test/Integration/Dialect/XeTile/sg_gather_scatter_slm.mlir index 9efef9fe2..5b013faca 100644 --- a/test/Integration/Dialect/XeTile/sg_gather_scatter_slm.mlir +++ b/test/Integration/Dialect/XeTile/sg_gather_scatter_slm.mlir @@ -1,34 +1,32 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes 
one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. - module @gemm attributes {gpu.container_module} { // a test case case return the transpose of A, which is viewed as memref<32x32xf16>. // it uses one workgroup containing 32 subgroups, organized as (8x4), so each subgroup // works on a 4x8 tile of A. It used SLM to do the transpose, to evaluate the functionality // of the SLM operations. - func.func @test(%A: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %A_gpu = gpu.alloc host_shared () : memref<32x32xf16> - memref.copy %A, %A_gpu : memref<32x32xf16> to memref<32x32xf16> - %B_gpu = gpu.alloc host_shared () : memref<32x32xf16> - gpu.launch_func @test_kernel::@trans_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%A_gpu : memref<32x32xf16>, %B_gpu : memref<32x32xf16>) - gpu.dealloc %A_gpu : memref<32x32xf16> - return %B_gpu : memref<32x32xf16> + %memref = gpu.alloc () : memref<32x32xf16> + gpu.memcpy %memref, %arg0 : memref<32x32xf16>, memref<32x32xf16> + %memref_0 = gpu.alloc () : memref<32x32xf16> + gpu.launch_func @test_kernel::@trans_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%memref : memref<32x32xf16>, %memref_0 : memref<32x32xf16>) + gpu.dealloc %memref : memref<32x32xf16> + %alloc = memref.alloc() : memref<32x32xf16> + gpu.memcpy %alloc, %memref_0 : memref<32x32xf16>, memref<32x32xf16> + gpu.dealloc %memref_0 : memref<32x32xf16> + return %alloc : memref<32x32xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @trans_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @trans_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index @@ -36,57 +34,40 @@ module @gemm attributes {gpu.container_module} { %c8 = arith.constant 8 : index %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index - - %sgid = gpu.subgroup_id : index // %tid_y = arith.divui %sgid, %c4 : index // %tid_x = arith.remui %sgid, %c4 : index - %tid_y = arith.shrui %sgid, %c2 : index - %tid_x = arith.andi %sgid, %c3 : index - - %off_y = arith.muli %tid_y, %c4 : index - %off_x = arith.muli %tid_x, %c8 : index - // load data from global memory using block load - %a_tile = xetile.init_tile %A[%off_y, %off_x] : memref<32x32xf16> -> !xetile.tile<4x8xf16> - %data = xetile.load_tile %a_tile : !xetile.tile<4x8xf16> -> vector<4x8xf16> - // %slm = memref.alloc() : memref<32x32xf16, 3> // %cast = memref.reinterpret_cast %slm to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf16, 3> to memref<1024xf16, 3> - - %slm = memref.alloc() : memref<1024xf16, 3> - - %mask = arith.constant dense: vector<4x8xi1> - // store data to slm using original layout - %base_indices = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7], - [32, 33, 34, 35, 36, 37, 38, 39], - [64, 65, 66, 67, 68, 69, 70, 71], - [96, 97, 98, 99, 100, 101, 102, 103]]>: vector<4x8xindex> - %off_y2 = arith.muli %tid_y, %c128 : 
index - %offset = arith.addi %off_y2, %off_x : index - %offsets = vector.broadcast %offset: index to vector<4x8xindex> - %indices = arith.addi %base_indices, %offsets : vector<4x8xindex> - %st_tile = xetile.init_tile %slm, %indices : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr> - xetile.store %data, %st_tile, %mask : vector<4x8xf16>, !xetile.tile<4x8xf16, #xetile.tile_attr>, vector<4x8xi1> - + %0 = gpu.subgroup_id : index + %1 = arith.shrui %0, %c2 : index + %2 = arith.andi %0, %c3 : index + %3 = arith.muli %1, %c4 : index + %4 = arith.muli %2, %c8 : index + %5 = xetile.init_tile %arg0[%3, %4] : memref<32x32xf16> -> !xetile.tile<4x8xf16> + %6 = xetile.load_tile %5 : !xetile.tile<4x8xf16> -> vector<4x8xf16> + %alloc = memref.alloc() : memref<1024xf16, 3> + %cst = arith.constant dense : vector<4x8xi1> + %cst_0 = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7], [32, 33, 34, 35, 36, 37, 38, 39], [64, 65, 66, 67, 68, 69, 70, 71], [96, 97, 98, 99, 100, 101, 102, 103]]> : vector<4x8xindex> + %7 = arith.muli %1, %c128 : index + %8 = arith.addi %7, %4 : index + %9 = vector.broadcast %8: index to vector<4x8xindex> + %10 = arith.addi %cst_0, %9 : vector<4x8xindex> + %11 = xetile.init_tile %alloc, %10 : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr> + xetile.store %6, %11, %cst : vector<4x8xf16>, !xetile.tile<4x8xf16, #xetile.tile_attr>, vector<4x8xi1> gpu.barrier - // load data from slm using indices with transpose effects - %trans_base_indices = arith.constant dense<[[0, 32, 64, 96, 128, 160, 192, 224], - [1, 33, 65, 97, 129, 161, 193, 225], - [2, 34, 66, 98, 130, 162, 194, 226], - [3, 35, 67, 99, 131, 163, 195, 227]]>: vector<4x8xindex> - - %trans_off_x = arith.muli %tid_x, %c256 : index - %trans_off_y = arith.muli %tid_y, %c4 : index - %trans_off = arith.addi %trans_off_x, %trans_off_y : index - %trans_offsets = vector.broadcast %trans_off: index to vector<4x8xindex> - %trans_indices = arith.addi %trans_base_indices, %trans_offsets : vector<4x8xindex> - %ld_tile = xetile.init_tile %slm, %trans_indices : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr> - %d = xetile.load %ld_tile, %mask : !xetile.tile<4x8xf16, #xetile.tile_attr>, vector<4x8xi1> -> vector<4x8xf16> - - %b_tile = xetile.init_tile %B[%off_y, %off_x] : memref<32x32xf16> -> !xetile.tile<4x8xf16> - xetile.store_tile %d, %b_tile: vector<4x8xf16>, !xetile.tile<4x8xf16> + %cst_1 = arith.constant dense<[[0, 32, 64, 96, 128, 160, 192, 224], [1, 33, 65, 97, 129, 161, 193, 225], [2, 34, 66, 98, 130, 162, 194, 226], [3, 35, 67, 99, 131, 163, 195, 227]]> : vector<4x8xindex> + %12 = arith.muli %2, %c256 : index + %13 = arith.muli %1, %c4 : index + %14 = arith.addi %12, %13 : index + %15 = vector.broadcast %14: index to vector<4x8xindex> + %16 = arith.addi %cst_1, %15 : vector<4x8xindex> + %17 = xetile.init_tile %alloc, %16 : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr> + %18 = xetile.load %17, %cst : !xetile.tile<4x8xf16, #xetile.tile_attr>, vector<4x8xi1> -> vector<4x8xf16> + %19 = xetile.init_tile %arg1[%3, %4] : memref<32x32xf16> -> !xetile.tile<4x8xf16> + xetile.store_tile %18, %19 : vector<4x8xf16>, !xetile.tile<4x8xf16> gpu.return } } @@ -94,30 +75,28 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %A = memref.alloc() : 
memref<32x32xf16> - %Ref = memref.alloc() : memref<32x32xf32> // intialize matrix A ; - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c32 step %c1 { - %m = arith.muli %i, %c32 : index - %a = arith.addi %m, %j : index - %v = index.castu %a : index to i16 - %val = arith.uitofp %v : i16 to f16 - memref.store %val, %A[%i, %j] : memref<32x32xf16> - %v32 = index.castu %a : index to i32 - %val32 = arith.uitofp %v32 : i32 to f32 - memref.store %val32, %Ref[%j, %i] : memref<32x32xf32> + %alloc = memref.alloc() : memref<32x32xf16> + %alloc_0 = memref.alloc() : memref<32x32xf32> + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c32 step %c1 { + %1 = arith.muli %arg0, %c32 : index + %2 = arith.addi %1, %arg1 : index + %3 = index.castu %2 : index to i16 + %4 = arith.uitofp %3 : i16 to f16 + memref.store %4, %alloc[%arg0, %arg1] : memref<32x32xf16> + %5 = index.castu %2 : index to i32 + %6 = arith.uitofp %5 : i32 to f32 + memref.store %6, %alloc_0[%arg1, %arg0] : memref<32x32xf32> } } - %B = call @test(%A) : (memref<32x32xf16>) -> memref<32x32xf16> - %cast = memref.cast %B : memref<32x32xf16> to memref<*xf16> - %Ref_cast = memref.cast %Ref : memref<32x32xf32> to memref<*xf32> //CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast, %Ref_cast) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<32x32xf16> - memref.dealloc %Ref : memref<32x32xf32> + %0 = call @test(%alloc) : (memref<32x32xf16>) -> memref<32x32xf16> + %cast = memref.cast %0 : memref<32x32xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<32x32xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<32x32xf16> + memref.dealloc %alloc_0 : memref<32x32xf32> return } func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir index 2fffb4263..72136e8da 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_f32.mlir @@ -1,149 +1,136 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
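// NOTE (editorial, illustrative only): "one subgroup per workgroup" here means the launch
// grid does all of the tiling. With blocks in (64, 32, 1) and threads in (1, 1, 1), block
// (x, y) owns the 16x32 C tile starting at row 16*x and column 32*y; since 64*16 = 1024 and
// 32*32 = 1024, the grid covers the 1024x1024 output exactly. For example, block (3, 5)
// writes rows 48..63 and columns 160..191. Inside the kernel the offsets are computed as
// (names are illustrative):
//   %row = arith.muli %block_id_x, %c16 : index
//   %col = arith.muli %block_id_y, %c32 : index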
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %A, %A_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %B, %B_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xbf16>, %B_gpu : memref<1024x1024xbf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xbf16> - gpu.dealloc %B_gpu : memref<1024x1024xbf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_0 = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xbf16>, %memref_0 : memref<1024x1024xbf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xbf16> + gpu.dealloc %memref_0 : memref<1024x1024xbf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xbf16> 
-> !xetile.tile<16x32xbf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<16x32xbf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x32xbf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xbf16> -> !xetile.tile<16x32xbf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %9 = xetile.tile_mma %7, %8, %arg6 : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xbf16> + %11 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xbf16> + scf.yield %10, %11, %9 : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %A = memref.alloc() : memref<1024x1024xbf16> - %B = memref.alloc() : memref<1024x1024xbf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to bf16 - memref.store %val, %A[%i, %j] : memref<1024x1024xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %cst_1 = arith.constant 1.000000e+00 : bf16 + %alloc = memref.alloc() : memref<1024x1024xbf16> + %alloc_2 = memref.alloc() : memref<1024x1024xbf16> + %alloc_3 = memref.alloc() : 
memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to bf16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xbf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xbf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xbf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xbf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xbf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xbf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xbf16> - %t = arith.mulf %a_val, %b_val : bf16 - %t_cast = arith.extf %t : bf16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xbf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xbf16> + %5 = arith.mulf %3, %4 : bf16 + %6 = arith.extf %5 : bf16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xbf16> to memref<*xbf16> // call @printMemrefbf16(%cast) : (memref<*xbf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, 
%cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xbf16> - memref.dealloc %B : memref<1024x1024xbf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xbf16> + memref.dealloc %alloc_2 : memref<1024x1024xbf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir index 543c1526a..24aee0ae4 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f16_blk_16x32x32.mlir @@ -1,142 +1,129 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
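// NOTE (editorial, illustrative only): the GEMM kernels in these tests share the same k-loop
// shape: the A tile, the B tile, and the f32 accumulator are carried as scf.for iter_args;
// each iteration loads both tiles, feeds them to xetile.tile_mma, and advances the tile
// offsets by the k step. A minimal sketch of that loop, assuming %a0 and %b0 are
// already-initialized 16x32 / 32x32 tiles and %c0, %c32, %c1024 are the usual index
// constants (names are illustrative):
%zero = arith.constant dense<0.000000e+00> : vector<16x32xf32>
%res:3 = scf.for %k = %c0 to %c1024 step %c32
    iter_args(%a = %a0, %b = %b0, %acc = %zero)
    -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) {
  %av = xetile.load_tile %a : !xetile.tile<16x32xf16> -> vector<16x32xf16>
  %bv = xetile.load_tile %b : !xetile.tile<32x32xf16> -> vector<32x32xf16>
  %acc_next = xetile.tile_mma %av, %bv, %acc : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
  %a_next = xetile.update_tile_offset %a, [%c0, %c32] : !xetile.tile<16x32xf16>
  %b_next = xetile.update_tile_offset %b, [%c32, %c0] : !xetile.tile<32x32xf16>
  scf.yield %a_next, %b_next, %acc_next : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>
}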
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>) -> (memref<1024x1024xf32>, memref<1024x1024xf16>) attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) -> (memref<1024x1024xf32>, memref<1024x1024xf16>) attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - %D_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>, %D_gpu : memref<1024x1024xf16>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu, %D_gpu : memref<1024x1024xf32>, memref<1024x1024xf16> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + %memref_2 = gpu.alloc () : memref<1024x1024xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>, %memref_2 : memref<1024x1024xf16>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + %alloc_3 = memref.alloc() : memref<1024x1024xf16> + gpu.memcpy %alloc_3, %memref_2 : memref<1024x1024xf16>, memref<1024x1024xf16> + gpu.dealloc %memref_2 : memref<1024x1024xf16> + return %alloc, %alloc_3 : memref<1024x1024xf32>, memref<1024x1024xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %D: memref<1024x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index - %c_init_value = 
arith.constant dense<0.0> : vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %cst = arith.constant dense<0.000000e+00> : vector<16x32xf32> + %2 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %3 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %4:3 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %2, %arg6 = %3, %arg7 = %cst) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %8 = xetile.load_tile %arg5 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %9 = xetile.load_tile %arg6 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %10 = xetile.tile_mma %8, %9, %arg7 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %11 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<16x32xf16> + %12 = xetile.update_tile_offset %arg6, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %11, %12, %10 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - %c_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - xetile.store_tile %out#2, %c_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> - %d_tile = xetile.init_tile %D[%m, %n] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %value = arith.truncf %out#2 : vector<16x32xf32> to vector<16x32xf16> - xetile.store_tile %value, %d_tile: vector<16x32xf16>, !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + xetile.store_tile %4#2, %5 : vector<16x32xf32>, !xetile.tile<16x32xf32> + %6 = xetile.init_tile %arg3[%0, %1] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %7 = arith.truncf %4#2 : vector<16x32xf32> to vector<16x32xf16> + xetile.store_tile %7, %6 : vector<16x32xf16>, !xetile.tile<16x32xf16> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = 
memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } - // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = arith.constant 0.0 : f32 - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %cst) -> (f32) { + %2 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %3 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %4 = arith.mulf %2, %3 : f16 + %5 = arith.extf %4 : f16 to f32 + %6 = arith.addf %5, %arg3 : f32 + scf.yield %6 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %1, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> } } - %2:2 = call @test(%A, %B) : (memref<1024x1024xf16>, memref<1024x1024xf16>) -> (memref<1024x1024xf32>, memref<1024x1024xf16>) - %cast_C = memref.cast %2#0 : memref<1024x1024xf32> to memref<*xf32> - %cast_D = memref.cast %2#1 : memref<1024x1024xf16> to memref<*xf16> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF16(%cast_D) : (memref<*xf16>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - call @printAllcloseF16(%cast_D, %cast_C_ref) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0:2 = call @test(%alloc, %alloc_2) 
: (memref<1024x1024xf16>, memref<1024x1024xf16>) -> (memref<1024x1024xf32>, memref<1024x1024xf16>) + %cast = memref.cast %0#0 : memref<1024x1024xf32> to memref<*xf32> + %cast_4 = memref.cast %0#1 : memref<1024x1024xf16> to memref<*xf16> + %cast_5 = memref.cast %alloc_3 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + call @printAllcloseF16(%cast_4, %cast_5) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir index ea9e2e1b3..643b15321 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir @@ -1,151 +1,136 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
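// NOTE (editorial, illustrative only): because B is built as the identity matrix, the
// host-side reference reduces to C_ref[i][j] = extf(A[i][j]) -- each inner product keeps
// exactly one nonzero term -- and that is what printAllcloseF32/printAllcloseF16 compare
// the GPU results against. The scalar reference accumulation used by these tests has the
// following shape (names are illustrative; %i and %j index the output element and
// %cst_zero is the f32 zero constant):
%elem = scf.for %k = %c0 to %c1024 step %c1 iter_args(%acc = %cst_zero) -> (f32) {
  %a = memref.load %A[%i, %k] : memref<1024x1024xf16>
  %b = memref.load %B[%k, %j] : memref<1024x1024xf16>
  %p = arith.mulf %a, %b : f16
  %pw = arith.extf %p : f16 to f32
  %s = arith.addf %pw, %acc : f32
  scf.yield %s : f32
}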
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - 
%b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = xetile.tile_mma %7, %8, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %11 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %10, %11, %9 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + 
scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - 
memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir index 8fcf9e691..a037c2cde 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_dynamic_memref.mlir @@ -1,150 +1,133 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
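// NOTE (editorial, illustrative only): the "dynamic_memref" variant below exercises the
// three-operand form of xetile.init_tile, where the base shape and strides are passed as
// SSA values because the source memref is dynamically shaped (the element types of those
// dynamic memrefs are not legible in this flattened diff; memref<?x?xf32> is assumed for
// the sketch). The C-tile initialization then takes the form:
//   %c_tile = xetile.init_tile %C[%row, %col], [%c1024, %c1024], [%c1024, %c1]
//       : memref<?x?xf32> -> !xetile.tile<16x32xf32>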
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> // Make the memrefs dynamic - %A_gpu_cast = memref.cast %A_gpu : memref<1024x1024xf16> to memref - %B_gpu_cast = memref.cast %B_gpu : memref<1024x1024xf16> to memref - %C_gpu_cast = memref.cast %C_gpu : memref<1024x1024xf32> to memref - - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu_cast : memref, %B_gpu_cast : memref, %C_gpu_cast : memref) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + %cast = memref.cast %memref : memref<1024x1024xf16> to memref + %cast_2 = memref.cast %memref_0 : memref<1024x1024xf16> to memref + %cast_3 = memref.cast %memref_1 : memref<1024x1024xf32> to memref + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%cast : memref, %cast_2 : memref, %cast_3 : memref) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref, %B: memref, %C: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref, %arg1: memref, %arg2: memref) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n], 
[%c1024, %c1024], [%c1024, %c1] : memref -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0], [%c1024, %c1024], [%c1024, %c1] : memref -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n], [%c1024, %c1024], [%c1024, %c1] : memref -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1], [%c1024, %c1024], [%c1024, %c1] : memref -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0], [%c1024, %c1024], [%c1024, %c1] : memref -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1], [%c1024, %c1024], [%c1024, %c1] : memref -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = xetile.tile_mma %7, %8, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %11 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %10, %11, %9 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = 
arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc 
%C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir index 2b76a847a..dd27ae6ad 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_a.mlir @@ -1,138 +1,128 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c32, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c32, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c32 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%c0, %m] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - 
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %a_trans = vector.transpose %a_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_trans, %b_value, %c_value - : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c32 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<32x32xf32> -> vector<32x32xf32> + %4 = xetile.init_tile %arg0[%c0, %0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = vector.transpose %7, [1, 0] : vector<32x32xf16> to vector<32x32xf16> + %10 = xetile.tile_mma %9, %8, %arg6 : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32> + %11 = xetile.update_tile_offset %arg4, [%c32, %c0] : !xetile.tile<32x32xf16> + %12 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %11, %12, %10 : !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<32x32xf32>, !xetile.tile<32x32xf32> + xetile.store_tile %6#2, %2 : vector<32x32xf32>, !xetile.tile<32x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 
to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%k, %i] : memref<1024x1024xf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg2, %arg0] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + 
memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir index b1ba03feb..2697a24f6 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_b.mlir @@ -1,152 +1,137 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. - module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, 
%memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_trans, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> 
-> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = vector.transpose %8, [1, 0] : vector<32x32xf16> to vector<32x32xf16> + %10 = xetile.tile_mma %7, %9, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %11 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %12 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> + scf.yield %11, %12, %10 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix B ; B[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %B[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix A an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %A[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %A[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load 
%A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%j, %k] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_col_major_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_col_major_b.mlir index 7894d3891..b30d1a372 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_col_major_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_col_major_b.mlir @@ -1,151 +1,144 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner 
mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16, strided<[1, 1024]>>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16, strided<[1, 1024]>>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - %B_gpu_cm = memref.reinterpret_cast %B_gpu to offset : [0], sizes : [1024,1024], strides : [1, 1024] : memref<1024x1024xf16> to memref<1024x1024xf16, strided<[1, 1024]>> - memref.copy %B, %B_gpu_cm : memref<1024x1024xf16, strided<[1, 1024]>> to memref<1024x1024xf16, strided<[1, 1024]>> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu_cm : memref<1024x1024xf16, strided<[1, 1024]>>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + + // gpu.memcpy only works on memrefs with an identity layout + // We have to: + // - reinterpret cast the column-major memref to a row-major (identity layout) one + // - do gpu.memcpy on the identity-layout memref + // - reinterpret cast the gpu buffer back to column-major + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + %b_host_reinterpret_cast = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1024, 1024], strides: [1024, 1] : memref<1024x1024xf16, strided<[1, 1024]>> to memref<1024x1024xf16> + gpu.memcpy %memref_0, %b_host_reinterpret_cast : memref<1024x1024xf16>, memref<1024x1024xf16> + %reinterpret_cast = memref.reinterpret_cast %memref_0 to offset: [0], sizes: [1024, 1024], strides: [1, 1024] : memref<1024x1024xf16> to memref<1024x1024xf16, strided<[1, 1024]>> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %reinterpret_cast : memref<1024x1024xf16, strided<[1, 1024]>>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes
{spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B_cm: memref<1024x1024xf16, strided<[1, 1024]>>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16, strided<[1, 1024]>>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B_cm[%n, %c0] : memref<1024x1024xf16, strided<[1, 1024]>> -> !xetile.tile<32x32xf16, #xetile.tile_attr> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<32x32xf16> - %b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_trans, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xf16, strided<[1, 1024]>> -> !xetile.tile<32x32xf16, #xetile.tile_attr> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<32x32xf16> + %9 = vector.transpose %8, [1, 0] : 
vector<32x32xf16> to vector<32x32xf16> + %10 = xetile.tile_mma %7, %9, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %11 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %12 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr> + scf.yield %11, %12, %10 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> // convert B to col-major - %B_cm = memref.reinterpret_cast %B to offset : [0], sizes : [1024,1024], strides : [1, 1024] : memref<1024x1024xf16> to memref<1024x1024xf16, strided<[1, 1024], offset:0>> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix B ; B[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %B_cm[%i, %j] : memref<1024x1024xf16, strided<[1, 1024]>> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %reinterpret_cast = memref.reinterpret_cast %alloc_2 to offset: [0], sizes: [1024, 1024], strides: [1, 1024] : memref<1024x1024xf16> to memref<1024x1024xf16, strided<[1, 1024]>> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %reinterpret_cast[%arg0, %arg1] : memref<1024x1024xf16, strided<[1, 1024]>> } } // make matrix A an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %A[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %A[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store 
%cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B_cm[%j, %k] : memref<1024x1024xf16, strided<[1, 1024]>> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %reinterpret_cast[%arg1, %arg2] : memref<1024x1024xf16, strided<[1, 1024]>> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B_cm, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16, strided<[1, 1024]>>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () // memref.dealloc %A : memref<1024x1024xf16> // memref.dealloc %B : memref<1024x1024xf16> // memref.dealloc %C : memref<1024x1024xf32> // memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %reinterpret_cast, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16, strided<[1, 1024]>>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_preop_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_preop_b.mlir index 3b9585d15..a44ee7c9d 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_preop_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32_transpose_preop_b.mlir @@ -1,154 +1,139 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: 
--shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. - module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = 
arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> - %preop = arith.addf %b_trans, %b_trans : vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %preop, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = vector.transpose %8, [1, 0] : vector<32x32xf16> to vector<32x32xf16> + %10 = arith.addf %9, %9 : vector<32x32xf16> + %11 = xetile.tile_mma %7, %10, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %12 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %13 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> + scf.yield %12, %13, %11 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = 
arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix B ; B[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %B[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix A an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %A[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %A[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%j, %k] : memref<1024x1024xf16> - %preop = arith.addf %b_val, %b_val : f16 - %t = arith.mulf %a_val, %preop : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<1024x1024xf16> + %5 = arith.addf %4, %4 : f16 + %6 = arith.mulf %3, %5 : f16 + %7 = arith.extf %6 : f16 to f32 + %8 = arith.addf %7, %arg3 : f32 + scf.yield %8 : f32 } - memref.store %c_val , %C_ref[%i, %j] : 
memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir index 37a2ef989..df78c049d 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f32_f32_f32_with_truncf_a_b.mlir @@ -1,77 +1,65 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %A, %A_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %B, %B_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf32>, %B_gpu : memref<1024x1024xf32>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf32> - gpu.dealloc %B_gpu : memref<1024x1024xf32> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_0 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf32>, %memref_0 : memref<1024x1024xf32>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf32> + gpu.dealloc %memref_0 : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - 
%b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf32>, !xetile.tile<32x32xf32>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32> - %a_trunc = arith.truncf %a_value: vector<16x32xf32> to vector<16x32xf16> - %b_trunc = arith.truncf %b_value: vector<32x32xf32> to vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_trunc, %b_trunc, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<16x32xf32> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x32xf32> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf32>, !xetile.tile<32x32xf32>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf32>, !xetile.tile<32x32xf32>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf32> -> vector<32x32xf32> + %9 = arith.truncf %7 : vector<16x32xf32> to vector<16x32xf16> + %10 = arith.truncf %8 : vector<32x32xf32> to vector<32x32xf16> + %11 = xetile.tile_mma %9, %10, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %12 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf32> + %13 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf32> + scf.yield %12, %13, %11 : !xetile.tile<16x32xf32>, !xetile.tile<32x32xf32>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } @@ -79,72 +67,70 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f32 - %cf_1 = arith.constant 1.0 : f32 - %A = memref.alloc() : memref<1024x1024xf32> - %B = memref.alloc() : memref<1024x1024xf32> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i32 - %val = arith.uitofp %t : i32 to f32 - memref.store %val, %A[%i, %j] : memref<1024x1024xf32> + %cst = arith.constant 0.000000e+00 : f32 + 
%cst_0 = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() : memref<1024x1024xf32> + %alloc_1 = memref.alloc() : memref<1024x1024xf32> + %alloc_2 = memref.alloc() : memref<1024x1024xf32> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i32 + %2 = arith.uitofp %1 : i32 to f32 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf32> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_0, %alloc_1[%arg0, %arg1] : memref<1024x1024xf32> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf32> + memref.store %cst, %alloc_1[%arg0, %arg1] : memref<1024x1024xf32> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf32> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf32> - %t = arith.mulf %a_val, %b_val : f32 - %c_sum = arith.addf %t, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf32> + %4 = memref.load %alloc_1[%arg2, %arg1] : memref<1024x1024xf32> + %5 = arith.mulf %3, %4 : f32 + %6 = arith.addf %5, %arg3 : f32 + scf.yield %6 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf32>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf32> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> 
() // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf32> - memref.dealloc %B : memref<1024x1024xf32> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_1, %alloc_2) : (memref<1024x1024xf32>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_4 = memref.cast %alloc_3 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_4) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf32> + memref.dealloc %alloc_1 : memref<1024x1024xf32> + memref.dealloc %alloc_2 : memref<1024x1024xf32> + memref.dealloc %alloc_3 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir index ad8e962b0..7d90341c6 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir @@ -1,138 +1,130 @@ -// TODO: Add imex-runner commands -// RUN: +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +// TODO: Add imex-runner commands // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
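The i8 test below keeps the same structure but accumulates into i32: the kernel walks the k-dimension in steps of 64, feeding 16x64 and 64x32 i8 tiles into a 16x32 i32 accumulator. A condensed, reformatted sketch of that k-loop as it appears in the hunk below (SSA names are illustrative; %a0, %b0 and %acc0 stand for the initial A/B tiles and the loaded C tile):

%out:3 = scf.for %k = %c0 to %c1024 step %c64
    iter_args(%a_tile = %a0, %b_tile = %b0, %acc = %acc0)
    -> (!xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32>) {
  %a = xetile.load_tile %a_tile : !xetile.tile<16x64xi8> -> vector<16x64xi8>
  %b = xetile.load_tile %b_tile : !xetile.tile<64x32xi8> -> vector<64x32xi8>
  // dpas-style multiply-accumulate: i8 x i8 -> i32
  %c_new = xetile.tile_mma %a, %b, %acc : vector<16x64xi8>, vector<64x32xi8>, vector<16x32xi32> -> vector<16x32xi32>
  %a_next = xetile.update_tile_offset %a_tile, [%c0, %c64] : !xetile.tile<16x64xi8>
  %b_next = xetile.update_tile_offset %b_tile, [%c64, %c0] : !xetile.tile<64x32xi8>
  scf.yield %a_next, %b_next, %c_new : !xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32>
}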
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xi8>, %arg1: memref<1024x1024xi8>, %arg2: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> - memref.copy %A, %A_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> - memref.copy %B, %B_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> - memref.copy %C, %C_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xi8>, %B_gpu : memref<1024x1024xi8>, %C_gpu : memref<1024x1024xi32>) - gpu.dealloc %A_gpu : memref<1024x1024xi8> - gpu.dealloc %B_gpu : memref<1024x1024xi8> - return %C_gpu : memref<1024x1024xi32> + %memref = gpu.alloc () : memref<1024x1024xi8> + gpu.memcpy %memref, %arg0 : memref<1024x1024xi8>, memref<1024x1024xi8> + %memref_0 = gpu.alloc () : memref<1024x1024xi8> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xi8>, memref<1024x1024xi8> + %memref_1 = gpu.alloc () : memref<1024x1024xi32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xi32>, memref<1024x1024xi32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xi8>, %memref_0 : memref<1024x1024xi8>, %memref_1 : memref<1024x1024xi32>) + gpu.dealloc %memref : memref<1024x1024xi8> + gpu.dealloc %memref_0 : memref<1024x1024xi8> + %alloc = memref.alloc() : memref<1024x1024xi32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xi32>, memref<1024x1024xi32> + gpu.dealloc %memref_1 : memref<1024x1024xi32> + return %alloc : memref<1024x1024xi32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16: index - %n = arith.muli %block_id_y, %c32: index + gpu.module @test_kernel { // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xi32> -> vector<16x32xi32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xi8> -> !xetile.tile<16x64xi8> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xi8> -> !xetile.tile<64x32xi8> // compute the value of C tile by iterating over tiles in k-dimension 
and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c64 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x64xi8> -> vector<16x64xi8> - %b_value = xetile.load_tile %b_tile : !xetile.tile<64x32xi8> -> vector<64x32xi8> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x64xi8>, vector<64x32xi8>, vector<16x32xi32> -> vector<16x32xi32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c64] : !xetile.tile<16x64xi8> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c64, %c0] : !xetile.tile<64x32xi8> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32> - } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile : vector<16x32xi32>, !xetile.tile<16x32xi32> - gpu.return + gpu.func @test_kernel(%arg0: memref<1024x1024xi8>, %arg1: memref<1024x1024xi8>, %arg2: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xi32> -> vector<16x32xi32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xi8> -> !xetile.tile<16x64xi8> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xi8> -> !xetile.tile<64x32xi8> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c64 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x64xi8> -> vector<16x64xi8> + %8 = xetile.load_tile %arg5 : !xetile.tile<64x32xi8> -> vector<64x32xi8> + %9 = xetile.tile_mma %7, %8, %arg6 : vector<16x64xi8>, vector<64x32xi8>, vector<16x32xi32> -> vector<16x32xi32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c64] : !xetile.tile<16x64xi8> + %11 = xetile.update_tile_offset %arg5, [%c64, %c0] : !xetile.tile<64x32xi8> + scf.yield %10, %11, %9 : !xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32> + } + xetile.store_tile %6#2, %2 : vector<16x32xi32>, !xetile.tile<16x32xi32> + gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %c0_i32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %ci_0 = arith.constant 0 : i8 - %ci_1 = arith.constant 1 : i8 - %A = memref.alloc() : memref<1024x1024xi8> - %B = memref.alloc() : memref<1024x1024xi8> - %C = memref.alloc() : memref<1024x1024xi32> - %C_ref = memref.alloc() : memref<1024x1024xi32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %val = index.castu %j : index to i8 - memref.store %val, %A[%i, %j] : memref<1024x1024xi8> + %c0_i8 = arith.constant 0 : i8 + %c1_i8 = arith.constant 1 : i8 + 
%alloc = memref.alloc() : memref<1024x1024xi8> + %alloc_0 = memref.alloc() : memref<1024x1024xi8> + %alloc_1 = memref.alloc() : memref<1024x1024xi32> + %alloc_2 = memref.alloc() : memref<1024x1024xi32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i8 + memref.store %1, %alloc[%arg0, %arg1] : memref<1024x1024xi8> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %ci_1, %B[%i, %j] : memref<1024x1024xi8> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %c1_i8, %alloc_0[%arg0, %arg1] : memref<1024x1024xi8> } else { - memref.store %ci_0, %B[%i, %j] : memref<1024x1024xi8> + memref.store %c0_i8, %alloc_0[%arg0, %arg1] : memref<1024x1024xi8> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_i32 = arith.constant 0: i32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_i32, %C[%i, %j] : memref<1024x1024xi32> - memref.store %c0_i32, %C_ref[%i, %j] : memref<1024x1024xi32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %c0_i32, %alloc_1[%arg0, %arg1] : memref<1024x1024xi32> + memref.store %c0_i32, %alloc_2[%arg0, %arg1] : memref<1024x1024xi32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xi32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> i32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xi8> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xi8> - %a_val_i32 = arith.extui %a_val : i8 to i32 - %b_val_i32 = arith.extui %b_val : i8 to i32 - %t = arith.muli %a_val_i32, %b_val_i32 : i32 - %c_sum = arith.addi %t, %c_partial : i32 - scf.yield %c_sum : i32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_2[%arg0, %arg1] : memref<1024x1024xi32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (i32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xi8> + %4 = memref.load %alloc_0[%arg2, %arg1] : memref<1024x1024xi8> + %5 = arith.extui %3 : i8 to i32 + %6 = arith.extui %4 : i8 to i32 + %7 = arith.muli %5, %6 : i32 + %8 = arith.addi %7, %arg3 : i32 + scf.yield %8 : i32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xi32> + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xi32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xi8>, memref<1024x1024xi8>, memref<1024x1024xi32>) -> memref<1024x1024xi32> - %cast_C = memref.cast %2 : memref<1024x1024xi32> to memref<*xi32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xi32> to memref<*xi32> - - call @printAllcloseI32(%cast_C, %cast_C_ref) : (memref<*xi32>, memref<*xi32>) -> () - memref.dealloc %A : memref<1024x1024xi8> - memref.dealloc %B : memref<1024x1024xi8> - memref.dealloc %C : memref<1024x1024xi32> - memref.dealloc %C_ref : memref<1024x1024xi32> + %0 = call @test(%alloc, %alloc_0, %alloc_1) : (memref<1024x1024xi8>, memref<1024x1024xi8>, memref<1024x1024xi32>) -> 
memref<1024x1024xi32> + %cast = memref.cast %0 : memref<1024x1024xi32> to memref<*xi32> + %cast_3 = memref.cast %alloc_2 : memref<1024x1024xi32> to memref<*xi32> + call @printAllcloseI32(%cast, %cast_3) : (memref<*xi32>, memref<*xi32>) -> () + memref.dealloc %alloc : memref<1024x1024xi8> + memref.dealloc %alloc_0 : memref<1024x1024xi8> + memref.dealloc %alloc_1 : memref<1024x1024xi32> + memref.dealloc %alloc_2 : memref<1024x1024xi32> return } func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir index bef7b27dd..0c6d04072 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_preop_postop_bf16_bf16_f32.mlir @@ -1,173 +1,157 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
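The bf16 variant below adds an element-wise pre-op on both operands before each tile_mma and a bias add as a post-op after the k-loop. The essential kernel fragment, condensed from the hunk below (SSA names are illustrative; %a and %b are the loaded A/B tiles, %acc the running accumulator, %acc_final the loop result, %bias the preloaded bias tile, %c_tile the C tile handle):

// pre-op: double both operands before the MMA
%a2 = arith.addf %a, %a : vector<16x32xbf16>
%b2 = arith.addf %b, %b : vector<32x32xbf16>
%acc_new = xetile.tile_mma %a2, %b2, %acc : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32>
// post-op (after the k-loop): add the bias tile, then store
%biased = arith.addf %acc_final, %bias : vector<16x32xf32>
xetile.store_tile %biased, %c_tile : vector<16x32xf32>, !xetile.tile<16x32xf32>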
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>, %Bias: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %A, %A_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %B, %B_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %Bias_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %Bias, %Bias_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xbf16>, %B_gpu : memref<1024x1024xbf16>, %C_gpu : memref<1024x1024xf32>, %Bias_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xbf16> - gpu.dealloc %B_gpu : memref<1024x1024xbf16> - gpu.dealloc %Bias_gpu : memref<1024x1024xf32> - memref.copy %C_gpu, %C : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.dealloc %C_gpu : memref<1024x1024xf32> - return %C : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_0 = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_2 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_2, %arg3 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xbf16>, %memref_0 : memref<1024x1024xbf16>, %memref_1 : memref<1024x1024xf32>, %memref_2 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xbf16> + gpu.dealloc %memref_0 : memref<1024x1024xbf16> + gpu.dealloc %memref_2 : memref<1024x1024xf32> + gpu.memcpy %arg2, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %arg2 : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>, %Bias: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = 
arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // intialize Bias tile and load it - %bias_init_tile = xetile.init_tile %Bias[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %bias_init_value = xetile.load_tile %bias_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xbf16> -> !xetile.tile<16x32xbf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> - %a_value_preop = arith.addf %a_value, %a_value : vector<16x32xbf16> - %b_value_preop = arith.addf %b_value, %b_value : vector<32x32xbf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value_preop, %b_value_preop, %c_value - : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<16x32xbf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x32xbf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg3[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %5 = xetile.load_tile %4 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %6 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xbf16> -> !xetile.tile<16x32xbf16> + %7 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> + %8:3 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %6, %arg6 = %7, %arg7 = %3) -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { + %10 = xetile.load_tile %arg5 : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> + %11 = xetile.load_tile %arg6 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %12 = arith.addf %10, %10 : vector<16x32xbf16> + %13 = arith.addf %11, %11 : vector<32x32xbf16> + %14 = xetile.tile_mma %12, %13, %arg7 : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> + %15 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<16x32xbf16> + %16 = xetile.update_tile_offset %arg6, [%c32, %c0] : 
!xetile.tile<32x32xbf16> + scf.yield %15, %16, %14 : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> } // add bias to the final C tile result - %c_bias = arith.addf %out#2, %bias_init_value : vector<16x32xf32> // store the final accumulated C tile result back to memory - xetile.store_tile %c_bias, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + %9 = arith.addf %8#2, %5 : vector<16x32xf32> + xetile.store_tile %9, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %A = memref.alloc() : memref<1024x1024xbf16> - %B = memref.alloc() : memref<1024x1024xbf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> - %Bias = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to bf16 - memref.store %val, %A[%i, %j] : memref<1024x1024xbf16> + %cst_1 = arith.constant 0.000000e+00 : bf16 + %cst_2 = arith.constant 1.000000e+00 : bf16 + %alloc = memref.alloc() : memref<1024x1024xbf16> + %alloc_3 = memref.alloc() : memref<1024x1024xbf16> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + %alloc_5 = memref.alloc() : memref<1024x1024xf32> + %alloc_6 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to bf16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xbf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xbf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_2, %alloc_3[%arg0, %arg1] : memref<1024x1024xbf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xbf16> + memref.store %cst_1, %alloc_3[%arg0, %arg1] : memref<1024x1024xbf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst_0, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst_0, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } // intialize matrix Bias ; Bias[i, j] = 1 - %c1_f32 = arith.constant 1.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c1_f32, %Bias[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_6[%arg0, %arg1] : memref<1024x1024xf32> } } // 
compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xbf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xbf16> - %a_val_preop = arith.addf %a_val, %a_val : bf16 - %b_val_preop = arith.addf %b_val, %b_val : bf16 - %t = arith.mulf %a_val_preop, %b_val_preop : bf16 - %t_cast = arith.extf %t : bf16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %5 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xbf16> + %6 = memref.load %alloc_3[%arg2, %arg1] : memref<1024x1024xbf16> + %7 = arith.addf %5, %5 : bf16 + %8 = arith.addf %6, %6 : bf16 + %9 = arith.mulf %7, %8 : bf16 + %10 = arith.extf %9 : bf16 to f32 + %11 = arith.addf %10, %arg3 : f32 + scf.yield %11 : f32 } - %bias_val = memref.load %Bias[%i, %j] : memref<1024x1024xf32> - %c_val_bias = arith.addf %c_val, %bias_val : f32 - memref.store %c_val_bias, %C_ref[%i, %j] : memref<1024x1024xf32> + %3 = memref.load %alloc_6[%arg0, %arg1] : memref<1024x1024xf32> + %4 = arith.addf %2, %3 : f32 + memref.store %4, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C, %Bias) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xbf16> to memref<*xbf16> // call @printMemrefbf16(%cast) : (memref<*xbf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xbf16> - memref.dealloc %B : memref<1024x1024xbf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_3, %alloc_4, %alloc_6) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_7 = memref.cast %alloc_5 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_7) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xbf16> + memref.dealloc %alloc_3 : memref<1024x1024xbf16> + memref.dealloc %alloc_4 : memref<1024x1024xf32> + memref.dealloc %alloc_5 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir index 177020419..92249d543 100644 
--- a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_transposed_b_preop_postop_bf16_bf16_f32.mlir @@ -1,174 +1,158 @@ -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 IMEX_ENABLE_LARGE_REG_FILE=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. - module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>, %Bias: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %A, %A_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> - memref.copy %B, %B_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %Bias_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %Bias, %Bias_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xbf16>, %B_gpu : memref<1024x1024xbf16>, %C_gpu : memref<1024x1024xf32>, %Bias_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xbf16> - gpu.dealloc %B_gpu : memref<1024x1024xbf16> - gpu.dealloc %Bias_gpu : memref<1024x1024xf32> - memref.copy %C_gpu, %C : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.dealloc %C_gpu : memref<1024x1024xf32> - return %C : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_0 = gpu.alloc () : memref<1024x1024xbf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xbf16>, memref<1024x1024xbf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_2 = gpu.alloc () : 
memref<1024x1024xf32> + gpu.memcpy %memref_2, %arg3 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xbf16>, %memref_0 : memref<1024x1024xbf16>, %memref_1 : memref<1024x1024xf32>, %memref_2 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xbf16> + gpu.dealloc %memref_0 : memref<1024x1024xbf16> + gpu.dealloc %memref_2 : memref<1024x1024xf32> + gpu.memcpy %arg2, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %arg2 : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>, %Bias: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // intialize Bias tile and load it - %bias_init_tile = xetile.init_tile %Bias[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %bias_init_value = xetile.load_tile %bias_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xbf16> -> !xetile.tile<16x32xbf16> - %b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> - %b_value_trans = vector.transpose %b_value, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> - %a_value_preop = arith.addf %a_value, %a_value : vector<16x32xbf16> - %b_value_preop = arith.addf %b_value_trans, %b_value_trans : vector<32x32xbf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value_preop, %b_value_preop, %c_value - : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<16x32xbf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32] : !xetile.tile<32x32xbf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value 
- : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg3[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %5 = xetile.load_tile %4 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %6 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xbf16> -> !xetile.tile<16x32xbf16> + %7 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> + %8:3 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %6, %arg6 = %7, %arg7 = %3) -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { + %10 = xetile.load_tile %arg5 : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> + %11 = xetile.load_tile %arg6 : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + %12 = vector.transpose %11, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> + %13 = arith.addf %10, %10 : vector<16x32xbf16> + %14 = arith.addf %12, %12 : vector<32x32xbf16> + %15 = xetile.tile_mma %13, %14, %arg7 : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> + %16 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<16x32xbf16> + %17 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xbf16> + scf.yield %16, %17, %15 : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> } // add bias to the final C tile result - %c_bias = arith.addf %out#2, %bias_init_value : vector<16x32xf32> // store the final accumulated C tile result back to memory - xetile.store_tile %c_bias, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + %9 = arith.addf %8#2, %5 : vector<16x32xf32> + xetile.store_tile %9, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %A = memref.alloc() : memref<1024x1024xbf16> - %B = memref.alloc() : memref<1024x1024xbf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> - %Bias = memref.alloc() : memref<1024x1024xf32> // intialize matrix B ; B[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to bf16 - memref.store %val, %B[%i, %j] : memref<1024x1024xbf16> + %cst_1 = arith.constant 0.000000e+00 : bf16 + %cst_2 = arith.constant 1.000000e+00 : bf16 + %alloc = memref.alloc() : memref<1024x1024xbf16> + %alloc_3 = memref.alloc() : memref<1024x1024xbf16> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + %alloc_5 = memref.alloc() : memref<1024x1024xf32> + %alloc_6 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to bf16 + memref.store %2, %alloc_3[%arg0, %arg1] : memref<1024x1024xbf16> } } // make matrix A an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to 
i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %A[%i, %j] : memref<1024x1024xbf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_2, %alloc[%arg0, %arg1] : memref<1024x1024xbf16> } else { - memref.store %cf_0, %A[%i, %j] : memref<1024x1024xbf16> + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<1024x1024xbf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst_0, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst_0, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } // intialize matrix Bias ; Bias[i, j] = 1 - %c1_f32 = arith.constant 1.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c1_f32, %Bias[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_6[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xbf16> - %b_val = memref.load %B[%j, %k] : memref<1024x1024xbf16> - %a_val_preop = arith.addf %a_val, %a_val : bf16 - %b_val_preop = arith.addf %b_val, %b_val : bf16 - %t = arith.mulf %a_val_preop, %b_val_preop : bf16 - %t_cast = arith.extf %t : bf16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %5 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xbf16> + %6 = memref.load %alloc_3[%arg1, %arg2] : memref<1024x1024xbf16> + %7 = arith.addf %5, %5 : bf16 + %8 = arith.addf %6, %6 : bf16 + %9 = arith.mulf %7, %8 : bf16 + %10 = arith.extf %9 : bf16 to f32 + %11 = arith.addf %10, %arg3 : f32 + scf.yield %11 : f32 } - %bias_val = memref.load %Bias[%i, %j] : memref<1024x1024xf32> - %c_val_bias = arith.addf %c_val, %bias_val : f32 - memref.store %c_val_bias, %C_ref[%i, %j] : memref<1024x1024xf32> + %3 = memref.load %alloc_6[%arg0, %arg1] : memref<1024x1024xf32> + %4 = arith.addf %2, %3 : f32 + memref.store %4, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C, %Bias) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xbf16> to memref<*xbf16> // call @printMemrefbf16(%cast) : (memref<*xbf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call 
@printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xbf16> - memref.dealloc %B : memref<1024x1024xbf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_3, %alloc_4, %alloc_6) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_7 = memref.cast %alloc_5 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_7) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xbf16> + memref.dealloc %alloc_3 : memref<1024x1024xbf16> + memref.dealloc %alloc_4 : memref<1024x1024xf32> + memref.dealloc %alloc_5 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir index 013d6f73a..7f5539561 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir @@ -1,6 +1,9 @@ -// TODO: Add run commands -// RUN: +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck +// TODO: Add run commands // NOTES: // This example assumes 2x2 subgroups per one workgroup and the kernel specifies the computation // done by a single subgroup. 
This shows the result of lowering wg_gemm_1kx1kx1k_f16_f16_f32 example @@ -14,53 +17,30 @@ // // #wg_map_c = #xetile.wg_map // #xe_map_c = #xetile.xe_map - - - module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c2, %c2, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c2, %c2, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index + gpu.module @test_kernel { // %c8 = arith.constant 8 : index // %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c128 : index - %n = arith.muli %block_id_y, %c128 : index - // get linear sub group id - %sg_id = gpu.subgroup_id : index // get the x, y cordinate of this linear id assuming [2, 2] coord system - %c2 = arith.constant 2 : index - %sg_coord_x = index.floordivs %sg_id, %c2 - %sg_coord_y = index.and %sg_id, %c1 - // each subgroup in the [2, 2] 
subgroups needs to update four 32x32 C sub-tiles // that are arranged in round robin fashin according to SG coords // | (0,0) | (0,1) | (0,0) | (0,1) | @@ -68,204 +48,161 @@ module @gemm attributes {gpu.container_module} { // | (0,0) | (0,1) | (0,0) | (0,1) | // | (1,0) | (1,1) | (1,0) | (1,1) | // first calculate the offset into the first SG sub-tile - %C_sg_tile_offset_x = index.mul %c32, %sg_coord_x - %C_sg_tile_offset_y = index.mul %c32, %sg_coord_y - // C sub tiles // global offset for sub tile 1 for this SG - %global_offset_slice0_x = index.add %m, %C_sg_tile_offset_x - %global_offset_slice0_y = index.add %n, %C_sg_tile_offset_y // global offset for sub tile 2 for this SG (shift 64 in x) - %global_offset_slice1_x = index.add %global_offset_slice0_x, %c64 - %global_offset_slice1_y = index.add %global_offset_slice0_y, %c0 // global offset for sub tile 3 for this SG (shift 64 in y) - %global_offset_slice2_x = index.add %global_offset_slice0_x, %c0 - %global_offset_slice2_y = index.add %global_offset_slice0_y, %c64 // global offset for sub tile 4 for this SG (shift 64 in x and y) - %global_offset_slice3_x = index.add %global_offset_slice0_x, %c64 - %global_offset_slice3_y = index.add %global_offset_slice0_y, %c64 - // intialize C sub tiles and load them - %c_init_subtile0 = xetile.init_tile %C[%global_offset_slice0_x, %global_offset_slice0_y] : memref<1024x1024xf32> - -> !xetile.tile<32x32xf32> - %c_init_value0 = xetile.load_tile %c_init_subtile0 : !xetile.tile<32x32xf32> - -> vector<32x32xf32> - %c_init_subtile1 = xetile.init_tile %C[%global_offset_slice1_x, %global_offset_slice1_y] : memref<1024x1024xf32> - -> !xetile.tile<32x32xf32> - %c_init_value1 = xetile.load_tile %c_init_subtile1 : !xetile.tile<32x32xf32> - -> vector<32x32xf32> - %c_init_subtile2 = xetile.init_tile %C[%global_offset_slice2_x, %global_offset_slice2_y] : memref<1024x1024xf32> - -> !xetile.tile<32x32xf32> - %c_init_value2 = xetile.load_tile %c_init_subtile2 : !xetile.tile<32x32xf32> - -> vector<32x32xf32> - %c_init_subtile3 = xetile.init_tile %C[%global_offset_slice3_x, %global_offset_slice3_y] : memref<1024x1024xf32> - -> !xetile.tile<32x32xf32> - %c_init_value3 = xetile.load_tile %c_init_subtile2 : !xetile.tile<32x32xf32> - -> vector<32x32xf32> - // for A, each subgroup need to load two 32x128 subtiles. The access arrangement is as follows // | (0,0), (0,1)| // | (1,0), (1,1)| // | (0,0), (0,1)| // | (1,0), (1,1)| - // calculate the initial offset in x dim for this sg - %a_init_offset = index.mul %sg_coord_x, %c32 - // x offsets for A subtiles - %a_subtile0_x = index.add %m, %a_init_offset - %a_subtile1_x = index.add %a_subtile0_x, %c64 - // init A subtiles - %a_init_subtile0 = xetile.init_tile %A[%a_subtile0_x, %c0] : memref<1024x1024xf16> - -> !xetile.tile<32x128xf16> - %a_init_subtile1 = xetile.init_tile %A[%a_subtile1_x, %c0] : memref<1024x1024xf16> - -> !xetile.tile<32x128xf16> - // for B, each subgroup need to load two 128x32 subtiles. 
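(Editor's note, not part of the patch.) To make the sub-tile offsets above concrete: with workgroup block offsets m = n = 0 and subgroup coordinates (1, 1), the four C sub-tile origins evaluate to

    sub-tile 0: (0 + 32*1, 0 + 32*1) = (32, 32)
    sub-tile 1: (32 + 64, 32)        = (96, 32)
    sub-tile 2: (32, 32 + 64)        = (32, 96)
    sub-tile 3: (32 + 64, 32 + 64)   = (96, 96)

and the two A sub-tiles for that subgroup start at rows 32 and 96, each covering a 32x128 slice of the current k block. Taken together, the four subgroups' sixteen 32x32 C sub-tiles exactly cover the 128x128 block owned by one workgroup.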
The access arrangement is as follows // | (0,0) | (0,1) | (0,0) | (0, 1) | // | (1,0) | (1,1) | (1,0) | (1, 1) | - // calculate the initial offset along y dim for this sg - %b_init_offset = index.mul %sg_coord_y, %c32 - // y offsets for B subtiles - %b_subtile0_y = index.add %n, %b_init_offset - %b_subtile1_y = index.add %b_subtile0_y, %c64 - // init B subtiles - %b_init_subtile0 = xetile.init_tile %B[%c0, %b_subtile0_y] : memref<1024x1024xf16> - -> !xetile.tile<128x32xf16> - %b_init_subtile1 = xetile.init_tile %B[%c0, %b_subtile1_y] : memref<1024x1024xf16> - -> !xetile.tile<128x32xf16> - // compute the value of C subtiles by iterating over subtiles in k-dimension and doing dpas - %out:8 = scf.for %k = %c0 to %c1024 step %c128 - iter_args(%a_subtile0 = %a_init_subtile0, %a_subtile1 = %a_init_subtile1, - %b_subtile0 = %b_init_subtile0, %b_subtile1 = %b_init_subtile1, - %c_value0 = %c_init_value0, %c_value1 = %c_init_value2, - %c_value2 = %c_init_value2, %c_value3 = %c_init_value3) - -> (!xetile.tile<32x128xf16>, - !xetile.tile<32x128xf16>, - !xetile.tile<128x32xf16>, - !xetile.tile<128x32xf16>, - vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>) { - // load A subtiles - %a_value0 = xetile.load_tile %a_subtile0 : !xetile.tile<32x128xf16> - -> vector<32x128xf16> - %a_value1 = xetile.load_tile %a_subtile1 : !xetile.tile<32x128xf16> - -> vector<32x128xf16> - // load B subtiles - %b_value0 = xetile.load_tile %b_subtile0 : !xetile.tile<128x32xf16> - -> vector<128x32xf16> - %b_value1 = xetile.load_tile %b_subtile1 : !xetile.tile<128x32xf16> - -> vector<128x32xf16> - // perform 4 dpas ops and update the C subtiles - %c_new_value0 = xetile.tile_mma %a_value0, %b_value0, %c_value0 - : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> - %c_new_value1 = xetile.tile_mma %a_value0, %b_value1, %c_value1 - : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> - %c_new_value2 = xetile.tile_mma %a_value1, %b_value0, %c_value2 - : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> - %c_new_value3 = xetile.tile_mma %a_value1, %b_value1, %c_value3 - : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> - // update offsets for A subtiles - %a_next_subtile0 = xetile.update_tile_offset %a_subtile0, [%c0, %c128] : !xetile.tile<32x128xf16> - %a_next_subtile1 = xetile.update_tile_offset %a_subtile1, [%c0, %c128] : !xetile.tile<32x128xf16> // update offsets for B subtiles - %b_next_subtile0 = xetile.update_tile_offset %b_subtile0, [%c128, %c0] : !xetile.tile<128x32xf16> - %b_next_subtile1 = xetile.update_tile_offset %b_subtile1, [%c128, %c0] : !xetile.tile<128x32xf16> - // yield subtiles and partial C results - scf.yield %a_next_subtile0, %a_next_subtile1, %b_next_subtile0, %b_next_subtile1, - %c_new_value0, %c_new_value1, %c_new_value2, %c_new_value2 - : !xetile.tile<32x128xf16>, - !xetile.tile<32x128xf16>, - !xetile.tile<128x32xf16>, - !xetile.tile<128x32xf16>, - vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32> - } // store the C final subtiles into memory - xetile.store_tile %out#4, %c_init_subtile0 : vector<32x32xf32>, - !xetile.tile<32x32xf32> - xetile.store_tile %out#5, %c_init_subtile1 : vector<32x32xf32>, - !xetile.tile<32x32xf32> - xetile.store_tile %out#6, %c_init_subtile2 : vector<32x32xf32>, - !xetile.tile<32x32xf32> - xetile.store_tile %out#7, %c_init_subtile3 : vector<32x32xf32>, - !xetile.tile<32x32xf32> - - gpu.return + gpu.func 
@test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c128 : index + %1 = arith.muli %block_id_y, %c128 : index + %2 = gpu.subgroup_id : index + %c2 = arith.constant 2 : index + %3 = index.floordivs %2, %c2 + %4 = index.and %2, %c1 + %5 = index.mul %c32, %3 + %6 = index.mul %c32, %4 + %7 = index.add %0, %5 + %8 = index.add %1, %6 + %9 = index.add %7, %c64 + %10 = index.add %8, %c0 + %11 = index.add %7, %c0 + %12 = index.add %8, %c64 + %13 = index.add %7, %c64 + %14 = index.add %8, %c64 + %15 = xetile.init_tile %arg2[%7, %8] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + %16 = xetile.load_tile %15 : !xetile.tile<32x32xf32> -> vector<32x32xf32> + %17 = xetile.init_tile %arg2[%9, %10] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + %18 = xetile.load_tile %17 : !xetile.tile<32x32xf32> -> vector<32x32xf32> + %19 = xetile.init_tile %arg2[%11, %12] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + %20 = xetile.load_tile %19 : !xetile.tile<32x32xf32> -> vector<32x32xf32> + %21 = xetile.init_tile %arg2[%13, %14] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + %22 = xetile.load_tile %19 : !xetile.tile<32x32xf32> -> vector<32x32xf32> + %23 = index.mul %3, %c32 + %24 = index.add %0, %23 + %25 = index.add %24, %c64 + %26 = xetile.init_tile %arg0[%24, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x128xf16> + %27 = xetile.init_tile %arg0[%25, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x128xf16> + %28 = index.mul %4, %c32 + %29 = index.add %1, %28 + %30 = index.add %29, %c64 + %31 = xetile.init_tile %arg1[%c0, %29] : memref<1024x1024xf16> -> !xetile.tile<128x32xf16> + %32 = xetile.init_tile %arg1[%c0, %30] : memref<1024x1024xf16> -> !xetile.tile<128x32xf16> + %33:8 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %26, %arg5 = %27, %arg6 = %31, %arg7 = %32, %arg8 = %16, %arg9 = %20, %arg10 = %20, %arg11 = %22) -> (!xetile.tile<32x128xf16>, !xetile.tile<32x128xf16>, !xetile.tile<128x32xf16>, !xetile.tile<128x32xf16>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>) { + %34 = xetile.load_tile %arg4 : !xetile.tile<32x128xf16> -> vector<32x128xf16> + %35 = xetile.load_tile %arg5 : !xetile.tile<32x128xf16> -> vector<32x128xf16> + %36 = xetile.load_tile %arg6 : !xetile.tile<128x32xf16> -> vector<128x32xf16> + %37 = xetile.load_tile %arg7 : !xetile.tile<128x32xf16> -> vector<128x32xf16> + %38 = xetile.tile_mma %34, %36, %arg8 : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> + %39 = xetile.tile_mma %34, %37, %arg9 : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> + %40 = xetile.tile_mma %35, %36, %arg10 : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> + %41 = xetile.tile_mma %35, %37, %arg11 : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32> + %42 = xetile.update_tile_offset %arg4, [%c0, %c128] : !xetile.tile<32x128xf16> + %43 = xetile.update_tile_offset %arg5, [%c0, %c128] : !xetile.tile<32x128xf16> + %44 = xetile.update_tile_offset %arg6, [%c128, %c0] : !xetile.tile<128x32xf16> + 
%45 = xetile.update_tile_offset %arg7, [%c128, %c0] : !xetile.tile<128x32xf16> + scf.yield %42, %43, %44, %45, %38, %39, %40, %40 : !xetile.tile<32x128xf16>, !xetile.tile<32x128xf16>, !xetile.tile<128x32xf16>, !xetile.tile<128x32xf16>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32> + } + xetile.store_tile %33#4, %15 : vector<32x32xf32>, !xetile.tile<32x32xf32> + xetile.store_tile %33#5, %17 : vector<32x32xf32>, !xetile.tile<32x32xf32> + xetile.store_tile %33#6, %19 : vector<32x32xf32>, !xetile.tile<32x32xf32> + xetile.store_tile %33#7, %21 : vector<32x32xf32>, !xetile.tile<32x32xf32> + gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = 
arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> - - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_broadcast_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_broadcast_b.mlir index aa59e048c..f2087b619 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_broadcast_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_broadcast_b.mlir @@ -1,30 +1,29 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck + #map = affine_map<() -> (0)> #map1 = affine_map<() -> (384)> module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<128x384xf16>, %B: memref<1x384xf16>) -> memref<128x256xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<128x384xf16>, %arg1: memref<1x384xf16>) -> memref<128x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %A_gpu = gpu.alloc host_shared 
() : memref<128x384xf16> - memref.copy %A, %A_gpu : memref<128x384xf16> to memref<128x384xf16> - %B_gpu = gpu.alloc host_shared () : memref<1x384xf16> - memref.copy %B, %B_gpu : memref<1x384xf16> to memref<1x384xf16> - %D_gpu = gpu.alloc host_shared () : memref<128x256xf32> - gpu.launch_func @m128_n256_k384::@m128_n256_k384 blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%A_gpu : memref<128x384xf16>, %B_gpu : memref<1x384xf16>, %D_gpu : memref<128x256xf32>) - gpu.dealloc %A_gpu : memref<128x384xf16> - gpu.dealloc %B_gpu : memref<1x384xf16> - return %D_gpu : memref<128x256xf32> + %memref = gpu.alloc () : memref<128x384xf16> + gpu.memcpy %memref, %arg0 : memref<128x384xf16>, memref<128x384xf16> + %memref_0 = gpu.alloc () : memref<1x384xf16> + gpu.memcpy %memref_0, %arg1 : memref<1x384xf16>, memref<1x384xf16> + %memref_1 = gpu.alloc () : memref<128x256xf32> + gpu.launch_func @m128_n256_k384::@m128_n256_k384 blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%memref : memref<128x384xf16>, %memref_0 : memref<1x384xf16>, %memref_1 : memref<128x256xf32>) + gpu.dealloc %memref : memref<128x384xf16> + gpu.dealloc %memref_0 : memref<1x384xf16> + %alloc = memref.alloc() : memref<128x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<128x256xf32>, memref<128x256xf32> + gpu.dealloc %memref_1 : memref<128x256xf32> + return %alloc : memref<128x256xf32> } - - gpu.module @m128_n256_k384 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.module @m128_n256_k384 { gpu.func @m128_n256_k384(%arg0: memref<128x384xf16>, %arg1: memref<1x384xf16>, %arg2: memref<128x256xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array, gpu.known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { cf.br ^bb1 ^bb1: // pred: ^bb0 @@ -59,13 +58,13 @@ module @gemm attributes {gpu.container_module} { %15 = arith.addi %12, %2 : index %16 = xetile.init_tile %arg0[%15, %c0] : memref<128x384xf16> -> !xetile.tile<4x32xf16> xetile.prefetch_tile %16 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> - %17 = xetile.update_tile_offset %16, [%c0, %c32] : !xetile.tile<4x32xf16> + %17 = xetile.update_tile_offset %16, [%c0, %c32] : !xetile.tile<4x32xf16> %18 = xetile.init_tile %arg1[%1, %c0] : memref<1x384xf16> -> !xetile.tile<1x32xf16> xetile.prefetch_tile %18 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<1x32xf16> - %19 = xetile.update_tile_offset %18, [%c0, %c32] : !xetile.tile<1x32xf16> + %19 = xetile.update_tile_offset %18, [%c0, %c32] : !xetile.tile<1x32xf16> %20:6 = scf.for %arg3 = %c0 to %c384 step %c32 iter_args(%arg4 = %cst, %arg5 = %13, %arg6 = %14, %arg7 = %17, %arg8 = %19, %arg9 = %c0) -> (vector<32x32xf32>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, !xetile.tile<4x32xf16>, !xetile.tile<1x32xf16>, index) { - %22 = xetile.load_tile %arg5 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %23 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<1x32xf16> -> vector<1x32xf16> + %22 = xetile.load_tile %arg5 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %23 = xetile.load_tile %arg6 {padding = 0.000000e+00 : f32} : !xetile.tile<1x32xf16> -> vector<1x32xf16> %24 = arith.cmpi eq, %arg9, %c51 : index %25 = arith.select %24, %c0, %arg9 : index scf.if %24 { @@ -76,13 +75,13 @@ module @gemm attributes {gpu.container_module} { xetile.prefetch_tile %arg7 {l1_hint = 
#xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> xetile.prefetch_tile %arg8 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<1x32xf16> xegpu.compile_hint - %27 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<4x32xf16> - %28 = xetile.update_tile_offset %arg8, [%c0, %c32] : !xetile.tile<1x32xf16> + %27 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<4x32xf16> + %28 = xetile.update_tile_offset %arg8, [%c0, %c32] : !xetile.tile<1x32xf16> %29 = vector.transpose %23, [1, 0] : vector<1x32xf16> to vector<32x1xf16> xegpu.compile_hint %30 = xetile.broadcast %29 [1] : vector<32x1xf16> -> vector<32x32xf16> - %31 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> - %32 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<1x32xf16> + %31 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> + %32 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<1x32xf16> xegpu.compile_hint // Use broadcast result directly %33 = xetile.tile_mma %22, %30, %arg4 : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32> @@ -94,43 +93,42 @@ module @gemm attributes {gpu.container_module} { gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 3.840000e+02 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index %c384 = arith.constant 384 : index - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<128x384xf16> - %B = memref.alloc() : memref<1x384xf16> - %C_ref = memref.alloc() : memref<128x256xf32> // Make matrix A an identity matrix - scf.for %i = %c0 to %c128 step %c1 { - scf.for %j = %c0 to %c384 step %c1 { - memref.store %cf_1, %A[%i, %j] : memref<128x384xf16> + %cst_0 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<128x384xf16> + %alloc_1 = memref.alloc() : memref<1x384xf16> + %alloc_2 = memref.alloc() : memref<128x256xf32> + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c384 step %c1 { + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<128x384xf16> } } // Make matrix B an identity matrix - scf.for %j = %c0 to %c384 step %c1 { - memref.store %cf_1, %B[%c0, %j] : memref<1x384xf16> + scf.for %arg0 = %c0 to %c384 step %c1 { + memref.store %cst_0, %alloc_1[%c0, %arg0] : memref<1x384xf16> } // intialize matrix C_ref - %cf_384 = arith.constant 384.0 : f32 - scf.for %i = %c0 to %c128 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - memref.store %cf_384, %C_ref[%i, %j] : memref<128x256xf32> + scf.for %arg0 = %c0 to %c128 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<128x256xf32> } } - %C = call @test(%A, %B) : (memref<128x384xf16>, memref<1x384xf16>) -> memref<128x256xf32> - %cast_C = memref.cast %C : memref<128x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<128x256xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<128x384xf16> - memref.dealloc %B : memref<1x384xf16> - memref.dealloc %C_ref : memref<128x256xf32> + %0 = call @test(%alloc, %alloc_1) : (memref<128x384xf16>, memref<1x384xf16>) -> memref<128x256xf32> + %cast = memref.cast %0 : memref<128x256xf32> to memref<*xf32> + %cast_3 = memref.cast %alloc_2 : memref<128x256xf32> to 
memref<*xf32> + call @printAllcloseF32(%cast, %cast_3) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<128x384xf16> + memref.dealloc %alloc_1 : memref<1x384xf16> + memref.dealloc %alloc_2 : memref<128x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir b/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir index 6caa056a0..d4221d46a 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_postop.mlir @@ -1,139 +1,122 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %A, %A_gpu : memref<256x256xf16> to memref<256x256xf16> - %B_gpu = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %B, %B_gpu : memref<256x256xf16> to memref<256x256xf16> - %C_gpu = gpu.alloc host_shared () : memref<256x256xf32> - memref.copy %C, %C_gpu : memref<256x256xf32> to memref<256x256xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>) - gpu.dealloc %A_gpu : memref<256x256xf16> - gpu.dealloc %B_gpu : memref<256x256xf16> - return %C_gpu : memref<256x256xf32> + %memref = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref, %arg0 : memref<256x256xf16>, memref<256x256xf16> + %memref_0 = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x256xf16>, memref<256x256xf16> + %memref_1 = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_1, %arg2 : memref<256x256xf32>, memref<256x256xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<256x256xf16>, %memref_0 : memref<256x256xf16>, %memref_1 : memref<256x256xf32>) + gpu.dealloc %memref : memref<256x256xf16> + gpu.dealloc %memref_0 : memref<256x256xf16> + %alloc = memref.alloc() : memref<256x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<256x256xf32>, 
memref<256x256xf32> + gpu.dealloc %memref_1 : memref<256x256xf32> + return %alloc : memref<256x256xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c256 = arith.constant 256 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<256x256xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<256x256xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<256x256xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c256 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a, %b, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<256x256xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<256x256xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<256x256xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %8 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %9 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %10 = xetile.tile_mma %8, %9, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %11 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %12 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + 
scf.yield %11, %12, %10 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // post op - %exp = math.exp %out#2 : vector<16x32xf32> // store the final accumulated C tile result back to memory - xetile.store_tile %exp, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + %7 = math.exp %6#2 : vector<16x32xf32> + xetile.store_tile %7, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c256 = arith.constant 256 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<256x256xf16> - %B = memref.alloc() : memref<256x256xf16> - %C = memref.alloc() : memref<256x256xf32> - %C_ref = memref.alloc() : memref<256x256xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<256x256xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<256x256xf16> + %alloc_2 = memref.alloc() : memref<256x256xf16> + %alloc_3 = memref.alloc() : memref<256x256xf32> + %alloc_4 = memref.alloc() : memref<256x256xf32> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<256x256xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<256x256xf16> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<256x256xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<256x256xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<256x256xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } // compute C for reference - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<256x256xf32> - %c_val = scf.for %k = %c0 to %c256 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a = memref.load %A[%i, %k] : memref<256x256xf16> - %b = memref.load %B[%k, %j] : memref<256x256xf16> - %t = arith.mulf %a, %b : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<256x256xf32> + %2 = scf.for %arg2 = %c0 to %c256 
step %c1 iter_args(%arg3 = %1) -> (f32) { + %4 = memref.load %alloc[%arg0, %arg2] : memref<256x256xf16> + %5 = memref.load %alloc_2[%arg2, %arg1] : memref<256x256xf16> + %6 = arith.mulf %4, %5 : f16 + %7 = arith.extf %6 : f16 to f32 + %8 = arith.addf %7, %arg3 : f32 + scf.yield %8 : f32 } - %exp = math.exp %c_val : f32 - memref.store %exp , %C_ref[%i, %j] : memref<256x256xf32> + %3 = math.exp %2 : f32 + memref.store %3, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } - %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> // %cast = memref.cast %B : memref<256x256xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> // Debugging prints (Do not remove) // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () @@ -143,13 +126,15 @@ module @gemm attributes {gpu.container_module} { // %C_ref_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32> // %C_ref_row_0_cast = memref.cast %C_ref_row_0 : memref<1x256xf32> to memref<*xf32> // call @printMemrefF32(%C_ref_row_0_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<256x256xf16> - memref.dealloc %B : memref<256x256xf16> - memref.dealloc %C : memref<256x256xf32> - memref.dealloc %C_ref : memref<256x256xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %cast = memref.cast %0 : memref<256x256xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<256x256xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<256x256xf16> + memref.dealloc %alloc_2 : memref<256x256xf16> + memref.dealloc %alloc_3 : memref<256x256xf32> + memref.dealloc %alloc_4 : memref<256x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir index 49c9b91fb..c9e22e951 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_a.mlir @@ -1,139 +1,117 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: 
memref<1024x1024xf32>, %bcast: memref<1x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %bcast_gpu = gpu.alloc host_shared () : memref<1x1024xf16> - memref.copy %bcast, %bcast_gpu : memref<1x1024xf16> to memref<1x1024xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>, %bcast_gpu : memref<1x1024xf16>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_2 = gpu.alloc () : memref<1x1024xf16> + gpu.memcpy %memref_2, %arg3 : memref<1x1024xf16>, memref<1x1024xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>, %memref_2 : memref<1x1024xf16>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %bcast: memref<1x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - 
%c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - %bcast_init_tile = xetile.init_tile %bcast[%c0, %c0] : memref<1x1024xf16> -> !xetile.tile<1x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:4 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %bcast_tile = %bcast_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %bcast_val = xetile.load_tile %bcast_tile : !xetile.tile<1x32xf16> -> vector<1x32xf16> // broadcast and add to a - %bcast_val_ = xetile.broadcast %bcast_val [0] : vector<1x32xf16> -> vector<16x32xf16> - %a_value = arith.addf %a, %bcast_val_ : vector<16x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A, B and bcast - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> - %bcast_next_tile = xetile.update_tile_offset %bcast_tile, [%c0, %c32] - : !xetile.tile<1x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %bcast_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6 = xetile.init_tile %arg3[%c0, %c0] : memref<1x1024xf16> -> !xetile.tile<1x32xf16> + %7:4 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %4, %arg6 = %5, %arg7 = %6, %arg8 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32>) { + %8 = xetile.load_tile %arg5 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %9 = xetile.load_tile %arg6 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %10 = xetile.load_tile %arg7 : !xetile.tile<1x32xf16> -> vector<1x32xf16> + %11 = xetile.broadcast %10 [0] : vector<1x32xf16> -> vector<16x32xf16> + %12 = arith.addf %8, %11 : vector<16x32xf16> + %13 = xetile.tile_mma %12, %9, %arg8 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %14 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<16x32xf16> + %15 = xetile.update_tile_offset %arg6, [%c32, %c0] : !xetile.tile<32x32xf16> + %16 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<1x32xf16> + scf.yield %14, %15, %16, %13 : !xetile.tile<16x32xf16>, 
!xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#3, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %7#3, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %true = arith.constant true + %cst_0 = arith.constant 3.000000e+00 : f32 + %cst_1 = arith.constant -3.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %cf_11 = arith.constant 1.1 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %bcast = memref.alloc() : memref<1x1024xf16> - %C_ref = memref.alloc() : memref<1024x1024xf32> - // random init - %cf_lower = arith.constant -3.0 : f32 - %cf_upper = arith.constant 3.0 : f32 - %c_gen_int = arith.constant 1 : i1 - - %A_random = memref.cast %A : memref<1024x1024xf16> to memref<*xf16> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - %B_random = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - %bcast_random = memref.cast %bcast : memref<1x1024xf16> to memref<*xf16> - call @fillResource1DRandomF16(%bcast_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1x1024xf16> + %alloc_5 = memref.alloc() : memref<1024x1024xf32> + %cast = memref.cast %alloc : memref<1024x1024xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_0, %true) : (memref<*xf16>, f32, f32, i1) -> () + %cast_6 = memref.cast %alloc_2 : memref<1024x1024xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_6, %cst_1, %cst_0, %true) : (memref<*xf16>, f32, f32, i1) -> () + %cast_7 = memref.cast %alloc_4 : memref<1x1024xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_7, %cst_1, %cst_0, %true) : (memref<*xf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b = memref.load %B[%k, %j] : memref<1024x1024xf16> - %bcast_val = memref.load %bcast[%c0, %k] : memref<1x1024xf16> - %a_val = arith.addf %a, %bcast_val : f16 - %t = arith.mulf %a_val, %b : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield 
%c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = memref.load %alloc_4[%c0, %arg2] : memref<1x1024xf16> + %6 = arith.addf %3, %5 : f16 + %7 = arith.mulf %6, %4 : f16 + %8 = arith.extf %7 : f16 to f32 + %9 = arith.addf %8, %arg3 : f32 + scf.yield %9 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C, %bcast) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>, memref<1x1024xf16>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // Debugging prints (Do not remove) // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () @@ -143,13 +121,15 @@ module @gemm attributes {gpu.container_module} { // %C_ref_row_0 = memref.subview %C_ref[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_ref_row_0_cast = memref.cast %C_ref_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_ref_row_0_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3, %alloc_4) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>, memref<1x1024xf16>) -> memref<1024x1024xf32> + %cast_8 = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_9 = memref.cast %alloc_5 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast_8, %cast_9) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_5 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir index 54da4f5ec..e6fc5541f 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_pre_broadcast_b.mlir @@ -1,142 +1,119 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: 
--entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %bcast: memref<1x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %bcast_gpu = gpu.alloc host_shared () : memref<1x1024xf16> - memref.copy %bcast, %bcast_gpu : memref<1x1024xf16> to memref<1x1024xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>, %bcast_gpu : memref<1x1024xf16>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_2 = gpu.alloc () : memref<1x1024xf16> + gpu.memcpy %memref_2, %arg3 : memref<1x1024xf16>, memref<1x1024xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>, %memref_2 : memref<1x1024xf16>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %bcast: memref<1x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = 
arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - %bcast_init_tile = xetile.init_tile %bcast[%c0, %c0] : memref<1x1024xf16> -> !xetile.tile<1x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:4 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %bcast_tile = %bcast_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %bcast_val = xetile.load_tile %bcast_tile : !xetile.tile<1x32xf16> -> vector<1x32xf16> - %t1 = vector.shape_cast %bcast_val : vector<1x32xf16> to vector<32xf16> - %t2 = vector.shape_cast %t1 : vector<32xf16> to vector<32x1xf16> // broadcast and add to a - %bcast_val_ = xetile.broadcast %t2 [1] : vector<32x1xf16> -> vector<32x32xf16> - %b_value = arith.addf %b, %bcast_val_ : vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A, B and bcast - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> - %bcast_next_tile = xetile.update_tile_offset %bcast_tile, [%c0, %c32] - : !xetile.tile<1x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %bcast_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6 = xetile.init_tile %arg3[%c0, %c0] : memref<1x1024xf16> -> !xetile.tile<1x32xf16> + %7:4 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %4, %arg6 = %5, %arg7 = %6, %arg8 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32>) { + %8 = xetile.load_tile %arg5 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %9 = xetile.load_tile %arg6 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %10 = xetile.load_tile %arg7 : !xetile.tile<1x32xf16> -> vector<1x32xf16> + %11 = vector.shape_cast %10 : 
vector<1x32xf16> to vector<32xf16> + %12 = vector.shape_cast %11 : vector<32xf16> to vector<32x1xf16> + %13 = xetile.broadcast %12 [1] : vector<32x1xf16> -> vector<32x32xf16> + %14 = arith.addf %9, %13 : vector<32x32xf16> + %15 = xetile.tile_mma %8, %14, %arg8 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %16 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<16x32xf16> + %17 = xetile.update_tile_offset %arg6, [%c32, %c0] : !xetile.tile<32x32xf16> + %18 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<1x32xf16> + scf.yield %16, %17, %18, %15 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<1x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#3, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %7#3, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %true = arith.constant true + %cst_0 = arith.constant 3.000000e+00 : f32 + %cst_1 = arith.constant -3.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %cf_11 = arith.constant 1.1 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %bcast = memref.alloc() : memref<1x1024xf16> - %C_ref = memref.alloc() : memref<1024x1024xf32> - // random init - %cf_lower = arith.constant -3.0 : f32 - %cf_upper = arith.constant 3.0 : f32 - %c_gen_int = arith.constant 1 : i1 - - %A_random = memref.cast %A : memref<1024x1024xf16> to memref<*xf16> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - %B_random = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - %bcast_random = memref.cast %bcast : memref<1x1024xf16> to memref<*xf16> - call @fillResource1DRandomF16(%bcast_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1x1024xf16> + %alloc_5 = memref.alloc() : memref<1024x1024xf32> + %cast = memref.cast %alloc : memref<1024x1024xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast, %cst_1, %cst_0, %true) : (memref<*xf16>, f32, f32, i1) -> () + %cast_6 = memref.cast %alloc_2 : memref<1024x1024xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_6, %cst_1, %cst_0, %true) : (memref<*xf16>, f32, f32, i1) -> () + %cast_7 = memref.cast %alloc_4 : memref<1x1024xf16> to memref<*xf16> + call @fillResource1DRandomF16(%cast_7, %cst_1, %cst_0, %true) : (memref<*xf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_5[%arg0, %arg1] 
: memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b = memref.load %B[%k, %j] : memref<1024x1024xf16> - %bcast_val = memref.load %bcast[%c0, %k] : memref<1x1024xf16> - %b_val = arith.addf %b, %bcast_val : f16 - %t = arith.mulf %a, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = memref.load %alloc_4[%c0, %arg2] : memref<1x1024xf16> + %6 = arith.addf %4, %5 : f16 + %7 = arith.mulf %3, %6 : f16 + %8 = arith.extf %7 : f16 to f32 + %9 = arith.addf %8, %arg3 : f32 + scf.yield %9 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C, %bcast) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>, memref<1x1024xf16>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // Debugging prints (Do not remove) // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () @@ -146,18 +123,19 @@ module @gemm attributes {gpu.container_module} { // %C_ref_row_0 = memref.subview %C_ref[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_ref_row_0_cast = memref.cast %C_ref_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_ref_row_0_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3, %alloc_4) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>, memref<1x1024xf16>) -> memref<1024x1024xf32> + %cast_8 = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_9 = memref.cast %alloc_5 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast_8, %cast_9) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_5 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface} - } diff --git 
a/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir b/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir index a0a80eeab..0b59b5071 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_preop_a.mlir @@ -1,138 +1,121 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func 
@test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %a_value = arith.addf %a, %a: vector<16x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = arith.addf %7, %7 : vector<16x32xf16> + %10 = xetile.tile_mma %9, %8, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %11 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %12 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %11, %12, %10 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final 
accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b = memref.load %B[%k, %j] : memref<1024x1024xf16> - %a_val = arith.addf %a, %a: f16 - %t = arith.mulf %a_val, %b : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + 
%4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.addf %3, %3 : f16 + %6 = arith.mulf %5, %4 : f16 + %7 = arith.extf %6 : f16 to f32 + %8 = arith.addf %7, %arg3 : f32 + scf.yield %8 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // Debugging prints (Do not remove) // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () @@ -142,13 +125,15 @@ module @gemm attributes {gpu.container_module} { // %C_ref_row_0 = memref.subview %C_ref[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_ref_row_0_cast = memref.cast %C_ref_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_ref_row_0_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir index 16f49cff5..a78e52631 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_preop_a_b.mlir @@ -1,140 +1,123 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: 
memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %A, %A_gpu : memref<256x256xf16> to memref<256x256xf16> - %B_gpu = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %B, %B_gpu : memref<256x256xf16> to memref<256x256xf16> - %C_gpu = gpu.alloc host_shared () : memref<256x256xf32> - memref.copy %C, %C_gpu : memref<256x256xf32> to memref<256x256xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>) - gpu.dealloc %A_gpu : memref<256x256xf16> - gpu.dealloc %B_gpu : memref<256x256xf16> - return %C_gpu : memref<256x256xf32> + %memref = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref, %arg0 : memref<256x256xf16>, memref<256x256xf16> + %memref_0 = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x256xf16>, memref<256x256xf16> + %memref_1 = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_1, %arg2 : memref<256x256xf32>, memref<256x256xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<256x256xf16>, %memref_0 : memref<256x256xf16>, %memref_1 : memref<256x256xf32>) + gpu.dealloc %memref : memref<256x256xf16> + gpu.dealloc %memref_0 : memref<256x256xf16> + %alloc = memref.alloc() : memref<256x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_1 : memref<256x256xf32> + return %alloc : memref<256x256xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c256 = arith.constant 256 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<256x256xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<256x256xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<256x256xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c256 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> 
(!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %a_value = arith.addf %a, %a: vector<16x32xf16> - %b_value = arith.addf %b, %b: vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<256x256xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<256x256xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<256x256xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = arith.addf %7, %7 : vector<16x32xf16> + %10 = arith.addf %8, %8 : vector<32x32xf16> + %11 = xetile.tile_mma %9, %10, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %12 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %13 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %12, %13, %11 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c256 = arith.constant 256 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<256x256xf16> - %B = memref.alloc() : memref<256x256xf16> - %C = memref.alloc() : memref<256x256xf32> - %C_ref = memref.alloc() : memref<256x256xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<256x256xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<256x256xf16> + %alloc_2 = memref.alloc() : memref<256x256xf16> + %alloc_3 = memref.alloc() : memref<256x256xf32> + %alloc_4 = memref.alloc() : memref<256x256xf32> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + 
memref.store %2, %alloc[%arg0, %arg1] : memref<256x256xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<256x256xf16> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<256x256xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<256x256xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<256x256xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<256x256xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } // compute C for reference - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<256x256xf32> - %c_val = scf.for %k = %c0 to %c256 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a = memref.load %A[%i, %k] : memref<256x256xf16> - %b = memref.load %B[%k, %j] : memref<256x256xf16> - %a_val = arith.addf %a, %a: f16 - %b_val = arith.addf %b, %b: f16 - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<256x256xf32> + %2 = scf.for %arg2 = %c0 to %c256 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<256x256xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<256x256xf16> + %5 = arith.addf %3, %3 : f16 + %6 = arith.addf %4, %4 : f16 + %7 = arith.mulf %5, %6 : f16 + %8 = arith.extf %7 : f16 to f32 + %9 = arith.addf %8, %arg3 : f32 + scf.yield %9 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<256x256xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } - %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> // %cast = memref.cast %B : memref<256x256xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> // Debugging prints (Do not remove) // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () @@ -144,13 +127,15 @@ module @gemm attributes {gpu.container_module} { // %C_ref_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32> // %C_ref_row_0_cast = memref.cast %C_ref_row_0 : memref<1x256xf32> to memref<*xf32> // call @printMemrefF32(%C_ref_row_0_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : 
memref<256x256xf16> - memref.dealloc %B : memref<256x256xf16> - memref.dealloc %C : memref<256x256xf32> - memref.dealloc %C_ref : memref<256x256xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %cast = memref.cast %0 : memref<256x256xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<256x256xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<256x256xf16> + memref.dealloc %alloc_2 : memref<256x256xf16> + memref.dealloc %alloc_3 : memref<256x256xf32> + memref.dealloc %alloc_4 : memref<256x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir b/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir index 77c24182b..d54fce92a 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_preop_b.mlir @@ -1,138 +1,121 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> 
+ %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_value = arith.addf %b, %b: vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> 
!xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = arith.addf %8, %8 : vector<32x32xf16> + %10 = xetile.tile_mma %7, %9, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %11 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %12 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16> + scf.yield %11, %12, %10 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step 
%c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b = memref.load %B[%k, %j] : memref<1024x1024xf16> - %b_val = arith.addf %b, %b: f16 - %t = arith.mulf %a, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.addf %4, %4 : f16 + %6 = arith.mulf %3, %5 : f16 + %7 = arith.extf %6 : f16 to f32 + %8 = arith.addf %7, %arg3 : f32 + scf.yield %8 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // Debugging prints (Do not remove) // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () @@ -142,13 +125,15 @@ module @gemm attributes {gpu.container_module} { // %C_ref_row_0 = memref.subview %C_ref[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_ref_row_0_cast = memref.cast %C_ref_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_ref_row_0_cast) : (memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_transpose_b_order.mlir b/test/Integration/Dialect/XeTile/sg_gemm_transpose_b_order.mlir index 64966bcd9..c7df612ad 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_transpose_b_order.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_transpose_b_order.mlir @@ -1,136 +1,118 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner 
mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %A, %A_gpu : memref<256x256xf16> to memref<256x256xf16> - %B_gpu = gpu.alloc host_shared () : memref<256x256xf16> - memref.copy %B, %B_gpu : memref<256x256xf16> to memref<256x256xf16> - %C_gpu = gpu.alloc host_shared () : memref<256x256xf32> - memref.copy %C, %C_gpu : memref<256x256xf32> to memref<256x256xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>) - gpu.dealloc %A_gpu : memref<256x256xf16> - gpu.dealloc %B_gpu : memref<256x256xf16> - return %C_gpu : memref<256x256xf32> + %memref = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref, %arg0 : memref<256x256xf16>, memref<256x256xf16> + %memref_0 = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_0, %arg1 : memref<256x256xf16>, memref<256x256xf16> + %memref_1 = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_1, %arg2 : memref<256x256xf32>, memref<256x256xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<256x256xf16>, %memref_0 : memref<256x256xf16>, %memref_1 : memref<256x256xf32>) + gpu.dealloc %memref : memref<256x256xf16> + gpu.dealloc %memref_0 : memref<256x256xf16> + %alloc = memref.alloc() : memref<256x256xf32> + gpu.memcpy %alloc, %memref_1 : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_1 : memref<256x256xf32> + return %alloc : memref<256x256xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c256 = arith.constant 256 : index - %block_id_x = 
gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<256x256xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<256x256xf16> -> !xetile.tile<16x32xf16> - %B_cast = memref.reinterpret_cast %B to offset : [0], sizes : [256,256], strides : [1, 256] : memref<256x256xf16> to memref<256x256xf16, strided<[1, 256], offset:0>> - %b_init_tile = xetile.init_tile %B_cast[%c0, %n] : memref<256x256xf16, strided<[1, 256], offset:0>> -> !xetile.tile<32x32xf16, #xetile.tile_attr> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c256 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<32x32xf16> // %b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x32xf16, #xetile.tile_attr> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<256x256xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<256x256xf16> -> !xetile.tile<16x32xf16> + %reinterpret_cast = memref.reinterpret_cast %arg1 to offset: [0], sizes: [256, 256], strides: [1, 256] : memref<256x256xf16> to memref<256x256xf16, strided<[1, 256]>> + %5 = xetile.init_tile %reinterpret_cast[%c0, %1] : memref<256x256xf16, strided<[1, 256]>> -> !xetile.tile<32x32xf16, #xetile.tile_attr> + %6:3 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16, #xetile.tile_attr> -> vector<32x32xf16> + %9 = xetile.tile_mma %7, %8, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %11 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x32xf16, #xetile.tile_attr> + scf.yield %10, %11, %9 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16, #xetile.tile_attr>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile 
%out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %true = arith.constant true + %cst = arith.constant 3.000000e+00 : f32 + %cst_0 = arith.constant -3.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c256 = arith.constant 256 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_0_f32 = arith.constant 0.0 : f32 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<256x256xf16> - %B = memref.alloc() : memref<256x256xf16> - %C = memref.alloc() : memref<256x256xf32> - %C_ref = memref.alloc() : memref<256x256xf32> - // fill A, B with random values - %cf_lower = arith.constant -3.0 : f32 - %cf_upper = arith.constant 3.0 : f32 - %c_gen_int = arith.constant 1 : i1 - %A_random = memref.cast %A : memref<256x256xf16> to memref<*xf16> - %B_random = memref.cast %B : memref<256x256xf16> to memref<*xf16> - %C_zeros = memref.cast %C : memref<256x256xf32> to memref<*xf32> - %C_ref_zeros = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // fill C, C_ref with zeros - call @fillResource1DF32(%C_zeros, %cf_0_f32) : (memref<*xf32>, f32) -> () - call @fillResource1DF32(%C_ref_zeros, %cf_0_f32) : (memref<*xf32>, f32) -> () - // compute C for reference - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c256 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<256x256xf32> - %c_val = scf.for %k = %c0 to %c256 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<256x256xf16> - %b_val = memref.load %B[%j, %k] : memref<256x256xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + %cst_1 = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() : memref<256x256xf16> + %alloc_2 = memref.alloc() : memref<256x256xf16> + %alloc_3 = memref.alloc() : memref<256x256xf32> + %alloc_4 = memref.alloc() : memref<256x256xf32> + %cast = memref.cast %alloc : memref<256x256xf16> to memref<*xf16> + %cast_5 = memref.cast %alloc_2 : memref<256x256xf16> to memref<*xf16> + %cast_6 = memref.cast %alloc_3 : memref<256x256xf32> to memref<*xf32> + %cast_7 = memref.cast %alloc_4 : memref<256x256xf32> to memref<*xf32> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %true) : (memref<*xf16>, f32, f32, i1) -> () + call @fillResource1DRandomF16(%cast_5, %cst_0, %cst, %true) : (memref<*xf16>, f32, f32, i1) -> () + call @fillResource1DF32(%cast_6, %cst_1) : (memref<*xf32>, f32) -> () + call @fillResource1DF32(%cast_7, %cst_1) : (memref<*xf32>, f32) -> () + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c256 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<256x256xf32> + %2 = scf.for %arg2 = %c0 to %c256 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<256x256xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<256x256xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<256x256xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<256x256xf32> } } - - %2 
= call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> // %cast = memref.cast %B : memref<256x256xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[1, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset: 256>> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 256>>to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<256x256xf16> - memref.dealloc %B : memref<256x256xf16> - memref.dealloc %C : memref<256x256xf32> - memref.dealloc %C_ref : memref<256x256xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %cast_8 = memref.cast %0 : memref<256x256xf32> to memref<*xf32> + %cast_9 = memref.cast %alloc_4 : memref<256x256xf32> to memref<*xf32> + call @printAllcloseF32(%cast_8, %cast_9) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<256x256xf16> + memref.dealloc %alloc_2 : memref<256x256xf16> + memref.dealloc %alloc_3 : memref<256x256xf32> + memref.dealloc %alloc_4 : memref<256x256xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_1.mlir b/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_1.mlir index 8b15b26bc..44e1d4ebd 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_1.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_1.mlir @@ -1,167 +1,150 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. 
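// As a concrete reading of the note above, the kernel launch in this test uses
//   blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1)
// i.e. a 64x32 grid of workgroups, each running a single thread (one subgroup).
// The updated test also stages GPU buffers explicitly instead of relying on
// host_shared allocations; a minimal sketch of that pattern (the %host_in,
// %host_out and %dev names and the 1024x1024xf16 shape are illustrative only):
//   %dev = gpu.alloc () : memref<1024x1024xf16>
//   gpu.memcpy %dev, %host_in : memref<1024x1024xf16>, memref<1024x1024xf16>
//   gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%dev : memref<1024x1024xf16>)
//   %host_out = memref.alloc() : memref<1024x1024xf16>
//   gpu.memcpy %host_out, %dev : memref<1024x1024xf16>, memref<1024x1024xf16>
//   gpu.dealloc %dev : memref<1024x1024xf16>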
- module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %B1: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>, %arg3: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B1_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B1, %B1_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %B1_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - gpu.dealloc %B1_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_2 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_2, %arg3 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf16>, %memref_2 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + gpu.dealloc %memref_1 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_2 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %B1: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>, %arg3: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 
: index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> - %b1_init_tile = xetile.init_tile %B1[%n, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:4 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %b1_tile = %b1_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> - %b1_value = xetile.load_tile %b1_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b1_trans = vector.transpose %b1_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> - %b_final = arith.addf %b_trans, %b1_trans : vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_final, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32] - : !xetile.tile<32x32xf16> - %b1_next_tile = xetile.update_tile_offset %b1_tile, [%c0, %c32] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %b1_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg3[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6 = xetile.init_tile %arg2[%1, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %7:4 = scf.for %arg4 = %c0 to %c1024 step %c32 iter_args(%arg5 = %4, %arg6 = %5, %arg7 = %6, %arg8 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %8 = xetile.load_tile %arg5 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %9 = xetile.load_tile %arg6 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %10 = vector.transpose %9, [1, 0] : vector<32x32xf16> to vector<32x32xf16> + %11 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %12 = vector.transpose %11, [1, 0] : vector<32x32xf16> to vector<32x32xf16> + %13 = arith.addf %10, %12 : vector<32x32xf16> + %14 = xetile.tile_mma %8, %13, %arg8 
: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %15 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<16x32xf16> + %16 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<32x32xf16> + %17 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16> + scf.yield %15, %16, %17, %14 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#3, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %7#3, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %B1 = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix B ; B[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %B[%i, %j] : memref<1024x1024xf16> - memref.store %val, %B1[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf16> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + %alloc_5 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> + memref.store %2, %alloc_3[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix A an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %A[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %A[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : 
memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%j, %k] : memref<1024x1024xf16> - %b1_val = memref.load %B1[%j, %k] : memref<1024x1024xf16> - %b = arith.addf %b_val, %b1_val : f16 - %t = arith.mulf %a_val, %b : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<1024x1024xf16> + %5 = memref.load %alloc_3[%arg1, %arg2] : memref<1024x1024xf16> + %6 = arith.addf %4, %5 : f16 + %7 = arith.mulf %3, %6 : f16 + %8 = arith.extf %7 : f16 to f32 + %9 = arith.addf %8, %arg3 : f32 + scf.yield %9 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_5[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %B1, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) - -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %B1 : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3, %alloc_4) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_6 = memref.cast %alloc_5 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_6) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf16> + memref.dealloc %alloc_4 : memref<1024x1024xf32> + memref.dealloc %alloc_5 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_2.mlir b/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_2.mlir index 2aec11bdc..6bf1503ef 100644 --- a/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_2.mlir +++ b/test/Integration/Dialect/XeTile/sg_gemm_transpose_binary_preop_b_2.mlir @@ -1,155 +1,140 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable 
%imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. - module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: 
memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %cst_b1 = arith.constant dense<1.0> : vector<32x32xf16> - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c32 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_trans = vector.transpose %b_value, [1, 0] : vector<32x32xf16> to vector<32x32xf16> - %b_final = arith.addf %b_trans, %cst_b1: vector<32x32xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_final, %c_value - : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<16x32xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c32] - : !xetile.tile<32x32xf16> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + %cst = arith.constant dense<1.000000e+00> : vector<32x32xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf32> -> vector<16x32xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %5 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %9 = vector.transpose %8, [1, 0] : vector<32x32xf16> to vector<32x32xf16> + %10 = arith.addf %9, %cst : vector<32x32xf16> + %11 = xetile.tile_mma %7, %10, %arg6 : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + %12 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<16x32xf16> + %13 = xetile.update_tile_offset %arg5, [%c0, %c32] : !xetile.tile<32x32xf16> + scf.yield %12, %13, %11 : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, 
vector<16x32xf32> } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + xetile.store_tile %6#2, %2 : vector<16x32xf32>, !xetile.tile<16x32xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix B ; B[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %B[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix A an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %A[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %A[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%j, %k] : memref<1024x1024xf16> - %b_final = arith.addf %b_val, %cf_1 : f16 - %t = arith.mulf %a_val, %b_final : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = 
memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<1024x1024xf16> + %5 = arith.addf %4, %cst_1 : f16 + %6 = arith.mulf %3, %5 : f16 + %7 = arith.extf %6 : f16 to f32 + %8 = arith.addf %7, %arg3 : f32 + scf.yield %8 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/simple_order.mlir b/test/Integration/Dialect/XeTile/simple_order.mlir index d37a800f5..746cc7c26 100644 --- a/test/Integration/Dialect/XeTile/simple_order.mlir +++ b/test/Integration/Dialect/XeTile/simple_order.mlir @@ -1,123 +1,108 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<32x64xf16>, %B: memref<64x64xf16>, %C: memref<32x64xf32>) -> memref<32x64xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<32x64xf16>, %arg1: memref<64x64xf16>, %arg2: memref<32x64xf32>) -> memref<32x64xf32> attributes 
{llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %A_gpu = gpu.alloc host_shared () : memref<32x64xf16> - memref.copy %A, %A_gpu : memref<32x64xf16> to memref<32x64xf16> - %B_gpu = gpu.alloc host_shared () : memref<64x64xf16> - memref.copy %B, %B_gpu : memref<64x64xf16> to memref<64x64xf16> - %C_gpu = gpu.alloc host_shared () : memref<32x64xf32> - memref.copy %C, %C_gpu : memref<32x64xf32> to memref<32x64xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<32x64xf16>, %B_gpu : memref<64x64xf16>, %C_gpu : memref<32x64xf32>) - gpu.dealloc %A_gpu : memref<32x64xf16> - gpu.dealloc %B_gpu : memref<64x64xf16> - return %C_gpu : memref<32x64xf32> + %memref = gpu.alloc () : memref<32x64xf16> + gpu.memcpy %memref, %arg0 : memref<32x64xf16>, memref<32x64xf16> + %memref_0 = gpu.alloc () : memref<64x64xf16> + gpu.memcpy %memref_0, %arg1 : memref<64x64xf16>, memref<64x64xf16> + %memref_1 = gpu.alloc () : memref<32x64xf32> + gpu.memcpy %memref_1, %arg2 : memref<32x64xf32>, memref<32x64xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<32x64xf16>, %memref_0 : memref<64x64xf16>, %memref_1 : memref<32x64xf32>) + gpu.dealloc %memref : memref<32x64xf16> + gpu.dealloc %memref_0 : memref<64x64xf16> + %alloc = memref.alloc() : memref<32x64xf32> + gpu.memcpy %alloc, %memref_1 : memref<32x64xf32>, memref<32x64xf32> + gpu.dealloc %memref_1 : memref<32x64xf32> + return %alloc : memref<32x64xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<32x64xf16>, %B: memref<64x64xf16>, %C: memref<32x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { /// canonicalize + gpu.func @test_kernel(%arg0: memref<32x64xf16>, %arg1: memref<64x64xf16>, %arg2: memref<32x64xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y // intialize C tile and load it - %c_tile = xetile.init_tile %C[%c0, %c0] : memref<32x64xf32> -> !xetile.tile<32x64xf32> - %c_value = xetile.load_tile %c_tile : !xetile.tile<32x64xf32> -> vector<32x64xf32> - %B_cast = memref.reinterpret_cast %B to offset : [0], sizes : [64, 64], strides : [1, 64] : memref<64x64xf16> to memref<64x64xf16, strided<[1, 64], offset:0>> // k iter 0 : do a partial C tile 32x32x64 - %a_tile = xetile.init_tile %A[%c0, %c0] : memref<32x64xf16> -> !xetile.tile<32x32xf16> - %b_tile = xetile.init_tile %B_cast[%c0, %c0] : memref<64x64xf16, strided<[1, 64], offset:0>> -> !xetile.tile<32x64xf16, #xetile.tile_attr> - %a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<32x64xf16> - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<32x32xf16>, vector<32x64xf16>, vector<32x64xf32> -> vector<32x64xf32> // k iter 1 : update offsets and do a 
partial C tile 32x32x64 - %a_tile_1 = xetile.update_tile_offset %a_tile, [%c0, %c32] : !xetile.tile<32x32xf16> - %b_tile_1 = xetile.update_tile_offset %b_tile, [%c32, %c0] : !xetile.tile<32x64xf16, #xetile.tile_attr> - %a_value_1 = xetile.load_tile %a_tile_1 : !xetile.tile<32x32xf16> -> vector<32x32xf16> - %b_value_1 = xetile.load_tile %b_tile_1 : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<32x64xf16> - %c_new_value_1 = xetile.tile_mma %a_value_1, %b_value_1, %c_new_value - : vector<32x32xf16>, vector<32x64xf16>, vector<32x64xf32> -> vector<32x64xf32> // store the C tile result back to memory - xetile.store_tile %c_new_value_1, %c_tile: vector<32x64xf32>, !xetile.tile<32x64xf32> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = xetile.init_tile %arg2[%c0, %c0] : memref<32x64xf32> -> !xetile.tile<32x64xf32> + %1 = xetile.load_tile %0 : !xetile.tile<32x64xf32> -> vector<32x64xf32> + %reinterpret_cast = memref.reinterpret_cast %arg1 to offset: [0], sizes: [64, 64], strides: [1, 64] : memref<64x64xf16> to memref<64x64xf16, strided<[1, 64]>> + %2 = xetile.init_tile %arg0[%c0, %c0] : memref<32x64xf16> -> !xetile.tile<32x32xf16> + %3 = xetile.init_tile %reinterpret_cast[%c0, %c0] : memref<64x64xf16, strided<[1, 64]>> -> !xetile.tile<32x64xf16, #xetile.tile_attr> + %4 = xetile.load_tile %2 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %5 = xetile.load_tile %3 : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<32x64xf16> + %6 = xetile.tile_mma %4, %5, %1 : vector<32x32xf16>, vector<32x64xf16>, vector<32x64xf32> -> vector<32x64xf32> + %7 = xetile.update_tile_offset %2, [%c0, %c32] : !xetile.tile<32x32xf16> + %8 = xetile.update_tile_offset %3, [%c32, %c0] : !xetile.tile<32x64xf16, #xetile.tile_attr> + %9 = xetile.load_tile %7 : !xetile.tile<32x32xf16> -> vector<32x32xf16> + %10 = xetile.load_tile %8 : !xetile.tile<32x64xf16, #xetile.tile_attr> -> vector<32x64xf16> + %11 = xetile.tile_mma %9, %10, %6 : vector<32x32xf16>, vector<32x64xf16>, vector<32x64xf32> -> vector<32x64xf32> + xetile.store_tile %11, %0 : vector<32x64xf32>, !xetile.tile<32x64xf32> gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %true = arith.constant true + %cst = arith.constant 3.000000e+00 : f32 + %cst_0 = arith.constant -3.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_0_f32 = arith.constant 0.0 : f32 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<32x64xf16> - %B = memref.alloc() : memref<64x64xf16> - %C = memref.alloc() : memref<32x64xf32> - %C_ref = memref.alloc() : memref<32x64xf32> - // fill A, B with random values - %cf_lower = arith.constant -3.0 : f32 - %cf_upper = arith.constant 3.0 : f32 - %c_gen_int = arith.constant 1 : i1 - %A_random = memref.cast %A : memref<32x64xf16> to memref<*xf16> - %B_random = memref.cast %B : memref<64x64xf16> to memref<*xf16> - %C_zeros = memref.cast %C : memref<32x64xf32> to memref<*xf32> - %C_ref_zeros = memref.cast %C_ref : memref<32x64xf32> to memref<*xf32> - call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> () - // fill C, C_ref with zeros - call @fillResource1DF32(%C_zeros, %cf_0_f32) : (memref<*xf32>, f32) -> () - call @fillResource1DF32(%C_ref_zeros, %cf_0_f32) : (memref<*xf32>, f32) -> () - // 
compute C for reference - scf.for %i = %c0 to %c32 step %c1 { - scf.for %j = %c0 to %c64 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<32x64xf32> - %c_val = scf.for %k = %c0 to %c64 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<32x64xf16> - %b_val = memref.load %B[%j, %k] : memref<64x64xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + %cst_1 = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() : memref<32x64xf16> + %alloc_2 = memref.alloc() : memref<64x64xf16> + %alloc_3 = memref.alloc() : memref<32x64xf32> + %alloc_4 = memref.alloc() : memref<32x64xf32> + %cast = memref.cast %alloc : memref<32x64xf16> to memref<*xf16> + %cast_5 = memref.cast %alloc_2 : memref<64x64xf16> to memref<*xf16> + %cast_6 = memref.cast %alloc_3 : memref<32x64xf32> to memref<*xf32> + %cast_7 = memref.cast %alloc_4 : memref<32x64xf32> to memref<*xf32> + call @fillResource1DRandomF16(%cast, %cst_0, %cst, %true) : (memref<*xf16>, f32, f32, i1) -> () + call @fillResource1DRandomF16(%cast_5, %cst_0, %cst, %true) : (memref<*xf16>, f32, f32, i1) -> () + call @fillResource1DF32(%cast_6, %cst_1) : (memref<*xf32>, f32) -> () + call @fillResource1DF32(%cast_7, %cst_1) : (memref<*xf32>, f32) -> () + scf.for %arg0 = %c0 to %c32 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<32x64xf32> + %2 = scf.for %arg2 = %c0 to %c64 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<32x64xf16> + %4 = memref.load %alloc_2[%arg1, %arg2] : memref<64x64xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<32x64xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<32x64xf32> } } - - - %2 = call @test(%A, %B, %C) : (memref<32x64xf16>, memref<64x64xf16>, memref<32x64xf32>) -> memref<32x64xf32> // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () - %cast_C = memref.cast %2 : memref<32x64xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<32x64xf32> to memref<*xf32> // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<32x64xf16> - memref.dealloc %B : memref<64x64xf16> - memref.dealloc %C : memref<32x64xf32> - memref.dealloc %C_ref : memref<32x64xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<32x64xf16>, memref<64x64xf16>, memref<32x64xf32>) -> memref<32x64xf32> + %cast_8 = memref.cast %0 : memref<32x64xf32> to memref<*xf32> + %cast_9 = memref.cast %alloc_4 : memref<32x64xf32> to memref<*xf32> + call @printAllcloseF32(%cast_8, %cast_9) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<32x64xf16> + memref.dealloc %alloc_2 : memref<64x64xf16> + memref.dealloc %alloc_3 : memref<32x64xf32> + memref.dealloc %alloc_4 : memref<32x64xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir b/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir index b0d7b30c5..8d30f52aa 100644 --- 
a/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir +++ b/test/Integration/Dialect/XeTile/transpose_1kx1kx1k_f16_f16_f16.mlir @@ -1,52 +1,44 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : // This example assumes one subgroup per one workgroup and the kernel specifies the computation // done by a single subgroup. - module @transpose attributes {gpu.container_module} { - func.func @transpose_test(%A: memref<1024x1024xf16>) -> memref<1024x1024xf16> attributes {llvm.emit_c_interface} { + func.func @transpose_test(%arg0: memref<1024x1024xf16>) -> memref<1024x1024xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - gpu.launch_func @transpose_kernel::@transpose_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - return %B_gpu : memref<1024x1024xf16> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.launch_func @transpose_kernel::@transpose_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>) + gpu.dealloc %memref : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf16> + gpu.memcpy %alloc, %memref_0 : memref<1024x1024xf16>, memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + return %alloc : memref<1024x1024xf16> } - gpu.module @transpose_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @transpose_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @transpose_kernel { + gpu.func @transpose_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index %c32 = arith.constant 32 : index %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - 
%block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c16 : index - %n = arith.muli %block_id_y, %c32 : index // initalize A and B tiles - %a_tile = xetile.init_tile %A[%m, %n] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> - %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> - %b_value = vector.transpose %a_value, [1, 0] : vector<16x32xf16> to vector<32x16xf16> - - %b_tile = xetile.init_tile %B[%n, %m] : memref<1024x1024xf16> -> !xetile.tile<32x16xf16> - xetile.store_tile %b_value, %b_tile: vector<32x16xf16>, !xetile.tile<32x16xf16> + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c16 : index + %1 = arith.muli %block_id_y, %c32 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %3 = xetile.load_tile %2 : !xetile.tile<16x32xf16> -> vector<16x32xf16> + %4 = vector.transpose %3, [1, 0] : vector<16x32xf16> to vector<32x16xf16> + %5 = xetile.init_tile %arg1[%1, %0] : memref<1024x1024xf16> -> !xetile.tile<32x16xf16> + xetile.store_tile %4, %5 : vector<32x16xf16>, !xetile.tile<32x16xf16> gpu.return } } @@ -54,28 +46,25 @@ module @transpose attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A; A[i, j] = j; B[i, j] = i - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> - %val_f32 = arith.extf %val : f16 to f32 - memref.store %val_f32, %B_ref[%j, %i] : memref<1024x1024xf32> + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_0 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> + %3 = arith.extf %2 : f16 to f32 + memref.store %3, %alloc_0[%arg1, %arg0] : memref<1024x1024xf32> } } - - %2 = call @transpose_test(%A) : (memref<1024x1024xf16>) -> memref<1024x1024xf16> - %cast_B = memref.cast %2 : memref<1024x1024xf16> to memref<*xf16> - %cast_B_ref = memref.cast %B_ref : memref<1024x1024xf32> to memref<*xf32> // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF16(%cast_B, %cast_B_ref) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B_ref : memref<1024x1024xf32> + %0 = call @transpose_test(%alloc) : (memref<1024x1024xf16>) -> memref<1024x1024xf16> + %cast = memref.cast %0 : memref<1024x1024xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_0 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/wg_coop_transpose.mlir b/test/Integration/Dialect/XeTile/wg_coop_transpose.mlir index 29dd5edca..8ad87a70d 100644 --- a/test/Integration/Dialect/XeTile/wg_coop_transpose.mlir +++ b/test/Integration/Dialect/XeTile/wg_coop_transpose.mlir @@ -1,11 +1,7 @@ -// RUN: %python_executable %imex_runner 
--requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck // NOTES : This test load a tile from A, and then do a transpose on it, // and store it back to B, using 16 threads in a workgroup. Each thread @@ -13,27 +9,28 @@ // with other threads via convert_layout. Finally each thread will store // a 8x32 block to B. module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<64x64xf16>) -> memref<64x64xf16> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<64x64xf16>) -> memref<64x64xf16> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c16 = arith.constant 16 : index - %A_gpu = gpu.alloc host_shared () : memref<64x64xf16> - memref.copy %A, %A_gpu : memref<64x64xf16> to memref<64x64xf16> - %B_gpu = gpu.alloc host_shared () : memref<64x64xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%A_gpu : memref<64x64xf16>, %B_gpu : memref<64x64xf16>) - gpu.dealloc %A_gpu : memref<64x64xf16> - return %B_gpu : memref<64x64xf16> + %memref = gpu.alloc () : memref<64x64xf16> + gpu.memcpy %memref, %arg0 : memref<64x64xf16>, memref<64x64xf16> + %memref_0 = gpu.alloc () : memref<64x64xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%memref : memref<64x64xf16>, %memref_0 : memref<64x64xf16>) + gpu.dealloc %memref : memref<64x64xf16> + %alloc = memref.alloc() : memref<64x64xf16> + gpu.memcpy %alloc, %memref_0 : memref<64x64xf16>, memref<64x64xf16> + gpu.dealloc %memref_0 : memref<64x64xf16> + return %alloc : memref<64x64xf16> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<64x64xf16>, %B: memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + gpu.module @test_kernel { + gpu.func @test_kernel(%arg0: memref<64x64xf16>, %arg1: memref<64x64xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { %c0 = arith.constant 0 : index - %a_tile = xetile.init_tile %A[%c0, %c0] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> - %data = xetile.load_tile %a_tile : !xetile.tile<64x64xf16, #xetile.tile_attr>> -> vector<64x64xf16> - - %trans = xetile.transpose %data, [1, 0] {map = #xetile.wg_map} : vector<64x64xf16> -> vector<64x64xf16> - %cvt = xetile.convert_layout %trans {wg_map_result = #xetile.wg_map} : vector<64x64xf16> - - %b_tile = xetile.init_tile %B[%c0, %c0] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> - xetile.store_tile %cvt, %b_tile: vector<64x64xf16>, !xetile.tile<64x64xf16, #xetile.tile_attr>> + %0 = xetile.init_tile %arg0[%c0, %c0] : 
memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> + %1 = xetile.load_tile %0 : !xetile.tile<64x64xf16, #xetile.tile_attr>> -> vector<64x64xf16> + %2 = xetile.transpose %1, [1, 0] {map = #xetile.wg_map} : vector<64x64xf16> -> vector<64x64xf16> + %3 = xetile.convert_layout %2 {wg_map_result = #xetile.wg_map} : vector<64x64xf16> + %4 = xetile.init_tile %arg1[%c0, %c0] : memref<64x64xf16> -> !xetile.tile<64x64xf16, #xetile.tile_attr>> + xetile.store_tile %3, %4 : vector<64x64xf16>, !xetile.tile<64x64xf16, #xetile.tile_attr>> gpu.return } } @@ -41,32 +38,29 @@ module @gemm attributes {gpu.container_module} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<64x64xf16> - %Ref = memref.alloc() : memref<64x64xf32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c64 step %c1 { - scf.for %j = %c0 to %c64 step %c1 { // %mul = arith.muli %i, %c64 : index // %add = arith.addi %mul, %j : index // %t = index.castu %add : index to i16 - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<64x64xf16> - %val32 = arith.extf %val : f16 to f32 - memref.store %val32, %Ref[%j, %i] : memref<64x64xf32> + %alloc = memref.alloc() : memref<64x64xf16> + %alloc_0 = memref.alloc() : memref<64x64xf32> + scf.for %arg0 = %c0 to %c64 step %c1 { + scf.for %arg1 = %c0 to %c64 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<64x64xf16> + %3 = arith.extf %2 : f16 to f32 + memref.store %3, %alloc_0[%arg1, %arg0] : memref<64x64xf32> } } - %B = call @test(%A) : (memref<64x64xf16>) -> memref<64x64xf16> - %cast = memref.cast %B : memref<64x64xf16> to memref<*xf16> // call @printMemrefF16(%cast) : (memref<*xf16>) -> () // CHECK: [ALLCLOSE: TRUE] - %cast_ref = memref.cast %Ref : memref<64x64xf32> to memref<*xf32> - call @printAllcloseF16(%cast, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () - memref.dealloc %A : memref<64x64xf16> - memref.dealloc %Ref : memref<64x64xf32> + %0 = call @test(%alloc) : (memref<64x64xf16>) -> memref<64x64xf16> + %cast = memref.cast %0 : memref<64x64xf16> to memref<*xf16> + %cast_1 = memref.cast %alloc_0 : memref<64x64xf32> to memref<*xf32> + call @printAllcloseF16(%cast, %cast_1) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<64x64xf16> + memref.dealloc %alloc_0 : memref<64x64xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_1k_btranspose.mlir b/test/Integration/Dialect/XeTile/wg_gemm_1k_btranspose.mlir index 9dbf297ac..d187f7380 100644 --- a/test/Integration/Dialect/XeTile/wg_gemm_1k_btranspose.mlir +++ b/test/Integration/Dialect/XeTile/wg_gemm_1k_btranspose.mlir @@ -1,196 +1,139 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e 
main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#wg_map_a = #xetile.wg_map -#tile_attr_a = #xetile.tile_attr - -#wg_map_b = #xetile.wg_map -#tile_attr_b = #xetile.tile_attr - -#wg_map_c = #xetile.wg_map -#tile_attr_c = #xetile.tile_attr - -#wg_map_d = #xetile.wg_map -#tile_attr_d = #xetile.tile_attr +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %D: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf16>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - %D_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %D, %D_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c4, %c4, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>, %D_gpu : memref<1024x1024xf16>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + %memref_2 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_2, %arg3 : memref<1024x1024xf16>, memref<1024x1024xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c4, %c4, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>, %memref_2 : memref<1024x1024xf16>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>, %D: memref<1024x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = 
#spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c128 : index - %n = arith.muli %block_id_y, %c128 : index - + gpu.module @test_kernel { // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> - -> !xetile.tile<128x128xf32, #tile_attr_c> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x128xf32, #tile_attr_c> - -> vector<128x128xf32> - - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> - -> !xetile.tile<128x128xf16, #tile_attr_a> - - %b_init_tile = xetile.init_tile %B[%n, %c0] : memref<1024x1024xf16> - -> !xetile.tile<128x128xf16, #tile_attr_b> - - %d_init_tile = xetile.init_tile %D[%c0, %n] : memref<1024x1024xf16> - -> !xetile.tile<128x128xf16, #tile_attr_d> - // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:4 = scf.for %k = %c0 to %c1024 step %c128 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %d_tile = %d_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<128x128xf16, #tile_attr_a>, - !xetile.tile<128x128xf16, #tile_attr_b>, - !xetile.tile<128x128xf16, #tile_attr_d>, - vector<128x128xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<128x128xf16, #tile_attr_a> - -> vector<128x128xf16> - - %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xf16, #tile_attr_b> - -> vector<128x128xf16> - - %d_value = xetile.load_tile %d_tile : !xetile.tile<128x128xf16, #tile_attr_d> - -> vector<128x128xf16> - - %b_transpose = vector.transpose %b_value, [1, 0] {map = #xetile.wg_map} : vector<128x128xf16> to vector<128x128xf16> - - %pre_op = arith.addf %b_transpose, %d_value {map = #xetile.wg_map} : vector<128x128xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %pre_op, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_d, wg_map_c = #wg_map_c} - : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> - // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] : !xetile.tile<128x128xf16, #tile_attr_a> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c0, %c128] : !xetile.tile<128x128xf16, #tile_attr_b> - %d_next_tile = xetile.update_tile_offset %d_tile, [%c128, %c0] : !xetile.tile<128x128xf16, #tile_attr_d> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %d_next_tile, %c_new_value - : !xetile.tile<128x128xf16, #tile_attr_a>, - !xetile.tile<128x128xf16, #tile_attr_b>, - !xetile.tile<128x128xf16, #tile_attr_d>, - vector<128x128xf32> - } // store the final accumulated C tile result back to memory - xetile.store_tile %out#3, %c_init_tile : vector<128x128xf32>, - !xetile.tile<128x128xf32, #tile_attr_c> - gpu.return - } + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c128 : index + %1 = arith.muli %block_id_y, %c128 : index + %2 = xetile.init_tile %arg2[%0, %1] : 
memref<1024x1024xf32> -> !xetile.tile<128x128xf32, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<128x128xf32, #xetile.tile_attr>> -> vector<128x128xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xetile.tile_attr>> + %5 = xetile.init_tile %arg1[%1, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xetile.tile_attr>> + %6 = xetile.init_tile %arg3[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xetile.tile_attr>> + %7:4 = scf.for %arg4 = %c0 to %c1024 step %c128 iter_args(%arg5 = %4, %arg6 = %5, %arg7 = %6, %arg8 = %3) -> (!xetile.tile<128x128xf16, #xetile.tile_attr>>, !xetile.tile<128x128xf16, #xetile.tile_attr>>, !xetile.tile<128x128xf16, #xetile.tile_attr>>, vector<128x128xf32>) { + %8 = xetile.load_tile %arg5 : !xetile.tile<128x128xf16, #xetile.tile_attr>> -> vector<128x128xf16> + %9 = xetile.load_tile %arg6 : !xetile.tile<128x128xf16, #xetile.tile_attr>> -> vector<128x128xf16> + %10 = xetile.load_tile %arg7 : !xetile.tile<128x128xf16, #xetile.tile_attr>> -> vector<128x128xf16> + %11 = vector.transpose %9, [1, 0] {map = #xetile.wg_map} : vector<128x128xf16> to vector<128x128xf16> + %12 = arith.addf %11, %10 {map = #xetile.wg_map} : vector<128x128xf16> + %13 = xetile.tile_mma %8, %12, %arg8 {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> + %14 = xetile.update_tile_offset %arg5, [%c0, %c128] : !xetile.tile<128x128xf16, #xetile.tile_attr>> + %15 = xetile.update_tile_offset %arg6, [%c0, %c128] : !xetile.tile<128x128xf16, #xetile.tile_attr>> + %16 = xetile.update_tile_offset %arg7, [%c128, %c0] : !xetile.tile<128x128xf16, #xetile.tile_attr>> + scf.yield %14, %15, %16, %13 : !xetile.tile<128x128xf16, #xetile.tile_attr>>, !xetile.tile<128x128xf16, #xetile.tile_attr>>, !xetile.tile<128x128xf16, #xetile.tile_attr>>, vector<128x128xf32> + } + xetile.store_tile %7#3, %2 : vector<128x128xf32>, !xetile.tile<128x128xf32, #xetile.tile_attr>> + gpu.return } + } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index %c1_i32 = arith.constant 1 : i32 - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %cf_5 = arith.constant 5.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %D = memref.alloc() : memref<1024x1024xf16> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 5.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_1 = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf32> + %alloc_3 = memref.alloc() : memref<1024x1024xf16> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // Initialize matrix B with values such that B is not symmetric - scf.for %i = %c0 to %c1024 step %c1 { - scf.for 
%j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - // Compute a value that ensures B[i,j] != B[j,i] when i != j - %diff = arith.subi %i_i32, %j_i32 : i32 - %value_i32 = arith.addi %diff, %c1_i32 : i32 - %value_f16 = arith.sitofp %value_i32 : i32 to f16 - // Store the value in B[i,j] - memref.store %value_f16, %B[%i, %j] : memref<1024x1024xf16> - } - } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.subi %1, %2 : i32 + %4 = arith.addi %3, %c1_i32 : i32 + %5 = arith.sitofp %4 : i32 to f16 + memref.store %5, %alloc_1[%arg0, %arg1] : memref<1024x1024xf16> } } - // Pre-op: Compute D = B + 5 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %b_val = memref.load %B[%i, %j] : memref<1024x1024xf16> - %d_val = arith.addf %b_val, %cf_5 : f16 - memref.store %d_val, %D[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %D[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 - } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_1[%arg0, %arg1] : memref<1024x1024xf16> + %2 = arith.addf %1, %cst_0 : f16 + memref.store %2, %alloc_3[%arg0, %arg1] : memref<1024x1024xf16> } } - %2 = call @test(%A, %B, %C, %D) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>, memref<1024x1024xf16>) -> memref<1024x1024xf32> - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_3[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 + } + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + } + } + %0 = call 
@test(%alloc, %alloc_1, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>, memref<1024x1024xf16>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_1 : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir index 87ed4d505..b6e878c34 100644 --- a/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir @@ -1,164 +1,123 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#wg_map_a = #xetile.wg_map -#tile_attr_a = #xetile.tile_attr - -#wg_map_b = #xetile.wg_map -#tile_attr_b = #xetile.tile_attr - -#wg_map_c = #xetile.wg_map -#tile_attr_c = #xetile.tile_attr +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> - memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> - memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c4, %c4, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) - gpu.dealloc %A_gpu : memref<1024x1024xf16> - gpu.dealloc %B_gpu : memref<1024x1024xf16> - return %C_gpu : memref<1024x1024xf32> + %memref = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref, %arg0 : 
memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_0 = gpu.alloc () : memref<1024x1024xf16> + gpu.memcpy %memref_0, %arg1 : memref<1024x1024xf16>, memref<1024x1024xf16> + %memref_1 = gpu.alloc () : memref<1024x1024xf32> + gpu.memcpy %memref_1, %arg2 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c4, %c4, %c1) args(%memref : memref<1024x1024xf16>, %memref_0 : memref<1024x1024xf16>, %memref_1 : memref<1024x1024xf32>) + gpu.dealloc %memref : memref<1024x1024xf16> + gpu.dealloc %memref_0 : memref<1024x1024xf16> + %alloc = memref.alloc() : memref<1024x1024xf32> + gpu.memcpy %alloc, %memref_1 : memref<1024x1024xf32>, memref<1024x1024xf32> + gpu.dealloc %memref_1 : memref<1024x1024xf32> + return %alloc : memref<1024x1024xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c128 : index - %n = arith.muli %block_id_y, %c128 : index - + gpu.module @test_kernel { // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> - -> !xetile.tile<128x128xf32, #tile_attr_c> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x128xf32, #tile_attr_c> - -> vector<128x128xf32> - - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> - -> !xetile.tile<128x128xf16, #tile_attr_a> - - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> - -> !xetile.tile<128x128xf16, #tile_attr_b> - // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c128 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<128x128xf16, #tile_attr_a>, - !xetile.tile<128x128xf16, #tile_attr_b>, - vector<128x128xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<128x128xf16, #tile_attr_a> - -> vector<128x128xf16> - - %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xf16, #tile_attr_b> - -> vector<128x128xf16> - // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c} - : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> - // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] : !xetile.tile<128x128xf16, #tile_attr_a> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c128, %c0] : !xetile.tile<128x128xf16, #tile_attr_b> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<128x128xf16, #tile_attr_a>, - !xetile.tile<128x128xf16, #tile_attr_b>, vector<128x128xf32> - } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile : vector<128x128xf32>, - !xetile.tile<128x128xf32, #tile_attr_c> - gpu.return - } + gpu.func @test_kernel(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) kernel 
attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c128 : index + %1 = arith.muli %block_id_y, %c128 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<1024x1024xf32> -> !xetile.tile<128x128xf32, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<128x128xf32, #xetile.tile_attr>> -> vector<128x128xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xetile.tile_attr>> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xetile.tile_attr>> + %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<128x128xf16, #xetile.tile_attr>>, !xetile.tile<128x128xf16, #xetile.tile_attr>>, vector<128x128xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<128x128xf16, #xetile.tile_attr>> -> vector<128x128xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<128x128xf16, #xetile.tile_attr>> -> vector<128x128xf16> + %9 = xetile.tile_mma %7, %8, %arg6 {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c128] : !xetile.tile<128x128xf16, #xetile.tile_attr>> + %11 = xetile.update_tile_offset %arg5, [%c128, %c0] : !xetile.tile<128x128xf16, #xetile.tile_attr>> + scf.yield %10, %11, %9 : !xetile.tile<128x128xf16, #xetile.tile_attr>>, !xetile.tile<128x128xf16, #xetile.tile_attr>>, vector<128x128xf32> + } + xetile.store_tile %6#2, %2 : vector<128x128xf32>, !xetile.tile<128x128xf32, #xetile.tile_attr>> + gpu.return } + } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c1024 = arith.constant 1024 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<1024x1024xf16> - %B = memref.alloc() : memref<1024x1024xf16> - %C = memref.alloc() : memref<1024x1024xf32> - %C_ref = memref.alloc() : memref<1024x1024xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<1024x1024xf16> + %alloc_2 = memref.alloc() : memref<1024x1024xf16> + %alloc_3 = memref.alloc() : memref<1024x1024xf32> + %alloc_4 = memref.alloc() : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<1024x1024xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + 
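// A minimal sketch for the identity-B check above, assuming the initializations in this main():
// with A[i, j] = j and B the identity, the f16 products are exact for j < 1024, so the reference
// result reduces to C_ref[i, j] = (f32) j. A direct fill such as the following (the @fill_expected
// name and the standalone wrapper are illustrative, not part of this patch) would produce the same
// values as the accumulation loop further below; the test keeps the full triple loop so the same
// main() shape also covers non-identity B.
//
  func.func @fill_expected(%c_ref: memref<1024x1024xf32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c1024 = arith.constant 1024 : index
    scf.for %i = %c0 to %c1024 step %c1 {
      scf.for %j = %c0 to %c1024 step %c1 {
        // same value chain as the A initialization: index -> i16 -> f16 -> f32
        %t = index.castu %j : index to i16
        %vf16 = arith.uitofp %t : i16 to f16
        %vf32 = arith.extf %vf16 : f16 to f32
        memref.store %vf32, %c_ref[%i, %j] : memref<1024x1024xf32>
      }
    }
    return
  }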
%1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<1024x1024xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<1024x1024xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> + %2 = scf.for %arg2 = %c0 to %c1024 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<1024x1024xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<1024x1024xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<1024x1024xf32> } } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> - %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<1024x1024xf16> - memref.dealloc %B : memref<1024x1024xf16> - memref.dealloc %C : memref<1024x1024xf32> - memref.dealloc %C_ref : memref<1024x1024xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast = memref.cast %0 : memref<1024x1024xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<1024x1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<1024x1024xf16> + memref.dealloc %alloc_2 : memref<1024x1024xf16> + memref.dealloc %alloc_3 : memref<1024x1024xf32> + memref.dealloc %alloc_4 : memref<1024x1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir deleted file mode 100644 index 6493cf304..000000000 --- a/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir +++ /dev/null @@ -1,168 +0,0 @@ -// TODO: Add run commands -// RUN: - -// *** 
Experimental *** -// This example works at the work grpup level. This demonstrates how the user can specify the -// mapping for both subgroup within workgroup and work items within a single subgroup. The mapping -// of subgroups to subtiles are specified using `wg_map` and, work items to data elements mapping is -// specified using `sg_map`. Through this way, user has full control of how each work items works on -// exactly which data elements. XeTile fully honor the mapping provided by users. -// -// Note that lowering of this code to XeGPU is not supported yet because XeTile-XeGPU lowering assumes -// subgroup level programming at XeTile. - -#sg_map_a = #xetile.sg_map -#wg_map_a = #xetile.wg_map -#xe_map_a = #xetile.xe_map - -#sg_map_b = #xetile.sg_map -#wg_map_b = #xetile.wg_map -#xe_map_b = #xetile.xe_map - -#sg_map_c = #xetile.sg_map -#wg_map_c = #xetile.wg_map -#xe_map_c = #xetile.xe_map - -module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> - memref.copy %A, %A_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> - %B_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> - memref.copy %B, %B_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> - %C_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> - memref.copy %C, %C_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c2, %c2, %c1) args(%A_gpu : memref<1024x1024xi8>, %B_gpu : memref<1024x1024xi8>, %C_gpu : memref<1024x1024xi32>) - gpu.dealloc %A_gpu : memref<1024x1024xi8> - gpu.dealloc %B_gpu : memref<1024x1024xi8> - return %C_gpu : memref<1024x1024xi32> - } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - // %c8 = arith.constant 8 : index - // %c16 = arith.constant 16 : index - %c128 = arith.constant 128 : index - %c1024 = arith.constant 1024 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c128 : index - %n = arith.muli %block_id_y, %c128 : index - // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xi32> - -> !xetile.tile<128x128xi32, #xe_map_c> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x128xi32, #xe_map_c> - -> vector<128x128xi32> - // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xi8> - -> !xetile.tile<128x128xi8, #xe_map_a> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xi8> - -> !xetile.tile<128x128xi8, #xe_map_b> - // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c128 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, 
%c_value = %c_init_value) - -> (!xetile.tile<128x128xi8, #xe_map_a>, - !xetile.tile<128x128xi8, #xe_map_b>, - vector<128x128xi32>) { - - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<128x128xi8, #xe_map_a> - -> vector<128x128xi8> - %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xi8, #xe_map_b> - -> vector<128x128xi8> - // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<128x128xi8>, vector<128x128xi8>, vector<128x128xi32> -> vector<128x128xi32> - // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] : !xetile.tile<128x128xi8, #xe_map_a> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c128, %c0] : !xetile.tile<128x128xi8, #xe_map_b> - // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<128x128xi8, #xe_map_a>, - !xetile.tile<128x128xi8, #xe_map_b>, vector<128x128xi32> - } - // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile : vector<128x128xi32>, - !xetile.tile<128x128xi32, #xe_map_c> - gpu.return - } - } - func.func @main() attributes {llvm.emit_c_interface} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c1024 = arith.constant 1024 : index - %ci_0 = arith.constant 0 : i8 - %ci_1 = arith.constant 1 : i8 - %A = memref.alloc() : memref<1024x1024xi8> - %B = memref.alloc() : memref<1024x1024xi8> - %C = memref.alloc() : memref<1024x1024xi32> - %C_ref = memref.alloc() : memref<1024x1024xi32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %val = index.castu %j : index to i8 - memref.store %val, %A[%i, %j] : memref<1024x1024xi8> - } - } - // make matrix B an identity matrix - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %ci_1, %B[%i, %j] : memref<1024x1024xi8> - } else { - memref.store %ci_0, %B[%i, %j] : memref<1024x1024xi8> - } - } - } - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_i32 = arith.constant 0: i32 - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c0_i32, %C[%i, %j] : memref<1024x1024xi32> - memref.store %c0_i32, %C_ref[%i, %j] : memref<1024x1024xi32> - } - } - // compute C for reference - scf.for %i = %c0 to %c1024 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xi32> - %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> i32 { - %a_val = memref.load %A[%i, %k] : memref<1024x1024xi8> - %b_val = memref.load %B[%k, %j] : memref<1024x1024xi8> - %a_val_i32 = arith.extui %a_val : i8 to i32 - %b_val_i32 = arith.extui %b_val : i8 to i32 - %t = arith.muli %a_val_i32, %b_val_i32 : i32 - %c_sum = arith.addi %t, %c_partial : i32 - scf.yield %c_sum : i32 - } - memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xi32> - } - } - %2 = call @test(%A, %B, %C) : (memref<1024x1024xi8>, memref<1024x1024xi8>, memref<1024x1024xi32>) -> memref<1024x1024xi32> - %cast_C = memref.cast %2 : memref<1024x1024xi32> to memref<*xi32> - %cast_C_ref = memref.cast %C_ref : memref<1024x1024xi32> to memref<*xi32> - - call @printAllcloseI32(%cast_C, %cast_C_ref) : (memref<*xi32>, memref<*xi32>) -> () - memref.dealloc %A : memref<1024x1024xi8> - 
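// The experimental comment at the head of this removed test describes the two-level mapping model:
// wg_map assigns subtiles to subgroups within a workgroup, and sg_map assigns data elements to work
// items within a subgroup. The concrete mapping parameters are not reproduced in this patch text; a
// plausible shape for such attribute aliases, with the field names and the layout values given purely
// as illustrative assumptions, is:
//
  #wg_map_a = #xetile.wg_map<sg_layout = [4, 4], sg_data = [32, 32]>
  #sg_map_a = #xetile.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
//
// The retained f16/bf16 workgroup tests in this patch rely on the same wg_map mechanism on their
// tiles and tile_mma ops.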
memref.dealloc %B : memref<1024x1024xi8> - memref.dealloc %C : memref<1024x1024xi32> - memref.dealloc %C_ref : memref<1024x1024xi32> - return - } - func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} - func.func private @printMemrefI8(memref<*xi8>) attributes {llvm.emit_c_interface} - func.func private @printAllcloseI32(memref<*xi32>, memref<*xi32>) attributes {llvm.emit_c_interface} -} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir index d9abea128..5d0f5625a 100644 --- a/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir +++ b/test/Integration/Dialect/XeTile/wg_gemm_4k_8x4_sg_layout.mlir @@ -1,222 +1,145 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck - -#wg_map_a = #xetile.wg_map -#tile_attr_a = #xetile.tile_attr - -#wg_map_b = #xetile.wg_map -#tile_attr_b = #xetile.tile_attr - -#wg_map_c = #xetile.wg_map -#tile_attr_c = #xetile.tile_attr +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xbf16>, %arg1: memref<4096x4096xbf16>, %arg2: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16> - memref.copy %A, %A_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xbf16> - memref.copy %B, %B_gpu : memref<4096x4096xbf16> to memref<4096x4096xbf16> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32> - memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%A_gpu : memref<4096x4096xbf16>, %B_gpu : memref<4096x4096xbf16>, %C_gpu : memref<4096x4096xf32>) - gpu.dealloc %A_gpu : memref<4096x4096xbf16> - gpu.dealloc %B_gpu : memref<4096x4096xbf16> - return %C_gpu : memref<4096x4096xf32> + %memref = gpu.alloc () : memref<4096x4096xbf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xbf16>, memref<4096x4096xbf16> + %memref_0 = gpu.alloc () : memref<4096x4096xbf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xbf16>, memref<4096x4096xbf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf32> + gpu.memcpy 
%memref_1, %arg2 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<4096x4096xbf16>, %memref_0 : memref<4096x4096xbf16>, %memref_1 : memref<4096x4096xf32>) + gpu.dealloc %memref : memref<4096x4096xbf16> + gpu.dealloc %memref_0 : memref<4096x4096xbf16> + %alloc = memref.alloc() : memref<4096x4096xf32> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.dealloc %memref_1 : memref<4096x4096xf32> + return %alloc : memref<4096x4096xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c256 : index - %n = arith.muli %block_id_y, %c256 : index + gpu.module @test_kernel { // intialize C tile and load it // %prefetch_c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> // -> !xetile.tile<256x256xf32, #tile_attr_c> - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> - -> !xetile.tile<256x256xf32, #tile_attr_c> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #tile_attr_c> - -> vector<256x256xf32> - // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16> - -> !xetile.tile<256x32xbf16, #tile_attr_a> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16> - -> !xetile.tile<32x256xbf16, #tile_attr_b> - // prefetch first 32 slice - %prefetch_a_init_tile_1 = xetile.init_tile %A[%m, %c0] : memref<4096x4096xbf16> - -> !xetile.tile<256x32xbf16, #tile_attr_a> - %prefetch_b_init_tile_1 = xetile.init_tile %B[%c0, %n] : memref<4096x4096xbf16> - -> !xetile.tile<32x256xbf16, #tile_attr_b> - xetile.prefetch_tile %prefetch_a_init_tile_1 : !xetile.tile<256x32xbf16, #tile_attr_a> - xetile.prefetch_tile %prefetch_b_init_tile_1 : !xetile.tile<32x256xbf16, #tile_attr_b> - // prefetch second 32 slice - %prefetch_a_init_tile_2 = xetile.init_tile %A[%m, %c32] : memref<4096x4096xbf16> - -> !xetile.tile<256x32xbf16, #tile_attr_a> - %prefetch_b_init_tile_2 = xetile.init_tile %B[%c32, %n] : memref<4096x4096xbf16> - -> !xetile.tile<32x256xbf16, #tile_attr_b> - xetile.prefetch_tile %prefetch_a_init_tile_2 : !xetile.tile<256x32xbf16, #tile_attr_a> - xetile.prefetch_tile %prefetch_b_init_tile_2 : !xetile.tile<32x256xbf16, #tile_attr_b> - - // prefetch third 32 slice - %prefetch_a_init_tile_3 = xetile.init_tile %A[%m, %c64] : memref<4096x4096xbf16> - -> !xetile.tile<256x32xbf16, #tile_attr_a> - %prefetch_b_init_tile_3 = xetile.init_tile %B[%c64, %n] : memref<4096x4096xbf16> - -> !xetile.tile<32x256xbf16, #tile_attr_b> - - xegpu.alloc_nbarrier 1 - %nbarrier_id = arith.constant 0 : i8 - %num_threads = arith.constant 32 : i8 - %nbarrier = xegpu.init_nbarrier %nbarrier_id, %num_threads : i8, i8 -> !xegpu.nbarrier - %c0_i32 = arith.constant 0 : i32 - // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:5 = scf.for %k = %c0 to %c4096 step %c32 - 
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value, - %prefetch_a_tile = %prefetch_a_init_tile_3, - %prefetch_b_tile = %prefetch_b_init_tile_3 - ) - -> (!xetile.tile<256x32xbf16, #tile_attr_a>, - !xetile.tile<32x256xbf16, #tile_attr_b>, - vector<256x256xf32>, - !xetile.tile<256x32xbf16, #tile_attr_a>, - !xetile.tile<32x256xbf16, #tile_attr_b> - ) { - // all SGs must arrive here first // %every_8th_iter = arith.remui %k, %c256 : index // %every_8th_iter_i32 = arith.index_cast %every_8th_iter : index to i32 // %every_8th_iter_cond = arith.cmpi eq, %every_8th_iter_i32, %c0_i32 : i32 // scf.if %every_8th_iter_cond { - xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier // } - - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<256x32xbf16, #tile_attr_a> - -> vector<256x32xbf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<32x256xbf16, #tile_attr_b> - -> vector<32x256xbf16> - - xegpu.compile_hint - // prefetch next A and B tiles - xetile.prefetch_tile %prefetch_a_tile : !xetile.tile<256x32xbf16, #tile_attr_a> - xetile.prefetch_tile %prefetch_b_tile : !xetile.tile<32x256xbf16, #tile_attr_b> - - xegpu.compile_hint - // update prefetch tile offsets - %15 = xetile.update_tile_offset %prefetch_a_tile, [%c0, %c32] : !xetile.tile<256x32xbf16, #tile_attr_a> - %16 = xetile.update_tile_offset %prefetch_b_tile, [%c32, %c0] : !xetile.tile<32x256xbf16, #tile_attr_b> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] - : !xetile.tile<256x32xbf16, #tile_attr_a> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] - : !xetile.tile<32x256xbf16, #tile_attr_b> - - xegpu.compile_hint - // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c} - : vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xf32> -> vector<256x256xf32> - - xegpu.compile_hint // barrier wait // scf.if %every_8th_iter_cond { - xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier // } // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value, %15, %16 - : !xetile.tile<256x32xbf16, #tile_attr_a>, - !xetile.tile<32x256xbf16, #tile_attr_b>, vector<256x256xf32>, - !xetile.tile<256x32xbf16, #tile_attr_a>, - !xetile.tile<32x256xbf16, #tile_attr_b> - } // store the final accumulated C tile result back to memory - %c_init_tile_1 = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> - -> !xetile.tile<256x256xf32, #tile_attr_c> - xetile.store_tile %out#2, %c_init_tile_1 : vector<256x256xf32>, - !xetile.tile<256x256xf32, #tile_attr_c> + gpu.func @test_kernel(%arg0: memref<4096x4096xbf16>, %arg1: memref<4096x4096xbf16>, %arg2: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c256 : index + %1 = arith.muli %block_id_y, %c256 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<4096x4096xf32> -> !xetile.tile<256x256xf32, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<256x256xf32, #xetile.tile_attr>> -> vector<256x256xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<4096x4096xbf16> -> !xetile.tile<256x32xbf16, 
#xetile.tile_attr>> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<4096x4096xbf16> -> !xetile.tile<32x256xbf16, #xetile.tile_attr>> + %6 = xetile.init_tile %arg0[%0, %c0] : memref<4096x4096xbf16> -> !xetile.tile<256x32xbf16, #xetile.tile_attr>> + %7 = xetile.init_tile %arg1[%c0, %1] : memref<4096x4096xbf16> -> !xetile.tile<32x256xbf16, #xetile.tile_attr>> + xetile.prefetch_tile %6 : !xetile.tile<256x32xbf16, #xetile.tile_attr>> + xetile.prefetch_tile %7 : !xetile.tile<32x256xbf16, #xetile.tile_attr>> + %8 = xetile.init_tile %arg0[%0, %c32] : memref<4096x4096xbf16> -> !xetile.tile<256x32xbf16, #xetile.tile_attr>> + %9 = xetile.init_tile %arg1[%c32, %1] : memref<4096x4096xbf16> -> !xetile.tile<32x256xbf16, #xetile.tile_attr>> + xetile.prefetch_tile %8 : !xetile.tile<256x32xbf16, #xetile.tile_attr>> + xetile.prefetch_tile %9 : !xetile.tile<32x256xbf16, #xetile.tile_attr>> + %10 = xetile.init_tile %arg0[%0, %c64] : memref<4096x4096xbf16> -> !xetile.tile<256x32xbf16, #xetile.tile_attr>> + %11 = xetile.init_tile %arg1[%c64, %1] : memref<4096x4096xbf16> -> !xetile.tile<32x256xbf16, #xetile.tile_attr>> + xegpu.alloc_nbarrier 1 + %c0_i8 = arith.constant 0 : i8 + %c32_i8 = arith.constant 32 : i8 + %12 = xegpu.init_nbarrier %c0_i8, %c32_i8 : i8, i8 -> !xegpu.nbarrier + %c0_i32 = arith.constant 0 : i32 + %13:5 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3, %arg7 = %10, %arg8 = %11) -> (!xetile.tile<256x32xbf16, #xetile.tile_attr>>, !xetile.tile<32x256xbf16, #xetile.tile_attr>>, vector<256x256xf32>, !xetile.tile<256x32xbf16, #xetile.tile_attr>>, !xetile.tile<32x256xbf16, #xetile.tile_attr>>) { + xegpu.nbarrier_arrive %12 : !xegpu.nbarrier + %15 = xetile.load_tile %arg4 : !xetile.tile<256x32xbf16, #xetile.tile_attr>> -> vector<256x32xbf16> + %16 = xetile.load_tile %arg5 : !xetile.tile<32x256xbf16, #xetile.tile_attr>> -> vector<32x256xbf16> xegpu.compile_hint - gpu.return + xetile.prefetch_tile %arg7 : !xetile.tile<256x32xbf16, #xetile.tile_attr>> + xetile.prefetch_tile %arg8 : !xetile.tile<32x256xbf16, #xetile.tile_attr>> + xegpu.compile_hint + %17 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<256x32xbf16, #xetile.tile_attr>> + %18 = xetile.update_tile_offset %arg8, [%c32, %c0] : !xetile.tile<32x256xbf16, #xetile.tile_attr>> + %19 = xetile.update_tile_offset %arg4, [%c0, %c32] : !xetile.tile<256x32xbf16, #xetile.tile_attr>> + %20 = xetile.update_tile_offset %arg5, [%c32, %c0] : !xetile.tile<32x256xbf16, #xetile.tile_attr>> + xegpu.compile_hint + %21 = xetile.tile_mma %15, %16, %arg6 {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xf32> -> vector<256x256xf32> + xegpu.compile_hint + xegpu.nbarrier_wait %12 : !xegpu.nbarrier + scf.yield %19, %20, %21, %17, %18 : !xetile.tile<256x32xbf16, #xetile.tile_attr>>, !xetile.tile<32x256xbf16, #xetile.tile_attr>>, vector<256x256xf32>, !xetile.tile<256x32xbf16, #xetile.tile_attr>>, !xetile.tile<32x256xbf16, #xetile.tile_attr>> + } + %14 = xetile.init_tile %arg2[%0, %1] : memref<4096x4096xf32> -> !xetile.tile<256x256xf32, #xetile.tile_attr>> + xetile.store_tile %13#2, %14 : vector<256x256xf32>, !xetile.tile<256x256xf32, #xetile.tile_attr>> + xegpu.compile_hint + gpu.return } } - func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c1_f16 = arith.constant 1.0 : bf16 - %c2_f16 = arith.constant 2.0 : bf16 %c4096 = arith.constant 4096 : index - %cf_0 = 
arith.constant 0.0 : bf16 - %cf_1 = arith.constant 1.0 : bf16 - %c_gen_int = arith.constant 0 : i1 - %cf_lower = arith.constant 0.0 : f32 - %cf_upper = arith.constant 1.0 : f32 - - %A = memref.alloc() : memref<4096x4096xbf16> - %B = memref.alloc() : memref<4096x4096xbf16> - %C = memref.alloc() : memref<4096x4096xf32> - %C_ref = memref.alloc() : memref<4096x4096xf32> - // convert the memref to 1D and fill with random values in (0.0, 1.0) - %A_random = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // convert the memref to 1D and fill with random values in (0.0, 1.0) - %B_random = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> - call @fillResource1DRandomBF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xbf16>, f32, f32, i1) -> () - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f16 = arith.constant 0.0 : bf16 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + %false = arith.constant false + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() : memref<4096x4096xbf16> + %alloc_1 = memref.alloc() : memref<4096x4096xbf16> + %alloc_2 = memref.alloc() : memref<4096x4096xf32> + %alloc_3 = memref.alloc() : memref<4096x4096xf32> + %cast = memref.cast %alloc : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + %cast_4 = memref.cast %alloc_1 : memref<4096x4096xbf16> to memref<*xbf16> + call @fillResource1DRandomBF16(%cast_4, %cst, %cst_0, %false) : (memref<*xbf16>, f32, f32, i1) -> () + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst, %alloc_2[%arg0, %arg1] : memref<4096x4096xf32> + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<4096x4096xf32> } } - // Run GPU. 
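// The @test call below exercises the updated host path used throughout these tests: instead of
// host_shared allocations plus memref.copy, the revised code allocates device memory, stages inputs
// with gpu.memcpy, and copies the result back into a freshly allocated host buffer before returning.
// A minimal sketch of that shape, assuming a single f32 buffer (the @device_roundtrip name and the
// one-argument signature are illustrative, not part of this patch):
//
  func.func @device_roundtrip(%host_in: memref<1024x1024xf32>) -> memref<1024x1024xf32> {
    %dev = gpu.alloc () : memref<1024x1024xf32>
    gpu.memcpy %dev, %host_in : memref<1024x1024xf32>, memref<1024x1024xf32>
    // ... gpu.launch_func consuming %dev goes here ...
    %host_out = memref.alloc() : memref<1024x1024xf32>
    gpu.memcpy %host_out, %dev : memref<1024x1024xf32>, memref<1024x1024xf32>
    gpu.dealloc %dev : memref<1024x1024xf32>
    return %host_out : memref<1024x1024xf32>
  }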
- %2 = call @test(%A, %B, %C) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> - %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - // Run CPU - %A_cast = memref.cast %A : memref<4096x4096xbf16> to memref<*xbf16> - %B_cast = memref.cast %B : memref<4096x4096xbf16> to memref<*xbf16> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - call @gemmBF16BF16F32(%A_cast, %B_cast, %cast_C_ref) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () - // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xbf16> - memref.dealloc %B : memref<4096x4096xbf16> - memref.dealloc %C : memref<4096x4096xf32> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_1, %alloc_2) : (memref<4096x4096xbf16>, memref<4096x4096xbf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + %cast_5 = memref.cast %0 : memref<4096x4096xf32> to memref<*xf32> + %cast_6 = memref.cast %alloc : memref<4096x4096xbf16> to memref<*xbf16> + %cast_7 = memref.cast %alloc_1 : memref<4096x4096xbf16> to memref<*xbf16> + %cast_8 = memref.cast %alloc_3 : memref<4096x4096xf32> to memref<*xf32> + call @gemmBF16BF16F32(%cast_6, %cast_7, %cast_8) : (memref<*xbf16>, memref<*xbf16>, memref<*xf32>) -> () + call @printAllcloseF32(%cast_5, %cast_8) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xbf16> + memref.dealloc %alloc_1 : memref<4096x4096xbf16> + memref.dealloc %alloc_2 : memref<4096x4096xf32> + memref.dealloc %alloc_3 : memref<4096x4096xf32> return } func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir index 75b84d455..afd5d2ad8 100644 --- a/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir +++ b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir @@ -1,147 +1,123 @@ -#wg_map_a = #xetile.wg_map -#tile_attr_a = #xetile.tile_attr - -#wg_map_b = #xetile.wg_map -#tile_attr_b = #xetile.tile_attr - -#wg_map_c = #xetile.wg_map -#tile_attr_c = #xetile.tile_attr +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + func.func @test(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %A, %A_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> - memref.copy %B, %B_gpu : memref<4096x4096xf16> to 
memref<4096x4096xf16> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32> - memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c4, %c4, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf32>) - gpu.dealloc %A_gpu : memref<4096x4096xf16> - gpu.dealloc %B_gpu : memref<4096x4096xf16> - return %C_gpu : memref<4096x4096xf32> + %memref = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref, %arg0 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_0 = gpu.alloc () : memref<4096x4096xf16> + gpu.memcpy %memref_0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16> + %memref_1 = gpu.alloc () : memref<4096x4096xf32> + gpu.memcpy %memref_1, %arg2 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c4, %c4, %c1) args(%memref : memref<4096x4096xf16>, %memref_0 : memref<4096x4096xf16>, %memref_1 : memref<4096x4096xf32>) + gpu.dealloc %memref : memref<4096x4096xf16> + gpu.dealloc %memref_0 : memref<4096x4096xf16> + %alloc = memref.alloc() : memref<4096x4096xf32> + gpu.memcpy %alloc, %memref_1 : memref<4096x4096xf32>, memref<4096x4096xf32> + gpu.dealloc %memref_1 : memref<4096x4096xf32> + return %alloc : memref<4096x4096xf32> } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c256 : index - %n = arith.muli %block_id_y, %c256 : index + gpu.module @test_kernel { // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> - -> !xetile.tile<256x256xf32, #tile_attr_c> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #tile_attr_c> - -> vector<256x256xf32> // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xf16> - -> !xetile.tile<256x256xf16, #tile_attr_a> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xf16> - -> !xetile.tile<256x256xf16, #tile_attr_b> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c4096 step %c256 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<256x256xf16, #tile_attr_a>, - !xetile.tile<256x256xf16, #tile_attr_b>, - vector<256x256xf32>) { - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<256x256xf16, #tile_attr_a> - -> vector<256x256xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<256x256xf16, #tile_attr_b> - -> vector<256x256xf16> // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value {wg_map_a = #wg_map_a, wg_map_b = #wg_map_b, wg_map_c = #wg_map_c} - : vector<256x256xf16>, vector<256x256xf16>, vector<256x256xf32> -> vector<256x256xf32> // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c256] : !xetile.tile<256x256xf16, #tile_attr_a> - %b_next_tile = 
xetile.update_tile_offset %b_tile, [%c256, %c0] : !xetile.tile<256x256xf16, #tile_attr_b> // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<256x256xf16, #tile_attr_a>, - !xetile.tile<256x256xf16, #tile_attr_b>, vector<256x256xf32> - } // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile : vector<256x256xf32>, - !xetile.tile<256x256xf32, #tile_attr_c> - gpu.return + gpu.func @test_kernel(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c256 : index + %1 = arith.muli %block_id_y, %c256 : index + %2 = xetile.init_tile %arg2[%0, %1] : memref<4096x4096xf32> -> !xetile.tile<256x256xf32, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<256x256xf32, #xetile.tile_attr>> -> vector<256x256xf32> + %4 = xetile.init_tile %arg0[%0, %c0] : memref<4096x4096xf16> -> !xetile.tile<256x256xf16, #xetile.tile_attr>> + %5 = xetile.init_tile %arg1[%c0, %1] : memref<4096x4096xf16> -> !xetile.tile<256x256xf16, #xetile.tile_attr>> + %6:3 = scf.for %arg3 = %c0 to %c4096 step %c256 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xetile.tile<256x256xf16, #xetile.tile_attr>>, !xetile.tile<256x256xf16, #xetile.tile_attr>>, vector<256x256xf32>) { + %7 = xetile.load_tile %arg4 : !xetile.tile<256x256xf16, #xetile.tile_attr>> -> vector<256x256xf16> + %8 = xetile.load_tile %arg5 : !xetile.tile<256x256xf16, #xetile.tile_attr>> -> vector<256x256xf16> + %9 = xetile.tile_mma %7, %8, %arg6 {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<256x256xf16>, vector<256x256xf16>, vector<256x256xf32> -> vector<256x256xf32> + %10 = xetile.update_tile_offset %arg4, [%c0, %c256] : !xetile.tile<256x256xf16, #xetile.tile_attr>> + %11 = xetile.update_tile_offset %arg5, [%c256, %c0] : !xetile.tile<256x256xf16, #xetile.tile_attr>> + scf.yield %10, %11, %9 : !xetile.tile<256x256xf16, #xetile.tile_attr>>, !xetile.tile<256x256xf16, #xetile.tile_attr>>, vector<256x256xf32> + } + xetile.store_tile %6#2, %2 : vector<256x256xf32>, !xetile.tile<256x256xf32, #xetile.tile_attr>> + gpu.return } } func.func @main() attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4096 = arith.constant 4096 : index - %cf_0 = arith.constant 0.0 : f16 - %cf_1 = arith.constant 1.0 : f16 - %A = memref.alloc() : memref<4096x4096xf16> - %B = memref.alloc() : memref<4096x4096xf16> - %C = memref.alloc() : memref<4096x4096xf32> - %C_ref = memref.alloc() : memref<4096x4096xf32> // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %t = index.castu %j : index to i16 - %val = arith.uitofp %t : i16 to f16 - memref.store %val, %A[%i, %j] : memref<4096x4096xf16> + %cst_0 = arith.constant 0.000000e+00 : f16 + %cst_1 = arith.constant 1.000000e+00 : f16 + %alloc = memref.alloc() : memref<4096x4096xf16> + %alloc_2 = memref.alloc() : memref<4096x4096xf16> + %alloc_3 = memref.alloc() : memref<4096x4096xf32> + %alloc_4 = memref.alloc() : memref<4096x4096xf32> + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to 
%c4096 step %c1 { + %1 = index.castu %arg1 : index to i16 + %2 = arith.uitofp %1 : i16 to f16 + memref.store %2, %alloc[%arg0, %arg1] : memref<4096x4096xf16> } } // make matrix B an identity matrix - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %cf_1, %B[%i, %j] : memref<4096x4096xf16> + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + %1 = index.castu %arg0 : index to i32 + %2 = index.castu %arg1 : index to i32 + %3 = arith.cmpi eq, %1, %2 : i32 + scf.if %3 { + memref.store %cst_1, %alloc_2[%arg0, %arg1] : memref<4096x4096xf16> } else { - memref.store %cf_0, %B[%i, %j] : memref<4096x4096xf16> + memref.store %cst_0, %alloc_2[%arg0, %arg1] : memref<4096x4096xf16> } } } // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_f32 = arith.constant 0.0 : f32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32> - memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32> + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + memref.store %cst, %alloc_3[%arg0, %arg1] : memref<4096x4096xf32> + memref.store %cst, %alloc_4[%arg0, %arg1] : memref<4096x4096xf32> } } // compute C for reference - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<4096x4096xf32> - %c_val = scf.for %k = %c0 to %c4096 step %c1 iter_args(%c_partial = %c_curr) -> f32 { - %a_val = memref.load %A[%i, %k] : memref<4096x4096xf16> - %b_val = memref.load %B[%k, %j] : memref<4096x4096xf16> - %t = arith.mulf %a_val, %b_val : f16 - %t_cast = arith.extf %t : f16 to f32 - %c_sum = arith.addf %t_cast, %c_partial : f32 - scf.yield %c_sum : f32 + scf.for %arg0 = %c0 to %c4096 step %c1 { + scf.for %arg1 = %c0 to %c4096 step %c1 { + %1 = memref.load %alloc_4[%arg0, %arg1] : memref<4096x4096xf32> + %2 = scf.for %arg2 = %c0 to %c4096 step %c1 iter_args(%arg3 = %1) -> (f32) { + %3 = memref.load %alloc[%arg0, %arg2] : memref<4096x4096xf16> + %4 = memref.load %alloc_2[%arg2, %arg1] : memref<4096x4096xf16> + %5 = arith.mulf %3, %4 : f16 + %6 = arith.extf %5 : f16 to f32 + %7 = arith.addf %6, %arg3 : f32 + scf.yield %7 : f32 } - memref.store %c_val , %C_ref[%i, %j] : memref<4096x4096xf32> + memref.store %2, %alloc_4[%arg0, %arg1] : memref<4096x4096xf32> } } - %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> - %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32> - - call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %A : memref<4096x4096xf16> - memref.dealloc %B : memref<4096x4096xf16> - memref.dealloc %C : memref<4096x4096xf32> - memref.dealloc %C_ref : memref<4096x4096xf32> + %0 = call @test(%alloc, %alloc_2, %alloc_3) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32> + %cast = memref.cast %0 : memref<4096x4096xf32> to memref<*xf32> + %cast_5 = memref.cast %alloc_4 : memref<4096x4096xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_5) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<4096x4096xf16> + memref.dealloc %alloc_2 : 
memref<4096x4096xf16> + memref.dealloc %alloc_3 : memref<4096x4096xf32> + memref.dealloc %alloc_4 : memref<4096x4096xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir deleted file mode 100644 index d8f58682b..000000000 --- a/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir +++ /dev/null @@ -1,166 +0,0 @@ -// TODO: Add run commands -// RUN: - -// *** Experimental *** -// This example works at the work grpup level. This demonstrates how the user can specify the -// mapping for both subgroup within workgroup and work items within a single subgroup. The mapping -// of subgroups to subtiles are specified using `wg_map` and, work items to data elements mapping is -// specified using `sg_map`. Through this way, user has full control of how each work items works on -// exactly which data elements. XeTile fully honor the mapping provided by users. -// -// Note that lowering of this code to XeGPU is not supported yet because XeTile-XeGPU lowering assumes -// subgroup level programming at XeTile. - -#sg_map_a = #xetile.sg_map -#wg_map_a = #xetile.wg_map -#xe_map_a = #xetile.xe_map - -#sg_map_b = #xetile.sg_map -#wg_map_b = #xetile.wg_map -#xe_map_b = #xetile.xe_map - -#sg_map_c = #xetile.sg_map -#wg_map_c = #xetile.wg_map -#xe_map_c = #xetile.xe_map - -module @gemm attributes {gpu.container_module} { - func.func @test(%A: memref<4096x4096xi8>, %B: memref<4096x4096xi8>, %C: memref<4096x4096xi32>) -> memref<4096x4096xi32> attributes {llvm.emit_c_interface} { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c128 = arith.constant 128 : index - %c512 = arith.constant 512 : index - %A_gpu = gpu.alloc host_shared () : memref<4096x4096xi8> - memref.copy %A, %A_gpu : memref<4096x4096xi8> to memref<4096x4096xi8> - %B_gpu = gpu.alloc host_shared () : memref<4096x4096xi8> - memref.copy %B, %B_gpu : memref<4096x4096xi8> to memref<4096x4096xi8> - %C_gpu = gpu.alloc host_shared () : memref<4096x4096xi32> - memref.copy %C, %C_gpu : memref<4096x4096xi32> to memref<4096x4096xi32> - gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c4, %c4, %c1) args(%A_gpu : memref<4096x4096xi8>, %B_gpu : memref<4096x4096xi8>, %C_gpu : memref<4096x4096xi32>) - gpu.dealloc %A_gpu : memref<4096x4096xi8> - gpu.dealloc %B_gpu : memref<4096x4096xi8> - return %C_gpu : memref<4096x4096xi32> - } - gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { - gpu.func @test_kernel(%A: memref<4096x4096xi8>, %B: memref<4096x4096xi8>, %C: memref<4096x4096xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c256 = arith.constant 256 : index - %c4096 = arith.constant 4096 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c256 : index - %n = arith.muli %block_id_y, %c256 : index - // intialize C tile and load it - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xi32> - -> !xetile.tile<256x256xi32, #xe_map_c> - %c_init_value = xetile.load_tile %c_init_tile : 
!xetile.tile<256x256xi32, #xe_map_c> - -> vector<256x256xi32> - // initalize A and B tiles - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xi8> - -> !xetile.tile<256x256xi8, #xe_map_a> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xi8> - -> !xetile.tile<256x256xi8, #xe_map_b> - // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c4096 step %c256 - iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<256x256xi8, #xe_map_a>, - !xetile.tile<256x256xi8, #xe_map_b>, - vector<256x256xi32>) { - - // load A and B tiles - %a_value = xetile.load_tile %a_tile : !xetile.tile<256x256xi8, #xe_map_a> - -> vector<256x256xi8> - %b_value = xetile.load_tile %b_tile : !xetile.tile<256x256xi8, #xe_map_b> - -> vector<256x256xi8> - // perform dpas and accumulate - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : vector<256x256xi8>, vector<256x256xi8>, vector<256x256xi32> -> vector<256x256xi32> - // update the offsets for A and B tiles - %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c256] : !xetile.tile<256x256xi8, #xe_map_a> - %b_next_tile = xetile.update_tile_offset %b_tile, [%c256, %c0] : !xetile.tile<256x256xi8, #xe_map_b> - // partial C tile result - scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<256x256xi8, #xe_map_a>, - !xetile.tile<256x256xi8, #xe_map_b>, vector<256x256xi32> - } - // store the final accumulated C tile result back to memory - xetile.store_tile %out#2, %c_init_tile : vector<256x256xi32>, - !xetile.tile<256x256xi32, #xe_map_c> - gpu.return - } - } - func.func @main() attributes {llvm.emit_c_interface} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4096 = arith.constant 4096 : index - %ci_0 = arith.constant 0 : i8 - %ci_1 = arith.constant 1 : i8 - %A = memref.alloc() : memref<4096x4096xi8> - %B = memref.alloc() : memref<4096x4096xi8> - %C = memref.alloc() : memref<4096x4096xi32> - %C_ref = memref.alloc() : memref<4096x4096xi32> - // intialize matrix A ; A[i, j] = j - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %val = index.castu %j : index to i8 - memref.store %val, %A[%i, %j] : memref<4096x4096xi8> - } - } - // make matrix B an identity matrix - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %i_i32 = index.castu %i : index to i32 - %j_i32 = index.castu %j : index to i32 - %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 - - scf.if %i_j_same { - memref.store %ci_1, %B[%i, %j] : memref<4096x4096xi8> - } else { - memref.store %ci_0, %B[%i, %j] : memref<4096x4096xi8> - } - } - } - // intialize matrix C and C_ref ; C[i, j] = 0 - %c0_i32 = arith.constant 0: i32 - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - memref.store %c0_i32, %C[%i, %j] : memref<4096x4096xi32> - memref.store %c0_i32, %C_ref[%i, %j] : memref<4096x4096xi32> - } - } - // compute C for reference - scf.for %i = %c0 to %c4096 step %c1 { - scf.for %j = %c0 to %c4096 step %c1 { - %c_curr = memref.load %C_ref[%i, %j] : memref<4096x4096xi32> - %c_val = scf.for %k = %c0 to %c4096 step %c1 iter_args(%c_partial = %c_curr) -> i32 { - %a_val = memref.load %A[%i, %k] : memref<4096x4096xi8> - %b_val = memref.load %B[%k, %j] : memref<4096x4096xi8> - %a_val_i32 = arith.extui %a_val : i8 to i32 - %b_val_i32 = arith.extui %b_val : i8 to i32 - %t = arith.muli %a_val_i32, %b_val_i32 : i32 - %c_sum = arith.addi %t, %c_partial : i32 - 
scf.yield %c_sum : i32 - } - memref.store %c_val , %C_ref[%i, %j] : memref<4096x4096xi32> - } - } - %2 = call @test(%A, %B, %C) : (memref<4096x4096xi8>, memref<4096x4096xi8>, memref<4096x4096xi32>) -> memref<4096x4096xi32> - %cast_C = memref.cast %2 : memref<4096x4096xi32> to memref<*xi32> - %cast_C_ref = memref.cast %C_ref : memref<4096x4096xi32> to memref<*xi32> - - call @printAllcloseI32(%cast_C, %cast_C_ref) : (memref<*xi32>, memref<*xi32>) -> () - memref.dealloc %A : memref<4096x4096xi8> - memref.dealloc %B : memref<4096x4096xi8> - memref.dealloc %C : memref<4096x4096xi32> - memref.dealloc %C_ref : memref<4096x4096xi32> - return - } - func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} - func.func private @printMemrefI8(memref<*xi8>) attributes {llvm.emit_c_interface} - func.func private @printAllcloseI32(memref<*xi32>, memref<*xi32>) attributes {llvm.emit_c_interface} -} diff --git a/test/Integration/Dialect/XeTile/wg_reduction.mlir b/test/Integration/Dialect/XeTile/wg_reduction.mlir index f6f6255fe..b403eff9e 100644 --- a/test/Integration/Dialect/XeTile/wg_reduction.mlir +++ b/test/Integration/Dialect/XeTile/wg_reduction.mlir @@ -1,92 +1,78 @@ -// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ +// RUN: %python_executable %imex_runner --requires=mlir-levelzero-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ // RUN: --runner mlir-runner -e main \ // RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck -// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-wg-to-func-vc.pp \ -// RUN: --runner mlir-runner -e main \ -// RUN: --entry-point-result=void \ -// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%mlir_levelzero_runtime --filecheck module @reduction attributes {gpu.container_module} { - func.func @reduce_test(%a: memref<256x1024xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { + func.func @reduce_test(%arg0: memref<256x1024xf32>) -> memref<1x1024xf32> attributes {llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - - %a_gpu = gpu.alloc host_shared () : memref<256x1024xf32> - memref.copy %a, %a_gpu : memref<256x1024xf32> to memref<256x1024xf32> - %b_gpu = gpu.alloc host_shared () : memref<1x1024xf32> - - gpu.launch_func @kernel::@test_reduction blocks in (%c1, %c8, %c1) threads in (%c8, %c4, %c1) args(%a_gpu : memref<256x1024xf32>, %b_gpu : memref<1x1024xf32>) - - gpu.dealloc %a_gpu : memref<256x1024xf32> - return %b_gpu : memref<1x1024xf32> + %memref = gpu.alloc () : memref<256x1024xf32> + gpu.memcpy %memref, %arg0 : memref<256x1024xf32>, memref<256x1024xf32> + %memref_0 = gpu.alloc () : memref<1x1024xf32> + gpu.launch_func @kernel::@test_reduction blocks in (%c1, %c8, %c1) threads in (%c8, %c4, %c1) args(%memref : memref<256x1024xf32>, %memref_0 : memref<1x1024xf32>) + gpu.dealloc %memref : memref<256x1024xf32> + %alloc = memref.alloc() : memref<1x1024xf32> + gpu.memcpy %alloc, %memref_0 : memref<1x1024xf32>, memref<1x1024xf32> + gpu.dealloc %memref_0 : memref<1x1024xf32> + return %alloc : memref<1x1024xf32> } - -gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, 
#spirv.resource_limits<>>} { - gpu.func @test_reduction(%arg0 : memref<256x1024xf32>, %arg1 : memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { - %c0 = arith.constant 0 : index - %c256 = arith.constant 256 : index - %c128 = arith.constant 128 : index - %block_id_x = gpu.block_id x - %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c256 : index - %n = arith.muli %block_id_y, %c128 : index - %init_tile = xetile.init_tile %arg0[%m, %n] : memref<256x1024xf32> -> !xetile.tile<256x128xf32, #xetile.tile_attr>> - %load_tile = xetile.load_tile %init_tile: !xetile.tile<256x128xf32, #xetile.tile_attr>> -> vector<256x128xf32> - %cst_0 = arith.constant {map = #xetile.wg_map} dense<0.0> : vector<8x128xf32> - %reshape = vector.shape_cast %load_tile {map = #xetile.wg_map} : vector<256x128xf32> to vector<8x32x128xf32> - %reduction = vector.multi_reduction , %reshape, %cst_0 {map = #xetile.wg_map} [1] : vector<8x32x128xf32> to vector<8x128xf32> - %conv_layout = xetile.convert_layout %reduction {wg_map_result = #xetile.wg_map} : vector<8x128xf32> - %cst_1 = arith.constant {map = #xetile.wg_map} dense<0.0> : vector<128xf32> - %reduce = vector.multi_reduction , %conv_layout, %cst_1 {map = #xetile.wg_map} [0] : vector<8x128xf32> to vector<128xf32> - %shape_cast = vector.shape_cast %reduce {map = #xetile.wg_map} : vector<128xf32> to vector<1x128xf32> - %init_store_tile = xetile.init_tile %arg1[%c0, %n] : memref<1x1024xf32> -> !xetile.tile<1x128xf32, #xetile.tile_attr>> - xetile.store_tile %shape_cast, %init_store_tile : vector<1x128xf32>, !xetile.tile<1x128xf32, #xetile.tile_attr>> - gpu.return + gpu.module @kernel { + gpu.func @test_reduction(%arg0: memref<256x1024xf32>, %arg1: memref<1x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c128 = arith.constant 128 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c256 : index + %1 = arith.muli %block_id_y, %c128 : index + %2 = xetile.init_tile %arg0[%0, %1] : memref<256x1024xf32> -> !xetile.tile<256x128xf32, #xetile.tile_attr>> + %3 = xetile.load_tile %2 : !xetile.tile<256x128xf32, #xetile.tile_attr>> -> vector<256x128xf32> + %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<8x128xf32> + %4 = vector.shape_cast %3 {map = #xetile.wg_map} : vector<256x128xf32> to vector<8x32x128xf32> + %5 = vector.multi_reduction , %4, %cst {map = #xetile.wg_map} [1] : vector<8x32x128xf32> to vector<8x128xf32> + %6 = xetile.convert_layout %5 {wg_map_result = #xetile.wg_map} : vector<8x128xf32> + %cst_0 = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<128xf32> + %7 = vector.multi_reduction , %6, %cst_0 {map = #xetile.wg_map} [0] : vector<8x128xf32> to vector<128xf32> + %8 = vector.shape_cast %7 {map = #xetile.wg_map} : vector<128xf32> to vector<1x128xf32> + %9 = xetile.init_tile %arg1[%c0, %1] : memref<1x1024xf32> -> !xetile.tile<1x128xf32, #xetile.tile_attr>> + xetile.store_tile %8, %9 : vector<1x128xf32>, !xetile.tile<1x128xf32, #xetile.tile_attr>> + gpu.return + } } -} - -func.func @main() attributes {llvm.emit_c_interface} { + func.func @main() attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index %c1024 = arith.constant 1024 : index - %c32 = arith.constant 32 : 
index - %c0_f32 = arith.constant 0.0 : f32 - %c32_f32 = arith.constant 32.0 : f32 - %c1_f32 = arith.constant 1.0 : f32 - %c100_f32 = arith.constant 100.0 : f32 - %a = memref.alloc() : memref<256x1024xf32> - %b_ref = memref.alloc() : memref<1024xf32> - - // intialize matrix A ; A[i, j] = 1 - scf.for %i = %c0 to %c256 step %c1 { - scf.for %j = %c0 to %c1024 step %c1 { - memref.store %c1_f32, %a[%i, %j] : memref<256x1024xf32> + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() : memref<256x1024xf32> + %alloc_1 = memref.alloc() : memref<1024xf32> + scf.for %arg0 = %c0 to %c256 step %c1 { + scf.for %arg1 = %c0 to %c1024 step %c1 { + memref.store %cst_0, %alloc[%arg0, %arg1] : memref<256x1024xf32> } } - - scf.for %j = %c0 to %c1024 step %c1 { - %sum = scf.for %i = %c0 to %c256 step %c1 iter_args(%arg = %c0_f32) -> (f32) { - %val = memref.load %a[%i, %j] : memref<256x1024xf32> - %2 = arith.addf %arg, %val : f32 - scf.yield %2 : f32 + scf.for %arg0 = %c0 to %c1024 step %c1 { + %1 = scf.for %arg1 = %c0 to %c256 step %c1 iter_args(%arg2 = %cst) -> (f32) { + %2 = memref.load %alloc[%arg1, %arg0] : memref<256x1024xf32> + %3 = arith.addf %arg2, %2 : f32 + scf.yield %3 : f32 } - memref.store %sum, %b_ref[%j] : memref<1024xf32> + memref.store %1, %alloc_1[%arg0] : memref<1024xf32> } - - %b = call @reduce_test(%a) : (memref<256x1024xf32>) -> memref<1x1024xf32> - %cast_b = memref.cast %b : memref<1x1024xf32> to memref<*xf32> - %cast_b_ref = memref.cast %b_ref : memref<1024xf32> to memref<*xf32> //call @printMemrefF32(%cast_b): (memref<*xf32>) -> () //call @printMemrefF32(%cast_b_ref): (memref<*xf32>) -> () // CHECK: [ALLCLOSE: TRUE] - call @printAllcloseF32(%cast_b, %cast_b_ref) : (memref<*xf32>, memref<*xf32>) -> () - memref.dealloc %a : memref<256x1024xf32> - memref.dealloc %b_ref : memref<1024xf32> + %0 = call @reduce_test(%alloc) : (memref<256x1024xf32>) -> memref<1x1024xf32> + %cast = memref.cast %0 : memref<1x1024xf32> to memref<*xf32> + %cast_2 = memref.cast %alloc_1 : memref<1024xf32> to memref<*xf32> + call @printAllcloseF32(%cast, %cast_2) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %alloc : memref<256x1024xf32> + memref.dealloc %alloc_1 : memref<1024xf32> return } func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} diff --git a/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp b/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp index 99691a327..0816be8d2 100644 --- a/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp +++ b/test/Integration/Dialect/XeTile/xetile-to-func-vc.pp @@ -1,38 +1,42 @@ +// gpu dialect with subgroup level XeTile dialect to +// llvm dialect (for host code) and +// spirv dialect (for device code) lowering pipeline. +// Ready for imex runner starting from GPU dialect. 
+ builtin.module( cse - gpu.module(xetile-init-duplicate - xetile-canonicalization - xetile-blocking - cse - convert-xetile-to-xegpu - cse - imex-xegpu-hoist-transpose - imex-xegpu-apply-vnni-transformation - imex-xegpu-optimize-transpose - cse + gpu.module(xetile-init-duplicate, + xetile-canonicalization, + xetile-blocking, + cse, + convert-xetile-to-xegpu, + cse, + imex-xegpu-hoist-transpose, + imex-xegpu-apply-vnni-transformation, + imex-xegpu-optimize-transpose) + cse + gpu.module(convert-math-to-vc{enable-high-precision-interim-calculation=true}, convert-xegpu-to-vc) cse xegpu-vector-linearize canonicalize cse reconcile-unrealized-casts - bf16-to-gpu - imex-convert-gpu-to-spirv - spirv.module(spirv-lower-abi-attrs - spirv-update-vce) + gpu.module(math-extend-to-supported-types{target-type=f32}) + gpu.module(arith-emulate-unsupported-floats{source-types=bf16 target-type=f32}) + spirv-attach-target{ver=v1.0 caps=Addresses,BFloat16TypeKHR,Float16Buffer,Int64,Int16,Int8,Kernel,Linkage,Vector16,GenericPointer,Groups,Float16,Float64,AtomicFloat32AddEXT,ExpectAssumeKHR,SubgroupDispatch,VectorComputeINTEL,VectorAnyINTEL,Bfloat16ConversionINTEL exts=SPV_EXT_shader_atomic_float_add,SPV_KHR_bfloat16,SPV_KHR_expect_assume,SPV_INTEL_vector_compute,SPV_INTEL_bfloat16_conversion} + imex-convert-to-spirv{use-64bit-index=true} + gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce)) func.func(llvm-request-c-wrappers) - serialize-spirv convert-vector-to-scf - convert-gpu-to-gpux convert-scf-to-cf + func.func(gpu-async-region) expand-strided-metadata + gpu-to-llvm{use-bare-pointers-for-kernels=true} finalize-memref-to-llvm - convert-cf-to-llvm - convert-vector-to-llvm - convert-index-to-llvm - convert-arith-to-llvm - convert-func-to-llvm - convert-math-to-llvm - convert-gpux-to-llvm + convert-to-llvm lower-affine - reconcile-unrealized-casts) + reconcile-unrealized-casts + gpu-module-to-binary) + +// End diff --git a/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp b/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp index e59558926..d2218213c 100644 --- a/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp +++ b/test/Integration/Dialect/XeTile/xetile-wg-to-func-vc.pp @@ -1,40 +1,44 @@ +// gpu dialect with workgroup level XeTile dialect to +// llvm dialect (for host code) and +// spirv dialect (for device code) lowering pipeline. +// Ready for imex runner starting from GPU dialect. 
+ builtin.module( cse - gpu.module(xetile-wg-to-sg - cse - xetile-init-duplicate - xetile-canonicalization - xetile-blockop-fallback - xetile-blocking - cse - convert-xetile-to-xegpu - cse - imex-xegpu-hoist-transpose - imex-xegpu-apply-vnni-transformation - imex-xegpu-optimize-transpose - cse + gpu.module(xetile-wg-to-sg, + cse, + xetile-init-duplicate, + xetile-canonicalization, + xetile-blockop-fallback, + xetile-blocking, + cse, + convert-xetile-to-xegpu, + cse, + imex-xegpu-hoist-transpose, + imex-xegpu-apply-vnni-transformation, + imex-xegpu-optimize-transpose) + cse + gpu.module(convert-math-to-vc{enable-high-precision-interim-calculation=true}, convert-xegpu-to-vc) cse xegpu-vector-linearize canonicalize + cse reconcile-unrealized-casts - bf16-to-gpu - imex-convert-gpu-to-spirv - spirv.module(spirv-lower-abi-attrs - spirv-update-vce) + gpu.module(math-extend-to-supported-types{target-type=f32}) + gpu.module(arith-emulate-unsupported-floats{source-types=bf16 target-type=f32}) + spirv-attach-target{ver=v1.0 caps=Addresses,BFloat16TypeKHR,Float16Buffer,Int64,Int16,Int8,Kernel,Linkage,Vector16,GenericPointer,Groups,Float16,Float64,AtomicFloat32AddEXT,ExpectAssumeKHR,SubgroupDispatch,VectorComputeINTEL,VectorAnyINTEL,Bfloat16ConversionINTEL exts=SPV_EXT_shader_atomic_float_add,SPV_KHR_bfloat16,SPV_KHR_expect_assume,SPV_INTEL_vector_compute,SPV_INTEL_bfloat16_conversion} + imex-convert-to-spirv{use-64bit-index=true} + gpu.module(spirv.module(spirv-lower-abi-attrs, spirv-update-vce)) func.func(llvm-request-c-wrappers) - serialize-spirv convert-vector-to-scf - convert-gpu-to-gpux convert-scf-to-cf + func.func(gpu-async-region) expand-strided-metadata + gpu-to-llvm{use-bare-pointers-for-kernels=true} finalize-memref-to-llvm - convert-cf-to-llvm - convert-vector-to-llvm - convert-index-to-llvm - convert-arith-to-llvm - convert-func-to-llvm - convert-math-to-llvm - convert-gpux-to-llvm + convert-to-llvm + gpu-module-to-binary lower-affine reconcile-unrealized-casts) +// End
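
For reference, the following is a minimal, hypothetical C++ sketch of how the SPIR-V portion of the pipelines above (imex-convert-to-spirv{use-64bit-index=true} followed by the spirv-lower-abi-attrs / spirv-update-vce cleanup) could be assembled programmatically through MLIR's textual pipeline parser. The helper name runReducedPipeline and the standalone-driver setting are assumptions, not part of this patch; the tests themselves only exercise the full pipelines through the imex-runner RUN lines and the .pp files.

// Hypothetical sketch: build a reduced version of the SPIR-V lowering steps
// from the .pp files with MLIR's pipeline parser. The pass names below are
// taken from the pipeline files in this patch and must already be registered
// with the global pass registry for parsing to succeed.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

static mlir::LogicalResult runReducedPipeline(mlir::ModuleOp module) {
  // Anchor the pass manager on builtin.module, matching the .pp files.
  mlir::PassManager pm(module.getContext(),
                       mlir::ModuleOp::getOperationName());
  constexpr llvm::StringLiteral pipeline =
      "imex-convert-to-spirv{use-64bit-index=true},"
      "gpu.module(spirv.module(spirv-lower-abi-attrs,spirv-update-vce))";
  // parsePassPipeline appends the parsed passes onto the module-anchored pm.
  if (mlir::failed(mlir::parsePassPipeline(pipeline, pm, llvm::errs())))
    return mlir::failure();
  return pm.run(module);
}

Reusing the textual pipeline string keeps such a programmatic driver in sync with the .pp files, since both are resolved through the same pass-registry parsing path.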