From 010182003038bfb3c0cf5ccd6a34f5cc0fb74150 Mon Sep 17 00:00:00 2001 From: banesullivan Date: Thu, 21 Jul 2022 21:23:54 -0600 Subject: [PATCH] Support nodata values --- pvxarray/accessor.py | 34 +++++++++++++++++++++++++++++----- pvxarray/errors.py | 4 ++++ pvxarray/rectilinear.py | 3 ++- pvxarray/structured.py | 3 ++- 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/pvxarray/accessor.py b/pvxarray/accessor.py index 833afdc..683cad5 100644 --- a/pvxarray/accessor.py +++ b/pvxarray/accessor.py @@ -1,9 +1,12 @@ from typing import Optional +import warnings +import numpy as np import pyvista as pv import xarray as xr from pvxarray import rectilinear, structured +from pvxarray.errors import DataCopyWarning, DataModificationWarning class _LocIndexer: @@ -39,9 +42,28 @@ def _get_array(self, key): f"Key {key} not present in DataArray. Choices are: {list(self._obj.coords.keys())}" ) - @property - def data(self): - return self._obj.values + def data(self, nodata: Optional[float] = None): + values = self._obj.values + if nodata is not None: + nans = values == nodata + if np.any(nans): + try: + values[nans] = np.nan + warnings.warn( + DataModificationWarning( + "nodata values overwritten with `np.nan` in source DataArray." + ) + ) + except ValueError: + dytpe = values.dtype + values = values.astype(float) + values[nans] = np.nan + warnings.warn( + DataCopyWarning( + f"{dytpe} does not support overwritting values with nan. Copying and casting these data to float." + ) + ) + return values def mesh( self, @@ -50,6 +72,7 @@ def mesh( z: Optional[str] = None, order: Optional[str] = None, component: Optional[str] = None, + nodata: Optional[float] = None, ) -> pv.DataSet: ndim = 0 if x is not None: @@ -69,7 +92,7 @@ def mesh( else: # RectilinearGrid meth = rectilinear.mesh - return meth(self, x=x, y=y, z=z, order=order, component=component) + return meth(self, x=x, y=y, z=z, order=order, component=component, nodata=nodata) def plot( self, @@ -77,6 +100,7 @@ def plot( y: Optional[str] = None, z: Optional[str] = None, order: str = "C", + nodata: Optional[float] = None, **kwargs, ): - return self.mesh(x=x, y=y, z=z, order=order).plot(**kwargs) + return self.mesh(x=x, y=y, z=z, order=order, nodata=nodata).plot(**kwargs) diff --git a/pvxarray/errors.py b/pvxarray/errors.py index 9580902..c630e38 100644 --- a/pvxarray/errors.py +++ b/pvxarray/errors.py @@ -1,2 +1,6 @@ class DataCopyWarning(Warning): pass + + +class DataModificationWarning(Warning): + pass diff --git a/pvxarray/rectilinear.py b/pvxarray/rectilinear.py index e8d00eb..e777aa9 100644 --- a/pvxarray/rectilinear.py +++ b/pvxarray/rectilinear.py @@ -13,6 +13,7 @@ def mesh( z: Optional[str] = None, order: Optional[str] = "C", component: Optional[str] = None, + nodata: Optional[float] = None, ): if order is None: order = "C" @@ -28,7 +29,7 @@ def mesh( if z is not None: self._mesh.z = self._get_array(z) # Handle data values - values = self.data + values = self.data(nodata=nodata) values_dim = values.ndim if component is not None: # if ndim < values.ndim and values.ndim == ndim + 1: diff --git a/pvxarray/structured.py b/pvxarray/structured.py index fd4dadb..4b12777 100644 --- a/pvxarray/structured.py +++ b/pvxarray/structured.py @@ -76,6 +76,7 @@ def mesh( z: Optional[str] = None, order: str = "F", component: Optional[str] = None, # TODO + nodata: Optional[float] = None, ): if order is None: order = "F" @@ -89,7 +90,7 @@ def mesh( points, shape = _points(self, x=x, y=y, z=z, order=order) self._mesh.points = points self._mesh.dimensions = shape - data = self.data + data = self.data(nodata=nodata) if tuple(data.shape) != tuple(shape): raise ValueError( "Coord and data shape mismatch. You may need to `transpose` the DataArray. "