diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 959edf29e2429..ce384a64bccaf 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -826,6 +826,55 @@ setMethod("xxhash64", column(jc) }) +#' @details +#' \code{assert_true}: Returns null if the input column is true; throws an exception +#' with the provided error message otherwise. +#' +#' @param errMsg (optional) The error message to be thrown. +#' +#' @rdname column_misc_functions +#' @aliases assert_true assert_true,Column-method +#' @examples +#' \dontrun{ +#' tmp <- mutate(df, v1 = assert_true(df$vs < 2), +#' v2 = assert_true(df$vs < 2, "custom error message"), +#' v3 = assert_true(df$vs < 2, df$vs)) +#' head(tmp)} +#' @note assert_true since 3.1.0 +setMethod("assert_true", + signature(x = "Column"), + function(x, errMsg = NULL) { + jc <- if (is.null(errMsg)) { + callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc) + } else { + if (is.character(errMsg) && length(errMsg) == 1) { + errMsg <- lit(errMsg) + } + callJStatic("org.apache.spark.sql.functions", "assert_true", x@jc, errMsg@jc) + } + column(jc) + }) + +#' @details +#' \code{raise_error}: Throws an exception with the provided error message. +#' +#' @rdname column_misc_functions +#' @aliases raise_error raise_error,characterOrColumn-method +#' @examples +#' \dontrun{ +#' tmp <- mutate(df, v1 = raise_error("error message")) +#' head(tmp)} +#' @note raise_error since 3.1.0 +setMethod("raise_error", + signature(x = "characterOrColumn"), + function(x) { + if (is.character(x) && length(x) == 1) { + x <- lit(x) + } + jc <- callJStatic("org.apache.spark.sql.functions", "raise_error", x@jc) + column(jc) + }) + #' @details #' \code{dayofmonth}: Extracts the day of the month as an integer from a #' given date/timestamp/string. 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b9cf0261adc28..6b732e594cd3f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -850,6 +850,10 @@ setGeneric("arrays_zip_with", function(x, y, f) { standardGeneric("arrays_zip_wi #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) +#' @rdname column_misc_functions +#' @name NULL +setGeneric("assert_true", function(x, errMsg = NULL) { standardGeneric("assert_true") }) + #' @param x Column to compute on or a GroupedData object. #' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg @@ -1223,6 +1227,10 @@ setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") #' @name NULL setGeneric("quarter", function(x) { standardGeneric("quarter") }) +#' @rdname column_misc_functions +#' @name NULL +setGeneric("raise_error", function(x) { standardGeneric("raise_error") }) + #' @rdname column_nonaggregate_functions #' @name NULL setGeneric("rand", function(seed) { standardGeneric("rand") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 2ac3093e77ea8..268f5734813ba 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3945,6 +3945,24 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", { dropTempView("cars") }) +test_that("assert_true, raise_error", { + df <- read.json(jsonPath) + filtered <- filter(df, "age < 20") + + expect_equal(collect(select(filtered, assert_true(filtered$age < 20)))$age, c(NULL)) + expect_equal(collect(select(filtered, assert_true(filtered$age < 20, "error message")))$age, + c(NULL)) + expect_equal(collect(select(filtered, assert_true(filtered$age < 20, filtered$name)))$age, + c(NULL)) + expect_error(collect(select(df, assert_true(df$age < 20))), "is not true!") + expect_error(collect(select(df, assert_true(df$age < 20, "error message"))), + "error message") + expect_error(collect(select(df, 
assert_true(df$age < 20, df$name))), "Michael") + + expect_error(collect(select(filtered, raise_error("error message"))), "error message") + expect_error(collect(select(filtered, raise_error(filtered$name))), "Justin") +}) + compare_list <- function(list1, list2) { # get testthat to show the diff by first making the 2 lists equal in length expect_equal(length(list1), length(list2)) diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst index 692d098c89cdc..0ed2f1b86ada5 100644 --- a/python/docs/source/reference/pyspark.sql.rst +++ b/python/docs/source/reference/pyspark.sql.rst @@ -292,6 +292,7 @@ Functions asc_nulls_last ascii asin + assert_true atan atan2 avg @@ -420,6 +421,7 @@ Functions pow quarter radians + raise_error rand randn rank diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7007d505d048d..97146fdb804ab 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1592,6 +1592,57 @@ def xxhash64(*cols): return Column(jc) +@since(3.1) +def assert_true(col, errMsg=None): + """ + Returns null if the input column is true; throws an exception with the provided error message + otherwise. 
+ + >>> df = spark.createDataFrame([(0,1)], ['a', 'b']) + >>> df.select(assert_true(df.a < df.b).alias('r')).collect() + [Row(r=None)] + >>> df = spark.createDataFrame([(0,1)], ['a', 'b']) + >>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect() + [Row(r=None)] + >>> df = spark.createDataFrame([(0,1)], ['a', 'b']) + >>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect() + [Row(r=None)] + """ + sc = SparkContext._active_spark_context + if errMsg is None: + return Column(sc._jvm.functions.assert_true(_to_java_column(col))) + if not isinstance(errMsg, (str, Column)): + raise TypeError( + "errMsg should be a Column or a str, got {}".format(type(errMsg)) + ) + + errMsg = ( + _create_column_from_literal(errMsg) + if isinstance(errMsg, str) + else _to_java_column(errMsg) + ) + return Column(sc._jvm.functions.assert_true(_to_java_column(col), errMsg)) + + +@since(3.1) +def raise_error(errMsg): + """ + Throws an exception with the provided error message. + """ + if not isinstance(errMsg, (str, Column)): + raise TypeError( + "errMsg should be a Column or a str, got {}".format(type(errMsg)) + ) + + sc = SparkContext._active_spark_context + errMsg = ( + _create_column_from_literal(errMsg) + if isinstance(errMsg, str) + else _to_java_column(errMsg) + ) + return Column(sc._jvm.functions.raise_error(errMsg)) + + # ---------------------- String/Binary functions ------------------------------ _string_functions = { @@ -3448,14 +3499,14 @@ def bucket(numBuckets, col): ... ).createOrReplace() .. warning:: - This function can be used only in combinatiion with + This function can be used only in combination with :py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy` method of the `DataFrameWriterV2`.
""" if not isinstance(numBuckets, (int, Column)): raise TypeError( - "numBuckets should be a Column or and int, got {}".format(type(numBuckets)) + "numBuckets should be a Column or an int, got {}".format(type(numBuckets)) ) sc = SparkContext._active_spark_context diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 8efe65205315e..6249bca5cef68 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -137,6 +137,8 @@ def sha1(col: ColumnOrName) -> Column: ... def sha2(col: ColumnOrName, numBits: int) -> Column: ... def hash(*cols: ColumnOrName) -> Column: ... def xxhash64(*cols: ColumnOrName) -> Column: ... +def assert_true(col: ColumnOrName, errMsg: Union[Column, str] = ...) -> Column: ... +def raise_error(errMsg: Union[Column, str]) -> Column: ... def concat(*cols: ColumnOrName) -> Column: ... def concat_ws(sep: str, *cols: ColumnOrName) -> Column: ... def decode(col: ColumnOrName, charset: str) -> Column: ... diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 8d05ed28b8d4e..26d260fe77b0c 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -19,6 +19,7 @@ from itertools import chain import re +from py4j.protocol import Py4JJavaError from pyspark.sql import Row, Window from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -524,6 +525,55 @@ def test_datetime_functions(self): parse_result = df.select(functions.to_date(functions.col("dateCol"))).first() self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)']) + def test_assert_true(self): + from pyspark.sql.functions import assert_true + + df = self.spark.range(3) + + self.assertEquals( + df.select(assert_true(df.id < 3)).toDF("val").collect(), + [Row(val=None), Row(val=None), Row(val=None)], + ) + + with self.assertRaises(Py4JJavaError) as cm: +
df.select(assert_true(df.id < 2, 'too big')).toDF("val").collect() + self.assertIn("java.lang.RuntimeException", str(cm.exception)) + self.assertIn("too big", str(cm.exception)) + + with self.assertRaises(Py4JJavaError) as cm: + df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect() + self.assertIn("java.lang.RuntimeException", str(cm.exception)) + self.assertIn("2000000", str(cm.exception)) + + with self.assertRaises(TypeError) as cm: + df.select(assert_true(df.id < 2, 5)) + self.assertEquals( + "errMsg should be a Column or a str, got <class 'int'>", + str(cm.exception) + ) + + def test_raise_error(self): + from pyspark.sql.functions import raise_error + + df = self.spark.createDataFrame([Row(id="foobar")]) + + with self.assertRaises(Py4JJavaError) as cm: + df.select(raise_error(df.id)).collect() + self.assertIn("java.lang.RuntimeException", str(cm.exception)) + self.assertIn("foobar", str(cm.exception)) + + with self.assertRaises(Py4JJavaError) as cm: + df.select(raise_error("barfoo")).collect() + self.assertIn("java.lang.RuntimeException", str(cm.exception)) + self.assertIn("barfoo", str(cm.exception)) + + with self.assertRaises(TypeError) as cm: + df.select(raise_error(None)) + self.assertEquals( + "errMsg should be a Column or a str, got <class 'NoneType'>", + str(cm.exception) + ) + if __name__ == "__main__": import unittest diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3fae34cbf00c2..508239077a70e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -479,6 +479,7 @@ object FunctionRegistry { // misc functions expression[AssertTrue]("assert_true"), + expression[RaiseError]("raise_error"), expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Uuid]("uuid"), diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 1eec26c8e987a..4e71c8c103889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -53,51 +53,81 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { } /** - * A function throws an exception if 'condition' is not true. + * Throw with the result of an expression (used for debugging). */ @ExpressionDescription( - usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.", + usage = "_FUNC_(expr) - Throws an exception with `expr`.", examples = """ Examples: - > SELECT _FUNC_(0 < 1); - NULL + > SELECT _FUNC_('custom error message'); + java.lang.RuntimeException + custom error message """, - since = "2.0.0") -case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + since = "3.1.0") +case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def foldable: Boolean = false override def nullable: Boolean = true - - override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def dataType: DataType = NullType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - override def prettyName: String = "assert_true" + override def prettyName: String = "raise_error" - private val errMsg = s"'${child.simpleString(SQLConf.get.maxToStringFields)}' is not true!" 
- - override def eval(input: InternalRow) : Any = { - val v = child.eval(input) - if (v == null || java.lang.Boolean.FALSE.equals(v)) { - throw new RuntimeException(errMsg) - } else { - null + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + throw new RuntimeException() } + throw new RuntimeException(value.toString) } + // if (true) is to avoid codegen compilation exception that statement is unreachable override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) + ExprCode( + code = code"""${eval.code} + |if (true) { + | if (${eval.isNull}) { + | throw new RuntimeException(); + | } + | throw new RuntimeException(${eval.value}.toString()); + |}""".stripMargin, + isNull = TrueLiteral, + value = JavaCode.defaultLiteral(dataType) + ) + } +} - // Use unnamed reference that doesn't create a local field here to reduce the number of fields - // because errMsgField is used only when the value is null or false. - val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = code"""${eval.code} - |if (${eval.isNull} || !${eval.value}) { - | throw new RuntimeException($errMsgField); - |}""".stripMargin, isNull = TrueLiteral, - value = JavaCode.defaultLiteral(dataType)) +/** + * A function that throws an exception if 'condition' is not true. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Throws an exception if `expr` is not true.", + examples = """ + Examples: + > SELECT _FUNC_(0 < 1); + NULL + """, + since = "2.0.0") +case class AssertTrue(left: Expression, right: Expression, child: Expression) + extends RuntimeReplaceable { + + override def prettyName: String = "assert_true" + + def this(left: Expression, right: Expression) = { + this(left, right, If(left, Literal(null), RaiseError(right))) } - override def sql: String = s"assert_true(${child.sql})" + def this(left: Expression) = { + this(left, Literal(s"'${left.simpleString(SQLConf.get.maxToStringFields)}' is not true!")) + } + + override def flatArguments: Iterator[Any] = Iterator(left, right) + override def exprsReplaced: Seq[Expression] = Seq(left, right) +} + +object AssertTrue { + def apply(left: Expression): AssertTrue = new AssertTrue(left) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f1de63adc3d9a..adaabfe4d32bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-17160: field names are properly escaped by AssertTrue") { - GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil) + GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)).child :: Nil) } test("should not apply common subexpression elimination on conditional expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 
341b26ddf6575..d0b0d04d1f719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -164,7 +164,11 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB val errMsg = intercept[T] { eval }.getMessage - if (!errMsg.contains(expectedErrMsg)) { + if (errMsg == null) { + if (expectedErrMsg != null) { + fail(s"Expected error message is `$expectedErrMsg`, but null found") + } + } else if (expectedErrMsg == null || !errMsg.contains(expectedErrMsg)) { fail(s"Expected error message is `$expectedErrMsg`, but `$errMsg` found") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 4b2d153a28cc8..d42081024c1dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.types._ class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("assert_true") { - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null) - } - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null) - } - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Literal.create(null, NullType)), null) - } - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null) - } - checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null) - checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) + test("RaiseError") { + checkExceptionInExpression[RuntimeException]( + RaiseError(Literal("error message")), + EmptyRow, + "error message" + ) +
+ checkExceptionInExpression[RuntimeException]( + RaiseError(Literal.create(null, StringType)), + EmptyRow, + null + ) + + // Expects a string + assert(RaiseError(Literal(5)).checkInputDataTypes().isFailure) } test("uuid") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2c545fe762b6d..2efe5aae09709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2318,6 +2318,36 @@ object functions { new XxHash64(cols.map(_.expr)) } + /** + * Returns null if the condition is true, and throws an exception otherwise. + * + * @group misc_funcs + * @since 3.1.0 + */ + def assert_true(c: Column): Column = withExpr { + new AssertTrue(c.expr) + } + + /** + * Returns null if the condition is true; throws an exception with the error message otherwise. + * + * @group misc_funcs + * @since 3.1.0 + */ + def assert_true(c: Column, e: Column): Column = withExpr { + new AssertTrue(c.expr, e.expr) + } + + /** + * Throws an exception with the provided error message. 
+ * + * @group misc_funcs + * @since 3.1.0 + */ + def raise_error(c: Column): Column = withExpr { + RaiseError(c.expr) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 473204c182a69..1675fb1cc7c62 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ ## Summary - - Number of queries: 340 + - Number of queries: 341 - Number of expressions that missing example: 13 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window ## Schema of Built-in Functions @@ -34,7 +34,7 @@ | org.apache.spark.sql.catalyst.expressions.Ascii | ascii | SELECT ascii('222') | struct | | org.apache.spark.sql.catalyst.expressions.Asin | asin | SELECT asin(0) | struct | | org.apache.spark.sql.catalyst.expressions.Asinh | asinh | SELECT asinh(0) | struct | -| org.apache.spark.sql.catalyst.expressions.AssertTrue | assert_true | SELECT assert_true(0 < 1) | struct | +| org.apache.spark.sql.catalyst.expressions.AssertTrue | assert_true | SELECT assert_true(0 < 1) | struct | | org.apache.spark.sql.catalyst.expressions.Atan | atan | SELECT atan(0) | struct | | org.apache.spark.sql.catalyst.expressions.Atan2 | atan2 | SELECT atan2(0, 0) | struct | | org.apache.spark.sql.catalyst.expressions.Atanh | atanh | SELECT atanh(0) | struct | @@ -210,6 +210,7 @@ | org.apache.spark.sql.catalyst.expressions.Pow | power | SELECT power(2, 3) | struct | | org.apache.spark.sql.catalyst.expressions.Quarter | quarter | SELECT quarter('2016-08-31') | struct | | org.apache.spark.sql.catalyst.expressions.RLike | rlike | SELECT 
'%SystemDrive%\Users\John' rlike '%SystemDrive%\\Users.*' | struct<%SystemDrive%UsersJohn RLIKE %SystemDrive%\Users.*:boolean> | +| org.apache.spark.sql.catalyst.expressions.RaiseError | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | | org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | diff --git a/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql index 95f71925e9294..907ff33000d8e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql @@ -8,3 +8,15 @@ select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2); select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days'); select typeof(x'ABCD'), typeof('SPARK'); select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark')); + +-- Spark-32793: Rewrite AssertTrue with RaiseError +SELECT assert_true(true), assert_true(boolean(1)); +SELECT assert_true(false); +SELECT assert_true(boolean(0)); +SELECT assert_true(null); +SELECT assert_true(boolean(null)); +SELECT assert_true(false, 'custom error message'); + +CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v); +SELECT raise_error('error message'); +SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc; diff --git a/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out index bd8ffb82ee129..bf45ec3d10215 100644 --- a/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/misc-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically 
generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 16 -- !query @@ -56,3 +56,82 @@ select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', struct -- !query output array map struct + + +-- !query +SELECT assert_true(true), assert_true(boolean(1)) +-- !query schema +struct +-- !query output +NULL NULL + + +-- !query +SELECT assert_true(false) +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +'false' is not true! + + +-- !query +SELECT assert_true(boolean(0)) +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +'cast(0 as boolean)' is not true! + + +-- !query +SELECT assert_true(null) +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +'null' is not true! + + +-- !query +SELECT assert_true(boolean(null)) +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +'cast(null as boolean)' is not true! + + +-- !query +SELECT assert_true(false, 'custom error message') +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +custom error message + + +-- !query +CREATE TEMPORARY VIEW tbl_misc AS SELECT * FROM (VALUES (1), (8), (2)) AS T(v) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT raise_error('error message') +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +error message + + +-- !query +SELECT if(v > 5, raise_error('too big: ' || v), v + 1) FROM tbl_misc +-- !query schema +struct<> +-- !query output +java.lang.RuntimeException +too big: 8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b11f4c603dfd6..937de92bcaba6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.io.{LongWritable, Text} import 
org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.matchers.should.Matchers._ +import org.apache.spark.SparkException import org.apache.spark.sql.UpdateFieldsBenchmark._ import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression} import org.apache.spark.sql.execution.ProjectExec @@ -2302,4 +2303,54 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } } } + + test("assert_true") { + // assert_true(condition, errMsgCol) + val booleanDf = Seq((true), (false)).toDF("cond") + checkAnswer( + booleanDf.filter("cond = true").select(assert_true($"cond")), + Row(null) :: Nil + ) + val e1 = intercept[SparkException] { + booleanDf.select(assert_true($"cond", lit(null.asInstanceOf[String]))).collect() + } + assert(e1.getCause.isInstanceOf[RuntimeException]) + assert(e1.getCause.getMessage == null) + + val nullDf = Seq(("first row", None), ("second row", Some(true))).toDF("n", "cond") + checkAnswer( + nullDf.filter("cond = true").select(assert_true($"cond", $"cond")), + Row(null) :: Nil + ) + val e2 = intercept[SparkException] { + nullDf.select(assert_true($"cond", $"n")).collect() + } + assert(e2.getCause.isInstanceOf[RuntimeException]) + assert(e2.getCause.getMessage == "first row") + + // assert_true(condition) + val intDf = Seq((0, 1)).toDF("a", "b") + checkAnswer(intDf.select(assert_true($"a" < $"b")), Row(null) :: Nil) + val e3 = intercept[SparkException] { + intDf.select(assert_true($"a" > $"b")).collect() + } + assert(e3.getCause.isInstanceOf[RuntimeException]) + assert(e3.getCause.getMessage == "'('a > 'b)' is not true!") + } + + test("raise_error") { + val strDf = Seq(("hello")).toDF("a") + + val e1 = intercept[SparkException] { + strDf.select(raise_error(lit(null.asInstanceOf[String]))).collect() + } + assert(e1.getCause.isInstanceOf[RuntimeException]) + assert(e1.getCause.getMessage == null) + + val e2 = intercept[SparkException] { + 
strDf.select(raise_error($"a")).collect() + } + assert(e2.getCause.isInstanceOf[RuntimeException]) + assert(e2.getCause.getMessage == "hello") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index f487a30c8dfa3..9f62ff8301ebc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -163,7 +163,9 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { "org.apache.spark.sql.catalyst.expressions.InputFileBlockLength", // The example calls methods that return unstable results. "org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection", - "org.apache.spark.sql.catalyst.expressions.SparkVersion") + "org.apache.spark.sql.catalyst.expressions.SparkVersion", + // Throws an error + "org.apache.spark.sql.catalyst.expressions.RaiseError") val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector) parFuncs.foreach { funcId => @@ -197,9 +199,16 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) - // Do not check these expressions, because these expressions extend NullIntolerant - // and override the eval method to avoid evaluating input1 if input2 is 0. 
- val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) + // Do not check these expressions, because these expressions override the eval method + val ignoreSet = Set( + // Extend NullIntolerant and avoid evaluating input1 if input2 is 0 + classOf[IntegralDivide], + classOf[Divide], + classOf[Remainder], + classOf[Pmod], + // Throws an exception, even if input is null + classOf[RaiseError] + ) val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)