Skip to content

Commit

Permalink
Merge pull request #32 from firedrakeproject/connorjward/merge-upstream
Browse files Browse the repository at this point in the history
merge upstream
  • Loading branch information
connorjward authored Feb 5, 2025
2 parents ad07454 + 31548ad commit 27aead5
Show file tree
Hide file tree
Showing 40 changed files with 234 additions and 222 deletions.
6 changes: 1 addition & 5 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"pyopencl": ("https://documen.tician.de/pyopencl", None),
"cgen": ("https://documen.tician.de/cgen", None),
"pymbolic": ("https://documen.tician.de/pymbolic", None),
"pyrsistent": ("https://pyrsistent.readthedocs.io/en/latest/", None),
"constantdict": ("https://matthiasdiener.github.io/constantdict/", None),
}

nitpicky = True
Expand All @@ -43,10 +43,6 @@
["py:class", r"numpy\.float[0-9]+"],
["py:class", r"numpy\.complex[0-9]+"],

# As of 2022-06-22, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],

# Reference not found from "<unknown>"? I'm not even sure where to look.
["py:class", r"ExpressionNode"],

Expand Down
1 change: 1 addition & 0 deletions doc/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ For convenience, loopy kernels also directly accept :mod:`numpy` arrays:

.. doctest::

>>> knl = lp.set_options(knl, write_code=False)
>>> evt, (out,) = knl(queue, a=x_vec_host)
>>> assert (out == (2*x_vec_host)).all()

Expand Down
8 changes: 5 additions & 3 deletions examples/python/call-external.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from constantdict import constantdict

import loopy as lp
from loopy.diagnostic import LoopyError
Expand Down Expand Up @@ -30,9 +31,10 @@ def with_types(self, arg_id_to_dtype, callables_table):
"types")

return (self.copy(name_in_target=name_in_target,
arg_id_to_dtype={0: vec_dtype,
1: vec_dtype,
-1: vec_dtype}),
arg_id_to_dtype=constantdict({
0: vec_dtype,
1: vec_dtype,
-1: vec_dtype})),
callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
Expand Down
8 changes: 4 additions & 4 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
Sequence,
)

import immutables
import constantdict


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -168,7 +168,7 @@ class CodeGenerationState:
seen_functions: set[SeenFunction]
seen_atomic_dtypes: set[LoopyType]

var_subst_map: immutables.Map[str, Expression]
var_subst_map: constantdict.constantdict[str, Expression]
allow_complex: bool
callables_table: CallablesTable
is_entrypoint: bool
Expand Down Expand Up @@ -381,7 +381,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
seen_dtypes=seen_dtypes,
seen_functions=seen_functions,
seen_atomic_dtypes=seen_atomic_dtypes,
var_subst_map=immutables.Map(),
var_subst_map=constantdict.constantdict(),
allow_complex=allow_complex,
var_name_generator=kernel.get_var_name_generator(),
is_generating_device_code=False,
Expand Down Expand Up @@ -482,7 +482,7 @@ def diverge_callee_entrypoints(program):

new_callables[name] = clbl

return program.copy(callables_table=immutables.Map(new_callables))
return program.copy(callables_table=constantdict.constantdict(new_callables))


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions loopy/frontend/fortran/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from constantdict import constantdict

import islpy as isl
from islpy import dim_type
Expand Down Expand Up @@ -334,7 +334,7 @@ def specialize_fortran_division(t_unit):

new_callables[name] = clbl

return t_unit.copy(callables_table=Map(new_callables))
return t_unit.copy(callables_table=constantdict(new_callables))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from warnings import warn

import numpy as np
from immutables import Map
from constantdict import constantdict

import islpy # to help out Sphinx
import islpy as isl
Expand Down Expand Up @@ -183,7 +183,7 @@ class LoopKernel(Taggable):
Callable[[LoopKernel, str], tuple[LoopyType, str] | None]] = ()
linearization: Sequence[ScheduleItem] | None = None
iname_slab_increments: Mapping[InameStr, tuple[int, int]] = field(
default_factory=Map)
default_factory=constantdict)
"""
A mapping from inames to (lower_incr,
upper_incr) tuples that will be separated out in the execution to generate
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@


if TYPE_CHECKING:
from immutables import Map
from collections.abc import Mapping

from pymbolic import ArithmeticExpression, Variable

Expand Down Expand Up @@ -437,7 +437,7 @@ class _ArraySeparationInfo:
should be used to realize this array.
"""
sep_axis_indices_set: frozenset[int]
subarray_names: Map[tuple[int, ...], str]
subarray_names: Mapping[tuple[int, ...], str]


class ArrayArg(ArrayBase, KernelArgument):
Expand Down
33 changes: 19 additions & 14 deletions loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from warnings import warn

from immutabledict import immutabledict
from constantdict import constantdict
from typing_extensions import Self

from loopy.diagnostic import LoopyError
Expand Down Expand Up @@ -348,15 +348,17 @@ def __init__(self,
try:
hash(arg_id_to_dtype)
except TypeError:
arg_id_to_dtype = immutabledict(arg_id_to_dtype)
assert arg_id_to_dtype is not None
arg_id_to_dtype = constantdict(arg_id_to_dtype)
warn("arg_id_to_dtype passed to InKernelCallable was not hashable. "
"This usage is deprecated and will stop working in 2026.",
DeprecationWarning, stacklevel=3)

try:
hash(arg_id_to_descr)
except TypeError:
arg_id_to_descr = immutabledict(arg_id_to_descr)
assert arg_id_to_descr is not None
arg_id_to_descr = constantdict(arg_id_to_descr)
warn("arg_id_to_descr passed to InKernelCallable was not hashable. "
"This usage is deprecated and will stop working in 2026.",
DeprecationWarning, stacklevel=3)
Expand Down Expand Up @@ -563,9 +565,10 @@ def with_types(self, arg_id_to_dtype, callables_table):
"the function %s." % (self.name))

def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate()
new_arg_id_to_descr[-1] = ValueArgDescriptor()

arg_id_to_descr[-1] = ValueArgDescriptor()
return (self.copy(arg_id_to_descr=arg_id_to_descr),
return (self.copy(arg_id_to_descr=new_arg_id_to_descr.finish()),
clbl_inf_ctx)

def get_hw_axes_sizes(self, arg_id_to_arg, space, callables_table):
Expand Down Expand Up @@ -773,21 +776,22 @@ def with_types(self, arg_id_to_dtype, callables_table):
# Return the kernel call with specialized subkernel and the corresponding
# new arg_id_to_dtype
return self.copy(subkernel=specialized_kernel,
arg_id_to_dtype=immutabledict(new_arg_id_to_dtype)), callables_table
arg_id_to_dtype=constantdict(new_arg_id_to_dtype)), callables_table

def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):

# arg_id_to_descr expressions provided are from the caller's namespace,
# need to register

new_arg_id_to_descr = constantdict(arg_id_to_descr).mutate()
kw_to_pos, pos_to_kw = get_kw_pos_association(self.subkernel)

kw_to_callee_idx = {arg.name: i
for i, arg in enumerate(self.subkernel.args)}

new_args = self.subkernel.args[:]

for arg_id, descr in arg_id_to_descr.items():
for arg_id, descr in new_arg_id_to_descr.items():
if isinstance(arg_id, int):
arg_id = pos_to_kw[arg_id]

Expand Down Expand Up @@ -835,20 +839,20 @@ def with_descrs(self, arg_id_to_descr, clbl_inf_ctx):
for arg in subkernel.args:
kw = arg.name
if isinstance(arg, ArrayBase):
arg_id_to_descr[kw] = (
new_arg_id_to_descr[kw] = (
ArrayArgDescriptor(shape=arg.shape,
dim_tags=arg.dim_tags,
address_space=arg.address_space))
else:
assert isinstance(arg, ValueArg)
arg_id_to_descr[kw] = ValueArgDescriptor()
new_arg_id_to_descr[kw] = ValueArgDescriptor()

arg_id_to_descr[kw_to_pos[kw]] = arg_id_to_descr[kw]
new_arg_id_to_descr[kw_to_pos[kw]] = new_arg_id_to_descr[kw]

# }}}

return (self.copy(subkernel=subkernel,
arg_id_to_descr=immutabledict(arg_id_to_descr)),
arg_id_to_descr=new_arg_id_to_descr.finish()),
clbl_inf_ctx)

def with_added_arg(self, arg_dtype, arg_descr):
Expand All @@ -866,6 +870,7 @@ def with_added_arg(self, arg_dtype, arg_descr):
arg_id_to_dtype = {}
else:
arg_id_to_dtype = dict(self.arg_id_to_dtype)

if self.arg_id_to_descr is None:
arg_id_to_descr = {}
else:
Expand All @@ -877,8 +882,8 @@ def with_added_arg(self, arg_dtype, arg_descr):
arg_id_to_descr[kw_to_pos[var_name]] = arg_descr

return (self.copy(subkernel=subknl,
arg_id_to_dtype=arg_id_to_dtype,
arg_id_to_descr=arg_id_to_descr),
arg_id_to_dtype=constantdict(arg_id_to_dtype),
arg_id_to_descr=constantdict(arg_id_to_descr)),
var_name)

else:
Expand All @@ -900,7 +905,7 @@ def with_packing_for_args(self):
address_space=AddressSpace.GLOBAL)

return self.copy(subkernel=self.subkernel,
arg_id_to_descr=arg_id_to_descr)
arg_id_to_descr=constantdict(arg_id_to_descr))

def get_used_hw_axes(self, callables_table):
gsize, lsize = self.subkernel.get_grid_size_upper_bounds(callables_table,
Expand Down
16 changes: 8 additions & 8 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def __init__(self,
*,
depends_on: frozenset[str] | str | None = None,
) -> None:
from immutabledict import immutabledict
from constantdict import constantdict

if predicates is None:
predicates = frozenset()
Expand Down Expand Up @@ -324,29 +324,29 @@ def __init__(self,
raise LoopyError("Setting depends_on_is_final to True requires "
"actually specifying happens_after/depends_on")

if isinstance(happens_after, immutabledict):
if isinstance(happens_after, constantdict):
pass
elif happens_after is None:
happens_after = immutabledict()
happens_after = constantdict()
elif isinstance(happens_after, str):
warn("Passing a string for happens_after/depends_on is deprecated and "
"will stop working in 2025. Instead, pass a full-fledged "
"happens_after data structure.", DeprecationWarning, stacklevel=2)

happens_after = immutabledict({
happens_after = constantdict({
after_id.strip(): HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after.split(",")
if after_id.strip()})
elif isinstance(happens_after, frozenset):
happens_after = immutabledict({
happens_after = constantdict({
after_id: HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after})
elif isinstance(happens_after, dict):
happens_after = immutabledict(happens_after)
happens_after = constantdict(happens_after)
else:
raise TypeError("'happens_after' has unexpected type: "
f"{type(happens_after)}")
Expand Down Expand Up @@ -569,13 +569,13 @@ def update_persistent_hash(self, key_hash, key_builder):
def __setstate__(self, val):
super().__setstate__(val)

from immutabledict import immutabledict
from constantdict import constantdict

from loopy.tools import intern_frozenset_of_ids

if self.id is not None: # pylint:disable=access-member-before-definition
self.id = intern(self.id)
self.happens_after = immutabledict({
self.happens_after = constantdict({
intern(after_id): ha
for after_id, ha in self.happens_after.items()})
self.groups = intern_frozenset_of_ids(self.groups)
Expand Down
4 changes: 2 additions & 2 deletions loopy/kernel/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,7 +2089,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
:arg t_unit: An instance of :class:`TranslationUnit`.
"""
from pyrsistent import pmap
from constantdict import constantdict

from loopy.kernel import KernelState

Expand All @@ -2116,7 +2116,7 @@ def get_call_graph(t_unit, only_kernel_callables=False):
call_graph[name] = clbl.get_called_callables(t_unit.callables_table,
recursive=False)

return pmap(call_graph)
return constantdict(call_graph)

# }}}

Expand Down
14 changes: 8 additions & 6 deletions loopy/library/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING

import numpy as np
from constantdict import constantdict

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand All @@ -38,21 +39,22 @@

class MakeTupleCallable(ScalarCallable):
def with_types(self, arg_id_to_dtype, callables_table):
new_arg_id_to_dtype = arg_id_to_dtype.copy()
new_arg_id_to_dtype = constantdict(arg_id_to_dtype).mutate()
for i in range(len(arg_id_to_dtype)):
if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None:
new_arg_id_to_dtype[-i-1] = new_arg_id_to_dtype[i]

return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype,
name_in_target="loopy_make_tuple"), callables_table)
return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype.finish(),
name_in_target="loopy_make_tuple"),
callables_table)

def with_descrs(self, arg_id_to_descr, callables_table):
from loopy.kernel.function_interface import ValueArgDescriptor
new_arg_id_to_descr = {(id, ValueArgDescriptor()):
(-id-1, ValueArgDescriptor()) for id in arg_id_to_descr.keys()}
(-id-1, ValueArgDescriptor()) for id in arg_id_to_descr}

return (
self.copy(arg_id_to_descr=new_arg_id_to_descr),
self.copy(arg_id_to_descr=constantdict(new_arg_id_to_descr)),
callables_table)


Expand All @@ -63,7 +65,7 @@ def with_types(self, arg_id_to_dtype, callables_table):
if dtype is not None}
new_arg_id_to_dtype[-1] = NumpyType(np.int32)

return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype),
return (self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype)),
callables_table)

def emit_call(self, expression_to_code_mapper, expression, target):
Expand Down
Loading

0 comments on commit 27aead5

Please sign in to comment.