Adding parser tests
jerryyin committed Feb 5, 2025
1 parent 9242816 commit 42ea5f5
Showing 4 changed files with 169 additions and 21 deletions.
6 changes: 3 additions & 3 deletions tuner/tuner/candidate_gen.py
@@ -97,9 +97,9 @@ def get_td_spec(
rhs_type = ir.ShapedType(conv_op.operands[1].type)
acc_type = ir.ShapedType(conv_op.operands[2].type)

M_str = "x".join([str(mDim) for mDim in problem_size.matmul_size.M])
N_str = "x".join([str(nDim) for nDim in problem_size.matmul_size.N])
K_str = "x".join([str(kDim) for kDim in problem_size.matmul_size.K])
M_str = "x".join(map(str, problem_size.matmul_size.M))
N_str = "x".join(map(str, problem_size.matmul_size.N))
K_str = "x".join(map(str, problem_size.matmul_size.K))

conv_type = conv_op.name.split(".")[-1]
# TODO(Max191): Get the function name from the func.func in the input module.
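
Both spellings build the same "x"-separated shape string; a quick sanity sketch with hypothetical dimension values:

M = [2, 32, 32]  # hypothetical values, for illustration only
assert "x".join(map(str, M)) == "x".join([str(mDim) for mDim in M]) == "2x32x32"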
2 changes: 1 addition & 1 deletion tuner/tuner/common.py
@@ -131,7 +131,7 @@ class ConvolutionDimensions:
Stores which dimensions of the iteration space belong to the convolution.
For example, the following is a simple nhwc_fhwc conv:
linalg.generic ... indexing_maps = [
-        affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, ih, iw, ic)>,
+        affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, oh + fh, ow + fw, ic)>,
affine_map<(b, oh, ow, oc, fh, fw, ic) -> (oc, fh, fw, ic)>,
affine_map<(b, oh, ow, oc, fh, fw, ic) -> (b, oh, ow, oc)>,
]
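
For the nhwc_fhwc example above, the fields would be populated like this (a sketch mirroring the fhwc parser test added later in this commit, with iteration-space order (b, oh, ow, oc, fh, fw, ic)):

ConvolutionDimensions(
    batch=[0],           # b
    outputImage=[1, 2],  # oh, ow
    outputChannel=[3],   # oc
    filterLoop=[4, 5],   # fh, fw
    inputChannel=[6],    # ic
    depth=[],
    strides=[1, 1],
    dilations=[1, 1],
)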
180 changes: 164 additions & 16 deletions tuner/tuner/dispatch_parser_test.py
@@ -32,6 +32,10 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
yield ctx


###############################################################################
# Contraction Op Interface Parser Tests
###############################################################################

CONTRACTION_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
@@ -59,9 +63,8 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]:
}}"""


-def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
+def test_get_contraction_transb_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx

with ir.Location.unknown():
transpose_b_str = CONTRACTION_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()),
@@ -89,6 +92,9 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
assert shapes.res_type.shape == [16, 32]
assert isinstance(shapes.res_type.element_type, ir.F32Type)


def test_get_contraction_bmm_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
with ir.Location.unknown():
bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()),
@@ -100,8 +106,10 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
iterator_types='["parallel", "parallel", "parallel", "reduction"]',
)
module = ir.Module.parse(bmm_transposed_inputs_str, context)
parser = dispatch_parser.ContractionOpInterfaceParser()
mmt_op = parser.get_contraction_operation(module)
shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines())
assert mmt_op is not None
assert shapes.matmul_size.B == [5]
assert shapes.matmul_size.M == [8]
assert shapes.matmul_size.N == [40]
@@ -130,25 +138,165 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
assert shapes.matmul_size.K == [15, 256]


-def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
###############################################################################
# Convolution Op Interface Parser Tests
###############################################################################

CONVOLUTION_NAMED_OP_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : {res_type}
%1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
%2 = {conv_op} {{root_op}}
ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
outs(%1 : {res_type}) -> {res_type}
return %2 : {res_type}
}}
}}"""


def test_get_conv_named_hwcf_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
module_str = """
builtin.module{
func.func @test(%arg0: tensor<2x34x34x16xi8>, %arg1: tensor<3x3x16x16xi8>) -> tensor<2x32x32x16xi32> {
%cst = arith.constant 0 : i32
%0 = tensor.empty() : tensor<2x32x32x16xi32>
%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
%2 = linalg.conv_2d_nhwc_hwcf {root_op}
ins(%arg0, %arg1 : tensor<2x34x34x16xi8>, tensor<3x3x16x16xi8>)
outs(%1 : tensor<2x32x32x16xi32>) -> tensor<2x32x32x16xi32>
return %2 : tensor<2x32x32x16xi32>
}
}"""
module = ir.Module.parse(module_str, context)
with ir.Location.unknown():
conv_nhwc_hwcf_input_str = CONVOLUTION_NAMED_OP_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
rhs_type=ir.RankedTensorType.get([3, 3, 16, 16], ir.F16Type.get()),
res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F16Type.get()),
conv_op="linalg.conv_2d_nhwc_hwcf",
)
module = ir.Module.parse(conv_nhwc_hwcf_input_str, context)
parser = dispatch_parser.ConvolutionOpInterfaceParser()
conv_op = parser.get_conv_operation(module)
shapes = parser.get_shapes(conv_nhwc_hwcf_input_str.splitlines())
assert conv_op is not None
assert isinstance(conv_op.opview, linalg.Conv2DNhwcHwcfOp)
assert shapes.matmul_size == common.ContractionSizes(
M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
)
assert shapes.contraction_dims == common.ContractionDimensions(
m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
)
assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
assert shapes.rhs_expr_dims == [[4], [5], [6], [3]]
assert shapes.res_expr_dims == [[0], [1], [2], [3]]
assert shapes.conv_dims == common.ConvolutionDimensions(
batch=[0],
outputImage=[1, 2],
outputChannel=[3],
filterLoop=[4, 5],
inputChannel=[6],
depth=[],
strides=[1, 1],
dilations=[1, 1],
)
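
The expected expr dims read straight off the op's indexing maps; as a sketch of the correspondence over (d0..d6) = (n, oh, ow, oc, fh, fw, ic):

# nhwc_hwcf indexing maps and their expr dims:
#   input  (d0, d1 + d4, d2 + d5, d6)  ->  lhs_expr_dims [[0], [1, 4], [2, 5], [6]]
#   filter (d4, d5, d6, d3)            ->  rhs_expr_dims [[4], [5], [6], [3]]
#   output (d0, d1, d2, d3)            ->  res_expr_dims [[0], [1], [2], [3]]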


def test_get_conv_named_fhwc_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
with ir.Location.unknown():
conv_nhwc_fhwc_input_str = CONVOLUTION_NAMED_OP_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
rhs_type=ir.RankedTensorType.get([16, 3, 3, 16], ir.F16Type.get()),
res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F16Type.get()),
conv_op="linalg.conv_2d_nhwc_fhwc",
)
module = ir.Module.parse(conv_nhwc_fhwc_input_str, context)
parser = dispatch_parser.ConvolutionOpInterfaceParser()
conv_op = parser.get_conv_operation(module)
shapes = parser.get_shapes(conv_nhwc_fhwc_input_str.splitlines())
assert conv_op is not None
assert isinstance(conv_op.opview, linalg.Conv2DNhwcFhwcOp)
assert shapes.matmul_size == common.ContractionSizes(
M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
)
assert shapes.contraction_dims == common.ContractionDimensions(
m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
)
assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
assert shapes.rhs_expr_dims == [[3], [4], [5], [6]]
assert shapes.res_expr_dims == [[0], [1], [2], [3]]
assert shapes.conv_dims == common.ConvolutionDimensions(
batch=[0],
outputImage=[1, 2],
outputChannel=[3],
filterLoop=[4, 5],
inputChannel=[6],
depth=[],
strides=[1, 1],
dilations=[1, 1],
)


GENERICS_CONVOLUTION_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : {res_type}
%1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
%2 = linalg.generic {{
indexing_maps = [
{lhs_map},
{rhs_map},
{res_map}],
iterator_types = {iterator_types}}}
{{root_op}}
ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
outs(%1 : {res_type}) {{
^bb0(%in: f16, %in_0: f16, %out: f32):
%3 = arith.extf %in : f16 to f32
%4 = arith.extf %in_0 : f16 to f32
%5 = arith.mulf %3, %4 : f32
%6 = arith.addf %out, %5 : f32
linalg.yield %6 : f32
}} -> {res_type}
return %2 : {res_type}
}}
}}"""


def test_get_conv_generic_hwfc_operation(tuner_ctx: common.TunerContext) -> None:
context = tuner_ctx.mlir_ctx
with ir.Location.unknown():
conv_hwfc_input_str = GENERICS_CONVOLUTION_TEMPLATE.format(
lhs_type=ir.RankedTensorType.get([2, 34, 34, 16], ir.F16Type.get()),
rhs_type=ir.RankedTensorType.get([3, 3, 16, 16], ir.F16Type.get()),
res_type=ir.RankedTensorType.get([2, 32, 32, 16], ir.F32Type.get()),
lhs_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1+d4, d2+d5, d6)>",
rhs_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d3, d6)>",
res_map="affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>",
iterator_types='["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]',
)
module = ir.Module.parse(conv_hwfc_input_str, context)
parser = dispatch_parser.ConvolutionOpInterfaceParser()
conv_op = parser.get_conv_operation(module)
shapes = parser.get_shapes(conv_hwfc_input_str.splitlines())
assert conv_op is not None
assert isinstance(conv_op.opview, linalg.GenericOp)
assert shapes.matmul_size == common.ContractionSizes(
M=[2, 32, 32], N=[16], K=[3, 3, 16], B=[]
)
assert shapes.contraction_dims == common.ContractionDimensions(
m=[0, 1, 2], n=[3], k=[4, 5, 6], batch=[]
)
assert shapes.lhs_expr_dims == [[0], [1, 4], [2, 5], [6]]
assert shapes.rhs_expr_dims == [[4], [5], [3], [6]]
assert shapes.res_expr_dims == [[0], [1], [2], [3]]
assert shapes.conv_dims == common.ConvolutionDimensions(
batch=[0],
outputImage=[1, 2],
outputChannel=[3],
filterLoop=[4, 5],
inputChannel=[6],
depth=[],
strides=[1, 1],
dilations=[1, 1],
)
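
The hwfc filter layout (fh, fw, oc, ic) is exercised through linalg.generic rather than a named op; to my knowledge upstream linalg provides no conv_2d_nhwc_hwfc named form. Relative to the hwcf expectations above, only the filter expr changes:

# hwfc filter (d4, d5, d3, d6)  ->  rhs_expr_dims [[4], [5], [3], [6]]
# (oc, i.e. d3, sits third here instead of last as in hwcf's (d4, d5, d6, d3))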


###############################################################################
# Misc Dispatch Parser Tests
###############################################################################


def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
2 changes: 1 addition & 1 deletion tuner/tuner/op_matchers.py
@@ -293,7 +293,7 @@ def __init__(self) -> None:
super().__init__()
self.supported_named_ops = [
"linalg.conv_2d_nhwc_hwcf",
"linalg.conv_2d_nhwc_hwfc",
"linalg.conv_2d_nhwc_fhwc",
]
self.op_names.extend(self.supported_named_ops)
self.convolution_dimensions: Optional[ConvolutionDimensions] = None
