Skip to content

Commit

Permalink
Type rename_inames
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 13, 2024
1 parent d52c290 commit 920cb49
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""


from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Collection, Iterable, Mapping, Sequence
from typing import Any, FrozenSet, Optional

from typing_extensions import TypeAlias
Expand All @@ -34,6 +34,7 @@
from loopy.kernel import LoopKernel
from loopy.kernel.function_interface import CallableKernel
from loopy.kernel.instruction import InstructionBase
from loopy.match import ToStackMatchCovertible
from loopy.symbolic import (
RuleAwareIdentityMapper,
RuleAwareSubstitutionMapper,
Expand Down Expand Up @@ -2369,8 +2370,14 @@ def add_inames_for_unused_hw_axes(kernel, within=None):

@for_each_kernel
@remove_any_newly_unused_inames
def rename_inames(kernel, old_inames, new_iname, existing_ok=False,
within=None, raise_on_domain_mismatch: Optional[bool] = None):
def rename_inames(
kernel: LoopKernel,
old_inames: Collection[str],
new_iname: str,
existing_ok: bool = False,
within: ToStackMatchCovertible = None,
raise_on_domain_mismatch: Optional[bool] = None
) -> LoopKernel:
r"""
:arg old_inames: A collection of inames that must be renamed to **new_iname**.
:arg within: a stack match as understood by
Expand All @@ -2380,7 +2387,6 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False,
:math:`\exists (i_1,i_2) \in \{\text{old\_inames}\}^2 |
\mathcal{D}_{i_1} \neq \mathcal{D}_{i_2}`.
"""
from collections.abc import Collection
if (isinstance(old_inames, str)
or not isinstance(old_inames, Collection)):
raise LoopyError("'old_inames' must be a collection of strings, "
Expand Down Expand Up @@ -2508,9 +2514,15 @@ def does_insn_involve_iname(kernel, insn, *args):


@for_each_kernel
def rename_iname(kernel, old_iname, new_iname, existing_ok=False,
within=None, preserve_tags=True,
raise_on_domain_mismatch: Optional[bool] = None):
def rename_iname(
kernel: LoopKernel,
old_iname: str,
new_iname: str,
existing_ok: bool = False,
within: ToStackMatchCovertible = None,
preserve_tags: bool = True,
raise_on_domain_mismatch: Optional[bool] = None
) -> LoopKernel:
r"""
Single iname version of :func:`loopy.rename_inames`.
:arg existing_ok: execute even if *new_iname* already exists.
Expand All @@ -2528,7 +2540,7 @@ def rename_iname(kernel, old_iname, new_iname, existing_ok=False,
kernel = rename_inames(kernel, [old_iname], new_iname, existing_ok,
within, raise_on_domain_mismatch)
if preserve_tags:
kernel = tag_inames(kernel, product([new_iname], tags))
kernel = tag_inames(kernel, list(product([new_iname], tags)))
return kernel

# }}}
Expand Down

0 comments on commit 920cb49

Please sign in to comment.