diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir
index 870f1228b6..3773f8a9c0 100644
--- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir
+++ b/tests/filecheck/dialects/onnx/onnx_invalid.mlir
@@ -484,3 +484,39 @@ builtin.module {
 }
 
+// -----
+
+builtin.module {
+  %t0, %t1 = "test.op"() : () -> (tensor<2x4x3xf32>, tensor<4x2xf32>)
+
+  // CHECK: Operation does not verify: input matrix A should be a 2D tensor
+  %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4x3xf32>, tensor<4x2xf32>) -> tensor<2x2xf32>
+}
+
+// -----
+
+builtin.module {
+  %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2x3xf32>)
+
+  // CHECK: Operation does not verify: input matrix B should be a 2D tensor
+  %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2x3xf32>) -> tensor<2x2xf32>
+}
+
+// -----
+
+builtin.module {
+  %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<5x2xf32>)
+
+  // CHECK: Operation does not verify: operands have incompatible shapes: (2, 4) and (5, 2)
+  %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<5x2xf32>) -> tensor<2x2xf32>
+}
+
+// -----
+
+
+builtin.module {
+  %t0, %t1 = "test.op"() : () -> (tensor<2x4xf32>, tensor<4x2xf32>)
+
+  // CHECK: Operation does not verify: result shape [2, 2] does not match result type [2, 3]
+  %res_matmul = "onnx.MatMul"(%t0, %t1) {onnx_node_name = "/MatMul"} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x3xf32>
+}
diff --git a/tests/filecheck/dialects/onnx/onnx_ops.mlir b/tests/filecheck/dialects/onnx/onnx_ops.mlir
index 75fa6391dc..96acd664f2 100644
--- a/tests/filecheck/dialects/onnx/onnx_ops.mlir
+++ b/tests/filecheck/dialects/onnx/onnx_ops.mlir
@@ -78,6 +78,7 @@
 %res_gemm_3 = "onnx.Gemm"(%t27, %t28, %t29) {onnx_node_name = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64}: (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32>
 // CHECK: %res_gemm_3 = onnx.Gemm(%t27, %t28, %t29) {"onnx_node_name" = "/Gemm", "alpha" = 1.000000e+00 : f32, "beta" = 1.000000e+00 : f32, "transA" = 0 : si64, "transB" = 1 : si64} : (tensor<1x320xf32>, tensor<50x320xf32>, tensor<50xf32>) -> tensor<1x50xf32>
-
+%res_matmul = "onnx.MatMul"(%t9, %t10) {onnx_node_name = "/MatMul"}: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+// CHECK: %res_matmul = onnx.MatMul(%t9, %t10) {"onnx_node_name" = "/MatMul"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
diff --git a/xdsl/dialects/onnx.py b/xdsl/dialects/onnx.py
index a38a789403..cfc52ad40c 100644
--- a/xdsl/dialects/onnx.py
+++ b/xdsl/dialects/onnx.py
@@ -7,6 +7,7 @@ from typing_extensions import Self
 
 from xdsl.dialects.builtin import (
+    Any,
     AnyFloat,
     AnyIntegerAttr,
     AnyTensorType,
@@ -860,6 +861,77 @@ def __init__(self, func: Attribute):
         )
 
 
+@irdl_op_definition
+class MatMul(IRDLOperation):
+    """
+    The operation MatMul performs matrix multiplication between two input matrices, A and B, and returns the result as matrix Y.
+    Matrix multiplication is a fundamental operation in linear algebra, where each element of the resulting matrix Y is computed by taking the
+    dot product of the corresponding row of matrix A and column of matrix B.
+    """
+
+    name = "onnx.MatMul"
+
+    # element type constraint shared by both operands and the result
+    T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")]
+
+    # input matrices
+    matrix_A = operand_def(TensorType[T])
+    matrix_B = operand_def(TensorType[T])
+
+    # output matrix
+    matrix_Y = result_def(TensorType[T])
+
+    assembly_format = (
+        "`(` $matrix_A `,` $matrix_B `)` attr-dict `:` "
+        "`(` type($matrix_A) `,` type($matrix_B) `)` `->` type($matrix_Y)"
+    )
+
+    def __init__(
+        self,
+        matrix_A: SSAValue,
+        matrix_B: SSAValue,
+        matrix_Y_type: Attribute,
+    ):
+        super().__init__(
+            operands=[matrix_A, matrix_B],
+            result_types=[matrix_Y_type],
+        )
+
+    def verify_(self) -> None:
+        # res_shape will hold the result shape inferred from A and B
+        res_shape: list[int] = []
+        matrix_A_type = cast(TensorType[Any], self.matrix_A.type)
+        matrix_B_type = cast(TensorType[Any], self.matrix_B.type)
+        matrix_Y_type = cast(TensorType[Any], self.matrix_Y.type)
+
+        # check shape compatibility
+        matrix_A_shape = matrix_A_type.get_shape()
+        matrix_B_shape = matrix_B_type.get_shape()
+
+        if matrix_A_type.get_num_dims() != 2:
+            raise VerifyException("input matrix A should be a 2D tensor")
+
+        if matrix_B_type.get_num_dims() != 2:
+            raise VerifyException("input matrix B should be a 2D tensor")
+
+        if matrix_A_shape[1] != matrix_B_shape[0]:
+            raise VerifyException(
+                f"operands have incompatible shapes: {matrix_A_shape} and {matrix_B_shape}"
+            )
+        else:
+            res_shape.append(matrix_A_shape[0])
+            res_shape.append(matrix_B_shape[1])
+
+        matrix_Y_type_shape = list(matrix_Y_type.get_shape())
+        if (
+            len(res_shape) != len(matrix_Y_type_shape)
+            or res_shape != matrix_Y_type_shape
+        ):
+            raise VerifyException(
+                f"result shape {res_shape} does not match result type {matrix_Y_type_shape}"
+            )
+
+
 ONNX = Dialect(
     "onnx",
     [
@@ -870,6 +942,7 @@ def __init__(self, func: Attribute):
         Div,
         EntryPoint,
         Gemm,
+        MatMul,
         MaxPoolSingleOut,
         Mul,
         Relu,
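For reference, here is a minimal standalone sketch of the shape rule the new verifier encodes: a 2D A of shape (n, k) times a 2D B of shape (k, m) yields Y of shape (n, m), and anything else is rejected with the messages exercised in onnx_invalid.mlir. This is plain Python with no xdsl APIs, and infer_matmul_shape is a hypothetical helper name used only for illustration, not part of the patch.

# Standalone sketch of the shape check performed by MatMul.verify_;
# infer_matmul_shape is a hypothetical name, not part of the patch.
def infer_matmul_shape(
    a_shape: tuple[int, ...], b_shape: tuple[int, ...]
) -> tuple[int, int]:
    if len(a_shape) != 2:
        raise ValueError("input matrix A should be a 2D tensor")
    if len(b_shape) != 2:
        raise ValueError("input matrix B should be a 2D tensor")
    if a_shape[1] != b_shape[0]:
        raise ValueError(
            f"operands have incompatible shapes: {a_shape} and {b_shape}"
        )
    # (n, k) x (k, m) -> (n, m)
    return (a_shape[0], b_shape[1])


# Mirrors the filecheck cases above: (2, 4) x (4, 2) gives result shape (2, 2),
# while (2, 4) x (5, 2) is rejected as incompatible.
assert infer_matmul_shape((2, 4), (4, 2)) == (2, 2)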