From 789614fc8bc7bf9d38be08a637adf0b33ec8c950 Mon Sep 17 00:00:00 2001 From: Amily Wu Date: Tue, 27 Aug 2024 11:42:39 -0500 Subject: [PATCH] Add LLVMGPUVectorDistribute check for dispatch registry --- sharktank/sharktank/tools/tuner/candidate_gen.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sharktank/sharktank/tools/tuner/candidate_gen.py b/sharktank/sharktank/tools/tuner/candidate_gen.py index be31f8914..d19551621 100755 --- a/sharktank/sharktank/tools/tuner/candidate_gen.py +++ b/sharktank/sharktank/tools/tuner/candidate_gen.py @@ -529,11 +529,19 @@ def register(self, dispatch_tuners: list[DispatchTuner]) -> None: for dispatch_tuner in dispatch_tuners: self.registry.add(dispatch_tuner) + def validate_translation(self, attrs: list[ir.NamedAttribute]) -> bool: + for attr in attrs: + if (attr.name == "translation_info") and ( + "LLVMGPUVectorDistribute" in str(attr.attr) + ): + return True + assert False, "Translation info not supported" + def find_handler(self, op_name: str) -> DispatchTuner: for dispatch_tuner in self.registry: if dispatch_tuner.supports(op_name): return dispatch_tuner - assert False, "Not supported" + assert False, "Dispatch kind not supported" class MmtTuner(DispatchTuner): @@ -1249,6 +1257,8 @@ def walk_callback_get_fn( walk_result: OpWalkResult, dispatch_tuner_registry: DispatchTunerRegistry, ) -> ir.WalkResult: + if op.name == "func.func": + dispatch_tuner_registry.validate_translation([a for a in op.opview.attributes]) if op.name == "util.func": func_name = str(op.opview.sym_name) walk_result.was_interrupted = True