From a7c13dc6f14c0abe839d7405b1b112d1f6ba1435 Mon Sep 17 00:00:00 2001 From: Benoit Moriceau Date: Mon, 18 Sep 2023 16:31:57 -0700 Subject: [PATCH] Improve the memory usage of the async destination (#30479) Co-authored-by: benmoriceau --- .../AsyncStreamConsumer.java | 52 +++++++------- .../buffers/BufferManager.java | 10 ++- .../state/GlobalAsyncStateManager.java | 70 +++++++++++-------- .../AsyncStreamConsumerTest.java | 13 ++-- .../destination-snowflake/Dockerfile | 2 +- .../destination-snowflake/build.gradle | 2 +- .../destination-snowflake/metadata.yaml | 2 +- docs/integrations/destinations/snowflake.md | 3 +- 8 files changed, 87 insertions(+), 67 deletions(-) diff --git a/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/AsyncStreamConsumer.java b/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/AsyncStreamConsumer.java index 3b9c86579758..caac2e43f253 100644 --- a/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/AsyncStreamConsumer.java +++ b/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/AsyncStreamConsumer.java @@ -18,7 +18,6 @@ import io.airbyte.protocol.models.v0.AirbyteMessage.Type; import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog; import io.airbyte.protocol.models.v0.StreamDescriptor; -import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import lombok.extern.slf4j.Slf4j; @@ -110,16 +109,14 @@ public void accept(final String messageString, final Integer sizeInBytes) throws * to try to use a thread pool to partially deserialize to get record type and stream name, we can * do it without touching buffer manager. */ - deserializeAirbyteMessage(messageString) - .ifPresent(message -> { - if (Type.RECORD.equals(message.getType())) { - if (Strings.isNullOrEmpty(message.getRecord().getNamespace())) { - message.getRecord().setNamespace(defaultNamespace); - } - validateRecord(message); - } - bufferEnqueue.addRecord(message, sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES); - }); + final var message = deserializeAirbyteMessage(messageString); + if (Type.RECORD.equals(message.getType())) { + if (Strings.isNullOrEmpty(message.getRecord().getNamespace())) { + message.getRecord().setNamespace(defaultNamespace); + } + validateRecord(message); + } + bufferEnqueue.addRecord(message, sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES); } /** @@ -134,24 +131,27 @@ public void accept(final String messageString, final Integer sizeInBytes) throws * @return PartialAirbyteMessage if the message is valid, empty otherwise */ @VisibleForTesting - public static Optional deserializeAirbyteMessage(final String messageString) { + public static PartialAirbyteMessage deserializeAirbyteMessage(final String messageString) { // TODO: (ryankfu) plumb in the serialized AirbyteStateMessage to match AirbyteRecordMessage code // parity. https://github.com/airbytehq/airbyte/issues/27530 for additional context - final Optional messageOptional = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class) - .map(partial -> { - if (Type.RECORD.equals(partial.getType()) && partial.getRecord().getData() != null) { - return partial.withSerialized(partial.getRecord().getData().toString()); - } else if (Type.STATE.equals(partial.getType())) { - return partial.withSerialized(messageString); - } else { - return null; - } - }); - - if (messageOptional.isPresent()) { - return messageOptional; + final var partial = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class) + .orElseThrow(() -> new RuntimeException("Unable to deserialize PartialAirbyteMessage.")); + + final var msgType = partial.getType(); + if (Type.RECORD.equals(msgType) && partial.getRecord().getData() != null) { + // store serialized json + partial.withSerialized(partial.getRecord().getData().toString()); + // The connector doesn't need to be able to access to the record value. We can serialize it here and + // drop the json + // object. Having this data stored as a string is slightly more optimal for the memory usage. + partial.getRecord().setData(null); + } else if (Type.STATE.equals(msgType)) { + partial.withSerialized(messageString); + } else { + throw new RuntimeException(String.format("Unsupported message type: %s", msgType)); } - throw new RuntimeException("Invalid serialized message"); + + return partial; } @Override diff --git a/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/buffers/BufferManager.java b/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/buffers/BufferManager.java index 3216135fdf73..b9aa5d75e26d 100644 --- a/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/buffers/BufferManager.java +++ b/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/buffers/BufferManager.java @@ -74,7 +74,9 @@ public void close() throws Exception { } private void printQueueInfo() { - final var queueInfo = new StringBuilder().append("QUEUE INFO").append(System.lineSeparator()); + final var queueInfo = new StringBuilder().append("START OF QUEUE INFO").append(System.lineSeparator()) + .append("This represents an estimation of the size of the elements contain in the in memory buffer.") + .append(System.lineSeparator()); queueInfo .append(String.format(" Global Mem Manager -- max: %s, allocated: %s (%s MB), %% used: %s", @@ -91,6 +93,12 @@ private void printQueueInfo() { entry.getKey().getName(), queue.size(), AirbyteFileUtils.byteCountToDisplaySize(queue.getCurrentMemoryUsage()))) .append(System.lineSeparator()); } + + queueInfo.append(stateManager.getMemoryUsageMessage()) + .append(System.lineSeparator()); + + queueInfo.append("END OF QUEUE INFO"); + log.info(queueInfo.toString()); } diff --git a/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/state/GlobalAsyncStateManager.java b/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/state/GlobalAsyncStateManager.java index 420f892a66ef..2ddc643763cc 100644 --- a/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/state/GlobalAsyncStateManager.java +++ b/airbyte-integrations/bases/base-java/src/main/java/io/airbyte/integrations/destination_async/state/GlobalAsyncStateManager.java @@ -18,6 +18,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -68,9 +69,8 @@ public class GlobalAsyncStateManager { private final AtomicLong memoryUsed; boolean preState = true; + private final ConcurrentMap> descToStateIdQ = new ConcurrentHashMap<>(); private final ConcurrentMap stateIdToCounter = new ConcurrentHashMap<>(); - private final ConcurrentMap> streamToStateIdQ = new ConcurrentHashMap<>(); - private final ConcurrentMap> stateIdToState = new ConcurrentHashMap<>(); // empty in the STREAM case. @@ -144,30 +144,39 @@ public void decrement(final long stateId, final long count) { public List flushStates() { final List output = new ArrayList<>(); Long bytesFlushed = 0L; - for (final Map.Entry> entry : streamToStateIdQ.entrySet()) { - // Remove all states with 0 counters. - // Per-stream synchronized is required to make sure the state (at the head of the queue) - // logic is applied to is the state actually removed. - synchronized (this) { + synchronized (this) { + for (final Map.Entry> entry : descToStateIdQ.entrySet()) { + // Remove all states with 0 counters. + // Per-stream synchronized is required to make sure the state (at the head of the queue) + // logic is applied to is the state actually removed. + final LinkedList stateIdQueue = entry.getValue(); while (true) { - final Long oldestState = stateIdQueue.peek(); - if (oldestState == null) { + final Long oldestStateId = stateIdQueue.peek(); + // no state to flush for this stream + if (oldestStateId == null) { break; } // technically possible this map hasn't been updated yet. - final boolean noCorrespondingStateMsg = stateIdToState.get(oldestState) == null; - if (noCorrespondingStateMsg) { + final var oldestStateCounter = stateIdToCounter.get(oldestStateId); + Objects.requireNonNull(oldestStateCounter, "Invariant Violation: No record counter found for state message."); + + final var oldestState = stateIdToState.get(oldestStateId); + // no state to flush for this stream + if (oldestState == null) { break; } - final boolean noPrevRecs = !stateIdToCounter.containsKey(oldestState); - final boolean allRecsEmitted = stateIdToCounter.get(oldestState).get() == 0; - if (noPrevRecs || allRecsEmitted) { - var polled = entry.getValue().poll(); // poll to remove. no need to read as the earlier peek is still valid. - output.add(stateIdToState.get(oldestState).getLeft()); - bytesFlushed += stateIdToState.get(oldestState).getRight(); + final var allRecordsCommitted = oldestStateCounter.get() == 0; + if (allRecordsCommitted) { + output.add(oldestState.getLeft()); + bytesFlushed += oldestState.getRight(); + + // cleanup + entry.getValue().poll(); + stateIdToState.remove(oldestStateId); + stateIdToCounter.remove(oldestStateId); } else { break; } @@ -183,10 +192,10 @@ private Long getStateIdAndIncrement(final StreamDescriptor streamDescriptor, fin final StreamDescriptor resolvedDescriptor = stateType == AirbyteStateMessage.AirbyteStateType.STREAM ? streamDescriptor : SENTINEL_GLOBAL_DESC; // As concurrent collections do not guarantee data consistency when iterating, use `get` instead of // `containsKey`. - if (streamToStateIdQ.get(resolvedDescriptor) == null) { + if (descToStateIdQ.get(resolvedDescriptor) == null) { registerNewStreamDescriptor(resolvedDescriptor); } - final Long stateId = streamToStateIdQ.get(resolvedDescriptor).peekLast(); + final Long stateId = descToStateIdQ.get(resolvedDescriptor).peekLast(); final var update = stateIdToCounter.get(stateId).addAndGet(increment); log.trace("State id: {}, count: {}", stateId, update); return stateId; @@ -231,12 +240,12 @@ private void convertToGlobalIfNeeded(final PartialAirbyteMessage message) { // upon conversion, all previous tracking data structures need to be cleared as we move // into the non-STREAM world for correctness. - aliasIds.addAll(streamToStateIdQ.values().stream().flatMap(Collection::stream).toList()); - streamToStateIdQ.clear(); + aliasIds.addAll(descToStateIdQ.values().stream().flatMap(Collection::stream).toList()); + descToStateIdQ.clear(); retroactiveGlobalStateId = StateIdProvider.getNextId(); - streamToStateIdQ.put(SENTINEL_GLOBAL_DESC, new LinkedList<>()); - streamToStateIdQ.get(SENTINEL_GLOBAL_DESC).add(retroactiveGlobalStateId); + descToStateIdQ.put(SENTINEL_GLOBAL_DESC, new LinkedList<>()); + descToStateIdQ.get(SENTINEL_GLOBAL_DESC).add(retroactiveGlobalStateId); final long combinedCounter = stateIdToCounter.values() .stream() @@ -291,9 +300,12 @@ private void allocateMemoryToState(final long sizeInBytes) { throw new RuntimeException(e); } } + LOGGER.debug(getMemoryUsageMessage()); } - memoryUsed.addAndGet(sizeInBytes); - LOGGER.debug("State Manager memory usage: Allocated: {}, Used: {}, % Used {}", + } + + public String getMemoryUsageMessage() { + return String.format("State Manager memory usage: Allocated: %s, Used: %s, percentage Used %f", FileUtils.byteCountToDisplaySize(memoryAllocated.get()), FileUtils.byteCountToDisplaySize(memoryUsed.get()), (double) memoryUsed.get() / memoryAllocated.get()); @@ -312,14 +324,14 @@ private long getStateAfterAlias(final long stateId) { } private void registerNewStreamDescriptor(final StreamDescriptor resolvedDescriptor) { - streamToStateIdQ.put(resolvedDescriptor, new LinkedList<>()); + descToStateIdQ.put(resolvedDescriptor, new LinkedList<>()); registerNewStateId(resolvedDescriptor); } private void registerNewStateId(final StreamDescriptor resolvedDescriptor) { final long stateId = StateIdProvider.getNextId(); - streamToStateIdQ.get(resolvedDescriptor).add(stateId); stateIdToCounter.put(stateId, new AtomicLong(0)); + descToStateIdQ.get(resolvedDescriptor).add(stateId); } /** @@ -327,10 +339,10 @@ private void registerNewStateId(final StreamDescriptor resolvedDescriptor) { */ private static class StateIdProvider { - private static long pk = 0; + private static final AtomicLong pk = new AtomicLong(0); public static long getNextId() { - return pk++; + return pk.incrementAndGet(); } } diff --git a/airbyte-integrations/bases/base-java/src/test/java/io/airbyte/integrations/destination_async/AsyncStreamConsumerTest.java b/airbyte-integrations/bases/base-java/src/test/java/io/airbyte/integrations/destination_async/AsyncStreamConsumerTest.java index 06e399d28f36..bb1bf7420499 100644 --- a/airbyte-integrations/bases/base-java/src/test/java/io/airbyte/integrations/destination_async/AsyncStreamConsumerTest.java +++ b/airbyte-integrations/bases/base-java/src/test/java/io/airbyte/integrations/destination_async/AsyncStreamConsumerTest.java @@ -39,7 +39,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -236,8 +235,8 @@ void deserializeAirbyteMessageWithAirbyteRecord() { .withData(PAYLOAD)); final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage); final String airbyteRecordString = Jsons.serialize(PAYLOAD); - final Optional partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage); - assertEquals(airbyteRecordString, partial.get().getSerialized()); + final PartialAirbyteMessage partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage); + assertEquals(airbyteRecordString, partial.getSerialized()); } @Test @@ -250,8 +249,8 @@ void deserializeAirbyteMessageWithEmptyAirbyteRecord() { .withNamespace(SCHEMA_NAME) .withData(Jsons.jsonNode(emptyMap))); final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage); - final Optional partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage); - assertEquals(emptyMap.toString(), partial.get().getSerialized()); + final PartialAirbyteMessage partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage); + assertEquals(emptyMap.toString(), partial.getSerialized()); } @Test @@ -266,8 +265,8 @@ void deserializeAirbyteMessageWithNoStateOrRecord() { @Test void deserializeAirbyteMessageWithAirbyteState() { final String serializedAirbyteMessage = Jsons.serialize(STATE_MESSAGE1); - final Optional partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage); - assertEquals(serializedAirbyteMessage, partial.get().getSerialized()); + final PartialAirbyteMessage partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage); + assertEquals(serializedAirbyteMessage, partial.getSerialized()); } @Test diff --git a/airbyte-integrations/connectors/destination-snowflake/Dockerfile b/airbyte-integrations/connectors/destination-snowflake/Dockerfile index 4d12b9d1b757..e57766de0b3f 100644 --- a/airbyte-integrations/connectors/destination-snowflake/Dockerfile +++ b/airbyte-integrations/connectors/destination-snowflake/Dockerfile @@ -29,5 +29,5 @@ RUN tar xf ${APPLICATION}.tar --strip-components=1 ENV ENABLE_SENTRY true -LABEL io.airbyte.version=3.1.7 +LABEL io.airbyte.version=3.1.8 LABEL io.airbyte.name=airbyte/destination-snowflake diff --git a/airbyte-integrations/connectors/destination-snowflake/build.gradle b/airbyte-integrations/connectors/destination-snowflake/build.gradle index d37a6b84432b..bf0ee019906d 100644 --- a/airbyte-integrations/connectors/destination-snowflake/build.gradle +++ b/airbyte-integrations/connectors/destination-snowflake/build.gradle @@ -16,7 +16,7 @@ application { // '-Dcom.sun.management.jmxremote=true', // '-Dcom.sun.management.jmxremote.port=6000', // '-Dcom.sun.management.jmxremote.rmi.port=6000', -// '-Dcom.sun.management.jmxremote.local.only=false', +// '-Dcom.sun.management.jmxremote.local.only=false' // '-Dcom.sun.management.jmxremote.authenticate=false', // '-Dcom.sun.management.jmxremote.ssl=false', // '-agentpath:/usr/local/YourKit-JavaProfiler-2021.3/bin/linux-x86-64/libyjpagent.so=port=10001,listen=all' diff --git a/airbyte-integrations/connectors/destination-snowflake/metadata.yaml b/airbyte-integrations/connectors/destination-snowflake/metadata.yaml index 8d6c07817ea4..9c1ee11cc9d4 100644 --- a/airbyte-integrations/connectors/destination-snowflake/metadata.yaml +++ b/airbyte-integrations/connectors/destination-snowflake/metadata.yaml @@ -2,7 +2,7 @@ data: connectorSubtype: database connectorType: destination definitionId: 424892c4-daac-4491-b35d-c6688ba547ba - dockerImageTag: 3.1.7 + dockerImageTag: 3.1.8 dockerRepository: airbyte/destination-snowflake githubIssueLabel: destination-snowflake icon: snowflake.svg diff --git a/docs/integrations/destinations/snowflake.md b/docs/integrations/destinations/snowflake.md index af1b4e6a9283..195fb11133a6 100644 --- a/docs/integrations/destinations/snowflake.md +++ b/docs/integrations/destinations/snowflake.md @@ -270,7 +270,8 @@ Otherwise, make sure to grant the role the required permissions in the desired n ## Changelog | Version | Date | Pull Request | Subject | -| :-------------- | :--------- | :--------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- | +|:----------------|:-----------|:-----------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------| +| 3.1.8 | 2023-09-18 | [\#30479](https://github.com/airbytehq/airbyte/pull/30479) | Fix async memory management | | 3.1.7 | 2023-09-15 | [\#30491](https://github.com/airbytehq/airbyte/pull/30491) | Improve error message display | | 3.1.6 | 2023-09-14 | [\#30439](https://github.com/airbytehq/airbyte/pull/30439) | Fix a transient error | | 3.1.5 | 2023-09-13 | [\#30416](https://github.com/airbytehq/airbyte/pull/30416) | Support `${` in stream name/namespace, and in column names |