Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add 'create_table_if_not_exist' params #352

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion project/SparkRedshiftBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ object SparkRedshiftBuild extends Build {
organization := "com.databricks",
scalaVersion := "2.11.7",
crossScalaVersions := Seq("2.10.5", "2.11.7"),
sparkVersion := "2.0.0",
sparkVersion := "2.1.1",
//sparkVersion := "2.0.0",
testSparkVersion := sys.props.get("spark.testVersion").getOrElse(sparkVersion.value),
testSparkAvroVersion := sys.props.get("sparkAvro.testVersion").getOrElse("3.0.0"),
testHadoopVersion := sys.props.get("hadoop.testVersion").getOrElse("2.2.0"),
Expand Down
8 changes: 8 additions & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sbtassembly.ShadeRule

addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")

addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.5")
Expand All @@ -17,3 +19,9 @@ addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.0")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")

libraryDependencies += "org.apache.maven" % "maven-artifact" % "3.3.9"

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.5")

//assemblyShadeRules in assembly := Seq(
// ShadeRule.rename("com.databricks.spark.redshift.**" -> "edmunds.@1").inProject
//)
30 changes: 26 additions & 4 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ private[redshift] object Parameters {
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
"preactions" -> ";",
"postactions" -> ";"
"postactions" -> ";",
"create_table_if_not_exist" -> "true"
)

val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
Expand All @@ -47,8 +48,9 @@ private[redshift] object Parameters {
* Merge user parameters with the defaults, preferring user parameters if specified
*/
def mergeParameters(userParameters: Map[String, String]): MergedParameters = {
if (!userParameters.contains("tempdir")) {
throw new IllegalArgumentException("'tempdir' is required for all Redshift loads and saves")
if (!userParameters.contains("tempdir") && !userParameters.contains("permdir")) {
throw new IllegalArgumentException("'tempdir' or 'permdir' is required "
+ "for all Redshift loads and saves")
}
if (userParameters.contains("tempformat") &&
!VALID_TEMP_FORMATS.contains(userParameters("tempformat").toUpperCase)) {
Expand Down Expand Up @@ -121,8 +123,15 @@ private[redshift] object Parameters {

/**
 * Returns the directory to use for a single query's intermediate files.
 *
 * When a permanent directory is configured via the 'permdir' parameter, it is
 * returned as-is so the unloaded/copied files persist after the query; otherwise
 * a per-query subdirectory with a random UUID is created under [[rootTempDir]].
 */
def createPerQueryTempDir(): String =
  if (permDir.nonEmpty) permDir else Utils.makeTempPath(rootTempDir)

/**
* The Redshift table to be used as the target when loading or writing data.
Expand Down Expand Up @@ -285,5 +294,18 @@ private[redshift] object Parameters {
new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
}
}


/**
 * Whether the writer should issue a CREATE TABLE statement before loading data,
 * controlled by the 'create_table_if_not_exist' parameter (defaults to "true").
 * Throws IllegalArgumentException if the parameter value is not a valid boolean.
 */
def createTableIfNotExist: Boolean = {
  val rawFlag = parameters("create_table_if_not_exist")
  rawFlag.toBoolean
}

/**
 * Optional permanent directory for intermediate load/unload files, set via the
 * 'permdir' parameter. Unlike 'tempdir', files written here are intended to
 * outlive the query (see createPerQueryTempDir).
 *
 * @return the configured 'permdir' value, or the empty string when not set
 */
def permDir: String = parameters.getOrElse("permdir", "")

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,12 @@ private[redshift] class RedshiftWriter(
creds: AWSCredentialsProvider,
manifestUrl: Option[String]): Unit = {

// If the table doesn't exist, we need to create it first, using JDBC to infer column types
val createStatement = createTableSql(data, params)
log.info(createStatement)
jdbcWrapper.executeInterruptibly(conn.prepareStatement(createStatement))
if (params.createTableIfNotExist) {
// If the table doesn't exist, we need to create it first, using JDBC to infer column types
val createStatement = createTableSql(data, params)
log.info(createStatement)
jdbcWrapper.executeInterruptibly(conn.prepareStatement(createStatement))
}

val preActions = commentActions(params.description, data.schema) ++ params.preActions

Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/com/databricks/spark/redshift/TableName.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ private[redshift] case class TableName(unescapedSchemaName: String, unescapedTab
// Wraps an identifier in double quotes, doubling any embedded double quotes,
// per standard SQL delimited-identifier escaping rules.
private def quote(str: String) = '"' + str.replace("\"", "\"\"") + '"'
def escapedSchemaName: String = quote(unescapedSchemaName)
def escapedTableName: String = quote(unescapedTableName)
/**
 * Renders the table reference for use in SQL. An empty schema name means the
 * table was given unqualified, so only the table name is emitted — otherwise
 * we would produce a bogus leading `""."` qualifier.
 */
override def toString: String =
  if (unescapedSchemaName.isEmpty) escapedTableName
  else s"$escapedSchemaName.$escapedTableName"
}

private[redshift] object TableName {
Expand All @@ -39,7 +42,7 @@ private[redshift] object TableName {
def unescapeQuotes(s: String) = s.replace("\"\"", "\"")
def unescape(s: String) = unescapeQuotes(dropOuterQuotes(s))
splitByDots(str) match {
case Seq(tableName) => TableName("PUBLIC", unescape(tableName))
case Seq(tableName) => TableName("", unescape(tableName)) // jchoi "PUBLIC"
case Seq(schemaName, tableName) => TableName(unescape(schemaName), unescape(tableName))
case other => throw new IllegalArgumentException(s"Could not parse table name from '$str'")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import org.scalatest.FunSuite
class TableNameSuite extends FunSuite {
test("TableName.parseFromEscaped") {
assert(TableName.parseFromEscaped("foo.bar") === TableName("foo", "bar"))
assert(TableName.parseFromEscaped("foo") === TableName("PUBLIC", "foo"))
assert(TableName.parseFromEscaped("\"foo\"") === TableName("PUBLIC", "foo"))
assert(TableName.parseFromEscaped("foo") === TableName("", "foo")) // PUBLIC jchoi
assert(TableName.parseFromEscaped("\"foo\"") === TableName("", "foo")) // PUBLIC jchoi
assert(TableName.parseFromEscaped("\"\"\"foo\"\"\".bar") === TableName("\"foo\"", "bar"))
// Dots (.) can also appear inside of valid identifiers.
assert(TableName.parseFromEscaped("\"foo.bar\".baz") === TableName("foo.bar", "baz"))
Expand Down
2 changes: 1 addition & 1 deletion version.sbt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version in ThisBuild := "3.0.0-SNAPSHOT"
version in ThisBuild := "3.0.0-edmunds-SNAPSHOT"