diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index f41e5d4..ce37729 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -537,7 +537,10 @@ def triton_kernel_call_lowering( named_args = dict(unsafe_zip(fn.arg_names, args)) if isinstance(fn, autotuner.Autotuner): - key_idxs = [fn.arg_names.index(k) for k in fn.keys] + if hasattr(fn, "key_idx"): + key_idxs = copy.copy(fn.key_idx) + else: + key_idxs = [fn.arg_names.index(k) for k in fn.keys] if any(idx not in key_idxs for idx, _, _ in scalar_args): logging.warning( "Auto-tuning key does not include all scalar arguments. "