Skip to content

Commit f0becdc

Browse files
zhengruifengdongjoon-hyun
authored andcommitted
[SPARK-54144][PYTHON] Short Circuit Eval Type Inferences
### What changes were proposed in this pull request? Short Circuit Eval Type Inferences: ### Why are the changes needed? minor optimization that avoid unnecessary inference ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No Closes #52843 from zhengruifeng/short_circuit_type_infer. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]> (cherry picked from commit 6e4936d) Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 23afe0d commit f0becdc

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

python/pyspark/sql/pandas/typehints.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ def infer_group_arrow_eval_type(
353353
return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
354354
)
355355
)
356+
if is_iterator_batch:
357+
return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
358+
356359
# Tuple[pa.Scalar, ...], Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
357360
is_iterator_batch_with_keys = (
358361
len(parameters_sig) == 2
@@ -364,19 +367,21 @@ def infer_group_arrow_eval_type(
364367
return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
365368
)
366369
)
367-
368-
if is_iterator_batch or is_iterator_batch_with_keys:
370+
if is_iterator_batch_with_keys:
369371
return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
370372

371373
# pa.Table -> pa.Table
372374
is_table = (
373375
len(parameters_sig) == 1 and parameters_sig[0] == pa.Table and return_annotation == pa.Table
374376
)
377+
if is_table:
378+
return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
379+
375380
# Tuple[pa.Scalar, ...], pa.Table -> pa.Table
376381
is_table_with_keys = (
377382
len(parameters_sig) == 2 and parameters_sig[1] == pa.Table and return_annotation == pa.Table
378383
)
379-
if is_table or is_table_with_keys:
384+
if is_table_with_keys:
380385
return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
381386

382387
return None
@@ -441,6 +446,9 @@ def infer_group_pandas_eval_type(
441446
return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
442447
)
443448
)
449+
if is_iterator_dataframe:
450+
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
451+
444452
# Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
445453
is_iterator_dataframe_with_keys = (
446454
len(parameters_sig) == 2
@@ -452,8 +460,7 @@ def infer_group_pandas_eval_type(
452460
return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
453461
)
454462
)
455-
456-
if is_iterator_dataframe or is_iterator_dataframe_with_keys:
463+
if is_iterator_dataframe_with_keys:
457464
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
458465

459466
# pd.DataFrame -> pd.DataFrame
@@ -462,13 +469,16 @@ def infer_group_pandas_eval_type(
462469
and parameters_sig[0] == pd.DataFrame
463470
and return_annotation == pd.DataFrame
464471
)
472+
if is_dataframe:
473+
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
474+
465475
# Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
466476
is_dataframe_with_keys = (
467477
len(parameters_sig) == 2
468478
and parameters_sig[1] == pd.DataFrame
469479
and return_annotation == pd.DataFrame
470480
)
471-
if is_dataframe or is_dataframe_with_keys:
481+
if is_dataframe_with_keys:
472482
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
473483

474484
return None
@@ -512,11 +522,9 @@ def check_iterator_annotation(
512522
def check_union_annotation(
513523
annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
514524
) -> bool:
515-
import typing
516-
517525
# Note that we cannot rely on '__origin__' in other type hints as it has changed from version
518526
# to version.
519527
origin = getattr(annotation, "__origin__", None)
520-
return origin == typing.Union and (
528+
return origin == Union and (
521529
parameter_check_func is None or all(map(parameter_check_func, annotation.__args__))
522530
)

0 commit comments

Comments
 (0)