Skip to content

Commit

Permalink
feat(models): Add GFS Model repository
Browse files Browse the repository at this point in the history
  • Loading branch information
devsjc committed Oct 31, 2024
1 parent e22c95f commit 0f6c84b
Show file tree
Hide file tree
Showing 12 changed files with 497 additions and 81 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ authors = [
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"dask == 2024.8.1",
"eccodes == 2.38.1",
"eccodes == 2.38.3",
"ecmwf-api-client == 1.6.3",
"cfgrib == 0.9.14.0",
"dagster-pipes == 1.8.5",
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ def parse_env() -> Adaptors:
"""Parse from the environment."""
model_repository_adaptor: type[ports.ModelRepository]
match os.getenv("MODEL_REPOSITORY"):
case None:
log.error("MODEL_REPOSITORY is not set in environment.")
sys.exit(1)
case None | "gfs":
model_repository_adaptor = repositories.NOAAGFSS3ModelRepository
case "ceda":
model_repository_adaptor = repositories.CedaMetOfficeGlobalModelRepository
case "ecmwf-realtime-s3":
case "ecmwf-realtime":
model_repository_adaptor = repositories.ECMWFRealTimeS3ModelRepository
case _ as model:
log.error(f"Unknown model: {model}")
Expand Down
2 changes: 1 addition & 1 deletion src/nwp_consumer/internal/entities/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def to_pandas(self) -> dict[str, pd.Index]: # type: ignore
This is useful for interoperability with xarray, which prefers to define
DataArray coordinates using a dict pandas Index objects.
For the most part, the conversion consists of a straighforward cast
For the most part, the conversion consists of a straightforward cast
to a pandas Index object. However, there are some caveats involving
the time-centric dimensions:
Expand Down
32 changes: 32 additions & 0 deletions src/nwp_consumer/internal/entities/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import dataclasses
from enum import StrEnum, auto

from returns.result import Failure, ResultE, Success


@dataclasses.dataclass(slots=True)
class ParameterLimits:
Expand Down Expand Up @@ -77,6 +79,9 @@ class ParameterData:
Used in sanity and validity checking the database values.
"""

alternate_shortnames: list[str] = dataclasses.field(default_factory=list)
"""Alternate names for the parameter found in the wild."""

def __str__(self) -> str:
"""String representation of the parameter."""
return self.name
Expand Down Expand Up @@ -121,6 +126,7 @@ def metadata(self) -> ParameterData:
description="Temperature at screen level",
units="C",
limits=ParameterLimits(upper=60, lower=-90),
alternate_shortnames=["t", "t2m"],
)
case self.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.name:
return ParameterData(
Expand All @@ -130,6 +136,7 @@ def metadata(self) -> ParameterData:
"incident on the surface expected over the next hour.",
units="W/m^2",
limits=ParameterLimits(upper=1500, lower=0),
alternate_shortnames=["swavr", "ssrd", "dswrf"],
)
case self.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.name:
return ParameterData(
Expand All @@ -139,6 +146,7 @@ def metadata(self) -> ParameterData:
"incident on the surface expected over the next hour.",
units="W/m^2",
limits=ParameterLimits(upper=500, lower=0),
alternate_shortnames=["strd", "dlwrf"]
)
case self.RELATIVE_HUMIDITY_SL.name:
return ParameterData(
Expand All @@ -148,6 +156,7 @@ def metadata(self) -> ParameterData:
"to the equilibrium vapour pressure of water",
units="%",
limits=ParameterLimits(upper=100, lower=0),
alternate_shortnames=["r"],
)
case self.VISIBILITY_SL.name:
return ParameterData(
Expand All @@ -157,6 +166,7 @@ def metadata(self) -> ParameterData:
"horizontally in daylight conditions.",
units="m",
limits=ParameterLimits(upper=4500, lower=0),
alternate_shortnames=["vis"],
)
case self.WIND_U_COMPONENT_10m.name:
return ParameterData(
Expand All @@ -166,6 +176,7 @@ def metadata(self) -> ParameterData:
"the wind in the eastward direction.",
units="m/s",
limits=ParameterLimits(upper=100, lower=-100),
alternate_shortnames=["u10"],
)
case self.WIND_V_COMPONENT_10m.name:
return ParameterData(
Expand All @@ -176,6 +187,7 @@ def metadata(self) -> ParameterData:
units="m/s",
# Non-tornadic winds are usually < 100m/s
limits=ParameterLimits(upper=100, lower=-100),
alternate_shortnames=["v10"],
)
case self.WIND_U_COMPONENT_100m.name:
return ParameterData(
Expand All @@ -185,6 +197,7 @@ def metadata(self) -> ParameterData:
"the wind in the eastward direction.",
units="m/s",
limits=ParameterLimits(upper=100, lower=-100),
alternate_shortnames=["u100"],
)
case self.WIND_V_COMPONENT_100m.name:
return ParameterData(
Expand All @@ -194,6 +207,7 @@ def metadata(self) -> ParameterData:
"the wind in the northward direction.",
units="m/s",
limits=ParameterLimits(upper=100, lower=-100),
alternate_shortnames=["v100"],
)
case self.WIND_U_COMPONENT_200m.name:
return ParameterData(
Expand All @@ -203,6 +217,7 @@ def metadata(self) -> ParameterData:
"the wind in the eastward direction.",
units="m/s",
limits=ParameterLimits(upper=150, lower=-150),
alternate_shortnames=["u200"],
)
case self.WIND_V_COMPONENT_200m.name:
return ParameterData(
Expand All @@ -212,13 +227,15 @@ def metadata(self) -> ParameterData:
"the wind in the northward direction.",
units="m/s",
limits=ParameterLimits(upper=150, lower=-150),
alternate_shortnames=["v200"],
)
case self.SNOW_DEPTH_GL.name:
return ParameterData(
name=str(self),
description="Depth of snow on the ground.",
units="m",
limits=ParameterLimits(upper=12, lower=0),
alternate_shortnames=["sd", "sdwe"],
)
case self.CLOUD_COVER_HIGH.name:
return ParameterData(
Expand All @@ -229,6 +246,7 @@ def metadata(self) -> ParameterData:
"to the square's total area.",
units="UI",
limits=ParameterLimits(upper=1, lower=0),
alternate_shortnames=["hcc"],
)
case self.CLOUD_COVER_MEDIUM.name:
return ParameterData(
Expand All @@ -239,6 +257,7 @@ def metadata(self) -> ParameterData:
"to the square's total area.",
units="UI",
limits=ParameterLimits(upper=1, lower=0),
alternate_shortnames=["mcc"],
)
case self.CLOUD_COVER_LOW.name:
return ParameterData(
Expand All @@ -249,6 +268,7 @@ def metadata(self) -> ParameterData:
"to the square's total area.",
units="UI",
limits=ParameterLimits(upper=1, lower=0),
alternate_shortnames=["lcc"],
)
case self.CLOUD_COVER_TOTAL.name:
return ParameterData(
Expand All @@ -259,6 +279,7 @@ def metadata(self) -> ParameterData:
"to the square's total area.",
units="UI",
limits=ParameterLimits(upper=1, lower=0),
alternate_shortnames=["tcc", "clt"],
)
case self.TOTAL_PRECIPITATION_RATE_GL.name:
return ParameterData(
Expand All @@ -268,6 +289,7 @@ def metadata(self) -> ParameterData:
"including rain, snow, and hail.",
units="kg/m^2/s",
limits=ParameterLimits(upper=0.2, lower=0),
alternate_shortnames=["prate", "tprate"],
)
case self.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.name:
return ParameterData(
Expand All @@ -278,6 +300,7 @@ def metadata(self) -> ParameterData:
"expected over the next hour.",
units="W/m^2",
limits=ParameterLimits(upper=1000, lower=0),
alternate_shortnames=["uvb"],
)
case self.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.name:
return ParameterData(
Expand All @@ -289,7 +312,16 @@ def metadata(self) -> ParameterData:
"expected over the next hour.",
units="W/m^2",
limits=ParameterLimits(upper=1000, lower=0),
alternate_shortnames=["dsrp"],
)
case _:
# Shouldn't happen thanks to the test case in test_parameters.py
raise ValueError(f"Unknown parameter: {self}")

@staticmethod
def try_from_alternate(name: str) -> ResultE["Parameter"]:
    """Map an alternate shortname to its parameter.

    Declared as a static method so the lookup behaves the same whether it is
    called on the class (``Parameter.try_from_alternate("t2m")``) or on a
    member; without the decorator a call through a member would bind the
    member itself to ``name``.

    Args:
        name: The alternate shortname to look up. It is compared exactly
            against each parameter's ``alternate_shortnames``.

    Returns:
        ``Success`` wrapping the first parameter (in iteration order of
        ``Parameter``) whose metadata lists ``name``, otherwise ``Failure``
        wrapping a ``ValueError``.
    """
    for p in Parameter:
        if name in p.metadata().alternate_shortnames:
            return Success(p)
    return Failure(ValueError(f"Unknown shortname: {name}"))

10 changes: 10 additions & 0 deletions src/nwp_consumer/internal/entities/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from hypothesis import given
from hypothesis import strategies as st
from returns.pipeline import is_successful

from .parameters import Parameter

Expand All @@ -15,6 +16,15 @@ def test_metadata(self, p: Parameter) -> None:
metadata = p.metadata()
self.assertEqual(metadata.name, p.value)

@given(st.sampled_from([s for p in Parameter for s in p.metadata().alternate_shortnames]))
def test_try_from_shortname(self, shortname: str) -> None:
    """Test that try_from_alternate resolves known shortnames and rejects unknown ones."""
    known = Parameter.try_from_alternate(shortname)
    self.assertTrue(is_successful(known))
    # Success alone is not enough: the resolved parameter must actually
    # list the shortname that was looked up.
    self.assertIn(shortname, known.unwrap().metadata().alternate_shortnames)

    unknown = Parameter.try_from_alternate("invalid")
    self.assertFalse(is_successful(unknown))


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions src/nwp_consumer/internal/repositories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .model_repositories import (
CedaMetOfficeGlobalModelRepository,
ECMWFRealTimeS3ModelRepository,
NOAAGFSS3ModelRepository,
)
from .notification_repositories import (
StdoutNotificationRepository,
Expand All @@ -35,6 +36,7 @@
__all__ = [
"CedaMetOfficeGlobalModelRepository",
"ECMWFRealTimeS3ModelRepository",
"NOAAGFSS3ModelRepository",
"StdoutNotificationRepository",
"DagsterPipesNotificationRepository",
]
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .metoffice_global import CedaMetOfficeGlobalModelRepository
from .ecmwf_realtime import ECMWFRealTimeS3ModelRepository
from .noaa_gfs import NOAAGFSS3ModelRepository

__all__ = [
"CedaMetOfficeGlobalModelRepository",
"ECMWFRealTimeS3ModelRepository",
"NOAAGFSS3ModelRepository",
]

Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def repository() -> entities.ModelRepositoryMetadata:
name="ECMWF-Realtime-S3",
is_archive=False,
is_order_based=True,
running_hours=[0, 12],
running_hours=[0, 6, 12, 18],
delay_minutes=(60 * 6), # 6 hours
max_connections=100,
required_env=[
Expand Down Expand Up @@ -196,7 +196,7 @@ def _download(self, url: str) -> ResultE[pathlib.Path]:
# Only download the file if not already present
if not local_path.exists():
local_path.parent.mkdir(parents=True, exist_ok=True)
log.info("Requesting file from S3 at: '%s'", url)
log.debug("Requesting file from S3 at: '%s'", url)

try:
if not self._fs.exists(url):
Expand Down Expand Up @@ -234,13 +234,19 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:
f"Error opening '{path}' as list of xarray Datasets: {e}",
))
if len(dss) == 0:
return Failure(ValueError(f"No datasets found in '{path}'"))
return Failure(ValueError(
f"No datasets found in '{path}'. File may be corrupted. "
"A redownload of the file may be required.",
))

processed_das: list[xr.DataArray] = []
for i, ds in enumerate(dss):
try:
da: xr.DataArray = (
ds.pipe(ECMWFRealTimeS3ModelRepository._rename_vars)
ECMWFRealTimeS3ModelRepository._rename_or_drop_vars(
ds=ds,
allowed_parameters=ECMWFRealTimeS3ModelRepository.model().expected_coordinates.variable,
)
.rename(name_dict={"time": "init_time"})
.expand_dims(dim="init_time")
.expand_dims(dim="step")
Expand Down Expand Up @@ -274,36 +280,6 @@ def _convert(path: pathlib.Path) -> ResultE[list[xr.DataArray]]:

return Success(processed_das)

@staticmethod
def _rename_vars(ds: xr.Dataset) -> xr.Dataset:
    """Rename any recognised shortname variables to their expected parameter names.

    Variables whose names are not keys of the lookup table are left untouched.

    Args:
        ds: The xarray dataset whose data variables should be renamed.

    Returns:
        The dataset with every recognised variable renamed.
    """
    params = entities.Parameter
    shortname_to_expected: dict[str, str] = {
        "dsrp": params.DIRECT_SHORTWAVE_RADIATION_FLUX_GL.value,
        "uvb": params.DOWNWARD_ULTRAVIOLET_RADIATION_FLUX_GL.value,
        "sd": params.SNOW_DEPTH_GL.value,
        "tcc": params.CLOUD_COVER_TOTAL.value,
        "clt": params.CLOUD_COVER_TOTAL.value,
        "u10": params.WIND_U_COMPONENT_10m.value,
        "v10": params.WIND_V_COMPONENT_10m.value,
        "t2m": params.TEMPERATURE_SL.value,
        "ssrd": params.DOWNWARD_SHORTWAVE_RADIATION_FLUX_GL.value,
        "strd": params.DOWNWARD_LONGWAVE_RADIATION_FLUX_GL.value,
        "lcc": params.CLOUD_COVER_LOW.value,
        "mcc": params.CLOUD_COVER_MEDIUM.value,
        "hcc": params.CLOUD_COVER_HIGH.value,
        "vis": params.VISIBILITY_SL.value,
        "u200": params.WIND_U_COMPONENT_200m.value,
        "v200": params.WIND_V_COMPONENT_200m.value,
        "u100": params.WIND_U_COMPONENT_100m.value,
        "v100": params.WIND_V_COMPONENT_100m.value,
        "tprate": params.TOTAL_PRECIPITATION_RATE_GL.value,
    }

    # Renames are applied one at a time, in table order, against the
    # progressively renamed dataset.
    for shortname, expected in shortname_to_expected.items():
        if shortname not in ds.data_vars:
            continue
        ds = ds.rename({shortname: expected})
    return ds

@staticmethod
def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool:
"""Determine if the file is wanted based on the init time.
Expand All @@ -329,3 +305,25 @@ def _wanted_file(filename: str, it: dt.datetime, max_step: int) -> bool:
"%Y%m%d%H%M%z",
)
return tt < it + dt.timedelta(hours=max_step)


@staticmethod
def _rename_or_drop_vars(ds: xr.Dataset, allowed_parameters: list[entities.Parameter]) \
-> xr.Dataset:
"""Rename variables to match the expected names, dropping invalid ones.
Args:
ds: The xarray dataset to rename.
allowed_parameters: The list of parameters allowed in the resultant dataset.
"""
for var in ds.data_vars:
param_result = entities.Parameter.try_from_alternate(str(var))
match param_result:
case Success(p):
if p in allowed_parameters:
ds = ds.rename_vars({var: p.value})
continue
log.warning("Dropping invalid parameter '%s' from dataset", var)
ds = ds.drop_vars(str(var))
return ds

Loading

0 comments on commit 0f6c84b

Please sign in to comment.