Skip to content

Commit

Permalink
Replace the HLO and MHLO TopK comparator with a simpler boolean arg…
Browse files Browse the repository at this point in the history
…ument `largest`.

`comparator` only supported GT and LT, so this new version is equivalent and more type safe.

PiperOrigin-RevId: 584545672
  • Loading branch information
dimitar-asenov authored and TensorFlow MLIR Team committed Nov 22, 2023
1 parent b0249a8 commit 50ff2ab
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 150 deletions.
40 changes: 0 additions & 40 deletions mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5442,46 +5442,6 @@ LogicalResult TopKOp::inferReturnTypeComponents(
inferredReturnShapes);
}

bool isMhloCompareOfBodyArgumentsGtOrLt(Block& body) {
auto terminator = dyn_cast<ReturnOp>(body.getTerminator());
if (!terminator || terminator->getNumOperands() != 1) return false;

auto compare = terminator.getOperand(0).getDefiningOp<CompareOp>();
if (!compare) return false;
auto direction = compare.getComparisonDirection();
if (direction != ComparisonDirection::GT &&
direction != ComparisonDirection::LT)
return false;

if (body.getNumArguments() != 2) return false;
auto arg0 = matchers::m_Val(body.getArgument(0));
auto arg1 = matchers::m_Val(body.getArgument(1));
return matchPattern(compare.getResult(), m_Op<CompareOp>(arg0, arg1)) ||
matchPattern(compare.getResult(), m_Op<CompareOp>(arg1, arg0));
}

LogicalResult TopKOp::verify() {
Builder builder(getContext());
auto operandType = getOperand().getType();
Block& body = getBody().front();

auto expectedBodyArgType =
RankedTensorType::get({}, operandType.getElementType());
auto expectedBodyType =
builder.getFunctionType({expectedBodyArgType, expectedBodyArgType},
{RankedTensorType::get({}, builder.getI1Type())});
auto actualBodyType = builder.getFunctionType(
body.getArgumentTypes(), body.getTerminator()->getOperandTypes());
if (expectedBodyType != actualBodyType)
return emitOpError() << "unsupported body: expected: " << expectedBodyType
<< ", got " << actualBodyType;
if (!isMhloCompareOfBodyArgumentsGtOrLt(body))
return emitOpError() << "unsupported body: expected mhlo.compare of "
<< "body arguments with GT or LT comparison_direction";

return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 7 additions & 11 deletions mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -2976,33 +2976,29 @@ def MHLO_TopKOp : MHLO_Op<"topk", [RecursiveMemoryEffects, InferTensorType]> {
let summary = "TopK operation";
let description = [{
Returns top `k` values and their indices, along the last
dimension of the operand using the given `comparator` (for usual topk
behavior, it should be strict-greater-than operation).
dimension of the operand if `largest=true` or the bottom `k` values if
`largest=false`.

See:
https://www.tensorflow.org/xla/operation_semantics#top-k

Example:
```mlir
%values, %indices = mhlo.topk(%operand, k=5) {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%predicate = mhlo.compare GT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<100xf32> -> (tensor<5xf32>, tensor<5xi32>)
%values, %indices = mhlo.topk(%operand, k=5, largest=true)
: tensor<100xf32> -> (tensor<5xf32>, tensor<5xi32>)
```
}];

let arguments = (ins
MHLO_Tensor:$operand,
I64Attr:$k
I64Attr:$k,
DefaultValuedOptionalAttr<BoolAttr, "true">:$largest
);
let regions = (region SizedRegion<1>:$body);
let results = (outs MHLO_Tensor:$values,
MHLO_Tensor:$indices);

let hasVerifier = 1;
let assemblyFormat = [{
`(`$operand `,` `k` `=` $k`)` $body attr-dict `:`
`(`$operand `,` `k` `=` $k `,` `largest` `=` $largest `)` attr-dict `:`
type($operand) `->` `(`type($values)`,` type($indices)`)`
}];
let hasCustomHLOConverter = 1;
Expand Down
6 changes: 1 addition & 5 deletions tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2059,11 +2059,7 @@ func.func @op_stochastic_convert(%arg0: tensor<f32>, %arg1: tensor<ui32>) -> ten

func.func @op_topk(%arg0 : tensor<16xf32>) {
// expected-error@+1 {{failed to legalize operation 'mhlo.topk' that was explicitly marked illegal}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

Expand Down
101 changes: 7 additions & 94 deletions tests/Dialect/mhlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6518,55 +6518,35 @@ func.func @f8e5m2(%arg0: tensor<f16>) -> tensor<f8E5M2> {
// -----

func.func @top_k_1d(%arg0 : tensor<16xf32>) {
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

// -----

func.func @top_k_nd(%arg0 : tensor<16x16xf32>) {
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=false) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
return
}

// -----

func.func @top_k_unbounded(%arg0 : tensor<?x16x?xf32>) {
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<?x16x?xf32> -> (tensor<?x16x8xf32>, tensor<?x16x8xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<?x16x?xf32> -> (tensor<?x16x8xf32>, tensor<?x16x8xi32>)
return
}

// -----

func.func @top_k_bounded(%arg0 : tensor<?x?x?xf32, #mhlo.type_extensions<bounds = [?, 16, 16]>>) {
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<?x?x?xf32, #mhlo.type_extensions<bounds = [?, 16, 16]>> -> (tensor<16x?x8xf32, #mhlo.type_extensions<bounds = [?, 16, ?]>>, tensor<16x?x8xi32, #mhlo.type_extensions<bounds = [?, 16, ?]>>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<?x?x?xf32, #mhlo.type_extensions<bounds = [?, 16, 16]>> -> (tensor<16x?x8xf32, #mhlo.type_extensions<bounds = [?, 16, ?]>>, tensor<16x?x8xi32, #mhlo.type_extensions<bounds = [?, 16, ?]>>)
return
}

// -----

func.func @top_k_unranked(%arg0 : tensor<*xf32>) {
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<*xf32> -> (tensor<*xf32>, tensor<*xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<*xf32> -> (tensor<*xf32>, tensor<*xi32>)
return
}

Expand All @@ -6575,11 +6555,7 @@ func.func @top_k_unranked(%arg0 : tensor<*xf32>) {
func.func @topk_rank_at_least_one(%arg0 : tensor<f32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{operand's rank must be at least 1}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<f32> -> (tensor<8xf32>, tensor<8xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<f32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

Expand All @@ -6588,69 +6564,6 @@ func.func @topk_rank_at_least_one(%arg0 : tensor<f32>) {
func.func @topk_last_dimension_at_least_k(%arg0 : tensor<4xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{operand's last dimension must be at least 8}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<4xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

// -----

func.func @topk_body_must_have_two_arguments(%arg0 : tensor<16xf32>) {
// expected-error@+1 {{unsupported body: expected: '(tensor<f32>, tensor<f32>) -> tensor<i1>', got '(tensor<f32>) -> tensor<i1>'}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

// -----

func.func @topk_body_must_have_one_result(%arg0 : tensor<16xf32>) {
// expected-error@+1 {{unsupported body: expected: '(tensor<f32>, tensor<f32>) -> tensor<i1>', got '(tensor<f32>, tensor<f32>) -> (tensor<i1>, tensor<i1>)'}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate, %predicate : tensor<i1>, tensor<i1>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

// -----

func.func @topk_body_arguments_must_have_operand_element_type(%arg0 : tensor<16xf32>) {
// expected-error@+1 {{unsupported body: expected: '(tensor<f32>, tensor<f32>) -> tensor<i1>', got '(tensor<i32>, tensor<i32>) -> tensor<i1>'}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):
%predicate = mhlo.compare GT, %arg1, %arg2 : (tensor<i32>, tensor<i32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

// -----

func.func @topk_body_results_must_have_i1_element_type(%arg0 : tensor<16xf32>) {
// expected-error@+1 {{unsupported body: expected: '(tensor<f32>, tensor<f32>) -> tensor<i1>', got '(tensor<f32>, tensor<f32>) -> tensor<f32>'}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
mhlo.return %arg1 : tensor<f32>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

// -----

func.func @topk_body_must_consist_of_compare_gt_or_compare_lt(%arg0 : tensor<16xf32>) {
// expected-error@+1 {{unsupported body: expected mhlo.compare of body arguments with GT or LT comparison_direction}}
%0:2 = mhlo.topk(%arg0, k=8) {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%predicate = mhlo.compare EQ, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
mhlo.return %predicate : tensor<i1>
} : tensor<16xf32> -> (tensor<8xf32>, tensor<8xi32>)
%0:2 = mhlo.topk(%arg0, k=8, largest=true) : tensor<4xf32> -> (tensor<8xf32>, tensor<8xi32>)
return
}

0 comments on commit 50ff2ab

Please sign in to comment.