Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 committed Dec 12, 2024
1 parent fb7846f commit d40a223
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def generate_configs_and_td_specs(

dispatch_tuner = walk_result.dispatch_tuner
assert dispatch_tuner, "No suitable dispatch tuner found"
problem_size: ProblemSize = dispatch_tuner.get_shapes(str(input_module))
problem_size: ProblemSize = dispatch_tuner.get_shapes(str(input_module).splitlines())
tune_logger.debug(str(problem_size))

# Index 0 is reserved for default config, so it gets no td spec.
Expand Down
1 change: 1 addition & 0 deletions tuner/tuner/dispatch_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
if contraction_op is None:
assert False, f"contraction op not found"
cdims = matcher.contraction_dimensions
assert cdims, "no contraction dimensions"
assert len(cdims.m) == 1, f"must have a single m dimension"
assert len(cdims.n) == 1, f"must have a single n dimension"
assert len(cdims.k) == 1, f"must have a single k dimension"
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import json
from abc import ABC, abstractmethod
import iree.runtime as ireert # type: ignore
from iree.compiler import ir
from iree.compiler import ir # type: ignore
from . import candidate_gen
from . import dispatch_parser
from .common import *
Expand Down
2 changes: 0 additions & 2 deletions tuner/tuner/op_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from .common import *
from iree.compiler import ir # type: ignore
from iree.compiler.dialects import linalg


def walk_collect_ops(
Expand Down Expand Up @@ -67,7 +66,6 @@ class NamedOpMatcher:
def __init__(self, op_names: list[str]):
self.op_names = op_names

@abstractmethod
def match(self, op: ir.Operation) -> bool:
return op.name in self.op_names

Expand Down
3 changes: 1 addition & 2 deletions tuner/tuner/spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
# in the code and runs it.

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_codegen
from iree.compiler.dialects import iree_gpu
from iree.compiler.dialects import iree_codegen # type: ignore

from .common import *
from .dispatch_constraints import *
Expand Down

0 comments on commit d40a223

Please sign in to comment.