diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 7a33d9f16..001740a35 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -3,7 +3,7 @@ # import logging -from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union +from typing import Any, Generic, Iterator, List, Mapping, Optional, Tuple, Union, Callable from airbyte_cdk.models import ( AirbyteCatalog, @@ -27,18 +27,24 @@ ) from airbyte_cdk.sources.declarative.models.declarative_component_schema import ( DatetimeBasedCursor as DatetimeBasedCursorModel, + DeclarativeStream as DeclarativeStreamModel, ) from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ( ModelToComponentFactory, + ComponentDefinition, ) from airbyte_cdk.sources.declarative.requesters import HttpRequester -from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever +from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, Retriever +from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( + DeclarativePartitionFactory, + StreamSlicerPartitionGenerator, +) from airbyte_cdk.sources.declarative.transformations.add_fields import AddFields from airbyte_cdk.sources.declarative.types import ConnectionDefinition from airbyte_cdk.sources.source import TState +from airbyte_cdk.sources.types import Config, StreamState from airbyte_cdk.sources.streams import Stream from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream -from airbyte_cdk.sources.streams.concurrent.adapters import CursorPartitionGenerator from airbyte_cdk.sources.streams.concurrent.availability_strategy import ( AlwaysAvailableAvailabilityStrategy, ) @@ -213,31 +219,18 @@ def _group_streams( ) ) - # This is an optimization so that we don't invoke any cursor or state management flows within the - # low-code framework because state management is handled through the ConcurrentCursor. - if ( - declarative_stream - and declarative_stream.retriever - and isinstance(declarative_stream.retriever, SimpleRetriever) - ): - # Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is - # called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor - # for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and - # ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized - # with state. 
- if declarative_stream.retriever.cursor: - declarative_stream.retriever.cursor.set_initial_state( - stream_state=stream_state - ) - declarative_stream.retriever.cursor = None - - partition_generator = CursorPartitionGenerator( - stream=declarative_stream, - message_repository=self.message_repository, # type: ignore # message_repository is always instantiated with a value by factory - cursor=cursor, - connector_state_converter=connector_state_converter, - cursor_field=[cursor.cursor_field.cursor_field_key], - slice_boundary_fields=cursor.slice_boundary_fields, + partition_generator = StreamSlicerPartitionGenerator( + DeclarativePartitionFactory( + declarative_stream.name, + declarative_stream.get_json_schema(), + self._retriever_factory( + name_to_stream_mapping[declarative_stream.name], + config, + stream_state, + ), + self.message_repository, + ), + cursor, ) concurrent_streams.append( @@ -350,3 +343,34 @@ def _remove_concurrent_streams_from_catalog( if stream.stream.name not in concurrent_stream_names ] ) + + def _retriever_factory( + self, stream_config: ComponentDefinition, source_config: Config, stream_state: StreamState + ) -> Callable[[], Retriever]: + def _factory_method() -> Retriever: + declarative_stream: DeclarativeStream = self._constructor.create_component( + DeclarativeStreamModel, + stream_config, + source_config, + emit_connector_builder_messages=self._emit_connector_builder_messages, + ) + + # This is an optimization so that we don't invoke any cursor or state management flows within the + # low-code framework because state management is handled through the ConcurrentCursor. + if ( + declarative_stream + and declarative_stream.retriever + and isinstance(declarative_stream.retriever, SimpleRetriever) + ): + # Also a temporary hack. In the legacy Stream implementation, as part of the read, set_initial_state() is + # called to instantiate incoming state on the cursor. Although we no longer rely on the legacy low-code cursor + # for concurrent checkpointing, low-code components like StopConditionPaginationStrategyDecorator and + # ClientSideIncrementalRecordFilterDecorator still rely on a DatetimeBasedCursor that is properly initialized + # with state. 
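+                # Setting the initial state below keeps those legacy components working; clearing the cursor
+                # immediately afterwards guarantees the ConcurrentCursor remains the only checkpointing mechanism.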
+            if declarative_stream.retriever.cursor:
+                declarative_stream.retriever.cursor.set_initial_state(stream_state=stream_state)
+                declarative_stream.retriever.cursor = None
+
+            return declarative_stream.retriever
+
+        return _factory_method
diff --git a/airbyte_cdk/sources/declarative/manifest_declarative_source.py b/airbyte_cdk/sources/declarative/manifest_declarative_source.py
index 05a80321d..82f4dff3a 100644
--- a/airbyte_cdk/sources/declarative/manifest_declarative_source.py
+++ b/airbyte_cdk/sources/declarative/manifest_declarative_source.py
@@ -8,7 +8,7 @@
 import re
 from copy import deepcopy
 from importlib import metadata
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple
 
 import yaml
 from airbyte_cdk.models import (
@@ -94,7 +94,7 @@ def resolved_manifest(self) -> Mapping[str, Any]:
         return self._source_config
 
     @property
-    def message_repository(self) -> Union[None, MessageRepository]:
+    def message_repository(self) -> MessageRepository:
         return self._message_repository
 
     @property
diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
new file mode 100644
index 000000000..1c2ad06cf
--- /dev/null
+++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
+
+from typing import Any, Callable, Iterable, Mapping, Optional
+
+from airbyte_cdk.sources.declarative.retrievers import Retriever
+from airbyte_cdk.sources.message import MessageRepository
+from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
+from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator
+from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
+from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
+from airbyte_cdk.sources.types import StreamSlice
+from airbyte_cdk.utils.slice_hasher import SliceHasher
+
+
+class DeclarativePartitionFactory:
+    def __init__(
+        self,
+        stream_name: str,
+        json_schema: Mapping[str, Any],
+        retriever_factory: Callable[[], Retriever],
+        message_repository: MessageRepository,
+    ) -> None:
+        """
+        The DeclarativePartitionFactory takes a retriever_factory rather than a retriever directly because our
+        components are not thread-safe: classes like `DefaultPaginator` can misbehave when multiple threads
+        concurrently read and mutate shared fields. A fresh retriever per partition keeps processing thread-safe.
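+
+        Illustrative usage (a sketch; `build_retriever`, `message_repository`, and `stream_slice` are
+        placeholders supplied by the caller, not names defined in this module):
+
+            factory = DeclarativePartitionFactory(
+                "a_stream_name",
+                {"type": "object", "properties": {}},
+                build_retriever,
+                message_repository,
+            )
+            partition = factory.create(stream_slice)  # instantiates a fresh retriever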
+ """ + self._stream_name = stream_name + self._json_schema = json_schema + self._retriever_factory = retriever_factory + self._message_repository = message_repository + + def create(self, stream_slice: StreamSlice) -> Partition: + return DeclarativePartition( + self._stream_name, + self._json_schema, + self._retriever_factory(), + self._message_repository, + stream_slice, + ) + + +class DeclarativePartition(Partition): + def __init__( + self, + stream_name: str, + json_schema: Mapping[str, Any], + retriever: Retriever, + message_repository: MessageRepository, + stream_slice: StreamSlice, + ): + self._stream_name = stream_name + self._json_schema = json_schema + self._retriever = retriever + self._message_repository = message_repository + self._stream_slice = stream_slice + self._hash = SliceHasher.hash(self._stream_name, self._stream_slice) + + def read(self) -> Iterable[Record]: + for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice): + if isinstance(stream_data, Mapping): + yield Record(stream_data, self) + else: + self._message_repository.emit_message(stream_data) + + def to_slice(self) -> Optional[Mapping[str, Any]]: + return self._stream_slice + + def stream_name(self) -> str: + return self._stream_name + + def __hash__(self) -> int: + return self._hash + + +class StreamSlicerPartitionGenerator(PartitionGenerator): + def __init__( + self, partition_factory: DeclarativePartitionFactory, stream_slicer: StreamSlicer + ) -> None: + self._partition_factory = partition_factory + self._stream_slicer = stream_slicer + + def generate(self) -> Iterable[Partition]: + for stream_slice in self._stream_slicer.stream_slices(): + yield self._partition_factory.create(stream_slice) diff --git a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py index af9c438f8..db15496ff 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/stream_slicer.py @@ -2,18 +2,17 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from abc import abstractmethod -from dataclasses import dataclass -from typing import Iterable +from abc import ABC from airbyte_cdk.sources.declarative.requesters.request_options.request_options_provider import ( RequestOptionsProvider, ) -from airbyte_cdk.sources.types import StreamSlice +from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import ( + StreamSlicer as ConcurrentStreamSlicer, +) -@dataclass -class StreamSlicer(RequestOptionsProvider): +class StreamSlicer(ConcurrentStreamSlicer, RequestOptionsProvider, ABC): """ Slices the stream into a subset of records. Slices enable state checkpointing and data retrieval parallelization. @@ -23,10 +22,4 @@ class StreamSlicer(RequestOptionsProvider): See the stream slicing section of the docs for more information. 
""" - @abstractmethod - def stream_slices(self) -> Iterable[StreamSlice]: - """ - Defines stream slices - - :return: List of stream slices - """ + pass diff --git a/airbyte_cdk/sources/streams/concurrent/adapters.py b/airbyte_cdk/sources/streams/concurrent/adapters.py index 1df713037..679f2d865 100644 --- a/airbyte_cdk/sources/streams/concurrent/adapters.py +++ b/airbyte_cdk/sources/streams/concurrent/adapters.py @@ -38,15 +38,13 @@ from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.partition_generator import PartitionGenerator from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( - DateTimeStreamStateConverter, -) from airbyte_cdk.sources.streams.core import StreamData -from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.sources.utils.schema_helpers import InternalConfig from airbyte_cdk.sources.utils.slice_logger import SliceLogger from deprecated.classic import deprecated +from airbyte_cdk.utils.slice_hasher import SliceHasher + """ This module contains adapters to help enabling concurrency on Stream objects without needing to migrate to AbstractStream """ @@ -270,6 +268,7 @@ def __init__( self._sync_mode = sync_mode self._cursor_field = cursor_field self._state = state + self._hash = SliceHasher.hash(self._stream.name, self._slice) def read(self) -> Iterable[Record]: """ @@ -309,12 +308,7 @@ def to_slice(self) -> Optional[Mapping[str, Any]]: return self._slice def __hash__(self) -> int: - if self._slice: - # Convert the slice to a string so that it can be hashed - s = json.dumps(self._slice, sort_keys=True, cls=SliceEncoder) - return hash((self._stream.name, s)) - else: - return hash(self._stream.name) + return self._hash def stream_name(self) -> str: return self._stream.name @@ -363,83 +357,6 @@ def generate(self) -> Iterable[Partition]: ) -class CursorPartitionGenerator(PartitionGenerator): - """ - This class generates partitions using the concurrent cursor and iterates through state slices to generate partitions. - - It is used when synchronizing a stream in incremental or full-refresh mode where state information is maintained - across partitions. Each partition represents a subset of the stream's data and is determined by the cursor's state. - """ - - _START_BOUNDARY = 0 - _END_BOUNDARY = 1 - - def __init__( - self, - stream: Stream, - message_repository: MessageRepository, - cursor: Cursor, - connector_state_converter: DateTimeStreamStateConverter, - cursor_field: Optional[List[str]], - slice_boundary_fields: Optional[Tuple[str, str]], - ): - """ - Initialize the CursorPartitionGenerator with a stream, sync mode, and cursor. - - :param stream: The stream to delegate to for partition generation. - :param message_repository: The message repository to use to emit non-record messages. - :param sync_mode: The synchronization mode. - :param cursor: A Cursor object that maintains the state and the cursor field. - """ - self._stream = stream - self.message_repository = message_repository - self._sync_mode = SyncMode.full_refresh - self._cursor = cursor - self._cursor_field = cursor_field - self._state = self._cursor.state - self._slice_boundary_fields = slice_boundary_fields - self._connector_state_converter = connector_state_converter - - def generate(self) -> Iterable[Partition]: - """ - Generate partitions based on the slices in the cursor's state. 
-
-        This method iterates through the list of slices found in the cursor's state, and for each slice, it generates
-        a `StreamPartition` object.
-
-        :return: An iterable of StreamPartition objects.
-        """
-
-        start_boundary = (
-            self._slice_boundary_fields[self._START_BOUNDARY]
-            if self._slice_boundary_fields
-            else "start"
-        )
-        end_boundary = (
-            self._slice_boundary_fields[self._END_BOUNDARY]
-            if self._slice_boundary_fields
-            else "end"
-        )
-
-        for slice_start, slice_end in self._cursor.generate_slices():
-            stream_slice = StreamSlice(
-                partition={},
-                cursor_slice={
-                    start_boundary: self._connector_state_converter.output_format(slice_start),
-                    end_boundary: self._connector_state_converter.output_format(slice_end),
-                },
-            )
-
-            yield StreamPartition(
-                self._stream,
-                copy.deepcopy(stream_slice),
-                self.message_repository,
-                self._sync_mode,
-                self._cursor_field,
-                self._state,
-            )
-
-
 @deprecated(
     "Availability strategy has been soft deprecated. Do not use. Class is subject to removal",
     category=ExperimentalClassWarning,
diff --git a/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte_cdk/sources/streams/concurrent/cursor.py
index 15e9b59a4..1cc7e8965 100644
--- a/airbyte_cdk/sources/streams/concurrent/cursor.py
+++ b/airbyte_cdk/sources/streams/concurrent/cursor.py
@@ -11,9 +11,11 @@
 from airbyte_cdk.sources.streams import NO_CURSOR_STATE_KEY
 from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
 from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
+from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
 from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import (
     AbstractStreamStateConverter,
 )
+from airbyte_cdk.sources.types import StreamSlice
 
 
 def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
@@ -61,7 +63,7 @@ def extract_value(self, record: Record) -> CursorValueType:
         return cursor_value  # type: ignore  # we assume that the value the path points at is a comparable
 
 
-class Cursor(ABC):
+class Cursor(StreamSlicer, ABC):
     @property
     @abstractmethod
     def state(self) -> MutableMapping[str, Any]: ...
@@ -88,12 +90,12 @@ def ensure_at_least_one_state_emitted(self) -> None:
         """
         raise NotImplementedError()
 
-    def generate_slices(self) -> Iterable[Tuple[Any, Any]]:
+    def stream_slices(self) -> Iterable[StreamSlice]:
         """
-        Default placeholder implementation of generate_slices.
+        Default placeholder implementation of stream_slices.
         Subclasses can override this method to provide actual behavior.
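+        By default a single empty slice is yielded, so cursors that do not partition the stream
+        still produce exactly one partition.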
""" - yield from () + yield StreamSlice(partition={}, cursor_slice={}) class FinalStateCursor(Cursor): @@ -184,8 +186,15 @@ def cursor_field(self) -> CursorField: return self._cursor_field @property - def slice_boundary_fields(self) -> Optional[Tuple[str, str]]: - return self._slice_boundary_fields + def _slice_boundary_fields_wrapper(self) -> Tuple[str, str]: + return ( + self._slice_boundary_fields + if self._slice_boundary_fields + else ( + self._connector_state_converter.START_KEY, + self._connector_state_converter.END_KEY, + ) + ) def _get_concurrent_state( self, state: MutableMapping[str, Any] @@ -299,7 +308,7 @@ def ensure_at_least_one_state_emitted(self) -> None: """ self._emit_state_message() - def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: + def stream_slices(self) -> Iterable[StreamSlice]: """ Generating slices based on a few parameters: * lookback_window: Buffer to remove from END_KEY of the highest slice @@ -368,7 +377,7 @@ def _calculate_lower_boundary_of_last_slice( def _split_per_slice_range( self, lower: CursorValueType, upper: CursorValueType, upper_is_end: bool - ) -> Iterable[Tuple[CursorValueType, CursorValueType]]: + ) -> Iterable[StreamSlice]: if lower >= upper: return @@ -377,10 +386,22 @@ def _split_per_slice_range( lower = max(lower, self._start) if self._start else lower if not self._slice_range or self._evaluate_upper_safely(lower, self._slice_range) >= upper: - if self._cursor_granularity and not upper_is_end: - yield lower, upper - self._cursor_granularity - else: - yield lower, upper + start_value, end_value = ( + (lower, upper - self._cursor_granularity) + if self._cursor_granularity and not upper_is_end + else (lower, upper) + ) + yield StreamSlice( + partition={}, + cursor_slice={ + self._slice_boundary_fields_wrapper[ + self._START_BOUNDARY + ]: self._connector_state_converter.output_format(start_value), + self._slice_boundary_fields_wrapper[ + self._END_BOUNDARY + ]: self._connector_state_converter.output_format(end_value), + }, + ) else: stop_processing = False current_lower_boundary = lower @@ -389,12 +410,24 @@ def _split_per_slice_range( self._evaluate_upper_safely(current_lower_boundary, self._slice_range), upper ) has_reached_upper_boundary = current_upper_boundary >= upper - if self._cursor_granularity and ( - not upper_is_end or not has_reached_upper_boundary - ): - yield current_lower_boundary, current_upper_boundary - self._cursor_granularity - else: - yield current_lower_boundary, current_upper_boundary + + start_value, end_value = ( + (current_lower_boundary, current_upper_boundary - self._cursor_granularity) + if self._cursor_granularity + and (not upper_is_end or not has_reached_upper_boundary) + else (current_lower_boundary, current_upper_boundary) + ) + yield StreamSlice( + partition={}, + cursor_slice={ + self._slice_boundary_fields_wrapper[ + self._START_BOUNDARY + ]: self._connector_state_converter.output_format(start_value), + self._slice_boundary_fields_wrapper[ + self._END_BOUNDARY + ]: self._connector_state_converter.output_format(end_value), + }, + ) current_lower_boundary = current_upper_boundary if current_upper_boundary >= upper: stop_processing = True diff --git a/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py b/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py new file mode 100644 index 000000000..98ac04ed7 --- /dev/null +++ b/airbyte_cdk/sources/streams/concurrent/partitions/stream_slicer.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights 
reserved. + +from abc import ABC, abstractmethod +from typing import Iterable + +from airbyte_cdk.sources.types import StreamSlice + + +class StreamSlicer(ABC): + """ + Slices the stream into chunks that can be fetched independently. Slices enable state checkpointing and data retrieval parallelization. + """ + + @abstractmethod + def stream_slices(self) -> Iterable[StreamSlice]: + """ + Defines stream slices + + :return: An iterable of stream slices + """ + pass diff --git a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py index 1b4779761..987915317 100644 --- a/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py +++ b/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py @@ -124,6 +124,13 @@ def increment(self, value: Any) -> Any: """ ... + @abstractmethod + def output_format(self, value: Any) -> Any: + """ + Convert the cursor value type to a JSON valid type. + """ + ... + def merge_intervals( self, intervals: List[MutableMapping[str, Any]] ) -> List[MutableMapping[str, Any]]: diff --git a/airbyte_cdk/utils/slice_hasher.py b/airbyte_cdk/utils/slice_hasher.py new file mode 100644 index 000000000..d86147da0 --- /dev/null +++ b/airbyte_cdk/utils/slice_hasher.py @@ -0,0 +1,30 @@ +import hashlib +import json +from typing import Any, Mapping, Optional, Final + + +class SliceEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if hasattr(obj, "__json_serializable__"): + return obj.__json_serializable__() + + # Let the base class default method raise the TypeError + return super().default(obj) + + +class SliceHasher: + _ENCODING: Final = "utf-8" + + @classmethod + def hash(cls, stream_name: str, stream_slice: Optional[Mapping[str, Any]] = None) -> int: + if stream_slice: + try: + s = json.dumps(stream_slice, sort_keys=True, cls=SliceEncoder) + hash_input = f"{stream_name}:{s}".encode(cls._ENCODING) + except TypeError as e: + raise ValueError(f"Failed to serialize stream slice: {e}") + else: + hash_input = stream_name.encode(cls._ENCODING) + + # Use last 8 bytes as 64-bit integer for better distribution + return int.from_bytes(hashlib.sha256(hash_input).digest()[-8:], "big") diff --git a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index c8d0781ab..3ec1f3765 100644 --- a/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -3073,11 +3073,11 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields( assert concurrent_cursor._lookback_window == expected_lookback_window assert ( - concurrent_cursor.slice_boundary_fields[ConcurrentCursor._START_BOUNDARY] + concurrent_cursor._slice_boundary_fields[ConcurrentCursor._START_BOUNDARY] == expected_start_boundary ) assert ( - concurrent_cursor.slice_boundary_fields[ConcurrentCursor._END_BOUNDARY] + concurrent_cursor._slice_boundary_fields[ConcurrentCursor._END_BOUNDARY] == expected_end_boundary ) @@ -3096,14 +3096,14 @@ def test_create_concurrent_cursor_from_datetime_based_cursor_all_fields( [ pytest.param( {"partition_field_start": None}, - "slice_boundary_fields", + "_slice_boundary_fields", ("start_time", "custom_end"), None, id="test_no_partition_field_start", ), pytest.param( {"partition_field_end": None}, - 
"slice_boundary_fields", + "_slice_boundary_fields", ("custom_start", "end_time"), None, id="test_no_partition_field_end", diff --git a/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py new file mode 100644 index 000000000..be601de0d --- /dev/null +++ b/unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py @@ -0,0 +1,73 @@ +from typing import List +from unittest import TestCase +from unittest.mock import Mock + +from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type +from airbyte_cdk.sources.declarative.retrievers import Retriever +from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import ( + DeclarativePartitionFactory, +) +from airbyte_cdk.sources.message import MessageRepository +from airbyte_cdk.sources.streams.core import StreamData +from airbyte_cdk.sources.types import StreamSlice + +_STREAM_NAME = "a_stream_name" +_JSON_SCHEMA = {"type": "object", "properties": {}} +_A_STREAM_SLICE = StreamSlice( + partition={"partition_key": "partition_value"}, cursor_slice={"cursor_key": "cursor_value"} +) +_ANOTHER_STREAM_SLICE = StreamSlice( + partition={"partition_key": "another_partition_value"}, + cursor_slice={"cursor_key": "cursor_value"}, +) +_AIRBYTE_LOG_MESSAGE = AirbyteMessage( + type=Type.LOG, log=AirbyteLogMessage(level=Level.DEBUG, message="a log message") +) +_A_RECORD = {"record_field": "record_value"} + + +class StreamSlicerPartitionGeneratorTest(TestCase): + def setUp(self) -> None: + self._retriever_factory = Mock() + self._message_repository = Mock(spec=MessageRepository) + self._partition_factory = DeclarativePartitionFactory( + _STREAM_NAME, + _JSON_SCHEMA, + self._retriever_factory, + self._message_repository, + ) + + def test_given_multiple_slices_when_read_then_read_from_different_retrievers(self) -> None: + first_retriever = self._mock_retriever([]) + second_retriever = self._mock_retriever([]) + self._retriever_factory.side_effect = [first_retriever, second_retriever] + + list(self._partition_factory.create(_A_STREAM_SLICE).read()) + list(self._partition_factory.create(_ANOTHER_STREAM_SLICE).read()) + + first_retriever.read_records.assert_called_once() + second_retriever.read_records.assert_called_once() + + def test_given_a_mapping_when_read_then_yield_record(self) -> None: + retriever = self._mock_retriever([_A_RECORD]) + self._retriever_factory.return_value = retriever + partition = self._partition_factory.create(_A_STREAM_SLICE) + + records = list(partition.read()) + + assert len(records) == 1 + assert records[0].partition == partition + assert records[0].data == _A_RECORD + + def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None: + retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE]) + self._retriever_factory.return_value = retriever + + list(self._partition_factory.create(_A_STREAM_SLICE).read()) + + self._message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE) + + def _mock_retriever(self, read_return_value: List[StreamData]) -> Mock: + retriever = Mock(spec=Retriever) + retriever.read_records.return_value = iter(read_return_value) + return retriever diff --git a/unit_tests/sources/streams/concurrent/test_adapters.py b/unit_tests/sources/streams/concurrent/test_adapters.py index 93e8fd212..c7c168c0b 100644 --- a/unit_tests/sources/streams/concurrent/test_adapters.py +++ 
b/unit_tests/sources/streams/concurrent/test_adapters.py @@ -1,7 +1,6 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -import datetime import logging import unittest from unittest.mock import Mock @@ -12,7 +11,6 @@ from airbyte_cdk.sources.message import InMemoryMessageRepository from airbyte_cdk.sources.streams.concurrent.adapters import ( AvailabilityStrategyFacade, - CursorPartitionGenerator, StreamFacade, StreamPartition, StreamPartitionGenerator, @@ -25,11 +23,7 @@ from airbyte_cdk.sources.streams.concurrent.cursor import Cursor from airbyte_cdk.sources.streams.concurrent.exceptions import ExceptionWithDisplayMessage from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import ( - CustomFormatConcurrentStreamStateConverter, -) from airbyte_cdk.sources.streams.core import Stream -from airbyte_cdk.sources.types import StreamSlice from airbyte_cdk.sources.utils.slice_logger import SliceLogger from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer @@ -149,6 +143,7 @@ def test_stream_partition(transformer, expected_records): def test_stream_partition_raising_exception(exception_type, expected_display_message): stream = Mock() stream.get_error_display_message.return_value = expected_display_message + stream.name = _STREAM_NAME message_repository = InMemoryMessageRepository() _slice = None @@ -175,10 +170,10 @@ def test_stream_partition_raising_exception(exception_type, expected_display_mes [ pytest.param( {"partition": 1, "k": "v"}, - hash(("stream", '{"k": "v", "partition": 1}')), + 1088629586613270006, id="test_hash_with_slice", ), - pytest.param(None, hash("stream"), id="test_hash_no_slice"), + pytest.param(None, 5149571505982114308, id="test_hash_no_slice"), ], ) def test_stream_partition_hash(_slice, expected_hash): @@ -442,43 +437,3 @@ def test_get_error_display_message(exception, expected_display_message): display_message = facade.get_error_display_message(exception) assert display_message == expected_display_message - - -def test_cursor_partition_generator(): - stream = Mock() - cursor = Mock() - message_repository = Mock() - connector_state_converter = CustomFormatConcurrentStreamStateConverter( - datetime_format="%Y-%m-%dT%H:%M:%S" - ) - cursor_field = Mock() - slice_boundary_fields = ("start", "end") - - expected_slices = [ - StreamSlice( - partition={}, - cursor_slice={"start": "2024-01-01T00:00:00", "end": "2024-01-02T00:00:00"}, - ) - ] - cursor.generate_slices.return_value = [ - (datetime.datetime(year=2024, month=1, day=1), datetime.datetime(year=2024, month=1, day=2)) - ] - - partition_generator = CursorPartitionGenerator( - stream, - message_repository, - cursor, - connector_state_converter, - cursor_field, - slice_boundary_fields, - ) - - partitions = list(partition_generator.generate()) - generated_slices = [partition.to_slice() for partition in partitions] - - assert all( - isinstance(partition, StreamPartition) for partition in partitions - ), "Not all partitions are instances of StreamPartition" - assert ( - generated_slices == expected_slices - ), f"Expected {expected_slices}, but got {generated_slices}" diff --git a/unit_tests/sources/streams/concurrent/test_cursor.py b/unit_tests/sources/streams/concurrent/test_cursor.py index 883f2418f..e7a1f42d9 100644 --- a/unit_tests/sources/streams/concurrent/test_cursor.py +++ b/unit_tests/sources/streams/concurrent/test_cursor.py @@ -10,6 +10,7 @@ import freezegun 
import pytest + from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor @@ -28,6 +29,7 @@ EpochValueConcurrentStreamStateConverter, IsoMillisConcurrentStreamStateConverter, ) +from airbyte_cdk.sources.types import StreamSlice from isodate import parse_duration _A_STREAM_NAME = "a stream name" @@ -227,10 +229,16 @@ def test_given_no_state_when_generate_slices_then_create_slice_from_start_to_end _NO_LOOKBACK_WINDOW, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(10, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 10, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -260,10 +268,16 @@ def test_given_one_slice_when_generate_slices_then_create_slice_from_slice_upper _NO_LOOKBACK_WINDOW, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 20, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -291,10 +305,16 @@ def test_given_start_after_slices_when_generate_slices_then_generate_from_start( _NO_LOOKBACK_WINDOW, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 30, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -328,10 +348,16 @@ def test_given_state_with_gap_and_start_after_slices_when_generate_slices_then_g _NO_LOOKBACK_WINDOW, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 30, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -361,12 +387,30 @@ def test_given_small_slice_range_when_generate_slices_then_create_many_slices(se small_slice_range, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(30, timezone.utc)), - (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(40, timezone.utc)), - (datetime.fromtimestamp(40, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 20, + _SLICE_BOUNDARY_FIELDS[1]: 30, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 30, + _SLICE_BOUNDARY_FIELDS[1]: 40, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 40, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -402,10 +446,16 @@ def 
test_given_difference_between_slices_match_slice_range_when_generate_slices_ small_slice_range, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(40, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 30, + _SLICE_BOUNDARY_FIELDS[1]: 40, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -439,12 +489,30 @@ def test_given_small_slice_range_with_granularity_when_generate_slices_then_crea granularity, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(29, timezone.utc)), - (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(39, timezone.utc)), - (datetime.fromtimestamp(40, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 20, + _SLICE_BOUNDARY_FIELDS[1]: 29, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 30, + _SLICE_BOUNDARY_FIELDS[1]: 39, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 40, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -482,13 +550,16 @@ def test_given_difference_between_slices_match_slice_range_and_cursor_granularit granularity, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - ( - datetime.fromtimestamp(31, timezone.utc), - datetime.fromtimestamp(40, timezone.utc), - ), # FIXME there should probably be the granularity at the beginning too + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 31, + _SLICE_BOUNDARY_FIELDS[1]: 40, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -525,12 +596,30 @@ def test_given_non_continuous_state_when_generate_slices_then_create_slices_betw _NO_LOOKBACK_WINDOW, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(10, timezone.utc), datetime.fromtimestamp(20, timezone.utc)), - (datetime.fromtimestamp(25, timezone.utc), datetime.fromtimestamp(30, timezone.utc)), - (datetime.fromtimestamp(40, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 10, + _SLICE_BOUNDARY_FIELDS[1]: 20, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 25, + _SLICE_BOUNDARY_FIELDS[1]: 30, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 40, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -565,11 +654,23 @@ def test_given_lookback_window_when_generate_slices_then_apply_lookback_on_most_ lookback_window, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(30, timezone.utc)), - (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 20, + _SLICE_BOUNDARY_FIELDS[1]: 30, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + 
_SLICE_BOUNDARY_FIELDS[0]: 30, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) @@ -599,11 +700,23 @@ def test_given_start_is_before_first_slice_lower_boundary_when_generate_slices_t _NO_LOOKBACK_WINDOW, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(0, timezone.utc), datetime.fromtimestamp(10, timezone.utc)), - (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 0, + _SLICE_BOUNDARY_FIELDS[1]: 10, + }, + ), + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 20, + _SLICE_BOUNDARY_FIELDS[1]: 50, + }, + ), ] def test_slices_with_records_when_close_then_most_recent_cursor_value_from_most_recent_slice( @@ -714,10 +827,16 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t slice_range=a_very_big_slice_range, ) - slices = list(cursor.generate_slices()) + slices = list(cursor.stream_slices()) assert slices == [ - (datetime.fromtimestamp(0, timezone.utc), datetime.fromtimestamp(10, timezone.utc)) + StreamSlice( + partition={}, + cursor_slice={ + _SLICE_BOUNDARY_FIELDS[0]: 0, + _SLICE_BOUNDARY_FIELDS[1]: 10, + }, + ), ] @@ -733,30 +852,30 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "P5D", {}, [ - ( - datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 10, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 11, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 20, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 21, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 30, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 31, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 9, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 10, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 19, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 20, 0, 0, tzinfo=timezone.utc), - datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc), - ), + { + "start": "2024-01-01T00:00:00.000Z", + "end": "2024-01-10T23:59:59.000Z", + }, + { + "start": "2024-01-11T00:00:00.000Z", + "end": "2024-01-20T23:59:59.000Z", + }, + { + "start": "2024-01-21T00:00:00.000Z", + "end": "2024-01-30T23:59:59.000Z", + }, + { + "start": "2024-01-31T00:00:00.000Z", + "end": "2024-02-09T23:59:59.000Z", + }, + { + "start": "2024-02-10T00:00:00.000Z", + "end": "2024-02-19T23:59:59.000Z", + }, + { + "start": "2024-02-20T00:00:00.000Z", + "end": "2024-03-01T00:00:00.000Z", + }, ], id="test_datetime_based_cursor_all_fields", ), @@ -776,18 +895,18 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "state_type": "date-range", }, [ - ( - datetime(2024, 2, 5, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 14, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 15, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 24, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 25, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc), - ), + { + "start": "2024-02-05T00:00:00.000Z", + "end": "2024-02-14T23:59:59.000Z", + }, + { + "start": "2024-02-15T00:00:00.000Z", + "end": "2024-02-24T23:59:59.000Z", + }, + { + "start": "2024-02-25T00:00:00.000Z", + "end": "2024-03-01T00:00:00.000Z", + }, ], id="test_datetime_based_cursor_with_state", ), @@ -807,22 
+926,22 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "state_type": "date-range", }, [ - ( - datetime(2024, 1, 20, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 8, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 9, 0, 0, tzinfo=timezone.utc), - datetime(2024, 2, 28, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 29, 0, 0, tzinfo=timezone.utc), - datetime(2024, 3, 19, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 3, 20, 0, 0, tzinfo=timezone.utc), - datetime(2024, 4, 1, 0, 0, 0, tzinfo=timezone.utc), - ), + { + "start": "2024-01-20T00:00:00.000Z", + "end": "2024-02-08T23:59:59.000Z", + }, + { + "start": "2024-02-09T00:00:00.000Z", + "end": "2024-02-28T23:59:59.000Z", + }, + { + "start": "2024-02-29T00:00:00.000Z", + "end": "2024-03-19T23:59:59.000Z", + }, + { + "start": "2024-03-20T00:00:00.000Z", + "end": "2024-04-01T00:00:00.000Z", + }, ], id="test_datetime_based_cursor_with_state_and_end_date", ), @@ -834,14 +953,14 @@ def test_given_overflowing_slice_gap_when_generate_slices_then_cap_upper_bound_t "P5D", {}, [ - ( - datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 1, 31, 23, 59, 59, tzinfo=timezone.utc), - ), - ( - datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc), - datetime(2024, 3, 1, 0, 0, 0, tzinfo=timezone.utc), - ), + { + "start": "2024-01-01T00:00:00.000Z", + "end": "2024-01-31T23:59:59.000Z", + }, + { + "start": "2024-02-01T00:00:00.000Z", + "end": "2024-03-01T00:00:00.000Z", + }, ], id="test_datetime_based_cursor_using_large_step_duration", ), @@ -927,7 +1046,7 @@ def test_generate_slices_concurrent_cursor_from_datetime_based_cursor( cursor_granularity=cursor_granularity, ) - actual_slices = list(cursor.generate_slices()) + actual_slices = list(cursor.stream_slices()) assert actual_slices == expected_slices
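
For reference, the pieces above compose as follows. This is a minimal sketch mirroring the new unit tests: the Mock retriever, SingleSliceSlicer, and make_retriever are stand-ins invented for illustration, not part of the CDK.

    from typing import Iterable
    from unittest.mock import Mock

    from airbyte_cdk.sources.declarative.retrievers import Retriever
    from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
        DeclarativePartitionFactory,
        StreamSlicerPartitionGenerator,
    )
    from airbyte_cdk.sources.message import InMemoryMessageRepository
    from airbyte_cdk.sources.streams.concurrent.partitions.stream_slicer import StreamSlicer
    from airbyte_cdk.sources.types import StreamSlice


    class SingleSliceSlicer(StreamSlicer):
        # Trivial slicer: one empty slice, like the default Cursor.stream_slices().
        def stream_slices(self) -> Iterable[StreamSlice]:
            yield StreamSlice(partition={}, cursor_slice={})


    def make_retriever() -> Retriever:
        # One fresh retriever per partition, since retrievers are not thread-safe.
        retriever = Mock(spec=Retriever)
        retriever.read_records.return_value = iter([{"id": 1}])
        return retriever


    generator = StreamSlicerPartitionGenerator(
        DeclarativePartitionFactory(
            "a_stream_name",
            {"type": "object", "properties": {}},
            make_retriever,
            InMemoryMessageRepository(),
        ),
        SingleSliceSlicer(),
    )

    for partition in generator.generate():
        assert [record.data for record in partition.read()] == [{"id": 1}]

The hash values pinned in the updated test_stream_partition_hash follow directly from the new SliceHasher, which replaces the previous json.dumps-plus-hash() logic with a stable, sha256-based 64-bit integer:

    from airbyte_cdk.utils.slice_hasher import SliceHasher

    assert SliceHasher.hash("stream", {"partition": 1, "k": "v"}) == 1088629586613270006
    assert SliceHasher.hash("stream") == 5149571505982114308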