Skip to content

Commit

Permalink
style[next]: improve typing in next.type_translation (#1493)
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt authored Mar 15, 2024
1 parent 1abfaea commit 424dfe2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,6 @@ module = 'gt4py.next.ffront.*'
ignore_errors = true
module = 'gt4py.next.ffront.decorator'

[[tool.mypy.overrides]]
ignore_errors = true
module = 'gt4py.next.type_system.type_translation'

[[tool.mypy.overrides]]
ignore_errors = true
module = 'gt4py.next.iterator.runtime'
Expand Down
34 changes: 25 additions & 9 deletions src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind:
# make int & float precision platform independent.
dt: np.dtype
if dtype is builtins.int:
dt = np.dtype("int64")
elif dtype is builtins.float:
Expand Down Expand Up @@ -66,7 +67,7 @@ def from_type_hint(
localns: Optional[dict[str, Any]] = None,
) -> ts.TypeSpec:
recursive_make_symbol = functools.partial(from_type_hint, globalns=globalns, localns=localns)
extra_args = ()
extra_args: list = []

# ForwardRef
if isinstance(type_hint, str):
Expand All @@ -82,7 +83,10 @@ def from_type_hint(
# Annotated
if typing.get_origin(type_hint) is typing.Annotated:
type_hint, *extra_args = typing.get_args(type_hint)
if not isinstance(type_hint, collections.abc.Callable):
if not isinstance(
type_hint,
collections.abc.Callable, # type:ignore[arg-type] # see https://github.com/python/mypy/issues/14928
):
type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns)

canonical_type = (
Expand All @@ -101,7 +105,9 @@ def from_type_hint(
raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.")
if Ellipsis in args:
raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.")
return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args])
tuple_types = [recursive_make_symbol(arg) for arg in args]
assert all(isinstance(elem, ts.DataType) for elem in tuple_types)
return ts.TupleType(types=tuple_types) # type: ignore[arg-type] # checked in assert

case common.Field:
if (n_args := len(args)) != 2:
Expand Down Expand Up @@ -137,7 +143,8 @@ def from_type_hint(

try:
arg_types, return_type = args
args = [recursive_make_symbol(arg) for arg in arg_types]
new_args = [recursive_make_symbol(arg) for arg in arg_types]
assert all(isinstance(arg, ts.DataType) for arg in new_args)
except Exception as error:
raise ValueError(f"Invalid callable annotations in '{type_hint}'.") from error

Expand All @@ -148,13 +155,17 @@ def from_type_hint(
arg: recursive_make_symbol(arg_type)
for arg, arg_type in kwargs_info[0].data.items()
}
assert all(isinstance(val, (ts.DataType, ts.DeferredType)) for val in kwargs.values())

returns = recursive_make_symbol(return_type)
assert isinstance(returns, (ts.DataType, ts.DeferredType, ts.VoidType))

# TODO(tehrengruber): print better error when no return type annotation is given
return ts.FunctionType(
pos_only_args=args,
pos_or_kw_args=kwargs,
pos_only_args=new_args, # type: ignore[arg-type] # checked in assert
pos_or_kw_args=kwargs, # type: ignore[arg-type] # checked in assert
kw_only_args={}, # TODO
returns=recursive_make_symbol(return_type),
returns=returns,
)

raise ValueError(f"'{type_hint}' type is not supported.")
Expand Down Expand Up @@ -188,18 +199,23 @@ def from_value(value: Any) -> ts.TypeSpec:
elif common.is_field(value):
dims = list(value.domain.dims)
dtype = from_type_hint(value.dtype.scalar_type)
assert isinstance(dtype, ts.ScalarType)
symbol_type = ts.FieldType(dims=dims, dtype=dtype)
elif isinstance(value, tuple):
# Since the elements of the tuple might be one of the special cases
# above, we can not resort to generic `infer_type` but need to do it
# manually here. If we get rid of all the special cases this is
# not needed anymore.
return ts.TupleType(types=[from_value(el) for el in value])
elems = [from_value(el) for el in value]
assert all(isinstance(elem, ts.DataType) for elem in elems)
return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert
else:
type_ = xtyping.infer_type(value, annotate_callable_kwargs=True)
symbol_type = from_type_hint(type_)

if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)):
if isinstance(symbol_type, (ts.DataType, ts.OffsetType, ts.DimensionType)) or (
isinstance(symbol_type, ts.CallableType) and isinstance(symbol_type, ts.TypeSpec)
):
return symbol_type
else:
raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.")

0 comments on commit 424dfe2

Please sign in to comment.