Parallel UDF optimizations (#211)
* Parallel UDF optimizations

- return rows instead of dicts from SQL query
- marshal UDF rows with msgpack instead of dill
- add batching to UDF results

* Fix 'processed rows' counter
dreadatour authored Aug 8, 2024
1 parent 77662ff commit 6a16e6f
Showing 9 changed files with 302 additions and 155 deletions.
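
The commit message's second bullet — marshalling UDF rows with msgpack instead of dill — concerns dispatch code that is not in the visible hunks below. As a rough illustration only (not DataChain's actual wire format), the trade-off looks like this: msgpack encodes plain row values compactly and language-neutrally, which is why rows become simple sequences, while dill pickles arbitrary Python objects at the cost of Python-only payloads.

# Illustrative comparison, not the code from this commit: msgpack handles plain
# row values (ints, floats, strings, lists); dill pickles arbitrary objects.
import dill
import msgpack

row = (42, "images/cat.jpg", 2048, 0.75)  # hypothetical row of plain values

packed = msgpack.packb(list(row))
print(len(packed), msgpack.unpackb(packed))

pickled = dill.dumps(row)
print(len(pickled), dill.loads(pickled))
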
22 changes: 8 additions & 14 deletions src/datachain/data_storage/warehouse.py
@@ -17,7 +17,7 @@

from datachain.client import Client
from datachain.data_storage.serializer import Serializable
from datachain.dataset import DatasetRecord, RowDict
from datachain.dataset import DatasetRecord
from datachain.node import DirType, DirTypeGroup, Entry, Node, NodeWithPath, get_path
from datachain.sql.functions import path as pathfunc
from datachain.sql.types import Int, SQLType
@@ -201,23 +201,17 @@ def dataset_row_cls(self) -> type["DataTable"]:
def dataset_select_paginated(
self,
query,
limit: Optional[int] = None,
order_by: tuple["ColumnElement[Any]", ...] = (),
page_size: int = SELECT_BATCH_SIZE,
) -> Generator[RowDict, None, None]:
) -> Generator[Sequence, None, None]:
"""
This is equivalent to `db.execute`, but for selecting rows in batches
"""
cols = query.selected_columns
cols_names = [c.name for c in cols]
limit = query._limit
paginated_query = query.limit(page_size)

if not order_by:
ordering = [cols.sys__id]
else:
ordering = order_by # type: ignore[assignment]

# reset query order by and apply new order by id
paginated_query = query.order_by(None).order_by(*ordering).limit(page_size)
if not paginated_query._order_by_clauses:
# default order by is order by `sys__id`
paginated_query = paginated_query.order_by(query.selected_columns.sys__id)

results = None
offset = 0
@@ -236,7 +230,7 @@ def dataset_select_paginated(
processed = False
for row in results:
processed = True
yield RowDict(zip(cols_names, row))
yield row
num_yielded += 1

if not processed:
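
The dataset_select_paginated change boils down to: yield raw row sequences instead of RowDicts, keep any ORDER BY already on the query, and fall back to ordering by sys__id so pages stay deterministic. Below is a minimal sketch of that pagination pattern against a throwaway SQLite table, not the DataChain warehouse implementation.

# Minimal sketch of offset-based pagination with a deterministic fallback order
# (toy SQLite table; table and data are invented for illustration).
import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
items = sa.Table(
    "items",
    metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("name", sa.String),
)
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(items.insert(), [{"name": f"row-{i}"} for i in range(10)])

def select_paginated(conn, query, page_size=4):
    paginated = query.limit(page_size)
    if not paginated._order_by_clauses:
        # default order by `sys__id` keeps pages stable and non-overlapping
        paginated = paginated.order_by(query.selected_columns.sys__id)
    offset = 0
    while True:
        rows = conn.execute(paginated.offset(offset)).fetchall()
        if not rows:
            break
        yield from rows  # plain row tuples, no per-row dict construction
        offset += page_size

with engine.connect() as conn:
    for row in select_paginated(conn, sa.select(items)):
        print(tuple(row))
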
35 changes: 21 additions & 14 deletions src/datachain/lib/udf.py
@@ -1,6 +1,5 @@
import sys
import traceback
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING, Callable, Optional

from fsspec.callbacks import DEFAULT_CALLBACK, Callback
@@ -14,16 +13,19 @@
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf_signature import UdfSignature
from datachain.lib.utils import AbstractUDF, DataChainError, DataChainParamsError
from datachain.query.batch import RowBatch
from datachain.query.batch import UDFInputBatch
from datachain.query.schema import ColumnParameter
from datachain.query.udf import UDFBase as _UDFBase
from datachain.query.udf import UDFProperties, UDFResult
from datachain.query.udf import UDFProperties

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence

from typing_extensions import Self

from datachain.catalog import Catalog
from datachain.query.batch import BatchingResult
from datachain.query.batch import RowsOutput, UDFInput
from datachain.query.udf import UDFResult


class UdfError(DataChainParamsError):
@@ -42,35 +44,40 @@ def __init__(

def run(
self,
udf_inputs: "Iterable[BatchingResult]",
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[RowsOutput]",
catalog: "Catalog",
is_generator: bool,
cache: bool,
download_cb: Callback = DEFAULT_CALLBACK,
processed_cb: Callback = DEFAULT_CALLBACK,
) -> Iterator[Iterable["UDFResult"]]:
) -> "Iterator[Iterable[UDFResult]]":
self.inner._catalog = catalog
if hasattr(self.inner, "setup") and callable(self.inner.setup):
self.inner.setup()

for batch in udf_inputs:
n_rows = len(batch.rows) if isinstance(batch, RowBatch) else 1
output = self.run_once(catalog, batch, is_generator, cache, cb=download_cb)
processed_cb.relative_update(n_rows)
yield output
yield from super().run(
udf_fields,
udf_inputs,
catalog,
is_generator,
cache,
download_cb,
processed_cb,
)

if hasattr(self.inner, "teardown") and callable(self.inner.teardown):
self.inner.teardown()

def run_once(
self,
catalog: "Catalog",
arg: "BatchingResult",
arg: "UDFInput",
is_generator: bool = False,
cache: bool = False,
cb: Callback = DEFAULT_CALLBACK,
) -> Iterable[UDFResult]:
if isinstance(arg, RowBatch):
) -> "Iterable[UDFResult]":
if isinstance(arg, UDFInputBatch):
udf_inputs = [
self.bind_parameters(catalog, row, cache=cache, cb=cb)
for row in arg.rows
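
Net effect of the udf.py changes: `run` now receives the column names once (`udf_fields`) along with raw row sequences and defers to the shared base-class loop, so per-row dicts are only rebuilt where a UDF needs named access. A hedged sketch of that fields-plus-rows pattern follows; the field names and UDF body are made up.

# Sketch of passing column names once and rebuilding dicts per row on demand
# (hypothetical fields and UDF logic, not DataChain's UDF dispatcher).
from collections.abc import Iterable, Iterator, Sequence

def run_udf(udf_fields: Sequence[str], rows: Iterable[Sequence]) -> Iterator[str]:
    for row in rows:
        row_dict = dict(zip(udf_fields, row))  # names + values -> mapping
        yield row_dict["path"].upper()         # hypothetical UDF body

fields = ["sys__id", "path"]
print(list(run_udf(fields, [(1, "cat.jpg"), (2, "dog.jpg")])))  # ['CAT.JPG', 'DOG.JPG']
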
86 changes: 45 additions & 41 deletions src/datachain/query/batch.py
@@ -5,21 +5,29 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Union

import sqlalchemy as sa

from datachain.data_storage.schema import PARTITION_COLUMN_ID
from datachain.data_storage.warehouse import SELECT_BATCH_SIZE

if TYPE_CHECKING:
from sqlalchemy import Select

from datachain.dataset import RowDict


@dataclass
class RowBatch:
class RowsOutputBatch:
rows: Sequence[Sequence]


RowsOutput = Union[Sequence, RowsOutputBatch]


@dataclass
class UDFInputBatch:
rows: Sequence["RowDict"]


BatchingResult = Union["RowDict", RowBatch]
UDFInput = Union["RowDict", UDFInputBatch]


class BatchingStrategy(ABC):
@@ -28,9 +36,9 @@ class BatchingStrategy(ABC):
@abstractmethod
def __call__(
self,
execute: Callable,
query: sa.sql.selectable.Select,
) -> Generator[BatchingResult, None, None]:
execute: Callable[..., Generator[Sequence, None, None]],
query: "Select",
) -> Generator[RowsOutput, None, None]:
"""Apply the provided parameters to the UDF."""


@@ -42,10 +50,10 @@ class NoBatching(BatchingStrategy):

def __call__(
self,
execute: Callable,
query: sa.sql.selectable.Select,
) -> Generator["RowDict", None, None]:
return execute(query, limit=query._limit, order_by=query._order_by_clauses)
execute: Callable[..., Generator[Sequence, None, None]],
query: "Select",
) -> Generator[Sequence, None, None]:
return execute(query)


class Batch(BatchingStrategy):
@@ -59,31 +67,24 @@ def __init__(self, count: int):

def __call__(
self,
execute: Callable,
query: sa.sql.selectable.Select,
) -> Generator[RowBatch, None, None]:
execute: Callable[..., Generator[Sequence, None, None]],
query: "Select",
) -> Generator[RowsOutputBatch, None, None]:
# choose page size that is a multiple of the batch size
page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count

# select rows in batches
results: list[RowDict] = []

with contextlib.closing(
execute(
query,
page_size=page_size,
limit=query._limit,
order_by=query._order_by_clauses,
)
) as rows:
results: list[Sequence] = []

with contextlib.closing(execute(query, page_size=page_size)) as rows:
for row in rows:
results.append(row)
if len(results) >= self.count:
batch, results = results[: self.count], results[self.count :]
yield RowBatch(batch)
yield RowsOutputBatch(batch)

if len(results) > 0:
yield RowBatch(results)
yield RowsOutputBatch(results)


class Partition(BatchingStrategy):
@@ -95,27 +96,30 @@ class Partition(BatchingStrategy):

def __call__(
self,
execute: Callable,
query: sa.sql.selectable.Select,
) -> Generator[RowBatch, None, None]:
execute: Callable[..., Generator[Sequence, None, None]],
query: "Select",
) -> Generator[RowsOutputBatch, None, None]:
current_partition: Optional[int] = None
batch: list[RowDict] = []

with contextlib.closing(
execute(
query,
order_by=(PARTITION_COLUMN_ID, "sys__id", *query._order_by_clauses),
limit=query._limit,
)
) as rows:
batch: list[Sequence] = []

query_fields = [str(c.name) for c in query.selected_columns]
partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)

ordered_query = query.order_by(None).order_by(
PARTITION_COLUMN_ID,
"sys__id",
*query._order_by_clauses,
)

with contextlib.closing(execute(ordered_query)) as rows:
for row in rows:
partition = row[PARTITION_COLUMN_ID]
partition = row[partition_column_idx]
if current_partition != partition:
current_partition = partition
if len(batch) > 0:
yield RowBatch(batch)
yield RowsOutputBatch(batch)
batch = []
batch.append(row)

if len(batch) > 0:
yield RowBatch(batch)
yield RowsOutputBatch(batch)
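
The three strategies above differ only in how they group the row stream: NoBatching passes rows through one by one, Batch emits fixed-size RowsOutputBatch chunks, and Partition starts a new batch whenever the partition column changes (relying on the partition-ordered query). A compact, self-contained sketch of the two grouping behaviours, using plain lists instead of the RowsOutputBatch dataclass:

# Standalone sketch of fixed-size batching and partition-boundary batching over
# a stream of row tuples (illustrative, not the DataChain classes themselves).
from collections.abc import Iterable, Iterator, Sequence

def batch_rows(rows: Iterable[Sequence], count: int) -> Iterator[list[Sequence]]:
    batch: list[Sequence] = []
    for row in rows:
        batch.append(row)
        if len(batch) >= count:
            yield batch
            batch = []
    if batch:
        yield batch

def partition_rows(rows: Iterable[Sequence], partition_idx: int) -> Iterator[list[Sequence]]:
    current = object()  # sentinel that never equals a real partition value
    batch: list[Sequence] = []
    for row in rows:
        if row[partition_idx] != current:
            current = row[partition_idx]
            if batch:
                yield batch
                batch = []
        batch.append(row)
    if batch:
        yield batch

rows = [(1, "a"), (1, "b"), (2, "c"), (2, "d"), (2, "e")]
print(list(batch_rows(rows, 2)))      # [[(1,'a'),(1,'b')], [(2,'c'),(2,'d')], [(2,'e')]]
print(list(partition_rows(rows, 0)))  # [[(1,'a'),(1,'b')], [(2,'c'),(2,'d'),(2,'e')]]
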
19 changes: 13 additions & 6 deletions src/datachain/query/dataset.py
@@ -461,6 +461,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

processes = determine_processes(self.parallel)

udf_fields = [str(c.name) for c in query.selected_columns]

try:
if workers:
from datachain.catalog.loader import get_distributed_class
@@ -473,6 +475,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
query,
workers,
processes,
udf_fields=udf_fields,
is_generator=self.is_generator,
use_partitioning=use_partitioning,
cache=self.cache,
@@ -489,6 +492,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
"warehouse_clone_params": self.catalog.warehouse.clone_params(),
"table": udf_table,
"query": query,
"udf_fields": udf_fields,
"batching": batching,
"processes": processes,
"is_generator": self.is_generator,
@@ -528,6 +532,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
generated_cb = get_generated_callback(self.is_generator)
try:
udf_results = udf.run(
udf_fields,
udf_inputs,
self.catalog,
self.is_generator,
@@ -1244,21 +1249,23 @@ def extract(
actual_params = [normalize_param(p) for p in params]
try:
query = self.apply_steps().select()
query_fields = [str(c.name) for c in query.selected_columns]

def row_iter() -> Generator[RowDict, None, None]:
def row_iter() -> Generator[Sequence, None, None]:
# warehouse isn't threadsafe, we need to clone() it
# in the thread that uses the results
with self.catalog.warehouse.clone() as warehouse:
gen = warehouse.dataset_select_paginated(
query, limit=query._limit, order_by=query._order_by_clauses
)
gen = warehouse.dataset_select_paginated(query)
with contextlib.closing(gen) as rows:
yield from rows

async def get_params(row: RowDict) -> tuple:
async def get_params(row: Sequence) -> tuple:
row_dict = RowDict(zip(query_fields, row))
return tuple(
[
await p.get_value_async(self.catalog, row, mapper, **kwargs)
await p.get_value_async(
self.catalog, row_dict, mapper, **kwargs
)
for p in actual_params
]
)
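
The recurring `[str(c.name) for c in query.selected_columns]` expression is how both populate_udf_table and extract derive udf_fields/query_fields: column names are read off the SELECT once, then zipped with each raw row wherever a mapping is needed. A tiny sketch (table and column names are invented):

# Where the field names come from: the selected columns of a SQLAlchemy Select
# (illustrative table; no database connection required to inspect the query).
import sqlalchemy as sa

metadata = sa.MetaData()
files = sa.Table(
    "files",
    metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("path", sa.String),
    sa.Column("size", sa.Integer),
)

query = sa.select(files.c.sys__id, files.c.path, files.c.size)
query_fields = [str(c.name) for c in query.selected_columns]
print(query_fields)  # ['sys__id', 'path', 'size']

row = (7, "images/cat.jpg", 2048)  # a raw row as yielded by the paginated select
print(dict(zip(query_fields, row)))
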