Skip to content

Commit

Permalink
dialects: (onnx) implementation onnx.MatMul (#2467)
Browse files Browse the repository at this point in the history
Implementation of onnx dialect operation MatMul:

- Implementation of  MatMul IRLDOperation in onnx.py
- Implementation of tests in onnx_invalid.mlir
- Implementation of tests in onnx_ops.mlir

@superlopuh @compor

---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
alecerio and superlopuh authored Apr 17, 2024
1 parent ff8e134 commit 8482c67
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 1 deletion.
36 changes: 36 additions & 0 deletions tests/filecheck/dialects/onnx/onnx_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,39 @@ builtin.module {

}

// -----

builtin.module {
%t0, %t1 = "test.op"() : () -> (tensor<2x4x3xf32>, tensor<4x2xf32>)

// CHECK: Operation does not verify: input matrix A should be a 2D tensor
%res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4x3xf32>, tensor<4x2xf32>) -> tensor<2x2xf32>
}

// -----

builtin.module {
%t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2x3xf32>)

// CHECK: Operation does not verify: input matrix B should be a 2D tensor
%res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2x3xf32>) -> tensor<2x2xf32>
}

// -----

builtin.module {
%t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<5x2xf32>)

// CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (5, 2)
%res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<5x2xf32>) -> tensor<2x2xf32>
}

// -----


builtin.module {
%t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2xf32>)

// CHECK: Operation does not verify: result shape [2, 2] does not match result type [2, 3]
%res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x3xf32>
}
3 changes: 2 additions & 1 deletion tests/filecheck/dialects/onnx/onnx_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
%res_gemm_3 = "onnx.Gemm"(%t27, %t28, %t29) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32>
// CHECK: %res_gemm_3 = onnx.Gemm(%t27, %t28, %t29) {"onnx_node_name" = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64} : (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32>


%res_matmul = "onnx.MatMul"(%t9, %t10) {onnx_node_name = "/MatMul"}: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %res_matmul = onnx.MatMul(%t9, %t10) {"onnx_node_name" = "/MatMul"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>


73 changes: 73 additions & 0 deletions xdsl/dialects/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing_extensions import Self

from xdsl.dialects.builtin import (
Any,
AnyFloat,
AnyIntegerAttr,
AnyTensorType,
Expand Down Expand Up @@ -860,6 +861,77 @@ def __init__(self, func: Attribute):
)


@irdl_op_definition
class MatMul(IRDLOperation):
"""
The operation MatMul performs matrix multiplication between two input matrices, A and B, and returns the result as matrix Y.
Matrix multiplication is a fundamental operation in linear algebra, where each element of the resulting matrix Y is computed by taking the
dot product of the corresponding row of matrix A and column of matrix B.
"""

name = "onnx.MatMul"

# describe annotated type
T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")]

# input matrices
matrix_A = operand_def(TensorType[T])
matrix_B = operand_def(TensorType[T])

# output matrices
matrix_Y = result_def(TensorType[T])

assembly_format = (
"`(` $matrix_A `,` $matrix_B `)` attr-dict `:` `(` type($matrix_A) `,"
"` type($matrix_B) `)` `->` type($matrix_Y) "
)

def __init__(
self,
matrix_A: SSAValue,
matrix_B: SSAValue,
matrix_Y_type: Attribute,
):
super().__init__(
operands=[matrix_A, matrix_B],
result_types=[matrix_Y_type],
)

def verify_(self) -> None:
# store dimensions of tensor A and tensor B
res_shape: list[int] = []
matrix_A_type = cast(TensorType[Any], self.matrix_A.type)
matrix_B_type = cast(TensorType[Any], self.matrix_B.type)
matrix_Y_type = cast(TensorType[Any], self.matrix_Y.type)

# check shape compatibility
matrix_A_shape = matrix_A_type.get_shape()
matrix_B_shape = matrix_B_type.get_shape()

if matrix_A_type.get_num_dims() != 2:
raise VerifyException("input matrix A should be a 2D tensor")

if matrix_B_type.get_num_dims() != 2:
raise VerifyException("input matrix B should be a 2D tensor")

if matrix_A_shape[1] != matrix_B_shape[0]:
raise VerifyException(
f"operands have incompatible shapes: {matrix_A_shape} and {matrix_B_shape}"
)
else:
res_shape.append(matrix_A_shape[0])
res_shape.append(matrix_B_shape[1])

matrix_Y_type_shape = list(matrix_Y_type.get_shape())
if (
len(res_shape) != len(matrix_Y_type_shape)
or res_shape != matrix_Y_type_shape
):
raise VerifyException(
f"result shape {res_shape} does not match result type {matrix_Y_type_shape}"
)


ONNX = Dialect(
"onnx",
[
Expand All @@ -870,6 +942,7 @@ def __init__(self, func: Attribute):
Div,
EntryPoint,
Gemm,
MatMul,
MaxPoolSingleOut,
Mul,
Relu,
Expand Down

0 comments on commit 8482c67

Please sign in to comment.