Skip to content

Commit

Permalink
update error msg
Browse files Browse the repository at this point in the history
  • Loading branch information
allisonwang-db committed Aug 30, 2024
1 parent 7b54d9c commit 475fb49
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ def _create_batch(self, series):
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
"A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s))
"Invalid return type. Please make sure that the UDF returns a "
"pandas.DataFrame when the specified return type is StructType."
)
arrs.append(self._create_struct_array(s, t))
else:
Expand Down
16 changes: 15 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit
from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType
from pyspark.errors import ParseException, PythonException, PySparkTypeError
from pyspark.errors import ParseException, PythonException, PySparkTypeError, PySparkValueError
from pyspark.util import PythonEvalType
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
Expand Down Expand Up @@ -339,6 +339,20 @@ def noop(s: pd.Series) -> pd.Series:
self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))

def test_pandas_udf_return_type_error(self):
    """Check that a struct-typed pandas UDF returning a Series fails clearly.

    The UDF is declared with the DDL return type ``"s string"`` (expected to
    be parsed as a struct with one string field) but returns a
    ``pandas.Series`` instead of a ``pandas.DataFrame``, so collecting the
    result should fail with the "Invalid return type" message.
    """
    import pandas as pd

    @pandas_udf("s string")
    def upper(s: pd.Series) -> pd.Series:
        return s.str.upper()

    df = self.spark.createDataFrame([("a",)], schema="s string")

    # The return-type check runs inside the Python worker while the result
    # batch is being serialized. The driver therefore sees the failure as a
    # PythonException wrapping the worker-side error, not as the
    # PySparkValueError raised in the worker itself.
    # The pattern omits the trailing "." so no regex metacharacter is used.
    with self.assertRaisesRegex(PythonException, "Invalid return type"):
        df.select(upper("s")).collect()


class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
    """Concrete test case combining the pandas UDF test mixin with ReusedSQLTestCase."""
Expand Down

0 comments on commit 475fb49

Please sign in to comment.