Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580488801
  • Loading branch information
Weatherbench authors committed Nov 8, 2023
1 parent 0fa2cba commit 1167715
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
32 changes: 31 additions & 1 deletion scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from weatherbench2 import flag_utils
from weatherbench2 import metrics
from weatherbench2.derived_variables import DERIVED_VARIABLE_DICT
from weatherbench2.regions import CombinedRegion
from weatherbench2.regions import LandRegion
from weatherbench2.regions import SliceRegion
import xarray as xr

Expand Down Expand Up @@ -91,7 +93,9 @@
EVALUATE_CLIMATOLOGY = flags.DEFINE_bool(
'evaluate_climatology',
False,
'Evaluate climatology forecast specified in climatology path',
'Evaluate climatology forecast specified in climatology path. Note that'
' this will not work for probabilistic evaluation. Please use the'
' EVALUATE_PROBABILISTIC_CLIMATOLOGY flag.',
)
EVALUATE_PROBABILISTIC_CLIMATOLOGY = flags.DEFINE_bool(
'evaluate_probabilistic_climatology',
Expand Down Expand Up @@ -122,6 +126,15 @@
'predefined regions.'
),
)
LSM_DATASET = flags.DEFINE_string(
'lsm_dataset',
None,
help=(
'Dataset containing land-sea-mask at same resolution of datasets to be'
' evaluated. Required if region with land-sea-mask is picked. If None,'
' defaults to observation dataset.'
),
)
COMPUTE_SEEPS = flags.DEFINE_bool(
'compute_seeps', False, 'Compute SEEPS for total_precipitation_24hr.'
)
Expand Down Expand Up @@ -305,6 +318,23 @@ def main(argv: list[str]) -> None:
'arctic': SliceRegion(lat_slice=slice(60, 90)),
'antarctic': SliceRegion(lat_slice=slice(-90, -60)),
}
try:
if LSM_DATASET.value:
land_sea_mask = xr.open_zarr(LSM_DATASET.value)['land_sea_mask'].compute()
else:
land_sea_mask = xr.open_zarr(OBS_PATH.value)['land_sea_mask'].compute()
land_regions = {
'global_land': LandRegion(land_sea_mask=land_sea_mask),
'extra-tropics_land': CombinedRegion(
regions=[
SliceRegion(lat_slice=[slice(None, -20), slice(20, None)]),
LandRegion(land_sea_mask=land_sea_mask),
]
),
}
predefined_regions = predefined_regions | land_regions
except KeyError:
print('No land_sea_mask found.')
if REGIONS.value == ['all']:
regions = predefined_regions
elif REGIONS.value is None:
Expand Down
2 changes: 1 addition & 1 deletion weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def _evaluate(
)
forecast_pipeline |= beam.MapTuple(
self._climatology_like_forecast_chunk,
probabilistic_climatology=probabilistic_climatology,
climatology=probabilistic_climatology,
variables=variables,
)
elif self.eval_config.evaluate_persistence:
Expand Down
31 changes: 28 additions & 3 deletions weatherbench2/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Region:

def apply(
self, dataset: xr.Dataset, weights: xr.DataArray
) -> tuple[xr.Dataset, xr.Dataset]:
) -> tuple[xr.Dataset, xr.DataArray]:
"""Apply region selection to dataset and/or weights.
Args:
Expand All @@ -48,8 +48,8 @@ def apply(
Returns:
dataset: Potentially modified (sliced) dataset.
weights: Potentially modified weights dataset, to be used in combination
with dataset, e.g. in _spatial_average().
weights: Potentially modified weights data array, to be used in
combination with dataset, e.g. in _spatial_average().
"""
raise NotImplementedError

Expand Down Expand Up @@ -128,6 +128,31 @@ def apply( # pytype: disable=signature-mismatch
) -> tuple[xr.Dataset, xr.DataArray]:
"""Returns weights multiplied with a boolean land mask."""
land_weights = self.land_sea_mask
# Make sure lsm has same dtype for lat/lon
land_weights = land_weights.assign_coords(
latitude=land_weights.latitude.astype(dataset.latitude.dtype),
longitude=land_weights.longitude.astype(dataset.longitude.dtype),
)
if self.threshold is not None:
land_weights = (land_weights > self.threshold).astype(float)
return dataset, weights * land_weights


@dataclasses.dataclass
class CombinedRegion(Region):
"""Sequentially applies regions selections.
Allows for combination of e.g. SliceRegion and LandRegion.
Attributes:
regions: List of Region instances
"""

regions: list[Region] = dataclasses.field(default_factory=list)

def apply( # pytype: disable=signature-mismatch
self, dataset: xr.Dataset, weights: xr.DataArray
) -> tuple[xr.Dataset, xr.DataArray]:
for region in self.regions:
dataset, weights = region.apply(dataset, weights)
return dataset, weights

0 comments on commit 1167715

Please sign in to comment.