From 8add0298d33e6bdb5d80ad69458d1fa7b2eecb3c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 4 Apr 2024 10:55:52 +0200 Subject: [PATCH 01/61] refactor[next]: itir embedded: cleaner closure run --- src/gt4py/next/embedded/context.py | 4 +- src/gt4py/next/iterator/embedded.py | 75 +++++++++++++++++------------ 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index 672dc9c620..ef7ba1c7e3 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -62,4 +62,6 @@ def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None: def within_context() -> bool: - return offset_provider.get() is not _undefined_offset_provider + return ( + offset_provider.get() is not _undefined_offset_provider + ) # TODO: this is broken: if there are no shifts, the offset_provider can be empty even within context diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 5c6976eae8..e9c65e9168 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -53,7 +53,7 @@ runtime_checkable, ) from gt4py.next import common, embedded as next_embedded -from gt4py.next.embedded import exceptions as embedded_exceptions +from gt4py.next.embedded import context as embedded_context, exceptions as embedded_exceptions from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime @@ -1508,28 +1508,29 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: ) -def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): - if "offset_provider" not in kwargs: - raise RuntimeError("'offset_provider' not provided.") - - offset_provider = kwargs["offset_provider"] - - @runtime.closure.register(EMBEDDED) - def closure( - domain_: Domain, - sten: Callable[..., Any], - out, #: MutableLocatedField, - ins: list[common.Field], - ) -> None: - _validate_domain(domain_, kwargs["offset_provider"]) - domain: dict[Tag, range] = _dimension_to_tag(domain_) - if not (isinstance(out, common.Field) or is_tuple_of_field(out)): - raise TypeError("'Out' needs to be a located field.") - - column_range = None - column: Optional[ColumnDescriptor] = None - if kwargs.get("column_axis") and kwargs["column_axis"].value in domain: - column_axis = kwargs["column_axis"] +@runtime.closure.register(EMBEDDED) +def closure( + domain_: Domain, + sten: Callable[..., Any], + out, #: MutableLocatedField, + ins: list[common.Field], +) -> None: + offset_provider = embedded_context.offset_provider.get() + _validate_domain(domain_, offset_provider) + domain: dict[Tag, range] = _dimension_to_tag(domain_) + if not (isinstance(out, common.Field) or is_tuple_of_field(out)): + raise TypeError("'Out' needs to be a located field.") + + new_context: dict[str, Any] = {"offset_provider": offset_provider} + + column: Optional[ColumnDescriptor] = None + column_range: Optional[common.NamedRange] = None + if (col_range_placeholder := embedded_context.closure_column_range.get(None)) is not None: + assert ( + col_range_placeholder.unit_range.is_empty() + ) # check it's just the placeholder with empty range + column_axis = col_range_placeholder.dim + if column_axis is not None and column_axis.value in domain: column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] @@ -1537,13 +1538,13 @@ def closure( column_axis, common.UnitRange(column.col_range.start, column.col_range.stop) ) - out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) + new_context["closure_column_range"] = column_range + + out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) - def _closure_runner(): - # Set context variables before executing the closure - column_range_cvar.set(column_range) - offset_provider_cvar.set(offset_provider) + with embedded_context.new_context(**new_context) as ctx: + def _iterate(): for pos in _domain_iterator(domain): promoted_ins = [promote_scalars(inp) for inp in ins] ins_iters = list( @@ -1562,10 +1563,22 @@ def _closure_runner(): assert _is_concrete_position(col_pos) out.field_setitem(col_pos, res[k]) - ctx = cvars.copy_context() - ctx.run(_closure_runner) + ctx.run(_iterate) + + +def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): + if "offset_provider" not in kwargs: + raise RuntimeError("'offset_provider' not provided.") + + context_vars = {"offset_provider": kwargs["offset_provider"]} + if "column_axis" in kwargs: + context_vars["closure_column_range"] = common.NamedRange( + kwargs["column_axis"], + common.UnitRange(0, 0), # empty: indicates column operation, will update later + ) - fun(*args) + with embedded_context.new_context(**context_vars) as ctx: + ctx.run(fun, *args) runtime.fendef_embedded = fendef_embedded From 853d3e1b775d6f781dc0b97710ff010d90369d73 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 4 Apr 2024 11:00:10 +0200 Subject: [PATCH 02/61] cleanup --- src/gt4py/next/embedded/context.py | 2 +- src/gt4py/next/iterator/embedded.py | 25 +++++++++---------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index ef7ba1c7e3..d663e9aac2 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -64,4 +64,4 @@ def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None: def within_context() -> bool: return ( offset_provider.get() is not _undefined_offset_provider - ) # TODO: this is broken: if there are no shifts, the offset_provider can be empty even within context + ) # TODO(havogt): this is broken: if there are no shifts, the offset_provider can be empty even within context diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index e9c65e9168..edd2bd53e8 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -17,7 +17,6 @@ from __future__ import annotations import abc -import contextvars as cvars import copy import dataclasses import itertools @@ -52,7 +51,7 @@ overload, runtime_checkable, ) -from gt4py.next import common, embedded as next_embedded +from gt4py.next import common from gt4py.next.embedded import context as embedded_context, exceptions as embedded_exceptions from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime @@ -191,12 +190,6 @@ class MutableLocatedField(LocatedField, Protocol): def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ... -#: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range -#: Offset provider dict in the current closure execution context. -offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider - - class Column(np.lib.mixins.NDArrayOperatorsMixin): """Represents a column when executed in column mode (`column_axis != None`). @@ -207,7 +200,7 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin): def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 - column_range: common.NamedRange = column_range_cvar.get() + column_range: common.NamedRange = embedded_context.closure_column_range.get() self.data = ( data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data) ) @@ -751,7 +744,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get().unit_range + column_range = embedded_context.closure_column_range.get().unit_range assert column_range is not None col: list[ @@ -796,7 +789,7 @@ class MDIterator: def shift(self, *offsets: OffsetPart) -> MDIterator: complete_offsets = group_offsets(*offsets) - offset_provider = offset_provider_cvar.get() + offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None return MDIterator( self.field, @@ -821,7 +814,7 @@ def deref(self) -> Any: if not all(axis.value in shifted_pos.keys() for axis in axes if axis is not None): raise IndexError("Iterator position doesn't point to valid location for its field.") slice_column = dict[Tag, range]() - column_range = column_range_cvar.get() + column_range = embedded_context.closure_column_range.get() if self.column_axis is not None: assert column_range is not None k_pos = shifted_pos.pop(self.column_axis) @@ -862,7 +855,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get().unit_range + column_range = embedded_context.closure_column_range.get().unit_range # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1303,7 +1296,7 @@ def __getitem__(self, _): def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_str = offset.value if isinstance(offset, runtime.Offset) else offset assert isinstance(offset_str, str) - offset_provider = offset_provider_cvar.get() + offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[offset_str] assert isinstance(connectivity, common.Connectivity) @@ -1359,7 +1352,7 @@ class SparseListIterator: offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True) def deref(self) -> Any: - offset_provider = offset_provider_cvar.get() + offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[self.list_offset] assert isinstance(connectivity, common.Connectivity) @@ -1480,7 +1473,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get().unit_range + column_range = embedded_context.closure_column_range.get().unit_range if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") From f661cd32cbce692ecadb07fb2742c34fb98bc05a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 4 Apr 2024 17:36:00 +0200 Subject: [PATCH 03/61] fix test --- .../unit_tests/iterator_tests/test_embedded_internals.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index ec6e613529..4ce8394bb8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -13,13 +13,13 @@ # SPDX-License-Identifier: GPL-3.0-or-later import contextvars as cvars -import threading from typing import Any, Callable, Optional import numpy as np import pytest from gt4py.next import common +from gt4py.next.embedded import context as embedded_context from gt4py.next.iterator import embedded @@ -30,8 +30,8 @@ def _run_within_context( offset_provider: Optional[embedded.OffsetProvider] = None, ) -> Any: def wrapped_func(): - embedded.column_range_cvar.set(column_range) - embedded.offset_provider_cvar.set(offset_provider) + embedded_context.closure_column_range.set(column_range) + embedded_context.offset_provider.set(offset_provider) func() cvars.copy_context().run(wrapped_func) @@ -59,7 +59,7 @@ def test_func(data_a: int, data_b: int): assert res.kstart == 1 # Setting an invalid column_range here shouldn't affect other contexts - embedded.column_range_cvar.set(range(2, 999)) + embedded_context.closure_column_range.set(range(2, 999)) _run_within_context( lambda: test_func(2, 3), column_range=common.NamedRange( From 09e568dce7432fabf98cfa0a2e12b014998e38bf Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 8 Apr 2024 21:29:22 +0200 Subject: [PATCH 04/61] without temporaries --- src/gt4py/next/iterator/ir.py | 19 +++ .../iterator/transforms/fencil_to_program.py | 36 +++++ .../next/iterator/transforms/pass_manager.py | 8 +- .../codegens/gtfn/gtfn_module.py | 6 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 138 +++++++++++------- .../formatters/type_check.py | 2 +- .../gtfn_tests/test_itir_to_gtfn_ir.py | 20 +++ 7 files changed, 174 insertions(+), 55 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/fencil_to_program.py diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 56f931f451..87478f0316 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -198,6 +198,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "can_deref", "scan", "if_", + "apply_stencil", *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } @@ -212,6 +213,24 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] +class Stmt(Node): ... + + +class Assign(Stmt): + target: SymRef + expr: Expr # TODO Program expression + + +class Program(Node, ValidatedSymbolTableTrait): + id: Coerced[SymbolName] + function_definitions: List[FunctionDefinition] + params: List[Sym] + declarations: List[Sym] + body: List[Stmt] + + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] + + # TODO(fthaler): just use hashable types in nodes (tuples instead of lists) Sym.__hash__ = Node.__hash__ # type: ignore[method-assign] Expr.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py new file mode 100644 index 0000000000..31f8f7f826 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -0,0 +1,36 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im + + +class FencilToProgram(eve.NodeTranslator): + @classmethod + def apply(cls, node: itir.FencilDefinition) -> itir.Program: + return cls().visit(node) + + def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.Assign: + apply_stencil = im.call(im.call("apply_stencil")(node.stencil, node.domain))(*node.inputs) + return itir.Assign(target=node.output, expr=apply_stencil) + + def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: + return itir.Program( + id=node.id, + function_definitions=node.function_definitions, + params=node.params, + declarations=[], + body=self.visit(node.closures), + ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 5852ba9ae5..98edc19928 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -23,6 +23,7 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.eta_reduction import EtaReduction +from gt4py.next.iterator.transforms.fencil_to_program import FencilToProgram from gt4py.next.iterator.transforms.fuse_maps import FuseMaps from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars @@ -88,7 +89,7 @@ def apply_common_transforms( Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, -): +) -> itir.Program: icdlv_uids = eve_utils.UIDGenerator() if lift_mode is None: @@ -203,4 +204,7 @@ def apply_common_transforms( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) - return ir + assert isinstance(ir, itir.FencilDefinition) + prog = FencilToProgram.apply(ir) + + return prog diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ca293aa235..043c508c4d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -28,7 +28,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, pass_manager +from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -183,7 +183,7 @@ def _preprocess_program( program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], runtime_lift_mode: Optional[LiftMode], - ) -> itir.FencilDefinition: + ) -> itir.Program: # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added # to the interface of all (or at least all of concern) backends, but instead should be # configured in the backend itself (like it is here), until then we respect the argument @@ -197,7 +197,7 @@ def _preprocess_program( ) if not self.enable_itir_transforms: - return program + return fencil_to_program.FencilToProgram.apply(program) apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 4617e54eae..3081e3d947 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -16,8 +16,8 @@ from typing import Any, ClassVar, Iterable, Optional, Type, TypeGuard, Union import gt4py.eve as eve +from gt4py.eve import utils as eve_utils from gt4py.eve.concepts import SymbolName -from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import global_tmps @@ -66,8 +66,18 @@ def pytype_to_cpptype(t: str) -> Optional[str]: _horizontal_dimension = "gtfn::unstructured::dim::horizontal" -def _get_domains(closures: Iterable[itir.StencilClosure]) -> Iterable[itir.FunCall]: - return (c.domain for c in closures) +def _get_domains(node: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: + apply_stencils = ( + eve_utils.xiter(node) + .if_isinstance(itir.Assign) + .getattr("expr") + .if_isinstance(itir.FunCall) + .filter( + lambda x: isinstance(x.fun, itir.FunCall) + and x.fun.fun == itir.SymRef(id="apply_stencil") + ) + ) + return (a.fun.args[1] for a in apply_stencils) def _extract_grid_type(domain: itir.FunCall) -> common.GridType: @@ -78,8 +88,8 @@ def _extract_grid_type(domain: itir.FunCall) -> common.GridType: return common.GridType.UNSTRUCTURED -def _get_gridtype(closures: list[itir.StencilClosure]) -> common.GridType: - domains = _get_domains(closures) +def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: + domains = _get_domains(body) grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( @@ -97,9 +107,9 @@ def _name_from_named_range(named_range_call: itir.FunCall) -> str: def _collect_dimensions_from_domain( - closures: Iterable[itir.StencilClosure], + body: Iterable[itir.Stmt], ) -> dict[str, TagDefinition]: - domains = _get_domains(closures) + domains = _get_domains(body) offset_definitions = {} for domain in domains: if domain.fun == itir.SymRef(id="cartesian_domain"): @@ -198,6 +208,14 @@ def _bool_from_literal(node: itir.Node) -> bool: return node.value == "True" +def _is_applied_apply_stencil(arg: itir.Expr) -> TypeGuard[itir.FunCall]: + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and arg.fun.fun == itir.SymRef(id="apply_stencil") + ) + + @dataclasses.dataclass(frozen=True) class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): _binary_op_map: ClassVar[dict[str, str]] = { @@ -224,26 +242,29 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) - uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) + uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=eve_utils.UIDGenerator + ) @classmethod def apply( cls, - node: itir.FencilDefinition | global_tmps.FencilWithTemporaries, + node: itir.Program | global_tmps.FencilWithTemporaries, *, offset_provider: dict, column_axis: Optional[common.Dimension], ) -> FencilDefinition: if isinstance(node, global_tmps.FencilWithTemporaries): - fencil_definition = node.fencil - elif isinstance(node, itir.FencilDefinition): - fencil_definition = node + raise AssertionError() # TODO + prog = node.fencil + elif isinstance(node, itir.Program): + prog = node else: raise TypeError( f"Expected a 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'." ) - grid_type = _get_gridtype(fencil_definition.closures) + grid_type = _get_gridtype(prog.body) return cls( offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type ).visit(node) @@ -437,40 +458,9 @@ def _visit_output_argument(self, node: itir.Expr) -> SidComposite | SymRef: def visit_StencilClosure( self, node: itir.StencilClosure, extracted_functions: list, **kwargs: Any ) -> Union[ScanExecution, StencilExecution]: - backend = Backend(domain=self.visit(node.domain, stencil=node.stencil, **kwargs)) - if _is_scan(node.stencil): - scan_id = self.uids.sequential_id(prefix="_scan") - scan_lambda = self.visit(node.stencil.args[0], **kwargs) - forward = _bool_from_literal(node.stencil.args[1]) - scan_def = ScanPassDefinition( - id=scan_id, params=scan_lambda.params, expr=scan_lambda.expr, forward=forward - ) - extracted_functions.append(scan_def) - scan = Scan( - function=SymRef(id=scan_id), - output=0, - inputs=[i + 1 for i, _ in enumerate(node.inputs)], - init=self.visit(node.stencil.args[2], **kwargs), - ) - column_axis = self.column_axis - assert isinstance(column_axis, common.Dimension) - return ScanExecution( - backend=backend, - scans=[scan], - args=[self._visit_output_argument(node.output), *self.visit(node.inputs)], - axis=SymRef(id=column_axis.value), - ) - return StencilExecution( - stencil=self.visit( - node.stencil, - force_function_extraction=True, - extracted_functions=extracted_functions, - **kwargs, - ), - output=self._visit_output_argument(node.output), - inputs=self.visit(node.inputs, **kwargs), - backend=backend, - ) + raise AssertionError( + "Internal error: StencilClosures are no longer supported." + ) # TODO remove after refactoring is complete @staticmethod def _merge_scans( @@ -517,15 +507,65 @@ def remap_args(s: Scan) -> Scan: res.append(execution) return res + def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: + raise AssertionError("Internal error: all Stmts need to be handled explicitly.") + + def visit_Assign( + self, node: itir.Assign, *, extracted_functions: list, **kwargs: Any + ) -> Union[StencilExecution, ScanExecution]: + assert _is_applied_apply_stencil(node.expr) + stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert + domain = node.expr.fun.args[1] # type: ignore[attr-defined] # checked in assert + inputs = node.expr.args + backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) + if _is_scan(stencil): + scan_id = self.uids.sequential_id(prefix="_scan") + scan_lambda = self.visit(stencil.args[0], **kwargs) + forward = _bool_from_literal(stencil.args[1]) + scan_def = ScanPassDefinition( + id=scan_id, params=scan_lambda.params, expr=scan_lambda.expr, forward=forward + ) + extracted_functions.append(scan_def) + scan = Scan( + function=SymRef(id=scan_id), + output=0, + inputs=[i + 1 for i, _ in enumerate(inputs)], + init=self.visit(stencil.args[2], **kwargs), + ) + column_axis = self.column_axis + assert isinstance(column_axis, common.Dimension) + return ScanExecution( + backend=backend, + scans=[scan], + args=[self._visit_output_argument(node.target), *self.visit(inputs)], + axis=SymRef(id=column_axis.value), + ) + return StencilExecution( + stencil=self.visit( + stencil, + force_function_extraction=True, + extracted_functions=extracted_functions, + **kwargs, + ), + output=self._visit_output_argument(node.target), + inputs=self.visit(inputs, **kwargs), + backend=backend, + ) + def visit_FencilDefinition( self, node: itir.FencilDefinition, **kwargs: Any ) -> FencilDefinition: + raise AssertionError( + "Internal error: Fencils are no longer supported." + ) # TODO remove after refactoring is complete + + def visit_Program(self, node: itir.Program, **kwargs: Any) -> FencilDefinition: extracted_functions: list[Union[FunctionDefinition, ScanPassDefinition]] = [] - executions = self.visit(node.closures, extracted_functions=extracted_functions) + executions = self.visit(node.body, extracted_functions=extracted_functions) executions = self._merge_scans(executions) function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { - **_collect_dimensions_from_domain(node.closures), + **_collect_dimensions_from_domain(node.body), **_collect_offset_definitions(node, self.grid_type, self.offset_provider), } return FencilDefinition( diff --git a/src/gt4py/next/program_processors/formatters/type_check.py b/src/gt4py/next/program_processors/formatters/type_check.py index 03aeef1264..749fc923c6 100644 --- a/src/gt4py/next/program_processors/formatters/type_check.py +++ b/src/gt4py/next/program_processors/formatters/type_check.py @@ -26,7 +26,7 @@ def check_type_inference(program: itir.FencilDefinition, *args: Any, **kwargs: A program, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] ) if isinstance(transformed, global_tmps.FencilWithTemporaries): - transformed = transformed.fencil + raise AssertionError("TODO") # transformed = transformed.fencil return type_inference.pformat( type_inference.infer(transformed, offset_provider=kwargs["offset_provider"]) ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 43a0b45ce6..79d3a78656 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -14,6 +14,7 @@ import gt4py.next as gtx from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.program_processors.codegens.gtfn import gtfn_ir, itir_to_gtfn_ir as it2gtfn @@ -41,3 +42,22 @@ def test_unapplied_funcall_to_function_object(): ).visit(testee) assert expected == actual + + +def test_get_domains(): + domain = im.call("cartesian_domain")(im.call("named_range")(itir.AxisLiteral(value="D"), 1, 2)) + testee = itir.Program( + id="foo", + function_definitions=[], + params=[itir.Sym(id="bar")], + declarations=[], + body=[ + itir.Assign( + target=itir.SymRef(id="bar"), + expr=im.call(im.call("apply_stencil")("deref", domain))(), + ) + ], + ) + + result = list(it2gtfn._get_domains(testee.body)) + assert result == [domain] From 12b8696f7b8ffdb5016c7b8d1315c78b15df5210 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 8 Apr 2024 22:16:06 +0200 Subject: [PATCH 05/61] temporaries --- src/gt4py/next/iterator/ir.py | 14 +++++++++---- .../iterator/transforms/fencil_to_program.py | 9 +++++++++ .../next/iterator/transforms/global_tmps.py | 20 ++++++------------- .../next/iterator/transforms/pass_manager.py | 3 +-- .../codegens/gtfn/itir_to_gtfn_ir.py | 17 +++++----------- .../runners/dace_iterator/itir_to_sdfg.py | 8 ++------ .../program_processors/runners/roundtrip.py | 2 +- 7 files changed, 34 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 87478f0316..c9fd2db5b4 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import typing -from typing import ClassVar, List, Optional, Union +from typing import Any, ClassVar, List, Optional, Union import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels @@ -217,15 +217,21 @@ class Stmt(Node): ... class Assign(Stmt): - target: SymRef - expr: Expr # TODO Program expression + target: Expr # `make_tuple` or SymRef + expr: Expr # only `apply_stencil` + + +class Temporary(Node): + id: Coerced[eve.SymbolName] + domain: Optional[Expr] = None + dtype: Optional[Any] = None # TODO class Program(Node, ValidatedSymbolTableTrait): id: Coerced[SymbolName] function_definitions: List[FunctionDefinition] params: List[Sym] - declarations: List[Sym] + declarations: List[Temporary] body: List[Stmt] _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index 31f8f7f826..33273f6961 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -34,3 +34,12 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: declarations=[], body=self.visit(node.closures), ) + + def visit_FencilWithTemporaries(self, node) -> itir.Program: + return itir.Program( + id=node.fencil.id, + function_definitions=node.fencil.function_definitions, + params=node.params, + declarations=node.tmps, + body=self.visit(node.fencil.closures), + ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d099272b2b..cc59d239e9 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -17,9 +17,8 @@ from collections.abc import Mapping from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence -import gt4py.eve as eve import gt4py.next as gtx -from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common @@ -54,26 +53,18 @@ # Iterator IR extension nodes -class Temporary(ir.Node): - """Iterator IR extension: declaration of a temporary buffer.""" - - id: Coerced[eve.SymbolName] - domain: Optional[ir.Expr] = None - dtype: Optional[Any] = None - - class FencilWithTemporaries(ir.Node, SymbolTableTrait): """Iterator IR extension: declaration of a fencil with temporary buffers.""" fencil: ir.FencilDefinition params: list[ir.Sym] - tmps: list[Temporary] + tmps: list[ir.Temporary] # Extensions for `PrettyPrinter` for easier debugging -def pformat_Temporary(printer: PrettyPrinter, node: Temporary, *, prec: int) -> list[str]: +def pformat_Temporary(printer: PrettyPrinter, node: ir.Temporary, *, prec: int) -> list[str]: start, end = [node.id + " = temporary("], [");"] args = [] if node.domain is not None: @@ -367,7 +358,7 @@ def always_extract_heuristics(_): location=node.location, ), params=node.params, - tmps=[Temporary(id=tmp.id) for tmp in tmps], + tmps=[ir.Temporary(id=tmp.id) for tmp in tmps], ) @@ -638,7 +629,8 @@ def convert_type(dtype): fencil=node.fencil, params=node.params, tmps=[ - Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) for tmp in node.tmps + ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) + for tmp in node.tmps ], ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 98edc19928..01f3d4f1d3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -204,7 +204,6 @@ def apply_common_transforms( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) - assert isinstance(ir, itir.FencilDefinition) - prog = FencilToProgram.apply(ir) + prog = FencilToProgram.apply(ir) # type: ignore[arg-type] # TODO: remove after refactoring return prog diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3081e3d947..8b9711849a 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -575,11 +575,11 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> FencilDefinition: grid_type=self.grid_type, offset_definitions=list(offset_definitions.values()), function_definitions=function_definitions, - temporaries=[], + temporaries=self.visit(node.declarations, params=[p.id for p in node.params]), ) def visit_Temporary( - self, node: global_tmps.Temporary, *, params: list, **kwargs: Any + self, node: itir.Temporary, *, params: list, **kwargs: Any ) -> TemporaryAllocation: def dtype_to_cpp(x: int | tuple | str) -> str: if isinstance(x, int): @@ -601,13 +601,6 @@ def dtype_to_cpp(x: int | tuple | str) -> str: def visit_FencilWithTemporaries( self, node: global_tmps.FencilWithTemporaries, **kwargs: Any ) -> FencilDefinition: - fencil = self.visit(node.fencil, **kwargs) - return FencilDefinition( - id=fencil.id, - params=self.visit(node.params), - executions=fencil.executions, - grid_type=fencil.grid_type, - offset_definitions=fencil.offset_definitions, - function_definitions=fencil.function_definitions, - temporaries=self.visit(node.tmps, params=[p.id for p in node.params]), - ) + raise AssertionError( + "Internal error: Fencils are no longer supported." + ) # TODO remove after refactoring is complete diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index f0cfad5f1f..a8844fa5f1 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -20,11 +20,7 @@ import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing from gt4py.next.common import NeighborTable -from gt4py.next.iterator import ( - ir as itir, - transforms as itir_transforms, - type_inference as itir_typing, -) +from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef from gt4py.next.type_system import type_specifications as ts, type_translation @@ -164,7 +160,7 @@ def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTable], - tmps: list[itir_transforms.global_tmps.Temporary], + tmps: list[itir.Temporary], use_field_canonical_representation: bool, column_axis: Optional[Dimension] = None, ): diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 38714221fc..aba2dbf71b 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -86,7 +86,7 @@ def visit_FencilWithTemporaries( + f"\n {node.fencil.id}({args}, **kwargs)\n" ) - def visit_Temporary(self, node: gtmps_transform.Temporary, **kwargs: Any) -> str: + def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: assert ( isinstance(node.domain, itir.FunCall) and isinstance(node.domain.fun, itir.SymRef) From 540a2d8858630b8f8c9b151c7d0cd77148306472 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 9 Apr 2024 10:17:04 +0200 Subject: [PATCH 06/61] cleanup --- .../codegens/gtfn/itir_to_gtfn_ir.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 8b9711849a..dfd5956f15 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -455,13 +455,6 @@ def _visit_output_argument(self, node: itir.Expr) -> SidComposite | SymRef: return SidComposite(values=[self._visit_output_argument(v) for v in node.args]) raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.") - def visit_StencilClosure( - self, node: itir.StencilClosure, extracted_functions: list, **kwargs: Any - ) -> Union[ScanExecution, StencilExecution]: - raise AssertionError( - "Internal error: StencilClosures are no longer supported." - ) # TODO remove after refactoring is complete - @staticmethod def _merge_scans( executions: list[Union[StencilExecution, ScanExecution]], @@ -552,13 +545,6 @@ def visit_Assign( backend=backend, ) - def visit_FencilDefinition( - self, node: itir.FencilDefinition, **kwargs: Any - ) -> FencilDefinition: - raise AssertionError( - "Internal error: Fencils are no longer supported." - ) # TODO remove after refactoring is complete - def visit_Program(self, node: itir.Program, **kwargs: Any) -> FencilDefinition: extracted_functions: list[Union[FunctionDefinition, ScanPassDefinition]] = [] executions = self.visit(node.body, extracted_functions=extracted_functions) @@ -597,10 +583,3 @@ def dtype_to_cpp(x: int | tuple | str) -> str: return TemporaryAllocation( id=node.id, dtype=dtype_to_cpp(node.dtype), domain=self.visit(node.domain, **kwargs) ) - - def visit_FencilWithTemporaries( - self, node: global_tmps.FencilWithTemporaries, **kwargs: Any - ) -> FencilDefinition: - raise AssertionError( - "Internal error: Fencils are no longer supported." - ) # TODO remove after refactoring is complete From 23ddef1779f468f70e0553a60f33be06bc3d08c0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 10 Apr 2024 21:45:01 +0200 Subject: [PATCH 07/61] move to SetAt --- src/gt4py/next/iterator/ir.py | 9 +++++++-- .../iterator/transforms/fencil_to_program.py | 4 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 18 ++++-------------- .../gtfn_tests/test_itir_to_gtfn_ir.py | 5 +++-- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index c9fd2db5b4..5c9c993130 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -186,6 +186,10 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "cast_", } +GTIR_BUILTINS = { + "apply_stencil", # `apply_stencil(stencil)` creates field_operator from stencil +} + BUILTINS = { *GRAMMAR_BUILTINS, "named_range", @@ -216,9 +220,10 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): class Stmt(Node): ... -class Assign(Stmt): +class SetAt(Stmt): # from JAX array.at[...].set() + expr: Expr # only `apply_stencil(stencil)(inp0, ...)` in first refactoring + domain: Expr target: Expr # `make_tuple` or SymRef - expr: Expr # only `apply_stencil` class Temporary(Node): diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index 33273f6961..a266023143 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -22,9 +22,9 @@ class FencilToProgram(eve.NodeTranslator): def apply(cls, node: itir.FencilDefinition) -> itir.Program: return cls().visit(node) - def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.Assign: + def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: apply_stencil = im.call(im.call("apply_stencil")(node.stencil, node.domain))(*node.inputs) - return itir.Assign(target=node.output, expr=apply_stencil) + return itir.SetAt(expr=apply_stencil, domain=node.domain, target=node.output) def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: return itir.Program( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index dfd5956f15..8f2d8db66d 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -67,17 +67,7 @@ def pytype_to_cpptype(t: str) -> Optional[str]: def _get_domains(node: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: - apply_stencils = ( - eve_utils.xiter(node) - .if_isinstance(itir.Assign) - .getattr("expr") - .if_isinstance(itir.FunCall) - .filter( - lambda x: isinstance(x.fun, itir.FunCall) - and x.fun.fun == itir.SymRef(id="apply_stencil") - ) - ) - return (a.fun.args[1] for a in apply_stencils) + return eve_utils.xiter(node).if_isinstance(itir.SetAt).getattr("domain").to_set() def _extract_grid_type(domain: itir.FunCall) -> common.GridType: @@ -503,12 +493,12 @@ def remap_args(s: Scan) -> Scan: def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: raise AssertionError("Internal error: all Stmts need to be handled explicitly.") - def visit_Assign( - self, node: itir.Assign, *, extracted_functions: list, **kwargs: Any + def visit_SetAt( + self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: assert _is_applied_apply_stencil(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert - domain = node.expr.fun.args[1] # type: ignore[attr-defined] # checked in assert + domain = node.domain inputs = node.expr.args backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) if _is_scan(stencil): diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 79d3a78656..507718523f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -52,9 +52,10 @@ def test_get_domains(): params=[itir.Sym(id="bar")], declarations=[], body=[ - itir.Assign( + itir.SetAt( + expr=im.call(im.call("apply_stencil")("deref"))(), + domain=domain, target=itir.SymRef(id="bar"), - expr=im.call(im.call("apply_stencil")("deref", domain))(), ) ], ) From c99f44d88d192b68b46c14879dad3ec6b8cf96ce Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 10 Apr 2024 22:43:46 +0200 Subject: [PATCH 08/61] embedded --- src/gt4py/next/iterator/builtins.py | 5 ++ src/gt4py/next/iterator/embedded.py | 42 ++++++++++ src/gt4py/next/iterator/runtime.py | 7 +- .../iterator_tests/test_program.py | 78 +++++++++++++++++++ 4 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 7bbcd67337..854730881c 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -23,6 +23,11 @@ def __init__(self) -> None: super().__init__("Backend not selected") +@builtin_dispatch +def apply_stencil(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def deref(*args): raise BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index edd2bd53e8..a1c5f6de48 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1501,6 +1501,48 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: ) +@runtime.set_at.register(EMBEDDED) +def set_at(expr, domain, target) -> None: + # domain is assumed to match the domain of outer apply stencil / fields + if callable(expr): # lazy apply stencil #TODO check + expr(target, lazy_domain=domain) + + +@builtins.apply_stencil.register(EMBEDDED) +def apply_stencil(fun, *, domain=None): + def impl(*args, **kwargs): + new_domain = None + # if common.is_domain_like(domain): + new_domain = domain + # else: + # assert callable(domain) + # new_domain = domain(*args, **kwargs) + + # TODO this only works if the apply_stencil is directly in set_at + # for the clean solution we need to pre-allocate the result buffer (see strategy for scan in field_view embedded) + def lazy_apply_stencil(out, *, lazy_domain=None): + if new_domain is None: + assert lazy_domain is not None + domain = lazy_domain + else: + domain = new_domain + + closure( + # cartesian_domain( + # *(named_range(d, start, stop) for d, (start, stop) in domain.items()) + # ), + domain, + fun, + out, + list(args), # TODO kwargs + ) + return out + + return lazy_apply_stencil + + return impl + + @runtime.closure.register(EMBEDDED) def closure( domain_: Domain, diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index f1710159a4..5790cceefc 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -206,5 +206,10 @@ def fundef(fun): @builtin_dispatch -def closure(*args): +def closure(*args): # TODO remove + return BackendNotSelectedError() + + +@builtin_dispatch +def set_at(*args): return BackendNotSelectedError() diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py new file mode 100644 index 0000000000..00af4ecee0 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -0,0 +1,78 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next.iterator.builtins import apply_stencil, cartesian_domain, deref, named_range +from gt4py.next.iterator.runtime import fendef, fundef, set_at + +from next_tests.unit_tests.conftest import program_processor, run_processor + + +I = gtx.Dimension("I") + +_isize = 10 + + +@pytest.fixture +def dom(): + return {I: range(_isize)} + + +def a_field(): + return gtx.as_field([I], np.arange(0, _isize, dtype=np.float64)) + + +def out_field(): + return gtx.as_field([I], np.zeros(shape=(_isize,))) + + +@fundef +def copy_stencil(inp): + return deref(inp) + + +@fendef +def copy_program(inp, out, size): + set_at( + # apply_stencil(copy_stencil, domain=cartesian_domain(named_range(I, 0, size)))(inp), + apply_stencil(copy_stencil)(inp), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +def test_prog(): + validate = True + + inp = a_field() + out = out_field() + + copy_program(inp, out, _isize, offset_provider={}) + if validate: + assert np.allclose(inp.asnumpy(), out.asnumpy()) + + +# example for +# @field_operator +# def sum(a, b, c): +# a + b + c + +# def plus(a,b): +# return deref(a)+deref(b) + +# def sum_prog(a, b, c, out): +# set_at(apply_stencil(plus)(a, apply_stencil(plus)(b, c)), out.domain, out) From 1a6f8852706079319fd483d78344564955c10d10 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 09:43:14 +0200 Subject: [PATCH 09/61] roundtrip+double_roundtrip with shortcuts --- src/gt4py/next/iterator/builtins.py | 1 + src/gt4py/next/iterator/ir.py | 2 ++ src/gt4py/next/iterator/pretty_printer.py | 33 +++++++++++++++++ src/gt4py/next/iterator/runtime.py | 2 +- src/gt4py/next/iterator/tracing.py | 36 ++++++++++++++----- .../iterator/transforms/fencil_to_program.py | 2 +- .../next/iterator/transforms/pass_manager.py | 2 ++ .../program_processors/runners/roundtrip.py | 13 ++++++- 8 files changed, 80 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 854730881c..4c08b39234 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -435,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "cartesian_domain", "unstructured_domain", "named_range", + "apply_stencil", *MATH_BUILTINS, } diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 5c9c993130..c1879cd6f4 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -255,3 +255,5 @@ class Program(Node, ValidatedSymbolTableTrait): FunctionDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] StencilClosure.__hash__ = Node.__hash__ # type: ignore[method-assign] FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] +Program.__hash__ = Node.__hash__ # type: ignore[method-assign] +SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 041c12e7b0..66629051c9 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -272,6 +272,21 @@ def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[st ) return self._optimum(h, v) + def visit_SetAt(self, node: ir.SetAt, *, prec: int) -> list[str]: + expr = self.visit(node.expr, prec=0) + domain = self.visit(node.domain, prec=0) + target = self.visit(node.target, prec=0) + + head = self._hmerge(target, [" @ "], domain) + foot = self._hmerge([" ← "], expr, [";"]) + + h = self._hmerge(head, foot) + v = self._vmerge( + head, + self._indent(self._indent(foot)), + ) + return self._optimum(h, v) + def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) @@ -291,6 +306,24 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> lis params, self._indent(function_definitions), self._indent(closures), ["}"] ) + def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: + assert prec == 0 + function_definitions = self.visit(node.function_definitions, prec=0) + body = self.visit(node.body, prec=0) + # TODO declarations + params = self.visit(node.params, prec=0) + + hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) + vparams = self._vmerge( + [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] + ) + params = self._optimum(hparams, vparams) + + function_definitions = self._vmerge(*function_definitions) + body = self._vmerge(*body) + + return self._vmerge(params, self._indent(function_definitions), self._indent(body), ["}"]) + @classmethod def apply(cls, node: ir.Node, indent: int, width: int) -> str: return "\n".join(cls(indent=indent, width=width).visit(node, prec=0)) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 5790cceefc..133d5b0df3 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -31,7 +31,7 @@ ) -__all__ = ["offset", "fundef", "fendef", "closure"] +__all__ = ["offset", "fundef", "fendef", "closure", "set_at"] @dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d6dbb47ee9..d4fc993102 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -20,7 +20,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import Node from gt4py.next import common, iterator -from gt4py.next.iterator import builtins +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir import ( AxisLiteral, Expr, @@ -210,6 +210,7 @@ def __bool__(self): class TracerContext: fundefs: ClassVar[List[FunctionDefinition]] = [] closures: ClassVar[List[StencilClosure]] = [] + body: ClassVar[List[itir.Stmt]] = [] @classmethod def add_fundef(cls, fun): @@ -220,6 +221,10 @@ def add_fundef(cls, fun): def add_closure(cls, closure): cls.closures.append(closure) + @classmethod + def add_stmt(cls, stmt): + cls.body.append(stmt) + def __enter__(self): iterator.builtins.builtin_dispatch.push_key(TRACING) @@ -241,6 +246,11 @@ def closure(domain, stencil, output, inputs): ) +@iterator.runtime.set_at.register(TRACING) +def set_at(expr, domain, target): + TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) + + def _contains_tuple_dtype_field(arg): if isinstance(arg, tuple): return any(_contains_tuple_dtype_field(el) for el in arg) @@ -296,7 +306,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: def trace_fencil_definition( fun: typing.Callable, args: typing.Iterable, *, use_arg_types=True -) -> FencilDefinition: +) -> FencilDefinition | itir.Program: """ Transform fencil given as a callable into `itir.FencilDefinition` using tracing. @@ -313,9 +323,19 @@ def trace_fencil_definition( params = _make_fencil_params(fun, args, use_arg_types=use_arg_types) trace_function_call(fun, args=(_s(param.id) for param in params)) - return FencilDefinition( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - closures=TracerContext.closures, - ) + if TracerContext.closures: + return FencilDefinition( + id=fun.__name__, + function_definitions=TracerContext.fundefs, + params=params, + closures=TracerContext.closures, + ) + else: + assert TracerContext.body + return itir.Program( + id=fun.__name__, + function_definitions=TracerContext.fundefs, + params=params, + declarations=[], # TODO + body=TracerContext.body, + ) diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index a266023143..c34947fc60 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -23,7 +23,7 @@ def apply(cls, node: itir.FencilDefinition) -> itir.Program: return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - apply_stencil = im.call(im.call("apply_stencil")(node.stencil, node.domain))(*node.inputs) + apply_stencil = im.call(im.call("apply_stencil")(node.stencil))(*node.inputs) return itir.SetAt(expr=apply_stencil, domain=node.domain, target=node.output) def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 01f3d4f1d3..06be5dbe97 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -90,6 +90,8 @@ def apply_common_transforms( ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: + if isinstance(ir, itir.Program): + return ir # TODO upgrade transformations to work on Program icdlv_uids = eve_utils.UIDGenerator() if lift_mode is None: diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index aba2dbf71b..035d9e3362 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -67,6 +67,15 @@ def ${id}(${','.join(params)}): return ${expr} """ ) + Program = as_mako( + """ +${''.join(function_definitions)} +@fendef +def ${id}(${','.join(params)}): + ${'\\n '.join(body)} + """ + ) + SetAt = as_mako("set_at(${expr}, ${domain}, ${target})") # extension required by global_tmps def visit_FencilWithTemporaries( @@ -190,7 +199,9 @@ def fencil_generator( if not debug: pathlib.Path(source_file_name).unlink(missing_ok=True) - assert isinstance(ir, (itir.FencilDefinition, gtmps_transform.FencilWithTemporaries)) + assert isinstance( + ir, (itir.FencilDefinition, gtmps_transform.FencilWithTemporaries, itir.Program) + ) fencil_name = ( ir.fencil.id + "_wrapper" if isinstance(ir, gtmps_transform.FencilWithTemporaries) From 39d6d7cb4c764c7b8c1fc138f526e30ebb56aaf9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 14:32:07 +0200 Subject: [PATCH 10/61] changes --- src/gt4py/next/iterator/ir.py | 2 +- .../feature_tests/iterator_tests/test_program.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index c1879cd6f4..fa8e64461b 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -202,7 +202,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "can_deref", "scan", "if_", - "apply_stencil", + *GTIR_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 00af4ecee0..9a97c8c101 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -55,6 +55,18 @@ def copy_program(inp, out, size): ) +# @fundef +# def plus_stencil(inp0,inp1): +# return plus(deref(inp0),deref(inp1)) + +# set_at( +# # apply_stencil(copy_stencil, domain=cartesian_domain(named_range(I, 0, size)))(inp), +# apply_stencil(plus_stencil)(inp0, apply_stencil(plus_stencil)(inp1,inp2)), +# cartesian_domain(named_range(I, 0, size)), +# out, +# ) + + def test_prog(): validate = True From ab44009972f835fa6f3e570b19d87d2ce1c3d8a6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 14:55:05 +0200 Subject: [PATCH 11/61] fencil2program only for gtfn --- .../next/iterator/transforms/fencil_to_program.py | 5 +++-- src/gt4py/next/iterator/transforms/pass_manager.py | 10 ++++------ .../program_processors/codegens/gtfn/gtfn_module.py | 9 +++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index a266023143..c573b206eb 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -15,11 +15,12 @@ from gt4py import eve from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import global_tmps class FencilToProgram(eve.NodeTranslator): @classmethod - def apply(cls, node: itir.FencilDefinition) -> itir.Program: + def apply(cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries) -> itir.Program: return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: @@ -35,7 +36,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: body=self.visit(node.closures), ) - def visit_FencilWithTemporaries(self, node) -> itir.Program: + def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program: return itir.Program( id=node.fencil.id, function_definitions=node.fencil.function_definitions, diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 01f3d4f1d3..8a19203275 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -23,9 +23,8 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.fencil_to_program import FencilToProgram from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps +from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps, FencilWithTemporaries from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars from gt4py.next.iterator.transforms.inline_fundefs import InlineFundefs, PruneUnreferencedFundefs from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan @@ -89,7 +88,7 @@ def apply_common_transforms( Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> itir.Program: +) -> itir.FencilDefinition | FencilWithTemporaries: icdlv_uids = eve_utils.UIDGenerator() if lift_mode is None: @@ -204,6 +203,5 @@ def apply_common_transforms( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) - prog = FencilToProgram.apply(ir) # type: ignore[arg-type] # TODO: remove after refactoring - - return prog + assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)) + return ir diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 043c508c4d..c7895aea20 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -28,7 +28,7 @@ from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager +from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, global_tmps, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -183,7 +183,7 @@ def _preprocess_program( program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], runtime_lift_mode: Optional[LiftMode], - ) -> itir.Program: + ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries: # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added # to the interface of all (or at least all of concern) backends, but instead should be # configured in the backend itself (like it is here), until then we respect the argument @@ -197,7 +197,7 @@ def _preprocess_program( ) if not self.enable_itir_transforms: - return fencil_to_program.FencilToProgram.apply(program) + return program apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, @@ -231,8 +231,9 @@ def generate_stencil_source( runtime_lift_mode: Optional[LiftMode] = None, ) -> str: new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + program_itir = fencil_to_program.FencilToProgram().apply(new_program) gtfn_ir = GTFN_lowering.apply( - new_program, offset_provider=offset_provider, column_axis=column_axis + program_itir, offset_provider=offset_provider, column_axis=column_axis ) if self.use_imperative_backend: From 12f1663978fd967e3b84f09c59560613253b5614 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 14:56:47 +0200 Subject: [PATCH 12/61] fix import --- .../transforms_tests/test_global_tmps.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 0521b0414b..442f71bc9b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -21,7 +21,6 @@ AUTO_DOMAIN, FencilWithTemporaries, SimpleTemporaryExtractionHeuristics, - Temporary, collect_tmps_info, split_closures, update_domains, @@ -87,7 +86,7 @@ def test_split_closures(): ], ) actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [Temporary(id="_tmp_1"), Temporary(id="_tmp_2")] + assert actual.tmps == [ir.Temporary(id="_tmp_1"), ir.Temporary(id="_tmp_2")] assert actual.fencil == expected @@ -141,7 +140,7 @@ def test_split_closures_simple_heuristics(): actual = split_closures( testee, extraction_heuristics=SimpleTemporaryExtractionHeuristics, offset_provider={} ) - assert actual.tmps == [Temporary(id="_tmp_1")] + assert actual.tmps == [ir.Temporary(id="_tmp_1")] assert actual.fencil == expected @@ -211,7 +210,7 @@ def test_split_closures_lifted_scan(): ) actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [Temporary(id="_tmp_1")] + assert actual.tmps == [ir.Temporary(id="_tmp_1")] assert actual.fencil == expected @@ -255,7 +254,7 @@ def test_update_cartesian_domains(): ], ), params=[im.sym("i"), im.sym("j"), im.sym("k"), im.sym("inp"), im.sym("out")], - tmps=[Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1")], + tmps=[ir.Temporary(id="_gtmp_0"), ir.Temporary(id="_gtmp_1")], ) expected = copy.deepcopy(testee) assert expected.fencil.params.pop() == im.sym("_gtmp_auto_domain") @@ -413,14 +412,14 @@ def test_collect_tmps_info(): ], ), params=[ir.Sym(id="i"), ir.Sym(id="j"), ir.Sym(id="k"), ir.Sym(id="inp"), ir.Sym(id="out")], - tmps=[Temporary(id="_gtmp_0"), Temporary(id="_gtmp_1")], + tmps=[ir.Temporary(id="_gtmp_0"), ir.Temporary(id="_gtmp_1")], ) expected = FencilWithTemporaries( fencil=testee.fencil, params=testee.params, tmps=[ - Temporary(id="_gtmp_0", domain=tmp_domain, dtype="float64"), - Temporary(id="_gtmp_1", domain=tmp_domain, dtype="float64"), + ir.Temporary(id="_gtmp_0", domain=tmp_domain, dtype="float64"), + ir.Temporary(id="_gtmp_1", domain=tmp_domain, dtype="float64"), ], ) actual = collect_tmps_info(testee, offset_provider={}) From 50374938dec27c18222c3243d67795e198ba9e5d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 15:00:37 +0200 Subject: [PATCH 13/61] fix builtins list --- src/gt4py/next/iterator/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 5c9c993130..2e2c30d6e8 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -202,7 +202,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "can_deref", "scan", "if_", - "apply_stencil", + *GTIR_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } From 751581ef27cb60904e2c97f2cbefe1652bae2eda Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 15:02:09 +0200 Subject: [PATCH 14/61] add comment --- .../next/program_processors/codegens/gtfn/gtfn_module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index c7895aea20..d28e513093 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -231,7 +231,9 @@ def generate_stencil_source( runtime_lift_mode: Optional[LiftMode] = None, ) -> str: new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) - program_itir = fencil_to_program.FencilToProgram().apply(new_program) + program_itir = fencil_to_program.FencilToProgram().apply( + new_program + ) # TODO(havogt): should be removed after refactoring to combined IR gtfn_ir = GTFN_lowering.apply( program_itir, offset_provider=offset_provider, column_axis=column_axis ) From 3d2f33e829de119b235ac29d4eb38f2d5c63695f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Apr 2024 16:36:20 +0200 Subject: [PATCH 15/61] fix type checker --- src/gt4py/next/program_processors/formatters/type_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/formatters/type_check.py b/src/gt4py/next/program_processors/formatters/type_check.py index 749fc923c6..03aeef1264 100644 --- a/src/gt4py/next/program_processors/formatters/type_check.py +++ b/src/gt4py/next/program_processors/formatters/type_check.py @@ -26,7 +26,7 @@ def check_type_inference(program: itir.FencilDefinition, *args: Any, **kwargs: A program, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] ) if isinstance(transformed, global_tmps.FencilWithTemporaries): - raise AssertionError("TODO") # transformed = transformed.fencil + transformed = transformed.fencil return type_inference.pformat( type_inference.infer(transformed, offset_provider=kwargs["offset_provider"]) ) From 4cbce7e6d77ecb09b27fd91a1c06271a2e9cec73 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 12 Apr 2024 08:42:49 +0200 Subject: [PATCH 16/61] Apply suggestions from code review --- src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index cc59d239e9..f23dafe63b 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -53,7 +53,7 @@ # Iterator IR extension nodes -class FencilWithTemporaries(ir.Node, SymbolTableTrait): +class FencilWithTemporaries(ir.Node, SymbolTableTrait): # TODO(havogt): will be removed after refactoring """Iterator IR extension: declaration of a fencil with temporary buffers.""" fencil: ir.FencilDefinition From c95564516c404c5dd8f647766532823e0b3204b5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 12 Apr 2024 07:58:28 +0000 Subject: [PATCH 17/61] format --- src/gt4py/next/iterator/transforms/global_tmps.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index f23dafe63b..81dbace942 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -53,7 +53,9 @@ # Iterator IR extension nodes -class FencilWithTemporaries(ir.Node, SymbolTableTrait): # TODO(havogt): will be removed after refactoring +class FencilWithTemporaries( + ir.Node, SymbolTableTrait +): # TODO(havogt): will be removed after refactoring """Iterator IR extension: declaration of a fencil with temporary buffers.""" fencil: ir.FencilDefinition From 6effe10be05af055936cbe6c92b1e26dd56ee381 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 12 Apr 2024 12:51:56 +0200 Subject: [PATCH 18/61] pretty printing/parsing --- src/gt4py/next/iterator/pretty_parser.py | 32 +++++++++++++++ src/gt4py/next/iterator/pretty_printer.py | 36 ++++++++++++++++- .../iterator_tests/test_pretty_parser.py | 39 +++++++++++++++++++ .../iterator_tests/test_pretty_printer.py | 39 +++++++++++++++++++ 4 files changed, 145 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 9dd96b076e..2f00e98df2 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -24,6 +24,8 @@ start: fencil_definition | function_definition | stencil_closure + | set_at + | program | prec0 SYM: CNAME @@ -64,6 +66,7 @@ | "·" prec7 -> deref | "¬" prec7 -> bool_not | "↑" prec7 -> lift + | "⇑" prec7 -> apply_stencil ?prec8: prec9 | prec8 "[" prec0 "]" -> tuple_get @@ -81,7 +84,9 @@ named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" + set_at: prec0 "@" prec0 "←" prec1 ";" fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}" + program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( set_at )+ "}" %import common (CNAME, SIGNED_FLOAT, SIGNED_INT, WS) %ignore WS @@ -167,6 +172,9 @@ def deref(self, arg: ir.Expr) -> ir.FunCall: def lift(self, arg: ir.Expr) -> ir.FunCall: return ir.FunCall(fun=ir.SymRef(id="lift"), args=[arg]) + def apply_stencil(self, arg: ir.Expr) -> ir.FunCall: + return ir.FunCall(fun=ir.SymRef(id="apply_stencil"), args=[arg]) + def astype(self, arg: ir.Expr) -> ir.FunCall: return ir.FunCall(fun=ir.SymRef(id="cast_"), args=[arg]) @@ -202,6 +210,10 @@ def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: output, stencil, *inputs, domain = args return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) + def set_at(self, *args: ir.Expr) -> ir.SetAt: + target, domain, expr = args + return ir.SetAt(expr=expr, domain=domain, target=target) + def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: params = [] function_definitions = [] @@ -218,6 +230,26 @@ def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: id=fid, function_definitions=function_definitions, params=params, closures=closures ) + def program(self, fid: str, *args: ir.Node) -> ir.Program: + params = [] + function_definitions = [] + body = [] + for arg in args: + if isinstance(arg, ir.Sym): + params.append(arg) + elif isinstance(arg, ir.FunctionDefinition): + function_definitions.append(arg) + else: + assert isinstance(arg, ir.SetAt) + body.append(arg) + return ir.Program( + id=fid, + function_definitions=function_definitions, + params=params, + body=body, + declarations=[], # TODO + ) + def start(self, arg: ir.Node) -> ir.Node: return arg diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 041c12e7b0..941e9c3073 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -43,7 +43,7 @@ } # replacements for builtin unary operations -UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬"} +UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬", "apply_stencil": "⇑"} # operator precedence PRECEDENCE: Final = { @@ -63,6 +63,7 @@ "deref": 7, "not_": 7, "lift": 7, + "apply_stencil": 7, "tuple_get": 8, "__call__": 8, } @@ -272,6 +273,21 @@ def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[st ) return self._optimum(h, v) + def visit_SetAt(self, node: ir.SetAt, *, prec: int) -> list[str]: + expr = self.visit(node.expr, prec=0) + domain = self.visit(node.domain, prec=0) + target = self.visit(node.target, prec=0) + + head = self._hmerge(target, [" @ "], domain) + foot = self._hmerge([" ← "], expr, [";"]) + + h = self._hmerge(head, foot) + v = self._vmerge( + head, + self._indent(self._indent(foot)), + ) + return self._optimum(h, v) + def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) @@ -291,6 +307,24 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> lis params, self._indent(function_definitions), self._indent(closures), ["}"] ) + def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: + assert prec == 0 + function_definitions = self.visit(node.function_definitions, prec=0) + body = self.visit(node.body, prec=0) + # TODO declarations + params = self.visit(node.params, prec=0) + + hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) + vparams = self._vmerge( + [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] + ) + params = self._optimum(hparams, vparams) + + function_definitions = self._vmerge(*function_definitions) + body = self._vmerge(*body) + + return self._vmerge(params, self._indent(function_definitions), self._indent(body), ["}"]) + @classmethod def apply(cls, node: ir.Node, indent: int, width: int) -> str: return "\n".join(cls(indent=indent, width=width).visit(node, prec=0)) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index b3a9ba8001..3f610035ab 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -69,6 +69,13 @@ def test_lift(): assert actual == expected +def test_apply_stencil(): + testee = "⇑x" + expected = ir.FunCall(fun=ir.SymRef(id="apply_stencil"), args=[ir.SymRef(id="x")]) + actual = pparse(testee) + assert actual == expected + + def test_bool_arithmetic(): testee = "¬(¬a ∨ b ∧ (c ∨ d))" expected = ir.FunCall( @@ -195,6 +202,17 @@ def test_stencil_closure(): assert actual == expected +def test_set_at(): + testee = "y @ cartesian_domain() ← x;" + expected = ir.SetAt( + expr=ir.SymRef(id="x"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + target=ir.SymRef(id="y"), + ) + actual = pparse(testee) + assert actual == expected + + def test_fencil_definition(): testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" expected = ir.FencilDefinition( @@ -214,3 +232,24 @@ def test_fencil_definition(): ) actual = pparse(testee) assert actual == expected + + +def test_program(): + testee = "f(d, x, y) {\n g = λ(x) → x;\n y @ cartesian_domain() ← x;\n}" + expected = ir.Program( + id="f", + function_definitions=[ + ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) + ], + params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], + body=[ + ir.SetAt( + expr=ir.SymRef(id="x"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + target=ir.SymRef(id="y"), + ) + ], + declarations=[], # TODO + ) + actual = pparse(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 844c905e8e..aa0ac28050 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -159,6 +159,13 @@ def test_lift(): assert actual == expected +def test_apply_stencil(): + testee = ir.FunCall(fun=ir.SymRef(id="apply_stencil"), args=[ir.SymRef(id="x")]) + expected = "⇑x" # TODO consider ⇈ + actual = pformat(testee) + assert actual == expected + + def test_bool_arithmetic(): testee = ir.FunCall( fun=ir.SymRef(id="not_"), @@ -299,6 +306,17 @@ def test_stencil_closure(): assert actual == expected +def test_set_at(): + testee = ir.SetAt( + expr=ir.SymRef(id="x"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + target=ir.SymRef(id="y"), + ) + expected = "y @ cartesian_domain() ← x;" + actual = pformat(testee) + assert actual == expected + + def test_fencil_definition(): testee = ir.FencilDefinition( id="f", @@ -318,3 +336,24 @@ def test_fencil_definition(): actual = pformat(testee) expected = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" assert actual == expected + + +def test_program(): + testee = ir.Program( + id="f", + function_definitions=[ + ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) + ], + params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], + declarations=[], # TODO + body=[ + ir.SetAt( + expr=ir.SymRef(id="x"), + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + target=ir.SymRef(id="y"), + ) + ], + ) + actual = pformat(testee) + expected = "f(d, x, y) {\n g = λ(x) → x;\n y @ cartesian_domain() ← x;\n}" + assert actual == expected From 66de3ec43870205e30dcba13661c1b36000b956a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 08:56:39 +0200 Subject: [PATCH 19/61] Apply suggestions from code review Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/pretty_parser.py | 1 + src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- .../next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../next_tests/unit_tests/iterator_tests/test_pretty_parser.py | 1 + .../next_tests/unit_tests/iterator_tests/test_pretty_printer.py | 1 + 5 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 2f00e98df2..195180af42 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -214,6 +214,7 @@ def set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) + # TODO(havogt): remove after refactoring. def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: params = [] function_definitions = [] diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 81dbace942..5a90cd0360 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -55,7 +55,7 @@ class FencilWithTemporaries( ir.Node, SymbolTableTrait -): # TODO(havogt): will be removed after refactoring +): # TODO(havogt): remove and use new `itir.Program` instead. """Iterator IR extension: declaration of a fencil with temporary buffers.""" fencil: ir.FencilDefinition diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 8f2d8db66d..6f63d83db8 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -491,7 +491,7 @@ def remap_args(s: Scan) -> Scan: return res def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: - raise AssertionError("Internal error: all Stmts need to be handled explicitly.") + raise AssertionError("All Stmts need to be handled explicitly.") def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index 3f610035ab..6173c66bdf 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -213,6 +213,7 @@ def test_set_at(): assert actual == expected +# TODO: remove after refactoring def test_fencil_definition(): testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" expected = ir.FencilDefinition( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index aa0ac28050..d354b779ea 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -317,6 +317,7 @@ def test_set_at(): assert actual == expected +# TODO(havogt): remove after refactoring. def test_fencil_definition(): testee = ir.FencilDefinition( id="f", From e63da777bfd765e97f8f5322de7a656c81c5d6e1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 09:15:07 +0200 Subject: [PATCH 20/61] address more review comments --- src/gt4py/next/iterator/ir.py | 21 ++++++++++++------ src/gt4py/next/iterator/pretty_parser.py | 6 ++--- src/gt4py/next/iterator/pretty_printer.py | 4 ++-- .../iterator/transforms/fencil_to_program.py | 6 +++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 22 ++++++------------- .../iterator_tests/test_pretty_parser.py | 4 ++-- .../iterator_tests/test_pretty_printer.py | 4 ++-- .../gtfn_tests/test_itir_to_gtfn_ir.py | 2 +- 8 files changed, 35 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 2e2c30d6e8..f8f03559f5 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -22,6 +22,11 @@ from gt4py.eve.utils import noninstantiable +# TODO(havogt): +# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. +# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. + + @noninstantiable class Node(eve.Node): location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) @@ -186,10 +191,6 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "cast_", } -GTIR_BUILTINS = { - "apply_stencil", # `apply_stencil(stencil)` creates field_operator from stencil -} - BUILTINS = { *GRAMMAR_BUILTINS, "named_range", @@ -202,11 +203,17 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "can_deref", "scan", "if_", - *GTIR_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } +# only used in `Program`` not `FencilDefinition` +# TODO(havogt): restructure after refactoring to GTIR +GTIR_BUILTINS = { + *BUILTINS, + "as_field_operator", # `as_field_operator(stencil)` creates field_operator from stencil +} + class FencilDefinition(Node, ValidatedSymbolTableTrait): id: Coerced[SymbolName] @@ -221,7 +228,7 @@ class Stmt(Node): ... class SetAt(Stmt): # from JAX array.at[...].set() - expr: Expr # only `apply_stencil(stencil)(inp0, ...)` in first refactoring + expr: Expr # only `as_field_operator(stencil)(inp0, ...)` in first refactoring domain: Expr target: Expr # `make_tuple` or SymRef @@ -239,7 +246,7 @@ class Program(Node, ValidatedSymbolTableTrait): declarations: List[Temporary] body: List[Stmt] - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS] # TODO(fthaler): just use hashable types in nodes (tuples instead of lists) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 195180af42..1afc7b841d 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -66,7 +66,7 @@ | "·" prec7 -> deref | "¬" prec7 -> bool_not | "↑" prec7 -> lift - | "⇑" prec7 -> apply_stencil + | "⇑" prec7 -> as_field_operator ?prec8: prec9 | prec8 "[" prec0 "]" -> tuple_get @@ -172,8 +172,8 @@ def deref(self, arg: ir.Expr) -> ir.FunCall: def lift(self, arg: ir.Expr) -> ir.FunCall: return ir.FunCall(fun=ir.SymRef(id="lift"), args=[arg]) - def apply_stencil(self, arg: ir.Expr) -> ir.FunCall: - return ir.FunCall(fun=ir.SymRef(id="apply_stencil"), args=[arg]) + def as_field_operator(self, arg: ir.Expr) -> ir.FunCall: + return ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[arg]) def astype(self, arg: ir.Expr) -> ir.FunCall: return ir.FunCall(fun=ir.SymRef(id="cast_"), args=[arg]) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 941e9c3073..dd64f38d5d 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -43,7 +43,7 @@ } # replacements for builtin unary operations -UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬", "apply_stencil": "⇑"} +UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬", "as_field_operator": "⇑"} # operator precedence PRECEDENCE: Final = { @@ -63,7 +63,7 @@ "deref": 7, "not_": 7, "lift": 7, - "apply_stencil": 7, + "as_field_operator": 7, "tuple_get": 8, "__call__": 8, } diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index c573b206eb..18aeb9ab0e 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -24,8 +24,10 @@ def apply(cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries) return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - apply_stencil = im.call(im.call("apply_stencil")(node.stencil, node.domain))(*node.inputs) - return itir.SetAt(expr=apply_stencil, domain=node.domain, target=node.output) + as_field_operator = im.call(im.call("as_field_operator")(node.stencil, node.domain))( + *node.inputs + ) + return itir.SetAt(expr=as_field_operator, domain=node.domain, target=node.output) def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: return itir.Program( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 6f63d83db8..0dfb683ac1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -20,7 +20,6 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, BinaryExpr, @@ -198,11 +197,11 @@ def _bool_from_literal(node: itir.Node) -> bool: return node.value == "True" -def _is_applied_apply_stencil(arg: itir.Expr) -> TypeGuard[itir.FunCall]: +def _is_applied_as_field_operator(arg: itir.Expr) -> TypeGuard[itir.FunCall]: return ( isinstance(arg, itir.FunCall) and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="apply_stencil") + and arg.fun.fun == itir.SymRef(id="as_field_operator") ) @@ -239,22 +238,15 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): @classmethod def apply( cls, - node: itir.Program | global_tmps.FencilWithTemporaries, + node: itir.Program, *, offset_provider: dict, column_axis: Optional[common.Dimension], ) -> FencilDefinition: - if isinstance(node, global_tmps.FencilWithTemporaries): - raise AssertionError() # TODO - prog = node.fencil - elif isinstance(node, itir.Program): - prog = node - else: - raise TypeError( - f"Expected a 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'." - ) + if not isinstance(node, itir.Program): + raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - grid_type = _get_gridtype(prog.body) + grid_type = _get_gridtype(node.body) return cls( offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type ).visit(node) @@ -496,7 +488,7 @@ def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - assert _is_applied_apply_stencil(node.expr) + assert _is_applied_as_field_operator(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain inputs = node.expr.args diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index 6173c66bdf..c80f993d3a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -69,9 +69,9 @@ def test_lift(): assert actual == expected -def test_apply_stencil(): +def test_as_field_operator(): testee = "⇑x" - expected = ir.FunCall(fun=ir.SymRef(id="apply_stencil"), args=[ir.SymRef(id="x")]) + expected = ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[ir.SymRef(id="x")]) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index d354b779ea..7ba464bc92 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -159,8 +159,8 @@ def test_lift(): assert actual == expected -def test_apply_stencil(): - testee = ir.FunCall(fun=ir.SymRef(id="apply_stencil"), args=[ir.SymRef(id="x")]) +def test_as_field_operator(): + testee = ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[ir.SymRef(id="x")]) expected = "⇑x" # TODO consider ⇈ actual = pformat(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 507718523f..d31ce389d4 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -53,7 +53,7 @@ def test_get_domains(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("apply_stencil")("deref"))(), + expr=im.call(im.call("as_field_operator")("deref"))(), domain=domain, target=itir.SymRef(id="bar"), ) From 45fba853e1b233b6144911a48ee8dea5abda8c2f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 09:36:13 +0200 Subject: [PATCH 21/61] move tmp to pretty_printer --- src/gt4py/next/iterator/pretty_printer.py | 27 ++++++++++++++++--- .../next/iterator/transforms/global_tmps.py | 16 ----------- .../iterator_tests/test_pretty_printer.py | 17 ++++++++++-- 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index dd64f38d5d..219329335d 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""A pretty printer for the functional IR. +"""A pretty self for the functional IR. Inspired by P. Yelland, “A New Approach to Optimal Code Formatting”, 2015 """ @@ -273,6 +273,20 @@ def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[st ) return self._optimum(h, v) + def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: + start, end = [node.id + " = temporary("], [");"] + args = [] + if node.domain is not None: + args.append(self._hmerge(["domain="], self.visit(node.domain, prec=0))) + if node.dtype is not None: + args.append(self._hmerge(["dtype="], [str(node.dtype)])) + hargs = self._hmerge(*self._hinterleave(args, ", ")) + vargs = self._vmerge(*self._hinterleave(args, ",")) + oargs = self._optimum(hargs, vargs) + h = self._hmerge(start, oargs, end) + v = self._vmerge(start, self._indent(oargs), end) + return self._optimum(h, v) + def visit_SetAt(self, node: ir.SetAt, *, prec: int) -> list[str]: expr = self.visit(node.expr, prec=0) domain = self.visit(node.domain, prec=0) @@ -311,7 +325,7 @@ def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) body = self.visit(node.body, prec=0) - # TODO declarations + declarations = self.visit(node.declarations, prec=0) params = self.visit(node.params, prec=0) hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) @@ -321,9 +335,16 @@ def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: params = self._optimum(hparams, vparams) function_definitions = self._vmerge(*function_definitions) + declarations = self._vmerge(*declarations) body = self._vmerge(*body) - return self._vmerge(params, self._indent(function_definitions), self._indent(body), ["}"]) + return self._vmerge( + params, + self._indent(function_definitions), + self._indent(declarations), + self._indent(body), + ["}"], + ) @classmethod def apply(cls, node: ir.Node, indent: int, width: int) -> str: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5a90cd0360..42d68318a0 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -66,21 +66,6 @@ class FencilWithTemporaries( # Extensions for `PrettyPrinter` for easier debugging -def pformat_Temporary(printer: PrettyPrinter, node: ir.Temporary, *, prec: int) -> list[str]: - start, end = [node.id + " = temporary("], [");"] - args = [] - if node.domain is not None: - args.append(printer._hmerge(["domain="], printer.visit(node.domain, prec=0))) - if node.dtype is not None: - args.append(printer._hmerge(["dtype="], [str(node.dtype)])) - hargs = printer._hmerge(*printer._hinterleave(args, ", ")) - vargs = printer._vmerge(*printer._hinterleave(args, ",")) - oargs = printer._optimum(hargs, vargs) - h = printer._hmerge(start, oargs, end) - v = printer._vmerge(start, printer._indent(oargs), end) - return printer._optimum(h, v) - - def pformat_FencilWithTemporaries( printer: PrettyPrinter, node: FencilWithTemporaries, *, prec: int ) -> list[str]: @@ -110,7 +95,6 @@ def pformat_FencilWithTemporaries( return printer._vmerge(params, printer._indent(body), ["}"]) -PrettyPrinter.visit_Temporary = pformat_Temporary # type: ignore PrettyPrinter.visit_FencilWithTemporaries = pformat_FencilWithTemporaries # type: ignore diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 7ba464bc92..099c6a3403 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -294,6 +294,13 @@ def test_function_definition(): assert actual == expected +def test_temporary(): + testee = ir.Temporary(id="t", domain=ir.SymRef(id="domain"), dtype="float64") + expected = "t = temporary(domain=domain, dtype=float64);" + actual = pformat(testee) + assert actual == expected, actual + + def test_stencil_closure(): testee = ir.StencilClosure( domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), @@ -346,7 +353,13 @@ def test_program(): ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) ], params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - declarations=[], # TODO + declarations=[ + ir.Temporary( + id="tmp", + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + dtype="float64", + ), + ], body=[ ir.SetAt( expr=ir.SymRef(id="x"), @@ -356,5 +369,5 @@ def test_program(): ], ) actual = pformat(testee) - expected = "f(d, x, y) {\n g = λ(x) → x;\n y @ cartesian_domain() ← x;\n}" + expected = "f(d, x, y) {\n g = λ(x) → x;\n tmp = temporary(domain=cartesian_domain(), dtype=float64);\n y @ cartesian_domain() ← x;\n}" assert actual == expected From 1a702184271dec84cd560af023c01a18ddf4634b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 09:50:17 +0200 Subject: [PATCH 22/61] pparse for temporaries --- src/gt4py/next/iterator/pretty_parser.py | 13 +++++++++++-- .../iterator_tests/test_pretty_parser.py | 19 ++++++++++++++++--- .../iterator_tests/test_pretty_printer.py | 2 +- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 1afc7b841d..3d938e01c8 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -23,6 +23,7 @@ GRAMMAR = """ start: fencil_definition | function_definition + | declaration | stencil_closure | set_at | program @@ -83,10 +84,11 @@ named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" + declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" prec0 ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" set_at: prec0 "@" prec0 "←" prec1 ";" fencil_definition: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( stencil_closure )+ "}" - program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( set_at )+ "}" + program: ID_NAME "(" ( SYM "," )* SYM ")" "{" ( function_definition )* ( declaration )* ( set_at )+ "}" %import common (CNAME, SIGNED_FLOAT, SIGNED_INT, WS) %ignore WS @@ -210,6 +212,10 @@ def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: output, stencil, *inputs, domain = args return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) + def declaration(self, *args: ir.Expr) -> ir.Temporary: + tid, domain, dtype = args + return ir.Temporary(id=tid, domain=domain, dtype=dtype) + def set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) @@ -235,11 +241,14 @@ def program(self, fid: str, *args: ir.Node) -> ir.Program: params = [] function_definitions = [] body = [] + declarations = [] for arg in args: if isinstance(arg, ir.Sym): params.append(arg) elif isinstance(arg, ir.FunctionDefinition): function_definitions.append(arg) + elif isinstance(arg, ir.Temporary): + declarations.append(arg) else: assert isinstance(arg, ir.SetAt) body.append(arg) @@ -248,7 +257,7 @@ def program(self, fid: str, *args: ir.Node) -> ir.Program: function_definitions=function_definitions, params=params, body=body, - declarations=[], # TODO + declarations=declarations, ) def start(self, arg: ir.Node) -> ir.Node: diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index c80f993d3a..969f30bdce 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -190,6 +190,13 @@ def test_function_definition(): assert actual == expected +def test_temporary(): + testee = "t = temporary(domain=domain, dtype=float64);" + expected = ir.Temporary(id="t", domain=ir.SymRef(id="domain"), dtype=ir.SymRef(id="float64")) + actual = pparse(testee) + assert actual == expected + + def test_stencil_closure(): testee = "y ← (deref)(x) @ cartesian_domain();" expected = ir.StencilClosure( @@ -213,7 +220,7 @@ def test_set_at(): assert actual == expected -# TODO: remove after refactoring +# TODO(havogt): remove after refactoring to GTIR def test_fencil_definition(): testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" expected = ir.FencilDefinition( @@ -236,13 +243,20 @@ def test_fencil_definition(): def test_program(): - testee = "f(d, x, y) {\n g = λ(x) → x;\n y @ cartesian_domain() ← x;\n}" + testee = "f(d, x, y) {\n g = λ(x) → x;\n tmp = temporary(domain=cartesian_domain(), dtype=float64);\n y @ cartesian_domain() ← x;\n}" expected = ir.Program( id="f", function_definitions=[ ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) ], params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], + declarations=[ + ir.Temporary( + id="tmp", + domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + dtype=ir.SymRef(id="float64"), + ), + ], body=[ ir.SetAt( expr=ir.SymRef(id="x"), @@ -250,7 +264,6 @@ def test_program(): target=ir.SymRef(id="y"), ) ], - declarations=[], # TODO ) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 099c6a3403..b4a894c8cc 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -298,7 +298,7 @@ def test_temporary(): testee = ir.Temporary(id="t", domain=ir.SymRef(id="domain"), dtype="float64") expected = "t = temporary(domain=domain, dtype=float64);" actual = pformat(testee) - assert actual == expected, actual + assert actual == expected def test_stencil_closure(): From c39c6039c9674db596b8d0e974804081c1b5f5d8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 09:55:52 +0200 Subject: [PATCH 23/61] rename gtfn.FencilDefinition -> Program --- .../next/program_processors/codegens/gtfn/codegen.py | 6 ++---- .../next/program_processors/codegens/gtfn/gtfn_ir.py | 2 +- .../program_processors/codegens/gtfn/itir_to_gtfn_ir.py | 8 ++++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 987fe6ca27..cc07fac55c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -234,9 +234,7 @@ def visit_TemporaryAllocation(self, node: gtfn_ir.TemporaryAllocation, **kwargs: "auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes});" ) - def visit_FencilDefinition( - self, node: gtfn_ir.FencilDefinition, **kwargs: Any - ) -> Union[str, Collection[str]]: + def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Collection[str]]: self.is_cartesian = node.grid_type == common.GridType.CARTESIAN self.user_defined_function_ids = list( str(fundef.id) for fundef in node.function_definitions @@ -248,7 +246,7 @@ def visit_FencilDefinition( **kwargs, ) - FencilDefinition = as_mako( + Program = as_mako( """ #include #include diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index b266f577c3..65ea6c9137 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -155,7 +155,7 @@ class TagDefinition(Node): alias: Optional[Union[str, SymRef]] = None -class FencilDefinition(Node, ValidatedSymbolTableTrait): +class Program(Node, ValidatedSymbolTableTrait): id: SymbolName params: list[Sym] function_definitions: list[ diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 0dfb683ac1..de9f63525b 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -25,13 +25,13 @@ BinaryExpr, CartesianDomain, CastExpr, - FencilDefinition, FunCall, FunctionDefinition, IntegralConstant, Lambda, Literal, OffsetLiteral, + Program, Scan, ScanExecution, ScanPassDefinition, @@ -242,7 +242,7 @@ def apply( *, offset_provider: dict, column_axis: Optional[common.Dimension], - ) -> FencilDefinition: + ) -> Program: if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") @@ -527,7 +527,7 @@ def visit_SetAt( backend=backend, ) - def visit_Program(self, node: itir.Program, **kwargs: Any) -> FencilDefinition: + def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: extracted_functions: list[Union[FunctionDefinition, ScanPassDefinition]] = [] executions = self.visit(node.body, extracted_functions=extracted_functions) executions = self._merge_scans(executions) @@ -536,7 +536,7 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> FencilDefinition: **_collect_dimensions_from_domain(node.body), **_collect_offset_definitions(node, self.grid_type, self.offset_provider), } - return FencilDefinition( + return Program( id=SymbolName(node.id), params=self.visit(node.params), executions=executions, From 705cfcff14a3017c32e644e78f7822e24d28ff4d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 10:33:11 +0200 Subject: [PATCH 24/61] remove TODO --- .../next_tests/unit_tests/iterator_tests/test_pretty_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index b4a894c8cc..2f02812f0f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -161,7 +161,7 @@ def test_lift(): def test_as_field_operator(): testee = ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[ir.SymRef(id="x")]) - expected = "⇑x" # TODO consider ⇈ + expected = "⇑x" actual = pformat(testee) assert actual == expected From c5e78c454b33e5910554ae68696bb30259aa270f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 10:35:05 +0200 Subject: [PATCH 25/61] Apply suggestions from code review --- src/gt4py/next/iterator/pretty_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 219329335d..d08b184469 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""A pretty self for the functional IR. +"""A pretty printer for the functional IR. Inspired by P. Yelland, “A New Approach to Optimal Code Formatting”, 2015 """ From a336bf55b2f8fec9c3f0d57532c31cfcd5e35c57 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 11:54:45 +0200 Subject: [PATCH 26/61] rename as_field_operator -> as_fieldop --- src/gt4py/next/iterator/ir.py | 4 ++-- src/gt4py/next/iterator/pretty_parser.py | 6 +++--- src/gt4py/next/iterator/pretty_printer.py | 4 ++-- src/gt4py/next/iterator/transforms/fencil_to_program.py | 6 ++---- .../program_processors/codegens/gtfn/itir_to_gtfn_ir.py | 6 +++--- .../unit_tests/iterator_tests/test_pretty_parser.py | 4 ++-- .../unit_tests/iterator_tests/test_pretty_printer.py | 4 ++-- 7 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index f8f03559f5..5d1907b26f 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -211,7 +211,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib # TODO(havogt): restructure after refactoring to GTIR GTIR_BUILTINS = { *BUILTINS, - "as_field_operator", # `as_field_operator(stencil)` creates field_operator from stencil + "as_fieldop", # `as_fieldop(stencil)` creates field_operator from stencil } @@ -228,7 +228,7 @@ class Stmt(Node): ... class SetAt(Stmt): # from JAX array.at[...].set() - expr: Expr # only `as_field_operator(stencil)(inp0, ...)` in first refactoring + expr: Expr # only `as_fieldop(stencil)(inp0, ...)` in first refactoring domain: Expr target: Expr # `make_tuple` or SymRef diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 3d938e01c8..f6d532ee30 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -67,7 +67,7 @@ | "·" prec7 -> deref | "¬" prec7 -> bool_not | "↑" prec7 -> lift - | "⇑" prec7 -> as_field_operator + | "⇑" prec7 -> as_fieldop ?prec8: prec9 | prec8 "[" prec0 "]" -> tuple_get @@ -174,8 +174,8 @@ def deref(self, arg: ir.Expr) -> ir.FunCall: def lift(self, arg: ir.Expr) -> ir.FunCall: return ir.FunCall(fun=ir.SymRef(id="lift"), args=[arg]) - def as_field_operator(self, arg: ir.Expr) -> ir.FunCall: - return ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[arg]) + def as_fieldop(self, arg: ir.Expr) -> ir.FunCall: + return ir.FunCall(fun=ir.SymRef(id="as_fieldop"), args=[arg]) def astype(self, arg: ir.Expr) -> ir.FunCall: return ir.FunCall(fun=ir.SymRef(id="cast_"), args=[arg]) diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index d08b184469..3f224d2ef4 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -43,7 +43,7 @@ } # replacements for builtin unary operations -UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬", "as_field_operator": "⇑"} +UNARY_OPS: Final = {"deref": "·", "lift": "↑", "not_": "¬", "as_fieldop": "⇑"} # operator precedence PRECEDENCE: Final = { @@ -63,7 +63,7 @@ "deref": 7, "not_": 7, "lift": 7, - "as_field_operator": 7, + "as_fieldop": 7, "tuple_get": 8, "__call__": 8, } diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index 18aeb9ab0e..fff12a6c88 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -24,10 +24,8 @@ def apply(cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries) return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - as_field_operator = im.call(im.call("as_field_operator")(node.stencil, node.domain))( - *node.inputs - ) - return itir.SetAt(expr=as_field_operator, domain=node.domain, target=node.output) + as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs) + return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output) def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: return itir.Program( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index de9f63525b..bbde720677 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -197,11 +197,11 @@ def _bool_from_literal(node: itir.Node) -> bool: return node.value == "True" -def _is_applied_as_field_operator(arg: itir.Expr) -> TypeGuard[itir.FunCall]: +def _is_applied_as_fieldop(arg: itir.Expr) -> TypeGuard[itir.FunCall]: return ( isinstance(arg, itir.FunCall) and isinstance(arg.fun, itir.FunCall) - and arg.fun.fun == itir.SymRef(id="as_field_operator") + and arg.fun.fun == itir.SymRef(id="as_fieldop") ) @@ -488,7 +488,7 @@ def visit_Stmt(self, node: itir.Stmt, **kwargs: Any) -> None: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: - assert _is_applied_as_field_operator(node.expr) + assert _is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain inputs = node.expr.args diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index 969f30bdce..b02c610aff 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -69,9 +69,9 @@ def test_lift(): assert actual == expected -def test_as_field_operator(): +def test_as_fieldop(): testee = "⇑x" - expected = ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[ir.SymRef(id="x")]) + expected = ir.FunCall(fun=ir.SymRef(id="as_fieldop"), args=[ir.SymRef(id="x")]) actual = pparse(testee) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 2f02812f0f..bc1372cea6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -159,8 +159,8 @@ def test_lift(): assert actual == expected -def test_as_field_operator(): - testee = ir.FunCall(fun=ir.SymRef(id="as_field_operator"), args=[ir.SymRef(id="x")]) +def test_as_fieldop(): + testee = ir.FunCall(fun=ir.SymRef(id="as_fieldop"), args=[ir.SymRef(id="x")]) expected = "⇑x" actual = pformat(testee) assert actual == expected From af16f4083694e6083599968b7f5412d66d320513 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Apr 2024 13:02:36 +0200 Subject: [PATCH 27/61] missed a file --- .../codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index d31ce389d4..bfb511751d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -53,7 +53,7 @@ def test_get_domains(): declarations=[], body=[ itir.SetAt( - expr=im.call(im.call("as_field_operator")("deref"))(), + expr=im.call(im.call("as_fieldop")("deref"))(), domain=domain, target=itir.SymRef(id="bar"), ) From bc2c2d37e0f656b93f57ef6ae2a06763951a57b2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 16 Apr 2024 13:38:12 +0200 Subject: [PATCH 28/61] add fencil2program to roundtrip --- src/gt4py/next/iterator/embedded.py | 2 +- src/gt4py/next/iterator/transforms/pass_manager.py | 3 +++ src/gt4py/next/program_processors/runners/roundtrip.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3fd476a2be..beee47ca26 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1504,7 +1504,7 @@ def set_at(expr, domain, target) -> None: @builtins.as_fieldop.register(EMBEDDED) -def as_fieldop(fun, *, domain=None): +def as_fieldop(fun, domain=None): def impl(*args, **kwargs): new_domain = None # if common.is_domain_like(domain): diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8a19203275..c3232b867e 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -89,6 +89,9 @@ def apply_common_transforms( ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.FencilDefinition | FencilWithTemporaries: + if isinstance(ir, itir.Program): + # TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program + return ir icdlv_uids = eve_utils.UIDGenerator() if lift_mode is None: diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 035d9e3362..b1e419b13e 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -28,7 +28,7 @@ from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms -from gt4py.next.iterator.transforms import global_tmps as gtmps_transform +from gt4py.next.iterator.transforms import fencil_to_program, global_tmps as gtmps_transform from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi @@ -146,6 +146,8 @@ def fencil_generator( ir, lift_mode=lift_mode, offset_provider=offset_provider ) + ir = fencil_to_program.FencilToProgram.apply(ir) + program = EmbeddedDSL.apply(ir) # format output in debug mode for better debuggability From f45b460735d5997fc33867c64032e7bf25eb29d9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Apr 2024 09:23:22 +0200 Subject: [PATCH 29/61] pre-allocate result buffer --- src/gt4py/next/iterator/embedded.py | 87 +++++++++---------- src/gt4py/next/iterator/ir.py | 4 +- src/gt4py/next/iterator/tracing.py | 2 +- .../iterator/transforms/fencil_to_program.py | 4 +- .../next/iterator/transforms/pass_manager.py | 2 +- .../program_processors/runners/roundtrip.py | 9 +- .../iterator_tests/test_program.py | 4 +- 7 files changed, 57 insertions(+), 55 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index beee47ca26..8a7f5e3f89 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -53,7 +53,11 @@ runtime_checkable, ) from gt4py.next import common -from gt4py.next.embedded import context as embedded_context, exceptions as embedded_exceptions +from gt4py.next.embedded import ( + context as embedded_context, + exceptions as embedded_exceptions, + operators, +) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime @@ -1498,42 +1502,44 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: @runtime.set_at.register(EMBEDDED) def set_at(expr, domain, target) -> None: - # domain is assumed to match the domain of outer apply stencil / fields - if callable(expr): # lazy apply stencil #TODO check - expr(target, lazy_domain=domain) + operators._tuple_assign_field(target, expr, common.domain(domain)) + + +def _compute_point( + sten: Callable, ins, pos, column_range: common.NamedRange | eve.NothingType = eve.NOTHING +): + promoted_ins = [promote_scalars(inp) for inp in ins] + ins_iters = list( + make_in_iterator( + inp, + pos, + column_axis=column_range.dim.value if column_range is not eve.NOTHING else None, + ) + for inp in promoted_ins + ) + return sten(*ins_iters) -@builtins.as_fieldop.register(EMBEDDED) -def as_fieldop(fun, domain=None): - def impl(*args, **kwargs): - new_domain = None - # if common.is_domain_like(domain): - new_domain = domain - # else: - # assert callable(domain) - # new_domain = domain(*args, **kwargs) - - # TODO this only works if the as_fieldop is directly in set_at - # for the clean solution we need to pre-allocate the result buffer (see strategy for scan in field_view embedded) - def lazy_as_fieldop(out, *, lazy_domain=None): - if new_domain is None: - assert lazy_domain is not None - domain = lazy_domain - else: - domain = new_domain - - closure( - # cartesian_domain( - # *(named_range(d, start, stop) for d, (start, stop) in domain.items()) - # ), - domain, - fun, - out, - list(args), # TODO kwargs - ) - return out +# def _allocate_out(sten, ins, pos) -> common.MutableField: - return lazy_as_fieldop + +@builtins.as_fieldop.register(EMBEDDED) +def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.UnstructuredDomain): + def impl(*args): + # TODO extract function, move private utils + pos = next(_domain_iterator(_dimension_to_tag(domain))) + single_point_result = _compute_point(fun, args, pos) + + xp = operators._get_array_ns(*args) + out = operators._construct_scan_array(common.domain(domain), xp)(single_point_result) + + closure( + domain, + fun, + out, + list(args), + ) + return out return impl @@ -1571,18 +1577,7 @@ def closure( def _iterate(): for pos in _domain_iterator(domain): - promoted_ins = [promote_scalars(inp) for inp in ins] - ins_iters = list( - make_in_iterator( - inp, - pos, - column_axis=column_range.dim.value - if column_range is not eve.NOTHING - else None, - ) - for inp in promoted_ins - ) - res = sten(*ins_iters) + res = _compute_point(sten, ins, pos, column_range) if column_range is eve.NOTHING: assert _is_concrete_position(pos) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 6f594dd8dc..98f2cb5aee 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -20,6 +20,7 @@ from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable +from gt4py.next import common # TODO(havogt): @@ -91,6 +92,7 @@ class OffsetLiteral(Expr): class AxisLiteral(Expr): value: str + kind: common.DimensionKind = common.DimensionKind.HORIZONTAL class SymRef(Expr): @@ -211,7 +213,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib # TODO(havogt): restructure after refactoring to GTIR GTIR_BUILTINS = { *BUILTINS, - "as_fieldop", # `as_fieldop(stencil)` creates field_operator from stencil + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) } diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d4fc993102..a63dda5c56 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -148,7 +148,7 @@ def make_node(o): if isinstance(o, Node): return o if isinstance(o, common.Dimension): - return AxisLiteral(value=o.value) + return AxisLiteral(value=o.value, kind=o.kind) if callable(o): if o.__name__ == "": return lambdadef(o) diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index fff12a6c88..3df1160f3f 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -20,7 +20,9 @@ class FencilToProgram(eve.NodeTranslator): @classmethod - def apply(cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries) -> itir.Program: + def apply( + cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program + ) -> itir.Program: return cls().visit(node) def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index c3232b867e..3e4d021432 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -88,7 +88,7 @@ def apply_common_transforms( Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> itir.FencilDefinition | FencilWithTemporaries: +) -> itir.FencilDefinition | FencilWithTemporaries | itir.Program: if isinstance(ir, itir.Program): # TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program return ir diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index b1e419b13e..ba7a87f07a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -162,8 +162,8 @@ def fencil_generator( .if_isinstance(str) .to_set() ) - axis_literals: Iterable[str] = ( - ir.pre_walk_values().if_isinstance(itir.AxisLiteral).getattr("value").to_set() + axis_literals: Iterable[itir.AxisLiteral] = ( + ir.pre_walk_values().if_isinstance(itir.AxisLiteral).to_set() ) if use_embedded: @@ -185,7 +185,10 @@ def fencil_generator( if debug: print(source_file_name) offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] - axis_literals = [f'{o} = gtx.Dimension("{o}")' for o in axis_literals] + axis_literals = [ + f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' + for o in axis_literals + ] source_file.write(header) source_file.write("\n".join(offset_literals)) source_file.write("\n") diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index e6d2e8fe44..3d79a91039 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -48,8 +48,8 @@ def copy_stencil(inp): @fendef def copy_program(inp, out, size): set_at( - # as_fieldop(copy_stencil, domain=cartesian_domain(named_range(I, 0, size)))(inp), - as_fieldop(copy_stencil)(inp), + as_fieldop(copy_stencil, domain=cartesian_domain(named_range(I, 0, size)))(inp), + # as_fieldop(copy_stencil)(inp), cartesian_domain(named_range(I, 0, size)), out, ) From 5d4fc3d42aa04f013f616599612a3712239776fc Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Apr 2024 09:29:28 +0200 Subject: [PATCH 30/61] fix tracer context --- src/gt4py/next/iterator/tracing.py | 1 + src/gt4py/next/program_processors/runners/roundtrip.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index a63dda5c56..2d85245214 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -231,6 +231,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): type(self).fundefs = [] type(self).closures = [] + type(self).body = [] iterator.builtins.builtin_dispatch.pop_key() diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index ba7a87f07a..930e5f0db8 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -162,7 +162,7 @@ def fencil_generator( .if_isinstance(str) .to_set() ) - axis_literals: Iterable[itir.AxisLiteral] = ( + axis_literals_set: Iterable[itir.AxisLiteral] = ( ir.pre_walk_values().if_isinstance(itir.AxisLiteral).to_set() ) @@ -187,7 +187,7 @@ def fencil_generator( offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] axis_literals = [ f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' - for o in axis_literals + for o in axis_literals_set ] source_file.write(header) source_file.write("\n".join(offset_literals)) From 1d93192aea710e00952c015043e38e9b8e9907ff Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Apr 2024 16:54:11 +0200 Subject: [PATCH 31/61] first (almost) complete embedded version --- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/iterator/embedded.py | 78 +++++++++++++------ .../codegens/gtfn/gtfn_module.py | 2 +- .../program_processors/runners/roundtrip.py | 6 +- .../ffront_tests/test_gt4py_builtins.py | 2 + .../ffront_tests/test_ffront_fvm_nabla.py | 6 +- 6 files changed, 67 insertions(+), 29 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a7e9751c4e..d7727fa097 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -340,7 +340,7 @@ def _construct_itir_domain_arg( domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), - args=[itir.AxisLiteral(value=dim.value), lower, upper], + args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper], ) ) domain_args_kind.append(dim.kind) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 8a7f5e3f89..d3ccddba40 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -286,10 +286,20 @@ def cast_(obj, new_dtype): @builtins.not_.register(EMBEDDED) def not_(a): if isinstance(a, Column): - return np.logical_not(a.data) + return np.logical_not(a) return not a +@builtins.gamma.register(EMBEDDED) +def gamma(a): + gamma_ = np.vectorize(math.gamma) + if isinstance(a, Column): + return Column(kstart=a.kstart, data=gamma_(a.data)) + res = gamma_(a) + assert res.ndim == 0 + return res.item() + + @builtins.and_.register(EMBEDDED) def and_(a, b): if isinstance(a, Column): @@ -491,8 +501,7 @@ def promote_scalars(val: CompositeOfScalarOrField): decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable if math_builtin_name == "gamma": - # numpy has no gamma function - impl = np.vectorize(math.gamma) + continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent # with compiled backends. Currently using Python types to preserve @@ -500,6 +509,7 @@ def promote_scalars(val: CompositeOfScalarOrField): impl = python_builtins[math_builtin_name] else: impl = getattr(np, math_builtin_name) + globals()[math_builtin_name] = decorator(impl) @@ -1502,6 +1512,7 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: @runtime.set_at.register(EMBEDDED) def set_at(expr, domain, target) -> None: + # TODO we can't set the column_range here, because it's too late: `expr` already evaluated operators._tuple_assign_field(target, expr, common.domain(domain)) @@ -1513,28 +1524,59 @@ def _compute_point( make_in_iterator( inp, pos, - column_axis=column_range.dim.value if column_range is not eve.NOTHING else None, + column_axis=column_range.dim.value + if isinstance(column_range, common.NamedRange) + else None, ) for inp in promoted_ins ) return sten(*ins_iters) -# def _allocate_out(sten, ins, pos) -> common.MutableField: +def _extract_column_range(domain) -> common.NamedRange | eve.NothingType: + if (col_range_placeholder := embedded_context.closure_column_range.get(None)) is not None: + assert ( + col_range_placeholder.unit_range.is_empty() + ) # check it's just the placeholder with empty range + column_axis = col_range_placeholder.dim + if column_axis is not None and column_axis.value in domain: + return common.NamedRange( + column_axis, + common.UnitRange(domain[column_axis.value].start, domain[column_axis.value].stop), + ) + return eve.NOTHING + + +# TODO handle in clean way +def _np_void_to_tuple(a): + if isinstance(a, np.void): + return tuple(_np_void_to_tuple(elem) for elem in a) + return a @builtins.as_fieldop.register(EMBEDDED) -def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.UnstructuredDomain): +def as_fieldop(fun: Callable, domain_: runtime.CartesianDomain | runtime.UnstructuredDomain): def impl(*args): # TODO extract function, move private utils - pos = next(_domain_iterator(_dimension_to_tag(domain))) - single_point_result = _compute_point(fun, args, pos) + domain = _dimension_to_tag(domain_) + col_range = _extract_column_range(domain) + if col_range is not eve.NOTHING: + del domain[col_range.dim.value] + + pos = next(_domain_iterator(domain)) + with embedded_context.new_context(closure_column_range=col_range) as ctx: + single_point_result = ctx.run(_compute_point, fun, args, pos, col_range) + if isinstance(single_point_result, Column): + single_point_result = single_point_result.data[0] + single_point_result = _np_void_to_tuple(single_point_result) xp = operators._get_array_ns(*args) - out = operators._construct_scan_array(common.domain(domain), xp)(single_point_result) + out = operators._construct_scan_array(common.domain(domain_), xp)(single_point_result) + + # TODO `out` gets allocated in the order of domain_, but might not match the order of `target` in set_at closure( - domain, + _dimension_to_tag(domain_), fun, out, list(args), @@ -1558,18 +1600,10 @@ def closure( if not (isinstance(out, common.Field) or is_tuple_of_field(out)): raise TypeError("'Out' needs to be a located field.") - column_range: common.NamedRange | eve.NothingType = eve.NOTHING - if (col_range_placeholder := embedded_context.closure_column_range.get(None)) is not None: - assert ( - col_range_placeholder.unit_range.is_empty() - ) # check it's just the placeholder with empty range - column_axis = col_range_placeholder.dim - if column_axis is not None and column_axis.value in domain: - column_range = common.NamedRange( - column_axis, - common.UnitRange(domain[column_axis.value].start, domain[column_axis.value].stop), - ) - del domain[column_axis.value] + column_range: common.NamedRange | eve.NothingType = _extract_column_range(domain) + + if isinstance(column_range, common.NamedRange): + del domain[column_range.dim.value] out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d28e513093..8986ace116 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -183,7 +183,7 @@ def _preprocess_program( program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], runtime_lift_mode: Optional[LiftMode], - ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries: + ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program: # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added # to the interface of all (or at least all of concern) backends, but instead should be # configured in the backend itself (like it is here), until then we respect the argument diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 930e5f0db8..daa792075f 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -26,7 +26,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import allocators as next_allocators, backend as next_backend, common +from gt4py.next import allocators as next_allocators, backend as next_backend, common, config from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms from gt4py.next.iterator.transforms import fencil_to_program, global_tmps as gtmps_transform from gt4py.next.otf import stages, workflow @@ -225,7 +225,7 @@ def execute_roundtrip( *args: Any, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - debug: bool = False, + debug: bool = config.DEBUG, lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, dispatch_backend: Optional[ppi.ProgramExecutor] = None, ) -> None: @@ -246,7 +246,7 @@ def execute_roundtrip( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: bool = False + debug: bool = config.DEBUG lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 0bb1e78582..404a91eaa5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -111,6 +111,8 @@ def reduction_ke_field( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case, fop): + if fop == reduction_ke_field: # TODO need to resolve order of dimensions + pytest.skip() v2e_table = unstructured_case.offset_provider["V2E"].table edge_f = cases.allocate(unstructured_case, fop, "edge_f")() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index aeed607b01..bdb50f27ff 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -87,7 +87,9 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False ) - compute_zavgS.with_backend(executor)(pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v}) + compute_zavgS.with_backend(exec_alloc_descriptor)( + pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) assert_close(388241977.58389181, np.max(zavgS.asnumpy())) @@ -113,7 +115,7 @@ def test_ffront_nabla(exec_alloc_descriptor): atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 ) - pnabla.with_backend(executor)( + pnabla.with_backend(exec_alloc_descriptor)( pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} ) From 5882c28919fb4077fb5f758c37781b46d0a3f6c5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Apr 2024 18:19:16 +0200 Subject: [PATCH 32/61] add dim kind to print/parse --- src/gt4py/next/iterator/ir.py | 2 ++ src/gt4py/next/iterator/pretty_parser.py | 6 ++++-- src/gt4py/next/iterator/pretty_printer.py | 7 ++++++- .../iterator_tests/test_pretty_parser.py | 18 ++++++++++++++++-- .../iterator_tests/test_pretty_printer.py | 18 ++++++++++++++++-- 5 files changed, 44 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 98f2cb5aee..ea2a84ee52 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -23,6 +23,8 @@ from gt4py.next import common +DimensionKind = common.DimensionKind + # TODO(havogt): # After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. # During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index f6d532ee30..be57e4ccd7 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -36,7 +36,7 @@ OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ" _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL ID_NAME: CNAME - AXIS_NAME: CNAME + AXIS_NAME: CNAME ("ᵥ" | "ₕ") ?prec0: prec1 | "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam @@ -123,7 +123,9 @@ def ID_NAME(self, value: lark_lexer.Token) -> str: return value.value def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral: - return ir.AxisLiteral(value=value.value) + name = value.value[:-1] + kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL + return ir.AxisLiteral(value=name, kind=kind) def lam(self, *args: ir.Node) -> ir.Lambda: *params, expr = args diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 3f224d2ef4..64d56254a5 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -145,7 +145,12 @@ def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str] return [str(node.value) + "ₒ"] def visit_AxisLiteral(self, node: ir.AxisLiteral, *, prec: int) -> list[str]: - return [str(node.value)] + kind = "" + if node.kind == ir.DimensionKind.HORIZONTAL: + kind = "ₕ" + elif node.kind == ir.DimensionKind.VERTICAL: + kind = "ᵥ" + return [str(node.value) + kind] def visit_SymRef(self, node: ir.SymRef, *, prec: int) -> list[str]: return [node.id] diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index b02c610aff..003a7663d0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -130,8 +130,8 @@ def test_make_tuple(): assert actual == expected -def test_named_range(): - testee = "IDim: [x, y)" +def test_named_range_horizontal(): + testee = "IDimₕ: [x, y)" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], @@ -140,6 +140,20 @@ def test_named_range(): assert actual == expected +def test_named_range_vertical(): + testee = "IDimᵥ: [x, y)" + expected = ir.FunCall( + fun=ir.SymRef(id="named_range"), + args=[ + ir.AxisLiteral(value="IDim", kind=ir.DimensionKind.VERTICAL), + ir.SymRef(id="x"), + ir.SymRef(id="y"), + ], + ) + actual = pparse(testee) + assert actual == expected + + def test_cartesian_domain(): testee = "c⟨ x, y ⟩" expected = ir.FunCall( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index bc1372cea6..0ec74f5d5d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -218,12 +218,26 @@ def test_make_tuple(): assert actual == expected -def test_named_range(): +def test_axis_literal_horizontal(): + testee = ir.AxisLiteral(value="I", kind=ir.DimensionKind.HORIZONTAL) + expected = "Iₕ: [x, y)" + actual = pformat(testee) + assert actual == expected + + +def test_axis_literal_vertical(): + testee = ir.AxisLiteral(value="I", kind=ir.DimensionKind.VERTICAL) + expected = "Iᵥ" + actual = pformat(testee) + assert actual == expected + + +def test_named_range_horizontal(): testee = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], ) - expected = "IDim: [x, y)" + expected = "IDimₕ: [x, y)" actual = pformat(testee) assert actual == expected From b7cbf167c152aa08cae7a2a2aaa792563d43701f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Apr 2024 20:58:04 +0200 Subject: [PATCH 33/61] fix tests --- .../feature_tests/iterator_tests/test_program.py | 3 +-- .../unit_tests/iterator_tests/test_pretty_printer.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 3d79a91039..503c0e9e1d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -48,8 +48,7 @@ def copy_stencil(inp): @fendef def copy_program(inp, out, size): set_at( - as_fieldop(copy_stencil, domain=cartesian_domain(named_range(I, 0, size)))(inp), - # as_fieldop(copy_stencil)(inp), + as_fieldop(copy_stencil, cartesian_domain(named_range(I, 0, size)))(inp), cartesian_domain(named_range(I, 0, size)), out, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 0ec74f5d5d..d0bc4427b8 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -220,7 +220,7 @@ def test_make_tuple(): def test_axis_literal_horizontal(): testee = ir.AxisLiteral(value="I", kind=ir.DimensionKind.HORIZONTAL) - expected = "Iₕ: [x, y)" + expected = "Iₕ" actual = pformat(testee) assert actual == expected From e97ca25ff83d32a6213f5e261514cfc2919186e4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Apr 2024 21:21:29 +0200 Subject: [PATCH 34/61] cleanup test_program --- .../iterator_tests/test_program.py | 60 ++++++------------- 1 file changed, 17 insertions(+), 43 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 503c0e9e1d..3b28717921 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -18,27 +18,14 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range from gt4py.next.iterator.runtime import fendef, fundef, set_at +from gt4py.next.program_processors.formatters import type_check +from gt4py.next.program_processors.runners import dace, gtfn from next_tests.unit_tests.conftest import program_processor, run_processor I = gtx.Dimension("I") -_isize = 10 - - -@pytest.fixture -def dom(): - return {I: range(_isize)} - - -def a_field(): - return gtx.as_field([I], np.arange(0, _isize, dtype=np.float64)) - - -def out_field(): - return gtx.as_field([I], np.zeros(shape=(_isize,))) - @fundef def copy_stencil(inp): @@ -54,36 +41,23 @@ def copy_program(inp, out, size): ) -# @fundef -# def plus_stencil(inp0,inp1): -# return plus(deref(inp0),deref(inp1)) +def test_prog(program_processor): + program_processor, validate = program_processor -# set_at( -# # as_fieldop(copy_stencil, domain=cartesian_domain(named_range(I, 0, size)))(inp), -# as_fieldop(plus_stencil)(inp0, as_fieldop(plus_stencil)(inp1,inp2)), -# cartesian_domain(named_range(I, 0, size)), -# out, -# ) + if program_processor in [ + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn.run_gtfn_with_temporaries.executor, + dace.run_dace_cpu.executor, + type_check.check_type_inference, + ]: + # TODO(havogt): Remove skip during refactoring to GTIR + pytest.skip("Executor requires to start from fencil.") + isize = 10 + inp = gtx.as_field([I], np.arange(0, isize, dtype=np.float64)) + out = gtx.as_field([I], np.zeros(shape=(isize,))) -def test_prog(): - validate = True - - inp = a_field() - out = out_field() - - copy_program(inp, out, _isize, offset_provider={}) + run_processor(copy_program, program_processor, inp, out, isize, offset_provider={}) if validate: assert np.allclose(inp.asnumpy(), out.asnumpy()) - - -# example for -# @field_operator -# def sum(a, b, c): -# a + b + c - -# def plus(a,b): -# return deref(a)+deref(b) - -# def sum_prog(a, b, c, out): -# set_at(as_fieldop(plus)(a, as_fieldop(plus)(b, c)), out.domain, out) From 8c2bd8f623477c38da5fd59b1ea8df92c2467d64 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 13:48:18 +0200 Subject: [PATCH 35/61] re-enable lift mode in roundtrip --- src/gt4py/next/iterator/ir.py | 2 + .../program_processors/runners/roundtrip.py | 53 ++++++------------- .../iterator_tests/test_fvm_nabla.py | 5 +- 3 files changed, 20 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index ea2a84ee52..ee4f87e89a 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -93,6 +93,8 @@ class OffsetLiteral(Expr): class AxisLiteral(Expr): + # TODO(havogt): Refactor to use declare Axis/Dimension at the Program level. + # Now every use of the literal has to provide the kind, where usually we only care of the name. value: str kind: common.DimensionKind = common.DimensionKind.HORIZONTAL diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index daa792075f..612bdb3979 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -19,6 +19,7 @@ import pathlib import tempfile import textwrap +import warnings from collections.abc import Callable, Iterable from typing import Any, Optional @@ -28,7 +29,7 @@ from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms -from gt4py.next.iterator.transforms import fencil_to_program, global_tmps as gtmps_transform +from gt4py.next.iterator.transforms import fencil_to_program from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi @@ -52,14 +53,6 @@ class EmbeddedDSL(codegen.TemplatedGenerator): FunCall = as_fmt("{fun}({','.join(args)})") Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") - FencilDefinition = as_mako( - """ -${''.join(function_definitions)} -@fendef -def ${id}(${','.join(params)}): - ${'\\n '.join(closures)} - """ - ) FunctionDefinition = as_mako( """ @fundef @@ -72,29 +65,12 @@ def ${id}(${','.join(params)}): ${''.join(function_definitions)} @fendef def ${id}(${','.join(params)}): + ${'\\n '.join(declarations)} ${'\\n '.join(body)} """ ) SetAt = as_mako("set_at(${expr}, ${domain}, ${target})") - # extension required by global_tmps - def visit_FencilWithTemporaries( - self, node: gtmps_transform.FencilWithTemporaries, **kwargs: Any - ) -> str: - params = self.visit(node.params) - - tmps = "\n ".join(self.visit(node.tmps)) - args = ", ".join(params + [tmp.id for tmp in node.tmps]) - params = ", ".join(params) - fencil = self.visit(node.fencil) - return ( - fencil - + "\n" - + f"def {node.fencil.id}_wrapper({params}, **kwargs):\n " - + tmps - + f"\n {node.fencil.id}({args}, **kwargs)\n" - ) - def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: assert ( isinstance(node.domain, itir.FunCall) @@ -112,8 +88,6 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_BACKEND_NAME = "roundtrip" - _FENCIL_CACHE: dict[int, Callable] = {} @@ -140,6 +114,8 @@ def fencil_generator( # caching mechanism cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: + if debug: + print(f"Using cached fencil for key {cache_key}") return _FENCIL_CACHE[cache_key] ir = itir_transforms.apply_common_transforms( @@ -204,14 +180,8 @@ def fencil_generator( if not debug: pathlib.Path(source_file_name).unlink(missing_ok=True) - assert isinstance( - ir, (itir.FencilDefinition, gtmps_transform.FencilWithTemporaries, itir.Program) - ) - fencil_name = ( - ir.fencil.id + "_wrapper" - if isinstance(ir, gtmps_transform.FencilWithTemporaries) - else ir.id - ) + assert isinstance(ir, itir.Program) + fencil_name = ir.id fencil = getattr(mod, fencil_name) _FENCIL_CACHE[cache_key] = fencil @@ -251,11 +221,18 @@ class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): use_embedded: bool = True def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: + lift_mode = inp.kwargs.get("lift_mode", self.lift_mode) + if lift_mode != self.lift_mode: + warnings.warn( + f"Roundtrip Backend was configured for LiftMode `{self.lift_mode!s}`, but " + f"overriden to be {lift_mode!s} at runtime.", + stacklevel=2, + ) return fencil_generator( inp.program, offset_provider=inp.kwargs.get("offset_provider", None), debug=self.debug, - lift_mode=self.lift_mode, + lift_mode=lift_mode, use_embedded=self.use_embedded, ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index d29ef68d4e..0de0a7b1d2 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -274,7 +274,9 @@ def test_nabla2(program_processor, lift_mode): AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 ) - nabla2( + run_processor( + nabla2, + program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), pp, @@ -282,7 +284,6 @@ def test_nabla2(program_processor, lift_mode): sign, vol, offset_provider={"E2V": e2v, "V2E": v2e}, - program_processor=program_processor, lift_mode=lift_mode, ) From 21b230be149e5009a2b95e50ebb3d1c6881e50a0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 14:41:30 +0200 Subject: [PATCH 36/61] replace lift_mode fixture by backend in program_processor --- src/gt4py/next/__init__.py | 2 +- .../runners/double_roundtrip.py | 2 +- .../program_processors/runners/roundtrip.py | 19 ++- tests/next_tests/definitions.py | 8 +- .../feature_tests/iterator_tests/test_scan.py | 5 +- .../iterator_tests/test_trivial.py | 14 +- .../cpp_backend_tests/.gitignore | 1 - .../cpp_backend_tests/CMakeLists.txt | 79 ------------ .../cpp_backend_tests/anton_lap.py | 77 ----------- .../cpp_backend_tests/anton_lap_driver.cpp | 40 ------ .../cmake/FetchGoogleTest.cmake | 14 -- .../cpp_backend_tests/copy_stencil.py | 56 -------- .../cpp_backend_tests/copy_stencil_driver.cpp | 42 ------ .../copy_stencil_field_view.py | 55 -------- .../copy_stencil_field_view_driver.cpp | 52 -------- .../cpp_backend_tests/fn_select_gt4py.hpp | 78 ----------- .../cpp_backend_tests/fvm_nabla.py | 106 --------------- .../cpp_backend_tests/fvm_nabla_driver.cpp | 122 ------------------ .../cpp_backend_tests/test_driver.py | 73 ----------- .../cpp_backend_tests/tridiagonal_solve.py | 78 ----------- .../tridiagonal_solve_driver.cpp | 54 -------- .../iterator_tests/test_anton_toy.py | 18 +-- .../iterator_tests/test_column_stencil.py | 17 +-- .../iterator_tests/test_fvm_nabla.py | 29 ++--- .../iterator_tests/test_hdiff.py | 18 +-- .../iterator_tests/test_vertical_advection.py | 21 +-- .../test_with_toy_connectivity.py | 47 +++---- tests/next_tests/unit_tests/conftest.py | 13 +- 28 files changed, 70 insertions(+), 1070 deletions(-) delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index e79e2f5517..f33b9c5127 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -53,7 +53,7 @@ run_gtfn_cached as gtfn_cpu, run_gtfn_gpu_cached as gtfn_gpu, ) -from .program_processors.runners.roundtrip import backend as itir_python +from .program_processors.runners.roundtrip import default as itir_python __all__ = [ diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index e6220ea879..6d3f368170 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -22,5 +22,5 @@ executor=roundtrip.RoundtripExecutorFactory( dispatch_backend=roundtrip.RoundtripExecutorFactory() ), - allocator=roundtrip.backend.allocator, + allocator=roundtrip.default.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index aba2dbf71b..408451648e 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -26,7 +26,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import allocators as next_allocators, backend as next_backend, common +from gt4py.next import allocators as next_allocators, backend as next_backend, common, config from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms from gt4py.next.iterator.transforms import global_tmps as gtmps_transform from gt4py.next.otf import stages, workflow @@ -103,8 +103,6 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_BACKEND_NAME = "roundtrip" - _FENCIL_CACHE: dict[int, Callable] = {} @@ -131,6 +129,8 @@ def fencil_generator( # caching mechanism cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: + if debug: + print(f"Using cached fencil for key {cache_key}") return _FENCIL_CACHE[cache_key] ir = itir_transforms.apply_common_transforms( @@ -209,7 +209,7 @@ def execute_roundtrip( *args: Any, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - debug: bool = False, + debug: bool = config.DEBUG, lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, dispatch_backend: Optional[ppi.ProgramExecutor] = None, ) -> None: @@ -230,7 +230,7 @@ def execute_roundtrip( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: bool = False + debug: bool = config.DEBUG lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True @@ -275,7 +275,14 @@ class Params: executor = RoundtripExecutorFactory(name="roundtrip") +executor_with_temporaries = RoundtripExecutorFactory( + name="roundtrip_with_temporaries", + roundtrip_workflow=RoundtripFactory(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES), +) -backend = next_backend.Backend( +default = next_backend.Backend( executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() ) +with_temporaries = next_backend.Backend( + executor=executor_with_temporaries, allocator=next_allocators.StandardCPUFieldBufferAllocator() +) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 86eac69712..c7573aa8f3 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -54,7 +54,8 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" ) GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" - ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend" + ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -192,4 +193,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), + ], } diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index fce1aa3960..ef38e23e60 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -20,11 +20,11 @@ from gt4py.next.iterator.runtime import fundef, offset from next_tests.integration_tests.cases import IDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor @pytest.mark.uses_index_fields -def test_scan_in_stencil(program_processor, lift_mode): +def test_scan_in_stencil(program_processor): program_processor, validate = program_processor isize = 1 @@ -54,7 +54,6 @@ def wrapped(inp): program_processor, inp, out=out, - lift_mode=lift_mode, offset_provider={"Koff": KDim}, column_axis=KDim, ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index f85b9b4035..7bb023aabb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -21,7 +21,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") @@ -44,7 +44,7 @@ def baz(baz_inp): return deref(lift(bar)(baz_inp)) -def test_trivial(program_processor, lift_mode): +def test_trivial(program_processor): program_processor, validate = program_processor rng = np.random.default_rng() @@ -60,7 +60,6 @@ def test_trivial(program_processor, lift_mode): program_processor, inp_s, out=out_s, - lift_mode=lift_mode, offset_provider={"I": IDim, "J": JDim}, ) @@ -73,12 +72,9 @@ def stencil_shifted_arg_to_lift(inp): return deref(lift(deref)(shift(I, -1)(inp))) -def test_shifted_arg_to_lift(program_processor, lift_mode): +def test_shifted_arg_to_lift(program_processor): program_processor, validate = program_processor - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") - rng = np.random.default_rng() inp = rng.uniform(size=(5, 7)) out = np.zeros_like(inp) @@ -95,7 +91,6 @@ def test_shifted_arg_to_lift(program_processor, lift_mode): program_processor, inp_s, out=out_s, - lift_mode=lift_mode, offset_provider={"I": IDim, "J": JDim}, ) @@ -113,7 +108,7 @@ def fen_direct_deref(i_size, j_size, out, inp): ) -def test_direct_deref(program_processor, lift_mode): +def test_direct_deref(program_processor): program_processor, validate = program_processor rng = np.random.default_rng() @@ -129,7 +124,6 @@ def test_direct_deref(program_processor, lift_mode): *out.shape, out_s, inp_s, - lift_mode=lift_mode, offset_provider=dict(), ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore deleted file mode 100644 index 4ecaff7fe9..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore +++ /dev/null @@ -1 +0,0 @@ -build_*/ \ No newline at end of file diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt deleted file mode 100644 index 0a1b19d2a9..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt +++ /dev/null @@ -1,79 +0,0 @@ -cmake_minimum_required(VERSION 3.18.1) - -project(cpp_backend_tests_driver LANGUAGES CXX) - -set(BACKEND NOTFOUND CACHE STRING "fn backend") -set_property(CACHE BACKEND PROPERTY STRINGS "NOTFOUND;naive;gpu") - -if(NOT BACKEND) - message(FATAL_ERROR "No backend selected") -else() - message(STATUS "Testing backend \"${BACKEND}\"") -endif() - -if(BACKEND STREQUAL "gpu") - enable_language(CUDA) - set(is_gpu ON) -endif() -string(TOUPPER ${BACKEND} backend_upper_case) - -include(FetchContent) -FetchContent_Declare(GridTools - GIT_REPOSITORY https://github.com/GridTools/gridtools.git - GIT_TAG master -) -FetchContent_MakeAvailable(GridTools) - -function(generate_computation) - set(options) - set(oneValueArgs NAME SRC_FILE GENERATED_FILENAME IMPERATIVE) - set(multiValueArgs) - cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - add_custom_command(OUTPUT ${ARG_GENERATED_FILENAME} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${ARG_SRC_FILE} ${ARG_GENERATED_FILENAME} ${ARG_IMPERATIVE} - DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${ARG_INPUT}) - add_custom_target(generated_${ARG_NAME} DEPENDS ${ARG_GENERATED_FILENAME}) -endfunction() - -add_library(regression_main ${gridtools_SOURCE_DIR}/tests/src/regression_main.cpp) -target_include_directories(regression_main PUBLIC ${gridtools_SOURCE_DIR}/tests/include) -target_link_libraries(regression_main PUBLIC gtest gmock gridtools) - -function(add_fn_codegen_test) - set(options) - set(oneValueArgs NAME SRC_FILE DRIVER_FILE IMPERATIVE) - set(multiValueArgs) - cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - set(STENCIL_IMPL_SOURCE ${CMAKE_CURRENT_BINARY_DIR}/generated_${ARG_NAME}.hpp) - generate_computation(NAME ${ARG_NAME} SRC_FILE ${ARG_SRC_FILE} GENERATED_FILENAME ${STENCIL_IMPL_SOURCE} IMPERATIVE ${ARG_IMPERATIVE}) - add_executable(${ARG_NAME} ${ARG_DRIVER_FILE}) - target_link_libraries(${ARG_NAME} fn_${BACKEND} regression_main) - target_compile_definitions(${ARG_NAME} PRIVATE GT_FN_${backend_upper_case}) - target_compile_definitions(${ARG_NAME} PRIVATE GT_FN_BACKEND=${BACKEND}) - target_compile_definitions(${ARG_NAME} PRIVATE GENERATED_FILE=\"${STENCIL_IMPL_SOURCE}\") - if(is_gpu) - gridtools_setup_target(${ARG_NAME} CUDA_ARCH sm_60) #TODO - endif() - add_dependencies(${ARG_NAME} generated_${ARG_NAME}) - - add_test(NAME ${ARG_NAME} COMMAND $) -endfunction() - -include(CTest) -if(BUILD_TESTING) - find_package(Python3 COMPONENTS Interpreter REQUIRED) - - include(cmake/FetchGoogleTest.cmake) - fetch_googletest() - - add_fn_codegen_test(NAME copy_stencil SRC_FILE copy_stencil.py DRIVER_FILE copy_stencil_driver.cpp) - add_fn_codegen_test(NAME copy_stencil_field_view SRC_FILE copy_stencil_field_view.py DRIVER_FILE copy_stencil_field_view_driver.cpp) - add_fn_codegen_test(NAME anton_lap SRC_FILE anton_lap.py DRIVER_FILE anton_lap_driver.cpp) - add_fn_codegen_test(NAME fvm_nabla_fun SRC_FILE fvm_nabla.py DRIVER_FILE fvm_nabla_driver.cpp IMPERATIVE FALSE) - add_fn_codegen_test(NAME fvm_nabla_imp SRC_FILE fvm_nabla.py DRIVER_FILE fvm_nabla_driver.cpp IMPERATIVE TRUE) - add_fn_codegen_test(NAME tridiagonal_solve SRC_FILE tridiagonal_solve.py DRIVER_FILE tridiagonal_solve_driver.cpp) -endif() - - diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py deleted file mode 100644 index 5af4605988..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import sys - -import gt4py.next as gtx -from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fundef, offset -from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.runners.gtfn import run_gtfn - - -@fundef -def ldif(d): - return lambda inp: deref(shift(d, -1)(inp)) - deref(inp) - - -@fundef -def rdif(d): - return lambda inp: ldif(d)(shift(d, 1)(inp)) - - -@fundef -def dif2(d): - return lambda inp: ldif(d)(lift(rdif(d))(inp)) - - -i = offset("i") -j = offset("j") - - -@fundef -def lap(inp): - return dif2(i)(inp) + dif2(j)(inp) - - -IDim = gtx.Dimension("IDim") -JDim = gtx.Dimension("JDim") -KDim = gtx.Dimension("KDim") - - -def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): - closure( - cartesian_domain( - named_range(IDim, i_off, i_size + i_off), - named_range(JDim, j_off, j_size + j_off), - named_range(KDim, k_off, k_size + k_off), - ), - lap, - out, - [inp], - ) - - -if __name__ == "__main__": - if len(sys.argv) != 2: - raise RuntimeError(f"Usage: {sys.argv[0]} ") - output_file = sys.argv[1] - - prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) - generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( - prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None - ) - - with open(output_file, "w+") as output: - output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp deleted file mode 100644 index e2dbcbb6ce..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -#include "fn_select_gt4py.hpp" -#include -#include GENERATED_FILE - -#include - -namespace { -using namespace gridtools; -using namespace fn; -using namespace literals; - -constexpr inline auto in = [](auto... indices) { return (... + indices); }; - -using backend_t = - fn_backend_t::values<32, 8, 1>>; - -GT_REGRESSION_TEST(fn_lap, test_environment<1>, backend_t) { - auto actual = TypeParam::make_storage(); - - auto expected = [&](auto i, auto j, auto k) { - return in(i + 1, j, k) + in(i - 1, j, k) + in(i, j + 1, k) + - in(i, j - 1, k) - 4 * in(i, j, k); - }; - - generated::lap_fencil(tuple{})( - backend_t(), at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), 1, 1, 0, - sid::rename_numbered_dimensions(actual), - sid::rename_numbered_dimensions( - TypeParam::make_const_storage(in))); - - TypeParam::verify(expected, actual); -} -} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake deleted file mode 100644 index 983cce8594..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake +++ /dev/null @@ -1,14 +0,0 @@ -function(fetch_googletest) - # The gtest library needs to be built as static library to avoid RPATH issues - set(BUILD_SHARED_LIBS OFF) - - include(FetchContent) - option(INSTALL_GTEST OFF) - mark_as_advanced(INSTALL_GTEST) - FetchContent_Declare( - GoogleTest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG release-1.11.0 - ) - FetchContent_MakeAvailable(GoogleTest) -endfunction() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py deleted file mode 100644 index 3e8b88ac66..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py +++ /dev/null @@ -1,56 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import sys - -import gt4py.next as gtx -from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fundef -from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.runners.gtfn import run_gtfn - - -IDim = gtx.Dimension("IDim") -JDim = gtx.Dimension("JDim") -KDim = gtx.Dimension("KDim") - - -@fundef -def copy_stencil(inp): - return deref(inp) - - -def copy_fencil(isize, jsize, ksize, inp, out): - closure( - cartesian_domain( - named_range(IDim, 0, isize), named_range(JDim, 0, jsize), named_range(KDim, 0, ksize) - ), - copy_stencil, - out, - [inp], - ) - - -if __name__ == "__main__": - if len(sys.argv) != 2: - raise RuntimeError(f"Usage: {sys.argv[0]} ") - output_file = sys.argv[1] - - prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) - generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( - prog, offset_provider={}, column_axis=None - ) - - with open(output_file, "w+") as output: - output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp deleted file mode 100644 index e226b78907..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -#include "fn_select_gt4py.hpp" -#include -#include GENERATED_FILE - -#include - -namespace { -using namespace gridtools; -using namespace fn; -using namespace literals; - -constexpr inline auto in = [](auto... indices) { return (... + indices); }; - -using backend_t = - fn_backend_t::values<32, 8, 1>>; - -GT_REGRESSION_TEST(fn_cartesian_copy, test_environment<>, backend_t) { - auto out = TypeParam::make_storage(); - auto out_wrapped = - sid::rename_numbered_dimensions(out); - - auto in_wrapped = - sid::rename_numbered_dimensions( - TypeParam::make_const_storage(in)); - auto comp = [&] { - generated::copy_fencil(tuple{})( - backend_t{}, at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), in_wrapped, - out_wrapped); - }; - comp(); - - TypeParam::verify(in, out); -} - -} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py deleted file mode 100644 index fdc57449ee..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py +++ /dev/null @@ -1,55 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import sys - -from numpy import float64 - -import gt4py.next as gtx -from gt4py.next import Field, field_operator, program -from gt4py.next.program_processors.runners.gtfn import run_gtfn - - -IDim = gtx.Dimension("IDim") -JDim = gtx.Dimension("JDim") -KDim = gtx.Dimension("KDim") - - -@field_operator -def copy_stencil(inp: Field[[IDim, JDim, KDim], float64]) -> Field[[IDim, JDim, KDim], float64]: - return inp - - -@program -def copy_program( - inp: Field[[IDim, JDim, KDim], float64], - out: Field[[IDim, JDim, KDim], float64], - out2: Field[[IDim, JDim, KDim], float64], -): - copy_stencil(inp, out=out) - copy_stencil(inp, out=out2) - - -if __name__ == "__main__": - if len(sys.argv) != 2: - raise RuntimeError(f"Usage: {sys.argv[0]} ") - output_file = sys.argv[1] - - prog = copy_program.itir - generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( - prog, offset_provider={}, column_axis=None - ) - - with open(output_file, "w+") as output: - output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp deleted file mode 100644 index 19bd9fc3b9..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -#include "fn_select_gt4py.hpp" -#include -#include GENERATED_FILE - -#include - -namespace { -using namespace gridtools; -using namespace fn; -using namespace literals; - -constexpr inline auto in = [](auto... indices) { return (... + indices); }; -using backend_t = - fn_backend_t::values<32, 8, 1>>; - -GT_REGRESSION_TEST(fn_cartesian_copy, test_environment<>, backend_t) { - auto out = TypeParam::make_storage(); - auto out_wrapped = - sid::rename_numbered_dimensions(out); - auto out2 = TypeParam::make_storage(); - auto out2_wrapped = - sid::rename_numbered_dimensions(out2); - - auto in_wrapped = - sid::rename_numbered_dimensions( - TypeParam::make_const_storage(in)); - auto comp = [&] { - generated::copy_program(tuple{})( - backend_t{}, in_wrapped, out_wrapped, out2_wrapped, - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes())); - }; - comp(); - - TypeParam::verify(in, out); - TypeParam::verify(in, out2); -} - -} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp deleted file mode 100644 index a8f664c334..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -// COPIED AND SLIGHTLY MODIFIED FROM gritools/test/include/fn_select.hpp -// The gridtools fn_select assumes that dimensions are named -// integral_constant, the GT4Py generated code uses structs as dimension -// identifiers. - -#include - -#include - -// fn backend -#if defined(GT_FN_NAIVE) -#ifndef GT_STENCIL_NAIVE -#define GT_STENCIL_NAIVE -#endif -#ifndef GT_STORAGE_CPU_KFIRST -#define GT_STORAGE_CPU_KFIRST -#endif -#ifndef GT_TIMER_DUMMY -#define GT_TIMER_DUMMY -#endif -#include -namespace { -template using fn_backend_t = gridtools::fn::backend::naive; -} -#elif defined(GT_FN_GPU) -#ifndef GT_STENCIL_GPU -#define GT_STENCIL_GPU -#endif -#ifndef GT_STORAGE_GPU -#define GT_STORAGE_GPU -#endif -#ifndef GT_TIMER_CUDA -#define GT_TIMER_CUDA -#endif -#include -namespace { -template -using fn_backend_t = gridtools::fn::backend::gpu; -} // namespace -#endif - -#include "stencil_select.hpp" -#include "storage_select.hpp" -#include "timer_select.hpp" -namespace { -template struct block_sizes_t { - template - using values = gridtools::meta::zip< - gridtools::meta::list, - gridtools::meta::list...>>; -}; -} // namespace -namespace gridtools::fn::backend { -namespace naive_impl_ { -template struct naive_with_threadpool; -template -storage::cpu_kfirst backend_storage_traits(naive_with_threadpool); -template -timer_dummy backend_timer_impl(naive_with_threadpool); -template -inline char const *backend_name(naive_with_threadpool const &) { - return "naive"; -} -} // namespace naive_impl_ - -namespace gpu_impl_ { -template struct gpu; -template -storage::gpu backend_storage_traits(gpu); -template timer_cuda backend_timer_impl(gpu); -template -inline char const *backend_name(gpu const &) { - return "gpu"; -} -} // namespace gpu_impl_ -} // namespace gridtools::fn::backend diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py deleted file mode 100644 index fe8b54f95c..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ /dev/null @@ -1,106 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import sys -from dataclasses import dataclass - -import gt4py.next as gtx -from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fundef, offset -from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative - - -E2V = offset("E2V") -V2E = offset("V2E") - - -@fundef -def compute_zavgS(pp, S_M): - zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp))) - return make_tuple(tuple_get(0, deref(S_M)) * zavg, tuple_get(1, deref(S_M)) * zavg) - - -@fundef -def tuple_dot_fun(acc, zavgS, sign): - return make_tuple( - tuple_get(0, acc) + tuple_get(0, zavgS) * sign, - tuple_get(1, acc) + tuple_get(1, zavgS) * sign, - ) - - -@fundef -def tuple_dot(a, b): - return reduce(tuple_dot_fun, make_tuple(0.0, 0.0))(a, b) - - -@fundef -def compute_pnabla(pp, S_M, sign, vol): - zavgS = lift(compute_zavgS)(pp, S_M) - pnabla_M = tuple_dot(neighbors(V2E, zavgS), deref(sign)) - return make_tuple(tuple_get(0, pnabla_M) / deref(vol), tuple_get(1, pnabla_M) / deref(vol)) - - -def zavgS_fencil(edge_domain, out, pp, S_M): - closure(edge_domain, compute_zavgS, out, [pp, S_M]) - - -Vertex = gtx.Dimension("Vertex") -K = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL) - - -def nabla_fencil(n_vertices, n_levels, out, pp, S_M, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_vertices), named_range(K, 0, n_levels)), - compute_pnabla, - out, - [pp, S_M, sign, vol], - ) - - -@dataclass -class DummyConnectivity: - max_neighbors: int - has_skip_values: int - origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") - neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") - index_type: type[int] = int - - def mapped_index(_, __) -> int: - return 0 - - -if __name__ == "__main__": - if len(sys.argv) != 3: - raise RuntimeError(f"Usage: {sys.argv[0]} ") - output_file = sys.argv[1] - imperative = sys.argv[2].lower() == "true" - - if imperative: - backend = run_gtfn_imperative - else: - backend = run_gtfn - - # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils - prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) - offset_provider = { - "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True), - "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), - } - generated_code = backend.executor.otf_workflow.translation.generate_stencil_source( - prog, offset_provider=offset_provider, column_axis=None - ) - - with open(output_file, "w+") as output: - output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp deleted file mode 100644 index f0112a1636..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include - -#include "fn_select_gt4py.hpp" -#include -#include GENERATED_FILE - -namespace { -using namespace gridtools; -using namespace fn; -using namespace literals; - -using backend_t = - fn_backend_t::values<32, 8>>; - -// copied from gridtools::fn test -constexpr inline auto pp = [](int vertex, int k) { return (vertex + k) % 19; }; -constexpr inline auto sign = [](int vertex) { - return array{0, 1, vertex % 2, 1, (vertex + 1) % 2, 0}; -}; -constexpr inline auto vol = [](int vertex) { return vertex % 13 + 1; }; -constexpr inline auto s = [](int edge, int k) { - return tuple((edge + k) % 17, (edge + k) % 7); -}; - -constexpr inline auto zavg = [](auto const &e2v) { - return [&e2v](int edge, int k) { - double tmp = 0.0; - for (int neighbor = 0; neighbor < 2; ++neighbor) - tmp += pp(e2v(edge)[neighbor], k); - tmp /= 2.0; - return tuple{tmp * get<0>(s(edge, k)), tmp * get<1>(s(edge, k))}; - }; -}; - -constexpr inline auto make_zavg_expected = [](auto const &mesh) { - return [e2v_table = mesh.e2v_table()](int edge, int k) { - auto e2v = e2v_table->const_host_view(); - return zavg(e2v)(edge, k); - }; -}; - -constexpr inline auto expected = [](auto const &v2e, auto const &e2v) { - return [&v2e, zavg = zavg(e2v)](int vertex, int k) { - auto res = tuple(0.0, 0.0); - for (int neighbor = 0; neighbor < 6; ++neighbor) { - int edge = v2e(vertex)[neighbor]; - if (edge != -1) { - get<0>(res) += get<0>(zavg(edge, k)) * sign(vertex)[neighbor]; - get<1>(res) += get<1>(zavg(edge, k)) * sign(vertex)[neighbor]; - } - } - get<0>(res) /= vol(vertex); - get<1>(res) /= vol(vertex); - return res; - }; -}; - -constexpr inline auto make_expected = [](auto const &mesh) { - return [v2e_table = mesh.v2e_table(), - e2v_table = mesh.e2v_table()](int vertex, int k) { - auto v2e = v2e_table->const_host_view(); - auto e2v = e2v_table->const_host_view(); - return expected(v2e, e2v)(vertex, k); - }; -}; - -// GT_REGRESSION_TEST(unstructured_zavg, test_environment<>, fn_backend_t) { -// using float_t = typename TypeParam::float_t; - -// auto mesh = TypeParam::fn_unstructured_mesh(); -// auto actual = mesh.template make_storage>( -// mesh.nedges(), mesh.nlevels()); - -// auto pp_ = mesh.make_const_storage(pp, mesh.nvertices(), mesh.nlevels()); -// auto s_ = mesh.template make_const_storage>( -// s, mesh.nedges(), mesh.nlevels()); - -// auto e2v_conn = -// connectivity(mesh.e2v_table()->get_const_target_ptr()); -// auto edge_domain = -// unstructured_domain({mesh.nedges(), mesh.nlevels()}, {}, e2v_conn); - -// generated::zavgS_fencil(fn_backend_t{}, edge_domain, actual, pp_, s_); - -// auto expected = make_zavg_expected(mesh); -// TypeParam::verify(expected, actual); -// } - -GT_REGRESSION_TEST(unstructured_nabla, test_environment<>, backend_t) { - using float_t = typename TypeParam::float_t; - - auto mesh = TypeParam::fn_unstructured_mesh(); - auto actual = mesh.template make_storage>( - mesh.nvertices(), mesh.nlevels()); - - auto pp_ = mesh.make_const_storage(pp, mesh.nvertices(), mesh.nlevels()); - auto sign_ = mesh.template make_const_storage>( - sign, mesh.nvertices()); - auto vol_ = mesh.make_const_storage(vol, mesh.nvertices()); - auto s_ = mesh.template make_const_storage>( - s, mesh.nedges(), mesh.nlevels()); - - auto v2e_tbl = mesh.v2e_table(); - auto v2e_conn = - connectivity(v2e_tbl->get_const_target_ptr()); - - auto e2v_tbl = mesh.e2v_table(); - auto e2v_conn = - connectivity(e2v_tbl->get_const_target_ptr()); - auto vertex_domain = unstructured_domain({mesh.nvertices(), mesh.nlevels()}, - {}, v2e_conn, e2v_conn); - - generated::nabla_fencil(e2v_conn, v2e_conn)(backend_t{}, mesh.nvertices(), - mesh.nlevels(), actual, pp_, s_, - sign_, vol_); - - auto expected = make_expected(mesh); - TypeParam::verify(expected, actual); -} - -} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py deleted file mode 100644 index 7de56cb5bb..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py +++ /dev/null @@ -1,73 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import os -import subprocess -import sys -from pathlib import Path - -import pytest - - -def _source_dir(): - return Path(__file__).resolve().parent - - -def _build_dir(backend_dir: str): - return _source_dir() / f"build_{backend_dir}" - - -def _execute_cmake(backend_str: str): - build_dir = _build_dir(backend_str) - build_dir.mkdir(exist_ok=True) - cmake = ["cmake", "-B", build_dir, f"-DBACKEND={backend_str}"] - subprocess.run(cmake, cwd=_source_dir(), check=True, stderr=sys.stderr) - - -def _get_available_cpu_count(): - if hasattr(os, "sched_getaffinity"): - return len(os.sched_getaffinity(0)) - return os.cpu_count() - - -def _execute_build(backend_str: str): - build = [ - "cmake", - "--build", - _build_dir(backend_str), - "--parallel", - str(_get_available_cpu_count()), - ] - subprocess.run(build, check=True) - - -def _execute_ctest(backend_str: str): - ctest = "ctest" - subprocess.run(ctest, cwd=_build_dir(backend_str), check=True) - - -backends = ["naive"] -try: - import cupy # TODO actually cupy is not the requirement but a CUDA compiler... - - backends.append("gpu") -except ImportError: - pass - - -@pytest.mark.parametrize("backend_str", backends) -def test_driver_cpp_backends(backend_str): - _execute_cmake(backend_str) - _execute_build(backend_str) - _execute_ctest(backend_str) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py deleted file mode 100644 index 9755774fd0..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ /dev/null @@ -1,78 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import sys - -import gt4py.next as gtx -from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fundef -from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors.runners.gtfn import run_gtfn - - -IDim = gtx.Dimension("IDim") -JDim = gtx.Dimension("JDim") -KDim = gtx.Dimension("KDim") - - -@fundef -def tridiag_forward(state, a, b, c, d): - return make_tuple( - deref(c) / (deref(b) - deref(a) * tuple_get(0, state)), - (deref(d) - deref(a) * tuple_get(1, state)) / (deref(b) - deref(a) * tuple_get(0, state)), - ) - - -@fundef -def tridiag_backward(x_kp1, cpdp): - cpdpv = deref(cpdp) - cp = tuple_get(0, cpdpv) - dp = tuple_get(1, cpdpv) - return dp - cp * x_kp1 - - -@fundef -def solve_tridiag(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward, False, 0.0)(cpdp) - - -def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, isize), named_range(JDim, 0, jsize), named_range(KDim, 0, ksize) - ), - solve_tridiag, - x, - [a, b, c, d], - ) - - -if __name__ == "__main__": - if len(sys.argv) != 2: - raise RuntimeError(f"Usage: {sys.argv[0]} ") - output_file = sys.argv[1] - - prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) - offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} - generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( - prog, - offset_provider=offset_provider, - runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, - column_axis=KDim, - ) - - with open(output_file, "w+") as output: - output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp deleted file mode 100644 index c51e3d4c41..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include - -#include "fn_select_gt4py.hpp" -#include -#include GENERATED_FILE - -#include - -namespace { -using namespace gridtools; -using namespace fn; -using namespace literals; - -constexpr inline auto a = [](auto...) { return -1; }; -constexpr inline auto b = [](auto...) { return 3; }; -constexpr inline auto c = [](auto...) { return 1; }; -constexpr inline auto d = [](int ksize) { - return [kmax = ksize - 1](auto, auto, auto k) { - return k == 0 ? 4 : k == kmax ? 2 : 3; - }; -}; -constexpr inline auto expected = [](auto...) { return 1; }; -using backend_t = - fn_backend_t::values<32, 8, 1>>; - -GT_REGRESSION_TEST(fn_cartesian_tridiagonal_solve, vertical_test_environment<>, - backend_t) { - using float_t = typename TypeParam::float_t; - - auto wrap = [](auto &&storage) { - return sid::rename_numbered_dimensions( - std::forward(storage)); - }; - - auto x = TypeParam::make_storage(); - auto x_wrapped = wrap(x); - auto a_wrapped = wrap(TypeParam::make_const_storage(a)); - auto b_wrapped = wrap(TypeParam::make_const_storage(b)); - auto c_wrapped = wrap(TypeParam::make_const_storage(c)); - auto d_wrapped = wrap(TypeParam::make_const_storage(d(TypeParam::d(2)))); - auto comp = [&] { - generated::tridiagonal_solve_fencil(tuple{})( - backend_t(), at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), - at_key(TypeParam::fn_cartesian_sizes()), a_wrapped, - b_wrapped, c_wrapped, d_wrapped, x_wrapped); - }; - comp(); - TypeParam::verify(expected, x); -} - -} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index bcea9e0901..f79ef747b9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -20,7 +20,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor @fundef @@ -79,19 +79,9 @@ def naive_lap(inp): @pytest.mark.uses_origin -def test_anton_toy(program_processor, lift_mode): +def test_anton_toy(program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn.executor, - gtfn.run_gtfn_imperative.executor, - gtfn.run_gtfn_with_temporaries.executor, - ]: - from gt4py.next.iterator import transforms - - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("TODO: issue with temporaries that crashes the application") - shape = [5, 7, 9] rng = np.random.default_rng() inp = gtx.as_field( @@ -102,9 +92,7 @@ def test_anton_toy(program_processor, lift_mode): out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) ref = naive_lap(inp) - run_processor( - fencil, program_processor, shape[0], shape[1], shape[2], out, inp, lift_mode=lift_mode - ) + run_processor(fencil, program_processor, shape[0], shape[1], shape[2], out, inp) if validate: assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index c7c8cf6c57..7f6caa9de0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -21,7 +21,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") @@ -78,7 +78,7 @@ def basic_stencils(request): @pytest.mark.uses_origin -def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): +def test_basic_column_stencils(program_processor, basic_stencils): program_processor, validate = program_processor stencil, ref_fun, inp_fun = basic_stencils @@ -99,7 +99,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): out=out, offset_provider={"I": IDim, "K": KDim}, column_axis=KDim, - lift_mode=lift_mode, ) if validate: @@ -153,7 +152,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): ], ) @pytest.mark.uses_tuple_args -def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): +def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor k_size = 5 @@ -170,7 +169,6 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct out=out, offset_provider={"K": KDim}, column_axis=KDim, - lift_mode=lift_mode, ) if validate: @@ -201,7 +199,7 @@ def ksum_fencil(i_size, k_start, k_end, inp, out): "kstart, reference", [(0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]]))], ) -def test_ksum_scan(program_processor, lift_mode, kstart, reference): +def test_ksum_scan(program_processor, kstart, reference): program_processor, validate = program_processor shape = [1, 7] inp = gtx.as_field([IDim, KDim], np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) @@ -216,7 +214,6 @@ def test_ksum_scan(program_processor, lift_mode, kstart, reference): inp, out, offset_provider={"I": IDim, "K": KDim}, - lift_mode=lift_mode, ) if validate: @@ -238,7 +235,7 @@ def ksum_back_fencil(i_size, k_size, inp, out): ) -def test_ksum_back_scan(program_processor, lift_mode): +def test_ksum_back_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] inp = gtx.as_field([IDim, KDim], np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) @@ -254,7 +251,6 @@ def test_ksum_back_scan(program_processor, lift_mode): inp, out, offset_provider={"I": IDim, "K": KDim}, - lift_mode=lift_mode, ) if validate: @@ -300,7 +296,7 @@ def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): ), ], ) -def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): +def test_kdoublesum_scan(program_processor, kstart, reference): program_processor, validate = program_processor pytest.xfail("structured dtype input/output currently unsupported") shape = [1, 7] @@ -321,7 +317,6 @@ def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): inp1, out, offset_provider={"I": IDim, "K": KDim}, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index d29ef68d4e..f83430ef9f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -42,7 +42,7 @@ assert_close, nabla_setup, ) -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor Vertex = gtx.Dimension("Vertex") @@ -116,7 +116,7 @@ def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): @pytest.mark.requires_atlas -def test_compute_zavgS(program_processor, lift_mode): +def test_compute_zavgS(program_processor): program_processor, validate = program_processor setup = nabla_setup() @@ -137,7 +137,6 @@ def test_compute_zavgS(program_processor, lift_mode): pp, S_MXX, offset_provider={"E2V": e2v}, - lift_mode=lift_mode, ) if validate: @@ -152,7 +151,6 @@ def test_compute_zavgS(program_processor, lift_mode): pp, S_MYY, offset_provider={"E2V": e2v}, - lift_mode=lift_mode, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -165,7 +163,7 @@ def compute_zavgS2_fencil(n_edges, out, pp, S_M): @pytest.mark.requires_atlas -def test_compute_zavgS2(program_processor, lift_mode): +def test_compute_zavgS2(program_processor): program_processor, validate = program_processor setup = nabla_setup() @@ -190,7 +188,6 @@ def test_compute_zavgS2(program_processor, lift_mode): pp, S, offset_provider={"E2V": e2v}, - lift_mode=lift_mode, ) if validate: @@ -202,10 +199,9 @@ def test_compute_zavgS2(program_processor, lift_mode): @pytest.mark.requires_atlas -def test_nabla(program_processor, lift_mode): +def test_nabla(program_processor): program_processor, validate = program_processor - if lift_mode != LiftMode.FORCE_INLINE: - pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") + setup = nabla_setup() sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) @@ -234,7 +230,6 @@ def test_nabla(program_processor, lift_mode): sign, vol, offset_provider={"E2V": e2v, "V2E": v2e}, - lift_mode=lift_mode, ) if validate: @@ -255,7 +250,7 @@ def nabla2(n_nodes, out, pp, S, sign, vol): @pytest.mark.requires_atlas -def test_nabla2(program_processor, lift_mode): +def test_nabla2(program_processor): program_processor, validate = program_processor setup = nabla_setup() @@ -274,7 +269,9 @@ def test_nabla2(program_processor, lift_mode): AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 ) - nabla2( + run_processor( + nabla2, + program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), pp, @@ -282,8 +279,6 @@ def test_nabla2(program_processor, lift_mode): sign, vol, offset_provider={"E2V": e2v, "V2E": v2e}, - program_processor=program_processor, - lift_mode=lift_mode, ) if validate: @@ -334,10 +329,9 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ @pytest.mark.requires_atlas -def test_nabla_sign(program_processor, lift_mode): +def test_nabla_sign(program_processor): program_processor, validate = program_processor - if lift_mode != LiftMode.FORCE_INLINE: - pytest.xfail("test is broken due to bad lift semantics in iterator IR") + setup = nabla_setup() is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) @@ -368,7 +362,6 @@ def test_nabla_sign(program_processor, lift_mode): gtx.index_field(Vertex), is_pole_edge, offset_provider={"E2V": e2v, "V2E": v2e}, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 5d369c3a8f..3abdd7cd5a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -24,7 +24,7 @@ from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( hdiff_reference, ) -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") @@ -72,18 +72,8 @@ def hdiff(inp, coeff, out, x, y): @pytest.mark.uses_origin -def test_hdiff(hdiff_reference, program_processor, lift_mode): +def test_hdiff(hdiff_reference, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn.executor, - gtfn.run_gtfn_imperative.executor, - gtfn.run_gtfn_with_temporaries.executor, - ]: - # TODO(tehrengruber): check if still true - from gt4py.next.iterator import transforms - - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("Temporaries are not compatible with origins.") inp, coeff, out = hdiff_reference shape = (out.shape[0], out.shape[1]) @@ -92,9 +82,7 @@ def test_hdiff(hdiff_reference, program_processor, lift_mode): coeff_s = gtx.as_field([IDim, JDim], coeff[:, :, 0]) out_s = gtx.as_field([IDim, JDim], np.zeros_like(coeff[:, :, 0])) - run_processor( - hdiff, program_processor, inp_s, coeff_s, out_s, shape[0], shape[1], lift_mode=lift_mode - ) + run_processor(hdiff, program_processor, inp_s, coeff_s, out_s, shape[0], shape[1]) if validate: assert np.allclose(out[:, :, 0], out_s.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 820e9415bc..4eb07e27b9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -23,7 +23,7 @@ from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor @fundef @@ -110,23 +110,9 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) @pytest.mark.uses_lift_expressions -def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): +def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor - if ( - program_processor - in [ - gtfn.run_gtfn.executor, - gtfn.run_gtfn_imperative.executor, - gtfn.run_gtfn_with_temporaries.executor, - gtfn_formatters.format_cpp, - ] - and lift_mode == LiftMode.FORCE_INLINE - ): - pytest.skip("gtfn does only support lifted scans when using temporaries") - if ( - program_processor == gtfn.run_gtfn_with_temporaries.executor - or lift_mode == LiftMode.USE_TEMPORARIES - ): + if program_processor == gtfn.run_gtfn_with_temporaries.executor: pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape @@ -150,7 +136,6 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): x_s, offset_provider={}, column_axis=KDim, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 714e568b8f..dfd54debb6 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -48,7 +48,7 @@ v2e_arr, v2v_arr, ) -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings @@ -84,7 +84,7 @@ def sum_edges_to_vertices_reduce(in_edges): "stencil", [sum_edges_to_vertices, sum_edges_to_vertices_list_get_neighbors, sum_edges_to_vertices_reduce], ) -def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): +def test_sum_edges_to_vertices(program_processor, stencil): program_processor, validate = program_processor inp = edge_index_field() out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -96,7 +96,6 @@ def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -107,7 +106,7 @@ def map_neighbors(in_edges): return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) -def test_map_neighbors(program_processor, lift_mode): +def test_map_neighbors(program_processor): program_processor, validate = program_processor inp = edge_index_field() out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -119,7 +118,6 @@ def test_map_neighbors(program_processor, lift_mode): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -131,7 +129,7 @@ def map_make_const_list(in_edges): @pytest.mark.uses_constant_fields -def test_map_make_const_list(program_processor, lift_mode): +def test_map_make_const_list(program_processor): program_processor, validate = program_processor inp = edge_index_field() out = gtx.as_field([Vertex], np.zeros([9], inp.dtype)) @@ -143,7 +141,6 @@ def test_map_make_const_list(program_processor, lift_mode): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -154,7 +151,7 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) -def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor, lift_mode): +def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor): program_processor, validate = program_processor inp = vertex_index_field() out = gtx.as_field([Cell], np.zeros([9], dtype=inp.dtype)) @@ -169,7 +166,6 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), }, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -180,7 +176,7 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) -def test_sparse_input_field(program_processor, lift_mode): +def test_sparse_input_field(program_processor): program_processor, validate = program_processor non_sparse = gtx.as_field([Edge], np.zeros(18, dtype=np.int32)) @@ -196,14 +192,13 @@ def test_sparse_input_field(program_processor, lift_mode): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) -def test_sparse_input_field_v2v(program_processor, lift_mode): +def test_sparse_input_field_v2v(program_processor): program_processor, validate = program_processor non_sparse = gtx.as_field([Edge], np.zeros(18, dtype=np.int32)) @@ -222,7 +217,6 @@ def test_sparse_input_field_v2v(program_processor, lift_mode): "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), }, - lift_mode=lift_mode, ) if validate: @@ -235,7 +229,7 @@ def slice_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields -def test_slice_sparse(program_processor, lift_mode): +def test_slice_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -248,7 +242,6 @@ def test_slice_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -261,7 +254,7 @@ def slice_twice_sparse_stencil(sparse): @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.") -def test_slice_twice_sparse(program_processor, lift_mode): +def test_slice_twice_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim, V2VDim], v2v_arr[v2v_arr]) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -273,7 +266,6 @@ def test_slice_twice_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -286,7 +278,7 @@ def shift_sliced_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields -def test_shift_sliced_sparse(program_processor, lift_mode): +def test_shift_sliced_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -299,7 +291,6 @@ def test_shift_sliced_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -312,7 +303,7 @@ def slice_shifted_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields -def test_slice_shifted_sparse(program_processor, lift_mode): +def test_slice_shifted_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -325,7 +316,6 @@ def test_slice_shifted_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -342,7 +332,7 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) -def test_lift(program_processor, lift_mode): +def test_lift(program_processor): program_processor, validate = program_processor inp = vertex_index_field() out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -354,7 +344,6 @@ def test_lift(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -366,7 +355,7 @@ def sparse_shifted_stencil(inp): @pytest.mark.uses_sparse_fields -def test_shift_sparse_input_field(program_processor, lift_mode): +def test_shift_sparse_input_field(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -378,7 +367,6 @@ def test_shift_sparse_input_field(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -396,7 +384,7 @@ def shift_sparse_stencil2(inp): @pytest.mark.uses_sparse_fields -def test_shift_sparse_input_field2(program_processor, lift_mode): +def test_shift_sparse_input_field2(program_processor): program_processor, validate = program_processor if program_processor in [ gtfn.run_gtfn, @@ -423,7 +411,6 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): inp, out=out1, offset_provider=offset_provider, - lift_mode=lift_mode, ) run_processor( shift_sparse_stencil2[domain], @@ -431,7 +418,6 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): inp_sparse, out=out2, offset_provider=offset_provider, - lift_mode=lift_mode, ) if validate: @@ -448,10 +434,8 @@ def sum_(a, b): @pytest.mark.uses_sparse_fields @pytest.mark.uses_reduction_with_only_sparse_fields -def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): +def test_sparse_shifted_stencil_reduce(program_processor): program_processor, validate = program_processor - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -472,7 +456,6 @@ def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index c9406884e6..84a2d459e5 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -35,18 +35,6 @@ import next_tests -@pytest.fixture( - params=[ - transforms.LiftMode.FORCE_INLINE, - transforms.LiftMode.USE_TEMPORARIES, - transforms.LiftMode.SIMPLE_HEURISTIC, - ], - ids=lambda p: f"lift_mode={p.name}", -) -def lift_mode(request): - return request.param - - OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append((next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True)) @@ -62,6 +50,7 @@ def lift_mode(request): params=[ (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), From 09467834e116bc22c123f0b16b7085d42e70f5d2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 15:00:53 +0200 Subject: [PATCH 37/61] fix doctests --- src/gt4py/next/constructors.py | 12 ++++-------- .../ffront_tests/test_ffront_fvm_nabla.py | 6 ++++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 43d9cb81b9..18a89ec07a 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -67,9 +67,8 @@ def empty( Initialize a field in one dimension with a backend and a range domain: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> a = gtx.empty({IDim: range(3, 10)}, allocator=roundtrip.backend) + >>> a = gtx.empty({IDim: range(3, 10)}, allocator=gtx.itir_python) >>> a.shape (7,) @@ -109,9 +108,8 @@ def zeros( Examples: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> gtx.zeros({IDim: range(3, 10)}, allocator=roundtrip.backend).ndarray + >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ field = empty( @@ -137,9 +135,8 @@ def ones( Examples: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> gtx.ones({IDim: range(3, 10)}, allocator=roundtrip.backend).ndarray + >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ field = empty( @@ -171,9 +168,8 @@ def full( Examples: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> gtx.full({IDim: 3}, 5, allocator=roundtrip.backend).ndarray + >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ field = empty( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index aeed607b01..bdb50f27ff 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -87,7 +87,9 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False ) - compute_zavgS.with_backend(executor)(pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v}) + compute_zavgS.with_backend(exec_alloc_descriptor)( + pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) assert_close(388241977.58389181, np.max(zavgS.asnumpy())) @@ -113,7 +115,7 @@ def test_ffront_nabla(exec_alloc_descriptor): atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 ) - pnabla.with_backend(executor)( + pnabla.with_backend(exec_alloc_descriptor)( pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} ) From f93da09fa0a357146cb9ad42e63506b6d539f341 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 17:29:50 +0200 Subject: [PATCH 38/61] fix tests --- docs/user/next/QuickstartGuide.md | 21 ++++++++++++++----- .../iterator_tests/test_anton_toy.py | 5 +++++ .../iterator_tests/test_vertical_advection.py | 6 ++++++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index f8ff64a980..ee39831d45 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -77,7 +77,7 @@ array_of_ones_numpy = np.ones((grid_shape[0], grid_shape[1])) field_of_ones = gtx.ones( domain={I: range(grid_shape[0]), J: range(grid_shape[0])}, dtype=np.float64, - allocator=gtx.program_processors.runners.roundtrip.backend + allocator=gtx.itir_python ) ``` @@ -150,10 +150,21 @@ The sign of the edge difference in the sum of the pseudo-laplacian is always suc This section approaches the pseudo-laplacian by introducing the required APIs progressively through the following subsections: -- [Defining the mesh and the connectivities (adjacencies) between cells and edges](#Defining-the-mesh-and-its-connectivities) -- [Using connectivities in field operators](#Using-connectivities-in-field-operators) -- [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) -- [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) +- [Getting started with the GT4Py declarative frontend](#getting-started-with-the-gt4py-declarative-frontend) + - [Installation](#installation) + - [Programming guide](#programming-guide) + - [Key concepts and application structure](#key-concepts-and-application-structure) + - [Importing features](#importing-features) + - [Fields](#fields) + - [Field operators](#field-operators) + - [Programs](#programs) + - [Composing field operators and programs](#composing-field-operators-and-programs) + - [Operations on unstructured meshes](#operations-on-unstructured-meshes) + - [Defining the mesh and its connectivities](#defining-the-mesh-and-its-connectivities) + - [Using connectivities in field operators](#using-connectivities-in-field-operators) + - [Using reductions on connected mesh elements](#using-reductions-on-connected-mesh-elements) + - [Using conditionals on fields](#using-conditionals-on-fields) + - [Implementing the pseudo-laplacian](#implementing-the-pseudo-laplacian) +++ diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index f79ef747b9..445255f391 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -82,6 +82,11 @@ def naive_lap(inp): def test_anton_toy(program_processor): program_processor, validate = program_processor + if program_processor in [ + gtfn.run_gtfn_with_temporaries.executor, + ]: + pytest.xfail("TODO: issue with temporaries that crashes the application") + shape = [5, 7, 9] rng = np.random.default_rng() inp = gtx.as_field( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 4eb07e27b9..921d1ae116 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -112,6 +112,12 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.uses_lift_expressions def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor + if program_processor in [ + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn_formatters.format_cpp, + ]: + pytest.skip("gtfn does only support lifted scans when using temporaries") if program_processor == gtfn.run_gtfn_with_temporaries.executor: pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference From 97663bde6c2826662152ba62539d6f8133b38018 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 19:50:57 +0200 Subject: [PATCH 39/61] undo quickstart changes --- docs/user/next/QuickstartGuide.md | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index ee39831d45..f8ff64a980 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -77,7 +77,7 @@ array_of_ones_numpy = np.ones((grid_shape[0], grid_shape[1])) field_of_ones = gtx.ones( domain={I: range(grid_shape[0]), J: range(grid_shape[0])}, dtype=np.float64, - allocator=gtx.itir_python + allocator=gtx.program_processors.runners.roundtrip.backend ) ``` @@ -150,21 +150,10 @@ The sign of the edge difference in the sum of the pseudo-laplacian is always suc This section approaches the pseudo-laplacian by introducing the required APIs progressively through the following subsections: -- [Getting started with the GT4Py declarative frontend](#getting-started-with-the-gt4py-declarative-frontend) - - [Installation](#installation) - - [Programming guide](#programming-guide) - - [Key concepts and application structure](#key-concepts-and-application-structure) - - [Importing features](#importing-features) - - [Fields](#fields) - - [Field operators](#field-operators) - - [Programs](#programs) - - [Composing field operators and programs](#composing-field-operators-and-programs) - - [Operations on unstructured meshes](#operations-on-unstructured-meshes) - - [Defining the mesh and its connectivities](#defining-the-mesh-and-its-connectivities) - - [Using connectivities in field operators](#using-connectivities-in-field-operators) - - [Using reductions on connected mesh elements](#using-reductions-on-connected-mesh-elements) - - [Using conditionals on fields](#using-conditionals-on-fields) - - [Implementing the pseudo-laplacian](#implementing-the-pseudo-laplacian) +- [Defining the mesh and the connectivities (adjacencies) between cells and edges](#Defining-the-mesh-and-its-connectivities) +- [Using connectivities in field operators](#Using-connectivities-in-field-operators) +- [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) +- [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) +++ From 3297f7b207795473007ad433e60d18d8ed3aa0a2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 19:56:31 +0200 Subject: [PATCH 40/61] undo delete cpp_backend_tests --- .../cpp_backend_tests/.gitignore | 1 + .../cpp_backend_tests/CMakeLists.txt | 79 ++++++++++++ .../cpp_backend_tests/anton_lap.py | 77 +++++++++++ .../cpp_backend_tests/anton_lap_driver.cpp | 40 ++++++ .../cmake/FetchGoogleTest.cmake | 14 ++ .../cpp_backend_tests/copy_stencil.py | 56 ++++++++ .../cpp_backend_tests/copy_stencil_driver.cpp | 42 ++++++ .../copy_stencil_field_view.py | 55 ++++++++ .../copy_stencil_field_view_driver.cpp | 52 ++++++++ .../cpp_backend_tests/fn_select_gt4py.hpp | 78 +++++++++++ .../cpp_backend_tests/fvm_nabla.py | 106 +++++++++++++++ .../cpp_backend_tests/fvm_nabla_driver.cpp | 122 ++++++++++++++++++ .../cpp_backend_tests/test_driver.py | 73 +++++++++++ .../cpp_backend_tests/tridiagonal_solve.py | 78 +++++++++++ .../tridiagonal_solve_driver.cpp | 54 ++++++++ 15 files changed, 927 insertions(+) create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore new file mode 100644 index 0000000000..4ecaff7fe9 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/.gitignore @@ -0,0 +1 @@ +build_*/ \ No newline at end of file diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt new file mode 100644 index 0000000000..0a1b19d2a9 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/CMakeLists.txt @@ -0,0 +1,79 @@ +cmake_minimum_required(VERSION 3.18.1) + +project(cpp_backend_tests_driver LANGUAGES CXX) + +set(BACKEND NOTFOUND CACHE STRING "fn backend") +set_property(CACHE BACKEND PROPERTY STRINGS "NOTFOUND;naive;gpu") + +if(NOT BACKEND) + message(FATAL_ERROR "No backend selected") +else() + message(STATUS "Testing backend \"${BACKEND}\"") +endif() + +if(BACKEND STREQUAL "gpu") + enable_language(CUDA) + set(is_gpu ON) +endif() +string(TOUPPER ${BACKEND} backend_upper_case) + +include(FetchContent) +FetchContent_Declare(GridTools + GIT_REPOSITORY https://github.com/GridTools/gridtools.git + GIT_TAG master +) +FetchContent_MakeAvailable(GridTools) + +function(generate_computation) + set(options) + set(oneValueArgs NAME SRC_FILE GENERATED_FILENAME IMPERATIVE) + set(multiValueArgs) + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + add_custom_command(OUTPUT ${ARG_GENERATED_FILENAME} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/${ARG_SRC_FILE} ${ARG_GENERATED_FILENAME} ${ARG_IMPERATIVE} + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${ARG_INPUT}) + add_custom_target(generated_${ARG_NAME} DEPENDS ${ARG_GENERATED_FILENAME}) +endfunction() + +add_library(regression_main ${gridtools_SOURCE_DIR}/tests/src/regression_main.cpp) +target_include_directories(regression_main PUBLIC ${gridtools_SOURCE_DIR}/tests/include) +target_link_libraries(regression_main PUBLIC gtest gmock gridtools) + +function(add_fn_codegen_test) + set(options) + set(oneValueArgs NAME SRC_FILE DRIVER_FILE IMPERATIVE) + set(multiValueArgs) + cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(STENCIL_IMPL_SOURCE ${CMAKE_CURRENT_BINARY_DIR}/generated_${ARG_NAME}.hpp) + generate_computation(NAME ${ARG_NAME} SRC_FILE ${ARG_SRC_FILE} GENERATED_FILENAME ${STENCIL_IMPL_SOURCE} IMPERATIVE ${ARG_IMPERATIVE}) + add_executable(${ARG_NAME} ${ARG_DRIVER_FILE}) + target_link_libraries(${ARG_NAME} fn_${BACKEND} regression_main) + target_compile_definitions(${ARG_NAME} PRIVATE GT_FN_${backend_upper_case}) + target_compile_definitions(${ARG_NAME} PRIVATE GT_FN_BACKEND=${BACKEND}) + target_compile_definitions(${ARG_NAME} PRIVATE GENERATED_FILE=\"${STENCIL_IMPL_SOURCE}\") + if(is_gpu) + gridtools_setup_target(${ARG_NAME} CUDA_ARCH sm_60) #TODO + endif() + add_dependencies(${ARG_NAME} generated_${ARG_NAME}) + + add_test(NAME ${ARG_NAME} COMMAND $) +endfunction() + +include(CTest) +if(BUILD_TESTING) + find_package(Python3 COMPONENTS Interpreter REQUIRED) + + include(cmake/FetchGoogleTest.cmake) + fetch_googletest() + + add_fn_codegen_test(NAME copy_stencil SRC_FILE copy_stencil.py DRIVER_FILE copy_stencil_driver.cpp) + add_fn_codegen_test(NAME copy_stencil_field_view SRC_FILE copy_stencil_field_view.py DRIVER_FILE copy_stencil_field_view_driver.cpp) + add_fn_codegen_test(NAME anton_lap SRC_FILE anton_lap.py DRIVER_FILE anton_lap_driver.cpp) + add_fn_codegen_test(NAME fvm_nabla_fun SRC_FILE fvm_nabla.py DRIVER_FILE fvm_nabla_driver.cpp IMPERATIVE FALSE) + add_fn_codegen_test(NAME fvm_nabla_imp SRC_FILE fvm_nabla.py DRIVER_FILE fvm_nabla_driver.cpp IMPERATIVE TRUE) + add_fn_codegen_test(NAME tridiagonal_solve SRC_FILE tridiagonal_solve.py DRIVER_FILE tridiagonal_solve_driver.cpp) +endif() + + diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py new file mode 100644 index 0000000000..5af4605988 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py @@ -0,0 +1,77 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import sys + +import gt4py.next as gtx +from gt4py.next.iterator.builtins import * +from gt4py.next.iterator.runtime import closure, fundef, offset +from gt4py.next.iterator.tracing import trace_fencil_definition +from gt4py.next.program_processors.runners.gtfn import run_gtfn + + +@fundef +def ldif(d): + return lambda inp: deref(shift(d, -1)(inp)) - deref(inp) + + +@fundef +def rdif(d): + return lambda inp: ldif(d)(shift(d, 1)(inp)) + + +@fundef +def dif2(d): + return lambda inp: ldif(d)(lift(rdif(d))(inp)) + + +i = offset("i") +j = offset("j") + + +@fundef +def lap(inp): + return dif2(i)(inp) + dif2(j)(inp) + + +IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") +KDim = gtx.Dimension("KDim") + + +def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): + closure( + cartesian_domain( + named_range(IDim, i_off, i_size + i_off), + named_range(JDim, j_off, j_size + j_off), + named_range(KDim, k_off, k_size + k_off), + ), + lap, + out, + [inp], + ) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + raise RuntimeError(f"Usage: {sys.argv[0]} ") + output_file = sys.argv[1] + + prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None + ) + + with open(output_file, "w+") as output: + output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp new file mode 100644 index 0000000000..e2dbcbb6ce --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap_driver.cpp @@ -0,0 +1,40 @@ +#include + +#include "fn_select_gt4py.hpp" +#include +#include GENERATED_FILE + +#include + +namespace { +using namespace gridtools; +using namespace fn; +using namespace literals; + +constexpr inline auto in = [](auto... indices) { return (... + indices); }; + +using backend_t = + fn_backend_t::values<32, 8, 1>>; + +GT_REGRESSION_TEST(fn_lap, test_environment<1>, backend_t) { + auto actual = TypeParam::make_storage(); + + auto expected = [&](auto i, auto j, auto k) { + return in(i + 1, j, k) + in(i - 1, j, k) + in(i, j + 1, k) + + in(i, j - 1, k) - 4 * in(i, j, k); + }; + + generated::lap_fencil(tuple{})( + backend_t(), at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), 1, 1, 0, + sid::rename_numbered_dimensions(actual), + sid::rename_numbered_dimensions( + TypeParam::make_const_storage(in))); + + TypeParam::verify(expected, actual); +} +} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake new file mode 100644 index 0000000000..983cce8594 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/cmake/FetchGoogleTest.cmake @@ -0,0 +1,14 @@ +function(fetch_googletest) + # The gtest library needs to be built as static library to avoid RPATH issues + set(BUILD_SHARED_LIBS OFF) + + include(FetchContent) + option(INSTALL_GTEST OFF) + mark_as_advanced(INSTALL_GTEST) + FetchContent_Declare( + GoogleTest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.11.0 + ) + FetchContent_MakeAvailable(GoogleTest) +endfunction() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py new file mode 100644 index 0000000000..3e8b88ac66 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py @@ -0,0 +1,56 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import sys + +import gt4py.next as gtx +from gt4py.next.iterator.builtins import * +from gt4py.next.iterator.runtime import closure, fundef +from gt4py.next.iterator.tracing import trace_fencil_definition +from gt4py.next.program_processors.runners.gtfn import run_gtfn + + +IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") +KDim = gtx.Dimension("KDim") + + +@fundef +def copy_stencil(inp): + return deref(inp) + + +def copy_fencil(isize, jsize, ksize, inp, out): + closure( + cartesian_domain( + named_range(IDim, 0, isize), named_range(JDim, 0, jsize), named_range(KDim, 0, ksize) + ), + copy_stencil, + out, + [inp], + ) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + raise RuntimeError(f"Usage: {sys.argv[0]} ") + output_file = sys.argv[1] + + prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) + + with open(output_file, "w+") as output: + output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp new file mode 100644 index 0000000000..e226b78907 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_driver.cpp @@ -0,0 +1,42 @@ +#include + +#include "fn_select_gt4py.hpp" +#include +#include GENERATED_FILE + +#include + +namespace { +using namespace gridtools; +using namespace fn; +using namespace literals; + +constexpr inline auto in = [](auto... indices) { return (... + indices); }; + +using backend_t = + fn_backend_t::values<32, 8, 1>>; + +GT_REGRESSION_TEST(fn_cartesian_copy, test_environment<>, backend_t) { + auto out = TypeParam::make_storage(); + auto out_wrapped = + sid::rename_numbered_dimensions(out); + + auto in_wrapped = + sid::rename_numbered_dimensions( + TypeParam::make_const_storage(in)); + auto comp = [&] { + generated::copy_fencil(tuple{})( + backend_t{}, at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), in_wrapped, + out_wrapped); + }; + comp(); + + TypeParam::verify(in, out); +} + +} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py new file mode 100644 index 0000000000..fdc57449ee --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py @@ -0,0 +1,55 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import sys + +from numpy import float64 + +import gt4py.next as gtx +from gt4py.next import Field, field_operator, program +from gt4py.next.program_processors.runners.gtfn import run_gtfn + + +IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") +KDim = gtx.Dimension("KDim") + + +@field_operator +def copy_stencil(inp: Field[[IDim, JDim, KDim], float64]) -> Field[[IDim, JDim, KDim], float64]: + return inp + + +@program +def copy_program( + inp: Field[[IDim, JDim, KDim], float64], + out: Field[[IDim, JDim, KDim], float64], + out2: Field[[IDim, JDim, KDim], float64], +): + copy_stencil(inp, out=out) + copy_stencil(inp, out=out2) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + raise RuntimeError(f"Usage: {sys.argv[0]} ") + output_file = sys.argv[1] + + prog = copy_program.itir + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) + + with open(output_file, "w+") as output: + output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp new file mode 100644 index 0000000000..19bd9fc3b9 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view_driver.cpp @@ -0,0 +1,52 @@ +#include + +#include "fn_select_gt4py.hpp" +#include +#include GENERATED_FILE + +#include + +namespace { +using namespace gridtools; +using namespace fn; +using namespace literals; + +constexpr inline auto in = [](auto... indices) { return (... + indices); }; +using backend_t = + fn_backend_t::values<32, 8, 1>>; + +GT_REGRESSION_TEST(fn_cartesian_copy, test_environment<>, backend_t) { + auto out = TypeParam::make_storage(); + auto out_wrapped = + sid::rename_numbered_dimensions(out); + auto out2 = TypeParam::make_storage(); + auto out2_wrapped = + sid::rename_numbered_dimensions(out2); + + auto in_wrapped = + sid::rename_numbered_dimensions( + TypeParam::make_const_storage(in)); + auto comp = [&] { + generated::copy_program(tuple{})( + backend_t{}, in_wrapped, out_wrapped, out2_wrapped, + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes())); + }; + comp(); + + TypeParam::verify(in, out); + TypeParam::verify(in, out2); +} + +} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp new file mode 100644 index 0000000000..a8f664c334 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fn_select_gt4py.hpp @@ -0,0 +1,78 @@ +#pragma once + +// COPIED AND SLIGHTLY MODIFIED FROM gritools/test/include/fn_select.hpp +// The gridtools fn_select assumes that dimensions are named +// integral_constant, the GT4Py generated code uses structs as dimension +// identifiers. + +#include + +#include + +// fn backend +#if defined(GT_FN_NAIVE) +#ifndef GT_STENCIL_NAIVE +#define GT_STENCIL_NAIVE +#endif +#ifndef GT_STORAGE_CPU_KFIRST +#define GT_STORAGE_CPU_KFIRST +#endif +#ifndef GT_TIMER_DUMMY +#define GT_TIMER_DUMMY +#endif +#include +namespace { +template using fn_backend_t = gridtools::fn::backend::naive; +} +#elif defined(GT_FN_GPU) +#ifndef GT_STENCIL_GPU +#define GT_STENCIL_GPU +#endif +#ifndef GT_STORAGE_GPU +#define GT_STORAGE_GPU +#endif +#ifndef GT_TIMER_CUDA +#define GT_TIMER_CUDA +#endif +#include +namespace { +template +using fn_backend_t = gridtools::fn::backend::gpu; +} // namespace +#endif + +#include "stencil_select.hpp" +#include "storage_select.hpp" +#include "timer_select.hpp" +namespace { +template struct block_sizes_t { + template + using values = gridtools::meta::zip< + gridtools::meta::list, + gridtools::meta::list...>>; +}; +} // namespace +namespace gridtools::fn::backend { +namespace naive_impl_ { +template struct naive_with_threadpool; +template +storage::cpu_kfirst backend_storage_traits(naive_with_threadpool); +template +timer_dummy backend_timer_impl(naive_with_threadpool); +template +inline char const *backend_name(naive_with_threadpool const &) { + return "naive"; +} +} // namespace naive_impl_ + +namespace gpu_impl_ { +template struct gpu; +template +storage::gpu backend_storage_traits(gpu); +template timer_cuda backend_timer_impl(gpu); +template +inline char const *backend_name(gpu const &) { + return "gpu"; +} +} // namespace gpu_impl_ +} // namespace gridtools::fn::backend diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py new file mode 100644 index 0000000000..fe8b54f95c --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -0,0 +1,106 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import sys +from dataclasses import dataclass + +import gt4py.next as gtx +from gt4py.next.iterator.builtins import * +from gt4py.next.iterator.runtime import closure, fundef, offset +from gt4py.next.iterator.tracing import trace_fencil_definition +from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative + + +E2V = offset("E2V") +V2E = offset("V2E") + + +@fundef +def compute_zavgS(pp, S_M): + zavg = 0.5 * (deref(shift(E2V, 0)(pp)) + deref(shift(E2V, 1)(pp))) + return make_tuple(tuple_get(0, deref(S_M)) * zavg, tuple_get(1, deref(S_M)) * zavg) + + +@fundef +def tuple_dot_fun(acc, zavgS, sign): + return make_tuple( + tuple_get(0, acc) + tuple_get(0, zavgS) * sign, + tuple_get(1, acc) + tuple_get(1, zavgS) * sign, + ) + + +@fundef +def tuple_dot(a, b): + return reduce(tuple_dot_fun, make_tuple(0.0, 0.0))(a, b) + + +@fundef +def compute_pnabla(pp, S_M, sign, vol): + zavgS = lift(compute_zavgS)(pp, S_M) + pnabla_M = tuple_dot(neighbors(V2E, zavgS), deref(sign)) + return make_tuple(tuple_get(0, pnabla_M) / deref(vol), tuple_get(1, pnabla_M) / deref(vol)) + + +def zavgS_fencil(edge_domain, out, pp, S_M): + closure(edge_domain, compute_zavgS, out, [pp, S_M]) + + +Vertex = gtx.Dimension("Vertex") +K = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL) + + +def nabla_fencil(n_vertices, n_levels, out, pp, S_M, sign, vol): + closure( + unstructured_domain(named_range(Vertex, 0, n_vertices), named_range(K, 0, n_levels)), + compute_pnabla, + out, + [pp, S_M, sign, vol], + ) + + +@dataclass +class DummyConnectivity: + max_neighbors: int + has_skip_values: int + origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") + neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") + index_type: type[int] = int + + def mapped_index(_, __) -> int: + return 0 + + +if __name__ == "__main__": + if len(sys.argv) != 3: + raise RuntimeError(f"Usage: {sys.argv[0]} ") + output_file = sys.argv[1] + imperative = sys.argv[2].lower() == "true" + + if imperative: + backend = run_gtfn_imperative + else: + backend = run_gtfn + + # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils + prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) + offset_provider = { + "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True), + "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), + } + generated_code = backend.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider=offset_provider, column_axis=None + ) + + with open(output_file, "w+") as output: + output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp new file mode 100644 index 0000000000..f0112a1636 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla_driver.cpp @@ -0,0 +1,122 @@ +#include + +#include "fn_select_gt4py.hpp" +#include +#include GENERATED_FILE + +namespace { +using namespace gridtools; +using namespace fn; +using namespace literals; + +using backend_t = + fn_backend_t::values<32, 8>>; + +// copied from gridtools::fn test +constexpr inline auto pp = [](int vertex, int k) { return (vertex + k) % 19; }; +constexpr inline auto sign = [](int vertex) { + return array{0, 1, vertex % 2, 1, (vertex + 1) % 2, 0}; +}; +constexpr inline auto vol = [](int vertex) { return vertex % 13 + 1; }; +constexpr inline auto s = [](int edge, int k) { + return tuple((edge + k) % 17, (edge + k) % 7); +}; + +constexpr inline auto zavg = [](auto const &e2v) { + return [&e2v](int edge, int k) { + double tmp = 0.0; + for (int neighbor = 0; neighbor < 2; ++neighbor) + tmp += pp(e2v(edge)[neighbor], k); + tmp /= 2.0; + return tuple{tmp * get<0>(s(edge, k)), tmp * get<1>(s(edge, k))}; + }; +}; + +constexpr inline auto make_zavg_expected = [](auto const &mesh) { + return [e2v_table = mesh.e2v_table()](int edge, int k) { + auto e2v = e2v_table->const_host_view(); + return zavg(e2v)(edge, k); + }; +}; + +constexpr inline auto expected = [](auto const &v2e, auto const &e2v) { + return [&v2e, zavg = zavg(e2v)](int vertex, int k) { + auto res = tuple(0.0, 0.0); + for (int neighbor = 0; neighbor < 6; ++neighbor) { + int edge = v2e(vertex)[neighbor]; + if (edge != -1) { + get<0>(res) += get<0>(zavg(edge, k)) * sign(vertex)[neighbor]; + get<1>(res) += get<1>(zavg(edge, k)) * sign(vertex)[neighbor]; + } + } + get<0>(res) /= vol(vertex); + get<1>(res) /= vol(vertex); + return res; + }; +}; + +constexpr inline auto make_expected = [](auto const &mesh) { + return [v2e_table = mesh.v2e_table(), + e2v_table = mesh.e2v_table()](int vertex, int k) { + auto v2e = v2e_table->const_host_view(); + auto e2v = e2v_table->const_host_view(); + return expected(v2e, e2v)(vertex, k); + }; +}; + +// GT_REGRESSION_TEST(unstructured_zavg, test_environment<>, fn_backend_t) { +// using float_t = typename TypeParam::float_t; + +// auto mesh = TypeParam::fn_unstructured_mesh(); +// auto actual = mesh.template make_storage>( +// mesh.nedges(), mesh.nlevels()); + +// auto pp_ = mesh.make_const_storage(pp, mesh.nvertices(), mesh.nlevels()); +// auto s_ = mesh.template make_const_storage>( +// s, mesh.nedges(), mesh.nlevels()); + +// auto e2v_conn = +// connectivity(mesh.e2v_table()->get_const_target_ptr()); +// auto edge_domain = +// unstructured_domain({mesh.nedges(), mesh.nlevels()}, {}, e2v_conn); + +// generated::zavgS_fencil(fn_backend_t{}, edge_domain, actual, pp_, s_); + +// auto expected = make_zavg_expected(mesh); +// TypeParam::verify(expected, actual); +// } + +GT_REGRESSION_TEST(unstructured_nabla, test_environment<>, backend_t) { + using float_t = typename TypeParam::float_t; + + auto mesh = TypeParam::fn_unstructured_mesh(); + auto actual = mesh.template make_storage>( + mesh.nvertices(), mesh.nlevels()); + + auto pp_ = mesh.make_const_storage(pp, mesh.nvertices(), mesh.nlevels()); + auto sign_ = mesh.template make_const_storage>( + sign, mesh.nvertices()); + auto vol_ = mesh.make_const_storage(vol, mesh.nvertices()); + auto s_ = mesh.template make_const_storage>( + s, mesh.nedges(), mesh.nlevels()); + + auto v2e_tbl = mesh.v2e_table(); + auto v2e_conn = + connectivity(v2e_tbl->get_const_target_ptr()); + + auto e2v_tbl = mesh.e2v_table(); + auto e2v_conn = + connectivity(e2v_tbl->get_const_target_ptr()); + auto vertex_domain = unstructured_domain({mesh.nvertices(), mesh.nlevels()}, + {}, v2e_conn, e2v_conn); + + generated::nabla_fencil(e2v_conn, v2e_conn)(backend_t{}, mesh.nvertices(), + mesh.nlevels(), actual, pp_, s_, + sign_, vol_); + + auto expected = make_expected(mesh); + TypeParam::verify(expected, actual); +} + +} // namespace diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py new file mode 100644 index 0000000000..7de56cb5bb --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os +import subprocess +import sys +from pathlib import Path + +import pytest + + +def _source_dir(): + return Path(__file__).resolve().parent + + +def _build_dir(backend_dir: str): + return _source_dir() / f"build_{backend_dir}" + + +def _execute_cmake(backend_str: str): + build_dir = _build_dir(backend_str) + build_dir.mkdir(exist_ok=True) + cmake = ["cmake", "-B", build_dir, f"-DBACKEND={backend_str}"] + subprocess.run(cmake, cwd=_source_dir(), check=True, stderr=sys.stderr) + + +def _get_available_cpu_count(): + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + return os.cpu_count() + + +def _execute_build(backend_str: str): + build = [ + "cmake", + "--build", + _build_dir(backend_str), + "--parallel", + str(_get_available_cpu_count()), + ] + subprocess.run(build, check=True) + + +def _execute_ctest(backend_str: str): + ctest = "ctest" + subprocess.run(ctest, cwd=_build_dir(backend_str), check=True) + + +backends = ["naive"] +try: + import cupy # TODO actually cupy is not the requirement but a CUDA compiler... + + backends.append("gpu") +except ImportError: + pass + + +@pytest.mark.parametrize("backend_str", backends) +def test_driver_cpp_backends(backend_str): + _execute_cmake(backend_str) + _execute_build(backend_str) + _execute_ctest(backend_str) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py new file mode 100644 index 0000000000..9755774fd0 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -0,0 +1,78 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import sys + +import gt4py.next as gtx +from gt4py.next.iterator.builtins import * +from gt4py.next.iterator.runtime import closure, fundef +from gt4py.next.iterator.tracing import trace_fencil_definition +from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.program_processors.runners.gtfn import run_gtfn + + +IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") +KDim = gtx.Dimension("KDim") + + +@fundef +def tridiag_forward(state, a, b, c, d): + return make_tuple( + deref(c) / (deref(b) - deref(a) * tuple_get(0, state)), + (deref(d) - deref(a) * tuple_get(1, state)) / (deref(b) - deref(a) * tuple_get(0, state)), + ) + + +@fundef +def tridiag_backward(x_kp1, cpdp): + cpdpv = deref(cpdp) + cp = tuple_get(0, cpdpv) + dp = tuple_get(1, cpdpv) + return dp - cp * x_kp1 + + +@fundef +def solve_tridiag(a, b, c, d): + cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) + return scan(tridiag_backward, False, 0.0)(cpdp) + + +def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): + closure( + cartesian_domain( + named_range(IDim, 0, isize), named_range(JDim, 0, jsize), named_range(KDim, 0, ksize) + ), + solve_tridiag, + x, + [a, b, c, d], + ) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + raise RuntimeError(f"Usage: {sys.argv[0]} ") + output_file = sys.argv[1] + + prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) + offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, + offset_provider=offset_provider, + runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, + column_axis=KDim, + ) + + with open(output_file, "w+") as output: + output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp new file mode 100644 index 0000000000..c51e3d4c41 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve_driver.cpp @@ -0,0 +1,54 @@ +#include + +#include "fn_select_gt4py.hpp" +#include +#include GENERATED_FILE + +#include + +namespace { +using namespace gridtools; +using namespace fn; +using namespace literals; + +constexpr inline auto a = [](auto...) { return -1; }; +constexpr inline auto b = [](auto...) { return 3; }; +constexpr inline auto c = [](auto...) { return 1; }; +constexpr inline auto d = [](int ksize) { + return [kmax = ksize - 1](auto, auto, auto k) { + return k == 0 ? 4 : k == kmax ? 2 : 3; + }; +}; +constexpr inline auto expected = [](auto...) { return 1; }; +using backend_t = + fn_backend_t::values<32, 8, 1>>; + +GT_REGRESSION_TEST(fn_cartesian_tridiagonal_solve, vertical_test_environment<>, + backend_t) { + using float_t = typename TypeParam::float_t; + + auto wrap = [](auto &&storage) { + return sid::rename_numbered_dimensions( + std::forward(storage)); + }; + + auto x = TypeParam::make_storage(); + auto x_wrapped = wrap(x); + auto a_wrapped = wrap(TypeParam::make_const_storage(a)); + auto b_wrapped = wrap(TypeParam::make_const_storage(b)); + auto c_wrapped = wrap(TypeParam::make_const_storage(c)); + auto d_wrapped = wrap(TypeParam::make_const_storage(d(TypeParam::d(2)))); + auto comp = [&] { + generated::tridiagonal_solve_fencil(tuple{})( + backend_t(), at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), + at_key(TypeParam::fn_cartesian_sizes()), a_wrapped, + b_wrapped, c_wrapped, d_wrapped, x_wrapped); + }; + comp(); + TypeParam::verify(expected, x); +} + +} // namespace From f37b3726b058c7d9d5ba1ec2c1747eae582e6860 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Apr 2024 20:01:06 +0200 Subject: [PATCH 41/61] fix quickstart guide again --- docs/user/next/QuickstartGuide.md | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index f8ff64a980..ee39831d45 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -77,7 +77,7 @@ array_of_ones_numpy = np.ones((grid_shape[0], grid_shape[1])) field_of_ones = gtx.ones( domain={I: range(grid_shape[0]), J: range(grid_shape[0])}, dtype=np.float64, - allocator=gtx.program_processors.runners.roundtrip.backend + allocator=gtx.itir_python ) ``` @@ -150,10 +150,21 @@ The sign of the edge difference in the sum of the pseudo-laplacian is always suc This section approaches the pseudo-laplacian by introducing the required APIs progressively through the following subsections: -- [Defining the mesh and the connectivities (adjacencies) between cells and edges](#Defining-the-mesh-and-its-connectivities) -- [Using connectivities in field operators](#Using-connectivities-in-field-operators) -- [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) -- [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) +- [Getting started with the GT4Py declarative frontend](#getting-started-with-the-gt4py-declarative-frontend) + - [Installation](#installation) + - [Programming guide](#programming-guide) + - [Key concepts and application structure](#key-concepts-and-application-structure) + - [Importing features](#importing-features) + - [Fields](#fields) + - [Field operators](#field-operators) + - [Programs](#programs) + - [Composing field operators and programs](#composing-field-operators-and-programs) + - [Operations on unstructured meshes](#operations-on-unstructured-meshes) + - [Defining the mesh and its connectivities](#defining-the-mesh-and-its-connectivities) + - [Using connectivities in field operators](#using-connectivities-in-field-operators) + - [Using reductions on connected mesh elements](#using-reductions-on-connected-mesh-elements) + - [Using conditionals on fields](#using-conditionals-on-fields) + - [Implementing the pseudo-laplacian](#implementing-the-pseudo-laplacian) +++ From e242ab61aa9a108b57c742c9617aa94a70a07ea6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Apr 2024 09:35:23 +0200 Subject: [PATCH 42/61] remove runtime lift --- .../codegens/gtfn/gtfn_module.py | 20 ++----------------- .../program_processors/formatters/gtfn.py | 1 - .../runners/dace_iterator/__init__.py | 3 --- .../runners/dace_iterator/workflow.py | 17 +--------------- .../next/program_processors/runners/gtfn.py | 2 -- .../cpp_backend_tests/tridiagonal_solve.py | 1 - 6 files changed, 3 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d28e513093..52b9376b82 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -16,7 +16,6 @@ import dataclasses import functools -import warnings from typing import Any, Callable, Final, Optional import factory @@ -182,26 +181,13 @@ def _preprocess_program( self, program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], - runtime_lift_mode: Optional[LiftMode], ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries: - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode and runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{self.lift_mode!s}`, but " - f"overriden to be {runtime_lift_mode!s} at runtime.", - stacklevel=2, - ) - if not self.enable_itir_transforms: return program apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, - lift_mode=lift_mode, + lift_mode=self.lift_mode, offset_provider=offset_provider, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, @@ -228,9 +214,8 @@ def generate_stencil_source( program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], column_axis: Optional[common.Dimension], - runtime_lift_mode: Optional[LiftMode] = None, ) -> str: - new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + new_program = self._preprocess_program(program, offset_provider) program_itir = fencil_to_program.FencilToProgram().apply( new_program ) # TODO(havogt): should be removed after refactoring to combined IR @@ -278,7 +263,6 @@ def __call__( program, inp.kwargs["offset_provider"], inp.kwargs.get("column_axis", None), - inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 6c8d4478c2..632974a787 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -29,5 +29,4 @@ def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str program, offset_provider=kwargs.get("offset_provider", None), column_axis=kwargs.get("column_axis", None), - runtime_lift_mode=kwargs.get("lift_mode", None), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2b56dc0420..2bbf068d53 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -248,9 +248,6 @@ def build_sdfg_from_itir( load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically. - - Notes: - Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored. """ sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg" diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 96a40e7450..1a7a36b8c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -15,7 +15,6 @@ from __future__ import annotations import dataclasses -import warnings from typing import Callable, Optional, cast import dace @@ -62,22 +61,9 @@ def generate_sdfg( arg_types: list[ts.TypeSpec], offset_provider: dict[str, common.Dimension | common.Connectivity], column_axis: Optional[common.Dimension], - runtime_lift_mode: Optional[LiftMode] = None, ) -> dace.SDFG: on_gpu = True if self.device_type == core_defs.DeviceType.CUDA else False - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode and runtime_lift_mode != self.lift_mode: - warnings.warn( - f"DaCe Backend was configured for LiftMode `{self.lift_mode!s}`, but " - f"overriden to be {runtime_lift_mode!s} at runtime.", - stacklevel=2, - ) - return build_sdfg_from_itir( program, arg_types, @@ -85,7 +71,7 @@ def generate_sdfg( auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, - lift_mode=lift_mode, + lift_mode=self.lift_mode, symbolic_domain_sizes=self.symbolic_domain_sizes, temporary_extraction_heuristics=self.temporary_extraction_heuristics, load_sdfg_from_file=False, @@ -105,7 +91,6 @@ def __call__( arg_types, inp.kwargs["offset_provider"], inp.kwargs.get("column_axis", None), - inp.kwargs.get("lift_mode", None), ) param_types = tuple( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 39ec607323..f88dcc5825 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -109,8 +109,6 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: content_hash(tuple(from_value(arg) for arg in otf_closure.args)), id(offset_provider) if offset_provider else None, otf_closure.kwargs.get("column_axis", None), - # TODO(tehrengruber): Remove `lift_mode` from call interface. - otf_closure.kwargs.get("lift_mode", None), ) ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py index 9755774fd0..6eb66240f7 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -70,7 +70,6 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider=offset_provider, - runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, column_axis=KDim, ) From a9f1043a4e4452a01b72e7e1f4407e00eba3f809 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Apr 2024 10:40:45 +0200 Subject: [PATCH 43/61] Update docs/user/next/QuickstartGuide.md --- docs/user/next/QuickstartGuide.md | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index ee39831d45..81604c7770 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -150,21 +150,10 @@ The sign of the edge difference in the sum of the pseudo-laplacian is always suc This section approaches the pseudo-laplacian by introducing the required APIs progressively through the following subsections: -- [Getting started with the GT4Py declarative frontend](#getting-started-with-the-gt4py-declarative-frontend) - - [Installation](#installation) - - [Programming guide](#programming-guide) - - [Key concepts and application structure](#key-concepts-and-application-structure) - - [Importing features](#importing-features) - - [Fields](#fields) - - [Field operators](#field-operators) - - [Programs](#programs) - - [Composing field operators and programs](#composing-field-operators-and-programs) - - [Operations on unstructured meshes](#operations-on-unstructured-meshes) - - [Defining the mesh and its connectivities](#defining-the-mesh-and-its-connectivities) - - [Using connectivities in field operators](#using-connectivities-in-field-operators) - - [Using reductions on connected mesh elements](#using-reductions-on-connected-mesh-elements) - - [Using conditionals on fields](#using-conditionals-on-fields) - - [Implementing the pseudo-laplacian](#implementing-the-pseudo-laplacian) +- [Defining the mesh and the connectivities (adjacencies) between cells and edges](#Defining-the-mesh-and-its-connectivities) +- [Using connectivities in field operators](#Using-connectivities-in-field-operators) +- [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) +- [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) +++ From e7195a52d5ee4693d1f2384e2b2608f368edc733 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Apr 2024 23:19:47 +0200 Subject: [PATCH 44/61] cleanup out field construction --- src/gt4py/next/embedded/operators.py | 38 ++-------- src/gt4py/next/field_utils.py | 29 +++++++ src/gt4py/next/iterator/embedded.py | 73 ++++++++++-------- .../next/type_system/type_specifications.py | 3 + .../next/type_system/type_translation.py | 29 +++++++ src/gt4py/next/utils.py | 76 +++++++++++++++---- tests/next_tests/unit_tests/test_utils.py | 60 +++++++++++++++ 7 files changed, 234 insertions(+), 74 deletions(-) create mode 100644 tests/next_tests/unit_tests/test_utils.py diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index dd197ac2c5..26ccae0d85 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -13,15 +13,14 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses -from types import ModuleType from typing import Any, Callable, Generic, Optional, ParamSpec, Sequence, TypeVar -import numpy as np - from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.next import common, errors, utils +from gt4py.next import common, errors, field_utils, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context +from gt4py.next.field_utils import get_array_ns +from gt4py.next.type_system import type_translation _P = ParamSpec("_P") @@ -64,8 +63,10 @@ def __call__( # type: ignore[override] # even if the scan dimension is not in the input, we can scan over it out_domain = common.Domain(*out_domain, (scan_range)) - xp = _get_array_ns(*all_args) - res = _construct_scan_array(out_domain, xp)(self.init) + xp = get_array_ns(*all_args) + res = field_utils.field_from_typespec(out_domain, xp)( + type_translation.from_value(self.init) + ) def scan_loop(hpos: Sequence[common.NamedIndex]) -> None: acc: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] = self.init @@ -163,31 +164,6 @@ def _intersect_scan_args( ) -def _get_array_ns( - *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], -) -> ModuleType: - for arg in utils.flatten_nested_tuple(args): - if hasattr(arg, "array_ns"): - return arg.array_ns - return np - - -def _construct_scan_array( - domain: common.Domain, - xp: ModuleType, # TODO(havogt) introduce a NDArrayNamespace protocol -) -> Callable[ - [core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]], - common.MutableField | tuple[common.MutableField | tuple, ...], -]: - @utils.tree_map - def impl(init: core_defs.Scalar) -> common.MutableField: - res = common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain) - assert isinstance(res, common.MutableField) - return res - - return impl - - def _tuple_assign_value( pos: Sequence[common.NamedIndex], target: common.MutableField | tuple[common.MutableField | tuple, ...], diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py index 04c51d3698..8e3b5e032a 100644 --- a/src/gt4py/next/field_utils.py +++ b/src/gt4py/next/field_utils.py @@ -12,11 +12,40 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from types import ModuleType +from typing import Callable + import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import common, utils +from gt4py.next.type_system import type_specifications as ts, type_translation @utils.tree_map def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: return field.asnumpy() if isinstance(field, common.Field) else field + + +def field_from_typespec( + domain: common.Domain, xp: ModuleType +) -> Callable[..., common.MutableField | tuple[common.MutableField | tuple, ...]]: + @utils.tree_map(collection_type=ts.TupleType, result_collection_type=tuple) + def impl(type_: ts.ScalarType) -> common.MutableField: + res = common._field( + xp.empty(domain.shape, dtype=xp.dtype(type_translation.as_dtype(type_).scalar_type)), + domain=domain, + ) + assert isinstance(res, common.MutableField) + return res + + return impl + + +def get_array_ns( + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], +) -> ModuleType: + for arg in utils.flatten_nested_tuple(args): + if hasattr(arg, "array_ns"): + return arg.array_ns + return np diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index d3ccddba40..565cba6c0b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -52,7 +52,7 @@ overload, runtime_checkable, ) -from gt4py.next import common +from gt4py.next import common, field_utils from gt4py.next.embedded import ( context as embedded_context, exceptions as embedded_exceptions, @@ -60,6 +60,7 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime +from gt4py.next.type_system import type_specifications as ts, type_translation EMBEDDED = "embedded" @@ -210,6 +211,10 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data) ) + @property + def dtype(self) -> np.dtype: + return self.data.dtype + def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] # numpy type @@ -792,7 +797,7 @@ def _make_tuple( raise RuntimeError( "Found 'Undefined' value, this should not happen for a legal program." ) - dtype = _column_dtype(first) + dtype = _elem_dtype(first) return Column(column_range.start, np.asarray(col, dtype=dtype)) @@ -1472,11 +1477,12 @@ def as_tuple_field(field: tuple | TupleField) -> TupleField: return TupleOfFields(tuple(_wrap_field(f) for f in field)) -def _column_dtype(elem: Any) -> np.dtype: +def _elem_dtype(elem: Any) -> np.dtype: + if hasattr(elem, "dtype"): + return elem.dtype if isinstance(elem, tuple): - return np.dtype([(f"f{i}", _column_dtype(e)) for i, e in enumerate(elem)]) - else: - return np.dtype(type(elem)) + return np.dtype([(f"f{i}", _elem_dtype(e)) for i, e in enumerate(elem)]) + return np.dtype(type(elem)) @builtins.scan.register(EMBEDDED) @@ -1488,7 +1494,7 @@ def impl(*iters: ItIterator): sorted_column_range = column_range if is_forward else reversed(column_range) state = init - col = Column(column_range.start, np.zeros(len(column_range), dtype=_column_dtype(init))) + col = Column(column_range.start, np.zeros(len(column_range), dtype=_elem_dtype(init))) for i in sorted_column_range: state = scan_pass(state, *map(shifted_scan_arg(i), iters)) col[i] = state @@ -1547,36 +1553,43 @@ def _extract_column_range(domain) -> common.NamedRange | eve.NothingType: return eve.NOTHING -# TODO handle in clean way -def _np_void_to_tuple(a): - if isinstance(a, np.void): - return tuple(_np_void_to_tuple(elem) for elem in a) - return a +def _structured_dtype_to_typespec(structured_dtype: np.dtype) -> ts.ScalarType | ts.TupleType: + if structured_dtype.names is None: + return type_translation.from_dtype(core_defs.dtype(structured_dtype)) + return ts.TupleType( + types=[ + _structured_dtype_to_typespec(structured_dtype[name]) for name in structured_dtype.names + ] + ) + + +def _get_output_type( + fun: Callable, + domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, + args: tuple[Any, ...], +) -> ts.TypeSpec: + domain = _dimension_to_tag(domain_) + col_range = _extract_column_range(domain) + if isinstance(col_range, common.NamedRange): + del domain[col_range.dim.value] + + pos = next(iter(_domain_iterator(domain))) + with embedded_context.new_context(closure_column_range=col_range) as ctx: + single_point_result = ctx.run(_compute_point, fun, args, pos, col_range) + dtype = _elem_dtype(single_point_result) + return _structured_dtype_to_typespec(dtype) @builtins.as_fieldop.register(EMBEDDED) -def as_fieldop(fun: Callable, domain_: runtime.CartesianDomain | runtime.UnstructuredDomain): +def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.UnstructuredDomain): def impl(*args): - # TODO extract function, move private utils - domain = _dimension_to_tag(domain_) - col_range = _extract_column_range(domain) - if col_range is not eve.NOTHING: - del domain[col_range.dim.value] - - pos = next(_domain_iterator(domain)) - with embedded_context.new_context(closure_column_range=col_range) as ctx: - single_point_result = ctx.run(_compute_point, fun, args, pos, col_range) - if isinstance(single_point_result, Column): - single_point_result = single_point_result.data[0] - single_point_result = _np_void_to_tuple(single_point_result) - - xp = operators._get_array_ns(*args) - out = operators._construct_scan_array(common.domain(domain_), xp)(single_point_result) + xp = field_utils.get_array_ns(*args) + type_ = _get_output_type(fun, domain, args) + out = field_utils.field_from_typespec(common.domain(domain), xp)(type_) # TODO `out` gets allocated in the order of domain_, but might not match the order of `target` in set_at - closure( - _dimension_to_tag(domain_), + _dimension_to_tag(domain), fun, out, list(args), diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 9487d2f12b..21399b5f74 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -102,6 +102,9 @@ def __str__(self) -> str: def __iter__(self) -> Iterator[DataType]: yield from self.types + def __len__(self) -> int: + return len(self.types) + @dataclass(frozen=True) class FieldType(DataType, CallableType): diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 396d4c06b6..429e551738 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -22,6 +22,7 @@ import numpy as np import numpy.typing as npt +from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.next import common from gt4py.next.type_system import type_info, type_specifications as ts @@ -219,3 +220,31 @@ def from_value(value: Any) -> ts.TypeSpec: return symbol_type else: raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.") + + +def as_dtype(type_: ts.ScalarType) -> core_defs.DType: + if type_.kind == ts.ScalarKind.BOOL: + return core_defs.BoolDType() + elif type_.kind == ts.ScalarKind.INT32: + return core_defs.Int32DType() + elif type_.kind == ts.ScalarKind.INT64: + return core_defs.Int64DType() + elif type_.kind == ts.ScalarKind.FLOAT32: + return core_defs.Float32DType() + elif type_.kind == ts.ScalarKind.FLOAT64: + return core_defs.Float64DType() + raise ValueError(f"Scalar type '{type_}' not supported.") + + +def from_dtype(dtype: core_defs.DType) -> ts.ScalarType: + if dtype == core_defs.BoolDType(): + return ts.ScalarType(kind=ts.ScalarKind.BOOL) + elif dtype == core_defs.Int32DType(): + return ts.ScalarType(kind=ts.ScalarKind.INT32) + elif dtype == core_defs.Int64DType(): + return ts.ScalarType(kind=ts.ScalarKind.INT64) + elif dtype == core_defs.Float32DType(): + return ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + elif dtype == core_defs.Float64DType(): + return ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + raise ValueError(f"DType '{dtype}' not supported.") diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 14caa1ae3e..6aa70dab5b 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import functools -from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast +from typing import Any, Callable, ClassVar, Optional, ParamSpec, TypeGuard, TypeVar, cast, overload class RecursionGuard: @@ -69,9 +69,31 @@ def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]: return (value,) -def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: +@overload +def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... + + +@overload +def tree_map( + *, collection_type: type | tuple[type, ...], result_collection_type: Optional[type] = None +) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ... + + +def tree_map( + *args: Callable[_P, _R], + collection_type: type | tuple[type, ...] = tuple, + result_collection_type: Optional[type] = None, +) -> ( + Callable[..., _R | tuple[_R | tuple, ...]] + | Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]] +): """ - Apply `fun` to each entry of (possibly nested) tuples. + Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s). + + Args: + fun: Function to apply to each entry of the collection. + collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types. + result_collection_type: Type of the collection to be returned. If `None` the same type as `collection_type` is used. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) @@ -79,16 +101,44 @@ def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...] >>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6)) ((5, 7), 9) - """ - @functools.wraps(fun) - def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: - if isinstance(args[0], tuple): - assert all(isinstance(arg, tuple) and len(args[0]) == len(arg) for arg in args) - return tuple(impl(*arg) for arg in zip(*args)) + >>> tree_map(collection_type=list)(lambda x: x + 1)([[1, 2], 3]) + [[2, 3], 4] - return fun( - *cast(_P.args, args) - ) # mypy doesn't understand that `args` at this point is of type `_P.args` + >>> tree_map(collection_type=list, result_collection_type=tuple)(lambda x: x + 1)([[1, 2], 3]) + ((2, 3), 4) + """ - return impl + if result_collection_type is None: + if isinstance(collection_type, tuple): + raise TypeError( + "tree_map() requires `result_collection_type` when `collection_type` is a tuple." + ) + result_collection_type = collection_type + + if len(args) == 1: + fun = args[0] + + @functools.wraps(fun) + def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: + if isinstance(args[0], collection_type): + assert all( + isinstance(arg, collection_type) and len(args[0]) == len(arg) for arg in args + ) + assert result_collection_type is not None + return result_collection_type(impl(*arg) for arg in zip(*args)) + + return fun( + *cast(_P.args, args) + ) # mypy doesn't understand that `args` at this point is of type `_P.args` + + return impl + if len(args) == 0: + return functools.partial( + tree_map, + collection_type=collection_type, + result_collection_type=result_collection_type, + ) + raise TypeError( + "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_type`." + ) diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py new file mode 100644 index 0000000000..42b16ae2de --- /dev/null +++ b/tests/next_tests/unit_tests/test_utils.py @@ -0,0 +1,60 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest + +from gt4py.next import utils + + +def test_tree_map_default(): + @utils.tree_map + def testee(x): + return x + 1 + + assert testee(((1, 2), 3)) == ((2, 3), 4) + + +def test_tree_map_multi_arg(): + @utils.tree_map + def testee(x, y): + return x + y + + assert testee(((1, 2), 3), ((4, 5), 6)) == ((5, 7), 9) + + +def test_tree_map_custom_input_type(): + @utils.tree_map(collection_type=list) + def testee(x): + return x + 1 + + assert testee([[1, 2], 3]) == [[2, 3], 4] + + with pytest.raises(TypeError): + testee(((1, 2), 3)) # tries to `((1, 2), 3)` because `tuple_type` is `list` + + +def test_tree_map_custom_output_type(): + @utils.tree_map(result_collection_type=list) + def testee(x): + return x + 1 + + assert testee(((1, 2), 3)) == [[2, 3], 4] + + +def test_tree_map_multiple_input_types(): + @utils.tree_map(collection_type=(list, tuple), result_collection_type=tuple) + def testee(x): + return x + 1 + + assert testee([(1, [2]), 3]) == ((2, (3,)), 4) From 3f677466cf19572cd49814c923562fec9873f6aa Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 22 Apr 2024 13:52:37 +0200 Subject: [PATCH 45/61] Update src/gt4py/next/program_processors/runners/double_roundtrip.py Co-authored-by: Till Ehrengruber --- src/gt4py/next/program_processors/runners/double_roundtrip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 6d3f368170..0eb6a10000 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -20,7 +20,7 @@ backend = next_backend.Backend( executor=roundtrip.RoundtripExecutorFactory( - dispatch_backend=roundtrip.RoundtripExecutorFactory() + dispatch_backend=roundtrip.default.executor ), allocator=roundtrip.default.allocator, ) From 369eae714ada9b0e207b917622410fdc9c74cb1c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 22 Apr 2024 13:54:18 +0200 Subject: [PATCH 46/61] read config.DEBUG at execution --- src/gt4py/next/program_processors/runners/roundtrip.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 408451648e..d90e7e5c8f 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -209,10 +209,11 @@ def execute_roundtrip( *args: Any, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - debug: bool = config.DEBUG, + debug: Optional[bool] = None, lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, dispatch_backend: Optional[ppi.ProgramExecutor] = None, ) -> None: + debug = debug if debug is not None else config.DEBUG fencil = fencil_generator( ir, offset_provider=offset_provider, @@ -230,15 +231,16 @@ def execute_roundtrip( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: bool = config.DEBUG + debug: Optional[bool] = None lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: + debug = config.DEBUG if self.debug is None else self.debug return fencil_generator( inp.program, offset_provider=inp.kwargs.get("offset_provider", None), - debug=self.debug, + debug=debug, lift_mode=self.lift_mode, use_embedded=self.use_embedded, ) From 35f2132d0010ea91bf54890dc4c7904f8da15db6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 22 Apr 2024 13:55:39 +0200 Subject: [PATCH 47/61] remove LiftMode.SIMPLE_HEURISTIC --- src/gt4py/next/iterator/transforms/pass_manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8a19203275..32b42f8d2b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -17,7 +17,6 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import simple_inline_heuristic from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -41,14 +40,11 @@ class LiftMode(enum.Enum): FORCE_INLINE = enum.auto() USE_TEMPORARIES = enum.auto() - SIMPLE_HEURISTIC = enum.auto() def _inline_lifts(ir, lift_mode): if lift_mode == LiftMode.FORCE_INLINE: return InlineLifts().visit(ir) - elif lift_mode == LiftMode.SIMPLE_HEURISTIC: - return InlineLifts(simple_inline_heuristic.is_eligible_for_inlining).visit(ir) elif lift_mode == LiftMode.USE_TEMPORARIES: return InlineLifts( flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT From 4a4f9b18c70827e60caf28de01dc54ef8b5c5b74 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 23 Apr 2024 08:32:15 +0200 Subject: [PATCH 48/61] fix formatting --- src/gt4py/next/program_processors/runners/double_roundtrip.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index 0eb6a10000..0b0f71c2f7 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -19,8 +19,6 @@ backend = next_backend.Backend( - executor=roundtrip.RoundtripExecutorFactory( - dispatch_backend=roundtrip.default.executor - ), + executor=roundtrip.RoundtripExecutorFactory(dispatch_backend=roundtrip.default.executor), allocator=roundtrip.default.allocator, ) From bfcc11873b0fad120c0dd9ea5cd0422cc80d3d75 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 15 May 2024 16:33:14 +0200 Subject: [PATCH 49/61] move ordering of unstructured domain to gtfn --- src/gt4py/next/ffront/past_to_itir.py | 5 ---- .../codegens/gtfn/itir_to_gtfn_ir.py | 28 +++++++++++++++++++ .../ffront_tests/test_gt4py_builtins.py | 2 -- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 9d6e050ab3..b0a348f60b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -349,11 +349,6 @@ def _construct_itir_domain_arg( domain_builtin = "cartesian_domain" elif self.grid_type == common.GridType.UNSTRUCTURED: domain_builtin = "unstructured_domain" - # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) - if domain_args_kind[0] == common.DimensionKind.VERTICAL: - assert len(domain_args) == 2 - assert domain_args_kind[1] == common.DimensionKind.HORIZONTAL - domain_args[0], domain_args[1] = domain_args[1], domain_args[0] else: raise AssertionError() diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index e5ba965e3b..ef7f0f4ffc 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -206,6 +206,32 @@ def _is_applied_as_fieldop(arg: itir.Expr) -> TypeGuard[itir.FunCall]: ) +class _CannonicalizeUnstructuredDomain(eve.NodeTranslator): + def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: + if node.fun == itir.SymRef(id="unstructured_domain"): + # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) + assert isinstance(node.args[0], itir.FunCall) + first_axis_literal = node.args[0].args[0] + assert isinstance(first_axis_literal, itir.AxisLiteral) + if first_axis_literal.kind == itir.DimensionKind.VERTICAL: + assert len(node.args) == 2 + assert isinstance(node.args[1], itir.FunCall) + assert isinstance(node.args[1].args[0], itir.AxisLiteral) + assert node.args[1].args[0].kind == itir.DimensionKind.HORIZONTAL + return itir.FunCall(fun=node.fun, args=[node.args[1], node.args[0]]) + return node + + @classmethod + def apply( + cls, + node: itir.Program, + ) -> itir.Program: + if not isinstance(node, itir.Program): + raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") + + return cls().visit(node) + + @dataclasses.dataclass(frozen=True) class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): _binary_op_map: ClassVar[dict[str, str]] = { @@ -248,6 +274,8 @@ def apply( raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") grid_type = _get_gridtype(node.body) + if grid_type == common.GridType.UNSTRUCTURED: + node = _CannonicalizeUnstructuredDomain.apply(node) return cls( offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type ).visit(node) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 404a91eaa5..0bb1e78582 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -111,8 +111,6 @@ def reduction_ke_field( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case, fop): - if fop == reduction_ke_field: # TODO need to resolve order of dimensions - pytest.skip() v2e_table = unstructured_case.offset_provider["V2E"].table edge_f = cases.allocate(unstructured_case, fop, "edge_f")() From 54f44cc6769cd7d5145438c43d9f1d696cc49a1c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 16 May 2024 10:28:29 +0200 Subject: [PATCH 50/61] fix problem in column dtype if contains None --- src/gt4py/next/iterator/embedded.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 565cba6c0b..5d1677e225 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -213,7 +213,8 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: @property def dtype(self) -> np.dtype: - return self.data.dtype + # not directly dtype of `self.data` as that might be a structured type containing `None` + return np.dtype(type(self.data[self.kstart])) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] From b8b26e65d716d775b6cfe80128bc73a51f34fee5 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 16 May 2024 10:46:33 +0200 Subject: [PATCH 51/61] address more review comments --- src/gt4py/next/iterator/embedded.py | 13 ++++--------- src/gt4py/next/iterator/tracing.py | 4 +++- src/gt4py/next/iterator/transforms/pass_manager.py | 1 + 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 5d1677e225..b8427d539f 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1518,8 +1518,7 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: @runtime.set_at.register(EMBEDDED) -def set_at(expr, domain, target) -> None: - # TODO we can't set the column_range here, because it's too late: `expr` already evaluated +def set_at(expr: common.Field, domain: common.DomainLike, target: common.MutableField) -> None: operators._tuple_assign_field(target, expr, common.domain(domain)) @@ -1588,13 +1587,9 @@ def impl(*args): type_ = _get_output_type(fun, domain, args) out = field_utils.field_from_typespec(common.domain(domain), xp)(type_) - # TODO `out` gets allocated in the order of domain_, but might not match the order of `target` in set_at - closure( - _dimension_to_tag(domain), - fun, - out, - list(args), - ) + # TODO(havogt): after updating all tests to use the new program, + # we should get rid of closure and move the implementation to this function + closure(_dimension_to_tag(domain), fun, out, list(args)) return out return impl diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 2d85245214..5c1d2030f3 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -209,7 +209,9 @@ def __bool__(self): class TracerContext: fundefs: ClassVar[List[FunctionDefinition]] = [] - closures: ClassVar[List[StencilClosure]] = [] + closures: ClassVar[ + List[StencilClosure] + ] = [] # TODO(havogt): remove after refactoring to `Program` is complete, currently handles both programs and fencils body: ClassVar[List[itir.Stmt]] = [] @classmethod diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ef8283605e..bddf25db3a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -87,6 +87,7 @@ def apply_common_transforms( ) -> itir.FencilDefinition | FencilWithTemporaries | itir.Program: if isinstance(ir, itir.Program): # TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program + # (currently the case when using the roundtrip backend) return ir icdlv_uids = eve_utils.UIDGenerator() From 9ee02e4be247db4002b537814e1749e0caf7cbd0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 16 May 2024 13:38:47 +0200 Subject: [PATCH 52/61] fix tuples in columns --- src/gt4py/next/iterator/embedded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b8427d539f..b399bbde08 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -214,7 +214,7 @@ def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: @property def dtype(self) -> np.dtype: # not directly dtype of `self.data` as that might be a structured type containing `None` - return np.dtype(type(self.data[self.kstart])) + return _elem_dtype(self.data[self.kstart]) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] From d10463334cecc59f696fbd7d2073b9f9cc50686d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 17 May 2024 10:44:00 +0200 Subject: [PATCH 53/61] fix preserve axis kind in global tmps --- src/gt4py/next/iterator/transforms/global_tmps.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a3260d5a37..f554a63e7f 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -408,7 +408,7 @@ def translate(self, distance: int) -> "SymbolicRange": @dataclasses.dataclass class SymbolicDomain: grid_type: Literal["unstructured_domain", "cartesian_domain"] - ranges: dict[str, SymbolicRange] + ranges: dict[ir.AxisLiteral, SymbolicRange] @classmethod def from_expr(cls, node: ir.Node): @@ -417,7 +417,7 @@ def from_expr(cls, node: ir.Node): im.ref("cartesian_domain"), ] - ranges: dict[str, SymbolicRange] = {} + ranges: dict[ir.AxisLiteral, SymbolicRange] = {} for named_range in node.args: assert ( isinstance(named_range, ir.FunCall) @@ -427,13 +427,13 @@ def from_expr(cls, node: ir.Node): axis_literal, lower_bound, upper_bound = named_range.args assert isinstance(axis_literal, ir.AxisLiteral) - ranges[axis_literal.value] = SymbolicRange(lower_bound, upper_bound) + ranges[axis_literal] = SymbolicRange(lower_bound, upper_bound) return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above def as_expr(self): return im.call(self.grid_type)( *[ - im.call("named_range")(ir.AxisLiteral(value=d), r.start, r.stop) + im.call("named_range")(ir.AxisLiteral(value=d.value, kind=d.kind), r.start, r.stop) for d, r in self.ranges.items() ] ) From 66e2464312c14437abd5ca07408a851bced4b752 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 17 May 2024 11:26:28 +0200 Subject: [PATCH 54/61] fix follow up issue --- src/gt4py/next/iterator/transforms/global_tmps.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index f554a63e7f..95777f368e 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -408,7 +408,9 @@ def translate(self, distance: int) -> "SymbolicRange": @dataclasses.dataclass class SymbolicDomain: grid_type: Literal["unstructured_domain", "cartesian_domain"] - ranges: dict[ir.AxisLiteral, SymbolicRange] + ranges: dict[ + common.Dimension, SymbolicRange + ] # TODO(havogt): remove `AxisLiteral` by `Dimension` everywhere @classmethod def from_expr(cls, node: ir.Node): @@ -417,7 +419,7 @@ def from_expr(cls, node: ir.Node): im.ref("cartesian_domain"), ] - ranges: dict[ir.AxisLiteral, SymbolicRange] = {} + ranges: dict[common.Dimension, SymbolicRange] = {} for named_range in node.args: assert ( isinstance(named_range, ir.FunCall) @@ -427,7 +429,9 @@ def from_expr(cls, node: ir.Node): axis_literal, lower_bound, upper_bound = named_range.args assert isinstance(axis_literal, ir.AxisLiteral) - ranges[axis_literal] = SymbolicRange(lower_bound, upper_bound) + ranges[common.Dimension(value=axis_literal.value, kind=axis_literal.kind)] = ( + SymbolicRange(lower_bound, upper_bound) + ) return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above def as_expr(self): @@ -514,7 +518,7 @@ def update_domains( for offset_name, offset in _group_offsets(shift_chain): if isinstance(offset_provider[offset_name], gtx.Dimension): # cartesian shift - dim = offset_provider[offset_name].value + dim = offset_provider[offset_name] consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) elif isinstance(offset_provider[offset_name], common.Connectivity): # unstructured shift From 56e0086e4f8e8bbea163dd643a1e30e7a7316c0e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 10 Jun 2024 09:24:50 +0200 Subject: [PATCH 55/61] Start using tree_map instead of apply_to_primitive_constituents --- .../ffront/foast_passes/type_deduction.py | 13 +++++--- src/gt4py/next/ffront/type_info.py | 19 ++++++++--- src/gt4py/next/field_utils.py | 2 +- src/gt4py/next/type_system/type_info.py | 2 ++ src/gt4py/next/utils.py | 24 +++++++------- tests/next_tests/unit_tests/test_utils.py | 32 +++++++++++++++++-- 6 files changed, 67 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index fad4df8c84..13a60caba1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -833,12 +833,15 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.", ) - return_type = type_info.apply_to_primitive_constituents( - value.type, - lambda primitive_type: with_altered_scalar_kind( + # return_type = type_info.apply_to_primitive_constituents( + # value.type, + # lambda primitive_type: with_altered_scalar_kind( + # primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + # ), + # ) + return_type = type_info.type_tree_map(lambda primitive_type: with_altered_scalar_kind( primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) - ), - ) + ))(value.type) assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType)) return foast.Call( diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 2bd4f21993..5188facc24 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -17,7 +17,7 @@ import gt4py.next.ffront.type_specifications as ts_ffront import gt4py.next.type_system.type_specifications as ts -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.type_system import type_info @@ -37,8 +37,9 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec: return ts.FieldType(dims=[], dtype=type_el) return type_el - return type_info.apply_to_primitive_constituents(type_, promote_el) + # return type_info.apply_to_primitive_constituents(type_, promote_el) + return type_info.type_tree_map(promote_el)(type_) def promote_zero_dims( function_type: ts.FunctionType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] @@ -308,6 +309,14 @@ def return_type_scanop( # field [callable_type.axis], ) - return type_info.apply_to_primitive_constituents( - carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)) - ) + + # return type_info.apply_to_primitive_constituents( + # carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)) + # ) + # @utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=lambda x: ts.TupleType(types=[*x])) + # def tmp(x): + # return ts.FieldType(dims=promoted_dims, dtype=x) + + #return tmp(carry_dtype) + + return type_info.type_tree_map(lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)))(carry_dtype) \ No newline at end of file diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py index 8e3b5e032a..5f7950b4a5 100644 --- a/src/gt4py/next/field_utils.py +++ b/src/gt4py/next/field_utils.py @@ -30,7 +30,7 @@ def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: def field_from_typespec( domain: common.Domain, xp: ModuleType ) -> Callable[..., common.MutableField | tuple[common.MutableField | tuple, ...]]: - @utils.tree_map(collection_type=ts.TupleType, result_collection_type=tuple) + @utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=tuple) def impl(type_: ts.ScalarType) -> common.MutableField: res = common._field( xp.empty(domain.shape, dtype=xp.dtype(type_translation.as_dtype(type_).scalar_type)), diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index a05b9afde8..da64c46f3c 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -23,6 +23,7 @@ from gt4py.eve.utils import XIterable, xiter from gt4py.next import common from gt4py.next.type_system import type_specifications as ts +from gt4py.next import utils def _number_to_ordinal_number(number: int) -> str: @@ -182,6 +183,7 @@ def apply_to_primitive_constituents( else: return fun(symbol_type) # type: ignore[call-arg] # mypy not aware of `with_path_arg` +type_tree_map=utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=lambda x: ts.TupleType(types=[*x])) def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: """ diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 6aa70dab5b..c17a7bb317 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -75,14 +75,14 @@ def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, . @overload def tree_map( - *, collection_type: type | tuple[type, ...], result_collection_type: Optional[type] = None + *, collection_type: type | tuple[type, ...], result_collection_constructor: Optional[type] = None ) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ... def tree_map( *args: Callable[_P, _R], collection_type: type | tuple[type, ...] = tuple, - result_collection_type: Optional[type] = None, + result_collection_constructor: Optional[type] = None, # Todo: check name with Enrique ) -> ( Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]] @@ -93,10 +93,10 @@ def tree_map( Args: fun: Function to apply to each entry of the collection. collection_type: Type of the collection to be traversed. Can be a single type or a tuple of types. - result_collection_type: Type of the collection to be returned. If `None` the same type as `collection_type` is used. + result_collection_constructor: Constructor of the collection to be returned. If `None` the same type as `collection_type` is used. Examples: - >>> tree_map(lambda x: x + 1)(((1, 2), 3)) + >>> tree_map(lambda x: x + 1)(((1, 2), 3)) # TODO: tree_map(lambda x: x + 1,((1, 2), 3)) like map/reduce, decorator and original way should also work ((2, 3), 4) >>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6)) @@ -105,16 +105,16 @@ def tree_map( >>> tree_map(collection_type=list)(lambda x: x + 1)([[1, 2], 3]) [[2, 3], 4] - >>> tree_map(collection_type=list, result_collection_type=tuple)(lambda x: x + 1)([[1, 2], 3]) + >>> tree_map(collection_type=list, result_collection_constructor=tuple)(lambda x: x + 1)([[1, 2], 3]) ((2, 3), 4) """ - if result_collection_type is None: + if result_collection_constructor is None: if isinstance(collection_type, tuple): raise TypeError( - "tree_map() requires `result_collection_type` when `collection_type` is a tuple." + "tree_map() requires `result_collection_constructor` when `collection_type` is a tuple." ) - result_collection_type = collection_type + result_collection_constructor = collection_type if len(args) == 1: fun = args[0] @@ -125,8 +125,8 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert all( isinstance(arg, collection_type) and len(args[0]) == len(arg) for arg in args ) - assert result_collection_type is not None - return result_collection_type(impl(*arg) for arg in zip(*args)) + assert result_collection_constructor is not None + return result_collection_constructor(impl(*arg) for arg in zip(*args)) return fun( *cast(_P.args, args) @@ -137,8 +137,8 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: return functools.partial( tree_map, collection_type=collection_type, - result_collection_type=result_collection_type, + result_collection_constructor=result_collection_constructor, ) raise TypeError( - "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_type`." + "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`." ) diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index 42b16ae2de..903ffa66d5 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -15,7 +15,35 @@ import pytest from gt4py.next import utils +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info +from gt4py.next.common import Field +from numpy import int64 +def test_tree_map_scalar(): + @utils.tree_map(collection_type=ts.ScalarType,result_collection_constructor=tuple) + def testee(x): + return x + 1 + + assert testee(1) == (2) + +def test_apply_to_primitive_constituents(): + int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) + tuple_type = ts.TupleType(types=[ts.TupleType(types=[int_type, int_type]), int_type]) + + # @type_info.type_tree_map + # def tmp(x): + # return ts.FieldType(dims=[], dtype=x) + # + # tree = tmp(tuple_type) + + tree = type_info.type_tree_map(lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type))(tuple_type) + + prim = type_info.apply_to_primitive_constituents( + tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) + ) + + assert tree == prim def test_tree_map_default(): @utils.tree_map @@ -45,7 +73,7 @@ def testee(x): def test_tree_map_custom_output_type(): - @utils.tree_map(result_collection_type=list) + @utils.tree_map(result_collection_constructor=list) def testee(x): return x + 1 @@ -53,7 +81,7 @@ def testee(x): def test_tree_map_multiple_input_types(): - @utils.tree_map(collection_type=(list, tuple), result_collection_type=tuple) + @utils.tree_map(collection_type=(list, tuple), result_collection_constructor=tuple) def testee(x): return x + 1 From 8ab97c493bda9d6cc7618b79c101a75215708b10 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 3 Jul 2024 18:42:36 +0200 Subject: [PATCH 56/61] Add functionality to call also tree_map(lambda x: x + 1, ((1, 2), 3)) and run pre-commit (still failing because of src/gt4py/next/type_system/type_info.py:186:17: error: No overload variant of "tree_map" matches argument types "type[TupleType]", "Callable[[Any], TupleType]" [call-overload] ) --- .../ffront/foast_passes/type_deduction.py | 6 ++- src/gt4py/next/ffront/type_info.py | 9 ++-- src/gt4py/next/type_system/type_info.py | 9 ++-- src/gt4py/next/utils.py | 49 ++++++++++++++----- tests/next_tests/unit_tests/test_utils.py | 19 +++++-- 5 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 13a60caba1..8972fd4c81 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -839,9 +839,11 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: # primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) # ), # ) - return_type = type_info.type_tree_map(lambda primitive_type: with_altered_scalar_kind( + return_type = type_info.type_tree_map( + lambda primitive_type: with_altered_scalar_kind( primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) - ))(value.type) + ) + )(value.type) assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType)) return foast.Call( diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 5188facc24..b24f622e80 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -17,7 +17,7 @@ import gt4py.next.ffront.type_specifications as ts_ffront import gt4py.next.type_system.type_specifications as ts -from gt4py.next import common, utils +from gt4py.next import common from gt4py.next.type_system import type_info @@ -41,6 +41,7 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec: return type_info.type_tree_map(promote_el)(type_) + def promote_zero_dims( function_type: ts.FunctionType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] ) -> tuple[list, dict]: @@ -317,6 +318,8 @@ def return_type_scanop( # def tmp(x): # return ts.FieldType(dims=promoted_dims, dtype=x) - #return tmp(carry_dtype) + # return tmp(carry_dtype) - return type_info.type_tree_map(lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)))(carry_dtype) \ No newline at end of file + return type_info.type_tree_map( + lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg)) + )(carry_dtype) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index da64c46f3c..a4f3957caf 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -21,9 +21,8 @@ import numpy as np from gt4py.eve.utils import XIterable, xiter -from gt4py.next import common +from gt4py.next import common, utils from gt4py.next.type_system import type_specifications as ts -from gt4py.next import utils def _number_to_ordinal_number(number: int) -> str: @@ -183,7 +182,11 @@ def apply_to_primitive_constituents( else: return fun(symbol_type) # type: ignore[call-arg] # mypy not aware of `with_path_arg` -type_tree_map=utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=lambda x: ts.TupleType(types=[*x])) + +type_tree_map = utils.tree_map( + collection_type=ts.TupleType, result_collection_constructor=lambda x: ts.TupleType(types=[*x]) +) + def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: """ diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index c17a7bb317..651ef483dd 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -75,17 +75,21 @@ def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, . @overload def tree_map( - *, collection_type: type | tuple[type, ...], result_collection_constructor: Optional[type] = None + *, + collection_type: type | tuple[type, ...], + result_collection_constructor: Optional[type] = None, ) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ... def tree_map( *args: Callable[_P, _R], collection_type: type | tuple[type, ...] = tuple, - result_collection_constructor: Optional[type] = None, # Todo: check name with Enrique + result_collection_constructor: Optional[type] = None, # Todo: check name with Enrique ) -> ( Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]] + | _R + | tuple[_R | tuple, ...] ): """ Apply `fun` to each entry of (possibly nested) collections (by default `tuple`s). @@ -96,7 +100,10 @@ def tree_map( result_collection_constructor: Constructor of the collection to be returned. If `None` the same type as `collection_type` is used. Examples: - >>> tree_map(lambda x: x + 1)(((1, 2), 3)) # TODO: tree_map(lambda x: x + 1,((1, 2), 3)) like map/reduce, decorator and original way should also work + >>> tree_map(lambda x: x + 1)(((1, 2), 3)) + ((2, 3), 4) + + >>> tree_map(lambda x: x + 1, ((1, 2), 3)) ((2, 3), 4) >>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6)) @@ -105,7 +112,12 @@ def tree_map( >>> tree_map(collection_type=list)(lambda x: x + 1)([[1, 2], 3]) [[2, 3], 4] - >>> tree_map(collection_type=list, result_collection_constructor=tuple)(lambda x: x + 1)([[1, 2], 3]) + >>> tree_map(collection_type=list)(lambda x: x + 1, [[1, 2], 3]) + [[2, 3], 4] + + >>> tree_map(collection_type=list, result_collection_constructor=tuple)(lambda x: x + 1)( + ... [[1, 2], 3] + ... ) ((2, 3), 4) """ @@ -116,8 +128,24 @@ def tree_map( ) result_collection_constructor = collection_type - if len(args) == 1: + if len(args) == 0: + return functools.partial( + tree_map, + collection_type=collection_type, + result_collection_constructor=result_collection_constructor, + ) + + if callable(args[0]): fun = args[0] + colls = args[1:] + + if len(colls) == 0: + return functools.partial( + tree_map, + fun, + collection_type=collection_type, + result_collection_constructor=result_collection_constructor, + ) @functools.wraps(fun) def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: @@ -132,13 +160,8 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: *cast(_P.args, args) ) # mypy doesn't understand that `args` at this point is of type `_P.args` - return impl - if len(args) == 0: - return functools.partial( - tree_map, - collection_type=collection_type, - result_collection_constructor=result_collection_constructor, - ) + return impl(*colls) + raise TypeError( - "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`." + "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`, or with a function and collection." ) diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index 903ffa66d5..4ec36707d7 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -20,13 +20,15 @@ from gt4py.next.common import Field from numpy import int64 + def test_tree_map_scalar(): - @utils.tree_map(collection_type=ts.ScalarType,result_collection_constructor=tuple) + @utils.tree_map(collection_type=ts.ScalarType, result_collection_constructor=tuple) def testee(x): return x + 1 assert testee(1) == (2) + def test_apply_to_primitive_constituents(): int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) tuple_type = ts.TupleType(types=[ts.TupleType(types=[int_type, int_type]), int_type]) @@ -37,20 +39,27 @@ def test_apply_to_primitive_constituents(): # # tree = tmp(tuple_type) - tree = type_info.type_tree_map(lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type))(tuple_type) + tree = type_info.type_tree_map( + lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) + )(tuple_type) prim = type_info.apply_to_primitive_constituents( - tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) + tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) ) assert tree == prim + def test_tree_map_default(): + expected_result = ((2, 3), 4) + @utils.tree_map - def testee(x): + def testee1(x): return x + 1 - assert testee(((1, 2), 3)) == ((2, 3), 4) + assert testee1(((1, 2), 3)) == expected_result + assert utils.tree_map(lambda x: x + 1)(((1, 2), 3)) == expected_result + assert utils.tree_map(lambda x: x + 1, ((1, 2), 3)) == expected_result def test_tree_map_multi_arg(): From 827a4d3e70df416876df1f1cbb4ebf93d6c033ed Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 31 Dec 2024 11:09:06 +0100 Subject: [PATCH 57/61] Run pre-commit --- src/gt4py/next/embedded/operators.py | 1 - src/gt4py/next/ffront/foast_passes/type_deduction.py | 1 - src/gt4py/next/field_utils.py | 4 ---- src/gt4py/next/iterator/embedded.py | 1 + src/gt4py/next/iterator/ir.py | 2 ++ src/gt4py/next/iterator/runtime.py | 1 - src/gt4py/next/iterator/tracing.py | 10 ---------- src/gt4py/next/iterator/transforms/global_tmps.py | 1 - src/gt4py/next/iterator/transforms/pass_manager.py | 1 - src/gt4py/next/program_processors/runners/roundtrip.py | 1 - src/gt4py/next/utils.py | 2 +- .../feature_tests/iterator_tests/test_program.py | 3 ++- tests/next_tests/unit_tests/test_utils.py | 2 +- 13 files changed, 7 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index ebabdbb936..77420dae09 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -17,7 +17,6 @@ from gt4py.next.type_system import type_specifications as ts, type_translation - _P = ParamSpec("_P") _R = TypeVar("_R") diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 7a491089bb..8f131569a4 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -848,7 +848,6 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call: f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.", ) - return_type = type_info.type_tree_map( lambda primitive_type: with_altered_scalar_kind( primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py index 8c5d4cb5aa..65865709ba 100644 --- a/src/gt4py/next/field_utils.py +++ b/src/gt4py/next/field_utils.py @@ -8,9 +8,6 @@ from types import ModuleType -from types import ModuleType -from typing import Callable - import numpy as np from gt4py._core import definitions as core_defs @@ -61,7 +58,6 @@ def impl(type_: ts.ScalarType) -> common.MutableField: return impl(type_) - def get_array_ns( *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> ModuleType: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 00bf596616..13c64e264e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1724,6 +1724,7 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider def set_at(expr: common.Field, domain: common.DomainLike, target: common.MutableField) -> None: operators._tuple_assign_field(target, expr, common.domain(domain)) + @runtime.if_stmt.register(EMBEDDED) def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None: """ diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 679697ead6..e875709631 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -19,6 +19,7 @@ DimensionKind = common.DimensionKind + @noninstantiable class Node(eve.Node): location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) @@ -172,6 +173,7 @@ class FunctionDefinition(Node, SymbolTableTrait): *TYPEBUILTINS, } + class Stmt(Node): ... diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index 8ac8e4be4b..c9a5b15de7 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -22,7 +22,6 @@ from gt4py.next.program_processors import program_formatter - if TYPE_CHECKING: # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 943049963b..12c86680b5 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -212,10 +212,6 @@ def add_fundef(cls, fun): def add_stmt(cls, stmt): cls.body.append(stmt) - @classmethod - def add_stmt(cls, stmt): - cls.body.append(stmt) - def __enter__(self): iterator.builtins.builtin_dispatch.push_key(TRACING) @@ -251,11 +247,6 @@ def if_stmt( ) -@iterator.runtime.set_at.register(TRACING) -def set_at(expr, domain, target): - TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) - - def _contains_tuple_dtype_field(arg): if isinstance(arg, tuple): return any(_contains_tuple_dtype_field(el) for el in arg) @@ -302,7 +293,6 @@ def _make_program_params(fun, args) -> list[Sym]: return params - def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program: """ Transform fencil given as a callable into `itir.Program` using tracing. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 2e3c47a4d7..334fb330d7 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -47,7 +47,6 @@ def _transform_if( return None - def _transform_by_pattern( stmt: itir.Stmt, predicate: Callable[[itir.Expr, int], bool], diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b69b1b5ae8..d967c8fbb8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -52,7 +52,6 @@ def apply_common_transforms( #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, - offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index b88fe90414..32c3f7a360 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -21,7 +21,6 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config - from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index c72103996c..4dfcf1461a 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -170,4 +170,4 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: raise TypeError( "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`, or with a function and collection." - ) \ No newline at end of file + ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index caa9e95a6f..facacc1c23 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -1,6 +1,5 @@ # GT4Py - GridTools Framework # - # Copyright (c) 2014-2024, ETH Zurich # All rights reserved. # @@ -29,6 +28,7 @@ I = gtx.Dimension("I") Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) + @fundef def copy_stencil(inp): return deref(inp) @@ -54,6 +54,7 @@ def test_prog(program_processor): if validate: assert np.allclose(inp.asnumpy(), out.asnumpy()) + @fendef def index_program_simple(out, size): set_at( diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index 6f36dd9c06..7ce9674247 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -33,7 +33,7 @@ def test_apply_to_primitive_constituents(): )(tuple_type) prim = type_info.apply_to_primitive_constituents( - lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), tuple_type + lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), tuple_type ) assert tree == prim From 0a5ac377fd2ba2a2e4392ea0c964bad87a97c000 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 31 Dec 2024 11:13:12 +0100 Subject: [PATCH 58/61] Minor --- src/gt4py/next/iterator/pretty_parser.py | 4 --- .../codegens/gtfn/itir_to_gtfn_ir.py | 26 ------------------- .../iterator_tests/test_program.py | 1 - tests/next_tests/unit_tests/test_utils.py | 5 ++-- 4 files changed, 2 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index a5710aa422..a077b39911 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -34,10 +34,6 @@ AXIS_LITERAL: CNAME ("ᵥ" | "ₕ") _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL ID_NAME: CNAME -<<<<<<< HEAD - AXIS_NAME: CNAME ("ᵥ" | "ₕ") -======= ->>>>>>> main ?prec0: prec1 | "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 1bad5910d3..d5b34fd5b9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -307,32 +307,6 @@ def _gen_constituent_expr(el_type: ts.ScalarType | ts.FieldType, path: tuple[int return result -class _CannonicalizeUnstructuredDomain(eve.NodeTranslator): - def visit_FunCall(self, node: itir.FunCall) -> itir.FunCall: - if node.fun == itir.SymRef(id="unstructured_domain"): - # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) - assert isinstance(node.args[0], itir.FunCall) - first_axis_literal = node.args[0].args[0] - assert isinstance(first_axis_literal, itir.AxisLiteral) - if first_axis_literal.kind == itir.DimensionKind.VERTICAL: - assert len(node.args) == 2 - assert isinstance(node.args[1], itir.FunCall) - assert isinstance(node.args[1].args[0], itir.AxisLiteral) - assert node.args[1].args[0].kind == itir.DimensionKind.HORIZONTAL - return itir.FunCall(fun=node.fun, args=[node.args[1], node.args[0]]) - return node - - @classmethod - def apply( - cls, - node: itir.Program, - ) -> itir.Program: - if not isinstance(node, itir.Program): - raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - - return cls().visit(node) - - @dataclasses.dataclass(frozen=True) class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): _binary_op_map: ClassVar[dict[str, str]] = { diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index facacc1c23..c79f8dbb6b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - import numpy as np import pytest diff --git a/tests/next_tests/unit_tests/test_utils.py b/tests/next_tests/unit_tests/test_utils.py index 7ce9674247..2204380353 100644 --- a/tests/next_tests/unit_tests/test_utils.py +++ b/tests/next_tests/unit_tests/test_utils.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - import pytest from gt4py.next import utils @@ -43,10 +42,10 @@ def test_tree_map_default(): expected_result = ((2, 3), 4) @utils.tree_map - def testee1(x): + def testee(x): return x + 1 - assert testee1(((1, 2), 3)) == expected_result + assert testee(((1, 2), 3)) == expected_result assert utils.tree_map(lambda x: x + 1)(((1, 2), 3)) == expected_result assert utils.tree_map(lambda x: x + 1, ((1, 2), 3)) == expected_result From 70283e60d95cd366b209b79ba29927de2e28907b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 31 Dec 2024 11:50:42 +0100 Subject: [PATCH 59/61] Replace more apply_to_primitive_constituents by tree_map --- .../iterator/transforms/fuse_as_fieldop.py | 4 ++-- .../next/iterator/transforms/global_tmps.py | 19 +++++++------------ .../next/iterator/type_system/inference.py | 5 ++--- .../iterator/type_system/type_synthesizer.py | 9 ++++----- src/gt4py/next/utils.py | 4 ++-- 5 files changed, 17 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index b7087472e0..8331317cf4 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -139,7 +139,7 @@ def fuse_as_fieldop( # just a safety check if typing information is available if arg.type and not isinstance(arg.type, ts.DeferredType): assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + dtype = type_info.type_tree_map(type_info.extract_dtype)(arg.type) assert not isinstance(dtype, it_ts.ListType) new_param: str if isinstance( @@ -233,7 +233,7 @@ def visit_FunCall(self, node: itir.FunCall): eligible_args = [] for arg, arg_shifts in zip(args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + dtype = type_info.type_tree_map(type_info.extract_dtype)(arg.type) # TODO(tehrengruber): make this configurable eligible_args.append( _is_tuple_expr_of_literals(arg) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 334fb330d7..cca2a7599c 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -93,18 +93,13 @@ def _transform_by_pattern( domain_expr = domain.as_expr() assert isinstance(tmp_expr.type, ts.TypeSpec) - tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( - lambda x: uids.sequential_id(), - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) - tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( - type_info.apply_to_primitive_constituents( - type_info.extract_dtype, - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) - ) + tmp_names: str | tuple[str | tuple, ...] = type_info.type_tree_map( + result_collection_constructor=lambda *elements: tuple(elements) + )(lambda x: uids.sequential_id)(tmp_expr.type) + + tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = type_info.type_tree_map( + result_collection_constructor=lambda *elements: tuple(elements) + )(type_info.extract_dtype)(tmp_expr.type) # allocate temporary for all tuple elements def allocate_temporary(tmp_name: str, dtype: ts.ScalarType): diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1b980783fa..34760813fc 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -478,9 +478,8 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup assert isinstance(domain, it_ts.DomainType) assert domain.dims != "unknown" assert node.dtype - return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above - node.dtype, + return type_info.type_tree_map(lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype))( + node.dtype ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..3e3de3c931 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -258,7 +258,7 @@ def _convert_as_fieldop_input_to_iterator( input_dims = [] element_type: ts.DataType - element_type = type_info.apply_to_primitive_constituents(type_info.extract_dtype, input_) + element_type = type_info.type_tree_map(type_info.extract_dtype)(input_) # handle neighbor / sparse input fields defined_dims = [] @@ -311,12 +311,11 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: offset_provider_type=offset_provider_type, ) assert isinstance(stencil_return, ts.DataType) - return type_info.apply_to_primitive_constituents( + return type_info.type_tree_map( lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) if domain.dims != "unknown" - else ts.DeferredType(constraint=ts.FieldType), - stencil_return, - ) + else ts.DeferredType(constraint=ts.FieldType) + )(stencil_return) return applied_as_fieldop diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 4dfcf1461a..e9b94df6e9 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -130,7 +130,7 @@ def tree_map( if result_collection_constructor is None: if isinstance(collection_type, tuple): raise TypeError( - "tree_map() requires `result_collection_constructor` when `collection_type` is a tuple." + "tree_map() requires `result_collection_constructor` when `collection_type` is a tuple of types." ) result_collection_constructor = collection_type @@ -162,7 +162,7 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert result_collection_constructor is not None return result_collection_constructor(impl(*arg) for arg in zip(*args)) - return fun( + return fun( # type: ignore[call-arg] # mypy not smart enough *cast(_P.args, args) ) # mypy doesn't understand that `args` at this point is of type `_P.args` From 14619d2dbb05ffde65d42a7e2824e5c748ca19b3 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 31 Dec 2024 12:08:50 +0100 Subject: [PATCH 60/61] Minor fix --- src/gt4py/next/iterator/transforms/global_tmps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index cca2a7599c..2d81036dae 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -95,7 +95,7 @@ def _transform_by_pattern( assert isinstance(tmp_expr.type, ts.TypeSpec) tmp_names: str | tuple[str | tuple, ...] = type_info.type_tree_map( result_collection_constructor=lambda *elements: tuple(elements) - )(lambda x: uids.sequential_id)(tmp_expr.type) + )(lambda x: uids.sequential_id())(tmp_expr.type) tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = type_info.type_tree_map( result_collection_constructor=lambda *elements: tuple(elements) From fac65de6ca9f90fbd8d11bd8f65f750e2a9602df Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 31 Dec 2024 12:53:05 +0100 Subject: [PATCH 61/61] Revert replacing when tuple_constructor is present --- .../next/iterator/transforms/global_tmps.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 2d81036dae..db12ba13ed 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -93,13 +93,19 @@ def _transform_by_pattern( domain_expr = domain.as_expr() assert isinstance(tmp_expr.type, ts.TypeSpec) - tmp_names: str | tuple[str | tuple, ...] = type_info.type_tree_map( - result_collection_constructor=lambda *elements: tuple(elements) - )(lambda x: uids.sequential_id())(tmp_expr.type) - - tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = type_info.type_tree_map( - result_collection_constructor=lambda *elements: tuple(elements) - )(type_info.extract_dtype)(tmp_expr.type) + tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( + lambda x: uids.sequential_id(), + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), + ) # TODO: how should tuple_constructorb e handled? + + tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( + type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), + ) + ) # TODO: how should tuple_constructorb e handled? # allocate temporary for all tuple elements def allocate_temporary(tmp_name: str, dtype: ts.ScalarType):