diff --git a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
index 524f111e271c..243a331ef480 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/DetachElementwiseFromNamedOps.cpp
@@ -69,6 +69,18 @@ struct DetachElementwisePattern
     Location loc = linalgOp.getLoc();
 
+    // Check if the output tensor access is a projected permutation
+    if (!linalgOp.getMatchingIndexingMap(outputOperands.front())
+             .isProjectedPermutation()) {
+      return rewriter.notifyMatchFailure(
+          linalgOp, "Output indexing map must be a projected permutation.");
+    }
+
+    int64_t outputRank = outputType.getRank();
+    SmallVector<utils::IteratorType> iterators(outputRank,
+                                               utils::IteratorType::parallel);
+    SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank));
+
     // Create a zero tensor as the new output tensor operand to the Linalg
     // contraction op.
     SmallVector<OpFoldResult> mixedSizes =
@@ -84,24 +96,6 @@ struct DetachElementwisePattern
     rewriter.modifyOpInPlace(linalgOp,
                              [&]() { linalgOp.setDpsInitOperand(0, fill); });
 
-    auto outputMap = mlir::compressUnusedDims(
-        linalgOp.getMatchingIndexingMap(outputOperands.front()));
-    // Only support identity map for output access for now; this is the case for
-    // all existing contraction/convolution ops.
-    if (!outputMap.isIdentity())
-      return failure();
-    SmallVector<AffineMap> maps(3, outputMap);
-
-    SmallVector<utils::IteratorType> iterators;
-    iterators.reserve(outputMap.getNumResults());
-    for (int i = 0, e = outputMap.getNumResults(); i < e; ++i) {
-      int pos = cast<AffineDimExpr>(outputMap.getResult(i)).getPosition();
-      auto attr = linalgOp.getIteratorTypesArray()[pos];
-      if (!linalg::isParallelIterator(attr))
-        return failure();
-      iterators.push_back(attr);
-    }
-
     // Create a generic op to add back the original output tensor operand.
     rewriter.setInsertionPointAfter(linalgOp);
     auto genericOp = rewriter.create<linalg::GenericOp>(
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/detach_elementwise_from_named_ops.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/detach_elementwise_from_named_ops.mlir
index cec787f43588..870e3518637e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/detach_elementwise_from_named_ops.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/detach_elementwise_from_named_ops.mlir
@@ -101,6 +101,27 @@ util.func public @conv(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32
 
 // -----
 
+util.func public @depthwise_conv(%arg0: tensor<1x96x62x62xf32>, %arg1: tensor<96x7x7xf32>, %arg2: tensor<96xf32>) -> tensor<1x96x56x56xf32> {
+  %0 = tensor.empty() : tensor<1x96x56x56xf32>
+  %broadcasted = linalg.broadcast ins(%arg2 : tensor<96xf32>) outs(%0 : tensor<1x96x56x56xf32>) dimensions = [0, 2, 3]
+  %1 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %arg1 : tensor<1x96x62x62xf32>, tensor<96x7x7xf32>) outs(%broadcasted : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+  util.return %1 : tensor<1x96x56x56xf32>
+}
+
+// CHECK-LABEL: util.func public @depthwise_conv
+// CHECK-SAME:  (%[[INPUT:.+]]: tensor<1x96x62x62xf32>, %[[FILTER:.+]]: tensor<96x7x7xf32>, %[[BIAS:.+]]: tensor<96xf32>)
+// CHECK:         %[[INIT:.+]] = linalg.broadcast
+// CHECK-SAME:      ins(%[[BIAS]] :
+// CHECK:         %[[FILL:.+]] = linalg.fill
+// CHECK:         %[[CONV:.+]] = linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME:      ins(%[[INPUT]], %[[FILTER]]
+// CHECK-SAME:      outs(%[[FILL]]
+// CHECK:         linalg.generic
+// CHECK-SAME:      ins(%[[CONV]], %[[INIT]]
+// CHECK-SAME:      outs(%[[FILL]]
+
+// -----
+
 util.func public @keep_fill(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index