diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td
index e35acb0..3224d91 100644
--- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td
+++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/Passes.td
@@ -10,6 +10,17 @@ def FormMicrokernelsPass
   ];
 }
 
+def PromotePadsToL1Pass : Pass<"quidditch-promote-pads-to-l1"> {
+  let description = [{
+    Converts supported `tensor.pad` operations to `start_tensor_copy` and
+    `wait_for_tensor_copy` pairs.
+  }];
+
+  let dependentDialects = [
+    "quidditch::Snitch::QuidditchSnitchDialect",
+  ];
+}
+
 def PromoteOperandsToL1Pass : Pass<"quidditch-promote-operands-to-l1"> {
   let description = [{
     TODO:
diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp
index a39369b..2a5129b 100644
--- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp
+++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/PromoteToL1.cpp
@@ -4,11 +4,15 @@
 #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h"
 #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 namespace quidditch::Snitch {
 #define GEN_PASS_DEF_PROMOTEOPERANDSTOL1PASS
 #define GEN_PASS_DEF_PROMOTEALLOCSTOL1PASS
+#define GEN_PASS_DEF_PROMOTEPADSTOL1PASS
 #include "Quidditch/Dialect/Snitch/Transforms/Passes.h.inc"
 } // namespace quidditch::Snitch
 
@@ -32,6 +36,16 @@ class PromoteAllocsToL1
 protected:
   void runOnOperation() override;
 };
+
+class PromotePadsToL1
+    : public quidditch::Snitch::impl::PromotePadsToL1PassBase<PromotePadsToL1> {
+public:
+  using Base::Base;
+
+protected:
+  void runOnOperation() override;
+};
+
 } // namespace
 
 using namespace mlir;
@@ -76,3 +90,32 @@ void PromoteAllocsToL1::runOnOperation() {
     tensorOp.erase();
   });
 }
+
+void PromotePadsToL1::runOnOperation() {
+  getOperation()->walk([&](tensor::PadOp padOp) {
+    // 'start_tensor_copy' does not yet support lower padding.
+    if (!padOp.hasZeroLowPad())
+      return;
+
+    Value constant = padOp.getConstantPaddingValue();
+    if (!constant)
+      return;
+
+    // 'start_tensor_copy' only supports zero-padding right now.
+    // Poison (undef) can also be lowered to perform zero-padding.
+    if (!matchPattern(constant, m_Zero()) &&
+        !matchPattern(constant, m_PosZeroFloat()) &&
+        !matchPattern(constant, m_Constant<ub::PoisonAttr>(nullptr)))
+      return;
+
+    OpBuilder builder(padOp);
+    auto copyOp = builder.create<StartTensorCopyOp>(
+        padOp.getLoc(), padOp.getType(), builder.getType<DMATokenType>(),
+        padOp.getSource(), padOp.getHigh(), padOp.getStaticHighAttr());
+    auto waitOp = builder.create<WaitForTensorCopyOp>(
+        padOp.getLoc(), copyOp.getResult(), copyOp.getToken(),
+        /*copy=*/padOp.getSource());
+    padOp.replaceAllUsesWith(waitOp.getResult());
+    padOp.erase();
+  });
+}
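
(Annotation, not part of the patch: the walk above amounts to the following IR-level rewrite. The sketch is illustrative only; the printed form of `start_tensor_copy`/`wait_for_tensor_copy` is inferred from the new test at the end of this patch, the SSA names are made up, and result-type annotations are omitted.)

  // Before: a tensor.pad that only zero-pads at the high end.
  %0 = tensor.pad %a low[0, 0] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : f32
  } : tensor<32x32xf32> to tensor<33x33xf32>

  // After: the DMA copy into L1 zero-fills the high padding, and the
  // explicit wait marks the point where the transfer must have completed.
  %copy, %token = quidditch_snitch.start_tensor_copy %a pad with zero to [1, 1]
  %0 = quidditch_snitch.wait_for_tensor_copy of %a to %copy using %token
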
diff --git a/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp b/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp
index 24a1188..3fd65b9 100644
--- a/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp
+++ b/codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp
@@ -59,6 +59,14 @@ static LogicalResult setRootConfig(FunctionOpInterface funcOp,
   SmallVector<int64_t> l1Tiles(3, 0);
   bool dualBuffer = false;
+  if (funcOp.getName() ==
+      "main$async_dispatch_9_matmul_transpose_b_1x161x600_f64") {
+    workgroupTiles[2] = 100;
+
+    l1Tiles[0] = 0;
+    l1Tiles[1] = 56;
+    dualBuffer = true;
+  }
   if (funcOp.getName() ==
       "main$async_dispatch_0_matmul_transpose_b_1x400x161_f64") {
     l1Tiles[1] = 40;
diff --git a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
index 69fa9aa..d4e9d40 100644
--- a/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
+++ b/codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
@@ -180,6 +180,8 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
         .addPass(quidditch::createRemoveTrivialLoopsPass)
         .addPass(createCanonicalizerPass)
         .addPass(createCSEPass)
+        .addPass(createFuseTensorPadWithConsumerPass)
+        .addPass(createConcretizePadResultShapePass)
         .addPass([] {
           return quidditch::createTensorTilePass(
               {quidditch::TilingLevel::Reduction});
@@ -191,6 +193,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
         })
         .addPass(createFuseTensorPadWithConsumerPass)
         .addPass(createConcretizePadResultShapePass)
+        .addPass(quidditch::Snitch::createPromotePadsToL1Pass)
         .addPass(quidditch::Snitch::createPromoteOperandsToL1Pass)
         .addPass(createCanonicalizerPass)
         .addPass(createCSEPass)
diff --git a/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir b/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir
new file mode 100644
index 0000000..de95548
--- /dev/null
+++ b/codegen/tests/Dialect/Snitch/Transforms/promote-pads-to-l1.mlir
@@ -0,0 +1,35 @@
+// RUN: quidditch-opt %s -p "builtin.module(func.func(quidditch-promote-pads-to-l1))" --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @test_zero_f32(
+// CHECK-SAME: %[[A:[[:alnum:]]+]]: tensor<32x32xf32>
+func.func @test_zero_f32(%a : tensor<32x32xf32>) -> tensor<33x33xf32> {
+  %c = arith.constant 0.0 : f32
+  // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[A]]
+  // CHECK-SAME: pad with zero to [1, 1]
+  // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]]
+  // CHECK-SAME: to %[[R]]
+  // CHECK-SAME: using %[[T]]
+  %0 = tensor.pad %a low[0, 0] high[1, 1] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %c : f32
+  } : tensor<32x32xf32> to tensor<33x33xf32>
+  // CHECK: return %[[R2]]
+  return %0 : tensor<33x33xf32>
+}
+
+// CHECK-LABEL: @test_poison(
+// CHECK-SAME: %[[A:[[:alnum:]]+]]: tensor<32x32xf32>
+func.func @test_poison(%a : tensor<32x32xf32>) -> tensor<33x33xf32> {
+  %c = ub.poison : f32
+  // CHECK: %[[R:.*]], %[[T:.*]] = quidditch_snitch.start_tensor_copy %[[A]]
+  // CHECK-SAME: pad with zero to [1, 1]
+  // CHECK: %[[R2:.*]] = quidditch_snitch.wait_for_tensor_copy of %[[A]]
+  // CHECK-SAME: to %[[R]]
+  // CHECK-SAME: using %[[T]]
+  %0 = tensor.pad %a low[0, 0] high[1, 1] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %c : f32
+  } : tensor<32x32xf32> to tensor<33x33xf32>
+  // CHECK: return %[[R2]]
+  return %0 : tensor<33x33xf32>
+}
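
(Reviewer note, not part of the patch: both tests cover the rewrite path. A negative case along the following lines, with a hypothetical function name and a non-zero pad value, would also pin down the pass's early-return for padding values it does not support:)

  // CHECK-LABEL: @test_nonzero(
  func.func @test_nonzero(%a : tensor<32x32xf32>) -> tensor<33x33xf32> {
    %c = arith.constant 1.0 : f32
    // CHECK-NOT: quidditch_snitch.start_tensor_copy
    // CHECK: tensor.pad
    %0 = tensor.pad %a low[0, 0] high[1, 1] {
    ^bb0(%arg0: index, %arg1: index):
      tensor.yield %c : f32
    } : tensor<32x32xf32> to tensor<33x33xf32>
    return %0 : tensor<33x33xf32>
  }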