Skip to content

Commit

Permalink
Add comments to virtual partition, and update drain call queue to acc…
Browse files Browse the repository at this point in the history
…ept num_splits argument

Signed-off-by: Rehan Durrani <[email protected]>
  • Loading branch information
RehanSD committed Jun 15, 2022
1 parent 348272d commit 6784b4b
Showing 1 changed file with 30 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class PandasOnDaskDataframeVirtualPartition(PandasDataframeAxisPartition):
Parameters
----------
list_of_blocks : list
List of ``PandasOnDaskDataframePartition`` objects.
list_of_blocks : Union[list, PandasOnDaskDataframePartition]
List of ``PandasOnDaskDataframePartition`` and
``PandasOnDaskDataframeVirtualPartition`` objects, or a single
``PandasOnDaskDataframePartition``.
get_ip : bool, default: False
Whether to get node IP addresses of conforming partitions or not.
full_axis : bool, default: True
Expand All @@ -45,16 +47,21 @@ class PandasOnDaskDataframeVirtualPartition(PandasDataframeAxisPartition):
axis = None

def __init__(self, list_of_blocks, get_ip=False, full_axis=True, call_queue=None):
self.call_queue = call_queue or []
self.full_axis = full_axis
if isinstance(list_of_blocks, PandasOnDaskDataframePartition):
list_of_blocks = [list_of_blocks]
self.call_queue = call_queue or []
self.full_axis = full_axis
# In the simple case, none of the partitions that will compose this
# partition are themselves virtual partition. The partitions that will
# be combined are just the partitions as given to the constructor.
if not any(
isinstance(o, PandasOnDaskDataframeVirtualPartition) for o in list_of_blocks
):
self.list_of_partitions_to_combine = list_of_blocks
return

# Check that all axis are the same in `list_of_blocks`
# We should never have mismatching axis in the current implementation. We add this
# defensive assertion to ensure that undefined behavior does not happen.
assert (
len(
set(
Expand All @@ -65,6 +72,8 @@ def __init__(self, list_of_blocks, get_ip=False, full_axis=True, call_queue=None
)
== 1
)
# When the axis of all virtual partitions matches this axis,
# extend and combine the lists of physical partitions.
if (
next(
o
Expand All @@ -83,6 +92,7 @@ def __init__(self, list_of_blocks, get_ip=False, full_axis=True, call_queue=None
o
)
self.list_of_partitions_to_combine = new_list_of_blocks
# Materialize partitions if the axis of this virtual does not match the virtual partitions
else:
self.list_of_partitions_to_combine = [
obj.force_materialization().list_of_partitions_to_combine[0]
Expand Down Expand Up @@ -277,6 +287,8 @@ def apply(
A list of `PandasOnDaskDataframeVirtualPartition` objects.
"""
if not self.full_axis:
# If this is not a full axis partition, it already contains a subset of
# the full axis, so we shouldn't split the result further.
num_splits = 1
if len(self.call_queue) > 0:
self.drain_call_queue()
Expand All @@ -286,6 +298,7 @@ def apply(
if self.full_axis:
return result
else:
# If this is a full axis partition, just take out the single split in the result.
return result[0]

def force_materialization(self, get_ip=False):
Expand Down Expand Up @@ -381,15 +394,24 @@ def width(self):
self._width_cache = self.list_of_partitions_to_combine[0].width()
return self._width_cache

def drain_call_queue(self):
"""Execute all operations stored in this partition's call queue."""
def drain_call_queue(self, num_splits=None):
"""
Execute all operations stored in this partition's call queue.
Parameters
----------
num_splits : int, default: None
The number of times to split the result object.
"""

def drain(df):
for func, args, kwargs in self.call_queue:
df = func(df, *args, **kwargs)
return df

drained = super(PandasOnDaskDataframeVirtualPartition, self).apply(drain)
drained = super(PandasOnDaskDataframeVirtualPartition, self).apply(
drain, num_splits=num_splits
)
self.list_of_partitions_to_combine = drained
self.call_queue = []

Expand Down

0 comments on commit 6784b4b

Please sign in to comment.