diff --git a/pangeo_forge_recipes/transforms.py b/pangeo_forge_recipes/transforms.py index 226025ec..50b9d409 100644 --- a/pangeo_forge_recipes/transforms.py +++ b/pangeo_forge_recipes/transforms.py @@ -4,7 +4,7 @@ import random import sys from dataclasses import dataclass, field -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union # PEP612 Concatenate & ParamSpec are useful for annotating decorators, but their import # differs between Python versions 3.9 & 3.10. See: https://stackoverflow.com/a/71990006 @@ -69,6 +69,19 @@ P = ParamSpec("P") +class RequiredAtRuntimeDefault: +    """Sentinel class to use as default for transform attributes which are required to run a +    pipeline, but may not be available (or preferable) to define during recipe development; for +    example, the ``target_root`` kwarg of a transform that writes data to a target location. By +    using this sentinel as the default value for such a kwarg, a recipe module can define all +    required arguments on the transform (and therefore be importable, satisfy type-checkers, be +    unit-testable, etc.) before it is deployed, with the understanding that the attribute using +    this sentinel as default will be re-assigned to the desired value at deploy time. +    """ + +    pass + + # TODO: replace with beam.MapTuple? def _add_keys( func: Callable[Concatenate[T, P], R], @@ -434,9 +447,17 @@ def expand(self, references: beam.PCollection) -> beam.PCollection: class WriteCombinedReference(beam.PTransform, ZarrWriterMixin): """Store a singleton PCollection consisting of a ``kerchunk.combine.MultiZarrToZarr`` object. + :param store_name: Name for the Zarr store. It will be created with + this name under `target_root`. + :param target_root: Root path the Zarr store will be created inside; + `store_name` will be appended to this prefix to create a full path. 
:param output_json_fname: Name to give the output references file. Must end in ``.json``. """ + store_name: str + target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field( + default_factory=RequiredAtRuntimeDefault + ) output_json_fname: str = "reference.json" def expand(self, reference: beam.PCollection) -> beam.PCollection: @@ -452,6 +473,10 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin): """Store a PCollection of Xarray datasets to Zarr. :param combine_dims: The dimensions to combine + :param store_name: Name for the Zarr store. It will be created with + this name under `target_root`. + :param target_root: Root path the Zarr store will be created inside; + `store_name` will be appended to this prefix to create a full path. :param target_chunks: Dictionary mapping dimension names to chunks sizes. If a dimension is a not named, the chunks will be inferred from the data. """ @@ -459,6 +484,10 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin): # TODO: make it so we don't have to explicitly specify combine_dims # Could be inferred from the pattern instead combine_dims: List[Dimension] + store_name: str + target_root: Union[str, FSSpecTarget, RequiredAtRuntimeDefault] = field( + default_factory=RequiredAtRuntimeDefault + ) target_chunks: Dict[str, int] = field(default_factory=dict) def expand( diff --git a/pangeo_forge_recipes/writers.py b/pangeo_forge_recipes/writers.py index 71ae025f..4ce5b01e 100644 --- a/pangeo_forge_recipes/writers.py +++ b/pangeo_forge_recipes/writers.py @@ -1,6 +1,5 @@ import os -from dataclasses import dataclass -from typing import Tuple, Union +from typing import Protocol, Tuple, Union import numpy as np import xarray as xr @@ -114,22 +113,27 @@ def write_combined_reference( raise NotImplementedError(f"{file_ext = } not supported.") -@dataclass -class ZarrWriterMixin: - """Defines common attributes and methods for storing zarr datasets, which can be either actual - zarr stores or virtual (i.e. kerchunked) stores. 
This class should not be directly instantiated. - Instead, PTransforms in the `.transforms` module which write consolidated zarr stores should - inherit from this mixin, so that they share a common interface for target store naming. - - :param target_root: Location the Zarr store will be created inside. - :param store_name: Name for the Zarr store. It will be created with this name - under `target_root`. +class ZarrWriterProtocol(Protocol): + """Protocol for mixin typing, following best practices described in: + https://mypy.readthedocs.io/en/stable/more_types.html#mixin-classes. + When used as a type hint for the `self` argument on mixin classes, this protocol just tells type + checkers that the given method is expected to be called in the context of a class which defines + the attributes declared here. This satisfies type checkers without the need to define these + attributes more than once in an inheritance hierarchy. """ - target_root: Union[str, FSSpecTarget] store_name: str + target_root: Union[str, FSSpecTarget] + + +class ZarrWriterMixin: + """Defines common methods relevant to storing zarr datasets, which can be either actual zarr + stores or virtual (i.e. kerchunked) stores. This class should not be directly instantiated. + Instead, PTransforms in the `.transforms` module which write zarr stores should inherit from + this mixin, so that they share a common interface for target store naming. 
+ """ - def get_full_target(self) -> FSSpecTarget: + def get_full_target(self: ZarrWriterProtocol) -> FSSpecTarget: if isinstance(self.target_root, str): target_root = FSSpecTarget.from_url(self.target_root) else: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 296e5786..e40f77c1 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -264,3 +264,17 @@ def _is_xr_dataset(actual): ) open_store = target_store | OpenZarrStore() assert_that(open_store, is_xrdataset()) + + +def test_StoreToZarr_target_root_default_unrunnable( + pipeline, + netcdf_local_file_pattern_sequential, +): + pattern: FilePattern = netcdf_local_file_pattern_sequential + with pytest.raises(TypeError, match=r"unsupported operand"): + with pipeline as p: + datasets = p | beam.Create(pattern.items()) | OpenWithXarray() + _ = datasets | StoreToZarr( + store_name="test.zarr", + combine_dims=pattern.combine_dim_keys, + )