[SPARK-47060][SQL][TESTS] Check `SparkIllegalArgumentException` instead of `IllegalArgumentException` in `catalyst`

### What changes were proposed in this pull request?
In the PR, I propose to use `checkError()` in `catalyst` tests to check `SparkIllegalArgumentException` and its fields, as sketched below.
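
For example, here is the before/after pattern, condensed from the `RowJsonSuite` change in this commit. The suite name is illustrative; `checkError` is the assertion helper from Spark's `SparkFunSuite` test base, and `Row.jsonValue` is visible only inside the `org.apache.spark.sql` package:

```scala
import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

class ExampleSuite extends SparkFunSuite {
  test("unsupported type") {
    // A row whose schema declares an unsupported object type for JSON conversion.
    val row = new GenericRowWithSchema(
      Array((1, 2)),
      new StructType().add("a", ObjectType(classOf[(Int, Int)])))

    // Before: only the exception type and a message substring were checked.
    //   val e = intercept[IllegalArgumentException] { row.jsonValue }
    //   assert(e.getMessage.contains("Failed to convert value"))

    // After: the Spark exception type, its error class, and its message
    // parameters are all asserted explicitly.
    checkError(
      exception = intercept[SparkIllegalArgumentException] {
        row.jsonValue
      },
      errorClass = "_LEGACY_ERROR_TEMP_3249",
      parameters = Map(
        "value" -> "(1,2)",
        "valueClass" -> "class scala.Tuple2$mcII$sp",
        "dataType" -> "ObjectType(class scala.Tuple2)"))
  }
}
```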

### Why are the changes needed?
Checking for `SparkIllegalArgumentException` and its fields, such as the error class and the message parameters, prevents the exception from being silently replaced back with a plain `IllegalArgumentException`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By running the modified test suites.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#45118 from MaxGekk/migrate-IllegalArgumentException-catalyst-tests.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
MaxGekk committed Feb 16, 2024
1 parent 61e25e1 commit 64fa13b
Showing 17 changed files with 238 additions and 174 deletions.
@@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable}
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -202,6 +202,7 @@ object IntervalUtils extends SparkIntervalUtils {
     try {
       f
     } catch {
+      case e: SparkThrowable => throw e
       case NonFatal(e) =>
         throw new SparkIllegalArgumentException(
           errorClass = "_LEGACY_ERROR_TEMP_3213",
@@ -21,7 +21,7 @@ import java.time.LocalDate
 
 import org.json4s.JsonAST.{JArray, JBool, JDecimal, JDouble, JLong, JNull, JObject, JString, JValue}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
 import org.apache.spark.sql.catalyst.encoders.{ExamplePoint, ExamplePointUDT}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
@@ -128,12 +128,17 @@ class RowJsonSuite extends SparkFunSuite {
   }
 
   test("unsupported type") {
-    val e = intercept[IllegalArgumentException] {
-      val row = new GenericRowWithSchema(
-        Array((1, 2)),
-        new StructType().add("a", ObjectType(classOf[(Int, Int)])))
-      row.jsonValue
-    }
-    assert(e.getMessage.contains("Failed to convert value"))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        val row = new GenericRowWithSchema(
+          Array((1, 2)),
+          new StructType().add("a", ObjectType(classOf[(Int, Int)])))
+        row.jsonValue
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_3249",
+      parameters = Map(
+        "value" -> "(1,2)",
+        "valueClass" -> "class scala.Tuple2$mcII$sp",
+        "dataType" -> "ObjectType(class scala.Tuple2)"))
   }
 }
@@ -24,7 +24,7 @@ import org.scalatest.funspec.AnyFunSpec
 import org.scalatest.matchers.must.Matchers
 import org.scalatest.matchers.should.Matchers._
 
-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
 import org.apache.spark.sql.types._
@@ -66,7 +66,7 @@ class RowTest extends AnyFunSpec with Matchers {
    }
 
    it("Accessing non existent field throws an exception") {
-      intercept[IllegalArgumentException] {
+      intercept[SparkIllegalArgumentException] {
        sampleRow.getAs[String]("non_existent")
      }
    }
@@ -76,7 +76,7 @@ class CSVExprUtilsSuite extends SparkFunSuite {
     // backslash, then tab
     ("""\\t""", Some("""\t"""), None),
     // invalid special character (dot)
-    ("""\.""", None, Some("Unsupported special character for delimiter")),
+    ("""\.""", None, Some("_LEGACY_ERROR_TEMP_3236")),
     // backslash, then dot
     ("""\\.""", Some("""\."""), None),
     // nothing special, just straight conversion
@@ -90,17 +90,16 @@ class CSVExprUtilsSuite extends SparkFunSuite {
   )
 
   test("should correctly produce separator strings, or exceptions, from input") {
-    forAll(testCases) { (input, separatorStr, expectedErrorMsg) =>
+    forAll(testCases) { (input, separatorStr, expectedErrorClass) =>
       try {
         val separator = CSVExprUtils.toDelimiterStr(input)
         assert(separatorStr.isDefined)
-        assert(expectedErrorMsg.isEmpty)
+        assert(expectedErrorClass.isEmpty)
         assert(separator.equals(separatorStr.get))
       } catch {
-        case e: IllegalArgumentException =>
+        case e: SparkIllegalArgumentException =>
           assert(separatorStr.isEmpty)
-          assert(expectedErrorMsg.isDefined)
-          assert(e.getMessage.contains(expectedErrorMsg.get))
+          assert(e.getErrorClass === expectedErrorClass.get)
       }
     }
   }
@@ -23,7 +23,7 @@ import java.util.{Locale, TimeZone}
 
 import org.apache.commons.lang3.time.FastDateFormat
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -304,19 +304,23 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper {
       filters = Seq(EqualTo("d", 3.14)),
       expected = Some(InternalRow(1, 3.14)))
 
-    val errMsg = intercept[IllegalArgumentException] {
-      check(filters = Seq(EqualTo("invalid attr", 1)), expected = None)
-    }.getMessage
-    assert(errMsg.contains("invalid attr does not exist"))
-
-    val errMsg2 = intercept[IllegalArgumentException] {
-      check(
-        dataSchema = new StructType(),
-        requiredSchema = new StructType(),
-        filters = Seq(EqualTo("i", 1)),
-        expected = Some(InternalRow.empty))
-    }.getMessage
-    assert(errMsg2.contains("i does not exist"))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        check(filters = Seq(EqualTo("invalid attr", 1)), expected = None)
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_3252",
+      parameters = Map("name" -> "invalid attr", "fieldNames" -> "i"))
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        check(
+          dataSchema = new StructType(),
+          requiredSchema = new StructType(),
+          filters = Seq(EqualTo("i", 1)),
+          expected = Some(InternalRow.empty))
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_3252",
+      parameters = Map("name" -> "i", "fieldNames" -> ""))
   }
 
   test("SPARK-30960: parse date/timestamp string with legacy format") {
@@ -366,9 +370,11 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper {
     check(new UnivocityParser(StructType(Seq.empty), optionsWithPattern(true)))
 
     // With legacy parser disabled, parsing results in error.
-    val err = intercept[IllegalArgumentException] {
-      check(new UnivocityParser(StructType(Seq.empty), optionsWithPattern(false)))
-    }
-    assert(err.getMessage.contains("Illegal pattern character: n"))
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        check(new UnivocityParser(StructType(Seq.empty), optionsWithPattern(false)))
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_3258",
+      parameters = Map("c" -> "n"))
   }
 }
@@ -22,7 +22,7 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period}
 import java.time.temporal.ChronoUnit
 import java.util.{Calendar, Locale, TimeZone}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1105,9 +1105,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
 
     Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH")
       .foreach { interval =>
-        checkExceptionInExpression[IllegalArgumentException](
+        checkErrorInExpression[SparkIllegalArgumentException](
          cast(Literal.create(interval), YearMonthIntervalType()),
-          "Error parsing interval year-month string: integer overflow")
+          "_LEGACY_ERROR_TEMP_3213",
+          Map("interval" -> "year-month", "msg" -> "integer overflow"))
      }
 
     Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue)
@@ -1173,13 +1174,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
 
     Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
       val dataType = YearMonthIntervalType()
-      val expectedMsg = s"Interval string does not match year-month format of " +
-        s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-          .map(format => s"`$format`").mkString(", ")} " +
-        s"when cast to ${dataType.typeName}: $interval"
-      checkExceptionInExpression[IllegalArgumentException](
+      checkErrorInExpression[SparkIllegalArgumentException](
         cast(Literal.create(interval), dataType),
-        expectedMsg
+        "_LEGACY_ERROR_TEMP_3214",
+        Map(
+          "fallBackNotice" -> "",
+          "typeName" -> "interval year to month",
+          "intervalStr" -> "year-month",
+          "supportedFormat" -> "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`",
+          "input" -> interval)
       )
     }
     Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
@@ -1193,13 +1196,17 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
       ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR)),
       ("INTERVAL '1' MONTH", YearMonthIntervalType(YEAR, MONTH)))
       .foreach { case (interval, dataType) =>
-        val expectedMsg = s"Interval string does not match year-month format of " +
-          s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-            .map(format => s"`$format`").mkString(", ")} " +
-          s"when cast to ${dataType.typeName}: $interval"
-        checkExceptionInExpression[IllegalArgumentException](
+        checkErrorInExpression[SparkIllegalArgumentException](
           cast(Literal.create(interval), dataType),
-          expectedMsg)
+          "_LEGACY_ERROR_TEMP_3214",
+          Map(
+            "fallBackNotice" -> "",
+            "typeName" -> dataType.typeName,
+            "intervalStr" -> "year-month",
+            "supportedFormat" ->
+              IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+                .map(format => s"`$format`").mkString(", "),
+            "input" -> interval))
      }
   }
 
@@ -1313,15 +1320,17 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
       ("1.23", DayTimeIntervalType(MINUTE)),
       ("1.23", DayTimeIntervalType(MINUTE)))
       .foreach { case (interval, dataType) =>
-        val expectedMsg = s"Interval string does not match day-time format of " +
-          s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-            .map(format => s"`$format`").mkString(", ")} " +
-          s"when cast to ${dataType.typeName}: $interval, " +
-          s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
-          "to restore the behavior before Spark 3.0."
-        checkExceptionInExpression[IllegalArgumentException](
+        checkErrorInExpression[SparkIllegalArgumentException](
           cast(Literal.create(interval), dataType),
-          expectedMsg
+          "_LEGACY_ERROR_TEMP_3214",
+          Map("fallBackNotice" -> (", set spark.sql.legacy.fromDayTimeString.enabled" +
+            " to true to restore the behavior before Spark 3.0."),
+            "intervalStr" -> "day-time",
+            "typeName" -> dataType.typeName,
+            "input" -> interval,
+            "supportedFormat" ->
+              IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+                .map(format => s"`$format`").mkString(", "))
        )
      }
 
@@ -1337,15 +1346,17 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
       ("INTERVAL '1537228672801:54.7757' MINUTE TO SECOND", DayTimeIntervalType(MINUTE, SECOND)),
       ("INTERVAL '92233720368541.775807' SECOND", DayTimeIntervalType(SECOND)))
       .foreach { case (interval, dataType) =>
-        val expectedMsg = "Interval string does not match day-time format of " +
-          s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
-            .map(format => s"`$format`").mkString(", ")} " +
-          s"when cast to ${dataType.typeName}: $interval, " +
-          s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
-          "to restore the behavior before Spark 3.0."
-        checkExceptionInExpression[IllegalArgumentException](
+        checkErrorInExpression[SparkIllegalArgumentException](
           cast(Literal.create(interval), dataType),
-          expectedMsg)
+          "_LEGACY_ERROR_TEMP_3214",
+          Map("fallBackNotice" -> (", set spark.sql.legacy.fromDayTimeString.enabled" +
+            " to true to restore the behavior before Spark 3.0."),
+            "intervalStr" -> "day-time",
+            "typeName" -> dataType.typeName,
+            "input" -> interval,
+            "supportedFormat" ->
+              IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
+                .map(format => s"`$format`").mkString(", ")))
      }
   }
 
@@ -28,7 +28,7 @@ import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.Random
 
-import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkException, SparkFunSuite, SparkUpgradeException}
+import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkException, SparkFunSuite, SparkIllegalArgumentException, SparkUpgradeException}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
@@ -434,9 +434,12 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    }
 
    withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) {
-      checkExceptionInExpression[IllegalArgumentException](
+      checkErrorInExpression[SparkIllegalArgumentException](
        DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
-        "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
+        "_LEGACY_ERROR_TEMP_2000",
+        Map("message" ->
+          "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date",
+          "ansiConfig" -> "\"spark.sql.ansi.enabled\""))
    }
 
    withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
@@ -1499,7 +1502,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
    }
 
    Seq('q', 'Q', 'e', 'c', 'A', 'n', 'N', 'p').foreach { l =>
-      checkException[IllegalArgumentException](l.toString)
+      checkException[SparkIllegalArgumentException](l.toString)
    }
  }
 
@@ -154,6 +154,13 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB
     checkErrorInExpression[T](expression, InternalRow.empty, errorClass, parameters)
   }
 
+  protected def checkErrorInExpression[T <: SparkThrowable : ClassTag](
+      expression: => Expression,
+      inputRow: InternalRow,
+      errorClass: String): Unit = {
+    checkErrorInExpression[T](expression, inputRow, errorClass, Map.empty[String, String])
+  }
+
   protected def checkErrorInExpression[T <: SparkThrowable : ClassTag](
       expression: => Expression,
       inputRow: InternalRow,
@@ -188,7 +188,7 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
       "interval 1 year 2 month",
       "interval '1' year '2' month",
       "\tinterval '1-2' year to month").foreach { interval =>
-      intercept[IllegalArgumentException] {
+      intercept[SparkIllegalArgumentException] {
         TimeWindow(Literal(10L, TimestampType), interval, interval, interval)
       }
     }
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkThrowable}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
 import org.apache.spark.sql.types._
@@ -44,6 +44,14 @@ class TryCastSuite extends CastWithAnsiOnSuite {
     checkEvaluation(expression, null, inputRow)
   }
 
+  override def checkErrorInExpression[T <: SparkThrowable : ClassTag](
+      expression: => Expression,
+      inputRow: InternalRow,
+      errorClass: String,
+      parameters: Map[String, String]): Unit = {
+    checkEvaluation(expression, null, inputRow)
+  }
+
   override def checkCastToBooleanError(l: Literal, to: DataType, tryCastResult: Any): Unit = {
     checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
   }
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.streaming
 
 import java.util.Locale
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
 import org.apache.spark.sql.streaming.OutputMode
 
 class InternalOutputModesSuite extends SparkFunSuite {
@@ -40,7 +40,7 @@ class InternalOutputModesSuite extends SparkFunSuite {
   test("unsupported strings") {
     def testMode(outputMode: String): Unit = {
       val acceptedModes = Seq("append", "update", "complete")
-      val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode))
+      val e = intercept[SparkIllegalArgumentException](InternalOutputModes(outputMode))
       (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
         assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
       }