Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nodata values #35

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions pvxarray/accessor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Optional
import warnings

import numpy as np
import pyvista as pv
import xarray as xr

from pvxarray import rectilinear, structured
from pvxarray.errors import DataCopyWarning, DataModificationWarning


class _LocIndexer:
Expand Down Expand Up @@ -39,9 +42,28 @@ def _get_array(self, key):
f"Key {key} not present in DataArray. Choices are: {list(self._obj.coords.keys())}"
)

@property
def data(self):
return self._obj.values
def data(self, nodata: Optional[float] = None):
values = self._obj.values
if nodata is not None:
nans = values == nodata
if np.any(nans):
try:
values[nans] = np.nan
warnings.warn(
DataModificationWarning(
"nodata values overwritten with `np.nan` in source DataArray."
)
)
except ValueError:
dytpe = values.dtype
values = values.astype(float)
values[nans] = np.nan
warnings.warn(
DataCopyWarning(
f"{dytpe} does not support overwritting values with nan. Copying and casting these data to float."
)
)
return values

def mesh(
self,
Expand All @@ -50,6 +72,7 @@ def mesh(
z: Optional[str] = None,
order: Optional[str] = None,
component: Optional[str] = None,
nodata: Optional[float] = None,
) -> pv.DataSet:
ndim = 0
if x is not None:
Expand All @@ -69,14 +92,15 @@ def mesh(
else:
# RectilinearGrid
meth = rectilinear.mesh
return meth(self, x=x, y=y, z=z, order=order, component=component)
return meth(self, x=x, y=y, z=z, order=order, component=component, nodata=nodata)

def plot(
self,
x: Optional[str] = None,
y: Optional[str] = None,
z: Optional[str] = None,
order: str = "C",
nodata: Optional[float] = None,
**kwargs,
):
return self.mesh(x=x, y=y, z=z, order=order).plot(**kwargs)
return self.mesh(x=x, y=y, z=z, order=order, nodata=nodata).plot(**kwargs)
4 changes: 4 additions & 0 deletions pvxarray/errors.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class DataCopyWarning(Warning):
pass


class DataModificationWarning(Warning):
pass
3 changes: 2 additions & 1 deletion pvxarray/rectilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def mesh(
z: Optional[str] = None,
order: Optional[str] = "C",
component: Optional[str] = None,
nodata: Optional[float] = None,
):
if order is None:
order = "C"
Expand All @@ -28,7 +29,7 @@ def mesh(
if z is not None:
self._mesh.z = self._get_array(z)
# Handle data values
values = self.data
values = self.data(nodata=nodata)
values_dim = values.ndim
if component is not None:
# if ndim < values.ndim and values.ndim == ndim + 1:
Expand Down
3 changes: 2 additions & 1 deletion pvxarray/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def mesh(
z: Optional[str] = None,
order: str = "F",
component: Optional[str] = None, # TODO
nodata: Optional[float] = None,
):
if order is None:
order = "F"
Expand All @@ -89,7 +90,7 @@ def mesh(
points, shape = _points(self, x=x, y=y, z=z, order=order)
self._mesh.points = points
self._mesh.dimensions = shape
data = self.data
data = self.data(nodata=nodata)
if tuple(data.shape) != tuple(shape):
raise ValueError(
"Coord and data shape mismatch. You may need to `transpose` the DataArray. "
Expand Down