Skip to content

Commit

Permalink
Add LLVMGPUVectorDistribute check for dispatch registry
Browse files Browse the repository at this point in the history
  • Loading branch information
RattataKing committed Aug 27, 2024
1 parent 9635551 commit 789614f
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion sharktank/sharktank/tools/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,19 @@ def register(self, dispatch_tuners: list[DispatchTuner]) -> None:
for dispatch_tuner in dispatch_tuners:
self.registry.add(dispatch_tuner)

def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool:
for attr in attrs:
if (attr.name == "translation_info") and (
"LLVMGPUVectorDistribute" in str(attr.attr)
):
return True
assert False, "Translation info not supported"

def find_handler(self, op_name: str) -> DispatchTuner:
for dispatch_tuner in self.registry:
if dispatch_tuner.supports(op_name):
return dispatch_tuner
assert False, "Not supported"
assert False, "Dispatch kind not supported"


class MmtTuner(DispatchTuner):
Expand Down Expand Up @@ -1249,6 +1257,8 @@ def walk_callback_get_fn(
walk_result: OpWalkResult,
dispatch_tuner_registry: DispatchTunerRegistry,
) -> ir.WalkResult:
if op.name == "func.func":
dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes])
if op.name == "util.func":
func_name = str(op.opview.sym_name)
walk_result.was_interrupted = True
Expand Down

0 comments on commit 789614f

Please sign in to comment.