switch to using uclales-utils for extraction #13

Open · wants to merge 12 commits into base: master
101 changes: 46 additions & 55 deletions genesis/utils/pipeline/data/extraction.py
@@ -11,6 +11,7 @@

from ....utils import calc_flux, find_vertical_grid_spacing, transforms
from ... import mask_functions
from ..data_sources.uclales.common import _fix_time_units
from .base import (
NumpyDatetimeParameter,
XArrayTarget,
@@ -32,51 +33,19 @@
)


def fix_time_units(da):
modified = False
if np.issubdtype(da.dtype, np.datetime64):
# already converted since xarray has managed to parse the time in
# CF-format
pass
elif da.attrs["units"].startswith("seconds since 2000-01-01"):
# I fixed UCLALES to produce CF-valid output; this is output from a
# fixed version
pass
elif da.attrs["units"].startswith("seconds since 2000-00-00"):
da.attrs["units"] = da.attrs["units"].replace(
"seconds since 2000-00-00",
"seconds since 2000-01-01",
)
modified = True
elif da.attrs["units"].startswith("seconds since 0-00-00"):
# 2D fields have strange time units...
da.attrs["units"] = da.attrs["units"].replace(
"seconds since 0-00-00",
"seconds since 2000-01-01",
)
modified = True
elif da.attrs["units"].startswith("seconds since 0-0-0"):
# 2D fields have strange time units...
da.attrs["units"] = da.attrs["units"].replace(
"seconds since 0-0-0",
"seconds since 2000-01-01",
)
modified = True
elif da.attrs["units"] == "day as %Y%m%d.%f":
da = (da * 24 * 60 * 60).astype(int)
da.attrs["units"] = "seconds since 2000-01-01 00:00:00"
modified = True
else:
raise NotImplementedError(da.attrs["units"])
return da, modified


class XArrayTarget3DExtraction(XArrayTarget):
def open(self, *args, **kwargs):
ds = super(XArrayTarget3DExtraction, self).open(*args, **kwargs)
if len(ds.coords) == 0:
if len(ds.coords) == 0 and len(ds.dims) == 0:
raise Exception(f"{self.fn} doesn't contain any data")
ds = self._ensure_coord_units(ds)

if isinstance(ds, xr.Dataset) and len(ds.variables) == 0:
raise Exception(
f"Stored 3D file for `{self.path}` is empty, please delete so"
"it can be recreated"
)

return ds

def _ensure_coord_units(self, da):
@@ -98,6 +67,8 @@ class ExtractField3D(luigi.Task):
base_name = luigi.Parameter()
field_name = luigi.Parameter()

# follows the filename convention of uclales-utils, given that we put
# ".tn{tn}" into "var_name":
# SINGLE_VAR_FILENAME_FORMAT_3D = "{file_prefix}.{var_name}.tn{tn}.nc"
FN_FORMAT = "{experiment_name}.{field_name}.nc"

@staticmethod
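For illustration, the two naming conventions referenced in the comment above, with hypothetical values (uclales-utils writes one file per variable and timestep, while genesis stores one extracted file per field):

```python
SINGLE_VAR_FILENAME_FORMAT_3D = "{file_prefix}.{var_name}.tn{tn}.nc"
FN_FORMAT = "{experiment_name}.{field_name}.nc"

print(SINGLE_VAR_FILENAME_FORMAT_3D.format(file_prefix="rico", var_name="w", tn=4))
# -> rico.w.tn4.nc
print(FN_FORMAT.format(experiment_name="rico", field_name="w"))
# -> rico.w.nc
```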
@@ -209,26 +180,14 @@ def output(self):

p = get_workdir() / self.base_name / fn

t = XArrayTarget3DExtraction(str(p))

if t.exists():
data = t.open()
if isinstance(data, xr.Dataset):
if len(data.variables) == 0:
warnings.warn(
"Stored file for `{}` is empty, deleting..."
"".format(self.field_name)
)
p.unlink()

return t
return XArrayTarget3DExtraction(str(p))


class XArrayTarget2DCrossSection(XArrayTarget):
def open(self, *args, **kwargs):
kwargs["decode_times"] = False
da = super().open(*args, **kwargs)
da["time"], _ = fix_time_units(da["time"])
da["time"], _ = _fix_time_units(da["time"])

# xr.decode_cf only works on datasets
ds = xr.decode_cf(da.to_dataset())
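For reference, a minimal sketch of the open-then-decode pattern used in `XArrayTarget2DCrossSection.open` above; the sample units string is one of the malformed UCLALES variants, and the single `replace` here stands in for `_fix_time_units` in `data_sources/uclales/common.py`, which handles the full set of unit strings:

```python
import numpy as np
import xarray as xr

# a 2D-output-style variable whose time coordinate carries the malformed
# units UCLALES writes (would normally come from opening with
# decode_times=False)
da = xr.DataArray(
    np.random.rand(3, 4),
    dims=("time", "xt"),
    coords={"time": ("time", [0.0, 60.0, 120.0],
                     {"units": "seconds since 2000-00-00 00:00:00"})},
    name="lwp",
)

# patch the invalid month/day so the units become CF-parseable
da["time"].attrs["units"] = da["time"].attrs["units"].replace(
    "seconds since 2000-00-00", "seconds since 2000-01-01"
)

# xr.decode_cf only works on datasets, so wrap and unwrap
ds = xr.decode_cf(da.to_dataset())
print(ds["lwp"]["time"].values)  # datetime64 values starting at 2000-01-01
```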
@@ -243,6 +202,31 @@ class TimeCrossSectionSlices2D(luigi.Task):

FN_FORMAT = "{exp_name}.out.xy.{var_name}.nc"

@staticmethod
def _get_data_loader_module(meta):
model_name = meta.get("model")
if model_name is None:
model_name = "UCLALES"

module_name = ".data_sources.{}".format(model_name.lower().replace("-", "_"))
return importlib.import_module(module_name, package="genesis.utils.pipeline")

def requires(self):
meta = _get_dataset_meta_info(self.base_name)
data_loader = self._get_data_loader_module(meta=meta)
fn = getattr(data_loader, "build_runtime_cross_section_extraction_task")
# TODO remove hardcoded orientation
base_name = meta.get("experiment_name", self.base_name)
dest_path = get_workdir() / self.base_name / "cross_sections" / "runtime_slices"
task = fn(
dataset_meta=meta,
var_name=self.var_name,
orientation="xy",
dest_path=str(dest_path),
base_name=base_name,
)
return task

def _extract_and_symlink_local_file(self):
meta = _get_dataset_meta_info(self.base_name)

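A sketch of how the relative import in `_get_data_loader_module` above resolves; the meta dict is hypothetical, and hyphenated model names map to underscored module names:

```python
import importlib

meta = {"model": "UCLALES"}  # hypothetical dataset meta

model_name = meta.get("model") or "UCLALES"
module_name = ".data_sources.{}".format(model_name.lower().replace("-", "_"))
# resolves to genesis.utils.pipeline.data_sources.uclales
data_loader = importlib.import_module(
    module_name, package="genesis.utils.pipeline"
)
# requires() then looks up the task factory on the resolved module
fn = getattr(data_loader, "build_runtime_cross_section_extraction_task")
```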
@@ -414,7 +398,14 @@ def output(self):

fn = ".".join(name_parts)

p = get_workdir() / self.base_name / "cross_sections" / "runtime_slices" / fn
p = (
get_workdir()
/ self.base_name
/ "cross_sections"
/ "runtime_slices"
/ "by_time"
/ fn
)
return XArrayTarget(str(p))


10 changes: 7 additions & 3 deletions genesis/utils/pipeline/data/tracking_2d/aggregation.py
@@ -167,6 +167,7 @@ def fn_unique_dropna(v):
def run(self):
da_labels = self.input()["tracking_labels"].open().fillna(0).astype(int)

op = self.op
if self.var_name in ["xt", "yt"]:
if self.var_name in "xt":
_, da_values = xr.broadcast(da_labels.xt, da_labels.yt)
@@ -187,18 +188,19 @@ def run(self):
da_values = xr.ones_like(da_labels) * dx**2.0
da_values.attrs["units"] = f"{da_labels.xt.units}^2"
da_values.attrs["long_name"] = "area"
op = "sum_labels"
else:
da_values = self.input()["field"].open()

if self.op == "histogram":
if op == "histogram":
da_out = self._aggregate_as_hist(
da_values=da_values,
da_labels=da_labels,
)
name = f"{self.var_name}__hist_{self.dx}"
elif self.op in vars(dmeasure):
elif op in vars(dmeasure):
da_out = self._aggregate_generic(
da_values=da_values, da_labels=da_labels, op=self.op
da_values=da_values, da_labels=da_labels, op=op
)
name = f"{self.var_name}__{self.op}"
else:
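The `op = "sum_labels"` override above turns the per-object area into a label-wise sum of the constant `dx**2` field; a standalone sketch of that reduction, using `scipy.ndimage` as a stand-in for the `dmeasure` module referenced above (values are hypothetical):

```python
import numpy as np
from scipy import ndimage

dx = 25.0  # hypothetical grid spacing [m]
labels = np.array([[0, 1, 1],
                   [2, 2, 0],
                   [2, 0, 0]])  # 0 = background, 1..2 = tracked objects
cell_area = np.full(labels.shape, dx**2)

object_ids = np.array([1, 2])
areas = ndimage.sum_labels(cell_area, labels=labels, index=object_ids)
print(areas)  # [1250. 1875.] -> 2 and 3 cells of 625 m^2 each
```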
@@ -468,6 +470,8 @@ def make_object_agg_task(self, object_id):
def _get_times(object_id):
tstart_obj = self.da_tstart.sel({obj_var: object_id})
tend_obj = self.da_tend.sel({obj_var: object_id})
if tstart_obj == tend_obj:
return []
times = self.da_time.sel(time=slice(tstart_obj, tend_obj)).values

# super hacky way of sending a list of datetimes as a luigi parameter
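The `tstart_obj == tend_obj` guard above matters because label-based slicing in xarray is inclusive on both ends: a degenerate slice still returns one timestep rather than an empty selection, as this sketch with hypothetical timestamps shows:

```python
import numpy as np
import xarray as xr

times = np.array(["2000-01-01T00:00", "2000-01-01T01:00", "2000-01-01T02:00"],
                 dtype="datetime64[ns]")
da_time = xr.DataArray(times, dims=("time",), coords={"time": times})

t0 = times[0]
# a zero-length interval still selects the single matching value
print(da_time.sel(time=slice(t0, t0)).values)  # one element, not empty
```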
4 changes: 4 additions & 0 deletions genesis/utils/pipeline/data/tracking_2d/base.py
@@ -203,6 +203,10 @@ def run(self):
)

else:
print(
f"Didn't find tracking output in `{self.output().path}`"
", trying to run tracking utility"
)
dataset_name = meta["experiment_name"]

if self.run_in_temp_dir:
6 changes: 5 additions & 1 deletion genesis/utils/pipeline/data_sources/uclales/__init__.py
@@ -1,2 +1,6 @@
from . import tracking_2d # noqa
from .base import Extract3D # noqa
from .base import ( # noqa
DERIVED_FIELDS,
build_runtime_cross_section_extraction_task,
extract_field_to_filename,
)
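With these re-exports, callers can reach the UCLALES helpers through the package namespace; a minimal usage sketch (the meta dict and paths are hypothetical):

```python
from genesis.utils.pipeline.data_sources import uclales

print(uclales.DERIVED_FIELDS)  # fields that can be derived from raw output

task = uclales.build_runtime_cross_section_extraction_task(
    dataset_meta={"path": "/data/runs/rico", "experiment_name": "rico"},
    var_name="lwp",
    orientation="xy",
    base_name="rico",
    dest_path="/data/runs/rico/cross_sections/runtime_slices",
)
```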
104 changes: 81 additions & 23 deletions genesis/utils/pipeline/data_sources/uclales/base.py
@@ -4,10 +4,10 @@
import numpy as np
import scipy.optimize
import xarray as xr
from uclales.output import Extract

from .... import center_staggered_field
from .common import _fix_time_units
from .extraction import Extract3D

FIELD_NAME_MAPPING = dict(
w="w_zt",
@@ -120,23 +120,49 @@ class RawDataPathDoesNotExist(Exception):
pass


def _build_block_extraction_task(dataset_meta, field_name):
if field_name in ["u", "v", "w"]:
var_name = field_name
else:
var_name = FIELD_NAME_MAPPING[field_name]
def _build_block_extraction_task(dataset_meta, field_name, output_path):

raw_data_path = Path(dataset_meta["path"]) / "raw_data"
dest_path = Path(dataset_meta["path"]) / "3d_blocks" / "full_domain"

if not raw_data_path.exists():
raise RawDataPathDoesNotExist

task = Extract3D(
task_kwargs = dict(
source_path=raw_data_path,
file_prefix=dataset_meta["experiment_name"],
var_name=var_name,
tn=dataset_meta["timestep"],
kind="3d",
dest_path=dest_path,
fix_units=True,
)

if field_name in ["u", "v", "w"]:
task_kwargs["var_name"] = field_name
elif field_name in FIELD_NAME_MAPPING:
task_kwargs["var_name"] = FIELD_NAME_MAPPING[field_name]

if not raw_data_path.exists():
raise RawDataPathDoesNotExist

task = Extract(**task_kwargs)
return task


def build_runtime_cross_section_extraction_task(
dataset_meta, var_name, orientation, base_name, dest_path
):
raw_data_path = Path(dataset_meta["path"]) / "raw_data"

task_kwargs = {}
task_kwargs["var_name"] = var_name
task_kwargs["kind"] = "2d"
task_kwargs["orientation"] = orientation
task_kwargs["dest_path"] = dest_path
task_kwargs["file_prefix"] = base_name
task_kwargs["source_path"] = raw_data_path

if not raw_data_path.exists():
raise RawDataPathDoesNotExist

task = Extract(**task_kwargs)
return task


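Usage sketch for the two task factories above with a hypothetical `dataset_meta`; both now construct a `uclales.output.Extract` task from uclales-utils, differing only in the `kind`/`orientation` kwargs:

```python
from pathlib import Path

dataset_meta = {
    "path": "/data/runs/rico",  # hypothetical run directory
    "experiment_name": "rico",
    "timestep": 4,
}

task_3d = _build_block_extraction_task(
    dataset_meta=dataset_meta,
    field_name="w",
    output_path=Path("/tmp/rico.w.nc"),  # hypothetical destination
)

task_2d = build_runtime_cross_section_extraction_task(
    dataset_meta=dataset_meta,
    var_name="lwp",
    orientation="xy",
    base_name="rico",
    dest_path="/data/runs/rico/cross_sections/runtime_slices",
)
# both can then be scheduled with e.g. luigi.build([task_3d, task_2d])
```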
@@ -182,15 +208,21 @@ def extract_field_to_filename(dataset_meta, path_out, field_name, **kwargs):  #
task = _build_block_extraction_task(
dataset_meta=dataset_meta,
field_name=field_name,
output_path=path_out,
)
# if the source file doesn't exist we return a task to create it;
# the next time we pass through here the file should exist and we
# can just open it
if not task.output().exists():
return task

da = task.output().open(decode_times=False)
can_symlink = False
try:
da = task.output().open()
except ValueError:
da = task.output().open(decode_times=False)
can_symlink = False
# set `path_in` to ensure we try to symlink the right file
path_in = Path(task.output().path)

except RawDataPathDoesNotExist:
raise Exception(
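A standalone sketch of the decode-with-fallback pattern introduced in the (truncated) hunk above: try normal CF time decoding first, and only fall back to `decode_times=False` (which also disables symlinking, since the on-disk time units would still need fixing) when xarray cannot parse the units. The path is hypothetical:

```python
import xarray as xr

def open_with_time_fallback(path):
    """Return (dataarray, can_symlink) for a possibly malformed file."""
    try:
        return xr.open_dataarray(path), True
    except ValueError:
        # malformed CF units, e.g. "seconds since 2000-00-00 00:00:00"
        return xr.open_dataarray(path, decode_times=False), False

da, can_symlink = open_with_time_fallback("/data/runs/rico/rico.w.tn4.nc")
```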
@@ -287,40 +319,63 @@ def _calc_qv__norain(qt, qc):
return qv


@np.vectorize
def _calc_temperature_single(q_l, p, theta_l):
# @numba.jit(numba.float64(numba.float64, numba.float64, numba.float64), nopython=True)
def _calc_theta_l(T, p, q_l):
# constants from UCLALES
L_v = 2.5 * 1.0e6 # [J/kg]
p_theta = 1.0e5
cp_d = 1.004 * 1.0e3 # [J/kg/K]
R_d = 287.04 # [J/kg/K]
# XXX: this is *not* the *actual* liquid potential temperature (as
# given in B. Stevens' notes on moist thermodynamics), but instead
# reflects the form used in UCLALES, where the dry-air heat capacity
# is used in place of the mixture heat capacity
return T * (p_theta / p) ** (R_d / cp_d) * np.exp(-L_v * q_l / (cp_d * T))


def _calc_temperature_single(q_l, p, theta_l, T_rtol=1.0e-6, T_abstol=1.0e-6):
# constants from UCLALES
cp_d = 1.004 * 1.0e3 # [J/kg/K]
R_d = 287.04 # [J/kg/K]
L_v = 2.5 * 1.0e6 # [J/kg]
p_theta = 1.0e5

# XXX: this is *not* the *actual* liquid potential temperature (as
# given in B. Stevens' notes on moist thermodynamics), but instead
# reflects the form used in UCLALES, where the dry-air heat capacity
# is used in place of the mixture heat capacity
def temp_func(T):
return theta_l - T * (p_theta / p) ** (R_d / cp_d) * np.exp(
-L_v * q_l / (cp_d * T)
)
def temp_func(T, theta_l, p, q_l):
return theta_l - _calc_theta_l(T=T, p=p, q_l=q_l)

if np.all(q_l == 0.0):
if q_l == 0.0:
# no need for root finding
return theta_l / ((p_theta / p) ** (R_d / cp_d))

# XXX: the brentq solver requires bounds; we don't expect temperatures below -100C
T_min = -100.0 + 273.0
T_max = 50.0 + 273.0
T = scipy.optimize.brentq(f=temp_func, a=T_min, b=T_max)
T = scipy.optimize.brentq(
    f=temp_func,
    a=T_min,
    b=T_max,
    args=(theta_l, p, q_l),
    xtol=T_abstol,
    rtol=T_rtol,
)

# check that we're within 1.0e-4
assert np.all(np.abs(temp_func(T)) < 1.0e-4)
assert np.all(np.abs(temp_func(T, theta_l, p, q_l)) < 1.0e-4)

return T

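Usage example for the scalar inversion above, with hypothetical but physically plausible values; for `q_l == 0` the solver is skipped and the result reduces to the dry relation `T = theta_l * (p / p_theta) ** (R_d / cp_d)`:

```python
p = 9.5e4        # pressure [Pa]
theta_l = 298.0  # liquid potential temperature [K]

T_dry = _calc_temperature_single(q_l=0.0, p=p, theta_l=theta_l)
T_cloudy = _calc_temperature_single(q_l=1.0e-3, p=p, theta_l=theta_l)
# the latent heat carried by 1 g/kg of liquid implies a warmer parcel
assert T_cloudy > T_dry
```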

def _calc_temperature(qc, qr, p, theta_l):
if not qc.dims == qr.dims == p.dims == theta_l.dims:
expvars = [qc, qr, p, theta_l]
s = "\n\t".join([f"{v.name}: {list(v.dims)}" for v in expvars])
raise Exception("Incompatible dims:\n\t" + s)

raise Exception("Use julia implementation")
"""

q_l = qc + qr

q_l = q_l.isel(**kws)
p = p.isel(**kws)
theta_l = theta_l.isel(**kws)

arr_temperature = np.vectorize(_calc_temperature_single)(
q_l=q_l, p=p, theta_l=theta_l
)
@@ -329,6 +384,9 @@ def _calc_temperature(qc, qr, p, theta_l):
arr_temperature,
dims=p.dims,
attrs=dict(longname="temperature", units="K"),
coords=p.coords,
name="abstemp"
)

return da_temperature
"""