diff --git a/airbyte-integrations/connectors/destination-snowflake/build.gradle b/airbyte-integrations/connectors/destination-snowflake/build.gradle index 61f54b0647cd..061b1ac96fde 100644 --- a/airbyte-integrations/connectors/destination-snowflake/build.gradle +++ b/airbyte-integrations/connectors/destination-snowflake/build.gradle @@ -5,7 +5,7 @@ plugins { airbyteJavaConnector { cdkVersionRequired = '0.35.13' features = ['db-destinations', 's3-destinations', 'typing-deduping'] - useLocalCdk = false + useLocalCdk = true } java { @@ -50,4 +50,6 @@ integrationTestJava { dependencies { implementation 'net.snowflake:snowflake-jdbc:3.14.1' implementation 'org.apache.commons:commons-text:1.10.0' + + testImplementation "io.mockk:mockk:1.13.11" } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDatabase.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDatabaseUtils.kt similarity index 79% rename from airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDatabase.kt rename to airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDatabaseUtils.kt index ed0e517b7b68..cb8ce1f2ffd2 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDatabase.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDatabaseUtils.kt @@ -8,7 +8,9 @@ import com.zaxxer.hikari.HikariDataSource import io.airbyte.cdk.db.jdbc.DefaultJdbcDatabase import io.airbyte.cdk.db.jdbc.JdbcDatabase import io.airbyte.cdk.db.jdbc.JdbcUtils +import io.airbyte.commons.exceptions.ConfigErrorException import io.airbyte.commons.json.Jsons.deserialize +import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType import java.io.IOException import java.io.PrintWriter import java.net.URI @@ -23,12 +25,13 @@ import java.util.* import java.util.concurrent.TimeUnit import java.util.stream.Collectors import javax.sql.DataSource +import net.snowflake.client.jdbc.SnowflakeSQLException import org.slf4j.Logger import org.slf4j.LoggerFactory /** SnowflakeDatabase contains helpers to create connections to and run queries on Snowflake. 
*/ -object SnowflakeDatabase { - private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeDatabase::class.java) +object SnowflakeDatabaseUtils { + private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeDatabaseUtils::class.java) private const val PAUSE_BETWEEN_TOKEN_REFRESH_MIN = 7 // snowflake access token TTL is 10min and can't be modified @@ -42,14 +45,20 @@ object SnowflakeDatabase { .version(HttpClient.Version.HTTP_2) .connectTimeout(Duration.ofSeconds(10)) .build() - const val PRIVATE_KEY_FILE_NAME: String = "rsa_key.p8" - const val PRIVATE_KEY_FIELD_NAME: String = "private_key" - const val PRIVATE_KEY_PASSWORD: String = "private_key_password" + private const val PRIVATE_KEY_FILE_NAME: String = "rsa_key.p8" + private const val PRIVATE_KEY_FIELD_NAME: String = "private_key" + private const val PRIVATE_KEY_PASSWORD: String = "private_key_password" private const val CONNECTION_STRING_IDENTIFIER_KEY = "application" private const val CONNECTION_STRING_IDENTIFIER_VAL = "Airbyte_Connector" + // This is an unfortunately fragile way to capture the errors, but Snowflake doesn't + // provide a more specific permission exception error code + private const val NO_PRIVILEGES_ERROR_MESSAGE = "but current role has no privileges on it" + private const val IP_NOT_IN_WHITE_LIST_ERR_MSG = "not allowed to access Snowflake" + @JvmStatic fun createDataSource(config: JsonNode, airbyteEnvironment: String?): HikariDataSource { + val dataSource = HikariDataSource() val jdbcUrl = @@ -243,4 +252,45 @@ object SnowflakeDatabase { } } } + + fun checkForKnownConfigExceptions(e: Exception?): Optional { + if (e is SnowflakeSQLException && e.message!!.contains(NO_PRIVILEGES_ERROR_MESSAGE)) { + return Optional.of( + ConfigErrorException( + "Encountered Error with Snowflake Configuration: Current role does not have permissions on the target schema please verify your privileges", + e + ) + ) + } + if (e is SnowflakeSQLException && e.message!!.contains(IP_NOT_IN_WHITE_LIST_ERR_MSG)) { + return Optional.of( + ConfigErrorException( + """ + Snowflake has blocked access from Airbyte IP address. Please make sure that your Snowflake user account's + network policy allows access from all Airbyte IP addresses. 
See this page for the list of Airbyte IPs: + https://docs.airbyte.com/cloud/getting-started-with-airbyte-cloud#allowlist-ip-addresses and this page + for documentation on Snowflake network policies: https://docs.snowflake.com/en/user-guide/network-policies + + """.trimIndent(), + e + ) + ) + } + return Optional.empty() + } + + fun toSqlTypeName(airbyteProtocolType: AirbyteProtocolType): String { + return when (airbyteProtocolType) { + AirbyteProtocolType.STRING -> "TEXT" + AirbyteProtocolType.NUMBER -> "FLOAT" + AirbyteProtocolType.INTEGER -> "NUMBER" + AirbyteProtocolType.BOOLEAN -> "BOOLEAN" + AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ" + AirbyteProtocolType.TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ" + AirbyteProtocolType.TIME_WITH_TIMEZONE -> "TEXT" + AirbyteProtocolType.TIME_WITHOUT_TIMEZONE -> "TIME" + AirbyteProtocolType.DATE -> "DATE" + AirbyteProtocolType.UNKNOWN -> "VARIANT" + } + } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestination.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestination.kt index a539c804dcfd..2f0ddd391748 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestination.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestination.kt @@ -17,19 +17,23 @@ import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer import io.airbyte.cdk.integrations.base.TypingAndDedupingFlag.getRawNamespaceOverride import io.airbyte.cdk.integrations.base.adaptive.AdaptiveDestinationRunner import io.airbyte.cdk.integrations.destination.NamingConventionTransformer +import io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer +import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager import io.airbyte.cdk.integrations.destination.jdbc.JdbcCheckOperations -import io.airbyte.cdk.integrations.destination.staging.StagingConsumerFactory.Companion.builder +import io.airbyte.cdk.integrations.destination.operation.SyncOperation +import io.airbyte.cdk.integrations.destination.s3.FileUploadFormat +import io.airbyte.cdk.integrations.destination.staging.operation.StagingStreamOperations +import io.airbyte.integrations.base.destination.operation.DefaultFlush +import io.airbyte.integrations.base.destination.operation.DefaultSyncOperation import io.airbyte.integrations.base.destination.typing_deduping.CatalogParser -import io.airbyte.integrations.base.destination.typing_deduping.DefaultTyperDeduper -import io.airbyte.integrations.base.destination.typing_deduping.NoOpTyperDeduperWithV1V2Migrations +import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStatus import io.airbyte.integrations.base.destination.typing_deduping.ParsedCatalog -import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduper import io.airbyte.integrations.base.destination.typing_deduping.migrators.Migration import io.airbyte.integrations.destination.snowflake.migrations.SnowflakeState +import io.airbyte.integrations.destination.snowflake.operation.SnowflakeStagingClient +import io.airbyte.integrations.destination.snowflake.operation.SnowflakeStorageOperation import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeDestinationHandler import 
io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeSqlGenerator -import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeV1V2Migrator -import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeV2TableMigrator import io.airbyte.protocol.models.v0.AirbyteConnectionStatus import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog @@ -93,11 +97,11 @@ constructor( } private fun getDataSource(config: JsonNode): DataSource { - return SnowflakeDatabase.createDataSource(config, airbyteEnvironment) + return SnowflakeDatabaseUtils.createDataSource(config, airbyteEnvironment) } private fun getDatabase(dataSource: DataSource): JdbcDatabase { - return SnowflakeDatabase.getDatabase(dataSource) + return SnowflakeDatabaseUtils.getDatabase(dataSource) } override fun getSerializedMessageConsumer( @@ -115,13 +119,10 @@ constructor( } val retentionPeriodDays = - SnowflakeSqlOperations.getRetentionPeriodDays( - config[SnowflakeSqlOperations.RETENTION_PERIOD_DAYS_CONFIG_KEY], + getRetentionPeriodDays( + config[RETENTION_PERIOD_DAYS], ) - val sqlGenerator = SnowflakeSqlGenerator(retentionPeriodDays) - val parsedCatalog: ParsedCatalog - val typerDeduper: TyperDeduper val database = getDatabase(getDataSource(config)) val databaseName = config[JdbcUtils.DATABASE_KEY].asText() val rawTableSchemaName: String @@ -135,63 +136,51 @@ constructor( } val snowflakeDestinationHandler = SnowflakeDestinationHandler(databaseName, database, rawTableSchemaName) - parsedCatalog = catalogParser.parseCatalog(catalog) - val migrator = SnowflakeV1V2Migrator(this.nameTransformer, database, databaseName) - val v2TableMigrator = - SnowflakeV2TableMigrator( - database, - databaseName, - sqlGenerator, - snowflakeDestinationHandler - ) + val parsedCatalog: ParsedCatalog = catalogParser.parseCatalog(catalog) val disableTypeDedupe = config.has(DISABLE_TYPE_DEDUPE) && config[DISABLE_TYPE_DEDUPE].asBoolean(false) val migrations = listOf>() - typerDeduper = - if (disableTypeDedupe) { - NoOpTyperDeduperWithV1V2Migrations( - sqlGenerator, - snowflakeDestinationHandler, - parsedCatalog, - migrator, - v2TableMigrator, - migrations - ) - } else { - DefaultTyperDeduper( - sqlGenerator, - snowflakeDestinationHandler, - parsedCatalog, - migrator, - v2TableMigrator, - migrations, - ) - } - return builder( - outputRecordCollector, - database, - SnowflakeInternalStagingSqlOperations(nameTransformer), - nameTransformer, - config, - catalog, - true, - typerDeduper, + val snowflakeStagingClient = SnowflakeStagingClient(database) + + val snowflakeStorageOperation = + SnowflakeStorageOperation( + sqlGenerator = sqlGenerator, + destinationHandler = snowflakeDestinationHandler, + retentionPeriodDays, + snowflakeStagingClient + ) + + val syncOperation: SyncOperation = + DefaultSyncOperation( parsedCatalog, + snowflakeDestinationHandler, defaultNamespace, - JavaBaseConstants.DestinationColumns.V2_WITHOUT_META, + { initialStatus: DestinationInitialStatus, disableTD -> + StagingStreamOperations( + snowflakeStorageOperation, + initialStatus, + FileUploadFormat.CSV, + JavaBaseConstants.DestinationColumns.V2_WITHOUT_META, + disableTD + ) + }, + migrations, + disableTypeDedupe ) - .setBufferMemoryLimit(Optional.of(snowflakeBufferMemoryLimit)) - .setOptimalBatchSizeBytes( - // The per stream size limit is following recommendations from: - // https://docs.snowflake.com/en/user-guide/data-load-considerations-prepare.html#general-file-sizing-recommendations - 
// "To optimize the number of parallel operations for a load, - // we recommend aiming to produce data files roughly 100-250 MB (or larger) in size - // compressed." - (200 * 1024 * 1024).toLong(), - ) - .build() - .createAsync() + + return AsyncStreamConsumer( + outputRecordCollector = outputRecordCollector, + onStart = {}, + onClose = { _, streamSyncSummaries -> + syncOperation.finalizeStreams(streamSyncSummaries) + SCHEDULED_EXECUTOR_SERVICE.shutdownNow() + }, + onFlush = DefaultFlush(optimalFlushBatchSize, syncOperation), + catalog = catalog, + bufferManager = BufferManager(snowflakeBufferMemoryLimit), + defaultNamespace = Optional.of(defaultNamespace), + ) } override val isV2Destination: Boolean @@ -209,7 +198,7 @@ constructor( companion object { private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeDestination::class.java) const val RAW_SCHEMA_OVERRIDE: String = "raw_data_schema" - + const val RETENTION_PERIOD_DAYS: String = "retention_period_days" const val DISABLE_TYPE_DEDUPE: String = "disable_type_dedupe" @JvmField val SCHEDULED_EXECUTOR_SERVICE: ScheduledExecutorService = @@ -241,23 +230,43 @@ constructor( } } + fun getRetentionPeriodDays(node: JsonNode?): Int { + val retentionPeriodDays = + if (node == null || node.isNull) { + 1 + } else { + node.asInt() + } + return retentionPeriodDays + } + private val snowflakeBufferMemoryLimit: Long get() = (Runtime.getRuntime().maxMemory() * 0.5).toLong() + + // The per stream size limit is following recommendations from: + // https://docs.snowflake.com/en/user-guide/data-load-considerations-prepare.html#general-file-sizing-recommendations + // "To optimize the number of parallel operations for a load, + // we recommend aiming to produce data files roughly 100-250 MB (or larger) in size + // compressed." 
+ private val optimalFlushBatchSize: Long + get() = (200 * 1024 * 1024).toLong() } } fun main(args: Array) { IntegrationRunner.addOrphanedThreadFilter { t: Thread -> - for (stackTraceElement in IntegrationRunner.getThreadCreationInfo(t).stack) { - val stackClassName = stackTraceElement.className - val stackMethodName = stackTraceElement.methodName - if ( - SFStatement::class.java.canonicalName == stackClassName && - "close" == stackMethodName || - SFSession::class.java.canonicalName == stackClassName && - "callHeartBeatWithQueryTimeout" == stackMethodName - ) { - return@addOrphanedThreadFilter false + if (IntegrationRunner.getThreadCreationInfo(t) != null) { + for (stackTraceElement in IntegrationRunner.getThreadCreationInfo(t)!!.stack) { + val stackClassName = stackTraceElement.className + val stackMethodName = stackTraceElement.methodName + if ( + SFStatement::class.java.canonicalName == stackClassName && + "close" == stackMethodName || + SFSession::class.java.canonicalName == stackClassName && + "callHeartBeatWithQueryTimeout" == stackMethodName + ) { + return@addOrphanedThreadFilter false + } } } true @@ -277,5 +286,4 @@ fun main(args: Array) { ) } .run(args) - SnowflakeDestination.SCHEDULED_EXECUTOR_SERVICE.shutdownNow() } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeSqlOperations.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeSqlOperations.kt index fdbe3ecd6ddc..352ca5f7dd2b 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeSqlOperations.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeSqlOperations.kt @@ -15,7 +15,6 @@ import io.airbyte.commons.exceptions.ConfigErrorException import java.sql.SQLException import java.util.* import java.util.function.Consumer -import net.snowflake.client.jdbc.SnowflakeSQLException import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -143,29 +142,7 @@ open class SnowflakeSqlOperations : JdbcSqlOperations(), SqlOperations { } override fun checkForKnownConfigExceptions(e: Exception?): Optional { - if (e is SnowflakeSQLException && e.message!!.contains(NO_PRIVILEGES_ERROR_MESSAGE)) { - return Optional.of( - ConfigErrorException( - "Encountered Error with Snowflake Configuration: Current role does not have permissions on the target schema please verify your privileges", - e - ) - ) - } - if (e is SnowflakeSQLException && e.message!!.contains(IP_NOT_IN_WHITE_LIST_ERR_MSG)) { - return Optional.of( - ConfigErrorException( - """ - Snowflake has blocked access from Airbyte IP address. Please make sure that your Snowflake user account's - network policy allows access from all Airbyte IP addresses. 
See this page for the list of Airbyte IPs: - https://docs.airbyte.com/cloud/getting-started-with-airbyte-cloud#allowlist-ip-addresses and this page - for documentation on Snowflake network policies: https://docs.snowflake.com/en/user-guide/network-policies - - """.trimIndent(), - e - ) - ) - } - return Optional.empty() + return SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e) } companion object { @@ -174,11 +151,6 @@ open class SnowflakeSqlOperations : JdbcSqlOperations(), SqlOperations { private val LOGGER: Logger = LoggerFactory.getLogger(SnowflakeSqlOperations::class.java) private const val MAX_FILES_IN_LOADING_QUERY_LIMIT = 1000 - // This is an unfortunately fragile way to capture this, but Snowflake doesn't - // provide a more specific permission exception error code - private const val NO_PRIVILEGES_ERROR_MESSAGE = "but current role has no privileges on it" - private const val IP_NOT_IN_WHITE_LIST_ERR_MSG = "not allowed to access Snowflake" - private val retentionPeriodDaysFromConfigSingleton: Int /** * Sort of hacky. The problem is that SnowflakeSqlOperations is constructed in the diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStagingClient.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStagingClient.kt new file mode 100644 index 000000000000..5992e4aa8f23 --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStagingClient.kt @@ -0,0 +1,248 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ +package io.airbyte.integrations.destination.snowflake.operation + +import io.airbyte.cdk.db.jdbc.JdbcDatabase +import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer +import io.airbyte.commons.string.Strings.join +import io.airbyte.integrations.base.destination.typing_deduping.StreamId +import io.airbyte.integrations.destination.snowflake.SnowflakeDatabaseUtils +import io.github.oshai.kotlinlogging.KotlinLogging +import java.io.IOException +import java.sql.SQLException +import java.util.* + +private val log = KotlinLogging.logger {} + +/** Client wrapper providing Snowflake Stage related operations. 
*/ +class SnowflakeStagingClient(private val database: JdbcDatabase) { + + // Most of the code here is preserved from + // https://github.com/airbytehq/airbyte/blob/503b819b846663b0dff4c90322d0219a93e61d14/airbyte-integrations/connectors/destination-snowflake/src/main/java/io/airbyte/integrations/destination/snowflake/SnowflakeInternalStagingSqlOperations.java + @Throws(IOException::class) + fun uploadRecordsToStage( + recordsData: SerializableBuffer, + stageName: String, + stagingPath: String + ): String { + val exceptionsThrown: MutableList = ArrayList() + var succeeded = false + while (exceptionsThrown.size < UPLOAD_RETRY_LIMIT && !succeeded) { + try { + uploadRecordsToBucket(stageName, stagingPath, recordsData) + succeeded = true + } catch (e: Exception) { + log.error(e) { "Failed to upload records into stage $stagingPath" } + exceptionsThrown.add(e) + } + if (!succeeded) { + log.info { + "Retrying to upload records into stage $stagingPath (${exceptionsThrown.size}/$UPLOAD_RETRY_LIMIT})" + } + } + } + if (!succeeded) { + throw RuntimeException( + String.format( + "Exceptions thrown while uploading records into stage: %s", + join(exceptionsThrown, "\n") + ) + ) + } + log.info { + "Successfully loaded records to stage $stagingPath with ${exceptionsThrown.size} re-attempt(s)" + } + return recordsData.filename + } + + @Throws(Exception::class) + private fun uploadRecordsToBucket( + stageName: String, + stagingPath: String, + recordsData: SerializableBuffer + ) { + val query = getPutQuery(stageName, stagingPath, recordsData.file!!.absolutePath) + log.info { "Executing query: $query" } + database.execute(query) + if (!checkStageObjectExists(stageName, stagingPath, recordsData.filename)) { + log.error { + "Failed to upload data into stage, object @${ + (stagingPath + "/" + recordsData.filename).replace( + "/+".toRegex(), + "/", + ) + } not found" + } + throw RuntimeException("Upload failed") + } + } + + internal fun getPutQuery(stageName: String, stagingPath: String, filePath: String): String { + return String.format( + PUT_FILE_QUERY, + filePath, + stageName, + stagingPath, + Runtime.getRuntime().availableProcessors() + ) + } + + @Throws(SQLException::class) + private fun checkStageObjectExists( + stageName: String, + stagingPath: String, + filename: String + ): Boolean { + val query = getListQuery(stageName, stagingPath, filename) + log.debug { "Executing query: $query" } + val result: Boolean + database.unsafeQuery(query).use { stream -> result = stream.findAny().isPresent } + return result + } + + /** + * Creates a SQL query to list file which is staged + * + * @param stageName name of staging folder + * @param stagingPath path to the files within the staging folder + * @param filename name of the file within staging area + * @return SQL query string + */ + internal fun getListQuery(stageName: String, stagingPath: String, filename: String): String { + return String.format(LIST_STAGE_QUERY, stageName, stagingPath, filename) + .replace("/+".toRegex(), "/") + } + + @Throws(Exception::class) + fun createStageIfNotExists(stageName: String) { + val query = getCreateStageQuery(stageName) + log.debug { "Executing query: $query" } + try { + database.execute(query) + } catch (e: Exception) { + throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { e } + } + } + + /** + * Creates a SQL query to create a staging folder. 
This query will create a staging folder if + * one previously did not exist + * + * @param stageName name of the staging folder + * @return SQL query string + */ + internal fun getCreateStageQuery(stageName: String): String { + return String.format(CREATE_STAGE_QUERY, stageName) + } + + @Throws(SQLException::class) + fun copyIntoTableFromStage( + stageName: String, + stagingPath: String, + stagedFiles: List, + streamId: StreamId + ) { + try { + val query = getCopyQuery(stageName, stagingPath, stagedFiles, streamId) + log.info { "Executing query: $query" } + database.execute(query) + } catch (e: SQLException) { + throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { e } + } + } + + /** + * Creates a SQL query to bulk copy data into fully qualified destination table See + * https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html for more context + * + * @param stageName name of staging folder + * @param stagingPath path of staging folder to data files + * @param stagedFiles collection of the staging files + * @param streamId + * @return SQL query string + */ + internal fun getCopyQuery( + stageName: String, + stagingPath: String, + stagedFiles: List, + streamId: StreamId + ): String { + return String.format( + COPY_QUERY_1S1T + generateFilesList(stagedFiles) + ";", + streamId.rawNamespace, + streamId.rawName, + stageName, + stagingPath + ) + } + + // TODO: Do we need this sketchy logic when all we use is just 1 file. + private fun generateFilesList(files: List): String { + if (0 < files.size && files.size < MAX_FILES_IN_LOADING_QUERY_LIMIT) { + // see + // https://docs.snowflake.com/en/user-guide/data-load-considerations-load.html#lists-of-files + val filesString = + files.joinToString { filename: String -> + "'${ + filename.substring( + filename.lastIndexOf("/") + 1, + ) + }'" + } + return " files = ($filesString)" + } else { + return "" + } + } + + @Throws(Exception::class) + fun dropStageIfExists(stageName: String) { + try { + val query = getDropQuery(stageName) + log.debug { "Executing query: $query" } + database.execute(query) + } catch (e: SQLException) { + throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { e } + } + } + + /** + * Creates a SQL query to drop staging area and all associated files within the staged area + * https://docs.snowflake.com/en/sql-reference/sql/drop-stage + * @param stageName name of staging folder + * @return SQL query string + */ + internal fun getDropQuery(stageName: String?): String { + return String.format(DROP_STAGE_QUERY, stageName) + } + + companion object { + private const val UPLOAD_RETRY_LIMIT: Int = 3 + private const val MAX_FILES_IN_LOADING_QUERY_LIMIT = 1000 + private const val CREATE_STAGE_QUERY = + "CREATE STAGE IF NOT EXISTS %s encryption = (type = 'SNOWFLAKE_SSE') copy_options = (on_error='skip_file');" + private const val PUT_FILE_QUERY = "PUT file://%s @%s/%s PARALLEL = %d;" + private const val LIST_STAGE_QUERY = "LIST @%s/%s/%s;" + + // the 1s1t copy query explicitly quotes the raw table+schema name. + // we set error_on_column_count_mismatch because (at time of writing), we haven't yet added + // the airbyte_meta column to the raw table. + // See also https://github.com/airbytehq/airbyte/issues/36410 for improved error handling. + // TODO remove error_on_column_count_mismatch once snowflake has airbyte_meta in raw data. 
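+        // COPY options: FIELD_OPTIONALLY_ENCLOSED_BY = '"' lets quoted CSV fields contain commas and newlines,
+        // NULL_IF=('') loads empty strings as SQL NULL, and error_on_column_count_mismatch=false tolerates a
+        // difference between the CSV column count and the raw table's column count (see the TODO above).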
+ private val COPY_QUERY_1S1T = + """ + |COPY INTO "%s"."%s" FROM '@%s/%s' + |file_format = ( + | type = csv + | compression = auto + | field_delimiter = ',' + | skip_header = 0 + | FIELD_OPTIONALLY_ENCLOSED_BY = '"' + | NULL_IF=('') + | error_on_column_count_mismatch=false + |) + """.trimMargin() + private const val DROP_STAGE_QUERY = "DROP STAGE IF EXISTS %s;" + } +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStorageOperation.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStorageOperation.kt new file mode 100644 index 000000000000..9bcfdadd7169 --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStorageOperation.kt @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.integrations.destination.snowflake.operation + +import io.airbyte.cdk.integrations.base.JavaBaseConstants +import io.airbyte.cdk.integrations.destination.StandardNameTransformer +import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer +import io.airbyte.integrations.base.destination.operation.StorageOperation +import io.airbyte.integrations.base.destination.typing_deduping.Sql +import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig +import io.airbyte.integrations.base.destination.typing_deduping.StreamId +import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil +import io.airbyte.integrations.destination.snowflake.SnowflakeSQLNameTransformer +import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeDestinationHandler +import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeSqlGenerator +import io.airbyte.protocol.models.v0.DestinationSyncMode +import io.github.oshai.kotlinlogging.KotlinLogging +import java.time.Instant +import java.time.ZoneOffset +import java.time.ZonedDateTime +import java.util.* + +private val log = KotlinLogging.logger {} + +class SnowflakeStorageOperation( + private val sqlGenerator: SnowflakeSqlGenerator, + private val destinationHandler: SnowflakeDestinationHandler, + private val retentionPeriodDays: Int, + private val staging: SnowflakeStagingClient, + private val nameTransformer: StandardNameTransformer = SnowflakeSQLNameTransformer(), +) : StorageOperation { + + private val connectionId = UUID.randomUUID() + private val syncDateTime = Instant.now() + + override fun prepareStage(streamId: StreamId, destinationSyncMode: DestinationSyncMode) { + // create raw table + destinationHandler.execute(Sql.of(createTableQuery(streamId))) + if (destinationSyncMode == DestinationSyncMode.OVERWRITE) { + destinationHandler.execute(Sql.of(truncateTableQuery(streamId))) + } + // create stage + staging.createStageIfNotExists(getStageName(streamId)) + } + + internal fun createTableQuery(streamId: StreamId): String { + return """ + |CREATE TABLE IF NOT EXISTS "${streamId.rawNamespace}"."${streamId.rawName}"( + | "${JavaBaseConstants.COLUMN_NAME_AB_RAW_ID}" VARCHAR PRIMARY KEY, + | "${JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT}" TIMESTAMP WITH TIME ZONE DEFAULT current_timestamp(), + | "${JavaBaseConstants.COLUMN_NAME_AB_LOADED_AT}" TIMESTAMP WITH TIME ZONE DEFAULT NULL, + | "${JavaBaseConstants.COLUMN_NAME_DATA}" VARIANT + |) data_retention_time_in_days = 
$retentionPeriodDays; + """.trimMargin() + } + + internal fun truncateTableQuery(streamId: StreamId): String { + return "TRUNCATE TABLE \"${streamId.rawNamespace}\".\"${streamId.rawName}\";\n" + } + + override fun writeToStage(streamId: StreamId, data: SerializableBuffer) { + val stageName = getStageName(streamId) + val stagingPath = getStagingPath() + val stagedFileName = staging.uploadRecordsToStage(data, stageName, stagingPath) + staging.copyIntoTableFromStage(stageName, stagingPath, listOf(stagedFileName), streamId) + } + override fun cleanupStage(streamId: StreamId) { + val stageName = getStageName(streamId) + log.info { "Cleaning up stage $stageName" } + staging.dropStageIfExists(stageName) + } + + internal fun getStageName(streamId: StreamId): String { + return """ + "${nameTransformer.convertStreamName(streamId.rawNamespace)}"."${ nameTransformer.convertStreamName(streamId.rawName)}" + """.trimIndent() + } + + internal fun getStagingPath(): String { + // see https://docs.snowflake.com/en/user-guide/data-load-considerations-stage.html + val zonedDateTime = ZonedDateTime.ofInstant(syncDateTime, ZoneOffset.UTC) + return nameTransformer.applyDefaultCase( + String.format( + "%s/%02d/%02d/%02d/%s/", + zonedDateTime.year, + zonedDateTime.monthValue, + zonedDateTime.dayOfMonth, + zonedDateTime.hour, + connectionId + ) + ) + } + + override fun createFinalTable(streamConfig: StreamConfig, suffix: String, replace: Boolean) { + destinationHandler.execute(sqlGenerator.createTable(streamConfig, suffix, replace)) + } + + override fun overwriteFinalTable(streamConfig: StreamConfig, tmpTableSuffix: String) { + if (tmpTableSuffix.isNotBlank()) { + log.info { + "Overwriting table ${streamConfig.id.finalTableId(SnowflakeSqlGenerator.QUOTE)} with ${ + streamConfig.id.finalTableId( + SnowflakeSqlGenerator.QUOTE, + tmpTableSuffix, + ) + }" + } + destinationHandler.execute( + sqlGenerator.overwriteFinalTable(streamConfig.id, tmpTableSuffix) + ) + } + } + + override fun softResetFinalTable(streamConfig: StreamConfig) { + TyperDeduperUtil.executeSoftReset(sqlGenerator, destinationHandler, streamConfig) + } + + override fun typeAndDedupe( + streamConfig: StreamConfig, + maxProcessedTimestamp: Optional, + finalTableSuffix: String + ) { + TyperDeduperUtil.executeTypeAndDedupe( + sqlGenerator = sqlGenerator, + destinationHandler = destinationHandler, + streamConfig, + maxProcessedTimestamp, + finalTableSuffix + ) + } +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeDestinationHandler.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeDestinationHandler.kt index 1d602d18cac4..0ce1ca6eccdb 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeDestinationHandler.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeDestinationHandler.kt @@ -23,6 +23,7 @@ import io.airbyte.integrations.base.destination.typing_deduping.StreamId import io.airbyte.integrations.base.destination.typing_deduping.Struct import io.airbyte.integrations.base.destination.typing_deduping.Union import io.airbyte.integrations.base.destination.typing_deduping.UnsupportedOneOf +import io.airbyte.integrations.destination.snowflake.SnowflakeDatabaseUtils 
import io.airbyte.integrations.destination.snowflake.migrations.SnowflakeState import io.airbyte.protocol.models.v0.DestinationSyncMode import java.sql.Connection @@ -241,7 +242,9 @@ class SnowflakeDestinationHandler( } else { e.message } - throw RuntimeException(trimmedMessage, e) + throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { + RuntimeException(trimmedMessage, e) + } } LOGGER.info( @@ -430,22 +433,34 @@ class SnowflakeDestinationHandler( } private fun toJdbcTypeName(airbyteProtocolType: AirbyteProtocolType): String { - return when (airbyteProtocolType) { - AirbyteProtocolType.STRING -> "TEXT" - AirbyteProtocolType.NUMBER -> "FLOAT" - AirbyteProtocolType.INTEGER -> "NUMBER" - AirbyteProtocolType.BOOLEAN -> "BOOLEAN" - AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ" - AirbyteProtocolType.TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ" - AirbyteProtocolType.TIME_WITH_TIMEZONE -> "TEXT" - AirbyteProtocolType.TIME_WITHOUT_TIMEZONE -> "TIME" - AirbyteProtocolType.DATE -> "DATE" - AirbyteProtocolType.UNKNOWN -> "VARIANT" - } + return SnowflakeDatabaseUtils.toSqlTypeName(airbyteProtocolType) } override fun createNamespaces(schemas: Set) { - // do nothing? + schemas.forEach { + try { + // 1s1t is assuming a lowercase airbyte_internal schema name, so we need to quote it + // we quote for final schemas names too (earlier existed in + // SqlGenerator#createSchema). + if (!isSchemaExists(it)) { + database.execute(String.format("CREATE SCHEMA IF NOT EXISTS \"%s\";", it)) + } + } catch (e: Exception) { + throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { e } + } + } + } + + private fun isSchemaExists(schema: String): Boolean { + try { + database.unsafeQuery(SHOW_SCHEMAS).use { results -> + return results + .map { schemas: JsonNode -> schemas[NAME].asText() } + .anyMatch { anObject: String? 
-> schema == anObject } + } + } catch (e: Exception) { + throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { e } + } } companion object { @@ -453,6 +468,8 @@ class SnowflakeDestinationHandler( LoggerFactory.getLogger(SnowflakeDestinationHandler::class.java) const val EXCEPTION_COMMON_PREFIX: String = "JavaScript execution error: Uncaught Execution of multiple statements failed on statement" + const val SHOW_SCHEMAS: String = "show schemas;" + const val NAME: String = "name" @Throws(SQLException::class) fun findExistingTables( diff --git a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGenerator.kt b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGenerator.kt index 2cb5f8c85810..736cc61f94cb 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGenerator.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/main/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGenerator.kt @@ -23,6 +23,7 @@ import io.airbyte.integrations.base.destination.typing_deduping.Struct import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil.SOFT_RESET_SUFFIX import io.airbyte.integrations.base.destination.typing_deduping.Union import io.airbyte.integrations.base.destination.typing_deduping.UnsupportedOneOf +import io.airbyte.integrations.destination.snowflake.SnowflakeDatabaseUtils import io.airbyte.protocol.models.v0.DestinationSyncMode import java.time.Instant import java.util.* @@ -76,20 +77,8 @@ class SnowflakeSqlGenerator(private val retentionPeriodDays: Int) : SqlGenerator throw IllegalArgumentException("Unsupported AirbyteType: $type") } - fun toDialectType(airbyteProtocolType: AirbyteProtocolType): String { - // TODO verify these types against normalization - return when (airbyteProtocolType) { - AirbyteProtocolType.STRING -> "TEXT" - AirbyteProtocolType.NUMBER -> "FLOAT" - AirbyteProtocolType.INTEGER -> "NUMBER" - AirbyteProtocolType.BOOLEAN -> "BOOLEAN" - AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE -> "TIMESTAMP_TZ" - AirbyteProtocolType.TIMESTAMP_WITHOUT_TIMEZONE -> "TIMESTAMP_NTZ" - AirbyteProtocolType.TIME_WITH_TIMEZONE -> "TEXT" - AirbyteProtocolType.TIME_WITHOUT_TIMEZONE -> "TIME" - AirbyteProtocolType.DATE -> "DATE" - AirbyteProtocolType.UNKNOWN -> "VARIANT" - } + private fun toDialectType(airbyteProtocolType: AirbyteProtocolType): String { + return SnowflakeDatabaseUtils.toSqlTypeName(airbyteProtocolType) } override fun createSchema(schema: String): Sql { diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestinationIntegrationTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestinationIntegrationTest.kt index 1ee4a151fcec..6d824e78098b 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestinationIntegrationTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeDestinationIntegrationTest.kt @@ -50,9 +50,9 @@ internal class 
SnowflakeDestinationIntegrationTest { val config = config val schema = config["schema"].asText() val dataSource: DataSource = - SnowflakeDatabase.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) + SnowflakeDatabaseUtils.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) try { - val database = SnowflakeDatabase.getDatabase(dataSource) + val database = SnowflakeDatabaseUtils.getDatabase(dataSource) Assertions.assertDoesNotThrow { syncWithNamingResolver(database, schema) } Assertions.assertThrows(SQLException::class.java) { syncWithoutNamingResolver(database, schema) diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeInsertDestinationAcceptanceTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeInsertDestinationAcceptanceTest.kt index 50f2220abbe8..62a4a158091d 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeInsertDestinationAcceptanceTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/SnowflakeInsertDestinationAcceptanceTest.kt @@ -52,8 +52,8 @@ open class SnowflakeInsertDestinationAcceptanceTest : DestinationAcceptanceTest( } private var config: JsonNode = Jsons.clone(staticConfig) private var dataSource: DataSource = - SnowflakeDatabase.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) - private var database: JdbcDatabase = SnowflakeDatabase.getDatabase(dataSource) + SnowflakeDatabaseUtils.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) + private var database: JdbcDatabase = SnowflakeDatabaseUtils.getDatabase(dataSource) @BeforeEach fun setup() { @@ -187,8 +187,9 @@ open class SnowflakeInsertDestinationAcceptanceTest : DestinationAcceptanceTest( this.config = Jsons.clone(staticConfig) (config as ObjectNode?)!!.put("schema", schemaName) - dataSource = SnowflakeDatabase.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) - database = SnowflakeDatabase.getDatabase(dataSource) + dataSource = + SnowflakeDatabaseUtils.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) + database = SnowflakeDatabaseUtils.getDatabase(dataSource) database.execute(createSchemaQuery) } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/AbstractSnowflakeTypingDedupingTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/AbstractSnowflakeTypingDedupingTest.kt index 15e915181ffb..ee3be7581a5e 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/AbstractSnowflakeTypingDedupingTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/AbstractSnowflakeTypingDedupingTest.kt @@ -45,8 +45,9 @@ abstract class AbstractSnowflakeTypingDedupingTest : BaseTypingDedupingTest() { val config = deserialize(readFile(Path.of(configPath))) (config as ObjectNode).put("schema", "typing_deduping_default_schema$uniqueSuffix") databaseName = config.get(JdbcUtils.DATABASE_KEY).asText() - dataSource = 
SnowflakeDatabase.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) - database = SnowflakeDatabase.getDatabase(dataSource) + dataSource = + SnowflakeDatabaseUtils.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS) + database = SnowflakeDatabaseUtils.getDatabase(dataSource) cleanAirbyteInternalTable(database) return config } @@ -117,111 +118,6 @@ abstract class AbstractSnowflakeTypingDedupingTest : BaseTypingDedupingTest() { */ get() = JavaBaseConstants.DEFAULT_AIRBYTE_INTERNAL_NAMESPACE - /** - * Run a sync using 3.0.0 (which is the highest version that still creates v2 final tables with - * lowercased+quoted names). Then run a sync using our current version. - */ - @Test - @Throws(Exception::class) - open fun testFinalTableUppercasingMigration_append() { - try { - val catalog = - ConfiguredAirbyteCatalog() - .withStreams( - java.util.List.of( - ConfiguredAirbyteStream() - .withSyncMode(SyncMode.FULL_REFRESH) - .withDestinationSyncMode(DestinationSyncMode.APPEND) - .withStream( - AirbyteStream() - .withNamespace(streamNamespace) - .withName(streamName) - .withJsonSchema(SCHEMA) - ) - ) - ) - - // First sync - val messages1 = readMessages("dat/sync1_messages.jsonl") - runSync(catalog, messages1, "airbyte/destination-snowflake:3.0.0") - - // We no longer have the code to dump a lowercased table, so just move on directly to - // the new sync - - // Second sync - val messages2 = readMessages("dat/sync2_messages.jsonl") - - runSync(catalog, messages2) - - val expectedRawRecords2 = readRecords("dat/sync2_expectedrecords_raw_mixed_tzs.jsonl") - val expectedFinalRecords2 = - readRecords("dat/sync2_expectedrecords_fullrefresh_append_final.jsonl") - verifySyncResult( - expectedRawRecords2, - expectedFinalRecords2, - disableFinalTableComparison() - ) - } finally { - // manually drop the lowercased schema, since we no longer have the code to do it - // automatically - // (the raw table is still in lowercase "airbyte_internal"."whatever", so the - // auto-cleanup code - // handles it fine) - database!!.execute("DROP SCHEMA IF EXISTS \"$streamNamespace\" CASCADE") - } - } - - @Test - @Throws(Exception::class) - fun testFinalTableUppercasingMigration_overwrite() { - try { - val catalog = - ConfiguredAirbyteCatalog() - .withStreams( - java.util.List.of( - ConfiguredAirbyteStream() - .withSyncMode(SyncMode.FULL_REFRESH) - .withDestinationSyncMode(DestinationSyncMode.OVERWRITE) - .withStream( - AirbyteStream() - .withNamespace(streamNamespace) - .withName(streamName) - .withJsonSchema(SCHEMA) - ) - ) - ) - - // First sync - val messages1 = readMessages("dat/sync1_messages.jsonl") - runSync(catalog, messages1, "airbyte/destination-snowflake:3.0.0") - - // We no longer have the code to dump a lowercased table, so just move on directly to - // the new sync - - // Second sync - val messages2 = readMessages("dat/sync2_messages.jsonl") - - runSync(catalog, messages2) - - val expectedRawRecords2 = - readRecords("dat/sync2_expectedrecords_fullrefresh_overwrite_raw.jsonl") - val expectedFinalRecords2 = - readRecords("dat/sync2_expectedrecords_fullrefresh_overwrite_final.jsonl") - verifySyncResult( - expectedRawRecords2, - expectedFinalRecords2, - disableFinalTableComparison() - ) - } finally { - // manually drop the lowercased schema, since we no longer have the code to do it - // automatically - // (the raw table is still in lowercase "airbyte_internal"."whatever", so the - // auto-cleanup code - // handles it fine) - database!!.execute("DROP SCHEMA IF EXISTS \"$streamNamespace\" CASCADE") - } - } 
- @Test @Throws(Exception::class) open fun testRemovingPKNonNullIndexes() { diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeInternalStagingCaseInsensitiveTypingDedupingTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeInternalStagingCaseInsensitiveTypingDedupingTest.kt index 95d19095d71b..a08a28843458 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeInternalStagingCaseInsensitiveTypingDedupingTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeInternalStagingCaseInsensitiveTypingDedupingTest.kt @@ -7,8 +7,6 @@ import com.fasterxml.jackson.databind.JsonNode import com.fasterxml.jackson.databind.node.ObjectNode import io.airbyte.commons.json.Jsons.emptyObject import java.util.* -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test class SnowflakeInternalStagingCaseInsensitiveTypingDedupingTest : AbstractSnowflakeTypingDedupingTest() { @@ -35,13 +33,4 @@ class SnowflakeInternalStagingCaseInsensitiveTypingDedupingTest : } .toList() } - - @Disabled( - "This test assumes the ability to create case-sensitive tables, which is by definition not available with QUOTED_IDENTIFIERS_IGNORE_CASE=TRUE" - ) - @Test - @Throws(Exception::class) - override fun testFinalTableUppercasingMigration_append() { - super.testFinalTableUppercasingMigration_append() - } } diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGeneratorIntegrationTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGeneratorIntegrationTest.kt index 92ddc945d3d9..697c4c7b568a 100644 --- a/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGeneratorIntegrationTest.kt +++ b/airbyte-integrations/connectors/destination-snowflake/src/test-integration/kotlin/io/airbyte/integrations/destination/snowflake/typing_deduping/SnowflakeSqlGeneratorIntegrationTest.kt @@ -17,7 +17,7 @@ import io.airbyte.integrations.base.destination.typing_deduping.Sql import io.airbyte.integrations.base.destination.typing_deduping.StreamId import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil.executeTypeAndDedupe import io.airbyte.integrations.destination.snowflake.OssCloudEnvVarConsts -import io.airbyte.integrations.destination.snowflake.SnowflakeDatabase +import io.airbyte.integrations.destination.snowflake.SnowflakeDatabaseUtils import io.airbyte.integrations.destination.snowflake.SnowflakeSourceOperations import io.airbyte.integrations.destination.snowflake.SnowflakeTestUtils import io.airbyte.integrations.destination.snowflake.SnowflakeTestUtils.dumpFinalTable @@ -1844,15 +1844,15 @@ class SnowflakeSqlGeneratorIntegrationTest : BaseSqlGeneratorIntegrationTest } - Assertions.assertEquals(AsyncStreamConsumer::class.java, consumer.javaClass) - } - companion object { @JvmStatic private fun urlsDataProvider(): Stream { diff --git 
a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStagingClientTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStagingClientTest.kt new file mode 100644 index 000000000000..b0d51d7e3341 --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStagingClientTest.kt @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.integrations.destination.snowflake.operation + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.jdbc.JdbcDatabase +import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer +import io.airbyte.commons.json.Jsons +import io.airbyte.integrations.base.destination.typing_deduping.StreamId +import io.mockk.confirmVerified +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.mockk.verifySequence +import java.io.File +import java.lang.RuntimeException +import java.sql.SQLException +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.mockito.kotlin.times + +class SnowflakeStagingClientTest { + + @Nested + inner class SuccessTest { + private val database = + mockk(relaxed = true, relaxUnitFun = true) { + // checkIfStage exists should be the only call to mock. it checks if any object + // exists. + every { unsafeQuery(any()) } returns listOf(Jsons.emptyObject()).stream() + } + private val stagingClient = SnowflakeStagingClient(database) + + @Test + fun verifyUploadRecordsToStage() { + val mockFileName = "mock-file-name" + val mockFileAbsolutePath = "/tmp/$mockFileName" + val recordBuffer = + mockk() { + every { filename } returns mockFileName + every { file } returns + mockk() { every { absolutePath } returns mockFileAbsolutePath } + } + val stageName = "dummy" + val stagingPath = "2024/uuid-random" + + val putQuery = stagingClient.getPutQuery(stageName, stagingPath, mockFileAbsolutePath) + val listQuery = stagingClient.getListQuery(stageName, stagingPath, mockFileName) + val stagedFile = + stagingClient.uploadRecordsToStage(recordBuffer, stageName, stagingPath) + assertEquals(stagedFile, mockFileName) + verify { + database.execute(putQuery) + database.unsafeQuery(listQuery) + } + confirmVerified(database) + } + + @Test + fun verifyCreateStageIfNotExists() { + val stageName = "dummy" + stagingClient.createStageIfNotExists(stageName) + verify { database.execute(stagingClient.getCreateStageQuery(stageName)) } + confirmVerified(database) + } + + @Test + fun verifyCopyIntoTableFromStage() { + val stageName = "dummy" + val stagingPath = "2024/uuid-random" + val stagedFiles = listOf("mock-file-name") + stagingClient.copyIntoTableFromStage(stageName, stagingPath, stagedFiles, streamId) + verify { + database.execute( + stagingClient.getCopyQuery(stageName, stagingPath, stagedFiles, streamId) + ) + } + confirmVerified(database) + } + + @Test + fun verifyDropStageIfExists() { + val stageName = "dummy" + stagingClient.dropStageIfExists(stageName) + verify { database.execute(stagingClient.getDropQuery(stageName)) } + confirmVerified(database) + } + } + + @Nested + inner class FailureTest { + + @Test + fun verifyUploadToStageRetried() { + val database = + mockk(relaxed = true, relaxUnitFun = true) { + // throw exception first 
on execute and success + every { execute(any(String::class)) } throws + Exception("First query failed") andThen + Unit + // return empty stream on first checkStage and then some data + every { unsafeQuery(any()) } returns + listOf().stream() andThen + listOf(Jsons.emptyObject()).stream() + } + val stagingClient = SnowflakeStagingClient(database) + + val mockFileName = "mock-file-name" + val mockFileAbsolutePath = "/tmp/$mockFileName" + val recordBuffer = + mockk() { + every { filename } returns mockFileName + every { file } returns + mockk() { every { absolutePath } returns mockFileAbsolutePath } + } + val stageName = "dummy" + val stagingPath = "2024/uuid-random" + + val putQuery = stagingClient.getPutQuery(stageName, stagingPath, mockFileAbsolutePath) + val listQuery = stagingClient.getListQuery(stageName, stagingPath, mockFileName) + val stagedFile = + stagingClient.uploadRecordsToStage(recordBuffer, stageName, stagingPath) + assertEquals(stagedFile, mockFileName) + verifySequence { + database.execute(putQuery) + database.execute(putQuery) + database.unsafeQuery(listQuery) + database.execute(putQuery) + database.unsafeQuery(listQuery) + } + confirmVerified(database) + } + + @Test + fun verifyUploadToStageExhaustedRetries() { + val database = + mockk(relaxed = true, relaxUnitFun = true) { + // throw exception first on execute and success + every { execute(any(String::class)) } throws + SQLException("Query can't be executed") + } + val stagingClient = SnowflakeStagingClient(database) + + val mockFileName = "mock-file-name" + val mockFileAbsolutePath = "/tmp/$mockFileName" + val recordBuffer = + mockk() { + every { filename } returns mockFileName + every { file } returns + mockk() { every { absolutePath } returns mockFileAbsolutePath } + } + val stageName = "dummy" + val stagingPath = "2024/uuid-random" + + val putQuery = stagingClient.getPutQuery(stageName, stagingPath, mockFileAbsolutePath) + assertThrows(RuntimeException::class.java) { + stagingClient.uploadRecordsToStage(recordBuffer, stageName, stagingPath) + } + verifySequence { + database.execute(putQuery) + database.execute(putQuery) + database.execute(putQuery) + } + confirmVerified(database) + } + } + + companion object { + val streamId = + StreamId( + "final_namespace", + "final_name", + "raw_namespace", + "raw_name", + "original_namespace", + "original_name", + ) + } +} diff --git a/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStorageOperationTest.kt b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStorageOperationTest.kt new file mode 100644 index 000000000000..84f19355af4b --- /dev/null +++ b/airbyte-integrations/connectors/destination-snowflake/src/test/kotlin/io/airbyte/integrations/destination/snowflake/operation/SnowflakeStorageOperationTest.kt @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
+ */ + +package io.airbyte.integrations.destination.snowflake.operation + +import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer +import io.airbyte.integrations.base.destination.typing_deduping.Sql +import io.airbyte.integrations.base.destination.typing_deduping.StreamId +import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeDestinationHandler +import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeSqlGenerator +import io.airbyte.protocol.models.v0.DestinationSyncMode +import io.mockk.confirmVerified +import io.mockk.every +import io.mockk.mockk +import io.mockk.verifySequence +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +class SnowflakeStorageOperationTest { + + @Nested + inner class SuccessTest { + private val sqlGenerator = mockk(relaxed = true) + private val destinationHandler = mockk(relaxed = true) + private val stagingClient = mockk(relaxed = true) + private val storageOperation: SnowflakeStorageOperation = + SnowflakeStorageOperation(sqlGenerator, destinationHandler, 1, stagingClient) + @Test + fun verifyPrepareStageCreatesTableAndStage() { + storageOperation.prepareStage(streamId, DestinationSyncMode.APPEND) + verifySequence { + destinationHandler.execute(Sql.of(storageOperation.createTableQuery(streamId))) + stagingClient.createStageIfNotExists(storageOperation.getStageName(streamId)) + } + confirmVerified(destinationHandler) + confirmVerified(stagingClient) + } + + @Test + fun verifyPrepareStageOverwriteTruncatesTable() { + storageOperation.prepareStage(streamId, DestinationSyncMode.OVERWRITE) + verifySequence { + destinationHandler.execute(Sql.of(storageOperation.createTableQuery(streamId))) + destinationHandler.execute(Sql.of(storageOperation.truncateTableQuery(streamId))) + stagingClient.createStageIfNotExists(storageOperation.getStageName(streamId)) + } + confirmVerified(destinationHandler) + confirmVerified(stagingClient) + } + + @Test + fun verifyWriteToStage() { + val mockTmpFileName = "random-tmp-file-name" + val data = mockk() { every { filename } returns mockTmpFileName } + val stageName = storageOperation.getStageName(streamId) + // stagingPath has UUID which isn't injected atm. + val stagingClient = + mockk(relaxed = true) { + every { uploadRecordsToStage(any(), any(), any()) } returns mockTmpFileName + } + val storageOperation = + SnowflakeStorageOperation(sqlGenerator, destinationHandler, 1, stagingClient) + storageOperation.writeToStage(streamId, data) + + verifySequence { + stagingClient.uploadRecordsToStage( + data, + stageName, + any(), + ) + stagingClient.copyIntoTableFromStage( + stageName, + any(), + listOf(mockTmpFileName), + streamId, + ) + } + confirmVerified(stagingClient) + } + + @Test + fun verifyCleanUpStage() { + storageOperation.cleanupStage(streamId) + verifySequence { + stagingClient.dropStageIfExists(storageOperation.getStageName(streamId)) + } + confirmVerified(stagingClient) + } + } + + companion object { + val streamId = + StreamId( + "final_namespace", + "final_name", + "raw_namespace", + "raw_name", + "original_namespace", + "original_name", + ) + } +}