Skip to content

Commit

Permalink
Update and extend various parts
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Mar 11, 2024
1 parent c80e394 commit 4c3dddf
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 11 deletions.
9 changes: 6 additions & 3 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import numpy as np
import numpy.typing as npt

import gt4py.next as gtx
from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -80,12 +79,16 @@ def __call__(self, val: int) -> NamedIndex:
return self, val

def __add__(self, offset: int):
from gt4py.next.ffront import fbuiltins

assert isinstance(self.value, str)
return gtx.FieldOffset(f"{self.value}off", source=self, target=(self,))[offset]
return fbuiltins.FieldOffset(f"{self.value}off", source=self, target=(self,))[offset]

def __sub__(self, offset: int):
from gt4py.next.ffront import fbuiltins

assert isinstance(self.value, str)
return gtx.FieldOffset(f"{self.value}off", source=self, target=(self,))[-offset]
return fbuiltins.FieldOffset(f"{self.value}off", source=self, target=(self,))[-offset]


class Infinity(enum.Enum):
Expand Down
4 changes: 1 addition & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ def __call__(
index_field: common.ConnectivityField | fbuiltins.FieldOffset,
*args: common.ConnectivityField | fbuiltins.FieldOffset,
) -> common.Field:
if not args:
return self.remap(index_field)
return self.__call__(index_field, *args[1:]).remap(args[0])
return functools.reduce(lambda field, connectivity: field.remap(connectivity), [index_field, *args], self)

def restrict(self, index: common.AnyIndexSpec) -> common.Field:
new_domain, buffer_slice = self._slice(index)
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,10 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr:
left=foast.Name(id=dimension),
right=foast.Constant(value=offset_index),
):
shift_offset = im.shift(
dimension, offset_index
) # shift_offset = im.shift(node.args[0].left.type.dim.value, node.args[0].right.value)
if node.args[0].op == dialect_ast_enums.BinaryOperator.SUB:
offset_index *= -1
shift_offset = im.shift(f"{dimension}off", offset_index)
# shift_offset = im.shift(node.args[0].left.type.dim.value, node.args[0].right.value)
case foast.Name(id=offset_name):
return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs))
case foast.Call(func=foast.Name(id="as_offset")):
Expand Down
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, cast

from gt4py.eve import NodeTranslator, concepts, traits
from gt4py.next import errors
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import lowering_utils, program_ast as past, type_specifications as ts_ffront
from gt4py.next.iterator import ir as itir
Expand Down Expand Up @@ -324,9 +325,10 @@ def _compute_field_slice(node: past.Subscript):
node_dims_ls = cast(ts.FieldType, node.type).dims
assert isinstance(node_dims_ls, list)
if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls):
raise ValueError(
raise errors.DSLError(
node.location,
f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}"
f"-dimensional, but {len(out_field_slice_)} were indexed."
f"-dimensional, but {len(out_field_slice_)} were indexed.",
)
return out_field_slice_

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def test_ffront_lap(cartesian_case):
ref=lap_ref(in_field.ndarray),
)


def test_ffront_skewedlap(cartesian_case):
in_field = cases.allocate(cartesian_case, skewedlap_program, "in_field")()
out_field = cases.allocate(cartesian_case, skewedlap_program, "out_field")()

Expand All @@ -112,6 +114,8 @@ def test_ffront_lap(cartesian_case):
ref=skewedlap_ref(in_field.ndarray),
)


def test_ffront_laplap(cartesian_case):
in_field = cases.allocate(cartesian_case, laplap_program, "in_field")()
out_field = cases.allocate(cartesian_case, laplap_program, "out_field")()

Expand Down

0 comments on commit 4c3dddf

Please sign in to comment.