Skip to content

Commit

Permalink
Refactor pipeline mappable (#272)
Browse files (browse the repository at this point in the history)
* Change Pipeline to have a single Stage

* Inline Stage (in Pipeline)

* Fix CoiledFunctionsDagExecutor

* Fix formatting
  • Loading branch information
tomwhite authored Jul 25, 2023
1 parent e42f480 commit 834d2ee
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 151 deletions.
19 changes: 9 additions & 10 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from cubed.runtime.pipeline import already_computed
from cubed.storage.zarr import LazyZarrArray
from cubed.utils import chunk_memory, extract_stack_summaries, join_path, memory_repr
from cubed.vendor.rechunker.types import Stage

# A unique ID with sensible ordering, used for making directory names
CONTEXT_ID = f"cubed-{datetime.now().strftime('%Y%m%dT%H%M%S')}-{uuid.uuid4()}"
Expand Down Expand Up @@ -379,14 +378,6 @@ def create_zarr_array(lazy_zarr_array, *, config=None):


def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
stages = [
Stage(
create_zarr_array,
"create_zarr_array",
mappable=lazy_zarr_arrays,
)
]

# projected memory is size of largest initial values, or dtype size if there aren't any
projected_mem = (
max(
Expand All @@ -403,5 +394,13 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
num_tasks = len(lazy_zarr_arrays)

return CubedPipeline(
stages, None, None, projected_mem, reserved_mem, num_tasks, None
create_zarr_array,
"create_zarr_array",
lazy_zarr_arrays,
None,
None,
projected_mem,
reserved_mem,
num_tasks,
None,
)
42 changes: 20 additions & 22 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten
from cubed.vendor.rechunker.types import Stage

from .types import CubedArrayProxy, CubedPipeline

Expand Down Expand Up @@ -204,14 +203,6 @@ def blockwise(
write_proxy = CubedArrayProxy(target_array, chunksize)
spec = BlockwiseSpec(block_function, func_with_kwargs, read_proxies, write_proxy)

stages = [
Stage(
apply_blockwise,
gensym("apply_blockwise"),
mappable=output_blocks,
)
]

# calculate projected memory
projected_mem = reserved_mem + extra_projected_mem
# inputs
Expand All @@ -233,7 +224,15 @@ def blockwise(
)

return CubedPipeline(
stages, spec, target_array, projected_mem, reserved_mem, num_tasks, None
apply_blockwise,
gensym("apply_blockwise"),
output_blocks,
spec,
target_array,
projected_mem,
reserved_mem,
num_tasks,
None,
)


Expand All @@ -244,8 +243,7 @@ def is_fuse_candidate(pipeline: CubedPipeline) -> bool:
"""
Return True if a pipeline is a candidate for blockwise fusion.
"""
stages = pipeline.stages
return len(stages) == 1 and stages[0].function == apply_blockwise
return pipeline.function == apply_blockwise


def can_fuse_pipelines(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> bool:
Expand All @@ -261,15 +259,7 @@ def fuse(pipeline1: CubedPipeline, pipeline2: CubedPipeline) -> CubedPipeline:

assert pipeline1.num_tasks == pipeline2.num_tasks

mappable = pipeline2.stages[0].mappable

stages = [
Stage(
apply_blockwise,
gensym("fused_apply_blockwise"),
mappable=mappable,
)
]
mappable = pipeline2.mappable

def fused_blockwise_func(out_key):
return pipeline1.config.block_function(
Expand All @@ -289,7 +279,15 @@ def fused_func(*args):
num_tasks = pipeline2.num_tasks

return CubedPipeline(
stages, spec, target_array, projected_mem, reserved_mem, num_tasks, None
apply_blockwise,
gensym("fused_apply_blockwise"),
mappable,
spec,
target_array,
projected_mem,
reserved_mem,
num_tasks,
None,
)


Expand Down
8 changes: 5 additions & 3 deletions cubed/primitive/types.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from dataclasses import dataclass
from typing import Any, Optional, Sequence
from typing import Any, Iterable, Optional

import zarr

from cubed.storage.zarr import T_ZarrArray, open_if_lazy_zarr_array
from cubed.types import T_RegularChunks
from cubed.vendor.rechunker.types import Config, Stage
from cubed.vendor.rechunker.types import Config, StageFunction


@dataclass(frozen=True)
class CubedPipeline:
"""Generalisation of rechunker ``Pipeline`` with extra attributes."""

stages: Sequence[Stage]
function: StageFunction
name: str
mappable: Iterable
config: Config
target_array: Any
projected_mem: int
Expand Down
39 changes: 22 additions & 17 deletions cubed/runtime/executors/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,28 @@ def execute_dag(


def add_to_pcoll(name, rechunker_pipeline, pcoll):
for step, stage in enumerate(rechunker_pipeline.stages):
if stage.mappable is not None:
pcoll |= stage.name >> _SingleArgumentStage(
step, stage, rechunker_pipeline.config, name
)
else:
pcoll |= stage.name >> beam.Map(
_no_arg_stage,
current=step,
fun=stage.function,
config=rechunker_pipeline.config,
)

# This prevents fusion:
# https://cloud.google.com/dataflow/docs/guides/deploying-a-pipeline#preventing-fusion
# Avoiding fusion on Dataflow is necessary to ensure that stages execute serially.
pcoll |= gensym("Reshuffle") >> beam.Reshuffle()
step = 0
stage = Stage(
rechunker_pipeline.function,
rechunker_pipeline.name,
rechunker_pipeline.mappable,
)
if stage.mappable is not None:
pcoll |= stage.name >> _SingleArgumentStage(
step, stage, rechunker_pipeline.config, name
)
else:
pcoll |= stage.name >> beam.Map(
_no_arg_stage,
current=step,
fun=stage.function,
config=rechunker_pipeline.config,
)

# This prevents fusion:
# https://cloud.google.com/dataflow/docs/guides/deploying-a-pipeline#preventing-fusion
# Avoiding fusion on Dataflow is necessary to ensure that stages execute serially.
pcoll |= gensym("Reshuffle") >> beam.Reshuffle()

pcoll |= gensym("End") >> beam.Map(lambda x: -1)
return pcoll
Expand Down
34 changes: 15 additions & 19 deletions cubed/runtime/executors/coiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,18 @@ def execute_dag(
# Note: the task graph for each stage is currently built only when the computation reaches that stage
for name, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
for stage in pipeline.stages:
if stage.mappable is not None:
futures = []
for m in stage.mappable:
future_func = exec_stage_func(
stage.function, m, coiled_kwargs, config=pipeline.config
)
futures.append(future_func)
else:
raise NotImplementedError()

# gather the results of the coiled functions
ac = as_completed(futures)
if callbacks is not None:
for future in ac:
result, stats = future.result()
if name is not None:
stats["array_name"] = name
handle_callbacks(callbacks, stats)
futures = []
for m in pipeline.mappable:
future_func = exec_stage_func(
pipeline.function, m, coiled_kwargs, config=pipeline.config
)
futures.append(future_func)

# gather the results of the coiled functions
ac = as_completed(futures)
if callbacks is not None:
for future in ac:
result, stats = future.result()
if name is not None:
stats["array_name"] = name
handle_callbacks(callbacks, stats)
21 changes: 7 additions & 14 deletions cubed/runtime/executors/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,11 @@ def execute_dag(
# Note: the task graph for each stage is currently built only when the computation reaches that stage
for name, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
for stage in pipeline.stages:
if stage.mappable is not None:
stage_delayed_funcs = []
for m in stage.mappable:
delayed_func = exec_stage_func(
stage.function, m, config=pipeline.config
)
stage_delayed_funcs.append(delayed_func)
else:
delayed_func = exec_stage_func(
stage.function, config=pipeline.config
)
stage_delayed_funcs = [delayed_func]
stage_delayed_funcs = []
for m in pipeline.mappable:
delayed_func = exec_stage_func(
pipeline.function, m, config=pipeline.config
)
stage_delayed_funcs.append(delayed_func)

dask.compute(*stage_delayed_funcs, **compute_kwargs)
dask.compute(*stage_delayed_funcs, **compute_kwargs)
35 changes: 16 additions & 19 deletions cubed/runtime/executors/lithops.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,33 +185,30 @@ def execute_dag(
if not compute_arrays_in_parallel:
for name, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]
for stage in pipeline.stages:
if stage.mappable is not None:
for _, stats in map_unordered(
executor,
[run_func],
[stage.mappable],
[name],
func=stage.function,
config=pipeline.config,
name=name,
use_backups=use_backups,
return_stats=True,
):
handle_callbacks(callbacks, stats)
else:
raise NotImplementedError()
for _, stats in map_unordered(
executor,
[run_func],
[pipeline.mappable],
[name],
func=pipeline.function,
config=pipeline.config,
name=name,
use_backups=use_backups,
return_stats=True,
):
handle_callbacks(callbacks, stats)
else:
for gen in visit_node_generations(dag, resume=resume):
group_map_functions = []
group_map_iterdata = []
group_names = []
for name, node in gen:
pipeline = node["pipeline"]
stage = pipeline.stages[0] # assume one
f = partial(run_func, func=stage.function, config=pipeline.config)
f = partial(
run_func, func=pipeline.function, config=pipeline.config
)
group_map_functions.append(f)
group_map_iterdata.append(stage.mappable)
group_map_iterdata.append(pipeline.mappable)
group_names.append(name)
for _, stats in map_unordered(
executor,
Expand Down
23 changes: 9 additions & 14 deletions cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,15 @@ def execute_dag(
raise ValueError(f"Unrecognized cloud: {cloud}")
for name, node in visit_nodes(dag, resume=resume):
pipeline = node["pipeline"]

for stage in pipeline.stages:
if stage.mappable is not None:
task_create_tstamp = time.time()
for _, stats in app_function.map(
stage.mappable,
order_outputs=False,
kwargs=dict(func=stage.function, config=pipeline.config),
):
stats["array_name"] = name
stats["task_create_tstamp"] = task_create_tstamp
handle_callbacks(callbacks, stats)
else:
raise NotImplementedError()
task_create_tstamp = time.time()
for _, stats in app_function.map(
pipeline.mappable,
order_outputs=False,
kwargs=dict(func=pipeline.function, config=pipeline.config),
):
stats["array_name"] = name
stats["task_create_tstamp"] = task_create_tstamp
handle_callbacks(callbacks, stats)


class ModalDagExecutor(DagExecutor):
Expand Down
8 changes: 2 additions & 6 deletions cubed/runtime/executors/modal_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,18 @@ async def map_unordered(


def pipeline_to_stream(app_function, name, pipeline, **kwargs):
if any([stage for stage in pipeline.stages if stage.mappable is None]):
raise NotImplementedError("All stages must be mappable in pipelines")
it = stream.iterate(
[
partial(
map_unordered,
app_function,
stage.mappable,
pipeline.mappable,
return_stats=True,
name=name,
func=stage.function,
func=pipeline.function,
config=pipeline.config,
**kwargs,
)
for stage in pipeline.stages
if stage.mappable is not None
]
)
# concat stages, running only one stage at a time
Expand Down
17 changes: 5 additions & 12 deletions cubed/runtime/executors/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,8 @@ def execute_dag(
) -> None:
for name, node in visit_nodes(dag, resume=resume):
pipeline: CubedPipeline = node["pipeline"]
for stage in pipeline.stages:
if stage.mappable is not None:
for m in stage.mappable:
exec_stage_func(stage.function, m, config=pipeline.config)
if callbacks is not None:
event = TaskEndEvent(array_name=name)
[callback.on_task_end(event) for callback in callbacks]
else:
exec_stage_func(stage.function, config=pipeline.config)
if callbacks is not None:
event = TaskEndEvent(array_name=name)
[callback.on_task_end(event) for callback in callbacks]
for m in pipeline.mappable:
exec_stage_func(pipeline.function, m, config=pipeline.config)
if callbacks is not None:
event = TaskEndEvent(array_name=name)
[callback.on_task_end(event) for callback in callbacks]
Loading

0 comments on commit 834d2ee

Please sign in to comment.