diff --git a/modin/core/execution/dask/implementations/pandas_on_dask/partitioning/virtual_partition.py b/modin/core/execution/dask/implementations/pandas_on_dask/partitioning/virtual_partition.py index ff3ba09bc7a..276e1b9fef3 100644 --- a/modin/core/execution/dask/implementations/pandas_on_dask/partitioning/virtual_partition.py +++ b/modin/core/execution/dask/implementations/pandas_on_dask/partitioning/virtual_partition.py @@ -32,8 +32,10 @@ class PandasOnDaskDataframeVirtualPartition(PandasDataframeAxisPartition): Parameters ---------- - list_of_blocks : list - List of ``PandasOnDaskDataframePartition`` objects. + list_of_blocks : Union[list, PandasOnDaskDataframePartition] + List of ``PandasOnDaskDataframePartition`` and + ``PandasOnDaskDataframeVirtualPartition`` objects, or a single + ``PandasOnDaskDataframePartition``. get_ip : bool, default: False Whether to get node IP addresses of conforming partitions or not. full_axis : bool, default: True @@ -45,16 +47,21 @@ class PandasOnDaskDataframeVirtualPartition(PandasDataframeAxisPartition): axis = None def __init__(self, list_of_blocks, get_ip=False, full_axis=True, call_queue=None): - self.call_queue = call_queue or [] - self.full_axis = full_axis if isinstance(list_of_blocks, PandasOnDaskDataframePartition): list_of_blocks = [list_of_blocks] + self.call_queue = call_queue or [] + self.full_axis = full_axis + # In the simple case, none of the partitions that will compose this + # partition are themselves virtual partitions. The partitions that will + # be combined are just the partitions as given to the constructor. if not any( isinstance(o, PandasOnDaskDataframeVirtualPartition) for o in list_of_blocks ): self.list_of_partitions_to_combine = list_of_blocks return - + # Check that all axes are the same in `list_of_blocks` + # We should never have mismatching axes in the current implementation. We add this + # defensive assertion to ensure that undefined behavior does not happen. 
assert ( len( set( @@ -65,6 +72,8 @@ def __init__(self, list_of_blocks, get_ip=False, full_axis=True, call_queue=None ) == 1 ) + # When the axis of all virtual partitions matches this axis, + # extend and combine the lists of physical partitions. if ( next( o @@ -83,6 +92,7 @@ def __init__(self, list_of_blocks, get_ip=False, full_axis=True, call_queue=None o ) self.list_of_partitions_to_combine = new_list_of_blocks + # Materialize partitions if the axis of this virtual partition does not match the axis of the virtual partitions else: self.list_of_partitions_to_combine = [ obj.force_materialization().list_of_partitions_to_combine[0] @@ -277,6 +287,8 @@ def apply( A list of `PandasOnDaskDataframeVirtualPartition` objects. """ if not self.full_axis: + # If this is not a full axis partition, it already contains a subset of + # the full axis, so we shouldn't split the result further. num_splits = 1 if len(self.call_queue) > 0: self.drain_call_queue() @@ -286,6 +298,7 @@ def apply( if self.full_axis: return result else: + # If this is not a full axis partition, just take out the single split in the result. return result[0] def force_materialization(self, get_ip=False): @@ -381,15 +394,24 @@ def width(self): self._width_cache = self.list_of_partitions_to_combine[0].width() return self._width_cache - def drain_call_queue(self): - """Execute all operations stored in this partition's call queue.""" + def drain_call_queue(self, num_splits=None): + """ + Execute all operations stored in this partition's call queue. + + Parameters + ---------- + num_splits : int, default: None + The number of times to split the result object. + """ def drain(df): for func, args, kwargs in self.call_queue: df = func(df, *args, **kwargs) return df - drained = super(PandasOnDaskDataframeVirtualPartition, self).apply(drain) + drained = super(PandasOnDaskDataframeVirtualPartition, self).apply( + drain, num_splits=num_splits + ) self.list_of_partitions_to_combine = drained self.call_queue = []