Skip to content

Commit

Permalink
Fix reduce result type when the minor-most dimension is not reduced.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675903105
  • Loading branch information
jreiffers authored and Google-ML-Automation committed Sep 18, 2024
1 parent c4ae62b commit 7ba8195
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
25 changes: 23 additions & 2 deletions xla/service/gpu/fusions/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32,

// -----

func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32
func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32

#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1),
domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32],
Expand Down Expand Up @@ -208,4 +208,25 @@ func.func @reduce(%in0: tensor<16x8x4xf32>, %init0: f32,
// CHECK: xla_gpu.reduce(%[[IN1]], %[[IN2]])
// CHECK-SAME: inits(%[[INIT1]], %[[INIT2]]) dimensions=[0, 2]
// CHECK-SAME: combiner=@add {xla.range = [0 : index, 42 : index]}
// CHECK-SAME: : tensor<16x8x4xf32>, tensor<16x8x4xi32>
// CHECK-SAME: : tensor<16x8x4xf32>, tensor<16x8x4xi32>

// -----

func.func @add(%a_acc: f32, %a: f32) -> (f32) {
%0 = arith.addf %a_acc, %a : f32
func.return %0 : f32
}

func.func @reduce_middle_dim(%in: tensor<16x8x4xf32>, %init: f32)
-> (tensor<16x4xf32>) {
%sum = xla_gpu.reduce (%in) inits(%init) dimensions=[1]
combiner=@add : tensor<16x8x4xf32>
func.return %sum : tensor<16x4xf32>
}

// CHECK-LABEL: func.func @reduce_middle_dim(
// CHECK-SAME: %[[IN:.*]]: tensor<16x8x4xf32>, %[[INIT:.*]]: f32)
// CHECK: xla_gpu.reduce(%[[IN]])
// CHECK-SAME: inits(%[[INIT]]) dimensions=[1]
// CHECK-SAME: combiner=@add
// CHECK-SAME: : tensor<16x8x4xf32>
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ SmallVector<Type> inferReductionResultTypes(TypeRange input_types,
output_shape.reserve(input_shape.size() - num_reduced_dims);
int reduce_dim = 0;
for (int64_t i = 0; i < input_shape.size(); ++i) {
if (reduce_dim >= num_reduced_dims || i == reduced_dims[reduce_dim]) {
if (reduce_dim < num_reduced_dims && i == reduced_dims[reduce_dim]) {
++reduce_dim;
continue;
}
Expand Down

0 comments on commit 7ba8195

Please sign in to comment.