Skip to content

Commit

Permalink
require Mapping attributes of arrays to be immutabledicts
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed May 9, 2024
1 parent 5aa8aa3 commit c7674dd
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
21 changes: 10 additions & 11 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,8 @@ class DictOfNamedArrays(AbstractResultWithNamedArrays):
.. automethod:: __init__
"""
_data: Mapping[str, Array]
_data: Mapping[str, Array] = attrs.field(
validator=attrs.validators.instance_of(immutabledict))

_mapper_method: ClassVar[str] = "map_dict_of_named_arrays"

Expand Down Expand Up @@ -887,17 +888,13 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array):
.. automethod:: with_tagged_reduction
"""
expr: prim.Expression
bindings: Mapping[str, Array] = attrs.field()
var_to_reduction_descr: Mapping[str, ReductionDescriptor]
bindings: Mapping[str, Array] = attrs.field(
validator=attrs.validators.instance_of(immutabledict))
var_to_reduction_descr: Mapping[str, ReductionDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))

_mapper_method: ClassVar[str] = "map_index_lambda"

if __debug__:
@bindings.validator # type: ignore[attr-defined, misc]
def _check_bindings(self, attribute: Any, value: Any) -> None:
if isinstance(value, dict):
raise TypeError("bindings may not be a dict")

def with_tagged_reduction(self,
reduction_variable: str,
tag: Tag) -> IndexLambda:
Expand Down Expand Up @@ -1006,8 +1003,10 @@ class Einsum(Array):
access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...]
args: Tuple[Array, ...]
redn_axis_to_redn_descr: Mapping[EinsumReductionAxis,
ReductionDescriptor]
index_to_access_descr: Mapping[str, EinsumAxisDescriptor]
ReductionDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
index_to_access_descr: Mapping[str, EinsumAxisDescriptor] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
_mapper_method: ClassVar[str] = "map_einsum"

@memoize_method
Expand Down
6 changes: 4 additions & 2 deletions pytato/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ class FunctionDefinition(Taggable):
"""
parameters: FrozenSet[str]
return_type: ReturnType
returns: Mapping[str, Array]
returns: Mapping[str, Array] = attrs.field(
validator=attrs.validators.instance_of(immutabledict))
tags: FrozenSet[Tag] = attrs.field(kw_only=True)

@cached_property
Expand Down Expand Up @@ -276,7 +277,8 @@ class Call(AbstractResultWithNamedArrays):
"""
function: FunctionDefinition
bindings: Mapping[str, Array]
bindings: Mapping[str, Array] = attrs.field(
validator=attrs.validators.instance_of(immutabledict))

_mapper_method: ClassVar[str] = "map_call"

Expand Down
3 changes: 2 additions & 1 deletion pytato/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class LoopyCall(AbstractResultWithNamedArrays):
:mod:`loopy` translation unit.
"""
translation_unit: "lp.TranslationUnit"
bindings: Mapping[str, ArrayOrScalar]
bindings: Mapping[str, ArrayOrScalar] = \
attrs.field(validator=attrs.validators.instance_of(immutabledict))
entrypoint: str

_mapper_method: ClassVar[str] = "map_loopy_call"
Expand Down

0 comments on commit c7674dd

Please sign in to comment.