diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
index df3f11a2fc03f1..dbd006df877c8e 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
+++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc
@@ -5884,9 +5884,9 @@ LogicalResult ScatterOp::fold(
     // value.
     indexIndex.clear();
     if (indexVectorDim == 0) indexIndex.push_back(0);
+    auto updateWindowDims = getScatterDimensionNumbers().getUpdateWindowDims();
     for (int64_t i = 0; i < static_cast<int64_t>(updateIndex.size()); ++i) {
-      if (llvm::count(getScatterDimensionNumbers().getUpdateWindowDims(), i) ==
-          0)
+      if (!llvm::is_contained(updateWindowDims, i))
         indexIndex.push_back(updateIndex[i]);
       if (static_cast<int64_t>(indexIndex.size()) == indexVectorDim)
         indexIndex.push_back(0);
@@ -5894,6 +5894,14 @@ LogicalResult ScatterOp::fold(
 
     // Compute the index for the given update value in the base tensor.
     baseIndex.assign(baseType.getRank(), 0);
+    auto inputBatchingDims =
+        getScatterDimensionNumbers().getInputBatchingDims();
+    auto scatterIndicesBatchingDims =
+        getScatterDimensionNumbers().getScatterIndicesBatchingDims();
+    for (auto [operandDim, indicesDim] :
+         llvm::zip_equal(inputBatchingDims, scatterIndicesBatchingDims)) {
+      baseIndex[operandDim] = indexIndex[indicesDim];
+    }
     uint64_t indexCount = indexType.getShape()[indexVectorDim];
     for (uint64_t i = 0; i < indexCount; ++i) {
       uint64_t operandDim =
@@ -5905,9 +5913,10 @@ LogicalResult ScatterOp::fold(
     uint64_t updateWindowDimIndex = 0;
     auto insertedWindowDims =
         getScatterDimensionNumbers().getInsertedWindowDims();
-    auto updateWindowDims = getScatterDimensionNumbers().getUpdateWindowDims();
     for (uint64_t i = 0; i < baseIndex.size(); ++i) {
-      if (llvm::count(insertedWindowDims, i)) continue;
+      if (llvm::is_contained(insertedWindowDims, i) ||
+          llvm::is_contained(inputBatchingDims, i))
+        continue;
       baseIndex[i] += updateIndex[updateWindowDims[updateWindowDimIndex]];
       updateWindowDimIndex++;
     }
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir
index 901428df8a10c3..fbea8d7f791290 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir
@@ -53,6 +53,24 @@ func.func @scatter_full_overwrite() -> tensor<512x1x6400x6400xf32> {
 
 // -----
 
+// Verify that a full overwrite of the "base" with a batched scatter is
+// correctly folded.
+func.func @scatter_batching_dims_full_overwrite() -> tensor<3x1x6400x6400xf32> {
+  %base = mhlo.constant dense<0.000000e+00> : tensor<3x1x6400x6400xf32>
+  %index = mhlo.constant dense<0> : tensor<3x1xi32>
+  %update = mhlo.constant dense<1.000000e+00> : tensor<3x1x6400x6400xf32>
+  %scatter = "mhlo.scatter"(%base, %index, %update) ({
+  ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
+    "mhlo.return"(%arg6) : (tensor<f32>) -> ()
+  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1, 2, 3], input_batching_dims = [0], scatter_indices_batching_dims = [0], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = true} : (tensor<3x1x6400x6400xf32>, tensor<3x1xi32>, tensor<3x1x6400x6400xf32>) -> tensor<3x1x6400x6400xf32>
+
+  // CHECK: %[[FOLD:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3x1x6400x6400xf32>
+  // CHECK: return %[[FOLD]]
+  func.return %scatter : tensor<3x1x6400x6400xf32>
+}
+
+// -----
+
 // Verify that a full overwrite of the "base" with a scatter is correctly folded
 // even if the base and update are not constant values.
 func.func @scatter_full_overwrite_non_const(
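
Note: the core of the change is how a batching dim on the operand is paired 1:1 with a batching dim on the scatter-indices tensor, so the batch coordinate is copied straight across rather than coming from the index vector, and batching dims are then skipped (like inserted window dims) when applying the update-window offsets. The following standalone C++ sketch mirrors that arithmetic with plain std containers; it is an illustration under assumed names (computeBaseIndex, the vector parameters), not XLA's actual code or API.

    // Minimal analogue of the patched base-index computation in
    // ScatterOp::fold, using std::vector instead of MLIR types.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Maps one update element's scatter index to a position in the base
    // tensor. indexIndex is the position into the scatter-indices tensor for
    // this update; scatterIndex is the index vector read from that tensor.
    std::vector<int64_t> computeBaseIndex(
        int64_t baseRank, const std::vector<int64_t>& indexIndex,
        const std::vector<int64_t>& scatterIndex,
        const std::vector<int64_t>& scatterDimsToOperandDims,
        const std::vector<int64_t>& inputBatchingDims,
        const std::vector<int64_t>& scatterIndicesBatchingDims) {
      std::vector<int64_t> baseIndex(baseRank, 0);

      // New in the patch: each operand batching dim takes its coordinate
      // directly from the paired scatter-indices batching dim.
      assert(inputBatchingDims.size() == scatterIndicesBatchingDims.size());
      for (size_t i = 0; i < inputBatchingDims.size(); ++i)
        baseIndex[inputBatchingDims[i]] =
            indexIndex[scatterIndicesBatchingDims[i]];

      // Unchanged: scatter the index vector's components into their
      // corresponding operand dims.
      for (size_t i = 0; i < scatterIndex.size(); ++i)
        baseIndex[scatterDimsToOperandDims[i]] += scatterIndex[i];
      return baseIndex;
    }

    int main() {
      // Mirrors the new test: base is 3x1x6400x6400, operand dim 0 is a
      // batching dim paired with indices dim 0, and the single index
      // component maps to operand dim 1.
      auto idx = computeBaseIndex(/*baseRank=*/4, /*indexIndex=*/{2, 0},
                                  /*scatterIndex=*/{0},
                                  /*scatterDimsToOperandDims=*/{1},
                                  /*inputBatchingDims=*/{0},
                                  /*scatterIndicesBatchingDims=*/{0});
      for (int64_t d : idx) std::cout << d << ' ';  // prints: 2 0 0 0
      std::cout << '\n';
    }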