Merge pull request #588 from pangeo-forge/target-root-default
Default value for `target_root`, to make recipe modules importable
cisaacstern authored Sep 1, 2023
2 parents 5e9eae4 + 7c7c23d commit f4202a6
Showing 3 changed files with 62 additions and 15 deletions.
31 changes: 30 additions & 1 deletion pangeo_forge_recipes/transforms.py
@@ -4,7 +4,7 @@
import random
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union

# PEP612 Concatenate & ParamSpec are useful for annotating decorators, but their import
# differs between Python versions 3.9 & 3.10. See: https://stackoverflow.com/a/71990006
@@ -69,6 +69,19 @@
P = ParamSpec("P")


class RequiredAtRuntimeDefault:
"""Sentinel class to use as default for transform attributes which are required to run a
pipeline, but may not be available (or preferable) to define during recipe development; for
example, the ``target_root`` kwarg of a transform that writes data to a target location. By
using this sentinel as the default value for such a kwarg, a recipe module can define all
required arguments on the transform (and therefore be importable, satisfy type-checkers, be
unit-testable, etc.) before it is deployed, with the understanding that the attribute using
this sentinel as default will be re-assigned to the desired value at deploy time.
"""

pass
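# --- Editorial sketch, not part of this commit: how this sentinel is meant to be
# used. A recipe module instantiates its transforms without a target_root, so the
# module imports, type-checks, and unit-tests cleanly; deployment tooling then
# re-assigns the attribute before the pipeline runs. The file URLs and the final
# assignment below are illustrative assumptions, not project API.
#
#     import apache_beam as beam
#     from pangeo_forge_recipes.patterns import pattern_from_file_sequence
#     from pangeo_forge_recipes.transforms import (
#         OpenURLWithFSSpec, OpenWithXarray, StoreToZarr,
#     )
#
#     pattern = pattern_from_file_sequence(
#         ["https://example.com/a.nc", "https://example.com/b.nc"], concat_dim="time"
#     )
#     store = StoreToZarr(
#         store_name="my-dataset.zarr",
#         combine_dims=pattern.combine_dim_keys,
#         # target_root omitted: defaults to a RequiredAtRuntimeDefault instance
#     )
#     recipe = beam.Create(pattern.items()) | OpenURLWithFSSpec() | OpenWithXarray() | store
#
#     # At deploy time: store.target_root = "s3://bucket/prefix"  (str or FSSpecTarget)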


# TODO: replace with beam.MapTuple?
def _add_keys(
func: Callable[Concatenate[T, P], R],
@@ -434,9 +447,17 @@ def expand(self, references: beam.PCollection) -> beam.PCollection:
class WriteCombinedReference(beam.PTransform, ZarrWriterMixin):
"""Store a singleton PCollection consisting of a ``kerchunk.combine.MultiZarrToZarr`` object.
:param store_name: Name for the Zarr store. It will be created with
this name under `target_root`.
:param target_root: Root path the Zarr store will be created inside;
`store_name` will be appended to this prefix to create a full path.
:param output_json_fname: Name to give the output references file. Must end in ``.json``.
"""

store_name: str
target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field(
default_factory=RequiredAtRuntimeDefault
)
output_json_fname: str = "reference.json"

def expand(self, reference: beam.PCollection) -> beam.PCollection:
@@ -452,13 +473,21 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin):
"""Store a PCollection of Xarray datasets to Zarr.
:param combine_dims: The dimensions to combine
:param store_name: Name for the Zarr store. It will be created with
this name under `target_root`.
:param target_root: Root path the Zarr store will be created inside;
`store_name` will be appended to this prefix to create a full path.
:param target_chunks: Dictionary mapping dimension names to chunk sizes.
If a dimension is not named, the chunks will be inferred from the data.
"""

# TODO: make it so we don't have to explicitly specify combine_dims
# Could be inferred from the pattern instead
combine_dims: List[Dimension]
store_name: str
target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field(
default_factory=RequiredAtRuntimeDefault
)
target_chunks: Dict[str, int] = field(default_factory=dict)

def expand(
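A note on the `field(default_factory=RequiredAtRuntimeDefault)` construction used above: the factory is called once per transform instance, so each `StoreToZarr` or `WriteCombinedReference` gets its own sentinel object rather than sharing a class-level default. A minimal sketch of the resulting behavior, assuming only the definitions in this diff (the empty `combine_dims` and the paths are placeholders):

from pangeo_forge_recipes.transforms import RequiredAtRuntimeDefault, StoreToZarr

t = StoreToZarr(combine_dims=[], store_name="test.zarr")
assert isinstance(t.target_root, RequiredAtRuntimeDefault)  # sentinel default in place

t.target_root = "/tmp/output"  # deploy-time re-assignment; a str or an FSSpecTarget
assert isinstance(t.target_root, str)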
32 changes: 18 additions & 14 deletions pangeo_forge_recipes/writers.py
@@ -1,6 +1,5 @@
import os
from dataclasses import dataclass
from typing import Tuple, Union
from typing import Protocol, Tuple, Union

import numpy as np
import xarray as xr
@@ -114,22 +113,27 @@ def write_combined_reference(
raise NotImplementedError(f"{file_ext = } not supported.")


@dataclass
class ZarrWriterMixin:
"""Defines common attributes and methods for storing zarr datasets, which can be either actual
zarr stores or virtual (i.e. kerchunked) stores. This class should not be directly instantiated.
Instead, PTransforms in the `.transforms` module which write consolidated zarr stores should
inherit from this mixin, so that they share a common interface for target store naming.
:param target_root: Location the Zarr store will be created inside.
:param store_name: Name for the Zarr store. It will be created with this name
under `target_root`.
class ZarrWriterProtocol(Protocol):
"""Protocol for mixin typing, following best practices described in:
https://mypy.readthedocs.io/en/stable/more_types.html#mixin-classes.
When used as a type hint for the `self` argument of mixin-class methods, this protocol tells
type checkers that the given method is expected to be called in the context of a class which
defines the attributes declared here. This satisfies type checkers without the need to define
these attributes more than once in an inheritance hierarchy.
"""

target_root: Union[str, FSSpecTarget]
store_name: str
target_root: Union[str, FSSpecTarget]


class ZarrWriterMixin:
"""Defines common methods relevant to storing zarr datasets, which can be either actual zarr
stores or virtual (i.e. kerchunked) stores. This class should not be directly instantiated.
Instead, PTransforms in the `.transforms` module which write zarr stores should inherit from
this mixin, so that they share a common interface for target store naming.
"""

def get_full_target(self) -> FSSpecTarget:
def get_full_target(self: ZarrWriterProtocol) -> FSSpecTarget:
if isinstance(self.target_root, str):
target_root = FSSpecTarget.from_url(self.target_root)
else:
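The `self: ZarrWriterProtocol` annotation is the structural-typing trick from the mypy docs linked above: the mixin never declares the attributes itself, yet type checkers verify they exist wherever its methods are actually called. A self-contained sketch of the same pattern with illustrative names (everything here except the `Protocol` mechanics is an assumption, not project code):

from typing import Protocol


class HasTarget(Protocol):
    store_name: str
    target_root: str  # the real protocol also permits FSSpecTarget


class NamingMixin:
    # Annotating `self` with the protocol tells type checkers this method is
    # only called on classes that provide these attributes, without NamingMixin
    # declaring (or inheriting) them itself.
    def full_path(self: HasTarget) -> str:
        return f"{self.target_root}/{self.store_name}"


class Writer(NamingMixin):
    def __init__(self, store_name: str, target_root: str) -> None:
        self.store_name = store_name
        self.target_root = target_root


print(Writer("test.zarr", "s3://bucket").full_path())  # s3://bucket/test.zarr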
14 changes: 14 additions & 0 deletions tests/test_transforms.py
@@ -264,3 +264,17 @@ def _is_xr_dataset(actual):
)
open_store = target_store | OpenZarrStore()
assert_that(open_store, is_xrdataset())


def test_StoreToZarr_target_root_default_unrunnable(
pipeline,
netcdf_local_file_pattern_sequential,
):
pattern: FilePattern = netcdf_local_file_pattern_sequential
with pytest.raises(TypeError, match=r"unsupported operand"):
with pipeline as p:
datasets = p | beam.Create(pattern.items()) | OpenWithXarray()
_ = datasets | StoreToZarr(
store_name="test.zarr",
combine_dims=pattern.combine_dim_keys,
)
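This test asserts that leaving the sentinel in place fails loudly at run time rather than silently. The `unsupported operand` message presumably arises when the writer joins `target_root` with `store_name` while `target_root` is still the sentinel; since `RequiredAtRuntimeDefault` defines no operators, a `/`-style path join raises immediately (the exact join site is an assumption here, as the diff truncates `get_full_target`). In plain Python:

class RequiredAtRuntimeDefault:
    pass

RequiredAtRuntimeDefault() / "test.zarr"
# TypeError: unsupported operand type(s) for /: 'RequiredAtRuntimeDefault' and 'str'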
