From 43e0135421b2262cbb0e06aae53523f663b4f959 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Thu, 20 Aug 2015 15:30:31 +0800
Subject: [PATCH] [SPARK-10092] [SQL] Multi-DB support follow up.

https://issues.apache.org/jira/browse/SPARK-10092

This PR is a follow-up for Multi-DB support. It makes the following changes:

* `HiveContext.refreshTable` now accepts `dbName.tableName`.
* `HiveContext.analyze` now accepts `dbName.tableName`.
* `CreateTableUsing`, `CreateTableUsingAsSelect`, `CreateTempTableUsing`, `CreateTempTableUsingAsSelect`, `CreateMetastoreDataSource`, and `CreateMetastoreDataSourceAsSelect` all take a `TableIdentifier` instead of the string representation of the table name.
* When `saveAsTable` is called with a database-qualified name, the data is saved to the correct location (under that database's warehouse directory).
* Explicitly disallow creating a temporary table with a specified database name (this was not possible before either; it now fails with an explicit error).
* When a table is saved to the metastore, the database name and table name are also checked against what Hive accepts (using `MetaStoreUtils.validateName`).

Author: Yin Huai

Closes #8324 from yhuai/saveAsTableDB.
---
 .../spark/sql/catalyst/TableIdentifier.scala  |   4 +-
 .../spark/sql/catalyst/analysis/Catalog.scala |  63 +++++--
 .../apache/spark/sql/DataFrameWriter.scala    |   2 +-
 .../org/apache/spark/sql/SQLContext.scala     |  15 +-
 .../spark/sql/execution/SparkStrategies.scala |  10 +-
 .../sql/execution/datasources/DDLParser.scala |  32 ++--
 .../spark/sql/execution/datasources/ddl.scala |  22 +--
 .../sql/execution/datasources/rules.scala     |   8 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  35 ++++
 .../apache/spark/sql/hive/HiveContext.scala   |  14 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala |  22 ++-
 .../spark/sql/hive/HiveStrategies.scala       |  12 +-
 .../spark/sql/hive/execution/commands.scala   |  54 ++++--
 .../spark/sql/hive/ListTablesSuite.scala      |   6 -
 .../spark/sql/hive/MultiDatabaseSuite.scala   | 158 +++++++++++++++++-
 .../sql/hive/execution/SQLQuerySuite.scala    |  35 ++++
 16 files changed, 398 insertions(+), 94 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala
index aebcdeb9d070f..d701559bf2d9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala
@@ -25,7 +25,9 @@ private[sql] case class TableIdentifier(table: String, database: Option[String]
 
   def toSeq: Seq[String] = database.toSeq :+ table
 
-  override def toString: String = toSeq.map("`" + _ + "`").mkString(".")
+  override def toString: String = quotedString
+
+  def quotedString: String = toSeq.map("`" + _ + "`").mkString(".")
 
   def unquotedString: String = toSeq.mkString(".")
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index 5766e6a2dd51a..503c4f4b20f38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -23,6 +23,7 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
 
@@ -55,12 +56,15 @@ trait
Catalog { def refreshTable(tableIdent: TableIdentifier): Unit + // TODO: Refactor it in the work of SPARK-10104 def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit + // TODO: Refactor it in the work of SPARK-10104 def unregisterTable(tableIdentifier: Seq[String]): Unit def unregisterAllTables(): Unit + // TODO: Refactor it in the work of SPARK-10104 protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = { if (conf.caseSensitiveAnalysis) { tableIdentifier @@ -69,6 +73,7 @@ trait Catalog { } } + // TODO: Refactor it in the work of SPARK-10104 protected def getDbTableName(tableIdent: Seq[String]): String = { val size = tableIdent.size if (size <= 2) { @@ -78,9 +83,22 @@ trait Catalog { } } + // TODO: Refactor it in the work of SPARK-10104 protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { (tableIdent.lift(tableIdent.size - 2), tableIdent.last) } + + /** + * It is not allowed to specifiy database name for tables stored in [[SimpleCatalog]]. + * We use this method to check it. + */ + protected def checkTableIdentifier(tableIdentifier: Seq[String]): Unit = { + if (tableIdentifier.length > 1) { + throw new AnalysisException("Specifying database name or other qualifiers are not allowed " + + "for temporary tables. If the table name has dots (.) in it, please quote the " + + "table name with backticks (`).") + } + } } class SimpleCatalog(val conf: CatalystConf) extends Catalog { @@ -89,11 +107,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.remove(getDbTableName(tableIdent)) } @@ -103,6 +123,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def tableExists(tableIdentifier: Seq[String]): Boolean = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.containsKey(getDbTableName(tableIdent)) } @@ -110,6 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def lookupRelation( tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) val table = tables.get(tableFullName) @@ -149,7 +171,13 @@ trait OverrideCatalog extends Catalog { abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - overrides.get(getDBTable(tableIdent)) match { + // A temporary tables only has a single part in the tableIdentifier. + val overriddenTable = if (tableIdentifier.length > 1) { + None: Option[LogicalPlan] + } else { + overrides.get(getDBTable(tableIdent)) + } + overriddenTable match { case Some(_) => true case None => super.tableExists(tableIdentifier) } @@ -159,7 +187,12 @@ trait OverrideCatalog extends Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) - val overriddenTable = overrides.get(getDBTable(tableIdent)) + // A temporary tables only has a single part in the tableIdentifier. 
+ val overriddenTable = if (tableIdentifier.length > 1) { + None: Option[LogicalPlan] + } else { + overrides.get(getDBTable(tableIdent)) + } val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -171,20 +204,8 @@ trait OverrideCatalog extends Catalog { } abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val dbName = if (conf.caseSensitiveAnalysis) { - databaseName - } else { - if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None - } - - val temporaryTables = overrides.filter { - // If a temporary table does not have an associated database, we should return its name. - case ((None, _), _) => true - // If a temporary table does have an associated database, we should return it if the database - // matches the given database name. - case ((db: Some[String], _), _) if db == dbName => true - case _ => false - }.map { + // We always return all temporary tables. + val temporaryTables = overrides.map { case ((_, tableName), _) => (tableName, true) }.toSeq @@ -194,13 +215,19 @@ trait OverrideCatalog extends Catalog { override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) overrides.put(getDBTable(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.remove(getDBTable(tableIdent)) + // A temporary tables only has a single part in the tableIdentifier. + // If tableIdentifier has more than one parts, it is not a temporary table + // and we do not need to do anything at here. 
+ if (tableIdentifier.length == 1) { + val tableIdent = processTableIdentifier(tableIdentifier) + overrides.remove(getDBTable(tableIdent)) + } } override def unregisterAllTables(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index f0bf1be506411..ce8744b53175b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -218,7 +218,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { case _ => val cmd = CreateTableUsingAsSelect( - tableIdent.unquotedString, + tableIdent, source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 58fe75b59f418..126c9c6f839c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -584,9 +584,10 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = None, source, temporary = false, @@ -594,7 +595,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -629,9 +630,10 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = Some(schema), source, temporary = false, @@ -639,7 +641,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -724,7 +726,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + table(new SqlParser().parseTableIdentifier(tableName)) + } + + private def table(tableIdent: TableIdentifier): DataFrame = { DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1fc870d44b578..4df53687a0731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -395,22 +395,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => + case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( CreateTempTableUsing( - tableName, userSpecifiedSchema, provider, opts)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts)) :: Nil case c: CreateTableUsing if !c.temporary => sys.error("Tables 
created with SQLContext must be TEMPORARY. Use a HiveContext instead.") case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query) + case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) if partitionsCols.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) => + case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => val cmd = CreateTempTableUsingAsSelect( - tableName, provider, Array.empty[String], mode, opts, query) + tableIdent, provider, Array.empty[String], mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 6c462fa30461b..f7a88b98c0b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -80,9 +80,9 @@ class DDLParser(parseQuery: String => LogicalPlan) */ protected lazy val createTable: Parser[LogicalPlan] = { // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => if (temp.isDefined && allowExisting.isDefined) { throw new DDLException( "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") @@ -104,7 +104,7 @@ class DDLParser(parseQuery: String => LogicalPlan) } val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, + CreateTableUsingAsSelect(tableIdent, provider, temp.isDefined, Array.empty[String], @@ -114,7 +114,7 @@ class DDLParser(parseQuery: String => LogicalPlan) } else { val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema, provider, temp.isDefined, @@ -125,6 +125,12 @@ class DDLParser(parseQuery: String => LogicalPlan) } } + // This is the same as tableIdentifier in SqlParser. + protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? ~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" /* @@ -132,21 +138,15 @@ class DDLParser(parseQuery: String => LogicalPlan) * This will display all columns of table `avroTable` includes column_name,column_type,comment */ protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? 
~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { + case e ~ tableIdent => + DescribeCommand(UnresolvedRelation(tableIdent.toSeq, None), e.isDefined) } protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) + REFRESH ~> TABLE ~> tableIdentifier ^^ { + case tableIndet => + RefreshTable(tableIndet) } protected lazy val options: Parser[Map[String, String]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index ecd304c30cdee..31d6b75e13477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -53,7 +53,7 @@ case class DescribeCommand( * If it is false, an exception will be thrown */ case class CreateTableUsing( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, @@ -71,8 +71,9 @@ case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). case class CreateTableUsingAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, temporary: Boolean, partitionColumns: Array[String], @@ -80,12 +81,10 @@ case class CreateTableUsingAsSelect( options: Map[String, String], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] - // TODO: Override resolved after we support databaseName. 
- // override lazy val resolved = databaseName != None && childrenResolved } case class CreateTempTableUsing( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String]) extends RunnableCommand { @@ -93,14 +92,16 @@ case class CreateTempTableUsing( def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent.toSeq, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) + Seq.empty[Row] } } case class CreateTempTableUsingAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -110,8 +111,9 @@ case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent.toSeq, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 9d3d35692ffcc..16c9138419fa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -140,12 +140,12 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableName, _, _, partitionColumns, mode, _, query) => + case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(Seq(tableName))) { + if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent.toSeq)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { + EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). 
case l @ LogicalRelation(dest: BaseRelation) => @@ -155,7 +155,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableName that is also being read from.") + s"Cannot overwrite table $tableIdent that is also being read from.") } else { // OK } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 141468ca00d67..da50aec17c89e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1644,4 +1644,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select count(num) from 1one"), Row(10)) } } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 17762649fd70d..17cc83087fb1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} @@ -189,6 +189,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options // into the isolated client loader val metadataConf = new HiveConf() + + val defaltWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("defalt warehouse location is " + defaltWarehouseLocation) + // `configure` goes second to override other settings. 
val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure @@ -288,12 +292,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - catalog.invalidateTable(catalog.client.currentDatabase, tableName) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.invalidateTable(tableIdent) } /** @@ -307,7 +312,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { case relation: MetastoreRelation => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6770462bb0ad3..bbe8c1911bf86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -174,10 +174,13 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) + invalidateTable(tableIdent) } - def invalidateTable(databaseName: String, tableName: String): Unit = { + def invalidateTable(tableIdent: TableIdentifier): Unit = { + val databaseName = tableIdent.database.getOrElse(client.currentDatabase) + val tableName = tableIdent.table + cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) } @@ -187,6 +190,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * Creates a data source table (a table created with USING clause) in Hive's metastore. * Returns true when the table has been created. Otherwise, false. */ + // TODO: Remove this in SPARK-10104. 
def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -203,7 +207,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive isExternal) } - private def createDataSourceTable( + def createDataSourceTable( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], @@ -371,10 +375,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { + hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + } + + def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) + val database = tableIdent.database.getOrElse(client.currentDatabase) + new Path( - new Path(client.getDatabase(client.currentDatabase).location), - tableName.toLowerCase).toString + new Path(client.getDatabase(database).location), + tableIdent.table.toLowerCase).toString } def tableExists(tableIdentifier: Seq[String]): Boolean = { @@ -635,7 +645,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - desc.name, + TableIdentifier(desc.name), hive.conf.defaultDataSourceName, temporary = false, Array.empty[String], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cd6cd322c94ed..d38ad9127327d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -83,14 +83,16 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateTableUsing( - tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => - ExecutedCommand( + tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => + val cmd = CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) + ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) => + case CreateTableUsingAsSelect( + tableIdent, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query) + CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 05a78930afe3d..d1699dd536817 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser} import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.Attribute import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -120,9 +122,10 @@ case class AddFile(path: String) extends RunnableCommand { } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSource( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], @@ -130,9 +133,24 @@ case class CreateMetastoreDataSource( managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (hiveContext.catalog.tableExists(tableIdent.toSeq)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -144,13 +162,13 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, userSpecifiedSchema, Array.empty[String], provider, @@ -161,9 +179,10 @@ case class CreateMetastoreDataSource( } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSourceAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -171,19 +190,34 @@ case class CreateMetastoreDataSourceAsSelect( query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. 
Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(Seq(tableName))) { + if (sqlContext.catalog.tableExists(tableIdent.toSeq)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -200,7 +234,7 @@ case class CreateMetastoreDataSourceAsSelect( val resolved = ResolvedDataSource( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => if (l.relation != createdRelation.relation) { val errorDescription = @@ -249,7 +283,7 @@ case class CreateMetastoreDataSourceAsSelect( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, Some(resolved.relation.schema), partitionColumns, provider, @@ -258,7 +292,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. - hiveContext.refreshTable(tableName) + hiveContext.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1c15997ea8e6d..d3388a9429e41 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -34,7 +34,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") @@ -42,7 +41,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def afterAll(): Unit = { catalog.unregisterTable(Seq("ListTablesSuiteTable")) - catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") @@ -55,7 +53,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), Row("hivelisttablessuitetable", false)) @@ -69,9 +66,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - checkAnswer( - allTables.filter("tableName = 'indblisttablessuitetable'"), - Row("indblisttablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 417e8b07917cc..997c667ec0d1b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -19,14 +19,22 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val _sqlContext: SQLContext = TestHive + override val _sqlContext: HiveContext = TestHive private val sqlContext = _sqlContext private val df = sqlContext.range(10).coalesce(1) + private def checkTablePath(dbName: String, tableName: String): Unit = { + // val hiveContext = sqlContext.asInstanceOf[HiveContext] + val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName) + val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName + + assert(metastoreTable.serdeProperties("path") === expectedPath) + } + test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => activateDatabase(db) { @@ -37,6 +45,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") } } @@ -45,6 +55,58 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + 
test(s"createExternalTable() to non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + + sqlContext.createExternalTable("t", path, "parquet") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table("t"), df) + + sql( + s""" + |CREATE TABLE t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table("t1"), df) + } + } + } + } + + test(s"createExternalTable() to non-default database - without USE") { + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + sqlContext.createExternalTable(s"$db.t", path, "parquet") + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + sql( + s""" + |CREATE TABLE $db.t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table(s"$db.t1"), df) + } } } @@ -59,6 +121,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -68,6 +132,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -130,7 +196,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { } } - test("Refreshes a table in a non-default database") { + test("Refreshes a table in a non-default database - with USE") { import org.apache.spark.sql.functions.lit withTempDatabase { db => @@ -151,8 +217,94 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql("ALTER TABLE t ADD PARTITION (p=2)") + sqlContext.refreshTable("t") + checkAnswer( + sqlContext.table("t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) } } } } + + test("Refreshes a table in a non-default database - without USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + sql( + s"""CREATE EXTERNAL TABLE $db.t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") + sql(s"REFRESH TABLE $db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") + sqlContext.refreshTable(s"$db.t") + checkAnswer( + sqlContext.table(s"$db.t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + + test("invalid database name and table names") { + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`t:a`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + 
df.write.format("parquet").saveAsTable("`d:b`.`table`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`t:a` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`table` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8b8f520776e70..55ecbd5b5f21d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1138,4 +1138,39 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(CalendarInterval.fromString( "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } }