Skip to content

Commit

Permalink
zorder by support quote
Browse files Browse the repository at this point in the history
  • Loading branch information
XorSum committed Aug 4, 2024
1 parent 2a990be commit 16ffa12
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,30 @@ class KyuubiSparkSQLAstBuilder extends KyuubiSparkSQLBaseVisitor[AnyRef] with SQ

override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
withOrigin(ctx) {
ctx.parts.asScala.map(_.getText).toSeq
ctx.parts.asScala.map(typedVisit[String]).toSeq
}

override def visitIdentifier(ctx: IdentifierContext): String = {
withOrigin(ctx) {
ctx.strictIdentifier() match {
case quotedContext: QuotedIdentifierAlternativeContext =>
typedVisit[String](quotedContext)
case _ => ctx.getText
}
}
}

override def visitQuotedIdentifier(ctx: QuotedIdentifierContext): String = {
withOrigin(ctx) {
ctx.BACKQUOTED_IDENTIFIER().getText.stripPrefix("`").stripSuffix("`").replace("``", "`")
}
}

override def visitZorderClause(ctx: ZorderClauseContext): Seq[UnresolvedAttribute] =
withOrigin(ctx) {
val res = ListBuffer[UnresolvedAttribute]()
ctx.multipartIdentifier().forEach { identifier =>
res += UnresolvedAttribute(identifier.parts.asScala.map(_.getText).toSeq)
res += UnresolvedAttribute(identifier.parts.asScala.map(typedVisit[String]).toSeq)
}
res.toSeq
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,49 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel
}
}

test("optimize sort by backquoted column name") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
withTable("up") {
sql(s"DROP TABLE IF EXISTS up")
val target = Seq(
Seq(0, 0),
Seq(1, 0),
Seq(0, 1),
Seq(1, 1),
Seq(2, 0),
Seq(3, 0),
Seq(2, 1),
Seq(3, 1),
Seq(0, 2),
Seq(1, 2),
Seq(0, 3),
Seq(1, 3),
Seq(2, 2),
Seq(3, 2),
Seq(2, 3),
Seq(3, 3))
sql(s"CREATE TABLE up (c1 INT, `@c2` INT, c3 INT)")
sql(s"INSERT INTO TABLE up VALUES" +
"(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
"(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
"(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
"(3,0,3),(3,1,4),(3,2,9),(3,3,0)")

sql("OPTIMIZE up ZORDER BY c1, `@c2`")
val res = sql("SELECT c1, `@c2` FROM up").collect()

assert(res.length == 16)

for (i <- target.indices) {
val t = target(i)
val r = res(i)
assert(t(0) == r.getInt(0))
assert(t(1) == r.getInt(1))
}
}
}
}

def createParser: ParserInterface
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,30 @@ class KyuubiSparkSQLAstBuilder extends KyuubiSparkSQLBaseVisitor[AnyRef] with SQ

override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
withOrigin(ctx) {
ctx.parts.asScala.map(_.getText).toSeq
ctx.parts.asScala.map(typedVisit[String]).toSeq
}

override def visitIdentifier(ctx: IdentifierContext): String = {
withOrigin(ctx) {
ctx.strictIdentifier() match {
case quotedContext: QuotedIdentifierAlternativeContext =>
typedVisit[String](quotedContext)
case _ => ctx.getText
}
}
}

override def visitQuotedIdentifier(ctx: QuotedIdentifierContext): String = {
withOrigin(ctx) {
ctx.BACKQUOTED_IDENTIFIER().getText.stripPrefix("`").stripSuffix("`").replace("``", "`")
}
}

override def visitZorderClause(ctx: ZorderClauseContext): Seq[UnresolvedAttribute] =
withOrigin(ctx) {
val res = ListBuffer[UnresolvedAttribute]()
ctx.multipartIdentifier().forEach { identifier =>
res += UnresolvedAttribute(identifier.parts.asScala.map(_.getText).toSeq)
res += UnresolvedAttribute(identifier.parts.asScala.map(typedVisit[String]).toSeq)
}
res.toSeq
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,49 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel
}
}

test("optimize sort by backquoted column name") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
withTable("up") {
sql(s"DROP TABLE IF EXISTS up")
val target = Seq(
Seq(0, 0),
Seq(1, 0),
Seq(0, 1),
Seq(1, 1),
Seq(2, 0),
Seq(3, 0),
Seq(2, 1),
Seq(3, 1),
Seq(0, 2),
Seq(1, 2),
Seq(0, 3),
Seq(1, 3),
Seq(2, 2),
Seq(3, 2),
Seq(2, 3),
Seq(3, 3))
sql(s"CREATE TABLE up (c1 INT, `@c2` INT, c3 INT)")
sql(s"INSERT INTO TABLE up VALUES" +
"(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
"(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
"(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
"(3,0,3),(3,1,4),(3,2,9),(3,3,0)")

sql("OPTIMIZE up ZORDER BY c1, `@c2`")
val res = sql("SELECT c1, `@c2` FROM up").collect()

assert(res.length == 16)

for (i <- target.indices) {
val t = target(i)
val r = res(i)
assert(t(0) == r.getInt(0))
assert(t(1) == r.getInt(1))
}
}
}
}

def createParser: ParserInterface
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,30 @@ class KyuubiSparkSQLAstBuilder extends KyuubiSparkSQLBaseVisitor[AnyRef] with SQ

override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
withOrigin(ctx) {
ctx.parts.asScala.map(_.getText).toSeq
ctx.parts.asScala.map(typedVisit[String]).toSeq
}

override def visitIdentifier(ctx: IdentifierContext): String = {
withOrigin(ctx) {
ctx.strictIdentifier() match {
case quotedContext: QuotedIdentifierAlternativeContext =>
typedVisit[String](quotedContext)
case _ => ctx.getText
}
}
}

override def visitQuotedIdentifier(ctx: QuotedIdentifierContext): String = {
withOrigin(ctx) {
ctx.BACKQUOTED_IDENTIFIER().getText.stripPrefix("`").stripSuffix("`").replace("``", "`")
}
}

override def visitZorderClause(ctx: ZorderClauseContext): Seq[UnresolvedAttribute] =
withOrigin(ctx) {
val res = ListBuffer[UnresolvedAttribute]()
ctx.multipartIdentifier().forEach { identifier =>
res += UnresolvedAttribute(identifier.parts.asScala.map(_.getText).toSeq)
res += UnresolvedAttribute(identifier.parts.asScala.map(typedVisit[String]).toSeq)
}
res.toSeq
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,49 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel
}
}

test("optimize sort by backquoted column name") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
withTable("up") {
sql(s"DROP TABLE IF EXISTS up")
val target = Seq(
Seq(0, 0),
Seq(1, 0),
Seq(0, 1),
Seq(1, 1),
Seq(2, 0),
Seq(3, 0),
Seq(2, 1),
Seq(3, 1),
Seq(0, 2),
Seq(1, 2),
Seq(0, 3),
Seq(1, 3),
Seq(2, 2),
Seq(3, 2),
Seq(2, 3),
Seq(3, 3))
sql(s"CREATE TABLE up (c1 INT, `@c2` INT, c3 INT)")
sql(s"INSERT INTO TABLE up VALUES" +
"(0,0,2),(0,1,2),(0,2,1),(0,3,3)," +
"(1,0,4),(1,1,2),(1,2,1),(1,3,3)," +
"(2,0,2),(2,1,1),(2,2,5),(2,3,5)," +
"(3,0,3),(3,1,4),(3,2,9),(3,3,0)")

sql("OPTIMIZE up ZORDER BY c1, `@c2`")
val res = sql("SELECT c1, `@c2` FROM up").collect()

assert(res.length == 16)

for (i <- target.indices) {
val t = target(i)
val r = res(i)
assert(t(0) == r.getInt(0))
assert(t(1) == r.getInt(1))
}
}
}
}

def createParser: ParserInterface
}

Expand Down

0 comments on commit 16ffa12

Please sign in to comment.