Adding conv igemm layout support - step 1: conv op matcher #943

Open · wants to merge 2 commits into base: main
69 changes: 69 additions & 0 deletions tuner/tuner/common.py
@@ -125,14 +125,83 @@ class ContractionDimensions:
batch: list[int] = field(default_factory=list)


@dataclass
class ConvolutionDimensions:
"""
Stores which dimensions of the iteration space correspond to each convolution dimension.
For example, the following is a simple nhwc_fhwc conv:
linalg.generic ... indexing_maps = [
affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, oh + fh, ow + fw, ic)>,
affine_map<(b, oh, ow, oc, fh, fw, ic) -> (oc, fh, fw, ic)>,
affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, oh, ow, oc)>,
]
The ConvolutionDimensions would be:
batch = [0]
outputImage = [1, 2]
outputChannel = [3]
filterLoop = [4, 5]
inputChannel = [6]
depth = []
strides = [1, 1]
dilations = [1, 1]
"""

batch: list[int] = field(default_factory=list)
outputImage: list[int] = field(default_factory=list)
outputChannel: list[int] = field(default_factory=list)
filterLoop: list[int] = field(default_factory=list)
inputChannel: list[int] = field(default_factory=list)
depth: list[int] = field(default_factory=list)
strides: list[int] = field(default_factory=list)
dilations: list[int] = field(default_factory=list)


@dataclass
class ProblemSize:
"""
Represents the problem size of a contraction or convolution operation. When the operation
is a convolution, all fields must be set, including contraction_dims, lhs_expr_dims,
rhs_expr_dims, res_expr_dims, and conv_dims.

For example, the following is a simple convolution:
%conv = linalg.conv_2d_nhwc_hwcf
{dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
ins(%6, %7 : tensor<2x62x66x960xf16>, tensor<3x3x960x640xf16>)
outs(%13 : tensor<2x60x64x640xf32>) -> tensor<2x60x64x640xf32>

The ProblemSize would be:
matmul_size = ContractionSizes(M=[2, 60, 64], N=[640], K=[3, 3, 960], B=[]),
lhs_type = ShapedType(shape=[2, 62, 66, 960], element_type=F16Type(f16)),
rhs_type = ShapedType(shape=[3, 3, 960, 640], element_type=F16Type(f16)),
res_type = ShapedType(shape=[2, 60, 64, 640], element_type=F32Type(f32)),
dispatch_kind = DispatchKind.conv

# The contraction_dims field is intuitive for gemms. For convolutions, it records the mapping
# from convolution dimensions to contraction dimensions. In the igemm setting:
# - [d0, d1, d2] map to [b, oh, ow], i.e. the m dimension of the gemm
# - [d3] maps to [oc], i.e. the n dimension of the gemm
# - [d4, d5, d6] map to [fh, fw, ic], i.e. the k dimension of the gemm
contraction_dims = ContractionDimensions(m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]),

# The *_expr_dims fields record, for each dimension of the corresponding operand tensor, the
# iteration-space dimensions that appear in that dimension's indexing-map result expression.
lhs_expr_dims = [[0], [1, 4], [2, 5], [6]],
rhs_expr_dims = [[4], [5], [6], [3]],
res_expr_dims = [[0], [1], [2], [3]],
conv_dims = ConvolutionDimensions(...)

"""

matmul_size: ContractionSizes
lhs_type: ShapedType
rhs_type: ShapedType
res_type: ShapedType
dispatch_kind: DispatchKind
contraction_dims: ContractionDimensions
lhs_expr_dims: Optional[list[list[int]]] = None
rhs_expr_dims: Optional[list[list[int]]] = None
res_expr_dims: Optional[list[list[int]]] = None
conv_dims: Optional[ConvolutionDimensions] = None

@property
def MNK(self) -> tuple[list[int], list[int], list[int]]:
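As a usage sketch of the dataclasses above (the tuner.common import path is an assumption; the sizes come from the conv_2d_nhwc_hwcf example in the ProblemSize docstring), the convolution-to-igemm mapping can be spelled out like this:

import math
from tuner import common  # assumed import path

# Iteration-space dims of the example: (b, oh, ow, oc, fh, fw, ic) = (d0, ..., d6).
conv_dims = common.ConvolutionDimensions(
    batch=[0], outputImage=[1, 2], outputChannel=[3],
    filterLoop=[4, 5], inputChannel=[6], depth=[],
    strides=[1, 1], dilations=[1, 1],
)
# igemm view: [b, oh, ow] -> m, [oc] -> n, [fh, fw, ic] -> k.
contraction_dims = common.ContractionDimensions(m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[])
matmul_size = common.ContractionSizes(M=[2, 60, 64], N=[640], K=[3, 3, 960], B=[])
# Flattened GEMM problem implied by the igemm formulation: 7680 x 640 x 8640.
flat_m, flat_n, flat_k = (math.prod(d) for d in (matmul_size.M, matmul_size.N, matmul_size.K))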
17 changes: 14 additions & 3 deletions tuner/tuner/dispatch_parser.py
@@ -93,13 +93,16 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
res_type=ShapedType(res_type.shape, res_type.element_type),
dispatch_kind=DispatchKind.contraction,
contraction_dims=contraction_dims,
lhs_expr_dims=[[d] for d in matcher.lhs_dims],
rhs_expr_dims=[[d] for d in matcher.rhs_dims],
res_expr_dims=[[d] for d in matcher.res_dims],
)


# TODO(Max191): Support more convolution types. Only NHWC convs are supported.
class ConvolutionOpInterfaceParser(DispatchParser):
def __init__(self):
self.supported_ops = ["linalg.conv_2d_nhwc_hwcf"]
self.supported_ops = ["conv"]

def supports(self, op_name: str) -> bool:
for supported_op_name in self.supported_ops:
@@ -111,13 +114,17 @@ def get_conv_operation(
self,
ir_module: ir.Module,
) -> Optional[ir.Operation]:
return match_root_op(ir_module, NamedOpMatcher(self.supported_ops))
return match_root_op(ir_module, ConvolutionOpInterfaceMatcher())

# TODO(Max191): Pass the ir_module directly instead of the template str.
def get_shapes(self, template: list[str]) -> ProblemSize:
ir_module = ir.Module.parse("\n".join(template))
conv_op = match_root_op(ir_module, NamedOpMatcher(self.supported_ops))
matcher = ConvolutionOpInterfaceMatcher()
conv_op = match_root_op(ir_module, matcher)
conv_dims = matcher.convolution_dimensions

assert conv_op is not None, f"convolution op not found"

lhs_type = ir.RankedTensorType(conv_op.operands[0].type)
rhs_type = ir.RankedTensorType(conv_op.operands[1].type)
res_type = ir.RankedTensorType(conv_op.operands[2].type)
@@ -137,4 +144,8 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
n=[3],
k=[4, 5, 6],
),
lhs_expr_dims=matcher.lhs_expr_dims,
rhs_expr_dims=matcher.rhs_expr_dims,
res_expr_dims=matcher.res_expr_dims,
conv_dims=conv_dims,
)
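For intuition about what the new *_expr_dims fields hold, the sketch below derives them from an operand's indexing map by collecting the iteration dims used in each result expression. This is only an illustration, not the ConvolutionOpInterfaceMatcher referenced above; the iree.compiler binding import and the string-based dim parsing are assumptions.

import re
from iree.compiler import ir  # assumed binding import

def expr_dims_of_map(indexing_map: ir.AffineMap) -> list[list[int]]:
    # For the NHWC input map (d0, d1 + d4, d2 + d5, d6) this returns
    # [[0], [1, 4], [2, 5], [6]], matching ProblemSize.lhs_expr_dims.
    dims_per_result = []
    for result_expr in indexing_map.results:
        # Each iteration dim prints as "dN" in the expression, e.g. "d1 + d4".
        positions = sorted({int(d) for d in re.findall(r"d(\d+)", str(result_expr))})
        dims_per_result.append(positions)
    return dims_per_result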
144 changes: 128 additions & 16 deletions tuner/tuner/dispatch_parser_test.py
@@ -32,6 +32,10 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
yield ctx


###############################################################################
# Contraction Op Interface Parser Tests
###############################################################################

CONTRACTION_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
@@ -59,9 +63,8 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
}}"""


def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
def test_get_contraction_transb_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx

with ir.Location.unknown():
transpose_b_str = CONTRACTION_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()),
@@ -89,6 +92,9 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
assert shapes.res_type.shape == [16, 32]
assert isinstance(shapes.res_type.element_type, ir.F32Type)


def test_get_contraction_bmm_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
with ir.Location.unknown():
bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()),
@@ -100,8 +106,10 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
iterator_types='["parallel", "parallel", "parallel", "reduction"]',
)
module = ir.Module.parse(bmm_transposed_inputs_str, context)
parser = dispatch_parser.ContractionOpInterfaceParser()
mmt_op = parser.get_contraction_operation(module)
shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines())
assert mmt_op is not None
assert shapes.matmul_size.B == [5]
assert shapes.matmul_size.M == [8]
assert shapes.matmul_size.N == [40]
@@ -130,25 +138,129 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
assert shapes.matmul_size.K == [15, 256]


def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
###############################################################################
# Convolution Op Interface Parser Tests
###############################################################################

CONVOLUTION_NAMED_OP_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : {res_type}
%1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
%2 = {conv_op} {{root_op}}
ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
outs(%1 : {res_type}) -> {res_type}
return %2 : {res_type}
}}
}}"""


def test_get_conv_named_hwcf_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
module_str = """
builtin.module{
func.func @test(%arg0: tensor<2x34x34x16xi8>, %arg1: tensor<3x3x16x16xi8>) -> tensor<2x32x32x16xi32> {
%cst = arith.constant 0 : i32
%0 = tensor.empty() : tensor<2x32x32x16xi32>
%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
%2 = linalg.conv_2d_nhwc_hwcf {root_op}
ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>)
outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
return %2 : tensor<2x32x32x16xi32>
}
}"""
module = ir.Module.parse(module_str, context)
with ir.Location.unknown():
conv_nhwc_hwcf_input_str = CONVOLUTION_NAMED_OP_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
rhs_type=ir.RankedTensorType.get([3, 3, 16, 16], ir.F16Type.get()),
res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F16Type.get()),
conv_op="linalg.conv_2d_nhwc_hwcf",
)
module = ir.Module.parse(conv_nhwc_hwcf_input_str, context)
parser = dispatch_parser.ConvolutionOpInterfaceParser()
conv_op = parser.get_conv_operation(module)
shapes = parser.get_shapes(conv_nhwc_hwcf_input_str.splitlines())
assert conv_op is not None
assert isinstance(conv_op.opview, linalg.Conv2DNhwcHwcfOp)
assert shapes.matmul_size == common.ContractionSizes(
M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
)
assert shapes.contraction_dims == common.ContractionDimensions(
m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
)
assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
assert shapes.rhs_expr_dims == [[4], [5], [6], [3]]
assert shapes.res_expr_dims == [[0], [1], [2], [3]]
assert shapes.conv_dims == common.ConvolutionDimensions(
batch=[0],
outputImage=[1, 2],
outputChannel=[3],
filterLoop=[4, 5],
inputChannel=[6],
depth=[],
strides=[1, 1],
dilations=[1, 1],
)


GENERICS_CONVOLUTION_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : {res_type}
%1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
%2 = linalg.generic {{
indexing_maps = [
{lhs_map},
{rhs_map},
{res_map}],
iterator_types = {iterator_types}}}
{{root_op}}
ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
outs(%1 : {res_type}) {{
^bb0(%in: f16, %in_0: f16, %out: f32):
%3 = arith.extf %in : f16 to f32
%4 = arith.extf %in_0 : f16 to f32
%5 = arith.mulf %3, %4 : f32
%6 = arith.addf %out, %5 : f32
linalg.yield %6 : f32
}} -> {res_type}
return %2 : {res_type}
}}
}}"""


def test_get_conv_generic_hwcf_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
with ir.Location.unknown():
conv_hwcf_input_str = GENERICS_CONVOLUTION_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
rhs_type=ir.RankedTensorType.get([3, 3, 16, 16], ir.F16Type.get()),
res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F32Type.get()),
lhs_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1+d4, d2+d5, d6)>",
rhs_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>",
res_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>",
iterator_types='["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]',
)
module = ir.Module.parse(conv_hwcf_input_str, context)
parser = dispatch_parser.ConvolutionOpInterfaceParser()
conv_op = parser.get_conv_operation(module)
shapes = parser.get_shapes(conv_hwcf_input_str.splitlines())
assert conv_op is not None
assert isinstance(conv_op.opview, linalg.GenericOp)
assert shapes.matmul_size == common.ContractionSizes(
M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
)
assert shapes.contraction_dims == common.ContractionDimensions(
m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
)
assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
assert shapes.rhs_expr_dims == [[4], [5], [6], [3]]
assert shapes.res_expr_dims == [[0], [1], [2], [3]]
assert shapes.conv_dims == common.ConvolutionDimensions(
batch=[0],
outputImage=[1, 2],
outputChannel=[3],
filterLoop=[4, 5],
inputChannel=[6],
depth=[],
strides=[1, 1],
dilations=[1, 1],
)
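The strides and dilations asserted above follow from the image operand's access expressions: d1 + d4 and d2 + d5 give unit stride and dilation, while an expression like d1 * 2 + d4 * 3 would correspond to stride 2 and dilation 3. A rough sketch of that relationship is below; the string parsing and the ordering of the output-image term before the filter term are assumptions, and this is not the matcher's actual logic.

import re

def stride_and_dilation(image_access_expr: str) -> tuple[int, int]:
    # Parses "d1 * 2 + d4 * 3" -> (2, 3); a bare dim like "d1" contributes a factor of 1.
    factors = []
    for term in image_access_expr.split("+"):
        match = re.fullmatch(r"\s*d\d+\s*(?:\*\s*(\d+))?\s*", term)
        factors.append(int(match.group(1)) if match and match.group(1) else 1)
    return factors[0], factors[1]

assert stride_and_dilation("d1 + d4") == (1, 1)          # matches the unit-stride test above
assert stride_and_dilation("d1 * 2 + d4 * 3") == (2, 3)  # hypothetical strided/dilated case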


###############################################################################
# Misc Dispatch Parser Tests
###############################################################################


def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: