Skip to content

Commit

Permalink
[XLA:GPU][Emitters] Add xla_gpu.reduce op.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675443048
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Sep 17, 2024
1 parent d65956e commit d03c061
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 5 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ cc_library(
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@stablehlo//:stablehlo_type_inference",
],
)

Expand Down
28 changes: 28 additions & 0 deletions xla/service/gpu/fusions/ir/tests/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,32 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>,
%0 = xla_gpu.insert %input(%i, %j) into %output at #map1
: !xla_gpu.indexed_vector<32x64xf32, #map> -> tensor<32x64xf32>
func.return %0 : tensor<32x64xf32>
}

// -----

// Negative test: the reduce op below names `@add` as its combiner, but no
// function with that name is defined in this test case, so the op is
// expected to be rejected with a "not found" diagnostic.
func.func @reduce(%in0: tensor<16x8x4xf32>, %init0: f32,
%in1: tensor<16x8x4xi32>, %init1: i32) -> (tensor<8xf32>, tensor<8xi32>) {
// expected-error @+1 {{combiner `@add` not found}}
%sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2]
combiner=@add {xla.range = [0 : index, 42 : index]}
: tensor<16x8x4xf32>, tensor<16x8x4xi32>
func.return %sum#0, %sum#1 : tensor<8xf32>, tensor<8xi32>
}

// -----

// Negative test: `@add` exists here, but it combines two f32 values, while
// the reduce op below has one f32 input and one i32 input. The combiner
// signature therefore does not match the one derived from the inputs.
func.func @add(%a_acc: f32, %b_acc: f32, %a: f32, %b: f32)
-> (f32, f32) {
%0 = arith.addf %a_acc, %a : f32
%1 = arith.addf %b_acc, %b : f32
func.return %0, %1 : f32, f32
}
func.func @reduce(%in0: tensor<16x8x4xf32>, %init0: f32,
%in1: tensor<16x8x4xi32>, %init1: i32) -> (tensor<8xf32>, tensor<8xi32>) {
// expected-error @+1 {{combiner `@add expected to have type '(f32, i32, f32, i32) -> (f32, i32)' but got '(f32, f32, f32, f32) -> (f32, f32)'}}
%sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2]
combiner=@add {xla.range = [0 : index, 42 : index]}
: tensor<16x8x4xf32>, tensor<16x8x4xi32>
func.return %sum#0, %sum#1 : tensor<8xf32>, tensor<8xi32>
}
24 changes: 24 additions & 0 deletions xla/service/gpu/fusions/ir/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,27 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index,
// CHECK-SAME: #[[$MAP]](%{{.*}}, %{{.*}})
// CHECK: xla_gpu.insert %[[MATERIALIZED]](%{{.*}}, %{{.*}}) into
// CHECK-SAME: at #[[$MAP2]] : <32x64xf32, #[[$MAP1]]>

// -----

// Positive test: a two-input reduction (f32 and i32 tensors) over dimensions
// 0 and 2 with a combiner whose signature is
// (f32, i32, f32, i32) -> (f32, i32). The op also carries an extra
// `xla.range` attribute in its attribute dictionary.
func.func @add(%a_acc: f32, %b_acc: i32, %a: f32, %b: i32)
-> (f32, i32) {
%0 = arith.addf %a_acc, %a : f32
%1 = arith.addi %b_acc, %b : i32
func.return %0, %1 : f32, i32
}
func.func @reduce(%in0: tensor<16x8x4xf32>, %init0: f32,
%in1: tensor<16x8x4xi32>, %init1: i32) -> (tensor<8xf32>, tensor<8xi32>) {
%sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2]
combiner=@add {xla.range = [0 : index, 42 : index]}
: tensor<16x8x4xf32>, tensor<16x8x4xi32>
func.return %sum#0, %sum#1 : tensor<8xf32>, tensor<8xi32>
}
// CHECK-LABEL: func.func @reduce(
// CHECK-SAME: %[[IN1:.*]]: tensor<16x8x4xf32>, %[[INIT1:.*]]: f32,
// CHECK-SAME: %[[IN2:.*]]: tensor<16x8x4xi32>, %[[INIT2:.*]]: i32)

// CHECK: xla_gpu.reduce(%[[IN1]], %[[IN2]])
// CHECK-SAME: inits(%[[INIT1]], %[[INIT2]]) dimensions=[0, 2]
// CHECK-SAME: combiner=@add {xla.range = [0 : index, 42 : index]}
// CHECK-SAME: : tensor<16x8x4xf32>, tensor<16x8x4xi32>
135 changes: 130 additions & 5 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ limitations under the License.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h" // IWYU pragma: keep
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep
#include "mlir/IR/OpDefinition.h"
Expand All @@ -40,10 +41,12 @@ limitations under the License.
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/TypeInference.h"
#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc"
#include "xla/service/gpu/model/indexing_map.h"

Expand All @@ -55,16 +58,19 @@ using llvm::ArrayRef;
using mlir::AffineExpr;
using mlir::AffineMap;
using mlir::Block;
using mlir::DenseI64ArrayAttr;
using mlir::failure;
using mlir::getAffineConstantExpr;
using mlir::getAffineDimExpr;
using mlir::getAffineSymbolExpr;
using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpAsmParser;
using mlir::OpAsmPrinter;
using mlir::OpBuilder;
using mlir::OperationState;
using mlir::ParseResult;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::Region;
Expand Down Expand Up @@ -143,16 +149,16 @@ void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result,
}

// Parses a comma-separated list of operands, ex: %d1, %d2.
mlir::ParseResult parseOperands(
ParseResult parseOperands(
OpAsmParser& parser,
SmallVector<OpAsmParser::UnresolvedOperand, 4>* operands) {
OpAsmParser::UnresolvedOperand operand;
return parser.parseCommaSeparatedList(
[&]() { return parser.parseOperand(operands->emplace_back()); });
}

mlir::ParseResult ApplyIndexingOp::parse(OpAsmParser& parser,
OperationState& result) {
ParseResult ApplyIndexingOp::parse(OpAsmParser& parser,
OperationState& result) {
mlir::Builder& builder = parser.getBuilder();
auto index_type = builder.getIndexType();

Expand Down Expand Up @@ -508,7 +514,7 @@ struct FoldApplyIndexingResults

LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op,
PatternRewriter& rewriter) const override {
mlir::Location loc = indexing_op.getLoc();
Location loc = indexing_op.getLoc();
IndexingMap indexing_map = indexing_op.getIndexingMap();
if (indexing_map.IsKnownEmpty()) {
return rewriter.notifyMatchFailure(indexing_op,
Expand Down Expand Up @@ -728,7 +734,7 @@ void LoopOp::build(OpBuilder& builder, OperationState& result,
bodyBuilder);
}

mlir::ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) {
ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument, 4> region_args, ivs, map_results,
iter_args;
SmallVector<OpAsmParser::UnresolvedOperand, 4> dim_operands;
Expand Down Expand Up @@ -1053,6 +1059,125 @@ LogicalResult InsertOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

// Computes the result tensor types of a reduction.
//
// `input_types` holds the types of the reduced tensors. Every entry is cast to
// RankedTensorType, and the shape of the first entry is used as the shape of
// all inputs. `reduced_dims` lists the dimensions that are reduced away; it is
// consumed with a single forward scan, so it must be sorted in strictly
// increasing order to be matched correctly.
//
// Returns one ranked tensor type per input. Its shape keeps the non-reduced
// dimensions of the input shape in their original order, and its element type
// is the element type of the corresponding input. Returns an empty list when
// there are no inputs.
SmallVector<Type> inferReductionResultTypes(TypeRange input_types,
                                            ArrayRef<int64_t> reduced_dims) {
  SmallVector<Type> result_types;
  // Without inputs there is no shape to derive; also avoids front() on an
  // empty range.
  if (input_types.empty()) {
    return result_types;
  }

  ArrayRef<int64_t> input_shape =
      mlir::cast<RankedTensorType>(input_types.front()).getShape();
  const int64_t rank = static_cast<int64_t>(input_shape.size());
  const int64_t num_reduced_dims = static_cast<int64_t>(reduced_dims.size());

  SmallVector<int64_t, 4> output_shape;
  // Only pre-size when the difference is non-negative: malformed IR may list
  // more reduced dimensions than the input has.
  if (num_reduced_dims <= rank) {
    output_shape.reserve(rank - num_reduced_dims);
  }

  int64_t reduce_dim = 0;
  for (int64_t i = 0; i < rank; ++i) {
    // Drop dimension `i` only while there are reduced dimensions left to
    // match AND the next one is `i`. Once all reduced dimensions have been
    // consumed, every remaining dimension is kept.
    if (reduce_dim < num_reduced_dims && i == reduced_dims[reduce_dim]) {
      ++reduce_dim;
      continue;
    }
    output_shape.push_back(input_shape[i]);
  }

  result_types.reserve(input_types.size());
  for (Type input_type : input_types) {
    result_types.push_back(RankedTensorType::get(
        output_shape,
        mlir::cast<RankedTensorType>(input_type).getElementType()));
  }
  return result_types;
}

// Derives the types of the reduction init values from the input tensor types:
// the i-th init type is the element type of the i-th input, which is cast to
// RankedTensorType.
SmallVector<Type> inferReductionInitTypes(TypeRange input_types) {
  SmallVector<Type, 2> element_types;
  element_types.reserve(input_types.size());
  for (Type tensor_type : input_types) {
    auto ranked_type = mlir::cast<RankedTensorType>(tensor_type);
    element_types.push_back(ranked_type.getElementType());
  }
  return element_types;
}

// InferTypeOpInterface hook: appends to `inferredReturnTypes` the result types
// computed by inferReductionResultTypes from the types of the op's `inputs`
// and its `dimensions` attribute. Always reports success.
LogicalResult ReduceOp::inferReturnTypes(
    MLIRContext* context, std::optional<Location> location, ValueRange operands,
    mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties,
    mlir::RegionRange regions,
    mlir::SmallVectorImpl<Type>& inferredReturnTypes) {
  // The adaptor gives named access to the not-yet-built op's operands and
  // attributes.
  ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
  SmallVector<Type> result_types = inferReductionResultTypes(
      TypeRange{adaptor.getInputs()}, adaptor.getDimensions());
  inferredReturnTypes.append(result_types.begin(), result_types.end());
  return success();
}

// Parses the custom assembly of the reduce op:
//
//   (inputs...) inits(inits...) dimensions=[d0, d1, ...] combiner=@symbol
//       attr-dict? : input-types...
//
// Only the input types are spelled out in the assembly. The result types and
// the init types are derived from them with inferReductionResultTypes and
// inferReductionInitTypes. Operands are added to `result` as inputs followed
// by inits.
ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs, inits;
  SmallVector<int64_t, 2> dimensions;
  mlir::StringAttr combiner;
  SmallVector<Type, 2> input_types;

  if (parser.parseLParen() || parseOperands(parser, &inputs) ||
      parser.parseRParen() || parser.parseKeyword("inits") ||
      parser.parseLParen() || parseOperands(parser, &inits) ||
      parser.parseRParen() || parser.parseKeyword("dimensions") ||
      parser.parseEqual() ||
      parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Square,
                                     [&]() -> ParseResult {
                                       return parser.parseInteger(
                                           dimensions.emplace_back());
                                     }) ||
      parser.parseKeyword("combiner") || parser.parseEqual() ||
      parser.parseSymbolName(combiner) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(input_types)) {
    return failure();
  }
  mlir::SMLoc loc = parser.getCurrentLocation();

  // The type-inference helpers cast every input type to RankedTensorType
  // unconditionally. Reject anything else here with a diagnostic so that
  // malformed IR does not trip an assertion inside the parser.
  for (Type input_type : input_types) {
    if (!mlir::isa<RankedTensorType>(input_type)) {
      return parser.emitError(loc)
             << "expected input types to be ranked tensors, but got "
             << input_type;
    }
  }

  auto ctx = result.getContext();
  mlir::OperationName opname(ReduceOp::getOperationName(), ctx);
  result.addAttribute(ReduceOp::getDimensionsAttrName(opname),
                      DenseI64ArrayAttr::get(ctx, dimensions));
  result.addAttribute(ReduceOp::getCombinerAttrName(opname),
                      mlir::FlatSymbolRefAttr::get(ctx, combiner));
  result.addTypes(inferReductionResultTypes(input_types, dimensions));

  // Inputs must be resolved before inits to keep the operand order of the op.
  auto init_types = inferReductionInitTypes(input_types);
  if (parser.resolveOperands(inputs, input_types, loc, result.operands) ||
      parser.resolveOperands(inits, init_types, loc, result.operands)) {
    return failure();
  }
  return success();
}

// Prints the custom assembly of the reduce op in the form accepted by
// ReduceOp::parse: the inputs, the inits, the reduced dimensions and the
// combiner symbol, then any remaining attributes, then the input types.
void ReduceOp::print(OpAsmPrinter& p) {
  p << '(' << getInputs() << ')';
  p << " inits(" << getInits() << ')';
  p << " dimensions=[" << getDimensions() << ']';
  p << " combiner=@" << getCombiner();
  // `combiner` and `dimensions` were already printed inline above, so they
  // are elided from the attribute dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {getCombinerAttrName(), getDimensionsAttrName()});
  p << " : ";
  p << TypeRange(getInputs());
}

// Verifies the combiner of the reduce op.
//
// The combiner symbol is looked up in the enclosing module and must name a
// func.func whose type is
//   (init types..., init types...) -> (init types...)
// where the init types are the element types of the op's inputs. Emits an op
// error and returns failure if the op is not nested in a module, if the
// combiner cannot be found, or if its type differs from the expected one.
LogicalResult ReduceOp::verify() {
  auto module = this->getOperation()->getParentOfType<mlir::ModuleOp>();
  // getParentOfType returns a null op when there is no enclosing module;
  // looking a symbol up in it would dereference a null operation.
  if (!module) {
    return emitOpError() << "expected to be nested in a module to resolve "
                         << "combiner `@" << getCombiner() << "`";
  }
  auto combiner = module.lookupSymbol<mlir::func::FuncOp>(getCombinerAttr());
  if (!combiner) {
    return emitOpError() << "combiner `@" << getCombiner() << "` not found";
  }

  auto inferred_init_types = inferReductionInitTypes(TypeRange(getInputs()));
  // The combiner takes two groups of values of the init types (the example in
  // the op description calls them accumulators and new elements), so the init
  // types appear twice in the operand list and once in the result list.
  SmallVector<Type, 2> combiner_operand_types;
  combiner_operand_types.reserve(getNumOperands());
  combiner_operand_types.append(inferred_init_types);
  combiner_operand_types.append(inferred_init_types);
  auto expected_combiner_type = mlir::FunctionType::get(
      getContext(), combiner_operand_types, inferred_init_types);
  if (expected_combiner_type != combiner.getFunctionType()) {
    return emitOpError() << "provided combiner `@" << getCombiner()
                         << " expected to have type " << expected_combiner_type
                         << " but got " << combiner.getFunctionType();
  }
  return success();
}

} // namespace gpu
} // namespace xla

Expand Down
47 changes: 47 additions & 0 deletions xla/service/gpu/fusions/ir/xla_gpu_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def XLAGPU_ShuffleReduceOp : XLAGPU_Op<"shuffle_reduce",
function. The function is invoked with the operands from the low lanes,
followed by the operands from the high lanes. For example:

// TODO: update the syntax to make it similar to xla_gpu.reduce.
```
shuffle_reduce @argmax(%value, %idx) : (f32, index)
```
Expand Down Expand Up @@ -413,4 +414,50 @@ def XLAGPU_InsertOp : XLAGPU_Op<"insert", [TypesMatchWith<
}];
}

// Variadic reduction op: reduces `inputs` over `dimensions`, starting from
// `inits`, using the function referenced by the `combiner` symbol.
// SameVariadicOperandSize ties the sizes of the two variadic operand groups
// (`inputs` and `inits`) together.
def XLAGPU_ReduceOp : XLAGPU_Op<"reduce", [
Pure, CallOpInterface, SameVariadicOperandSize,
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Performs a reduction";
// NOTE(review): the example in the description stops after `combiner=@add`,
// whereas the lit tests for this op also write a trailing
// `: <input tensor types>` list — consider completing the example.
let description = [{
The `xla_gpu.reduce` op performs a variadic reduction of the provided
operands using the list of dimensions and a symbol for a combiner function.

```mlir
func.func @add(%a_acc: f32, %b_acc: i32, %a: f32, %b: i32)
-> (f32, i32) {
%0 = arith.addf %a_acc, %a : f32
%1 = arith.addi %b_acc, %b : i32
func.return %0, %1 : f32, i32
}
%sum:2 = xla_gpu.reduce (%in0, %in1) inits(%init0, %init1) dimensions=[0, 2]
combiner=@add
```
}];
// `inits` accepts any type at the ODS level. `dimensions` is constrained to
// a strictly sorted i64 array.
let arguments = (ins
Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyType>:$inits,
ConfinedAttr<DenseI64ArrayAttr,
[DenseArrayStrictlySorted<DenseI64ArrayAttr>]>:$dimensions,
FlatSymbolRefAttr:$combiner);
let results = (outs Variadic<AnyRankedTensor>:$results);

// CallOpInterface hooks: the call arguments are reported as `inputs`, and the
// callee is read from / written to the `combiner` attribute.
let extraClassDeclaration = [{
operand_range getArgOperands() {
return getInputs();
}
mlir::MutableOperandRange getArgOperandsMutable() {
return getInputsMutable();
}
mlir::CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<mlir::SymbolRefAttr>("combiner");
}
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
(*this)->setAttr("combiner", callee.get<mlir::SymbolRefAttr>());
}
}];
// The assembly format (parse/print) and the verifier are hand-written in C++.
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_OPS

0 comments on commit d03c061

Please sign in to comment.