[GlobalOptimization] Fix a silent bug in DetachElementwiseFromNamedOps pass (iree-org#19356)

This moves the match-failure checks before any modification of the linalg op, and loosens
the output tensor access check from requiring an identity map to requiring a projected
permutation.

### Context:

Specific depthwise convolution ops were encountering numeric failures.
See <iree-org#18600> and
<iree-org#19339>. I noticed that the bias
was not affecting the output values, and tracked down where the bias was
getting deleted.

The issue is that the pass `DetachElementwiseFromNamedOps` was
modifying the `depthwise_conv` op to use a zero-fill *before* checking
for some match failures. This resulted in a partial application of the
pattern in which the original bias was never added back to the result of
the modified linalg op.
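
In rewrite-pattern terms, the pattern broke the usual contract that no IR is mutated until the match is known to succeed. Here is a minimal sketch of the intended ordering, assuming an MLIR build environment; the pattern and variable names are illustrative rather than the actual IREE code:

```cpp
// Minimal sketch of the "check everything, then mutate" ordering an MLIR
// rewrite pattern must follow. Pattern/variable names are illustrative and
// assume an MLIR build environment; this is not the actual IREE pattern.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

namespace {
struct CheckBeforeMutatePattern
    : public mlir::OpInterfaceRewritePattern<mlir::linalg::LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::linalg::LinalgOp linalgOp,
                  mlir::PatternRewriter &rewriter) const override {
    // 1) All match-failure checks come first. Bailing out here leaves the op
    //    exactly as it was.
    mlir::OpOperand *init = linalgOp.getDpsInitOperand(0);
    if (!linalgOp.getMatchingIndexingMap(init).isProjectedPermutation())
      return rewriter.notifyMatchFailure(
          linalgOp, "output access is not a projected permutation");

    // 2) Only after every check has passed is the op modified. Returning
    //    failure after this point is the partial-application bug described
    //    above: the zero-fill applied but the bias never re-added.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      // ... e.g. linalgOp.setDpsInitOperand(0, zeroFilledTensor);
    });
    return mlir::success();
  }
};
} // namespace
```

The fix in this PR applies the same ordering: the projected-permutation check and the iterator/map construction are hoisted above the zero-fill rewrite.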

The depthwise conv ops were specifically failing the identity-map check
on the output tensor access.

For example:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<1x96x56x56xf32>, %arg1: tensor<96x1x7x7xf32>, %arg2: tensor<96xf32>) -> tensor<1x96x56x56xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<1x96x56x56xf32> to tensor<1x96x62x62xf32>
    %0 = tensor.empty() : tensor<1x96x56x56xf32>
    %broadcasted = linalg.broadcast ins(%arg2 : tensor<96xf32>) outs(%0 : tensor<1x96x56x56xf32>) dimensions = [0, 2, 3] 
    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2], [3]] : tensor<96x1x7x7xf32> into tensor<96x7x7xf32>
    %1 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %collapsed : tensor<1x96x62x62xf32>, tensor<96x7x7xf32>) outs(%broadcasted : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
    return %1 : tensor<1x96x56x56xf32>
  }
}
```

generalizes (e.g. via `mlir-opt --linalg-generalize-named-ops`) to

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<1x96x56x56xf32>, %arg1: tensor<96x1x7x7xf32>, %arg2: tensor<96xf32>) -> tensor<1x96x56x56xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %padded = tensor.pad %arg0 low[0, 0, 3, 3] high[0, 0, 3, 3] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<1x96x56x56xf32> to tensor<1x96x62x62xf32>
    %0 = tensor.empty() : tensor<1x96x56x56xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<96xf32>) outs(%0 : tensor<1x96x56x56xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x96x56x56xf32>
    %collapsed = tensor.collapse_shape %arg1 [[0, 1], [2], [3]] : tensor<96x1x7x7xf32> into tensor<96x7x7xf32>
    %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%padded, %collapsed : tensor<1x96x62x62xf32>, tensor<96x7x7xf32>) outs(%1 : tensor<1x96x56x56xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %3 = arith.mulf %in, %in_0 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<1x96x56x56xf32>
    return %2 : tensor<1x96x56x56xf32>
  }
}
```

For some reason, the iteration dim for the output channel (`d3`) comes after
the spatial dims (`d1` and `d2`) for this particular op, so the output access
map `#map4 = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)` is a projected
permutation of the output dims, but not the identity.
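
The sketch below (standalone, assuming an MLIR development setup; not part of the patch) reproduces the old and new checks on that map:

```cpp
// Standalone sketch (not part of the patch; assumes an MLIR development
// setup): rebuild the output access map #map4 from the generalized IR above
// and run the old and new checks on it.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  auto d = [&](unsigned i) { return mlir::getAffineDimExpr(i, &ctx); };

  // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
  mlir::AffineMap outputMap = mlir::AffineMap::get(
      /*dimCount=*/6, /*symbolCount=*/0, {d(0), d(3), d(1), d(2)}, &ctx);

  // New check: passes, so the pattern may proceed.
  llvm::outs() << "isProjectedPermutation: "
               << outputMap.isProjectedPermutation() << "\n"; // prints 1

  // Old check: drop the unused reduction dims (d4, d5) and require identity.
  // The compressed map is (d0, d1, d2, d3) -> (d0, d3, d1, d2), a permutation
  // of the output dims but not the identity, so this check fails.
  mlir::AffineMap compressed = mlir::compressUnusedDims(outputMap);
  llvm::outs() << "isIdentity: " << compressed.isIdentity() << "\n"; // prints 0
}
```

`compressUnusedDims(...).isIdentity()` is false, which is why the old code returned `failure()` after the zero-fill had already been applied, while `isProjectedPermutation()` is true, so the new up-front check passes.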

---------

Signed-off-by: zjgarvey <[email protected]>
zjgarvey authored Dec 5, 2024
1 parent 543fb31 commit d48071d
Showing 2 changed files with 33 additions and 18 deletions.
Changes to the `DetachElementwiseFromNamedOps` pass:

```diff
@@ -69,6 +69,18 @@ struct DetachElementwisePattern
 
     Location loc = linalgOp.getLoc();
 
+    // Check if the output tensor access is a projected permutation
+    if (!linalgOp.getMatchingIndexingMap(outputOperands.front())
+             .isProjectedPermutation()) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Output indexing map must be a projected permutation.");
+    }
+
+    int64_t outputRank = outputType.getRank();
+    SmallVector<utils::IteratorType> iterators(outputRank,
+                                               utils::IteratorType::parallel);
+    SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank));
+
     // Create a zero tensor as the new output tensor operand to the Linalg
     // contraction op.
     SmallVector<OpFoldResult> mixedSizes =
@@ -84,24 +96,6 @@
     rewriter.modifyOpInPlace(linalgOp,
                              [&]() { linalgOp.setDpsInitOperand(0, fill); });
 
-    auto outputMap = mlir::compressUnusedDims(
-        linalgOp.getMatchingIndexingMap(outputOperands.front()));
-    // Only support identity map for output access for now; this is the case for
-    // all existing contraction/convolution ops.
-    if (!outputMap.isIdentity())
-      return failure();
-    SmallVector<AffineMap> maps(3, outputMap);
-
-    SmallVector<utils::IteratorType> iterators;
-    iterators.reserve(outputMap.getNumResults());
-    for (int i = 0, e = outputMap.getNumResults(); i < e; ++i) {
-      int pos = cast<AffineDimExpr>(outputMap.getResult(i)).getPosition();
-      auto attr = linalgOp.getIteratorTypesArray()[pos];
-      if (!linalg::isParallelIterator(attr))
-        return failure();
-      iterators.push_back(attr);
-    }
-
     // Create a generic op to add back the original output tensor operand.
     rewriter.setInsertionPointAfter(linalgOp);
    auto genericOp = rewriter.create<linalg::GenericOp>(
```
Changes to the pass's lit test:

```diff
@@ -101,6 +101,27 @@ util.func public @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32
 
 // -----
 
+util.func public @depthwise_conv(%arg0: tensor<1x96x62x62xf32>, %arg1: tensor<96x7x7xf32>, %arg2: tensor<96xf32>) -> tensor<1x96x56x56xf32> {
+  %0 = tensor.empty() : tensor<1x96x56x56xf32>
+  %broadcasted = linalg.broadcast ins(%arg2 : tensor<96xf32>) outs(%0 : tensor<1x96x56x56xf32>) dimensions = [0, 2, 3]
+  %1 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %arg1 : tensor<1x96x62x62xf32>, tensor<96x7x7xf32>) outs(%broadcasted : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+  util.return %1 : tensor<1x96x56x56xf32>
+}
+
+// CHECK-LABEL: util.func public @depthwise_conv
+// CHECK-SAME:    (%[[INPUT:.+]]: tensor<1x96x62x62xf32>, %[[FILTER:.+]]: tensor<96x7x7xf32>, %[[BIAS:.+]]: tensor<96xf32>)
+// CHECK:         %[[INIT:.+]] = linalg.broadcast
+// CHECK-SAME:      ins(%[[BIAS]] :
+// CHECK:         %[[FILL:.+]] = linalg.fill
+// CHECK:         %[[CONV:.+]] = linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME:      ins(%[[INPUT]], %[[FILTER]]
+// CHECK-SAME:      outs(%[[FILL]]
+// CHECK:         linalg.generic
+// CHECK-SAME:      ins(%[[CONV]], %[[INIT]]
+// CHECK-SAME:      outs(%[[FILL]]
+
+// -----
+
 util.func public @keep_fill(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
```
