From 8130309a70df0b5b1eb3f64c6beb8329e2c6638d Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Tue, 17 Sep 2024 13:40:52 -0700 Subject: [PATCH] [XLA:GPU][Emitters] Make xla_gpu.shuffle_reduce syntax similar to xla_gpu.reduce. PiperOrigin-RevId: 675693420 --- xla/service/gpu/fusions/ir/tests/ops.mlir | 18 +++++++++ xla/service/gpu/fusions/ir/xla_gpu_ops.cc | 39 +++++++++++++++++++ xla/service/gpu/fusions/ir/xla_gpu_ops.td | 23 +++++------ .../fusions/tests/reduce_row/mof_epilogue.hlo | 14 +++---- .../transforms/lower_xla_gpu_to_scf.cc | 2 +- .../tests/lower_xla_gpu_to_scf.mlir | 26 ++++++------- 6 files changed, 88 insertions(+), 34 deletions(-) diff --git a/xla/service/gpu/fusions/ir/tests/ops.mlir b/xla/service/gpu/fusions/ir/tests/ops.mlir index d398721f9e4860..54779f84486cfc 100644 --- a/xla/service/gpu/fusions/ir/tests/ops.mlir +++ b/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -261,3 +261,21 @@ func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { // CHECK: %[[C0:.*]] = arith.constant 0.00 // CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] default %[[C0]] : // CHECK-SAME: tensor<1022xf32> -> tensor<16x64xf32> + + +// ----- + +func.func @do_nothing(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { + return %a, %b : f32, i32 +} +func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { + %ret:2 = xla_gpu.shuffle_reduce(%a, %b) to 4 combiner=@do_nothing + {xla.range = [0 : index, 42 : index]} : f32, i32 + return %ret#0, %ret#1 : f32, i32 +} +// CHECK-LABEL: func.func @shuffler( +// CHECK-SAME: %[[IN1:.*]]: f32, %[[IN2:.*]]: i32) + +// CHECK: xla_gpu.shuffle_reduce(%[[IN1]], %[[IN2]]) to 4 +// CHECK-SAME: combiner=@do_nothing {xla.range = [0 : index, 42 : index]} +// CHECK-SAME: : f32, i32 \ No newline at end of file diff --git a/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 98bcb81fc757e4..faa9a20f558b56 100644 --- a/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ 
b/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -1190,6 +1190,50 @@ LogicalResult ReduceOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ShuffleReduceOp +//===----------------------------------------------------------------------===// + +// Parses `(operands) to N combiner=@callee {attrs} : operand-types`, where the +// attribute dictionary is optional. Result types are set to the operand types. +ParseResult ShuffleReduceOp::parse(OpAsmParser& parser, + OperationState& result) { + SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs; + mlir::StringAttr combiner; + int64_t max_distance = 0; + SmallVector<mlir::Type, 2> operand_types; + mlir::SMLoc loc = parser.getCurrentLocation(); + if (parser.parseLParen() || parseOperands(parser, &inputs) || + parser.parseRParen() || parser.parseKeyword("to") || + parser.parseInteger(max_distance) || parser.parseKeyword("combiner") || + parser.parseEqual() || parser.parseSymbolName(combiner) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonTypeList(operand_types) || + parser.resolveOperands(inputs, operand_types, loc, result.operands)) { + return failure(); + } + auto ctx = result.getContext(); + mlir::OperationName opname(ShuffleReduceOp::getOperationName(), ctx); + result.addAttribute(ShuffleReduceOp::getCombinerAttrName(opname), + mlir::FlatSymbolRefAttr::get(ctx, combiner)); + result.addAttribute( + ShuffleReduceOp::getMaxDistanceAttrName(opname), + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), max_distance)); + result.addTypes(operand_types); + return success(); +} + +// Prints the form accepted by parse(). `combiner` and `max_distance` are +// printed inline and therefore elided from the trailing attribute dictionary. +void ShuffleReduceOp::print(OpAsmPrinter& p) { + p << '(' << getOperands() << ") to " << getMaxDistance() << " combiner="; + // printSymbolName emits the '@' sigil and quotes the name if required. + p.printSymbolName(getCombiner()); + p.printOptionalAttrDict((*this)->getAttrs(), + {getCombinerAttrName(), getMaxDistanceAttrName()}); + p << " : " << TypeRange(getResultTypes()); +} + } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/fusions/ir/xla_gpu_ops.td b/xla/service/gpu/fusions/ir/xla_gpu_ops.td index 715b5cdfdbaed3..a5a40884d6efa3 100644 --- a/xla/service/gpu/fusions/ir/xla_gpu_ops.td +++ b/xla/service/gpu/fusions/ir/xla_gpu_ops.td @@ 
-169,39 +169,35 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce", function. The function is invoked with the operands from the low lanes, followed by the operands from the high lanes. For example: - // TODO: update the syntax to make it similar to xla_gpu.reduce. ``` - shuffle_reduce @argmax(%value, %idx) : (f32, index) + %result:2 = xla_gpu.shuffle_reduce (%in0, %in1) to 16 combiner=@argmax : f32, index ``` Will perform shuffles with distance 16, 8, 4, 2 and 1, and will invoke @argmax five times. The first invocations will be ``` - @argmax(%value[i], %idx[i], %value[16+i], %idx[16+i]) + @argmax(%in0[i], %in1[i], %in0[16+i], %in1[16+i]) ``` }]; let builders = [ - OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{ + OpBuilder<(ins "mlir::func::FuncOp":$combiner, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{ $_state.addOperands(operands); - $_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer)); + $_state.addAttribute("combiner", mlir::SymbolRefAttr::get(combiner)); $_state.addAttribute("max_distance", mlir::IntegerAttr::get( - mlir::IntegerType::get(reducer.getContext(), 64), + mlir::IntegerType::get(combiner.getContext(), 64), max_distance)); - $_state.addTypes(reducer.getFunctionType().getResults()); + $_state.addTypes(combiner.getFunctionType().getResults()); }]>]; - let arguments = (ins FlatSymbolRefAttr:$reducer, + let arguments = (ins FlatSymbolRefAttr:$combiner, Variadic<AnyType>:$operands, I64Attr:$max_distance); let results = (outs Variadic<AnyType>:$results); - let assemblyFormat = [{ - $reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands) - }]; let extraClassDeclaration = [{ mlir::CallInterfaceCallable getCallableForCallee() { - return (*this)->getAttrOfType<mlir::SymbolRefAttr>("reducer"); + return (*this)->getAttrOfType<mlir::SymbolRefAttr>("combiner"); } operand_range getArgOperands() { return getOperands(); @@ -213,6 +209,7 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce", - (*this)->setAttr("reducer", callee.get<mlir::SymbolRefAttr>());
+ (*this)->setAttr("combiner", callee.get<mlir::SymbolRefAttr>()); } }]; + let hasCustomAssemblyFormat = 1; } def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert", @@ -431,7 +428,7 @@ def XLAGPU_ReduceOp : XLAGPU_Op<"reduce", [ func.return %0, %1 : f32, i32 } %sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2] - combiner=@add + combiner=@add : tensor<16x8x4xf32>, tensor<16x8x4xi32> ``` }]; let arguments = (ins diff --git a/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo b/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo index 45b2d5f9043da1..d2c5885b9d6eca 100644 --- a/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo +++ b/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo @@ -28,16 +28,16 @@ fusion { ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs) } -// CHECK-DAG: shuffle_reduce @add_add -// CHECK-DAG: shuffle_reduce @mul_mul +// CHECK-DAG: shuffle_reduce(%{{.*}}) to 16 combiner=@add_add +// CHECK-DAG: shuffle_reduce(%{{.*}}) to 16 combiner=@mul_mul // CHECK: allocate_shared // CHECK: allocate_shared // CHECK: sync_threads -// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add -// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul -// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]] -// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]] -// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]] +// CHECK-DAG: %[[SUM:.*]] = xla_gpu.shuffle_reduce(%{{.*}}) to 16 combiner=@add_add +// CHECK-DAG: %[[PROD:.*]] = xla_gpu.shuffle_reduce(%{{.*}}) to 16 combiner=@mul_mul +// CHECK-DAG: %[[LOG:.*]] = math.log %[[SUM]] +// CHECK-DAG: %[[ABS:.*]] = math.absf %[[SUM]] +// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[PROD]] // CHECK-DAG: xla_gpu.predicated_insert %[[LOG]] // CHECK-DAG: xla_gpu.predicated_insert %[[ABS]] // CHECK-DAG: xla_gpu.predicated_insert %[[NEG]] diff --git a/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc b/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc index f0197709d8a833..be1686164d656f 100644 --- 
a/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc +++ b/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc @@ -186,7 +186,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { args.push_back(shuffle(value)); } values = b.create(op.getResultTypes(), - op.getReducerAttr().getAttr(), args) + op.getCombinerAttr().getAttr(), args) .getResults(); } rewriter.replaceOp(op, values); diff --git a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index 8cc6b652efecc0..dd15bdaafc533f 100644 --- a/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -1,12 +1,12 @@ // RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \ // RUN: | FileCheck %s -func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { +func.func @combiner(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { return %a, %b : f32, i32 } func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { - %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32 + %ret:2 = xla_gpu.shuffle_reduce (%a, %b) to 4 combiner=@combiner: f32, i32 return %ret#0, %ret#1 : f32, i32 } // CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32) @@ -16,23 +16,23 @@ func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { // CHECK-DAG: %[[C32:.*]] = arith.constant 32 // CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]] // CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]] -// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]]) +// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @combiner(%[[A]], %[[B]], %[[A4H]], %[[B4H]]) // CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]] // CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]] -// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = 
xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]]) +// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @combiner(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]]) // CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]] // CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]] -// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]]) +// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @combiner(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]]) // CHECK: return %[[AB1_0]], %[[AB1_1]] // ----- -func.func @reducer(%a: f64, %b: f64) -> f64 { +func.func @combiner(%a: f64, %b: f64) -> f64 { return %a : f64 } func.func @shuffler(%a: f64) -> f64 { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64 + %ret = xla_gpu.shuffle_reduce(%a) to 1 combiner=@combiner : f64 return %ret : f64 } // CHECK: @shuffler(%[[A:.*]]: f64 @@ -41,12 +41,12 @@ func.func @shuffler(%a: f64) -> f64 { // ----- -func.func @reducer(%a: complex, %b: complex) -> complex { +func.func @combiner(%a: complex, %b: complex) -> complex { return %a : complex } func.func @shuffler(%a: complex) -> complex { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex + %ret = xla_gpu.shuffle_reduce(%a) to 1 combiner=@combiner : complex return %ret : complex } // CHECK: @shuffler @@ -54,12 +54,12 @@ func.func @shuffler(%a: complex) -> complex { // ----- -func.func @reducer(%a: ui64, %b: ui64) -> ui64 { +func.func @combiner(%a: ui64, %b: ui64) -> ui64 { return %a : ui64 } func.func @shuffler(%a: ui64) -> ui64 { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64 + %ret = xla_gpu.shuffle_reduce (%a) to 1 combiner=@combiner : ui64 return %ret : ui64 } // CHECK: @shuffler @@ -68,12 +68,12 @@ func.func @shuffler(%a: ui64) -> ui64 { // ----- -func.func @reducer(%a: i8, %b: i8) -> i8 { +func.func @combiner(%a: i8, %b: i8) -> i8 { return %a : i8 } func.func @shuffler_i8(%a: i8) -> i8 { - %ret = 
xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8 + %ret = xla_gpu.shuffle_reduce (%a) to 1 combiner=@combiner : i8 return %ret : i8 } // CHECK: @shuffler_i8(