From ede40da1f8c1e91601b985cd32ad785aa8806880 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Fri, 6 Sep 2024 10:45:59 +0800 Subject: [PATCH] [mlir][tensor] Add check for indices of `tensor.gather` (#106894) This patch add a check for indices of `tensor.gather` and `tensor.scatter`. For that the length of gather_dims/scatter_dims should match the size of last dimension of the indices. Fix #94901. --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 +++- mlir/test/Dialect/Tensor/invalid.mlir | 80 ++++++++++++++++++------ 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 996de530c255d4..5fbb3d55f8faa5 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1288,7 +1288,8 @@ RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType, } static LogicalResult -verifyGatherOrScatterDims(Operation *op, ArrayRef dims, int64_t rank, +verifyGatherOrScatterDims(Operation *op, ArrayRef dims, + ArrayRef indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest) { if (dims.empty()) return op->emitOpError(gatherOrScatter) << "_dims must be non-empty"; @@ -1297,6 +1298,9 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef dims, int64_t rank, if (numGatherDims > rank) return op->emitOpError(gatherOrScatter) << "_dims overflow " << sourceOrDest << " rank"; + if (indices.empty() || indices.back() != numGatherDims) + return op->emitOpError(gatherOrScatter) + << "_dims length must match the size of last dimension of indices"; for (int64_t val : dims) { if (val < 0) return op->emitOpError(gatherOrScatter) @@ -1316,7 +1320,8 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef dims, int64_t rank, LogicalResult GatherOp::verify() { int64_t sourceRank = getSourceType().getRank(); ArrayRef gatherDims = getGatherDims(); - if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank, + if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, + getIndicesType().getShape(), sourceRank, "gather", "source"))) return failure(); @@ -3530,7 +3535,8 @@ void ScatterOp::getAsmResultNames( LogicalResult ScatterOp::verify() { int64_t destRank = getDestType().getRank(); ArrayRef scatterDims = getScatterDims(); - if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank, + if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, + getIndicesType().getShape(), destRank, "scatter", "dest"))) return failure(); diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index d9db32b8801ac2..84e6c59e403dde 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -455,41 +455,59 @@ func.func @gather_coordinate_rank_overflow( // ----- +func.func @gather_coordinate_rank_mismatch0( + %source: tensor<4x5x6xf32>, %indices: tensor) { + // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}} + %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]): + (tensor<4x5x6xf32>, tensor) -> tensor<1x2xf32> +} + +// ----- + +func.func @gather_coordinate_rank_mismatch1( + %source: tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { + // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}} + %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]): + (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32> +} + +// ----- + func.func @gather_coordinate_negative( - %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) { // expected-error@+1 {{gather_dims value must be non-negative}} %out = tensor.gather %source[%indices] gather_dims([-1]): - (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32> return } // ----- func.func @gather_coordinate_overflow( - %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) { // expected-error@+1 {{gather_dims value must be smaller than source rank}} %out = tensor.gather %source[%indices] gather_dims([42]): - (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + (tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32> return } // ----- -func.func @gather_coordinate_overflow( - %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { +func.func @gather_coordinate_increase( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { // expected-error@+1 {{gather_dims values must be strictly increasing}} %out = tensor.gather %source[%indices] gather_dims([1, 0]): - (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32> return } // ----- func.func @gather_wrong_result_type( - %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { // expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}} %out = tensor.gather %source[%indices] gather_dims([0, 2]): - (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> + (tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32> return } @@ -517,12 +535,34 @@ func.func @scatter_coordinate_rank_overflow( // ----- +func.func @scatter_coordinate_rank_mismatch0( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor) { + // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique: + (tensor, tensor<4x5x6xf32>, tensor) -> tensor<1x2xf32> + return +} + +// ----- + +func.func @scatter_coordinate_rank_mismatch1( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { + // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32> + return +} + +// ----- + func.func @scatter_coordinate_negative( %source : tensor, - %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) { // expected-error@+1 {{scatter_dims value must be non-negative}} %out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique: - (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + (tensor, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32> return } @@ -530,21 +570,21 @@ func.func @scatter_coordinate_negative( func.func @scatter_coordinate_overflow( %source : tensor, - %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) { // expected-error@+1 {{scatter_dims value must be smaller than dest rank}} %out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique: - (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + (tensor, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32> return } // ----- -func.func @scatter_coordinate_overflow( +func.func @scatter_coordinate_increase( %source : tensor, - %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { // expected-error@+1 {{scatter_dims values must be strictly increasing}} %out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique: - (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + (tensor, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32> return } @@ -552,10 +592,10 @@ func.func @scatter_coordinate_overflow( func.func @scatter_missing_unique( %source : tensor, - %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { // expected-error@+1 {{requires 'unique' attribute to be set}} %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]): - (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> + (tensor, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32> return } @@ -563,10 +603,10 @@ func.func @scatter_missing_unique( func.func @scatter_wrong_result_type( %source : tensor, - %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) { // expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor')}} %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique: - (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> + (tensor, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32> return }