diff --git a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/SpecializeDMACode.cpp b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/SpecializeDMACode.cpp index 2ed9fa9..923015c 100644 --- a/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/SpecializeDMACode.cpp +++ b/codegen/compiler/src/Quidditch/Dialect/Snitch/Transforms/SpecializeDMACode.cpp @@ -2,6 +2,7 @@ #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h" #include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/FunctionInterfaces.h" namespace quidditch::Snitch { @@ -31,6 +32,18 @@ static void removeComputeOps(FunctionOpInterface dmaCode) { dmaCode->walk([&](Operation *operation) { if (isa(operation)) operation->erase(); + if (auto index = dyn_cast(operation)) { + OpBuilder builder(operation); + // Make the DMA core follow the control flow of the first compute core. + // This whole pass runs under the assumption that any operation that is + // run on either the DMA core or compute cores are in non-divergent + // control flow. Making the DMA core follow any compute cores control + // flow is therefore safe to do. + // This is mainly required for barriers within a `scf.forall`. + operation->replaceAllUsesWith( + builder.create(operation->getLoc(), 0)); + operation->erase(); + } }); } diff --git a/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir b/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir index 39e9e46..1631720 100644 --- a/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir +++ b/codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir @@ -41,22 +41,26 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK: scf.if - %r = scf.if %cond -> !quidditch_snitch.dma_token { + %r:2 = scf.if %cond -> (!quidditch_snitch.dma_token, index) { // CHECK-NEXT: quidditch_snitch.microkernel_fence // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: %[[C:.*]] = quidditch_snitch.completed_token - // CHECK-NEXT: yield %[[C]] %t3 = quidditch_snitch.start_dma_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32> - scf.yield %t3 : !quidditch_snitch.dma_token + // CHECK-NEXT: %[[I:.*]] = quidditch_snitch.compute_core_index + %i = quidditch_snitch.compute_core_index + // CHECK-NEXT: yield %[[C]], %[[I]] + scf.yield %t3, %i : !quidditch_snitch.dma_token, index } else { // CHECK-NEXT: else // CHECK-NEXT: %[[C:.*]] = quidditch_snitch.completed_token - // CHECK-NEXT: yield %[[C]] %c = quidditch_snitch.completed_token - scf.yield %c : !quidditch_snitch.dma_token + // CHECK-NEXT: %[[I:.*]] = arith.constant + %i = arith.constant 1 : index + // CHECK-NEXT: yield %[[C]], %[[I]] + scf.yield %c, %i : !quidditch_snitch.dma_token, index } // CHECK: quidditch_snitch.barrier - quidditch_snitch.wait_for_dma_transfers %r : !quidditch_snitch.dma_token + quidditch_snitch.wait_for_dma_transfers %r#0 : !quidditch_snitch.dma_token // CHECK-NEXT: return return } @@ -83,9 +87,12 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) { // CHECK-NEXT: scf.if // CHECK-NEXT: quidditch_snitch.barrier // CHECK-NEXT: quidditch_snitch.start_dma_transfer -// CHECK-NEXT: yield +// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 +// CHECK-NEXT: yield %{{.*}}, %[[ZERO]] : // CHECK-NEXT: else // CHECK-NEXT: completed_token +// CHECK-NEXT: arith.constant +// CHECK-NEXT: yield // CHECK: quidditch_snitch.wait_for_dma_transfers // CHECK-NEXT: quidditch_snitch.barrier