From 920cb49887ef815eac2debf7aa2a4bc128f6e6a0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 13 Nov 2024 11:16:08 -0600 Subject: [PATCH] Type rename_inames --- loopy/transform/iname.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 1f318313c..795154099 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -21,7 +21,7 @@ """ -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Collection, Iterable, Mapping, Sequence from typing import Any, FrozenSet, Optional from typing_extensions import TypeAlias @@ -34,6 +34,7 @@ from loopy.kernel import LoopKernel from loopy.kernel.function_interface import CallableKernel from loopy.kernel.instruction import InstructionBase +from loopy.match import ToStackMatchCovertible from loopy.symbolic import ( RuleAwareIdentityMapper, RuleAwareSubstitutionMapper, @@ -2369,8 +2370,14 @@ def add_inames_for_unused_hw_axes(kernel, within=None): @for_each_kernel @remove_any_newly_unused_inames -def rename_inames(kernel, old_inames, new_iname, existing_ok=False, - within=None, raise_on_domain_mismatch: Optional[bool] = None): +def rename_inames( + kernel: LoopKernel, + old_inames: Collection[str], + new_iname: str, + existing_ok: bool = False, + within: ToStackMatchCovertible = None, + raise_on_domain_mismatch: Optional[bool] = None + ) -> LoopKernel: r""" :arg old_inames: A collection of inames that must be renamed to **new_iname**. :arg within: a stack match as understood by @@ -2380,7 +2387,6 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False, :math:`\exists (i_1,i_2) \in \{\text{old\_inames}\}^2 | \mathcal{D}_{i_1} \neq \mathcal{D}_{i_2}`. """ - from collections.abc import Collection if (isinstance(old_inames, str) or not isinstance(old_inames, Collection)): raise LoopyError("'old_inames' must be a collection of strings, " @@ -2508,9 +2514,15 @@ def does_insn_involve_iname(kernel, insn, *args): @for_each_kernel -def rename_iname(kernel, old_iname, new_iname, existing_ok=False, - within=None, preserve_tags=True, - raise_on_domain_mismatch: Optional[bool] = None): +def rename_iname( + kernel: LoopKernel, + old_iname: str, + new_iname: str, + existing_ok: bool = False, + within: ToStackMatchCovertible = None, + preserve_tags: bool = True, + raise_on_domain_mismatch: Optional[bool] = None + ) -> LoopKernel: r""" Single iname version of :func:`loopy.rename_inames`. :arg existing_ok: execute even if *new_iname* already exists. @@ -2528,7 +2540,7 @@ def rename_iname(kernel, old_iname, new_iname, existing_ok=False, kernel = rename_inames(kernel, [old_iname], new_iname, existing_ok, within, raise_on_domain_mismatch) if preserve_tags: - kernel = tag_inames(kernel, product([new_iname], tags)) + kernel = tag_inames(kernel, list(product([new_iname], tags))) return kernel # }}}