Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: Fix accept args #1830

Merged
merged 2 commits into from
Feb 1, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 4 additions & 59 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
@@ -32,66 +32,11 @@ def _is_representable_as_int(s: int | str) -> bool:
return False


def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec):
"""
Predicate to determine if two types are compatible.

This function gracefully handles:
- iterators with unknown positions which are considered compatible to any other positions
of another iterator.
- iterators which are defined everywhere, i.e. empty defined dimensions
Beside that this function simply checks for equality of types.

>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
>>> IDim = common.Dimension(value="IDim")
>>> type_on_i_of_i_it = it_ts.IteratorType(
... position_dims=[IDim], defined_dims=[IDim], element_type=bool_type
... )
>>> type_on_undefined_of_i_it = it_ts.IteratorType(
... position_dims="unknown", defined_dims=[IDim], element_type=bool_type
... )
>>> _is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it)
True

>>> JDim = common.Dimension(value="JDim")
>>> type_on_j_of_j_it = it_ts.IteratorType(
... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type
... )
>>> _is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it)
False
"""
is_compatible = True

if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType):
if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]):
is_compatible &= type_a.position_dims == type_b.position_dims
if type_a.defined_dims and type_b.defined_dims:
is_compatible &= type_a.defined_dims == type_b.defined_dims
is_compatible &= type_a.element_type == type_b.element_type
elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType):
for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True):
is_compatible &= _is_compatible_type(el_type_a, el_type_b)
elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType):
for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True):
is_compatible &= _is_compatible_type(arg_a, arg_b)
for arg_a, arg_b in zip(
type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True
):
is_compatible &= _is_compatible_type(arg_a, arg_b)
for arg_a, arg_b in zip(
type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True
):
is_compatible &= _is_compatible_type(arg_a, arg_b)
is_compatible &= _is_compatible_type(type_a.returns, type_b.returns)
else:
is_compatible &= type_info.is_concretizable(type_a, type_b)

return is_compatible


def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None:
if node.type:
assert _is_compatible_type(node.type, type_), "Node already has a type which differs."
assert type_info.is_compatible_type(
node.type, type_
), "Node already has a type which differs."
node.type = type_


@@ -475,7 +420,7 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
if isinstance(node, itir.Node):
if isinstance(result, ts.TypeSpec):
if node.type and not isinstance(node.type, ts.DeferredType):
assert _is_compatible_type(node.type, result)
assert type_info.is_compatible_type(node.type, result)
node.type = result
elif isinstance(result, ObservableTypeSynthesizer) or result is None:
pass
72 changes: 66 additions & 6 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@

from gt4py.eve.utils import XIterable, xiter
from gt4py.next import common
from gt4py.next.iterator.type_system import type_specifications as it_ts
from gt4py.next.type_system import type_specifications as ts


@@ -432,6 +433,69 @@ def contains_local_field(type_: ts.TypeSpec) -> bool:
)


# TODO(tehrengruber): This function has specializations on Iterator types, which are not part of
# the general / shared type system. This functionality should be moved to the iterator-only
# type system, but we need some sort of multiple dispatch for that.
# TODO(tehrengruber): Should this have a direction like is_concretizable?
def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool:
"""
Predicate to determine if two types are compatible.

This function gracefully handles:
- iterators with unknown positions which are considered compatible to any other positions
of another iterator.
- iterators which are defined everywhere, i.e. empty defined dimensions
Beside that this function simply checks for equality of types.

>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
>>> IDim = common.Dimension(value="IDim")
>>> type_on_i_of_i_it = it_ts.IteratorType(
... position_dims=[IDim], defined_dims=[IDim], element_type=bool_type
... )
>>> type_on_undefined_of_i_it = it_ts.IteratorType(
... position_dims="unknown", defined_dims=[IDim], element_type=bool_type
... )
>>> is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it)
True

>>> JDim = common.Dimension(value="JDim")
>>> type_on_j_of_j_it = it_ts.IteratorType(
... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type
... )
>>> is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it)
False
"""
is_compatible = True

if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType):
if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]):
is_compatible &= type_a.position_dims == type_b.position_dims
if type_a.defined_dims and type_b.defined_dims:
is_compatible &= type_a.defined_dims == type_b.defined_dims
is_compatible &= type_a.element_type == type_b.element_type
elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType):
if len(type_a.types) != len(type_b.types):
return False
for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True):
is_compatible &= is_compatible_type(el_type_a, el_type_b)
elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType):
for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True):
is_compatible &= is_compatible_type(arg_a, arg_b)
for arg_a, arg_b in zip(
type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True
):
is_compatible &= is_compatible_type(arg_a, arg_b)
for arg_a, arg_b in zip(
type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True
):
is_compatible &= is_compatible_type(arg_a, arg_b)
is_compatible &= is_compatible_type(type_a.returns, type_b.returns)
else:
is_compatible &= is_concretizable(type_a, type_b)

return is_compatible


def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool:
"""
Check if ``symbol_type`` can be concretized to ``to_type``.
@@ -725,11 +789,7 @@ def function_signature_incompatibilities_func(
for i, (a_arg, b_arg) in enumerate(
zip(list(func_type.pos_only_args) + list(func_type.pos_or_kw_args.values()), args)
):
if (
b_arg is not UNDEFINED_ARG
and a_arg != b_arg
and not is_concretizable(a_arg, to_type=b_arg)
):
if b_arg is not UNDEFINED_ARG and a_arg != b_arg and not is_compatible_type(a_arg, b_arg):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the change below is only real change in this PR.

if i < len(func_type.pos_only_args):
arg_repr = f"{_number_to_ordinal_number(i + 1)} argument"
else:
@@ -739,7 +799,7 @@ def function_signature_incompatibilities_func(
for kwarg in set(func_type.kw_only_args.keys()) & set(kwargs.keys()):
if (a_kwarg := func_type.kw_only_args[kwarg]) != (
b_kwarg := kwargs[kwarg]
) and not is_concretizable(a_kwarg, to_type=b_kwarg):
) and not is_compatible_type(a_kwarg, b_kwarg):
yield f"Expected keyword argument '{kwarg}' to be of type '{func_type.kw_only_args[kwarg]}', got '{kwargs[kwarg]}'."


Loading