
implement locality awareness
EC2 Default User committed Nov 21, 2024
1 parent 274f300 commit ea89052
Showing 3 changed files with 87 additions and 36 deletions.
7 changes: 7 additions & 0 deletions daft/execution/execution_step.py
@@ -54,6 +54,9 @@ class PartitionTask(Generic[PartitionT]):
# Indicates if the PartitionTask is "done" or not
is_done: bool = False

# Desired node_id to schedule this task
node_id: str | None = None

_id: int = field(default_factory=lambda: next(ID_GEN))

def id(self) -> str:
@@ -108,6 +111,7 @@ def __init__(
partial_metadatas: list[PartialPartitionMetadata] | None,
resource_request: ResourceRequest = ResourceRequest(),
actor_pool_id: str | None = None,
node_id: str | None = None,
) -> None:
self.inputs = inputs
if partial_metadatas is not None:
@@ -118,6 +122,7 @@ def __init__(
self.instructions: list[Instruction] = list()
self.num_results = len(inputs)
self.actor_pool_id = actor_pool_id
self.node_id = node_id

def add_instruction(
self,
@@ -156,6 +161,7 @@ def finalize_partition_task_single_output(self, stage_id: int) -> SingleOutputPa
resource_request=resource_request_final_cpu,
partial_metadatas=self.partial_metadatas,
actor_pool_id=self.actor_pool_id,
node_id=self.node_id,
)

def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPartitionTask[PartitionT]:
@@ -177,6 +183,7 @@ def finalize_partition_task_multi_output(self, stage_id: int) -> MultiOutputPart
resource_request=resource_request_final_cpu,
partial_metadatas=self.partial_metadatas,
actor_pool_id=self.actor_pool_id,
node_id=self.node_id,
)

def __str__(self) -> str:
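
The node_id added here is nothing more than an optional scheduling hint that the builder copies onto the finalized task. A simplified, self-contained sketch of that plumbing (ToyTask and ToyTaskBuilder are hypothetical stand-ins for PartitionTask and PartitionTaskBuilder, not Daft classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTask:
    inputs: list
    # Desired node to run this task on; None means "no preference".
    node_id: Optional[str] = None


class ToyTaskBuilder:
    def __init__(self, inputs: list, node_id: Optional[str] = None) -> None:
        self.inputs = inputs
        self.node_id = node_id

    def finalize(self) -> ToyTask:
        # The locality hint survives finalization, mirroring finalize_partition_task_*.
        return ToyTask(inputs=self.inputs, node_id=self.node_id)


task = ToyTaskBuilder(inputs=["partition-0"], node_id="node-abc").finalize()
assert task.node_id == "node-abc"
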
105 changes: 72 additions & 33 deletions daft/execution/shuffles/pre_shuffle_merge.py
@@ -1,6 +1,8 @@
import logging
from typing import Dict

import ray.experimental # noqa: TID253

from daft.daft import ResourceRequest
from daft.execution import execution_step
from daft.execution.execution_step import (
@@ -40,7 +42,7 @@ def pre_shuffle_merge(
no_more_input = False

while True:
# Get and sort materialized maps by size.
# Get materialized maps by size
materialized_maps = sorted(
[
(
@@ -58,42 +60,79 @@
done_with_input = no_more_input and len(materialized_maps) == len(in_flight_maps)

if enough_maps or done_with_input:
# Initialize the first merge group
merge_groups = []
current_group = [materialized_maps[0][0]]
current_size = materialized_maps[0][1]

# Group remaining maps based on memory threshold
for partition, size in materialized_maps[1:]:
if current_size + size > pre_shuffle_merge_threshold:
merge_groups.append(current_group)
current_group = [partition]
current_size = size
# Get location information for all materialized partitions
partitions = [m[0].result().partition() for m in materialized_maps]
location_map = ray.experimental.get_object_locations(partitions)

# Group partitions by node
node_groups = {}
unknown_location_group = [] # Special group for partitions without known location

for (partition, size) in materialized_maps:
partition_ref = partition.partition()
location_info = location_map.get(partition_ref, {})

if not location_info or 'node_ids' not in location_info or not location_info['node_ids']:
unknown_location_group.append((partition, size))
else:
current_group.append(partition)
current_size += size

# Add the last group if it exists and is either:
# 1. Contains more than 1 partition
# 2. Is the last group and we're done with input
# 3. The partition exceeds the memory threshold
if current_group:
if len(current_group) > 1 or done_with_input or current_size > pre_shuffle_merge_threshold:
merge_groups.append(current_group)
node_id = location_info['node_ids'][0] # Use first node if multiple locations exist
if node_id not in node_groups:
node_groups[node_id] = []
node_groups[node_id].append((partition, size))

# Function to create merge groups for a list of partitions
def create_merge_groups(partitions_list):
if not partitions_list:
return []

groups = []
current_group = [partitions_list[0][0]]
current_size = partitions_list[0][1]

for partition, size in partitions_list[1:]:
if current_size + size > pre_shuffle_merge_threshold:
groups.append(current_group)
current_group = [partition]
current_size = size
else:
current_group.append(partition)
current_size += size

# Flush the trailing group if it:
# 1. contains more than one partition,
# 2. is the final group and no more input is coming, or
# 3. already exceeds the memory threshold on its own
if current_group:
if len(current_group) > 1 or done_with_input or current_size > pre_shuffle_merge_threshold:
groups.append(current_group)

return groups

# Process each node's partitions and unknown location partitions
merge_groups = {}

# Process node-specific groups
for (node_id, node_partitions) in node_groups.items():
merge_groups[node_id] = create_merge_groups(node_partitions)

# Process unknown location group
merge_groups[None] = create_merge_groups(unknown_location_group)

# Create merge steps and remove processed maps
for group in merge_groups:
for (node_id, groups) in merge_groups.items():
# Remove processed maps from in_flight_maps
for partition in group:
del in_flight_maps[partition.id()]

total_size = sum(m.partition_metadata().size_bytes or 0 for m in group)
merge_step = PartitionTaskBuilder[PartitionT](
inputs=[p.partition() for p in group],
partial_metadatas=[m.partition_metadata() for m in group],
resource_request=ResourceRequest(memory_bytes=total_size),
).add_instruction(instruction=execution_step.ReduceMerge())
yield merge_step
for group in groups:
for partition in group:
del in_flight_maps[partition.id()]
logger.debug(f"Scheduling merge step for {len(group)} partitions on node {node_id}")
total_size = sum(m.partition_metadata().size_bytes or 0 for m in group)
merge_step = PartitionTaskBuilder[PartitionT](
inputs=[p.partition() for p in group],
partial_metadatas=[m.partition_metadata() for m in group],
resource_request=ResourceRequest(memory_bytes=total_size * 2),
node_id=node_id,
).add_instruction(instruction=execution_step.ReduceMerge())
yield merge_step

# Process next map task if available
try:
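
The grouping above hinges on ray.experimental.get_object_locations, which reports which node(s) currently hold each object reference. A standalone sketch of the same bucket-by-node, then pack-under-threshold idea using only public Ray APIs (the threshold, object sizes, and group counts are arbitrary example values, not Daft defaults):

import numpy as np
import ray
import ray.experimental

ray.init(ignore_reinit_error=True)

THRESHOLD_BYTES = 32 * 1024 * 1024  # stand-in for pre_shuffle_merge_threshold

# Materialize a few objects in the cluster's object store.
refs = [ray.put(np.zeros(4 * 1024 * 1024, dtype=np.uint8)) for _ in range(8)]
locations = ray.experimental.get_object_locations(refs)

# Bucket refs by the node that holds them; None collects refs with no known location.
by_node = {}
for ref in refs:
    info = locations.get(ref, {})
    node_ids = info.get("node_ids") or []
    node_id = node_ids[0] if node_ids else None
    by_node.setdefault(node_id, []).append((ref, info.get("object_size") or 0))

# Within each node bucket, greedily pack refs into groups under the size threshold.
merge_groups = {}
for node_id, items in by_node.items():
    groups, current, current_size = [], [], 0
    for ref, size in items:
        if current and current_size + size > THRESHOLD_BYTES:
            groups.append(current)
            current, current_size = [], 0
        current.append(ref)
        current_size += size
    if current:
        groups.append(current)
    merge_groups[node_id] = groups

print({node: [len(g) for g in gs] for node, gs in merge_groups.items()})
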
11 changes: 8 additions & 3 deletions daft/runners/ray_runner.py
@@ -13,7 +13,8 @@
# The ray runner is not a top-level module, so we don't need to lazily import pyarrow to minimize
# import times. If this changes, we first need to make the daft.lazy_import.LazyImport class
# serializable before importing pa from daft.dependencies.
import pyarrow as pa # noqa: TID253
import pyarrow as pa
import ray.experimental # noqa: TID253

from daft.arrow_utils import ensure_array
from daft.context import execution_config_ctx, get_context
@@ -517,7 +518,7 @@ def fanout_pipeline(


@ray_tracing.ray_remote_traced
@ray.remote(scheduling_strategy="SPREAD")
@ray.remote
def reduce_pipeline(
task_context: PartitionTaskContext,
daft_execution_config: PyDaftExecutionConfig,
@@ -534,7 +535,7 @@


@ray_tracing.ray_remote_traced
@ray.remote(scheduling_strategy="SPREAD")
@ray.remote
def reduce_and_fanout(
task_context: PartitionTaskContext,
daft_execution_config: PyDaftExecutionConfig,
@@ -1016,6 +1017,10 @@ def _build_partitions(
if task.instructions and isinstance(task.instructions[-1], FanoutInstruction)
else reduce_pipeline
)
if task.node_id is not None:
ray_options["scheduling_strategy"] = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(task.node_id, soft=True)
else:
ray_options["scheduling_strategy"] = "SPREAD"
build_remote = build_remote.options(**ray_options).with_tracing(runner_tracer, task)
[metadatas_ref, *partitions] = build_remote.remote(
PartitionTaskContext(job_id=job_id, task_id=task.id(), stage_id=task.stage_id),
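
On the runner side the feature reduces to choosing a Ray scheduling strategy per task: a soft NodeAffinitySchedulingStrategy when a target node is known (soft=True lets Ray fall back to another node rather than wait), otherwise the previous SPREAD behavior. A minimal standalone sketch of that pattern (the merge task and the way the node id is picked are illustrative, not the runner's actual code path):

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init(ignore_reinit_error=True)


@ray.remote
def merge(*parts):
    return sum(parts)


# In the runner this would be task.node_id; here we just grab some live node's id.
target_node_id = ray.nodes()[0]["NodeID"]

if target_node_id is not None:
    # Prefer the node that already holds the inputs, but allow rescheduling elsewhere.
    strategy = NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=True)
else:
    strategy = "SPREAD"

result_ref = merge.options(scheduling_strategy=strategy).remote(1, 2, 3)
print(ray.get(result_ref))  # 6
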
