Skip to content

Commit

Permalink
Improve the memory usage of the async destination (#30479)
Browse files Browse the repository at this point in the history
Co-authored-by: benmoriceau <[email protected]>
  • Loading branch information
benmoriceau and benmoriceau authored Sep 18, 2023
1 parent f226503 commit a7c13dc
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -110,16 +109,14 @@ public void accept(final String messageString, final Integer sizeInBytes) throws
* to try to use a thread pool to partially deserialize to get record type and stream name, we can
* do it without touching buffer manager.
*/
deserializeAirbyteMessage(messageString)
.ifPresent(message -> {
if (Type.RECORD.equals(message.getType())) {
if (Strings.isNullOrEmpty(message.getRecord().getNamespace())) {
message.getRecord().setNamespace(defaultNamespace);
}
validateRecord(message);
}
bufferEnqueue.addRecord(message, sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES);
});
final var message = deserializeAirbyteMessage(messageString);
if (Type.RECORD.equals(message.getType())) {
if (Strings.isNullOrEmpty(message.getRecord().getNamespace())) {
message.getRecord().setNamespace(defaultNamespace);
}
validateRecord(message);
}
bufferEnqueue.addRecord(message, sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES);
}

/**
Expand All @@ -134,24 +131,27 @@ public void accept(final String messageString, final Integer sizeInBytes) throws
* @return PartialAirbyteMessage if the message is valid; otherwise a RuntimeException is thrown
*/
@VisibleForTesting
public static Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
public static PartialAirbyteMessage deserializeAirbyteMessage(final String messageString) {
// TODO: (ryankfu) plumb in the serialized AirbyteStateMessage to match AirbyteRecordMessage code
// parity. https://github.com/airbytehq/airbyte/issues/27530 for additional context
final Optional<PartialAirbyteMessage> messageOptional = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class)
.map(partial -> {
if (Type.RECORD.equals(partial.getType()) && partial.getRecord().getData() != null) {
return partial.withSerialized(partial.getRecord().getData().toString());
} else if (Type.STATE.equals(partial.getType())) {
return partial.withSerialized(messageString);
} else {
return null;
}
});

if (messageOptional.isPresent()) {
return messageOptional;
final var partial = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class)
.orElseThrow(() -> new RuntimeException("Unable to deserialize PartialAirbyteMessage."));

final var msgType = partial.getType();
if (Type.RECORD.equals(msgType) && partial.getRecord().getData() != null) {
// store serialized json
partial.withSerialized(partial.getRecord().getData().toString());
// The connector does not need access to the record value, so we serialize it here and drop the
// JSON object. Storing this data as a string is slightly better for memory usage.
partial.getRecord().setData(null);
} else if (Type.STATE.equals(msgType)) {
partial.withSerialized(messageString);
} else {
throw new RuntimeException(String.format("Unsupported message type: %s", msgType));
}
throw new RuntimeException("Invalid serialized message");

return partial;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ public void close() throws Exception {
}

private void printQueueInfo() {
final var queueInfo = new StringBuilder().append("QUEUE INFO").append(System.lineSeparator());
final var queueInfo = new StringBuilder().append("START OF QUEUE INFO").append(System.lineSeparator())
.append("This represents an estimation of the size of the elements contain in the in memory buffer.")
.append(System.lineSeparator());

queueInfo
.append(String.format(" Global Mem Manager -- max: %s, allocated: %s (%s MB), %% used: %s",
Expand All @@ -91,6 +93,12 @@ private void printQueueInfo() {
entry.getKey().getName(), queue.size(), AirbyteFileUtils.byteCountToDisplaySize(queue.getCurrentMemoryUsage())))
.append(System.lineSeparator());
}

queueInfo.append(stateManager.getMemoryUsageMessage())
.append(System.lineSeparator());

queueInfo.append("END OF QUEUE INFO");

log.info(queueInfo.toString());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
Expand Down Expand Up @@ -68,9 +69,8 @@ public class GlobalAsyncStateManager {
private final AtomicLong memoryUsed;

boolean preState = true;
private final ConcurrentMap<StreamDescriptor, LinkedList<Long>> descToStateIdQ = new ConcurrentHashMap<>();
private final ConcurrentMap<Long, AtomicLong> stateIdToCounter = new ConcurrentHashMap<>();
private final ConcurrentMap<StreamDescriptor, LinkedList<Long>> streamToStateIdQ = new ConcurrentHashMap<>();

private final ConcurrentMap<Long, ImmutablePair<PartialAirbyteMessage, Long>> stateIdToState = new ConcurrentHashMap<>();
// empty in the STREAM case.

Expand Down Expand Up @@ -144,30 +144,39 @@ public void decrement(final long stateId, final long count) {
public List<PartialAirbyteMessage> flushStates() {
final List<PartialAirbyteMessage> output = new ArrayList<>();
Long bytesFlushed = 0L;
for (final Map.Entry<StreamDescriptor, LinkedList<Long>> entry : streamToStateIdQ.entrySet()) {
// Remove all states with 0 counters.
// Per-stream synchronized is required to make sure the state (at the head of the queue)
// logic is applied to is the state actually removed.
synchronized (this) {
synchronized (this) {
for (final Map.Entry<StreamDescriptor, LinkedList<Long>> entry : descToStateIdQ.entrySet()) {
// Remove all states with 0 counters.
// Per-stream synchronization is required to make sure that the state the logic is applied to
// (the one at the head of the queue) is the state that is actually removed.

final LinkedList<Long> stateIdQueue = entry.getValue();
while (true) {
final Long oldestState = stateIdQueue.peek();
if (oldestState == null) {
final Long oldestStateId = stateIdQueue.peek();
// no state to flush for this stream
if (oldestStateId == null) {
break;
}

// technically possible this map hasn't been updated yet.
final boolean noCorrespondingStateMsg = stateIdToState.get(oldestState) == null;
if (noCorrespondingStateMsg) {
final var oldestStateCounter = stateIdToCounter.get(oldestStateId);
Objects.requireNonNull(oldestStateCounter, "Invariant Violation: No record counter found for state message.");

final var oldestState = stateIdToState.get(oldestStateId);
// no state to flush for this stream
if (oldestState == null) {
break;
}

final boolean noPrevRecs = !stateIdToCounter.containsKey(oldestState);
final boolean allRecsEmitted = stateIdToCounter.get(oldestState).get() == 0;
if (noPrevRecs || allRecsEmitted) {
var polled = entry.getValue().poll(); // poll to remove. no need to read as the earlier peek is still valid.
output.add(stateIdToState.get(oldestState).getLeft());
bytesFlushed += stateIdToState.get(oldestState).getRight();
final var allRecordsCommitted = oldestStateCounter.get() == 0;
if (allRecordsCommitted) {
output.add(oldestState.getLeft());
bytesFlushed += oldestState.getRight();

// cleanup
entry.getValue().poll();
stateIdToState.remove(oldestStateId);
stateIdToCounter.remove(oldestStateId);
} else {
break;
}
Expand All @@ -183,10 +192,10 @@ private Long getStateIdAndIncrement(final StreamDescriptor streamDescriptor, fin
final StreamDescriptor resolvedDescriptor = stateType == AirbyteStateMessage.AirbyteStateType.STREAM ? streamDescriptor : SENTINEL_GLOBAL_DESC;
// As concurrent collections do not guarantee data consistency when iterating, use `get` instead of
// `containsKey`.
if (streamToStateIdQ.get(resolvedDescriptor) == null) {
if (descToStateIdQ.get(resolvedDescriptor) == null) {
registerNewStreamDescriptor(resolvedDescriptor);
}
final Long stateId = streamToStateIdQ.get(resolvedDescriptor).peekLast();
final Long stateId = descToStateIdQ.get(resolvedDescriptor).peekLast();
final var update = stateIdToCounter.get(stateId).addAndGet(increment);
log.trace("State id: {}, count: {}", stateId, update);
return stateId;
Expand Down Expand Up @@ -231,12 +240,12 @@ private void convertToGlobalIfNeeded(final PartialAirbyteMessage message) {
// upon conversion, all previous tracking data structures need to be cleared as we move
// into the non-STREAM world for correctness.

aliasIds.addAll(streamToStateIdQ.values().stream().flatMap(Collection::stream).toList());
streamToStateIdQ.clear();
aliasIds.addAll(descToStateIdQ.values().stream().flatMap(Collection::stream).toList());
descToStateIdQ.clear();
retroactiveGlobalStateId = StateIdProvider.getNextId();

streamToStateIdQ.put(SENTINEL_GLOBAL_DESC, new LinkedList<>());
streamToStateIdQ.get(SENTINEL_GLOBAL_DESC).add(retroactiveGlobalStateId);
descToStateIdQ.put(SENTINEL_GLOBAL_DESC, new LinkedList<>());
descToStateIdQ.get(SENTINEL_GLOBAL_DESC).add(retroactiveGlobalStateId);

final long combinedCounter = stateIdToCounter.values()
.stream()
Expand Down Expand Up @@ -291,9 +300,12 @@ private void allocateMemoryToState(final long sizeInBytes) {
throw new RuntimeException(e);
}
}
LOGGER.debug(getMemoryUsageMessage());
}
memoryUsed.addAndGet(sizeInBytes);
LOGGER.debug("State Manager memory usage: Allocated: {}, Used: {}, % Used {}",
}

public String getMemoryUsageMessage() {
return String.format("State Manager memory usage: Allocated: %s, Used: %s, percentage Used %f",
FileUtils.byteCountToDisplaySize(memoryAllocated.get()),
FileUtils.byteCountToDisplaySize(memoryUsed.get()),
(double) memoryUsed.get() / memoryAllocated.get());
Expand All @@ -312,25 +324,25 @@ private long getStateAfterAlias(final long stateId) {
}

private void registerNewStreamDescriptor(final StreamDescriptor resolvedDescriptor) {
streamToStateIdQ.put(resolvedDescriptor, new LinkedList<>());
descToStateIdQ.put(resolvedDescriptor, new LinkedList<>());
registerNewStateId(resolvedDescriptor);
}

private void registerNewStateId(final StreamDescriptor resolvedDescriptor) {
final long stateId = StateIdProvider.getNextId();
streamToStateIdQ.get(resolvedDescriptor).add(stateId);
stateIdToCounter.put(stateId, new AtomicLong(0));
descToStateIdQ.get(resolvedDescriptor).add(stateId);
}

/**
* Simplify internal tracking by providing a global always increasing counter for state ids.
*/
private static class StateIdProvider {

private static long pk = 0;
private static final AtomicLong pk = new AtomicLong(0);

public static long getNextId() {
return pk++;
return pk.incrementAndGet();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -236,8 +235,8 @@ void deserializeAirbyteMessageWithAirbyteRecord() {
.withData(PAYLOAD));
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final String airbyteRecordString = Jsons.serialize(PAYLOAD);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(airbyteRecordString, partial.get().getSerialized());
final PartialAirbyteMessage partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(airbyteRecordString, partial.getSerialized());
}

@Test
Expand All @@ -250,8 +249,8 @@ void deserializeAirbyteMessageWithEmptyAirbyteRecord() {
.withNamespace(SCHEMA_NAME)
.withData(Jsons.jsonNode(emptyMap)));
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(emptyMap.toString(), partial.get().getSerialized());
final PartialAirbyteMessage partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(emptyMap.toString(), partial.getSerialized());
}

@Test
Expand All @@ -266,8 +265,8 @@ void deserializeAirbyteMessageWithNoStateOrRecord() {
@Test
void deserializeAirbyteMessageWithAirbyteState() {
final String serializedAirbyteMessage = Jsons.serialize(STATE_MESSAGE1);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(serializedAirbyteMessage, partial.get().getSerialized());
final PartialAirbyteMessage partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(serializedAirbyteMessage, partial.getSerialized());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ RUN tar xf ${APPLICATION}.tar --strip-components=1

ENV ENABLE_SENTRY true

LABEL io.airbyte.version=3.1.7
LABEL io.airbyte.version=3.1.8
LABEL io.airbyte.name=airbyte/destination-snowflake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ application {
// '-Dcom.sun.management.jmxremote=true',
// '-Dcom.sun.management.jmxremote.port=6000',
// '-Dcom.sun.management.jmxremote.rmi.port=6000',
// '-Dcom.sun.management.jmxremote.local.only=false',
// '-Dcom.sun.management.jmxremote.local.only=false'
// '-Dcom.sun.management.jmxremote.authenticate=false',
// '-Dcom.sun.management.jmxremote.ssl=false',
// '-agentpath:/usr/local/YourKit-JavaProfiler-2021.3/bin/linux-x86-64/libyjpagent.so=port=10001,listen=all'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
dockerImageTag: 3.1.7
dockerImageTag: 3.1.8
dockerRepository: airbyte/destination-snowflake
githubIssueLabel: destination-snowflake
icon: snowflake.svg
Expand Down
3 changes: 2 additions & 1 deletion docs/integrations/destinations/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ Otherwise, make sure to grant the role the required permissions in the desired n
## Changelog

| Version | Date | Pull Request | Subject |
| :-------------- | :--------- | :--------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|:----------------|:-----------|:-----------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 3.1.8 | 2023-09-18 | [\#30479](https://github.com/airbytehq/airbyte/pull/30479) | Fix async memory management |
| 3.1.7 | 2023-09-15 | [\#30491](https://github.com/airbytehq/airbyte/pull/30491) | Improve error message display |
| 3.1.6 | 2023-09-14 | [\#30439](https://github.com/airbytehq/airbyte/pull/30439) | Fix a transient error |
| 3.1.5 | 2023-09-13 | [\#30416](https://github.com/airbytehq/airbyte/pull/30416) | Support `${` in stream name/namespace, and in column names |
Expand Down

0 comments on commit a7c13dc

Please sign in to comment.