From 39777ebe2138532edf449138a851e0d074e3ac6b Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Mon, 10 Feb 2025 19:57:32 +0000
Subject: [PATCH 1/2] Adding convolution op matcher support

Co-authored-by: Max Dawkins
Co-authored-by: jerryyin
---
 tuner/tuner/common.py               |  69 +++++++++++++
 tuner/tuner/dispatch_parser.py      |  17 +++-
 tuner/tuner/dispatch_parser_test.py | 144 +++++++++++++++++++++++---
 tuner/tuner/op_matchers.py          | 150 +++++++++++++++++++++++++++-
 4 files changed, 360 insertions(+), 20 deletions(-)

diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index 05820526b..ac22ffe82 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -125,14 +125,83 @@ class ContractionDimensions:
     batch: list[int] = field(default_factory=list)
 
 
+@dataclass
+class ConvolutionDimensions:
+    """
+    Stores which dimensions of the iteration space belong to the convolution.
+    For example, the following is a simple nhwc_fhwc conv:
+    linalg.generic ... indexing_maps = [
+        affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, oh + fh, ow + fw, ic)>,
+        affine_map<(b, oh, ow, oc, fh, fw, ic) -> (oc, fh, fw, ic)>,
+        affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, oh, ow, oc)>,
+    ]
+    The ConvolutionDimensions would be:
+    batch = [0]
+    outputImage = [1, 2]
+    outputChannel = [3]
+    filterLoop = [4, 5]
+    inputChannel = [6]
+    depth = []
+    strides = [1, 1]
+    dilations = [1, 1]
+    """
+
+    batch: list[int] = field(default_factory=list)
+    outputImage: list[int] = field(default_factory=list)
+    outputChannel: list[int] = field(default_factory=list)
+    filterLoop: list[int] = field(default_factory=list)
+    inputChannel: list[int] = field(default_factory=list)
+    depth: list[int] = field(default_factory=list)
+    strides: list[int] = field(default_factory=list)
+    dilations: list[int] = field(default_factory=list)
+
+
 @dataclass
 class ProblemSize:
+    """
+    Represents the problem size of a contraction or convolution operation. When the
+    operation is a convolution, all of the fields contraction_dims, lhs_expr_dims,
+    rhs_expr_dims, res_expr_dims, and conv_dims must be set.
+
+    For example, the following is a simple convolution:
+    %conv = linalg.conv_2d_nhwc_hwcf
+        {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+        ins(%6, %7 : tensor<2x62x66x960xf16>, tensor<3x3x960x640xf16>)
+        outs(%13 : tensor<2x60x64x640xf32>) -> tensor<2x60x64x640xf32>
+
+    The ProblemSize would be:
+    matmul_size = ContractionSizes(M=[2, 60, 64], N=[640], K=[3, 3, 960], B=[]),
+    lhs_type = ShapedType(shape=[2, 62, 66, 960], element_type=F16Type(f16)),
+    rhs_type = ShapedType(shape=[3, 3, 960, 640], element_type=F16Type(f16)),
+    res_type = ShapedType(shape=[2, 60, 64, 640], element_type=F32Type(f32)),
+    dispatch_kind = DispatchKind.conv
+
+    # The contraction_dims field is intuitive for gemms. For convolutions, it records
+    # the mapping from convolution dimensions to contraction dimensions. In the igemm
+    # setting:
+    # - [d0, d1, d2] map to [b, oh, ow], the m dimension of the gemm
+    # - [d3] maps to [oc], the n dimension of the gemm
+    # - [d4, d5, d6] map to [fh, fw, ic], the k dimension of the gemm
+    contraction_dims = ContractionDimensions(m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]),
+
+    # The *_expr_dims fields record, for each dimension of an operand tensor, the
+    # iteration dimensions that appear in the corresponding affine map result expression.
+    lhs_expr_dims = [[0], [1, 4], [2, 5], [6]],
+    rhs_expr_dims = [[4], [5], [6], [3]],
+    res_expr_dims = [[0], [1], [2], [3]],
+    conv_dims = ConvolutionDimensions(...)
+
+    """
+
     matmul_size: ContractionSizes
     lhs_type: ShapedType
     rhs_type: ShapedType
     res_type: ShapedType
     dispatch_kind: DispatchKind
     contraction_dims: ContractionDimensions
+    lhs_expr_dims: Optional[list[list[int]]] = None
+    rhs_expr_dims: Optional[list[list[int]]] = None
+    res_expr_dims: Optional[list[list[int]]] = None
+    conv_dims: Optional[ConvolutionDimensions] = None
 
     @property
     def MNK(self) -> tuple[list[int], list[int], list[int]]:
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
index 8f7ee08c4..965f8ef74 100644
--- a/tuner/tuner/dispatch_parser.py
+++ b/tuner/tuner/dispatch_parser.py
@@ -93,13 +93,16 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
             res_type=ShapedType(res_type.shape, res_type.element_type),
             dispatch_kind=DispatchKind.contraction,
             contraction_dims=contraction_dims,
+            lhs_expr_dims=[[d] for d in matcher.lhs_dims],
+            rhs_expr_dims=[[d] for d in matcher.rhs_dims],
+            res_expr_dims=[[d] for d in matcher.res_dims],
         )
 
 
 # TODO(Max191): Support more convolution types. Only NHWC convs are supported.
 class ConvolutionOpInterfaceParser(DispatchParser):
     def __init__(self):
-        self.supported_ops = ["linalg.conv_2d_nhwc_hwcf"]
+        self.supported_ops = ["conv"]
 
     def supports(self, op_name: str) -> bool:
         for supported_op_name in self.supported_ops:
@@ -111,13 +114,17 @@ def get_conv_operation(
         self,
         ir_module: ir.Module,
     ) -> Optional[ir.Operation]:
-        return match_root_op(ir_module, NamedOpMatcher(self.supported_ops))
+        return match_root_op(ir_module, ConvolutionOpInterfaceMatcher())
 
     # TODO(Max191): Pass the ir_module directly instead of the template str.
     def get_shapes(self, template: list[str]) -> ProblemSize:
         ir_module = ir.Module.parse("\n".join(template))
-        conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops))
+        matcher = ConvolutionOpInterfaceMatcher()
+        conv_op = match_root_op(ir_module, matcher)
+        assert conv_op is not None, "convolution op not found"
+        conv_dims = matcher.convolution_dimensions
+
         lhs_type = ir.RankedTensorType(conv_op.operands[0].type)
         rhs_type = ir.RankedTensorType(conv_op.operands[1].type)
         res_type = ir.RankedTensorType(conv_op.operands[2].type)
@@ -137,4 +144,8 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
             n=[3],
             k=[4, 5, 6],
         ),
+        lhs_expr_dims=matcher.lhs_expr_dims,
+        rhs_expr_dims=matcher.rhs_expr_dims,
+        res_expr_dims=matcher.res_expr_dims,
+        conv_dims=conv_dims,
     )
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index 204f84b28..2f3474696 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -32,6 +32,10 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
     yield ctx
 
 
+###############################################################################
+# Contraction Op Interface Parser Tests
+###############################################################################
+
 CONTRACTION_TEMPLATE = r"""
 builtin.module{{
     func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
@@ -59,9 +63,8 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
 }}"""
 
 
-def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
+def test_get_contraction_transb_operation(tuner_ctx: common.TunerContext) -> None:
     context = tuner_ctx.mlir_ctx
-
     with ir.Location.unknown():
         transpose_b_str = CONTRACTION_TEMPLATE.format(
             lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()),
@@ -89,6 +92,9 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
         assert shapes.res_type.shape == [16, 32]
         assert isinstance(shapes.res_type.element_type, ir.F32Type)
 
+
+def test_get_contraction_bmm_operation(tuner_ctx: common.TunerContext) -> None:
+    context = tuner_ctx.mlir_ctx
     with ir.Location.unknown():
         bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format(
             lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()),
@@ -100,8 +106,10 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
             iterator_types='["parallel", "parallel", "parallel", "reduction"]',
         )
         module = ir.Module.parse(bmm_transposed_inputs_str, context)
+        parser = dispatch_parser.ContractionOpInterfaceParser()
         mmt_op = parser.get_contraction_operation(module)
         shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines())
+        assert mmt_op is not None
         assert shapes.matmul_size.B == [5]
         assert shapes.matmul_size.M == [8]
         assert shapes.matmul_size.N == [40]
@@ -130,25 +138,129 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
         assert shapes.matmul_size.K == [15, 256]
 
 
-def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
+###############################################################################
+# Convolution Op Interface Parser Tests
+###############################################################################
+
+CONVOLUTION_NAMED_OP_TEMPLATE = r"""
+builtin.module{{
+    func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = tensor.empty() : {res_type}
+        %1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
+        %2 = {conv_op} {{root_op}}
+            ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
+            outs(%1 : {res_type}) -> {res_type}
+        return %2 : {res_type}
+    }}
+}}"""
+
+
+def test_get_conv_named_hwcf_operation(tuner_ctx: common.TunerContext) -> None:
     context = tuner_ctx.mlir_ctx
-    module_str = """
-        builtin.module{
-            func.func @test(%arg0: tensor<2x34x34x16xi8>, %arg1: tensor<3x3x16x16xi8>) -> tensor<2x32x32x16xi32> {
-                %cst = arith.constant 0 : i32
-                %0 = tensor.empty() : tensor<2x32x32x16xi32>
-                %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
-                %2 = linalg.conv_2d_nhwc_hwcf {root_op}
-                    ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>)
-                    outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
-                return %2 : tensor<2x32x32x16xi32>
-            }
-        }"""
-    module = ir.Module.parse(module_str, context)
+    with ir.Location.unknown():
+        # The fill constant in the template is f32, so the result type must be f32.
+        conv_nhwc_hwcf_input_str = CONVOLUTION_NAMED_OP_TEMPLATE.format(
+            lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
+            rhs_type=ir.RankedTensorType.get([3, 3, 16, 16], ir.F16Type.get()),
+            res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F32Type.get()),
+            conv_op="linalg.conv_2d_nhwc_hwcf",
+        )
+        module = ir.Module.parse(conv_nhwc_hwcf_input_str, context)
         parser = dispatch_parser.ConvolutionOpInterfaceParser()
         conv_op = parser.get_conv_operation(module)
+        shapes = parser.get_shapes(conv_nhwc_hwcf_input_str.splitlines())
         assert conv_op is not None
         assert isinstance(conv_op.opview, linalg.Conv2DNhwcHwcfOp)
+        assert shapes.matmul_size == common.ContractionSizes(
+            M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
+        )
+        assert shapes.contraction_dims == common.ContractionDimensions(
+            m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
+        )
+        assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
+        assert shapes.rhs_expr_dims == [[4], [5], [6], [3]]
+        assert shapes.res_expr_dims == [[0], [1], [2], [3]]
+        assert shapes.conv_dims == common.ConvolutionDimensions(
+            batch=[0],
+            outputImage=[1, 2],
+            outputChannel=[3],
+            filterLoop=[4, 5],
+            inputChannel=[6],
+            depth=[],
+            strides=[1, 1],
+            dilations=[1, 1],
+        )
+
+
+GENERIC_CONVOLUTION_TEMPLATE = r"""
+builtin.module{{
+    func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = tensor.empty() : {res_type}
+        %1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
+        %2 = linalg.generic {{
+            indexing_maps = [
+                {lhs_map},
+                {rhs_map},
+                {res_map}],
+            iterator_types = {iterator_types}}}
+            {{root_op}}
+            ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
+            outs(%1 : {res_type}) {{
+        ^bb0(%in: f16, %in_0: f16, %out: f32):
+            %3 = arith.extf %in : f16 to f32
+            %4 = arith.extf %in_0 : f16 to f32
+            %5 = arith.mulf %3, %4 : f32
+            %6 = arith.addf %out, %5 : f32
+            linalg.yield %6 : f32
+        }} -> {res_type}
+        return %2 : {res_type}
+    }}
+}}"""
+
+
+def test_get_conv_generic_hwcf_operation(tuner_ctx: common.TunerContext) -> None:
+    context = tuner_ctx.mlir_ctx
+    with ir.Location.unknown():
+        conv_hwcf_input_str = GENERIC_CONVOLUTION_TEMPLATE.format(
+            lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
+            rhs_type=ir.RankedTensorType.get([3, 3, 16, 16], ir.F16Type.get()),
+            res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F32Type.get()),
+            lhs_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1+d4, d2+d5, d6)>",
+            rhs_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>",
+            res_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>",
+            iterator_types='["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]',
+        )
+        module = ir.Module.parse(conv_hwcf_input_str, context)
+        parser = dispatch_parser.ConvolutionOpInterfaceParser()
+        conv_op = parser.get_conv_operation(module)
+        shapes = parser.get_shapes(conv_hwcf_input_str.splitlines())
+        assert conv_op is not None
+        assert isinstance(conv_op.opview, linalg.GenericOp)
+        assert shapes.matmul_size == common.ContractionSizes(
+            M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
+        )
+        assert shapes.contraction_dims == common.ContractionDimensions(
+            m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
+        )
+        assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
+        assert shapes.rhs_expr_dims == [[4], [5], [6], [3]]
+        assert shapes.res_expr_dims == [[0], [1], [2], [3]]
+        assert shapes.conv_dims == common.ConvolutionDimensions(
+            batch=[0],
+            outputImage=[1, 2],
+            outputChannel=[3],
+            filterLoop=[4, 5],
+            inputChannel=[6],
+            depth=[],
+            strides=[1, 1],
+            dilations=[1, 1],
+        )
+
+
+###############################################################################
+# Misc Dispatch Parser Tests
+###############################################################################
 
 
 def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py
index 09fcd17ea..51e70e91f 100644
--- a/tuner/tuner/op_matchers.py
+++ b/tuner/tuner/op_matchers.py
@@ -93,7 +93,10 @@ def match(self, op: ir.Operation) -> bool:
 
         maps_attr = None
         for attr in op.opview.attributes:
-            if attr.name == "indexing_maps" and isinstance(attr.attr, ir.ArrayAttr):
+            if attr.name in (
+                "indexing_maps",
+                "linalg.memoized_indexing_maps",
+            ) and isinstance(attr.attr, ir.ArrayAttr):
                 maps_attr = attr.attr
         if maps_attr is None:
             return False
@@ -124,6 +127,113 @@ def get_map_result_dim_positions(map: ir.AffineMap):
     return exprs
 
 
+def get_convolution_dims(
+    lhs_map: ir.AffineMap, rhs_map: ir.AffineMap, res_map: ir.AffineMap
+):
+    if not rhs_map.is_projected_permutation or not res_map.is_projected_permutation:
+        return None
+    lhs_dims = []
+    rhs_dims = get_map_result_dim_positions(rhs_map)
+    res_dims = get_map_result_dim_positions(res_map)
+    batch_dims = []
+    input_channel_dims = []
+    output_channel_dims = []
+    output_img_dims = []
+    filter_loop_dims = []
+    strides = []
+    dilations = []
+
+    assert rhs_dims is not None
+    assert res_dims is not None
+
+    def get_dim_pos(dim):
+        # Parse a dimension token like "d3" into its position (3).
+        if len(dim) < 1 or dim[0] != "d" or not dim[1:].isdigit():
+            return None
+        return int(dim[1:])
+
+    def verify_and_add_img_and_filter_dims(img_dim, filter_dim) -> bool:
+        # An output image dim must appear in the result map but not the filter map.
+        if img_dim in rhs_dims or img_dim not in res_dims:
+            return False
+        output_img_dims.append(img_dim)
+        # A filter loop dim must appear in the filter map but not the result map.
+        if filter_dim not in rhs_dims or filter_dim in res_dims:
+            return False
+        filter_loop_dims.append(filter_dim)
+        return True
+
+    for expr in lhs_map.results:
+        dim_strs = str(expr).split(" ")
+        if len(dim_strs) == 1:
+            dim = get_dim_pos(dim_strs[0])
+            if dim is None:
+                return None
+            if dim not in rhs_dims and dim in res_dims:
+                batch_dims.append(dim)
+            elif dim in rhs_dims and dim not in res_dims:
+                input_channel_dims.append(dim)
+            else:
+                return None
+            lhs_dims.append([dim])
+            continue
+        # Convolved dim with no stride, e.g. `d1 + d4`.
+        if len(dim_strs) == 3:
+            if dim_strs[1] != "+":
+                return None
+            img_dim = get_dim_pos(dim_strs[0])
+            filter_dim = get_dim_pos(dim_strs[2])
+            if img_dim is None or filter_dim is None:
+                return None
+            if not verify_and_add_img_and_filter_dims(img_dim, filter_dim):
+                return None
+            strides.append(1)
+            dilations.append(1)
+            lhs_dims.append([img_dim, filter_dim])
+            continue
+        # Convolved dim with a stride, e.g. `d1 * 2 + d4`.
+        if len(dim_strs) == 5:
+            if dim_strs[1] != "*" or dim_strs[3] != "+":
+                return None
+            img_dim = get_dim_pos(dim_strs[0])
+            filter_dim = get_dim_pos(dim_strs[4])
+            if img_dim is None or filter_dim is None:
+                return None
+            if not verify_and_add_img_and_filter_dims(img_dim, filter_dim):
+                return None
+            if not dim_strs[2].isdigit():
+                return None
+            strides.append(int(dim_strs[2]))
+            dilations.append(1)
+            lhs_dims.append([img_dim, filter_dim])
+            continue
+        # Any other expression form (e.g. a dilated convolution) is unsupported.
+        return None
+
+    # lhs_dims is a 2D array; gather the set of dims it uses.
+
+    lhs_dim_set = set()
+    for dims in lhs_dims:
+        for dim in dims:
+            lhs_dim_set.add(dim)
+    # Collect output channel dims: dims used by the filter and result, but not the input.
+    for d in range(rhs_map.n_dims):
+        if d not in lhs_dim_set and d in rhs_dims and d in res_dims:
+            output_channel_dims.append(d)
+
+    # Verify that there are convolved dims.
+    if len(filter_loop_dims) == 0:
+        return None
+
+    return (
+        ConvolutionDimensions(
+            batch=batch_dims,
+            outputImage=output_img_dims,
+            outputChannel=output_channel_dims,
+            filterLoop=filter_loop_dims,
+            inputChannel=input_channel_dims,
+            strides=strides,
+            dilations=dilations,
+        ),
+        lhs_dims,
+        [[d] for d in rhs_dims],
+        [[d] for d in res_dims],
+    )
+
+
 class ContractionOpInterfaceMatcher(GenericOpMatcher):
     def __init__(self) -> None:
         super().__init__()
@@ -179,3 +289,41 @@ def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
         self.rhs_dims = rhs_dims
         self.res_dims = res_dims
         return True
+
+
+class ConvolutionOpInterfaceMatcher(GenericOpMatcher):
+    def __init__(self) -> None:
+        super().__init__()
+        self.supported_named_ops = [
+            "linalg.conv_2d_nhwc_hwcf",
+            "linalg.conv_2d_nhwc_fhwc",
+        ]
+        self.op_names.extend(self.supported_named_ops)
+        self.convolution_dimensions: Optional[ConvolutionDimensions] = None
+        self.lhs_expr_dims: Optional[list[list[int]]] = None
+        self.rhs_expr_dims: Optional[list[list[int]]] = None
+        self.res_expr_dims: Optional[list[list[int]]] = None
+
+    def match_operands(self, operands: ir.OpOperandList) -> bool:
+        if len(operands) != 3:
+            return False
+        for operand in operands:
+            if not isinstance(operand.type, ir.ShapedType):
+                return False
+        return True
+
+    def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
+        if len(maps) != 3:
+            return False
+
+        conv_dim_info = get_convolution_dims(maps[0], maps[1], maps[2])
+        if conv_dim_info is None:
+            return False
+
+        conv_dims, lhs_dims, rhs_dims, res_dims = conv_dim_info
+        self.convolution_dimensions = conv_dims
+        self.lhs_expr_dims = lhs_dims
+        self.rhs_expr_dims = rhs_dims
+        self.res_expr_dims = res_dims
+        return True

From 10ec7fece26fca05cec11c9da7f1349a7ef68bb2 Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Tue, 11 Feb 2025 15:01:16 +0000
Subject: [PATCH 2/2] Addressing review feedback

Signed-off-by: jerryyin
---
 tuner/tuner/op_matchers.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py
index 51e70e91f..f75e627a7 100644
--- a/tuner/tuner/op_matchers.py
+++ b/tuner/tuner/op_matchers.py
@@ -69,10 +69,9 @@ def match(self, op: ir.Operation) -> bool:
         return op.name in self.op_names
 
 
-# TODO(Max191): Add logic to match the body of the generic op.
-class GenericOpMatcher(NamedOpMatcher):
-    def __init__(self):
-        super().__init__(["linalg.generic"])
+class LinalgOpMatcher(NamedOpMatcher):
+    def __init__(self, op_names: list[str]) -> None:
+        super().__init__(op_names)
 
     @abstractmethod
     def match_operands(self, operands: ir.OpOperandList) -> bool:
@@ -110,6 +109,12 @@ def match(self, op: ir.Operation) -> bool:
         return True
 
 
+# TODO(Max191): Add logic to match the body of the generic op.
+class GenericOpMatcher(LinalgOpMatcher):
+    def __init__(self):
+        super().__init__(["linalg.generic"])
+
+
 def get_map_result_dim_positions(map: ir.AffineMap):
     exprs = []
     if not map.is_projected_permutation:
@@ -127,6 +132,8 @@ def get_map_result_dim_positions(map: ir.AffineMap):
     return exprs
 
 
+# TODO: Replace this function with a call to the MLIR python binding for
+# linalg::inferConvolutionDims once that binding is exposed.
 def get_convolution_dims(
     lhs_map: ir.AffineMap, rhs_map: ir.AffineMap, res_map: ir.AffineMap
 ):
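
Note for reviewers (not part of either patch): the contraction_dims bookkeeping
documented in common.py can be sanity-checked without an MLIR build. The sketch
below is a minimal, standalone recomputation of the gemm M/N/K sizes for the
nhwc_hwcf example in the ProblemSize docstring; the loop_bounds values are taken
from that example, and all other names here are illustrative only.

    # Loop order is (b, oh, ow, oc, fh, fw, ic), matching the docstring example.
    loop_bounds = [2, 60, 64, 640, 3, 3, 960]

    # contraction_dims for the igemm view of an nhwc_hwcf conv.
    contraction_dims = {"m": [0, 1, 2], "n": [3], "k": [4, 5, 6], "batch": []}

    def sizes_for(dims: list[int]) -> list[int]:
        # A contraction dimension's size is the bound of the loop it maps to.
        return [loop_bounds[d] for d in dims]

    M = sizes_for(contraction_dims["m"])  # [2, 60, 64]
    N = sizes_for(contraction_dims["n"])  # [640]
    K = sizes_for(contraction_dims["k"])  # [3, 3, 960]
    assert (M, N, K) == ([2, 60, 64], [640], [3, 3, 960])

Keeping per-loop lists here, rather than flattened products, is what lets the
tuner later assign tile sizes to individual convolution loops.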