diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a5d266c4c0..0fc9a6280d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -71,11 +71,11 @@ class Meta: model = PastToItir -def _column_axis(all_closure_vars: dict[str, Any]) -> common.Dimension: +def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]: # construct mapping from column axis to scan operators defined on # that dimension. only one column axis is allowed, but we can use # this mapping to provide good error messages. - scanops_per_axis: dict[common.Dimension, str] = {} + scanops_per_axis: dict[common.Dimension, list[str]] = {} for name, gt_callable in transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).items():