diff --git a/mhlo/IR/hlo_ops.cc b/mhlo/IR/hlo_ops.cc index 4d01cd520..50aa7b333 100644 --- a/mhlo/IR/hlo_ops.cc +++ b/mhlo/IR/hlo_ops.cc @@ -5442,46 +5442,6 @@ LogicalResult TopKOp::inferReturnTypeComponents( inferredReturnShapes); } -bool isMhloCompareOfBodyArgumentsGtOrLt(Block& body) { - auto terminator = dyn_cast(body.getTerminator()); - if (!terminator || terminator->getNumOperands() != 1) return false; - - auto compare = terminator.getOperand(0).getDefiningOp(); - if (!compare) return false; - auto direction = compare.getComparisonDirection(); - if (direction != ComparisonDirection::GT && - direction != ComparisonDirection::LT) - return false; - - if (body.getNumArguments() != 2) return false; - auto arg0 = matchers::m_Val(body.getArgument(0)); - auto arg1 = matchers::m_Val(body.getArgument(1)); - return matchPattern(compare.getResult(), m_Op(arg0, arg1)) || - matchPattern(compare.getResult(), m_Op(arg1, arg0)); -} - -LogicalResult TopKOp::verify() { - Builder builder(getContext()); - auto operandType = getOperand().getType(); - Block& body = getBody().front(); - - auto expectedBodyArgType = - RankedTensorType::get({}, operandType.getElementType()); - auto expectedBodyType = - builder.getFunctionType({expectedBodyArgType, expectedBodyArgType}, - {RankedTensorType::get({}, builder.getI1Type())}); - auto actualBodyType = builder.getFunctionType( - body.getArgumentTypes(), body.getTerminator()->getOperandTypes()); - if (expectedBodyType != actualBodyType) - return emitOpError() << "unsupported body: expected: " << expectedBodyType - << ", got " << actualBodyType; - if (!isMhloCompareOfBodyArgumentsGtOrLt(body)) - return emitOpError() << "unsupported body: expected mhlo.compare of " - << "body arguments with GT or LT comparison_direction"; - - return success(); -} - //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mhlo/IR/hlo_ops.td b/mhlo/IR/hlo_ops.td index d480e6794..dea1b6a1e 100644 --- a/mhlo/IR/hlo_ops.td +++ b/mhlo/IR/hlo_ops.td @@ -2976,33 +2976,29 @@ def MHLO_TopKOp : MHLO_Op<"topk", [RecursiveMemoryEffects, InferTensorType]> { let summary = "TopK operation"; let description = [{ Returns top `k` values and their indices, along the last - dimension of the operand using the given `comparator` (for usual topk - behavior, it should be strict-greater-than operation). + dimension of the operand if `largest=true` or the bottom `k` values if + `largest=false`. See: https://www.tensorflow.org/xla/operation_semantics#top-k Example: ```mlir - %values, %indices = mhlo.topk(%operand, k=5) { - ^bb0(%arg0: tensor, %arg1: tensor): - %predicate = mhlo.compare GT, %arg0, %arg1 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<100xf32> -> (tensor<5xf32>, tensor<5xi32>) + %values, %indices = mhlo.topk(%operand, k=5, largest=true) + : tensor<100xf32> -> (tensor<5xf32>, tensor<5xi32>) ``` }]; let arguments = (ins MHLO_Tensor:$operand, - I64Attr:$k + I64Attr:$k, + DefaultValuedOptionalAttr:$largest ); - let regions = (region SizedRegion<1>:$body); let results = (outs MHLO_Tensor:$values, MHLO_Tensor:$indices); - let hasVerifier = 1; let assemblyFormat = [{ - `(`$operand `,` `k` `=` $k`)` $body attr-dict `:` + `(`$operand `,` `k` `=` $k `,` `largest` `=` $largest `)` attr-dict `:` type($operand) `->` `(`type($values)`,` type($indices)`)` }]; let hasCustomHLOConverter = 1; diff --git a/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 25093f046..0d0699418 100644 --- a/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -2059,11 +2059,7 @@ func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> ten func.func @op_topk(%arg0 : tensor<16xf32>) { // expected-error@+1 {{failed to legalize operation 'mhlo.topk' that was explicitly marked illegal}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return } diff --git a/tests/Dialect/mhlo/ops.mlir b/tests/Dialect/mhlo/ops.mlir index 95583482c..27241c1dc 100644 --- a/tests/Dialect/mhlo/ops.mlir +++ b/tests/Dialect/mhlo/ops.mlir @@ -6518,55 +6518,35 @@ func.func @f8e5m2(%arg0: tensor) -> tensor { // ----- func.func @top_k_1d(%arg0 : tensor<16xf32>) { - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) return } // ----- func.func @top_k_nd(%arg0 : tensor<16x16xf32>) { - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=false) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) return } // ----- func.func @top_k_unbounded(%arg0 : tensor) { - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor -> (tensor, tensor) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor -> (tensor, tensor) return } // ----- func.func @top_k_bounded(%arg0 : tensor>) { - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor> -> (tensor<16x?x8xf32, #mhlo.type_extensions>, tensor<16x?x8xi32, #mhlo.type_extensions>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor> -> (tensor<16x?x8xf32, #mhlo.type_extensions>, tensor<16x?x8xi32, #mhlo.type_extensions>) return } // ----- func.func @top_k_unranked(%arg0 : tensor<*xf32>) { - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<*xf32> -> (tensor<*xf32>, tensor<*xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<*xf32> -> (tensor<*xf32>, tensor<*xi32>) return } @@ -6575,11 +6555,7 @@ func.func @top_k_unranked(%arg0 : tensor<*xf32>) { func.func @topk_rank_at_least_one(%arg0 : tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{operand's rank must be at least 1}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor -> (tensor<8xf32>, tensor<8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor -> (tensor<8xf32>, tensor<8xi32>) return } @@ -6588,69 +6564,6 @@ func.func @topk_rank_at_least_one(%arg0 : tensor) { func.func @topk_last_dimension_at_least_k(%arg0 : tensor<4xf32>) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{operand's last dimension must be at least 8}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<4xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - -func.func @topk_body_must_have_two_arguments(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{unsupported body: expected: '(tensor, tensor) -> tensor', got '(tensor) -> tensor'}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor): - %predicate = mhlo.compare GT, %arg1, %arg1 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - -func.func @topk_body_must_have_one_result(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{unsupported body: expected: '(tensor, tensor) -> tensor', got '(tensor, tensor) -> (tensor, tensor)'}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate, %predicate : tensor, tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - -func.func @topk_body_arguments_must_have_operand_element_type(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{unsupported body: expected: '(tensor, tensor) -> tensor', got '(tensor, tensor) -> tensor'}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare GT, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - -func.func @topk_body_results_must_have_i1_element_type(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{unsupported body: expected: '(tensor, tensor) -> tensor', got '(tensor, tensor) -> tensor'}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - mhlo.return %arg1 : tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) - return -} - -// ----- - -func.func @topk_body_must_consist_of_compare_gt_or_compare_lt(%arg0 : tensor<16xf32>) { - // expected-error@+1 {{unsupported body: expected mhlo.compare of body arguments with GT or LT comparison_direction}} - %0:2 = mhlo.topk(%arg0, k=8) { - ^bb0(%arg1: tensor, %arg2: tensor): - %predicate = mhlo.compare EQ, %arg1, %arg2 : (tensor, tensor) -> tensor - mhlo.return %predicate : tensor - } : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>) + %0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<4xf32> -> (tensor<8xf32>, tensor<8xi32>) return }