Skip to content

Commit

Permalink
[mhlo] fix ScatterOp::fold for batching dims
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675727628
  • Loading branch information
tomnatan30 authored and Google-ML-Automation committed Sep 17, 2024
1 parent ee266ed commit 891d972
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
17 changes: 13 additions & 4 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5884,16 +5884,24 @@ LogicalResult ScatterOp::fold(
// value.
indexIndex.clear();
if (indexVectorDim == 0) indexIndex.push_back(0);
auto updateWindowDims = getScatterDimensionNumbers().getUpdateWindowDims();
for (int64_t i = 0; i < static_cast<int64_t>(updateIndex.size()); ++i) {
if (llvm::count(getScatterDimensionNumbers().getUpdateWindowDims(), i) ==
0)
if (!llvm::is_contained(updateWindowDims, i))
indexIndex.push_back(updateIndex[i]);
if (static_cast<int64_t>(indexIndex.size()) == indexVectorDim)
indexIndex.push_back(0);
}

// Compute the index for the given update value in the base tensor.
baseIndex.assign(baseType.getRank(), 0);
auto inputBatchingDims =
getScatterDimensionNumbers().getInputBatchingDims();
auto scatterIndicesBatchingDims =
getScatterDimensionNumbers().getScatterIndicesBatchingDims();
for (auto [operandDim, indicesDim] :
llvm::zip_equal(inputBatchingDims, scatterIndicesBatchingDims)) {
baseIndex[operandDim] = indexIndex[indicesDim];
}
uint64_t indexCount = indexType.getShape()[indexVectorDim];
for (uint64_t i = 0; i < indexCount; ++i) {
uint64_t operandDim =
Expand All @@ -5905,9 +5913,10 @@ LogicalResult ScatterOp::fold(
uint64_t updateWindowDimIndex = 0;
auto insertedWindowDims =
getScatterDimensionNumbers().getInsertedWindowDims();
auto updateWindowDims = getScatterDimensionNumbers().getUpdateWindowDims();
for (uint64_t i = 0; i < baseIndex.size(); ++i) {
if (llvm::count(insertedWindowDims, i)) continue;
if (llvm::is_contained(insertedWindowDims, i) ||
llvm::is_contained(inputBatchingDims, i))
continue;
baseIndex[i] += updateIndex[updateWindowDims[updateWindowDimIndex]];
updateWindowDimIndex++;
}
Expand Down
18 changes: 18 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ func.func @scatter_full_overwrite() -> tensor<512x1x6400x6400xf32> {

// -----

// Verify that a full overwrite of the "base" with a batched scatter is
// correctly folded.
func.func @scatter_batching_dims_full_overwrite() -> tensor<3x1x6400x6400xf32> {
// Base is all zeros; dim 0 (size 3) is the batching dim paired with the
// scatter-indices batching dim below.
%base = mhlo.constant dense<0.000000e+00> : tensor<3x1x6400x6400xf32>
// One index per batch element; all zeros, so together with the batching
// dims every element of %base is addressed exactly once.
%index = mhlo.constant dense<0> : tensor<3x1xi32>
// Update tensor of ones, same shape as the base — a full overwrite.
%update = mhlo.constant dense<1.000000e+00> : tensor<3x1x6400x6400xf32>
// The region returns the update value unconditionally, so folding the op
// must yield %update itself. input_batching_dims = [0] and
// scatter_indices_batching_dims = [0] exercise the batched-fold path.
%scatter = "mhlo.scatter"(%base, %index, %update) ({
^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
"mhlo.return"(%arg6) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1, 2, 3], input_batching_dims = [0], scatter_indices_batching_dims = [0], scatter_dims_to_operand_dims = [3], index_vector_dim = 1>, unique_indices = true} : (tensor<3x1x6400x6400xf32>, tensor<3x1xi32>, tensor<3x1x6400x6400xf32>) -> tensor<3x1x6400x6400xf32>

// The scatter must fold to the constant update tensor.
// CHECK: %[[FOLD:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3x1x6400x6400xf32>
// CHECK: return %[[FOLD]]
func.return %scatter : tensor<3x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is correctly folded
// even if the base and update are not constant values.
func.func @scatter_full_overwrite_non_const(
Expand Down

0 comments on commit 891d972

Please sign in to comment.