Skip to content

Commit

Permalink
[SPARK-22543][SQL] fix java 64kb compile error for deeply nested expr…
Browse files Browse the repository at this point in the history
…essions

## What changes were proposed in this pull request?

A frequently reported issue with Spark is the Java 64KB compile error. It occurs when Spark generates a method that is too large, which usually has one of three causes:

1. a deep expression tree, e.g. a very complex filter condition
2. many individual expressions, e.g. expressions can have many children, operators can have many expressions.
3. a deep query plan tree (with whole stage codegen)

This PR focuses on 1. There are already several patches (apache#15620, apache#18972, apache#18641) trying to fix this issue, and some of them are already merged. However, this is an endless job, as every non-leaf expression has this issue.

This PR proposes to fix this issue in `Expression.genCode`, to make sure the code for a single expression won't grow too big.

According to maropu's benchmark, no regression is found with TPCDS (thanks maropu!): https://docs.google.com/spreadsheets/d/1K3_7lX05-ZgxDXi9X_GleNnDjcnJIfoSlSCDZcL4gdg/edit?usp=sharing

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>
Author: Wenchen Fan <[email protected]>

Closes apache#19767 from cloud-fan/codegen.
  • Loading branch information
cloud-fan authored and gatorsmile committed Nov 22, 2017
1 parent 327d25f commit 0605ad7
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,48 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val ve = doGenCode(ctx, ExprCode("", isNull, value))
if (ve.code.nonEmpty) {
val eval = doGenCode(ctx, ExprCode("", isNull, value))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
// Add `this` in the comment.
ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim)
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
} else {
ve
eval
}
}
}

/**
 * Moves the generated code of `eval` into a separate private function when that code is long,
 * and rewrites `eval` in place so that its code becomes a single call to the new function.
 *
 * Nothing is changed unless the trimmed code is longer than 1024 characters,
 * `ctx.INPUT_ROW` is non-null and `ctx.currentVars` is null.
 *
 * @param ctx the codegen context used to create fresh names, register mutable state and add
 *            the new function.
 * @param eval the generated code of this expression; its `code` and `value` fields, and its
 *             `isNull` field when that is not a constant, are mutated.
 */
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
// The new function returns only the value, so a null flag that is not the literal "true" or
// "false" is handed back through a mutable state variable assigned inside the function.
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val globalIsNull = ctx.freshName("globalIsNull")
ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
// Remember the original name before redirecting `eval.isNull` to the new variable; the
// assignment below copies the original flag into it.
val localIsNull = eval.isNull
eval.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
} else {
""
}

val javaType = ctx.javaType(dataType)
val newValue = ctx.freshName("value")

// The function body must be built before `eval.value` and `eval.code` are overwritten below,
// because it embeds the original code and returns the original value variable.
val funcName = ctx.freshName(nodeName)
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
| ${eval.code.trim}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)

// From here on, `eval` declares a fresh value variable initialised by calling the function.
eval.value = newValue
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}

/**
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,36 +930,6 @@ class CodegenContext {
}
}

/**
 * Wrap the generated code of expression, which was created from a row object in INPUT_ROW,
 * by a function. ev.isNull and ev.value are passed by global variables
 *
 * @param ev the code to evaluate expressions.
 * @param dataType the data type of ev.value.
 * @param baseFuncName the split function name base.
 * @return a tuple of (the function name returned by `addNewFunction`, the name of the global
 *         isNull variable, the name of the global value variable).
 */
def createAndAddFunction(
ev: ExprCode,
dataType: DataType,
baseFuncName: String): (String, String, String) = {
// Register one mutable state variable for the null flag and one for the value; the generated
// function assigns both so that the caller can read the result after invoking it.
val globalIsNull = freshName("isNull")
addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;")
val globalValue = freshName("value")
addMutableState(javaType(dataType), globalValue,
s"$globalValue = ${defaultValue(dataType)};")
val funcName = freshName(baseFuncName)
// The function returns void: it runs the expression code and copies ev.isNull and ev.value
// into the two global variables.
val funcBody =
s"""
|private void $funcName(InternalRow ${INPUT_ROW}) {
| ${ev.code.trim}
| $globalIsNull = ${ev.isNull};
| $globalValue = ${ev.value};
|}
""".stripMargin
// The name returned by addNewFunction, not funcName, is what gets handed back to the caller.
val fullFuncName = addNewFunction(funcName, funcBody)
(fullFuncName, globalIsNull, globalValue)
}

/**
* Perform a function which generates a sequence of ExprCodes with a given mapping between
* expressions and common expressions, instead of using the mapping in current context.
Expand Down Expand Up @@ -1065,7 +1035,8 @@ class CodegenContext {
* elimination will be performed. Subexpression elimination assumes that the code for each
* expression will be combined in the `expressions` order.
*/
def generateExpressions(expressions: Seq[Expression],
def generateExpressions(
expressions: Seq[Expression],
doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
if (doSubexpressionElimination) subexpressionElimination(expressions)
expressions.map(e => e.genCode(this))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,52 +64,22 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
val trueEval = trueValue.genCode(ctx)
val falseEval = falseValue.genCode(ctx)

// place generated code of condition, true value and false value in separate methods if
// their code combined is large
val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
val generatedCode = if (combinedLength > 1024 &&
// Split these expressions only if they are created from a row object
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {

val (condFuncName, condGlobalIsNull, condGlobalValue) =
ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr")
val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr")
val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr")
val code =
s"""
$condFuncName(${ctx.INPUT_ROW});
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!$condGlobalIsNull && $condGlobalValue) {
$trueFuncName(${ctx.INPUT_ROW});
${ev.isNull} = $trueGlobalIsNull;
${ev.value} = $trueGlobalValue;
} else {
$falseFuncName(${ctx.INPUT_ROW});
${ev.isNull} = $falseGlobalIsNull;
${ev.value} = $falseGlobalValue;
}
"""
}
else {
s"""
${condEval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${condEval.isNull} && ${condEval.value}) {
${trueEval.code}
${ev.isNull} = ${trueEval.isNull};
${ev.value} = ${trueEval.value};
} else {
${falseEval.code}
${ev.isNull} = ${falseEval.isNull};
${ev.value} = ${falseEval.value};
}
"""
}

ev.copy(code = generatedCode)
|${condEval.code}
|boolean ${ev.isNull} = false;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|if (!${condEval.isNull} && ${condEval.value}) {
| ${trueEval.code}
| ${ev.isNull} = ${trueEval.isNull};
| ${ev.value} = ${trueEval.value};
|} else {
| ${falseEval.code}
| ${ev.isNull} = ${falseEval.isNull};
| ${ev.value} = ${falseEval.value};
|}
""".stripMargin
ev.copy(code = code)
}

override def toString: String = s"if ($predicate) $trueValue else $falseValue"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ case class Alias(child: Expression, name: String)(

/** Just a simple passthrough for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
throw new IllegalStateException("Alias.doGenCode should not be called.")
}

override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,46 +378,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
val eval2 = right.genCode(ctx)

// The result should be `false` if either of them is `false`, whether the other is null or not.

// place generated code of eval1 and eval2 in separate methods if their code combined is large
val combinedLength = eval1.code.length + eval2.code.length
if (combinedLength > 1024 &&
// Split these expressions only if they are created from a row object
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {

val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
if (!left.nullable && !right.nullable) {
val generatedCode = s"""
$eval1FuncName(${ctx.INPUT_ROW});
boolean ${ev.value} = false;
if (${eval1GlobalValue}) {
$eval2FuncName(${ctx.INPUT_ROW});
${ev.value} = ${eval2GlobalValue};
}
"""
ev.copy(code = generatedCode, isNull = "false")
} else {
val generatedCode = s"""
$eval1FuncName(${ctx.INPUT_ROW});
boolean ${ev.isNull} = false;
boolean ${ev.value} = false;
if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) {
} else {
$eval2FuncName(${ctx.INPUT_ROW});
if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) {
} else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
${ev.value} = true;
} else {
${ev.isNull} = true;
}
}
"""
ev.copy(code = generatedCode)
}
} else if (!left.nullable && !right.nullable) {
if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
${eval1.code}
boolean ${ev.value} = false;
Expand Down Expand Up @@ -480,46 +441,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
val eval2 = right.genCode(ctx)

// The result should be `true` if either of them is `true`, whether the other is null or not.

// place generated code of eval1 and eval2 in separate methods if their code combined is large
val combinedLength = eval1.code.length + eval2.code.length
if (combinedLength > 1024 &&
// Split these expressions only if they are created from a row object
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {

val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
if (!left.nullable && !right.nullable) {
val generatedCode = s"""
$eval1FuncName(${ctx.INPUT_ROW});
boolean ${ev.value} = true;
if (!${eval1GlobalValue}) {
$eval2FuncName(${ctx.INPUT_ROW});
${ev.value} = ${eval2GlobalValue};
}
"""
ev.copy(code = generatedCode, isNull = "false")
} else {
val generatedCode = s"""
$eval1FuncName(${ctx.INPUT_ROW});
boolean ${ev.isNull} = false;
boolean ${ev.value} = true;
if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) {
} else {
$eval2FuncName(${ctx.INPUT_ROW});
if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) {
} else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
${ev.value} = false;
} else {
${ev.isNull} = true;
}
}
"""
ev.copy(code = generatedCode)
}
} else if (!left.nullable && !right.nullable) {
if (!left.nullable && !right.nullable) {
ev.isNull = "false"
ev.copy(code = s"""
${eval1.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actual(0) == cases)
}

test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") {
test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
var strExpr: Expression = Literal("abc")
for (_ <- 1 to 150) {
strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8")
Expand Down Expand Up @@ -342,7 +342,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
projection(row)
}

test("SPARK-21720: split large predications into blocks due to JVM code size limit") {
test("SPARK-22543: split large predicates into blocks due to JVM code size limit") {
val length = 600

val input = new GenericInternalRow(length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ case class HashAggregateExec(
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
// The generated function doesn't have input row in the code context.
ctx.INPUT_ROW = null

// generate variables for aggregation buffer
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
Expand Down

0 comments on commit 0605ad7

Please sign in to comment.