From c30f6a868731ad21f259b5c3419bfe3e0f42dcea Mon Sep 17 00:00:00 2001
From: Kev Wang
Date: Wed, 18 Dec 2024 18:31:42 -0800
Subject: [PATCH] fix(udf): udf call with empty table and batch size (#3604)

---
 daft/udf.py                   |  2 +-
 tests/expressions/test_udf.py | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/daft/udf.py b/daft/udf.py
index 36e841c683..3efad69dc2 100644
--- a/daft/udf.py
+++ b/daft/udf.py
@@ -125,7 +125,7 @@ def get_args_for_slice(start: int, end: int):
 
         return args, kwargs
 
-    if batch_size is None:
+    if batch_size is None or len(evaluated_expressions[0]) <= batch_size:
         args, kwargs = get_args_for_slice(0, len(evaluated_expressions[0]))
         try:
             results = [func(*args, **kwargs)]
diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py
index 0fad641939..ff43df639e 100644
--- a/tests/expressions/test_udf.py
+++ b/tests/expressions/test_udf.py
@@ -423,3 +423,19 @@ def noop(data):
 
     with pytest.raises(OverflowError):
         table.eval_expression_list([noop.override_options(batch_size=-1)(col("a"))])
+
+
+@pytest.mark.parametrize("batch_size", [None, 1, 2])
+@pytest.mark.parametrize("use_actor_pool", [False, True])
+def test_udf_empty(batch_size, use_actor_pool):
+    df = daft.from_pydict({"a": []})
+
+    @udf(return_dtype=DataType.int64(), batch_size=batch_size)
+    def identity(data):
+        return data
+
+    if use_actor_pool:
+        identity = identity.with_concurrency(2)
+
+    result = df.select(identity(col("a")))
+    assert result.to_pydict() == {"a": []}