Skip to content

Commit

Permalink
[XLA:GPU][Emitters] Make xla_gpu.shuffle_reduce syntax similar to xla…
Browse files Browse the repository at this point in the history
…_gpu.reduce.

PiperOrigin-RevId: 675693420
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 18, 2024
1 parent 372d3bb commit 8130309
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 34 deletions.
18 changes: 18 additions & 0 deletions xla/service/gpu/fusions/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,21 @@ func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0.00
// CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] default %[[C0]] :
// CHECK-SAME: tensor<1022xf32> -> tensor<16x64xf32>


// -----

func.func @do_nothing(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
return %a, %b : f32, i32
}
func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
%ret:2 = xla_gpu.shuffle_reduce(%a, %b) to 4 combiner=@do_nothing
{xla.range = [0 : index, 42 : index]} : f32, i32
return %ret#0, %ret#1 : f32, i32
}
// CHECK-LABEL: func.func @shuffler(
// CHECK-SAME: %[[IN1:.*]]: f32, %[[IN2:.*]]: i32)

// CHECK: xla_gpu.shuffle_reduce(%[[IN1]], %[[IN2]]) to 4
// CHECK-SAME: combiner=@do_nothing {xla.range = [0 : index, 42 : index]}
// CHECK-SAME: : f32, i32
39 changes: 39 additions & 0 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,45 @@ LogicalResult ReduceOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ShuffleReduceOp
//===----------------------------------------------------------------------===//

ParseResult ShuffleReduceOp::parse(OpAsmParser& parser,
OperationState& result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs;
mlir::StringAttr combiner;
int64_t max_distance;
SmallVector<Type, 2> operand_types;
mlir::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLParen() || parseOperands(parser, &inputs) ||
parser.parseRParen() || parser.parseKeyword("to") ||
parser.parseInteger(max_distance) || parser.parseKeyword("combiner") ||
parser.parseEqual() || parser.parseSymbolName(combiner) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(operand_types) ||
parser.resolveOperands(inputs, operand_types, loc, result.operands)) {
return failure();
}
auto ctx = result.getContext();
mlir::OperationName opname(ShuffleReduceOp::getOperationName(), ctx);
result.addAttribute(ShuffleReduceOp::getCombinerAttrName(opname),
mlir::FlatSymbolRefAttr::get(ctx, combiner));
result.addAttribute(
ShuffleReduceOp::getMaxDistanceAttrName(opname),
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), max_distance));
result.addTypes(operand_types);
return success();
}

void ShuffleReduceOp::print(OpAsmPrinter& p) {
p << '(' << getOperands() << ") to " << getMaxDistance() << " combiner=@"
<< getCombiner();
p.printOptionalAttrDict((*this)->getAttrs(),
{getCombinerAttrName(), getMaxDistanceAttrName()});
p << " : " << TypeRange(getResultTypes());
}

} // namespace gpu
} // namespace xla

Expand Down
23 changes: 10 additions & 13 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,39 +169,35 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
function. The function is invoked with the operands from the low lanes,
followed by the operands from the high lanes. For example:

// TODO: update the syntax to make it similar to xla_gpu.reduce.
```
shuffle_reduce @argmax(%value, %idx) : (f32, index)
%result:2 = xla_gpu.shuffle_reduce (%in0, %in1) to 16 combiner=@argmax
```

Will perform shuffles with distance 16, 8, 4, 2 and 1, and will invoke
@argmax five times. The first invocations will be

```
@argmax(%value[i], %idx[i], %value[16+i], %idx[16+i])
@argmax(%in0[i], %in1[i], %in0[16+i], %in1[16+i])
```
}];
let builders = [
OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
OpBuilder<(ins "mlir::func::FuncOp":$combiner, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
$_state.addOperands(operands);
$_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer));
$_state.addAttribute("combiner", mlir::SymbolRefAttr::get(combiner));
$_state.addAttribute("max_distance",
mlir::IntegerAttr::get(
mlir::IntegerType::get(reducer.getContext(), 64),
mlir::IntegerType::get(combiner.getContext(), 64),
max_distance));
$_state.addTypes(reducer.getFunctionType().getResults());
$_state.addTypes(combiner.getFunctionType().getResults());
}]>];
let arguments = (ins FlatSymbolRefAttr:$reducer,
let arguments = (ins FlatSymbolRefAttr:$combiner,
Variadic<AnyType>:$operands,
I64Attr:$max_distance);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
$reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands)
}];
let extraClassDeclaration = [{
mlir::CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("reducer");
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("combiner");
}
operand_range getArgOperands() {
return getOperands();
Expand All @@ -213,6 +209,7 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
(*this)->setAttr("reducer", callee.get<mlir::SymbolRefAttr>());
}
}];
let hasCustomAssemblyFormat = 1;
}

def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert",
Expand Down Expand Up @@ -431,7 +428,7 @@ def XLAGPU_ReduceOp : XLAGPU_Op<"reduce", [
func.return %0, %1 : f32, i32
}
%sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2]
combiner=@add
combiner=@add : tensor<16x8x4xf32>, tensor<16x8x4xi32>
```
}];
let arguments = (ins
Expand Down
14 changes: 7 additions & 7 deletions xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ fusion {
ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs)
}

// CHECK-DAG: shuffle_reduce @add_add
// CHECK-DAG: shuffle_reduce @mul_mul
// CHECK-DAG: shuffle_reduce(%{{.*}}) to 16 combiner=@add_add
// CHECK-DAG: shuffle_reduce(%{{.*}}) to 16 combiner=@mul_mul
// CHECK: allocate_shared
// CHECK: allocate_shared
// CHECK: sync_threads
// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add
// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul
// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]]
// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]]
// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]]
// CHECK-DAG: %[[SUM:.*]] = xla_gpu.shuffle_reduce(%{{.*}}) to 16 combiner=@add_add
// CHECK-DAG: %[[PROD:.*]] = xla_gpu.shuffle_reduce(%{{.*}}) to 16 combiner=@mul_mul
// CHECK-DAG: %[[LOG:.*]] = math.log %[[SUM]]
// CHECK-DAG: %[[ABS:.*]] = math.absf %[[SUM]]
// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[PROD]]
// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]]
// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]]
// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]]
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern<ShuffleReduceOp> {
args.push_back(shuffle(value));
}
values = b.create<PureCallOp>(op.getResultTypes(),
op.getReducerAttr().getAttr(), args)
op.getCombinerAttr().getAttr(), args)
.getResults();
}
rewriter.replaceOp(op, values);
Expand Down
26 changes: 13 additions & 13 deletions xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \
// RUN: | FileCheck %s

func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
func.func @combiner(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
return %a, %b : f32, i32
}

func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
%ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32
%ret:2 = xla_gpu.shuffle_reduce (%a, %b) to 4 combiner=@combiner: f32, i32
return %ret#0, %ret#1 : f32, i32
}
// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32)
Expand All @@ -16,23 +16,23 @@ func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
// CHECK-DAG: %[[C32:.*]] = arith.constant 32
// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]]
// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]]
// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @combiner(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]]
// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]]
// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @combiner(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]]
// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]]
// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @combiner(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
// CHECK: return %[[AB1_0]], %[[AB1_1]]

// -----

func.func @reducer(%a: f64, %b: f64) -> f64 {
func.func @combiner(%a: f64, %b: f64) -> f64 {
return %a : f64
}

func.func @shuffler(%a: f64) -> f64 {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64
%ret = xla_gpu.shuffle_reduce(%a) to 1 combiner=@combiner : f64
return %ret : f64
}
// CHECK: @shuffler(%[[A:.*]]: f64
Expand All @@ -41,25 +41,25 @@ func.func @shuffler(%a: f64) -> f64 {

// -----

func.func @reducer(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
func.func @combiner(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
return %a : complex<f64>
}

func.func @shuffler(%a: complex<f64>) -> complex<f64> {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex<f64>
%ret = xla_gpu.shuffle_reduce(%a) to 1 combiner=@combiner : complex<f64>
return %ret : complex<f64>
}
// CHECK: @shuffler
// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]]

// -----

func.func @reducer(%a: ui64, %b: ui64) -> ui64 {
func.func @combiner(%a: ui64, %b: ui64) -> ui64 {
return %a : ui64
}

func.func @shuffler(%a: ui64) -> ui64 {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64
%ret = xla_gpu.shuffle_reduce (%a) to 1 combiner=@combiner : ui64
return %ret : ui64
}
// CHECK: @shuffler
Expand All @@ -68,12 +68,12 @@ func.func @shuffler(%a: ui64) -> ui64 {

// -----

func.func @reducer(%a: i8, %b: i8) -> i8 {
func.func @combiner(%a: i8, %b: i8) -> i8 {
return %a : i8
}

func.func @shuffler_i8(%a: i8) -> i8 {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8
%ret = xla_gpu.shuffle_reduce (%a) to 1 combiner=@combiner : i8
return %ret : i8
}
// CHECK: @shuffler_i8(
Expand Down

0 comments on commit 8130309

Please sign in to comment.