Skip to content

Commit

Permalink
Remove add_keys_iter
Browse files Browse the repository at this point in the history
  • Loading branch information
moradology committed Feb 26, 2024
1 parent 07d7465 commit 05f1a38
Showing 1 changed file with 12 additions and 29 deletions.
41 changes: 12 additions & 29 deletions pangeo_forge_recipes/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import random
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union

# PEP612 Concatenate & ParamSpec are useful for annotating decorators, but their import
# differs between Python versions 3.9 & 3.10. See: https://stackoverflow.com/a/71990006
if sys.version_info < (3, 10):
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
else:
from typing import Concatenate, ParamSpec
from typing import ParamSpec

import apache_beam as beam
import fsspec
Expand Down Expand Up @@ -90,28 +90,6 @@ class RequiredAtRuntimeDefault:
pass


def _add_keys_iter(
func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[Iterable[IndexedArg], P], Iterator[IndexedReturn]]:
"""Convenience decorator to iteratively remove and re-add keys to items in a FlatMap"""
annotations = func.__annotations__.copy()
arg_name, annotation = next(iter(annotations.items()))
return_annotation = annotations["return"]

# mypy doesn't view `annotation` and `return_annotation` as valid types, so ignore
annotations[arg_name] = Iterable[Tuple[Index, annotation]] # type: ignore
annotations["return"] = Iterator[Tuple[Index, return_annotation]] # type: ignore

def iterable_wrapper(arg, *args: P.args, **kwargs: P.kwargs):
for inner_item in arg:
key, item = inner_item
result = func(item, *args, **kwargs)
yield key, result

iterable_wrapper.__annotations__ = annotations
return iterable_wrapper


def _assign_concurrency_group(elem, max_concurrency: int):
return (random.randint(0, max_concurrency - 1), elem)

Expand Down Expand Up @@ -144,11 +122,16 @@ def expand(self, pcoll):
if not self.max_concurrency
else (
pcoll
| beam.Map(_assign_concurrency_group, self.max_concurrency)
| beam.GroupByKey()
| beam.Values()
| "Assign concurrency key"
>> beam.Map(_assign_concurrency_group, self.max_concurrency)
| "Group together by concurrency key" >> beam.GroupByKey()
| "Drop concurrency key" >> beam.Values()
| f"{self.fn.__name__} (max_concurrency={self.max_concurrency})"
>> beam.FlatMap(_add_keys_iter(self.fn), *self.args, **self.kwargs)
>> beam.FlatMap(
lambda kvlist: [
(kv[0], self.fn(kv[1], *self.args, **self.kwargs)) for kv in kvlist
]
)
)
)

Expand Down

0 comments on commit 05f1a38

Please sign in to comment.