Skip to content

Commit

Permalink
Added named_arrays.AbstractArray.cell_centers(), a method to conver…
Browse files Browse the repository at this point in the history
…t from cell vertices to cell centers. (#93)
  • Loading branch information
byrdie authored Nov 6, 2024
1 parent 839f093 commit daebd94
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
30 changes: 30 additions & 0 deletions named_arrays/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,36 @@ def combine_axes(
Array with the specified axes combined
"""

def cell_centers(
self,
axis: None | str | Sequence[str] = None,
) -> na.AbstractExplicitArray:
"""
Convert an array from cell vertices to cell centers.
Parameters
----------
axis
The axes of the array to average over.
"""

if axis is None:
axis = self.axes
elif isinstance(axis, str):
axis = (axis, )

result = self.explicit

shape = result.shape

for a in axis:
if a in shape:
lower = {a: slice(None, ~0)}
upper = {a: slice(+1, None)}
result = (result[lower] + result[upper]) / 2

return result

def volume_cell(self, axis: None | str | Sequence[str]) -> na.AbstractScalar:
"""
Computes the n-dimensional volume of each cell formed by interpreting
Expand Down
10 changes: 10 additions & 0 deletions named_arrays/_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ def combine_axes(
outputs=outputs.combine_axes(axes=axes, axis_new=axis_new),
)

def cell_centers(
self,
axis: None | str | Sequence[str] = None,
) -> na.AbstractExplicitArray:
return dataclasses.replace(
self,
inputs=self.inputs.cell_centers(axis),
outputs=self.outputs.cell_centers(axis),
)

def to_string_array(
self,
format_value: str = "%.2f",
Expand Down
27 changes: 27 additions & 0 deletions named_arrays/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,33 @@ def test_combine_axes(
with pytest.raises(ValueError):
array.combine_axes(axes=axes, axis_new=axis_new)

@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
"y",
("y",),
("x", "y"),
]
)
def test_cell_centers(
self,
array: na.AbstractArray,
axis: None | str | Sequence[str],
):
if axis is None:
axis_normalized = array.axes
elif isinstance(axis, str):
axis_normalized = (axis, )
else:
axis_normalized = axis

result = array.cell_centers(axis)

for a in axis_normalized:
if a in array.shape:
assert result.shape[a] == array.shape[a] - 1

@pytest.mark.parametrize(
argnames="axis",
argvalues=[
Expand Down

0 comments on commit daebd94

Please sign in to comment.