Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tree_map and replace apply_to_primitive_constituents #1570

Open
wants to merge 69 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
8add029
refactor[next]: itir embedded: cleaner closure run
havogt Apr 4, 2024
853d3e1
cleanup
havogt Apr 4, 2024
f661cd3
fix test
havogt Apr 4, 2024
09e568d
without temporaries
havogt Apr 8, 2024
12b8696
temporaries
havogt Apr 8, 2024
540a2d8
cleanup
havogt Apr 9, 2024
23ddef1
move to SetAt
havogt Apr 10, 2024
e64b986
Merge branch 'refactor_itir_embedded' into itir_program_embedded2
havogt Apr 10, 2024
c99f44d
embedded
havogt Apr 10, 2024
1a6f885
roundtrip+double_roundtrip with shortcuts
havogt Apr 11, 2024
39d6d7c
changes
havogt Apr 11, 2024
ab44009
fencil2program only for gtfn
havogt Apr 11, 2024
12f1663
fix import
havogt Apr 11, 2024
aa80949
Merge remote-tracking branch 'upstream/main' into itir_program
havogt Apr 11, 2024
5037493
fix builtins list
havogt Apr 11, 2024
751581e
add comment
havogt Apr 11, 2024
3d2f33e
fix type checker
havogt Apr 11, 2024
53bad75
Merge branch 'itir_program' into itir_program_embedded2
havogt Apr 11, 2024
4cbce7e
Apply suggestions from code review
havogt Apr 12, 2024
c955645
format
havogt Apr 12, 2024
6effe10
pretty printing/parsing
havogt Apr 12, 2024
66de3ec
Apply suggestions from code review
havogt Apr 15, 2024
e63da77
address more review comments
havogt Apr 15, 2024
45fba85
move tmp to pretty_printer
havogt Apr 15, 2024
1a70218
pparse for temporaries
havogt Apr 15, 2024
c39c603
rename gtfn.FencilDefinition -> Program
havogt Apr 15, 2024
705cfcf
remove TODO
havogt Apr 15, 2024
c5e78c4
Apply suggestions from code review
havogt Apr 15, 2024
a336bf5
rename as_field_operator -> as_fieldop
havogt Apr 15, 2024
af16f40
missed a file
havogt Apr 15, 2024
2d6bfbf
Merge branch 'itir_program' into itir_program_embedded2
havogt Apr 15, 2024
df1146b
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 15, 2024
bc2c2d3
add fencil2program to roundtrip
havogt Apr 16, 2024
c7ccd6a
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 16, 2024
f45b460
pre-allocate result buffer
havogt Apr 17, 2024
5d4fc3d
fix tracer context
havogt Apr 17, 2024
1d93192
first (almost) complete embedded version
havogt Apr 17, 2024
5882c28
add dim kind to print/parse
havogt Apr 17, 2024
b7cbf16
fix tests
havogt Apr 17, 2024
e97ca25
cleanup test_program
havogt Apr 17, 2024
8c2bd8f
re-enable lift mode in roundtrip
havogt Apr 18, 2024
21b230b
replace lift_mode fixture by backend in program_processor
havogt Apr 18, 2024
0946783
fix doctests
havogt Apr 18, 2024
f93da09
fix tests
havogt Apr 18, 2024
97663bd
undo quickstart changes
havogt Apr 18, 2024
3297f7b
undo delete cpp_backend_tests
havogt Apr 18, 2024
f37b372
fix quickstart guide again
havogt Apr 18, 2024
e242ab6
remove runtime lift
havogt Apr 19, 2024
a07d8ea
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 19, 2024
3c2b9a5
Merge remote-tracking branch 'upstream/main' into test_lift_mode_to_p…
havogt Apr 19, 2024
a9f1043
Update docs/user/next/QuickstartGuide.md
havogt Apr 19, 2024
e7195a5
cleanup out field construction
havogt Apr 19, 2024
3f67746
Update src/gt4py/next/program_processors/runners/double_roundtrip.py
havogt Apr 22, 2024
369eae7
read config.DEBUG at execution
havogt Apr 22, 2024
35f2132
remove LiftMode.SIMPLE_HEURISTIC
havogt Apr 22, 2024
4a4f9b1
fix formatting
havogt Apr 23, 2024
2825588
Merge branch 'test_lift_mode_to_processor' into itir_program_embedded2
havogt Apr 23, 2024
65abde7
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt Apr 23, 2024
7373eb6
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 6, 2024
bfcc118
move ordering of unstructured domain to gtfn
havogt May 15, 2024
54f44cc
fix problem in column dtype if contains None
havogt May 16, 2024
b8b26e6
address more review comments
havogt May 16, 2024
b7b489e
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 16, 2024
9ee02e4
fix tuples in columns
havogt May 16, 2024
d104633
fix preserve axis kind in global tmps
havogt May 17, 2024
866c3d6
Merge remote-tracking branch 'upstream/main' into itir_program_embedded2
havogt May 17, 2024
66e2464
fix follow up issue
havogt May 17, 2024
56e0086
Start using tree_map instead of apply_to_primitive_constituents
SF-N Jun 10, 2024
8ab97c4
Add functionality to call also tree_map(lambda x: x + 1, ((1, 2), 3))…
SF-N Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 7 additions & 31 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from types import ModuleType
from typing import Any, Callable, Generic, Optional, ParamSpec, Sequence, TypeVar

import numpy as np

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, errors, utils
from gt4py.next import common, errors, field_utils, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.field_utils import get_array_ns
from gt4py.next.type_system import type_translation


_P = ParamSpec("_P")
Expand Down Expand Up @@ -64,8 +63,10 @@ def __call__( # type: ignore[override]
# even if the scan dimension is not in the input, we can scan over it
out_domain = common.Domain(*out_domain, (scan_range))

xp = _get_array_ns(*all_args)
res = _construct_scan_array(out_domain, xp)(self.init)
xp = get_array_ns(*all_args)
res = field_utils.field_from_typespec(out_domain, xp)(
type_translation.from_value(self.init)
)

def scan_loop(hpos: Sequence[common.NamedIndex]) -> None:
acc: core_defs.ScalarT | tuple[core_defs.ScalarT | tuple, ...] = self.init
Expand Down Expand Up @@ -163,31 +164,6 @@ def _intersect_scan_args(
)


def _get_array_ns(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...],
) -> ModuleType:
for arg in utils.flatten_nested_tuple(args):
if hasattr(arg, "array_ns"):
return arg.array_ns
return np


def _construct_scan_array(
domain: common.Domain,
xp: ModuleType, # TODO(havogt) introduce a NDArrayNamespace protocol
) -> Callable[
[core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]],
common.MutableField | tuple[common.MutableField | tuple, ...],
]:
@utils.tree_map
def impl(init: core_defs.Scalar) -> common.MutableField:
res = common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain)
assert isinstance(res, common.MutableField)
return res

return impl


def _tuple_assign_value(
pos: Sequence[common.NamedIndex],
target: common.MutableField | tuple[common.MutableField | tuple, ...],
Expand Down
13 changes: 9 additions & 4 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,12 +833,17 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.",
)

return_type = type_info.apply_to_primitive_constituents(
value.type,
# return_type = type_info.apply_to_primitive_constituents(
# value.type,
# lambda primitive_type: with_altered_scalar_kind(
# primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
# ),
# )
return_type = type_info.type_tree_map(
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
)
)
)(value.type)
assert isinstance(return_type, (ts.TupleType, ts.ScalarType, ts.FieldType))

return foast.Call(
Expand Down
7 changes: 1 addition & 6 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _construct_itir_domain_arg(
domain_args.append(
itir.FunCall(
fun=itir.SymRef(id="named_range"),
args=[itir.AxisLiteral(value=dim.value), lower, upper],
args=[itir.AxisLiteral(value=dim.value, kind=dim.kind), lower, upper],
)
)
domain_args_kind.append(dim.kind)
Expand All @@ -349,11 +349,6 @@ def _construct_itir_domain_arg(
domain_builtin = "cartesian_domain"
elif self.grid_type == common.GridType.UNSTRUCTURED:
domain_builtin = "unstructured_domain"
# for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical)
if domain_args_kind[0] == common.DimensionKind.VERTICAL:
assert len(domain_args) == 2
assert domain_args_kind[1] == common.DimensionKind.HORIZONTAL
domain_args[0], domain_args[1] = domain_args[1], domain_args[0]
else:
raise AssertionError()

Expand Down
20 changes: 16 additions & 4 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec:
return ts.FieldType(dims=[], dtype=type_el)
return type_el

return type_info.apply_to_primitive_constituents(type_, promote_el)
# return type_info.apply_to_primitive_constituents(type_, promote_el)

return type_info.type_tree_map(promote_el)(type_)


def promote_zero_dims(
Expand Down Expand Up @@ -308,6 +310,16 @@ def return_type_scanop(
# field
[callable_type.axis],
)
return type_info.apply_to_primitive_constituents(
carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg))
)

# return type_info.apply_to_primitive_constituents(
# carry_dtype, lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg))
# )
# @utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=lambda x: ts.TupleType(types=[*x]))
# def tmp(x):
# return ts.FieldType(dims=promoted_dims, dtype=x)

# return tmp(carry_dtype)

return type_info.type_tree_map(
lambda arg: ts.FieldType(dims=promoted_dims, dtype=cast(ts.ScalarType, arg))
)(carry_dtype)
29 changes: 29 additions & 0 deletions src/gt4py/next/field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,40 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from types import ModuleType
from typing import Callable

import numpy as np

from gt4py._core import definitions as core_defs
from gt4py.next import common, utils
from gt4py.next.type_system import type_specifications as ts, type_translation


@utils.tree_map
def asnumpy(field: common.Field | np.ndarray) -> np.ndarray:
return field.asnumpy() if isinstance(field, common.Field) else field


def field_from_typespec(
domain: common.Domain, xp: ModuleType
) -> Callable[..., common.MutableField | tuple[common.MutableField | tuple, ...]]:
@utils.tree_map(collection_type=ts.TupleType, result_collection_constructor=tuple)
def impl(type_: ts.ScalarType) -> common.MutableField:
res = common._field(
xp.empty(domain.shape, dtype=xp.dtype(type_translation.as_dtype(type_).scalar_type)),
domain=domain,
)
assert isinstance(res, common.MutableField)
return res

return impl


def get_array_ns(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...],
) -> ModuleType:
for arg in utils.flatten_nested_tuple(args):
if hasattr(arg, "array_ns"):
return arg.array_ns
return np
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def __init__(self) -> None:
super().__init__("Backend not selected")


@builtin_dispatch
def as_fieldop(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def deref(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -430,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"cartesian_domain",
"unstructured_domain",
"named_range",
"as_fieldop",
*MATH_BUILTINS,
}

Expand Down
Loading