Skip to content

Commit

Permalink
[SPARK-32793][SQL] Add raise_error function, adds error message param…
Browse files Browse the repository at this point in the history
…eter to assert_true

### What changes were proposed in this pull request?

Adds a SQL function `raise_error` which underlies the refactored `assert_true` function. `assert_true` now also (optionally) accepts a custom error message field.
`raise_error` is exposed in SQL, Python, Scala, and R.
`assert_true` was previously only exposed in SQL; it is now also exposed in Python, Scala, and R.

### Why are the changes needed?

Improves usability of `assert_true` by clarifying error messaging, and adds the useful helper function `raise_error`.

### Does this PR introduce _any_ user-facing change?

Yes:
- Adds `raise_error` function to the SQL, Python, Scala, and R APIs.
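- Adds `assert_true` function to the Python, Scala, and R APIs (it was previously exposed only in SQL), and lets `assert_true` optionally take a custom error message.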

### How was this patch tested?

Adds unit tests in SQL, Python, Scala, and R for `assert_true` and `raise_error`.

Closes apache#29947 from karenfeng/spark-32793.

Lead-authored-by: Karen Feng <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
  • Loading branch information
karenfeng and HyukjinKwon committed Oct 8, 2020
1 parent 473b3ba commit 39510b0
Show file tree
Hide file tree
Showing 18 changed files with 450 additions and 53 deletions.
49 changes: 49 additions & 0 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,55 @@ setMethod("xxhash64",
column(jc)
})

#' @details
#' \code{assert_true}: Returns null if the input column is true; throws an exception
#' with the provided error message otherwise.
#'
#' @param errMsg (optional) The error message to be thrown, given either as a single
#'               character string or as a Column.
#'
#' @rdname column_misc_functions
#' @aliases assert_true assert_true,Column-method
#' @examples
#' \dontrun{
#' tmp <- mutate(df, v1 = assert_true(df$vs < 2),
#'                   v2 = assert_true(df$vs < 2, "custom error message"),
#'                   v3 = assert_true(df$vs < 2, df$vs))
#' head(tmp)}
#' @note assert_true since 3.1.0
setMethod("assert_true",
          signature(x = "Column"),
          function(x, errMsg = NULL) {
            jc <- if (is.null(errMsg)) {
              # No message supplied: use the one-argument JVM overload.
              callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc)
            } else {
              if (is.character(errMsg) && length(errMsg) == 1) {
                # A single string is promoted to a literal Column.
                errMsg <- lit(errMsg)
              } else if (!inherits(errMsg, "Column")) {
                # Fail early with a readable message instead of an obscure
                # slot-access error on errMsg@jc below.
                stop("errMsg should be a Column or a single character string, got ",
                     class(errMsg)[[1]])
              }
              callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc, errMsg@jc)
            }
            column(jc)
          })

#' @details
#' \code{raise_error}: Throws an exception with the provided error message. The message
#' may be given as a single character string or as a Column.
#'
#' @rdname column_misc_functions
#' @aliases raise_error raise_error,characterOrColumn-method
#' @examples
#' \dontrun{
#' tmp <- mutate(df, v1 = raise_error("error message"))
#' head(tmp)}
#' @note raise_error since 3.1.0
setMethod("raise_error",
          signature(x = "characterOrColumn"),
          function(x) {
            # A single string is promoted to a literal Column; anything else is used as is.
            msgCol <- if (is.character(x) && length(x) == 1) lit(x) else x
            column(callJStatic("org.apache.spark.sql.functions", "raise_error", msgCol@jc))
          })

#' @details
#' \code{dayofmonth}: Extracts the day of the month as an integer from a
#' given date/timestamp/string.
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,10 @@ setGeneric("arrays_zip_with", function(x, y, f) { standardGeneric("arrays_zip_wi
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })

#' @rdname column_misc_functions
#' @name NULL
setGeneric("assert_true", function(x, errMsg = NULL) {
  # Dispatches to the matching S4 method.
  standardGeneric("assert_true")
})

#' @param x Column to compute on or a GroupedData object.
#' @param ... additional argument(s) when \code{x} is a GroupedData object.
#' @rdname avg
Expand Down Expand Up @@ -1223,6 +1227,10 @@ setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer")
#' @name NULL
setGeneric("quarter", function(x) { standardGeneric("quarter") })

#' @rdname column_misc_functions
#' @name NULL
setGeneric("raise_error", function(x) {
  # Dispatches to the matching S4 method.
  standardGeneric("raise_error")
})

#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("rand", function(seed) { standardGeneric("rand") })
Expand Down
18 changes: 18 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -3945,6 +3945,24 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", {
dropTempView("cars")
})

test_that("assert_true, raise_error", {
  df <- read.json(jsonPath)
  filtered <- filter(df, "age < 20")

  # On the filtered rows the asserted condition holds, so collecting must succeed
  # for every supported form of the error message (none, string, Column).
  # NOTE(review): the projected column is presumably not named `age`, so `$age`
  # looks like it is always NULL here; these checks would then only verify that
  # collect() does not throw. Confirm.
  passing <- list(
    assert_true(filtered$age < 20),
    assert_true(filtered$age < 20, "error message"),
    assert_true(filtered$age < 20, filtered$name)
  )
  for (assertion in passing) {
    expect_equal(collect(select(filtered, assertion))$age, c(NULL))
  }

  # On the unfiltered data the assertion is expected to fail, surfacing the
  # default message, the custom string, or the value of the message column.
  expect_error(collect(select(df, assert_true(df$age < 20))), "is not true!")
  expect_error(collect(select(df, assert_true(df$age < 20, "error message"))),
               "error message")
  expect_error(collect(select(df, assert_true(df$age < 20, df$name))), "Michael")

  # raise_error always fails, with either a literal or a column-derived message.
  expect_error(collect(select(filtered, raise_error("error message"))), "error message")
  expect_error(collect(select(filtered, raise_error(filtered$name))), "Justin")
})

compare_list <- function(list1, list2) {
# get testthat to show the diff by first making the 2 lists equal in length
expect_equal(length(list1), length(list2))
Expand Down
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ Functions
asc_nulls_last
ascii
asin
assert_true
atan
atan2
avg
Expand Down Expand Up @@ -420,6 +421,7 @@ Functions
pow
quarter
radians
raise_error
rand
randn
rank
Expand Down
55 changes: 53 additions & 2 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,57 @@ def xxhash64(*cols):
return Column(jc)


@since(3.1)
def assert_true(col, errMsg=None):
    """
    Returns null if the input column is true; throws an exception with the provided error message
    otherwise.

    :param col: the column holding the condition to check.
    :param errMsg: optional error message, given as a Python string or a :class:`Column`.
        A ``TypeError`` is raised for any other non-None type.

    >>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
    >>> df.select(assert_true(df.a < df.b).alias('r')).collect()
    [Row(r=None)]
    >>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
    >>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect()
    [Row(r=None)]
    >>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
    >>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect()
    [Row(r=None)]
    """
    if errMsg is None:
        # No message supplied: use the one-argument JVM overload.
        jvm_functions = SparkContext._active_spark_context._jvm.functions
        return Column(jvm_functions.assert_true(_to_java_column(col)))

    # Strings become literal columns; Column objects are passed through.
    if isinstance(errMsg, str):
        java_msg = _create_column_from_literal(errMsg)
    elif isinstance(errMsg, Column):
        java_msg = _to_java_column(errMsg)
    else:
        raise TypeError(
            "errMsg should be a Column or a str, got {}".format(type(errMsg))
        )

    ctx = SparkContext._active_spark_context
    return Column(ctx._jvm.functions.assert_true(_to_java_column(col), java_msg))


@since(3.1)
def raise_error(errMsg):
    """
    Throws an exception with the provided error message.

    :param errMsg: the error message, given as a Python string or a :class:`Column`.
        A ``TypeError`` is raised for any other type.
    """
    # Strings become literal columns; Column objects are passed through.
    if isinstance(errMsg, str):
        java_msg = _create_column_from_literal(errMsg)
    elif isinstance(errMsg, Column):
        java_msg = _to_java_column(errMsg)
    else:
        raise TypeError(
            "errMsg should be a Column or a str, got {}".format(type(errMsg))
        )

    ctx = SparkContext._active_spark_context
    return Column(ctx._jvm.functions.raise_error(java_msg))


# ---------------------- String/Binary functions ------------------------------

_string_functions = {
Expand Down Expand Up @@ -3448,14 +3499,14 @@ def bucket(numBuckets, col):
... ).createOrReplace()
.. warning::
This function can be used only in combinatiion with
This function can be used only in combination with
:py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy`
method of the `DataFrameWriterV2`.
"""
if not isinstance(numBuckets, (int, Column)):
raise TypeError(
"numBuckets should be a Column or and int, got {}".format(type(numBuckets))
"numBuckets should be a Column or an int, got {}".format(type(numBuckets))
)

sc = SparkContext._active_spark_context
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def sha1(col: ColumnOrName) -> Column: ...
def sha2(col: ColumnOrName, numBits: int) -> Column: ...
def hash(*cols: ColumnOrName) -> Column: ...
def xxhash64(*cols: ColumnOrName) -> Column: ...
def assert_true(col: ColumnOrName, errMsg: Union[Column, str] = ...) -> Column: ...
def raise_error(errMsg: Union[Column, str]) -> Column: ...
def concat(*cols: ColumnOrName) -> Column: ...
def concat_ws(sep: str, *cols: ColumnOrName) -> Column: ...
def decode(col: ColumnOrName, charset: str) -> Column: ...
Expand Down
50 changes: 50 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from itertools import chain
import re

from py4j.protocol import Py4JJavaError
from pyspark.sql import Row, Window
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit
from pyspark.testing.sqlutils import ReusedSQLTestCase
Expand Down Expand Up @@ -524,6 +525,55 @@ def test_datetime_functions(self):
parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])

def test_assert_true(self):
    """assert_true yields null when the condition holds and fails with the given message."""
    from pyspark.sql.functions import assert_true

    df = self.spark.range(3)

    # The condition holds for every row, so each row evaluates to null.
    self.assertEqual(
        df.select(assert_true(df.id < 3)).toDF("val").collect(),
        [Row(val=None), Row(val=None), Row(val=None)],
    )

    # A string message is surfaced in the JVM exception.
    with self.assertRaises(Py4JJavaError) as cm:
        df.select(assert_true(df.id < 2, 'too big')).toDF("val").collect()
    self.assertIn("java.lang.RuntimeException", str(cm.exception))
    self.assertIn("too big", str(cm.exception))

    # A Column message is evaluated for the failing row.
    with self.assertRaises(Py4JJavaError) as cm:
        df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect()
    self.assertIn("java.lang.RuntimeException", str(cm.exception))
    self.assertIn("2000000", str(cm.exception))

    # Any other message type is rejected on the Python side before reaching the JVM.
    with self.assertRaises(TypeError) as cm:
        df.select(assert_true(df.id < 2, 5))
    self.assertEqual(
        "errMsg should be a Column or a str, got <class 'int'>",
        str(cm.exception)
    )

def test_raise_error(self):
    """raise_error fails with the given message, supplied as a Column or as a string."""
    from pyspark.sql.functions import raise_error

    df = self.spark.createDataFrame([Row(id="foobar")])

    # Message taken from a column value.
    with self.assertRaises(Py4JJavaError) as cm:
        df.select(raise_error(df.id)).collect()
    self.assertIn("java.lang.RuntimeException", str(cm.exception))
    self.assertIn("foobar", str(cm.exception))

    # Message given as a string literal.
    with self.assertRaises(Py4JJavaError) as cm:
        df.select(raise_error("barfoo")).collect()
    self.assertIn("java.lang.RuntimeException", str(cm.exception))
    self.assertIn("barfoo", str(cm.exception))

    # Any other message type (here None) is rejected on the Python side.
    with self.assertRaises(TypeError) as cm:
        df.select(raise_error(None))
    self.assertEqual(
        "errMsg should be a Column or a str, got <class 'NoneType'>",
        str(cm.exception)
    )


if __name__ == "__main__":
import unittest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ object FunctionRegistry {

// misc functions
expression[AssertTrue]("assert_true"),
expression[RaiseError]("raise_error"),
expression[Crc32]("crc32"),
expression[Md5]("md5"),
expression[Uuid]("uuid"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,51 +53,81 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
}

/**
* A function throws an exception if 'condition' is not true.
* Throw with the result of an expression (used for debugging).
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.",
usage = "_FUNC_(expr) - Throws an exception with `expr`.",
examples = """
Examples:
> SELECT _FUNC_(0 < 1);
NULL
> SELECT _FUNC_('custom error message');
java.lang.RuntimeException
custom error message
""",
since = "2.0.0")
case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
since = "3.1.0")
case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def foldable: Boolean = false
override def nullable: Boolean = true

override def inputTypes: Seq[DataType] = Seq(BooleanType)

override def dataType: DataType = NullType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def prettyName: String = "assert_true"
override def prettyName: String = "raise_error"

private val errMsg = s"'${child.simpleString(SQLConf.get.maxToStringFields)}' is not true!"

override def eval(input: InternalRow) : Any = {
val v = child.eval(input)
if (v == null || java.lang.Boolean.FALSE.equals(v)) {
throw new RuntimeException(errMsg)
} else {
null
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
throw new RuntimeException()
}
throw new RuntimeException(value.toString)
}

// The throws are wrapped in `if (true)` so that the generated code does not fail to compile
// with an "unreachable statement" error.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ExprCode(
code = code"""${eval.code}
|if (true) {
| if (${eval.isNull}) {
| throw new RuntimeException();
| }
| throw new RuntimeException(${eval.value}.toString());
|}""".stripMargin,
isNull = TrueLiteral,
value = JavaCode.defaultLiteral(dataType)
)
}
}

// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ExprCode(code = code"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
|}""".stripMargin, isNull = TrueLiteral,
value = JavaCode.defaultLiteral(dataType))
// NOTE(review): the `usage` string below documents only the one-argument form, although a
// two-argument (condition, error message) constructor is defined; consider documenting it.
/**
 * A function that throws an exception if 'condition' is not true.
 *
 * `left` is the condition and `right` is the error message. `child` is the expression
 * built from them by the auxiliary constructors: `If(left, Literal(null), RaiseError(right))`.
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.",
examples = """
Examples:
> SELECT _FUNC_(0 < 1);
NULL
""",
since = "2.0.0")
case class AssertTrue(left: Expression, right: Expression, child: Expression)
extends RuntimeReplaceable {

override def prettyName: String = "assert_true"

// Two-argument form: yields a null literal when `left` holds, otherwise raises `right`.
def this(left: Expression, right: Expression) = {
this(left, right, If(left, Literal(null), RaiseError(right)))
}

override def sql: String = s"assert_true(${child.sql})"
// One-argument form: the default message is the condition's simple string
// followed by "is not true!".
def this(left: Expression) = {
this(left, Literal(s"'${left.simpleString(SQLConf.get.maxToStringFields)}' is not true!"))
}

// Only the user-supplied expressions are exposed here; the derived `child` is left out.
override def flatArguments: Iterator[Any] = Iterator(left, right)
override def exprsReplaced: Seq[Expression] = Seq(left, right)
}

/**
 * Companion providing a one-argument factory, so `AssertTrue(condition)` can be written
 * without `new`; it delegates to the single-argument auxiliary constructor.
 */
object AssertTrue {
def apply(left: Expression): AssertTrue = new AssertTrue(left)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SPARK-17160: field names are properly escaped by AssertTrue") {
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)).child :: Nil)
}

test("should not apply common subexpression elimination on conditional expressions") {
Expand Down
Loading

0 comments on commit 39510b0

Please sign in to comment.