Skip to content

Commit

Permalink
Fix tag_inames to apply multiple tags, type it
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 12, 2024
1 parent e02d390 commit e9ec2a3
Showing 1 changed file with 45 additions and 47 deletions.
92 changes: 45 additions & 47 deletions loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@
"""


from collections.abc import Iterable, Mapping
from typing import Any, FrozenSet, Optional

from typing_extensions import TypeAlias

import islpy as isl
from islpy import dim_type
from pyopencl import Sequence
from pytools.tag import Tag

from loopy.diagnostic import LoopyError
from loopy.kernel import LoopKernel
Expand Down Expand Up @@ -675,9 +680,18 @@ def untag_inames(kernel, iname_to_untag, tag_type):

# {{{ tag inames

_Tags_ish: TypeAlias = Tag | Sequence[Tag] | str | Sequence[str]


@for_each_kernel
def tag_inames(kernel, iname_to_tag, force=False,
ignore_nonexistent=False):
def tag_inames(
kernel: LoopKernel,
iname_to_tag: (Mapping[str, _Tags_ish]
| Sequence[tuple[str, _Tags_ish]]
| str),
force: bool = False,
ignore_nonexistent: bool = False
) -> LoopKernel:
"""Tag an iname
:arg iname_to_tag: a list of tuples ``(iname, new_tag)``. *new_tag* is given
Expand All @@ -697,97 +711,81 @@ def tag_inames(kernel, iname_to_tag, force=False,
"""

if isinstance(iname_to_tag, str):
def parse_kv(s):
def parse_kv(s: str) -> tuple[str, str]:
colon_index = s.find(":")
if colon_index == -1:
raise ValueError("tag decl '%s' has no colon" % s)

return (s[:colon_index].strip(), s[colon_index+1:].strip())

iname_to_tag = [
iname_to_tags_seq = [
parse_kv(s) for s in iname_to_tag.split(",")
if s.strip()]
elif isinstance(iname_to_tag, Mapping):
iname_to_tags_seq = list(iname_to_tag.items())
else:
iname_to_tags_seq = iname_to_tag

if not iname_to_tag:
return kernel

# convert dict to list of tuples
if isinstance(iname_to_tag, dict):
iname_to_tag = list(iname_to_tag.items())

# flatten iterables of tags for each iname

try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable # pylint:disable=no-name-in-module

unpack_iname_to_tag = []
for iname, tags in iname_to_tag:
unpack_iname_to_tag: list[tuple[str, Tag | str]] = []
for iname, tags in iname_to_tags_seq:
if isinstance(tags, Iterable) and not isinstance(tags, str):
for tag in tags:
unpack_iname_to_tag.append((iname, tag))
else:
unpack_iname_to_tag.append((iname, tags))
iname_to_tag = unpack_iname_to_tag

from loopy.kernel.data import parse_tag as inner_parse_tag

def parse_tag(tag):
def parse_tag(tag: Tag | str) -> Iterable[Tag]:
if isinstance(tag, str):
if tag.startswith("like."):
tags = kernel.iname_tags(tag[5:])
if len(tags) == 0:
return None
if len(tags) == 1:
return tags[0]
else:
raise LoopyError("cannot use like for multiple tags (for now)")
return kernel.iname_tags(tag[5:])
elif tag == "unused.g":
return find_unused_axis_tag(kernel, "g")
elif tag == "unused.l":
return find_unused_axis_tag(kernel, "l")

return inner_parse_tag(tag)

iname_to_tag = [(iname, parse_tag(tag)) for iname, tag in iname_to_tag]
result = inner_parse_tag(tag)
if result is None:
return []
else:
return [result]

# {{{ globbing
iname_to_parsed_tag = [
(iname, subtag)
for iname, tag in unpack_iname_to_tag
for subtag in parse_tag(tag)
]

knl_inames = dict(kernel.inames)
all_inames = kernel.all_inames()

from loopy.match import re_from_glob
new_iname_to_tag = {}
for iname, new_tag in iname_to_tag:

for iname, new_tag in iname_to_parsed_tag:
if "*" in iname or "?" in iname:
match_re = re_from_glob(iname)
for sub_iname in all_inames:
if match_re.match(sub_iname):
new_iname_to_tag[sub_iname] = new_tag

inames = [sub_iname for sub_iname in all_inames
if match_re.match(sub_iname)]
else:
if iname not in all_inames:
if ignore_nonexistent:
continue
else:
raise LoopyError("iname '%s' does not exist" % iname)

new_iname_to_tag[iname] = new_tag

iname_to_tag = new_iname_to_tag
del new_iname_to_tag
inames = [iname]

# }}}

knl_inames = kernel.inames.copy()
for name, new_tag in iname_to_tag.items():
if not new_tag:
if new_tag is None:
continue

if name not in kernel.all_inames():
raise ValueError("cannot tag '%s'--not known" % name)

knl_inames[name] = knl_inames[name].tagged(new_tag)
for sub_iname in inames:
knl_inames[sub_iname] = knl_inames[iname].tagged(new_tag)

return kernel.copy(inames=knl_inames)

Expand Down

0 comments on commit e9ec2a3

Please sign in to comment.