diff --git a/pangeo_forge_recipes/transforms.py b/pangeo_forge_recipes/transforms.py index 433ad860..efe40734 100644 --- a/pangeo_forge_recipes/transforms.py +++ b/pangeo_forge_recipes/transforms.py @@ -5,14 +5,14 @@ import random import sys from dataclasses import dataclass, field -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union # PEP612 Concatenate & ParamSpec are useful for annotating decorators, but their import # differs between Python versions 3.9 & 3.10. See: https://stackoverflow.com/a/71990006 if sys.version_info < (3, 10): - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec else: - from typing import Concatenate, ParamSpec + from typing import ParamSpec import apache_beam as beam import fsspec @@ -90,28 +90,6 @@ class RequiredAtRuntimeDefault: pass -def _add_keys_iter( - func: Callable[Concatenate[T, P], R], -) -> Callable[Concatenate[Iterable[IndexedArg], P], Iterator[IndexedReturn]]: - """Convenience decorator to iteratively remove and re-add keys to items in a FlatMap""" - annotations = func.__annotations__.copy() - arg_name, annotation = next(iter(annotations.items())) - return_annotation = annotations["return"] - - # mypy doesn't view `annotation` and `return_annotation` as valid types, so ignore - annotations[arg_name] = Iterable[Tuple[Index, annotation]] # type: ignore - annotations["return"] = Iterator[Tuple[Index, return_annotation]] # type: ignore - - def iterable_wrapper(arg, *args: P.args, **kwargs: P.kwargs): - for inner_item in arg: - key, item = inner_item - result = func(item, *args, **kwargs) - yield key, result - - iterable_wrapper.__annotations__ = annotations - return iterable_wrapper - - def _assign_concurrency_group(elem, max_concurrency: int): return (random.randint(0, max_concurrency - 1), elem) @@ -144,11 +122,16 @@ def expand(self, pcoll): if not self.max_concurrency else ( pcoll - | beam.Map(_assign_concurrency_group, self.max_concurrency) - | beam.GroupByKey() - | beam.Values() + | "Assign concurrency key" + >> beam.Map(_assign_concurrency_group, self.max_concurrency) + | "Group together by concurrency key" >> beam.GroupByKey() + | "Drop concurrency key" >> beam.Values() | f"{self.fn.__name__} (max_concurrency={self.max_concurrency})" - >> beam.FlatMap(_add_keys_iter(self.fn), *self.args, **self.kwargs) + >> beam.FlatMap( + lambda kvlist: [ + (kv[0], self.fn(kv[1], *self.args, **self.kwargs)) for kv in kvlist + ] + ) ) )