Commit: snowflake-new-intfs

gisripa committed May 28, 2024
1 parent 3e36434, commit f59cd09
Showing 16 changed files with 857 additions and 286 deletions.
@@ -5,7 +5,7 @@ plugins {
airbyteJavaConnector {
cdkVersionRequired = '0.35.13'
features = ['db-destinations', 's3-destinations', 'typing-deduping']
useLocalCdk = false
useLocalCdk = true
}

java {
@@ -50,4 +50,6 @@ integrationTestJava {
dependencies {
implementation 'net.snowflake:snowflake-jdbc:3.14.1'
implementation 'org.apache.commons:commons-text:1.10.0'

testImplementation "io.mockk:mockk:1.13.11"
}
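
Aside (not part of the diff): the new testImplementation "io.mockk:mockk:1.13.11" line brings in MockK for Kotlin-friendly mocking in the unit tests added by this commit. A minimal sketch of how such a test can look; the interface, class, and test names below are illustrative and not taken from the commit:

    import io.mockk.every
    import io.mockk.mockk
    import io.mockk.verify
    import org.junit.jupiter.api.Assertions.assertEquals
    import org.junit.jupiter.api.Test

    // Purely illustrative collaborator; the real tests mock connector/CDK classes instead.
    interface QueryRunner {
        fun runQuery(sql: String): Int
    }

    class QueryRunnerTest {
        @Test
        fun `returns the row count reported by the runner`() {
            val runner = mockk<QueryRunner>()
            every { runner.runQuery("SELECT COUNT(*) FROM users") } returns 42

            val rows = runner.runQuery("SELECT COUNT(*) FROM users")

            assertEquals(42, rows)
            verify(exactly = 1) { runner.runQuery("SELECT COUNT(*) FROM users") }
        }
    }
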
[next changed file]
@@ -8,7 +8,9 @@ import com.zaxxer.hikari.HikariDataSource
import io.airbyte.cdk.db.jdbc.DefaultJdbcDatabase
import io.airbyte.cdk.db.jdbc.JdbcDatabase
import io.airbyte.cdk.db.jdbc.JdbcUtils
import io.airbyte.commons.exceptions.ConfigErrorException
import io.airbyte.commons.json.Jsons.deserialize
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType
import java.io.IOException
import java.io.PrintWriter
import java.net.URI
@@ -23,12 +25,13 @@ import java.util.*
import java.util.concurrent.TimeUnit
import java.util.stream.Collectors
import javax.sql.DataSource
import net.snowflake.client.jdbc.SnowflakeSQLException
import org.slf4j.Logger
import org.slf4j.LoggerFactory

/** SnowflakeDatabase contains helpers to create connections to and run queries on Snowflake. */
object SnowflakeDatabase {
private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeDatabase::class.java)
object SnowflakeDatabaseUtils {
private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeDatabaseUtils::class.java)
private const val PAUSE_BETWEEN_TOKEN_REFRESH_MIN =
7 // snowflake access token TTL is 10min and can't be modified

Expand All @@ -42,14 +45,20 @@ object SnowflakeDatabase {
.version(HttpClient.Version.HTTP_2)
.connectTimeout(Duration.ofSeconds(10))
.build()
const val PRIVATE_KEY_FILE_NAME: String = "rsa_key.p8"
const val PRIVATE_KEY_FIELD_NAME: String = "private_key"
const val PRIVATE_KEY_PASSWORD: String = "private_key_password"
private const val PRIVATE_KEY_FILE_NAME: String = "rsa_key.p8"
private const val PRIVATE_KEY_FIELD_NAME: String = "private_key"
private const val PRIVATE_KEY_PASSWORD: String = "private_key_password"
private const val CONNECTION_STRING_IDENTIFIER_KEY = "application"
private const val CONNECTION_STRING_IDENTIFIER_VAL = "Airbyte_Connector"

// This is an unfortunately fragile way to capture the errors, but Snowflake doesn't
// provide a more specific permission exception error code
private const val NO_PRIVILEGES_ERROR_MESSAGE = "but current role has no privileges on it"
private const val IP_NOT_IN_WHITE_LIST_ERR_MSG = "not allowed to access Snowflake"

@JvmStatic
fun createDataSource(config: JsonNode, airbyteEnvironment: String?): HikariDataSource {

val dataSource = HikariDataSource()

val jdbcUrl =
@@ -243,4 +252,45 @@ object SnowflakeDatabase {
}
}
}

fun checkForKnownConfigExceptions(e: Exception?): Optional<ConfigErrorException> {
if (e is SnowflakeSQLException && e.message!!.contains(NO_PRIVILEGES_ERROR_MESSAGE)) {
return Optional.of(
ConfigErrorException(
"Encountered Error with Snowflake Configuration: Current role does not have permissions on the target schema please verify your privileges",
e
)
)
}
if (e is SnowflakeSQLException && e.message!!.contains(IP_NOT_IN_WHITE_LIST_ERR_MSG)) {
return Optional.of(
ConfigErrorException(
"""
Snowflake has blocked access from Airbyte IP address. Please make sure that your Snowflake user account's
network policy allows access from all Airbyte IP addresses. See this page for the list of Airbyte IPs:
https://docs.airbyte.com/cloud/getting-started-with-airbyte-cloud#allowlist-ip-addresses and this page
for documentation on Snowflake network policies: https://docs.snowflake.com/en/user-guide/network-policies
""".trimIndent(),
e
)
)
}
return Optional.empty()
}
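
Aside (not part of the diff): moving this check into SnowflakeDatabaseUtils lets the legacy SnowflakeSqlOperations and the new storage/staging path share one translation of known Snowflake failures. A hedged call-site sketch, assuming JdbcDatabase.execute(String); the wrapper name is purely illustrative:

    import io.airbyte.cdk.db.jdbc.JdbcDatabase

    // Translate recognized Snowflake failures into ConfigErrorException; rethrow everything else.
    fun executeOrRaiseConfigError(database: JdbcDatabase, sql: String) {
        try {
            database.execute(sql)
        } catch (e: Exception) {
            SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).ifPresent { throw it }
            throw e
        }
    }
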

fun toSqlTypeName(airbyteProtocolType: AirbyteProtocolType): String {
return when (airbyteProtocolType) {
AirbyteProtocolType.STRING -> "TEXT"
AirbyteProtocolType.NUMBER -> "FLOAT"
AirbyteProtocolType.INTEGER -> "NUMBER"
AirbyteProtocolType.BOOLEAN -> "BOOLEAN"
AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ"
AirbyteProtocolType.TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ"
AirbyteProtocolType.TIME_WITH_TIMEZONE -> "TEXT"
AirbyteProtocolType.TIME_WITHOUT_TIMEZONE -> "TIME"
AirbyteProtocolType.DATE -> "DATE"
AirbyteProtocolType.UNKNOWN -> "VARIANT"
}
}
}
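
Aside (not part of the diff): toSqlTypeName gives the SQL generator and storage operation a single source of truth for the Airbyte-to-Snowflake type mapping. A small illustrative use; columnDdl is a hypothetical helper, not part of the connector:

    import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType

    // Build a quoted column definition fragment from an Airbyte protocol type.
    fun columnDdl(name: String, type: AirbyteProtocolType): String =
        "\"$name\" ${SnowflakeDatabaseUtils.toSqlTypeName(type)}"

    // columnDdl("updated_at", AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE)
    //   -> "updated_at" TIMESTAMP_TZ
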
[next changed file]
@@ -17,19 +17,23 @@ import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer
import io.airbyte.cdk.integrations.base.TypingAndDedupingFlag.getRawNamespaceOverride
import io.airbyte.cdk.integrations.base.adaptive.AdaptiveDestinationRunner
import io.airbyte.cdk.integrations.destination.NamingConventionTransformer
import io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.jdbc.JdbcCheckOperations
import io.airbyte.cdk.integrations.destination.staging.StagingConsumerFactory.Companion.builder
import io.airbyte.cdk.integrations.destination.operation.SyncOperation
import io.airbyte.cdk.integrations.destination.s3.FileUploadFormat
import io.airbyte.cdk.integrations.destination.staging.operation.StagingStreamOperations
import io.airbyte.integrations.base.destination.operation.DefaultFlush
import io.airbyte.integrations.base.destination.operation.DefaultSyncOperation
import io.airbyte.integrations.base.destination.typing_deduping.CatalogParser
import io.airbyte.integrations.base.destination.typing_deduping.DefaultTyperDeduper
import io.airbyte.integrations.base.destination.typing_deduping.NoOpTyperDeduperWithV1V2Migrations
import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStatus
import io.airbyte.integrations.base.destination.typing_deduping.ParsedCatalog
import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduper
import io.airbyte.integrations.base.destination.typing_deduping.migrators.Migration
import io.airbyte.integrations.destination.snowflake.migrations.SnowflakeState
import io.airbyte.integrations.destination.snowflake.operation.SnowflakeStagingClient
import io.airbyte.integrations.destination.snowflake.operation.SnowflakeStorageOperation
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeDestinationHandler
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeSqlGenerator
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeV1V2Migrator
import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeV2TableMigrator
import io.airbyte.protocol.models.v0.AirbyteConnectionStatus
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
@@ -93,11 +97,11 @@ constructor(
}

private fun getDataSource(config: JsonNode): DataSource {
return SnowflakeDatabase.createDataSource(config, airbyteEnvironment)
return SnowflakeDatabaseUtils.createDataSource(config, airbyteEnvironment)
}

private fun getDatabase(dataSource: DataSource): JdbcDatabase {
return SnowflakeDatabase.getDatabase(dataSource)
return SnowflakeDatabaseUtils.getDatabase(dataSource)
}

override fun getSerializedMessageConsumer(
@@ -115,13 +119,10 @@ constructor(
}

val retentionPeriodDays =
SnowflakeSqlOperations.getRetentionPeriodDays(
config[SnowflakeSqlOperations.RETENTION_PERIOD_DAYS_CONFIG_KEY],
getRetentionPeriodDays(
config[RETENTION_PERIOD_DAYS],
)

val sqlGenerator = SnowflakeSqlGenerator(retentionPeriodDays)
val parsedCatalog: ParsedCatalog
val typerDeduper: TyperDeduper
val database = getDatabase(getDataSource(config))
val databaseName = config[JdbcUtils.DATABASE_KEY].asText()
val rawTableSchemaName: String
@@ -135,63 +136,51 @@ constructor(
}
val snowflakeDestinationHandler =
SnowflakeDestinationHandler(databaseName, database, rawTableSchemaName)
parsedCatalog = catalogParser.parseCatalog(catalog)
val migrator = SnowflakeV1V2Migrator(this.nameTransformer, database, databaseName)
val v2TableMigrator =
SnowflakeV2TableMigrator(
database,
databaseName,
sqlGenerator,
snowflakeDestinationHandler
)
val parsedCatalog: ParsedCatalog = catalogParser.parseCatalog(catalog)
val disableTypeDedupe =
config.has(DISABLE_TYPE_DEDUPE) && config[DISABLE_TYPE_DEDUPE].asBoolean(false)
val migrations = listOf<Migration<SnowflakeState>>()
typerDeduper =
if (disableTypeDedupe) {
NoOpTyperDeduperWithV1V2Migrations(
sqlGenerator,
snowflakeDestinationHandler,
parsedCatalog,
migrator,
v2TableMigrator,
migrations
)
} else {
DefaultTyperDeduper(
sqlGenerator,
snowflakeDestinationHandler,
parsedCatalog,
migrator,
v2TableMigrator,
migrations,
)
}

return builder(
outputRecordCollector,
database,
SnowflakeInternalStagingSqlOperations(nameTransformer),
nameTransformer,
config,
catalog,
true,
typerDeduper,
val snowflakeStagingClient = SnowflakeStagingClient(database)

val snowflakeStorageOperation =
SnowflakeStorageOperation(
sqlGenerator = sqlGenerator,
destinationHandler = snowflakeDestinationHandler,
retentionPeriodDays,
snowflakeStagingClient
)

val syncOperation: SyncOperation =
DefaultSyncOperation(
parsedCatalog,
snowflakeDestinationHandler,
defaultNamespace,
JavaBaseConstants.DestinationColumns.V2_WITHOUT_META,
{ initialStatus: DestinationInitialStatus<SnowflakeState>, disableTD ->
StagingStreamOperations(
snowflakeStorageOperation,
initialStatus,
FileUploadFormat.CSV,
JavaBaseConstants.DestinationColumns.V2_WITHOUT_META,
disableTD
)
},
migrations,
disableTypeDedupe
)
.setBufferMemoryLimit(Optional.of(snowflakeBufferMemoryLimit))
.setOptimalBatchSizeBytes(
// The per stream size limit is following recommendations from:
// https://docs.snowflake.com/en/user-guide/data-load-considerations-prepare.html#general-file-sizing-recommendations
// "To optimize the number of parallel operations for a load,
// we recommend aiming to produce data files roughly 100-250 MB (or larger) in size
// compressed."
(200 * 1024 * 1024).toLong(),
)
.build()
.createAsync()

return AsyncStreamConsumer(
outputRecordCollector = outputRecordCollector,
onStart = {},
onClose = { _, streamSyncSummaries ->
syncOperation.finalizeStreams(streamSyncSummaries)
SCHEDULED_EXECUTOR_SERVICE.shutdownNow()
},
onFlush = DefaultFlush(optimalFlushBatchSize, syncOperation),
catalog = catalog,
bufferManager = BufferManager(snowflakeBufferMemoryLimit),
defaultNamespace = Optional.of(defaultNamespace),
)
}

override val isV2Destination: Boolean
@@ -209,7 +198,7 @@ constructor(
companion object {
private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeDestination::class.java)
const val RAW_SCHEMA_OVERRIDE: String = "raw_data_schema"

const val RETENTION_PERIOD_DAYS: String = "retention_period_days"
const val DISABLE_TYPE_DEDUPE: String = "disable_type_dedupe"
@JvmField
val SCHEDULED_EXECUTOR_SERVICE: ScheduledExecutorService =
@@ -241,23 +230,43 @@ constructor(
}
}

fun getRetentionPeriodDays(node: JsonNode?): Int {
val retentionPeriodDays =
if (node == null || node.isNull) {
1
} else {
node.asInt()
}
return retentionPeriodDays
}
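
Aside (not part of the diff): getRetentionPeriodDays simply falls back to 1 when the retention_period_days key is absent or null. A couple of illustrative calls, using Jackson directly to build the config node:

    import com.fasterxml.jackson.databind.ObjectMapper

    val config = ObjectMapper().readTree("""{"retention_period_days": 7}""")
    val configured = SnowflakeDestination.getRetentionPeriodDays(config[SnowflakeDestination.RETENTION_PERIOD_DAYS])  // 7
    val defaulted = SnowflakeDestination.getRetentionPeriodDays(null)                                                  // 1
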

private val snowflakeBufferMemoryLimit: Long
get() = (Runtime.getRuntime().maxMemory() * 0.5).toLong()

// The per stream size limit is following recommendations from:
// https://docs.snowflake.com/en/user-guide/data-load-considerations-prepare.html#general-file-sizing-recommendations
// "To optimize the number of parallel operations for a load,
// we recommend aiming to produce data files roughly 100-250 MB (or larger) in size
// compressed."
private val optimalFlushBatchSize: Long
get() = (200 * 1024 * 1024).toLong()
}
}
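
Aside (not part of the diff): a back-of-the-envelope check of the constants above. Half the JVM heap bounds the async buffer, and the 200 * 1024 * 1024 flush batch works out to roughly 210 MB, inside Snowflake's recommended 100-250 MB (compressed) file size window:

    // Illustrative arithmetic only; mirrors the getters above.
    val flushBatchBytes: Long = 200L * 1024 * 1024                                   // 209_715_200 bytes
    val flushBatchMb: Double = flushBatchBytes / 1_000_000.0                         // ~209.7 MB
    val bufferLimitBytes: Long = (Runtime.getRuntime().maxMemory() * 0.5).toLong()   // half the max heap
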

fun main(args: Array<String>) {
IntegrationRunner.addOrphanedThreadFilter { t: Thread ->
for (stackTraceElement in IntegrationRunner.getThreadCreationInfo(t).stack) {
val stackClassName = stackTraceElement.className
val stackMethodName = stackTraceElement.methodName
if (
SFStatement::class.java.canonicalName == stackClassName &&
"close" == stackMethodName ||
SFSession::class.java.canonicalName == stackClassName &&
"callHeartBeatWithQueryTimeout" == stackMethodName
) {
return@addOrphanedThreadFilter false
if (IntegrationRunner.getThreadCreationInfo(t) != null) {
for (stackTraceElement in IntegrationRunner.getThreadCreationInfo(t)!!.stack) {
val stackClassName = stackTraceElement.className
val stackMethodName = stackTraceElement.methodName
if (
SFStatement::class.java.canonicalName == stackClassName &&
"close" == stackMethodName ||
SFSession::class.java.canonicalName == stackClassName &&
"callHeartBeatWithQueryTimeout" == stackMethodName
) {
return@addOrphanedThreadFilter false
}
}
}
true
@@ -277,5 +286,4 @@ fun main(args: Array<String>) {
)
}
.run(args)
SnowflakeDestination.SCHEDULED_EXECUTOR_SERVICE.shutdownNow()
}
[next changed file]
@@ -15,7 +15,6 @@ import io.airbyte.commons.exceptions.ConfigErrorException
import java.sql.SQLException
import java.util.*
import java.util.function.Consumer
import net.snowflake.client.jdbc.SnowflakeSQLException
import org.slf4j.Logger
import org.slf4j.LoggerFactory

@@ -143,29 +142,7 @@ open class SnowflakeSqlOperations : JdbcSqlOperations(), SqlOperations {
}

override fun checkForKnownConfigExceptions(e: Exception?): Optional<ConfigErrorException> {
if (e is SnowflakeSQLException && e.message!!.contains(NO_PRIVILEGES_ERROR_MESSAGE)) {
return Optional.of(
ConfigErrorException(
"Encountered Error with Snowflake Configuration: Current role does not have permissions on the target schema please verify your privileges",
e
)
)
}
if (e is SnowflakeSQLException && e.message!!.contains(IP_NOT_IN_WHITE_LIST_ERR_MSG)) {
return Optional.of(
ConfigErrorException(
"""
Snowflake has blocked access from Airbyte IP address. Please make sure that your Snowflake user account's
network policy allows access from all Airbyte IP addresses. See this page for the list of Airbyte IPs:
https://docs.airbyte.com/cloud/getting-started-with-airbyte-cloud#allowlist-ip-addresses and this page
for documentation on Snowflake network policies: https://docs.snowflake.com/en/user-guide/network-policies
""".trimIndent(),
e
)
)
}
return Optional.empty()
return SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e)
}

companion object {
@@ -174,11 +151,6 @@ open class SnowflakeSqlOperations : JdbcSqlOperations(), SqlOperations {
private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeSqlOperations::class.java)
private const val MAX_FILES_IN_LOADING_QUERY_LIMIT = 1000

// This is an unfortunately fragile way to capture this, but Snowflake doesn't
// provide a more specific permission exception error code
private const val NO_PRIVILEGES_ERROR_MESSAGE = "but current role has no privileges on it"
private const val IP_NOT_IN_WHITE_LIST_ERR_MSG = "not allowed to access Snowflake"

private val retentionPeriodDaysFromConfigSingleton: Int
/**
* Sort of hacky. The problem is that SnowflakeSqlOperations is constructed in the
[remaining changed files not shown]
