From c11bbbcf4b35d0c58884c042ba69340ecb5a6b47 Mon Sep 17 00:00:00 2001 From: Shane St Savage Date: Wed, 22 May 2024 00:19:46 -0700 Subject: [PATCH] WIP: Initial dataset transformation extension for virtual vector vars Generates speed and direction variables from detected component vars. Currently a proof of concept, the results are incomplete and surely incorrect in many (all?) cases. Specifics on how to render the directional variable with arrows/barbs and how to structure requests for and layering of the speed/magnitude and directional layers are TBD. --- datasets/datasets.yml | 2 + xreds/dataset_provider.py | 3 +- xreds/extensions/__init__.py | 3 +- xreds/extensions/virtual_vectors.py | 116 ++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 xreds/extensions/virtual_vectors.py diff --git a/datasets/datasets.yml b/datasets/datasets.yml index a477f84..285a726 100644 --- a/datasets/datasets.yml +++ b/datasets/datasets.yml @@ -5,6 +5,7 @@ cbofs: drop_variables: - dstart extensions: + virtual_vectors: vdatum: path: s3://nextgen-dmac-cloud-ingest/nos/vdatums/cbofs_vdatums.nc.zarr water_level_var: zeta @@ -17,6 +18,7 @@ ciofs: drop_variables: - dstart extensions: + virtual_vectors: vdatum: path: s3://nextgen-dmac-cloud-ingest/nos/vdatums/ciofs_vdatums.nc.zarr water_level_var: zeta diff --git a/xreds/dataset_provider.py b/xreds/dataset_provider.py index e90a8cc..94078f2 100644 --- a/xreds/dataset_provider.py +++ b/xreds/dataset_provider.py @@ -11,11 +11,12 @@ from xreds.logging import logger from xreds.config import settings from xreds.utils import load_dataset -from xreds.extensions import VDatumTransformationExtension +from xreds.extensions import VDatumTransformationExtension, VirtualVectorsTransformationExtension dataset_extension_manager = PluginManager(DATASET_EXTENSION_PLUGIN_NAMESPACE) dataset_extension_manager.register(VDatumTransformationExtension, name="vdatum") +dataset_extension_manager.register(VirtualVectorsTransformationExtension, name="virtual_vectors") class DatasetProvider(Plugin): diff --git a/xreds/extensions/__init__.py b/xreds/extensions/__init__.py index e07b9fb..ee00955 100644 --- a/xreds/extensions/__init__.py +++ b/xreds/extensions/__init__.py @@ -1,2 +1,3 @@ # __module__ -from .vdatum import VDatumTransformationExtension \ No newline at end of file +from .vdatum import VDatumTransformationExtension +from .virtual_vectors import VirtualVectorsTransformationExtension diff --git a/xreds/extensions/virtual_vectors.py b/xreds/extensions/virtual_vectors.py new file mode 100644 index 0000000..d2a5fb6 --- /dev/null +++ b/xreds/extensions/virtual_vectors.py @@ -0,0 +1,116 @@ +import numpy as np +import xarray as xr + +from xreds.dataset_extension import DatasetExtension, hookimpl +from xreds.logging import logger + + +class VectorPair(): + def __init__(self): + self.x_var = None + self.y_var = None + + def is_complete(self): + return self.x_var is not None and self.y_var is not None + + +class VirtualVectorsTransformationExtension(DatasetExtension): + """Virtual vector variables transformation extension""" + + name: str = "virtual_vectors" + + @hookimpl + def transform_dataset(self, ds: xr.Dataset, config: dict) -> xr.Dataset: + """Transform a dataset by adding virtual vector variables""" + + vector_pairs = {} + for var_name in ds: + var = ds[var_name] + if "standard_name" not in var.attrs: + continue + + def get_or_init_pair(vector_name): + if vector_name not in vector_pairs: + vector_pairs[vector_name] = VectorPair() + return vector_pairs[vector_name] + + def get_vector_var_name(std_name, prefixes, substrs, excludes): + if any(exclude in std_name for exclude in excludes): + return None + for prefix in prefixes: + if std_name.startswith(prefix): + return std_name.removeprefix(prefix) + for substr in substrs: + if substr in std_name: + return std_name.replace(substr, "_") + return None + + def check_scalar(var, vector_pair_attr, prefixes, substrs, excludes): + vector_name = get_vector_var_name( + std_name=var.attrs["standard_name"], + prefixes=prefixes, + substrs=substrs, + excludes=excludes) + if vector_name: + setattr(get_or_init_pair(vector_name), vector_pair_attr, var) + + check_scalar( + var, + vector_pair_attr="x_var", + prefixes=["eastward_"], + substrs=["_eastward_", "_x_"], + excludes=["_x_edges", "_x_spacing"]) + check_scalar( + var, + vector_pair_attr="y_var", + prefixes=["northward_"], + substrs=["_northward_", "_y_"], + excludes=["_y_edges", "_y_spacing"]) + + for pair_var_name in vector_pairs: + vector_pair = vector_pairs[pair_var_name] + if not vector_pair.is_complete(): + continue + + x_var = vector_pair.x_var + y_var = vector_pair.y_var + + if not x_var.dims == y_var.dims: + logger.warn( + f'Discovered vector pair {x_var.name}/{y_var.name}' + f' have mismatched dims {x_var.dims} vs {y_var.dims}' + ', skipping' + ) + continue + + template_var = x_var + vector_long_name = pair_var_name.replace("_", " ") + + speed_var = xr.DataArray( + data=np.sqrt(np.square(x_var) + np.square(y_var)), + dims=template_var.dims, + coords=template_var.coords, + attrs=template_var.attrs, + ) + del speed_var.attrs['standard_name'] + speed_var.attrs.update({ + "long_name": f"{vector_long_name} speed", + }) + ds[f"{pair_var_name}_speed"] = speed_var + + # NOTE: this is not yet checked whatsoever for correctness + # with regard to wind or wave to/from direction conventions + direction_var = xr.DataArray( + data=np.degrees(np.arctan2(x_var, y_var)) % 360, + dims=template_var.dims, + coords=template_var.coords, + attrs=template_var.attrs, + ) + del direction_var.attrs['standard_name'] + direction_var.attrs.update({ + "long_name": f"{vector_long_name} direction", + "units": "degrees", + }) + ds[f"{pair_var_name}_direction"] = direction_var + + return ds