Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace pyrsistent.PMap, immutables.Map with immutabledict #884

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"pyopencl": ("https://documen.tician.de/pyopencl", None),
"cgen": ("https://documen.tician.de/cgen", None),
"pymbolic": ("https://documen.tician.de/pymbolic", None),
"pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None),
"immutabledict": ("https://immutabledict.corenting.fr/", None),
}

# Some modules need to import things just so that sphinx can resolve symbols in
Expand All @@ -57,10 +57,6 @@
["py:class", r"numpy\.float[0-9]+"],
["py:class", r"numpy\.complex[0-9]+"],

# As of 2022-06-22, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],

# Reference not found from "<unknown>"? I'm not even sure where to look.
["py:class", r"ExpressionNode"],
]
Expand Down
8 changes: 4 additions & 4 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Union,
)

from immutables import Map
from immutabledict import immutabledict

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
Expand Down Expand Up @@ -207,7 +207,7 @@ class CodeGenerationState:
seen_functions: Set[SeenFunction]
seen_atomic_dtypes: Set[LoopyType]

var_subst_map: Map[str, Expression]
var_subst_map: immutabledict[str, Expression]
allow_complex: bool
callables_table: CallablesTable
is_entrypoint: bool
Expand Down Expand Up @@ -418,7 +418,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
seen_dtypes=seen_dtypes,
seen_functions=seen_functions,
seen_atomic_dtypes=seen_atomic_dtypes,
var_subst_map=Map(),
var_subst_map=immutabledict(),
allow_complex=allow_complex,
var_name_generator=kernel.get_var_name_generator(),
is_generating_device_code=False,
Expand Down Expand Up @@ -519,7 +519,7 @@ def diverge_callee_entrypoints(program):

new_callables[name] = clbl

return program.copy(callables_table=Map(new_callables))
return program.copy(callables_table=immutabledict(new_callables))


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions loopy/frontend/fortran/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from immutabledict import immutabledict

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -331,7 +331,7 @@ def specialize_fortran_division(t_unit):

new_callables[name] = clbl

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=immutabledict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from immutabledict import immutabledict

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -178,7 +178,7 @@ class LoopKernel(Taggable):
Callable[["LoopKernel", str], Optional[Tuple[LoopyType, str]]]] = ()
linearization: Optional[Sequence[ScheduleItem]] = None
iname_slab_increments: Mapping[InameStr, Tuple[int, int]] = field(
default_factory=Map)
default_factory=immutabledict)
"""
A mapping from inames to (lower_incr,
upper_incr) tuples that will be separated out in the execution to generate
Expand Down
5 changes: 2 additions & 3 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
THE SOFTWARE.
"""


from collections.abc import Mapping
from dataclasses import dataclass, replace
from enum import IntEnum
from sys import intern
Expand All @@ -43,7 +43,6 @@

import numpy # FIXME: imported as numpy to allow sphinx to resolve things
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpression, Variable
from pytools import ImmutableRecord
Expand Down Expand Up @@ -434,7 +433,7 @@ class _ArraySeparationInfo:
should be used to realize this array.
"""
sep_axis_indices_set: FrozenSet[int]
subarray_names: Map[Tuple[int, ...], str]
subarray_names: Mapping[Tuple[int, ...], str]


class ArrayArg(ArrayBase, KernelArgument):
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):

:arg t_unit: An instance of :class:`TranslationUnit`.
"""
from pyrsistent import pmap
from immutabledict import immutabledict

from loopy.kernel import KernelState

Expand All @@ -2111,7 +2111,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
call_graph[name] = clbl.get_called_callables(t_unit.callables_table,
recursive=False)

return pmap(call_graph)
return immutabledict(call_graph)

# }}}

Expand Down
12 changes: 5 additions & 7 deletions loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from functools import partial

import numpy as np
from immutables import Map
from immutabledict import immutabledict

from pytools import ProcessLogger

Expand Down Expand Up @@ -191,7 +191,7 @@ def make_arrays_for_sep_arrays(kernel: LoopKernel) -> LoopKernel:

sep_info = _ArraySeparationInfo(
sep_axis_indices_set=sep_axis_indices_set,
subarray_names=Map({
subarray_names=immutabledict({
ind: vng(f"{arg.name}_s{'_'.join(str(i) for i in ind)}")
for ind in np.ndindex(*cast(List[int], sep_shape))}))

Expand Down Expand Up @@ -599,8 +599,6 @@ def map_call_with_kwargs(self, expr):
raise NotImplementedError

def __call__(self, expr, kernel, insn, assignees=None):
import immutables

from loopy.kernel.data import InstructionBase
from loopy.symbolic import ExpansionState, UncachedIdentityMapper
assert insn is None or isinstance(insn, InstructionBase)
Expand All @@ -610,7 +608,7 @@ def __call__(self, expr, kernel, insn, assignees=None):
kernel=kernel,
instruction=insn,
stack=(),
arg_context=immutables.Map()), assignees=assignees)
arg_context=immutabledict()), assignees=assignees)

def map_kernel(self, kernel):

Expand Down Expand Up @@ -744,7 +742,7 @@ def filter_reachable_callables(t_unit):
t_unit.entrypoints)
new_callables = {name: clbl for name, clbl in t_unit.callables_table.items()
if name in (reachable_function_ids | t_unit.entrypoints)}
return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=immutabledict(new_callables))


def _preprocess_single_kernel(kernel: LoopKernel, is_entrypoint: bool) -> LoopKernel:
Expand Down Expand Up @@ -869,7 +867,7 @@ def preprocess_program(t_unit: TranslationUnit) -> TranslationUnit:

new_callables[func_id] = in_knl_callable

t_unit = t_unit.copy(callables_table=Map(new_callables))
t_unit = t_unit.copy(callables_table=immutabledict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TypeVar,
)

from immutables import Map
from immutabledict import immutabledict

import islpy as isl
from pytools import ImmutableRecord, MinRecursionLimit, ProcessLogger
Expand Down Expand Up @@ -2482,7 +2482,7 @@ def linearize(t_unit: TranslationUnit) -> TranslationUnit:
else:
raise NotImplementedError(type(clbl))

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=immutabledict(new_callables))


# vim: foldmethod=marker
4 changes: 2 additions & 2 deletions loopy/schedule/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from functools import cached_property, reduce
from typing import AbstractSet, Dict, FrozenSet, List, Sequence, Set, Tuple

from immutables import Map
from immutabledict import immutabledict
from typing_extensions import TypeAlias

import islpy as isl
Expand Down Expand Up @@ -1048,7 +1048,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree(
for iname in node:
iname_to_tree_node_id[iname] = node

return Map(iname_to_tree_node_id)
return immutabledict(iname_to_tree_node_id)


def get_loop_tree(kernel: LoopKernel) -> LoopTree:
Expand Down
18 changes: 9 additions & 9 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from functools import cached_property, reduce
from typing import Generic, TypeVar

from immutables import Map
from immutabledict import immutabledict

from pytools import memoize_method

Expand Down Expand Up @@ -72,13 +72,13 @@ class Tree(Generic[NodeT]):
this allocates a new stack frame for each iteration of the operation.
"""

_parent_to_children: Map[NodeT, tuple[NodeT, ...]]
_child_to_parent: Map[NodeT, NodeT | None]
_parent_to_children: immutabledict[NodeT, tuple[NodeT, ...]]
_child_to_parent: immutabledict[NodeT, NodeT | None]

@staticmethod
def from_root(root: NodeT) -> Tree[NodeT]:
return Tree(Map({root: ()}),
Map({root: None}))
return Tree(immutabledict({root: ()}),
immutabledict({root: None}))

@cached_property
def root(self) -> NodeT:
Expand Down Expand Up @@ -183,7 +183,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:

# {{{ update child to parent

child_to_parent_mut = self._child_to_parent.mutate()
child_to_parent_mut = dict(self._child_to_parent)
del child_to_parent_mut[node]
child_to_parent_mut[new_node] = parent

Expand All @@ -194,7 +194,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:

# {{{ update parent_to_children

parent_to_children_mut = self._parent_to_children.mutate()
parent_to_children_mut = dict(self._parent_to_children)
del parent_to_children_mut[node]
parent_to_children_mut[new_node] = children

Expand All @@ -206,8 +206,8 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]:

# }}}

return Tree(parent_to_children_mut.finish(),
child_to_parent_mut.finish())
return Tree(immutabledict(parent_to_children_mut),
immutabledict(child_to_parent_mut))

def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]:
"""
Expand Down
12 changes: 6 additions & 6 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
)
from warnings import warn

import immutables
import numpy as np
from immutabledict import immutabledict

import islpy as isl
import pymbolic.primitives # FIXME: also import by full name to allow sphinx to resolve
Expand Down Expand Up @@ -1036,12 +1036,12 @@ class ExpansionState(ImmutableRecord):
a dict representing current argument values
"""
def __init__(self, kernel, instruction, stack, arg_context):
if not isinstance(arg_context, immutables.Map):
if not isinstance(arg_context, immutabledict):
warn(f"Got a {type(arg_context)} for arg_context,"
" expected `immutables.Map`. This is deprecated"
" expected `immutabledict`. This is deprecated"
" and will result in an error from 2023.",
DeprecationWarning, stacklevel=2)
arg_context = immutables.Map(arg_context)
arg_context = immutabledict(arg_context)
super().__init__(kernel=kernel,
instruction=instruction,
stack=stack,
Expand Down Expand Up @@ -1274,7 +1274,7 @@ def make_new_arg_context(

from pymbolic.mapper.substitutor import make_subst_func
arg_subst_map = SubstitutionMapper(make_subst_func(arg_context))
return immutables.Map({
return immutabledict({
formal_arg_name: arg_subst_map(arg_value)
for formal_arg_name, arg_value in zip(arg_names, arguments)})

Expand Down Expand Up @@ -1317,7 +1317,7 @@ def __call__(self, expr, kernel, insn):
kernel=kernel,
instruction=insn,
stack=(),
arg_context=immutables.Map()))
arg_context=immutabledict()))

def map_instruction(self, kernel, insn):
return insn
Expand Down
4 changes: 2 additions & 2 deletions loopy/target/c/c_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import logging
import os
import tempfile
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Optional, Sequence, Tuple, Union

import numpy as np
from codepy.jit import compile_from_string
from codepy.toolchain import GCCToolchain, ToolchainGuessError, guess_toolchain
from immutables import Map

from pytools import memoize_method
from pytools.codegen import CodeGenerator, Indentation
Expand Down Expand Up @@ -493,7 +493,7 @@ def get_wrapper_generator(self):

@memoize_method
def translation_unit_info(self,
arg_to_dtype: Optional[Map[str, LoopyType]] = None) -> _KernelInfo:
arg_to_dtype: Optional[Mapping[str, LoopyType]] = None) -> _KernelInfo:
t_unit = self.get_typed_and_scheduled_translation_unit(arg_to_dtype)

from loopy.codegen import generate_code_v2
Expand Down
Loading
Loading