From 4ab1b0bae4b48852f941c51fac31c27144354900 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 15 Dec 2021 15:44:14 -0600 Subject: [PATCH] allow empty sub-containers in reductions --- arraycontext/container/traversal.py | 36 +++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 07c15446..d40d6418 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -464,7 +464,7 @@ def rec_map_reduce_array_container( or any other such traversal. """ - def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: + def rec(_ary: ArrayOrContainerT) -> Optional[ArrayOrContainerT]: if type(_ary) is leaf_class: return map_func(_ary) else: @@ -473,11 +473,22 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT: except NotAnArrayContainerError: return map_func(_ary) else: - return reduce_func([ - rec(subary) for _, subary in iterable - ]) + subary_results = [ + rec(subary) for _, subary in iterable] + filtered_subary_results = [ + result for result in subary_results + if result is not None] + if len(filtered_subary_results) > 0: + return reduce_func(filtered_subary_results) + else: + return None - return rec(ary) + result = rec(ary) + + if result is None: + raise ValueError("cannot reduce empty array container") + + return result def rec_multimap_reduce_array_container( @@ -503,12 +514,23 @@ def rec_multimap_reduce_array_container( # NOTE: this wrapper matches the signature of `deserialize_container` # to make plugging into `_multimap_array_container_impl` easier def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any: - return reduce_func([subary for _, subary in iterable]) + filtered_subary_results = [ + result for _, result in iterable + if result is not None] + if len(filtered_subary_results) > 0: + return reduce_func(filtered_subary_results) + else: + return None - return _multimap_array_container_impl( + result = _multimap_array_container_impl( map_func, *args, reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True) + if result is None: + raise ValueError("cannot reduce empty array container") + + return result + # }}}