Skip to content

Commit

Permalink
[SpecializeDMACode] Properly lower compute_core_index (#109)
Browse files Browse the repository at this point in the history
Barriers may be inserted in a loop resulting from the lowering of a
`scf.forall`. The lowering of `compute_core_index` unfortunately returns
`num_compute_cores + 1` which is outside the specified range of the
operation and leads to such loops and their barriers being skipped by
the DMA core. As the pass already assumes non-divergent control flow, we
can fix this by specializing `compute_core_index` to any integer in the
defined range of the op.
  • Loading branch information
zero9178 authored Aug 15, 2024
1 parent 60fa26e commit 03aaad4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchDialect.h"
#include "Quidditch/Dialect/Snitch/IR/QuidditchSnitchOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

namespace quidditch::Snitch {
Expand Down Expand Up @@ -31,6 +32,18 @@ static void removeComputeOps(FunctionOpInterface dmaCode) {
dmaCode->walk([&](Operation *operation) {
  // Operations that only make sense on compute cores are dropped from the
  // DMA core's copy of the function.
  if (isa<MemRefMicrokernelOp, MicrokernelFenceOp>(operation)) {
    operation->erase();
    // `operation` is destroyed at this point and must not be inspected any
    // further, so leave the callback immediately.
    return;
  }

  if (!isa<ComputeCoreIndexOp>(operation))
    return;

  OpBuilder builder(operation);
  // Make the DMA core follow the control flow of the first compute core.
  // This whole pass runs under the assumption that any operation that is
  // run on either the DMA core or the compute cores is in non-divergent
  // control flow. Making the DMA core follow any compute core's control
  // flow is therefore safe to do.
  // This is mainly required for barriers within a `scf.forall`.
  operation->replaceAllUsesWith(
      builder.create<arith::ConstantIndexOp>(operation->getLoc(), 0));
  operation->erase();
});
}

Expand Down
21 changes: 14 additions & 7 deletions codegen/tests/Dialect/Snitch/Transforms/specialize-dma-code.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,26 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {


// CHECK: scf.if
%r = scf.if %cond -> !quidditch_snitch.dma_token {
%r:2 = scf.if %cond -> (!quidditch_snitch.dma_token, index) {
// CHECK-NEXT: quidditch_snitch.microkernel_fence
// CHECK-NEXT: quidditch_snitch.barrier
// CHECK-NEXT: %[[C:.*]] = quidditch_snitch.completed_token
// CHECK-NEXT: yield %[[C]]
%t3 = quidditch_snitch.start_dma_transfer from %b_l1 : memref<32xf32> to %b : memref<32xf32>
scf.yield %t3 : !quidditch_snitch.dma_token
// CHECK-NEXT: %[[I:.*]] = quidditch_snitch.compute_core_index
%i = quidditch_snitch.compute_core_index
// CHECK-NEXT: yield %[[C]], %[[I]]
scf.yield %t3, %i : !quidditch_snitch.dma_token, index
} else {
// CHECK-NEXT: else
// CHECK-NEXT: %[[C:.*]] = quidditch_snitch.completed_token
// CHECK-NEXT: yield %[[C]]
%c = quidditch_snitch.completed_token
scf.yield %c : !quidditch_snitch.dma_token
// CHECK-NEXT: %[[I:.*]] = arith.constant
%i = arith.constant 1 : index
// CHECK-NEXT: yield %[[C]], %[[I]]
scf.yield %c, %i : !quidditch_snitch.dma_token, index
}
// CHECK: quidditch_snitch.barrier
quidditch_snitch.wait_for_dma_transfers %r : !quidditch_snitch.dma_token
quidditch_snitch.wait_for_dma_transfers %r#0 : !quidditch_snitch.dma_token
// CHECK-NEXT: return
return
}
Expand All @@ -83,9 +87,12 @@ func.func @test(%a : memref<32xf32>, %b : memref<32xf32>, %cond : i1) {
// CHECK-NEXT: scf.if
// CHECK-NEXT: quidditch_snitch.barrier
// CHECK-NEXT: quidditch_snitch.start_dma_transfer
// CHECK-NEXT: yield
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0
// CHECK-NEXT: yield %{{.*}}, %[[ZERO]] :
// CHECK-NEXT: else
// CHECK-NEXT: completed_token
// CHECK-NEXT: arith.constant
// CHECK-NEXT: yield
// CHECK: quidditch_snitch.wait_for_dma_transfers
// CHECK-NEXT: quidditch_snitch.barrier

Expand Down

0 comments on commit 03aaad4

Please sign in to comment.