CAT: fix incremental by running tests per stream (#36814)
roman-yermilov-gl authored Jun 3, 2024
1 parent e64d3b5 commit 1f325ec
Showing 2 changed files with 326 additions and 379 deletions.
@@ -4,16 +4,26 @@

import json
from pathlib import Path
from typing import Any, Dict, List, Mapping, MutableMapping, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

import pytest
from airbyte_protocol.models import AirbyteMessage, AirbyteStateMessage, AirbyteStateType, ConfiguredAirbyteCatalog, SyncMode, Type
from airbyte_protocol.models import (
AirbyteMessage,
AirbyteStateMessage,
AirbyteStateStats,
AirbyteStateType,
ConfiguredAirbyteCatalog,
SyncMode,
Type,
)
from connector_acceptance_test import BaseTest
from connector_acceptance_test.config import Config, EmptyStreamConfiguration, IncrementalConfig
from connector_acceptance_test.utils import ConnectorRunner, SecretDict, filter_output, incremental_only_catalog
from connector_acceptance_test.utils.timeouts import TWENTY_MINUTES
from deepdiff import DeepDiff

MIN_BATCHES_TO_TEST: int = 5


@pytest.fixture(name="future_state_configuration")
def future_state_configuration_fixture(inputs, base_path, test_strictness_level) -> Tuple[Path, List[EmptyStreamConfiguration]]:
@@ -170,70 +180,54 @@ async def test_read_sequential_slices(
pytest.skip("Skipping new incremental test based on acceptance-test-config.yml")
return

output_1 = await docker_runner.call_read(connector_config, configured_catalog_for_incremental)
records_1 = filter_output(output_1, type_=Type.RECORD)
states_1 = filter_output(output_1, type_=Type.STATE)

# We sometimes have identical duplicate state messages in a stream, which we can filter out to speed things up
unique_state_messages = [message for index, message in enumerate(states_1) if message not in states_1[:index]]
for stream in configured_catalog_for_incremental.streams:
configured_catalog_for_incremental_per_stream = ConfiguredAirbyteCatalog(streams=[stream])

# Important!
output_1 = await docker_runner.call_read(connector_config, configured_catalog_for_incremental_per_stream)

# There is only a small subset of assertions we can make
# in the absence of enforcing that all connectors return 3 or more state messages
# during the first read.
records_1 = filter_output(output_1, type_=Type.RECORD)
# If the output of a full read is empty, there is no reason to iterate over its state.
# So, reading from any checkpoint of an empty stream will also produce nothing.
if len(records_1) == 0:
continue

# To learn more: https://github.com/airbytehq/airbyte/issues/29926
if len(unique_state_messages) < 3:
pytest.skip("Skipping test because there are not enough state messages to test with")
return
states_1 = filter_output(output_1, type_=Type.STATE)

assert records_1, "First Read should produce at least one record"
# To learn more: https://github.com/airbytehq/airbyte/issues/29926
if len(states_1) == 0:
continue

# For legacy state format, the final state message contains the final state of all streams. For per-stream state format,
# the complete final state of streams must be assembled by going through all prior state messages received
is_per_stream = is_per_stream_state(states_1[-1])
states_with_expected_record_count = self._state_messages_selector(states_1)
if not states_with_expected_record_count:
pytest.fail(
"Unable to test because there is no suitable state checkpoint, likely due to a zero record count in the states."
)

# To avoid spamming APIs we only test a fraction of batches and enforce a minimum of 5 tested
min_batches_to_test = 5
sample_rate = len(unique_state_messages) // min_batches_to_test
mutating_stream_name_to_per_stream_state = dict()

mutating_stream_name_to_per_stream_state = dict()
for idx, state_message in enumerate(unique_state_messages):
assert state_message.type == Type.STATE
for idx, state_message_data in enumerate(states_with_expected_record_count):
state_message, expected_records_count = state_message_data
assert state_message.type == Type.STATE

# if first state message, skip
# this is because we cannot assert if the first state message will result in new records
# as in this case it is possible for a connector to return an empty state message when it first starts.
# e.g. if the connector decides it wants to let the caller know that it has started with an empty state.
if idx == 0:
continue

# if last state message, skip
# this is because we cannot assert if the last state message will result in new records
# as in this case it is possible for a connector to return a previous state message.
# e.g. if the connector is using pagination and the last page is only partially full
if idx == len(unique_state_messages) - 1:
continue

# if batching required, and not a sample, skip
if len(unique_state_messages) >= min_batches_to_test and idx % sample_rate != 0:
continue
state_input, mutating_stream_name_to_per_stream_state = self.get_next_state_input(
state_message, mutating_stream_name_to_per_stream_state
)

state_input, mutating_stream_name_to_per_stream_state = self.get_next_state_input(
state_message, mutating_stream_name_to_per_stream_state, is_per_stream
)
output_N = await docker_runner.call_read_with_state(
connector_config, configured_catalog_for_incremental_per_stream, state=state_input
)
records_N = filter_output(output_N, type_=Type.RECORD)

output_N = await docker_runner.call_read_with_state(connector_config, configured_catalog_for_incremental, state=state_input)
records_N = filter_output(output_N, type_=Type.RECORD)
assert (
records_N
), f"Read {idx + 2} of {len(unique_state_messages)} should produce at least one record.\n\n state: {state_input} \n\n records_{idx + 2}: {records_N}"
assert (
# We assume that the output may be empty when we read the latest state, or it must produce some data if we are in the middle of our progression
len(records_N)
>= expected_records_count
), f"Read {idx + 1} of {len(states_with_expected_record_count)} should produce at least one record.\n\n state: {state_input} \n\n records_{idx + 1}: {records_N}"

diff = naive_diff_records(records_1, records_N)
assert (
diff
), f"Records for subsequent reads with new state should be different.\n\n records_1: {records_1} \n\n state: {state_input} \n\n records_{idx + 2}: {records_N} \n\n diff: {diff}"
diff = naive_diff_records(records_1, records_N)
assert (
diff
), f"Records for subsequent reads with new state should be different.\n\n records_1: {records_1} \n\n state: {state_input} \n\n records_{idx + 1}: {records_N} \n\n diff: {diff}"

async def test_state_with_abnormally_large_values(
self, connector_config, configured_catalog, future_state, docker_runner: ConnectorRunner
@@ -249,25 +243,116 @@ async def test_state_with_abnormally_large_values(
assert states, "The sync should produce at least one STATE message"

def get_next_state_input(
self,
state_message: AirbyteStateMessage,
stream_name_to_per_stream_state: MutableMapping,
is_per_stream,
self, state_message: AirbyteStateMessage, stream_name_to_per_stream_state: MutableMapping
) -> Tuple[Union[List[MutableMapping], MutableMapping], MutableMapping]:
if is_per_stream:
# Including all the latest state values from previous batches, update the combined stream state
# with the current batch's stream state and then use it in the following read() request
current_state = state_message.state
if current_state and current_state.type == AirbyteStateType.STREAM:
per_stream = current_state.stream
if per_stream.stream_state:
stream_name_to_per_stream_state[per_stream.stream_descriptor.name] = (
per_stream.stream_state.dict() if per_stream.stream_state else {}
)
state_input = [
{"type": "STREAM", "stream": {"stream_descriptor": {"name": stream_name}, "stream_state": stream_state}}
for stream_name, stream_state in stream_name_to_per_stream_state.items()
]
return state_input, stream_name_to_per_stream_state
else:
return state_message.state.data, state_message.state.data
# Including all the latest state values from previous batches, update the combined stream state
# with the current batch's stream state and then use it in the following read() request
current_state = state_message.state
if current_state and current_state.type == AirbyteStateType.STREAM:
per_stream = current_state.stream
if per_stream.stream_state:
stream_name_to_per_stream_state[per_stream.stream_descriptor.name] = (
per_stream.stream_state.dict() if per_stream.stream_state else {}
)
state_input = [
{"type": "STREAM", "stream": {"stream_descriptor": {"name": stream_name}, "stream_state": stream_state}}
for stream_name, stream_state in stream_name_to_per_stream_state.items()
]
return state_input, stream_name_to_per_stream_state
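For illustration, the state_input assembled above is a list of per-stream STATE payloads carrying the latest known state of every stream seen so far; with two hypothetical streams ("users" and "orders" are made-up names and state values), it would look like:

state_input = [
    {"type": "STREAM", "stream": {"stream_descriptor": {"name": "users"}, "stream_state": {"updated_at": "2024-01-01"}}},
    {"type": "STREAM", "stream": {"stream_descriptor": {"name": "orders"}, "stream_state": {"last_id": 123}}},
]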

@staticmethod
def _get_state(airbyte_message: AirbyteMessage) -> AirbyteStateMessage:
if not airbyte_message.state.stream:
return airbyte_message.state
return airbyte_message.state.stream.stream_state

@staticmethod
def _get_record_count(airbyte_message: AirbyteMessage) -> float:
return airbyte_message.state.sourceStats.recordCount

def _get_unique_state_messages_with_record_count(self, states: List[AirbyteMessage]) -> List[Tuple[AirbyteMessage, float]]:
"""
Deduplicates consecutive state messages that share the same stream state: only the first message of each run is kept, and the run's total record count is attributed to it.
"""
if len(states) <= 1:
return [(state, 0.0) for state in states if self._get_record_count(state)]

current_idx = 0
unique_state_messages = []

# Iterate through the list of state messages
while current_idx < len(states) - 1:
next_idx = current_idx + 1
# Check if consecutive messages have the same stream state
while self._get_state(states[current_idx]) == self._get_state(states[next_idx]) and next_idx < len(states) - 1:
next_idx += 1

states[current_idx].state.sourceStats = AirbyteStateStats(
recordCount=sum(map(self._get_record_count, states[current_idx:next_idx]))
)
# Append the first message with a unique stream state to the result list
unique_state_messages.append(states[current_idx])
# If the last message has a different stream state than the previous one, append it to the result list
if next_idx == len(states) - 1 and self._get_state(states[current_idx]) != self._get_state(states[next_idx]):
unique_state_messages.append(states[next_idx])
current_idx = next_idx

# Drop all states with a record count of 0.0
unique_non_zero_state_messages = list(filter(self._get_record_count, unique_state_messages))

total_record_count = sum(map(self._get_record_count, unique_non_zero_state_messages))

# Calculates the expected record count per state based on the total record count and distribution across states.
# The expected record count is the number of records we expect to receive when applying a specific state checkpoint.
unique_non_zero_state_messages_with_record_count = zip(
unique_non_zero_state_messages,
[
total_record_count - sum(map(self._get_record_count, unique_non_zero_state_messages[: idx + 1]))
for idx in range(len(unique_non_zero_state_messages))
],
)

return list(unique_non_zero_state_messages_with_record_count)
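A hand-worked trace of this method's logic, re-implemented on plain (state, record_count) tuples (the values are invented for illustration):

from itertools import groupby

# Four state messages: two consecutive checkpoints share stream state "A".
states = [("A", 2.0), ("A", 3.0), ("B", 4.0), ("C", 1.0)]

# Collapse consecutive messages with the same stream state, summing counts.
unique = [(key, sum(count for _, count in group)) for key, group in groupby(states, key=lambda s: s[0])]
assert unique == [("A", 5.0), ("B", 4.0), ("C", 1.0)]

# Drop zero-count checkpoints, then compute how many records a read that
# resumes from each remaining checkpoint should still produce.
non_zero = [(state, count) for state, count in unique if count]
total = sum(count for _, count in non_zero)
cumulative, expected = 0.0, []
for state, count in non_zero:
    cumulative += count
    expected.append((state, total - cumulative))
assert expected == [("A", 5.0), ("B", 1.0), ("C", 0.0)]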

def _states_with_expected_record_count_batch_selector(
self, unique_state_messages_with_record_count: List[Tuple[AirbyteMessage, float]]
) -> List[Tuple[AirbyteMessage, float]]:
# Important!

# There is only a small subset of assertions we can make
# in the absence of enforcing that all connectors return 3 or more state messages
# during the first read.
if len(unique_state_messages_with_record_count) < 3:
return unique_state_messages_with_record_count[-1:]

# To avoid spamming APIs we only test a fraction of batches (4 or 5 states by default)
sample_rate = (len(unique_state_messages_with_record_count) // MIN_BATCHES_TO_TEST) or 1

states_with_expected_record_count_batch = []

for idx, state_message_data in enumerate(unique_state_messages_with_record_count):
# if first state message, skip
# this is because we cannot assert if the first state message will result in new records
# as in this case it is possible for a connector to return an empty state message when it first starts.
# e.g. if the connector decides it wants to let the caller know that it has started with an empty state.
if idx == 0:
continue

# if batching required, and not a sample, skip
if idx % sample_rate != 0:
continue

# if last state message, skip
# this is because we cannot assert if the last state message will result in new records
# as in this case it is possible for a connector to return a previous state message.
# e.g. if the connector is using pagination and the last page is only partially full
if idx == len(unique_state_messages_with_record_count) - 1:
continue

states_with_expected_record_count_batch.append(state_message_data)

return states_with_expected_record_count_batch
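To make the sampling rules above concrete, here is a small self-contained trace (the checkpoint count is illustrative):

MIN_BATCHES_TO_TEST = 5

# Eleven unique checkpoints, stand-ins for (message, expected_count) tuples.
checkpoints = list(range(11))
sample_rate = (len(checkpoints) // MIN_BATCHES_TO_TEST) or 1  # -> 2

# Skip the first and last checkpoints and anything off the sample grid.
kept = [
    idx
    for idx in checkpoints
    if idx != 0 and idx % sample_rate == 0 and idx != len(checkpoints) - 1
]
assert kept == [2, 4, 6, 8]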

def _state_messages_selector(self, state_messages: List[AirbyteMessage]) -> List[Tuple[AirbyteMessage, float]]:
unique_state_messages_with_record_count = self._get_unique_state_messages_with_record_count(state_messages)
return self._states_with_expected_record_count_batch_selector(unique_state_messages_with_record_count)