
[SPARK-51118][PYTHON] Fix ExtractPythonUDFs to check the chained UDF input types for fallback #50341

Closed

Conversation

ueshin
Member

@ueshin ueshin commented Mar 20, 2025

What changes were proposed in this pull request?

Fixes `ExtractPythonUDFs` to check the input types of chained UDFs when deciding whether to fall back to non-Arrow execution.

Why are the changes needed?

Currently, the fallback from an Arrow-optimized Python UDF to a non-Arrow UDF when the UDF has a UDT input or output only works for non-chained UDFs, because only the last UDF in the chain is checked.

For example:

```py
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT

row = Row(
    label=1.0,
    point=ExamplePoint(1.0, 2.0),
)

df = spark.createDataFrame([row])

@udf(returnType=DoubleType(), useArrow=True)
def udtInDoubleOut(e):
    return e.y

@udf(returnType=DoubleType(), useArrow=True)
def doubleInDoubleOut(d):
    return d * 100.0

df.select(doubleInDoubleOut(udtInDoubleOut(df.point))).show()
```

This doesn't fall back to non-Arrow because `doubleInDoubleOut` appears to have no UDT input or output, and it fails with:

```
pyspark.errors.exceptions.captured.PythonException:
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  ...
AttributeError: 'list' object has no attribute 'y'
```
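
For contrast, the non-chained call is handled correctly: the single UDF's UDT input triggers the fallback, so it runs on the non-Arrow path. A minimal sketch, reusing the DataFrame and UDF defined above:

```py
# Non-chained: the UDT input of udtInDoubleOut is detected, so this call
# falls back to a regular (non-Arrow) Python UDF and succeeds.
df.select(udtInDoubleOut(df.point)).show()
# expected value: 2.0, the y coordinate of ExamplePoint(1.0, 2.0)
```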

Does this PR introduce any user-facing change?

Yes, the fallback will work with chained UDFs, too.
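
With the fix, the chained call from the example above should hit the same fallback and succeed. A sketch, again reusing the earlier DataFrame and UDFs:

```py
# After the fix, both chained UDFs fall back together to the non-Arrow path.
df.select(doubleInDoubleOut(udtInDoubleOut(df.point)).alias("result")).show()
# expected value: 200.0 (point.y = 2.0, then multiplied by 100.0)
```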

How was this patch tested?

Added the related tests.

Was this patch authored or co-authored using generative AI tooling?

No.

@ueshin ueshin marked this pull request as draft March 21, 2025 00:24
@ueshin
Member Author

ueshin commented Mar 21, 2025

I'll change the implementation.

@ueshin ueshin requested a review from zhengruifeng March 21, 2025 02:30
@ueshin ueshin marked this pull request as ready for review March 21, 2025 03:11
```diff
@@ -173,7 +173,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
     e.children match {
       // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => correctEvalType(e) == correctEvalType(u) && canEvaluateInPython(u)
```
Contributor

I am wondering if it is possible to add a rewrite-like rule for SQL_ARROW_BATCHED_UDF <-> SQL_BATCHED_UDF conversion?

Member Author

I guess we can, but I feel it's too much, as this fallback should be temporary, and we should support UDTs with Arrow-optimized Python UDFs soon.
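
In the meantime, a possible user-side workaround (not part of this PR; the `*NoArrow` names below are hypothetical) is to opt out of Arrow optimization for UDFs whose inputs or outputs involve a UDT:

```py
# Hypothetical workaround, not from this PR: declare the UDFs with
# useArrow=False so the UDT-consuming chain runs as plain pickled Python UDFs.
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

@udf(returnType=DoubleType(), useArrow=False)
def udtInDoubleOutNoArrow(e):
    return e.y

@udf(returnType=DoubleType(), useArrow=False)
def doubleInDoubleOutNoArrow(d):
    return d * 100.0

# df is the DataFrame built in the PR description example above.
df.select(doubleInDoubleOutNoArrow(udtInDoubleOutNoArrow(df.point))).show()
```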

Contributor

SG

@HyukjinKwon
Member

HyukjinKwon commented Mar 24, 2025

Merged to master and branch-4.0.

HyukjinKwon pushed a commit that referenced this pull request Mar 24, 2025
[SPARK-51118][PYTHON] Fix ExtractPythonUDFs to check the chained UDF input types for fallback

### What changes were proposed in this pull request?

Fixes `ExtractPythonUDFs` to check the chained UDF input types for fallback.

### Why are the changes needed?

Currently, the fallback from an Arrow-optimized Python UDF to a non-Arrow UDF when the UDF has a UDT input or output only works for non-chained UDFs, because only the last UDF in the chain is checked.

For example:

```py
from pyspark.sql.functions import udf
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT

row = Row(
    label=1.0,
    point=ExamplePoint(1.0, 2.0),
)

df = spark.createDataFrame([row])

@udf(returnType=DoubleType(), useArrow=True)
def udtInDoubleOut(e):
    return e.y

@udf(returnType=DoubleType(), useArrow=True)
def doubleInDoubleOut(d):
    return d * 100.0

df.select(doubleInDoubleOut(udtInDoubleOut(df.point))).show()
```

This doesn't fall back to non-Arrow because `doubleInDoubleOut` appears to have no UDT input or output, and it fails with:

```
pyspark.errors.exceptions.captured.PythonException:
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  ...
AttributeError: 'list' object has no attribute 'y'
```

### Does this PR introduce _any_ user-facing change?

Yes, the fallback will work with chained UDFs, too.

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50341 from ueshin/issues/SPARK-51118/chained_udf_with_udt.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 4e30f96)
Signed-off-by: Hyukjin Kwon <[email protected]>
SauronShepherd pushed a commit to SauronShepherd/spark that referenced this pull request Mar 25, 2025
[SPARK-51118][PYTHON] Fix ExtractPythonUDFs to check the chained UDF input types for fallback
