Skip to content

Commit

Permalink
[mlir][linalg] Fix crashes in parser on linalg ops without operands
Browse files Browse the repository at this point in the history
`parseDstStyleOp` parses both `ins()` and `outs()` optionally. The parsers
for `linalg.transpose`, `linalg.broadcast` and `linalg.map` however
assume that at least one operand is present in the state, leading to crashes
otherwise.

This patch adds checks to the parsers which stop them from crashing if
no operands were parsed. After the Ops are parsed successfuly, the verifier
takes it from there.

Fix llvm#97857
  • Loading branch information
ubfx committed Jul 7, 2024
1 parent f4c7811 commit bd56160
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
9 changes: 6 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1356,8 +1356,10 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();

if (payloadOpName.has_value()) {
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
ArrayRef(result.operands).drop_back());
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
payloadOpAttrs,
ArrayRef(result.operands).drop_back());
} else {
SmallVector<OpAsmParser::Argument> regionArgs;
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
Expand Down Expand Up @@ -1739,7 +1741,8 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc,
ValueRange outputs) {
buildGenericRegion(builder, loc, region, inputs, outputs,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
if (!args.empty())
b.create<linalg::YieldOp>(loc, args[0]);
});
}

Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Dialect/Linalg/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,18 @@ func.func @map_input_output_shape_mismatch(

// -----

func.func @map_no_operands(
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-> tensor<64xf32> {
// This must not crash the parser.
linalg.map { arith.addf }
// expected-error @+1 {{cannot name an operation with no results}}
%add = linalg.map { arith.addf }
func.return %add : tensor<64xf32>
}

// -----

func.func @reduce_input_vs_init_dimension_mismatch(
%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
Expand Down Expand Up @@ -676,6 +688,16 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,

// -----

func.func @transpose_no_operands() -> tensor<32x64x16xf32> {
// This must not crash the parser.
linalg.transpose permutation = [1, 0, 2]
// expected-error @+1 {{cannot name an operation with no results}}
%transpose = linalg.transpose permutation = [1, 0, 2]
func.return %transpose : tensor<32x64x16xf32>
}

// -----

func.func @broadcast_input_dims_rank_mismatch(
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
Expand Down Expand Up @@ -725,6 +747,15 @@ func.func @broadcast_size_1_extension_not_supported(
dimensions = [1]
func.return %bcast : tensor<4x?x16xf32>
}
// -----

func.func @broadcast_no_operands()
-> tensor<4x?x16xf32> {
linalg.broadcast dimensions = [1]
// expected-error @+1 {{cannot name an operation with no results}}
%broadcast = linalg.broadcast dimensions = [1]
func.return %broadcast : tensor<32x64x16xf32>
}

// -----

Expand Down

0 comments on commit bd56160

Please sign in to comment.