Enabling CSE encryption for COPY command #377

Open · wants to merge 4 commits into base: master

4 changes: 4 additions & 0 deletions README.md
@@ -428,7 +428,11 @@ The following describes how each connection can be authenticated:
To use this capability, you should configure your Hadoop S3 FileSystem to use encryption by
setting the appropriate configuration properties (which will vary depending on whether you
are using `s3a`, `s3n`, EMRFS, etc.).

Note that the `MANIFEST` file (a list of all files written) will not be encrypted.

To use client-side encryption (CSE) with S3 during COPY, generate your own encryption key and pass it to the writer using
```.option("extracopyoptions", s"encrypted master_symmetric_key '$encodedSymmetricKey'")```


### Parameters
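For context, a minimal end-to-end sketch of a write that uses the option shown in the README change above. The JDBC URL, table name, tempdir, and key value are placeholders; `fs.s3.cse.enabled` is the EMRFS property this change relies on, and any key-provider properties for your cluster are assumed to be configured separately:

```scala
import org.apache.spark.sql.{DataFrame, SaveMode}

// Sketch only: placeholder connection settings and a base64-encoded symmetric key.
def writeWithCse(df: DataFrame, encodedSymmetricKey: String): Unit = {
  // Enable EMRFS client-side encryption for the data files written to tempdir.
  df.sqlContext.sparkContext.hadoopConfiguration.set("fs.s3.cse.enabled", "true")

  df.write
    .format("com.databricks.spark.redshift")
    .option("url", "jdbc:redshift://host:5439/db?user=user&password=pass") // placeholder
    .option("dbtable", "my_table")                                         // placeholder
    .option("tempdir", "s3://my-bucket/tmp/")                              // placeholder
    .option("forward_spark_s3_credentials", "true")
    // Ask Redshift to decrypt the staged files with the same key during COPY.
    .option("extracopyoptions", s"encrypted master_symmetric_key '$encodedSymmetricKey'")
    .mode(SaveMode.Append)
    .save()
}
```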
src/main/scala/com/databricks/spark/redshift/AWSCredentialsUtils.scala
@@ -28,6 +28,7 @@ private[redshift] object AWSCredentialsUtils {
/**
* Generates a credentials string for use in Redshift COPY and UNLOAD statements.
* Favors a configured `aws_iam_role` if available in the parameters.
* http://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html
*/
def getRedshiftCredentialsString(
params: MergedParameters,
@@ -36,15 +37,14 @@
def awsCredsToString(credentials: AWSCredentials): String = {
credentials match {
case creds: AWSSessionCredentials =>
s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}"
s"access_key_id '${creds.getAWSAccessKeyId}' secret_access_key '${creds.getAWSSecretKey}' " +
s"session_token '${creds.getSessionToken}'"
case creds =>
s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
s"aws_secret_access_key=${creds.getAWSSecretKey}"
s"access_key_id '${creds.getAWSAccessKeyId}' secret_access_key '${creds.getAWSSecretKey}'"
}
}
if (params.iamRole.isDefined) {
s"aws_iam_role=${params.iamRole.get}"
s"iam_role '${params.iamRole.get}'"
} else if (params.temporaryAWSCredentials.isDefined) {
awsCredsToString(params.temporaryAWSCredentials.get.getCredentials)
} else if (params.forwardSparkS3Credentials) {
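For reference, a small illustration (not part of the diff) of the clause shapes getRedshiftCredentialsString now produces; the values are placeholders, and the exact strings are asserted in AWSCredentialsUtilsSuite below:

```scala
// Illustrative only: the three credential clauses, which are spliced directly into
// COPY/UNLOAD text instead of being wrapped in WITH CREDENTIALS '...'.
val iamRoleClause = "iam_role 'arn:aws:iam::123456789000:role/redshift_iam_role'"
val staticKeyClause = "access_key_id 'ACCESSKEYID' secret_access_key 'SECRET/KEY'"
val sessionKeyClause =
  "access_key_id 'ACCESSKEYID' secret_access_key 'SECRET/KEY' session_token 'SESSION/Token'"
```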
src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -193,7 +193,7 @@ private[redshift] case class RedshiftRelation(
// the credentials passed via `credsString`.
val fixedUrl = Utils.fixS3Url(Utils.removeCredentialsFromURI(new URI(tempDir)).toString)

s"UNLOAD ('$query') TO '$fixedUrl' WITH CREDENTIALS '$credsString' ESCAPE MANIFEST"
s"UNLOAD ('$query') TO '$fixedUrl' $credsString ESCAPE MANIFEST"
}

private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
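As an illustration of the change above, the UNLOAD text now embeds the credentials clause directly rather than wrapping it in WITH CREDENTIALS. A sketch with placeholder bucket and keys (compare the expectations in RedshiftSourceSuite below):

```scala
// Sketch of the resulting statement shape, with placeholder values.
val query = "SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\""
val credsString = "access_key_id 'test1' secret_access_key 'test2'"
val unloadSql = s"UNLOAD ('$query') TO 's3://my-bucket/tmp/unload/' $credsString ESCAPE MANIFEST"
```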
32 changes: 26 additions & 6 deletions src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -22,14 +22,13 @@ import java.sql.{Connection, Date, SQLException, Timestamp}
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.s3.AmazonS3Client
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.TaskContext
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.util.control.NonFatal

import com.databricks.spark.redshift.Parameters.MergedParameters

import org.apache.hadoop.conf.Configuration
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
import org.apache.spark.sql.types._
@@ -63,6 +62,9 @@

private val log = LoggerFactory.getLogger(getClass)

// http://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-cluster-configuration-object-encryption.html
private val EMRFS_CLIENT_SIDE_ENCRYPTION_KEY: String = "fs.s3.cse.enabled"

/**
* Generate CREATE TABLE statement for Redshift
*/
@@ -98,7 +100,7 @@
case "AVRO" => "AVRO 'auto'"
case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
}
s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
s"COPY ${params.table.get} FROM '$fixedUrl' $credsString FORMAT AS " +
s"${format} manifest ${params.extraCopyOptions}"
}
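To show how extracopyoptions composes with the new credentials handling, the final COPY text for a client-side-encrypted load would look roughly like the following sketch (placeholder table, manifest URL, keys, and key material):

```scala
// Illustrative only: shape of the COPY statement issued for a CSE-encrypted load.
val credsString = "access_key_id 'ACCESSKEYID' secret_access_key 'SECRET/KEY'" // placeholder keys
val copySql =
  s"COPY \"PUBLIC\".\"test_table\" FROM 's3://my-bucket/tmp/manifest.json' $credsString " +
    "FORMAT AS AVRO 'auto' manifest " +
    "encrypted master_symmetric_key 'BASE64-ENCODED-KEY'" // supplied via extracopyoptions
```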

@@ -295,19 +297,36 @@
}).save(tempDir)

if (nonEmptyPartitions.value.isEmpty) {
log.info("Did not write any records. Not creating a manifest file.")
None
} else {
// See https://docs.aws.amazon.com/redshift/latest/dg/loading-data-files-using-manifest.html
// for a description of the manifest file format. The URLs in this manifest must be absolute
// and complete.

// The manifest file must be stored in plain text, even if the contents of the files being loaded
// into Redshift are encrypted. To make this work, we create a separate FileSystem without any
// encryption options set.

// Clone existing configuration set by the user
val plainTextHadoopConfig = new Configuration(sqlContext.sparkContext.hadoopConfiguration)

// Make sure that we don't get a cached FileSystem that still has client-side encryption enabled.
// The FileSystem class caches instances keyed by the scheme and authority of the path's URI, so
// even if we pass a completely different Configuration to FileSystem.get, we could otherwise
// receive a cached instance created with different properties.
plainTextHadoopConfig.set("fs.s3.impl.disable.cache", "true")

// Turn off any CSE if it's set
plainTextHadoopConfig.set(EMRFS_CLIENT_SIDE_ENCRYPTION_KEY, "false")

// The partition filenames are of the form part-r-XXXXX-UUID.fileExtension.
- val fs = FileSystem.get(URI.create(tempDir), sqlContext.sparkContext.hadoopConfiguration)
+ val fs = FileSystem.get(URI.create(tempDir), plainTextHadoopConfig)
val partitionIdRegex = "^part-(?:r-)?(\\d+)[^\\d+].*$".r
val filesToLoad: Seq[String] = {
val nonEmptyPartitionIds = nonEmptyPartitions.value.toSet
fs.listStatus(new Path(tempDir)).map(_.getPath.getName).collect {
- case file @ partitionIdRegex(id) if nonEmptyPartitionIds.contains(id.toInt) => file
+ case file@partitionIdRegex(id) if nonEmptyPartitionIds.contains(id.toInt) => file
}
}
// It's possible that tempDir contains AWS access keys. We shouldn't save those credentials to
@@ -325,6 +344,7 @@
} finally {
fsDataOut.close()
}

Some(manifestPath)
}
}
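For completeness, the manifest written here is plain JSON in the standard Redshift manifest format, which is why it must remain unencrypted and readable by COPY. A simplified sketch of its contents, with placeholder bucket and file names:

```scala
// Simplified illustration of the manifest contents (see the manifest documentation linked above).
// The URLs must be absolute, and the file itself is written without client-side encryption.
val manifestJson =
  """{"entries": [
    |  {"url": "s3://my-bucket/tmp/part-r-00000-uuid.avro", "mandatory": true},
    |  {"url": "s3://my-bucket/tmp/part-r-00001-uuid.avro", "mandatory": true}
    |]}""".stripMargin
```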
src/test/scala/com/databricks/spark/redshift/AWSCredentialsUtilsSuite.scala
@@ -42,7 +42,7 @@ class AWSCredentialsUtilsSuite extends FunSuite {
val params =
Parameters.mergeParameters(baseParams ++ Map("forward_spark_s3_credentials" -> "true"))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) ===
"aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY/WITH/SLASHES")
"access_key_id 'ACCESSKEYID' secret_access_key 'SECRET/KEY/WITH/SLASHES'")
}

test("credentialsString with STS temporary keys") {
@@ -51,15 +51,15 @@
"temporary_aws_secret_access_key" -> "SECRET/KEY",
"temporary_aws_session_token" -> "SESSION/Token"))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
"aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY;token=SESSION/Token")
"access_key_id 'ACCESSKEYID' secret_access_key 'SECRET/KEY' session_token 'SESSION/Token'")
}

test("Configured IAM roles should take precedence") {
val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token")
val iamRole = "arn:aws:iam::123456789000:role/redshift_iam_role"
val params = Parameters.mergeParameters(baseParams ++ Map("aws_iam_role" -> iamRole))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
s"aws_iam_role=$iamRole")
s"iam_role '$iamRole'")
}

test("AWSCredentials.load() STS temporary keys should take precedence") {
src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -159,7 +159,7 @@ class RedshiftSourceSuite
"\"testtimestamp\" " +
"FROM \"PUBLIC\".\"test_table\" '\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"access_key_id 'test1' secret_access_key 'test2' " +
"ESCAPE").r
val mockRedshift = new MockRedshift(
defaultParams("url"),
@@ -230,7 +230,7 @@
val expectedQuery = (
"UNLOAD \\('SELECT \"testbyte\", \"testbool\" FROM \"PUBLIC\".\"test_table\" '\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"access_key_id 'test1' secret_access_key 'test2' " +
"ESCAPE").r
val mockRedshift =
new MockRedshift(defaultParams("url"), Map("test_table" -> TestUtils.testSchema))
@@ -270,7 +270,7 @@
"AND \"testfloat\" >= 1.0 " +
"AND \"testint\" <= 43'\\) " +
"TO '.*' " +
"WITH CREDENTIALS 'aws_access_key_id=test1;aws_secret_access_key=test2' " +
"access_key_id 'test1' secret_access_key 'test2' " +
"ESCAPE").r
// scalastyle:on
val mockRedshift = new MockRedshift(