Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace pyrsistent.PMap, immutables.Map, immutabledict with constantdict #884

Merged
merged 16 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"pyopencl": ("https://documen.tician.de/pyopencl", None),
"cgen": ("https://documen.tician.de/cgen", None),
"pymbolic": ("https://documen.tician.de/pymbolic", None),
"pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None),
"constantdict": ("https://matthiasdiener.github.io/constantdict/", None),
}

nitpicky = True
Expand All @@ -43,10 +43,6 @@
["py:class", r"numpy\.float[0-9]+"],
["py:class", r"numpy\.complex[0-9]+"],

# As of 2022-06-22, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],

# Reference not found from "<unknown>"? I'm not even sure where to look.
["py:class", r"ExpressionNode"],

Expand Down
8 changes: 4 additions & 4 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Sequence,
)

import immutables
import constantdict


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -168,7 +168,7 @@ class CodeGenerationState:
seen_functions: set[SeenFunction]
seen_atomic_dtypes: set[LoopyType]

var_subst_map: immutables.Map[str, Expression]
var_subst_map: constantdict.constantdict[str, Expression]
allow_complex: bool
callables_table: CallablesTable
is_entrypoint: bool
Expand Down Expand Up @@ -381,7 +381,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
seen_dtypes=seen_dtypes,
seen_functions=seen_functions,
seen_atomic_dtypes=seen_atomic_dtypes,
var_subst_map=immutables.Map(),
var_subst_map=constantdict.constantdict(),
allow_complex=allow_complex,
var_name_generator=kernel.get_var_name_generator(),
is_generating_device_code=False,
Expand Down Expand Up @@ -482,7 +482,7 @@ def diverge_callee_entrypoints(program):

new_callables[name] = clbl

return program.copy(callables_table=immutables.Map(new_callables))
return program.copy(callables_table=constantdict.constantdict(new_callables))


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions loopy/frontend/fortran/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from constantdict import constantdict

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -334,7 +334,7 @@ def specialize_fortran_division(t_unit):

new_callables[name] = clbl

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=constantdict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from constantdict import constantdict

import islpy # to help out Sphinx
import islpy as isl
Expand Down Expand Up @@ -183,7 +183,7 @@ class LoopKernel(Taggable):
Callable[[LoopKernel, str], tuple[LoopyType, str] | None]] = ()
linearization: Sequence[ScheduleItem] | None = None
iname_slab_increments: Mapping[InameStr, tuple[int, int]] = field(
default_factory=Map)
default_factory=constantdict)
"""
A mapping from inames to (lower_incr,
upper_incr) tuples that will be separated out in the execution to generate
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@


if TYPE_CHECKING:
from immutables import Map
from collections.abc import Mapping

from pymbolic import ArithmeticExpression, Variable

Expand Down Expand Up @@ -437,7 +437,7 @@ class _ArraySeparationInfo:
should be used to realize this array.
"""
sep_axis_indices_set: frozenset[int]
subarray_names: Map[tuple[int, ...], str]
subarray_names: Mapping[tuple[int, ...], str]


class ArrayArg(ArrayBase, KernelArgument):
Expand Down
12 changes: 7 additions & 5 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from warnings import warn

from immutabledict import immutabledict
from constantdict import constantdict
from typing_extensions import Self

from loopy.diagnostic import LoopyError
Expand Down Expand Up @@ -348,15 +348,17 @@ def __init__(self,
try:
hash(arg_id_to_dtype)
except TypeError:
arg_id_to_dtype = immutabledict(arg_id_to_dtype)
assert arg_id_to_dtype is not None
arg_id_to_dtype = constantdict(arg_id_to_dtype)
warn("arg_id_to_dtype passed to InKernelCallable was not hashable. "
"This usage is deprecated and will stop working in 2026.",
DeprecationWarning, stacklevel=3)

try:
hash(arg_id_to_descr)
except TypeError:
arg_id_to_descr = immutabledict(arg_id_to_descr)
assert arg_id_to_descr is not None
arg_id_to_descr = constantdict(arg_id_to_descr)
warn("arg_id_to_descr passed to InKernelCallable was not hashable. "
"This usage is deprecated and will stop working in 2026.",
DeprecationWarning, stacklevel=3)
Expand Down Expand Up @@ -773,7 +775,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
# Return the kernel call with specialized subkernel and the corresponding
# new arg_id_to_dtype
return self.copy(subkernel=specialized_kernel,
arg_id_to_dtype=immutabledict(new_arg_id_to_dtype)), callables_table
arg_id_to_dtype=constantdict(new_arg_id_to_dtype)), callables_table

def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):

Expand Down Expand Up @@ -848,7 +850,7 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
# }}}

return (self.copy(subkernel=subkernel,
arg_id_to_descr=immutabledict(arg_id_to_descr)),
arg_id_to_descr=constantdict(arg_id_to_descr)),
clbl_inf_ctx)

def with_added_arg(self, arg_dtype, arg_descr):
Expand Down
16 changes: 8 additions & 8 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def __init__(self,
*,
depends_on: frozenset[str] | str | None = None,
) -> None:
from immutabledict import immutabledict
from constantdict import constantdict

if predicates is None:
predicates = frozenset()
Expand Down Expand Up @@ -324,29 +324,29 @@ def __init__(self,
raise LoopyError("Setting depends_on_is_final to True requires "
"actually specifying happens_after/depends_on")

if isinstance(happens_after, immutabledict):
if isinstance(happens_after, constantdict):
pass
elif happens_after is None:
happens_after = immutabledict()
happens_after = constantdict()
elif isinstance(happens_after, str):
warn("Passing a string for happens_after/depends_on is deprecated and "
"will stop working in 2025. Instead, pass a full-fledged "
"happens_after data structure.", DeprecationWarning, stacklevel=2)

happens_after = immutabledict({
happens_after = constantdict({
after_id.strip(): HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after.split(",")
if after_id.strip()})
elif isinstance(happens_after, frozenset):
happens_after = immutabledict({
happens_after = constantdict({
after_id: HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after})
elif isinstance(happens_after, dict):
happens_after = immutabledict(happens_after)
happens_after = constantdict(happens_after)
else:
raise TypeError("'happens_after' has unexpected type: "
f"{type(happens_after)}")
Expand Down Expand Up @@ -569,13 +569,13 @@ def update_persistent_hash(self, key_hash, key_builder):
def __setstate__(self, val):
super().__setstate__(val)

from immutabledict import immutabledict
from constantdict import constantdict

from loopy.tools import intern_frozenset_of_ids

if self.id is not None: # pylint:disable=access-member-before-definition
self.id = intern(self.id)
self.happens_after = immutabledict({
self.happens_after = constantdict({
intern(after_id): ha
for after_id, ha in self.happens_after.items()})
self.groups = intern_frozenset_of_ids(self.groups)
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,7 +2089,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):

:arg t_unit: An instance of :class:`TranslationUnit`.
"""
from pyrsistent import pmap
from constantdict import constantdict

from loopy.kernel import KernelState

Expand All @@ -2116,7 +2116,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
call_graph[name] = clbl.get_called_callables(t_unit.callables_table,
recursive=False)

return pmap(call_graph)
return constantdict(call_graph)

# }}}

Expand Down
12 changes: 5 additions & 7 deletions loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from functools import partial

import numpy as np
from immutables import Map
from constantdict import constantdict

from pytools import ProcessLogger

Expand Down Expand Up @@ -197,7 +197,7 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel:

sep_info = _ArraySeparationInfo(
sep_axis_indices_set=sep_axis_indices_set,
subarray_names=Map({
subarray_names=constantdict({
ind: vng(f"{arg.name}_s{'_'.join(str(i) for i in ind)}")
for ind in np.ndindex(*cast("List[int]", sep_shape))}))

Expand Down Expand Up @@ -605,8 +605,6 @@ def map_call_with_kwargs(self, expr):
raise NotImplementedError

def __call__(self, expr, kernel, insn, assignees=None):
import immutables

from loopy.kernel.data import InstructionBase
from loopy.symbolic import ExpansionState, UncachedIdentityMapper
assert insn is None or isinstance(insn, InstructionBase)
Expand All @@ -616,7 +614,7 @@ def __call__(self, expr, kernel, insn, assignees=None):
kernel=kernel,
instruction=insn,
stack=(),
arg_context=immutables.Map()), assignees=assignees)
arg_context=constantdict()), assignees=assignees)

def map_kernel(self, kernel):

Expand Down Expand Up @@ -750,7 +748,7 @@ def filter_reachable_callables(t_unit):
t_unit.entrypoints)
new_callables = {name: clbl for name, clbl in t_unit.callables_table.items()
if name in (reachable_function_ids | t_unit.entrypoints)}
return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=constantdict(new_callables))


def _preprocess_single_kernel(kernel: LoopKernel, is_entrypoint: bool) -> LoopKernel:
Expand Down Expand Up @@ -875,7 +873,7 @@ def preprocess_program(t_unit: TranslationUnit) -> TranslationUnit:

new_callables[func_id] = in_knl_callable

t_unit = t_unit.copy(callables_table=Map(new_callables))
t_unit = t_unit.copy(callables_table=constantdict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TypeVar,
)

from immutables import Map
from constantdict import constantdict

import islpy as isl
from pytools import ImmutableRecord, MinRecursionLimit, ProcessLogger
Expand Down Expand Up @@ -2480,7 +2480,7 @@ def linearize(t_unit: TranslationUnit) -> TranslationUnit:
else:
raise NotImplementedError(type(clbl))

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=constantdict(new_callables))


# vim: foldmethod=marker
4 changes: 2 additions & 2 deletions loopy/schedule/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from functools import cached_property, reduce
from typing import TYPE_CHECKING, AbstractSet, Sequence

from immutables import Map
from constantdict import constantdict
from typing_extensions import TypeAlias

import islpy as isl
Expand Down Expand Up @@ -1062,7 +1062,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree(
for iname in node:
iname_to_tree_node_id[iname] = node

return Map(iname_to_tree_node_id)
return constantdict(iname_to_tree_node_id)


def get_loop_tree(kernel: LoopKernel) -> LoopTree:
Expand Down
2 changes: 1 addition & 1 deletion loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from functools import cached_property, reduce
from typing import Generic, TypeVar

from immutables import Map
from constantdict import constantdict as Map # noqa: N812

from pytools import memoize_method

Expand Down
8 changes: 4 additions & 4 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
)
from warnings import warn

import immutables
import numpy as np
from constantdict import constantdict
from typing_extensions import Self

import islpy as isl
Expand Down Expand Up @@ -1114,7 +1114,7 @@
kernel: LoopKernel
instruction: InstructionBase
stack: tuple[tuple[str, Tag], ...]
arg_context: immutables.Map[str, Expression]
arg_context: Mapping[str, Expression]

def __post_init__(self) -> None:
hash(self.arg_context)
Expand Down Expand Up @@ -1352,7 +1352,7 @@

from pymbolic.mapper.substitutor import make_subst_func
arg_subst_map = SubstitutionMapper(make_subst_func(arg_context))
return immutables.Map({
return constantdict({
formal_arg_name: arg_subst_map(arg_value)
for formal_arg_name, arg_value in zip(arg_names, arguments)})

Expand Down Expand Up @@ -1398,7 +1398,7 @@
kernel=kernel,
instruction=insn,
stack=(),
arg_context=immutables.Map()))
arg_context=constantdict()))

def map_instruction(self, kernel, insn):
return insn
Expand Down Expand Up @@ -1740,7 +1740,7 @@
# pstate.expect(_colon):
pstate.advance()
subscript = self.parse_expression(pstate, _PREC_UNARY)
return SubArrayRef(swept_inames, subscript)

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1743 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.
else:
pstate = rollback_pstate
return super().parse_prefix(rollback_pstate)
Expand Down
4 changes: 2 additions & 2 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@


if TYPE_CHECKING:
from immutables import Map
from constantdict import constantdict

from loopy.codegen.result import GeneratedProgram
from loopy.kernel import LoopKernel
Expand Down Expand Up @@ -500,7 +500,7 @@ def get_wrapper_generator(self):

@memoize_method
def translation_unit_info(self,
arg_to_dtype: Map[str, LoopyType] | None = None) -> _KernelInfo:
arg_to_dtype: constantdict[str, LoopyType] | None = None) -> _KernelInfo:
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

from loopy.codegen import generate_code_v2
Expand Down
Loading
Loading