[MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batch_matmul' operation.

Goals:
1. Add syntax and semantics to 'batch_matmul' without changing any of the
   existing syntax expectations for current usage. batch_matmul is still
   just batch_matmul.

2. Move the definition of batch_matmul from linalg OpDSL to the TableGen ODS
   infra.

Scope of this patch:
Expose broadcast and transpose semantics on 'batch_matmul'.

The broadcast and transpose semantics are as follows:

By default, the behavior of 'linalg.batch_matmul' remains as is.
Broadcast and transpose semantics can be applied by specifying the
explicit attribute 'indexing_maps' as shown below. This is a list
attribute, so the list must include all the maps if specified.

    Example Transpose:
    ```
    linalg.batch_matmul indexing_maps = [
                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose
                   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                   affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                   ]
                   ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
                   outs(%arg2: memref<2x3x7xf32>)
    ```

    Example Broadcast:
    ```
    linalg.batch_matmul indexing_maps = [
                       affine_map<(d0, d1, d2, d3) -> (d3)>,     //broadcast
                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                     ]
                     ins(%arg0, %arg1 : memref<5xf32>,memref<2x5x7xf32>)
                     outs(%arg2: memref<2x3x7xf32>)
    ```

    Example Broadcast and transpose:
    ```
    linalg.batch_matmul indexing_maps = [
                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,     //broadcast
                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
                     ]
                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
                     outs(%arg2: memref<2x3x7xf32>)
    ```
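
The effect of the indexing maps above can be sketched as a plain loop nest.
The following C++ is an illustrative reference only, not code from this patch:
`IndexingMap`, `Buffer`, `offsetOf`, `batchMatmulRef` and
`runBroadcastTransposeExample` are hypothetical names, and each map is assumed
to be a simple projection of (d0, d1, d2, d3) = (batch, M, N, K).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper types for illustration; these are not MLIR APIs.
// An IndexingMap lists which iteration dimension (d0..d3) feeds each operand
// dimension, e.g. {0, 3, 1} stands for (d0, d1, d2, d3) -> (d0, d3, d1).
using IndexingMap = std::vector<int>;

struct Buffer {
  std::vector<std::size_t> shape; // row-major extents
  std::vector<float> data;        // flat row-major storage
};

// Row-major linear offset of the element that `map` selects at the
// iteration point iv = {d0, d1, d2, d3}.
std::size_t offsetOf(const Buffer &buf, const IndexingMap &map,
                     const std::array<std::size_t, 4> &iv) {
  assert(map.size() == buf.shape.size());
  std::size_t offset = 0;
  for (std::size_t i = 0; i < map.size(); ++i) {
    assert(iv[map[i]] < buf.shape[i]);
    offset = offset * buf.shape[i] + iv[map[i]];
  }
  return offset;
}

// Reference loop nest: d0 (batch), d1 (M) and d2 (N) are parallel,
// d3 (K) is the reduction dimension.
void batchMatmulRef(const Buffer &a, const IndexingMap &mapA,
                    const Buffer &b, const IndexingMap &mapB,
                    Buffer &c, const IndexingMap &mapC,
                    std::size_t batch, std::size_t m, std::size_t n,
                    std::size_t k) {
  for (std::size_t d0 = 0; d0 < batch; ++d0)
    for (std::size_t d1 = 0; d1 < m; ++d1)
      for (std::size_t d2 = 0; d2 < n; ++d2)
        for (std::size_t d3 = 0; d3 < k; ++d3) {
          const std::array<std::size_t, 4> iv = {d0, d1, d2, d3};
          c.data[offsetOf(c, mapC, iv)] +=
              a.data[offsetOf(a, mapA, iv)] * b.data[offsetOf(b, mapB, iv)];
        }
}

// Mirrors the "Broadcast and transpose" example: A is 3x5 and shared by
// every batch, B is 2x7x5 (transposed), C is 2x3x7.
Buffer runBroadcastTransposeExample() {
  const std::size_t batch = 2, m = 3, n = 7, k = 5;
  Buffer a{{m, k}, std::vector<float>(m * k)};
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < k; ++j)
      a.data[i * k + j] = static_cast<float>(i + 1); // row i holds i + 1
  Buffer b{{batch, n, k}, std::vector<float>(batch * n * k, 1.0f)};
  Buffer c{{batch, m, n}, std::vector<float>(batch * m * n, 0.0f)};
  batchMatmulRef(a, {1, 3}, b, {0, 2, 3}, c, {0, 1, 2}, batch, m, n, k);
  return c;
}
```

In this sketch a broadcast operand is simply one whose map omits the batch
dimension d0, so the same element is read for every batch, and a transpose is
a map that lists the same dimensions in a different order.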
shahidact committed Jan 9, 2025
1 parent 86440cb commit 0f0aa7d
Showing 8 changed files with 632 additions and 88 deletions.
69 changes: 0 additions & 69 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1472,75 +1472,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: batch_matmul
  cpp_class_name: BatchMatmulOp
  doc: |-
    Performs a batched matrix multiplication of two 3D inputs.
    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: A
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
  - !LinalgOperandDefConfig
    name: B
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
  - !LinalgOperandDefConfig
    name: C
    kind: output_tensor
    type_var: U
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
  iterator_types:
  - parallel
  - parallel
  - parallel
  - reduction
  assignments:
  - !ScalarAssign
    arg: C
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: C
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: A
            - !ScalarExpression
              scalar_fn:
                kind: type
                fn_name: cast_signed
                type_var: U
                operands:
                - !ScalarExpression
                  scalar_arg: B
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: batch_matmul_transpose_a
  cpp_class_name: BatchMatmulTransposeAOp
124 changes: 124 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,130 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
}];
}

//===----------------------------------------------------------------------===//
// Op definition for BatchMatmulOp
//===----------------------------------------------------------------------===//

def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSizedOperandSegments],
/*extraInterfaces=*/[LinalgContractionOpInterface])> {

let summary = [{Performs a batched matrix multiplication of two 3D inputs.}];
let description = [{Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.

Broadcast and Transpose semantics can be applied by specifying the explicit attribute
'indexing_maps' as shown below. This is a list attribute, so the list must include all
the maps if specified.

Example Transpose:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<2x3x7xf32>)
```

Example Broadcast:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<2x3x7xf32>)
```

Example Broadcast and transpose:
```
linalg.batch_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<2x3x7xf32>)
```
}];

let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildBatchMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, BatchMatmulOp::getRegionBuilder(),
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildBatchMatmulOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, BatchMatmulOp::getRegionBuilder(),
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext()));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
$_state.addOperands(operands);
$_state.addAttributes(attributes);
$_state.addTypes(resultTensorTypes);
(void)$_state.addRegion(),
BatchMatmulOp::getDefaultIndexingMaps($_builder.getContext());
}]>

];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{

SmallVector<utils::IteratorType> getIteratorTypesArray();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}

/// Returns a list of AffineMap with the typical batch_matmul indexing characteristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);

/// Returns true if the given broadcast map \p bcastMap is valid for this op.
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);

::mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}

// Generic methods.
static unsigned getNumRegionArgs();
bool hasDynamicIndexingMaps() { return true; }
std::string getLibraryCallName();
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
/// user defined indexing maps are not equal to default map.
bool hasUserDefinedMaps();
}];
}


//===----------------------------------------------------------------------===//
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//