bulk-cdk: source output performance improvements (#44865)
postamar authored Aug 28, 2024
1 parent 643c800 commit abd9da9
Showing 7 changed files with 57 additions and 19 deletions.
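The changes below all target per-record allocation overhead on the source output path: StdoutOutputConsumer serializes through a single reused Jackson SequenceWriter instead of building a fresh byte array per message, JdbcPartitionReader pre-builds one AirbyteMessage and its backing ObjectNode and refills them for every row, and JdbcSelectQuerier can hand back one reused ObjectNode per query when the new reuseResultObject parameter is set. A minimal sketch of the serialization change only, assuming a plain Jackson ObjectMapper in place of the CDK's Jsons singleton and a Map payload in place of AirbyteMessage:

    import com.fasterxml.jackson.databind.ObjectMapper
    import com.fasterxml.jackson.databind.SequenceWriter
    import java.io.ByteArrayOutputStream

    fun main() {
        val mapper = ObjectMapper()
        val buffer = ByteArrayOutputStream()

        // Before: every message produced an intermediate ByteArray that was then copied into the buffer.
        val json: ByteArray = mapper.writeValueAsBytes(mapOf("type" to "RECORD"))
        buffer.writeBytes(json)

        // After: a long-lived SequenceWriter streams each value straight into the buffer.
        val sequenceWriter: SequenceWriter = mapper.writer().writeValues(buffer)
        sequenceWriter.write(mapOf("type" to "RECORD"))
        sequenceWriter.flush()

        println(buffer.toString(Charsets.UTF_8))
    }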
────────────────────────────────────────
@@ -1,6 +1,7 @@
/* Copyright (c) 2024 Airbyte, Inc., all rights reserved. */
package io.airbyte.cdk.output

+ import com.fasterxml.jackson.databind.SequenceWriter
import io.airbyte.cdk.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteAnalyticsTraceMessage
import io.airbyte.protocol.models.v0.AirbyteCatalog
@@ -99,6 +100,7 @@ private class StdoutOutputConsumer : OutputConsumer {
override val emittedAt: Instant = Instant.now()

private val buffer = ByteArrayOutputStream()
+ private val sequenceWriter: SequenceWriter = Jsons.writer().writeValues(buffer)

override fun accept(airbyteMessage: AirbyteMessage) {
// This method effectively println's its JSON-serialized argument.
@@ -108,12 +110,12 @@ private class StdoutOutputConsumer : OutputConsumer {
// Other Airbyte message types are not buffered, instead they trigger an immediate flush.
// Such messages should not linger indefinitely in a buffer.
val isRecord: Boolean = airbyteMessage.type == AirbyteMessage.Type.RECORD
- val json: ByteArray = Jsons.writeValueAsBytes(airbyteMessage)
synchronized(this) {
if (buffer.size() > 0) {
buffer.write('\n'.code)
}
- buffer.writeBytes(json)
+ sequenceWriter.write(airbyteMessage)
+ sequenceWriter.flush()
if (!isRecord || buffer.size() >= BUFFER_MAX_SIZE) {
withLockFlush()
}
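The comments in accept() above spell out the buffering contract: RECORD messages accumulate in the buffer, while any other message type, or a buffer that has grown past BUFFER_MAX_SIZE, triggers withLockFlush(). The flush body itself is outside this diff; the following is only a hedged sketch of what such a flush plausibly looks like, not the actual implementation:

    import java.io.ByteArrayOutputStream

    // Hypothetical sketch -- StdoutOutputConsumer.withLockFlush() is not shown in this commit.
    class FlushSketch {
        private val buffer = ByteArrayOutputStream()

        fun withLockFlush() {
            if (buffer.size() > 0) {
                println(buffer.toString(Charsets.UTF_8)) // emit the buffered newline-separated JSON
                buffer.reset()                           // next batch starts from an empty buffer
            }
        }
    }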
────────────────────────────────────────
@@ -1,6 +1,7 @@
/* Copyright (c) 2024 Airbyte, Inc., all rights reserved. */
package io.airbyte.cdk.output

+ import io.airbyte.cdk.util.Jsons
import io.airbyte.protocol.models.v0.AirbyteCatalog
import io.airbyte.protocol.models.v0.AirbyteConnectionStatus
import io.airbyte.protocol.models.v0.AirbyteLogMessage
@@ -35,7 +36,10 @@ class BufferingOutputConsumer(
private val traces = mutableListOf<AirbyteTraceMessage>()
private val messages = mutableListOf<AirbyteMessage>()

- override fun accept(m: AirbyteMessage) {
+ override fun accept(input: AirbyteMessage) {
+ // Deep copy the input, which may be reused and mutated later on.
+ val m: AirbyteMessage =
+ Jsons.readValue(Jsons.writeValueAsBytes(input), AirbyteMessage::class.java)
synchronized(this) {
messages.add(m)
when (m.type) {
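The deep copy above matters because, after this change, producers such as JdbcPartitionReader reuse a single AirbyteMessage and ObjectNode across records; a test consumer that merely stored the reference would later observe mutated data. A standalone illustration of that aliasing hazard, using a plain ObjectMapper rather than the CDK's Jsons helper:

    import com.fasterxml.jackson.databind.ObjectMapper
    import com.fasterxml.jackson.databind.node.ObjectNode

    fun main() {
        val mapper = ObjectMapper()
        val reused: ObjectNode = mapper.createObjectNode().put("id", 1)

        val storedByReference: ObjectNode = reused // aliasing: no copy taken
        val storedByCopy: ObjectNode =             // serialize/deserialize round-trip, as in accept()
            mapper.readValue(mapper.writeValueAsBytes(reused), ObjectNode::class.java)

        reused.put("id", 2) // the producer refills the reused node for the next record

        println(storedByReference) // {"id":2} -- silently changed after the fact
        println(storedByCopy)      // {"id":1} -- the value that was actually emitted
    }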
────────────────────────────────────────
@@ -335,7 +335,7 @@ data class TestCase(
)
for (actualState in actual!!) {
Assertions.assertTrue(
- actualState in expected,
+ actualState.toString() in expected.map { it.toString() },
"$actualState should be in $expected",
)
}
────────────────────────────────────────
@@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode
import io.airbyte.cdk.command.OpaqueStateValue
import io.airbyte.cdk.output.OutputConsumer
import io.airbyte.cdk.util.Jsons
+ import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.AirbyteRecordMessage
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
@@ -38,18 +39,25 @@ sealed class JdbcPartitionReader<P : JdbcPartition<*>>(
}

fun out(record: ObjectNode) {
- val recordMessageData: ObjectNode = Jsons.objectNode()
for (fieldName in streamFieldNames) {
- recordMessageData.set<JsonNode>(fieldName, record[fieldName] ?: Jsons.nullNode())
+ outData.set<JsonNode>(fieldName, record[fieldName] ?: Jsons.nullNode())
}
- outputConsumer.accept(
- AirbyteRecordMessage()
- .withStream(stream.name)
- .withNamespace(stream.namespace)
- .withData(recordMessageData),
- )
+ outputConsumer.accept(msg)
}

+ private val outData: ObjectNode = Jsons.objectNode()
+
+ private val msg =
+ AirbyteMessage()
+ .withType(AirbyteMessage.Type.RECORD)
+ .withRecord(
+ AirbyteRecordMessage()
+ .withEmittedAt(outputConsumer.emittedAt.toEpochMilli())
+ .withStream(stream.name)
+ .withNamespace(stream.namespace)
+ .withData(outData)
+ )

val streamFieldNames: List<String> = stream.fields.map { it.id }

override fun releaseResources() {
@@ -69,7 +77,11 @@ class JdbcNonResumablePartitionReader<P : JdbcPartition<*>>(
selectQuerier
.executeQuery(
q = partition.nonResumableQuery,
- parameters = SelectQuerier.Parameters(streamState.fetchSize),
+ parameters =
+ SelectQuerier.Parameters(
+ reuseResultObject = true,
+ fetchSize = streamState.fetchSize
+ ),
)
.use { result: SelectQuerier.Result ->
for (record in result) {
@@ -109,7 +121,8 @@ class JdbcResumablePartitionReader<P : JdbcSplittablePartition<*>>(
selectQuerier
.executeQuery(
q = partition.resumableQuery(limit),
- parameters = SelectQuerier.Parameters(fetchSize),
+ parameters =
+ SelectQuerier.Parameters(reuseResultObject = true, fetchSize = fetchSize),
)
.use { result: SelectQuerier.Result ->
for (record in result) {
────────────────────────────────────────
@@ -21,6 +21,9 @@ interface SelectQuerier {
): Result

data class Parameters(
+ /** When set, the [ObjectNode] in the [Result] is reused; take care with this! */
+ val reuseResultObject: Boolean = false,
+ /** JDBC fetchSize value. */
val fetchSize: Int? = null,
)

@@ -46,6 +49,7 @@ class JdbcSelectQuerier(
var conn: Connection? = null
var stmt: PreparedStatement? = null
var rs: ResultSet? = null
+ val reusable: ObjectNode? = Jsons.objectNode().takeIf { parameters.reuseResultObject }

init {
log.info { "Querying ${q.sql}" }
@@ -94,7 +98,7 @@ class JdbcSelectQuerier(
// necessary.
if (!hasNext()) throw NoSuchElementException()
// Read the current row in the ResultSet
- val record: ObjectNode = Jsons.objectNode()
+ val record: ObjectNode = reusable ?: Jsons.objectNode()
var colIdx = 1
for (column in q.columns) {
log.debug { "Getting value #$colIdx for $column." }
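Because reuseResultObject = true makes the querier return the same ObjectNode instance on every iteration (the updated test further below asserts previous === record), callers must consume or copy each row before advancing. A hedged usage sketch; querier, query, and process are placeholders standing in for a SelectQuerier, a SelectQuery, and any row consumer, not names taken from this diff:

    querier.executeQuery(query, SelectQuerier.Parameters(reuseResultObject = true, fetchSize = 1_000))
        .use { result ->
            for (record in result) {
                // Copy only if the row must outlive this iteration; otherwise use it in place.
                process(record.deepCopy())
            }
        }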
────────────────────────────────────────
@@ -47,7 +47,7 @@ class JdbcPartitionReaderTest {
)
),
),
- SelectQuerier.Parameters(fetchSize = 2),
+ SelectQuerier.Parameters(reuseResultObject = true, fetchSize = 2),
"""{"id":1,"ts":"2024-08-01","msg":"hello"}""",
"""{"id":2,"ts":"2024-08-02","msg":"how"}""",
"""{"id":3,"ts":"2024-08-03","msg":"are"}""",
@@ -126,7 +126,7 @@ class JdbcPartitionReaderTest {
OrderBy(ts),
Limit(4),
),
- SelectQuerier.Parameters(fetchSize = 2),
+ SelectQuerier.Parameters(reuseResultObject = true, fetchSize = 2),
"""{"id":1,"ts":"2024-08-01","msg":"hello"}""",
"""{"id":2,"ts":"2024-08-02","msg":"how"}""",
"""{"id":3,"ts":"2024-08-03","msg":"are"}""",
────────────────────────────────────────
@@ -1,6 +1,7 @@
/* Copyright (c) 2024 Airbyte, Inc., all rights reserved. */
package io.airbyte.cdk.read

+ import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ObjectNode
import io.airbyte.cdk.discover.Field
import io.airbyte.cdk.h2.H2TestFixture
@@ -77,7 +78,7 @@ class JdbcSelectQuerierTest {

private fun runTest(
q: SelectQuery,
- vararg expected: String,
+ vararg expectedJson: String,
) {
val configPojo: H2SourceConfigurationJsonObject =
H2SourceConfigurationJsonObject().apply {
@@ -86,7 +87,21 @@ }
}
val config: H2SourceConfiguration = H2SourceConfigurationFactory().make(configPojo)
val querier: SelectQuerier = JdbcSelectQuerier(JdbcConnectionFactory(config))
+ // Vanilla query
+ val expected: List<JsonNode> = expectedJson.map(Jsons::readTree)
val actual: List<ObjectNode> = querier.executeQuery(q).use { it.asSequence().toList() }
- Assertions.assertIterableEquals(expected.toList().map(Jsons::readTree), actual)
+ Assertions.assertIterableEquals(expected, actual)
+ // Query with reuseResultObject = true
+ querier.executeQuery(q, SelectQuerier.Parameters(reuseResultObject = true)).use {
+ var i = 0
+ var previous: ObjectNode? = null
+ for (record in it) {
+ if (i > 0) {
+ Assertions.assertTrue(previous === record)
+ }
+ Assertions.assertEquals(expected[i++], record)
+ previous = record
+ }
+ }
}
}
