diff --git a/project/SparkRedshiftBuild.scala b/project/SparkRedshiftBuild.scala index 1a5301f9..e5b827e6 100644 --- a/project/SparkRedshiftBuild.scala +++ b/project/SparkRedshiftBuild.scala @@ -47,7 +47,8 @@ object SparkRedshiftBuild extends Build { organization := "com.databricks", scalaVersion := "2.11.7", crossScalaVersions := Seq("2.10.5", "2.11.7"), - sparkVersion := "2.0.0", + sparkVersion := "2.1.1", + //sparkVersion := "2.0.0", testSparkVersion := sys.props.get("spark.testVersion").getOrElse(sparkVersion.value), testSparkAvroVersion := sys.props.get("sparkAvro.testVersion").getOrElse("3.0.0"), testHadoopVersion := sys.props.get("hadoop.testVersion").getOrElse("2.2.0"), diff --git a/project/plugins.sbt b/project/plugins.sbt index 3ee88f7d..a4f9d989 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,5 @@ +import sbtassembly.ShadeRule + addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.5") @@ -17,3 +19,9 @@ addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0") addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") libraryDependencies += "org.apache.maven" % "maven-artifact" % "3.3.9" + +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.5") + +//assemblyShadeRules in assembly := Seq( +// ShadeRule.rename("com.databricks.spark.redshift.**" -> "edmunds.@1").inProject +//) diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala index 875f5b75..f188524c 100644 --- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala +++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala @@ -38,7 +38,8 @@ private[redshift] object Parameters { "diststyle" -> "EVEN", "usestagingtable" -> "true", "preactions" -> ";", - "postactions" -> ";" + "postactions" -> ";", + "create_table_if_not_exist" -> "true" ) val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP") @@ -47,8 +48,9 @@ 
private[redshift] object Parameters { * Merge user parameters with the defaults, preferring user parameters if specified */ def mergeParameters(userParameters: Map[String, String]): MergedParameters = { - if (!userParameters.contains("tempdir")) { - throw new IllegalArgumentException("'tempdir' is required for all Redshift loads and saves") + if (!userParameters.contains("tempdir") && !userParameters.contains("permdir")) { + throw new IllegalArgumentException("'tempdir' or 'permdir' is required " + + "for all Redshift loads and saves") } if (userParameters.contains("tempformat") && !VALID_TEMP_FORMATS.contains(userParameters("tempformat").toUpperCase)) { @@ -121,8 +123,15 @@ private[redshift] object Parameters { /** * Creates a per-query subdirectory in the [[rootTempDir]], with a random UUID. + * When the "permdir" option is set, that directory is returned verbatim instead. */ - def createPerQueryTempDir(): String = Utils.makeTempPath(rootTempDir) + def createPerQueryTempDir(): String = { + if (!permDir.isEmpty){ + return permDir + } + + Utils.makeTempPath(rootTempDir) + } /** * The Redshift table to be used as the target when loading or writing data.
@@ -285,5 +294,18 @@ private[redshift] object Parameters { new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken)) } } + + + /** + * Whether to issue a CREATE TABLE statement before loading ("create_table_if_not_exist" option, default "true") + */ + def createTableIfNotExist: Boolean = parameters("create_table_if_not_exist").toBoolean + + /** + * Optional permanent output directory used in place of a per-query temp dir. + * @return the "permdir" parameter value, or the empty string when unset + */ + def permDir: String = parameters.getOrElse("permdir", "") + } } diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala index 8383231d..a0c9514f 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala @@ -124,10 +124,12 @@ private[redshift] class RedshiftWriter( creds: AWSCredentialsProvider, manifestUrl: Option[String]): Unit = { - // If the table doesn't exist, we need to create it first, using JDBC to infer column types - val createStatement = createTableSql(data, params) - log.info(createStatement) - jdbcWrapper.executeInterruptibly(conn.prepareStatement(createStatement)) + if (params.createTableIfNotExist) { + // If the table doesn't exist, we need to create it first, using JDBC to infer column types + val createStatement = createTableSql(data, params) + log.info(createStatement) + jdbcWrapper.executeInterruptibly(conn.prepareStatement(createStatement)) + } val preActions = commentActions(params.description, data.schema) ++ params.preActions diff --git a/src/main/scala/com/databricks/spark/redshift/TableName.scala b/src/main/scala/com/databricks/spark/redshift/TableName.scala index d4a3d12e..bc11d986 100644 --- a/src/main/scala/com/databricks/spark/redshift/TableName.scala +++ b/src/main/scala/com/databricks/spark/redshift/TableName.scala @@ -25,7 +25,10 @@ private[redshift] case class TableName(unescapedSchemaName: String, unescapedTab private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"' def escapedSchemaName: String = quote(unescapedSchemaName) def
escapedTableName: String = quote(unescapedTableName) - override def toString: String = s"$escapedSchemaName.$escapedTableName" + + // an empty schema name quotes to "" (length 2), so omit the schema qualifier in that case + override def toString: String = if (escapedSchemaName.length == 2) escapedTableName + else s"$escapedSchemaName.$escapedTableName" } private[redshift] object TableName { @@ -39,7 +42,7 @@ private[redshift] object TableName { def unescapeQuotes(s: String) = s.replace("\"\"", "\"") def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s)) splitByDots(str) match { - case Seq(tableName) => TableName("PUBLIC", unescape(tableName)) + case Seq(tableName) => TableName("", unescape(tableName)) // schema left empty (was "PUBLIC") case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName)) case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'") } diff --git a/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala b/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala index 24c935f3..68939adc 100644 --- a/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala +++ b/src/test/scala/com/databricks/spark/redshift/TableNameSuite.scala @@ -21,8 +21,8 @@ import org.scalatest.FunSuite class TableNameSuite extends FunSuite { test("TableName.parseFromEscaped") { assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar")) - assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo")) - assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo")) + assert(TableName.parseFromEscaped("foo") === TableName("", "foo")) // schema was "PUBLIC" + assert(TableName.parseFromEscaped("\"foo\"") === TableName("", "foo")) // schema was "PUBLIC" assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar")) // Dots (.) can also appear inside of valid identifiers.
assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz")) diff --git a/version.sbt b/version.sbt index 4a2422e0..11809522 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "3.0.0-SNAPSHOT" \ No newline at end of file +version in ThisBuild := "3.0.0-edmunds-SNAPSHOT" \ No newline at end of file