Skip to content

Commit

Permalink
Added named_arrays.regridding module, a wrapper around the `regridd…
Browse files Browse the repository at this point in the history
…ing` package. (#96)
  • Loading branch information
byrdie authored Nov 11, 2024
1 parent b150c29 commit 3fa0c95
Show file tree
Hide file tree
Showing 7 changed files with 531 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
'xarray': ('https://docs.xarray.dev/en/stable/', None),
'ndfilters': ('https://ndfilters.readthedocs.io/en/stable', None),
'colorsynth': ('https://colorsynth.readthedocs.io/en/stable', None),
'regridding': ('https://regridding.readthedocs.io/en/stable', None),
}

# plt.Axes.__module__ = matplotlib.axes.__name__
Expand Down
1 change: 1 addition & 0 deletions named_arrays/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import transformations
from . import ndfilters
from . import colorsynth
from . import regridding
from ._core import *
from ._scalars.scalars import *
from ._scalars.uncertainties.uncertainties import *
Expand Down
78 changes: 78 additions & 0 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import astroscrappy
import ndfilters
import colorsynth
import regridding
import named_arrays as na
from . import scalars

Expand Down Expand Up @@ -1422,6 +1423,83 @@ def ndfilter(
)


@_implements(na.regridding.weights)
def regridding_weights(
coordinates_input: na.AbstractScalarArray | na.AbstractVectorArray,
coordinates_output: na.AbstractScalarArray | na.AbstractVectorArray,
axis_input: None | str | Sequence[str] = None,
axis_output: None | str | Sequence[str] = None,
method: Literal['multilinear', 'conservative'] = 'multilinear',
) -> tuple[na.AbstractScalar, dict[str, int], dict[str, int]]:

if not isinstance(coordinates_output, na.AbstractVectorArray):
coordinates_output = na.CartesianNdVectorArray(dict(x=coordinates_output))

if not isinstance(coordinates_input, na.AbstractVectorArray):
coordinates_input = na.CartesianNdVectorArray(dict(x=coordinates_input))

return na.regridding.weights(
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
axis_input=axis_input,
axis_output=axis_output,
method=method,
)


@_implements(na.regridding.regrid_from_weights)
def regridding_regrid_from_weights(
weights: na.AbstractScalarArray,
shape_input: dict[str, int],
shape_output: dict[str, int],
values_input: na.AbstractScalarArray,
) -> na.ScalarArray:

try:
weights = scalars._normalize(weights)
values_input = scalars._normalize(values_input)
except scalars.ScalarTypeError: # pragma: nocover
return NotImplemented

shape_weights = weights.shape

axis_input = tuple(a for a in shape_input if a not in shape_weights)
axis_output = tuple(a for a in shape_output if a not in shape_weights)

shape_values_input = values_input.shape
shape_orthogonal = {
a: shape_values_input[a]
for a in shape_values_input
if a not in axis_input
}
shape_orthogonal = na.broadcast_shapes(shape_orthogonal, shape_weights)

shape_input = na.broadcast_shapes(shape_orthogonal, shape_input)
shape_output = na.broadcast_shapes(shape_orthogonal, shape_output)

weights = weights.broadcast_to({
a: shape_input[a] if a not in axis_input else 1
for a in shape_input
})
values_input = values_input.broadcast_to(shape_input)

result = regridding.regrid_from_weights(
weights=weights.ndarray,
shape_input=tuple(shape_input.values()),
shape_output=tuple(shape_output.values()),
values_input=values_input.ndarray,
axis_input=tuple(tuple(shape_input).index(a) for a in axis_input),
axis_output=tuple(tuple(shape_output).index(a) for a in axis_output),
)

result = na.ScalarArray(
ndarray=result,
axes=tuple(shape_output),
)

return result


@_implements(na.despike)
def despike(
array: na.AbstractScalarArray,
Expand Down
101 changes: 100 additions & 1 deletion named_arrays/_vectors/vector_named_array_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Callable, TypeVar
from typing import Callable, TypeVar, Sequence, Literal
import numpy as np
import numpy.typing as npt
import matplotlib.axes
import astropy.units as u
import regridding
import named_arrays as na
from named_arrays._scalars import scalars
import named_arrays._scalars.scalar_named_array_functions
Expand Down Expand Up @@ -603,3 +604,101 @@ def ndfilter(
result = prototype.type_explicit.from_components(result)

return result


@_implements(na.regridding.weights)
def regridding_weights(
coordinates_input: na.AbstractVectorArray,
coordinates_output: na.AbstractVectorArray,
axis_input: None | str | Sequence[str] = None,
axis_output: None | str | Sequence[str] = None,
method: Literal['multilinear', 'conservative'] = 'multilinear',
) -> tuple[na.AbstractScalar, dict[str, int], dict[str, int]]:

try:
prototype = vectors._prototype(coordinates_input, coordinates_output)
coordinates_input = vectors._normalize(coordinates_input, prototype)
coordinates_output = vectors._normalize(coordinates_output, prototype)
except vectors.VectorTypeError: # pragma: nocover
return NotImplemented

try:
coordinates_output = coordinates_output.components
coordinates_output = {
c: scalars._normalize(coordinates_output[c])
for c in coordinates_output
if coordinates_output[c] is not None
}
coordinates_output = na.CartesianNdVectorArray(coordinates_output)
except scalars.ScalarTypeError: # pragma: nocover
return NotImplemented

try:
coordinates_input = coordinates_input.components
coordinates_input = {
c: scalars._normalize(coordinates_input[c])
for c in coordinates_output.components
}
coordinates_input = na.CartesianNdVectorArray(coordinates_input)
except scalars.ScalarTypeError: # pragma: nocover
return NotImplemented

coordinates_output = coordinates_output.explicit
coordinates_input = coordinates_input.explicit

shape_input = coordinates_input.shape
shape_output = coordinates_output.shape

if axis_input is None:
axis_input = tuple(shape_input)
elif isinstance(axis_input, str):
axis_input = (axis_input,)

if axis_output is None:
axis_output = tuple(shape_output)
elif isinstance(axis_output, str):
axis_output = (axis_output,)

shape_orthogonal_input = {
a: shape_input[a]
for a in shape_input if a not in axis_input
}
shape_orthogonal_output = {
a: shape_output[a]
for a in shape_output if a not in axis_output
}

shape_orthogonal = na.broadcast_shapes(
shape_orthogonal_input,
shape_orthogonal_output,
)

shape_input = na.broadcast_shapes(shape_orthogonal, shape_input)
shape_output = na.broadcast_shapes(shape_orthogonal, shape_output)

coordinates_input = coordinates_input.broadcast_to(shape_input)
coordinates_output = coordinates_output.broadcast_to(shape_output)

coordinates_input = coordinates_input.components
coordinates_output = coordinates_output.components

coordinates_input = tuple(coordinates_input[c].ndarray for c in coordinates_input)
coordinates_output = tuple(coordinates_output[c].ndarray for c in coordinates_output)

axis_input = tuple(tuple(shape_input).index(a) for a in axis_input)
axis_output = tuple(tuple(shape_output).index(a) for a in axis_output)

result, _shape_input, _shape_output = regridding.weights(
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
axis_input=axis_input,
axis_output=axis_output,
method=method,
)

result = na.ScalarArray(result, tuple(shape_orthogonal))

shape_input = dict(zip(shape_input, _shape_input))
shape_output = dict(zip(shape_output, _shape_output))

return result, shape_input, shape_output
Loading

0 comments on commit 3fa0c95

Please sign in to comment.