Data Seeding Scripts For Analysis Ready Dataset #53

Status: Closed (wants to merge 3 commits).
src/arco_era5/__init__.py (5 changes: 3 additions & 2 deletions)
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .source_data import GCP_DIRECTORY,SINGLE_LEVEL_VARIABLES,MULTILEVEL_VARIABLES,PRESSURE_LEVELS_GROUPS, TIME_RESOLUTION_HOURS
from .source_data import get_var_attrs_dict, read_multilevel_vars, read_single_level_vars, daily_date_iterator, align_coordinates, parse_arguments
from .source_data import GCP_DIRECTORY,SINGLE_LEVEL_VARIABLES,MULTILEVEL_VARIABLES,PRESSURE_LEVELS_GROUPS, TIME_RESOLUTION_HOURS, HOURS_PER_DAY
from .source_data import get_var_attrs_dict, read_multilevel_vars, read_single_level_vars, daily_date_iterator, align_coordinates, parse_arguments, get_pressure_levels_arg, LoadTemporalDataForDateDoFn
from .pangeo import run, parse_args
from .update import UpdateSlice
src/arco_era5/source_data.py (83 changes: 74 additions & 9 deletions)
@@ -2,17 +2,20 @@

__author__ = 'Matthew Willson, Alvaro Sanchez, Peter Battaglia, Stephan Hoyer, Stephan Rasp'

import apache_beam as beam
import argparse
import datetime
import fsspec
import immutabledict
import logging

import pathlib
import xarray

import numpy as np
import pandas as pd
import typing as t
import xarray as xr
import xarray_beam as xb

TIME_RESOLUTION_HOURS = 1

@@ -334,6 +337,7 @@
"geopotential_at_surface": "geopotential"
}

HOURS_PER_DAY = 24

def _read_nc_dataset(gpath_file):
"""Read a .nc NetCDF dataset from a cloud storage path and disk.
@@ -352,7 +356,7 @@ def _read_nc_dataset(gpath_file):
"""
path = str(gpath_file).replace('gs:/', 'gs://')
with fsspec.open(path, mode="rb") as fid:
dataset = xarray.open_dataset(fid, engine="scipy", cache=False)
dataset = xr.open_dataset(fid, engine="scipy", cache=False)
# All datasets have a single data array in them, so we just return the array.
assert len(dataset) == 1
dataarray = next(iter(dataset.values()))
@@ -372,12 +376,12 @@ def _read_nc_dataset(gpath_file):
# and: https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Dataupdatefrequency # pylint: disable=line-too-long
# for further details.

all_dims_except_time = tuple(set(dataarray.dims) - {"time"})
all_dims_except_time = tuple(set(dataarray.dims) - {"time", "expver"})
# Should have only trailing nans.
a = dataarray.sel(expver=1).isnull().any(dim=all_dims_except_time)
# Should have only leading nans.
b = dataarray.sel(expver=5).isnull().any(dim=all_dims_except_time)
disjoint_nans = bool(next(iter((a ^ b).all().data_vars.values())))
disjoint_nans = bool((a ^ b).all().variable.values)
assert disjoint_nans, "The nans are not disjoint in expver=1 vs 5"
dataarray = dataarray.sel(expver=1).combine_first(dataarray.sel(expver=5))
return dataarray
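
For readers unfamiliar with ERA5T, the block above merges the two experiment versions (expver 1 and 5) that recent months carry: the expver=1 slice has trailing NaNs and the expver=5 slice has leading NaNs along time, so combine_first stitches them into one complete array. A minimal, self-contained sketch of that behavior follows; it is not part of this diff and uses toy values only:

```python
import numpy as np
import xarray as xr

# Toy stand-in for an ERA5T variable with an "expver" dimension:
# expver=1 carries trailing NaNs, expver=5 carries leading NaNs.
time = np.arange(6)
data = np.array([
    [10., 11., 12., np.nan, np.nan, np.nan],  # expver=1
    [np.nan, np.nan, np.nan, 13., 14., 15.],  # expver=5
])
da = xr.DataArray(data, dims=("expver", "time"),
                  coords={"expver": [1, 5], "time": time})

# Same pattern as in _read_nc_dataset: fill the NaNs of expver=1 from expver=5.
merged = da.sel(expver=1).combine_first(da.sel(expver=5))
print(merged.values)  # [10. 11. 12. 13. 14. 15.]
```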
@@ -412,7 +416,7 @@ def read_single_level_vars(year, month, day, variables=SINGLE_LEVEL_VARIABLES,
relative_path = SINGLE_LEVEL_SUBDIR_TEMPLATE.format(
year=year, month=month, day=day, variable=era5_variable)
output[variable] = _read_nc_dataset(root_path / relative_path)
return xarray.Dataset(output)
return xr.Dataset(output)


def read_multilevel_vars(year,
@@ -451,8 +455,8 @@ def read_multilevel_vars(year,
single_level_data_array.coords["level"] = pressure_level
pressure_data.append(
single_level_data_array.expand_dims(dim="level", axis=1))
output[variable] = xarray.concat(pressure_data, dim="level")
return xarray.Dataset(output)
output[variable] = xr.concat(pressure_data, dim="level")
return xr.Dataset(output)


def get_var_attrs_dict(root_path=GCP_DIRECTORY):
@@ -558,6 +562,63 @@ def align_coordinates(dataset: xr.Dataset) -> xr.Dataset:

return dataset

def get_pressure_levels_arg(pressure_levels_group: str):
return PRESSURE_LEVELS_GROUPS[pressure_levels_group]


class LoadTemporalDataForDateDoFn(beam.DoFn):
Reviewer comment (Collaborator): Please add docstring.

Reviewer comment (Collaborator): Can we make this a PTransform instead?

def __init__(self, data_path, start_date, pressure_levels_group):
self.data_path = data_path
self.start_date = start_date
self.pressure_levels_group = pressure_levels_group

def process(self, args):
"""Loads temporal data for a day, with an xarray_beam key for it."""
year, month, day = args
logging.info("Loading NetCDF files for %d-%d-%d", year, month, day)

try:
single_level_vars = read_single_level_vars(
year,
month,
day,
variables=SINGLE_LEVEL_VARIABLES,
root_path=self.data_path)
multilevel_vars = read_multilevel_vars(
year,
month,
day,
variables=MULTILEVEL_VARIABLES,
pressure_levels=get_pressure_levels_arg(self.pressure_levels_group),
root_path=self.data_path)
except BaseException as e:
# Make sure we print the date as part of the error for easier debugging
# if something goes wrong. Note "from e" will also raise the details of the
# original exception.
raise Exception(f"Error loading {year}-{month}-{day}") from e
Reviewer comment (Collaborator): Why do we do this?

# It is crucial to actually "load" as otherwise we get a pickle error.
single_level_vars = single_level_vars.load()
multilevel_vars = multilevel_vars.load()

dataset = xr.merge([single_level_vars, multilevel_vars])
dataset = align_coordinates(dataset)
offsets = {"latitude": 0, "longitude": 0, "level": 0,
"time": offset_along_time_axis(self.start_date, year, month, day)}
key = xb.Key(offsets, vars=set(dataset.data_vars.keys()))
logging.info("Finished loading NetCDF files for %s-%s-%s", year, month, day)
yield key, dataset
dataset.close()
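
A hedged sketch of calling this DoFn directly, outside a pipeline, to see the shape of what it emits. Reading real data needs GCS access, and the date and pressure-levels group below are illustrative only:

```python
fn = LoadTemporalDataForDateDoFn(
    data_path=GCP_DIRECTORY,
    start_date="2021-01-01",
    pressure_levels_group="weatherbench_13")

for key, dataset in fn.process((2021, 1, 2)):
    # The "time" offset is 24 for the second day of the store (1-hourly steps),
    # and the dataset holds that day's merged single-level and pressure-level
    # variables on aligned coordinates.
    print(key.offsets["time"], len(dataset.data_vars))
```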


def offset_along_time_axis(start_date: str, year: int, month: int, day: int) -> int:
"""Offset in indices along the time axis, relative to start of the dataset."""
# Note the length of years can vary due to leap years, so the chunk lengths
# will not always be the same, and we need to do a proper date calculation
# not just multiply by 365*24.
time_delta = pd.Timestamp(
year=year, month=month, day=day) - pd.Timestamp(start_date)
return time_delta.days * HOURS_PER_DAY // TIME_RESOLUTION_HOURS
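
As a quick sanity check on the arithmetic (assuming TIME_RESOLUTION_HOURS = 1 and the 1900-01-01 init date used elsewhere in this PR), a few illustrative values follow; the pandas date subtraction handles leap years, so 1904 contributes 366 days while 1900 does not (century rule):

```python
assert offset_along_time_axis("1900-01-01", 1900, 1, 2) == 24
assert offset_along_time_axis("1900-01-01", 1901, 1, 1) == 365 * 24
assert offset_along_time_axis("1900-01-01", 1905, 1, 1) == (5 * 365 + 1) * 24  # 1904 is a leap year
```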

def parse_arguments(desc: str) -> t.Tuple[argparse.Namespace, t.List[str]]:
"""Parse command-line arguments for the data processing pipeline.
@@ -580,14 +641,18 @@ def parse_arguments(desc: str) -> t.Tuple[argparse.Namespace, t.List[str]]:
help='Start date, iso format string.')
parser.add_argument('-e', "--end_date", default='2020-01-02',
help='End date, iso format string.')
parser.add_argument("--temp_location", type=str, required=True,
help="A temp location where this data is stored temporarily.")
parser.add_argument('--find-missing', action='store_true', default=False,
help='Print all paths to missing input data.') # implementation pending
parser.add_argument("--pressure_levels_group", type=str, default="weatherbench_13",
help="Group label for the set of pressure levels to use.")
parser.add_argument("--time_chunk_size", type=int, required=True,
help="Number of 1-hourly timesteps to include in a \
single chunk. Must evenly divide 24.")
parser.add_argument("--init_date", type=str, default='1900-01-01',
help="Date to initialize the zarr store.")
parser.add_argument("--from_init_date", action='store_true', default=False,
help="To initialize the store from some previous date (--init_date). i.e. 1900-01-01")
parser.add_argument("--only_initialize_store", action='store_true', default=False,
help="Initialize zarr store without data.")

return parser.parse_known_args()
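
A hypothetical call showing how parse_arguments separates its own flags from Beam pipeline options (it wraps parse_known_args). All values are placeholders, and any required flags defined outside the visible part of this hunk would also need to be supplied:

```python
import sys

sys.argv = [
    "seed_era5.py",                            # hypothetical driver script
    "--temp_location", "gs://my-bucket/tmp",   # placeholder bucket
    "--time_chunk_size", "1",
    "--start_date", "2020-01-01",
    "--end_date", "2020-01-02",
    "--runner", "DirectRunner",                # left for the Beam runner
]
known_args, pipeline_args = parse_arguments("Seed the analysis-ready ERA5 store.")
print(known_args.pressure_levels_group)  # weatherbench_13 (default)
print(pipeline_args)                     # ['--runner', 'DirectRunner']
```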
src/arco_era5/update.py (36 changes: 36 additions & 0 deletions)
@@ -0,0 +1,36 @@
import apache_beam as beam
import datetime
import logging
import xarray as xr
import zarr

from arco_era5 import HOURS_PER_DAY
from dataclasses import dataclass
from typing import Tuple

logger = logging.getLogger(__name__)

@dataclass
class UpdateSlice(beam.PTransform):

target: str
init_date: str

def apply(self, offset_ds: Tuple[int, xr.Dataset, str]):
"""Generate region slice and update zarr array directly"""
key, ds = offset_ds
offset = key.offsets['time']
date = datetime.datetime.strptime(self.init_date, '%Y-%m-%d') + datetime.timedelta(days=offset / HOURS_PER_DAY)
date_str = date.strftime('%Y-%m-%d')
zf = zarr.open(self.target)
region = slice(offset, offset + HOURS_PER_DAY)
for vname in ds.data_vars:
logger.info(f"Started {vname} for {date_str}")
zv = zf[vname]
zv[region] = ds[vname].values
logger.info(f"Done {vname} for {date_str}")
del zv
del ds

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
return pcoll | beam.Map(self.apply)
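
To show how this transform is meant to sit at the end of the seeding pipeline, a hedged end-to-end sketch follows. The real wiring lives in the repo's pipeline module; the store path, dates, and the assumption that daily_date_iterator yields (year, month, day) tuples are illustrative, and the target Zarr store must already exist with its full time axis allocated:

```python
import apache_beam as beam

from arco_era5 import (
    GCP_DIRECTORY,
    LoadTemporalDataForDateDoFn,
    UpdateSlice,
    daily_date_iterator,
)

TARGET_STORE = "gs://my-bucket/era5.zarr"  # placeholder; pre-initialized store
INIT_DATE = "1900-01-01"

with beam.Pipeline() as p:
    _ = (
        p
        | "Days" >> beam.Create(daily_date_iterator("2023-01-01", "2023-01-08"))
        | "Load" >> beam.ParDo(
            LoadTemporalDataForDateDoFn(
                data_path=GCP_DIRECTORY,
                start_date=INIT_DATE,
                pressure_levels_group="weatherbench_13"))
        | "Write" >> UpdateSlice(target=TARGET_STORE, init_date=INIT_DATE)
    )
```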