Skip to content

Commit

Permalink
Merge pull request #700 from pangeo-forge/feature/remove-add-keys-iter
Browse files Browse the repository at this point in the history
Remove add_keys_iter
  • Loading branch information
moradology authored Feb 26, 2024
2 parents 07d7465 + 05f1a38 commit e39a205
Showing 1 changed file with 12 additions and 29 deletions.
41 changes: 12 additions & 29 deletions pangeo_forge_recipes/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import random
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union

# PEP612 Concatenate & ParamSpec are useful for annotating decorators, but their import
# differs between Python versions 3.9 & 3.10. See: https://stackoverflow.com/a/71990006
if sys.version_info < (3, 10):
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
else:
from typing import Concatenate, ParamSpec
from typing import ParamSpec

import apache_beam as beam
import fsspec
Expand Down Expand Up @@ -90,28 +90,6 @@ class RequiredAtRuntimeDefault:
pass


def _add_keys_iter(
func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[Iterable[IndexedArg], P], Iterator[IndexedReturn]]:
"""Convenience decorator to iteratively remove and re-add keys to items in a FlatMap"""
annotations = func.__annotations__.copy()
arg_name, annotation = next(iter(annotations.items()))
return_annotation = annotations["return"]

# mypy doesn't view `annotation` and `return_annotation` as valid types, so ignore
annotations[arg_name] = Iterable[Tuple[Index, annotation]] # type: ignore
annotations["return"] = Iterator[Tuple[Index, return_annotation]] # type: ignore

def iterable_wrapper(arg, *args: P.args, **kwargs: P.kwargs):
for inner_item in arg:
key, item = inner_item
result = func(item, *args, **kwargs)
yield key, result

iterable_wrapper.__annotations__ = annotations
return iterable_wrapper


def _assign_concurrency_group(elem, max_concurrency: int):
return (random.randint(0, max_concurrency - 1), elem)

Expand Down Expand Up @@ -144,11 +122,16 @@ def expand(self, pcoll):
if not self.max_concurrency
else (
pcoll
| beam.Map(_assign_concurrency_group, self.max_concurrency)
| beam.GroupByKey()
| beam.Values()
| "Assign concurrency key"
>> beam.Map(_assign_concurrency_group, self.max_concurrency)
| "Group together by concurrency key" >> beam.GroupByKey()
| "Drop concurrency key" >> beam.Values()
| f"{self.fn.__name__} (max_concurrency={self.max_concurrency})"
>> beam.FlatMap(_add_keys_iter(self.fn), *self.args, **self.kwargs)
>> beam.FlatMap(
lambda kvlist: [
(kv[0], self.fn(kv[1], *self.args, **self.kwargs)) for kv in kvlist
]
)
)
)

Expand Down

0 comments on commit e39a205

Please sign in to comment.