Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:GPU][Emitters] Make xla_gpu.shuffle_reduce syntax similar to xla_gpu.reduce. #17280

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions xla/service/gpu/fusions/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,21 @@ func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0.00
// CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] default %[[C0]] :
// CHECK-SAME: tensor<1022xf32> -> tensor<16x64xf32>


// -----

// Combiner stub referenced by @shuffler below: it takes four values and
// returns the first two (%a and %b) unchanged.
func.func @do_nothing(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
return %a, %b : f32, i32
}
// Exercises the xla_gpu.shuffle_reduce assembly syntax with two operands, a
// max distance of 4, the @do_nothing combiner and an additional xla.range
// attribute in the attribute dictionary. Both results are returned.
func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
%ret:2 = xla_gpu.shuffle_reduce(%a, %b) to 4 combiner=@do_nothing
{xla.range = [0 : index, 42 : index]} : f32, i32
return %ret#0, %ret#1 : f32, i32
}
// CHECK-LABEL: func.func @shuffler(
// CHECK-SAME: %[[IN1:.*]]: f32, %[[IN2:.*]]: i32)

// CHECK: xla_gpu.shuffle_reduce(%[[IN1]], %[[IN2]]) to 4
// CHECK-SAME: combiner=@do_nothing {xla.range = [0 : index, 42 : index]}
// CHECK-SAME: : f32, i32
39 changes: 39 additions & 0 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,45 @@ LogicalResult ReduceOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ShuffleReduceOp
//===----------------------------------------------------------------------===//

// Parses the custom assembly form of xla_gpu.shuffle_reduce:
//
//   `(` operands `)` `to` max-distance `combiner` `=` @symbol
//       attr-dict? `:` type-list
//
// The type list after the colon is used both to resolve the operands and as
// the result types of the op. Returns failure() as soon as any piece of the
// syntax fails to parse.
ParseResult ShuffleReduceOp::parse(OpAsmParser& parser,
                                   OperationState& result) {
  // Captured before anything is consumed; handed to resolveOperands below.
  mlir::SMLoc start_loc = parser.getCurrentLocation();

  // `(` operands `)`
  SmallVector<OpAsmParser::UnresolvedOperand, 4> unresolved_operands;
  if (parser.parseLParen() || parseOperands(parser, &unresolved_operands) ||
      parser.parseRParen()) {
    return failure();
  }

  // `to` max-distance
  int64_t distance;
  if (parser.parseKeyword("to") || parser.parseInteger(distance)) {
    return failure();
  }

  // `combiner` `=` @symbol
  mlir::StringAttr combiner_name;
  if (parser.parseKeyword("combiner") || parser.parseEqual() ||
      parser.parseSymbolName(combiner_name)) {
    return failure();
  }

  // attr-dict? `:` type-list, then bind the operands to those types.
  SmallVector<Type, 2> types;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.resolveOperands(unresolved_operands, types, start_loc,
                             result.operands)) {
    return failure();
  }

  // The combiner symbol and the max distance are spelled inline in the
  // syntax, so they are attached as attributes by hand.
  auto* context = result.getContext();
  mlir::OperationName op_name(ShuffleReduceOp::getOperationName(), context);
  auto combiner_ref = mlir::FlatSymbolRefAttr::get(context, combiner_name);
  result.addAttribute(ShuffleReduceOp::getCombinerAttrName(op_name),
                      combiner_ref);
  auto distance_attr =
      mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), distance);
  result.addAttribute(ShuffleReduceOp::getMaxDistanceAttrName(op_name),
                      distance_attr);

  result.addTypes(types);
  return success();
}

// Prints the custom assembly form of xla_gpu.shuffle_reduce:
//
//   (<operands>) to <max_distance> combiner=@<symbol> [attr-dict] : <types>
//
// The `combiner` and `max_distance` attributes are spelled inline, so they are
// elided from the trailing attribute dictionary. The printed type list is the
// list of result types.
void ShuffleReduceOp::print(OpAsmPrinter& p) {
  p << '(' << getOperands() << ") to " << getMaxDistance() << " combiner=";
  // printSymbolName emits the leading '@' and quotes names that are not bare
  // identifiers, so the output can be read back by parseSymbolName in
  // ShuffleReduceOp::parse.
  p.printSymbolName(getCombiner());
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {getCombinerAttrName(), getMaxDistanceAttrName()});
  p << " : " << TypeRange(getResultTypes());
}

} // namespace gpu
} // namespace xla

Expand Down
23 changes: 10 additions & 13 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,39 +169,35 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
function. The function is invoked with the operands from the low lanes,
followed by the operands from the high lanes. For example:

// TODO: update the syntax to make it similar to xla_gpu.reduce.
```
shuffle_reduce @argmax(%value, %idx) : (f32, index)
%result:2 = xla_gpu.shuffle_reduce (%in0, %in1) to 16 combiner=@argmax : f32, index
```

Will perform shuffles with distance 16, 8, 4, 2 and 1, and will invoke
@argmax five times. The first invocations will be

```
@argmax(%value[i], %idx[i], %value[16+i], %idx[16+i])
@argmax(%in0[i], %in1[i], %in0[16+i], %in1[16+i])
```
}];
let builders = [
OpBuilder<(ins "mlir::func::FuncOp":$reducer, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
OpBuilder<(ins "mlir::func::FuncOp":$combiner, "mlir::ValueRange":$operands, "int64_t":$max_distance), [{
$_state.addOperands(operands);
$_state.addAttribute("reducer", mlir::SymbolRefAttr::get(reducer));
$_state.addAttribute("combiner", mlir::SymbolRefAttr::get(combiner));
$_state.addAttribute("max_distance",
mlir::IntegerAttr::get(
mlir::IntegerType::get(reducer.getContext(), 64),
mlir::IntegerType::get(combiner.getContext(), 64),
max_distance));
$_state.addTypes(reducer.getFunctionType().getResults());
$_state.addTypes(combiner.getFunctionType().getResults());
}]>];
let arguments = (ins FlatSymbolRefAttr:$reducer,
let arguments = (ins FlatSymbolRefAttr:$combiner,
Variadic<AnyType>:$operands,
I64Attr:$max_distance);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
$reducer `(` $operands `)` `to` $max_distance attr-dict `:` type($operands)
}];
let extraClassDeclaration = [{
mlir::CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("reducer");
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("combiner");
}
operand_range getArgOperands() {
return getOperands();
Expand All @@ -213,6 +209,7 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
(*this)->setAttr("reducer", callee.get<mlir::SymbolRefAttr>());
}
}];
let hasCustomAssemblyFormat = 1;
}

def XLAGPU_PredicatedInsertOp : XLAGPU_Op<"predicated_insert",
Expand Down Expand Up @@ -431,7 +428,7 @@ def XLAGPU_ReduceOp : XLAGPU_Op<"reduce", [
func.return %0, %1 : f32, i32
}
%sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2]
combiner=@add
combiner=@add : tensor<16x8x4xf32>, tensor<16x8x4xi32>
```
}];
let arguments = (ins
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ fusion {
// CHECK: gpu.thread_id x {xla.range = [0 : index, 63 : index]}
// CHECK-NOT: vector<
// CHECK: allocate_shared : tensor<32x3xf32>
// CHECK: shuffle_reduce @add_add{{.*}} to 16 : f32
// CHECK: shuffle_reduce(%{{.*}}) to 16 combiner=@add_add
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ fusion {

// CHECK: vector<2xf32>
// CHECK: allocate_shared : tensor<8x33xf32>
// CHECK: shuffle_reduce @add_add{{.*}} to 4 : f32
// CHECK: shuffle_reduce(%{{.*}}) to 4 combiner=@add_add
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ fusion {
}

// CHECK: allocate_shared : tensor<32x9xf32>
// CHECK: shuffle_reduce @add_add{{.*}} to 16 : f32
// CHECK: shuffle_reduce(%{{.*}}) to 16 combiner=@add_add
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ fusion {

// CHECK: vector<4xf32>
// CHECK: allocate_shared : tensor<16x33xf32>
// CHECK: shuffle_reduce @add_add{{.*}} to 8 : f32
// CHECK: shuffle_reduce(%{{.*}}) to 8 combiner=@add_add
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/tests/reduce_multirow/f32_x8.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ fusion {
// Multi-row reductions do not use shared memory.
// CHECK-NOT: allocate_shared
// There should be 8 elements per warp.
// CHECK: shuffle_reduce {{.*}} to 2
// CHECK: shuffle_reduce(%{{.*}}) to 2
// CHECK-NOT: allocate_shared
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ fusion {

// reduce0 again, in the context of its status as a fusion hero:
// CHECK: tensor.extract %[[P1]][%[[RC]]]
// CHECK: shuffle_reduce @add_add(%{{.*}}) to 2
// CHECK: shuffle_reduce(%{{.*}}) to 2 combiner=@add_add
14 changes: 7 additions & 7 deletions xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ fusion {
ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs)
}

// CHECK-DAG: shuffle_reduce @add_add
// CHECK-DAG: shuffle_reduce @mul_mul
// CHECK-DAG: shuffle_reduce(%{{.*}}) to 16 combiner=@add_add
// CHECK-DAG: shuffle_reduce(%{{.*}}) to 16 combiner=@mul_mul
// CHECK: allocate_shared
// CHECK: allocate_shared
// CHECK: sync_threads
// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add
// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul
// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]]
// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]]
// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]]
// CHECK-DAG: %[[SUM:.*]] = xla_gpu.shuffle_reduce(%{{.*}}) to 16 combiner=@add_add
// CHECK-DAG: %[[PROD:.*]] = xla_gpu.shuffle_reduce(%{{.*}}) to 16 combiner=@mul_mul
// CHECK-DAG: %[[LOG:.*]] = math.log %[[SUM]]
// CHECK-DAG: %[[ABS:.*]] = math.absf %[[SUM]]
// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[PROD]]
// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]]
// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]]
// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]]
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern<ShuffleReduceOp> {
args.push_back(shuffle(value));
}
values = b.create<PureCallOp>(op.getResultTypes(),
op.getReducerAttr().getAttr(), args)
op.getCombinerAttr().getAttr(), args)
.getResults();
}
rewriter.replaceOp(op, values);
Expand Down
26 changes: 13 additions & 13 deletions xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \
// RUN: | FileCheck %s

func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
func.func @combiner(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) {
return %a, %b : f32, i32
}

func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
%ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32
%ret:2 = xla_gpu.shuffle_reduce (%a, %b) to 4 combiner=@combiner: f32, i32
return %ret#0, %ret#1 : f32, i32
}
// CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32)
Expand All @@ -16,23 +16,23 @@ func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) {
// CHECK-DAG: %[[C32:.*]] = arith.constant 32
// CHECK: %[[A4H:.*]], {{.*}} = gpu.shuffle down %[[A]], %[[C4]], %[[C32]]
// CHECK: %[[B4H:.*]], {{.*}} = gpu.shuffle down %[[B]], %[[C4]], %[[C32]]
// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @reducer(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
// CHECK: %[[AB4_0:.*]], %[[AB4_1:.*]] = xla_gpu.pure_call @combiner(%[[A]], %[[B]], %[[A4H]], %[[B4H]])
// CHECK: %[[A2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_0]], %[[C2]], %[[C32]]
// CHECK: %[[B2H:.*]], {{.*}} = gpu.shuffle down %[[AB4_1]], %[[C2]], %[[C32]]
// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @reducer(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
// CHECK: %[[AB2_0:.*]], %[[AB2_1:.*]] = xla_gpu.pure_call @combiner(%[[AB4_0]], %[[AB4_1]], %[[A2H]], %[[B2H]])
// CHECK: %[[A1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_0]], %[[C1]], %[[C32]]
// CHECK: %[[B1H:.*]], {{.*}} = gpu.shuffle down %[[AB2_1]], %[[C1]], %[[C32]]
// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @reducer(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
// CHECK: %[[AB1_0:.*]], %[[AB1_1:.*]] = xla_gpu.pure_call @combiner(%[[AB2_0]], %[[AB2_1]], %[[A1H]], %[[B1H]])
// CHECK: return %[[AB1_0]], %[[AB1_1]]

// -----

func.func @reducer(%a: f64, %b: f64) -> f64 {
func.func @combiner(%a: f64, %b: f64) -> f64 {
return %a : f64
}

func.func @shuffler(%a: f64) -> f64 {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64
%ret = xla_gpu.shuffle_reduce(%a) to 1 combiner=@combiner : f64
return %ret : f64
}
// CHECK: @shuffler(%[[A:.*]]: f64
Expand All @@ -41,25 +41,25 @@ func.func @shuffler(%a: f64) -> f64 {

// -----

func.func @reducer(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
func.func @combiner(%a: complex<f64>, %b: complex<f64>) -> complex<f64> {
return %a : complex<f64>
}

func.func @shuffler(%a: complex<f64>) -> complex<f64> {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex<f64>
%ret = xla_gpu.shuffle_reduce(%a) to 1 combiner=@combiner : complex<f64>
return %ret : complex<f64>
}
// CHECK: @shuffler
// CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]]

// -----

func.func @reducer(%a: ui64, %b: ui64) -> ui64 {
func.func @combiner(%a: ui64, %b: ui64) -> ui64 {
return %a : ui64
}

func.func @shuffler(%a: ui64) -> ui64 {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64
%ret = xla_gpu.shuffle_reduce (%a) to 1 combiner=@combiner : ui64
return %ret : ui64
}
// CHECK: @shuffler
Expand All @@ -68,12 +68,12 @@ func.func @shuffler(%a: ui64) -> ui64 {

// -----

func.func @reducer(%a: i8, %b: i8) -> i8 {
func.func @combiner(%a: i8, %b: i8) -> i8 {
return %a : i8
}

func.func @shuffler_i8(%a: i8) -> i8 {
%ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8
%ret = xla_gpu.shuffle_reduce (%a) to 1 combiner=@combiner : i8
return %ret : i8
}
// CHECK: @shuffler_i8(
Expand Down
Loading