diff --git a/xcdat/dataset.py b/xcdat/dataset.py
index f5139cd2..acac2651 100644
--- a/xcdat/dataset.py
+++ b/xcdat/dataset.py
@@ -22,7 +22,7 @@
 
 from xcdat import bounds as bounds_accessor  # noqa: F401
 from xcdat._logger import _setup_custom_logger
-from xcdat.axis import CFAxisKey, _get_all_coord_keys, swap_lon_axis
+from xcdat.axis import CFAxisKey, _get_all_coord_keys, get_dim_keys, swap_lon_axis
 from xcdat.axis import center_times as center_times_func
 
 logger = _setup_custom_logger(__name__)
@@ -746,3 +746,116 @@ def _get_data_var(dataset: xr.Dataset, key: str) -> xr.DataArray:
         raise KeyError(f"The data variable '{key}' does not exist in the Dataset.")
 
     return dv.copy()
+
+
+def get_bounded_dataarray(ds: xr.Dataset, key: str) -> xr.DataArray:
+    """
+    Convert a dataset to a dataarray with the bounds embedded as coordinates
+    (i.e., a bounded DataArray).
+
+    Each axis' bounds are stored as a 1-D coordinate along the axis dimension
+    using a structured dtype with the fields "lower" and "upper". Axes whose
+    dimension key or bounds cannot be resolved are skipped (best effort).
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The Dataset.
+    key : str
+        The data variable key.
+
+    Returns
+    -------
+    xr.DataArray
+        The bounded DataArray.
+
+    Raises
+    ------
+    KeyError
+        If the data variable does not exist in the Dataset.
+    """
+    ds = ds.copy()
+    # get dataarray
+    data_var = ds.get(key)
+    # check if dataarray exists
+    if data_var is None:
+        raise KeyError(f"The data variable '{key}' does not exist in the Dataset.")
+    # loop over axes to collect coordinates and coordinate bounds
+    coords = {}
+    for c_key in data_var.cf.axes:
+        try:
+            # get dimension key (e.g., "time", "lat", "lon")
+            dim_key = get_dim_keys(ds, c_key)
+            # add axis to coordinate dict
+            coords[dim_key] = ds[dim_key]
+            # create a bounds dtype (based on coordinate dtype)
+            dim_value_dtype = ds[dim_key].dtype
+            bounds_dtype = np.dtype(
+                [("lower", dim_value_dtype), ("upper", dim_value_dtype)]
+            )
+            # get the bounds for axis
+            bnds = ds.bounds.get_bounds(axis=c_key)
+            # convert each (lower, upper) row to a tuple, which is the form
+            # numpy expects when filling a structured array
+            bnds_rows = [tuple(row) for row in bnds.values]
+            # add bounds to coordinate dict
+            coords[bnds.name] = (dim_key, np.array(bnds_rows, dtype=bounds_dtype))
+        except Exception as err:
+            # Best effort: skip this axis, but do not swallow
+            # KeyboardInterrupt/SystemExit like a bare ``except`` would.
+            logger.debug("Skipping bounds for axis '%s': %s", c_key, err)
+            continue
+    # return dataarray with bounds
+    return xr.DataArray(
+        data_var, coords=coords, dims=data_var.dims, attrs=data_var.attrs
+    )
+
+
+def boundedDataArray_to_dataset(bda: xr.DataArray) -> xr.Dataset:
+    """
+    Convert a bounded dataarray to a dataset.
+
+    Parameters
+    ----------
+    bda : xr.DataArray
+        The bounded dataarray.
+
+    Returns
+    -------
+    xr.Dataset
+        The dataset.
+
+    Notes
+    -----
+    Note that the .name attribute must be set in the dataarray.
+    """
+    # convert to dataset object
+    ds = bda.to_dataset()
+    # loop over coordinates and convert data array bound coordinates
+    # to bounds dataarrays
+    for c_key in ds.cf.axes:
+        try:
+            # get dimension key (e.g., "time", "lat", "lon")
+            dim_key = get_dim_keys(ds, c_key)
+            # get bounds key
+            bnds_key = ds[dim_key].attrs["bounds"]
+            # get bounds in xr.dataset form
+            bnds = [[b[0], b[1]] for b in ds[bnds_key].to_numpy()]
+            # remove the structured bounds coordinate before re-adding the bounds
+            del ds[bnds_key]
+            # create bounds dataarray
+            bnds_da = xr.DataArray(
+                data=bnds, dims=[dim_key, "bnds"], coords={dim_key: ds[dim_key]}
+            )
+            # update bounds in output dataset
+            ds[bnds_key] = bnds_da
+        except Exception as err:
+            # Best effort: skip this axis, but do not swallow
+            # KeyboardInterrupt/SystemExit like a bare ``except`` would.
+            logger.debug("Skipping bounds for axis '%s': %s", c_key, err)
+            continue
+    return ds
+
+
+# add get_bounded_array call to xr dataset objects
+xr.Dataset.__call__ = get_bounded_dataarray
diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 15bec956..6dd181c0 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -26,7 +26,7 @@
     get_dim_coords,
     get_dim_keys,
 )
-from xcdat.dataset import _get_data_var
+from xcdat.dataset import _get_data_var, boundedDataArray_to_dataset
 from xcdat.utils import _if_multidim_dask_array_then_load
 
 #: Type alias for a dictionary of axis keys mapped to their bounds.
@@ -173,7 +173,7 @@ def average(
         Using custom weights for averaging:
 
         >>> # The shape of the weights must align with the data var.
-        >>> self.weights = xr.DataArray(
+        >>> weights = xr.DataArray(
         >>>     data=np.ones((4, 4)),
         >>>     coords={"lat": self.ds.lat, "lon": self.ds.lon},
         >>>     dims=["lat", "lon"],
@@ -742,3 +742,198 @@ def _averager(
         weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
 
         return weighted_mean
+
+
+# %% dataarray accessors
+@xr.register_dataarray_accessor("spatial")
+class SpatialAccessorDa:
+    """
+    Accessor exposing spatial operations on bounded DataArrays via the
+    ``spatial`` attribute.
+
+    Each method converts the bounded DataArray to a Dataset with
+    ``boundedDataArray_to_dataset`` and delegates to the Dataset ``spatial``
+    accessor. The DataArray must have its ``name`` attribute set.
+    """
+
+    def __init__(self, dataarray: xr.DataArray):
+        self._dataarray: xr.DataArray = dataarray
+
+    def average(
+        self,
+        axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"),
+        weights: Union[Literal["generate"], xr.DataArray] = "generate",
+        keep_weights: bool = False,
+        lat_bounds: Optional[RegionAxisBounds] = None,
+        lon_bounds: Optional[RegionAxisBounds] = None,
+    ) -> xr.DataArray:
+        """
+        Calculates the spatial average for a rectilinear grid over an optionally
+        specified regional domain.
+
+        Operations include:
+
+        - If a regional boundary is specified, check to ensure it is within the
+          data variable's domain boundary.
+        - If axis weights are not provided, get axis weights for standard axis
+          domains specified in ``axis``.
+        - Adjust weights to conform to the specified regional boundary.
+        - Compute spatial weighted average.
+
+        This method requires that the dataarray's coordinates have the 'axis'
+        attribute set to the keys in ``axis``. For example, the latitude
+        coordinates should have its 'axis' attribute set to 'Y' (which is also
+        CF-compliant). This 'axis' attribute is used to retrieve the related
+        coordinates via `cf_xarray`. Refer to this method's examples for more
+        information.
+
+        Parameters
+        ----------
+        axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
+            List of axis dimensions to average over, by default ("X", "Y").
+            Valid axis keys include "X" and "Y".
+        weights : {"generate", xr.DataArray}, optional
+            If "generate", then weights are generated. Otherwise, pass a
+            DataArray containing the regional weights used for weighted
+            averaging. ``weights`` must include the same spatial axis dimensions
+            and have the same dimensional sizes as the data variable, by default
+            "generate".
+        keep_weights : bool, optional
+            Forwarded to the Dataset accessor's ``average``, by default False.
+            Only the averaged data variable is returned, so anything else the
+            Dataset accessor stores in its output Dataset is not returned.
+        lat_bounds : Optional[RegionAxisBounds], optional
+            A tuple of floats/ints for the regional latitude lower and upper
+            boundaries. This arg is used when calculating axis weights, but is
+            ignored if ``weights`` are supplied. The lower bound cannot be
+            larger than the upper bound, by default None.
+        lon_bounds : Optional[RegionAxisBounds], optional
+            A tuple of floats/ints for the regional longitude lower and upper
+            boundaries. This arg is used when calculating axis weights, but is
+            ignored if ``weights`` are supplied. The lower bound can be larger
+            than the upper bound (e.g., across the prime meridian, dateline), by
+            default None.
+
+        Returns
+        -------
+        xr.DataArray
+            The spatially averaged data variable.
+
+        Examples
+        --------
+
+        Check the 'axis' attribute is set on the required coordinates:
+
+        >>> da.lat.attrs["axis"]
+        >>> Y
+        >>>
+        >>> da.lon.attrs["axis"]
+        >>> X
+
+        Set the 'axis' attribute for the required coordinates if it isn't:
+
+        >>> da.lat.attrs["axis"] = "Y"
+        >>> da.lon.attrs["axis"] = "X"
+
+        Call spatial averaging method:
+
+        >>> da.spatial.average(...)
+
+        Get global average time series:
+
+        >>> ts_global = da.spatial.average(axis=["X", "Y"])
+
+        Get time series in Nino 3.4 domain:
+
+        >>> ts_n34 = da.spatial.average(axis=["X", "Y"],
+        >>>                             lat_bounds=(-5, 5),
+        >>>                             lon_bounds=(-170, -120))
+
+        Get zonal mean time series:
+
+        >>> ts_zonal = da.spatial.average(axis=["X"])
+
+        Using custom weights for averaging:
+
+        >>> # The shape of the weights must align with the data var.
+        >>> weights = xr.DataArray(
+        >>>     data=np.ones((4, 4)),
+        >>>     coords={"lat": da.lat, "lon": da.lon},
+        >>>     dims=["lat", "lon"],
+        >>> )
+        >>>
+        >>> ts_global = da.spatial.average(axis=["X", "Y"], weights=weights)
+        """
+        # convert dataarray to a dataset
+        da = self._dataarray.copy()
+        ds = boundedDataArray_to_dataset(da)
+        # get data_var key
+        data_var = da.name
+        # pass on call to spatial averager
+        ds_sa = ds.spatial.average(
+            data_var,
+            axis=axis,
+            weights=weights,
+            keep_weights=keep_weights,
+            lat_bounds=lat_bounds,
+            lon_bounds=lon_bounds,
+        )
+        return ds_sa[data_var]
+
+    def get_weights(
+        self,
+        axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
+        lat_bounds: Optional[RegionAxisBounds] = None,
+        lon_bounds: Optional[RegionAxisBounds] = None,
+        data_var: Optional[str] = None,
+    ) -> xr.DataArray:
+        """
+        Get area weights for specified axis keys and an optional target domain.
+
+        This method first determines the weights for an individual axis based on
+        the difference between the upper and lower bound. For latitude the
+        weight is determined by the difference of sine(latitude). All axis
+        weights are then combined to form a DataArray of weights that can be
+        used to perform a weighted (spatial) average.
+
+        If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells
+        outside this selected regional domain are given zero weight. Grid cells
+        that are partially in this domain are given partial weight.
+
+        Parameters
+        ----------
+        axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
+            List of axis dimensions to average over.
+        lat_bounds : Optional[RegionAxisBounds]
+            Tuple of latitude boundaries for regional selection, by default
+            None.
+        lon_bounds : Optional[RegionAxisBounds]
+            Tuple of longitude boundaries for regional selection, by default
+            None.
+        data_var : Optional[str]
+            Ignored. The name of the DataArray is always used as the data
+            variable key; this parameter is accepted only to mirror the
+            ``data_var`` keyword of the Dataset accessor, by default None.
+
+        Returns
+        -------
+        xr.DataArray
+            A DataArray containing the region weights to use during averaging.
+            ``weights`` are 1-D and correspond to the specified axes (``axis``)
+            in the region.
+
+        Notes
+        -----
+        This method was developed for rectilinear grids only. ``get_weights()``
+        recognizes and operates on latitude and longitude, but could be extended
+        to work with other standard geophysical dimensions (e.g., time, depth,
+        and pressure).
+        """
+        # convert dataarray to a dataset
+        da = self._dataarray.copy()
+        ds = boundedDataArray_to_dataset(da)
+        # pass on call to get_weights, always keyed by the dataarray's own name
+        weights = ds.spatial.get_weights(
+            axis=axis, lat_bounds=lat_bounds, lon_bounds=lon_bounds, data_var=da.name
+        )
+        return weights