Skip to content

Commit

Permalink
Support observe hint
Browse files Browse the repository at this point in the history
  • Loading branch information
wForget committed Jan 25, 2024
1 parent 47a1091 commit 22bd178
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,12 @@ object KyuubiSQLConf {
.version("1.9.0")
.booleanConf
.createWithDefault(true)

// Feature switch for the OBSERVE hint; disabled by default.
// The documented syntax quotes the observer name because the resolution rule only
// recognizes a string literal as the name; any other first parameter is handled as
// a metric expression.
val OBSERVE_HINT_ENABLE =
  buildConf("spark.sql.optimizer.observeHint.enabled")
    .doc("Provide OBSERVE Hint to create an observer to collect aggregated metrics." +
      " The OBSERVE Hint Syntax: /*+ OBSERVE('name', exprs) */.")
    .version("1.9.0")
    .booleanConf
    .createWithDefault(false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.kyuubi.sql

import org.apache.spark.sql.{FinalStageResourceManager, InjectCustomResourceProfile, SparkSessionExtensions}

import org.apache.kyuubi.sql.observe.ResolveObserveHints
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, KyuubiUnsupportedOperationsCheck, MaxScanStrategy}

// scalastyle:off line.size.limit
Expand All @@ -32,6 +33,8 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
override def apply(extensions: SparkSessionExtensions): Unit = {
KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions)

extensions.injectResolutionRule(_ => ResolveObserveHints)

extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource)
extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive)
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.sql.observe

import java.util.Locale
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAlias}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Generator, NamedExpression, StringLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, LogicalPlan, UnresolvedHint}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_HINT
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression

import org.apache.kyuubi.sql.KyuubiSQLConf.OBSERVE_HINT_ENABLE

/**
 * Analyzer rule that rewrites an OBSERVE hint into a [[CollectMetrics]] node placed on top
 * of the hint's child plan.
 *
 * Usage: /*+ OBSERVE('name', exprs) */
 *
 * When the first hint parameter is a string literal it is used as the observer name;
 * otherwise a name of the form `OBSERVER_<n>` is generated and every parameter is treated
 * as a metric expression. The rule leaves the plan untouched unless `OBSERVE_HINT_ENABLE`
 * is set.
 */
object ResolveObserveHints extends Rule[LogicalPlan] {

  private val OBSERVE_HINT_NAME = "OBSERVE"

  // Counter backing the generated observer names.
  private val id = new AtomicLong(0)

  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (conf.getConf(OBSERVE_HINT_ENABLE)) {
      plan.resolveOperatorsWithPruning(_.containsPattern(UNRESOLVED_HINT)) {
        // Hints with any other name are not matched and therefore stay as they are.
        case hint: UnresolvedHint if isObserveHint(hint) => toCollectMetrics(hint)
      }
    } else {
      plan
    }
  }

  /** Case-insensitive check of the hint name against `OBSERVE`. */
  private def isObserveHint(hint: UnresolvedHint): Boolean =
    hint.name.toUpperCase(Locale.ROOT) == OBSERVE_HINT_NAME

  /**
   * Builds the [[CollectMetrics]] node for a single OBSERVE hint.
   *
   * Throws the error built by `invalidHintParameterError` when a metric parameter is not
   * an [[Expression]].
   */
  private def toCollectMetrics(hint: UnresolvedHint): LogicalPlan = {
    // Split the parameters into the observer name and the metric expressions.
    val (observerName, metricParams) = hint.parameters match {
      case Seq(StringLiteral(explicitName), rest @ _*) => (explicitName, rest)
      case Seq(all @ _*) => (nextObserverName(), all)
    }

    val invalidParams = metricParams.filterNot(_.isInstanceOf[Expression])
    if (invalidParams.nonEmpty) {
      throw invalidHintParameterError(hint.name.toUpperCase(Locale.ROOT), invalidParams)
    }

    // NOTE(review): a hint that carries no metric expressions produces a CollectMetrics
    // with an empty metric list - confirm how the analyzer reports that case.

    // Naming logic mirrors org.apache.spark.sql.Column.named.
    val metrics: Seq[NamedExpression] = metricParams.map {
      case named: NamedExpression => named

      // An unaliased generator keeps an empty name list; the analyzer fills in the
      // defaults once the nested expression's type has been resolved.
      case generator: Generator => MultiAlias(generator, Nil)

      // A top-level Cast may get a better alias when a NamedExpression sits under it.
      case cast: Cast =>
        val aliased = cast.transformUp {
          case inner @ Cast(_: NamedExpression, _, _, _) => UnresolvedAlias(inner)
        }
        aliased match {
          case named: NamedExpression => named
          case _ => UnresolvedAlias(cast, Some(generateAlias))
        }

      case other: Expression => UnresolvedAlias(other, Some(generateAlias))
    }

    CollectMetrics(observerName, metrics, hint.child)
  }

  /** Returns the next generated observer name, `OBSERVER_0`, `OBSERVER_1`, ... */
  private def nextObserverName(): String = {
    val sequence = id.getAndIncrement()
    s"OBSERVER_$sequence"
  }

  /** Builds the exception raised for hint parameters that are not expressions. */
  private def invalidHintParameterError(hintName: String, invalidParams: Seq[Any]): Throwable = {
    val parameters = Map(
      "hintName" -> hintName,
      "invalidParams" -> invalidParams.mkString(", "))
    new AnalysisException(
      errorClass = "_LEGACY_ERROR_TEMP_1047",
      messageParameters = parameters)
  }

  /**
   * Alias text for an unnamed metric: the aggregate function's own string form for typed
   * aggregates, the pretty-printed SQL for everything else.
   */
  private def generateAlias(e: Expression): String = e match {
    case agg: AggregateExpression =>
      agg.aggregateFunction match {
        case typed: TypedAggregateExpression => typed.toString
        case _ => toPrettySQL(agg)
      }
    case other => toPrettySQL(other)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.observe

import org.apache.spark.sql.{KyuubiSparkSQLExtensionTest, QueryTest, Row}

import org.apache.kyuubi.sql.KyuubiSQLConf.OBSERVE_HINT_ENABLE

/**
 * Checks that OBSERVE hints in SQL text end up as observed metrics on the executed query.
 */
class ResolveObserveHintsSuite extends KyuubiSparkSQLExtensionTest {

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    setupData()
  }

  test("test observe hint") {
    withSQLConf(OBSERVE_HINT_ENABLE.key -> "true") {
      val sqlText =
        """
          | SELECT /*+ OBSERVE('observer3', sum(tt2.c3), count(1)) */ *
          | FROM
          | (SELECT /*+ OBSERVE('observer1', sum(c1), count(1)) */ * from t1) tt1
          | join
          | (SELECT /*+ OBSERVE('observer2', sum(c1), count(1)) */ c1, c1 * 2 as c3 from t2) tt2
          | on tt1.c1 = tt2.c1
          |""".stripMargin
      val df = spark.sql(sqlText)
      // Run the query before reading the metrics from its execution.
      df.collect()
      val observedMetrics = df.queryExecution.observedMetrics
      assert(observedMetrics.size == 3)

      // QueryTest.sameRows only returns an optional mismatch description, so the result
      // has to be asserted explicitly; otherwise a wrong metric would go unnoticed.
      def assertMetrics(observerName: String, expected: Row): Unit = {
        val mismatch = QueryTest.sameRows(Seq(expected), Seq(observedMetrics(observerName)))
        assert(mismatch.isEmpty, s"Unexpected metrics for $observerName: ${mismatch.getOrElse("")}")
      }

      // Expected values presumably follow from the t1/t2 fixtures created by setupData().
      assertMetrics("observer1", Row(5050L, 100L))
      assertMetrics("observer2", Row(55L, 10L))
      assertMetrics("observer3", Row(110L, 10L))
    }
  }
}

0 comments on commit 22bd178

Please sign in to comment.