Skip to content

Commit

Permalink
[KYUUBI #6024] Insert crc checksum observer after all project nodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
wForget committed Jan 29, 2024
1 parent f531a37 commit 8506283
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,11 @@ object KyuubiSQLConf {
.version("1.9.0")
.booleanConf
.createWithDefault(true)

val INSERT_CHECKSUM_OBSERVER_AFTER_PROJECT_ENABLED =
buildConf("spark.sql.optimizer.insertChecksumObserverAfterProject.enabled")
.doc("If true, insert crc checksum observer after all project nodes.")
.version("1.9.0")
.booleanConf
.createWithDefault(false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.kyuubi.sql

import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions}

import org.apache.kyuubi.sql.observe.InsertChecksumObserverAfterProject
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, KyuubiUnsupportedOperationsCheck, MaxScanStrategy}

// scalastyle:off line.size.limit
Expand All @@ -32,6 +33,8 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
override def apply(extensions: SparkSessionExtensions): Unit = {
KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions)

extensions.injectResolutionRule(InsertChecksumObserverAfterProject(_))

extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource)
extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive)
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.sql.observe

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Crc32, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.{BinaryType, ByteType, DecimalType, IntegerType, LongType, ShortType, StringType}

import org.apache.kyuubi.sql.KyuubiSQLConf.INSERT_CHECKSUM_OBSERVER_AFTER_PROJECT_ENABLED
import org.apache.kyuubi.sql.observe.InsertChecksumObserverAfterProject._

case class InsertChecksumObserverAfterProject(session: SparkSession) extends Rule[LogicalPlan] {

private val INSERT_COLLECT_METRICS_TAG = TreeNodeTag[Unit]("__INSERT_COLLECT_METRICS_TAG")

override def apply(plan: LogicalPlan): LogicalPlan = {
if (conf.getConf(INSERT_CHECKSUM_OBSERVER_AFTER_PROJECT_ENABLED)) {
plan resolveOperatorsUp {
case p: Project if p.resolved && p.getTagValue(INSERT_COLLECT_METRICS_TAG).isEmpty =>
val metricExprs = p.output.map(toChecksumExpr) :+ countExpr
p.setTagValue(INSERT_COLLECT_METRICS_TAG, ())
CollectMetrics(nextObserverName, metricExprs, p)
}
} else {
plan
}
}

private def toChecksumExpr(attr: Attribute): NamedExpression = {
// sum(cast(crc32(cast(attr as binary)) as decimal(20, 0))) as attr_crc_sum
Alias(
Sum(Cast(Crc32(toBinaryExpr(attr)), LongDecimal)).toAggregateExpression(),
attr.name + "_crc_sum")()
}

private def toBinaryExpr(attr: Attribute): Expression = {
attr.dataType match {
case BinaryType => attr
case StringType | ByteType | ShortType | IntegerType | LongType => Cast(attr, BinaryType)
case _ => Cast(Cast(attr, StringType), BinaryType)
}
}

private def countExpr: NamedExpression = {
Alias(Count(Literal(1)).toAggregateExpression(), "cnt")()
}
}

object InsertChecksumObserverAfterProject {
private val id = new AtomicLong(0)
private def nextObserverName: String = s"CHECKSUM_OBSERVER_${id.getAndIncrement()}"
private val LongDecimal = DecimalType(20, 0)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.observe

import org.apache.spark.sql.{KyuubiSparkSQLExtensionTest, QueryTest, Row}

import org.apache.kyuubi.sql.KyuubiSQLConf.INSERT_CHECKSUM_OBSERVER_AFTER_PROJECT_ENABLED

class InsertChecksumObserverAfterProjectSuite extends KyuubiSparkSQLExtensionTest {

test("insert checksum observer after project") {
withSQLConf(INSERT_CHECKSUM_OBSERVER_AFTER_PROJECT_ENABLED.key -> "true") {
withTable("t") {
sql("CREATE TABLE t(i int)")
sql("INSERT INTO t VALUES (1), (2), (3)")
val df = sql("select a from (SELECT i as a FROM t) where a > 1")
df.collect()
val metrics = df.queryExecution.observedMetrics
assert(metrics.size == 2)
QueryTest.sameRows(
Seq(Row(BigDecimal(6569872598L), 2), Row(BigDecimal(8017165408L), 3)),
metrics.values.toSeq)
}
}
}

}

0 comments on commit 8506283

Please sign in to comment.