Commit: gireesh/snowflake_refreshes
gisripa authored and stephane-airbyte committed Jun 14, 2024
1 parent c7720b1 commit e2cf210
Showing 9 changed files with 373 additions and 44 deletions.
File 1 of 9:

@@ -5,7 +5,7 @@ plugins {
 airbyteJavaConnector {
     cdkVersionRequired = '0.37.1'
     features = ['db-destinations', 's3-destinations', 'typing-deduping']
-    useLocalCdk = false
+    useLocalCdk = true
 }

 java {
File 2 of 9:

@@ -122,6 +122,12 @@ constructor(
                 hasUnprocessedRecords = true,
                 maxProcessedTimestamp = Optional.empty()
             ),
+        initialTempRawTableStatus =
+            InitialRawTableStatus(
+                rawTableExists = false,
+                hasUnprocessedRecords = true,
+                maxProcessedTimestamp = Optional.empty()
+            ),
         isSchemaMismatch = true,
         isFinalTableEmpty = true,
         destinationState =
File 3 of 9:

@@ -169,11 +169,12 @@ class SnowflakeStagingClient(private val database: JdbcDatabase) {
         stageName: String,
         stagingPath: String,
         stagedFiles: List<String>,
-        streamId: StreamId
+        streamId: StreamId,
+        suffix: String = ""
     ) {
         try {
             val queryId = UUID.randomUUID()
-            val query = getCopyQuery(stageName, stagingPath, stagedFiles, streamId)
+            val query = getCopyQuery(stageName, stagingPath, stagedFiles, streamId, suffix)
             log.info { "query $queryId, $query" }
             // queryJsons is intentionally used here to get the error message in case of failure
             // instead of execute

@@ -252,12 +253,13 @@
         stageName: String,
         stagingPath: String,
         stagedFiles: List<String>,
-        streamId: StreamId
+        streamId: StreamId,
+        suffix: String
     ): String {
         return String.format(
             COPY_QUERY_1S1T + generateFilesList(stagedFiles) + ";",
             streamId.rawNamespace,
-            streamId.rawName,
+            streamId.rawName + suffix,
             stageName,
             stagingPath
         )
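The COPY_QUERY_1S1T template itself is outside this diff, but assuming it is the usual COPY INTO ... FROM '@stage/path' template taking (namespace, table, stage, path), a hedged sketch of a suffixed call and the statement it should render:

```kotlin
// Illustrative only; the stage name, staging path, and file name are invented,
// and "_airbyte_tmp" is assumed to match the CDK's TMP_TABLE_SUFFIX.
val copySql = stagingClient.getCopyQuery(
    stageName = "airbyte_stage",
    stagingPath = "2024/06/14/",
    stagedFiles = listOf("records0.csv.gz"),
    streamId = streamId,
    suffix = "_airbyte_tmp",
)
// Expected shape: COPY INTO "<rawNamespace>"."<rawName>_airbyte_tmp" FROM ...
// files = ('records0.csv.gz'); the suffix only changes the target table, so a
// refresh can load into the temp raw table through the same staging path.
```

Because the suffix parameter on copyIntoTableFromStage defaults to an empty string, existing call sites keep writing to the live raw table unchanged.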
File 4 of 9:

@@ -4,6 +4,7 @@

 package io.airbyte.integrations.destination.snowflake.operation

+import com.fasterxml.jackson.databind.JsonNode
 import io.airbyte.cdk.integrations.base.JavaBaseConstants
 import io.airbyte.cdk.integrations.destination.StandardNameTransformer
 import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer

@@ -15,7 +16,6 @@ import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil
 import io.airbyte.integrations.destination.snowflake.SnowflakeSQLNameTransformer
 import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeDestinationHandler
 import io.airbyte.integrations.destination.snowflake.typing_deduping.SnowflakeSqlGenerator
-import io.airbyte.protocol.models.v0.DestinationSyncMode
 import io.github.oshai.kotlinlogging.KotlinLogging
 import java.time.Instant
 import java.time.ZoneOffset

@@ -35,19 +35,77 @@ class SnowflakeStorageOperation(
     private val connectionId = UUID.randomUUID()
     private val syncDateTime = Instant.now()

-    override fun prepareStage(streamId: StreamId, destinationSyncMode: DestinationSyncMode) {
+    override fun prepareStage(streamId: StreamId, suffix: String, replace: Boolean) {
         // create raw table
-        destinationHandler.execute(Sql.of(createTableQuery(streamId)))
-        if (destinationSyncMode == DestinationSyncMode.OVERWRITE) {
-            destinationHandler.execute(Sql.of(truncateTableQuery(streamId)))
+        destinationHandler.execute(Sql.of(createTableQuery(streamId, suffix)))
+        if (replace) {
+            destinationHandler.execute(Sql.of(truncateTableQuery(streamId, suffix)))
         }
         // create stage
         staging.createStageIfNotExists(getStageName(streamId))
     }

-    internal fun createTableQuery(streamId: StreamId): String {
+    override fun overwriteStage(streamId: StreamId, suffix: String) {
+        if (suffix.isBlank()) {
+            throw IllegalArgumentException("Cannot overwrite raw table with empty suffix")
+        }
+        // Something weird happening with SWAP WITH in truncateRefresh tests,
+        // so using DROP AND ALTER RENAME instead
+        destinationHandler.execute(
+            Sql.of("DROP TABLE IF EXISTS \"${streamId.rawNamespace}\".\"${streamId.rawName}\"")
+        )
+        val swapQuery =
+            """
+            | ALTER TABLE "${streamId.rawNamespace}"."${streamId.rawName+suffix}" RENAME TO "${streamId.rawNamespace}"."${streamId.rawName}";
+            """.trimMargin()
+        destinationHandler.execute(Sql.of(swapQuery))
+    }
+
+    override fun transferFromTempStage(streamId: StreamId, suffix: String) {
+        if (suffix.isBlank()) {
+            throw IllegalArgumentException(
+                "Cannot transfer records from temp raw table with empty suffix"
+            )
+        }
+        destinationHandler.execute(
+            Sql.of(
+                """
+                INSERT INTO "${streamId.rawNamespace}"."${streamId.rawName}"
+                SELECT * FROM "${streamId.rawNamespace}"."${streamId.rawName + suffix}"
+                """.trimIndent()
+            )
+        )
+        destinationHandler.execute(
+            Sql.of(
+                """
+                DROP TABLE "${streamId.rawNamespace}"."${streamId.rawName + suffix}"
+                """.trimIndent()
+            )
+        )
+    }
+
+    override fun getStageGeneration(streamId: StreamId, suffix: String): Long? {
+        val results =
+            destinationHandler.query(
+                """
+                SELECT "${JavaBaseConstants.COLUMN_NAME_AB_GENERATION_ID}" FROM "${streamId.rawNamespace}"."${streamId.rawName + suffix}" LIMIT 1
+                """.trimIndent()
+            )
+        if (results.isEmpty()) return null
+        var generationIdNode: JsonNode? =
+            results.first().get(JavaBaseConstants.COLUMN_NAME_AB_GENERATION_ID)
+        if (generationIdNode == null) {
+            // This is the dance where QUOTED_IDENTIFIERS_IGNORE_CASE will return uppercase column
+            // as result, so check for fallback.
+            generationIdNode =
+                results.first().get(JavaBaseConstants.COLUMN_NAME_AB_GENERATION_ID.uppercase())
+        }
+        return generationIdNode?.asLong()
+    }
+
+    internal fun createTableQuery(streamId: StreamId, suffix: String): String {
         return """
-            |CREATE TABLE IF NOT EXISTS "${streamId.rawNamespace}"."${streamId.rawName}"(
+            |CREATE TABLE IF NOT EXISTS "${streamId.rawNamespace}"."${streamId.rawName + suffix}"(
             | "${JavaBaseConstants.COLUMN_NAME_AB_RAW_ID}" VARCHAR PRIMARY KEY,
             | "${JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT}" TIMESTAMP WITH TIME ZONE DEFAULT current_timestamp(),
             | "${JavaBaseConstants.COLUMN_NAME_AB_LOADED_AT}" TIMESTAMP WITH TIME ZONE DEFAULT NULL,
@@ -58,19 +116,24 @@
             """.trimMargin()
     }

-    internal fun truncateTableQuery(streamId: StreamId): String {
-        return "TRUNCATE TABLE \"${streamId.rawNamespace}\".\"${streamId.rawName}\";\n"
+    internal fun truncateTableQuery(streamId: StreamId, suffix: String): String {
+        return "TRUNCATE TABLE \"${streamId.rawNamespace}\".\"${streamId.rawName + suffix}\";\n"
     }

-    override fun writeToStage(streamConfig: StreamConfig, data: SerializableBuffer) {
+    override fun writeToStage(
+        streamConfig: StreamConfig,
+        suffix: String,
+        data: SerializableBuffer
+    ) {
         val stageName = getStageName(streamConfig.id)
         val stagingPath = getStagingPath()
         val stagedFileName = staging.uploadRecordsToStage(data, stageName, stagingPath)
         staging.copyIntoTableFromStage(
             stageName,
             stagingPath,
             listOf(stagedFileName),
-            streamConfig.id
+            streamConfig.id,
+            suffix
         )
     }
     override fun cleanupStage(streamId: StreamId) {
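Taken together, prepareStage, writeToStage, and the new overwrite/transfer methods give the connector a full temp-raw-table lifecycle. A minimal orchestration sketch, assuming the CDK's AbstractStreamOperation drives it with TMP_TABLE_SUFFIX (simplified; the real control flow lives in the CDK, not in this commit):

```kotlin
// Simplified sketch; storageOperation, streamId, streamConfig, and buffer are
// assumed to be in scope, and TMP_TABLE_SUFFIX is assumed to be "_airbyte_tmp".
val suffix = AbstractStreamOperation.TMP_TABLE_SUFFIX

// 1. Create the suffixed raw table (truncating it when starting a fresh refresh).
storageOperation.prepareStage(streamId, suffix, replace = true)

// 2. Upload each buffer to the stage and COPY it into the suffixed table.
storageOperation.writeToStage(streamConfig, suffix, buffer)

// 3. Finalize: either promote the temp table over the live raw table...
storageOperation.overwriteStage(streamId, suffix)
// ...or, when prior data must be preserved, append and drop instead:
// storageOperation.transferFromTempStage(streamId, suffix)
```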
File 5 of 9:

@@ -11,6 +11,7 @@ import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition
 import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition
 import io.airbyte.cdk.integrations.destination.jdbc.typing_deduping.JdbcDestinationHandler
 import io.airbyte.commons.json.Jsons.emptyObject
+import io.airbyte.integrations.base.destination.operation.AbstractStreamOperation
 import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType
 import io.airbyte.integrations.base.destination.typing_deduping.AirbyteType
 import io.airbyte.integrations.base.destination.typing_deduping.Array

@@ -85,22 +86,32 @@ class SnowflakeDestinationHandler(
     @Throws(Exception::class)
     private fun getInitialRawTableState(
         id: StreamId,
+        suffix: String,
     ): InitialRawTableStatus {
         // Short-circuit for overwrite, table will be truncated anyway
+        val rawTableName = id.rawName + suffix
         val tableExists =
             database.executeMetadataQuery { databaseMetaData: DatabaseMetaData ->
-                LOGGER.info("Retrieving table from Db metadata: {} {}", id.rawNamespace, id.rawName)
+                LOGGER.info(
+                    "Retrieving table from Db metadata: {} {}",
+                    id.rawNamespace,
+                    rawTableName
+                )
                 try {
                     val rs =
-                        databaseMetaData.getTables(databaseName, id.rawNamespace, id.rawName, null)
+                        databaseMetaData.getTables(
+                            databaseName,
+                            id.rawNamespace,
+                            rawTableName,
+                            null
+                        )
                     // When QUOTED_IDENTIFIERS_IGNORE_CASE is set to true, the raw table is
                     // interpreted as uppercase
                     // in db metadata calls. check for both
                     val rsUppercase =
                         databaseMetaData.getTables(
                             databaseName,
                             id.rawNamespace.uppercase(),
-                            id.rawName.uppercase(),
+                            rawTableName.uppercase(),
                             null
                         )
                     rs.next() || rsUppercase.next()

@@ -130,7 +141,7 @@
                 StringSubstitutor(
                     java.util.Map.of(
                         "raw_table",
-                        id.rawTableId(SnowflakeSqlGenerator.QUOTE)
+                        id.rawTableId(SnowflakeSqlGenerator.QUOTE, suffix)
                     )
                 )
                 .replace(

@@ -186,7 +197,7 @@
                 StringSubstitutor(
                     java.util.Map.of(
                         "raw_table",
-                        id.rawTableId(SnowflakeSqlGenerator.QUOTE)
+                        id.rawTableId(SnowflakeSqlGenerator.QUOTE, suffix)
                     )
                 )
                 .replace(

@@ -286,7 +297,7 @@
             "VARIANT" == existingTable.columns[abMetaColumnName]!!.type
         }

-    fun isAirbyteGenerationIdColumnMatch(existingTable: TableDefinition): Boolean {
+    private fun isAirbyteGenerationIdColumnMatch(existingTable: TableDefinition): Boolean {
         val abGenerationIdColumnName: String =
             JavaBaseConstants.COLUMN_NAME_AB_GENERATION_ID.uppercase(Locale.getDefault())
         return existingTable.columns.containsKey(abGenerationIdColumnName) &&

@@ -388,7 +399,12 @@
                 !existingSchemaMatchesStreamConfig(streamConfig, existingTable!!)
             isFinalTableEmpty = hasRowCount && tableRowCounts[namespace]!![name] == 0
         }
-        val initialRawTableState = getInitialRawTableState(streamConfig.id)
+        val initialRawTableState = getInitialRawTableState(streamConfig.id, "")
+        val tempRawTableState =
+            getInitialRawTableState(
+                streamConfig.id,
+                AbstractStreamOperation.TMP_TABLE_SUFFIX
+            )
         val destinationState =
             destinationStates.getOrDefault(
                 streamConfig.id.asPair(),

@@ -398,6 +414,7 @@
             streamConfig,
             isFinalTablePresent,
             initialRawTableState,
+            tempRawTableState,
             isSchemaMismatch,
             isFinalTableEmpty,
             destinationState

@@ -466,6 +483,10 @@
         }
     }

+    fun query(sql: String): List<JsonNode> {
+        return database.queryJsons(sql)
+    }
+
     companion object {
         private val LOGGER: Logger =
             LoggerFactory.getLogger(SnowflakeDestinationHandler::class.java)
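A rough sketch of how the handler's new pieces fit together at sync start: initial-state gathering now inspects both the live and the suffixed temp raw table, and the temp table's generation id (read via the new query() helper through getStageGeneration) decides whether an interrupted refresh can be resumed. Names outside this diff (gatherInitialState, initialTempRawTableStatus, streamConfig.generationId) are assumptions based on the CDK interfaces of this period:

```kotlin
// Simplified sketch, not code from this commit; the real resume/replace
// decision lives in the CDK's AbstractStreamOperation.
val status = destinationHandler.gatherInitialState(listOf(streamConfig)).first()
if (status.initialTempRawTableStatus.rawTableExists) {
    val tempGeneration =
        storageOperation.getStageGeneration(streamConfig.id, AbstractStreamOperation.TMP_TABLE_SUFFIX)
    // streamConfig.generationId is assumed here: a matching generation means the
    // earlier attempt can resume into the existing temp table; a different or
    // unreadable generation means the temp table should be truncated and restarted.
    val resumable = tempGeneration != null && tempGeneration == streamConfig.generationId
}
```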
(The diffs for the remaining 4 changed files did not load.)