diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 33be28ff1..00a617d8e 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -652,7 +652,7 @@ def generate_configs_and_td_specs( dispatch_tuner = walk_result.dispatch_tuner assert dispatch_tuner, "No suitable dispatch tuner found" - problem_size: ProblemSize = dispatch_tuner.get_shapes(str(input_module)) + problem_size: ProblemSize = dispatch_tuner.get_shapes(str(input_module).splitlines()) tune_logger.debug(str(problem_size)) # Index 0 is reserved for default config, so it gets no td spec. diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index d2c8430ca..b749b7e26 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -129,6 +129,7 @@ def get_shapes(self, template: list[str]) -> ProblemSize: if contraction_op is None: assert False, f"contraction op not found" cdims = matcher.contraction_dimensions + assert cdims, "no contraction dimensions" assert len(cdims.m) == 1, f"must have a single m dimension" assert len(cdims.n) == 1, f"must have a single n dimension" assert len(cdims.k) == 1, f"must have a single k dimension" diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 4349d0da0..a8c98a83e 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -39,7 +39,7 @@ import json from abc import ABC, abstractmethod import iree.runtime as ireert # type: ignore -from iree.compiler import ir +from iree.compiler import ir # type: ignore from . import candidate_gen from . import dispatch_parser from .common import * diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py index 9409f29af..8108019b1 100644 --- a/tuner/tuner/op_matchers.py +++ b/tuner/tuner/op_matchers.py @@ -10,7 +10,6 @@ from .common import * from iree.compiler import ir # type: ignore -from iree.compiler.dialects import linalg def walk_collect_ops( @@ -67,7 +66,6 @@ class NamedOpMatcher: def __init__(self, op_names: list[str]): self.op_names = op_names - @abstractmethod def match(self, op: ir.Operation) -> bool: return op.name in self.op_names diff --git a/tuner/tuner/spec_builder.py b/tuner/tuner/spec_builder.py index cad576439..bc59b33c5 100644 --- a/tuner/tuner/spec_builder.py +++ b/tuner/tuner/spec_builder.py @@ -8,8 +8,7 @@ # in the code and runs it. from iree.compiler import ir # type: ignore -from iree.compiler.dialects import iree_codegen -from iree.compiler.dialects import iree_gpu +from iree.compiler.dialects import iree_codegen # type: ignore from .common import * from .dispatch_constraints import *