[SPARK-10092] [SQL] Multi-DB support follow up.
https://issues.apache.org/jira/browse/SPARK-10092

This PR is a follow-up for Multi-DB support. It makes the following changes:

* `HiveContext.refreshTable` now accepts `dbName.tableName`.
* `HiveContext.analyze` now accepts `dbName.tableName`.
* `CreateTableUsing`, `CreateTableUsingAsSelect`, `CreateTempTableUsing`, `CreateTempTableUsingAsSelect`, `CreateMetastoreDataSource`, and `CreateMetastoreDataSourceAsSelect` all take `TableIdentifier` instead of the string representation of the table name.
* When you call `saveAsTable` with a specified database, the data will be saved to the correct location (see the usage sketch after this list).
* Explicitly disallow creating a temporary table with a specified database name (this was not possible before, either).
* When we save a table to the metastore, we also check whether the database name and table name are accepted by Hive (using `MetaStoreUtils.validateName`).
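
A minimal usage sketch of the user-facing changes (assuming a `HiveContext` named `sqlContext`, a `DataFrame` named `df`, and a pre-existing Hive database `mydb`; the table name is hypothetical):

```scala
import org.apache.spark.sql.SaveMode

// saveAsTable with a database-qualified name stores the data under
// mydb's location instead of the default database's location.
df.write.mode(SaveMode.Overwrite).saveAsTable("mydb.my_table")

// refreshTable and analyze accept dbName.tableName as well.
sqlContext.refreshTable("mydb.my_table")
sqlContext.analyze("mydb.my_table")
```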

Author: Yin Huai <[email protected]>

Closes apache#8324 from yhuai/saveAsTableDB.
yhuai authored and liancheng committed Aug 20, 2015
1 parent b762f99 commit 43e0135
Showing 16 changed files with 398 additions and 94 deletions.
@@ -25,7 +25,9 @@ private[sql] case class TableIdentifier(table: String, database: Option[String]

def toSeq: Seq[String] = database.toSeq :+ table

override def toString: String = toSeq.map("`" + _ + "`").mkString(".")
override def toString: String = quotedString

def quotedString: String = toSeq.map("`" + _ + "`").mkString(".")

def unquotedString: String = toSeq.mkString(".")
}
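
A quick sketch of the two string forms (illustrative only, since `TableIdentifier` is `private[sql]`; values are hypothetical):

```scala
val ident = TableIdentifier("myTable", Some("mydb"))
ident.quotedString    // "`mydb`.`myTable`"
ident.unquotedString  // "mydb.myTable"
ident.toString        // same as quotedString
```
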
@@ -23,6 +23,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}

@@ -55,12 +56,15 @@ trait Catalog {

def refreshTable(tableIdent: TableIdentifier): Unit

// TODO: Refactor this as part of SPARK-10104
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit

// TODO: Refactor this as part of SPARK-10104
def unregisterTable(tableIdentifier: Seq[String]): Unit

def unregisterAllTables(): Unit

// TODO: Refactor this as part of SPARK-10104
protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = {
if (conf.caseSensitiveAnalysis) {
tableIdentifier
@@ -69,6 +73,7 @@
}
}

// TODO: Refactor this as part of SPARK-10104
protected def getDbTableName(tableIdent: Seq[String]): String = {
val size = tableIdent.size
if (size <= 2) {
@@ -78,9 +83,22 @@
}
}

// TODO: Refactor this as part of SPARK-10104
protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
(tableIdent.lift(tableIdent.size - 2), tableIdent.last)
}

/**
* It is not allowed to specify a database name for tables stored in [[SimpleCatalog]].
* We use this method to check it.
*/
protected def checkTableIdentifier(tableIdentifier: Seq[String]): Unit = {
if (tableIdentifier.length > 1) {
throw new AnalysisException("Specifying database name or other qualifiers is not allowed " +
"for temporary tables. If the table name has dots (.) in it, please quote the " +
"table name with backticks (`).")
}
}
}
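
A hypothetical sketch of the check from inside a `Catalog` implementation:

```scala
// Single-part identifier: passes silently.
checkTableIdentifier(Seq("tempTable"))

// Multi-part identifier: throws AnalysisException, because temporary
// tables cannot be qualified with a database name.
checkTableIdentifier(Seq("mydb", "tempTable"))
```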

class SimpleCatalog(val conf: CatalystConf) extends Catalog {
@@ -89,11 +107,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
tables.put(getDbTableName(tableIdent), plan)
}

override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
tables.remove(getDbTableName(tableIdent))
}
@@ -103,13 +123,15 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
}

override def tableExists(tableIdentifier: Seq[String]): Boolean = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
tables.containsKey(getDbTableName(tableIdent))
}

override def lookupRelation(
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
val tableFullName = getDbTableName(tableIdent)
val table = tables.get(tableFullName)
@@ -149,7 +171,13 @@ trait OverrideCatalog extends Catalog {

abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.get(getDBTable(tableIdent)) match {
// A temporary table only has a single part in the tableIdentifier.
val overriddenTable = if (tableIdentifier.length > 1) {
None: Option[LogicalPlan]
} else {
overrides.get(getDBTable(tableIdent))
}
overriddenTable match {
case Some(_) => true
case None => super.tableExists(tableIdentifier)
}
@@ -159,7 +187,12 @@
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan = {
val tableIdent = processTableIdentifier(tableIdentifier)
val overriddenTable = overrides.get(getDBTable(tableIdent))
// A temporary table only has a single part in the tableIdentifier.
val overriddenTable = if (tableIdentifier.length > 1) {
None: Option[LogicalPlan]
} else {
overrides.get(getDBTable(tableIdent))
}
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))

// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
@@ -171,20 +204,8 @@
}

abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
val dbName = if (conf.caseSensitiveAnalysis) {
databaseName
} else {
if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
}

val temporaryTables = overrides.filter {
// If a temporary table does not have an associated database, we should return its name.
case ((None, _), _) => true
// If a temporary table does have an associated database, we should return it if the database
// matches the given database name.
case ((db: Some[String], _), _) if db == dbName => true
case _ => false
}.map {
// We always return all temporary tables.
val temporaryTables = overrides.map {
case ((_, tableName), _) => (tableName, true)
}.toSeq

@@ -194,13 +215,19 @@
override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
checkTableIdentifier(tableIdentifier)
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.put(getDBTable(tableIdent), plan)
}

override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.remove(getDBTable(tableIdent))
// A temporary table only has a single part in the tableIdentifier.
// If tableIdentifier has more than one part, it is not a temporary table
// and we do not need to do anything here.
if (tableIdentifier.length == 1) {
val tableIdent = processTableIdentifier(tableIdentifier)
overrides.remove(getDBTable(tableIdent))
}
}
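
Under this rule, only single-part names can reach the overrides map (a hypothetical sketch):

```scala
catalog.unregisterTable(Seq("tempTable"))          // removes the temporary table
catalog.unregisterTable(Seq("mydb", "someTable"))  // not temporary; a no-op here
```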

override def unregisterAllTables(): Unit = {
@@ -218,7 +218,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
case _ =>
val cmd =
CreateTableUsingAsSelect(
tableIdent.unquotedString,
tableIdent,
source,
temporary = false,
partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
15 changes: 10 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -584,17 +584,18 @@ class SQLContext(@transient val sparkContext: SparkContext)
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableName,
tableIdent,
userSpecifiedSchema = None,
source,
temporary = false,
options,
allowExisting = false,
managedIfNoPath = false)
executePlan(cmd).toRdd
table(tableName)
table(tableIdent)
}

/**
@@ -629,17 +630,18 @@ class SQLContext(@transient val sparkContext: SparkContext)
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableName,
tableIdent,
userSpecifiedSchema = Some(schema),
source,
temporary = false,
options,
allowExisting = false,
managedIfNoPath = false)
executePlan(cmd).toRdd
table(tableName)
table(tableIdent)
}

/**
@@ -724,7 +726,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
val tableIdent = new SqlParser().parseTableIdentifier(tableName)
table(new SqlParser().parseTableIdentifier(tableName))
}

private def table(tableIdent: TableIdentifier): DataFrame = {
DataFrame(this, catalog.lookupRelation(tableIdent.toSeq))
}
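
Because the public `table` now parses its argument with `SqlParser`, a qualified name becomes a proper `TableIdentifier` (hypothetical examples; resolving a qualified name requires a catalog that accepts one, such as `HiveContext`'s):

```scala
sqlContext.table("myTable")       // TableIdentifier("myTable", None)
sqlContext.table("mydb.myTable")  // TableIdentifier("myTable", Some("mydb"))
```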

@@ -395,22 +395,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

object DDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) =>
case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) =>
ExecutedCommand(
CreateTempTableUsing(
tableName, userSpecifiedSchema, provider, opts)) :: Nil
tableIdent, userSpecifiedSchema, provider, opts)) :: Nil
case c: CreateTableUsing if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")

case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query)
case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query)
if partitionsCols.nonEmpty =>
sys.error("Cannot create temporary partitioned table.")

case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) =>
case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) =>
val cmd = CreateTempTableUsingAsSelect(
tableName, provider, Array.empty[String], mode, opts, query)
tableIdent, provider, Array.empty[String], mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
@@ -80,9 +80,9 @@ class DDLParser(parseQuery: String => LogicalPlan)
*/
protected lazy val createTable: Parser[LogicalPlan] = {
// TODO: Support database.table.
(CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~
(CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~
tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query =>
case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query =>
if (temp.isDefined && allowExisting.isDefined) {
throw new DDLException(
"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
@@ -104,7 +104,7 @@
}

val queryPlan = parseQuery(query.get)
CreateTableUsingAsSelect(tableName,
CreateTableUsingAsSelect(tableIdent,
provider,
temp.isDefined,
Array.empty[String],
@@ -114,7 +114,7 @@
} else {
val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
CreateTableUsing(
tableName,
tableIdent,
userSpecifiedSchema,
provider,
temp.isDefined,
@@ -125,28 +125,28 @@
}
}

// This is the same as tableIdentifier in SqlParser.
protected lazy val tableIdentifier: Parser[TableIdentifier] =
(ident <~ ".").? ~ ident ^^ {
case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName)
}
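
A sketch of what this rule produces, exercised through `SqlParser` as the comment above suggests (hypothetical names):

```scala
new SqlParser().parseTableIdentifier("mydb.myTable")
// => TableIdentifier("myTable", Some("mydb"))
new SqlParser().parseTableIdentifier("myTable")
// => TableIdentifier("myTable", None)
```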

protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"

/*
* describe [extended] table avroTable
* This will display all columns of table `avroTable`, including column_name, column_type, and comment.
*/
protected lazy val describeTable: Parser[LogicalPlan] =
(DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
case e ~ db ~ tbl =>
val tblIdentifier = db match {
case Some(dbName) =>
Seq(dbName, tbl)
case None =>
Seq(tbl)
}
DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined)
(DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ {
case e ~ tableIdent =>
DescribeCommand(UnresolvedRelation(tableIdent.toSeq, None), e.isDefined)
}

protected lazy val refreshTable: Parser[LogicalPlan] =
REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ {
case maybeDatabaseName ~ tableName =>
RefreshTable(TableIdentifier(tableName, maybeDatabaseName))
REFRESH ~> TABLE ~> tableIdentifier ^^ {
case tableIdent =>
RefreshTable(tableIdent)
}
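
Both `describeTable` and `refreshTable` now route through `tableIdentifier`, so database-qualified names parse in SQL (hypothetical statements, e.g. against a `HiveContext`):

```scala
sqlContext.sql("REFRESH TABLE mydb.myTable")
sqlContext.sql("DESCRIBE EXTENDED mydb.myTable")
```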

protected lazy val options: Parser[Map[String, String]] =
@@ -53,7 +53,7 @@ case class DescribeCommand(
* If it is false, an exception will be thrown
*/
case class CreateTableUsing(
tableName: String,
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
provider: String,
temporary: Boolean,
@@ -71,36 +71,37 @@ case class CreateTableUsing(
* can analyze the logical plan that will be used to populate the table.
* So, [[PreWriteCheck]] can detect cases that are not allowed.
*/
// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104).
case class CreateTableUsingAsSelect(
tableName: String,
tableIdent: TableIdentifier,
provider: String,
temporary: Boolean,
partitionColumns: Array[String],
mode: SaveMode,
options: Map[String, String],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
// TODO: Override resolved after we support databaseName.
// override lazy val resolved = databaseName != None && childrenResolved
}

case class CreateTempTableUsing(
tableName: String,
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String]) extends RunnableCommand {

def run(sqlContext: SQLContext): Seq[Row] = {
val resolved = ResolvedDataSource(
sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
sqlContext.catalog.registerTable(
tableIdent.toSeq,
DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)

Seq.empty[Row]
}
}

case class CreateTempTableUsingAsSelect(
tableName: String,
tableIdent: TableIdentifier,
provider: String,
partitionColumns: Array[String],
mode: SaveMode,
@@ -110,8 +111,9 @@ case class CreateTempTableUsingAsSelect(
override def run(sqlContext: SQLContext): Seq[Row] = {
val df = DataFrame(sqlContext, query)
val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df)
sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
sqlContext.catalog.registerTable(
tableIdent.toSeq,
DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)

Seq.empty[Row]
}