diff --git a/pyproject.toml b/pyproject.toml index ecd1688c0d..acdaa019ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -215,10 +215,6 @@ module = 'gt4py.next.ffront.*' ignore_errors = true module = 'gt4py.next.ffront.decorator' -[[tool.mypy.overrides]] -ignore_errors = true -module = 'gt4py.next.type_system.type_translation' - [[tool.mypy.overrides]] ignore_errors = true module = 'gt4py.next.iterator.runtime' diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 237e50e5ae..a124703b8a 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -29,6 +29,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: # make int & float precision platform independent. + dt: np.dtype if dtype is builtins.int: dt = np.dtype("int64") elif dtype is builtins.float: @@ -66,7 +67,7 @@ def from_type_hint( localns: Optional[dict[str, Any]] = None, ) -> ts.TypeSpec: recursive_make_symbol = functools.partial(from_type_hint, globalns=globalns, localns=localns) - extra_args = () + extra_args: list = [] # ForwardRef if isinstance(type_hint, str): @@ -82,7 +83,10 @@ def from_type_hint( # Annotated if typing.get_origin(type_hint) is typing.Annotated: type_hint, *extra_args = typing.get_args(type_hint) - if not isinstance(type_hint, collections.abc.Callable): + if not isinstance( + type_hint, + collections.abc.Callable, # type:ignore[arg-type] # see https://github.com/python/mypy/issues/14928 + ): type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) canonical_type = ( @@ -101,7 +105,9 @@ def from_type_hint( raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") if Ellipsis in args: raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") - return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args]) + tuple_types = [recursive_make_symbol(arg) for arg in args] + assert all(isinstance(elem, ts.DataType) for elem in tuple_types) + return ts.TupleType(types=tuple_types) # type: ignore[arg-type] # checked in assert case common.Field: if (n_args := len(args)) != 2: @@ -137,7 +143,8 @@ def from_type_hint( try: arg_types, return_type = args - args = [recursive_make_symbol(arg) for arg in arg_types] + new_args = [recursive_make_symbol(arg) for arg in arg_types] + assert all(isinstance(arg, ts.DataType) for arg in new_args) except Exception as error: raise ValueError(f"Invalid callable annotations in '{type_hint}'.") from error @@ -148,13 +155,17 @@ def from_type_hint( arg: recursive_make_symbol(arg_type) for arg, arg_type in kwargs_info[0].data.items() } + assert all(isinstance(val, (ts.DataType, ts.DeferredType)) for val in kwargs.values()) + + returns = recursive_make_symbol(return_type) + assert isinstance(returns, (ts.DataType, ts.DeferredType, ts.VoidType)) # TODO(tehrengruber): print better error when no return type annotation is given return ts.FunctionType( - pos_only_args=args, - pos_or_kw_args=kwargs, + pos_only_args=new_args, # type: ignore[arg-type] # checked in assert + pos_or_kw_args=kwargs, # type: ignore[arg-type] # checked in assert kw_only_args={}, # TODO - returns=recursive_make_symbol(return_type), + returns=returns, ) raise ValueError(f"'{type_hint}' type is not supported.") @@ -188,18 +199,23 @@ def from_value(value: Any) -> ts.TypeSpec: elif common.is_field(value): dims = list(value.domain.dims) dtype = from_type_hint(value.dtype.scalar_type) + assert isinstance(dtype, ts.ScalarType) symbol_type = ts.FieldType(dims=dims, dtype=dtype) elif isinstance(value, tuple): # Since the elements of the tuple might be one of the special cases # above, we can not resort to generic `infer_type` but need to do it # manually here. If we get rid of all the special cases this is # not needed anymore. - return ts.TupleType(types=[from_value(el) for el in value]) + elems = [from_value(el) for el in value] + assert all(isinstance(elem, ts.DataType) for elem in elems) + return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) symbol_type = from_type_hint(type_) - if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): + if isinstance(symbol_type, (ts.DataType, ts.OffsetType, ts.DimensionType)) or ( + isinstance(symbol_type, ts.CallableType) and isinstance(symbol_type, ts.TypeSpec) + ): return symbol_type else: raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.")