[SPARK-51118][PYTHON] Fix ExtractPythonUDFs to check the chained UDF input types for fallback #50341

Closed
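To illustrate the scenario this PR targets, here is a minimal sketch (illustrative UDF names; `ExamplePoint` and `ExamplePointUDT` come from `pyspark.testing.sqlutils`, as in the tests below). An Arrow-optimized Python UDF whose input or return type contains a UDT falls back to non-Arrow execution, so when two such UDFs are chained, the planner must compare their *corrected* eval types rather than the declared ones:

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([Row(label=1.0)])

# Declared as Arrow-optimized, but returns a UDT, so it falls back to
# non-Arrow (batched) execution.
@udf(returnType=ExamplePointUDT(), useArrow=True)
def to_point(d):
    return ExamplePoint(d, 10 * d)

# Declared as Arrow-optimized, but takes a UDT input, so it also falls back.
@udf(returnType=DoubleType(), useArrow=True)
def point_y(p):
    return p.y

# Chained: both UDFs share the same corrected eval type (SQL_BATCHED_UDF),
# so they can be chained and evaluated together in Python.
df.select(point_y(to_point(df.label))).show()
```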
102 changes: 101 additions & 1 deletion python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -21,14 +21,17 @@
 from pyspark.sql import Row
 from pyspark.sql.functions import udf
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
-from pyspark.sql.types import VarcharType
+from pyspark.sql.types import ArrayType, DoubleType, VarcharType
 from pyspark.testing.sqlutils import (
+    ExamplePoint,
+    ExamplePointUDT,
     have_pandas,
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
     ReusedSQLTestCase,
 )
+from pyspark.testing.utils import assertDataFrameEqual
 from pyspark.util import PythonEvalType


@@ -214,6 +217,103 @@ def test_udf(a, b):
         with self.assertRaises(PythonException):
             self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()

+    def test_udf_with_udt(self):
+        row = Row(
+            label=1.0,
+            point=ExamplePoint(1.0, 2.0),
+            points=[ExamplePoint(4.0, 5.0), ExamplePoint(6.0, 7.0)],
+        )
+        df = self.spark.createDataFrame([row])
+
+        for use_arrow in [False, True]:
+            with self.subTest(use_arrow=use_arrow):
+
+                @udf(returnType=ExamplePointUDT(), useArrow=use_arrow)
+                def doubleInUDTOut(d):
+                    return ExamplePoint(d, 10 * d)
+
+                @udf(returnType=DoubleType(), useArrow=use_arrow)
+                def udtInDoubleOut(e):
+                    return e.y
+
+                @udf(returnType=ArrayType(ExamplePointUDT()), useArrow=use_arrow)
+                def doubleInUDTArrayOut(d):
+                    return [ExamplePoint(d + i, 10 * d + i) for i in range(2)]
+
+                @udf(returnType=DoubleType(), useArrow=use_arrow)
+                def udtArrayInDoubleOut(es):
+                    return es[-1].y
+
+                @udf(returnType=ExamplePointUDT(), useArrow=use_arrow)
+                def udtInUDTOut(e):
+                    return ExamplePoint(e.x * 10.0, e.y * 10.0)
+
+                @udf(returnType=DoubleType(), useArrow=use_arrow)
+                def doubleInDoubleOut(d):
+                    return d * 100.0
+
+                queries = [
+                    (
+                        "double -> UDT",
+                        df.select(doubleInUDTOut(df.label)),
+                        [Row(ExamplePoint(1.0, 10.0))],
+                    ),
+                    (
+                        "UDT -> double",
+                        df.select(udtInDoubleOut(df.point)),
+                        [Row(2.0)],
+                    ),
+                    (
+                        "double -> array of UDT",
+                        df.select(doubleInUDTArrayOut(df.label)),
+                        [Row([ExamplePoint(1.0, 10.0), ExamplePoint(2.0, 11.0)])],
+                    ),
+                    (
+                        "array of UDT -> double",
+                        df.select(udtArrayInDoubleOut(df.points)),
+                        [Row(7.0)],
+                    ),
+                    (
+                        "double -> UDT -> double",
+                        df.select(udtInDoubleOut(doubleInUDTOut(df.label))),
+                        [Row(10.0)],
+                    ),
+                    (
+                        "double -> UDT -> UDT",
+                        df.select(udtInUDTOut(doubleInUDTOut(df.label))),
+                        [Row(ExamplePoint(10.0, 100.0))],
+                    ),
+                    (
+                        "double -> double -> UDT",
+                        df.select(doubleInUDTOut(doubleInDoubleOut(df.label))),
+                        [Row(ExamplePoint(100.0, 1000.0))],
+                    ),
+                    (
+                        "UDT -> UDT -> double",
+                        df.select(udtInDoubleOut(udtInUDTOut(df.point))),
+                        [Row(20.0)],
+                    ),
+                    (
+                        "UDT -> UDT -> UDT",
+                        df.select(udtInUDTOut(udtInUDTOut(df.point))),
+                        [Row(ExamplePoint(100.0, 200.0))],
+                    ),
+                    (
+                        "UDT -> double -> double",
+                        df.select(doubleInDoubleOut(udtInDoubleOut(df.point))),
+                        [Row(200.0)],
+                    ),
+                    (
+                        "UDT -> double -> UDT",
+                        df.select(doubleInUDTOut(udtInDoubleOut(df.point))),
+                        [Row(ExamplePoint(2.0, 20.0))],
+                    ),
+                ]
+
+                for chain, actual, expected in queries:
+                    with self.subTest(chain=chain):
+                        assertDataFrameEqual(actual=actual, expected=expected)


 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_types.py
@@ -1077,7 +1077,7 @@ def test_udf_with_udt(self):
         udf = F.udf(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
         arrow_udf = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
-        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        self.assertEqual(2.0, df.select(arrow_udf(df.point)).first()[0])

         udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
         self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
ExtractPythonUDFs.scala
@@ -173,7 +173,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
     e.children match {
       // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => correctEvalType(e) == correctEvalType(u) && canEvaluateInPython(u)
Contributor:

I am wondering if it is possible to add a rewrite-like rule for SQL_ARROW_BATCHED_UDF <-> SQL_BATCHED_UDF conversion?

Member (Author):

I guess we can, but I feel it's too much, as this fallback should be temporary, and we should support UDTs with Arrow-optimized Python UDFs soon.

Contributor:

SG
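For reference, the conversion discussed here is the `correctEvalType` fallback added below. A rough Python paraphrase of that logic (hypothetical helper names; the UDT recursion is assumed to also cover map and struct types, which the truncated `containsUDT` match below does not show):

```python
from pyspark.sql.types import ArrayType, DataType, MapType, StructType, UserDefinedType
from pyspark.util import PythonEvalType

def contains_udt(dt: DataType) -> bool:
    # Recursively detect a UDT anywhere in a possibly nested type.
    if isinstance(dt, UserDefinedType):
        return True
    if isinstance(dt, ArrayType):
        return contains_udt(dt.elementType)
    if isinstance(dt, MapType):
        return contains_udt(dt.keyType) or contains_udt(dt.valueType)
    if isinstance(dt, StructType):
        return any(contains_udt(f.dataType) for f in dt.fields)
    return False

def corrected_eval_type(eval_type, return_type, input_types):
    # Demote an Arrow-batched UDF to a plain batched UDF when a UDT appears
    # in its inputs or return type; otherwise keep the declared eval type.
    if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and (
        contains_udt(return_type) or any(contains_udt(t) for t in input_types)
    ):
        return PythonEvalType.SQL_BATCHED_UDF
    return eval_type
```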

       // Python UDF can't be evaluated directly in JVM
       case children => !children.exists(hasScalarPythonUDF)
     }
@@ -197,10 +197,10 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
     def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
       case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
           && firstVisitedScalarUDFEvalType.isEmpty =>
-        firstVisitedScalarUDFEvalType = Some(udf.evalType)
+        firstVisitedScalarUDFEvalType = Some(correctEvalType(udf))
         Seq(udf)
       case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
-          && canChainUDF(udf.evalType) =>
+          && canChainUDF(correctEvalType(udf)) =>
         Seq(udf)
       case e => e.children.flatMap(collectEvaluableUDFs)
     }
@@ -235,6 +235,18 @@
     }
   }

+  private def correctEvalType(udf: PythonUDF): Int = {
+    if (udf.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) {
+      if (containsUDT(udf.dataType) || udf.children.exists(expr => containsUDT(expr.dataType))) {
+        PythonEvalType.SQL_BATCHED_UDF
+      } else {
+        PythonEvalType.SQL_ARROW_BATCHED_UDF
+      }
+    } else {
+      udf.evalType
+    }
+  }

   private def containsUDT(dataType: DataType): Boolean = dataType match {
     case _: UserDefinedType[_] => true
     case ArrayType(elementType, _) => containsUDT(elementType)
@@ -272,33 +284,25 @@
         AttributeReference(s"pythonUDF$i", u.dataType)()
       }

-      val evalTypes = validUdfs.map(_.evalType).toSet
+      val evalTypes = validUdfs.map(correctEvalType).toSet
       if (evalTypes.size != 1) {
         throw SparkException.internalError(
           "Expected udfs have the same evalType but got different evalTypes: " +
             evalTypes.mkString(","))
       }
       val evalType = evalTypes.head
 
-      val hasUDTInput = validUdfs.exists(_.children.exists(expr => containsUDT(expr.dataType)))
-      val hasUDTReturn = validUdfs.exists(udf => containsUDT(udf.dataType))
-
       val evaluation = evalType match {
         case PythonEvalType.SQL_BATCHED_UDF =>
-          BatchEvalPython(validUdfs, resultAttrs, child)
-        case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
-          ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
-        case PythonEvalType.SQL_ARROW_BATCHED_UDF =>
-          if (hasUDTInput || hasUDTReturn) {
+          if (validUdfs.exists(_.evalType != PythonEvalType.SQL_BATCHED_UDF)) {
             // Use BatchEvalPython if UDT is detected
             logWarning(log"Arrow optimization disabled due to " +
               log"${MDC(REASON, "UDT input or return type")}. " +
               log"Falling back to non-Arrow-optimized UDF execution.")
-            BatchEvalPython(validUdfs, resultAttrs, child)
-          } else {
-            ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
           }
+          BatchEvalPython(validUdfs, resultAttrs, child)
+        case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+            | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+          ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
         case _ =>
           throw SparkException.internalError("Unexpected UDF evalType")
       }
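One way to observe the fallback end to end is to inspect the physical plan: a UDT-involving UDF declared with `useArrow=True` should appear under a `BatchEvalPython` node rather than `ArrowEvalPython` (a sketch with illustrative names; exact plan text varies by Spark version):

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([Row(label=1.0)])

@udf(returnType=ExamplePointUDT(), useArrow=True)
def to_point(d):
    return ExamplePoint(d, 10 * d)

# Expect a BatchEvalPython node (the non-Arrow fallback) rather than
# ArrowEvalPython, since the UDF's return type contains a UDT.
df.select(to_point(df.label)).explain()
```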